diff --git a/TODO.md b/TODO.md index adeea108..8c3bed2d 100644 --- a/TODO.md +++ b/TODO.md @@ -99,8 +99,8 @@ - ~~Use snappy compression for storage~~ - ~~Fix exclude_scalar early returns~~ - ~~Fix result ids length during grouped overrides~~ -- Fix override grouping (collate_included_ids) -- Test for overriding result on second page +- ~~Fix override grouping (collate_included_ids)~~ +- ~~Test for overriding result on second page~~ - atleast 1 token match for proceeding with drop tokens - support wildcard query with filters - API for optimizing on disk storage diff --git a/include/collection.h b/include/collection.h index 4e06131d..648ca384 100644 --- a/include/collection.h +++ b/include/collection.h @@ -155,10 +155,10 @@ private: void remove_document(nlohmann::json & document, const uint32_t seq_id, bool remove_from_store); void populate_overrides(std::string query, - const std::map& pinned_hits, + const std::map>& pinned_hits, const std::vector& hidden_hits, - std::map & id_pos_map, - std::vector & included_ids, std::vector & excluded_ids); + std::map>& include_ids, + std::vector & excluded_ids); static bool facet_count_compare(const std::pair& a, const std::pair& b) { @@ -236,7 +236,7 @@ public: const size_t snippet_threshold = 30, const std::string & highlight_full_fields = "", size_t typo_tokens_threshold = Index::TYPO_TOKENS_THRESHOLD, - const std::map& pinned_hits={}, + const std::map>& pinned_hits={}, const std::vector& hidden_hits={}, const std::vector& group_by_fields={}, const size_t group_limit = 0); diff --git a/include/index.h b/include/index.h index 65cf951e..f2eb5fcc 100644 --- a/include/index.h +++ b/include/index.h @@ -27,7 +27,7 @@ struct search_args { std::vector search_fields; std::vector filters; std::vector facets; - std::vector included_ids; + std::map> included_ids; std::vector excluded_ids; std::vector sort_fields_std; facet_query_t facet_query; @@ -47,7 +47,7 @@ struct search_args { Topster* topster; Topster* curated_topster; std::vector> raw_result_kvs; - std::vector override_result_kvs; + std::vector> override_result_kvs; Option outcome; search_args(): outcome(0) { @@ -55,7 +55,7 @@ struct search_args { } search_args(std::string query, std::vector search_fields, std::vector filters, - std::vector facets, std::vector included_ids, std::vector excluded_ids, + std::vector facets, std::map> included_ids, std::vector excluded_ids, std::vector sort_fields_std, facet_query_t facet_query, int num_typos, size_t max_facet_values, size_t max_hits, size_t per_page, size_t page, token_ordering token_order, bool prefix, size_t drop_tokens_threshold, size_t typo_tokens_threshold, @@ -70,7 +70,7 @@ struct search_args { const size_t topster_size = std::max((size_t)1, max_hits); // needs to be atleast 1 since scoring is mandatory topster = new Topster(topster_size, group_limit); - curated_topster = new Topster(topster_size); + curated_topster = new Topster(topster_size, group_limit); } ~search_args() { @@ -213,7 +213,7 @@ private: const uint32_t indices_length); void collate_included_ids(const std::string & query, const std::string & field, const uint8_t field_id, - const std::vector & included_ids, + const std::map> & included_ids_map, Topster* curated_topster, std::vector> & searched_queries); uint64_t facet_token_hash(const field & a_field, const std::string &token); @@ -239,7 +239,8 @@ public: void search(Option & outcome, const std::string & query, const std::vector & search_fields, const std::vector & filters, std::vector & facets, facet_query_t & facet_query, - const std::vector & included_ids, const std::vector & excluded_ids, + const std::map> & included_ids_map, + const std::vector & excluded_ids, const std::vector & sort_fields_std, const int num_typos, Topster* topster, Topster* curated_topster, const size_t per_page, const size_t page, const token_ordering token_order, @@ -247,7 +248,8 @@ public: size_t & all_result_ids_len, spp::sparse_hash_set& groups_processed, std::vector> & searched_queries, - std::vector> & raw_result_kvs, std::vector & override_result_kvs, + std::vector> & raw_result_kvs, + std::vector> & override_result_kvs, const size_t typo_tokens_threshold); Option remove(const uint32_t seq_id, nlohmann::json & document); diff --git a/src/collection.cpp b/src/collection.cpp index 24547c62..a51a7e29 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -305,54 +305,70 @@ void Collection::prune_document(nlohmann::json &document, const spp::sparse_hash } void Collection::populate_overrides(std::string query, - const std::map& pinned_hits, + const std::map>& pinned_hits, const std::vector& hidden_hits, - std::map & id_pos_map, - std::vector & included_ids, + std::map>& include_ids, std::vector & excluded_ids) { StringUtils::tolowercase(query); + std::set excluded_set; + + // If pinned or hidden hits are provided, they take precedence over overrides + + // have to ensure that hidden hits take precedence over included hits + if(!hidden_hits.empty()) { + for(const auto & hit: hidden_hits) { + Option seq_id_op = doc_id_to_seq_id(hit); + if(seq_id_op.ok()) { + excluded_ids.push_back(seq_id_op.get()); + excluded_set.insert(seq_id_op.get()); + } + } + } for(const auto & override_kv: overrides) { const auto & override = override_kv.second; if( (override.rule.match == override_t::MATCH_EXACT && override.rule.query == query) || (override.rule.match == override_t::MATCH_CONTAINS && query.find(override.rule.query) != std::string::npos) ) { - for(const auto & hit: override.add_hits) { - Option seq_id_op = doc_id_to_seq_id(hit.doc_id); - if(seq_id_op.ok()) { - included_ids.push_back(seq_id_op.get()); - id_pos_map[seq_id_op.get()] = hit.position; - } - } + // have to ensure that dropped hits take precedence over added hits for(const auto & hit: override.drop_hits) { Option seq_id_op = doc_id_to_seq_id(hit.doc_id); if(seq_id_op.ok()) { excluded_ids.push_back(seq_id_op.get()); + excluded_set.insert(seq_id_op.get()); } } + + for(const auto & hit: override.add_hits) { + Option seq_id_op = doc_id_to_seq_id(hit.doc_id); + if(!seq_id_op.ok()) { + continue; + } + uint32_t seq_id = seq_id_op.get(); + bool excluded = (excluded_set.count(seq_id) != 0); + if(!excluded) { + include_ids[hit.position].push_back(seq_id); + } + } + + break; } } - // If pinned or hidden hits are provided, they take precedence over overrides - if(!pinned_hits.empty()) { - for(const auto & hit: pinned_hits) { - Option seq_id_op = doc_id_to_seq_id(hit.first); - if(seq_id_op.ok()) { - included_ids.push_back(seq_id_op.get()); - id_pos_map[seq_id_op.get()] = hit.second; - } - } - } - - if(!hidden_hits.empty()) { - for(const auto & hit: hidden_hits) { - Option seq_id_op = doc_id_to_seq_id(hit); - if(seq_id_op.ok()) { - included_ids.erase(std::remove(included_ids.begin(), included_ids.end(), seq_id_op.get()), included_ids.end()); - id_pos_map.erase(seq_id_op.get()); - excluded_ids.push_back(seq_id_op.get()); + for(const auto& pos_ids: pinned_hits) { + size_t pos = pos_ids.first; + for(const std::string& id: pos_ids.second) { + Option seq_id_op = doc_id_to_seq_id(id); + if(!seq_id_op.ok()) { + continue; + } + uint32_t seq_id = seq_id_op.get(); + bool excluded = (excluded_set.count(seq_id) != 0); + if(!excluded) { + include_ids[pos].push_back(seq_id); + } } } } @@ -371,22 +387,40 @@ Option Collection::search(const std::string & query, const std:: const size_t snippet_threshold, const std::string & highlight_full_fields, size_t typo_tokens_threshold, - const std::map& pinned_hits, + const std::map>& pinned_hits, const std::vector& hidden_hits, const std::vector& group_by_fields, const size_t group_limit) { - std::vector included_ids; std::vector excluded_ids; - std::map id_pos_map; - populate_overrides(query, pinned_hits, hidden_hits, id_pos_map, included_ids, excluded_ids); + std::map> include_ids; // position => list of IDs + populate_overrides(query, pinned_hits, hidden_hits, include_ids, excluded_ids); - std::map> index_to_included_ids; + /*for(auto kv: include_ids) { + LOG(INFO) << "key: " << kv.first; + for(auto val: kv.second) { + LOG(INFO) << val; + } + } + + LOG(INFO) << "Excludes:"; + + for(auto id: excluded_ids) { + LOG(INFO) << id; + }*/ + + //LOG(INFO) << "include_ids size: " << include_ids.size(); + //LOG(INFO) << "Pos 1: " << include_ids[1][0]; + + std::map>> index_to_included_ids; std::map> index_to_excluded_ids; - for(auto seq_id: included_ids) { - auto index_id = (seq_id % num_indices); - index_to_included_ids[index_id].push_back(seq_id); + for(const auto& pos_ids: include_ids) { + size_t position = pos_ids.first; + for(auto seq_id: pos_ids.second) { + auto index_id = (seq_id % num_indices); + index_to_included_ids[index_id][position].push_back(seq_id); + } } for(auto seq_id: excluded_ids) { @@ -650,7 +684,7 @@ Option Collection::search(const std::string & query, const std:: std::vector> searched_queries; // search queries used for generating the results std::vector> raw_result_kvs; - std::vector override_result_kvs; + std::vector> override_result_kvs; size_t total_found = 0; spp::sparse_hash_set groups_processed; // used to calculate total_found for grouped query @@ -697,9 +731,9 @@ Option Collection::search(const std::string & query, const std:: raw_result_kvs.push_back(kv_group); } - for(auto & field_order_kv: index->search_params->override_result_kvs) { - field_order_kv->query_index += searched_queries.size(); - override_result_kvs.push_back(field_order_kv); + for(const std::vector & kv_group: index->search_params->override_result_kvs) { + kv_group[0]->query_index += searched_queries.size(); + override_result_kvs.push_back(kv_group); } searched_queries.insert(searched_queries.end(), index->search_params->searched_queries.begin(), @@ -797,8 +831,8 @@ Option Collection::search(const std::string & query, const std:: // Sort based on position in overridden list std::sort( override_result_kvs.begin(), override_result_kvs.end(), - [&id_pos_map](const KV* a, const KV* b) -> bool { - return id_pos_map[a->key] < id_pos_map[b->key]; + [](const std::vector& a, std::vector& b) -> bool { + return a[0]->distinct_key < b[0]->distinct_key; } ); @@ -808,11 +842,12 @@ Option Collection::search(const std::string & query, const std:: // merge raw results and override results while(override_kv_index < override_result_kvs.size() && raw_results_index < raw_result_kvs.size()) { - if(override_kv_index < override_result_kvs.size() && - id_pos_map.count(override_result_kvs[override_kv_index]->key) != 0 && - result_group_kvs.size() + 1 == id_pos_map[override_result_kvs[override_kv_index]->key]) { - result_group_kvs.push_back({override_result_kvs[override_kv_index]}); - override_kv_index++; + size_t result_position = result_group_kvs.size() + 1; + uint64_t override_position = override_result_kvs[override_kv_index][0]->distinct_key; + + if(result_position == override_position) { + result_group_kvs.push_back(override_result_kvs[override_kv_index]); + override_kv_index++; } else { result_group_kvs.push_back(raw_result_kvs[raw_results_index]); raw_results_index++; diff --git a/src/core_api.cpp b/src/core_api.cpp index 18c96629..03f89a2f 100644 --- a/src/core_api.cpp +++ b/src/core_api.cpp @@ -390,7 +390,8 @@ bool get_search(http_req & req, http_res & res) { } } - std::map pinned_hits; + std::map> pinned_hits; + if(req.params.count(PINNED_HITS) != 0) { std::vector pinned_hits_strs; StringUtils::split(req.params[PINNED_HITS], pinned_hits_strs, ","); @@ -415,7 +416,7 @@ bool get_search(http_req & req, http_res & res) { return false; } - pinned_hits.emplace(expression_parts[0], position); + pinned_hits[position].emplace_back(expression_parts[0]); } } diff --git a/src/index.cpp b/src/index.cpp index fcbbc478..6fdee499 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1038,11 +1038,11 @@ void Index::run_search() { } void Index::collate_included_ids(const std::string & query, const std::string & field, const uint8_t field_id, - const std::vector & included_ids, + const std::map> & included_ids_map, Topster* curated_topster, std::vector> & searched_queries) { - if(included_ids.empty()) { + if(included_ids_map.empty()) { return; } @@ -1061,50 +1061,28 @@ void Index::collate_included_ids(const std::string & query, const std::string & art_fuzzy_search(search_index.at(field), (const unsigned char *) token.c_str(), token_len, 0, 0, 1, token_ordering::MAX_SCORE, false, leaves); - if(leaves.size() > 0) { + if(!leaves.empty()) { override_query.push_back(leaves[0]); } } - spp::sparse_hash_map leaf_to_indices; + for(const auto& pos_ids: included_ids_map) { + const size_t pos = pos_ids.first; - for (art_leaf *token_leaf : override_query) { - uint32_t *indices = new uint32_t[included_ids.size()]; - token_leaf->values->ids.indexOf(&included_ids[0], included_ids.size(), indices); - leaf_to_indices.emplace(token_leaf, indices); - } + for(size_t i = 0; i < pos_ids.second.size(); i++) { + uint32_t seq_id = pos_ids.second[i]; - // curated_topster->MAX_SIZE is initialized based on max_hits. - // Even if override has more IDs, we should restrict to max hits. - size_t iter_size = std::min((size_t)curated_topster->MAX_SIZE, included_ids.size()); + uint64_t distinct_id = pos; // position is the group distinct key + uint64_t match_score = (64000 - i); // index within a group is the match score - for(size_t j=0; j>> array_token_positions; - populate_token_positions(override_query, leaf_to_indices, j, array_token_positions); - - uint64_t match_score = 0; - - for(const std::vector> & token_positions: array_token_positions) { - if(token_positions.empty()) { - continue; - } - const Match & match = Match::match(seq_id, token_positions); - uint64_t this_match_score = match.get_match_score(0, field_id); - - if(this_match_score > match_score) { - match_score = this_match_score; - } + KV kv(field_id, searched_queries.size(), seq_id, distinct_id, match_score, scores); + curated_topster->add(&kv); } - - int64_t scores[3]; - scores[0] = int64_t(match_score); - scores[1] = int64_t(1); - scores[2] = int64_t(1); - - KV kv(field_id, searched_queries.size(), seq_id, seq_id, match_score, scores); - curated_topster->add(&kv); } searched_queries.push_back(override_query); @@ -1115,7 +1093,7 @@ void Index::search(Option & outcome, const std::vector & search_fields, const std::vector & filters, std::vector & facets, facet_query_t & facet_query, - const std::vector & included_ids, + const std::map> & included_ids_map, const std::vector & excluded_ids, const std::vector & sort_fields_std, const int num_typos, Topster* topster, @@ -1126,7 +1104,7 @@ void Index::search(Option & outcome, spp::sparse_hash_set& groups_processed, std::vector>& searched_queries, std::vector> & raw_result_kvs, - std::vector & override_result_kvs, + std::vector> & override_result_kvs, const size_t typo_tokens_threshold) { // process the filters @@ -1141,7 +1119,16 @@ void Index::search(Option & outcome, uint32_t filter_ids_length = op_filter_ids_length.get(); // we will be removing all curated IDs from organic result ids before running topster - std::set curated_ids(included_ids.begin(), included_ids.end()); + std::set curated_ids; + std::vector included_ids; + + for(const auto& pos_ids: included_ids_map) { + for(const uint32_t id: pos_ids.second) { + curated_ids.insert(id); + included_ids.push_back(id); + } + } + curated_ids.insert(excluded_ids.begin(), excluded_ids.end()); std::vector curated_ids_sorted(curated_ids.begin(), curated_ids.end()); @@ -1166,7 +1153,7 @@ void Index::search(Option & outcome, score_results(sort_fields_std, (uint16_t) searched_queries.size(), field_id, 0, topster, {}, groups_processed, filter_ids, filter_ids_length); - collate_included_ids(query, field, field_id, included_ids, curated_topster, searched_queries); + collate_included_ids(query, field, field_id, included_ids_map, curated_topster, searched_queries); all_result_ids_len = filter_ids_length; all_result_ids = filter_ids; @@ -1182,7 +1169,7 @@ void Index::search(Option & outcome, search_field(field_id, query, field, filter_ids, filter_ids_length, curated_ids_sorted, facets, sort_fields_std, num_typos, searched_queries, topster, groups_processed, &all_result_ids, all_result_ids_len, token_order, prefix, drop_tokens_threshold, typo_tokens_threshold); - collate_included_ids(query, field, field_id, included_ids, curated_topster, searched_queries); + collate_included_ids(query, field, field_id, included_ids_map, curated_topster, searched_queries); } } } @@ -1202,16 +1189,23 @@ void Index::search(Option & outcome, const std::vector group_kvs(group_topster->kvs, group_topster->kvs+group_topster->size); raw_result_kvs.emplace_back(group_kvs); } + + for(auto &curated_topster_entry: curated_topster->group_kv_map) { + Topster* group_topster = curated_topster_entry.second; + const std::vector group_kvs(group_topster->kvs, group_topster->kvs+group_topster->size); + override_result_kvs.emplace_back(group_kvs); + } + } else { for(uint32_t t = 0; t < topster->size; t++) { KV* kv = topster->getKV(t); raw_result_kvs.push_back({kv}); } - } - for(uint32_t t = 0; t < curated_topster->size; t++) { - KV* kv = curated_topster->getKV(t); - override_result_kvs.push_back(kv); + for(uint32_t t = 0; t < curated_topster->size; t++) { + KV* kv = curated_topster->getKV(t); + override_result_kvs.push_back({kv}); + } } // add curated IDs to result count diff --git a/test/collection_override_test.cpp b/test/collection_override_test.cpp index 49824142..eaf153f7 100644 --- a/test/collection_override_test.cpp +++ b/test/collection_override_test.cpp @@ -256,10 +256,9 @@ TEST_F(CollectionOverrideTest, ExcludeIncludeFacetFilterQuery) { } TEST_F(CollectionOverrideTest, IncludeExcludeHitsQuery) { - std::map pinned_hits; - std::vector hidden_hits; - pinned_hits["13"] = 1; - pinned_hits["4"] = 2; + std::map> pinned_hits; + pinned_hits[1] = {"13"}; + pinned_hits[2] = {"4"}; // basic pinning @@ -279,6 +278,7 @@ TEST_F(CollectionOverrideTest, IncludeExcludeHitsQuery) { // both pinning and hiding + std::vector hidden_hits; hidden_hits = {"11", "16"}; results = coll_mul_fields->search("the", {"title"}, "", {"starring"}, {}, 0, 50, 1, FREQUENCY, false, Index::DROP_TOKENS_THRESHOLD, @@ -291,6 +291,21 @@ TEST_F(CollectionOverrideTest, IncludeExcludeHitsQuery) { ASSERT_STREQ("4", results["hits"][1]["document"]["id"].get().c_str()); ASSERT_STREQ("6", results["hits"][2]["document"]["id"].get().c_str()); + // paginating such that pinned hits appear on second page + pinned_hits.clear(); + pinned_hits[4] = {"13"}; + pinned_hits[5] = {"4"}; + + results = coll_mul_fields->search("the", {"title"}, "", {"starring"}, {}, 0, 2, 2, FREQUENCY, + false, Index::DROP_TOKENS_THRESHOLD, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "starring: will", 30, + "", 10, + pinned_hits, hidden_hits).get(); + + ASSERT_STREQ("1", results["hits"][0]["document"]["id"].get().c_str()); + ASSERT_STREQ("13", results["hits"][1]["document"]["id"].get().c_str()); + // take precedence over override rules nlohmann::json override_json_include = {