diff --git a/include/collection.h b/include/collection.h index e1942d37..2c2bdd06 100644 --- a/include/collection.h +++ b/include/collection.h @@ -207,8 +207,7 @@ private: const std::map>& pinned_hits, const std::vector& hidden_hits, std::map>& include_ids, - std::vector& excluded_ids, std::vector& filter_overrides, - std::vector& filters) const; + std::vector& excluded_ids, std::vector& filter_overrides) const; Option check_and_update_schema(nlohmann::json& document, const DIRTY_VALUES& dirty_values); @@ -334,8 +333,6 @@ public: bool facet_value_to_string(const facet &a_facet, const facet_count_t &facet_count, const nlohmann::json &document, std::string &value) const; - static void aggregate_topster(size_t query_index, Topster &topster, Topster *index_topster); - static void populate_result_kvs(Topster *topster, std::vector> &result_kvs); void batch_index(std::vector& index_records, std::vector& json_out, size_t &num_indexed); @@ -387,7 +384,7 @@ public: bool enable_overrides=true, const std::string& highlight_fields="", const bool exhaustive_search = false, - size_t search_stop_millis = 200) const; + size_t search_stop_millis = 6000*1000) const; Option get_filter_ids(const std::string & simple_filter_query, std::vector>& index_ids); diff --git a/include/field.h b/include/field.h index 67c66168..e10cc8bb 100644 --- a/include/field.h +++ b/include/field.h @@ -478,7 +478,7 @@ struct filter { static Option parse_filter_query(const std::string& simple_filter_query, const std::unordered_map& search_schema, - Store* store, + const Store* store, const std::string& doc_id_prefix, std::vector& filters); }; diff --git a/include/index.h b/include/index.h index 29953e72..b2aad2fd 100644 --- a/include/index.h +++ b/include/index.h @@ -58,7 +58,7 @@ struct override_t { struct add_hit_t { std::string doc_id; - uint32_t position; + uint32_t position = 0; }; struct drop_hit_t { @@ -74,7 +74,7 @@ struct override_t { std::string filter_by; bool remove_matched_tokens = false; - override_t() {} + override_t() = default; static Option parse(const nlohmann::json& override_json, const std::string& id, override_t& override) { if(!override_json.is_object()) { @@ -198,6 +198,8 @@ struct override_t { while(i < override.rule.query.size()) { if(override.rule.query[i] == '}') { override.rule.dynamic_query = true; + // remove spaces around curlies + override.rule.query = StringUtils::trim_curly_spaces(override.rule.query); break; } i++; @@ -231,6 +233,11 @@ struct override_t { override["excludes"].push_back(exclude); } + if(!filter_by.empty()) { + override["filter_by"] = filter_by; + override["remove_matched_tokens"] = remove_matched_tokens; + } + return override; } }; @@ -368,15 +375,15 @@ class Index { private: mutable std::shared_mutex mutex; - ThreadPool* thread_pool; - static constexpr const uint64_t FACET_ARRAY_DELIMETER = std::numeric_limits::max(); std::string name; const uint32_t collection_id; - Store* store; + const Store* store; + + ThreadPool* thread_pool; size_t num_documents; @@ -425,31 +432,31 @@ private: uint32_t& token_bits, uint64& qhash); - void log_leaves(const int cost, const std::string &token, const std::vector &leaves) const; + void log_leaves(int cost, const std::string &token, const std::vector &leaves) const; void do_facets(std::vector & facets, facet_query_t & facet_query, size_t group_limit, const std::vector& group_by_fields, const uint32_t* result_ids, size_t results_size) const; - void static_filter_query_eval(const override_t* override, std::vector& tokens, + bool static_filter_query_eval(const override_t* override, std::vector& tokens, std::vector& filters) const; void process_filter_overrides(const std::vector& filter_overrides, std::vector& field_query_tokens, token_ordering token_order, - std::vector& filters, - uint32_t** filter_ids, - uint32_t& filter_ids_length) const; + std::vector& filters) const; - bool resolve_override(const std::vector& rule_parts, const bool exact_rule_match, + bool resolve_override(const std::vector& rule_tokens, bool exact_rule_match, const std::vector& query_tokens, token_ordering token_order, std::set& absorbed_tokens, - uint32_t*& override_ids, size_t& override_ids_len) const; + std::string& filter_by_clause) const; - bool check_for_overrides2(const token_ordering& token_order, const string& field_name, const bool slide_window, - uint32_t*& field_override_ids, size_t& field_override_ids_len, - bool exact_rule_match, std::vector& tokens, - std::set& absorbed_tokens) const; + bool check_for_overrides(const token_ordering& token_order, const string& field_name, bool slide_window, + bool exact_rule_match, std::vector& tokens, + std::set& absorbed_tokens, + std::vector& field_absorbed_tokens) const; + + static void aggregate_topster(Topster* agg_topster, Topster* index_topster); void search_field(const uint8_t & field_id, std::vector& query_tokens, @@ -460,19 +467,19 @@ private: const std::string & field, uint32_t *filter_ids, size_t filter_ids_length, const std::vector& curated_ids, std::vector & facets, const std::vector & sort_fields, - const int num_typos, std::vector> & searched_queries, + int num_typos, std::vector> & searched_queries, Topster* topster, spp::sparse_hash_set& groups_processed, uint32_t** all_result_ids, size_t & all_result_ids_len, size_t& field_num_results, - const size_t group_limit, + size_t group_limit, const std::vector& group_by_fields, bool prioritize_exact_match, size_t concurrency, std::set& query_hashes, - const token_ordering token_order = FREQUENCY, const bool prefix = false, - const size_t drop_tokens_threshold = Index::DROP_TOKENS_THRESHOLD, - const size_t typo_tokens_threshold = Index::TYPO_TOKENS_THRESHOLD, - const bool exhaustive_search = false) const; + token_ordering token_order = FREQUENCY, const bool prefix = false, + size_t drop_tokens_threshold = Index::DROP_TOKENS_THRESHOLD, + size_t typo_tokens_threshold = Index::TYPO_TOKENS_THRESHOLD, + bool exhaustive_search = false) const; void search_candidates(const uint8_t & field_id, bool field_is_array, @@ -719,5 +726,7 @@ public: void populate_sort_mapping(int* sort_order, std::vector& geopoint_indices, const std::vector& sort_fields_std, std::array*, 3>& field_values) const; + + void remove_matched_tokens(std::vector& tokens, const std::set& rule_token_set) const; }; diff --git a/include/posting.h b/include/posting.h index 85ab60e9..67557937 100644 --- a/include/posting.h +++ b/include/posting.h @@ -5,10 +5,10 @@ #include "posting_list.h" #include "threadpool.h" -#define IS_COMPACT_POSTING(x) (((uintptr_t)x & 1)) -#define SET_COMPACT_POSTING(x) ((void*)((uintptr_t)x | 1)) -#define RAW_POSTING_PTR(x) ((void*)((uintptr_t)x & ~1)) -#define COMPACT_POSTING_PTR(x) ((compact_posting_list_t*)((uintptr_t)x & ~1)) +#define IS_COMPACT_POSTING(x) (((uintptr_t)(x) & 1)) +#define SET_COMPACT_POSTING(x) ((void*)((uintptr_t)(x) | 1)) +#define RAW_POSTING_PTR(x) ((void*)((uintptr_t)(x) & ~1)) +#define COMPACT_POSTING_PTR(x) ((compact_posting_list_t*)((uintptr_t)(x) & ~1)) struct compact_posting_list_t { // structured to get 4 byte alignment for `id_offsets` @@ -22,7 +22,7 @@ struct compact_posting_list_t { static compact_posting_list_t* create(uint32_t num_ids, const uint32_t* ids, const uint32_t* offset_index, uint32_t num_offsets, uint32_t* offsets); - posting_list_t* to_full_posting_list() const; + [[nodiscard]] posting_list_t* to_full_posting_list() const; bool contains(uint32_t id); @@ -34,7 +34,7 @@ struct compact_posting_list_t { uint32_t first_id(); uint32_t last_id(); - uint32_t num_ids() const; + [[nodiscard]] uint32_t num_ids() const; bool contains_atleast_one(const uint32_t* target_ids, size_t target_ids_size); }; diff --git a/include/posting_list.h b/include/posting_list.h index 316f395e..6faeefdc 100644 --- a/include/posting_list.h +++ b/include/posting_list.h @@ -67,10 +67,11 @@ public: }; struct result_iter_state_t { - uint32_t* excluded_result_ids = nullptr; - size_t excluded_result_ids_size = 0; - uint32_t* filter_ids = nullptr; - size_t filter_ids_length = 0; + const uint32_t* excluded_result_ids = nullptr; + const size_t excluded_result_ids_size = 0; + + const uint32_t* filter_ids = nullptr; + const size_t filter_ids_length = 0; size_t excluded_result_ids_index = 0; size_t filter_ids_index = 0; diff --git a/include/string_utils.h b/include/string_utils.h index 372da1ce..7538c8b8 100644 --- a/include/string_utils.h +++ b/include/string_utils.h @@ -312,4 +312,9 @@ struct StringUtils { static std::map parse_query_string(const std::string& query); static std::string float_to_str(float value); + + static void replace_all(std::string& subject, const std::string& search, + const std::string& replace); + + static std::string trim_curly_spaces(const std::string& str); }; \ No newline at end of file diff --git a/src/collection.cpp b/src/collection.cpp index 020b02d8..ed5bec6b 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -390,8 +390,8 @@ void Collection::curate_results(string& actual_query, bool enable_overrides, boo const std::map>& pinned_hits, const std::vector& hidden_hits, std::map>& include_ids, - std::vector& excluded_ids, std::vector& filter_overrides, - std::vector& filters) const { + std::vector& excluded_ids, + std::vector& filter_overrides) const { std::set excluded_set; @@ -844,7 +844,7 @@ Option Collection::search(const std::string & raw_query, const s std::vector filter_overrides; std::string query = raw_query; curate_results(query, enable_overrides, pre_segmented_query, pinned_hits, hidden_hits, - include_ids, excluded_ids, filter_overrides, filters); + include_ids, excluded_ids, filter_overrides); /*for(auto& kv: include_ids) { LOG(INFO) << "key: " << kv.first; @@ -906,9 +906,6 @@ Option Collection::search(const std::string & raw_query, const s } // search all indices - size_t num_processed = 0; - std::mutex m_process; - std::condition_variable cv_process; size_t index_id = 0; search_args* search_params = new search_args(field_query_tokens, weighted_search_fields, @@ -1334,33 +1331,6 @@ void Collection::populate_result_kvs(Topster *topster, std::vectordistinct) { - for(auto &group_topster_entry: index_topster->group_kv_map) { - Topster* group_topster = group_topster_entry.second; - for(const auto& map_kv: group_topster->kv_map) { - map_kv.second->query_index += query_index; - if(map_kv.second->query_indices != nullptr) { - for(size_t i = 0; i < map_kv.second->query_indices[0]; i++) { - map_kv.second->query_indices[i+1] += query_index; - } - } - agg_topster.add(map_kv.second); - } - } - } else { - for(const auto& map_kv: index_topster->kv_map) { - map_kv.second->query_index += query_index; - if(map_kv.second->query_indices != nullptr) { - for(size_t i = 0; i < map_kv.second->query_indices[0]; i++) { - map_kv.second->query_indices[i+1] += query_index; - } - } - agg_topster.add(map_kv.second); - } - } -} - Option Collection::get_filter_ids(const std::string & simple_filter_query, std::vector>& index_ids) { std::shared_lock lock(mutex); diff --git a/src/field.cpp b/src/field.cpp index b0ba565e..9fa978a8 100644 --- a/src/field.cpp +++ b/src/field.cpp @@ -63,7 +63,7 @@ Option filter::parse_geopoint_filter_value(std::string& raw_value, Option filter::parse_filter_query(const string& simple_filter_query, const std::unordered_map& search_schema, - Store* store, + const Store* store, const std::string& doc_id_prefix, std::vector& filters) { @@ -332,7 +332,7 @@ Option filter::parse_filter_query(const string& simple_filter_query, "`: Unidentified field data type, see docs for supported data types."); } - if(f.comparators.size() > 0 && f.comparators.front() == NOT_EQUALS) { + if(!f.comparators.empty() && f.comparators.front() == NOT_EQUALS) { exclude_filters.push_back(f); } else { filters.push_back(f); diff --git a/src/index.cpp b/src/index.cpp index d0f34d3c..6b09752e 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -40,7 +40,7 @@ Index::Index(const std::string& name, const uint32_t collection_id, const Store* const std::unordered_map & search_schema, const std::map& facet_schema, const std::unordered_map& sort_schema, const std::vector& symbols_to_index, const std::vector& token_separators): - name(name), collection_id(collection_id), thread_pool(thread_pool), + name(name), collection_id(collection_id), store(store), thread_pool(thread_pool), search_schema(search_schema), facet_schema(facet_schema), sort_schema(sort_schema), symbols_to_index(symbols_to_index), token_separators(token_separators) { @@ -1079,17 +1079,17 @@ void Index::do_facets(std::vector & facets, facet_query_t & facet_query, } } -void aggregate_topster(Topster& agg_topster, Topster* index_topster) { +void Index::aggregate_topster(Topster* agg_topster, Topster* index_topster) { if(index_topster->distinct) { for(auto &group_topster_entry: index_topster->group_kv_map) { Topster* group_topster = group_topster_entry.second; for(const auto& map_kv: group_topster->kv_map) { - agg_topster.add(map_kv.second); + agg_topster->add(map_kv.second); } } } else { for(const auto& map_kv: index_topster->kv_map) { - agg_topster.add(map_kv.second); + agg_topster->add(map_kv.second); } } } @@ -1186,7 +1186,7 @@ void Index::search_candidates(const uint8_t & field_id, bool field_is_array, if(topster == nullptr) { posting_t::block_intersector_t( - posting_lists, iter_state, thread_pool, 100 + posting_lists, iter_state, thread_pool, 100 ) .intersect([&](uint32_t seq_id, std::vector& its, size_t index) { result_id_vecs[index].push_back(seq_id); @@ -1197,7 +1197,7 @@ void Index::search_candidates(const uint8_t & field_id, bool field_is_array, } posting_t::block_intersector_t( - posting_lists, iter_state, thread_pool, 100 + posting_lists, iter_state, thread_pool, 100 ) .intersect([&](uint32_t seq_id, std::vector& its, size_t index) { score_results(sort_fields, searched_queries.size(), field_id, field_is_array, @@ -1215,6 +1215,11 @@ void Index::search_candidates(const uint8_t & field_id, bool field_is_array, size_t num_result_ids = 0; for(size_t i = 0; i < concurrency; i++) { + if(result_id_vecs[i].empty()) { + // can happen if not all threads produce results + continue; + } + uint32_t* new_all_result_ids = nullptr; all_result_ids_len = ArrayUtils::or_scalar(*all_result_ids, all_result_ids_len, &result_id_vecs[i][0], result_id_vecs[i].size(), &new_all_result_ids); @@ -1224,7 +1229,7 @@ void Index::search_candidates(const uint8_t & field_id, bool field_is_array, num_result_ids += result_id_vecs[i].size(); if(topster != nullptr) { - aggregate_topster(*topster, topsters[i]); + aggregate_topster(topster, topsters[i]); delete topsters[i]; groups_processed.insert(groups_processed_vec[i].begin(), groups_processed_vec[i].end()); } @@ -1710,7 +1715,7 @@ void Index::concat_topster_ids(Topster* topster, spp::sparse_hash_map& tokens, std::vector& filters) const { @@ -1721,42 +1726,34 @@ void Index::static_filter_query_eval(const override_t* override, Option filter_op = filter::parse_filter_query(override->filter_by, search_schema, store, "", filters); - if(!filter_op.ok()) { - return ; - } - - if(override->remove_matched_tokens) { - query.replace(query.find(override->rule.query), override->rule.query.size(), ""); - if(StringUtils::trim(query).empty()) { - tokens = {"*"}; - } else { - tokens.clear(); - StringUtils::split(query, tokens, " "); - } - } + return filter_op.ok(); } + + return false; } -bool Index::resolve_override(const std::vector& rule_parts, const bool exact_rule_match, +bool Index::resolve_override(const std::vector& rule_tokens, const bool exact_rule_match, const std::vector& query_tokens, token_ordering token_order, std::set& absorbed_tokens, - uint32_t*& override_ids, size_t& override_ids_len) const { + std::string& filter_by_clause) const { bool resolved_override = false; size_t i = 0, j = 0; - while(i < rule_parts.size()) { - if(rule_parts[i].front() == '{' && rule_parts[i].back() == '}') { + std::unordered_map> field_placeholder_tokens; + + while(i < rule_tokens.size()) { + if(rule_tokens[i].front() == '{' && rule_tokens[i].back() == '}') { // found a field placeholder std::vector field_names; - std::string rule_part = rule_parts[i]; + std::string rule_part = rule_tokens[i]; field_names.emplace_back(rule_part.erase(0, 1).erase(rule_part.size() - 1)); // skip until we find a non-placeholder token i++; - while(i < rule_parts.size() && (rule_parts[i].front() == '{' && rule_parts[i].back() == '}')) { - rule_part = rule_parts[i]; + while(i < rule_tokens.size() && (rule_tokens[i].front() == '{' && rule_tokens[i].back() == '}')) { + rule_part = rule_tokens[i]; field_names.emplace_back(rule_part.erase(0, 1).erase(rule_part.size() - 1)); i++; } @@ -1766,12 +1763,12 @@ bool Index::resolve_override(const std::vector& rule_parts, const b std::vector matched_tokens; - while(j < query_tokens.size() && (i == rule_parts.size() || rule_parts[i] != query_tokens[j])) { + while(j < query_tokens.size() && (i == rule_tokens.size() || rule_tokens[i] != query_tokens[j])) { matched_tokens.emplace_back(query_tokens[j]); j++; } - if(i < rule_parts.size() && j < query_tokens.size() && rule_parts[i] != query_tokens[j]) { + if(i < rule_tokens.size() && j < query_tokens.size() && rule_tokens[i] != query_tokens[j]) { // if last token does not match, it means that the query does not match the rule resolved_override = false; goto RETURN_EARLY; @@ -1788,14 +1785,34 @@ bool Index::resolve_override(const std::vector& rule_parts, const b for(size_t findex = 0; findex < field_names.size(); findex++) { const auto& field_name = field_names[findex]; bool slide_window = (findex == 0); // fields following another field should match exactly - resolved_override &= check_for_overrides2(token_order, field_name, slide_window, - override_ids, override_ids_len, - exact_rule_match, matched_tokens, absorbed_tokens); + std::vector field_absorbed_tokens; + resolved_override &= check_for_overrides(token_order, field_name, slide_window, + exact_rule_match, matched_tokens, absorbed_tokens, + field_absorbed_tokens); + if(!resolved_override) { goto RETURN_EARLY; } + + field_placeholder_tokens[field_name] = field_absorbed_tokens; } } else { + // rule token is not a placeholder, so we have to skip the query tokens until it matches rule token + while(j < query_tokens.size() && query_tokens[j] != rule_tokens[i]) { + if(exact_rule_match) { + // a single mismatch is enough to fail exact match + return false; + } + j++; + } + + // either we have exhausted all query tokens + if(j == query_tokens.size()) { + return false; + } + + // or query token matches rule token, so we can proceed + i++; j++; } @@ -1804,39 +1821,69 @@ bool Index::resolve_override(const std::vector& rule_parts, const b RETURN_EARLY: if(!resolved_override || (exact_rule_match && query_tokens.size() != absorbed_tokens.size())) { - delete [] override_ids; - override_ids_len = 0; return false; } - return resolved_override; + // replace placeholder with field_absorbed_tokens in rule_tokens + for(const auto& kv: field_placeholder_tokens) { + std::string pattern = "{" + kv.first + "}"; + std::string replacement = StringUtils::join(kv.second, " "); + StringUtils::replace_all(filter_by_clause, pattern, replacement); + } + + return true; } void Index::process_filter_overrides(const std::vector& filter_overrides, std::vector& field_query_tokens, token_ordering token_order, - std::vector& filters, - uint32_t** filter_ids, - uint32_t& filter_ids_length) const { + std::vector& filters) const { + + size_t orig_filters_size = filters.size(); for(auto& override: filter_overrides) { if(!override->rule.dynamic_query) { - // simple static filtering: add to filter_by and rewrite query if needed - // we will cover both the original query and the synonym variants - size_t orig_filters_size = filters.size(); + // Simple static filtering: add to filter_by and rewrite query if needed. + // Check the original query and then the synonym variants until a rule matches. + bool resolved_override = static_filter_query_eval(override, field_query_tokens[0].q_include_tokens, filters); - static_filter_query_eval(override, field_query_tokens[0].q_include_tokens, filters); - for(auto& syn_tokens: field_query_tokens[0].q_synonyms) { - static_filter_query_eval(override, syn_tokens, filters); + if(!resolved_override) { + // we have not been able to resolve an override, so look at synonyms + for(auto& syn_tokens: field_query_tokens[0].q_synonyms) { + static_filter_query_eval(override, syn_tokens, filters); + if(orig_filters_size != filters.size()) { + // we have been able to resolve an override via synonym, so can stop looking + resolved_override = true; + break; + } + } } - if(orig_filters_size != filters.size()) { - // means that we have been able to resolve an override + if(resolved_override && override->remove_matched_tokens) { + std::vector& tokens = field_query_tokens[0].q_include_tokens; + + std::vector rule_tokens; + Tokenizer(override->rule.query, true).tokenize(rule_tokens); + std::set rule_token_set(rule_tokens.begin(), rule_tokens.end()); + + remove_matched_tokens(tokens, rule_token_set); + + for(auto& syn_tokens: field_query_tokens[0].q_synonyms) { + remove_matched_tokens(syn_tokens, rule_token_set); + } + + // copy over for other fields + for(size_t i = 1; i < field_query_tokens.size(); i++) { + field_query_tokens[i] = field_query_tokens[0]; + } + } + + if(resolved_override) { return ; } } else { // need to extract placeholder field names from the search query, filter on them and rewrite query - // we will again cover both original query and synonyms + // we will cover both original query and synonyms std::vector rule_parts; StringUtils::split(override->rule.query, rule_parts, " "); @@ -1845,81 +1892,64 @@ void Index::process_filter_overrides(const std::vector& filte uint32_t* field_override_ids = nullptr; size_t field_override_ids_len = 0; + bool exact_rule_match = override->rule.match == override_t::MATCH_EXACT; + std::string filter_by_clause = override->filter_by; + std::set absorbed_tokens; bool resolved_override = resolve_override(rule_parts, exact_rule_match, query_tokens, token_order, - absorbed_tokens, field_override_ids, field_override_ids_len); + absorbed_tokens, filter_by_clause); - if(resolved_override && override->remove_matched_tokens) { - std::vector new_tokens; - for(auto& token: query_tokens) { - if(absorbed_tokens.count(token) == 0) { - new_tokens.emplace_back(token); + if(!resolved_override) { + // try resolving synonym + + for(size_t i = 0; i < field_query_tokens[0].q_synonyms.size(); i++) { + absorbed_tokens.clear(); + std::vector& syn_tokens = field_query_tokens[0].q_synonyms[i]; + resolved_override = resolve_override(rule_parts, exact_rule_match, syn_tokens, token_order, + absorbed_tokens, filter_by_clause); + + if(resolved_override) { + break; } - } - for(size_t j=0; j < field_query_tokens.size(); j++) { - if(new_tokens.empty()) { - field_query_tokens[j].q_include_tokens = {"*"}; - } else { - field_query_tokens[j].q_include_tokens = new_tokens; + if(resolved_override && override->remove_matched_tokens) { + std::vector new_tokens; + for (auto& token: syn_tokens) { + if (absorbed_tokens.count(token) == 0) { + new_tokens.emplace_back(token); + } + } + + for(size_t j=0; j < field_query_tokens.size(); j++) { + if(new_tokens.empty()) { + field_query_tokens[j].q_synonyms[i] = {"*"}; + } else { + field_query_tokens[j].q_synonyms[i] = new_tokens; + } + } } } } - // now resolve synonyms + if(resolved_override) { + Option filter_parse_op = filter::parse_filter_query(filter_by_clause, search_schema, store, "", + filters); + if(filter_parse_op.ok()) { + if(override->remove_matched_tokens) { + std::vector& tokens = field_query_tokens[0].q_include_tokens; + remove_matched_tokens(tokens, absorbed_tokens); - for(size_t i = 0; i < field_query_tokens[0].q_synonyms.size(); i++) { - uint32_t* synonym_override_ids = nullptr; - size_t synonym_override_ids_len = 0; - absorbed_tokens.clear(); + for(auto& syn_tokens: field_query_tokens[0].q_synonyms) { + remove_matched_tokens(syn_tokens, absorbed_tokens); + } - std::vector& syn_tokens = field_query_tokens[0].q_synonyms[i]; - resolved_override = resolve_override(rule_parts, exact_rule_match, syn_tokens, token_order, absorbed_tokens, - synonym_override_ids, synonym_override_ids_len); - - if(resolved_override && override->remove_matched_tokens) { - std::vector new_tokens; - for (auto& token: syn_tokens) { - if (absorbed_tokens.count(token) == 0) { - new_tokens.emplace_back(token); + // copy over for other fields + for(size_t i = 1; i < field_query_tokens.size(); i++) { + field_query_tokens[i] = field_query_tokens[0]; } } - - for(size_t j=0; j < field_query_tokens.size(); j++) { - if(new_tokens.empty()) { - field_query_tokens[j].q_synonyms[i] = {"*"}; - } else { - field_query_tokens[j].q_synonyms[i] = new_tokens; - } - } - } - - - // result ids of synonyms will be ORed - - uint32_t* syn_field_override_ids = nullptr; - field_override_ids_len = ArrayUtils::or_scalar(field_override_ids, field_override_ids_len, - synonym_override_ids, - synonym_override_ids_len, &syn_field_override_ids); - - delete [] synonym_override_ids; - delete [] field_override_ids; - field_override_ids = syn_field_override_ids; - } - - if(field_override_ids_len != 0) { - if(filter_ids_length != 0) { - uint32_t* filtered_results = nullptr; - filter_ids_length = ArrayUtils::and_scalar(field_override_ids, field_override_ids_len, *filter_ids, - filter_ids_length, &filtered_results); - delete [] *filter_ids; - delete [] field_override_ids; - *filter_ids = filtered_results; - } else { - *filter_ids = field_override_ids; - filter_ids_length = field_override_ids_len; } return ; @@ -1928,10 +1958,26 @@ void Index::process_filter_overrides(const std::vector& filte } } -bool Index::check_for_overrides2(const token_ordering& token_order, const string& field_name, const bool slide_window, - uint32_t*& field_override_ids, size_t& field_override_ids_len, - bool exact_rule_match, std::vector& tokens, - std::set& absorbed_tokens) const { +void Index::remove_matched_tokens(std::vector& tokens, const std::set& rule_token_set) const { + std::vector new_tokens; + + for(std::string& token: tokens) { + if(rule_token_set.count(token) == 0) { + new_tokens.push_back(token); + } + } + + if(new_tokens.empty()) { + tokens = {"*"}; + } else { + tokens = new_tokens; + } +} + +bool Index::check_for_overrides(const token_ordering& token_order, const string& field_name, const bool slide_window, + bool exact_rule_match, std::vector& tokens, + std::set& absorbed_tokens, + std::vector& field_absorbed_tokens) const { for(size_t window_len = tokens.size(); window_len > 0; window_len--) { for(size_t start_index = 0; start_index+window_len-1 < tokens.size(); start_index++) { @@ -1961,18 +2007,6 @@ bool Index::check_for_overrides2(const token_ordering& token_order, const string false, 4, query_hashes, token_order, false, 0, 1, false); if(result_ids_len != 0) { - if(field_override_ids != nullptr) { - uint32_t* filtered_results = nullptr; - field_override_ids_len = ArrayUtils::and_scalar(field_override_ids, field_override_ids_len, result_ids, - result_ids_len, &filtered_results); - delete [] result_ids; - delete [] field_override_ids; - field_override_ids = filtered_results; - } else { - field_override_ids = result_ids; - field_override_ids_len = result_ids_len; - } - // remove window_tokens from `tokens` std::vector new_tokens; for(size_t new_i = start_index; new_i < tokens.size(); new_i++) { @@ -1981,6 +2015,7 @@ bool Index::check_for_overrides2(const token_ordering& token_order, const string new_tokens.emplace_back(token); } else { absorbed_tokens.insert(token); + field_absorbed_tokens.emplace_back(token); } } @@ -2049,8 +2084,7 @@ void Index::search(std::vector& field_query_tokens, std::vector curated_ids_sorted(curated_ids.begin(), curated_ids.end()); std::sort(curated_ids_sorted.begin(), curated_ids_sorted.end()); - process_filter_overrides(filter_overrides, field_query_tokens, token_order, filters, - &filter_ids, filter_ids_length); + process_filter_overrides(filter_overrides, field_query_tokens, token_order, filters); lock.unlock(); diff --git a/src/string_utils.cpp b/src/string_utils.cpp index e0c3ec16..6a525304 100644 --- a/src/string_utils.cpp +++ b/src/string_utils.cpp @@ -212,6 +212,85 @@ std::string StringUtils::unicode_nfkd(const std::string& text) { } } +void StringUtils::replace_all(std::string& subject, const std::string& search, const std::string& replace) { + if(search.empty()) { + return ; + } + + size_t pos = 0; + while ((pos = subject.find(search, pos)) != std::string::npos) { + subject.replace(pos, search.length(), replace); + pos += replace.length(); + } +} + +std::string StringUtils::trim_curly_spaces(const std::string& str) { + std::string left_trimmed; + int i = 0; + bool inside_curly = false; + + while(i < str.size()) { + switch (str[i]) { + case '{': + left_trimmed += str[i]; + inside_curly = true; + break; + + case '}': + left_trimmed += str[i]; + inside_curly = false; + break; + + case ' ': + if(!inside_curly) { + left_trimmed += str[i]; + inside_curly = false; + } + break; + + default: + left_trimmed += str[i]; + inside_curly = false; + } + + i++; + } + + std::string right_trimmed; + i = left_trimmed.size()-1; + inside_curly = false; + + while(i >= 0) { + switch (left_trimmed[i]) { + case '}': + right_trimmed += left_trimmed[i]; + inside_curly = true; + break; + + case '{': + right_trimmed += left_trimmed[i]; + inside_curly = false; + break; + + case ' ': + if(!inside_curly) { + right_trimmed += left_trimmed[i]; + inside_curly = false; + } + break; + + default: + right_trimmed += left_trimmed[i]; + inside_curly = false; + } + + i--; + } + + std::reverse(right_trimmed.begin(), right_trimmed.end()); + return right_trimmed; +} + /*size_t StringUtils::unicode_length(const std::string& bytes) { std::wstring_convert, char32_t> utf8conv; return utf8conv.from_bytes(bytes).size(); diff --git a/test/collection_override_test.cpp b/test/collection_override_test.cpp index 6702b087..401934cc 100644 --- a/test/collection_override_test.cpp +++ b/test/collection_override_test.cpp @@ -315,6 +315,12 @@ TEST_F(CollectionOverrideTest, ExcludeIncludeFacetFilterQuery) { coll_mul_fields->add_override(override_include); + std::map overrides = coll_mul_fields->get_overrides(); + ASSERT_EQ(1, overrides.size()); + auto override_json = overrides["include-rule"].to_json(); + ASSERT_FALSE(override_json.contains("filter_by")); + ASSERT_FALSE(override_json.contains("remove_matched_tokens")); + auto results = coll_mul_fields->search("not-found", {"title"}, "", {"starring"}, {}, {0}, 10, 1, FREQUENCY, {false}, Index::DROP_TOKENS_THRESHOLD, spp::sparse_hash_set(), @@ -1081,6 +1087,117 @@ TEST_F(CollectionOverrideTest, DynamicFilteringTokensBetweenPlaceholders) { collectionManager.drop_collection("coll1"); } +TEST_F(CollectionOverrideTest, DynamicFilteringWithNumericalFilter) { + Collection* coll1; + + std::vector fields = {field("name", field_types::STRING, false), + field("category", field_types::STRING, true), + field("brand", field_types::STRING, true), + field("color", field_types::STRING, true), + field("points", field_types::INT32, false)}; + + coll1 = collectionManager.get_collection("coll1").get(); + if (coll1 == nullptr) { + coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); + } + + nlohmann::json doc1; + doc1["id"] = "0"; + doc1["name"] = "Retro Shoes"; + doc1["category"] = "shoes"; + doc1["color"] = "yellow"; + doc1["brand"] = "Nike"; + doc1["points"] = 15; + + nlohmann::json doc2; + doc2["id"] = "1"; + doc2["name"] = "Baseball Shoes"; + doc2["category"] = "shoes"; + doc2["color"] = "white"; + doc2["brand"] = "Nike"; + doc2["points"] = 5; + + nlohmann::json doc3; + doc3["id"] = "2"; + doc3["name"] = "Running Shoes"; + doc3["category"] = "sports"; + doc3["color"] = "grey"; + doc3["brand"] = "Nike"; + doc3["points"] = 5; + + nlohmann::json doc4; + doc4["id"] = "3"; + doc4["name"] = "Running Shoes"; + doc4["category"] = "sports"; + doc4["color"] = "grey"; + doc4["brand"] = "Adidas"; + doc4["points"] = 5; + + ASSERT_TRUE(coll1->add(doc1.dump()).ok()); + ASSERT_TRUE(coll1->add(doc2.dump()).ok()); + ASSERT_TRUE(coll1->add(doc3.dump()).ok()); + ASSERT_TRUE(coll1->add(doc4.dump()).ok()); + + std::vector sort_fields = {sort_by("_text_match", "DESC"), sort_by("points", "DESC")}; + + nlohmann::json override_json = { + {"id", "dynamic-cat-filter"}, + { + "rule", { + {"query", "popular {brand} shoes"}, + {"match", override_t::MATCH_CONTAINS} + } + }, + {"remove_matched_tokens", false}, + {"filter_by", "brand: {brand} && points:> 10"} + }; + + override_t override; + auto op = override_t::parse(override_json, "dynamic-cat-filter", override); + ASSERT_TRUE(op.ok()); + + auto results = coll1->search("popular nike shoes", {"name", "category", "brand"}, "", + {}, sort_fields, {2, 2, 2}, 10).get(); + ASSERT_EQ(4, results["hits"].size()); + + coll1->add_override(override); + + results = coll1->search("popular nike shoes", {"name", "category", "brand"}, "", + {}, sort_fields, {2, 2, 2}, 10).get(); + + ASSERT_EQ(1, results["hits"].size()); + ASSERT_EQ("0", results["hits"][0]["document"]["id"].get()); + + // when overrides are disabled + + bool enable_overrides = false; + results = coll1->search("popular nike shoes", {"name", "category", "brand"}, "", + {}, sort_fields, {2, 2, 2}, 10, 1, FREQUENCY, {false, false, false}, 1, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 4, "", 1, {}, {}, {}, 0, + "", "", {1, 1, 1}, 10000, true, false, enable_overrides).get(); + ASSERT_EQ(4, results["hits"].size()); + + // should not match the defined override + + results = coll1->search("running adidas shoes", {"name", "category", "brand"}, "", + {}, sort_fields, {2, 2, 2}, 10).get(); + + ASSERT_EQ(4, results["hits"].size()); + ASSERT_EQ("3", results["hits"][0]["document"]["id"].get()); + ASSERT_EQ("2", results["hits"][1]["document"]["id"].get()); + ASSERT_EQ("0", results["hits"][2]["document"]["id"].get()); + ASSERT_EQ("1", results["hits"][3]["document"]["id"].get()); + + results = coll1->search("adidas", {"name", "category", "brand"}, "", + {}, sort_fields, {2, 2, 2}, 10).get(); + + ASSERT_EQ(1, results["hits"].size()); + ASSERT_EQ("3", results["hits"][0]["document"]["id"].get()); + + collectionManager.drop_collection("coll1"); +} + TEST_F(CollectionOverrideTest, DynamicFilteringWithSynonyms) { Collection *coll1; @@ -1103,7 +1220,7 @@ TEST_F(CollectionOverrideTest, DynamicFilteringWithSynonyms) { nlohmann::json doc2; doc2["id"] = "1"; - doc2["name"] = "Track Gym"; + doc2["name"] = "Exciting Track Gym"; doc2["category"] = "shoes"; doc2["brand"] = "Adidas"; doc2["points"] = 5; @@ -1121,38 +1238,59 @@ TEST_F(CollectionOverrideTest, DynamicFilteringWithSynonyms) { synonym_t synonym1{"sneakers-shoes", {"sneakers"}, {{"shoes"}} }; synonym_t synonym2{"boots-shoes", {"boots"}, {{"shoes"}} }; + synonym_t synonym3{"exciting-amazing", {"exciting"}, {{"amazing"}} }; coll1->add_synonym(synonym1); coll1->add_synonym(synonym2); + coll1->add_synonym(synonym3); std::vector sort_fields = { sort_by("_text_match", "DESC"), sort_by("points", "DESC") }; - // with override, results will be different - - nlohmann::json override_json = { - {"id", "dynamic-filters"}, - { - "rule", { - {"query", "{category}"}, - {"match", override_t::MATCH_EXACT} - } - }, - {"remove_matched_tokens", true}, - {"filter_by", "category: {category}"} + // spaces around field name should still work e.g. "{ field }" + nlohmann::json override_json1 = { + {"id", "dynamic-filters"}, + { + "rule", { + {"query", "{ category }"}, + {"match", override_t::MATCH_EXACT} + } + }, + {"remove_matched_tokens", true}, + {"filter_by", "category: {category}"} }; - override_t override; - auto op = override_t::parse(override_json, "dynamic-filters", override); + override_t override1; + auto op = override_t::parse(override_json1, "dynamic-filters", override1); ASSERT_TRUE(op.ok()); + coll1->add_override(override1); - coll1->add_override(override); + std::map overrides = coll1->get_overrides(); + ASSERT_EQ(1, overrides.size()); + auto override_json = overrides["dynamic-filters"].to_json(); + ASSERT_EQ("category: {category}", override_json["filter_by"].get()); + ASSERT_EQ(true, override_json["remove_matched_tokens"].get()); + + nlohmann::json override_json2 = { + {"id", "static-filters"}, + { + "rule", { + {"query", "exciting"}, + {"match", override_t::MATCH_CONTAINS} + } + }, + {"remove_matched_tokens", true}, + {"filter_by", "points: [5, 4]"} + }; + + override_t override2; + op = override_t::parse(override_json2, "static-filters", override2); + ASSERT_TRUE(op.ok()); + coll1->add_override(override2); auto results = coll1->search("sneakers", {"name", "category", "brand"}, "", {}, sort_fields, {2, 2, 2}, 10).get(); - ASSERT_EQ(3, results["hits"].size()); - ASSERT_EQ("1", results["hits"][0]["document"]["id"].get()); - ASSERT_EQ("2", results["hits"][1]["document"]["id"].get()); - ASSERT_EQ("0", results["hits"][2]["document"]["id"].get()); + ASSERT_EQ(1, results["hits"].size()); + ASSERT_EQ("2", results["hits"][0]["document"]["id"].get()); // keyword does not exist but has a synonym with results @@ -1164,6 +1302,15 @@ TEST_F(CollectionOverrideTest, DynamicFilteringWithSynonyms) { ASSERT_EQ("1", results["hits"][0]["document"]["id"].get()); ASSERT_EQ("0", results["hits"][1]["document"]["id"].get()); + // keyword has no override, but synonym's override is used + results = coll1->search("exciting", {"name", "category", "brand"}, "", + {}, sort_fields, {2, 2, 2}, 10).get(); + + ASSERT_EQ(2, results["hits"].size()); + + ASSERT_EQ("1", results["hits"][0]["document"]["id"].get()); + ASSERT_EQ("2", results["hits"][1]["document"]["id"].get()); + collectionManager.drop_collection("coll1"); } diff --git a/test/string_utils_test.cpp b/test/string_utils_test.cpp index e2d7b2a9..44422929 100644 --- a/test/string_utils_test.cpp +++ b/test/string_utils_test.cpp @@ -255,4 +255,12 @@ TEST(StringUtilsTest, ShouldParseStringifiedList) { StringUtils::split_to_values(str, strs); ASSERT_EQ(1, strs.size()); ASSERT_EQ("John Galt", strs[0]); -} \ No newline at end of file +} + +TEST(StringUtilsTest, ShouldTrimCurlySpaces) { + ASSERT_EQ("foo {bar}", StringUtils::trim_curly_spaces("foo { bar }")); + ASSERT_EQ("foo {bar}", StringUtils::trim_curly_spaces("foo { bar }")); + ASSERT_EQ("", StringUtils::trim_curly_spaces("")); + ASSERT_EQ("{}", StringUtils::trim_curly_spaces("{ }")); + ASSERT_EQ("foo {bar} {baz}", StringUtils::trim_curly_spaces("foo { bar } { baz}")); +}