From f5848be750ee8fc8fbdc9d3bf65f065abc136c65 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Fri, 18 Aug 2017 15:26:17 +0530 Subject: [PATCH] Address prefix search issues. Score based comparison was broken - test has been enhanced. --- include/collection.h | 2 +- src/api.cpp | 22 +++++++++++++++------- src/art.cpp | 2 +- src/collection.cpp | 32 ++++++++++++++++++++++++-------- src/main/benchmark.cpp | 2 +- test/collection_test.cpp | 25 ++++++++++++++++++++++++- test/documents.jsonl | 2 +- 7 files changed, 67 insertions(+), 20 deletions(-) diff --git a/include/collection.h b/include/collection.h index 488013cf..0c56d6a5 100644 --- a/include/collection.h +++ b/include/collection.h @@ -95,7 +95,7 @@ private: std::vector> & token_to_candidates, std::vector> & searched_queries, Topster<100> & topster, size_t & total_results, uint32_t** all_result_ids, size_t & all_result_ids_len, - const size_t & max_results); + const size_t & max_results, const bool prefix); void index_string_field(const std::string & text, const uint32_t score, art_tree *t, uint32_t seq_id, const bool verbatim) const; diff --git a/src/api.cpp b/src/api.cpp index 9cef8c4b..4d62c6e2 100644 --- a/src/api.cpp +++ b/src/api.cpp @@ -116,14 +116,15 @@ void post_create_collection(http_req & req, http_res & res) { sort_fields.push_back(field(sort_field_json["name"], sort_field_json["type"])); } + const char* PREFIX_RANKING_FIELD = "prefix_ranking_field"; std::string token_ranking_field = ""; - if(req_json.count("token_ranking_field") != 0) { - if(!req_json["token_ranking_field"].is_string()) { - return res.send_400("Wrong format for `token_ranking_field`. It should be a string (name of a field)."); + if(req_json.count(PREFIX_RANKING_FIELD) != 0) { + if(!req_json[PREFIX_RANKING_FIELD].is_string()) { + return res.send_400(std::string("Wrong format for `") + PREFIX_RANKING_FIELD + "`. 
It should be the name of an unsigned INT32 field."); } - token_ranking_field = req_json["token_ranking_field"].get<std::string>(); + token_ranking_field = req_json[PREFIX_RANKING_FIELD].get<std::string>(); } collectionManager.create_collection(req_json["name"], search_fields, facet_fields, sort_fields, token_ranking_field); @@ -157,6 +158,7 @@ void get_search(http_req & req, http_res & res) { const char *PER_PAGE = "per_page"; const char *PAGE = "page"; const char *CALLBACK = "callback"; + const char *SORT_PREFIXES_BY = "sort_prefixes_by"; if(req.params.count(NUM_TYPOS) == 0) { req.params[NUM_TYPOS] = "2"; @@ -217,11 +219,17 @@ void get_search(http_req & req, http_res & res) { bool prefix = (req.params[PREFIX] == "true"); - token_ordering token_order = FREQUENCY; - if(prefix && !collection->get_token_ranking_field().empty()) { - token_order = MAX_SCORE; + if(req.params.count(SORT_PREFIXES_BY) == 0) { + if(prefix && !collection->get_token_ranking_field().empty()) { + req.params[SORT_PREFIXES_BY] = "PREFIX_SORT_FIELD"; + } else { + req.params[SORT_PREFIXES_BY] = "TERM_FREQUENCY"; + } } + StringUtils::toupper(req.params[SORT_PREFIXES_BY]); + token_ordering token_order = (req.params[SORT_PREFIXES_BY] == "PREFIX_SORT_FIELD") ? 
MAX_SCORE : FREQUENCY; + Option result_op = collection->search(req.params["q"], search_fields, filter_str, facet_fields, sort_fields, std::stoi(req.params[NUM_TYPOS]), std::stoi(req.params[PER_PAGE]), std::stoi(req.params[PAGE]), diff --git a/src/art.cpp b/src/art.cpp index 3ac2d47f..97155f2c 100644 --- a/src/art.cpp +++ b/src/art.cpp @@ -91,7 +91,7 @@ bool compare_art_node_frequency_pq(const art_node *a, const art_node *b) { } bool compare_art_node_score_pq(const art_node* a, const art_node* b) { - return !compare_art_node_frequency(a, b); + return !compare_art_node_score(a, b); } /** diff --git a/src/collection.cpp b/src/collection.cpp index 496b6559..2f6161b3 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -405,7 +405,7 @@ void Collection::search_candidates(uint32_t* filter_ids, size_t filter_ids_lengt std::vector> & token_to_candidates, std::vector> & searched_queries, Topster<100> & topster, size_t & total_results, uint32_t** all_result_ids, size_t & all_result_ids_len, - const size_t & max_results) { + const size_t & max_results, const bool prefix) { const size_t combination_limit = 10; auto product = []( long long a, std::vector& b ) { return a*b.size(); }; long long int N = std::accumulate(token_to_candidates.begin(), token_to_candidates.end(), 1LL, product); @@ -422,10 +422,17 @@ void Collection::search_candidates(uint32_t* filter_ids, size_t filter_ids_lengt uint32_t* result_ids = query_suggestion[0]->values->ids.uncompress(); size_t result_size = query_suggestion[0]->values->ids.getLength(); - if(result_size == 0) continue; + if(result_size == 0) { + continue; + } candidate_rank += 1; + int actual_candidate_rank = candidate_rank; + if(prefix) { + actual_candidate_rank = 0; + } + // intersect the document ids for each token to find docs that contain all the tokens (stored in `result_ids`) for(auto i=1; i < query_suggestion.size(); i++) { uint32_t* out = nullptr; @@ -449,7 +456,7 @@ void Collection::search_candidates(uint32_t* filter_ids, 
size_t filter_ids_lengt do_facets(facets, filtered_result_ids, filtered_results_size); // go through each matching document id and calculate match score - score_results(sort_fields, searched_queries.size(), candidate_rank, topster, query_suggestion, + score_results(sort_fields, searched_queries.size(), actual_candidate_rank, topster, query_suggestion, filtered_result_ids, filtered_results_size); delete[] filtered_result_ids; @@ -463,14 +470,19 @@ void Collection::search_candidates(uint32_t* filter_ids, size_t filter_ids_lengt delete [] *all_result_ids; *all_result_ids = new_all_result_ids; - score_results(sort_fields, searched_queries.size(), candidate_rank, topster, query_suggestion, result_ids, result_size); + score_results(sort_fields, searched_queries.size(), actual_candidate_rank, topster, query_suggestion, + result_ids, result_size); delete[] result_ids; } total_results += topster.size; searched_queries.push_back(query_suggestion); - if(total_results >= max_results) { + if(!prefix && total_results >= max_results) { + break; + } + + if(prefix && candidate_rank >= max_results) { break; } } @@ -944,7 +956,6 @@ void Collection::search_field(std::string & query, const std::string & field, ui leaves = token_cost_cache[token_cost_hash]; } else { int token_len = prefix ? 
(int) token.length() : (int) token.length() + 1; - int count = search_index.count(field); art_fuzzy_search(search_index.at(field), (const unsigned char *) token.c_str(), token_len, @@ -985,12 +996,17 @@ void Collection::search_field(std::string & query, const std::string & field, ui if(token_to_candidates.size() != 0 && token_to_candidates.size() == tokens.size()) { // If all tokens were found, go ahead and search for candidates with what we have so far search_candidates(filter_ids, filter_ids_length, facets, sort_fields, candidate_rank, token_to_candidates, - searched_queries, topster, total_results, all_result_ids, all_result_ids_len, max_results); + searched_queries, topster, total_results, all_result_ids, all_result_ids_len, + max_results, prefix); - if (total_results >= max_results) { + if (!prefix && total_results >= max_results) { // If we don't find enough results, we continue outerloop (looking at tokens with greater cost) break; } + + if(prefix && candidate_rank > 10) { + break; + } } n++; diff --git a/src/main/benchmark.cpp b/src/main/benchmark.cpp index 0be93f3d..c97b8b74 100644 --- a/src/main/benchmark.cpp +++ b/src/main/benchmark.cpp @@ -24,7 +24,7 @@ int main(int argc, char* argv[]) { Collection *collection = collectionManager.get_collection("hnstories_direct"); if(collection == nullptr) { - collection = collectionManager.create_collection("hnstories_direct", fields_to_index, {}, sort_fields); + collection = collectionManager.create_collection("hnstories_direct", fields_to_index, {}, sort_fields, "points"); } std::ifstream infile("/Users/kishore/Downloads/hnstories.jsonl"); diff --git a/test/collection_test.cpp b/test/collection_test.cpp index 15a04b81..9aa3feba 100644 --- a/test/collection_test.cpp +++ b/test/collection_test.cpp @@ -349,7 +349,7 @@ TEST_F(CollectionTest, PrefixSearching) { std::vector facets; nlohmann::json results = collection->search("ex", query_fields, "", facets, sort_fields, 0, 10, 1, FREQUENCY, true).get(); ASSERT_EQ(2, 
results["hits"].size()); - std::vector<std::string> ids = {"12", "6"}; + std::vector<std::string> ids = {"6", "12"}; for(size_t i = 0; i < results["hits"].size(); i++) { nlohmann::json result = results["hits"].at(i); @@ -379,6 +379,29 @@ TEST_F(CollectionTest, PrefixSearching) { std::string id = ids.at(i); ASSERT_STREQ(id.c_str(), result_id.c_str()); } + + // restrict to only 2 results and differentiate between MAX_SCORE and FREQUENCY + results = collection->search("t", query_fields, "", facets, sort_fields, 0, 2, 1, MAX_SCORE, true).get(); + ASSERT_EQ(2, results["hits"].size()); + ids = {"19", "22"}; + + for(size_t i = 0; i < results["hits"].size(); i++) { + nlohmann::json result = results["hits"].at(i); + std::string result_id = result["id"]; + std::string id = ids.at(i); + ASSERT_STREQ(id.c_str(), result_id.c_str()); + } + + results = collection->search("t", query_fields, "", facets, sort_fields, 0, 2, 1, FREQUENCY, true).get(); + ASSERT_EQ(2, results["hits"].size()); + ids = {"1", "6"}; + + for(size_t i = 0; i < results["hits"].size(); i++) { + nlohmann::json result = results["hits"].at(i); + std::string result_id = result["id"]; + std::string id = ids.at(i); + ASSERT_STREQ(id.c_str(), result_id.c_str()); + } } TEST_F(CollectionTest, MultipleFields) { diff --git a/test/documents.jsonl b/test/documents.jsonl index 130fc0a0..7646a533 100644 --- a/test/documents.jsonl +++ b/test/documents.jsonl @@ -16,7 +16,7 @@ {"points":10,"title":"How late do the launch propellants ionize in a chemical rocket mission?"} {"points":8,"title":"How much does it cost to launch (right from start) a rocket today?"} {"points":16,"title":"Difference between Space Dynamics & Astrodynamics in engineering perspective?"} -{"points":18,"title":"What kind of biological research does ISS do?"} +{"points":18,"title":"What kind of biological research does ISS do then?"} {"points":10,"title":"Which kinds of radiation hit ISX ?"} {"points":7,"title":"What kinds of things have been tossed out of ISS in space?"} 
{"points":17,"title":"What does triple redundant closed loop digital avionics system mean?"}