diff --git a/include/index.h b/include/index.h
index ecb1e229..48aa9687 100644
--- a/include/index.h
+++ b/include/index.h
@@ -183,6 +183,7 @@ private:
                            Topster* topster, spp::sparse_hash_set<uint64_t>& groups_processed,
                            uint32_t** all_result_ids, size_t & all_result_ids_len,
+                           size_t& field_num_results,
                            const size_t typo_tokens_threshold);
 
     void insert_doc(const int64_t score, art_tree *t, uint32_t seq_id,
@@ -242,6 +243,8 @@ public:
                 std::vector<std::vector<KV*>> & override_result_kvs,
                 const size_t typo_tokens_threshold);
 
+    static void concat_topster_ids(Topster* topster, spp::sparse_hash_map<uint64_t, std::vector<KV*>>& topster_ids);
+
     Option<uint32_t> do_filtering(uint32_t** filter_ids_out, const std::vector<filter> & filters);
 
     Option<uint32_t> remove(const uint32_t seq_id, const nlohmann::json & document);
diff --git a/include/match_score.h b/include/match_score.h
index cc6667bd..8e172167 100644
--- a/include/match_score.h
+++ b/include/match_score.h
@@ -45,6 +45,17 @@ struct Match {
 
     }
 
+    // Explicit construction of match score
+    static inline uint64_t get_match_score(const uint32_t words_present, const uint32_t total_cost, const uint8_t distance,
+                                           const uint8_t field_id) {
+
+        uint64_t match_score = ((int64_t) (words_present) << 24) |
+                               ((int64_t) (255 - total_cost) << 16) |
+                               ((int64_t) (distance) << 8) |
+                               ((int64_t) (field_id));
+        return match_score;
+    }
+
     // Construct a single match score from individual components (for multi-field sort)
     inline uint64_t get_match_score(const uint32_t total_cost, const uint8_t field_id) const {
         uint64_t match_score = ((int64_t) (words_present) << 24) |
diff --git a/include/topster.h b/include/topster.h
index 33e2bf9b..5b6dcc83 100644
--- a/include/topster.h
+++ b/include/topster.h
@@ -88,9 +88,9 @@ struct Topster {
     }
 
     bool add(KV* kv) {
-        //LOG(INFO) << "kv_map size: " << kv_map.size() << " -- kvs[0]: " << kvs[0]->match_score;
-        /*for(auto kv: kv_map) {
-            LOG(INFO) << "kv key: " << kv.first << " => " << kv.second->match_score;
+        /*LOG(INFO) << "kv_map size: " << kv_map.size() << " -- kvs[0]: " << kvs[0]->scores[kvs[0]->match_score_index];
+        for(auto& mkv: kv_map) {
+            LOG(INFO) << "kv key: " << mkv.first << " => " << mkv.second->scores[mkv.second->match_score_index];
         }*/
 
         bool less_than_min_heap = (size >= MAX_SIZE) && is_smaller(kv, kvs[0]);
diff --git a/src/collection.cpp b/src/collection.cpp
index de795e33..a524aee2 100644
--- a/src/collection.cpp
+++ b/src/collection.cpp
@@ -1290,13 +1290,13 @@ void Collection::highlight_result(const field &search_field,
                                   const std::string& highlight_end_tag,
                                   highlight_t & highlight) {
 
-    std::vector<uint32_t*> leaf_to_indices;
-    std::vector<art_leaf*> query_suggestion;
-
    if(searched_queries.size() <= field_order_kv->query_index) {
        return ;
    }
 
+    std::vector<uint32_t*> leaf_to_indices;
+    std::vector<art_leaf*> query_suggestion;
+
    for (const art_leaf *token_leaf : searched_queries[field_order_kv->query_index]) {
        // Must search for the token string fresh on that field for the given document since `token_leaf`
        // is from the best matched field and need not be present in other fields of a document.
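The new static `Match::get_match_score` above packs four ranking components into a single `uint64_t`, so one integer comparison orders hits by words matched first, then typo cost, then token distance, then field. A minimal standalone sketch of why that byte layout yields the intended priority; `pack_match_score` is a made-up name for illustration, not part of the patch:

    #include <cassert>
    #include <cstdint>
    #include <cstdio>

    // Same layout as the patch: words_present occupies the most significant
    // packed byte, so it dominates; storing 255 - total_cost makes cheaper
    // (fewer typos) matches sort higher; field_id breaks the final tie.
    static uint64_t pack_match_score(uint32_t words_present, uint32_t total_cost,
                                     uint8_t distance, uint8_t field_id) {
        return ((uint64_t) words_present << 24) |
               ((uint64_t) (255 - total_cost) << 16) |
               ((uint64_t) distance << 8) |
               ((uint64_t) field_id);
    }

    int main() {
        // Two words matched with one typo outrank one word matched exactly ...
        assert(pack_match_score(2, 1, 100, 4) > pack_match_score(1, 0, 100, 4));
        // ... and with equal words/cost/distance, the larger field_id wins.
        assert(pack_match_score(1, 0, 100, 6) > pack_match_score(1, 0, 100, 4));
        printf("packing order holds\n");
        return 0;
    }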
diff --git a/src/index.cpp b/src/index.cpp
index b184b0fd..65d2d6eb 100644
--- a/src/index.cpp
+++ b/src/index.cpp
@@ -828,6 +828,7 @@ void Index::search_candidates(const uint8_t & field_id,
                               std::vector<std::vector<art_leaf*>> & searched_queries, Topster* topster,
                               spp::sparse_hash_set<uint64_t>& groups_processed,
                               uint32_t** all_result_ids, size_t & all_result_ids_len,
+                              size_t& field_num_results,
                               const size_t typo_tokens_threshold) {
     const long long combination_limit = 10;
@@ -901,7 +902,6 @@
             log_query << query_suggestion[i]->key << " ";
         }
 
-
         if(filter_ids != nullptr) {
             // intersect once again with filter ids
             uint32_t* filtered_result_ids = nullptr;
@@ -918,6 +918,8 @@
             score_results(sort_fields, (uint16_t) searched_queries.size(), field_id, total_cost, topster, query_suggestion,
                           groups_processed, filtered_result_ids, filtered_results_size);
 
+            field_num_results += filtered_results_size;
+
             delete[] filtered_result_ids;
             delete[] result_ids;
         } else {
@@ -934,13 +936,15 @@
                 LOG(INFO) << size_t(field_id) << " - " << log_query.str() << ", result_size: " << result_size;
             }*/
 
+            field_num_results += result_size;
+
             delete[] result_ids;
         }
 
         searched_queries.push_back(actual_query_suggestion);
 
-        //LOG(INFO) << "all_result_ids_len: " << all_result_ids_len << ", typo_tokens_threshold: " << typo_tokens_threshold;
-        if(all_result_ids_len >= typo_tokens_threshold) {
+        //LOG(INFO) << "field_num_results: " << field_num_results << ", typo_tokens_threshold: " << typo_tokens_threshold;
+        if(field_num_results >= typo_tokens_threshold) {
             break;
         }
     }
@@ -1275,6 +1279,23 @@ void Index::collate_included_ids(const std::vector<std::string>& q_included_toke
     searched_queries.push_back(override_query);
 }
 
+void Index::concat_topster_ids(Topster* topster, spp::sparse_hash_map<uint64_t, std::vector<KV*>>& topster_ids) {
+    if(topster->distinct) {
+        for(auto &group_topster_entry: topster->group_kv_map) {
+            Topster* group_topster = group_topster_entry.second;
+            for(const auto& map_kv: group_topster->kv_map) {
+                topster_ids[map_kv.first].push_back(map_kv.second);
+            }
+        }
+    } else {
+        for(const auto& map_kv: topster->kv_map) {
+            //LOG(INFO) << "map_kv.second.key: " << map_kv.second->key;
+            //LOG(INFO) << "map_kv.first: " << map_kv.first;
+            topster_ids[map_kv.first].push_back(map_kv.second);
+        }
+    }
+}
+
 void Index::search(Option<uint32_t> & outcome,
                    const std::vector<std::string>& q_include_tokens,
                    const std::vector<std::string>& q_exclude_tokens,
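`concat_topster_ids` above flattens each field's private topster into one map keyed by document id, so a document that surfaced in several fields ends up carrying one `KV*` per field. A toy model of that regrouping, with a simplified `ToyKV` struct and a plain `std::unordered_map` standing in for the real `KV` and `spp::sparse_hash_map`:

    #include <cstdint>
    #include <cstdio>
    #include <unordered_map>
    #include <vector>

    struct ToyKV {
        uint8_t  field_id;
        uint64_t key;     // document seq_id
        uint64_t score;
    };

    int main() {
        // Pretend field 8 (title) and field 6 (artist) each produced hits.
        std::vector<ToyKV> title_hits  = {{8, 12, 500}, {8, 15, 300}};
        std::vector<ToyKV> artist_hits = {{6, 12, 450}};

        // Regroup by document id, the same shape as the patch's
        // spp::sparse_hash_map<uint64_t, std::vector<KV*>>.
        std::unordered_map<uint64_t, std::vector<ToyKV*>> by_doc;
        for(auto& kv : title_hits)  by_doc[kv.key].push_back(&kv);
        for(auto& kv : artist_hits) by_doc[kv.key].push_back(&kv);

        // Doc 12 matched in both fields and now has two KVs to aggregate;
        // doc 15 matched only in the title field.
        printf("doc 12 has %zu field matches\n", by_doc[12].size());  // 2
        printf("doc 15 has %zu field matches\n", by_doc[15].size());  // 1
        return 0;
    }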
@@ -1413,6 +1434,9 @@ void Index::search(Option<uint32_t> & outcome,
             all_result_ids = filter_ids;
             filter_ids = nullptr;
         } else {
+            spp::sparse_hash_map<uint64_t, std::vector<KV*>> topster_ids;
+            std::vector<Topster*> ftopsters;
+
             // non-wildcard
             for(size_t i = 0; i < num_search_fields; i++) {
                 // proceed to query search only when no filters are provided or when filtering produces results
@@ -1425,10 +1449,15 @@
                 size_t num_tokens_dropped = 0;
 
                 //LOG(INFO) << "searching field! " << field;
+                Topster* ftopster = new Topster(topster->MAX_SIZE, topster->distinct);
+                ftopsters.push_back(ftopster);
+
+                // Don't waste additional cycles for single field searches
+                Topster* actual_topster = (num_search_fields == 1) ? topster : ftopster;
                 search_field(field_id, query_tokens, search_tokens, exclude_token_ids, exclude_token_ids_size, num_tokens_dropped,
                              field, filter_ids, filter_ids_length, curated_ids_sorted, facets, sort_fields_std,
-                             num_typos, searched_queries, topster, groups_processed, &all_result_ids, all_result_ids_len,
+                             num_typos, searched_queries, actual_topster, groups_processed, &all_result_ids, all_result_ids_len,
                              token_order, prefix, drop_tokens_threshold, typo_tokens_threshold);
 
                 // do synonym based searches
@@ -1439,16 +1468,88 @@
                     // for synonym we use a smaller field id than for original tokens
                     search_field(field_id-1, query_tokens, search_tokens, exclude_token_ids, exclude_token_ids_size, num_tokens_dropped,
                                  field, filter_ids, filter_ids_length, curated_ids_sorted, facets, sort_fields_std,
-                                 num_typos, searched_queries, topster, groups_processed, &all_result_ids, all_result_ids_len,
+                                 num_typos, searched_queries, actual_topster, groups_processed, &all_result_ids, all_result_ids_len,
                                  token_order, prefix, drop_tokens_threshold, typo_tokens_threshold);
                 }
 
+                concat_topster_ids(ftopster, topster_ids);
                 collate_included_ids(q_include_tokens, field, field_id, included_ids_map, curated_topster, searched_queries);
+                //LOG(INFO) << "topster_ids.size: " << topster_ids.size();
             }
         }
 
+        for(const auto& key_kvs: topster_ids) {
+            // first calculate existing aggregate scores across best matching fields
+            spp::sparse_hash_map<uint8_t, KV*> existing_field_kvs;
+
+            const auto& kvs = key_kvs.second;
+            const uint64_t seq_id = key_kvs.first;
+
+            //LOG(INFO) << "DOC ID: " << seq_id;
+
+            /*if(seq_id == 12 || seq_id == 15) {
+                LOG(INFO) << "here";
+            }*/
+
+            for(const auto kv: kvs) {
+                existing_field_kvs.emplace(kv->field_id, kv);
+            }
+
+            for(size_t i = 0; i < num_search_fields && num_search_fields > 1; i++) {
+                const uint8_t field_id = (uint8_t)(FIELD_LIMIT_NUM - (2*i)); // Order of `fields` used to sort results
+
+                if(field_id == kvs[0]->field_id) {
+                    continue;
+                }
+
+                if(existing_field_kvs.count(field_id) != 0) {
+                    // for existing field, we will simply sum field-wise match scores
+                    kvs[0]->scores[kvs[0]->match_score_index] +=
+                            existing_field_kvs[field_id]->scores[existing_field_kvs[field_id]->match_score_index];
+                    continue;
+                }
+
+                const std::string & field = search_fields[i];
+
+                // compute approximate match score for this field from actual query
+
+                size_t words_present = 0;
+
+                for(size_t token_index=0; token_index < q_include_tokens.size(); token_index++) {
+                    const auto& token = q_include_tokens[token_index];
+
+                    std::vector<art_leaf*> leaves;
+                    const bool prefix_search = prefix && (token_index == q_include_tokens.size()-1);
+                    const size_t token_len = prefix_search ? (int) token.length() : (int) token.length() + 1;
+                    art_fuzzy_search(search_index.at(field), (const unsigned char *) token.c_str(), token_len,
+                                     0, 0, 1, token_order, prefix_search, leaves);
+
+                    if(!leaves.empty() && leaves[0]->values->ids.contains(seq_id)) {
+                        words_present++;
+                    }
+
+                    /*if(!leaves.empty()) {
+                        LOG(INFO) << "tok: " << leaves[0]->key;
+                    }*/
+                }
+
+                if(words_present != 0) {
+                    uint64_t match_score = Match::get_match_score(words_present, 0, 100, field_id);
+                    kvs[0]->scores[kvs[0]->match_score_index] += match_score;
+                }
+            }
+
+            //LOG(INFO) << "kvs[0].key: " << kvs[0]->key;
+            topster->add(kvs[0]);
+        }
+
+        for(Topster* ftopster: ftopsters) {
+            delete ftopster;
+        }
     }
 
+    //LOG(INFO) << "topster size: " << topster->size;
+
     delete [] exclude_token_ids;
 
     do_facets(facets, facet_query, all_result_ids, all_result_ids_len);
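For a field that produced no `KV` for a document, the loop above approximates its contribution by probing each query token in that field's ART index and counting how many tokens' posting lists contain the document. A self-contained sketch of that counting step, where a `std::map`/`std::set` toy index and made-up data stand in for the ART tree and its compressed id list:

    #include <cstdint>
    #include <cstdio>
    #include <map>
    #include <set>
    #include <string>
    #include <vector>

    int main() {
        // token -> set of doc ids containing it, a stand-in for the
        // leaves[0]->values->ids posting list checked in the patch.
        std::map<std::string, std::set<uint64_t>> field_index = {
            {"taylor", {12, 15}}, {"swift", {12, 15}}, {"style", {12}}
        };

        const std::vector<std::string> q_include_tokens = {"taylor", "swift", "style"};
        const uint64_t seq_id = 15;

        size_t words_present = 0;
        for(const auto& token : q_include_tokens) {
            auto it = field_index.find(token);
            if(it != field_index.end() && it->second.count(seq_id) != 0) {
                words_present++;
            }
        }

        // Doc 15 contains "taylor" and "swift" but not "style" in this field,
        // so only two words feed into Match::get_match_score().
        printf("words_present = %zu\n", words_present);  // 2
        return 0;
    }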
@@ -1496,6 +1597,9 @@ void Index::search_field(const uint8_t & field_id,
 
     const size_t max_cost = (num_typos < 0 || num_typos > 2) ? 2 : num_typos;
 
+    // tracks the number of results found for the current field
+    size_t field_num_results = 0;
+
     // To prevent us from doing ART search repeatedly as we iterate through possible corrections
     spp::sparse_hash_map<std::string, std::vector<art_leaf*>> token_cost_cache;
 
@@ -1565,19 +1669,14 @@
                 //log_leaves(costs[token_index], token, leaves);
                 token_candidates_vec.push_back(token_candidates{token, costs[token_index], leaves});
             } else {
                 // No result at `cost = costs[token_index]`. Remove `cost` for token and re-do combinations
                 auto it = std::find(token_to_costs[token_index].begin(), token_to_costs[token_index].end(), costs[token_index]);
                 if(it != token_to_costs[token_index].end()) {
                     token_to_costs[token_index].erase(it);
 
-                    // when no more costs are left for this token and `drop_tokens_threshold` is breached
-                    if(token_to_costs[token_index].empty() && all_result_ids_len >= drop_tokens_threshold) {
-                        n = combination_limit; // to break outer loop
-                        break;
-                    }
-
-                    // otherwise, we try to drop the token and search with remaining tokens
+                    // when no more costs are left for this token
                     if(token_to_costs[token_index].empty()) {
+                        // we can try to drop the token and search with remaining tokens
                         token_to_costs.erase(token_to_costs.begin()+token_index);
                         search_tokens.erase(search_tokens.begin()+token_index);
                         query_tokens.erase(query_tokens.begin()+token_index);
@@ -1585,32 +1684,35 @@
                     }
                 }
 
-                // To continue outerloop on new cost combination
+                // Continue outer loop on new cost combination
                 n = -1;
                 N = std::accumulate(token_to_costs.begin(), token_to_costs.end(), 1LL, product);
-                break;
+                goto resume_typo_loop;
             }
 
             token_index++;
         }
 
-        if(!token_candidates_vec.empty() && token_candidates_vec.size() == search_tokens.size()) {
-            // If all tokens were found, go ahead and search for candidates with what we have so far
+        if(!token_candidates_vec.empty()) {
+            // If at least one token is found, go ahead and search for candidates
             search_candidates(field_id, filter_ids, filter_ids_length, exclude_token_ids, exclude_token_ids_size, curated_ids,
                               sort_fields, token_candidates_vec, searched_queries, topster,
-                              groups_processed, all_result_ids, all_result_ids_len, typo_tokens_threshold);
+                              groups_processed, all_result_ids, all_result_ids_len, field_num_results,
+                              typo_tokens_threshold);
        }

-        if (all_result_ids_len >= typo_tokens_threshold) {
-            // If we don't find enough results, we continue outerloop (looking at tokens with greater typo cost)
-            break;
+        resume_typo_loop:
+
+        if(field_num_results >= drop_tokens_threshold || field_num_results >= typo_tokens_threshold) {
+            // if either threshold is breached, we are done
+            return ;
         }

         n++;
     }

-    // When there are not enough overall results and atleast one token has results
-    if(all_result_ids_len < drop_tokens_threshold && !query_tokens.empty() && num_tokens_dropped < query_tokens.size()) {
+    // When at least one token from the query is available
+    if(!query_tokens.empty() && num_tokens_dropped < query_tokens.size()) {

         // Drop tokens from right until (len/2 + 1), and then from left until (len/2 + 1)
         std::vector<std::string> truncated_tokens;
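The restructuring above replaces the old `break` with `goto resume_typo_loop`, so every cost combination, productive or not, funnels through a single per-field threshold check, and breaching either threshold now exits `search_field` entirely instead of merely leaving the typo loop. A control-flow sketch of that pattern, using a toy loop and fabricated counts rather than the real search state:

    #include <cstddef>
    #include <cstdio>

    int main() {
        const size_t typo_tokens_threshold = 100;
        const size_t drop_tokens_threshold = 10;
        size_t field_num_results = 0;

        for(long long n = 0; n < 10; n++) {
            bool produced_candidates = (n % 2 == 0);   // fake per-combination outcome

            if(!produced_candidates) {
                goto resume_typo_loop;                 // skip scoring, still check thresholds
            }

            field_num_results += 7;                    // stand-in for search_candidates()

            resume_typo_loop:
            if(field_num_results >= drop_tokens_threshold ||
               field_num_results >= typo_tokens_threshold) {
                printf("done after %zu results\n", field_num_results);
                return 0;                              // mirrors the patch's early return
            }
        }

        return 0;
    }

Because the counters are per field rather than global, a strong match in one field no longer suppresses typo or token-drop exploration in the other fields being searched.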
diff --git a/test/collection_grouping_test.cpp b/test/collection_grouping_test.cpp
index 7ca2e20d..74132378 100644
--- a/test/collection_grouping_test.cpp
+++ b/test/collection_grouping_test.cpp
@@ -307,8 +307,8 @@ TEST_F(CollectionGroupingTest, GroupingWithMultiFieldRelevance) {
     ASSERT_STREQ("country", results["grouped_hits"][2]["group_key"][0].get<std::string>().c_str());
     ASSERT_EQ(2, results["grouped_hits"][2]["hits"].size());
 
-    ASSERT_STREQ("3", results["grouped_hits"][2]["hits"][0]["document"]["id"].get<std::string>().c_str());
-    ASSERT_STREQ("8", results["grouped_hits"][2]["hits"][1]["document"]["id"].get<std::string>().c_str());
+    ASSERT_STREQ("8", results["grouped_hits"][2]["hits"][0]["document"]["id"].get<std::string>().c_str());
+    ASSERT_STREQ("3", results["grouped_hits"][2]["hits"][1]["document"]["id"].get<std::string>().c_str());
 
     collectionManager.drop_collection("coll1");
 }
diff --git a/test/collection_test.cpp b/test/collection_test.cpp
index 6760033c..f35e6beb 100644
--- a/test/collection_test.cpp
+++ b/test/collection_test.cpp
@@ -2659,6 +2659,53 @@ TEST_F(CollectionTest, MultiFieldRelevance) {
     collectionManager.drop_collection("coll1");
 }
 
+TEST_F(CollectionTest, MultiFieldMatchRanking) {
+    Collection *coll1;
+
+    std::vector<field> fields = {field("title", field_types::STRING, false),
+                                 field("artist", field_types::STRING, false),
+                                 field("points", field_types::INT32, false),};
+
+    coll1 = collectionManager.get_collection("coll1");
+    if(coll1 == nullptr) {
+        coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get();
+    }
+
+    std::vector<std::vector<std::string>> records = {
+        {"Style", "Taylor Swift"},
+        {"Blank Space", "Taylor Swift"},
+        {"Balance Overkill", "Taylor Swift"},
+        {"Cardigan", "Taylor Swift"},
+        {"Invisible String", "Taylor Swift"},
+        {"The Last Great American Dynasty", "Taylor Swift"},
+        {"Mirrorball", "Taylor Swift"},
+        {"Peace", "Taylor Swift"},
+        {"Betty", "Taylor Swift"},
+        {"Mad Woman", "Taylor Swift"},
+    };
+
+    for(size_t i=0; i<records.size(); i++) {
+        nlohmann::json doc;
+        doc["id"] = std::to_string(i);
+        doc["title"] = records[i][0];
+        doc["artist"] = records[i][1];
+        doc["points"] = i;
+
+        ASSERT_TRUE(coll1->add(doc.dump()).ok());
+    }
+
+    auto results = coll1->search("taylor swift style",
+                                 {"artist", "title"}, "", {}, {}, 0, 3, 1, FREQUENCY, true, 5).get();
+
+    LOG(INFO) << results;
+
+    ASSERT_EQ(10, results["found"].get<size_t>());
+    ASSERT_EQ(3, results["hits"].size());
+
+    collectionManager.drop_collection("coll1");
+}
+
 TEST_F(CollectionTest, HighlightWithAccentedCharacters) {
     Collection *coll1;
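A note on the `FIELD_LIMIT_NUM - (2*i)` arithmetic the aggregation loop relies on: fields listed earlier in the search request receive larger field ids, the id just below each is reserved for that field's synonym searches, and since field_id occupies the least significant byte of the packed match score, earlier fields win exact ties. A sketch under an assumed `FIELD_LIMIT_NUM` value; the real constant lives in the Typesense sources:

    #include <cstdint>
    #include <cstdio>
    #include <string>
    #include <vector>

    int main() {
        const uint8_t FIELD_LIMIT_NUM = 100;   // assumed value, for illustration only
        const std::vector<std::string> search_fields = {"artist", "title"};

        for(size_t i = 0; i < search_fields.size(); i++) {
            const uint8_t field_id = (uint8_t)(FIELD_LIMIT_NUM - (2*i));
            printf("%s -> field_id %u (synonym searches use %u)\n",
                   search_fields[i].c_str(), (unsigned) field_id, (unsigned)(field_id - 1));
        }

        // artist -> 100 (synonyms 99), title -> 98 (synonyms 97). With words,
        // cost and distance equal, the larger field_id in the low byte of the
        // packed score lets the earlier-listed field win the tie.
        return 0;
    }

This is also why the new `MultiFieldMatchRanking` test queries `{"artist", "title"}` in that order: the per-field scores are summed per document, and field order only decides otherwise-identical matches.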