diff --git a/BUILD b/BUILD
index 75df2b3e..df4270f9 100644
--- a/BUILD
+++ b/BUILD
@@ -137,6 +137,7 @@ filegroup(
         "test/**/*.txt",
         "test/**/*.ini",
         "test/**/*.jsonl",
+        "test/**/*.gz",
     ]),
 )
 
diff --git a/include/collection.h b/include/collection.h
index bfc6d0cc..3abe576f 100644
--- a/include/collection.h
+++ b/include/collection.h
@@ -167,7 +167,7 @@ private:
     void remove_document(const nlohmann::json & document, const uint32_t seq_id, bool remove_from_store);
 
-    void process_remove_field_for_embedding_fields(const field& the_field, std::vector<field>& garbage_fields);
+    void process_remove_field_for_embedding_fields(const field& del_field, std::vector<field>& garbage_embed_fields);
 
     void curate_results(string& actual_query, const string& filter_query, bool enable_overrides, bool already_segmented,
                         const std::map<size_t, std::vector<std::string>>& pinned_hits,
diff --git a/include/raft_server.h b/include/raft_server.h
index dddaccb9..d5fcfec7 100644
--- a/include/raft_server.h
+++ b/include/raft_server.h
@@ -142,8 +142,6 @@ private:
     butil::EndPoint peering_endpoint;
 
-    Option<bool> handle_gzip(const std::shared_ptr<http_req>& request);
-
 public:
 
     static constexpr const char* log_dir_name = "log";
@@ -241,6 +239,8 @@ public:
 
     std::string get_leader_url() const;
 
+    static Option<bool> handle_gzip(const std::shared_ptr<http_req>& request);
+
 private:
 
     friend class ReplicationClosure;
diff --git a/include/text_embedder_manager.h b/include/text_embedder_manager.h
index 543e8f91..2681e0b9 100644
--- a/include/text_embedder_manager.h
+++ b/include/text_embedder_manager.h
@@ -72,6 +72,10 @@ public:
     Option<bool> validate_and_init_local_model(const nlohmann::json& model_config, size_t& num_dims);
     Option<bool> validate_and_init_model(const nlohmann::json& model_config, size_t& num_dims);
 
+    std::unordered_map<std::string, std::shared_ptr<TextEmbedder>> _get_text_embedders() {
+        return text_embedders;
+    }
+
 private:
     TextEmbedderManager() = default;
diff --git a/include/text_embedder_remote.h b/include/text_embedder_remote.h
index 8167fefd..b5f219d5 100644
--- a/include/text_embedder_remote.h
+++ b/include/text_embedder_remote.h
@@ -31,6 +31,7 @@ class RemoteEmbedder {
     virtual nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) = 0;
     virtual embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_tries = 2) = 0;
     virtual std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size = 200) = 0;
+    static const std::string get_model_key(const nlohmann::json& model_config);
     static void init(ReplicationState* rs) {
         raft_server = rs;
     }
@@ -51,6 +52,7 @@ class OpenAIEmbedder : public RemoteEmbedder {
     embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_tries = 2) override;
     std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size = 200) override;
     nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override;
+    static std::string get_model_key(const nlohmann::json& model_config);
 };
 
 
@@ -68,6 +70,7 @@ class GoogleEmbedder : public RemoteEmbedder {
     embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_tries = 2) override;
     std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size = 200) override;
     nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override;
+    static std::string get_model_key(const nlohmann::json& model_config);
 };
 
@@ -95,6 +98,7 @@ class GCPEmbedder : public RemoteEmbedder {
     embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_tries = 2) override;
     std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size = 200) override;
     nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override;
+    static std::string get_model_key(const nlohmann::json& model_config);
 };
 
diff --git a/src/collection.cpp b/src/collection.cpp
index 0e17d4db..e749e75b 100644
--- a/src/collection.cpp
+++ b/src/collection.cpp
@@ -4254,7 +4254,6 @@ Option Collection::alter(nlohmann::json& alter_payload) {
         }
     }
 
-    // hide credentials in the alter payload return
     for(auto& field_json : alter_payload["fields"]) {
         if(field_json[fields::embed].count(fields::model_config) != 0) {
@@ -4267,8 +4266,6 @@ Option Collection::alter(nlohmann::json& alter_payload) {
         }
     }
 
-
-
     return Option<bool>(true);
 }
 
@@ -5346,27 +5343,43 @@ Option Collection::populate_include_exclude_fields_lk(const spp::sparse_ha
 }
 
 // Removes the dropped field from embed_from of all embedding fields.
-void Collection::process_remove_field_for_embedding_fields(const field& the_field, std::vector<field>& garbage_fields) {
+void Collection::process_remove_field_for_embedding_fields(const field& del_field,
+                                                            std::vector<field>& garbage_embed_fields) {
     for(auto& field : fields) {
         if(field.embed.count(fields::from) == 0) {
             continue;
         }
 
-        auto embed_from = field.embed[fields::from].get<std::vector<std::string>>();
-        embed_from.erase(std::remove_if(embed_from.begin(), embed_from.end(), [&the_field](std::string field_name) {
-            return the_field.name == field_name;
-        }));
-        field.embed[fields::from] = std::move(embed_from);
-        embedding_fields[field.name] = field;
-        // mark this embedding field as "garbage" if it has no more embed_from fields
-        if(embed_from.empty()) {
-            embedding_fields.erase(field.name);
-            garbage_fields.push_back(field);
+        bool found_field = false;
+        nlohmann::json& embed_from_names = field.embed[fields::from];
+        for(auto it = embed_from_names.begin(); it != embed_from_names.end();) {
+            if(it.value() == del_field.name) {
+                it = embed_from_names.erase(it);
+                found_field = true;
+            } else {
+                it++;
+            }
         }
-
+        if(found_field) {
+            // mark this embedding field as "garbage" if it has no more embed_from fields
+            if(embed_from_names.empty()) {
+                garbage_embed_fields.push_back(field);
+            } else {
+                // the dropped field was present in `embed_from`, so we have to update the field objects
+                field.embed[fields::from] = embed_from_names;
+                embedding_fields[field.name].embed[fields::from] = embed_from_names;
+            }
+        }
     }
+
+    for(auto& garbage_field: garbage_embed_fields) {
+        embedding_fields.erase(garbage_field.name);
+        search_schema.erase(garbage_field.name);
+        fields.erase(std::remove_if(fields.begin(), fields.end(), [&garbage_field](const auto &f) {
+            return f.name == garbage_field.name;
+        }), fields.end());
+    }
 }
 
 void Collection::hide_credential(nlohmann::json& json, const std::string& credential_name) {
diff --git a/src/core_api.cpp b/src/core_api.cpp
index 06572514..88e8d036 100644
--- a/src/core_api.cpp
+++ b/src/core_api.cpp
@@ -730,7 +730,7 @@ bool get_export_documents(const std::shared_ptr& req, const std::share
         }
     }
 
-    res->content_type_header = "text/plain; charset=utf8";
+    res->content_type_header = "text/plain; charset=utf-8";
     res->status_code = 200;
 
     stream_response(req, res);
@@ -903,7 +903,7 @@ bool post_import_documents(const std::shared_ptr& req, const std::shar
         }
     }
 
-    res->content_type_header = "text/plain; charset=utf8";
+    res->content_type_header = "text/plain; charset=utf-8";
     res->status_code = 200;
     res->body = response_stream.str();
diff --git a/src/index.cpp b/src/index.cpp
index af6c62fd..8da8276d 100644
--- a/src/index.cpp
+++ b/src/index.cpp
@@ -953,12 +953,16 @@ void Index::index_field_in_memory(const field& afield, std::vector
                 try {
                     const std::vector<float>& float_vals = record.doc[afield.name].get<std::vector<float>>();
-                    if(afield.vec_dist == cosine) {
-                        std::vector<float> normalized_vals(afield.num_dim);
-                        hnsw_index_t::normalize_vector(float_vals, normalized_vals);
-                        vec_index->addPoint(normalized_vals.data(), (size_t)record.seq_id, true);
+                    if(float_vals.size() != afield.num_dim) {
+                        record.index_failure(400, "Vector size mismatch.");
                     } else {
-                        vec_index->addPoint(float_vals.data(), (size_t)record.seq_id, true);
+                        if(afield.vec_dist == cosine) {
+                            std::vector<float> normalized_vals(afield.num_dim);
+                            hnsw_index_t::normalize_vector(float_vals, normalized_vals);
+                            vec_index->addPoint(normalized_vals.data(), (size_t)record.seq_id, true);
+                        } else {
+                            vec_index->addPoint(float_vals.data(), (size_t)record.seq_id, true);
+                        }
                     }
                 } catch(const std::exception &e) {
                     record.index_failure(400, e.what());
diff --git a/src/text_embedder_manager.cpp b/src/text_embedder_manager.cpp
index ac2c110f..89400a79 100644
--- a/src/text_embedder_manager.cpp
+++ b/src/text_embedder_manager.cpp
@@ -43,9 +43,10 @@ Option TextEmbedderManager::validate_and_init_remote_model(const nlohmann:
     }
 
     std::unique_lock<std::mutex> lock(text_embedders_mutex);
-    auto text_embedder_it = text_embedders.find(model_name);
+    std::string model_key = is_remote_model(model_name) ? RemoteEmbedder::get_model_key(model_config) : model_name;
+    auto text_embedder_it = text_embedders.find(model_key);
     if(text_embedder_it == text_embedders.end()) {
-        text_embedders.emplace(model_name, std::make_shared<TextEmbedder>(model_config, num_dims));
+        text_embedders.emplace(model_key, std::make_shared<TextEmbedder>(model_config, num_dims));
     }
 
     return Option<bool>(true);
@@ -122,7 +123,8 @@ Option TextEmbedderManager::validate_and_init_local_model(const nlohmann::
 Option TextEmbedderManager::get_text_embedder(const nlohmann::json& model_config) {
     std::unique_lock<std::mutex> lock(text_embedders_mutex);
     const std::string& model_name = model_config.at("model_name");
-    auto text_embedder_it = text_embedders.find(model_name);
+    std::string model_key = is_remote_model(model_name) ? RemoteEmbedder::get_model_key(model_config) : model_name;
+    auto text_embedder_it = text_embedders.find(model_key);
 
     if(text_embedder_it == text_embedders.end()) {
         return Option(404, "Text embedder was not found.");
diff --git a/src/text_embedder_remote.cpp b/src/text_embedder_remote.cpp
index e59f93bf..06a53a65 100644
--- a/src/text_embedder_remote.cpp
+++ b/src/text_embedder_remote.cpp
@@ -53,6 +53,21 @@ long RemoteEmbedder::call_remote_api(const std::string& method, const std::strin
                             proxy_call_timeout_ms, true);
 }
 
+
+const std::string RemoteEmbedder::get_model_key(const nlohmann::json& model_config) {
+    const std::string model_namespace = TextEmbedderManager::get_model_namespace(model_config["model_name"].get<std::string>());
+
+    if(model_namespace == "openai") {
+        return OpenAIEmbedder::get_model_key(model_config);
+    } else if(model_namespace == "google") {
+        return GoogleEmbedder::get_model_key(model_config);
+    } else if(model_namespace == "gcp") {
+        return GCPEmbedder::get_model_key(model_config);
+    } else {
+        return "";
+    }
+}
+
 OpenAIEmbedder::OpenAIEmbedder(const std::string& openai_model_path, const std::string& api_key) : api_key(api_key), openai_model_path(openai_model_path) {
 
 }
@@ -255,6 +270,9 @@ nlohmann::json OpenAIEmbedder::get_error_json(const nlohmann::json& req_body, lo
     return embedding_res;
 }
 
+std::string OpenAIEmbedder::get_model_key(const nlohmann::json& model_config) {
+    return model_config["model_name"].get<std::string>() + ":" + model_config["api_key"].get<std::string>();
+}
 GoogleEmbedder::GoogleEmbedder(const std::string& google_api_key) : google_api_key(google_api_key) {
 
 }
@@ -372,6 +390,10 @@ nlohmann::json GoogleEmbedder::get_error_json(const nlohmann::json& req_body, lo
     return embedding_res;
 }
 
+std::string GoogleEmbedder::get_model_key(const nlohmann::json& model_config) {
+    return model_config["model_name"].get<std::string>() + ":" + model_config["api_key"].get<std::string>();
+}
+
 GCPEmbedder::GCPEmbedder(const std::string& project_id, const std::string& model_name, const std::string& access_token,
                          const std::string& refresh_token, const std::string& client_id, const std::string& client_secret) :
@@ -625,3 +647,7 @@ Option GCPEmbedder::generate_access_token(const std::string& refres
 
     return Option<std::string>(access_token);
 }
+
+std::string GCPEmbedder::get_model_key(const nlohmann::json& model_config) {
+    return model_config["model_name"].get<std::string>() + ":" + model_config["project_id"].get<std::string>() + ":" + model_config["client_secret"].get<std::string>();
+}
\ No newline at end of file
diff --git a/test/collection_schema_change_test.cpp b/test/collection_schema_change_test.cpp
index 7bc6f32c..0d8364fd 100644
--- a/test/collection_schema_change_test.cpp
+++ b/test/collection_schema_change_test.cpp
@@ -1580,9 +1580,13 @@ TEST_F(CollectionSchemaChangeTest, DropFieldUsedForEmbedding) {
     nlohmann::json schema = R"({
         "name": "objects",
         "fields": [
-        {"name": "names", "type": "string[]"},
-        {"name": "category", "type":"string"},
-        {"name": "embedding", "type":"float[]", "embed":{"from": ["names","category"], "model_config": {"model_name": "ts/e5-small"}}}
+        {"name": "title", "type": "string"},
+        {"name": "names", "type": "string[]"},
+        {"name": "category", "type":"string"},
+        {"name": "embedding", "type":"float[]", "embed":{"from": ["names","category"],
+            "model_config": {"model_name": "ts/e5-small"}}},
+        {"name": "embedding2", "type":"float[]", "embed":{"from": ["names"],
+            "model_config": {"model_name": "ts/e5-small"}}}
         ]
     })"_json;
 
@@ -1594,20 +1598,28 @@ TEST_F(CollectionSchemaChangeTest, DropFieldUsedForEmbedding) {
 
     LOG(INFO) << "Created collection";
 
+    auto embedding_fields = coll->get_embedding_fields();
+    ASSERT_EQ(2, embedding_fields.size());
+    ASSERT_EQ(2, embedding_fields["embedding"].embed[fields::from].get<std::vector<std::string>>().size());
+    ASSERT_EQ(1, embedding_fields["embedding2"].embed[fields::from].get<std::vector<std::string>>().size());
+
+    auto coll_schema = coll->get_schema();
+    ASSERT_EQ(5, coll_schema.size());
+
+    auto the_fields = coll->get_fields();
+    ASSERT_EQ(5, the_fields.size());
+
     auto schema_changes = R"({
         "fields": [
             {"name": "names", "drop": true}
         ]
     })"_json;
-
-    auto embedding_fields = coll->get_embedding_fields();
-    ASSERT_EQ(2, embedding_fields["embedding"].embed[fields::from].get<std::vector<std::string>>().size());
-
     auto alter_op = coll->alter(schema_changes);
     ASSERT_TRUE(alter_op.ok());
 
     embedding_fields = coll->get_embedding_fields();
+    ASSERT_EQ(1, embedding_fields.size());
     ASSERT_EQ(1, embedding_fields["embedding"].embed[fields::from].get<std::vector<std::string>>().size());
     ASSERT_EQ("category", embedding_fields["embedding"].embed[fields::from].get<std::vector<std::string>>()[0]);
@@ -1623,6 +1635,16 @@ TEST_F(CollectionSchemaChangeTest, DropFieldUsedForEmbedding) {
     embedding_fields = coll->get_embedding_fields();
     ASSERT_EQ(0, embedding_fields.size());
     ASSERT_EQ(0, coll->_get_index()->_get_vector_index().size());
+
+    // only title remains
+
+    coll_schema = coll->get_schema();
+    ASSERT_EQ(1, coll_schema.size());
+    ASSERT_EQ("title", coll_schema["title"].name);
+
+    the_fields = coll->get_fields();
+    ASSERT_EQ(1, the_fields.size());
+    ASSERT_EQ("title", the_fields[0].name);
 }
 
 TEST_F(CollectionSchemaChangeTest, EmbeddingFieldsMapTest) {
diff --git a/test/collection_vector_search_test.cpp b/test/collection_vector_search_test.cpp
index 059e4437..02bef6ff 100644
--- a/test/collection_vector_search_test.cpp
+++ b/test/collection_vector_search_test.cpp
@@ -1342,3 +1342,98 @@ TEST_F(CollectionVectorTest, HybridSearchReturnAllInfo) {
     ASSERT_EQ(1, results["hits"][0].count("text_match_info"));
     ASSERT_EQ(1, results["hits"][0].count("hybrid_search_info"));
 }
+
+
+TEST_F(CollectionVectorTest, DISABLED_HybridSortingTest) {
+    auto schema_json =
+            R"({
+            "name": "TEST",
+            "fields": [
+                {"name": "name", "type": "string"},
+                {"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/e5-small"}}}
+            ]
+            })"_json;
+
+    TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
+    auto collection_create_op = collectionManager.create_collection(schema_json);
+    ASSERT_TRUE(collection_create_op.ok());
+    auto coll1 = collection_create_op.get();
+
+    auto add_op = coll1->add(R"({
+        "name": "john doe"
+    })"_json.dump());
+    ASSERT_TRUE(add_op.ok());
+
+    add_op = coll1->add(R"({
+        "name": "john legend"
+    })"_json.dump());
+    ASSERT_TRUE(add_op.ok());
+
+    add_op = coll1->add(R"({
+        "name": "john krasinski"
+    })"_json.dump());
+    ASSERT_TRUE(add_op.ok());
+
+    add_op = coll1->add(R"({
+        "name": "john abraham"
+    })"_json.dump());
+    ASSERT_TRUE(add_op.ok());
+
+    // first do keyword search
+    auto results = coll1->search("john", {"name"},
+                                 "", {}, {}, {2}, 10,
+                                 1, FREQUENCY, {true},
+                                 0, spp::sparse_hash_set<std::string>()).get();
+
+    ASSERT_EQ(4, results["hits"].size());
+
+
+    // now do hybrid search with sort_by: _text_match:desc,_vector_distance:asc
+    std::vector<sort_by> sort_by_list = {{"_text_match", "desc"}, {"_vector_distance", "asc"}};
+
+    auto hybrid_results = coll1->search("john", {"name", "embedding"},
+                                        "", {}, sort_by_list, {2}, 10,
+                                        1, FREQUENCY, {true},
+                                        0, spp::sparse_hash_set<std::string>()).get();
+
+    // first 4 results should be same as keyword search
+    ASSERT_EQ(results["hits"][0]["document"]["name"].get<std::string>(), hybrid_results["hits"][0]["document"]["name"].get<std::string>());
+    ASSERT_EQ(results["hits"][1]["document"]["name"].get<std::string>(), hybrid_results["hits"][1]["document"]["name"].get<std::string>());
+    ASSERT_EQ(results["hits"][2]["document"]["name"].get<std::string>(), hybrid_results["hits"][2]["document"]["name"].get<std::string>());
+    ASSERT_EQ(results["hits"][3]["document"]["name"].get<std::string>(), hybrid_results["hits"][3]["document"]["name"].get<std::string>());
+}
+
+TEST_F(CollectionVectorTest, TestDifferentOpenAIApiKeys) {
+    if (std::getenv("api_key_1") == nullptr || std::getenv("api_key_2") == nullptr) {
+        LOG(INFO) << "Skipping test as api_key_1 or api_key_2 is not set";
+        return;
+    }
+
+    auto api_key1 = std::string(std::getenv("api_key_1"));
+    auto api_key2 = std::string(std::getenv("api_key_2"));
+
+    auto embedder_map = TextEmbedderManager::get_instance()._get_text_embedders();
+
+    ASSERT_EQ(embedder_map.find("openai/text-embedding-ada-002:" + api_key1), embedder_map.end());
+    ASSERT_EQ(embedder_map.find("openai/text-embedding-ada-002:" + api_key2), embedder_map.end());
+    ASSERT_EQ(embedder_map.find("openai/text-embedding-ada-002"), embedder_map.end());
+
+    nlohmann::json model_config1 = R"({
+        "model_name": "openai/text-embedding-ada-002"
+    })"_json;
+
+    nlohmann::json model_config2 = model_config1;
+
+    model_config1["api_key"] = api_key1;
+    model_config2["api_key"] = api_key2;
+
+    size_t num_dim;
+    TextEmbedderManager::get_instance().validate_and_init_remote_model(model_config1, num_dim);
+    TextEmbedderManager::get_instance().validate_and_init_remote_model(model_config2, num_dim);
+
+    embedder_map = TextEmbedderManager::get_instance()._get_text_embedders();
+
+    ASSERT_NE(embedder_map.find("openai/text-embedding-ada-002:" + api_key1), embedder_map.end());
+    ASSERT_NE(embedder_map.find("openai/text-embedding-ada-002:" + api_key2), embedder_map.end());
+    ASSERT_EQ(embedder_map.find("openai/text-embedding-ada-002"), embedder_map.end());
+}
\ No newline at end of file
diff --git a/test/core_api_utils_test.cpp b/test/core_api_utils_test.cpp
index acb597fd..d2ae2599 100644
--- a/test/core_api_utils_test.cpp
+++ b/test/core_api_utils_test.cpp
@@ -4,6 +4,7 @@
 #include <collection_manager.h>
 #include <core_api.h>
 #include "core_api_utils.h"
+#include "raft_server.h"
 
 class CoreAPIUtilsTest : public ::testing::Test {
 protected:
@@ -621,6 +622,7 @@ TEST_F(CoreAPIUtilsTest, PresetSingleSearch) {
 
     auto op = collectionManager.create_collection(schema);
     ASSERT_TRUE(op.ok());
+    Collection* coll1 = op.get();
 
     auto preset_value = R"(
         {"collection":"preset_coll", "per_page": "12"}
@@ -1157,4 +1159,59 @@ TEST_F(CoreAPIUtilsTest, TestProxyTimeout) {
 
     ASSERT_EQ(408, resp->status_code);
     ASSERT_EQ("Server error on remote server. Please try again later.", nlohmann::json::parse(resp->body)["message"]);
-}
\ No newline at end of file
+}
+
+TEST_F(CoreAPIUtilsTest, SampleGzipIndexTest) {
+    Collection *coll_hnstories;
+
+    std::vector<field> fields = {field("title", field_types::STRING, false),
+                                 field("points", field_types::INT32, false),};
+
+    coll_hnstories = collectionManager.get_collection("coll_hnstories").get();
+    if(coll_hnstories == nullptr) {
+        coll_hnstories = collectionManager.create_collection("coll_hnstories", 4, fields, "title").get();
+    }
+
+    auto req = std::make_shared<http_req>();
+    std::ifstream infile(std::string(ROOT_DIR)+"test/resources/hnstories.jsonl.gz");
+    std::stringstream outbuffer;
+
+    infile.seekg (0, infile.end);
+    int length = infile.tellg();
+    infile.seekg (0, infile.beg);
+
+    req->body.resize(length);
+    infile.read(&req->body[0], length);
+
+    auto res = ReplicationState::handle_gzip(req);
+    if (!res.error().empty()) {
+        LOG(ERROR) << res.error();
+        FAIL();
+    } else {
+        outbuffer << req->body;
+    }
+
+    std::vector<std::string> doc_lines;
+    std::string line;
+    while(std::getline(outbuffer, line)) {
+        doc_lines.push_back(line);
+    }
+
+    ASSERT_EQ(14, doc_lines.size());
+    ASSERT_EQ("{\"points\":1,\"title\":\"DuckDuckGo Settings\"}", doc_lines[0]);
+    ASSERT_EQ("{\"points\":1,\"title\":\"Making Twitter Easier to Use\"}", doc_lines[1]);
+    ASSERT_EQ("{\"points\":2,\"title\":\"London refers Uber app row to High Court\"}", doc_lines[2]);
+    ASSERT_EQ("{\"points\":1,\"title\":\"Young Global Leaders, who should be nominated? (World Economic Forum)\"}", doc_lines[3]);
+    ASSERT_EQ("{\"points\":1,\"title\":\"Blooki.st goes BETA in a few hours\"}", doc_lines[4]);
+    ASSERT_EQ("{\"points\":1,\"title\":\"Unicode Security Data: Beta Review\"}", doc_lines[5]);
+    ASSERT_EQ("{\"points\":2,\"title\":\"FileMap: MapReduce on the CLI\"}", doc_lines[6]);
+    ASSERT_EQ("{\"points\":1,\"title\":\"[Full Video] NBC News Interview with Edward Snowden\"}", doc_lines[7]);
+    ASSERT_EQ("{\"points\":1,\"title\":\"Hybrid App Monetization Example with Mobile Ads and In-App Purchases\"}", doc_lines[8]);
+    ASSERT_EQ("{\"points\":1,\"title\":\"We need oppinion from Android Developers\"}", doc_lines[9]);
+    ASSERT_EQ("{\"points\":1,\"title\":\"\\\\t Why Mobile Developers Should Care About Deep Linking\"}", doc_lines[10]);
+    ASSERT_EQ("{\"points\":2,\"title\":\"Are we getting too Sassy? Weighing up micro-optimisation vs. maintainability\"}", doc_lines[11]);
+    ASSERT_EQ("{\"points\":2,\"title\":\"Google's XSS game\"}", doc_lines[12]);
+    ASSERT_EQ("{\"points\":1,\"title\":\"Telemba Turns Your Old Roomba and Tablet Into a Telepresence Robot\"}", doc_lines[13]);
+
+    infile.close();
+}
\ No newline at end of file
diff --git a/test/resources/hnstories.jsonl.gz b/test/resources/hnstories.jsonl.gz
new file mode 100644
index 00000000..ce374189
Binary files /dev/null and b/test/resources/hnstories.jsonl.gz differ