Expose typo_tokens_threshold parameter.

If the number of results found for a query is less than this threshold, Typesense will attempt to match tokens with more typos until enough results are found.
This commit is contained in:
kishorenc 2020-03-07 12:31:05 +05:30
parent eef3a5a3de
commit 206fe5b833
6 changed files with 67 additions and 28 deletions

View File

@@ -227,7 +227,8 @@ public:
size_t max_facet_values=10, size_t max_hits=500,
const std::string & simple_facet_query = "",
const size_t snippet_threshold = 30,
const std::string & highlight_full_fields = "");
const std::string & highlight_full_fields = "",
size_t typo_tokens_threshold = Index::TYPO_TOKENS_THRESHOLD);
Option<nlohmann::json> get(const std::string & id);

View File

@@ -38,6 +38,7 @@ struct search_args {
token_ordering token_order;
bool prefix;
size_t drop_tokens_threshold;
size_t typo_tokens_threshold;
std::vector<KV> raw_result_kvs;
size_t all_result_ids_len;
std::vector<std::vector<art_leaf*>> searched_queries;
@@ -52,11 +53,12 @@ struct search_args {
std::vector<facet> facets, std::vector<uint32_t> included_ids, std::vector<uint32_t> excluded_ids,
std::vector<sort_by> sort_fields_std, facet_query_t facet_query, int num_typos, size_t max_facet_values,
size_t max_hits, size_t per_page, size_t page, token_ordering token_order, bool prefix,
size_t drop_tokens_threshold):
size_t drop_tokens_threshold, size_t typo_tokens_threshold):
query(query), search_fields(search_fields), filters(filters), facets(facets), included_ids(included_ids),
excluded_ids(excluded_ids), sort_fields_std(sort_fields_std), facet_query(facet_query), num_typos(num_typos),
max_facet_values(max_facet_values), max_hits(max_hits), per_page(per_page),
page(page), token_order(token_order), prefix(prefix), drop_tokens_threshold(drop_tokens_threshold),
page(page), token_order(token_order), prefix(prefix),
drop_tokens_threshold(drop_tokens_threshold), typo_tokens_threshold(typo_tokens_threshold),
all_result_ids_len(0), outcome(0) {
}
@@ -147,19 +149,22 @@ private:
void drop_facets(std::vector<facet> & facets, const std::vector<uint32_t> & ids);
void search_field(const uint8_t & field_id, std::string & query,
void search_field(const uint8_t & field_id, const std::string & query,
const std::string & field, uint32_t *filter_ids, size_t filter_ids_length,
std::vector<facet> & facets, const std::vector<sort_by> & sort_fields,
const int num_typos, std::vector<std::vector<art_leaf*>> & searched_queries,
Topster & topster, uint32_t** all_result_ids,
size_t & all_result_ids_len, const token_ordering token_order = FREQUENCY,
const bool prefix = false, const size_t drop_tokens_threshold = Index::DROP_TOKENS_THRESHOLD);
const bool prefix = false,
const size_t drop_tokens_threshold = Index::DROP_TOKENS_THRESHOLD,
const size_t typo_tokens_threshold = Index::TYPO_TOKENS_THRESHOLD);
void search_candidates(const uint8_t & field_id, uint32_t* filter_ids, size_t filter_ids_length,
const std::vector<sort_by> & sort_fields, std::vector<token_candidates> & token_to_candidates,
const token_ordering token_order, std::vector<std::vector<art_leaf*>> & searched_queries,
Topster & topster, uint32_t** all_result_ids,
size_t & all_result_ids_len, const size_t & max_results);
size_t & all_result_ids_len,
const size_t typo_tokens_threshold);
void insert_doc(const uint32_t score, art_tree *t, uint32_t seq_id,
const std::unordered_map<std::string, std::vector<uint32_t>> &token_to_offsets) const;
@@ -208,7 +213,7 @@ public:
void run_search();
void search(Option<uint32_t> & outcome, std::string query, const std::vector<std::string> & search_fields,
void search(Option<uint32_t> & outcome, const std::string & query, const std::vector<std::string> & search_fields,
const std::vector<filter> & filters, std::vector<facet> & facets,
facet_query_t & facet_query,
const std::vector<uint32_t> & included_ids, const std::vector<uint32_t> & excluded_ids,
@@ -216,7 +221,7 @@ public:
const size_t max_hits, const size_t per_page, const size_t page, const token_ordering token_order,
const bool prefix, const size_t drop_tokens_threshold, std::vector<KV> & raw_result_kvs,
size_t & all_result_ids_len, std::vector<std::vector<art_leaf*>> & searched_queries,
std::vector<KV> & override_result_kvs);
std::vector<KV> & override_result_kvs, const size_t typo_tokens_threshold);
Option<uint32_t> remove(const uint32_t seq_id, nlohmann::json & document);
@@ -250,7 +255,7 @@ public:
const spp::sparse_hash_map<std::string, art_tree *> &_get_search_index() const;
// for limiting number of results on multiple candidates / query rewrites
enum {SEARCH_LIMIT_NUM = 100};
enum {TYPO_TOKENS_THRESHOLD = 100};
// for limiting number of fields that can be searched on
enum {FIELD_LIMIT_NUM = 100};

View File

@@ -313,7 +313,8 @@ Option<nlohmann::json> Collection::search(const std::string & query, const std::
const size_t max_facet_values, const size_t max_hits,
const std::string & simple_facet_query,
const size_t snippet_threshold,
const std::string & highlight_full_fields) {
const std::string & highlight_full_fields,
size_t typo_tokens_threshold ) {
std::vector<uint32_t> included_ids;
std::vector<uint32_t> excluded_ids;
@@ -583,7 +584,8 @@ Option<nlohmann::json> Collection::search(const std::string & query, const std::
index->search_params = search_args(query, search_fields, filters, facets,
index_to_included_ids[index_id], index_to_excluded_ids[index_id],
sort_fields_std, facet_query, num_typos, max_facet_values, max_hits,
results_per_page, page, token_order, prefix, drop_tokens_threshold);
results_per_page, page, token_order, prefix,
drop_tokens_threshold, typo_tokens_threshold);
{
std::lock_guard<std::mutex> lk(index->m);
index->ready = true;

View File

@@ -172,6 +172,7 @@ void get_search(http_req & req, http_res & res) {
const char *NUM_TYPOS = "num_typos";
const char *PREFIX = "prefix";
const char *DROP_TOKENS_THRESHOLD = "drop_tokens_threshold";
const char *TYPO_TOKENS_THRESHOLD = "typo_tokens_threshold";
const char *FILTER = "filter_by";
const char *QUERY = "q";
const char *QUERY_BY = "query_by";
@@ -207,6 +208,10 @@ void get_search(http_req & req, http_res & res) {
req.params[DROP_TOKENS_THRESHOLD] = std::to_string(Index::DROP_TOKENS_THRESHOLD);
}
if(req.params.count(TYPO_TOKENS_THRESHOLD) == 0) {
req.params[TYPO_TOKENS_THRESHOLD] = std::to_string(Index::TYPO_TOKENS_THRESHOLD);
}
if(req.params.count(QUERY) == 0) {
return res.send_400(std::string("Parameter `") + QUERY + "` is required.");
}
@@ -260,6 +265,10 @@ void get_search(http_req & req, http_res & res) {
return res.send_400("Parameter `" + std::string(DROP_TOKENS_THRESHOLD) + "` must be an unsigned integer.");
}
if(!StringUtils::is_uint64_t(req.params[TYPO_TOKENS_THRESHOLD])) {
return res.send_400("Parameter `" + std::string(TYPO_TOKENS_THRESHOLD) + "` must be an unsigned integer.");
}
if(!StringUtils::is_uint64_t(req.params[NUM_TYPOS])) {
return res.send_400("Parameter `" + std::string(NUM_TYPOS) + "` must be an unsigned integer.");
}
@@ -320,6 +329,7 @@ void get_search(http_req & req, http_res & res) {
bool prefix = (req.params[PREFIX] == "true");
const size_t drop_tokens_threshold = (size_t) std::stoi(req.params[DROP_TOKENS_THRESHOLD]);
const size_t typo_tokens_threshold = (size_t) std::stoi(req.params[TYPO_TOKENS_THRESHOLD]);
if(req.params.count(RANK_TOKENS_BY) == 0) {
req.params[RANK_TOKENS_BY] = "DEFAULT_SORTING_FIELD";
@@ -338,7 +348,8 @@ void get_search(http_req & req, http_res & res) {
static_cast<size_t>(std::stoi(req.params[MAX_HITS])),
req.params[FACET_QUERY],
static_cast<size_t>(std::stoi(req.params[SNIPPET_THRESHOLD])),
req.params[HIGHLIGHT_FULL_FIELDS]
req.params[HIGHLIGHT_FULL_FIELDS],
typo_tokens_threshold
);
uint64_t timeMillis = std::chrono::duration_cast<std::chrono::milliseconds>(

View File

@@ -831,7 +831,7 @@ void Index::search_candidates(const uint8_t & field_id, uint32_t* filter_ids, si
std::vector<token_candidates> & token_candidates_vec, const token_ordering token_order,
std::vector<std::vector<art_leaf*>> & searched_queries, Topster & topster,
uint32_t** all_result_ids, size_t & all_result_ids_len,
const size_t & max_results) {
const size_t typo_tokens_threshold) {
const long long combination_limit = 10;
auto product = []( long long a, token_candidates & b ) { return a*b.candidates.size(); };
@@ -905,7 +905,7 @@ void Index::search_candidates(const uint8_t & field_id, uint32_t* filter_ids, si
searched_queries.push_back(query_suggestion);
if(all_result_ids_len >= max_results) {
if(all_result_ids_len >= typo_tokens_threshold) {
break;
}
}
@@ -1060,7 +1060,8 @@ void Index::run_search() {
search_params.excluded_ids, search_params.sort_fields_std, search_params.num_typos,
search_params.max_hits, search_params.per_page, search_params.page, search_params.token_order,
search_params.prefix, search_params.drop_tokens_threshold, search_params.raw_result_kvs,
search_params.all_result_ids_len, search_params.searched_queries, search_params.override_result_kvs);
search_params.all_result_ids_len, search_params.searched_queries, search_params.override_result_kvs,
search_params.typo_tokens_threshold);
// hand control back to main thread
processed = true;
@@ -1141,7 +1142,7 @@ void Index::collate_curated_ids(const std::string & query, const std::string & f
}
void Index::search(Option<uint32_t> & outcome,
std::string query,
const std::string & query,
const std::vector<std::string> & search_fields,
const std::vector<filter> & filters,
std::vector<facet> & facets, facet_query_t & facet_query,
@@ -1153,7 +1154,8 @@ void Index::search(Option<uint32_t> & outcome,
std::vector<KV> & raw_result_kvs,
size_t & all_result_ids_len,
std::vector<std::vector<art_leaf*>> & searched_queries,
std::vector<KV> & override_result_kvs) {
std::vector<KV> & override_result_kvs,
const size_t typo_tokens_threshold) {
const size_t num_results = (page * per_page);
@@ -1195,7 +1197,7 @@ void Index::search(Option<uint32_t> & outcome,
search_field(field_id, query, field, filter_ids, filter_ids_length, facets, sort_fields_std,
num_typos, searched_queries, topster, &all_result_ids, all_result_ids_len,
token_order, prefix, drop_tokens_threshold);
token_order, prefix, drop_tokens_threshold, typo_tokens_threshold);
collate_curated_ids(query, field, field_id, included_ids, curated_topster, searched_queries);
}
}
@@ -1253,12 +1255,13 @@ void Index::search(Option<uint32_t> & outcome,
4. Intersect the lists to find docs that match each phrase
5. Sort the docs based on some ranking criteria
*/
void Index::search_field(const uint8_t & field_id, std::string & query, const std::string & field,
void Index::search_field(const uint8_t & field_id, const std::string & query, const std::string & field,
uint32_t *filter_ids, size_t filter_ids_length,
std::vector<facet> & facets, const std::vector<sort_by> & sort_fields, const int num_typos,
std::vector<std::vector<art_leaf*>> & searched_queries,
Topster & topster, uint32_t** all_result_ids, size_t & all_result_ids_len,
const token_ordering token_order, const bool prefix, const size_t drop_tokens_threshold) {
const token_ordering token_order, const bool prefix,
const size_t drop_tokens_threshold, const size_t typo_tokens_threshold) {
std::vector<std::string> tokens;
StringUtils::split(query, tokens, " ");
@@ -1320,7 +1323,7 @@ void Index::search_field(const uint8_t & field_id, std::string & query, const st
leaves = token_cost_cache[token_cost_hash];
} else {
// prefix should apply only for last token
const bool prefix_search = prefix && ((token_index == tokens.size()-1) ? true : false);
const bool prefix_search = prefix && (token_index == tokens.size()-1);
const size_t token_len = prefix_search ? (int) token.length() : (int) token.length() + 1;
// If this is a prefix search, look for more candidates and do a union of those document IDs
@@ -1367,16 +1370,16 @@ void Index::search_field(const uint8_t & field_id, std::string & query, const st
token_index++;
}
if(token_candidates_vec.size() != 0 && token_candidates_vec.size() == tokens.size()) {
if(!token_candidates_vec.empty() && token_candidates_vec.size() == tokens.size()) {
// If all tokens were found, go ahead and search for candidates with what we have so far
search_candidates(field_id, filter_ids, filter_ids_length, sort_fields, token_candidates_vec,
token_order, searched_queries, topster, all_result_ids, all_result_ids_len,
Index::SEARCH_LIMIT_NUM);
typo_tokens_threshold);
}
if (all_result_ids_len >= Index::SEARCH_LIMIT_NUM) {
// If we don't find enough results, we continue outerloop (looking at tokens with greater cost)
break;
}
if (all_result_ids_len >= typo_tokens_threshold) {
// If we don't find enough results, we continue outerloop (looking at tokens with greater typo cost)
break;
}
n++;

View File

@@ -85,7 +85,7 @@ TEST_F(CollectionTest, ExactSearchShouldBeStable) {
nlohmann::json results = collection->search("the", query_fields, "", facets, sort_fields, 0, 10).get();
ASSERT_EQ(7, results["hits"].size());
ASSERT_EQ(7, results["found"].get<int>());
ASSERT_STREQ("the", results["request_params"]["q"].get<std::string>().c_str());
ASSERT_EQ(10, results["request_params"]["per_page"].get<size_t>());
@@ -539,6 +539,23 @@ TEST_F(CollectionTest, PrefixSearching) {
ASSERT_EQ("16", results["hits"].at(0)["document"]["id"]);
}
TEST_F(CollectionTest, TypoTokensThreshold) {
// Query expansion should happen only based on the `typo_tokens_threshold` value
auto results = collection->search("launch", {"title"}, "", {}, sort_fields, 2, 10, 1,
token_ordering::FREQUENCY, true, 10, spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, 500, "", 5, "", 0).get();
ASSERT_EQ(5, results["hits"].size());
ASSERT_EQ(5, results["found"].get<size_t>());
results = collection->search("launch", {"title"}, "", {}, sort_fields, 2, 10, 1,
token_ordering::FREQUENCY, true, 10, spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, 500, "", 5, "", 10).get();
ASSERT_EQ(7, results["hits"].size());
ASSERT_EQ(7, results["found"].get<size_t>());
}
TEST_F(CollectionTest, MultiOccurrenceString) {
Collection *coll_multi_string;