diff --git a/TODO.md b/TODO.md index 4d8e36b1..b4479eda 100644 --- a/TODO.md +++ b/TODO.md @@ -96,9 +96,10 @@ - ~~gzip compress responses~~ - ~~Have a LOG(ERROR) level~~ - ~~Handle SIGTERM which is sent when process is killed~~ +- Exact search +- NOT operator support - Log operations - Parameterize replica's MAX_UPDATES_TO_SEND -- NOT operator support - > INT32_MAX validation for float field - highlight of string arrays? - test for token ranking on float field diff --git a/include/collection.h b/include/collection.h index 13f89a41..8f5ac786 100644 --- a/include/collection.h +++ b/include/collection.h @@ -96,7 +96,8 @@ public: const std::string & simple_filter_query, const std::vector & facet_fields, const std::vector & sort_fields, const int num_typos, const size_t per_page = 10, const size_t page = 1, - const token_ordering token_order = FREQUENCY, const bool prefix = false); + const token_ordering token_order = FREQUENCY, const bool prefix = false, + const size_t drop_tokens_threshold = Index::DROP_TOKENS_THRESHOLD); Option get(const std::string & id); diff --git a/include/index.h b/include/index.h index ec51aede..79ab0ad2 100644 --- a/include/index.h +++ b/include/index.h @@ -32,6 +32,7 @@ struct search_args { size_t page; token_ordering token_order; bool prefix; + size_t drop_tokens_threshold; std::vector::KV>> field_order_kvs; size_t all_result_ids_len; std::vector> searched_queries; @@ -43,10 +44,11 @@ struct search_args { search_args(std::string query, std::vector search_fields, std::vector filters, std::vector facets, std::vector sort_fields_std, int num_typos, - size_t per_page, size_t page, token_ordering token_order, bool prefix): + size_t per_page, size_t page, token_ordering token_order, bool prefix, size_t drop_tokens_threshold): query(query), search_fields(search_fields), filters(filters), facets(facets), sort_fields_std(sort_fields_std), num_typos(num_typos), per_page(per_page), page(page), - token_order(token_order), prefix(prefix), all_result_ids_len(0), outcome(0) { + token_order(token_order), prefix(prefix), drop_tokens_threshold(drop_tokens_threshold), + all_result_ids_len(0), outcome(0) { } }; @@ -91,7 +93,8 @@ private: const int num_typos, const size_t num_results, std::vector> & searched_queries, Topster<512> & topster, uint32_t** all_result_ids, - size_t & all_result_ids_len, const token_ordering token_order = FREQUENCY, const bool prefix = false); + size_t & all_result_ids_len, const token_ordering token_order = FREQUENCY, + const bool prefix = false, const size_t drop_tokens_threshold = Index::DROP_TOKENS_THRESHOLD); void search_candidates(uint32_t* filter_ids, size_t filter_ids_length, std::vector & facets, const std::vector & sort_fields, std::vector & token_to_candidates, @@ -138,7 +141,7 @@ public: const std::vector & filters, std::vector & facets, std::vector sort_fields_std, const int num_typos, const size_t per_page, const size_t page, - const token_ordering token_order, const bool prefix, + const token_ordering token_order, const bool prefix, const size_t drop_tokens_threshold, std::vector::KV>> & field_order_kv, size_t & all_result_ids_len, std::vector> & searched_queries); @@ -150,7 +153,11 @@ public: Option index_in_memory(const nlohmann::json & document, uint32_t seq_id, int32_t points); - static const int SEARCH_LIMIT_NUM = 100; // for limiting number of results on multiple candidates / query rewrites + static const int SEARCH_LIMIT_NUM = 100; // for limiting number of results on multiple candidates / query rewrites + + // If the number of results found is less than this threshold, Typesense will attempt to drop the tokens + // in the query that have the least individual hits one by one until enough results are found. + static const int DROP_TOKENS_THRESHOLD = 10; // strings under this length will be fully highlighted, instead of showing a snippet of relevant portion enum {SNIPPET_STR_ABOVE_LEN = 30}; diff --git a/src/api.cpp b/src/api.cpp index b56218a4..2e5d5093 100644 --- a/src/api.cpp +++ b/src/api.cpp @@ -161,6 +161,7 @@ void get_search(http_req & req, http_res & res) { const char *NUM_TYPOS = "num_typos"; const char *PREFIX = "prefix"; + const char *DROP_TOKENS_THRESHOLD = "drop_tokens_threshold"; const char *FILTER = "filter_by"; const char *QUERY = "q"; const char *QUERY_BY = "query_by"; @@ -179,6 +180,10 @@ void get_search(http_req & req, http_res & res) { req.params[PREFIX] = "true"; } + if(req.params.count(DROP_TOKENS_THRESHOLD) == 0) { + req.params[DROP_TOKENS_THRESHOLD] = std::to_string(Index::DROP_TOKENS_THRESHOLD); + } + if(req.params.count(QUERY) == 0) { return res.send_400(std::string("Parameter `") + QUERY + "` is required."); } @@ -195,6 +200,10 @@ void get_search(http_req & req, http_res & res) { req.params[PAGE] = "1"; } + if(!StringUtils::is_uint64_t(req.params[DROP_TOKENS_THRESHOLD])) { + return res.send_400("Parameter `" + std::string(DROP_TOKENS_THRESHOLD) + "` must be an unsigned integer."); + } + if(!StringUtils::is_uint64_t(req.params[NUM_TYPOS])) { return res.send_400("Parameter `" + std::string(NUM_TYPOS) + "` must be an unsigned integer."); } @@ -245,6 +254,7 @@ void get_search(http_req & req, http_res & res) { } bool prefix = (req.params[PREFIX] == "true"); + const size_t drop_tokens_threshold = (size_t) std::stoi(req.params[DROP_TOKENS_THRESHOLD]); if(req.params.count(RANK_TOKENS_BY) == 0) { req.params[RANK_TOKENS_BY] = "DEFAULT_SORTING_FIELD"; @@ -256,7 +266,7 @@ void get_search(http_req & req, http_res & res) { Option result_op = collection->search(req.params[QUERY], search_fields, filter_str, facet_fields, sort_fields, std::stoi(req.params[NUM_TYPOS]), std::stoi(req.params[PER_PAGE]), std::stoi(req.params[PAGE]), - token_order, prefix); + token_order, prefix, drop_tokens_threshold); uint64_t timeMillis = std::chrono::duration_cast( std::chrono::high_resolution_clock::now() - begin).count(); diff --git a/src/art.cpp b/src/art.cpp index 2f099704..daa53a0a 100644 --- a/src/art.cpp +++ b/src/art.cpp @@ -905,15 +905,12 @@ int art_topk_iter(const art_node *root, token_ordering token_order, size_t max_r std::vector &results) { printf("INSIDE art_topk_iter: root->type: %d\n", root->type); - std::priority_queue, - std::function> q; + std::priority_queue, + decltype(&compare_art_node_score_pq)> q(compare_art_node_score_pq); if(token_order == FREQUENCY) { - q = std::priority_queue, - std::function>(compare_art_node_frequency_pq); - } else { - q = std::priority_queue, - std::function>(compare_art_node_score_pq); + q = std::priority_queue, + decltype(&compare_art_node_frequency_pq)>(compare_art_node_frequency_pq); } q.push(root); diff --git a/src/collection.cpp b/src/collection.cpp index 27436382..5d87140d 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -264,7 +264,8 @@ Option Collection::search(std::string query, const std::vector & facet_fields, const std::vector & sort_fields, const int num_typos, const size_t per_page, const size_t page, - const token_ordering token_order, const bool prefix) { + const token_ordering token_order, const bool prefix, + const size_t drop_tokens_threshold) { std::vector facets; // validate search fields @@ -430,7 +431,7 @@ Option Collection::search(std::string query, const std::vectorsearch_params = search_args(query, search_fields, filters, facets, sort_fields_std, - num_typos, per_page, page, token_order, prefix); + num_typos, per_page, page, token_order, prefix, drop_tokens_threshold); { std::lock_guard lk(index->m); index->ready = true; diff --git a/src/index.cpp b/src/index.cpp index 6c09ee6c..2a73a6fe 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -331,7 +331,7 @@ void Index::search_candidates(uint32_t* filter_ids, size_t filter_ids_length, st // every element in `query_suggestion` contains a token and its associated hits std::vector query_suggestion = next_suggestion(token_candidates_vec, n); - /*for(auto i=0; i < query_suggestion.size(); i++) { + /*for(size_t i=0; i < query_suggestion.size(); i++) { LOG(INFO) << "i: " << i << " - " << query_suggestion[i]->key; }*/ @@ -548,8 +548,8 @@ void Index::run_search() { search(search_params.outcome, search_params.query, search_params.search_fields, search_params.filters, search_params.facets, search_params.sort_fields_std, search_params.num_typos, search_params.per_page, search_params.page, - search_params.token_order, search_params.prefix, search_params.field_order_kvs, - search_params.all_result_ids_len, search_params.searched_queries); + search_params.token_order, search_params.prefix, search_params.drop_tokens_threshold, + search_params.field_order_kvs, search_params.all_result_ids_len, search_params.searched_queries); // hand control back to main thread processed = true; @@ -565,7 +565,8 @@ void Index::search(Option & outcome, std::string query, const std::vec const std::vector & filters, std::vector & facets, std::vector sort_fields_std, const int num_typos, const size_t per_page, const size_t page, const token_ordering token_order, - const bool prefix, std::vector::KV>> & field_order_kvs, + const bool prefix, const size_t drop_tokens_threshold, + std::vector::KV>> & field_order_kvs, size_t & all_result_ids_len, std::vector> & searched_queries) { const size_t num_results = (page * per_page); @@ -591,7 +592,8 @@ void Index::search(Option & outcome, std::string query, const std::vec // proceed to query search only when no filters are provided or when filtering produces results if(filters.size() == 0 || filter_ids_length > 0) { search_field(query, field, filter_ids, filter_ids_length, facets, sort_fields_std, num_typos, num_results, - searched_queries, topster, &all_result_ids, all_result_ids_len, token_order, prefix); + searched_queries, topster, &all_result_ids, all_result_ids_len, token_order, prefix, + drop_tokens_threshold); topster.sort(); } @@ -623,7 +625,7 @@ void Index::search_field(std::string & query, const std::string & field, uint32_ std::vector & facets, const std::vector & sort_fields, const int num_typos, const size_t num_results, std::vector> & searched_queries, Topster<512> &topster, uint32_t** all_result_ids, size_t & all_result_ids_len, - const token_ordering token_order, const bool prefix) { + const token_ordering token_order, const bool prefix, const size_t drop_tokens_threshold) { std::vector tokens; StringUtils::split(query, tokens, " "); @@ -713,7 +715,12 @@ void Index::search_field(std::string & query, const std::string & field, uint32_ if(it != token_to_costs[token_index].end()) { token_to_costs[token_index].erase(it); - // no more costs left for this token, clean up + // when no more costs are left for this token and `drop_tokens_threshold` is breached + if(token_to_costs[token_index].empty() && topster.size >= drop_tokens_threshold) { + break; + } + + // otherwise, we try to drop the token and search with remaining tokens if(token_to_costs[token_index].empty()) { token_to_costs.erase(token_to_costs.begin()+token_index); tokens.erase(tokens.begin()+token_index); @@ -747,7 +754,7 @@ void Index::search_field(std::string & query, const std::string & field, uint32_ } // When there are not enough overall results and atleast one token has results - if(topster.size < Index::SEARCH_LIMIT_NUM && token_to_count.size() > 1) { + if(topster.size < drop_tokens_threshold && token_to_count.size() > 1) { // Drop token with least hits and try searching again std::string truncated_query; diff --git a/test/collection_test.cpp b/test/collection_test.cpp index 91008b72..c503916d 100644 --- a/test/collection_test.cpp +++ b/test/collection_test.cpp @@ -206,6 +206,28 @@ TEST_F(CollectionTest, SkipUnindexedTokensDuringPhraseSearch) { ASSERT_STREQ(id.c_str(), result_id.c_str()); } + // should not try to drop tokens to expand query + results.clear(); + results = collection->search("the a", query_fields, "", facets, sort_fields, 0, 10, 1, FREQUENCY, false, 10).get(); + ASSERT_EQ(8, results["hits"].size()); + + results.clear(); + results = collection->search("the a", query_fields, "", facets, sort_fields, 0, 10, 1, FREQUENCY, false, 0).get(); + ASSERT_EQ(3, results["hits"].size()); + ids = {"8", "16", "10"}; + + for(size_t i = 0; i < results["hits"].size(); i++) { + nlohmann::json result = results["hits"].at(i); + std::string id = ids.at(i); + std::string result_id = result["document"]["id"]; + ASSERT_STREQ(id.c_str(), result_id.c_str()); + } + + results.clear(); + results = collection->search("the a DoesNotExist", query_fields, "", facets, sort_fields, 0, 10, 1, FREQUENCY, false, 0).get(); + ASSERT_EQ(0, results["hits"].size()); + + // with no indexed word results.clear(); results = collection->search("DoesNotExist1 DoesNotExist2", query_fields, "", facets, sort_fields, 0, 10).get(); ASSERT_EQ(0, results["hits"].size());