adding leftover changes

This commit is contained in:
krunal 2023-10-05 12:50:39 +05:30
parent a8234cf2b6
commit 77d934d381
2 changed files with 66 additions and 82 deletions

View File

@ -304,6 +304,11 @@ struct hnsw_index_t {
}
};
struct group_by_field_it_t {
std::string field_name;
posting_list_t::iterator_t it;
};
class Index {
private:
mutable std::shared_mutex mutex;
@ -538,7 +543,9 @@ private:
static void batch_embed_fields(std::vector<index_record*>& documents,
const tsl::htrie_map<char, field>& embedding_fields,
const tsl::htrie_map<char, field> & search_schema, const size_t remote_embedding_batch_size = 200);
std::vector<group_by_field_it_t> get_group_by_field_iterators(const std::vector<std::string>&) const;
public:
// for limiting number of results on multiple candidates / query rewrites
enum {TYPO_TOKENS_THRESHOLD = 1};

View File

@ -1259,6 +1259,21 @@ int64_t Index::get_doc_val_from_sort_index(sort_index_iterator sort_index_it, ui
return INT64_MAX;
}
std::vector<group_by_field_it_t> Index::get_group_by_field_iterators(const std::vector<std::string>& group_by_fields) const {
std::vector<group_by_field_it_t> group_by_field_it_vec;
for (const auto &field_name: group_by_fields) {
if (!facet_index_v4->has_hash_index(field_name)) {
continue;
}
auto facet_index = facet_index_v4->get_facet_hash_index(field_name);
auto facet_index_it = facet_index->new_iterator();
group_by_field_it_t group_by_field_it_struct {field_name, std::move(facet_index_it)};
group_by_field_it_vec.emplace_back(std::move(group_by_field_it_struct));
}
return group_by_field_it_vec;
}
void Index::do_facets(std::vector<facet> & facets, facet_query_t & facet_query,
bool estimate_facets, size_t facet_sample_percent,
const std::vector<facet_info_t>& facet_infos,
@ -1272,18 +1287,9 @@ void Index::do_facets(std::vector<facet> & facets, facet_query_t & facet_query,
return ;
}
auto group_by_field_it_vec = get_group_by_field_iterators(group_by_fields);
size_t total_docs = seq_ids->num_ids();
std::vector<std::pair<posting_list_t::iterator_t, std::string>> group_by_it_field_vec;
for(const auto& field_name: group_by_fields) {
if (!facet_index_v4->has_hash_index(field_name)) {
continue;
}
auto facet_index = facet_index_v4->get_facet_hash_index(field_name);
auto facet_index_it = facet_index->new_iterator();
group_by_it_field_vec.emplace_back(std::make_pair(std::move(facet_index_it), field_name));
}
// assumed that facet fields have already been validated upstream
for(size_t findex=0; findex < facets.size(); findex++) {
auto& a_facet = facets[findex];
@ -1384,8 +1390,8 @@ void Index::do_facets(std::vector<facet> & facets, facet_query_t & facet_query,
uint64_t distinct_id = 0;
if(group_limit) {
distinct_id = 1;
for(auto& kv : group_by_it_field_vec) {
get_distinct_id(kv.second, kv.first, doc_seq_id, group_missing_values, distinct_id);
for(auto& kv : group_by_field_it_vec) {
get_distinct_id(kv.field_name, kv.it, doc_seq_id, group_missing_values, distinct_id);
}
}
//LOG(INFO) << "facet_hash_count " << facet_hash_count;
@ -2387,13 +2393,9 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
uint64_t distinct_id = seq_id;
if (group_limit != 0) {
distinct_id = 1;
for(const auto& field_name: group_by_fields) {
if (!facet_index_v4->has_hash_index(field_name)) {
continue;
}
auto facet_index = facet_index_v4->get_facet_hash_index(field_name);
auto facet_index_it = facet_index->new_iterator();
get_distinct_id(field_name, facet_index_it, seq_id, group_missing_values, distinct_id);
auto group_by_field_it_vec = get_group_by_field_iterators(group_by_fields);
for(auto& kv : group_by_field_it_vec) {
get_distinct_id(kv.field_name, kv.it, seq_id, group_missing_values, distinct_id);
}
if(excluded_group_ids.count(distinct_id) != 0) {
continue;
@ -2516,13 +2518,9 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
uint64_t distinct_id = seq_id;
if (group_limit != 0) {
distinct_id = 1;
for(const auto& field_name: group_by_fields) {
if (!facet_index_v4->has_hash_index(field_name)) {
continue;
}
auto facet_index = facet_index_v4->get_facet_hash_index(field_name);
auto facet_index_it = facet_index->new_iterator();
get_distinct_id(field_name, facet_index_it, seq_id, group_missing_values, distinct_id);
auto group_by_field_it_vec = get_group_by_field_iterators(group_by_fields);
for(auto& kv : group_by_field_it_vec) {
get_distinct_id(kv.field_name, kv.it, seq_id, group_missing_values, distinct_id);
}
if(excluded_group_ids.count(distinct_id) != 0) {
continue;
@ -2954,14 +2952,12 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
uint64_t distinct_id = seq_id;
if (group_limit != 0) {
distinct_id = 1;
for(const auto& field_name: group_by_fields) {
if (!facet_index_v4->has_hash_index(field_name)) {
continue;
}
auto facet_index = facet_index_v4->get_facet_hash_index(field_name);
auto facet_index_it = facet_index->new_iterator();
get_distinct_id(field_name, facet_index_it, seq_id, group_missing_values, distinct_id);
auto group_by_field_it_vec = get_group_by_field_iterators(group_by_fields);
for(auto& kv : group_by_field_it_vec) {
get_distinct_id(kv.field_name, kv.it, seq_id, group_missing_values, distinct_id);
}
if(excluded_group_ids.count(distinct_id) != 0) {
continue;
}
@ -3221,14 +3217,12 @@ void Index::process_curated_ids(const std::vector<std::pair<uint32_t, uint32_t>>
// if one `id` of a group is present in curated hits, we have to exclude that entire group from results
for(auto seq_id: included_ids_vec) {
uint64_t distinct_id = 1;
for(const auto& field_name: group_by_fields) {
if (!facet_index_v4->has_hash_index(field_name)) {
continue;
}
auto facet_index = facet_index_v4->get_facet_hash_index(field_name);
auto facet_index_it = facet_index->new_iterator();
get_distinct_id(field_name, facet_index_it, seq_id, group_missing_values, distinct_id);
auto group_by_field_it_vec = get_group_by_field_iterators(group_by_fields);
for(auto& kv : group_by_field_it_vec) {
get_distinct_id(kv.field_name, kv.it, seq_id, group_missing_values, distinct_id);
}
excluded_group_ids.emplace(distinct_id);
}
}
@ -3947,14 +3941,12 @@ Option<bool> Index::search_across_fields(const std::vector<token_t>& query_token
uint64_t distinct_id = seq_id;
if(group_limit != 0) {
distinct_id = 1;
for(const auto& field_name: group_by_fields) {
if (!facet_index_v4->has_hash_index(field_name)) {
continue;
}
auto facet_index = facet_index_v4->get_facet_hash_index(field_name);
auto facet_index_it = facet_index->new_iterator();
get_distinct_id(field_name, facet_index_it, seq_id, group_missing_values, distinct_id);
auto group_by_field_it_vec = get_group_by_field_iterators(group_by_fields);
for(auto& kv : group_by_field_it_vec) {
get_distinct_id(kv.field_name, kv.it, seq_id, group_missing_values, distinct_id);
}
if(excluded_group_ids.count(distinct_id) != 0) {
return;
}
@ -4764,14 +4756,12 @@ Option<bool> Index::do_phrase_search(const size_t num_search_fields, const std::
uint64_t distinct_id = seq_id;
if(group_limit != 0) {
distinct_id = 1;
for(const auto& field_name: group_by_fields) {
if (!facet_index_v4->has_hash_index(field_name)) {
continue;
}
auto facet_index = facet_index_v4->get_facet_hash_index(field_name);
auto facet_index_it = facet_index->new_iterator();
get_distinct_id(field_name, facet_index_it, seq_id, group_missing_values, distinct_id);
auto group_by_field_it_vec = get_group_by_field_iterators(group_by_fields);
for(auto& kv : group_by_field_it_vec) {
get_distinct_id(kv.field_name, kv.it, seq_id, group_missing_values, distinct_id);
}
if(excluded_group_ids.count(distinct_id) != 0) {
continue;
}
@ -4934,14 +4924,12 @@ Option<bool> Index::do_infix_search(const size_t num_search_fields, const std::v
uint64_t distinct_id = seq_id;
if(group_limit != 0) {
distinct_id = 1;
for(const auto& field_name: group_by_fields) {
if (!facet_index_v4->has_hash_index(field_name)) {
continue;
}
auto facet_index = facet_index_v4->get_facet_hash_index(field_name);
auto facet_index_it = facet_index->new_iterator();
get_distinct_id(field_name, facet_index_it, seq_id, group_missing_values, distinct_id);
auto group_by_field_it_vec = get_group_by_field_iterators(group_by_fields);
for(auto& kv : group_by_field_it_vec) {
get_distinct_id(kv.field_name, kv.it, seq_id, group_missing_values, distinct_id);
}
if(excluded_group_ids.count(distinct_id) != 0) {
continue;
}
@ -5301,16 +5289,6 @@ Option<bool> Index::search_wildcard(filter_node_t const* const& filter_tree_root
search_stop_us = parent_search_stop_ms;
search_cutoff = parent_search_cutoff;
std::vector<std::pair<posting_list_t::iterator_t, std::string>> group_by_it_field_vec;
for(const auto& field_name: group_by_fields) {
if (!facet_index_v4->has_hash_index(field_name)) {
continue;
}
auto facet_index = facet_index_v4->get_facet_hash_index(field_name);
auto facet_index_it = facet_index->new_iterator();
group_by_it_field_vec.emplace_back(std::make_pair(std::move(facet_index_it), field_name));
}
std::vector<uint32_t> filter_indexes;
for(size_t i = 0; i < batch_result->count; i++) {
@ -5339,8 +5317,10 @@ Option<bool> Index::search_wildcard(filter_node_t const* const& filter_tree_root
uint64_t distinct_id = seq_id;
if(group_limit != 0) {
distinct_id = 1;
for(auto& kv: group_by_it_field_vec) {
get_distinct_id(kv.second, kv.first, seq_id, group_missing_values, distinct_id);
auto group_by_field_it_vec = get_group_by_field_iterators(group_by_fields);
for(auto& kv : group_by_field_it_vec) {
get_distinct_id(kv.field_name, kv.it, seq_id, group_missing_values, distinct_id);
}
if(excluded_group_ids.count(distinct_id) != 0) {
@ -5997,13 +5977,10 @@ void Index::score_results(const std::vector<sort_by> & sort_fields, const uint16
if(group_limit != 0) {
distinct_id = 1;
for(const auto& field_name: group_by_fields) {
if (!facet_index_v4->has_hash_index(field_name)) {
continue;
}
auto facet_index = facet_index_v4->get_facet_hash_index(field_name);
auto facet_index_it = facet_index->new_iterator();
get_distinct_id(field_name, facet_index_it, seq_id, group_missing_values, distinct_id);
auto group_by_field_it_vec = get_group_by_field_iterators(group_by_fields);
for(auto& kv : group_by_field_it_vec) {
get_distinct_id(kv.field_name, kv.it, seq_id, group_missing_values, distinct_id);
}
}