Drop tokens direction.

Kishore Nallan 2023-09-25 07:29:55 +05:30
parent 00be933b23
commit b530f80770
6 changed files with 87 additions and 13 deletions
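In brief: this change adds a drop_tokens_mode search parameter that controls the direction in which query tokens are dropped when a query produces fewer results than drop_tokens_threshold. The previous behaviour, dropping from the right end of the query first, remains the default; left_to_right drops from the front first. The parameter is threaded from CollectionManager::do_search through Collection::search and search_args into Index::search, whose dropping loop now also swaps direction once one side is exhausted.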

View File

@ -497,7 +497,8 @@ public:
const size_t remote_embedding_num_tries = 2,
const std::string& stopwords_set="",
const std::vector<std::string>& facet_return_parent = {},
const std::vector<ref_include_fields>& ref_include_fields_vec = {}) const;
const std::vector<ref_include_fields>& ref_include_fields_vec = {},
const drop_tokens_mode_t drop_tokens_mode = right_to_left) const;
Option<bool> get_filter_ids(const std::string & filter_query, filter_result_t& filter_result) const;

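The new drop_tokens_mode parameter is appended at the end of Collection::search's long signature with a default of right_to_left, so existing call sites keep compiling and keep the old drop-from-the-right behaviour unless they opt in explicitly.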
View File

@ -98,6 +98,11 @@ enum text_match_type_t {
max_weight
};
enum drop_tokens_mode_t {
left_to_right,
right_to_left,
};
struct search_args {
std::vector<query_tokens_t> field_query_tokens;
std::vector<search_field_t> search_fields;
@ -146,6 +151,7 @@ struct search_args {
vector_query_t& vector_query;
size_t facet_sample_percent;
size_t facet_sample_threshold;
drop_tokens_mode_t drop_tokens_mode;
search_args(std::vector<query_tokens_t> field_query_tokens, std::vector<search_field_t> search_fields,
const text_match_type_t match_type,
@ -161,7 +167,7 @@ struct search_args {
size_t min_len_1typo, size_t min_len_2typo, size_t max_candidates, const std::vector<enable_t>& infixes,
const size_t max_extra_prefix, const size_t max_extra_suffix, const size_t facet_query_num_typos,
const bool filter_curated_hits, const enable_t split_join_tokens, vector_query_t& vector_query,
size_t facet_sample_percent, size_t facet_sample_threshold) :
size_t facet_sample_percent, size_t facet_sample_threshold, drop_tokens_mode_t drop_tokens_mode) :
field_query_tokens(field_query_tokens),
search_fields(search_fields), match_type(match_type), filter_tree_root(filter_tree_root), facets(facets),
included_ids(included_ids), excluded_ids(excluded_ids), sort_fields_std(sort_fields_std),
@ -176,7 +182,8 @@ struct search_args {
infixes(infixes), max_extra_prefix(max_extra_prefix), max_extra_suffix(max_extra_suffix),
facet_query_num_typos(facet_query_num_typos), filter_curated_hits(filter_curated_hits),
split_join_tokens(split_join_tokens), vector_query(vector_query),
facet_sample_percent(facet_sample_percent), facet_sample_threshold(facet_sample_threshold) {
facet_sample_percent(facet_sample_percent), facet_sample_threshold(facet_sample_threshold),
drop_tokens_mode(drop_tokens_mode) {
const size_t topster_size = std::max((size_t)1, max_hits); // needs to be at least 1 since scoring is mandatory
topster = new Topster(topster_size, group_limit);
@ -641,7 +648,8 @@ public:
const size_t max_extra_suffix, const size_t facet_query_num_typos,
const bool filter_curated_hits, enable_t split_join_tokens,
const vector_query_t& vector_query, size_t facet_sample_percent, size_t facet_sample_threshold,
const std::string& collection_name, facet_index_type_t facet_index_type = DETECT) const;
const std::string& collection_name, facet_index_type_t facet_index_type = DETECT,
const drop_tokens_mode_t drop_tokens_mode = right_to_left) const;
void remove_field(uint32_t seq_id, const nlohmann::json& document, const std::string& field_name,
const bool is_update);

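Semantics of the new enum: right_to_left, the default and the previous hard-coded behaviour, drops tokens from the end of the query first, so "alpha beta gamma" is retried as "alpha beta", then "alpha". left_to_right drops from the front first ("beta gamma", then "gamma"); note in the Index::search hunk further below that dropping from the left also sets prefix_search = true, since the query's final (possibly still-being-typed) token is retained. Once one direction is exhausted, the loop swaps to the other, so both orders are eventually tried; the mode only decides which direction goes first.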
View File

@ -1389,7 +1389,8 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
const size_t remote_embedding_num_tries,
const std::string& stopwords_set,
const std::vector<std::string>& facet_return_parent,
const std::vector<ref_include_fields>& ref_include_fields_vec) const {
const std::vector<ref_include_fields>& ref_include_fields_vec,
const drop_tokens_mode_t drop_tokens_mode) const {
std::shared_lock lock(mutex);
@ -1888,7 +1889,7 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
min_len_1typo, min_len_2typo, max_candidates, infixes,
max_extra_prefix, max_extra_suffix, facet_query_num_typos,
filter_curated_hits, split_join_tokens, vector_query,
facet_sample_percent, facet_sample_threshold);
facet_sample_percent, facet_sample_threshold, drop_tokens_mode);
std::unique_ptr<search_args> search_params_guard(search_params);

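Collection::search itself only threads the new argument through into search_args; the actual direction handling happens in Index::search further below.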
View File

@ -976,6 +976,8 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
const char *FACET_SAMPLE_PERCENT = "facet_sample_percent";
const char *FACET_SAMPLE_THRESHOLD = "facet_sample_threshold";
const char *DROP_TOKENS_MODE = "drop_tokens_mode";
// enrich params with values from embedded params
for(auto& item: embedded_params.items()) {
if(item.key() == "expires_at") {
@ -1096,6 +1098,8 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
size_t facet_sample_percent = 100;
size_t facet_sample_threshold = 0;
std::string drop_tokens_mode_str = "right_to_left";
std::unordered_map<std::string, size_t*> unsigned_int_values = {
{MIN_LEN_1TYPO, &min_len_1typo},
{MIN_LEN_2TYPO, &min_len_2typo},
@ -1132,6 +1136,7 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
{HIGHLIGHT_END_TAG, &highlight_end_tag},
{PINNED_HITS, &pinned_hits_str},
{HIDDEN_HITS, &hidden_hits_str},
{DROP_TOKENS_MODE, &drop_tokens_mode_str},
};
std::unordered_map<std::string, bool*> bool_values = {
@ -1293,6 +1298,12 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
Index::NUM_CANDIDATES_DEFAULT_MIN);
}
auto drop_tokens_mode_op = magic_enum::enum_cast<drop_tokens_mode_t>(drop_tokens_mode_str);
drop_tokens_mode_t drop_tokens_mode = right_to_left;  // default when the value is missing or unrecognized
if(drop_tokens_mode_op.has_value()) {
drop_tokens_mode = drop_tokens_mode_op.value();
}
Option<nlohmann::json> result_op = collection->search(raw_query, search_fields, filter_query, facet_fields,
sort_fields, num_typos,
per_page,
@ -1341,7 +1352,8 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
remote_embedding_num_tries,
stopwords_set,
facet_return_parent,
ref_include_fields_vec);
ref_include_fields_vec,
drop_tokens_mode);
uint64_t timeMillis = std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::high_resolution_clock::now() - begin).count();

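A note on the parsing above: the accepted values for drop_tokens_mode are the enumerator spellings themselves ("left_to_right", "right_to_left"), since magic_enum::enum_cast matches on the enum's identifier names. And because enum_cast returns a std::optional, the parse-and-fallback can be collapsed into a single statement; a minimal sketch, assuming the enum and magic_enum.hpp are visible at this point:

    // Unrecognized strings (e.g. "sideways") fall back to the default mode
    // instead of leaving the variable unset.
    drop_tokens_mode_t drop_tokens_mode =
            magic_enum::enum_cast<drop_tokens_mode_t>(drop_tokens_mode_str)
                    .value_or(right_to_left);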
View File

@ -1762,7 +1762,8 @@ Option<bool> Index::run_search(search_args* search_params, const std::string& co
search_params->facet_sample_percent,
search_params->facet_sample_threshold,
collection_name,
facet_index_type);
facet_index_type,
search_params->drop_tokens_mode);
}
void Index::collate_included_ids(const std::vector<token_t>& q_included_tokens,
@ -2211,7 +2212,9 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
const bool filter_curated_hits, const enable_t split_join_tokens,
const vector_query_t& vector_query,
size_t facet_sample_percent, size_t facet_sample_threshold,
const std::string& collection_name, facet_index_type_t facet_index_type) const {
const std::string& collection_name,
facet_index_type_t facet_index_type,
const drop_tokens_mode_t drop_tokens_mode) const {
std::shared_lock lock(mutex);
auto filter_result_iterator = new filter_result_iterator_t(collection_name, this, filter_tree_root);
@ -2610,16 +2613,24 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
for (size_t qi = 0; qi < all_queries.size(); qi++) {
auto& orig_tokens = all_queries[qi];
size_t num_tokens_dropped = 0;
auto curr_direction = drop_tokens_mode;
size_t total_dirs_done = 0;
while(exhaustive_search || all_result_ids_len < drop_tokens_threshold) {
// When at least two tokens from the query are available, we can drop one
std::vector<token_t> truncated_tokens;
std::vector<token_t> dropped_tokens;
if(orig_tokens.size() > 1 && num_tokens_dropped < 2*(orig_tokens.size()-1)) {
bool prefix_search = false;
if(num_tokens_dropped >= orig_tokens.size() - 1) {
// swap direction and reset counter
curr_direction = (curr_direction == right_to_left) ? left_to_right : right_to_left;
num_tokens_dropped = 0;
total_dirs_done++;
}
if (num_tokens_dropped < orig_tokens.size() - 1) {
if(orig_tokens.size() > 1 && total_dirs_done < 2) {
bool prefix_search = false;
if (curr_direction == right_to_left) {
// drop from right
size_t truncated_len = orig_tokens.size() - num_tokens_dropped - 1;
for (size_t i = 0; i < orig_tokens.size(); i++) {
@ -2632,7 +2643,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
} else {
// drop from left
prefix_search = true;
size_t start_index = (num_tokens_dropped + 1) - orig_tokens.size() + 1;
size_t start_index = (num_tokens_dropped + 1);
for(size_t i = 0; i < orig_tokens.size(); i++) {
if(i >= start_index) {
truncated_tokens.emplace_back(orig_tokens[i]);

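To make the direction-swapping logic concrete, here is a minimal standalone sketch. It is not Typesense code: the drop_order helper and the plain std::string tokens are stand-ins for the real token_t machinery, but the loop mirrors the structure of the patch above and reproduces the order in which truncated queries are attempted.

#include <iostream>
#include <string>
#include <vector>

enum drop_tokens_mode_t { left_to_right, right_to_left };

// Drops one token per attempt in the current direction, swaps direction once
// all but one token have been dropped, and stops after both directions are done.
std::vector<std::vector<std::string>> drop_order(const std::vector<std::string>& tokens,
                                                 drop_tokens_mode_t mode) {
    std::vector<std::vector<std::string>> attempts;
    auto curr_direction = mode;
    size_t num_tokens_dropped = 0;
    size_t total_dirs_done = 0;

    while(true) {
        if(num_tokens_dropped >= tokens.size() - 1) {
            // swap direction and reset counter, as in the hunk above
            curr_direction = (curr_direction == right_to_left) ? left_to_right : right_to_left;
            num_tokens_dropped = 0;
            total_dirs_done++;
        }

        if(tokens.size() <= 1 || total_dirs_done >= 2) {
            break;
        }

        std::vector<std::string> truncated;
        if(curr_direction == right_to_left) {
            // drop from right: keep the first (size - dropped - 1) tokens
            truncated.assign(tokens.begin(), tokens.end() - num_tokens_dropped - 1);
        } else {
            // drop from left: skip the first (dropped + 1) tokens
            truncated.assign(tokens.begin() + num_tokens_dropped + 1, tokens.end());
        }
        attempts.push_back(truncated);
        num_tokens_dropped++;
    }

    return attempts;
}

int main() {
    for(const auto& attempt : drop_order({"alpha", "beta", "gamma"}, right_to_left)) {
        for(const auto& token : attempt) {
            std::cout << token << " ";
        }
        std::cout << "\n";
    }
    return 0;
}

For "alpha beta gamma", right_to_left attempts "alpha beta", "alpha", "beta gamma", "gamma", while left_to_right attempts "beta gamma", "gamma", "alpha beta", "alpha". This ordering is exactly what the test below relies on.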
View File

@ -2332,6 +2332,47 @@ TEST_F(CollectionSpecificMoreTest, ExhaustiveSearchWithoutExplicitDropTokens) {
ASSERT_EQ(2, res["hits"].size());
}
TEST_F(CollectionSpecificMoreTest, DropTokensLeftToRightFirst) {
nlohmann::json schema = R"({
"name": "coll1",
"fields": [
{"name": "title", "type": "string"}
]
})"_json;
Collection* coll1 = collectionManager.create_collection(schema).get();
nlohmann::json doc;
doc["title"] = "alpha beta";
ASSERT_TRUE(coll1->add(doc.dump()).ok());
doc["title"] = "beta gamma";
ASSERT_TRUE(coll1->add(doc.dump()).ok());
bool exhaustive_search = false;
size_t drop_tokens_threshold = 1;
auto res = coll1->search("alpha beta gamma", {"title"}, "", {}, {}, {0}, 3, 1, FREQUENCY, {false}, drop_tokens_threshold,
spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0,
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 10000,
4, 7, fallback, 4, {off}, 100, 100, 2, 2, false, "", true, 0, max_score, 100, 0,
0, HASH, 30000, 2, "", {}, {}, left_to_right).get();
ASSERT_EQ(1, res["hits"].size());
ASSERT_EQ("1", res["hits"][0]["document"]["id"].get<std::string>());
res = coll1->search("alpha beta gamma", {"title"}, "", {}, {}, {0}, 3, 1, FREQUENCY, {false}, drop_tokens_threshold,
spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0,
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 10000,
4, 7, fallback, 4, {off}, 100, 100, 2, 2, false, "", true, 0, max_score, 100, 0,
0, HASH, 30000, 2, "", {}, {}, right_to_left).get();
ASSERT_EQ(1, res["hits"].size());
ASSERT_EQ("0", res["hits"][0]["document"]["id"].get<std::string>());
}
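The two assertions pin the direction down. Neither document matches all three query tokens, so a token has to be dropped: with left_to_right the first retry is "beta gamma", which matches document 1; with right_to_left it is "alpha beta", which matches document 0. Since drop_tokens_threshold is 1 and the first retry already yields a hit, dropping stops there and only that direction's match is returned.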
TEST_F(CollectionSpecificMoreTest, DoNotHighlightFieldsForSpecialCharacterQuery) {
nlohmann::json schema = R"({
"name": "coll1",