From e256f693f94cd51ac95b418fd3a8afba0b3c63d8 Mon Sep 17 00:00:00 2001
From: Kishore Nallan
Date: Tue, 6 Jul 2021 17:30:46 +0530
Subject: [PATCH] Improve multi field typo ranking.

---
 src/index.cpp                     |  80 +++++++++++++++++-------
 test/collection_specific_test.cpp | 100 +++++++++++++++++++++++++++++-
 2 files changed, 155 insertions(+), 25 deletions(-)

diff --git a/src/index.cpp b/src/index.cpp
index 2b91cb70..b2a9ab5a 100644
--- a/src/index.cpp
+++ b/src/index.cpp
@@ -1000,7 +1000,8 @@ void Index::search_candidates(const uint8_t & field_id,
         *all_result_ids = new_all_result_ids;

         /*if(result_size != 0) {
-            LOG(INFO) << size_t(field_id) << " - " << log_query.str() << ", result_size: " << result_size;
+            LOG(INFO) << size_t(field_id) << " - " << log_query.str() << ", result_size: " << result_size
+                      << ", popcount: " << (__builtin_popcount(token_bits) - 1);
         }*/

         score_results(sort_fields, (uint16_t) searched_queries.size(), field_id, total_cost, topster,
@@ -1702,17 +1703,17 @@ void Index::search(const std::vector<query_tokens_t>& field_query_tokens,
                 existing_field_kvs.emplace(kv->field_id, kv);
             }

-            uint32_t token_bits = (uint32_t(1) << 31);      // top most bit set to guarantee atleast 1 bit set
-            uint64_t total_typos = 0, total_distances = 0;
+            uint32_t token_bits = (uint32_t(1) << 31);  // top most bit set to guarantee atleast 1 bit set
+            uint64_t total_typos = 0, total_distances = 0, min_typos = 1000;

-            uint64_t verbatim_matches = 0;     // query matching field verbatim
-            uint64_t query_matches = 0;        // field containing query tokens
-            uint64_t cross_field_matches = 0;  // total matches across fields (including fuzzy ones)
+            uint64_t verbatim_match_fields = 0;  // query matching field verbatim
+            uint64_t exact_match_fields = 0;     // number of fields that contains all of query tokens
+            uint64_t total_token_matches = 0;    // total matches across fields (including fuzzy ones)

             //LOG(INFO) << "Init pop count: " << __builtin_popcount(token_bits);

             for(size_t i = 0; i < num_search_fields; i++) {
-                const uint8_t field_id = (uint8_t)(FIELD_LIMIT_NUM - i);
+                const auto field_id = (uint8_t)(FIELD_LIMIT_NUM - i);
                 size_t weight = search_fields[i].weight;

                 //LOG(INFO) << "--- field index: " << i << ", weight: " << weight;
@@ -1725,16 +1726,20 @@
                     int64_t match_score = existing_field_kvs[field_id]->scores[existing_field_kvs[field_id]->match_score_index];

                     uint64_t tokens_found = ((match_score >> 24) & 0xFF);
-                    int64_t field_typos = 255 - ((match_score >> 16) & 0xFF);
+                    uint64_t field_typos = 255 - ((match_score >> 16) & 0xFF);
                     total_typos += (field_typos + 1) * weight;
                     total_distances += ((100 - ((match_score >> 8) & 0xFF)) + 1) * weight;
-                    verbatim_matches += (((match_score & 0xFF)) + 1) * weight;
+                    verbatim_match_fields += (((match_score & 0xFF)) + 1);

                     if(field_typos == 0 && tokens_found == field_query_tokens[i].q_include_tokens.size()) {
-                        query_matches++;
+                        exact_match_fields++;
                     }

-                    cross_field_matches += tokens_found;
+                    if(field_typos < min_typos) {
+                        min_typos = field_typos;
+                    }
+
+                    total_token_matches += tokens_found;

                     /*LOG(INFO) << "seq_id: " << seq_id << ", total_typos: " << (255 - ((match_score >> 8) & 0xFF))
                               << ", weighted typos: " << std::max((255 - ((match_score >> 8) & 0xFF)), 1) * weight
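[Aside, not part of the patch] The shift-and-mask expressions in the hunks above and below assume that a per-field match score packs one component per byte: tokens found, typo cost, word proximity, and a verbatim-match flag, with typo cost and proximity stored inverted (255 - typos, 100 - distance) so that a larger value always means a better match. A minimal standalone sketch of that assumed layout, using a hypothetical helper name:

// Sketch only: illustrates the byte layout implied by the (score >> N) & 0xFF
// expressions in this patch; pack_field_score() is a hypothetical helper.
#include <cassert>
#include <cstdint>

uint64_t pack_field_score(uint64_t tokens_found, uint64_t typos, uint64_t distance, uint64_t verbatim) {
    return (tokens_found << 24) | ((255 - typos) << 16) | ((100 - distance) << 8) | verbatim;
}

int main() {
    uint64_t score = pack_field_score(2, 1, 3, 0);
    assert(((score >> 24) & 0xFF) == 2);        // tokens found in the field
    assert(255 - ((score >> 16) & 0xFF) == 1);  // typo cost, recovered by un-inverting
    assert(100 - ((score >> 8) & 0xFF) == 3);   // word proximity (distance)
    assert((score & 0xFF) == 0);                // verbatim (exact field value) flag
    return 0;
}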
@@ -1787,30 +1792,45 @@ void Index::search(const std::vector<query_tokens_t>& field_query_tokens,
                     total_typos += (field_typos + 1) * weight;

                     if(field_typos == 0 && tokens_found == field_query_tokens[i].q_include_tokens.size()) {
-                        query_matches++;
-                        verbatim_matches++;
+                        exact_match_fields++;
+                        verbatim_match_fields++;  // this is only an approximate
                     }

-                    cross_field_matches += tokens_found;
+                    if(field_typos < min_typos) {
+                        min_typos = field_typos;
+                    }
+
+                    total_token_matches += tokens_found;

                     //LOG(INFO) << "seq_id: " << seq_id << ", total_typos: " << ((match_score >> 8) & 0xFF);
                 }
             }

-            int64_t tokens_present = int64_t(__builtin_popcount(token_bits)) - 1;
+            // num tokens present across fields including those containing typos
+            int64_t uniq_tokens_found = int64_t(__builtin_popcount(token_bits)) - 1;
+
             total_typos = std::min<uint64_t>(255, total_typos);
             total_distances = std::min<uint64_t>(100, total_distances);

             uint64_t aggregated_score = (
-                (query_matches << 40) |
-                (tokens_present << 32) |
-                (cross_field_matches << 24) |
-                ((255 - total_typos) << 16) |
-                ((100 - total_distances) << 8) |
-                (verbatim_matches)
+                (exact_match_fields << 48) |      // number of fields that contain *all tokens* in the query
+                (uniq_tokens_found << 40) |       // number of unique tokens found across fields including typos
+                ((255 - min_typos) << 32) |       // minimum typo cost across all fields
+                (total_token_matches << 24) |     // total matches across fields including typos
+                ((255 - total_typos) << 16) |     // total typos across fields (weighted)
+                ((100 - total_distances) << 8) |  // total distances across fields (weighted)
+                (verbatim_match_fields)           // field value *exactly* same as query tokens
             );

-            /*LOG(INFO) << "seq id: " << seq_id << ", tokens_present: " << tokens_present
-                      << ", total_distances: " << total_distances << ", total_typos: " << total_typos
+            //LOG(INFO) << "seq id: " << seq_id << ", aggregated_score: " << aggregated_score;
+
+            /*LOG(INFO) << "seq id: " << seq_id
+                      << ", exact_match_fields: " << exact_match_fields
+                      << ", uniq_tokens_found: " << uniq_tokens_found
+                      << ", min typo score: " << (255 - min_typos)
+                      << ", total_token_matches: " << total_token_matches
+                      << ", typo score: " << (255 - total_typos)
+                      << ", distance score: " << (100 - total_distances)
+                      << ", verbatim_match_fields: " << verbatim_match_fields
                       << ", aggregated_score: " << aggregated_score << ", token_bits: " << token_bits;*/

             kvs[0]->scores[kvs[0]->match_score_index] = aggregated_score;
@@ -2125,6 +2145,8 @@ void Index::score_results(const std::vector<sort_by> & sort_fields, const uint16
             std::unordered_map<size_t, std::vector<std::vector<uint16_t>>> array_token_positions;
             populate_token_positions(query_suggestion, leaf_to_indices, i, array_token_positions);

+            uint64_t total_tokens_found = 0, total_num_typos = 0, total_distance = 0, total_verbatim = 0;
+
             for (const auto& kv: array_token_positions) {
                 const std::vector<std::vector<uint16_t>>& token_positions = kv.second;
                 if (token_positions.empty()) {
@@ -2133,7 +2155,10 @@

                 const Match &match = Match(seq_id, token_positions, false, prioritize_exact_match);
                 uint64_t this_match_score = match.get_match_score(total_cost);
-                match_score += this_match_score;
+                total_tokens_found += ((this_match_score >> 24) & 0xFF);
+                total_num_typos += 255 - ((this_match_score >> 16) & 0xFF);
+                total_distance += 100 - ((this_match_score >> 8) & 0xFF);
+                total_verbatim += (this_match_score & 0xFF);

                 /*std::ostringstream os;
                 os << name << ", total_cost: " << (255 - total_cost)
@@ -2143,6 +2168,13 @@
                    << ", seq_id: " << seq_id << std::endl;
                 LOG(INFO) << os.str();*/
             }
+
+            match_score = (
+                (uint64_t(total_tokens_found) << 24) |
+                (uint64_t(255 - total_num_typos) << 16) |
+                (uint64_t(100 - total_distance) << 8) |
+                (uint64_t(total_verbatim) << 1)
+            );
         }

         const int64_t default_score = INT64_MIN;  // to handle field that doesn't exist in document (e.g. optional)
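[Aside, not part of the patch] The new aggregated_score puts exact_match_fields into the topmost bits, followed by the unique tokens found, the minimum per-field typo cost, the total token matches, and only then the weighted typo and distance totals. Because the packed components are compared as one unsigned integer, a document whose single field contains every query token with zero typos should outrank a document that merely accumulates fuzzy or partial matches across several fields; the new ExactSingleFieldMatch test below asserts exactly that ordering. A minimal standalone sketch of that ordering, with made-up component values and a hypothetical helper:

// Sketch only: mirrors the bit layout of aggregated_score above; the component
// values are illustrative, not taken from a real index.
#include <cassert>
#include <cstdint>

uint64_t aggregate(uint64_t exact_match_fields, uint64_t uniq_tokens_found, uint64_t min_typos,
                   uint64_t total_token_matches, uint64_t total_typos, uint64_t total_distances,
                   uint64_t verbatim_match_fields) {
    return (exact_match_fields << 48) |
           (uniq_tokens_found << 40) |
           ((255 - min_typos) << 32) |
           (total_token_matches << 24) |
           ((255 - total_typos) << 16) |
           ((100 - total_distances) << 8) |
           (verbatim_match_fields);
}

int main() {
    // Doc A: one field contains the whole query verbatim (exact_match_fields = 1).
    uint64_t doc_a = aggregate(1, 1, 0, 1, 1, 1, 1);
    // Doc B: the query token is only found with a typo, spread across two fields,
    // so it has more raw token matches but no exact-match field.
    uint64_t doc_b = aggregate(0, 1, 1, 2, 2, 2, 0);
    assert(doc_a > doc_b);  // the single exact field match ranks higher
    return 0;
}

Storing typos and distances inverted keeps every component of the packed score increasing with match quality, so a single unsigned comparison is enough to order documents.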
diff --git a/test/collection_specific_test.cpp b/test/collection_specific_test.cpp
index 83760c56..6ce707a6 100644
--- a/test/collection_specific_test.cpp
+++ b/test/collection_specific_test.cpp
@@ -144,7 +144,105 @@ TEST_F(CollectionSpecificTest, ExplicitHighlightFieldsConfig) {
     collectionManager.drop_collection("coll1");
 }

-TEST_F(CollectionSpecificTest, PrefixWithTypos1) {
+TEST_F(CollectionSpecificTest, ExactSingleFieldMatch) {
+    std::vector<field> fields = {field("title", field_types::STRING, false),
+                                 field("description", field_types::STRING, false),
+                                 field("points", field_types::INT32, false),};
+
+    Collection* coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get();
+
+    nlohmann::json doc1;
+    doc1["id"] = "0";
+    doc1["title"] = "Fast Electric Charger";
+    doc1["description"] = "A product you should buy.";
+    doc1["points"] = 100;
+
+    nlohmann::json doc2;
+    doc2["id"] = "1";
+    doc2["title"] = "Omega Chargex";
+    doc2["description"] = "Chargex is a great product.";
+    doc2["points"] = 200;
+
+    ASSERT_TRUE(coll1->add(doc1.dump()).ok());
+    ASSERT_TRUE(coll1->add(doc2.dump()).ok());
+
+    auto results = coll1->search("charger", {"title", "description"}, "", {}, {}, {2}, 10,
+                                 1, FREQUENCY, {true, true}).get();
+
+    LOG(INFO) << results;
+
+    ASSERT_EQ("0", results["hits"][0]["document"]["id"].get<std::string>());
+    ASSERT_EQ("1", results["hits"][1]["document"]["id"].get<std::string>());
+
+    collectionManager.drop_collection("coll1");
+}
+
+TEST_F(CollectionSpecificTest, OrderMultiFieldFuzzyMatch) {
+    std::vector<field> fields = {field("title", field_types::STRING, false),
+                                 field("description", field_types::STRING, false),
+                                 field("points", field_types::INT32, false),};
+
+    Collection* coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get();
+
+    nlohmann::json doc1;
+    doc1["id"] = "0";
+    doc1["title"] = "Moto Insta Share";
+    doc1["description"] = "Share information with this device.";
+    doc1["points"] = 100;
+
+    nlohmann::json doc2;
+    doc2["id"] = "1";
+    doc2["title"] = "Portable USB Store";
+    doc2["description"] = "Use it to charge your phone.";
+    doc2["points"] = 50;
+
+    ASSERT_TRUE(coll1->add(doc1.dump()).ok());
+    ASSERT_TRUE(coll1->add(doc2.dump()).ok());
+
+    auto results = coll1->search("charger", {"title", "description"}, "", {}, {}, {2}, 10,
+                                 1, FREQUENCY, {true, true}).get();
+
+    ASSERT_EQ("1", results["hits"][0]["document"]["id"].get<std::string>());
+    ASSERT_EQ("0", results["hits"][1]["document"]["id"].get<std::string>());
+
+    collectionManager.drop_collection("coll1");
+}
+
+TEST_F(CollectionSpecificTest, FieldWeighting) {
+    std::vector<field> fields = {field("title", field_types::STRING, false),
+                                 field("description", field_types::STRING, false),
+                                 field("points", field_types::INT32, false),};
+
+    Collection* coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get();
+
+    nlohmann::json doc1;
+    doc1["id"] = "0";
+    doc1["title"] = "The Quick Brown Fox";
+    doc1["description"] = "Share information with this device.";
+    doc1["points"] = 100;
+
+    nlohmann::json doc2;
+    doc2["id"] = "1";
+    doc2["title"] = "Random Title";
+    doc2["description"] = "The Quick Brown Fox";
+    doc2["points"] = 50;
+
+    ASSERT_TRUE(coll1->add(doc1.dump()).ok());
+    ASSERT_TRUE(coll1->add(doc2.dump()).ok());
+
+    auto results = coll1->search("brown fox", {"title", "description"}, "", {}, {}, {2}, 10,
+                                 1, FREQUENCY, {true, true},
+                                 10, spp::sparse_hash_set<std::string>(),
+                                 spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "", 40, {}, {}, {}, 0,
"", {1, 4}).get(); + + ASSERT_EQ("1", results["hits"][0]["document"]["id"].get()); + ASSERT_EQ("0", results["hits"][1]["document"]["id"].get()); + + collectionManager.drop_collection("coll1"); +} + +TEST_F(CollectionSpecificTest, PrefixWithTypos) { std::vector fields = {field("title", field_types::STRING, false), field("points", field_types::INT32, false),};