Handle override with synonyms.

This commit is contained in:
Kishore Nallan 2021-10-01 15:43:30 +05:30
parent 36a26f3f40
commit 947a5019d9
12 changed files with 470 additions and 220 deletions

View File

@ -207,8 +207,7 @@ private:
const std::map<size_t, std::vector<std::string>>& pinned_hits,
const std::vector<std::string>& hidden_hits,
std::map<size_t, std::vector<uint32_t>>& include_ids,
std::vector<uint32_t>& excluded_ids, std::vector<const override_t*>& filter_overrides,
std::vector<filter>& filters) const;
std::vector<uint32_t>& excluded_ids, std::vector<const override_t*>& filter_overrides) const;
Option<bool> check_and_update_schema(nlohmann::json& document, const DIRTY_VALUES& dirty_values);
@ -334,8 +333,6 @@ public:
bool facet_value_to_string(const facet &a_facet, const facet_count_t &facet_count, const nlohmann::json &document,
std::string &value) const;
static void aggregate_topster(size_t query_index, Topster &topster, Topster *index_topster);
static void populate_result_kvs(Topster *topster, std::vector<std::vector<KV *>> &result_kvs);
void batch_index(std::vector<index_record>& index_records, std::vector<std::string>& json_out, size_t &num_indexed);
@ -387,7 +384,7 @@ public:
bool enable_overrides=true,
const std::string& highlight_fields="",
const bool exhaustive_search = false,
size_t search_stop_millis = 200) const;
size_t search_stop_millis = 6000*1000) const;
Option<bool> get_filter_ids(const std::string & simple_filter_query,
std::vector<std::pair<size_t, uint32_t*>>& index_ids);

View File

@ -478,7 +478,7 @@ struct filter {
static Option<bool> parse_filter_query(const std::string& simple_filter_query,
const std::unordered_map<std::string, field>& search_schema,
Store* store,
const Store* store,
const std::string& doc_id_prefix,
std::vector<filter>& filters);
};

View File

@ -58,7 +58,7 @@ struct override_t {
struct add_hit_t {
std::string doc_id;
uint32_t position;
uint32_t position = 0;
};
struct drop_hit_t {
@ -74,7 +74,7 @@ struct override_t {
std::string filter_by;
bool remove_matched_tokens = false;
override_t() {}
override_t() = default;
static Option<bool> parse(const nlohmann::json& override_json, const std::string& id, override_t& override) {
if(!override_json.is_object()) {
@ -198,6 +198,8 @@ struct override_t {
while(i < override.rule.query.size()) {
if(override.rule.query[i] == '}') {
override.rule.dynamic_query = true;
// remove spaces around curlies
override.rule.query = StringUtils::trim_curly_spaces(override.rule.query);
break;
}
i++;
@ -231,6 +233,11 @@ struct override_t {
override["excludes"].push_back(exclude);
}
if(!filter_by.empty()) {
override["filter_by"] = filter_by;
override["remove_matched_tokens"] = remove_matched_tokens;
}
return override;
}
};
@ -368,15 +375,15 @@ class Index {
private:
mutable std::shared_mutex mutex;
ThreadPool* thread_pool;
static constexpr const uint64_t FACET_ARRAY_DELIMETER = std::numeric_limits<uint64_t>::max();
std::string name;
const uint32_t collection_id;
Store* store;
const Store* store;
ThreadPool* thread_pool;
size_t num_documents;
@ -425,31 +432,31 @@ private:
uint32_t& token_bits,
uint64& qhash);
void log_leaves(const int cost, const std::string &token, const std::vector<art_leaf *> &leaves) const;
void log_leaves(int cost, const std::string &token, const std::vector<art_leaf *> &leaves) const;
void do_facets(std::vector<facet> & facets, facet_query_t & facet_query,
size_t group_limit, const std::vector<std::string>& group_by_fields,
const uint32_t* result_ids, size_t results_size) const;
void static_filter_query_eval(const override_t* override, std::vector<std::string>& tokens,
bool static_filter_query_eval(const override_t* override, std::vector<std::string>& tokens,
std::vector<filter>& filters) const;
void process_filter_overrides(const std::vector<const override_t*>& filter_overrides,
std::vector<query_tokens_t>& field_query_tokens,
token_ordering token_order,
std::vector<filter>& filters,
uint32_t** filter_ids,
uint32_t& filter_ids_length) const;
std::vector<filter>& filters) const;
bool resolve_override(const std::vector<std::string>& rule_parts, const bool exact_rule_match,
bool resolve_override(const std::vector<std::string>& rule_tokens, bool exact_rule_match,
const std::vector<std::string>& query_tokens,
token_ordering token_order, std::set<std::string>& absorbed_tokens,
uint32_t*& override_ids, size_t& override_ids_len) const;
std::string& filter_by_clause) const;
bool check_for_overrides2(const token_ordering& token_order, const string& field_name, const bool slide_window,
uint32_t*& field_override_ids, size_t& field_override_ids_len,
bool exact_rule_match, std::vector<std::string>& tokens,
std::set<std::string>& absorbed_tokens) const;
bool check_for_overrides(const token_ordering& token_order, const string& field_name, bool slide_window,
bool exact_rule_match, std::vector<std::string>& tokens,
std::set<std::string>& absorbed_tokens,
std::vector<std::string>& field_absorbed_tokens) const;
static void aggregate_topster(Topster* agg_topster, Topster* index_topster);
void search_field(const uint8_t & field_id,
std::vector<token_t>& query_tokens,
@ -460,19 +467,19 @@ private:
const std::string & field, uint32_t *filter_ids, size_t filter_ids_length,
const std::vector<uint32_t>& curated_ids,
std::vector<facet> & facets, const std::vector<sort_by> & sort_fields,
const int num_typos, std::vector<std::vector<art_leaf*>> & searched_queries,
int num_typos, std::vector<std::vector<art_leaf*>> & searched_queries,
Topster* topster, spp::sparse_hash_set<uint64_t>& groups_processed,
uint32_t** all_result_ids, size_t & all_result_ids_len,
size_t& field_num_results,
const size_t group_limit,
size_t group_limit,
const std::vector<std::string>& group_by_fields,
bool prioritize_exact_match,
size_t concurrency,
std::set<uint64>& query_hashes,
const token_ordering token_order = FREQUENCY, const bool prefix = false,
const size_t drop_tokens_threshold = Index::DROP_TOKENS_THRESHOLD,
const size_t typo_tokens_threshold = Index::TYPO_TOKENS_THRESHOLD,
const bool exhaustive_search = false) const;
token_ordering token_order = FREQUENCY, const bool prefix = false,
size_t drop_tokens_threshold = Index::DROP_TOKENS_THRESHOLD,
size_t typo_tokens_threshold = Index::TYPO_TOKENS_THRESHOLD,
bool exhaustive_search = false) const;
void search_candidates(const uint8_t & field_id,
bool field_is_array,
@ -719,5 +726,7 @@ public:
void populate_sort_mapping(int* sort_order, std::vector<size_t>& geopoint_indices,
const std::vector<sort_by>& sort_fields_std,
std::array<spp::sparse_hash_map<uint32_t, int64_t>*, 3>& field_values) const;
void remove_matched_tokens(std::vector<std::string>& tokens, const std::set<std::string>& rule_token_set) const;
};

View File

@ -5,10 +5,10 @@
#include "posting_list.h"
#include "threadpool.h"
#define IS_COMPACT_POSTING(x) (((uintptr_t)x & 1))
#define SET_COMPACT_POSTING(x) ((void*)((uintptr_t)x | 1))
#define RAW_POSTING_PTR(x) ((void*)((uintptr_t)x & ~1))
#define COMPACT_POSTING_PTR(x) ((compact_posting_list_t*)((uintptr_t)x & ~1))
#define IS_COMPACT_POSTING(x) (((uintptr_t)(x) & 1))
#define SET_COMPACT_POSTING(x) ((void*)((uintptr_t)(x) | 1))
#define RAW_POSTING_PTR(x) ((void*)((uintptr_t)(x) & ~1))
#define COMPACT_POSTING_PTR(x) ((compact_posting_list_t*)((uintptr_t)(x) & ~1))
struct compact_posting_list_t {
// structured to get 4 byte alignment for `id_offsets`
@ -22,7 +22,7 @@ struct compact_posting_list_t {
static compact_posting_list_t* create(uint32_t num_ids, const uint32_t* ids, const uint32_t* offset_index,
uint32_t num_offsets, uint32_t* offsets);
posting_list_t* to_full_posting_list() const;
[[nodiscard]] posting_list_t* to_full_posting_list() const;
bool contains(uint32_t id);
@ -34,7 +34,7 @@ struct compact_posting_list_t {
uint32_t first_id();
uint32_t last_id();
uint32_t num_ids() const;
[[nodiscard]] uint32_t num_ids() const;
bool contains_atleast_one(const uint32_t* target_ids, size_t target_ids_size);
};

View File

@ -67,10 +67,11 @@ public:
};
struct result_iter_state_t {
uint32_t* excluded_result_ids = nullptr;
size_t excluded_result_ids_size = 0;
uint32_t* filter_ids = nullptr;
size_t filter_ids_length = 0;
const uint32_t* excluded_result_ids = nullptr;
const size_t excluded_result_ids_size = 0;
const uint32_t* filter_ids = nullptr;
const size_t filter_ids_length = 0;
size_t excluded_result_ids_index = 0;
size_t filter_ids_index = 0;

View File

@ -312,4 +312,9 @@ struct StringUtils {
static std::map<std::string, std::string> parse_query_string(const std::string& query);
static std::string float_to_str(float value);
static void replace_all(std::string& subject, const std::string& search,
const std::string& replace);
static std::string trim_curly_spaces(const std::string& str);
};

View File

@ -390,8 +390,8 @@ void Collection::curate_results(string& actual_query, bool enable_overrides, boo
const std::map<size_t, std::vector<std::string>>& pinned_hits,
const std::vector<std::string>& hidden_hits,
std::map<size_t, std::vector<uint32_t>>& include_ids,
std::vector<uint32_t>& excluded_ids, std::vector<const override_t*>& filter_overrides,
std::vector<filter>& filters) const {
std::vector<uint32_t>& excluded_ids,
std::vector<const override_t*>& filter_overrides) const {
std::set<uint32_t> excluded_set;
@ -844,7 +844,7 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query, const s
std::vector<const override_t*> filter_overrides;
std::string query = raw_query;
curate_results(query, enable_overrides, pre_segmented_query, pinned_hits, hidden_hits,
include_ids, excluded_ids, filter_overrides, filters);
include_ids, excluded_ids, filter_overrides);
/*for(auto& kv: include_ids) {
LOG(INFO) << "key: " << kv.first;
@ -906,9 +906,6 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query, const s
}
// search all indices
size_t num_processed = 0;
std::mutex m_process;
std::condition_variable cv_process;
size_t index_id = 0;
search_args* search_params = new search_args(field_query_tokens, weighted_search_fields,
@ -1334,33 +1331,6 @@ void Collection::populate_result_kvs(Topster *topster, std::vector<std::vector<K
}
}
// Merges the results of a per-index topster into the aggregate topster,
// shifting every KV's query index by `query_index` so that hits from
// different logical queries (e.g. synonym variants) keep distinct indices.
void Collection::aggregate_topster(size_t query_index, Topster& agg_topster, Topster* index_topster) {
if(index_topster->distinct) {
// Grouped (distinct) mode: each group holds its own nested topster.
for(auto &group_topster_entry: index_topster->group_kv_map) {
Topster* group_topster = group_topster_entry.second;
for(const auto& map_kv: group_topster->kv_map) {
map_kv.second->query_index += query_index;
if(map_kv.second->query_indices != nullptr) {
// query_indices[0] holds the count; entries start at index 1.
for(size_t i = 0; i < map_kv.second->query_indices[0]; i++) {
map_kv.second->query_indices[i+1] += query_index;
}
}
agg_topster.add(map_kv.second);
}
}
} else {
// Flat mode: KVs live directly in the topster's kv_map.
for(const auto& map_kv: index_topster->kv_map) {
map_kv.second->query_index += query_index;
if(map_kv.second->query_indices != nullptr) {
// query_indices[0] holds the count; entries start at index 1.
for(size_t i = 0; i < map_kv.second->query_indices[0]; i++) {
map_kv.second->query_indices[i+1] += query_index;
}
}
agg_topster.add(map_kv.second);
}
}
}
Option<bool> Collection::get_filter_ids(const std::string & simple_filter_query,
std::vector<std::pair<size_t, uint32_t*>>& index_ids) {
std::shared_lock lock(mutex);

View File

@ -63,7 +63,7 @@ Option<bool> filter::parse_geopoint_filter_value(std::string& raw_value,
Option<bool> filter::parse_filter_query(const string& simple_filter_query,
const std::unordered_map<std::string, field>& search_schema,
Store* store,
const Store* store,
const std::string& doc_id_prefix,
std::vector<filter>& filters) {
@ -332,7 +332,7 @@ Option<bool> filter::parse_filter_query(const string& simple_filter_query,
"`: Unidentified field data type, see docs for supported data types.");
}
if(f.comparators.size() > 0 && f.comparators.front() == NOT_EQUALS) {
if(!f.comparators.empty() && f.comparators.front() == NOT_EQUALS) {
exclude_filters.push_back(f);
} else {
filters.push_back(f);

View File

@ -40,7 +40,7 @@ Index::Index(const std::string& name, const uint32_t collection_id, const Store*
const std::unordered_map<std::string, field> & search_schema,
const std::map<std::string, field>& facet_schema, const std::unordered_map<std::string, field>& sort_schema,
const std::vector<char>& symbols_to_index, const std::vector<char>& token_separators):
name(name), collection_id(collection_id), thread_pool(thread_pool),
name(name), collection_id(collection_id), store(store), thread_pool(thread_pool),
search_schema(search_schema), facet_schema(facet_schema), sort_schema(sort_schema),
symbols_to_index(symbols_to_index), token_separators(token_separators) {
@ -1079,17 +1079,17 @@ void Index::do_facets(std::vector<facet> & facets, facet_query_t & facet_query,
}
}
void aggregate_topster(Topster& agg_topster, Topster* index_topster) {
void Index::aggregate_topster(Topster* agg_topster, Topster* index_topster) {
if(index_topster->distinct) {
for(auto &group_topster_entry: index_topster->group_kv_map) {
Topster* group_topster = group_topster_entry.second;
for(const auto& map_kv: group_topster->kv_map) {
agg_topster.add(map_kv.second);
agg_topster->add(map_kv.second);
}
}
} else {
for(const auto& map_kv: index_topster->kv_map) {
agg_topster.add(map_kv.second);
agg_topster->add(map_kv.second);
}
}
}
@ -1186,7 +1186,7 @@ void Index::search_candidates(const uint8_t & field_id, bool field_is_array,
if(topster == nullptr) {
posting_t::block_intersector_t(
posting_lists, iter_state, thread_pool, 100
posting_lists, iter_state, thread_pool, 100
)
.intersect([&](uint32_t seq_id, std::vector<posting_list_t::iterator_t>& its, size_t index) {
result_id_vecs[index].push_back(seq_id);
@ -1197,7 +1197,7 @@ void Index::search_candidates(const uint8_t & field_id, bool field_is_array,
}
posting_t::block_intersector_t(
posting_lists, iter_state, thread_pool, 100
posting_lists, iter_state, thread_pool, 100
)
.intersect([&](uint32_t seq_id, std::vector<posting_list_t::iterator_t>& its, size_t index) {
score_results(sort_fields, searched_queries.size(), field_id, field_is_array,
@ -1215,6 +1215,11 @@ void Index::search_candidates(const uint8_t & field_id, bool field_is_array,
size_t num_result_ids = 0;
for(size_t i = 0; i < concurrency; i++) {
if(result_id_vecs[i].empty()) {
// can happen if not all threads produce results
continue;
}
uint32_t* new_all_result_ids = nullptr;
all_result_ids_len = ArrayUtils::or_scalar(*all_result_ids, all_result_ids_len, &result_id_vecs[i][0],
result_id_vecs[i].size(), &new_all_result_ids);
@ -1224,7 +1229,7 @@ void Index::search_candidates(const uint8_t & field_id, bool field_is_array,
num_result_ids += result_id_vecs[i].size();
if(topster != nullptr) {
aggregate_topster(*topster, topsters[i]);
aggregate_topster(topster, topsters[i]);
delete topsters[i];
groups_processed.insert(groups_processed_vec[i].begin(), groups_processed_vec[i].end());
}
@ -1710,7 +1715,7 @@ void Index::concat_topster_ids(Topster* topster, spp::sparse_hash_map<uint64_t,
}
}
void Index::static_filter_query_eval(const override_t* override,
bool Index::static_filter_query_eval(const override_t* override,
std::vector<std::string>& tokens,
std::vector<filter>& filters) const {
@ -1721,42 +1726,34 @@ void Index::static_filter_query_eval(const override_t* override,
Option<bool> filter_op = filter::parse_filter_query(override->filter_by, search_schema,
store, "", filters);
if(!filter_op.ok()) {
return ;
}
if(override->remove_matched_tokens) {
query.replace(query.find(override->rule.query), override->rule.query.size(), "");
if(StringUtils::trim(query).empty()) {
tokens = {"*"};
} else {
tokens.clear();
StringUtils::split(query, tokens, " ");
}
}
return filter_op.ok();
}
return false;
}
bool Index::resolve_override(const std::vector<std::string>& rule_parts, const bool exact_rule_match,
bool Index::resolve_override(const std::vector<std::string>& rule_tokens, const bool exact_rule_match,
const std::vector<std::string>& query_tokens,
token_ordering token_order, std::set<std::string>& absorbed_tokens,
uint32_t*& override_ids, size_t& override_ids_len) const {
std::string& filter_by_clause) const {
bool resolved_override = false;
size_t i = 0, j = 0;
while(i < rule_parts.size()) {
if(rule_parts[i].front() == '{' && rule_parts[i].back() == '}') {
std::unordered_map<std::string, std::vector<std::string>> field_placeholder_tokens;
while(i < rule_tokens.size()) {
if(rule_tokens[i].front() == '{' && rule_tokens[i].back() == '}') {
// found a field placeholder
std::vector<std::string> field_names;
std::string rule_part = rule_parts[i];
std::string rule_part = rule_tokens[i];
field_names.emplace_back(rule_part.erase(0, 1).erase(rule_part.size() - 1));
// skip until we find a non-placeholder token
i++;
while(i < rule_parts.size() && (rule_parts[i].front() == '{' && rule_parts[i].back() == '}')) {
rule_part = rule_parts[i];
while(i < rule_tokens.size() && (rule_tokens[i].front() == '{' && rule_tokens[i].back() == '}')) {
rule_part = rule_tokens[i];
field_names.emplace_back(rule_part.erase(0, 1).erase(rule_part.size() - 1));
i++;
}
@ -1766,12 +1763,12 @@ bool Index::resolve_override(const std::vector<std::string>& rule_parts, const b
std::vector<std::string> matched_tokens;
while(j < query_tokens.size() && (i == rule_parts.size() || rule_parts[i] != query_tokens[j])) {
while(j < query_tokens.size() && (i == rule_tokens.size() || rule_tokens[i] != query_tokens[j])) {
matched_tokens.emplace_back(query_tokens[j]);
j++;
}
if(i < rule_parts.size() && j < query_tokens.size() && rule_parts[i] != query_tokens[j]) {
if(i < rule_tokens.size() && j < query_tokens.size() && rule_tokens[i] != query_tokens[j]) {
// if last token does not match, it means that the query does not match the rule
resolved_override = false;
goto RETURN_EARLY;
@ -1788,14 +1785,34 @@ bool Index::resolve_override(const std::vector<std::string>& rule_parts, const b
for(size_t findex = 0; findex < field_names.size(); findex++) {
const auto& field_name = field_names[findex];
bool slide_window = (findex == 0); // fields following another field should match exactly
resolved_override &= check_for_overrides2(token_order, field_name, slide_window,
override_ids, override_ids_len,
exact_rule_match, matched_tokens, absorbed_tokens);
std::vector<std::string> field_absorbed_tokens;
resolved_override &= check_for_overrides(token_order, field_name, slide_window,
exact_rule_match, matched_tokens, absorbed_tokens,
field_absorbed_tokens);
if(!resolved_override) {
goto RETURN_EARLY;
}
field_placeholder_tokens[field_name] = field_absorbed_tokens;
}
} else {
// rule token is not a placeholder, so we have to skip the query tokens until it matches rule token
while(j < query_tokens.size() && query_tokens[j] != rule_tokens[i]) {
if(exact_rule_match) {
// a single mismatch is enough to fail exact match
return false;
}
j++;
}
// either we have exhausted all query tokens
if(j == query_tokens.size()) {
return false;
}
// or query token matches rule token, so we can proceed
i++;
j++;
}
@ -1804,39 +1821,69 @@ bool Index::resolve_override(const std::vector<std::string>& rule_parts, const b
RETURN_EARLY:
if(!resolved_override || (exact_rule_match && query_tokens.size() != absorbed_tokens.size())) {
delete [] override_ids;
override_ids_len = 0;
return false;
}
return resolved_override;
// replace placeholder with field_absorbed_tokens in rule_tokens
for(const auto& kv: field_placeholder_tokens) {
std::string pattern = "{" + kv.first + "}";
std::string replacement = StringUtils::join(kv.second, " ");
StringUtils::replace_all(filter_by_clause, pattern, replacement);
}
return true;
}
void Index::process_filter_overrides(const std::vector<const override_t*>& filter_overrides,
std::vector<query_tokens_t>& field_query_tokens,
token_ordering token_order,
std::vector<filter>& filters,
uint32_t** filter_ids,
uint32_t& filter_ids_length) const {
std::vector<filter>& filters) const {
size_t orig_filters_size = filters.size();
for(auto& override: filter_overrides) {
if(!override->rule.dynamic_query) {
// simple static filtering: add to filter_by and rewrite query if needed
// we will cover both the original query and the synonym variants
size_t orig_filters_size = filters.size();
// Simple static filtering: add to filter_by and rewrite query if needed.
// Check the original query and then the synonym variants until a rule matches.
bool resolved_override = static_filter_query_eval(override, field_query_tokens[0].q_include_tokens, filters);
static_filter_query_eval(override, field_query_tokens[0].q_include_tokens, filters);
for(auto& syn_tokens: field_query_tokens[0].q_synonyms) {
static_filter_query_eval(override, syn_tokens, filters);
if(!resolved_override) {
// we have not been able to resolve an override, so look at synonyms
for(auto& syn_tokens: field_query_tokens[0].q_synonyms) {
static_filter_query_eval(override, syn_tokens, filters);
if(orig_filters_size != filters.size()) {
// we have been able to resolve an override via synonym, so can stop looking
resolved_override = true;
break;
}
}
}
if(orig_filters_size != filters.size()) {
// means that we have been able to resolve an override
if(resolved_override && override->remove_matched_tokens) {
std::vector<std::string>& tokens = field_query_tokens[0].q_include_tokens;
std::vector<std::string> rule_tokens;
Tokenizer(override->rule.query, true).tokenize(rule_tokens);
std::set<std::string> rule_token_set(rule_tokens.begin(), rule_tokens.end());
remove_matched_tokens(tokens, rule_token_set);
for(auto& syn_tokens: field_query_tokens[0].q_synonyms) {
remove_matched_tokens(syn_tokens, rule_token_set);
}
// copy over for other fields
for(size_t i = 1; i < field_query_tokens.size(); i++) {
field_query_tokens[i] = field_query_tokens[0];
}
}
if(resolved_override) {
return ;
}
} else {
// need to extract placeholder field names from the search query, filter on them and rewrite query
// we will again cover both original query and synonyms
// we will cover both original query and synonyms
std::vector<std::string> rule_parts;
StringUtils::split(override->rule.query, rule_parts, " ");
@ -1845,81 +1892,64 @@ void Index::process_filter_overrides(const std::vector<const override_t*>& filte
uint32_t* field_override_ids = nullptr;
size_t field_override_ids_len = 0;
bool exact_rule_match = override->rule.match == override_t::MATCH_EXACT;
std::string filter_by_clause = override->filter_by;
std::set<std::string> absorbed_tokens;
bool resolved_override = resolve_override(rule_parts, exact_rule_match, query_tokens, token_order,
absorbed_tokens, field_override_ids, field_override_ids_len);
absorbed_tokens, filter_by_clause);
if(resolved_override && override->remove_matched_tokens) {
std::vector<std::string> new_tokens;
for(auto& token: query_tokens) {
if(absorbed_tokens.count(token) == 0) {
new_tokens.emplace_back(token);
if(!resolved_override) {
// try resolving synonym
for(size_t i = 0; i < field_query_tokens[0].q_synonyms.size(); i++) {
absorbed_tokens.clear();
std::vector<std::string>& syn_tokens = field_query_tokens[0].q_synonyms[i];
resolved_override = resolve_override(rule_parts, exact_rule_match, syn_tokens, token_order,
absorbed_tokens, filter_by_clause);
if(resolved_override) {
break;
}
}
for(size_t j=0; j < field_query_tokens.size(); j++) {
if(new_tokens.empty()) {
field_query_tokens[j].q_include_tokens = {"*"};
} else {
field_query_tokens[j].q_include_tokens = new_tokens;
if(resolved_override && override->remove_matched_tokens) {
std::vector<std::string> new_tokens;
for (auto& token: syn_tokens) {
if (absorbed_tokens.count(token) == 0) {
new_tokens.emplace_back(token);
}
}
for(size_t j=0; j < field_query_tokens.size(); j++) {
if(new_tokens.empty()) {
field_query_tokens[j].q_synonyms[i] = {"*"};
} else {
field_query_tokens[j].q_synonyms[i] = new_tokens;
}
}
}
}
}
// now resolve synonyms
if(resolved_override) {
Option<bool> filter_parse_op = filter::parse_filter_query(filter_by_clause, search_schema, store, "",
filters);
if(filter_parse_op.ok()) {
if(override->remove_matched_tokens) {
std::vector<std::string>& tokens = field_query_tokens[0].q_include_tokens;
remove_matched_tokens(tokens, absorbed_tokens);
for(size_t i = 0; i < field_query_tokens[0].q_synonyms.size(); i++) {
uint32_t* synonym_override_ids = nullptr;
size_t synonym_override_ids_len = 0;
absorbed_tokens.clear();
for(auto& syn_tokens: field_query_tokens[0].q_synonyms) {
remove_matched_tokens(syn_tokens, absorbed_tokens);
}
std::vector<std::string>& syn_tokens = field_query_tokens[0].q_synonyms[i];
resolved_override = resolve_override(rule_parts, exact_rule_match, syn_tokens, token_order, absorbed_tokens,
synonym_override_ids, synonym_override_ids_len);
if(resolved_override && override->remove_matched_tokens) {
std::vector<std::string> new_tokens;
for (auto& token: syn_tokens) {
if (absorbed_tokens.count(token) == 0) {
new_tokens.emplace_back(token);
// copy over for other fields
for(size_t i = 1; i < field_query_tokens.size(); i++) {
field_query_tokens[i] = field_query_tokens[0];
}
}
for(size_t j=0; j < field_query_tokens.size(); j++) {
if(new_tokens.empty()) {
field_query_tokens[j].q_synonyms[i] = {"*"};
} else {
field_query_tokens[j].q_synonyms[i] = new_tokens;
}
}
}
// result ids of synonyms will be ORed
uint32_t* syn_field_override_ids = nullptr;
field_override_ids_len = ArrayUtils::or_scalar(field_override_ids, field_override_ids_len,
synonym_override_ids,
synonym_override_ids_len, &syn_field_override_ids);
delete [] synonym_override_ids;
delete [] field_override_ids;
field_override_ids = syn_field_override_ids;
}
if(field_override_ids_len != 0) {
if(filter_ids_length != 0) {
uint32_t* filtered_results = nullptr;
filter_ids_length = ArrayUtils::and_scalar(field_override_ids, field_override_ids_len, *filter_ids,
filter_ids_length, &filtered_results);
delete [] *filter_ids;
delete [] field_override_ids;
*filter_ids = filtered_results;
} else {
*filter_ids = field_override_ids;
filter_ids_length = field_override_ids_len;
}
return ;
@ -1928,10 +1958,26 @@ void Index::process_filter_overrides(const std::vector<const override_t*>& filte
}
}
bool Index::check_for_overrides2(const token_ordering& token_order, const string& field_name, const bool slide_window,
uint32_t*& field_override_ids, size_t& field_override_ids_len,
bool exact_rule_match, std::vector<std::string>& tokens,
std::set<std::string>& absorbed_tokens) const {
// Drops every token that appears in `rule_token_set` from `tokens`, in place.
// If nothing survives, the query degenerates to the match-all wildcard "*".
void Index::remove_matched_tokens(std::vector<std::string>& tokens, const std::set<std::string>& rule_token_set) const {
    std::vector<std::string> kept;
    kept.reserve(tokens.size());

    for(const auto& tok: tokens) {
        if(rule_token_set.find(tok) == rule_token_set.end()) {
            kept.push_back(tok);
        }
    }

    if(kept.empty()) {
        // every token was absorbed by the rule: fall back to wildcard query
        tokens = {"*"};
    } else {
        tokens = std::move(kept);
    }
}
bool Index::check_for_overrides(const token_ordering& token_order, const string& field_name, const bool slide_window,
bool exact_rule_match, std::vector<std::string>& tokens,
std::set<std::string>& absorbed_tokens,
std::vector<std::string>& field_absorbed_tokens) const {
for(size_t window_len = tokens.size(); window_len > 0; window_len--) {
for(size_t start_index = 0; start_index+window_len-1 < tokens.size(); start_index++) {
@ -1961,18 +2007,6 @@ bool Index::check_for_overrides2(const token_ordering& token_order, const string
false, 4, query_hashes, token_order, false, 0, 1, false);
if(result_ids_len != 0) {
if(field_override_ids != nullptr) {
uint32_t* filtered_results = nullptr;
field_override_ids_len = ArrayUtils::and_scalar(field_override_ids, field_override_ids_len, result_ids,
result_ids_len, &filtered_results);
delete [] result_ids;
delete [] field_override_ids;
field_override_ids = filtered_results;
} else {
field_override_ids = result_ids;
field_override_ids_len = result_ids_len;
}
// remove window_tokens from `tokens`
std::vector<std::string> new_tokens;
for(size_t new_i = start_index; new_i < tokens.size(); new_i++) {
@ -1981,6 +2015,7 @@ bool Index::check_for_overrides2(const token_ordering& token_order, const string
new_tokens.emplace_back(token);
} else {
absorbed_tokens.insert(token);
field_absorbed_tokens.emplace_back(token);
}
}
@ -2049,8 +2084,7 @@ void Index::search(std::vector<query_tokens_t>& field_query_tokens,
std::vector<uint32_t> curated_ids_sorted(curated_ids.begin(), curated_ids.end());
std::sort(curated_ids_sorted.begin(), curated_ids_sorted.end());
process_filter_overrides(filter_overrides, field_query_tokens, token_order, filters,
&filter_ids, filter_ids_length);
process_filter_overrides(filter_overrides, field_query_tokens, token_order, filters);
lock.unlock();

View File

@ -212,6 +212,85 @@ std::string StringUtils::unicode_nfkd(const std::string& text) {
}
}
// Replaces every non-overlapping occurrence of `search` inside `subject`
// with `replace`, scanning left to right. An empty needle would match at
// every position, so it is treated as a no-op.
void StringUtils::replace_all(std::string& subject, const std::string& search, const std::string& replace) {
    if(search.empty()) {
        return ;
    }

    size_t at = subject.find(search);
    while(at != std::string::npos) {
        subject.replace(at, search.size(), replace);
        // resume searching after the inserted text so the replacement
        // itself is never re-matched
        at = subject.find(search, at + replace.size());
    }
}
// Removes spaces adjacent to the inside of curly braces, e.g.
// "popular { brand } shoes" -> "popular {brand} shoes".
// Two passes: a forward pass drops spaces that follow '{', then a backward
// pass drops spaces that precede '}'.
//
// Fixes in this version: the original used a signed `int i` against
// `str.size()` (signed/unsigned comparison), and computed
// `left_trimmed.size()-1` in size_t before narrowing to int — which
// underflows to SIZE_MAX when the intermediate string is empty. Iterating
// with proper unsigned types / reverse iterators avoids both issues while
// preserving the exact output.
std::string StringUtils::trim_curly_spaces(const std::string& str) {
    std::string left_trimmed;
    left_trimmed.reserve(str.size());
    bool drop_spaces = false;  // true while consuming spaces right after '{'

    for(const char c: str) {
        if(c == '{') {
            left_trimmed += c;
            drop_spaces = true;
        } else if(c == ' ' && drop_spaces) {
            // skip: space immediately following '{' (or a run of such spaces)
        } else {
            left_trimmed += c;
            drop_spaces = false;
        }
    }

    std::string right_trimmed;
    right_trimmed.reserve(left_trimmed.size());
    drop_spaces = false;  // true while consuming spaces right before '}'

    // scan backwards; reverse iterators are safe for the empty string
    for(auto rit = left_trimmed.rbegin(); rit != left_trimmed.rend(); ++rit) {
        const char c = *rit;
        if(c == '}') {
            right_trimmed += c;
            drop_spaces = true;
        } else if(c == ' ' && drop_spaces) {
            // skip: space immediately preceding '}' (or a run of such spaces)
        } else {
            right_trimmed += c;
            drop_spaces = false;
        }
    }

    // the backward pass built the string reversed; restore original order
    std::reverse(right_trimmed.begin(), right_trimmed.end());
    return right_trimmed;
}
/*size_t StringUtils::unicode_length(const std::string& bytes) {
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> utf8conv;
return utf8conv.from_bytes(bytes).size();

View File

@ -315,6 +315,12 @@ TEST_F(CollectionOverrideTest, ExcludeIncludeFacetFilterQuery) {
coll_mul_fields->add_override(override_include);
std::map<std::string, override_t> overrides = coll_mul_fields->get_overrides();
ASSERT_EQ(1, overrides.size());
auto override_json = overrides["include-rule"].to_json();
ASSERT_FALSE(override_json.contains("filter_by"));
ASSERT_FALSE(override_json.contains("remove_matched_tokens"));
auto results = coll_mul_fields->search("not-found", {"title"}, "", {"starring"}, {}, {0}, 10, 1, FREQUENCY,
{false}, Index::DROP_TOKENS_THRESHOLD,
spp::sparse_hash_set<std::string>(),
@ -1081,6 +1087,117 @@ TEST_F(CollectionOverrideTest, DynamicFilteringTokensBetweenPlaceholders) {
collectionManager.drop_collection("coll1");
}
// Verifies that a dynamic override rule whose filter_by mixes a placeholder
// field ("{brand}") with a numerical filter ("points:> 10") is applied only
// when the query matches the rule, and is skipped when overrides are disabled.
TEST_F(CollectionOverrideTest, DynamicFilteringWithNumericalFilter) {
Collection* coll1;
std::vector<field> fields = {field("name", field_types::STRING, false),
field("category", field_types::STRING, true),
field("brand", field_types::STRING, true),
field("color", field_types::STRING, true),
field("points", field_types::INT32, false)};
coll1 = collectionManager.get_collection("coll1").get();
if (coll1 == nullptr) {
coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get();
}
// Seed four docs: only doc1 ("0") is a Nike product with points > 10,
// so it alone should survive the override's combined filter.
nlohmann::json doc1;
doc1["id"] = "0";
doc1["name"] = "Retro Shoes";
doc1["category"] = "shoes";
doc1["color"] = "yellow";
doc1["brand"] = "Nike";
doc1["points"] = 15;
nlohmann::json doc2;
doc2["id"] = "1";
doc2["name"] = "Baseball Shoes";
doc2["category"] = "shoes";
doc2["color"] = "white";
doc2["brand"] = "Nike";
doc2["points"] = 5;
nlohmann::json doc3;
doc3["id"] = "2";
doc3["name"] = "Running Shoes";
doc3["category"] = "sports";
doc3["color"] = "grey";
doc3["brand"] = "Nike";
doc3["points"] = 5;
nlohmann::json doc4;
doc4["id"] = "3";
doc4["name"] = "Running Shoes";
doc4["category"] = "sports";
doc4["color"] = "grey";
doc4["brand"] = "Adidas";
doc4["points"] = 5;
ASSERT_TRUE(coll1->add(doc1.dump()).ok());
ASSERT_TRUE(coll1->add(doc2.dump()).ok());
ASSERT_TRUE(coll1->add(doc3.dump()).ok());
ASSERT_TRUE(coll1->add(doc4.dump()).ok());
std::vector<sort_by> sort_fields = {sort_by("_text_match", "DESC"), sort_by("points", "DESC")};
// Rule: "popular {brand} shoes" captures the brand from the query and
// combines it with a static numerical predicate.
nlohmann::json override_json = {
{"id", "dynamic-cat-filter"},
{
"rule", {
{"query", "popular {brand} shoes"},
{"match", override_t::MATCH_CONTAINS}
}
},
{"remove_matched_tokens", false},
{"filter_by", "brand: {brand} && points:> 10"}
};
override_t override;
auto op = override_t::parse(override_json, "dynamic-cat-filter", override);
ASSERT_TRUE(op.ok());
// Baseline before the override is installed: all four docs match.
auto results = coll1->search("popular nike shoes", {"name", "category", "brand"}, "",
{}, sort_fields, {2, 2, 2}, 10).get();
ASSERT_EQ(4, results["hits"].size());
coll1->add_override(override);
// With the override active, only the Nike doc with points > 10 remains.
results = coll1->search("popular nike shoes", {"name", "category", "brand"}, "",
{}, sort_fields, {2, 2, 2}, 10).get();
ASSERT_EQ(1, results["hits"].size());
ASSERT_EQ("0", results["hits"][0]["document"]["id"].get<std::string>());
// when overrides are disabled
bool enable_overrides = false;
results = coll1->search("popular nike shoes", {"name", "category", "brand"}, "",
{}, sort_fields, {2, 2, 2}, 10, 1, FREQUENCY, {false, false, false}, 1,
spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "", 1, {}, {}, {}, 0,
"<mark>", "</mark>", {1, 1, 1}, 10000, true, false, enable_overrides).get();
ASSERT_EQ(4, results["hits"].size());
// should not match the defined override
results = coll1->search("running adidas shoes", {"name", "category", "brand"}, "",
{}, sort_fields, {2, 2, 2}, 10).get();
ASSERT_EQ(4, results["hits"].size());
ASSERT_EQ("3", results["hits"][0]["document"]["id"].get<std::string>());
ASSERT_EQ("2", results["hits"][1]["document"]["id"].get<std::string>());
ASSERT_EQ("0", results["hits"][2]["document"]["id"].get<std::string>());
ASSERT_EQ("1", results["hits"][3]["document"]["id"].get<std::string>());
results = coll1->search("adidas", {"name", "category", "brand"}, "",
{}, sort_fields, {2, 2, 2}, 10).get();
ASSERT_EQ(1, results["hits"].size());
ASSERT_EQ("3", results["hits"][0]["document"]["id"].get<std::string>());
collectionManager.drop_collection("coll1");
}
TEST_F(CollectionOverrideTest, DynamicFilteringWithSynonyms) {
Collection *coll1;
@ -1103,7 +1220,7 @@ TEST_F(CollectionOverrideTest, DynamicFilteringWithSynonyms) {
nlohmann::json doc2;
doc2["id"] = "1";
doc2["name"] = "Track Gym";
doc2["name"] = "Exciting Track Gym";
doc2["category"] = "shoes";
doc2["brand"] = "Adidas";
doc2["points"] = 5;
@ -1121,38 +1238,59 @@ TEST_F(CollectionOverrideTest, DynamicFilteringWithSynonyms) {
synonym_t synonym1{"sneakers-shoes", {"sneakers"}, {{"shoes"}} };
synonym_t synonym2{"boots-shoes", {"boots"}, {{"shoes"}} };
synonym_t synonym3{"exciting-amazing", {"exciting"}, {{"amazing"}} };
coll1->add_synonym(synonym1);
coll1->add_synonym(synonym2);
coll1->add_synonym(synonym3);
std::vector<sort_by> sort_fields = { sort_by("_text_match", "DESC"), sort_by("points", "DESC") };
// with override, results will be different
nlohmann::json override_json = {
{"id", "dynamic-filters"},
{
"rule", {
{"query", "{category}"},
{"match", override_t::MATCH_EXACT}
}
},
{"remove_matched_tokens", true},
{"filter_by", "category: {category}"}
// spaces around field name should still work e.g. "{ field }"
nlohmann::json override_json1 = {
{"id", "dynamic-filters"},
{
"rule", {
{"query", "{ category }"},
{"match", override_t::MATCH_EXACT}
}
},
{"remove_matched_tokens", true},
{"filter_by", "category: {category}"}
};
override_t override;
auto op = override_t::parse(override_json, "dynamic-filters", override);
override_t override1;
auto op = override_t::parse(override_json1, "dynamic-filters", override1);
ASSERT_TRUE(op.ok());
coll1->add_override(override1);
coll1->add_override(override);
std::map<std::string, override_t> overrides = coll1->get_overrides();
ASSERT_EQ(1, overrides.size());
auto override_json = overrides["dynamic-filters"].to_json();
ASSERT_EQ("category: {category}", override_json["filter_by"].get<std::string>());
ASSERT_EQ(true, override_json["remove_matched_tokens"].get<bool>());
nlohmann::json override_json2 = {
{"id", "static-filters"},
{
"rule", {
{"query", "exciting"},
{"match", override_t::MATCH_CONTAINS}
}
},
{"remove_matched_tokens", true},
{"filter_by", "points: [5, 4]"}
};
override_t override2;
op = override_t::parse(override_json2, "static-filters", override2);
ASSERT_TRUE(op.ok());
coll1->add_override(override2);
auto results = coll1->search("sneakers", {"name", "category", "brand"}, "",
{}, sort_fields, {2, 2, 2}, 10).get();
ASSERT_EQ(3, results["hits"].size());
ASSERT_EQ("1", results["hits"][0]["document"]["id"].get<std::string>());
ASSERT_EQ("2", results["hits"][1]["document"]["id"].get<std::string>());
ASSERT_EQ("0", results["hits"][2]["document"]["id"].get<std::string>());
ASSERT_EQ(1, results["hits"].size());
ASSERT_EQ("2", results["hits"][0]["document"]["id"].get<std::string>());
// keyword does not exist but has a synonym with results
@ -1164,6 +1302,15 @@ TEST_F(CollectionOverrideTest, DynamicFilteringWithSynonyms) {
ASSERT_EQ("1", results["hits"][0]["document"]["id"].get<std::string>());
ASSERT_EQ("0", results["hits"][1]["document"]["id"].get<std::string>());
// keyword has no override, but synonym's override is used
results = coll1->search("exciting", {"name", "category", "brand"}, "",
{}, sort_fields, {2, 2, 2}, 10).get();
ASSERT_EQ(2, results["hits"].size());
ASSERT_EQ("1", results["hits"][0]["document"]["id"].get<std::string>());
ASSERT_EQ("2", results["hits"][1]["document"]["id"].get<std::string>());
collectionManager.drop_collection("coll1");
}

View File

@ -255,4 +255,12 @@ TEST(StringUtilsTest, ShouldParseStringifiedList) {
StringUtils::split_to_values(str, strs);
ASSERT_EQ(1, strs.size());
ASSERT_EQ("John Galt", strs[0]);
}
}
// Verifies StringUtils::trim_curly_spaces, which strips whitespace just inside
// curly-brace placeholders (e.g. "{ bar }" -> "{bar}"); used to normalize
// dynamic-override rule queries such as "{ category }".
TEST(StringUtilsTest, ShouldTrimCurlySpaces) {
    ASSERT_EQ("foo {bar}", StringUtils::trim_curly_spaces("foo { bar }"));
    // NOTE(review): this input renders identically to the previous one; presumably it
    // differs in internal whitespace (e.g. repeated/odd spaces) — confirm against the
    // original file bytes.
    ASSERT_EQ("foo {bar}", StringUtils::trim_curly_spaces("foo { bar }"));
    // Empty input passes through unchanged.
    ASSERT_EQ("", StringUtils::trim_curly_spaces(""));
    // Braces containing only spaces collapse to "{}".
    ASSERT_EQ("{}", StringUtils::trim_curly_spaces("{ }"));
    // Multiple placeholders are each trimmed independently.
    ASSERT_EQ("foo {bar} {baz}", StringUtils::trim_curly_spaces("foo { bar } { baz}"));
}