diff --git a/include/collection.h b/include/collection.h
index 692e4142..88b8d750 100644
--- a/include/collection.h
+++ b/include/collection.h
@@ -187,9 +187,8 @@ private:
                                   std::vector<field>& new_fields,
                                   bool enable_nested_fields);
 
-    static bool facet_count_compare(const std::pair<uint64_t, facet_count_t>& a,
-                                    const std::pair<uint64_t, facet_count_t>& b) {
-        return std::tie(a.second.count, a.first) > std::tie(b.second.count, b.first);
+    static bool facet_count_compare(const facet_count_t& a, const facet_count_t& b) {
+        return std::tie(a.count, a.fhash) > std::tie(b.count, b.fhash);
     }
 
     static bool facet_count_str_compare(const facet_value_t& a,
@@ -276,10 +275,10 @@ private:
                                   const spp::sparse_hash_set<std::string>& exclude_fields,
                                   tsl::htrie_set<char>& include_fields_full,
                                   tsl::htrie_set<char>& exclude_fields_full) const;
-    
+
+    Option<uint32_t> get_reference_doc_id(const std::string& ref_collection_name, const uint32_t& seq_id) const;
 
-    Option<std::string> get_reference_field(const std::string & collection_name) const;
+    Option<std::string> get_reference_field(const std::string& ref_collection_name) const;
 
     static void hide_credential(nlohmann::json& json, const std::string& credential_name);
 
@@ -377,8 +376,10 @@ public:
     static void remove_flat_fields(nlohmann::json& document);
 
     static Option<bool> prune_doc(nlohmann::json& doc, const tsl::htrie_set<char>& include_names,
-                                  const tsl::htrie_set<char>& exclude_names, const std::string& parent_name = "", size_t depth = 0,
-                                  const std::map<std::string, reference_filter_result_t>& reference_filter_results = {});
+                                  const tsl::htrie_set<char>& exclude_names, const std::string& parent_name = "",
+                                  size_t depth = 0,
+                                  const std::map<std::string, reference_filter_result_t>& reference_filter_results = {},
+                                  Collection *const collection = nullptr, const uint32_t& seq_id = 0);
 
     const Index* _get_index() const;
 
diff --git a/include/field.h b/include/field.h
index 7160762e..a051c7ee 100644
--- a/include/field.h
+++ b/include/field.h
@@ -23,6 +23,7 @@ namespace field_types {
     static const std::string INT64 = "int64";
     static const std::string FLOAT = "float";
     static const std::string BOOL = "bool";
+    static const std::string NIL = "nil";
     static const std::string GEOPOINT = "geopoint";
     static const std::string STRING_ARRAY = "string[]";
     static const std::string INT32_ARRAY = "int32[]";
@@ -434,19 +435,19 @@ struct field {
                                        std::vector<field>& fields_vec);
 
     static bool flatten_obj(nlohmann::json& doc, nlohmann::json& value, bool has_array, bool has_obj_array,
-                            const field& the_field, const std::string& flat_name,
+                            bool is_update, const field& the_field, const std::string& flat_name,
                             const std::unordered_map<std::string, field>& dyn_fields,
                             std::unordered_map<std::string, field>& flattened_fields);
 
     static Option<bool> flatten_field(nlohmann::json& doc, nlohmann::json& obj, const field& the_field,
                                       std::vector<std::string>& path_parts, size_t path_index, bool has_array,
-                                      bool has_obj_array,
+                                      bool has_obj_array, bool is_update,
                                       const std::unordered_map<std::string, field>& dyn_fields,
                                       std::unordered_map<std::string, field>& flattened_fields);
 
     static Option<bool> flatten_doc(nlohmann::json& document,
                                     const tsl::htrie_map<char, field>& nested_fields,
                                     const std::unordered_map<std::string, field>& dyn_fields,
-                                    bool missing_is_ok, std::vector<field>& flattened_fields);
+                                    bool is_update, std::vector<field>& flattened_fields);
 
     static void compact_nested_fields(tsl::htrie_map<char, field>& nested_fields);
 };
@@ -569,6 +570,10 @@ public:
 
 struct facet_count_t {
     uint32_t count = 0;
+    // for value based faceting, actual value is stored here
+    std::string fvalue;
+    // for hash based faceting, hash value is stored here
+    int64_t fhash;
     // used to fetch the actual document and value for representation
     uint32_t doc_id = 0;
     uint32_t array_pos = 0;
@@ -583,9 +588,12 @@ struct facet_stats_t {
 
 struct facet {
     const std::string field_name;
-    spp::sparse_hash_map<std::string, facet_count_t> result_map;
+    spp::sparse_hash_map<uint64_t, facet_count_t> result_map;
+    spp::sparse_hash_map<std::string, facet_count_t> value_result_map;
+
     // used for facet value query
-    spp::sparse_hash_map<std::string, std::vector<std::string>> hash_tokens;
+    spp::sparse_hash_map<std::string, std::vector<std::string>> fvalue_tokens;
+    spp::sparse_hash_map<uint64_t, std::vector<std::string>> hash_tokens;
 
     // used for faceting grouped results
     spp::sparse_hash_map<uint32_t, spp::sparse_hash_set<uint64_t>> hash_groups;
@@ -593,7 +601,7 @@ struct facet {
     facet_stats_t stats;
 
     //dictionary of key=>pair(range_id, range_val)
-    std::map<std::string, std::string> facet_range_map;
+    std::map<int64_t, std::string> facet_range_map;
 
     bool is_range_query;
 
@@ -603,16 +611,14 @@ struct facet {
 
     bool is_intersected = false;
 
-    bool get_range(std::string key, std::pair<std::string, std::string>& range_pair)
-    {
-        if(facet_range_map.empty())
-        {
+    bool get_range(int64_t key, std::pair<int64_t, std::string>& range_pair) {
+        if(facet_range_map.empty()) {
             LOG (ERROR) << "Facet range is not defined!!!";
         }
+
         auto it = facet_range_map.lower_bound(key);
-        if(it != facet_range_map.end())
-        {
+
+        if(it != facet_range_map.end()) {
             range_pair.first = it->first;
             range_pair.second = it->second;
             return true;
@@ -621,19 +627,19 @@ struct facet {
         return false;
     }
 
-    explicit facet(const std::string& field_name,
-                   std::map<std::string, std::string> facet_range = {}, bool is_range_q = false)
-        :field_name(field_name){
-        facet_range_map = facet_range;
-        is_range_query = is_range_q;
+    explicit facet(const std::string& field_name, std::map<int64_t, std::string> facet_range = {},
+                   bool is_range_q = false) :field_name(field_name), facet_range_map(facet_range),
+                   is_range_query(is_range_q) {
     }
 };
 
 struct facet_info_t {
     // facet hash => resolved tokens
-    std::unordered_map<std::string, std::vector<std::string>> hashes;
+    std::unordered_map<uint32_t, std::vector<std::string>> hashes;
+    std::vector<std::string> fvalue_searched_tokens;
     bool use_facet_query = false;
     bool should_compute_stats = false;
+    bool use_value_index = false;
     field facet_field{"", "", false};
 };
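Note on the `get_range()` rewrite above: the `lower_bound` lookup only resolves a value to its bucket if `facet_range_map` is keyed by each range's upper bound, so that the first entry whose key is >= the faceted value is the matching range. A minimal standalone sketch of that lookup (the bucket boundaries and labels here are hypothetical, not taken from this diff):

#include <cstdint>
#include <iostream>
#include <map>
#include <string>

int main() {
    // hypothetical buckets for a numeric facet such as price
    std::map<int64_t, std::string> facet_range_map = {
        {100, "cheap"},     // covers values up to 100
        {500, "mid"},       // covers values up to 500
        {10000, "premium"}  // covers values up to 10000
    };

    int64_t doc_val = 350;

    // same lookup as facet::get_range(): the first key >= doc_val wins
    auto it = facet_range_map.lower_bound(doc_val);
    if(it != facet_range_map.end()) {
        std::cout << doc_val << " => " << it->second << "\n"; // prints: 350 => mid
    }
    return 0;
}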
diff --git a/include/index.h b/include/index.h
index 3c1976dd..bc961ec2 100644
--- a/include/index.h
+++ b/include/index.h
@@ -482,24 +482,8 @@ private:
                                 const size_t num_search_fields,
                                 std::vector<size_t>& popular_field_ids);
 
-    void numeric_not_equals_filter(num_tree_t* const num_tree,
-                                   const int64_t value,
-                                   const uint32_t& context_ids_length,
-                                   uint32_t* const& context_ids,
-                                   size_t& ids_len,
-                                   uint32_t*& ids) const;
-
     bool field_is_indexed(const std::string& field_name) const;
 
-    void aproximate_numerical_match(num_tree_t* const num_tree,
-                                    const NUM_COMPARATOR& comparator,
-                                    const int64_t& value,
-                                    const int64_t& range_end_value,
-                                    uint32_t& filter_ids_length) const;
-
-    void insert_doc(const int64_t score, art_tree *t, uint32_t seq_id,
-                    const std::unordered_map<std::string, std::vector<uint32_t>>& token_to_offsets) const;
-
     static void tokenize_string(const std::string& text, const field& a_field,
                                 const std::vector<char>& symbols_to_index,
@@ -744,6 +728,7 @@ public:
                              const size_t facet_query_num_typos,
                              const uint32_t* all_result_ids, const size_t& all_result_ids_len,
                              const std::vector<std::string>& group_by_fields,
+                             size_t group_limit, bool is_wildcard_no_filter_query,
                              size_t max_candidates,
                              std::vector<facet_info_t>& facet_infos,
                              facet_index_type_t facet_index_type) const;
@@ -940,6 +925,9 @@ public:
     Option<bool> seq_ids_outside_top_k(const std::string& field_name, size_t k,
                                        std::vector<uint32_t>& outside_seq_ids);
 
+    Option<uint32_t> get_reference_doc_id_with_lock(const std::string& reference_helper_field_name,
+                                                    const uint32_t& seq_id) const;
+
     friend class filter_result_iterator_t;
 };
 
diff --git a/include/topster.h b/include/topster.h
index 1e7a91bc..5985a48c 100644
--- a/include/topster.h
+++ b/include/topster.h
@@ -25,13 +25,17 @@ struct KV {
 
     std::map<std::string, reference_filter_result_t> reference_filter_results;
 
-    KV(uint16_t queryIndex, uint64_t key, uint64_t distinct_key, uint8_t match_score_index, const int64_t *scores,
+    KV(uint16_t queryIndex, uint64_t key, uint64_t distinct_key, int8_t match_score_index, const int64_t *scores,
        std::map<std::string, reference_filter_result_t> reference_filter_results = {}):
             match_score_index(match_score_index), query_index(queryIndex), array_index(0), key(key),
             distinct_key(distinct_key), reference_filter_results(std::move(reference_filter_results)) {
         this->scores[0] = scores[0];
         this->scores[1] = scores[1];
         this->scores[2] = scores[2];
+
+        if(match_score_index >= 0) {
+            this->text_match_score = scores[match_score_index];
+        }
     }
 
     KV() = default;
 
diff --git a/src/art.cpp b/src/art.cpp
index befdd194..e8a65322 100644
--- a/src/art.cpp
+++ b/src/art.cpp
@@ -1489,13 +1489,15 @@ static inline void rotate(int &i, int &j, int &k) {
 }
 
 // -1: return without adding, 0 : continue iteration, 1: return after adding
-static inline int fuzzy_search_state(const bool prefix, int key_index, bool last_key_char,
-                                     const int query_len, const int* cost_row, int min_cost, int max_cost) {
+static inline int fuzzy_search_state(const bool prefix, int key_index, unsigned char p, unsigned char c,
+                                     const unsigned char* query, const int query_len,
+                                     const int* cost_row, int min_cost, int max_cost) {
 
     // There are 2 scenarios:
     // a) key_len < query_len: "pltninum" (query) on "pst" (key)
     // b) query_len < key_len: "pst" (query) on "pltninum" (key)
 
+    bool last_key_char = (c == '\0');
     int key_len = last_key_char ? key_index : key_index + 1;
 
     if(last_key_char) {
@@ -1527,11 +1529,67 @@ static inline int fuzzy_search_state(const bool prefix, int key_index, bool last
         }
     }
 
-    // Terminate the search early or continue iterating on the key?
-    // We have to account for the case that `cost` could momentarily exceed max_cost but resolve later.
-    // e.g. key=example, query=exZZample, after 5 chars, cost is 3 but drops to 2 at the end.
-    // But we will limit this for longer keys for performance.
-    return cost > max_cost && (key_len > 3 ? cost > (max_cost * 2) : true) ? -1 : 0;
+    /*
+        Terminate the search early or continue iterating on the key?
+        We have to account for the case that `cost` could momentarily exceed max_cost but resolve later.
+        In such cases, we will compare characters in the query with p and/or c to decide.
+    */
+
+    if(cost <= max_cost) {
+        return 0;
+    }
+
+    if(cost == 2 || cost == 3) {
+        /*
+            [1 letter extra]
+            exam ple
+            exZa mple
+
+            [1 letter missing]
+            exam ple
+            exmp le
+
+            [1 letter missing + transpose]
+            dacrycystal gia
+            dacrcyystlg ia
+        */
+        bool letter_more = (key_index+1 < query_len && query[key_index+1] == c);
+        bool letter_less = (key_index > 0 && query[key_index-1] == c);
+        if(letter_more || letter_less) {
+            return 0;
+        }
+    }
+
+    if(cost == 3 || cost == 4) {
+        /*
+            [2 letter extra]
+            exam ple
+            eTxT ample
+
+            abbviat ion
+            abbrevi ation
+        */
+
+        bool extra_matching_letters = (key_index + 1 < query_len && p == query[key_index + 1] &&
+                                       key_index + 2 < query_len && c == query[key_index + 2]);
+
+        if(extra_matching_letters) {
+            return 0;
+        }
+
+        /*
+            [2 letter missing]
+            exam ple
+            expl e
+        */
+
+        bool two_letter_less = (key_index > 1 && query[key_index-2] == c);
+        if(two_letter_less) {
+            return 0;
+        }
+    }
+
+    return -1;
 }
 
 static void art_fuzzy_recurse(unsigned char p, unsigned char c, const art_node *n, int depth, const unsigned char *term,
@@ -1560,10 +1618,9 @@ static void art_fuzzy_recurse(unsigned char p, unsigned char c, const art_node *
         if(!prefix || !last_key_char) {
             levenshtein_dist(depth, p, c, term, term_len, rows[i], rows[j], rows[k]);
             rotate(i, j, k);
-            p = c;
         }
 
-        int action = fuzzy_search_state(prefix, depth, last_key_char, term_len, rows[j], min_cost, max_cost);
+        int action = fuzzy_search_state(prefix, depth, p, c, term, term_len, rows[j], min_cost, max_cost);
         if(1 == action) {
             results.push_back(n);
             return;
@@ -1573,6 +1630,7 @@ static void art_fuzzy_recurse(unsigned char p, unsigned char c, const art_node *
             return;
         }
 
+        p = c;
         depth++;
     }
 
@@ -1591,7 +1649,7 @@ static void art_fuzzy_recurse(unsigned char p, unsigned char c, const art_node *
 
     if(depth >= iter_len) {
         // when a preceding partial node completely contains the whole leaf (e.g. "[raspberr]y" on "raspberries")
"[raspberr]y" on "raspberries") - int action = fuzzy_search_state(prefix, depth, true, term_len, rows[j], min_cost, max_cost); + int action = fuzzy_search_state(prefix, depth, '\0', '\0', term, term_len, rows[j], min_cost, max_cost); if(action == 1) { results.push_back(n); } @@ -1611,10 +1669,9 @@ static void art_fuzzy_recurse(unsigned char p, unsigned char c, const art_node * printf("cost: %d, depth: %d, term_len: %d\n", temp_cost, depth, term_len); rotate(i, j, k); - p = c; } - int action = fuzzy_search_state(prefix, depth, last_key_char, term_len, rows[j], min_cost, max_cost); + int action = fuzzy_search_state(prefix, depth, p, c, term, term_len, rows[j], min_cost, max_cost); if(action == 1) { results.push_back(n); return; @@ -1624,6 +1681,7 @@ static void art_fuzzy_recurse(unsigned char p, unsigned char c, const art_node * return; } + p = c; depth++; } @@ -1640,9 +1698,8 @@ static void art_fuzzy_recurse(unsigned char p, unsigned char c, const art_node * levenshtein_dist(depth, p, c, term, term_len, rows[i], rows[j], rows[k]); rotate(i, j, k); - p = c; - int action = fuzzy_search_state(prefix, depth, false, term_len, rows[j], min_cost, max_cost); + int action = fuzzy_search_state(prefix, depth, p, c, term, term_len, rows[j], min_cost, max_cost); if(action == 1) { results.push_back(n); return; @@ -1652,6 +1709,7 @@ static void art_fuzzy_recurse(unsigned char p, unsigned char c, const art_node * return; } + p = c; depth++; } @@ -1660,9 +1718,8 @@ static void art_fuzzy_recurse(unsigned char p, unsigned char c, const art_node * c = term[depth]; levenshtein_dist(depth, p, c, term, term_len, rows[i], rows[j], rows[k]); rotate(i, j, k); - p = c; - int action = fuzzy_search_state(prefix, depth, false, term_len, rows[j], min_cost, max_cost); + int action = fuzzy_search_state(prefix, depth, p, c, term, term_len, rows[j], min_cost, max_cost); if(action == 1) { results.push_back(n); return; @@ -1672,6 +1729,7 @@ static void art_fuzzy_recurse(unsigned char p, unsigned char c, const art_node * return; } + p = c; depth++; partial_len++; } diff --git a/src/collection.cpp b/src/collection.cpp index ecee8eda..65afff8e 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -262,10 +262,6 @@ nlohmann::json Collection::get_summary_json() const { field_json[fields::reference] = coll_field.reference; } - if(!coll_field.embed.empty()) { - field_json[fields::embed] = coll_field.embed; - } - fields_arr.push_back(field_json); } @@ -1044,7 +1040,7 @@ Option Collection::extract_field_name(const std::string& field_name, for(auto kv = prefix_it.first; kv != prefix_it.second; ++kv) { bool exact_key_match = (kv.key().size() == field_name.size()); bool exact_primitive_match = exact_key_match && !kv.value().is_object(); - bool text_embedding = kv.value().type == field_types::FLOAT_ARRAY && kv.value().embed.count(fields::from) != 0; + bool text_embedding = kv.value().type == field_types::FLOAT_ARRAY && kv.value().num_dim > 0; if(extract_only_string_fields && !kv.value().is_string() && !text_embedding) { if(exact_primitive_match && !is_wildcard) { @@ -1074,7 +1070,7 @@ Option Collection::extract_field_name(const std::string& field_name, return Option(true); } -Option Collection::search(std::string raw_query, +Option Collection::search(std::string raw_query, const std::vector& raw_search_fields, const std::string & filter_query, const std::vector& facet_fields, const std::vector & sort_fields, const std::vector& num_typos, @@ -1205,6 +1201,7 @@ Option Collection::search(std::string raw_query, std::vector 
diff --git a/src/collection.cpp b/src/collection.cpp
index ecee8eda..65afff8e 100644
--- a/src/collection.cpp
+++ b/src/collection.cpp
@@ -262,10 +262,6 @@ nlohmann::json Collection::get_summary_json() const {
             field_json[fields::reference] = coll_field.reference;
         }
 
-        if(!coll_field.embed.empty()) {
-            field_json[fields::embed] = coll_field.embed;
-        }
-
         fields_arr.push_back(field_json);
     }
 
@@ -1044,7 +1040,7 @@ Option<bool> Collection::extract_field_name(const std::string& field_name,
     for(auto kv = prefix_it.first; kv != prefix_it.second; ++kv) {
         bool exact_key_match = (kv.key().size() == field_name.size());
         bool exact_primitive_match = exact_key_match && !kv.value().is_object();
-        bool text_embedding = kv.value().type == field_types::FLOAT_ARRAY && kv.value().embed.count(fields::from) != 0;
+        bool text_embedding = kv.value().type == field_types::FLOAT_ARRAY && kv.value().num_dim > 0;
 
         if(extract_only_string_fields && !kv.value().is_string() && !text_embedding) {
             if(exact_primitive_match && !is_wildcard) {
@@ -1074,7 +1070,7 @@ Option<bool> Collection::extract_field_name(const std::string& field_name,
     return Option<bool>(true);
 }
 
-Option<nlohmann::json> Collection::search(std::string raw_query,
+Option<nlohmann::json> Collection::search(std::string raw_query,
                                           const std::vector<std::string>& raw_search_fields,
                                           const std::string & filter_query, const std::vector<std::string>& facet_fields,
                                           const std::vector<sort_by> & sort_fields, const std::vector<uint32_t>& num_typos,
@@ -1205,6 +1201,7 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
     std::vector<search_field_t> processed_search_fields;
     std::vector<uint32_t> query_by_weights;
     size_t num_embed_fields = 0;
+    std::string query = raw_query;
 
     for(size_t i = 0; i < raw_search_fields.size(); i++) {
         const std::string& field_name = raw_search_fields[i];
@@ -1293,6 +1290,11 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
         }
     }
 
+    // Set query to * if it is semantic search
+    if(!vector_query.field_name.empty() && processed_search_fields.empty()) {
+        query = "*";
+    }
+
     if(!vector_query.field_name.empty() && vector_query.values.empty() && num_embed_fields == 0) {
         std::string error = "Vector query could not find any embedded fields.";
         return Option<nlohmann::json>(400, error);
@@ -1448,7 +1450,7 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
     size_t max_hits = DEFAULT_TOPSTER_SIZE;
 
     // ensure that `max_hits` never exceeds number of documents in collection
-    if(search_fields.size() <= 1 || raw_query == "*") {
+    if(search_fields.size() <= 1 || query == "*") {
         max_hits = std::min(std::max(fetch_size, max_hits), get_num_documents());
     } else {
         max_hits = std::min(std::max(fetch_size, max_hits), get_num_documents());
@@ -1481,7 +1483,6 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
     StringUtils::split(hidden_hits_str, hidden_hits, ",");
 
     std::vector<const override_t*> filter_overrides;
-    std::string query = raw_query;
     bool filter_curated_hits = false;
     std::string curated_sort_by;
     curate_results(query, filter_query, enable_overrides, pre_segmented_query, pinned_hits, hidden_hits,
@@ -1944,7 +1945,8 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
                                                   exclude_fields_full,
                                                   "",
                                                   0,
-                                                  field_order_kv->reference_filter_results);
+                                                  field_order_kv->reference_filter_results,
+                                                  const_cast<Collection *>(this), get_seq_id_from_key(seq_id_key));
             if (!prune_op.ok()) {
                 return Option<nlohmann::json>(prune_op.code(), prune_op.error());
             }
@@ -1955,10 +1957,10 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
             if(field_order_kv->match_score_index == CURATED_RECORD_IDENTIFIER) {
                 wrapper_doc["curated"] = true;
             } else if(field_order_kv->match_score_index >= 0) {
-                wrapper_doc["text_match"] = field_order_kv->scores[field_order_kv->match_score_index];
+                wrapper_doc["text_match"] = field_order_kv->text_match_score;
                 wrapper_doc["text_match_info"] = nlohmann::json::object();
                 populate_text_match_info(wrapper_doc["text_match_info"],
-                                         field_order_kv->scores[field_order_kv->match_score_index], match_type);
+                                         field_order_kv->text_match_score, match_type);
                 if(!vector_query.field_name.empty()) {
                     wrapper_doc["hybrid_search_info"] = nlohmann::json::object();
                     wrapper_doc["hybrid_search_info"]["rank_fusion_score"] = Index::int64_t_to_float(field_order_kv->scores[field_order_kv->match_score_index]);
@@ -2000,9 +2002,11 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
     result["facet_counts"] = nlohmann::json::array();
 
     // populate facets
-    for(facet & a_facet: facets) {
+    for(facet& a_facet: facets) {
         // Don't return zero counts for a wildcard facet.
-        if (a_facet.is_wildcard_match && a_facet.result_map.size() == 0) {
+        if (a_facet.is_wildcard_match &&
+            (((a_facet.is_intersected && a_facet.value_result_map.empty())) ||
+             (!a_facet.is_intersected && a_facet.result_map.empty()))) {
             continue;
         }
 
@@ -2019,28 +2023,28 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
         facet_result["counts"] = nlohmann::json::array();
 
         std::vector<facet_value_t> facet_values;
-        std::vector<std::pair<std::string, facet_count_t>> facet_counts;
+        std::vector<facet_count_t> facet_counts;
 
         for (const auto & kv : a_facet.result_map) {
-            facet_counts.emplace_back(std::make_pair(kv.first, kv.second));
+            facet_count_t v = kv.second;
+            v.fhash = kv.first;
+            facet_counts.emplace_back(v);
+        }
+
+        for (const auto& kv : a_facet.value_result_map) {
+            facet_count_t v = kv.second;
+            v.fvalue = kv.first;
+            v.fhash = StringUtils::hash_wy(kv.first.c_str(), kv.first.size());
+            facet_counts.emplace_back(v);
         }
 
         auto max_facets = std::min(max_facet_values, facet_counts.size());
         auto nthElement = max_facets == facet_counts.size() ? max_facets - 1 : max_facets;
-        std::nth_element(facet_counts.begin(), facet_counts.begin() + nthElement,
-                         facet_counts.end(), [&](const auto& kv1, const auto& kv2) {
-            size_t a_count = kv1.second.count;
-            size_t b_count = kv2.second.count;
-
-            size_t a_value_size = UINT64_MAX - kv1.first.size();
-            size_t b_value_size = UINT64_MAX - kv2.first.size();
-
-            return std::tie(a_count, a_value_size) > std::tie(b_count, b_value_size);
-        });
+        std::nth_element(facet_counts.begin(), facet_counts.begin() + nthElement, facet_counts.end(),
+                         Collection::facet_count_compare);
 
         if(a_facet.is_range_query){
-            for(auto kv : a_facet.result_map){
-
+            for(const auto& kv : a_facet.result_map){
                 auto facet_range_iter = a_facet.facet_range_map.find(kv.first);
                 if(facet_range_iter != a_facet.facet_range_map.end()){
                     auto & facet_count = kv.second;
@@ -2058,13 +2062,11 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
 
         for(size_t fi = 0; fi < max_facets; fi++) {
             // remap facet value hash with actual string
-            auto & kv = facet_counts[fi];
-            auto & facet_count = kv.second;
-
+            auto & facet_count = facet_counts[fi];
             std::string value;
 
             if(a_facet.is_intersected) {
-                value = kv.first;
+                value = facet_count.fvalue;
                 //LOG(INFO) << "used intersection";
             } else {
                 // fetch actual facet value from representative doc id
@@ -2088,7 +2090,8 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
                 }
 
                 std::unordered_map<std::string, size_t> ftoken_pos;
-                std::vector<std::string>& ftokens = a_facet.hash_tokens[kv.first];
+                std::vector<std::string>& ftokens = a_facet.is_intersected ? a_facet.fvalue_tokens[facet_count.fvalue] :
+                                                    a_facet.hash_tokens[facet_count.fhash];
                 //LOG(INFO) << "working on hash_tokens for hash " << kv.first << " with size " << ftokens.size();
                 for(size_t ti = 0; ti < ftokens.size(); ti++) {
                     if(the_field.is_bool()) {
@@ -2696,11 +2699,21 @@ Option<bool> Collection::get_filter_ids(const std::string& filter_query, filter_
     return index->do_filtering_with_lock(filter_tree_root, filter_result, name);
 }
 
-Option<std::string> Collection::get_reference_field(const std::string & collection_name) const {
+Option<uint32_t> Collection::get_reference_doc_id(const std::string& ref_collection_name, const uint32_t& seq_id) const {
+    auto get_reference_field_op = get_reference_field(ref_collection_name);
+    if (!get_reference_field_op.ok()) {
+        return Option<uint32_t>(get_reference_field_op.code(), get_reference_field_op.error());
+    }
+
+    auto field_name = get_reference_field_op.get() + REFERENCE_HELPER_FIELD_SUFFIX;
+    return index->get_reference_doc_id_with_lock(field_name, seq_id);
+}
+
+Option<std::string> Collection::get_reference_field(const std::string& ref_collection_name) const {
     std::string reference_field_name;
     for (auto const& pair: reference_fields) {
         auto reference_pair = pair.second;
-        if (reference_pair.collection == collection_name) {
+        if (reference_pair.collection == ref_collection_name) {
             reference_field_name = pair.first;
             break;
         }
@@ -2708,7 +2721,7 @@ Option<std::string> Collection::get_reference_field(const std::string & collecti
 
     if (reference_field_name.empty()) {
         return Option<std::string>(400, "Could not find any field in `" + name + "` referencing the collection `"
-                                        + collection_name + "`.");
+                                        + ref_collection_name + "`.");
     }
 
     return Option<std::string>(reference_field_name);
@@ -4043,7 +4056,8 @@ Option<bool> Collection::prune_doc(nlohmann::json& doc,
                                    const tsl::htrie_set<char>& include_names,
                                    const tsl::htrie_set<char>& exclude_names,
                                    const std::string& parent_name, size_t depth,
-                                   const std::map<std::string, reference_filter_result_t>& reference_filter_results) {
+                                   const std::map<std::string, reference_filter_result_t>& reference_filter_results,
+                                   Collection *const collection, const uint32_t& seq_id) {
     // doc can only be an object
     auto it = doc.begin();
     while(it != doc.end()) {
@@ -4119,12 +4133,8 @@ Option<bool> Collection::prune_doc(nlohmann::json& doc,
         it++;
     }
 
-    if (reference_filter_results.empty()) {
-        return Option<bool>(true);
-    }
-
-    auto include_reference_it = include_names.equal_prefix_range("$");
-    for (auto reference = include_reference_it.first; reference != include_reference_it.second; reference++) {
+    auto include_reference_it_pair = include_names.equal_prefix_range("$");
+    for (auto reference = include_reference_it_pair.first; reference != include_reference_it_pair.second; reference++) {
         auto ref = reference.key();
         size_t parenthesis_index = ref.find('(');
 
@@ -4161,9 +4171,36 @@ Option<bool> Collection::prune_doc(nlohmann::json& doc,
             return Option<bool>(include_exclude_op.code(), error_prefix + include_exclude_op.error());
         }
 
-        if (reference_filter_results.count(ref_collection_name) == 0 ||
-            reference_filter_results.at(ref_collection_name).count == 0) {
-            // doc has no references.
+        bool has_filter_reference = reference_filter_results.count(ref_collection_name) > 0;
+        if (!has_filter_reference) {
+            if (collection == nullptr) {
+                continue;
+            }
+
+            // Reference include_by without join, check if doc itself contains the reference.
+            auto get_reference_doc_id_op = collection->get_reference_doc_id(ref_collection_name, seq_id);
+            if (!get_reference_doc_id_op.ok()) {
+                continue;
+            }
+
+            auto ref_doc_seq_id = get_reference_doc_id_op.get();
+
+            nlohmann::json ref_doc;
+            auto get_doc_op = ref_collection->get_document_from_store(ref_doc_seq_id, ref_doc);
+            if (!get_doc_op.ok()) {
+                return Option<bool>(get_doc_op.code(), error_prefix + get_doc_op.error());
+            }
+
+            auto prune_op = prune_doc(ref_doc, ref_include_fields_full, ref_exclude_fields_full);
+            if (!prune_op.ok()) {
+                return Option<bool>(prune_op.code(), error_prefix + prune_op.error());
+            }
+
+            doc.update(ref_doc);
+            continue;
+        }
+
+        if (has_filter_reference && reference_filter_results.at(ref_collection_name).count == 0) {
             continue;
         }
 
@@ -4873,7 +4910,7 @@ Option<bool> Collection::parse_facet(const std::string& facet_field, std::vector
             return Option<bool>(400, error);
         }
 
-        std::vector<std::tuple<std::string, std::string, std::string>> tupVec;
+        std::vector<std::tuple<int64_t, int64_t, std::string>> tupVec;
 
         auto& range_map = a_facet.facet_range_map;
         for(const auto& range : result){
@@ -4888,26 +4925,28 @@ Option<bool> Collection::parse_facet(const std::string& facet_field, std::vector
             auto pos2 = range.find(",");
             auto pos3 = range.find("]");
 
-            std::string lower_range, upper_range;
+            int64_t lower_range, upper_range;
             auto lower_range_start = pos1 + 2;
             auto lower_range_len = pos2 - lower_range_start;
             auto upper_range_start = pos2 + 1;
             auto upper_range_len = pos3 - upper_range_start;
 
             if(a_field.is_integer()) {
-                lower_range = range.substr(lower_range_start, lower_range_len);
-                StringUtils::trim(lower_range);
-                upper_range = range.substr(upper_range_start, upper_range_len);
-                StringUtils::trim(upper_range);
+                std::string lower_range_str = range.substr(lower_range_start, lower_range_len);
+                StringUtils::trim(lower_range_str);
+                lower_range = std::stoll(lower_range_str);
+                std::string upper_range_str = range.substr(upper_range_start, upper_range_len);
+                StringUtils::trim(upper_range_str);
+                upper_range = std::stoll(upper_range_str);
             } else {
                 float val = std::stof(range.substr(pos1 + 2, pos2));
-                lower_range = std::to_string(Index::float_to_int64_t(val));
+                lower_range = Index::float_to_int64_t(val);
                 val = std::stof(range.substr(pos2 + 1, pos3));
-                upper_range = std::to_string(Index::float_to_int64_t(val));
+                upper_range = Index::float_to_int64_t(val);
             }
 
-            tupVec.emplace_back(std::make_tuple(lower_range, upper_range, range_val));
+            tupVec.emplace_back(lower_range, upper_range, range_val);
         }
 
         //sort the range values so that we can check continuity
@@ -5049,9 +5088,10 @@ void Collection::hide_credential(nlohmann::json& json, const std::string& creden
         // hide api key with * except first 5 chars
         std::string credential_name_str = json[credential_name];
         if(credential_name_str.size() > 5) {
-            json[credential_name] = credential_name_str.replace(5, credential_name_str.size() - 5, credential_name_str.size() - 5, '*');
+            size_t num_chars_to_replace = credential_name_str.size() - 5;
+            json[credential_name] = credential_name_str.replace(5, num_chars_to_replace, num_chars_to_replace, '*');
         } else {
-            json[credential_name] = credential_name_str.replace(0, credential_name_str.size(), credential_name_str.size(), '*');
+            json[credential_name] = "***********";
         }
     }
 }
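Note on the `hide_credential()` cleanup above: `std::string::replace(pos, len, count, ch)` erases `len` characters at `pos` and inserts `count` copies of `ch`, so passing the same value for both keeps the overall length intact. A small standalone sketch (the sample key matches the HideCredential test near the end of this diff):

#include <iostream>
#include <string>

std::string mask_credential(std::string credential) {
    if(credential.size() > 5) {
        // keep the first 5 chars, overwrite the rest with the same number of '*'s
        size_t num_chars_to_replace = credential.size() - 5;
        credential.replace(5, num_chars_to_replace, num_chars_to_replace, '*');
        return credential;
    }
    return "***********"; // too short to keep a recognizable prefix
}

int main() {
    std::cout << mask_credential("ax-abcdef12345") << "\n"; // prints: ax-ab*********
    return 0;
}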
diff --git a/src/field.cpp b/src/field.cpp
index 4bf1fdef..b7f551fa 100644
--- a/src/field.cpp
+++ b/src/field.cpp
@@ -337,18 +337,41 @@ Option<bool> field::json_field_to_field(bool enable_nested_fields, nlohmann::jso
 }
 
 bool field::flatten_obj(nlohmann::json& doc, nlohmann::json& value, bool has_array, bool has_obj_array,
-                        const field& the_field, const std::string& flat_name,
+                        bool is_update, const field& the_field, const std::string& flat_name,
                         const std::unordered_map<std::string, field>& dyn_fields,
                         std::unordered_map<std::string, field>& flattened_fields) {
     if(value.is_object()) {
         has_obj_array = has_array;
-        for(const auto& kv: value.items()) {
-            flatten_obj(doc, kv.value(), has_array, has_obj_array, the_field, flat_name + "." + kv.key(),
-                        dyn_fields, flattened_fields);
+        auto it = value.begin();
+        while(it != value.end()) {
+            const std::string& child_field_name = flat_name + "." + it.key();
+            if(it.value().is_null()) {
+                if(has_array) {
+                    doc[child_field_name].push_back(nullptr);
+                } else {
+                    doc[child_field_name] = nullptr;
+                }
+
+                field flattened_field;
+                flattened_field.name = child_field_name;
+                flattened_field.type = field_types::NIL;
+                flattened_fields[child_field_name] = flattened_field;
+
+                if(!is_update) {
+                    // update code path requires and takes care of null values
+                    it = value.erase(it);
+                } else {
+                    it++;
+                }
+            } else {
+                flatten_obj(doc, it.value(), has_array, has_obj_array, is_update, the_field, child_field_name,
+                            dyn_fields, flattened_fields);
+                it++;
+            }
         }
     } else if(value.is_array()) {
         for(const auto& kv: value.items()) {
-            flatten_obj(doc, kv.value(), true, has_obj_array, the_field, flat_name, dyn_fields, flattened_fields);
+            flatten_obj(doc, kv.value(), true, has_obj_array, is_update, the_field, flat_name, dyn_fields, flattened_fields);
         }
     } else { // must be a primitive
         if(doc.count(flat_name) != 0 && flattened_fields.find(flat_name) == flattened_fields.end()) {
@@ -404,7 +427,7 @@ bool field::flatten_obj(nlohmann::json& doc, nlohmann::json& value, bool has_arr
 
 Option<bool> field::flatten_field(nlohmann::json& doc, nlohmann::json& obj, const field& the_field,
                                   std::vector<std::string>& path_parts, size_t path_index,
-                                  bool has_array, bool has_obj_array,
+                                  bool has_array, bool has_obj_array, bool is_update,
                                   const std::unordered_map<std::string, field>& dyn_fields,
                                   std::unordered_map<std::string, field>& flattened_fields) {
     if(path_index == path_parts.size()) {
@@ -459,7 +482,8 @@ Option<bool> field::flatten_field(nlohmann::json& doc, nlohmann::json& obj, cons
 
         if(detected_type == the_field.type || is_numericaly_valid) {
             if(the_field.is_object()) {
-                flatten_obj(doc, obj, has_array, has_obj_array, the_field, the_field.name, dyn_fields, flattened_fields);
+                flatten_obj(doc, obj, has_array, has_obj_array, is_update, the_field, the_field.name,
+                            dyn_fields, flattened_fields);
             } else {
                 if(doc.count(the_field.name) != 0 && flattened_fields.find(the_field.name) == flattened_fields.end()) {
                     return Option<bool>(true);
@@ -502,7 +526,7 @@ Option<bool> field::flatten_field(nlohmann::json& doc, nlohmann::json& obj, cons
             for(auto& ele: it.value()) {
                 has_obj_array = has_obj_array || ele.is_object();
                 Option<bool> op = flatten_field(doc, ele, the_field, path_parts, path_index + 1, has_array,
-                                                has_obj_array, dyn_fields, flattened_fields);
+                                                has_obj_array, is_update, dyn_fields, flattened_fields);
                 if(!op.ok()) {
                     return op;
                 }
@@ -510,7 +534,7 @@ Option<bool> field::flatten_field(nlohmann::json& doc, nlohmann::json& obj, cons
             return Option<bool>(true);
         } else {
             return flatten_field(doc, it.value(), the_field, path_parts, path_index + 1, has_array, has_obj_array,
-                                 dyn_fields, flattened_fields);
+                                 is_update, dyn_fields, flattened_fields);
         }
     }
 
     return Option<bool>(404, "Field `" + the_field.name + "` not found.");
@@ -520,7 +544,7 @@ Option<bool> field::flatten_field(nlohmann::json& doc, nlohmann::json& obj, cons
 Option<bool> field::flatten_doc(nlohmann::json& document,
                                 const tsl::htrie_map<char, field>& nested_fields,
                                 const std::unordered_map<std::string, field>& dyn_fields,
-                                bool missing_is_ok, std::vector<field>& flattened_fields) {
+                                bool is_update, std::vector<field>& flattened_fields) {
 
     std::unordered_map<std::string, field> flattened_fields_map;
 
@@ -534,12 +558,12 @@ Option<bool> field::flatten_doc(nlohmann::json& document,
         }
 
         auto op = flatten_field(document, document, nested_field, field_parts, 0, false, false,
-                                dyn_fields, flattened_fields_map);
+                                is_update, dyn_fields, flattened_fields_map);
         if(op.ok()) {
             continue;
         }
 
-        if(op.code() == 404 && (missing_is_ok || nested_field.optional)) {
+        if(op.code() == 404 && (is_update || nested_field.optional)) {
             continue;
         } else {
             return op;
@@ -549,7 +573,10 @@ Option<bool> field::flatten_doc(nlohmann::json& document,
     document[".flat"] = nlohmann::json::array();
     for(auto& kv: flattened_fields_map) {
         document[".flat"].push_back(kv.second.name);
-        flattened_fields.push_back(kv.second);
+        if(kv.second.type != field_types::NIL) {
+            // not a real field so we won't add it
+            flattened_fields.push_back(kv.second);
+        }
     }
 
     return Option<bool>(true);
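Note on the null handling added to `flatten_obj()` above: removing object members while iterating is only safe if the iterator is reassigned from `erase()`, which is the pattern the loop uses. A minimal sketch of that pattern with nlohmann::json (the field names are illustrative only):

#include <iostream>
#include <nlohmann/json.hpp>

int main() {
    nlohmann::json value = R"({"eu": null, "us": 10000})"_json;

    auto it = value.begin();
    while(it != value.end()) {
        if(it.value().is_null()) {
            it = value.erase(it); // erase() returns the next valid iterator
        } else {
            it++;
        }
    }

    std::cout << value.dump() << std::endl; // prints: {"us":10000}
    return 0;
}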
diff --git a/src/http_server.cpp b/src/http_server.cpp
index c4f95bbe..ce381c9f 100644
--- a/src/http_server.cpp
+++ b/src/http_server.cpp
@@ -570,13 +570,11 @@ int HttpServer::async_req_cb(void *ctx, int is_end_stream) {
     bool async_req = custom_generator->rpath->async_req;
     bool is_http_v1 = (0x101 <= request->_req->version && request->_req->version < 0x200);
 
-    /*
-    LOG(INFO) << "async_req_cb, chunk.len=" << chunk.len
+    /*LOG(INFO) << "async_req_cb, chunk.len=" << chunk.len
               << ", is_http_v1: " << is_http_v1
-              << ", request->req->entity.len=" << request->req->entity.len
-              << ", content_len: " << request->req->content_length
-              << ", is_end_stream=" << is_end_stream;
-    */
+              << ", request->req->entity.len=" << request->_req->entity.len
+              << ", content_len: " << request->_req->content_length
+              << ", is_end_stream=" << is_end_stream;*/
 
     // disallow specific curl clients from using import call via http2
     // detects: https://github.com/curl/curl/issues/1410
diff --git a/src/index.cpp b/src/index.cpp
index cd7b2bbd..db429001 100644
--- a/src/index.cpp
+++ b/src/index.cpp
@@ -429,6 +429,8 @@ void Index::validate_and_preprocess(Index *index,
                 continue;
             }
 
+            handle_doc_ops(search_schema, index_rec.doc, index_rec.old_doc);
+
             if(do_validation) {
                 Option<uint32_t> validation_op = validator_t::validate_index_in_memory(index_rec.doc, index_rec.seq_id,
                                                                                        default_sorting_field,
@@ -466,7 +468,6 @@ void Index::validate_and_preprocess(Index *index,
                     }
                 }
             } else {
-                handle_doc_ops(search_schema, index_rec.doc, index_rec.old_doc);
                 if(generate_embeddings) {
                     records_to_embed.push_back(&index_rec);
                 }
@@ -943,9 +944,8 @@ void Index::index_field_in_memory(const field& afield, std::vector
                 continue;
             }
 
-            const std::vector<float>& float_vals = record.doc[afield.name].get<std::vector<float>>();
-
             try {
+                const std::vector<float>& float_vals = record.doc[afield.name].get<std::vector<float>>();
                 if(afield.vec_dist == cosine) {
                     std::vector<float> normalized_vals(afield.num_dim);
                     hnsw_index_t::normalize_vector(float_vals, normalized_vals);
@@ -1264,6 +1264,7 @@ void Index::do_facets(std::vector<facet> & facets, facet_query_t & facet_query,
         const bool use_facet_query = facet_infos[findex].use_facet_query;
         const auto& fquery_hashes = facet_infos[findex].hashes;
         const bool should_compute_stats = facet_infos[findex].should_compute_stats;
+        const bool use_value_index = facet_infos[findex].use_value_index;
 
         auto sort_index_it = sort_index.find(a_facet.field_name);
 
@@ -1277,15 +1278,6 @@ void Index::do_facets(std::vector<facet> & facets, facet_query_t & facet_query,
         bool is_wildcard_no_filter_query = is_wildcard_query && no_filters_provided;
         bool facet_value_index_exists = facet_index_v4->has_value_index(facet_field.name);
 
-        // We have to choose between hash and value index:
-        // 1. Group queries -> requires hash index
-        // 2. Wildcard + no filters -> use value index
-        // 3. Very few unique facet values (< 250) -> use value index
-        // 4. Result match > 50%
-        bool use_value_index = (group_limit == 0) && ( is_wildcard_no_filter_query ||
-                               (results_size > 1000 && num_facet_values < 250) ||
-                               (results_size > 1000 && results_size * 2 > total_docs));
-
 #ifdef TEST_BUILD
         if(facet_index_type == VALUE) {
 #else
@@ -1303,32 +1295,29 @@ void Index::do_facets(std::vector<facet> & facets, facet_query_t & facet_query,
                 //range facet processing
                 if(a_facet.is_range_query) {
                     const auto doc_val = kv.first;
-                    std::pair<std::string, std::string> range_pair {};
-                    if(a_facet.get_range(doc_val, range_pair)) {
+                    std::pair<int64_t, std::string> range_pair {};
+                    if(a_facet.get_range(std::stoll(doc_val), range_pair)) {
                         const auto& range_id = range_pair.first;
                         facet_count_t& facet_count = a_facet.result_map[range_id];
                         facet_count.count = kv.second;
                     }
                 } else {
                     if(use_facet_query) {
-                        const auto fquery_hashes_it = fquery_hashes.find(facet_field.name);
-                        if(fquery_hashes_it != fquery_hashes.end()) {
-                            const auto& searched_tokens = fquery_hashes_it->second;
-                            auto facet_str = kv.first;
-                            transform(facet_str.begin(), facet_str.end(), facet_str.begin(), ::tolower);
+                        const auto& searched_tokens = facet_infos[findex].fvalue_searched_tokens;
+                        auto facet_str = kv.first;
+                        transform(facet_str.begin(), facet_str.end(), facet_str.begin(), ::tolower);
 
-                            for(const auto& val : searched_tokens) {
-                                if(facet_str.find(val) != std::string::npos) {
-                                    facet_count_t& facet_count = a_facet.result_map[kv.first];
-                                    facet_count.count = kv.second;
+                        for(const auto& val : searched_tokens) {
+                            if(facet_str.find(val) != std::string::npos) {
+                                facet_count_t& facet_count = a_facet.value_result_map[kv.first];
+                                facet_count.count = kv.second;
 
-                                    a_facet.hash_tokens[kv.first] = fquery_hashes.at(facet_field.name);
-                                }
+                                a_facet.fvalue_tokens[kv.first] = searched_tokens;
                             }
                         }
-
                     } else {
-                        facet_count_t& facet_count = a_facet.result_map[kv.first];
+                        facet_count_t& facet_count = a_facet.value_result_map[kv.first];
                         facet_count.count = kv.second;
                     }
                 }
@@ -1339,7 +1328,7 @@ void Index::do_facets(std::vector<facet> & facets, facet_query_t & facet_query,
                     compute_facet_stats(a_facet, kv.first, facet_field.type);
                 }
             }
-        } 
+        }
         } else {
             //LOG(INFO) << "Using hashing to find facets";
             bool facet_hash_index_exists = facet_index_v4->has_hash_index(facet_field.name);
@@ -1397,18 +1386,17 @@ void Index::do_facets(std::vector<facet> & facets, facet_query_t & facet_query,
                         compute_facet_stats(a_facet, fhash, facet_field.type);
                     }
 
-                    std::string fhash_str = std::to_string(fhash);
                     if(a_facet.is_range_query) {
                         int64_t doc_val = get_doc_val_from_sort_index(sort_index_it, doc_seq_id);
 
-                        std::pair<std::string, std::string> range_pair {};
-                        if(a_facet.get_range(std::to_string(doc_val), range_pair)) {
+                        std::pair<int64_t, std::string> range_pair {};
+                        if(a_facet.get_range(doc_val, range_pair)) {
                             const auto& range_id = range_pair.first;
                             facet_count_t& facet_count = a_facet.result_map[range_id];
                             facet_count.count += 1;
                         }
-                    } else if(!use_facet_query || fquery_hashes.find(fhash_str) != fquery_hashes.end()) {
-                        facet_count_t& facet_count = a_facet.result_map[fhash_str];
+                    } else if(!use_facet_query || fquery_hashes.find(fhash) != fquery_hashes.end()) {
+                        facet_count_t& facet_count = a_facet.result_map[fhash];
                         //LOG(INFO) << "field: " << a_facet.field_name << ", doc id: " << doc_seq_id << ", hash: " << fhash;
                         facet_count.doc_id = doc_seq_id;
                         facet_count.array_pos = j;
@@ -1419,7 +1407,7 @@ void Index::do_facets(std::vector<facet> & facets, facet_query_t & facet_query,
                     }
                     if(use_facet_query) {
                         //LOG (INFO) << "adding hash tokens for hash " << fhash;
-                        a_facet.hash_tokens[fhash_str] = fquery_hashes.at(fhash_str);
+                        a_facet.hash_tokens[fhash] = fquery_hashes.at(fhash);
                     }
                 }
             }
@@ -1664,34 +1652,6 @@ bool Index::field_is_indexed(const std::string& field_name) const {
            geo_range_index.count(field_name) != 0;
 }
 
-void Index::aproximate_numerical_match(num_tree_t* const num_tree,
-                                       const NUM_COMPARATOR& comparator,
-                                       const int64_t& value,
-                                       const int64_t& range_end_value,
-                                       uint32_t& filter_ids_length) const {
-    if (comparator == RANGE_INCLUSIVE) {
-        num_tree->approx_range_inclusive_search_count(value, range_end_value, filter_ids_length);
-        return;
-    }
-
-    if (comparator == NOT_EQUALS) {
-        uint32_t to_exclude_ids_len = 0;
-        num_tree->approx_search_count(EQUALS, value, to_exclude_ids_len);
-
-        if (to_exclude_ids_len == 0) {
-            filter_ids_length += seq_ids->num_ids();
-        } else if (to_exclude_ids_len >= seq_ids->num_ids()) {
-            filter_ids_length += 0;
-        } else {
-            filter_ids_length += (seq_ids->num_ids() - to_exclude_ids_len);
-        }
-
-        return;
-    }
-
-    num_tree->approx_search_count(comparator, value, filter_ids_length);
-}
-
 Option<bool> Index::do_filtering_with_lock(filter_node_t* const filter_tree_root,
                                            filter_result_t& filter_result,
                                            const std::string& collection_name) const {
@@ -2753,6 +2713,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
                 auto result = result_it->second;
                 // old_score + (1 / rank_of_document) * WEIGHT)
                 result->vector_distance = vec_result.second;
+                result->text_match_score = result->scores[result->match_score_index];
                 int64_t match_score = float_to_int64_t(
                         (int64_t_to_float(result->scores[result->match_score_index])) +
                         ((1.0 / (res_index + 1)) * VECTOR_SEARCH_WEIGHT));
@@ -2773,6 +2734,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
                 int64_t match_score_index = -1;
                 compute_sort_scores(sort_fields_std, sort_order, field_values, geopoint_indices, doc_id, 0, match_score, scores, match_score_index, vec_result.second);
 
                 KV kv(searched_queries.size(), doc_id, doc_id, match_score_index, scores);
+                kv.text_match_score = 0;
                 kv.vector_distance = vec_result.second;
 
                 if (filter_result_iterator->is_valid &&
@@ -2812,6 +2774,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
     delete [] excluded_result_ids;
 
     bool estimate_facets = (facet_sample_percent < 100 && all_result_ids_len > facet_sample_threshold);
+    bool is_wildcard_no_filter_query = is_wildcard_query && no_filters_provided;
 
     if(!facets.empty()) {
         const size_t num_threads = 1;
@@ -2822,9 +2785,16 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
         std::mutex m_process;
         std::condition_variable cv_process;
 
+        // We have to choose between hash and value index:
+        // 1. Group queries -> requires hash index
+        // 2. Wildcard + no filters -> use value index
+        // 3. Very few unique facet values (< 250) -> use value index
+        // 4. Result match > 50%
+
         std::vector<facet_info_t> facet_infos(facets.size());
         compute_facet_infos(facets, facet_query, facet_query_num_typos, all_result_ids, all_result_ids_len,
-                            group_by_fields, max_candidates, facet_infos, facet_index_type);
+                            group_by_fields, group_limit, is_wildcard_no_filter_query,
+                            max_candidates, facet_infos, facet_index_type);
 
         std::vector<std::vector<facet>> facet_batches(num_threads);
         for(size_t i = 0; i < num_threads; i++) {
@@ -2889,7 +2859,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
                     for(auto & facet_kv: this_facet.result_map) {
                         uint32_t fhash = 0;
                         if(group_limit) {
-                            fhash = std::stoul(facet_kv.first);
+                            fhash = facet_kv.first;
                             // we have to add all group sets
                             acc_facet.hash_groups[fhash].insert(
                                 this_facet.hash_groups[fhash].begin(),
@@ -2913,6 +2883,24 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
                         acc_facet.hash_tokens[facet_kv.first] = this_facet.hash_tokens[facet_kv.first];
                     }
 
+                    for(auto& facet_kv: this_facet.value_result_map) {
+                        size_t count = 0;
+                        if(acc_facet.value_result_map.count(facet_kv.first) == 0) {
+                            // not found, so set it
+                            count = facet_kv.second.count;
+                        } else {
+                            count = acc_facet.value_result_map[facet_kv.first].count + facet_kv.second.count;
+                        }
+
+                        acc_facet.value_result_map[facet_kv.first].count = count;
+
+                        acc_facet.value_result_map[facet_kv.first].doc_id = facet_kv.second.doc_id;
+                        acc_facet.value_result_map[facet_kv.first].array_pos = facet_kv.second.array_pos;
+                        acc_facet.is_intersected = this_facet.is_intersected;
+
+                        acc_facet.fvalue_tokens[facet_kv.first] = this_facet.fvalue_tokens[facet_kv.first];
+                    }
+
                     if(this_facet.stats.fvcount != 0) {
                         acc_facet.stats.fvcount += this_facet.stats.fvcount;
                         acc_facet.stats.fvsum += this_facet.stats.fvsum;
@@ -2925,7 +2913,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
         for(auto & acc_facet: facets) {
             for(auto& facet_kv: acc_facet.result_map) {
                 if(group_limit) {
-                    facet_kv.second.count = acc_facet.hash_groups[std::stoul(facet_kv.first)].size();
+                    facet_kv.second.count = acc_facet.hash_groups[facet_kv.first].size();
                 }
 
                 if(estimate_facets) {
@@ -2933,6 +2921,12 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
                 }
             }
 
+            for(auto& facet_kv: acc_facet.value_result_map) {
+                if(estimate_facets) {
+                    facet_kv.second.count = size_t(double(facet_kv.second.count) * (100.0f / facet_sample_percent));
+                }
+            }
+
             if(estimate_facets) {
                 acc_facet.sampled = true;
             }
@@ -2945,7 +2939,8 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
             std::vector<std::pair<uint32_t, uint32_t>> found_docs;
             std::vector<facet_info_t> facet_infos(facets.size());
             compute_facet_infos(facets, facet_query, facet_query_num_typos,
-                                &included_ids_vec[0], included_ids_vec.size(), group_by_fields,
+                                &included_ids_vec[0], included_ids_vec.size(), group_by_fields,
+                                group_limit, is_wildcard_no_filter_query,
                                 max_candidates, facet_infos, facet_index_type);
             do_facets(facets, facet_query, estimate_facets, facet_sample_percent,
                       facet_infos, group_limit, group_by_fields, &included_ids_vec[0],
@@ -3729,6 +3724,7 @@ void Index::search_across_fields(const std::vector<token_t>& query_tokens,
 
         if(match_score_index != -1) {
             kv.scores[match_score_index] = aggregated_score;
+            kv.text_match_score = aggregated_score;
         }
 
         int ret = topster->add(&kv);
@@ -4370,6 +4366,7 @@ void Index::compute_facet_infos(const std::vector<facet>& facets, facet_query_t&
                                 const size_t facet_query_num_typos,
                                 const uint32_t* all_result_ids, const size_t& all_result_ids_len,
                                 const std::vector<std::string>& group_by_fields,
+                                const size_t group_limit, const bool is_wildcard_no_filter_query,
                                 const size_t max_candidates,
                                 std::vector<facet_info_t>& facet_infos,
                                 facet_index_type_t facet_index_type) const {
 
@@ -4377,13 +4374,14 @@ void Index::compute_facet_infos(const std::vector<facet>& facets, facet_query_t&
         return;
     }
 
+    size_t total_docs = seq_ids->num_ids();
+
     for(size_t findex=0; findex < facets.size(); findex++) {
         const auto& a_facet = facets[findex];
 
-        facet_infos[findex].use_facet_query = false;
-
         const field &facet_field = search_schema.at(a_facet.field_name);
-        facet_infos[findex].facet_field = facet_field;
+
+        facet_infos[findex].facet_field = facet_field;
+        facet_infos[findex].use_facet_query = false;
 
         facet_infos[findex].should_compute_stats = (facet_field.type != field_types::STRING &&
                                                     facet_field.type != field_types::BOOL &&
                                                     facet_field.type != field_types::STRING_ARRAY &&
@@ -4391,6 +4389,13 @@ void Index::compute_facet_infos(const std::vector<facet>& facets, facet_query_t&
                                                     facet_field.type != field_types::INT64 &&
                                                     facet_field.type != field_types::INT64_ARRAY);
 
+        size_t num_facet_values = facet_index_v4->get_facet_count(facet_field.name);
+        facet_infos[findex].use_value_index = (group_limit == 0) && ( is_wildcard_no_filter_query ||
+                                              (all_result_ids_len > 1000 && num_facet_values < 250) ||
+                                              (all_result_ids_len > 1000 && all_result_ids_len * 2 > total_docs));
+
+        bool facet_value_index_exists = facet_index_v4->has_value_index(facet_field.name);
+
         if(a_facet.field_name == facet_query.field_name && !facet_query.query.empty()) {
             facet_infos[findex].use_facet_query = true;
 
@@ -4450,33 +4455,53 @@ void Index::compute_facet_infos(const std::vector<facet>& facets, facet_query_t&
 
                 //LOG(INFO) << "si: " << si << ", field_result_ids_len: " << field_result_ids_len;
 
-                for(size_t i = 0; i < field_result_ids_len; i++) {
-                    uint32_t seq_id = field_result_ids[i];
-                    bool id_matched = true;
-
+#ifdef TEST_BUILD
+                if(facet_index_type == VALUE) {
+#else
+                if(facet_value_index_exists && facet_infos[findex].use_value_index) {
+#endif
+                    size_t num_tokens_found = 0;
                     for(auto pl: posting_lists) {
-                        if(!posting_t::contains(pl, seq_id)) {
-                            // need to ensure that document ID actually contains searched_query tokens
-                            // since `field_result_ids` contains documents matched across all queries
-                            id_matched = false;
+                        if(posting_t::contains_atleast_one(pl, field_result_ids, field_result_ids_len)) {
+                            num_tokens_found++;
+                        } else {
                             break;
                         }
                     }
 
-                    if(!id_matched) {
-                        continue;
+                    if(num_tokens_found == posting_lists.size()) {
+                        // need to ensure that document ID actually contains searched_query tokens
+                        // since `field_result_ids` contains documents matched across all queries
+                        // value based index
+                        for(const auto& val : searched_tokens) {
+                            facet_infos[findex].fvalue_searched_tokens.emplace_back(val);
+                        }
                     }
+                }
+
+                else {
+                    for(size_t i = 0; i < field_result_ids_len; i++) {
+                        uint32_t seq_id = field_result_ids[i];
+                        bool id_matched = true;
+
+                        for(auto pl: posting_lists) {
+                            if(!posting_t::contains(pl, seq_id)) {
+                                // need to ensure that document ID actually contains searched_query tokens
+                                // since `field_result_ids` contains documents matched across all queries
+                                id_matched = false;
+                                break;
+                            }
+                        }
+
+                        if(!id_matched) {
+                            continue;
+                        }
 
-                    #ifdef TEST_BUILD
-                    if(facet_index_type == HASH) {
-                    #else
-                    if(facet_index_v4->has_hash_index(a_facet.field_name)) {
-                    #endif
                         std::vector<uint32_t> facet_hashes;
                         auto facet_index = facet_index_v4->get_facet_hash_index(a_facet.field_name);
                         posting_list_t::iterator_t facet_index_it = facet_index->new_iterator();
                         facet_index_it.skip_to(seq_id);
-
+
                         if(facet_index_it.valid()) {
                             posting_list_t::get_offsets(facet_index_it, facet_hashes);
@@ -4486,7 +4511,7 @@ void Index::compute_facet_infos(const std::vector<facet>& facets, facet_query_t&
 
                             for(size_t array_index: array_indices) {
                                 if(array_index < facet_hashes.size()) {
-                                    std::string hash = std::to_string(facet_hashes[array_index]);
+                                    uint32_t hash = facet_hashes[array_index];
 
                                     /*LOG(INFO) << "seq_id: " << seq_id << ", hash: " << hash << ", array index: "
                                               << array_index;*/
@@ -4498,18 +4523,13 @@ void Index::compute_facet_infos(const std::vector<facet>& facets, facet_query_t&
                                 }
                             }
                         } else {
-                            std::string hash = std::to_string(facet_hashes[0]);
+                            uint32_t hash = facet_hashes[0];
                             if(facet_infos[findex].hashes.count(hash) == 0) {
                                 //LOG(INFO) << "adding searched_tokens for hash " << hash;
                                 facet_infos[findex].hashes.emplace(hash, searched_tokens);
                             }
                         }
                     }
-                } else {
-                    // value based index
-                    for(const auto& val : searched_tokens) {
-                        facet_infos[findex].hashes[facet_field.name].emplace_back(val);
-                    }
                 }
             }
         }
@@ -5812,7 +5832,6 @@ void Index::get_doc_changes(const index_operation_t op, const tsl::htrie_map
                             std::vector<index_record>& records,
         }
     }
 
+Option<uint32_t> Index::get_reference_doc_id_with_lock(const string& reference_helper_field_name,
+                                                       const uint32_t& seq_id) const {
+    std::shared_lock lock(mutex);
+    if (sort_index.count(reference_helper_field_name) == 0 ||
+        sort_index.at(reference_helper_field_name)->count(seq_id) == 0) {
+        return Option<uint32_t>(400, "Could not find a reference for doc " + std::to_string(seq_id));
+    }
+
+    return Option<uint32_t>(sort_index.at(reference_helper_field_name)->at(seq_id));
+}
+
 /*
 // https://stackoverflow.com/questions/924171/geo-fencing-point-inside-outside-polygon
 // NOTE: polygon and point should have been transformed with `transform_for_180th_meridian`
diff --git a/src/raft_server.cpp b/src/raft_server.cpp
index b3793b3c..560cd9ba 100644
--- a/src/raft_server.cpp
+++ b/src/raft_server.cpp
@@ -315,7 +315,7 @@ void ReplicationState::write_to_leader(const std::shared_ptr<http_req>& request,
         // Handle no leader scenario
         LOG(ERROR) << "Rejecting write: could not find a leader.";
 
-        if(request->_req->proceed_req && response->proxied_stream) {
+        if(response->proxied_stream) {
             // streaming in progress: ensure graceful termination (cannot start response again)
             LOG(ERROR) << "Terminating streaming request gracefully.";
             response->is_alive = false;
@@ -328,7 +328,7 @@ void ReplicationState::write_to_leader(const std::shared_ptr<http_req>& request,
         return message_dispatcher->send_message(HttpServer::STREAM_RESPONSE_MESSAGE, req_res);
     }
 
-    if (request->_req->proceed_req && response->proxied_stream) {
+    if (response->proxied_stream) {
         // indicates async request body of in-flight request
         //LOG(INFO) << "Inflight proxied request, returning control to caller, body_size=" << request->body.size();
         request->notify();
diff --git a/src/validator.cpp b/src/validator.cpp
index 38288673..766696e6 100644
--- a/src/validator.cpp
+++ b/src/validator.cpp
@@ -627,7 +627,7 @@ Option<uint32_t> validator_t::validate_index_in_memory(nlohmann::json& document,
             continue;
         }
 
-        if((a_field.optional || op == UPDATE || op == EMPLACE) && document.count(field_name) == 0) {
+        if((a_field.optional || op == UPDATE || (op == EMPLACE && is_update)) && document.count(field_name) == 0) {
             continue;
         }
 
@@ -717,7 +717,7 @@ Option<bool> validator_t::validate_embed_fields(const nlohmann::json& document,
             }
         }
     }
-    if(all_optional_and_null && !field.optional) {
+    if(all_optional_and_null && !field.optional && !is_update) {
         return Option<bool>(400, "No valid fields found to create embedding for `" + field.name + "`, please provide at least one valid field or make the embedding field optional.");
     }
 }
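Note on the facet index selection that moved into `compute_facet_infos()` above: the heuristic can be read as a pure function. This is a hedged restatement, not code from the diff, but the thresholds (1000 results, 250 unique facet values, results covering more than half the collection) are the exact ones used there:

#include <cstddef>
#include <cstdio>

bool should_use_value_index(size_t group_limit, bool is_wildcard_no_filter_query,
                            size_t results_size, size_t num_facet_values, size_t total_docs) {
    if(group_limit != 0) {
        return false; // grouped queries require the hash index
    }
    return is_wildcard_no_filter_query ||
           (results_size > 1000 && num_facet_values < 250) ||
           (results_size > 1000 && results_size * 2 > total_docs);
}

int main() {
    // e.g. 5000 matches out of 8000 docs: more than half the collection matched
    printf("%d\n", should_use_value_index(0, false, 5000, 10000, 8000)); // prints 1
    return 0;
}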
diff --git a/test/art_test.cpp b/test/art_test.cpp
index 0ce46340..daf9d4fd 100644
--- a/test/art_test.cpp
+++ b/test/art_test.cpp
@@ -819,7 +819,7 @@ TEST(ArtTest, test_art_fuzzy_search) {
     art_fuzzy_search(&t, (const unsigned char *) "hown", strlen("hown") + 1, 0, 1, 10, FREQUENCY, false, false, "", nullptr, 0, leaves, exclude_leaves);
     ASSERT_EQ(10, leaves.size());
 
-    std::set<std::string> expected_words = {"town", "sown", "mown", "lown", "howl", "howk", "howe", "how", "horn", "hoon"};
+    std::set<std::string> expected_words = {"town", "sown", "shown", "own", "mown", "lown", "howl", "howk", "howe", "how"};
 
     for(size_t leaf_index = 0; leaf_index < leaves.size(); leaf_index++) {
         art_leaf*& leaf = leaves.at(leaf_index);
@@ -1103,6 +1103,14 @@ TEST(ArtTest, test_art_search_roche_chews) {
 
     ASSERT_EQ(1, leaves.size());
 
+    term = "xxroche";
+    leaves.clear();
+    exclude_leaves.clear();
+    art_fuzzy_search(&t, (const unsigned char*)term.c_str(), term.size()+1, 0, 2, 10,
+                     FREQUENCY, false, false, "", nullptr, 0, leaves, exclude_leaves);
+
+    ASSERT_EQ(1, leaves.size());
+
     res = art_tree_destroy(&t);
     ASSERT_TRUE(res == 0);
 }
diff --git a/test/collection_join_test.cpp b/test/collection_join_test.cpp
index 68722bac..13d8b2a6 100644
--- a/test/collection_join_test.cpp
+++ b/test/collection_join_test.cpp
@@ -1383,8 +1383,7 @@ TEST_F(CollectionJoinTest, IncludeExcludeFieldsByReference_SingleMatch) {
     ASSERT_EQ(1, res_obj["hits"][0]["document"].count("product_name"));
     ASSERT_EQ(1, res_obj["hits"][0]["document"].count("product_price"));
     ASSERT_EQ(73.5, res_obj["hits"][0]["document"].at("product_price"));
-    ASSERT_NE(0, res_obj["hits"][0]["hybrid_search_info"].at("text_match_score"));
-    ASSERT_NE(0, res_obj["hits"][0]["hybrid_search_info"].at("vector_distance"));
+    ASSERT_NE(0, res_obj["hits"][0]["hybrid_search_info"].at("rank_fusion_score"));
 
     // Hybrid search - Only vector match
     req_params = {
@@ -1405,8 +1404,7 @@ TEST_F(CollectionJoinTest, IncludeExcludeFieldsByReference_SingleMatch) {
     ASSERT_EQ(1, res_obj["hits"][0]["document"].count("product_name"));
     ASSERT_EQ(1, res_obj["hits"][0]["document"].count("product_price"));
     ASSERT_EQ(73.5, res_obj["hits"][0]["document"].at("product_price"));
-    ASSERT_EQ(0, res_obj["hits"][0]["hybrid_search_info"].at("text_match_score"));
-    ASSERT_NE(0, res_obj["hits"][0]["hybrid_search_info"].at("vector_distance"));
+    ASSERT_NE(0, res_obj["hits"][0]["hybrid_search_info"].at("rank_fusion_score"));
 
     // Infix search
     req_params = {
@@ -1429,6 +1427,26 @@ TEST_F(CollectionJoinTest, IncludeExcludeFieldsByReference_SingleMatch) {
     ASSERT_EQ("soap", res_obj["hits"][0]["document"].at("product_name"));
     ASSERT_EQ(1, res_obj["hits"][0]["document"].count("product_price"));
     ASSERT_EQ(73.5, res_obj["hits"][0]["document"].at("product_price"));
+
+    // Reference include_by without join
+    req_params = {
+            {"collection", "Customers"},
+            {"q", "Joe"},
+            {"query_by", "customer_name"},
+            {"filter_by", "product_price:<100"},
+            {"include_fields", "$Products(product_name), product_price"}
+    };
+    search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts);
+    ASSERT_TRUE(search_op.ok());
+
+    res_obj = nlohmann::json::parse(json_res);
+    ASSERT_EQ(1, res_obj["found"].get<size_t>());
+    ASSERT_EQ(1, res_obj["hits"].size());
+    ASSERT_EQ(2, res_obj["hits"][0]["document"].size());
+    ASSERT_EQ(1, res_obj["hits"][0]["document"].count("product_name"));
+    ASSERT_EQ("soap", res_obj["hits"][0]["document"].at("product_name"));
+    ASSERT_EQ(1, res_obj["hits"][0]["document"].count("product_price"));
+    ASSERT_EQ(73.5, res_obj["hits"][0]["document"].at("product_price"));
res_obj["hits"][0]["document"].at("product_price")); } TEST_F(CollectionJoinTest, CascadeDeletion) { diff --git a/test/collection_nested_fields_test.cpp b/test/collection_nested_fields_test.cpp index a2eef13d..98a94f37 100644 --- a/test/collection_nested_fields_test.cpp +++ b/test/collection_nested_fields_test.cpp @@ -2560,6 +2560,144 @@ TEST_F(CollectionNestedFieldsTest, NullValuesWithExplicitSchema) { auto results = coll1->search("jack", {"name.first"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {false}).get(); ASSERT_EQ(1, results["found"].get()); ASSERT_EQ(2, results["hits"][0]["document"].size()); // id, name + ASSERT_EQ(1, results["hits"][0]["document"]["name"].size()); // name.first + ASSERT_EQ("Jack", results["hits"][0]["document"]["name"]["first"].get()); +} + +TEST_F(CollectionNestedFieldsTest, EmplaceWithNullValueOnRequiredField) { + nlohmann::json schema = R"({ + "name": "coll1", + "enable_nested_fields": true, + "fields": [ + {"name":"currency", "type":"object"}, + {"name":"currency.eu", "type":"int32", "optional": false} + ] + })"_json; + + auto op = collectionManager.create_collection(schema); + ASSERT_TRUE(op.ok()); + Collection *coll1 = op.get(); + + auto doc_with_null = R"({ + "id": "0", + "currency": { + "eu": null + } + })"_json; + + auto add_op = coll1->add(doc_with_null.dump(), EMPLACE); + ASSERT_FALSE(add_op.ok()); + + add_op = coll1->add(doc_with_null.dump(), CREATE); + ASSERT_FALSE(add_op.ok()); + + auto doc1 = R"({ + "id": "0", + "currency": { + "eu": 12000 + } + })"_json; + + add_op = coll1->add(doc1.dump(), CREATE); + ASSERT_TRUE(add_op.ok()); + + // now update with null value -- should not be allowed + auto update_doc = R"({ + "id": "0", + "currency": { + "eu": null + } + })"_json; + + auto update_op = coll1->add(update_doc.dump(), EMPLACE); + ASSERT_FALSE(update_op.ok()); + ASSERT_EQ("Field `currency.eu` must be an int32.", update_op.error()); +} + +TEST_F(CollectionNestedFieldsTest, EmplaceWithNullValueOnOptionalField) { + nlohmann::json schema = R"({ + "name": "coll1", + "enable_nested_fields": true, + "fields": [ + {"name":"currency", "type":"object"}, + {"name":"currency.eu", "type":"int32", "optional": true} + ] + })"_json; + + auto op = collectionManager.create_collection(schema); + ASSERT_TRUE(op.ok()); + Collection *coll1 = op.get(); + + auto doc1 = R"({ + "id": "0", + "currency": { + "eu": 12000 + } + })"_json; + + auto add_op = coll1->add(doc1.dump(), CREATE); + ASSERT_TRUE(add_op.ok()); + + // now update with null value -- should be allowed since field is optional + auto update_doc = R"({ + "id": "0", + "currency": { + "eu": null + } + })"_json; + + auto update_op = coll1->add(update_doc.dump(), EMPLACE); + ASSERT_TRUE(update_op.ok()); + + // try to fetch the document to see the stored value + auto results = coll1->search("*", {}, "", {}, {}, {0}, 10, 1, FREQUENCY, {false}).get(); + ASSERT_EQ(1, results["found"].get()); + ASSERT_EQ(2, results["hits"][0]["document"].size()); // id, currency + ASSERT_EQ(0, results["hits"][0]["document"]["currency"].size()); +} + +TEST_F(CollectionNestedFieldsTest, EmplaceWithMissingArrayValueOnOptionalField) { + nlohmann::json schema = R"({ + "name": "coll1", + "enable_nested_fields": true, + "fields": [ + {"name":"currency", "type":"object[]"}, + {"name":"currency.eu", "type":"int32[]", "optional": true} + ] + })"_json; + + auto op = collectionManager.create_collection(schema); + ASSERT_TRUE(op.ok()); + Collection *coll1 = op.get(); + + auto doc1 = R"({ + "id": "0", + "currency": [ + {"eu": 12000}, + {"us": 10000} + ] + 
})"_json; + + auto add_op = coll1->add(doc1.dump(), CREATE); + ASSERT_TRUE(add_op.ok()); + + // now update with null value -- should be allowed since field is optional + auto update_doc = R"({ + "id": "0", + "currency": [ + {"us": 10000} + ] + })"_json; + + auto update_op = coll1->add(update_doc.dump(), EMPLACE); + ASSERT_TRUE(update_op.ok()); + + // try to fetch the document to see the stored value + auto results = coll1->search("*", {}, "", {}, {}, {0}, 10, 1, FREQUENCY, {false}).get(); + ASSERT_EQ(1, results["found"].get()); + ASSERT_EQ(2, results["hits"][0]["document"].size()); // id, currency + ASSERT_EQ(1, results["hits"][0]["document"]["currency"].size()); + ASSERT_EQ(10000, results["hits"][0]["document"]["currency"][0]["us"].get()); } TEST_F(CollectionNestedFieldsTest, UpdateNestedDocument) { diff --git a/test/collection_specific_more_test.cpp b/test/collection_specific_more_test.cpp index 68980d21..aa83229e 100644 --- a/test/collection_specific_more_test.cpp +++ b/test/collection_specific_more_test.cpp @@ -2133,6 +2133,76 @@ TEST_F(CollectionSpecificMoreTest, WeightTakingPrecendeceOverMatch) { ASSERT_EQ(2, res["hits"][1]["text_match_info"]["tokens_matched"].get()); } +TEST_F(CollectionSpecificMoreTest, IncrementingCount) { + nlohmann::json schema = R"({ + "name": "coll1", + "fields": [ + {"name": "title", "type": "string"}, + {"name": "count", "type": "int32"} + ] + })"_json; + + Collection* coll1 = collectionManager.create_collection(schema).get(); + + // brand new document: create + upsert + emplace should work + + nlohmann::json doc; + doc["id"] = "0"; + doc["title"] = "Foo"; + doc["$operations"]["increment"]["count"] = 1; + ASSERT_TRUE(coll1->add(doc.dump(), CREATE).ok()); + + doc.clear(); + doc["id"] = "1"; + doc["title"] = "Bar"; + doc["$operations"]["increment"]["count"] = 1; + ASSERT_TRUE(coll1->add(doc.dump(), EMPLACE).ok()); + + doc.clear(); + doc["id"] = "2"; + doc["title"] = "Taz"; + doc["$operations"]["increment"]["count"] = 1; + ASSERT_TRUE(coll1->add(doc.dump(), UPSERT).ok()); + + auto res = coll1->search("*", {}, "", {}, {}, {2}, 10, 1, FREQUENCY, {true}, 5, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10).get(); + + ASSERT_EQ(3, res["hits"].size()); + ASSERT_EQ(1, res["hits"][0]["document"]["count"].get()); + ASSERT_EQ(1, res["hits"][1]["document"]["count"].get()); + ASSERT_EQ(1, res["hits"][2]["document"]["count"].get()); + + // should support updates + + doc.clear(); + doc["id"] = "0"; + doc["title"] = "Foo"; + doc["$operations"]["increment"]["count"] = 3; + ASSERT_TRUE(coll1->add(doc.dump(), UPSERT).ok()); + + doc.clear(); + doc["id"] = "1"; + doc["title"] = "Bar"; + doc["$operations"]["increment"]["count"] = 3; + ASSERT_TRUE(coll1->add(doc.dump(), EMPLACE).ok()); + + doc.clear(); + doc["id"] = "2"; + doc["title"] = "Bar"; + doc["$operations"]["increment"]["count"] = 3; + ASSERT_TRUE(coll1->add(doc.dump(), UPDATE).ok()); + + res = coll1->search("*", {}, "", {}, {}, {2}, 10, 1, FREQUENCY, {true}, 5, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10).get(); + + ASSERT_EQ(3, res["hits"].size()); + ASSERT_EQ(4, res["hits"][0]["document"]["count"].get()); + ASSERT_EQ(4, res["hits"][1]["document"]["count"].get()); + ASSERT_EQ(4, res["hits"][2]["document"]["count"].get()); +} + TEST_F(CollectionSpecificMoreTest, HighlightOnFieldNameWithDot) { nlohmann::json schema = R"({ "name": "coll1", @@ -2386,3 +2456,55 @@ TEST_F(CollectionSpecificMoreTest, TruncateAterTopK) { ASSERT_EQ(ids[i], results["hits"][i]["document"]["id"]); } } + 
+TEST_F(CollectionSpecificMoreTest, HybridSearchTextMatchInfo) {
+    auto schema_json =
+        R"({
+            "name": "Products",
+            "fields": [
+                {"name": "product_id", "type": "string"},
+                {"name": "product_name", "type": "string", "infix": true},
+                {"name": "product_description", "type": "string"},
+                {"name": "embedding", "type":"float[]", "embed":{"from": ["product_description"], "model_config": {"model_name": "ts/e5-small"}}}
+            ]
+        })"_json;
+    std::vector<nlohmann::json> documents = {
+        R"({
+            "product_id": "product_a",
+            "product_name": "shampoo",
+            "product_description": "Our new moisturizing shampoo is perfect for those with dry or damaged hair."
+        })"_json,
+        R"({
+            "product_id": "product_b",
+            "product_name": "soap",
+            "product_description": "Introducing our all-natural, organic soap bar made with essential oils and botanical ingredients."
+        })"_json
+    };
+
+    TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
+
+    auto collection_create_op = collectionManager.create_collection(schema_json);
+    ASSERT_TRUE(collection_create_op.ok());
+    for (auto const &json: documents) {
+        auto add_op = collection_create_op.get()->add(json.dump());
+        ASSERT_TRUE(add_op.ok());
+    }
+
+    auto coll1 = collection_create_op.get();
+    auto results = coll1->search("natural products", {"product_name", "embedding"},
+                                 "", {}, {}, {2}, 10,
+                                 1, FREQUENCY, {true},
+                                 0, spp::sparse_hash_set<std::string>()).get();
+
+    ASSERT_EQ(2, results["hits"].size());
+
+    // It's a hybrid search with only vector match
+    ASSERT_EQ("0", results["hits"][0]["text_match_info"]["score"].get<std::string>());
+    ASSERT_EQ("0", results["hits"][1]["text_match_info"]["score"].get<std::string>());
+
+    ASSERT_EQ(0, results["hits"][0]["text_match_info"]["fields_matched"].get<size_t>());
+    ASSERT_EQ(0, results["hits"][1]["text_match_info"]["fields_matched"].get<size_t>());
+
+    ASSERT_EQ(0, results["hits"][0]["text_match_info"]["tokens_matched"].get<size_t>());
+    ASSERT_EQ(0, results["hits"][1]["text_match_info"]["tokens_matched"].get<size_t>());
+}
diff --git a/test/collection_test.cpp b/test/collection_test.cpp
index ccc65d2d..ceeec72b 100644
--- a/test/collection_test.cpp
+++ b/test/collection_test.cpp
@@ -398,7 +398,7 @@ TEST_F(CollectionTest, QueryWithTypo) {
                                      spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
                                      "", 10).get();
 
-    ids = {"1", "13", "8"};
+    ids = {"8", "1", "17"};
 
    ASSERT_EQ(3, results["hits"].size());
 
diff --git a/test/collection_vector_search_test.cpp b/test/collection_vector_search_test.cpp
index 3ffe288e..059e4437 100644
--- a/test/collection_vector_search_test.cpp
+++ b/test/collection_vector_search_test.cpp
@@ -1033,6 +1033,134 @@ TEST_F(CollectionVectorTest, EmbedFromOptionalNullField) {
     ASSERT_TRUE(add_op.ok());
 }
 
+TEST_F(CollectionVectorTest, HideCredential) {
+    auto schema_json =
+        R"({
+            "name": "Products",
+            "fields": [
+                {"name": "product_name", "type": "string", "infix": true},
+                {"name": "embedding", "type":"float[]", "embed":{"from": ["product_name"],
+                    "model_config": {
+                        "model_name": "ts/e5-small",
+                        "api_key": "ax-abcdef12345",
+                        "access_token": "ax-abcdef12345",
+                        "refresh_token": "ax-abcdef12345",
+                        "client_id": "ax-abcdef12345",
+                        "client_secret": "ax-abcdef12345",
+                        "project_id": "ax-abcdef12345"
+                    }}}
+            ]
+        })"_json;
+
+    TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
+
+    auto collection_create_op = collectionManager.create_collection(schema_json);
+    ASSERT_TRUE(collection_create_op.ok());
+    auto coll1 = collection_create_op.get();
+    auto coll_summary = coll1->get_summary_json();
+
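+    // Credential-bearing keys are masked in the collection summary: long values
+    // keep their first five characters followed by asterisks, while short values
+    // (see the "Products2" case below) are masked entirely.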
ASSERT_EQ("ax-ab*********", coll_summary["fields"][1]["embed"]["model_config"]["access_token"].get()); + ASSERT_EQ("ax-ab*********", coll_summary["fields"][1]["embed"]["model_config"]["refresh_token"].get()); + ASSERT_EQ("ax-ab*********", coll_summary["fields"][1]["embed"]["model_config"]["client_id"].get()); + ASSERT_EQ("ax-ab*********", coll_summary["fields"][1]["embed"]["model_config"]["client_secret"].get()); + ASSERT_EQ("ax-ab*********", coll_summary["fields"][1]["embed"]["model_config"]["project_id"].get()); + + // small api key + + schema_json = + R"({ + "name": "Products2", + "fields": [ + {"name": "product_name", "type": "string", "infix": true}, + {"name": "embedding", "type":"float[]", "embed":{"from": ["product_name"], + "model_config": { + "model_name": "ts/e5-small", + "api_key": "ax1", + "access_token": "ax1", + "refresh_token": "ax1", + "client_id": "ax1", + "client_secret": "ax1", + "project_id": "ax1" + }}} + ] + })"_json; + + collection_create_op = collectionManager.create_collection(schema_json); + ASSERT_TRUE(collection_create_op.ok()); + auto coll2 = collection_create_op.get(); + coll_summary = coll2->get_summary_json(); + + ASSERT_EQ("***********", coll_summary["fields"][1]["embed"]["model_config"]["api_key"].get()); + ASSERT_EQ("***********", coll_summary["fields"][1]["embed"]["model_config"]["access_token"].get()); + ASSERT_EQ("***********", coll_summary["fields"][1]["embed"]["model_config"]["refresh_token"].get()); + ASSERT_EQ("***********", coll_summary["fields"][1]["embed"]["model_config"]["client_id"].get()); + ASSERT_EQ("***********", coll_summary["fields"][1]["embed"]["model_config"]["client_secret"].get()); + ASSERT_EQ("***********", coll_summary["fields"][1]["embed"]["model_config"]["project_id"].get()); +} + +TEST_F(CollectionVectorTest, UpdateOfCollWithNonOptionalEmbeddingField) { + nlohmann::json schema = R"({ + "name": "objects", + "fields": [ + {"name": "name", "type": "string"}, + {"name": "about", "type": "string"}, + {"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/e5-small"}}} + ] + })"_json; + + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); + + auto op = collectionManager.create_collection(schema); + ASSERT_TRUE(op.ok()); + Collection* coll = op.get(); + + nlohmann::json object; + object["id"] = "0"; + object["name"] = "butter"; + object["about"] = "about butter"; + + auto add_op = coll->add(object.dump(), CREATE); + ASSERT_TRUE(add_op.ok()); + + nlohmann::json update_object; + update_object["id"] = "0"; + update_object["about"] = "something about butter"; + auto update_op = coll->add(update_object.dump(), EMPLACE); + ASSERT_TRUE(update_op.ok()); + + // action = update + update_object["about"] = "something about butter 2"; + update_op = coll->add(update_object.dump(), UPDATE); + ASSERT_TRUE(update_op.ok()); +} + +TEST_F(CollectionVectorTest, FreshEmplaceWithOptionalEmbeddingReferencedField) { + auto schema = R"({ + "name": "objects", + "fields": [ + {"name": "name", "type": "string", "optional": true}, + {"name": "about", "type": "string"}, + {"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/e5-small"}}} + ] + })"_json; + + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); + + auto op = collectionManager.create_collection(schema); + ASSERT_TRUE(op.ok()); + Collection* coll = op.get(); + + nlohmann::json object; + object["id"] = "0"; + object["about"] = "about butter"; + + auto add_op = 
+    auto add_op = coll->add(object.dump(), EMPLACE);
+    ASSERT_FALSE(add_op.ok());
+    ASSERT_EQ("No valid fields found to create embedding for `embedding`, please provide at least one valid field "
+              "or make the embedding field optional.", add_op.error());
+}
+
 TEST_F(CollectionVectorTest, SkipEmbeddingOpWhenValueExists) {
     nlohmann::json schema = R"({
         "name": "objects",
@@ -1102,3 +1230,115 @@ TEST_F(CollectionVectorTest, SkipEmbeddingOpWhenValueExists) {
     ASSERT_FALSE(add_op.ok());
     ASSERT_EQ("Field `embedding` contains invalid float values.", add_op.error());
 }
+
+TEST_F(CollectionVectorTest, SemanticSearchReturnOnlyVectorDistance) {
+    auto schema_json =
+        R"({
+            "name": "Products",
+            "fields": [
+                {"name": "product_name", "type": "string", "infix": true},
+                {"name": "category", "type": "string"},
+                {"name": "embedding", "type":"float[]", "embed":{"from": ["product_name", "category"], "model_config": {"model_name": "ts/e5-small"}}}
+            ]
+        })"_json;
+
+    TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
+
+    auto collection_create_op = collectionManager.create_collection(schema_json);
+    ASSERT_TRUE(collection_create_op.ok());
+    auto coll1 = collection_create_op.get();
+
+    auto add_op = coll1->add(R"({
+        "product_name": "moisturizer",
+        "category": "beauty"
+    })"_json.dump());
+
+    ASSERT_TRUE(add_op.ok());
+
+    auto results = coll1->search("moisturizer", {"embedding"},
+                                 "", {}, {}, {2}, 10,
+                                 1, FREQUENCY, {true},
+                                 0, spp::sparse_hash_set<std::string>()).get();
+
+    ASSERT_EQ(1, results["hits"].size());
+
+    // Return only vector distance
+    ASSERT_EQ(0, results["hits"][0].count("text_match_info"));
+    ASSERT_EQ(0, results["hits"][0].count("hybrid_search_info"));
+    ASSERT_EQ(1, results["hits"][0].count("vector_distance"));
+}
+
+TEST_F(CollectionVectorTest, KeywordSearchReturnOnlyTextMatchInfo) {
+    auto schema_json =
+        R"({
+            "name": "Products",
+            "fields": [
+                {"name": "product_name", "type": "string", "infix": true},
+                {"name": "category", "type": "string"},
+                {"name": "embedding", "type":"float[]", "embed":{"from": ["product_name", "category"], "model_config": {"model_name": "ts/e5-small"}}}
+            ]
+        })"_json;
+
+    TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
+
+    auto collection_create_op = collectionManager.create_collection(schema_json);
+    ASSERT_TRUE(collection_create_op.ok());
+    auto coll1 = collection_create_op.get();
+    auto add_op = coll1->add(R"({
+        "product_name": "moisturizer",
+        "category": "beauty"
+    })"_json.dump());
+    ASSERT_TRUE(add_op.ok());
+
+    auto results = coll1->search("moisturizer", {"product_name"},
+                                 "", {}, {}, {2}, 10,
+                                 1, FREQUENCY, {true},
+                                 0, spp::sparse_hash_set<std::string>()).get();
+
+    ASSERT_EQ(1, results["hits"].size());
+
+    // Return only text match info
+    ASSERT_EQ(0, results["hits"][0].count("vector_distance"));
+    ASSERT_EQ(0, results["hits"][0].count("hybrid_search_info"));
+    ASSERT_EQ(1, results["hits"][0].count("text_match_info"));
+}
+
+TEST_F(CollectionVectorTest, HybridSearchReturnAllInfo) {
+    auto schema_json =
+        R"({
+            "name": "Products",
+            "fields": [
+                {"name": "product_name", "type": "string", "infix": true},
+                {"name": "category", "type": "string"},
+                {"name": "embedding", "type":"float[]", "embed":{"from": ["product_name", "category"], "model_config": {"model_name": "ts/e5-small"}}}
+            ]
+        })"_json;
+
+    TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
+
+    auto collection_create_op = collectionManager.create_collection(schema_json);
+    ASSERT_TRUE(collection_create_op.ok());
+    auto coll1 = collection_create_op.get();
+
+    auto add_op = coll1->add(R"({
+        "product_name": "moisturizer",
+        "category": "beauty"
+    })"_json.dump());
+    ASSERT_TRUE(add_op.ok());
+
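+    // query_by names both a plain text field and the embedding field, making this
+    // a hybrid search: every hit should carry all three scoring blocks.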
coll1->add(R"({ + "product_name": "moisturizer", + "category": "beauty" + })"_json.dump()); + ASSERT_TRUE(add_op.ok()); + + + auto results = coll1->search("moisturizer", {"product_name", "embedding"}, + "", {}, {}, {2}, 10, + 1, FREQUENCY, {true}, + 0, spp::sparse_hash_set()).get(); + + ASSERT_EQ(1, results["hits"].size()); + + // Return all info + ASSERT_EQ(1, results["hits"][0].count("vector_distance")); + ASSERT_EQ(1, results["hits"][0].count("text_match_info")); + ASSERT_EQ(1, results["hits"][0].count("hybrid_search_info")); +}