From 77d934d381357d32bc373262d69fd72fee07ecee Mon Sep 17 00:00:00 2001 From: krunal Date: Thu, 5 Oct 2023 12:50:39 +0530 Subject: [PATCH] adding leftover changes --- include/index.h | 9 +++- src/index.cpp | 139 ++++++++++++++++++++---------------------------- 2 files changed, 66 insertions(+), 82 deletions(-) diff --git a/include/index.h b/include/index.h index e60a9e21..95cdc1e7 100644 --- a/include/index.h +++ b/include/index.h @@ -304,6 +304,11 @@ struct hnsw_index_t { } }; +struct group_by_field_it_t { + std::string field_name; + posting_list_t::iterator_t it; +}; + class Index { private: mutable std::shared_mutex mutex; @@ -538,7 +543,9 @@ private: static void batch_embed_fields(std::vector& documents, const tsl::htrie_map& embedding_fields, const tsl::htrie_map & search_schema, const size_t remote_embedding_batch_size = 200); - + + std::vector get_group_by_field_iterators(const std::vector&) const; + public: // for limiting number of results on multiple candidates / query rewrites enum {TYPO_TOKENS_THRESHOLD = 1}; diff --git a/src/index.cpp b/src/index.cpp index 72b4e0f4..c86cb929 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1259,6 +1259,21 @@ int64_t Index::get_doc_val_from_sort_index(sort_index_iterator sort_index_it, ui return INT64_MAX; } +std::vector Index::get_group_by_field_iterators(const std::vector& group_by_fields) const { + std::vector group_by_field_it_vec; + for (const auto &field_name: group_by_fields) { + if (!facet_index_v4->has_hash_index(field_name)) { + continue; + } + auto facet_index = facet_index_v4->get_facet_hash_index(field_name); + auto facet_index_it = facet_index->new_iterator(); + + group_by_field_it_t group_by_field_it_struct {field_name, std::move(facet_index_it)}; + group_by_field_it_vec.emplace_back(std::move(group_by_field_it_struct)); + } + return group_by_field_it_vec; +} + void Index::do_facets(std::vector & facets, facet_query_t & facet_query, bool estimate_facets, size_t facet_sample_percent, const std::vector& facet_infos, @@ -1272,18 +1287,9 @@ void Index::do_facets(std::vector & facets, facet_query_t & facet_query, return ; } + auto group_by_field_it_vec = get_group_by_field_iterators(group_by_fields); + size_t total_docs = seq_ids->num_ids(); - - std::vector> group_by_it_field_vec; - for(const auto& field_name: group_by_fields) { - if (!facet_index_v4->has_hash_index(field_name)) { - continue; - } - auto facet_index = facet_index_v4->get_facet_hash_index(field_name); - auto facet_index_it = facet_index->new_iterator(); - group_by_it_field_vec.emplace_back(std::make_pair(std::move(facet_index_it), field_name)); - } - // assumed that facet fields have already been validated upstream for(size_t findex=0; findex < facets.size(); findex++) { auto& a_facet = facets[findex]; @@ -1384,8 +1390,8 @@ void Index::do_facets(std::vector & facets, facet_query_t & facet_query, uint64_t distinct_id = 0; if(group_limit) { distinct_id = 1; - for(auto& kv : group_by_it_field_vec) { - get_distinct_id(kv.second, kv.first, doc_seq_id, group_missing_values, distinct_id); + for(auto& kv : group_by_field_it_vec) { + get_distinct_id(kv.field_name, kv.it, doc_seq_id, group_missing_values, distinct_id); } } //LOG(INFO) << "facet_hash_count " << facet_hash_count; @@ -2387,13 +2393,9 @@ Option Index::search(std::vector& field_query_tokens, cons uint64_t distinct_id = seq_id; if (group_limit != 0) { distinct_id = 1; - for(const auto& field_name: group_by_fields) { - if (!facet_index_v4->has_hash_index(field_name)) { - continue; - } - auto facet_index = facet_index_v4->get_facet_hash_index(field_name); - auto facet_index_it = facet_index->new_iterator(); - get_distinct_id(field_name, facet_index_it, seq_id, group_missing_values, distinct_id); + auto group_by_field_it_vec = get_group_by_field_iterators(group_by_fields); + for(auto& kv : group_by_field_it_vec) { + get_distinct_id(kv.field_name, kv.it, seq_id, group_missing_values, distinct_id); } if(excluded_group_ids.count(distinct_id) != 0) { continue; @@ -2516,13 +2518,9 @@ Option Index::search(std::vector& field_query_tokens, cons uint64_t distinct_id = seq_id; if (group_limit != 0) { distinct_id = 1; - for(const auto& field_name: group_by_fields) { - if (!facet_index_v4->has_hash_index(field_name)) { - continue; - } - auto facet_index = facet_index_v4->get_facet_hash_index(field_name); - auto facet_index_it = facet_index->new_iterator(); - get_distinct_id(field_name, facet_index_it, seq_id, group_missing_values, distinct_id); + auto group_by_field_it_vec = get_group_by_field_iterators(group_by_fields); + for(auto& kv : group_by_field_it_vec) { + get_distinct_id(kv.field_name, kv.it, seq_id, group_missing_values, distinct_id); } if(excluded_group_ids.count(distinct_id) != 0) { continue; @@ -2954,14 +2952,12 @@ Option Index::search(std::vector& field_query_tokens, cons uint64_t distinct_id = seq_id; if (group_limit != 0) { distinct_id = 1; - for(const auto& field_name: group_by_fields) { - if (!facet_index_v4->has_hash_index(field_name)) { - continue; - } - auto facet_index = facet_index_v4->get_facet_hash_index(field_name); - auto facet_index_it = facet_index->new_iterator(); - get_distinct_id(field_name, facet_index_it, seq_id, group_missing_values, distinct_id); + auto group_by_field_it_vec = get_group_by_field_iterators(group_by_fields); + + for(auto& kv : group_by_field_it_vec) { + get_distinct_id(kv.field_name, kv.it, seq_id, group_missing_values, distinct_id); } + if(excluded_group_ids.count(distinct_id) != 0) { continue; } @@ -3221,14 +3217,12 @@ void Index::process_curated_ids(const std::vector> // if one `id` of a group is present in curated hits, we have to exclude that entire group from results for(auto seq_id: included_ids_vec) { uint64_t distinct_id = 1; - for(const auto& field_name: group_by_fields) { - if (!facet_index_v4->has_hash_index(field_name)) { - continue; - } - auto facet_index = facet_index_v4->get_facet_hash_index(field_name); - auto facet_index_it = facet_index->new_iterator(); - get_distinct_id(field_name, facet_index_it, seq_id, group_missing_values, distinct_id); + auto group_by_field_it_vec = get_group_by_field_iterators(group_by_fields); + + for(auto& kv : group_by_field_it_vec) { + get_distinct_id(kv.field_name, kv.it, seq_id, group_missing_values, distinct_id); } + excluded_group_ids.emplace(distinct_id); } } @@ -3947,14 +3941,12 @@ Option Index::search_across_fields(const std::vector& query_token uint64_t distinct_id = seq_id; if(group_limit != 0) { distinct_id = 1; - for(const auto& field_name: group_by_fields) { - if (!facet_index_v4->has_hash_index(field_name)) { - continue; - } - auto facet_index = facet_index_v4->get_facet_hash_index(field_name); - auto facet_index_it = facet_index->new_iterator(); - get_distinct_id(field_name, facet_index_it, seq_id, group_missing_values, distinct_id); + auto group_by_field_it_vec = get_group_by_field_iterators(group_by_fields); + + for(auto& kv : group_by_field_it_vec) { + get_distinct_id(kv.field_name, kv.it, seq_id, group_missing_values, distinct_id); } + if(excluded_group_ids.count(distinct_id) != 0) { return; } @@ -4764,14 +4756,12 @@ Option Index::do_phrase_search(const size_t num_search_fields, const std:: uint64_t distinct_id = seq_id; if(group_limit != 0) { distinct_id = 1; - for(const auto& field_name: group_by_fields) { - if (!facet_index_v4->has_hash_index(field_name)) { - continue; - } - auto facet_index = facet_index_v4->get_facet_hash_index(field_name); - auto facet_index_it = facet_index->new_iterator(); - get_distinct_id(field_name, facet_index_it, seq_id, group_missing_values, distinct_id); + auto group_by_field_it_vec = get_group_by_field_iterators(group_by_fields); + + for(auto& kv : group_by_field_it_vec) { + get_distinct_id(kv.field_name, kv.it, seq_id, group_missing_values, distinct_id); } + if(excluded_group_ids.count(distinct_id) != 0) { continue; } @@ -4934,14 +4924,12 @@ Option Index::do_infix_search(const size_t num_search_fields, const std::v uint64_t distinct_id = seq_id; if(group_limit != 0) { distinct_id = 1; - for(const auto& field_name: group_by_fields) { - if (!facet_index_v4->has_hash_index(field_name)) { - continue; - } - auto facet_index = facet_index_v4->get_facet_hash_index(field_name); - auto facet_index_it = facet_index->new_iterator(); - get_distinct_id(field_name, facet_index_it, seq_id, group_missing_values, distinct_id); + auto group_by_field_it_vec = get_group_by_field_iterators(group_by_fields); + + for(auto& kv : group_by_field_it_vec) { + get_distinct_id(kv.field_name, kv.it, seq_id, group_missing_values, distinct_id); } + if(excluded_group_ids.count(distinct_id) != 0) { continue; } @@ -5301,16 +5289,6 @@ Option Index::search_wildcard(filter_node_t const* const& filter_tree_root search_stop_us = parent_search_stop_ms; search_cutoff = parent_search_cutoff; - std::vector> group_by_it_field_vec; - for(const auto& field_name: group_by_fields) { - if (!facet_index_v4->has_hash_index(field_name)) { - continue; - } - auto facet_index = facet_index_v4->get_facet_hash_index(field_name); - auto facet_index_it = facet_index->new_iterator(); - group_by_it_field_vec.emplace_back(std::make_pair(std::move(facet_index_it), field_name)); - } - std::vector filter_indexes; for(size_t i = 0; i < batch_result->count; i++) { @@ -5339,8 +5317,10 @@ Option Index::search_wildcard(filter_node_t const* const& filter_tree_root uint64_t distinct_id = seq_id; if(group_limit != 0) { distinct_id = 1; - for(auto& kv: group_by_it_field_vec) { - get_distinct_id(kv.second, kv.first, seq_id, group_missing_values, distinct_id); + auto group_by_field_it_vec = get_group_by_field_iterators(group_by_fields); + + for(auto& kv : group_by_field_it_vec) { + get_distinct_id(kv.field_name, kv.it, seq_id, group_missing_values, distinct_id); } if(excluded_group_ids.count(distinct_id) != 0) { @@ -5997,13 +5977,10 @@ void Index::score_results(const std::vector & sort_fields, const uint16 if(group_limit != 0) { distinct_id = 1; - for(const auto& field_name: group_by_fields) { - if (!facet_index_v4->has_hash_index(field_name)) { - continue; - } - auto facet_index = facet_index_v4->get_facet_hash_index(field_name); - auto facet_index_it = facet_index->new_iterator(); - get_distinct_id(field_name, facet_index_it, seq_id, group_missing_values, distinct_id); + auto group_by_field_it_vec = get_group_by_field_iterators(group_by_fields); + + for(auto& kv : group_by_field_it_vec) { + get_distinct_id(kv.field_name, kv.it, seq_id, group_missing_values, distinct_id); } }