diff --git a/include/field.h b/include/field.h index 6b4ee937..d34d32ae 100644 --- a/include/field.h +++ b/include/field.h @@ -23,6 +23,7 @@ namespace field_types { static const std::string INT64 = "int64"; static const std::string FLOAT = "float"; static const std::string BOOL = "bool"; + static const std::string NIL = "nil"; static const std::string GEOPOINT = "geopoint"; static const std::string STRING_ARRAY = "string[]"; static const std::string INT32_ARRAY = "int32[]"; @@ -429,19 +430,19 @@ struct field { std::vector& fields_vec); static bool flatten_obj(nlohmann::json& doc, nlohmann::json& value, bool has_array, bool has_obj_array, - const field& the_field, const std::string& flat_name, + bool is_update, const field& the_field, const std::string& flat_name, const std::unordered_map& dyn_fields, std::unordered_map& flattened_fields); static Option flatten_field(nlohmann::json& doc, nlohmann::json& obj, const field& the_field, std::vector& path_parts, size_t path_index, bool has_array, - bool has_obj_array, + bool has_obj_array, bool is_update, const std::unordered_map& dyn_fields, std::unordered_map& flattened_fields); static Option flatten_doc(nlohmann::json& document, const tsl::htrie_map& nested_fields, const std::unordered_map& dyn_fields, - bool missing_is_ok, std::vector& flattened_fields); + bool is_update, std::vector& flattened_fields); static void compact_nested_fields(tsl::htrie_map& nested_fields); }; diff --git a/include/topster.h b/include/topster.h index b0b8f125..a16b4440 100644 --- a/include/topster.h +++ b/include/topster.h @@ -24,13 +24,17 @@ struct KV { // to be used only in final aggregation uint64_t* query_indices = nullptr; - KV(uint16_t queryIndex, uint64_t key, uint64_t distinct_key, uint8_t match_score_index, const int64_t *scores, + KV(uint16_t queryIndex, uint64_t key, uint64_t distinct_key, int8_t match_score_index, const int64_t *scores, reference_filter_result_t* reference_filter_result = nullptr): match_score_index(match_score_index), query_index(queryIndex), array_index(0), key(key), distinct_key(distinct_key), reference_filter_result(reference_filter_result) { this->scores[0] = scores[0]; this->scores[1] = scores[1]; this->scores[2] = scores[2]; + + if(match_score_index >= 0) { + this->text_match_score = scores[match_score_index]; + } } KV() = default; diff --git a/src/collection.cpp b/src/collection.cpp index 1b00fb8a..8883f4ab 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -261,10 +261,6 @@ nlohmann::json Collection::get_summary_json() const { field_json[fields::reference] = coll_field.reference; } - if(!coll_field.embed.empty()) { - field_json[fields::embed] = coll_field.embed; - } - fields_arr.push_back(field_json); } @@ -1043,7 +1039,7 @@ Option Collection::extract_field_name(const std::string& field_name, for(auto kv = prefix_it.first; kv != prefix_it.second; ++kv) { bool exact_key_match = (kv.key().size() == field_name.size()); bool exact_primitive_match = exact_key_match && !kv.value().is_object(); - bool text_embedding = kv.value().type == field_types::FLOAT_ARRAY && kv.value().embed.count(fields::from) != 0; + bool text_embedding = kv.value().type == field_types::FLOAT_ARRAY && kv.value().num_dim > 0; if(extract_only_string_fields && !kv.value().is_string() && !text_embedding) { if(exact_primitive_match && !is_wildcard) { @@ -1073,7 +1069,7 @@ Option Collection::extract_field_name(const std::string& field_name, return Option(true); } -Option Collection::search(std::string raw_query, +Option Collection::search(std::string raw_query, const std::vector& raw_search_fields, const std::string & filter_query, const std::vector& facet_fields, const std::vector & sort_fields, const std::vector& num_typos, @@ -1201,6 +1197,7 @@ Option Collection::search(std::string raw_query, std::vector processed_search_fields; std::vector query_by_weights; size_t num_embed_fields = 0; + std::string query = raw_query; for(size_t i = 0; i < raw_search_fields.size(); i++) { const std::string& field_name = raw_search_fields[i]; @@ -1289,6 +1286,11 @@ Option Collection::search(std::string raw_query, } } + // Set query to * if it is semantic search + if(!vector_query.field_name.empty() && processed_search_fields.empty()) { + query = "*"; + } + if(!vector_query.field_name.empty() && vector_query.values.empty() && num_embed_fields == 0) { std::string error = "Vector query could not find any embedded fields."; return Option(400, error); @@ -1444,7 +1446,7 @@ Option Collection::search(std::string raw_query, size_t max_hits = DEFAULT_TOPSTER_SIZE; // ensure that `max_hits` never exceeds number of documents in collection - if(search_fields.size() <= 1 || raw_query == "*") { + if(search_fields.size() <= 1 || query == "*") { max_hits = std::min(std::max(fetch_size, max_hits), get_num_documents()); } else { max_hits = std::min(std::max(fetch_size, max_hits), get_num_documents()); @@ -1477,7 +1479,6 @@ Option Collection::search(std::string raw_query, StringUtils::split(hidden_hits_str, hidden_hits, ","); std::vector filter_overrides; - std::string query = raw_query; bool filter_curated_hits = false; std::string curated_sort_by; curate_results(query, filter_query, enable_overrides, pre_segmented_query, pinned_hits, hidden_hits, @@ -1950,10 +1951,10 @@ Option Collection::search(std::string raw_query, if(field_order_kv->match_score_index == CURATED_RECORD_IDENTIFIER) { wrapper_doc["curated"] = true; } else if(field_order_kv->match_score_index >= 0) { - wrapper_doc["text_match"] = field_order_kv->scores[field_order_kv->match_score_index]; + wrapper_doc["text_match"] = field_order_kv->text_match_score; wrapper_doc["text_match_info"] = nlohmann::json::object(); populate_text_match_info(wrapper_doc["text_match_info"], - field_order_kv->scores[field_order_kv->match_score_index], match_type); + field_order_kv->text_match_score, match_type); if(!vector_query.field_name.empty()) { wrapper_doc["hybrid_search_info"] = nlohmann::json::object(); wrapper_doc["hybrid_search_info"]["rank_fusion_score"] = Index::int64_t_to_float(field_order_kv->scores[field_order_kv->match_score_index]); @@ -4931,9 +4932,10 @@ void Collection::hide_credential(nlohmann::json& json, const std::string& creden // hide api key with * except first 5 chars std::string credential_name_str = json[credential_name]; if(credential_name_str.size() > 5) { - json[credential_name] = credential_name_str.replace(5, credential_name_str.size() - 5, credential_name_str.size() - 5, '*'); + size_t num_chars_to_replace = credential_name_str.size() - 5; + json[credential_name] = credential_name_str.replace(5, num_chars_to_replace, num_chars_to_replace, '*'); } else { - json[credential_name] = credential_name_str.replace(0, credential_name_str.size(), credential_name_str.size(), '*'); + json[credential_name] = "***********"; } } } diff --git a/src/field.cpp b/src/field.cpp index bccd8da5..7d5e399c 100644 --- a/src/field.cpp +++ b/src/field.cpp @@ -824,18 +824,41 @@ Option field::json_field_to_field(bool enable_nested_fields, nlohmann::jso } bool field::flatten_obj(nlohmann::json& doc, nlohmann::json& value, bool has_array, bool has_obj_array, - const field& the_field, const std::string& flat_name, + bool is_update, const field& the_field, const std::string& flat_name, const std::unordered_map& dyn_fields, std::unordered_map& flattened_fields) { if(value.is_object()) { has_obj_array = has_array; - for(const auto& kv: value.items()) { - flatten_obj(doc, kv.value(), has_array, has_obj_array, the_field, flat_name + "." + kv.key(), - dyn_fields, flattened_fields); + auto it = value.begin(); + while(it != value.end()) { + const std::string& child_field_name = flat_name + "." + it.key(); + if(it.value().is_null()) { + if(has_array) { + doc[child_field_name].push_back(nullptr); + } else { + doc[child_field_name] = nullptr; + } + + field flattened_field; + flattened_field.name = child_field_name; + flattened_field.type = field_types::NIL; + flattened_fields[child_field_name] = flattened_field; + + if(!is_update) { + // update code path requires and takes care of null values + it = value.erase(it); + } else { + it++; + } + } else { + flatten_obj(doc, it.value(), has_array, has_obj_array, is_update, the_field, child_field_name, + dyn_fields, flattened_fields); + it++; + } } } else if(value.is_array()) { for(const auto& kv: value.items()) { - flatten_obj(doc, kv.value(), true, has_obj_array, the_field, flat_name, dyn_fields, flattened_fields); + flatten_obj(doc, kv.value(), true, has_obj_array, is_update, the_field, flat_name, dyn_fields, flattened_fields); } } else { // must be a primitive if(doc.count(flat_name) != 0 && flattened_fields.find(flat_name) == flattened_fields.end()) { @@ -891,7 +914,7 @@ bool field::flatten_obj(nlohmann::json& doc, nlohmann::json& value, bool has_arr Option field::flatten_field(nlohmann::json& doc, nlohmann::json& obj, const field& the_field, std::vector& path_parts, size_t path_index, - bool has_array, bool has_obj_array, + bool has_array, bool has_obj_array, bool is_update, const std::unordered_map& dyn_fields, std::unordered_map& flattened_fields) { if(path_index == path_parts.size()) { @@ -946,7 +969,8 @@ Option field::flatten_field(nlohmann::json& doc, nlohmann::json& obj, cons if(detected_type == the_field.type || is_numericaly_valid) { if(the_field.is_object()) { - flatten_obj(doc, obj, has_array, has_obj_array, the_field, the_field.name, dyn_fields, flattened_fields); + flatten_obj(doc, obj, has_array, has_obj_array, is_update, the_field, the_field.name, + dyn_fields, flattened_fields); } else { if(doc.count(the_field.name) != 0 && flattened_fields.find(the_field.name) == flattened_fields.end()) { return Option(true); @@ -989,7 +1013,7 @@ Option field::flatten_field(nlohmann::json& doc, nlohmann::json& obj, cons for(auto& ele: it.value()) { has_obj_array = has_obj_array || ele.is_object(); Option op = flatten_field(doc, ele, the_field, path_parts, path_index + 1, has_array, - has_obj_array, dyn_fields, flattened_fields); + has_obj_array, is_update, dyn_fields, flattened_fields); if(!op.ok()) { return op; } @@ -997,7 +1021,7 @@ Option field::flatten_field(nlohmann::json& doc, nlohmann::json& obj, cons return Option(true); } else { return flatten_field(doc, it.value(), the_field, path_parts, path_index + 1, has_array, has_obj_array, - dyn_fields, flattened_fields); + is_update, dyn_fields, flattened_fields); } } { return Option(404, "Field `" + the_field.name + "` not found."); @@ -1007,7 +1031,7 @@ Option field::flatten_field(nlohmann::json& doc, nlohmann::json& obj, cons Option field::flatten_doc(nlohmann::json& document, const tsl::htrie_map& nested_fields, const std::unordered_map& dyn_fields, - bool missing_is_ok, std::vector& flattened_fields) { + bool is_update, std::vector& flattened_fields) { std::unordered_map flattened_fields_map; @@ -1021,12 +1045,12 @@ Option field::flatten_doc(nlohmann::json& document, } auto op = flatten_field(document, document, nested_field, field_parts, 0, false, false, - dyn_fields, flattened_fields_map); + is_update, dyn_fields, flattened_fields_map); if(op.ok()) { continue; } - if(op.code() == 404 && (missing_is_ok || nested_field.optional)) { + if(op.code() == 404 && (is_update || nested_field.optional)) { continue; } else { return op; @@ -1036,7 +1060,10 @@ Option field::flatten_doc(nlohmann::json& document, document[".flat"] = nlohmann::json::array(); for(auto& kv: flattened_fields_map) { document[".flat"].push_back(kv.second.name); - flattened_fields.push_back(kv.second); + if(kv.second.type != field_types::NIL) { + // not a real field so we won't add it + flattened_fields.push_back(kv.second); + } } return Option(true); diff --git a/src/http_server.cpp b/src/http_server.cpp index 9fbfb2a5..50391ed8 100644 --- a/src/http_server.cpp +++ b/src/http_server.cpp @@ -569,13 +569,11 @@ int HttpServer::async_req_cb(void *ctx, int is_end_stream) { bool async_req = custom_generator->rpath->async_req; bool is_http_v1 = (0x101 <= request->_req->version && request->_req->version < 0x200); - /* - LOG(INFO) << "async_req_cb, chunk.len=" << chunk.len + /*LOG(INFO) << "async_req_cb, chunk.len=" << chunk.len << ", is_http_v1: " << is_http_v1 - << ", request->req->entity.len=" << request->req->entity.len - << ", content_len: " << request->req->content_length - << ", is_end_stream=" << is_end_stream; - */ + << ", request->req->entity.len=" << request->_req->entity.len + << ", content_len: " << request->_req->content_length + << ", is_end_stream=" << is_end_stream;*/ // disallow specific curl clients from using import call via http2 // detects: https://github.com/curl/curl/issues/1410 diff --git a/src/index.cpp b/src/index.cpp index a87f245c..4d81126a 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -434,6 +434,8 @@ void Index::validate_and_preprocess(Index *index, std::vector& ite continue; } + handle_doc_ops(search_schema, index_rec.doc, index_rec.old_doc); + if(do_validation) { Option validation_op = validator_t::validate_index_in_memory(index_rec.doc, index_rec.seq_id, default_sorting_field, @@ -471,7 +473,6 @@ void Index::validate_and_preprocess(Index *index, std::vector& ite } } } else { - handle_doc_ops(search_schema, index_rec.doc, index_rec.old_doc); if(generate_embeddings) { records_to_embed.push_back(&index_rec); } @@ -867,9 +868,8 @@ void Index::index_field_in_memory(const field& afield, std::vector continue; } - const std::vector& float_vals = record.doc[afield.name].get>(); - try { + const std::vector& float_vals = record.doc[afield.name].get>(); if(afield.vec_dist == cosine) { std::vector normalized_vals(afield.num_dim); hnsw_index_t::normalize_vector(float_vals, normalized_vals); @@ -3213,6 +3213,7 @@ Option Index::search(std::vector& field_query_tokens, cons auto result = result_it->second; // old_score + (1 / rank_of_document) * WEIGHT) result->vector_distance = vec_result.second; + result->text_match_score = result->scores[result->match_score_index]; int64_t match_score = float_to_int64_t( (int64_t_to_float(result->scores[result->match_score_index])) + ((1.0 / (res_index + 1)) * VECTOR_SEARCH_WEIGHT)); @@ -3234,6 +3235,7 @@ Option Index::search(std::vector& field_query_tokens, cons int64_t match_score_index = -1; compute_sort_scores(sort_fields_std, sort_order, field_values, geopoint_indices, doc_id, 0, match_score, scores, match_score_index, vec_result.second); KV kv(searched_queries.size(), doc_id, doc_id, match_score_index, scores); + kv.text_match_score = 0; kv.vector_distance = vec_result.second; topster->add(&kv); vec_search_ids.push_back(doc_id); @@ -4163,6 +4165,7 @@ void Index::search_across_fields(const std::vector& query_tokens, KV kv(searched_queries.size(), seq_id, distinct_id, match_score_index, scores); if(match_score_index != -1) { kv.scores[match_score_index] = aggregated_score; + kv.text_match_score = aggregated_score; } int ret = topster->add(&kv); @@ -6258,7 +6261,6 @@ void Index::get_doc_changes(const index_operation_t op, const tsl::htrie_map& request, // Handle no leader scenario LOG(ERROR) << "Rejecting write: could not find a leader."; - if(request->_req->proceed_req && response->proxied_stream) { + if(response->proxied_stream) { // streaming in progress: ensure graceful termination (cannot start response again) LOG(ERROR) << "Terminating streaming request gracefully."; response->is_alive = false; @@ -267,7 +267,7 @@ void ReplicationState::write_to_leader(const std::shared_ptr& request, return message_dispatcher->send_message(HttpServer::STREAM_RESPONSE_MESSAGE, req_res); } - if (request->_req->proceed_req && response->proxied_stream) { + if (response->proxied_stream) { // indicates async request body of in-flight request //LOG(INFO) << "Inflight proxied request, returning control to caller, body_size=" << request->body.size(); request->notify(); diff --git a/src/validator.cpp b/src/validator.cpp index 51a7d19c..f8c23ee9 100644 --- a/src/validator.cpp +++ b/src/validator.cpp @@ -626,7 +626,7 @@ Option validator_t::validate_index_in_memory(nlohmann::json& document, continue; } - if((a_field.optional || op == UPDATE || op == EMPLACE) && document.count(field_name) == 0) { + if((a_field.optional || op == UPDATE || (op == EMPLACE && is_update)) && document.count(field_name) == 0) { continue; } @@ -716,7 +716,7 @@ Option validator_t::validate_embed_fields(const nlohmann::json& document, } } } - if(all_optional_and_null && !field.optional) { + if(all_optional_and_null && !field.optional && !is_update) { return Option(400, "No valid fields found to create embedding for `" + field.name + "`, please provide at least one valid field or make the embedding field optional."); } } diff --git a/test/collection_nested_fields_test.cpp b/test/collection_nested_fields_test.cpp index a2eef13d..98a94f37 100644 --- a/test/collection_nested_fields_test.cpp +++ b/test/collection_nested_fields_test.cpp @@ -2560,6 +2560,144 @@ TEST_F(CollectionNestedFieldsTest, NullValuesWithExplicitSchema) { auto results = coll1->search("jack", {"name.first"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {false}).get(); ASSERT_EQ(1, results["found"].get()); ASSERT_EQ(2, results["hits"][0]["document"].size()); // id, name + ASSERT_EQ(1, results["hits"][0]["document"]["name"].size()); // name.first + ASSERT_EQ("Jack", results["hits"][0]["document"]["name"]["first"].get()); +} + +TEST_F(CollectionNestedFieldsTest, EmplaceWithNullValueOnRequiredField) { + nlohmann::json schema = R"({ + "name": "coll1", + "enable_nested_fields": true, + "fields": [ + {"name":"currency", "type":"object"}, + {"name":"currency.eu", "type":"int32", "optional": false} + ] + })"_json; + + auto op = collectionManager.create_collection(schema); + ASSERT_TRUE(op.ok()); + Collection *coll1 = op.get(); + + auto doc_with_null = R"({ + "id": "0", + "currency": { + "eu": null + } + })"_json; + + auto add_op = coll1->add(doc_with_null.dump(), EMPLACE); + ASSERT_FALSE(add_op.ok()); + + add_op = coll1->add(doc_with_null.dump(), CREATE); + ASSERT_FALSE(add_op.ok()); + + auto doc1 = R"({ + "id": "0", + "currency": { + "eu": 12000 + } + })"_json; + + add_op = coll1->add(doc1.dump(), CREATE); + ASSERT_TRUE(add_op.ok()); + + // now update with null value -- should not be allowed + auto update_doc = R"({ + "id": "0", + "currency": { + "eu": null + } + })"_json; + + auto update_op = coll1->add(update_doc.dump(), EMPLACE); + ASSERT_FALSE(update_op.ok()); + ASSERT_EQ("Field `currency.eu` must be an int32.", update_op.error()); +} + +TEST_F(CollectionNestedFieldsTest, EmplaceWithNullValueOnOptionalField) { + nlohmann::json schema = R"({ + "name": "coll1", + "enable_nested_fields": true, + "fields": [ + {"name":"currency", "type":"object"}, + {"name":"currency.eu", "type":"int32", "optional": true} + ] + })"_json; + + auto op = collectionManager.create_collection(schema); + ASSERT_TRUE(op.ok()); + Collection *coll1 = op.get(); + + auto doc1 = R"({ + "id": "0", + "currency": { + "eu": 12000 + } + })"_json; + + auto add_op = coll1->add(doc1.dump(), CREATE); + ASSERT_TRUE(add_op.ok()); + + // now update with null value -- should be allowed since field is optional + auto update_doc = R"({ + "id": "0", + "currency": { + "eu": null + } + })"_json; + + auto update_op = coll1->add(update_doc.dump(), EMPLACE); + ASSERT_TRUE(update_op.ok()); + + // try to fetch the document to see the stored value + auto results = coll1->search("*", {}, "", {}, {}, {0}, 10, 1, FREQUENCY, {false}).get(); + ASSERT_EQ(1, results["found"].get()); + ASSERT_EQ(2, results["hits"][0]["document"].size()); // id, currency + ASSERT_EQ(0, results["hits"][0]["document"]["currency"].size()); +} + +TEST_F(CollectionNestedFieldsTest, EmplaceWithMissingArrayValueOnOptionalField) { + nlohmann::json schema = R"({ + "name": "coll1", + "enable_nested_fields": true, + "fields": [ + {"name":"currency", "type":"object[]"}, + {"name":"currency.eu", "type":"int32[]", "optional": true} + ] + })"_json; + + auto op = collectionManager.create_collection(schema); + ASSERT_TRUE(op.ok()); + Collection *coll1 = op.get(); + + auto doc1 = R"({ + "id": "0", + "currency": [ + {"eu": 12000}, + {"us": 10000} + ] + })"_json; + + auto add_op = coll1->add(doc1.dump(), CREATE); + ASSERT_TRUE(add_op.ok()); + + // now update with null value -- should be allowed since field is optional + auto update_doc = R"({ + "id": "0", + "currency": [ + {"us": 10000} + ] + })"_json; + + auto update_op = coll1->add(update_doc.dump(), EMPLACE); + ASSERT_TRUE(update_op.ok()); + + // try to fetch the document to see the stored value + auto results = coll1->search("*", {}, "", {}, {}, {0}, 10, 1, FREQUENCY, {false}).get(); + ASSERT_EQ(1, results["found"].get()); + ASSERT_EQ(2, results["hits"][0]["document"].size()); // id, currency + ASSERT_EQ(1, results["hits"][0]["document"]["currency"].size()); + ASSERT_EQ(10000, results["hits"][0]["document"]["currency"][0]["us"].get()); } TEST_F(CollectionNestedFieldsTest, UpdateNestedDocument) { diff --git a/test/collection_specific_more_test.cpp b/test/collection_specific_more_test.cpp index e3f82f69..83181283 100644 --- a/test/collection_specific_more_test.cpp +++ b/test/collection_specific_more_test.cpp @@ -2123,6 +2123,76 @@ TEST_F(CollectionSpecificMoreTest, WeightTakingPrecendeceOverMatch) { ASSERT_EQ(2, res["hits"][1]["text_match_info"]["tokens_matched"].get()); } +TEST_F(CollectionSpecificMoreTest, IncrementingCount) { + nlohmann::json schema = R"({ + "name": "coll1", + "fields": [ + {"name": "title", "type": "string"}, + {"name": "count", "type": "int32"} + ] + })"_json; + + Collection* coll1 = collectionManager.create_collection(schema).get(); + + // brand new document: create + upsert + emplace should work + + nlohmann::json doc; + doc["id"] = "0"; + doc["title"] = "Foo"; + doc["$operations"]["increment"]["count"] = 1; + ASSERT_TRUE(coll1->add(doc.dump(), CREATE).ok()); + + doc.clear(); + doc["id"] = "1"; + doc["title"] = "Bar"; + doc["$operations"]["increment"]["count"] = 1; + ASSERT_TRUE(coll1->add(doc.dump(), EMPLACE).ok()); + + doc.clear(); + doc["id"] = "2"; + doc["title"] = "Taz"; + doc["$operations"]["increment"]["count"] = 1; + ASSERT_TRUE(coll1->add(doc.dump(), UPSERT).ok()); + + auto res = coll1->search("*", {}, "", {}, {}, {2}, 10, 1, FREQUENCY, {true}, 5, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10).get(); + + ASSERT_EQ(3, res["hits"].size()); + ASSERT_EQ(1, res["hits"][0]["document"]["count"].get()); + ASSERT_EQ(1, res["hits"][1]["document"]["count"].get()); + ASSERT_EQ(1, res["hits"][2]["document"]["count"].get()); + + // should support updates + + doc.clear(); + doc["id"] = "0"; + doc["title"] = "Foo"; + doc["$operations"]["increment"]["count"] = 3; + ASSERT_TRUE(coll1->add(doc.dump(), UPSERT).ok()); + + doc.clear(); + doc["id"] = "1"; + doc["title"] = "Bar"; + doc["$operations"]["increment"]["count"] = 3; + ASSERT_TRUE(coll1->add(doc.dump(), EMPLACE).ok()); + + doc.clear(); + doc["id"] = "2"; + doc["title"] = "Bar"; + doc["$operations"]["increment"]["count"] = 3; + ASSERT_TRUE(coll1->add(doc.dump(), UPDATE).ok()); + + res = coll1->search("*", {}, "", {}, {}, {2}, 10, 1, FREQUENCY, {true}, 5, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10).get(); + + ASSERT_EQ(3, res["hits"].size()); + ASSERT_EQ(4, res["hits"][0]["document"]["count"].get()); + ASSERT_EQ(4, res["hits"][1]["document"]["count"].get()); + ASSERT_EQ(4, res["hits"][2]["document"]["count"].get()); +} + TEST_F(CollectionSpecificMoreTest, HighlightOnFieldNameWithDot) { nlohmann::json schema = R"({ "name": "coll1", @@ -2530,4 +2600,58 @@ TEST_F(CollectionSpecificMoreTest, ApproxFilterMatchCount) { delete filter_tree_root; collectionManager.drop_collection("Collection"); -} \ No newline at end of file +} + +TEST_F(CollectionSpecificMoreTest, HybridSearchTextMatchInfo) { + auto schema_json = + R"({ + "name": "Products", + "fields": [ + {"name": "product_id", "type": "string"}, + {"name": "product_name", "type": "string", "infix": true}, + {"name": "product_description", "type": "string"}, + {"name": "embedding", "type":"float[]", "embed":{"from": ["product_description"], "model_config": {"model_name": "ts/e5-small"}}} + ] + })"_json; + std::vector documents = { + R"({ + "product_id": "product_a", + "product_name": "shampoo", + "product_description": "Our new moisturizing shampoo is perfect for those with dry or damaged hair." + })"_json, + R"({ + "product_id": "product_b", + "product_name": "soap", + "product_description": "Introducing our all-natural, organic soap bar made with essential oils and botanical ingredients." + })"_json + }; + + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); + + auto collection_create_op = collectionManager.create_collection(schema_json); + ASSERT_TRUE(collection_create_op.ok()); + for (auto const &json: documents) { + auto add_op = collection_create_op.get()->add(json.dump()); + ASSERT_TRUE(add_op.ok()); + } + + auto coll1 = collection_create_op.get(); + auto results = coll1->search("natural products", {"product_name", "embedding"}, + "", {}, {}, {2}, 10, + 1, FREQUENCY, {true}, + 0, spp::sparse_hash_set()).get(); + + ASSERT_EQ(2, results["hits"].size()); + + // It's a hybrid search with only vector match + ASSERT_EQ("0", results["hits"][0]["text_match_info"]["score"].get()); + ASSERT_EQ("0", results["hits"][1]["text_match_info"]["score"].get()); + + ASSERT_EQ(0, results["hits"][0]["text_match_info"]["fields_matched"].get()); + ASSERT_EQ(0, results["hits"][1]["text_match_info"]["fields_matched"].get()); + + ASSERT_EQ(0, results["hits"][0]["text_match_info"]["tokens_matched"].get()); + ASSERT_EQ(0, results["hits"][1]["text_match_info"]["tokens_matched"].get()); +} + + diff --git a/test/collection_vector_search_test.cpp b/test/collection_vector_search_test.cpp index 3ffe288e..059e4437 100644 --- a/test/collection_vector_search_test.cpp +++ b/test/collection_vector_search_test.cpp @@ -1033,6 +1033,134 @@ TEST_F(CollectionVectorTest, EmbedFromOptionalNullField) { ASSERT_TRUE(add_op.ok()); } +TEST_F(CollectionVectorTest, HideCredential) { + auto schema_json = + R"({ + "name": "Products", + "fields": [ + {"name": "product_name", "type": "string", "infix": true}, + {"name": "embedding", "type":"float[]", "embed":{"from": ["product_name"], + "model_config": { + "model_name": "ts/e5-small", + "api_key": "ax-abcdef12345", + "access_token": "ax-abcdef12345", + "refresh_token": "ax-abcdef12345", + "client_id": "ax-abcdef12345", + "client_secret": "ax-abcdef12345", + "project_id": "ax-abcdef12345" + }}} + ] + })"_json; + + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); + + auto collection_create_op = collectionManager.create_collection(schema_json); + ASSERT_TRUE(collection_create_op.ok()); + auto coll1 = collection_create_op.get(); + auto coll_summary = coll1->get_summary_json(); + + ASSERT_EQ("ax-ab*********", coll_summary["fields"][1]["embed"]["model_config"]["api_key"].get()); + ASSERT_EQ("ax-ab*********", coll_summary["fields"][1]["embed"]["model_config"]["access_token"].get()); + ASSERT_EQ("ax-ab*********", coll_summary["fields"][1]["embed"]["model_config"]["refresh_token"].get()); + ASSERT_EQ("ax-ab*********", coll_summary["fields"][1]["embed"]["model_config"]["client_id"].get()); + ASSERT_EQ("ax-ab*********", coll_summary["fields"][1]["embed"]["model_config"]["client_secret"].get()); + ASSERT_EQ("ax-ab*********", coll_summary["fields"][1]["embed"]["model_config"]["project_id"].get()); + + // small api key + + schema_json = + R"({ + "name": "Products2", + "fields": [ + {"name": "product_name", "type": "string", "infix": true}, + {"name": "embedding", "type":"float[]", "embed":{"from": ["product_name"], + "model_config": { + "model_name": "ts/e5-small", + "api_key": "ax1", + "access_token": "ax1", + "refresh_token": "ax1", + "client_id": "ax1", + "client_secret": "ax1", + "project_id": "ax1" + }}} + ] + })"_json; + + collection_create_op = collectionManager.create_collection(schema_json); + ASSERT_TRUE(collection_create_op.ok()); + auto coll2 = collection_create_op.get(); + coll_summary = coll2->get_summary_json(); + + ASSERT_EQ("***********", coll_summary["fields"][1]["embed"]["model_config"]["api_key"].get()); + ASSERT_EQ("***********", coll_summary["fields"][1]["embed"]["model_config"]["access_token"].get()); + ASSERT_EQ("***********", coll_summary["fields"][1]["embed"]["model_config"]["refresh_token"].get()); + ASSERT_EQ("***********", coll_summary["fields"][1]["embed"]["model_config"]["client_id"].get()); + ASSERT_EQ("***********", coll_summary["fields"][1]["embed"]["model_config"]["client_secret"].get()); + ASSERT_EQ("***********", coll_summary["fields"][1]["embed"]["model_config"]["project_id"].get()); +} + +TEST_F(CollectionVectorTest, UpdateOfCollWithNonOptionalEmbeddingField) { + nlohmann::json schema = R"({ + "name": "objects", + "fields": [ + {"name": "name", "type": "string"}, + {"name": "about", "type": "string"}, + {"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/e5-small"}}} + ] + })"_json; + + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); + + auto op = collectionManager.create_collection(schema); + ASSERT_TRUE(op.ok()); + Collection* coll = op.get(); + + nlohmann::json object; + object["id"] = "0"; + object["name"] = "butter"; + object["about"] = "about butter"; + + auto add_op = coll->add(object.dump(), CREATE); + ASSERT_TRUE(add_op.ok()); + + nlohmann::json update_object; + update_object["id"] = "0"; + update_object["about"] = "something about butter"; + auto update_op = coll->add(update_object.dump(), EMPLACE); + ASSERT_TRUE(update_op.ok()); + + // action = update + update_object["about"] = "something about butter 2"; + update_op = coll->add(update_object.dump(), UPDATE); + ASSERT_TRUE(update_op.ok()); +} + +TEST_F(CollectionVectorTest, FreshEmplaceWithOptionalEmbeddingReferencedField) { + auto schema = R"({ + "name": "objects", + "fields": [ + {"name": "name", "type": "string", "optional": true}, + {"name": "about", "type": "string"}, + {"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/e5-small"}}} + ] + })"_json; + + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); + + auto op = collectionManager.create_collection(schema); + ASSERT_TRUE(op.ok()); + Collection* coll = op.get(); + + nlohmann::json object; + object["id"] = "0"; + object["about"] = "about butter"; + + auto add_op = coll->add(object.dump(), EMPLACE); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("No valid fields found to create embedding for `embedding`, please provide at least one valid field " + "or make the embedding field optional.", add_op.error()); +} + TEST_F(CollectionVectorTest, SkipEmbeddingOpWhenValueExists) { nlohmann::json schema = R"({ "name": "objects", @@ -1102,3 +1230,115 @@ TEST_F(CollectionVectorTest, SkipEmbeddingOpWhenValueExists) { ASSERT_FALSE(add_op.ok()); ASSERT_EQ("Field `embedding` contains invalid float values.", add_op.error()); } + +TEST_F(CollectionVectorTest, SemanticSearchReturnOnlyVectorDistance) { + auto schema_json = + R"({ + "name": "Products", + "fields": [ + {"name": "product_name", "type": "string", "infix": true}, + {"name": "category", "type": "string"}, + {"name": "embedding", "type":"float[]", "embed":{"from": ["product_name", "category"], "model_config": {"model_name": "ts/e5-small"}}} + ] + })"_json; + + + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); + + auto collection_create_op = collectionManager.create_collection(schema_json); + ASSERT_TRUE(collection_create_op.ok()); + auto coll1 = collection_create_op.get(); + + auto add_op = coll1->add(R"({ + "product_name": "moisturizer", + "category": "beauty" + })"_json.dump()); + + ASSERT_TRUE(add_op.ok()); + + auto results = coll1->search("moisturizer", {"embedding"}, + "", {}, {}, {2}, 10, + 1, FREQUENCY, {true}, + 0, spp::sparse_hash_set()).get(); + + ASSERT_EQ(1, results["hits"].size()); + + // Return only vector distance + ASSERT_EQ(0, results["hits"][0].count("text_match_info")); + ASSERT_EQ(0, results["hits"][0].count("hybrid_search_info")); + ASSERT_EQ(1, results["hits"][0].count("vector_distance")); +} + +TEST_F(CollectionVectorTest, KeywordSearchReturnOnlyTextMatchInfo) { + auto schema_json = + R"({ + "name": "Products", + "fields": [ + {"name": "product_name", "type": "string", "infix": true}, + {"name": "category", "type": "string"}, + {"name": "embedding", "type":"float[]", "embed":{"from": ["product_name", "category"], "model_config": {"model_name": "ts/e5-small"}}} + ] + })"_json; + + + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); + + auto collection_create_op = collectionManager.create_collection(schema_json); + ASSERT_TRUE(collection_create_op.ok()); + auto coll1 = collection_create_op.get(); + auto add_op = coll1->add(R"({ + "product_name": "moisturizer", + "category": "beauty" + })"_json.dump()); + ASSERT_TRUE(add_op.ok()); + + auto results = coll1->search("moisturizer", {"product_name"}, + "", {}, {}, {2}, 10, + 1, FREQUENCY, {true}, + 0, spp::sparse_hash_set()).get(); + + + ASSERT_EQ(1, results["hits"].size()); + + // Return only text match info + ASSERT_EQ(0, results["hits"][0].count("vector_distance")); + ASSERT_EQ(0, results["hits"][0].count("hybrid_search_info")); + ASSERT_EQ(1, results["hits"][0].count("text_match_info")); +} + +TEST_F(CollectionVectorTest, HybridSearchReturnAllInfo) { + auto schema_json = + R"({ + "name": "Products", + "fields": [ + {"name": "product_name", "type": "string", "infix": true}, + {"name": "category", "type": "string"}, + {"name": "embedding", "type":"float[]", "embed":{"from": ["product_name", "category"], "model_config": {"model_name": "ts/e5-small"}}} + ] + })"_json; + + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); + + auto collection_create_op = collectionManager.create_collection(schema_json); + ASSERT_TRUE(collection_create_op.ok()); + auto coll1 = collection_create_op.get(); + + auto add_op = coll1->add(R"({ + "product_name": "moisturizer", + "category": "beauty" + })"_json.dump()); + ASSERT_TRUE(add_op.ok()); + + + auto results = coll1->search("moisturizer", {"product_name", "embedding"}, + "", {}, {}, {2}, 10, + 1, FREQUENCY, {true}, + 0, spp::sparse_hash_set()).get(); + + ASSERT_EQ(1, results["hits"].size()); + + // Return all info + ASSERT_EQ(1, results["hits"][0].count("vector_distance")); + ASSERT_EQ(1, results["hits"][0].count("text_match_info")); + ASSERT_EQ(1, results["hits"][0].count("hybrid_search_info")); +}