From 38c5c0b035057a3dd265ae80d1f4d264283547ca Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Thu, 25 May 2023 12:10:45 +0530 Subject: [PATCH] Allow text match bucket of 1. --- src/collection.cpp | 9 ++-- test/collection_sorting_test.cpp | 70 +++++++++++++++++--------------- 2 files changed, 42 insertions(+), 37 deletions(-) diff --git a/src/collection.cpp b/src/collection.cpp index 1c3e51b1..ad186d3e 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -1549,12 +1549,12 @@ Option Collection::search(std::string raw_query, return Option(408, "Request Timeout"); } - if(match_score_index >= 0 && sort_fields_std[match_score_index].text_match_buckets > 1) { + if(match_score_index >= 0 && sort_fields_std[match_score_index].text_match_buckets > 0) { size_t num_buckets = sort_fields_std[match_score_index].text_match_buckets; const size_t max_kvs_bucketed = std::min(DEFAULT_TOPSTER_SIZE, raw_result_kvs.size()); if(max_kvs_bucketed >= num_buckets) { - std::vector result_scores(max_kvs_bucketed); + spp::sparse_hash_map result_scores; // only first `max_kvs_bucketed` elements are bucketed to prevent pagination issues past 250 records size_t block_len = (max_kvs_bucketed / num_buckets); @@ -1563,7 +1563,7 @@ Option Collection::search(std::string raw_query, int64_t anchor_score = raw_result_kvs[i][0]->scores[raw_result_kvs[i][0]->match_score_index]; size_t j = 0; while(j < block_len && i+j < max_kvs_bucketed) { - result_scores[i+j] = raw_result_kvs[i+j][0]->scores[raw_result_kvs[i+j][0]->match_score_index]; + result_scores[raw_result_kvs[i+j][0]->key] = raw_result_kvs[i+j][0]->scores[raw_result_kvs[i+j][0]->match_score_index]; raw_result_kvs[i+j][0]->scores[raw_result_kvs[i+j][0]->match_score_index] = anchor_score; j++; } @@ -1577,7 +1577,8 @@ Option Collection::search(std::string raw_query, // restore original scores for(i = 0; i < max_kvs_bucketed; i++) { - raw_result_kvs[i][0]->scores[raw_result_kvs[i][0]->match_score_index] = result_scores[i]; + raw_result_kvs[i][0]->scores[raw_result_kvs[i][0]->match_score_index] = + result_scores[raw_result_kvs[i][0]->key]; } } } diff --git a/test/collection_sorting_test.cpp b/test/collection_sorting_test.cpp index 99cb95c7..7598a589 100644 --- a/test/collection_sorting_test.cpp +++ b/test/collection_sorting_test.cpp @@ -1636,7 +1636,7 @@ TEST_F(CollectionSortingTest, TextMatchBucketRanking) { nlohmann::json doc1; doc1["id"] = "0"; doc1["title"] = "Mark Antony"; - doc1["description"] = "Marriage Counsellor"; + doc1["description"] = "Counsellor"; doc1["points"] = 100; nlohmann::json doc2; @@ -1653,47 +1653,51 @@ TEST_F(CollectionSortingTest, TextMatchBucketRanking) { sort_by("points", "DESC"), }; - auto results = coll1->search("mark", {"title", "description"}, - "", {}, sort_fields, {2, 2}, 10, - 1, FREQUENCY, {true, true}, + auto results = coll1->search("mark", {"title"}, + "", {}, sort_fields, {2}, 10, + 1, FREQUENCY, {true}, 10, spp::sparse_hash_set(), spp::sparse_hash_set(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0, - "", "", {3, 1}, 1000, true).get(); + "", "", {3}, 1000, true).get(); // when there are more buckets than results, no bucketing will happen ASSERT_EQ(2, results["hits"].size()); ASSERT_EQ("0", results["hits"][0]["document"]["id"].get()); ASSERT_EQ("1", results["hits"][1]["document"]["id"].get()); - // bucketing by 1 produces original text match + // bucketing by 1 makes the text match score the same sort_fields = { sort_by("_text_match(buckets: 1)", "DESC"), sort_by("points", "DESC"), }; - results = coll1->search("mark", {"title", "description"}, - "", {}, sort_fields, {2, 2}, 10, - 1, FREQUENCY, {true, true}, + results = coll1->search("mark", {"title"}, + "", {}, sort_fields, {2}, 10, + 1, FREQUENCY, {true}, 10, spp::sparse_hash_set(), spp::sparse_hash_set(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0, - "", "", {3, 1}, 1000, true).get(); + "", "", {3}, 1000, true).get(); ASSERT_EQ(2, results["hits"].size()); - ASSERT_EQ("0", results["hits"][0]["document"]["id"].get()); - ASSERT_EQ("1", results["hits"][1]["document"]["id"].get()); + ASSERT_EQ("1", results["hits"][0]["document"]["id"].get()); + ASSERT_EQ("0", results["hits"][1]["document"]["id"].get()); - // likewise with bucket 0 + size_t score1 = std::stoul(results["hits"][0]["text_match_info"]["score"].get()); + size_t score2 = std::stoul(results["hits"][1]["text_match_info"]["score"].get()); + ASSERT_TRUE(score1 < score2); + + // bucketing by 0 produces original text match sort_fields = { sort_by("_text_match(buckets: 0)", "DESC"), sort_by("points", "DESC"), }; - results = coll1->search("mark", {"title", "description"}, - "", {}, sort_fields, {2, 2}, 10, - 1, FREQUENCY, {true, true}, + results = coll1->search("mark", {"title"}, + "", {}, sort_fields, {2}, 10, + 1, FREQUENCY, {true}, 10, spp::sparse_hash_set(), spp::sparse_hash_set(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0, - "", "", {3, 1}, 1000, true).get(); + "", "", {3}, 1000, true).get(); ASSERT_EQ(2, results["hits"].size()); ASSERT_EQ("0", results["hits"][0]["document"]["id"].get()); @@ -1702,46 +1706,46 @@ TEST_F(CollectionSortingTest, TextMatchBucketRanking) { // don't allow bad parameter name sort_fields[0] = sort_by("_text_match(foobar: 0)", "DESC"); - auto res_op = coll1->search("mark", {"title", "description"}, - "", {}, sort_fields, {2, 2}, 10, - 1, FREQUENCY, {true, true}, + auto res_op = coll1->search("mark", {"title"}, + "", {}, sort_fields, {2}, 10, + 1, FREQUENCY, {true}, 10, spp::sparse_hash_set(), spp::sparse_hash_set(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0, - "", "", {3, 1}, 1000, true); + "", "", {3}, 1000, true); ASSERT_FALSE(res_op.ok()); ASSERT_EQ("Invalid sorting parameter passed for _text_match.", res_op.error()); // handle bad syntax sort_fields[0] = sort_by("_text_match(foobar:", "DESC"); - res_op = coll1->search("mark", {"title", "description"}, - "", {}, sort_fields, {2, 2}, 10, - 1, FREQUENCY, {true, true}, + res_op = coll1->search("mark", {"title"}, + "", {}, sort_fields, {2}, 10, + 1, FREQUENCY, {true}, 10, spp::sparse_hash_set(), spp::sparse_hash_set(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0, - "", "", {3, 1}, 1000, true); + "", "", {3}, 1000, true); ASSERT_FALSE(res_op.ok()); ASSERT_EQ("Could not find a field named `_text_match(foobar:` in the schema for sorting.", res_op.error()); // handle bad value sort_fields[0] = sort_by("_text_match(buckets: x)", "DESC"); - res_op = coll1->search("mark", {"title", "description"}, - "", {}, sort_fields, {2, 2}, 10, - 1, FREQUENCY, {true, true}, + res_op = coll1->search("mark", {"title"}, + "", {}, sort_fields, {2}, 10, + 1, FREQUENCY, {true}, 10, spp::sparse_hash_set(), spp::sparse_hash_set(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0, - "", "", {3, 1}, 1000, true); + "", "", {3}, 1000, true); ASSERT_FALSE(res_op.ok()); ASSERT_EQ("Invalid value passed for _text_match `buckets` configuration.", res_op.error()); // handle negative value sort_fields[0] = sort_by("_text_match(buckets: -1)", "DESC"); - res_op = coll1->search("mark", {"title", "description"}, - "", {}, sort_fields, {2, 2}, 10, - 1, FREQUENCY, {true, true}, + res_op = coll1->search("mark", {"title"}, + "", {}, sort_fields, {2}, 10, + 1, FREQUENCY, {true}, 10, spp::sparse_hash_set(), spp::sparse_hash_set(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0, - "", "", {3, 1}, 1000, true); + "", "", {3}, 1000, true); ASSERT_FALSE(res_op.ok()); ASSERT_EQ("Invalid value passed for _text_match `buckets` configuration.", res_op.error());