Drop tokens direction.
This commit is contained in:
parent 00be933b23
commit b530f80770
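This commit makes the direction of query-token dropping configurable. When a search yields fewer results than `drop_tokens_threshold`, Typesense retries the query with tokens removed; until now tokens were always dropped from the right first. The change adds a `drop_tokens_mode_t` enum (`left_to_right`, `right_to_left`), threads a `drop_tokens_mode` argument through `Collection::search`, `search_args`, `Index::run_search`, and `Index::search` (defaulting to `right_to_left`, so existing call sites keep the old behavior), registers a `drop_tokens_mode` search parameter in `CollectionManager::do_search`, reworks the drop loop to exhaust one direction and then flip, and adds a regression test covering both modes.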
@@ -497,7 +497,8 @@ public:
                                  const size_t remote_embedding_num_tries = 2,
                                  const std::string& stopwords_set="",
                                  const std::vector<std::string>& facet_return_parent = {},
-                                 const std::vector<ref_include_fields>& ref_include_fields_vec = {}) const;
+                                 const std::vector<ref_include_fields>& ref_include_fields_vec = {},
+                                 const drop_tokens_mode_t drop_tokens_mode = right_to_left) const;

    Option<bool> get_filter_ids(const std::string & filter_query, filter_result_t& filter_result) const;

@@ -98,6 +98,11 @@ enum text_match_type_t {
    max_weight
};

+enum drop_tokens_mode_t {
+    left_to_right,
+    right_to_left,
+};
+
struct search_args {
    std::vector<query_tokens_t> field_query_tokens;
    std::vector<search_field_t> search_fields;
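With `right_to_left`, the historical behavior and the default used in the declarations above and below, tokens are dropped from the end of the query first; `left_to_right` drops from the beginning. For the query `alpha beta gamma`, `right_to_left` retries with `alpha beta` and then `alpha` before flipping direction, while `left_to_right` starts with `beta gamma`; a runnable sketch of the full retry order follows the `Index::search` loop hunks further down.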
@@ -146,6 +151,7 @@ struct search_args {
    vector_query_t& vector_query;
    size_t facet_sample_percent;
    size_t facet_sample_threshold;
+    drop_tokens_mode_t drop_tokens_mode;

    search_args(std::vector<query_tokens_t> field_query_tokens, std::vector<search_field_t> search_fields,
                const text_match_type_t match_type,
@@ -161,7 +167,7 @@ struct search_args {
                size_t min_len_1typo, size_t min_len_2typo, size_t max_candidates, const std::vector<enable_t>& infixes,
                const size_t max_extra_prefix, const size_t max_extra_suffix, const size_t facet_query_num_typos,
                const bool filter_curated_hits, const enable_t split_join_tokens, vector_query_t& vector_query,
-               size_t facet_sample_percent, size_t facet_sample_threshold) :
+               size_t facet_sample_percent, size_t facet_sample_threshold, drop_tokens_mode_t drop_tokens_mode) :
            field_query_tokens(field_query_tokens),
            search_fields(search_fields), match_type(match_type), filter_tree_root(filter_tree_root), facets(facets),
            included_ids(included_ids), excluded_ids(excluded_ids), sort_fields_std(sort_fields_std),
@@ -176,7 +182,8 @@ struct search_args {
            infixes(infixes), max_extra_prefix(max_extra_prefix), max_extra_suffix(max_extra_suffix),
            facet_query_num_typos(facet_query_num_typos), filter_curated_hits(filter_curated_hits),
            split_join_tokens(split_join_tokens), vector_query(vector_query),
-           facet_sample_percent(facet_sample_percent), facet_sample_threshold(facet_sample_threshold) {
+           facet_sample_percent(facet_sample_percent), facet_sample_threshold(facet_sample_threshold),
+           drop_tokens_mode(drop_tokens_mode) {

        const size_t topster_size = std::max((size_t)1, max_hits); // needs to be atleast 1 since scoring is mandatory
        topster = new Topster(topster_size, group_limit);
@@ -641,7 +648,8 @@ public:
                       const size_t max_extra_suffix, const size_t facet_query_num_typos,
                       const bool filter_curated_hits, enable_t split_join_tokens,
                       const vector_query_t& vector_query, size_t facet_sample_percent, size_t facet_sample_threshold,
-                      const std::string& collection_name, facet_index_type_t facet_index_type = DETECT) const;
+                      const std::string& collection_name, facet_index_type_t facet_index_type = DETECT,
+                      const drop_tokens_mode_t drop_tokens_mode = right_to_left) const;

    void remove_field(uint32_t seq_id, const nlohmann::json& document, const std::string& field_name,
                      const bool is_update);

@@ -1389,7 +1389,8 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
                                          const size_t remote_embedding_num_tries,
                                          const std::string& stopwords_set,
                                          const std::vector<std::string>& facet_return_parent,
-                                         const std::vector<ref_include_fields>& ref_include_fields_vec) const {
+                                         const std::vector<ref_include_fields>& ref_include_fields_vec,
+                                         const drop_tokens_mode_t drop_tokens_mode) const {

    std::shared_lock lock(mutex);

@@ -1888,7 +1889,7 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
                                          min_len_1typo, min_len_2typo, max_candidates, infixes,
                                          max_extra_prefix, max_extra_suffix, facet_query_num_typos,
                                          filter_curated_hits, split_join_tokens, vector_query,
-                                         facet_sample_percent, facet_sample_threshold);
+                                         facet_sample_percent, facet_sample_threshold, drop_tokens_mode);

    std::unique_ptr<search_args> search_params_guard(search_params);

@@ -976,6 +976,8 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
    const char *FACET_SAMPLE_PERCENT = "facet_sample_percent";
    const char *FACET_SAMPLE_THRESHOLD = "facet_sample_threshold";

+    const char *DROP_TOKENS_MODE = "drop_tokens_mode";
+
    // enrich params with values from embedded params
    for(auto& item: embedded_params.items()) {
        if(item.key() == "expires_at") {
@@ -1096,6 +1098,8 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
    size_t facet_sample_percent = 100;
    size_t facet_sample_threshold = 0;

+    std::string drop_tokens_mode_str = "right_to_left";
+
    std::unordered_map<std::string, size_t*> unsigned_int_values = {
        {MIN_LEN_1TYPO, &min_len_1typo},
        {MIN_LEN_2TYPO, &min_len_2typo},
@@ -1132,6 +1136,7 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
        {HIGHLIGHT_END_TAG, &highlight_end_tag},
        {PINNED_HITS, &pinned_hits_str},
        {HIDDEN_HITS, &hidden_hits_str},
+        {DROP_TOKENS_MODE, &drop_tokens_mode_str},
    };

    std::unordered_map<std::string, bool*> bool_values = {
@@ -1293,6 +1298,12 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
                                                      Index::NUM_CANDIDATES_DEFAULT_MIN);
    }

+    auto drop_tokens_mode_op = magic_enum::enum_cast<drop_tokens_mode_t>(drop_tokens_mode_str);
+    drop_tokens_mode_t drop_tokens_mode = right_to_left;  // fall back to the default when the mode string is unrecognized
+    if(drop_tokens_mode_op.has_value()) {
+        drop_tokens_mode = drop_tokens_mode_op.value();
+    }
+
    Option<nlohmann::json> result_op = collection->search(raw_query, search_fields, filter_query, facet_fields,
                                                          sort_fields, num_typos,
                                                          per_page,
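For context, a minimal sketch of the magic_enum lookup used above (assuming the single-header magic_enum library the hunk already relies on): `enum_cast` resolves a string against the enumerator names by compile-time reflection and returns `std::nullopt` for anything unrecognized, which is why the parsed value needs an initialized default.

```cpp
#include <iostream>
#include <magic_enum.hpp>

enum drop_tokens_mode_t { left_to_right, right_to_left };

int main() {
    // An exact enumerator name resolves to the corresponding value...
    auto ok = magic_enum::enum_cast<drop_tokens_mode_t>("left_to_right");
    // ...while an unknown string yields std::nullopt.
    auto bad = magic_enum::enum_cast<drop_tokens_mode_t>("sideways");
    std::cout << ok.has_value() << " " << bad.has_value() << "\n";  // prints "1 0"
}
```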
@@ -1341,7 +1352,8 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
                                                          remote_embedding_num_tries,
                                                          stopwords_set,
                                                          facet_return_parent,
-                                                          ref_include_fields_vec);
+                                                          ref_include_fields_vec,
+                                                          drop_tokens_mode);

    uint64_t timeMillis = std::chrono::duration_cast<std::chrono::milliseconds>(
            std::chrono::high_resolution_clock::now() - begin).count();

@@ -1762,7 +1762,8 @@ Option<bool> Index::run_search(search_args* search_params, const std::string& co
                  search_params->facet_sample_percent,
                  search_params->facet_sample_threshold,
                  collection_name,
-                 facet_index_type);
+                 facet_index_type,
+                 search_params->drop_tokens_mode);
}

void Index::collate_included_ids(const std::vector<token_t>& q_included_tokens,
@@ -2211,7 +2212,9 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
                           const bool filter_curated_hits, const enable_t split_join_tokens,
                           const vector_query_t& vector_query,
                           size_t facet_sample_percent, size_t facet_sample_threshold,
-                          const std::string& collection_name, facet_index_type_t facet_index_type) const {
+                          const std::string& collection_name,
+                          facet_index_type_t facet_index_type,
+                          const drop_tokens_mode_t drop_tokens_mode) const {
    std::shared_lock lock(mutex);

    auto filter_result_iterator = new filter_result_iterator_t(collection_name, this, filter_tree_root);
@@ -2610,16 +2613,24 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
        for (size_t qi = 0; qi < all_queries.size(); qi++) {
            auto& orig_tokens = all_queries[qi];
            size_t num_tokens_dropped = 0;
+            auto curr_direction = drop_tokens_mode;
+            size_t total_dirs_done = 0;

            while(exhaustive_search || all_result_ids_len < drop_tokens_threshold) {
                // When atleast two tokens from the query are available we can drop one
                std::vector<token_t> truncated_tokens;
                std::vector<token_t> dropped_tokens;

-                if(orig_tokens.size() > 1 && num_tokens_dropped < 2*(orig_tokens.size()-1)) {
-                    bool prefix_search = false;
+                if(num_tokens_dropped >= orig_tokens.size() - 1) {
+                    // swap direction and reset counter
+                    curr_direction = (curr_direction == right_to_left) ? left_to_right : right_to_left;
+                    num_tokens_dropped = 0;
+                    total_dirs_done++;
+                }

-                    if (num_tokens_dropped < orig_tokens.size() - 1) {
+                if(orig_tokens.size() > 1 && total_dirs_done < 2) {
+                    bool prefix_search = false;
+                    if (curr_direction == right_to_left) {
                        // drop from right
                        size_t truncated_len = orig_tokens.size() - num_tokens_dropped - 1;
                        for (size_t i = 0; i < orig_tokens.size(); i++) {
@@ -2632,7 +2643,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
                    } else {
                        // drop from left
                        prefix_search = true;
-                        size_t start_index = (num_tokens_dropped + 1) - orig_tokens.size() + 1;
+                        size_t start_index = (num_tokens_dropped + 1);
                        for(size_t i = 0; i < orig_tokens.size(); i++) {
                            if(i >= start_index) {
                                truncated_tokens.emplace_back(orig_tokens[i]);
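For intuition, here is a standalone sketch (not Typesense code; `token_t`, prefix handling, and the result-count bookkeeping are stripped out) of the retry order the loop above produces, ignoring the early exit once `drop_tokens_threshold` results are found. One more token is dropped per retry in the current direction; once only one token would remain, the direction flips, and the loop gives up after both directions are exhausted.

```cpp
#include <iostream>
#include <string>
#include <vector>

enum drop_tokens_mode_t { left_to_right, right_to_left };

// Enumerates the truncated queries tried for `tokens`, mirroring the loop in
// Index::search but without the early exit once enough results are found.
std::vector<std::vector<std::string>> drop_sequence(const std::vector<std::string>& tokens,
                                                    drop_tokens_mode_t mode) {
    std::vector<std::vector<std::string>> retries;
    auto curr_direction = mode;
    size_t num_tokens_dropped = 0, total_dirs_done = 0;

    while(tokens.size() > 1) {
        if(num_tokens_dropped >= tokens.size() - 1) {
            // this direction is exhausted: swap direction and reset counter
            curr_direction = (curr_direction == right_to_left) ? left_to_right : right_to_left;
            num_tokens_dropped = 0;
            total_dirs_done++;
        }
        if(total_dirs_done >= 2) {
            break;  // both directions have been tried
        }
        std::vector<std::string> truncated;
        if(curr_direction == right_to_left) {
            // keep the first (size - dropped - 1) tokens
            truncated.assign(tokens.begin(), tokens.end() - (num_tokens_dropped + 1));
        } else {
            // keep everything after the first (dropped + 1) tokens
            truncated.assign(tokens.begin() + (num_tokens_dropped + 1), tokens.end());
        }
        retries.push_back(truncated);
        num_tokens_dropped++;
    }
    return retries;
}

int main() {
    // right_to_left prints: alpha beta | alpha | beta gamma | gamma |
    // left_to_right prints: beta gamma | gamma | alpha beta | alpha |
    for(auto mode : {right_to_left, left_to_right}) {
        for(const auto& q : drop_sequence({"alpha", "beta", "gamma"}, mode)) {
            for(const auto& t : q) std::cout << t << " ";
            std::cout << "| ";
        }
        std::cout << "\n";
    }
}
```

This asymmetry is exactly what the new regression test below asserts: for the query `alpha beta gamma`, `left_to_right` finds the `beta gamma` document first, while `right_to_left` finds `alpha beta`.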
@@ -2332,6 +2332,47 @@ TEST_F(CollectionSpecificMoreTest, ExhaustiveSearchWithoutExplicitDropTokens) {
    ASSERT_EQ(2, res["hits"].size());
}

+TEST_F(CollectionSpecificMoreTest, DropTokensLeftToRightFirst) {
+    nlohmann::json schema = R"({
+        "name": "coll1",
+        "fields": [
+            {"name": "title", "type": "string"}
+        ]
+    })"_json;
+
+    Collection* coll1 = collectionManager.create_collection(schema).get();
+
+    nlohmann::json doc;
+    doc["title"] = "alpha beta";
+    ASSERT_TRUE(coll1->add(doc.dump()).ok());
+
+    doc["title"] = "beta gamma";
+    ASSERT_TRUE(coll1->add(doc.dump()).ok());
+
+    bool exhaustive_search = false;
+    size_t drop_tokens_threshold = 1;
+
+    auto res = coll1->search("alpha beta gamma", {"title"}, "", {}, {}, {0}, 3, 1, FREQUENCY, {false}, drop_tokens_threshold,
+                             spp::sparse_hash_set<std::string>(),
+                             spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0,
+                             "<mark>", "</mark>", {}, 1000, true, false, true, "", false, 10000,
+                             4, 7, fallback, 4, {off}, 100, 100, 2, 2, false, "", true, 0, max_score, 100, 0,
+                             0, HASH, 30000, 2, "", {}, {}, left_to_right).get();
+
+    ASSERT_EQ(1, res["hits"].size());
+    ASSERT_EQ("1", res["hits"][0]["document"]["id"].get<std::string>());
+
+    res = coll1->search("alpha beta gamma", {"title"}, "", {}, {}, {0}, 3, 1, FREQUENCY, {false}, drop_tokens_threshold,
+                        spp::sparse_hash_set<std::string>(),
+                        spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0,
+                        "<mark>", "</mark>", {}, 1000, true, false, true, "", false, 10000,
+                        4, 7, fallback, 4, {off}, 100, 100, 2, 2, false, "", true, 0, max_score, 100, 0,
+                        0, HASH, 30000, 2, "", {}, {}, right_to_left).get();
+
+    ASSERT_EQ(1, res["hits"].size());
+    ASSERT_EQ("0", res["hits"][0]["document"]["id"].get<std::string>());
+}
+
TEST_F(CollectionSpecificMoreTest, DoNotHighlightFieldsForSpecialCharacterQuery) {
    nlohmann::json schema = R"({
        "name": "coll1",
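Clients opt into the new order per request through the search parameter registered in `CollectionManager::do_search`, e.g. `drop_tokens_mode=left_to_right`; when the parameter is absent or unrecognized, the default `right_to_left` applies.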