diff --git a/TODO.md b/TODO.md index 9ac2d6d7..5c722715 100644 --- a/TODO.md +++ b/TODO.md @@ -61,7 +61,7 @@ - ~~Collection Manager collections map should store plain collection name~~ - ~~init_collection of Collection manager should probably take seq_id as param~~ - ~~node score should be int32, no longer uint16 like in document struct~~ -- Typo in prefix search +- ~~Typo in prefix search~~ - Proper logging - https support - Validate before string to int conversion in the http api layer diff --git a/src/collection.cpp b/src/collection.cpp index a7646464..369b84e6 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -453,11 +453,6 @@ void Collection::search_candidates(uint32_t* filter_ids, size_t filter_ids_lengt candidate_rank += 1; - int actual_candidate_rank = candidate_rank; - if(prefix) { - actual_candidate_rank = 0; - } - // intersect the document ids for each token to find docs that contain all the tokens (stored in `result_ids`) for(auto i=1; i < query_suggestion.size(); i++) { uint32_t* out = nullptr; @@ -481,7 +476,7 @@ void Collection::search_candidates(uint32_t* filter_ids, size_t filter_ids_lengt do_facets(facets, filtered_result_ids, filtered_results_size); // go through each matching document id and calculate match score - score_results(sort_fields, searched_queries.size(), actual_candidate_rank, topster, query_suggestion, + score_results(sort_fields, searched_queries.size(), candidate_rank, topster, query_suggestion, filtered_result_ids, filtered_results_size); delete[] filtered_result_ids; @@ -495,7 +490,7 @@ void Collection::search_candidates(uint32_t* filter_ids, size_t filter_ids_lengt delete [] *all_result_ids; *all_result_ids = new_all_result_ids; - score_results(sort_fields, searched_queries.size(), actual_candidate_rank, topster, query_suggestion, + score_results(sort_fields, searched_queries.size(), candidate_rank, topster, query_suggestion, result_ids, result_size); delete[] result_ids; } @@ -938,13 +933,21 @@ void 
Collection::search_field(std::string & query, const std::string & field, ui spp::sparse_hash_map token_to_count; std::vector> token_to_costs; - std::vector all_costs; - - for(int cost = 0; cost <= max_cost; cost++) { - all_costs.push_back(cost); - } for(size_t token_index = 0; token_index < tokens.size(); token_index++) { + std::vector all_costs; + const size_t token_len = tokens[token_index].length(); + + // This ensures that we don't end up doing a cost of 1 for a single char etc. + int bounded_cost = max_cost; + if(token_len > 0 && max_cost >= token_len && (token_len == 1 || token_len == 2)) { + bounded_cost = token_len - 1; + } + + for(int cost = 0; cost <= bounded_cost; cost++) { + all_costs.push_back(cost); + } + token_to_costs.push_back(all_costs); std::transform(tokens[token_index].begin(), tokens[token_index].end(), tokens[token_index].begin(), ::tolower); } @@ -986,9 +989,10 @@ void Collection::search_field(std::string & query, const std::string & field, ui // prefix should apply only for last token const bool prefix_search = prefix && ((token_index == tokens.size()-1) ? true : false); const int token_len = prefix_search ? (int) token.length() : (int) token.length() + 1; + const int max_candidates = prefix_search ? 
5 : 3; art_fuzzy_search(search_index.at(field), (const unsigned char *) token.c_str(), token_len, - costs[token_index], costs[token_index], 3, token_order, prefix_search, leaves); + costs[token_index], costs[token_index], max_candidates, token_order, prefix_search, leaves); if(!leaves.empty()) { token_cost_cache.emplace(token_cost_hash, leaves); @@ -1033,6 +1037,7 @@ void Collection::search_field(std::string & query, const std::string & field, ui break; } + // only allow up to 10 prefix candidate tokens if(prefix && candidate_rank > 10) { break; } @@ -1164,9 +1169,9 @@ void Collection::score_results(const std::vector & sort_fields, const i const number_t & secondary_rank_value = secondary_rank_score * secondary_rank_factor; topster.add(seq_id, query_index, match_score, primary_rank_value, secondary_rank_value); - /*std::cout << "candidate_rank_score: " << candidate_rank_score << ", words_present: " << mscore.words_present - << ", match_score: " << match_score << ", primary_rank_score: " << primary_rank_score - << ", seq_id: " << seq_id << std::endl;*/ + /*std::cout << "candidate_rank: " << candidate_rank << ", candidate_rank_score: " << candidate_rank_score + << ", words_present: " << mscore.words_present << ", match_score: " << match_score + << ", primary_rank_score: " << primary_rank_score.intval << ", seq_id: " << seq_id << std::endl;*/ } for (auto it = leaf_to_indices.begin(); it != leaf_to_indices.end(); it++) { diff --git a/test/collection_test.cpp b/test/collection_test.cpp index af7c0504..2f8f6374 100644 --- a/test/collection_test.cpp +++ b/test/collection_test.cpp @@ -351,7 +351,7 @@ TEST_F(CollectionTest, PrefixSearching) { std::vector facets; nlohmann::json results = collection->search("ex", query_fields, "", facets, sort_fields, 0, 10, 1, FREQUENCY, true).get(); ASSERT_EQ(2, results["hits"].size()); - std::vector ids = {"6", "12"}; + std::vector ids = {"12", "6"}; for(size_t i = 0; i < results["hits"].size(); i++) { nlohmann::json result = 
results["hits"].at(i); @@ -408,6 +408,18 @@ TEST_F(CollectionTest, PrefixSearching) { // only the last token in the query should be used for prefix search - so, "math" should not match "mathematics" results = collection->search("math fx", query_fields, "", facets, sort_fields, 0, 1, 1, FREQUENCY, true).get(); ASSERT_EQ(0, results["hits"].size()); + + // single and double char prefixes should set a ceiling on the num_typos possible + results = collection->search("x", query_fields, "", facets, sort_fields, 2, 2, 1, FREQUENCY, true).get(); + ASSERT_EQ(0, results["hits"].size()); + + results = collection->search("xq", query_fields, "", facets, sort_fields, 2, 2, 1, FREQUENCY, true).get(); + ASSERT_EQ(0, results["hits"].size()); + + // prefix with a typo + results = collection->search("late propx", query_fields, "", facets, sort_fields, 2, 1, 1, FREQUENCY, true).get(); + ASSERT_EQ(1, results["hits"].size()); + ASSERT_EQ("16", results["hits"].at(0)["id"]); } TEST_F(CollectionTest, MultipleFields) {