diff --git a/include/field.h b/include/field.h index d49d8e38..830b73d8 100644 --- a/include/field.h +++ b/include/field.h @@ -546,8 +546,7 @@ struct facet_count_t { // used to fetch the actual document and value for representation uint32_t doc_id; uint32_t array_pos; - - std::unordered_map query_token_pos; + std::vector tokens; }; struct facet_stats_t { @@ -567,6 +566,14 @@ struct facet { } }; +struct facet_info_t { + // facet hash => resolved tokens + std::unordered_map> hashes; + bool use_facet_query = false; + bool should_compute_stats = false; + field facet_field{"", "", false}; +}; + struct facet_query_t { std::string field_name; std::string query; diff --git a/include/index.h b/include/index.h index 24a68154..c4a2727d 100644 --- a/include/index.h +++ b/include/index.h @@ -380,8 +380,6 @@ class Index { private: mutable std::shared_mutex mutex; - static constexpr const uint64_t FACET_ARRAY_DELIMETER = std::numeric_limits::max(); - std::string name; const uint32_t collection_id; @@ -440,6 +438,7 @@ private: void log_leaves(int cost, const std::string &token, const std::vector &leaves) const; void do_facets(std::vector & facets, facet_query_t & facet_query, + const std::vector& facet_infos, size_t group_limit, const std::vector& group_by_fields, const uint32_t* result_ids, size_t results_size) const; @@ -469,9 +468,10 @@ private: const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, size_t& num_tokens_dropped, - const std::string & field, uint32_t *filter_ids, size_t filter_ids_length, + const field& the_field, const std::string& field_name, + const uint32_t *filter_ids, size_t filter_ids_length, const std::vector& curated_ids, - std::vector & facets, const std::vector & sort_fields, + const std::vector & sort_fields, int num_typos, std::vector> & searched_queries, Topster* topster, spp::sparse_hash_set& groups_processed, uint32_t** all_result_ids, size_t & all_result_ids_len, @@ -490,7 +490,7 @@ private: void search_candidates(const uint8_t & 
field_id, bool field_is_array, - uint32_t* filter_ids, size_t filter_ids_length, + const uint32_t* filter_ids, size_t filter_ids_length, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, const std::vector& curated_ids, const std::vector & sort_fields, std::vector & token_to_candidates, @@ -742,5 +742,10 @@ public: std::array*, 3>& field_values) const; static void remove_matched_tokens(std::vector& tokens, const std::set& rule_token_set) ; + + void compute_facet_infos(const std::vector& facets, facet_query_t& facet_query, + const uint32_t* all_result_ids, const size_t& all_result_ids_len, + const std::vector& group_by_fields, + std::vector& facet_infos) const; }; diff --git a/include/posting.h b/include/posting.h index fe49c544..77a2fac7 100644 --- a/include/posting.h +++ b/include/posting.h @@ -105,6 +105,13 @@ public: const std::vector& posting_lists, std::unordered_map>& array_token_positions ); + + static void get_exact_matches(const std::vector& raw_posting_lists, bool field_is_array, + const uint32_t* ids, uint32_t num_ids, + uint32_t*& exact_ids, size_t& num_exact_ids); + + static void get_matching_array_indices(const std::vector& raw_posting_lists, + uint32_t id, std::vector& indices); }; template diff --git a/include/posting_list.h b/include/posting_list.h index 6faeefdc..8f95a0f9 100644 --- a/include/posting_list.h +++ b/include/posting_list.h @@ -79,8 +79,8 @@ public: result_iter_state_t() = default; - result_iter_state_t(uint32_t* excluded_result_ids, size_t excluded_result_ids_size, uint32_t* filter_ids, - size_t filter_ids_length) : excluded_result_ids(excluded_result_ids), + result_iter_state_t(uint32_t* excluded_result_ids, size_t excluded_result_ids_size, + const uint32_t* filter_ids, const size_t filter_ids_length) : excluded_result_ids(excluded_result_ids), excluded_result_ids_size(excluded_result_ids_size), filter_ids(filter_ids), filter_ids_length(filter_ids_length) {} }; @@ -164,6 +164,13 @@ public: ); static bool 
is_single_token_verbatim_match(const posting_list_t::iterator_t& it, bool field_is_array); + + static void get_exact_matches(std::vector& its, bool field_is_array, + const uint32_t* ids, const uint32_t num_ids, + uint32_t*& exact_ids, size_t& num_exact_ids); + + static void get_matching_array_indices(uint32_t id, std::vector& its, + std::vector& indices); }; template diff --git a/src/collection.cpp b/src/collection.cpp index a5389753..a28775d3 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -1174,6 +1174,8 @@ Option Collection::search(const std::string & raw_query, const s facet_hash_counts.emplace_back(kv); } + auto the_field = search_schema.at(a_facet.field_name); + // keep only top K facets auto max_facets = std::min(max_facet_values, facet_hash_counts.size()); std::nth_element(facet_hash_counts.begin(), facet_hash_counts.begin() + max_facets, @@ -1181,7 +1183,11 @@ Option Collection::search(const std::string & raw_query, const s std::vector facet_query_tokens; - StringUtils::split(facet_query.query, facet_query_tokens, " "); + if(the_field.locale.empty() || the_field.locale == "en") { + StringUtils::split(facet_query.query, facet_query_tokens, " "); + } else { + Tokenizer(facet_query.query, true, !the_field.is_string()).tokenize(facet_query_tokens); + } std::vector facet_values; @@ -1207,32 +1213,71 @@ Option Collection::search(const std::string & raw_query, const s continue; } - std::vector tokens; - StringUtils::split(value, tokens, " "); - std::stringstream highlightedss; + std::unordered_map ftoken_pos; - // invert query_pos -> token_pos - spp::sparse_hash_map token_query_pos; - for(auto qtoken_pos: facet_count.query_token_pos) { - token_query_pos.emplace(qtoken_pos.second.pos, qtoken_pos.first); + for(size_t ti = 0; ti < facet_count.tokens.size(); ti++) { + if(the_field.is_bool()) { + if(facet_count.tokens[ti] == "1") { + facet_count.tokens[ti] = "true"; + } else { + facet_count.tokens[ti] = "false"; + } + } + + const std::string& 
resolved_token = facet_count.tokens[ti]; + ftoken_pos[resolved_token] = ti; } - for(size_t i = 0; i < tokens.size(); i++) { - if(i != 0) { - highlightedss << " "; + const std::string& last_full_q_token = facet_count.tokens.empty() ? "" : facet_count.tokens.back(); + const std::string& last_q_token = facet_query_tokens.empty() ? "" : facet_query_tokens.back(); + + // 2 passes: first identify tokens that need to be highlighted and then construct highlighted text + + Tokenizer tokenizer(value, true, !the_field.is_string()); + std::string raw_token; + size_t raw_token_index = 0, tok_start = 0, tok_end = 0; + + // need an ordered map here to ensure that it is ordered by the key (start offset) + std::map token_offsets; + size_t prefix_token_start_index = 0; + + while(tokenizer.next(raw_token, raw_token_index, tok_start, tok_end)) { + auto token_pos_it = ftoken_pos.find(raw_token); + if(token_pos_it != ftoken_pos.end()) { + token_offsets[tok_start] = tok_end; + if(raw_token == last_full_q_token) { + prefix_token_start_index = tok_start; + } + } + } + + auto offset_it = token_offsets.begin(); + size_t i = 0; + std::stringstream highlightedss; + + while(i < value.size()) { + if(offset_it != token_offsets.end()) { + if (i == offset_it->first) { + highlightedss << highlight_start_tag; + + // loop until end index, accumulate token and complete highlighting + size_t token_len = (i == prefix_token_start_index) ? 
+ std::min(last_full_q_token.size(), last_q_token.size()) : + (offset_it->second - i + 1); + + for(size_t j = 0; j < token_len; j++) { + highlightedss << value[i + j]; + } + + highlightedss << highlight_end_tag; + offset_it++; + i += token_len; + continue; + } } - if(token_query_pos.count(i) != 0) { - size_t query_token_len = facet_query_tokens[token_query_pos[i]].size(); - // handle query token being larger than actual token (typo correction) - query_token_len = std::min(query_token_len, tokens[i].size()); - const std::string & unmarked = tokens[i].substr(query_token_len, std::string::npos); - highlightedss << highlight_start_tag << - tokens[i].substr(0, query_token_len) << - highlight_end_tag << unmarked; - } else { - highlightedss << tokens[i]; - } + highlightedss << value[i]; + i++; } facet_value_t facet_value = {value, highlightedss.str(), facet_count.count}; @@ -1414,7 +1459,9 @@ bool Collection::facet_value_to_string(const facet &a_facet, const facet_count_t } else if(facet_schema.at(a_facet.field_name).type == field_types::FLOAT) { float raw_val = document[a_facet.field_name].get(); value = StringUtils::float_to_str(raw_val); - value.erase ( value.find_last_not_of('0') + 1, std::string::npos ); // remove trailing zeros + if(value != "0") { + value.erase ( value.find_last_not_of('0') + 1, std::string::npos ); // remove trailing zeros + } } else if(facet_schema.at(a_facet.field_name).type == field_types::FLOAT_ARRAY) { float raw_val = document[a_facet.field_name][facet_count.array_pos].get(); value = StringUtils::float_to_str(raw_val); diff --git a/src/field.cpp b/src/field.cpp index 662d236e..4d4e0930 100644 --- a/src/field.cpp +++ b/src/field.cpp @@ -292,12 +292,6 @@ Option filter::parse_filter_query(const string& simple_filter_query, NUM_COMPARATOR str_comparator = CONTAINS; if(raw_value[0] == '=') { - if(!_field.facet) { - // EQUALS filtering on string is possible only on facet fields - return Option(400, "To perform exact filtering, filter field `" + - 
_field.name + "` must be a facet field."); - } - // string filter should be evaluated in strict "equals" mode str_comparator = EQUALS; while(++filter_value_index < raw_value.size() && raw_value[filter_value_index] == ' '); diff --git a/src/index.cpp b/src/index.cpp index 625cb289..4c6c6f00 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -294,15 +294,9 @@ Option Index::index_in_memory(const index_record& record, uint32_t seq art_tree *t = search_index.at(field_pair.second.faceted_name()); - if(field_pair.second.is_array()) { - index_strings_field(points, t, seq_id, is_facet, field_pair.second, - field_index_it->second.offsets, - field_index_it->second.facet_hashes); - } else { - index_strings_field(points, t, seq_id, is_facet, field_pair.second, - field_index_it->second.offsets, - field_index_it->second.facet_hashes); - } + index_strings_field(points, t, seq_id, is_facet, field_pair.second, + field_index_it->second.offsets, + field_index_it->second.facet_hashes); } if(field_pair.second.is_string()) { @@ -762,11 +756,6 @@ void Index::tokenize_string_with_facets(const std::string& text, bool is_facet, continue; } - if(is_facet) { - uint64_t hash = Index::facet_token_hash(a_field, token); - facet_hashes.push_back(hash); - } - token_to_offsets[token].push_back(token_index + 1); last_token = token; } @@ -775,6 +764,11 @@ void Index::tokenize_string_with_facets(const std::string& text, bool is_facet, // push 0 for the last occurring token (used for exact match ranking) token_to_offsets[last_token].push_back(0); } + + if(is_facet) { + uint64_t hash = Index::facet_token_hash(a_field, text); + facet_hashes.push_back(hash); + } } void Index::index_strings_field(const int64_t score, art_tree *t, @@ -824,12 +818,6 @@ void Index::tokenize_string_array_with_facets(const std::vector& st continue; } - if(is_facet) { - uint64_t hash = facet_token_hash(a_field, token); - facet_hashes.push_back(hash); - //LOG(INFO) << "indexing " << token << ", hash:" << hash; - } - 
token_to_offsets[token].push_back(token_index + 1); token_set.insert(token); last_token = token; @@ -842,7 +830,9 @@ void Index::tokenize_string_array_with_facets(const std::vector& st } if(is_facet) { - facet_hashes.push_back(FACET_ARRAY_DELIMETER); // as a delimiter + uint64_t hash = facet_token_hash(a_field, str); + //LOG(INFO) << "indexing " << token << ", hash:" << hash; + facet_hashes.push_back(hash); } for(auto& the_token: token_set) { @@ -893,84 +883,16 @@ void Index::compute_facet_stats(facet &a_facet, uint64_t raw_value, const std::s } void Index::do_facets(std::vector & facets, facet_query_t & facet_query, + const std::vector& facet_infos, const size_t group_limit, const std::vector& group_by_fields, const uint32_t* result_ids, size_t results_size) const { - - struct facet_info_t { - // facet hash => token position in the query - std::unordered_map fhash_qtoken_pos; - - bool use_facet_query = false; - bool should_compute_stats = false; - field facet_field{"", "", false}; - }; - - std::vector facet_infos(facets.size()); - - for(size_t findex=0; findex < facets.size(); findex++) { - const auto& a_facet = facets[findex]; - - facet_infos[findex].use_facet_query = false; - - const field &facet_field = facet_schema.at(a_facet.field_name); - facet_infos[findex].facet_field = facet_field; - - facet_infos[findex].should_compute_stats = (facet_field.type != field_types::STRING && - facet_field.type != field_types::BOOL && - facet_field.type != field_types::STRING_ARRAY && - facet_field.type != field_types::BOOL_ARRAY); - - if(a_facet.field_name == facet_query.field_name && !facet_query.query.empty()) { - facet_infos[findex].use_facet_query = true; - - if (facet_field.is_bool()) { - if (facet_query.query == "true") { - facet_query.query = "1"; - } else if (facet_query.query == "false") { - facet_query.query = "0"; - } - } - - // for non-string fields, `faceted_name` returns their aliased stringified field name - art_tree *t = 
search_index.at(facet_field.faceted_name()); - - std::vector query_tokens; - Tokenizer(facet_query.query, true, !facet_field.is_string()).tokenize(query_tokens); - - for (size_t qtoken_index = 0; qtoken_index < query_tokens.size(); qtoken_index++) { - auto &q = query_tokens[qtoken_index]; - - int bounded_cost = (q.size() < 3) ? 0 : 1; - bool prefix_search = (qtoken_index == - (query_tokens.size() - 1)); // only last token must be used as prefix - - std::vector leaves; - - const size_t q_len = prefix_search ? q.length() : q.length() + 1; - art_fuzzy_search(t, (const unsigned char *) q.c_str(), - q_len, 0, bounded_cost, 10000, - token_ordering::MAX_SCORE, prefix_search, nullptr, 0, leaves); - - for (size_t leaf_index = 0; leaf_index < leaves.size(); leaf_index++) { - const auto &leaf = leaves[leaf_index]; - // calculate hash without terminating null char - std::string key_str((const char *) leaf->key, leaf->key_len - 1); - uint64_t hash = facet_token_hash(facet_field, key_str); - - token_pos_cost_t token_pos_cost = {qtoken_index, 0}; - facet_infos[findex].fhash_qtoken_pos.emplace(hash, token_pos_cost); - //printf("%.*s - %llu\n", leaf->key_len, leaf->key, hash); - } - } - } - } - + // assumed that facet fields have already been validated upstream for(size_t findex=0; findex < facets.size(); findex++) { auto& a_facet = facets[findex]; const auto& facet_field = facet_infos[findex].facet_field; const bool use_facet_query = facet_infos[findex].use_facet_query; - const auto& fhash_qtoken_pos = facet_infos[findex].fhash_qtoken_pos; + const auto& fquery_hashes = facet_infos[findex].hashes; const bool should_compute_stats = facet_infos[findex].should_compute_stats; const auto& field_facet_mapping_it = facet_index_v3.find(a_facet.field_name); @@ -988,91 +910,38 @@ void Index::do_facets(std::vector & facets, facet_query_t & facet_query, continue; } - // FORMAT OF VALUES - // String: h1 h2 h3 - // String array: h1 h2 h3 0 h1 0 h1 h2 0 const auto& facet_hashes = 
facet_hashes_it->second; - const uint64_t distinct_id = group_limit ? get_distinct_id(group_by_fields, doc_seq_id) : 0; - int array_pos = 0; - bool fvalue_found = false; - uint64_t combined_hash = 1; // for hashing the entire facet value (multiple tokens) - - std::unordered_map query_token_positions; - size_t field_token_index = -1; - auto fhashes = facet_hashes.hashes; - for(size_t j = 0; j < facet_hashes.size(); j++) { - if(fhashes[j] != FACET_ARRAY_DELIMETER) { - uint64_t ftoken_hash = fhashes[j]; - field_token_index++; + auto fhash = facet_hashes.hashes[j]; - // reference: https://stackoverflow.com/a/4182771/131050 - // we also include token index to maintain orderliness - combined_hash *= (1779033703 + 2*ftoken_hash*(field_token_index+1)); - - // ftoken_hash is the raw value for numeric fields - if(should_compute_stats) { - compute_facet_stats(a_facet, ftoken_hash, facet_field.type); - } - - const auto fhash_qtoken_pos_it = fhash_qtoken_pos.find(ftoken_hash); - - // not using facet query or this particular facet value is found in facet filter - if(!use_facet_query || fhash_qtoken_pos_it != fhash_qtoken_pos.end()) { - fvalue_found = true; - - if(use_facet_query) { - // map token index to query index (used for highlighting later on) - const token_pos_cost_t& qtoken_pos = fhash_qtoken_pos_it->second; - - // if the query token has already matched another token in the string - // we will replace the position only if the cost is lower - if(query_token_positions.find(qtoken_pos.pos) == query_token_positions.end() || - query_token_positions[qtoken_pos.pos].cost >= qtoken_pos.cost ) { - token_pos_cost_t ftoken_pos_cost = {field_token_index, qtoken_pos.cost}; - query_token_positions[qtoken_pos.pos] = ftoken_pos_cost; - } - } - } + if(should_compute_stats) { + compute_facet_stats(a_facet, fhash, facet_field.type); } - // 0 indicates separator, while the second condition checks for non-array string - if(fhashes[j] == FACET_ARRAY_DELIMETER || (facet_hashes.back() != 
FACET_ARRAY_DELIMETER && j == facet_hashes.size() - 1)) { - if(!use_facet_query || fvalue_found) { - uint64_t fhash = combined_hash; - - if(a_facet.result_map.count(fhash) == 0) { - a_facet.result_map.emplace(fhash, facet_count_t{0, spp::sparse_hash_set(), - doc_seq_id, 0, - std::unordered_map()}); - } - - facet_count_t& facet_count = a_facet.result_map[fhash]; - - /*LOG(INFO) << "field: " << a_facet.field_name << ", doc id: " << doc_seq_id - << ", hash: " << fhash;*/ - - facet_count.doc_id = doc_seq_id; - facet_count.array_pos = array_pos; - - if(group_limit) { - facet_count.groups.emplace(distinct_id); - } else { - facet_count.count += 1; - } - - if(use_facet_query) { - facet_count.query_token_pos = query_token_positions; - } + if(!use_facet_query || fquery_hashes.find(fhash) != fquery_hashes.end()) { + if(a_facet.result_map.count(fhash) == 0) { + a_facet.result_map.emplace(fhash, facet_count_t{0, spp::sparse_hash_set(), + doc_seq_id, 0, {}}); } - array_pos++; - fvalue_found = false; - combined_hash = 1; - std::unordered_map().swap(query_token_positions); - field_token_index = -1; + facet_count_t& facet_count = a_facet.result_map[fhash]; + + //LOG(INFO) << "field: " << a_facet.field_name << ", doc id: " << doc_seq_id << ", hash: " << fhash; + + facet_count.doc_id = doc_seq_id; + facet_count.array_pos = j; + + if(group_limit) { + facet_count.groups.emplace(distinct_id); + } else { + facet_count.count += 1; + } + + if(use_facet_query) { + facet_count.tokens = fquery_hashes.at(fhash); + } } } } @@ -1095,7 +964,7 @@ void Index::aggregate_topster(Topster* agg_topster, Topster* index_topster) { } void Index::search_candidates(const uint8_t & field_id, bool field_is_array, - uint32_t* filter_ids, size_t filter_ids_length, + const uint32_t* filter_ids, size_t filter_ids_length, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, const std::vector& curated_ids, const std::vector & sort_fields, @@ -1531,47 +1400,8 @@ void Index::do_filtering(uint32_t*& 
filter_ids, uint32_t& filter_ids_length, uint32_t* exact_strt_ids = new uint32_t[strt_ids_size]; size_t exact_strt_size = 0; - for(size_t strt_ids_index = 0; strt_ids_index < strt_ids_size; strt_ids_index++) { - uint32_t seq_id = strt_ids[strt_ids_index]; - const auto& fvalues = facet_index_v3.at(f.name)->at(seq_id); - bool found_filter = false; - - if(!f.is_array()) { - found_filter = (posting_lists.size() == fvalues.length); - } else { - uint64_t filter_hash = 1; - - for(size_t sindex=0; sindex < str_tokens.size(); sindex++) { - auto& this_str_token = str_tokens[sindex]; - uint64_t thash = facet_token_hash(f, this_str_token); - filter_hash *= (1779033703 + 2*thash*(sindex+1)); - } - - uint64_t all_fvalue_hash = 1; - size_t ftindex = 0; - - for(size_t findex=0; findex < fvalues.size(); findex++) { - auto fhash = fvalues.hashes[findex]; - if(fhash == FACET_ARRAY_DELIMETER) { - // end of array, check hash - if(all_fvalue_hash == filter_hash) { - found_filter = true; - break; - } - all_fvalue_hash = 1; - ftindex = 0; - } else { - all_fvalue_hash *= (1779033703 + 2*fhash*(ftindex + 1)); - ftindex++; - } - } - } - - if(found_filter) { - exact_strt_ids[exact_strt_size] = seq_id; - exact_strt_size++; - } - } + posting_t::get_exact_matches(posting_lists, f.is_array(), strt_ids, strt_ids_size, + exact_strt_ids, exact_strt_size); delete[] strt_ids; strt_ids = exact_strt_ids; @@ -2000,8 +1830,14 @@ bool Index::check_for_overrides(const token_ordering& token_order, const string& std::set query_hashes; size_t num_toks_dropped = 0; - search_field(0, window_tokens, search_tokens, nullptr, 0, num_toks_dropped, field_name, - nullptr, 0, {}, facets, {}, 2, searched_queries, topster, groups_processed, + + auto field_it = search_schema.find(field_name); + if(field_it == search_schema.end()) { + continue; + } + + search_field(0, window_tokens, search_tokens, nullptr, 0, num_toks_dropped, field_it->second, field_name, + nullptr, 0, {}, {}, 2, searched_queries, topster, 
groups_processed, &result_ids, result_ids_len, field_num_results, 0, group_by_fields, false, 4, query_hashes, token_order, false, 0, 1, false, 3, 7); @@ -2175,6 +2011,11 @@ void Index::search(std::vector& field_query_tokens, const uint8_t field_id = (uint8_t)(FIELD_LIMIT_NUM - i); const std::string& field_name = search_fields[i].name; + auto field_it = search_schema.find(field_name); + if(field_it == search_schema.end()) { + continue; + } + std::vector query_tokens = q_include_pos_tokens; std::vector search_tokens = q_include_pos_tokens; size_t num_tokens_dropped = 0; @@ -2190,8 +2031,9 @@ void Index::search(std::vector& field_query_tokens, size_t field_num_results = 0; std::set query_hashes; - search_field(field_id, query_tokens, search_tokens, exclude_token_ids, exclude_token_ids_size, num_tokens_dropped, - field_name, filter_ids, filter_ids_length, curated_ids_sorted, facets, sort_fields_std, + search_field(field_id, query_tokens, search_tokens, exclude_token_ids, exclude_token_ids_size, + num_tokens_dropped, field_it->second, field_name, + filter_ids, filter_ids_length, curated_ids_sorted, sort_fields_std, field_num_typos, searched_queries, actual_topster, groups_processed, &all_result_ids, all_result_ids_len, field_num_results, group_limit, group_by_fields, prioritize_exact_match, concurrency, query_hashes, token_order, field_prefix, @@ -2224,7 +2066,7 @@ void Index::search(std::vector& field_query_tokens, all_result_ids, all_result_ids_len, filter_ids, filter_ids_length); } else { search_field(field_id, query_tokens, search_tokens, exclude_token_ids, exclude_token_ids_size, num_tokens_dropped, - field_name, filter_ids, filter_ids_length, curated_ids_sorted, facets, sort_fields_std, + field_it->second, field_name, filter_ids, filter_ids_length, curated_ids_sorted, sort_fields_std, field_num_typos, searched_queries, actual_topster, groups_processed, &all_result_ids, all_result_ids_len, field_num_results, group_limit, group_by_fields, prioritize_exact_match, 
concurrency, query_hashes, token_order, field_prefix, @@ -2437,6 +2279,10 @@ void Index::search(std::vector& field_query_tokens, std::mutex m_process; std::condition_variable cv_process; + std::vector facet_infos(facets.size()); + compute_facet_infos(facets, facet_query, all_result_ids, all_result_ids_len, + group_by_fields, facet_infos); + std::vector> facet_batches(num_threads); for(size_t i = 0; i < num_threads; i++) { for(const auto& this_facet: facets) { @@ -2447,6 +2293,8 @@ void Index::search(std::vector& field_query_tokens, size_t num_queued = 0; size_t result_index = 0; + //auto beginF = std::chrono::high_resolution_clock::now(); + for(size_t thread_id = 0; thread_id < num_threads && result_index < all_result_ids_len; thread_id++) { size_t batch_res_len = window_size; @@ -2458,9 +2306,10 @@ void Index::search(std::vector& field_query_tokens, num_queued++; thread_pool->enqueue([this, thread_id, &facet_batches, &facet_query, group_limit, group_by_fields, - batch_result_ids, batch_res_len, &num_processed, &m_process, &cv_process]() { + batch_result_ids, batch_res_len, &facet_infos, + &num_processed, &m_process, &cv_process]() { auto fq = facet_query; - do_facets(facet_batches[thread_id], fq, group_limit, group_by_fields, + do_facets(facet_batches[thread_id], fq, facet_infos, group_limit, group_by_fields, batch_result_ids, batch_res_len); std::unique_lock lock(m_process); num_processed++; @@ -2497,7 +2346,7 @@ void Index::search(std::vector& field_query_tokens, acc_facet.result_map[facet_kv.first].doc_id = facet_kv.second.doc_id; acc_facet.result_map[facet_kv.first].array_pos = facet_kv.second.array_pos; - acc_facet.result_map[facet_kv.first].query_token_pos = facet_kv.second.query_token_pos; + acc_facet.result_map[facet_kv.first].tokens = facet_kv.second.tokens; } if(this_facet.stats.fvcount != 0) { @@ -2508,9 +2357,15 @@ void Index::search(std::vector& field_query_tokens, } } } + + /*long long int timeMillisF = std::chrono::duration_cast( + 
std::chrono::high_resolution_clock::now() - beginF).count(); + LOG(INFO) << "Time for faceting: " << timeMillisF;*/ } - do_facets(facets, facet_query, group_limit, group_by_fields, &included_ids[0], included_ids.size()); + std::vector facet_infos(facets.size()); + compute_facet_infos(facets, facet_query, &included_ids[0], included_ids.size(), group_by_fields, facet_infos); + do_facets(facets, facet_query, facet_infos, group_limit, group_by_fields, &included_ids[0], included_ids.size()); all_result_ids_len += curated_topster->size; @@ -2526,6 +2381,141 @@ void Index::search(std::vector& field_query_tokens, //LOG(INFO) << "Time taken for result calc: " << timeMillis << "ms"; } +void Index::compute_facet_infos(const std::vector& facets, facet_query_t& facet_query, + const uint32_t* all_result_ids, const size_t& all_result_ids_len, + const std::vector& group_by_fields, + std::vector& facet_infos) const { + + if(all_result_ids_len == 0) { + return; + } + + for(size_t findex=0; findex < facets.size(); findex++) { + const auto& a_facet = facets[findex]; + + const auto field_facet_mapping_it = facet_index_v3.find(a_facet.field_name); + if(field_facet_mapping_it == facet_index_v3.end()) { + continue; + } + + facet_infos[findex].use_facet_query = false; + + const field &facet_field = facet_schema.at(a_facet.field_name); + facet_infos[findex].facet_field = facet_field; + + facet_infos[findex].should_compute_stats = (facet_field.type != field_types::STRING && + facet_field.type != field_types::BOOL && + facet_field.type != field_types::STRING_ARRAY && + facet_field.type != field_types::BOOL_ARRAY); + + if(a_facet.field_name == facet_query.field_name && !facet_query.query.empty()) { + facet_infos[findex].use_facet_query = true; + + if (facet_field.is_bool()) { + if (facet_query.query == "true") { + facet_query.query = "1"; + } else if (facet_query.query == "false") { + facet_query.query = "0"; + } + } + + //LOG(INFO) << "facet_query.query: " << facet_query.query; + + 
std::vector query_tokens; + Tokenizer(facet_query.query, true, !facet_field.is_string()).tokenize(query_tokens); + + std::vector search_tokens, qtokens; + + for (size_t qtoken_index = 0; qtoken_index < query_tokens.size(); qtoken_index++) { + search_tokens.emplace_back(token_t{qtoken_index, query_tokens[qtoken_index]}); + qtokens.emplace_back(token_t{qtoken_index, query_tokens[qtoken_index]}); + } + + std::vector> searched_queries; + Topster* topster = nullptr; + spp::sparse_hash_set groups_processed; + uint32_t* field_result_ids = nullptr; + size_t field_result_ids_len = 0; + size_t field_num_results = 0; + std::set query_hashes; + size_t num_toks_dropped = 0; + + search_field(0, qtokens, search_tokens, nullptr, 0, num_toks_dropped, + facet_field, facet_field.faceted_name(), + all_result_ids, all_result_ids_len, {}, {}, 2, searched_queries, topster, groups_processed, + &field_result_ids, field_result_ids_len, field_num_results, 0, group_by_fields, + false, 4, query_hashes, MAX_SCORE, true, 0, 1, false, 3, 1000); + + //LOG(INFO) << "searched_queries.size: " << searched_queries.size(); + + // NOTE: `field_result_ids` will consists of IDs across ALL queries in searched_queries + + for(size_t si = 0; si < searched_queries.size(); si++) { + const auto& searched_query = searched_queries[si]; + std::vector searched_tokens; + + std::vector posting_lists; + for(auto leaf: searched_query) { + posting_lists.push_back(leaf->values); + std::string tok(reinterpret_cast(leaf->key), leaf->key_len - 1); + searched_tokens.push_back(tok); + //LOG(INFO) << "tok: " << tok; + } + + //LOG(INFO) << "si: " << si << ", field_result_ids_len: " << field_result_ids_len; + + for(size_t i = 0; i < std::min(1000, field_result_ids_len); i++) { + uint32_t seq_id = field_result_ids[i]; + + const auto doc_fvalues_it = field_facet_mapping_it->second->find(seq_id); + if(doc_fvalues_it == field_facet_mapping_it->second->end()) { + continue; + } + + bool id_matched = true; + + for(auto pl: 
posting_lists) { + if(!posting_t::contains(pl, seq_id)) { + // need to ensure that document ID actually contains both searched_query tokens + id_matched = false; + break; + } + } + + if(!id_matched) { + continue; + } + + if(facet_field.is_array()) { + std::vector array_indices; + posting_t::get_matching_array_indices(posting_lists, seq_id, array_indices); + + for(size_t array_index: array_indices) { + if(array_index < doc_fvalues_it->second.length) { + uint64_t hash = doc_fvalues_it->second.hashes[array_index]; + + /*LOG(INFO) << "seq_id: " << seq_id << ", hash: " << hash << ", array index: " + << array_index;*/ + + if(facet_infos[findex].hashes.count(hash) == 0) { + facet_infos[findex].hashes.emplace(hash, searched_tokens); + } + } + } + } else { + uint64_t hash = doc_fvalues_it->second.hashes[0]; + if(facet_infos[findex].hashes.count(hash) == 0) { + facet_infos[findex].hashes.emplace(hash, searched_tokens); + } + } + } + } + + delete [] field_result_ids; + } + } +} + void Index::curate_filtered_ids(const std::vector& filters, const std::set& curated_ids, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, uint32_t*& filter_ids, uint32_t& filter_ids_length, @@ -2642,10 +2632,10 @@ void Index::search_field(const uint8_t & field_id, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, size_t& num_tokens_dropped, - const std::string & field, - uint32_t *filter_ids, size_t filter_ids_length, + const field& the_field, const std::string& field_name, // to handle faceted index + const uint32_t *filter_ids, size_t filter_ids_length, const std::vector& curated_ids, - std::vector & facets, const std::vector & sort_fields, const int num_typos, + const std::vector & sort_fields, const int num_typos, std::vector> & searched_queries, Topster* topster, spp::sparse_hash_set& groups_processed, uint32_t** all_result_ids, size_t & all_result_ids_len, size_t& field_num_results, @@ -2663,13 +2653,6 @@ void Index::search_field(const uint8_t & field_id, // 
NOTE: `query_tokens` preserve original tokens, while `search_tokens` could be a result of dropped tokens size_t max_cost = (num_typos < 0 || num_typos > 2) ? 2 : num_typos; - auto field_it = search_schema.find(field); - - if(field_it == search_schema.end()) { - return; - } - - auto& the_field = field_it->second; if(the_field.locale != "" && the_field.locale != "en") { // disable fuzzy trie traversal for non-english locales @@ -2739,7 +2722,7 @@ void Index::search_field(const uint8_t & field_id, const size_t token_len = prefix_search ? (int) token.length() : (int) token.length() + 1; // need less candidates for filtered searches since we already only pick tokens with results - art_fuzzy_search(search_index.at(field), (const unsigned char *) token.c_str(), token_len, + art_fuzzy_search(search_index.at(field_name), (const unsigned char *) token.c_str(), token_len, costs[token_index], costs[token_index], num_fuzzy_candidates, token_order, prefix_search, filter_ids, filter_ids_length, leaves, unique_tokens); @@ -2834,7 +2817,7 @@ void Index::search_field(const uint8_t & field_id, } return search_field(field_id, query_tokens, truncated_tokens, exclude_token_ids, exclude_token_ids_size, - num_tokens_dropped, field, filter_ids, filter_ids_length, curated_ids,facets, + num_tokens_dropped, the_field, field_name, filter_ids, filter_ids_length, curated_ids, sort_fields, num_typos,searched_queries, topster, groups_processed, all_result_ids, all_result_ids_len, field_num_results, group_limit, group_by_fields, prioritize_exact_match, concurrency, query_hashes, @@ -2884,10 +2867,6 @@ void Index::score_results(const std::vector & sort_fields, const uint16 bool single_exact_query_token, const std::vector& posting_lists) const { - spp::sparse_hash_map* TEXT_MATCH_SENTINEL = &text_match_sentinel_value; - spp::sparse_hash_map* SEQ_ID_SENTINEL = &seq_id_sentinel_value; - spp::sparse_hash_map* GEO_SENTINEL = &geo_sentinel_value; - int64_t geopoint_distances[3]; for(auto& i: 
geopoint_indices) { @@ -2937,7 +2916,7 @@ void Index::score_results(const std::vector & sort_fields, const uint16 geopoint_distances[i] = dist; // Swap (id -> latlong) index to (id -> distance) index - field_values[i] = GEO_SENTINEL; + field_values[i] = &geo_sentinel_value; } //auto begin = std::chrono::high_resolution_clock::now(); @@ -3001,12 +2980,12 @@ void Index::score_results(const std::vector & sort_fields, const uint16 // avoiding loop if (sort_fields.size() > 0) { - if (field_values[0] == TEXT_MATCH_SENTINEL) { + if (field_values[0] == &text_match_sentinel_value) { scores[0] = int64_t(match_score); match_score_index = 0; - } else if (field_values[0] == SEQ_ID_SENTINEL) { + } else if (field_values[0] == &seq_id_sentinel_value) { scores[0] = seq_id; - } else if(field_values[0] == GEO_SENTINEL) { + } else if(field_values[0] == &geo_sentinel_value) { scores[0] = geopoint_distances[0]; } else { auto it = field_values[0]->find(seq_id); @@ -3019,12 +2998,12 @@ void Index::score_results(const std::vector & sort_fields, const uint16 } if(sort_fields.size() > 1) { - if (field_values[1] == TEXT_MATCH_SENTINEL) { + if (field_values[1] == &text_match_sentinel_value) { scores[1] = int64_t(match_score); match_score_index = 1; - } else if (field_values[1] == SEQ_ID_SENTINEL) { + } else if (field_values[1] == &seq_id_sentinel_value) { scores[1] = seq_id; - } else if(field_values[1] == GEO_SENTINEL) { + } else if(field_values[1] == &geo_sentinel_value) { scores[1] = geopoint_distances[1]; } else { auto it = field_values[1]->find(seq_id); @@ -3037,12 +3016,12 @@ void Index::score_results(const std::vector & sort_fields, const uint16 } if(sort_fields.size() > 2) { - if (field_values[2] == TEXT_MATCH_SENTINEL) { + if (field_values[2] == &text_match_sentinel_value) { scores[2] = int64_t(match_score); match_score_index = 2; - } else if (field_values[2] == SEQ_ID_SENTINEL) { + } else if (field_values[2] == &seq_id_sentinel_value) { scores[2] = seq_id; - } else if(field_values[2] 
== GEO_SENTINEL) { + } else if(field_values[2] == &geo_sentinel_value) { scores[2] = geopoint_distances[2]; } else { auto it = field_values[2]->find(seq_id); diff --git a/src/posting.cpp b/src/posting.cpp index a0270dc7..587e6d34 100644 --- a/src/posting.cpp +++ b/src/posting.cpp @@ -447,6 +447,53 @@ void posting_t::get_array_token_positions(uint32_t id, const std::vector& } }
+// Resolves which of `ids` match the tokens behind `raw_posting_lists` verbatim.
+// Expands the raw lists with `to_expanded_plists`, creates one iterator per list in `plists` and passes
+// them to `posting_list_t::get_exact_matches`, which fills `exact_ids` and `num_exact_ids`.
+// Every list placed in `expanded_plists` is deleted before returning.
+void posting_t::get_exact_matches(const std::vector& raw_posting_lists, const bool field_is_array,
+                                  const uint32_t* ids, const uint32_t num_ids,
+                                  uint32_t*& exact_ids, size_t& num_exact_ids) {
+
+    std::vector plists;
+    std::vector expanded_plists;
+    to_expanded_plists(raw_posting_lists, plists, expanded_plists);
+
+    std::vector its;
+
+    for(posting_list_t* pl: plists) {
+        its.push_back(pl->new_iterator());
+    }
+
+    posting_list_t::get_exact_matches(its, field_is_array, ids, num_ids, exact_ids, num_exact_ids);
+
+    for(posting_list_t* expanded_plist: expanded_plists) {
+        delete expanded_plist;
+    }
+}
+
+// Appends to `indices` the array positions of document `id` that contain all tokens behind
+// `raw_posting_lists`. Uses the same expand / iterate / delegate / delete steps as
+// `posting_t::get_exact_matches`, delegating to `posting_list_t::get_matching_array_indices`.
+void posting_t::get_matching_array_indices(const std::vector& raw_posting_lists,
+                                           uint32_t id, std::vector& indices) {
+    std::vector plists;
+    std::vector expanded_plists;
+    to_expanded_plists(raw_posting_lists, plists, expanded_plists);
+
+    std::vector its;
+
+    for(posting_list_t* pl: plists) {
+        its.push_back(pl->new_iterator());
+    }
+
+    posting_list_t::get_matching_array_indices(id, its, indices);
+
+    for(posting_list_t* expanded_plist: expanded_plists) {
+        delete expanded_plist;
+    }
+}
+
 void posting_t::block_intersector_t::split_lists(size_t concurrency, std::vector>& partial_its_vec) { const size_t num_blocks = this->plists[0]->num_blocks(); diff --git a/src/posting_list.cpp b/src/posting_list.cpp index bd048cbf..e202d689 100644 --- a/src/posting_list.cpp +++ b/src/posting_list.cpp @@ -1,4 +1,5 @@ #include "posting_list.h" +#include #include "for.h" #include "array_utils.h" @@ -977,6 +978,264 @@ bool posting_list_t::contains_atleast_one(const uint32_t* target_ids, size_t tar return false; }
+// Filters `ids` down to the ones whose field value is a verbatim match for the query tokens.
+// `its` carries one iterator per query token, in query order: token j has to sit at offset j + 1 and the
+// last query token has to be flagged (trailing 0 offset) as the final token of the value.
+// Matching ids are written to `exact_ids`, which the caller must have sized for up to `num_ids` entries;
+// their count is stored in `num_exact_ids`.
+void posting_list_t::get_exact_matches(std::vector& its, const bool field_is_array,
+                                       const uint32_t* ids, const uint32_t num_ids,
+                                       uint32_t*& exact_ids, size_t& num_exact_ids) {
+
+    size_t exact_id_index = 0;
+
+    if(its.empty()) {
+        // without query tokens nothing can match verbatim (the non-array path would accept every id)
+        num_exact_ids = 0;
+        return;
+    }
+
+    if(its.size() == 1) {
+        posting_list_t::iterator_t& it = its[0];
+
+        for(size_t i = 0; i < num_ids; i++) {
+            uint32_t id = ids[i];
+
+            // position the iterator on `id` first, as the multi-token paths do: otherwise every id is
+            // judged against whichever document the iterator was left on
+            it.skip_to(id);
+
+            if(it.block() == nullptr || it.index() == UINT32_MAX) {
+                // same "nothing to read" check the multi-token paths apply after skip_to()
+                continue;
+            }
+
+            if(is_single_token_verbatim_match(it, field_is_array)) {
+                exact_ids[exact_id_index++] = id;
+            }
+        }
+    } else {
+
+        if(!field_is_array) {
+            for(size_t i = 0; i < num_ids; i++) {
+                uint32_t id = ids[i];
+                bool is_exact_match = true;
+
+                for(int j = its.size()-1; j >= 0; j--) {
+                    posting_list_t::iterator_t& it = its[j];
+                    it.skip_to(id);
+
+                    block_t* curr_block = it.block();
+                    uint32_t curr_index = it.index();
+
+                    if(curr_block == nullptr || curr_index == UINT32_MAX) {
+                        is_exact_match = false;
+                        break;
+                    }
+
+                    uint32_t* offsets = it.offsets;
+
+                    uint32_t start_offset_index = it.offset_index[curr_index];
+                    uint32_t end_offset_index = (curr_index == curr_block->size() - 1) ?
+                                                curr_block->offsets.getLength() :
+                                                it.offset_index[curr_index + 1];
+
+                    if(j == its.size()-1) {
+                        // check if the last query token is the last offset
+                        if(offsets[end_offset_index-1] != 0) {
+                            // not the last token for the document, so skip
+                            is_exact_match = false;
+                            break;
+                        }
+                    }
+
+                    // looping handles duplicate query tokens, e.g. "hip hip hurray hurray"
+                    // start from "not found": running out of offsets without hitting j + 1 is a mismatch
+                    is_exact_match = false;
+                    while(start_offset_index < end_offset_index) {
+                        uint32_t offset = offsets[start_offset_index];
+                        start_offset_index++;  // always advance, otherwise a smaller offset loops forever
+
+                        if(offset == (j + 1)) {
+                            // we have found a matching index, no need to look further
+                            is_exact_match = true;
+                            break;
+                        }
+
+                        if(offset > (j + 1)) {
+                            is_exact_match = false;
+                            break;
+                        }
+                    }
+
+                    if(!is_exact_match) {
+                        break;
+                    }
+                }
+
+                if(is_exact_match) {
+                    exact_ids[exact_id_index++] = id;
+                }
+            }
+        }
+
+        else {
+            // field is an array
+
+            for(size_t i = 0; i < num_ids; i++) {
+                uint32_t id = ids[i];
+
+                std::map> array_index_to_token_index;
+                bool premature_exit = false;
+
+                for(int j = its.size()-1; j >= 0; j--) {
+                    posting_list_t::iterator_t& it = its[j];
+
+                    it.skip_to(id);
+
+                    block_t* curr_block = it.block();
+                    uint32_t curr_index = it.index();
+
+                    if(curr_block == nullptr || curr_index == UINT32_MAX) {
+                        premature_exit = true;
+                        break;
+                    }
+
+                    uint32_t* offsets = it.offsets;
+                    uint32_t start_offset_index = it.offset_index[curr_index];
+                    uint32_t end_offset_index = (curr_index == curr_block->size() - 1) ?
+                                                curr_block->offsets.getLength() :
+                                                it.offset_index[curr_index + 1];
+
+                    int prev_pos = -1;
+                    bool has_atleast_one_last_token = false;
+                    // NOTE(review): not reset between array elements, so every element scanned after the
+                    // first hit also gets its bit set below -- confirm that this is intended
+                    bool found_matching_index = false;
+
+                    // an element's offsets end with its last position repeated, followed by the array index
+                    // and, optionally, a 0 that flags the token as a last token
+                    while(start_offset_index < end_offset_index) {
+                        int pos = offsets[start_offset_index];
+                        start_offset_index++;
+
+                        if(pos == prev_pos) {  // indicates end of array index
+                            size_t array_index = (size_t) offsets[start_offset_index];
+
+                            if(start_offset_index+1 < end_offset_index) {
+                                size_t next_offset = (size_t) offsets[start_offset_index + 1];
+                                if(next_offset == 0) {
+                                    // indicates that token is the last token on the doc
+                                    has_atleast_one_last_token = true;
+                                    start_offset_index++;
+                                }
+                            }
+
+                            if(found_matching_index) {
+                                array_index_to_token_index[array_index].set(j+1);
+                            }
+
+                            start_offset_index++;  // skip current value which is the array index or flag for last index
+                            prev_pos = -1;
+                            continue;
+                        }
+
+                        if(pos == (j + 1)) {
+                            // we have found a matching index
+                            found_matching_index = true;
+                        }
+
+                        prev_pos = pos;
+                    }
+
+                    // check if the last query token is the last offset of ANY array element
+                    if(j == its.size()-1 && !has_atleast_one_last_token) {
+                        premature_exit = true;
+                        break;
+                    }
+
+                    if(!found_matching_index) {
+                        // not even a single matching index found: can never be an exact match
+                        premature_exit = true;
+                        break;
+                    }
+                }
+
+                if(!premature_exit) {
+                    // iterate array index to token index to check if atleast 1 array position contains all tokens
+                    for(auto& kv: array_index_to_token_index) {
+                        if(kv.second.count() == its.size()) {
+                            exact_ids[exact_id_index++] = id;
+                            break;
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    num_exact_ids = exact_id_index;
+}
+
+// Appends to `indices` every array position of document `id` in which all tokens of `its` occur.
+// Leaves `indices` untouched when any iterator has nothing to read after `skip_to(id)`.
+// Token positions inside an element are not compared here, only co-occurrence in the element.
+void posting_list_t::get_matching_array_indices(uint32_t id, std::vector& its,
+                                                std::vector& indices) {
+    std::map> array_index_to_token_index;
+
+    for(int j = its.size()-1; j >= 0; j--) {
+        posting_list_t::iterator_t& it = its[j];
+
+        it.skip_to(id);
+
+        block_t* curr_block = it.block();
+        uint32_t curr_index = it.index();
+
+        if(curr_block == nullptr || curr_index == UINT32_MAX) {
+            return;
+        }
+
+        uint32_t* offsets = it.offsets;
+        uint32_t start_offset_index = it.offset_index[curr_index];
+        uint32_t end_offset_index = (curr_index == curr_block->size() - 1) ?
+                                    curr_block->offsets.getLength() :
+                                    it.offset_index[curr_index + 1];
+
+        int prev_pos = -1;
+        while(start_offset_index < end_offset_index) {
+            int pos = offsets[start_offset_index];
+            start_offset_index++;
+
+            if(pos == prev_pos) {  // indicates end of array index
+                size_t array_index = (size_t) offsets[start_offset_index];
+
+                if(start_offset_index+1 < end_offset_index) {
+                    size_t next_offset = (size_t) offsets[start_offset_index + 1];
+                    if(next_offset == 0) {
+                        // indicates that token is the last token on the doc
+                        start_offset_index++;
+                    }
+                }
+
+                array_index_to_token_index[array_index].set(j+1);
+                start_offset_index++;  // skip current value which is the array index or flag for last index
+                prev_pos = -1;
+                continue;
+            }
+
+            prev_pos = pos;
+        }
+    }
+
+    // iterate array index to token index to check if atleast 1 array position contains all tokens
+    for(auto& kv: array_index_to_token_index) {
+        if(kv.second.count() == its.size()) {
+            indices.push_back(kv.first);
+        }
+    }
+}
+
 /* iterator_t operations */ posting_list_t::iterator_t::iterator_t(posting_list_t::block_t* start, posting_list_t::block_t* end): diff --git a/test/collection_faceting_test.cpp b/test/collection_faceting_test.cpp index 7982d041..bbd00c74 100644 --- a/test/collection_faceting_test.cpp +++ b/test/collection_faceting_test.cpp @@ -177,13 +177,14 @@ TEST_F(CollectionFacetingTest, FacetCounts) { results = coll_array_fields->search("*", query_fields, "", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}, Index::DROP_TOKENS_THRESHOLD, spp::sparse_hash_set(), - spp::sparse_hash_set(), 10, "tags: fxne aluminium").get(); + spp::sparse_hash_set(), 10, "tags: fxne platim").get(); ASSERT_EQ(5, results["hits"].size()); ASSERT_EQ(1, results["facet_counts"].size()); ASSERT_STREQ("tags",
results["facet_counts"][0]["field_name"].get().c_str()); ASSERT_EQ(1, (int) results["facet_counts"][0]["counts"][0]["count"]); ASSERT_STREQ("FINE PLATINUM", results["facet_counts"][0]["counts"][0]["value"].get().c_str()); + ASSERT_STREQ("FINE PLATINUM", results["facet_counts"][0]["counts"][0]["highlighted"].get().c_str()); // facet with facet filter query matching first token of an array results = coll_array_fields->search("*", query_fields, "", facets, sort_fields, {0}, 10, 1, FREQUENCY, @@ -218,6 +219,7 @@ TEST_F(CollectionFacetingTest, FacetCounts) { ASSERT_EQ(5, results["hits"].size()); ASSERT_EQ(1, results["facet_counts"].size()); ASSERT_STREQ("age", results["facet_counts"][0]["field_name"].get().c_str()); + ASSERT_EQ(1, (int) results["facet_counts"][0]["counts"][0]["count"]); ASSERT_STREQ("21", results["facet_counts"][0]["counts"][0]["value"].get().c_str()); ASSERT_STREQ("21", results["facet_counts"][0]["counts"][0]["highlighted"].get().c_str()); @@ -238,6 +240,10 @@ TEST_F(CollectionFacetingTest, FacetCounts) { ASSERT_FLOAT_EQ(24.400999426841736, results["facet_counts"][0]["stats"]["sum"].get()); ASSERT_FLOAT_EQ(5, results["facet_counts"][0]["stats"]["total_values"].get()); + // check for "0" case + ASSERT_STREQ("0", results["facet_counts"][0]["counts"][0]["value"].get().c_str()); + ASSERT_EQ(1, results["facet_counts"][0]["counts"][0]["count"].get()); + // facet query on a float field results = coll_array_fields->search("*", query_fields, "", {"rating"}, sort_fields, {0}, 10, 1, FREQUENCY, {false}, Index::DROP_TOKENS_THRESHOLD, @@ -264,7 +270,6 @@ TEST_F(CollectionFacetingTest, FacetCounts) { {false}, Index::DROP_TOKENS_THRESHOLD, spp::sparse_hash_set(), spp::sparse_hash_set(), 10, "timestamps: 142189002").get(); - ASSERT_EQ(5, results["hits"].size()); ASSERT_EQ(1, results["facet_counts"].size()); ASSERT_EQ(1, results["facet_counts"][0]["counts"].size()); @@ -688,8 +693,6 @@ TEST_F(CollectionFacetingTest, FacetCountOnSimilarStrings) { } 
TEST_F(CollectionFacetingTest, FacetQueryOnStringWithColon) { - ; - std::vector fields = {field("title", field_types::STRING, true), field("points", field_types::INT32, false)}; @@ -731,3 +734,82 @@ TEST_F(CollectionFacetingTest, FacetQueryOnStringWithColon) { collectionManager.drop_collection("coll1"); }
+
+TEST_F(CollectionFacetingTest, FacetQueryOnStringArray) {  // exercises facet_query against the string array field `genres`
+    Collection* coll1;
+
+    std::vector fields = {field("title", field_types::STRING, false),
+                                 field("genres", field_types::STRING_ARRAY, true)};
+
+    coll1 = collectionManager.get_collection("coll1").get();
+    if (coll1 == nullptr) {
+        coll1 = collectionManager.create_collection("coll1", 2, fields, "").get();
+    }
+
+    nlohmann::json doc1;
+    doc1["id"] = "0";
+    doc1["title"] = "Song 1";
+    doc1["genres"] = {"Country Punk Rock", "Country", "Slow"};
+
+    nlohmann::json doc2;
+    doc2["id"] = "1";
+    doc2["title"] = "Song 2";
+    doc2["genres"] = {"Soft Rock", "Rock", "Electronic"};
+
+    nlohmann::json doc3;
+    doc3["id"] = "2";
+    doc3["title"] = "Song 3";
+    doc3["genres"] = {"Rockabilly", "Metal"};
+
+    nlohmann::json doc4;
+    doc4["id"] = "3";
+    doc4["title"] = "Song 4";
+    doc4["genres"] = {"Pop Rock", "Rock", "Fast"};
+
+    nlohmann::json doc5;
+    doc5["id"] = "4";
+    doc5["title"] = "Song 5";
+    doc5["genres"] = {"Pop", "Rockabilly", "Fast"};
+
+    ASSERT_TRUE(coll1->add(doc1.dump()).ok());
+    ASSERT_TRUE(coll1->add(doc2.dump()).ok());
+    ASSERT_TRUE(coll1->add(doc3.dump()).ok());
+    ASSERT_TRUE(coll1->add(doc4.dump()).ok());
+    ASSERT_TRUE(coll1->add(doc5.dump()).ok());
+
+    auto results = coll1->search("*", {}, "", {"genres"}, sort_fields, {0}, 0, 1, FREQUENCY,
+                                 {false}, Index::DROP_TOKENS_THRESHOLD,
+                                 spp::sparse_hash_set(),
+                                 spp::sparse_hash_set(), 10, "genres: roc").get();  // single-token facet query
+
+    ASSERT_EQ(1, results["facet_counts"].size());
+    ASSERT_EQ(5, results["facet_counts"][0]["counts"].size());  // 5 distinct values hold a token starting with "roc": Country Punk Rock, Soft Rock, Rock, Rockabilly, Pop Rock
+
+    results = coll1->search("*", {}, "", {"genres"}, sort_fields, {0}, 0, 1, FREQUENCY,
+                            {false}, Index::DROP_TOKENS_THRESHOLD,
+                            spp::sparse_hash_set(),
+                            spp::sparse_hash_set(), 10, "genres: soft roc").get();  // two-token facet query
+
+    ASSERT_EQ(1, results["facet_counts"].size());
+    ASSERT_EQ(1, results["facet_counts"][0]["counts"].size());  // presumably "Soft Rock"; the value itself is not asserted here
+
+    results = coll1->search("*", {}, "", {"genres"}, sort_fields, {0}, 0, 1, FREQUENCY,
+                            {false}, Index::DROP_TOKENS_THRESHOLD,
+                            spp::sparse_hash_set(),
+                            spp::sparse_hash_set(), 10, "genres: punk roc").get();  // tokens taken from the middle and end of one value
+
+    ASSERT_EQ(1, results["facet_counts"].size());
+    ASSERT_EQ(1, results["facet_counts"][0]["counts"].size());
+    ASSERT_EQ("Country Punk Rock", results["facet_counts"][0]["counts"][0]["highlighted"].get());
+
+    results = coll1->search("*", {}, "", {"genres"}, sort_fields, {0}, 0, 1, FREQUENCY,
+                            {false}, Index::DROP_TOKENS_THRESHOLD,
+                            spp::sparse_hash_set(),
+                            spp::sparse_hash_set(), 10, "genres: country roc").get();  // "Country" and "Rock" are not adjacent in the stored value
+
+    ASSERT_EQ(1, results["facet_counts"].size());
+    ASSERT_EQ(1, results["facet_counts"][0]["counts"].size());  // a single value: the standalone "Country" element of the same document is not counted
+    ASSERT_EQ("Country Punk Rock", results["facet_counts"][0]["counts"][0]["highlighted"].get());
+
+    collectionManager.drop_collection("coll1");  // leave no collection behind
+}
diff --git a/test/collection_filtering_test.cpp b/test/collection_filtering_test.cpp index 09ca4c27..01452837 100644 --- a/test/collection_filtering_test.cpp +++ b/test/collection_filtering_test.cpp @@ -260,10 +260,10 @@ TEST_F(CollectionFilteringTest, FacetFieldStringArrayFiltering) { ASSERT_EQ(1, results["hits"].size()); ASSERT_EQ(1, results["found"].get()); - // don't allow exact filter on non-faceted field - auto res_op = coll_array_fields->search("Jeremy", query_fields, "name:= Jeremy Howard", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}); - ASSERT_FALSE(res_op.ok()); - ASSERT_STREQ("To perform exact filtering, filter field `name` must be a facet field.", res_op.error().c_str()); + // allow exact filter on non-faceted field + results = coll_array_fields->search("Jeremy", query_fields, "name:= Jeremy Howard", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get(); + ASSERT_EQ(5,
results["hits"].size()); + ASSERT_EQ(5, results["found"].get()); // multi match exact query (OR condition) results = coll_array_fields->search("Jeremy", query_fields, "tags:= [Gold, bronze]", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get();