From 956d596e43057233d2c84486121a80a7ca850110 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Thu, 3 Aug 2023 15:23:30 +0530 Subject: [PATCH 1/2] Handle repeated facet values in arrays during searching. --- src/collection.cpp | 7 ------- src/index.cpp | 13 +++++++++---- test/collection_faceting_test.cpp | 18 ++++++++++++++++-- 3 files changed, 25 insertions(+), 13 deletions(-) diff --git a/src/collection.cpp b/src/collection.cpp index e60ffe41..1b00fb8a 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -1727,13 +1727,6 @@ Option Collection::search(std::string raw_query, if(!facet_query.query.empty()) { // identify facet hash tokens - for(const auto& the_facet: facets) { - if(the_facet.field_name == facet_query.field_name) { - //the_facet.hash_tokens - break; - } - } - auto fq_field = search_schema.at(facet_query.field_name); bool is_cyrillic = Tokenizer::is_cyrillic(fq_field.locale); bool normalise = is_cyrillic ? false : true; diff --git a/src/index.cpp b/src/index.cpp index 6a1c99bf..69df9301 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1056,8 +1056,6 @@ void Index::tokenize_string_array_with_facets(const std::vector& st std::unordered_map>& token_to_offsets, std::vector& facet_hashes) { - std::set facet_hash_set; // required to deal with repeating phrases - for(size_t array_index = 0; array_index < strings.size(); array_index++) { const std::string& str = strings[array_index]; std::set token_set; // required to deal with repeating tokens @@ -1091,9 +1089,8 @@ void Index::tokenize_string_array_with_facets(const std::vector& st } } - if(is_facet && facet_hash_set.count(facet_hash) == 0) { + if(is_facet) { facet_hashes.push_back(facet_hash); - facet_hash_set.insert(facet_hash); } if(token_set.empty()) { @@ -1226,11 +1223,19 @@ void Index::do_facets(std::vector & facets, facet_query_t & facet_query, RETURN_CIRCUIT_BREAKER } + std::set unique_facet_hashes; + for(size_t j = 0; j < facet_hash_count; j++) { if(facet_field.is_array()) { fhash = facet_map_it->second.hashes[j]; } + + if(unique_facet_hashes.count(fhash) == 0) { + unique_facet_hashes.insert(fhash); + } else { + continue; + } if(should_compute_stats) { compute_facet_stats(a_facet, fhash, facet_field.type); diff --git a/test/collection_faceting_test.cpp b/test/collection_faceting_test.cpp index fec66c48..ccd8c124 100644 --- a/test/collection_faceting_test.cpp +++ b/test/collection_faceting_test.cpp @@ -1111,7 +1111,7 @@ TEST_F(CollectionFacetingTest, FacetByArrayField) { })"_json; auto doc2 = R"({ - "data": ["Foo", "Foo"] + "data": ["Foo", "Foo", "Bazinga"] })"_json; ASSERT_TRUE(coll1->add(doc1.dump(), CREATE).ok()); @@ -1124,9 +1124,23 @@ TEST_F(CollectionFacetingTest, FacetByArrayField) { ASSERT_EQ(2, results["found"].get()); ASSERT_EQ(1, results["facet_counts"].size()); ASSERT_EQ("data", results["facet_counts"][0]["field_name"]); - ASSERT_EQ(1, results["facet_counts"][0]["counts"].size()); + ASSERT_EQ(2, results["facet_counts"][0]["counts"].size()); ASSERT_EQ(2, results["facet_counts"][0]["counts"][0]["count"].get()); ASSERT_EQ("Foo", results["facet_counts"][0]["counts"][0]["value"].get()); + + ASSERT_EQ(1, results["facet_counts"][0]["counts"][1]["count"].get()); + ASSERT_EQ("Bazinga", results["facet_counts"][0]["counts"][1]["value"].get()); + + results = coll1->search("*", {}, "", {"data"}, {}, {0}, 10, 1, + token_ordering::FREQUENCY, {true}, 10, spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "data:baz", 30, 4).get(); + + ASSERT_EQ(2, results["found"].get()); + ASSERT_EQ(1, results["facet_counts"].size()); + ASSERT_EQ("data", results["facet_counts"][0]["field_name"]); + ASSERT_EQ(1, results["facet_counts"][0]["counts"].size()); + ASSERT_EQ(1, results["facet_counts"][0]["counts"][0]["count"].get()); + ASSERT_EQ("Bazinga", results["facet_counts"][0]["counts"][0]["value"].get()); } TEST_F(CollectionFacetingTest, FacetParseTest){ From cc9af18d9ca5051da829db66dd0c5a2f40ea477c Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Thu, 3 Aug 2023 15:24:59 +0530 Subject: [PATCH 2/2] Refactor remote/local text embedder initialization. --- include/text_embedder.h | 6 +++-- include/text_embedder_manager.h | 7 +++--- src/collection_manager.cpp | 3 ++- src/field.cpp | 2 +- src/text_embedder.cpp | 16 ++++++++---- src/text_embedder_manager.cpp | 44 ++++++++++++++++----------------- 6 files changed, 43 insertions(+), 35 deletions(-) diff --git a/include/text_embedder.h b/include/text_embedder.h index 660e6ae4..ca64aa52 100644 --- a/include/text_embedder.h +++ b/include/text_embedder.h @@ -14,15 +14,16 @@ class TextEmbedder { // Constructor for local or public models TextEmbedder(const std::string& model_path); // Constructor for remote models - TextEmbedder(const nlohmann::json& model_config); + TextEmbedder(const nlohmann::json& model_config, size_t num_dims); ~TextEmbedder(); embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_tries = 2); std::vector batch_embed(const std::vector& inputs, const size_t remote_embedding_batch_size = 200); const std::string& get_vocab_file_name() const; + const size_t get_num_dim() const; bool is_remote() { return remote_embedder_ != nullptr; } - Option validate(size_t& num_dims); + Option validate(); private: std::unique_ptr session_; Ort::Env env_; @@ -33,5 +34,6 @@ class TextEmbedder { std::string vocab_file_name; static std::vector mean_pooling(const std::vector>& input); std::string output_tensor_name; + size_t num_dim; std::mutex mutex_; }; diff --git a/include/text_embedder_manager.h b/include/text_embedder_manager.h index b9158305..543e8f91 100644 --- a/include/text_embedder_manager.h +++ b/include/text_embedder_manager.h @@ -36,7 +36,6 @@ public: TextEmbedderManager& operator=(const TextEmbedderManager&) = delete; Option get_text_embedder(const nlohmann::json& model_config); - Option init_text_embedder(const nlohmann::json& model_config, size_t& num_dim); void delete_text_embedder(const std::string& model_path); void delete_all_text_embedders(); @@ -69,9 +68,9 @@ public: bool is_public_model(const std::string& model_name); static bool is_remote_model(const std::string& model_name); - static Option validate_and_init_remote_model(const nlohmann::json& model_config, size_t& num_dims); - static Option validate_and_init_local_model(const nlohmann::json& model_config, size_t& num_dims); - static Option validate_and_init_model(const nlohmann::json& model_config, size_t& num_dims); + Option validate_and_init_remote_model(const nlohmann::json& model_config, size_t& num_dims); + Option validate_and_init_local_model(const nlohmann::json& model_config, size_t& num_dims); + Option validate_and_init_model(const nlohmann::json& model_config, size_t& num_dims); private: TextEmbedderManager() = default; diff --git a/src/collection_manager.cpp b/src/collection_manager.cpp index 3610eeb5..9aec9e5c 100644 --- a/src/collection_manager.cpp +++ b/src/collection_manager.cpp @@ -80,13 +80,14 @@ Collection* CollectionManager::init_collection(const nlohmann::json & collection size_t num_dim = 0; auto& model_config = field_obj[fields::embed][fields::model_config]; - auto res = TextEmbedderManager::validate_and_init_model(model_config, num_dim); + auto res = TextEmbedderManager::get_instance().validate_and_init_model(model_config, num_dim); if(!res.ok()) { const std::string& model_name = model_config["model_name"].get(); LOG(ERROR) << "Error initializing model: " << model_name << ", error: " << res.error(); continue; } + field_obj[fields::num_dim] = num_dim; LOG(INFO) << "Model init done."; } diff --git a/src/field.cpp b/src/field.cpp index dc82a021..ca7d5149 100644 --- a/src/field.cpp +++ b/src/field.cpp @@ -1115,7 +1115,7 @@ Option field::validate_and_init_embed_fields(const std::vector(res.code(), res.error()); } diff --git a/src/text_embedder.cpp b/src/text_embedder.cpp index cdfa5eba..61244054 100644 --- a/src/text_embedder.cpp +++ b/src/text_embedder.cpp @@ -8,7 +8,7 @@ #include TextEmbedder::TextEmbedder(const std::string& model_name) { - // create environment + // create environment for local model Ort::SessionOptions session_options; auto providers = Ort::GetAvailableProviders(); for(auto& provider : providers) { @@ -50,14 +50,15 @@ TextEmbedder::TextEmbedder(const std::string& model_name) { if (shape.size() == 3 && shape[0] == -1 && shape[1] == -1 && shape[2] > 0) { Ort::AllocatorWithDefaultOptions allocator; output_tensor_name = std::string(session_->GetOutputNameAllocated(i, allocator).get()); + num_dim = shape[2]; break; } } } -TextEmbedder::TextEmbedder(const nlohmann::json& model_config) { +TextEmbedder::TextEmbedder(const nlohmann::json& model_config, size_t num_dims) { const std::string& model_name = model_config["model_name"].get(); - LOG(INFO) << "Initializing embedding model: " << model_name; + LOG(INFO) << "Initializing remote embedding model: " << model_name; auto model_namespace = TextEmbedderManager::get_model_namespace(model_name); if(model_namespace == "openai") { @@ -78,6 +79,8 @@ TextEmbedder::TextEmbedder(const nlohmann::json& model_config) { remote_embedder_ = std::make_unique(project_id, model_name, access_token, refresh_token, client_id, client_secret); } + + num_dim = num_dims; } @@ -267,7 +270,7 @@ batch_encoded_input_t TextEmbedder::batch_encode(const std::vector& return encoded_inputs; } -Option TextEmbedder::validate(size_t& num_dims) { +Option TextEmbedder::validate() { if(session_->GetInputCount() != 3 && session_->GetInputCount() != 2) { LOG(ERROR) << "Invalid model: input count is not 3 or 2"; return Option(400, "Invalid model: input count is not 3 or 2"); @@ -300,7 +303,6 @@ Option TextEmbedder::validate(size_t& num_dims) { for (size_t i = 0; i < output_tensor_count; i++) { auto shape = session_->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape(); if (shape.size() == 3 && shape[0] == -1 && shape[1] == -1 && shape[2] > 0) { - num_dims = shape[2]; found_output_tensor = true; break; } @@ -313,3 +315,7 @@ Option TextEmbedder::validate(size_t& num_dims) { return Option(true); } + +const size_t TextEmbedder::get_num_dim() const { + return num_dim; +} diff --git a/src/text_embedder_manager.cpp b/src/text_embedder_manager.cpp index 392fc61e..ac2c110f 100644 --- a/src/text_embedder_manager.cpp +++ b/src/text_embedder_manager.cpp @@ -42,7 +42,13 @@ Option TextEmbedderManager::validate_and_init_remote_model(const nlohmann: return Option(400, "Invalid model namespace"); } - return TextEmbedderManager::get_instance().init_text_embedder(model_config, num_dims); + std::unique_lock lock(text_embedders_mutex); + auto text_embedder_it = text_embedders.find(model_name); + if(text_embedder_it == text_embedders.end()) { + text_embedders.emplace(model_name, std::make_shared(model_config, num_dims)); + } + + return Option(true); } Option TextEmbedderManager::validate_and_init_local_model(const nlohmann::json& model_config, size_t& num_dims) { @@ -53,9 +59,8 @@ Option TextEmbedderManager::validate_and_init_local_model(const nlohmann:: return public_model_op; } - Ort::SessionOptions session_options; - Ort::Env env; - std::string abs_path = TextEmbedderManager::get_absolute_model_path(TextEmbedderManager::get_model_name_without_namespace(model_name)); + std::string abs_path = TextEmbedderManager::get_absolute_model_path( + TextEmbedderManager::get_model_name_without_namespace(model_name)); if(!std::filesystem::exists(abs_path)) { LOG(ERROR) << "Model file not found: " << abs_path; @@ -92,30 +97,25 @@ Option TextEmbedderManager::validate_and_init_local_model(const nlohmann:: return Option(400, "Invalid model type"); } } - - return TextEmbedderManager::get_instance().init_text_embedder(model_config, num_dims); -} -Option TextEmbedderManager::init_text_embedder(const nlohmann::json& model_config, size_t& num_dim) { std::unique_lock lock(text_embedders_mutex); - const std::string& model_name = model_config.at("model_name"); auto text_embedder_it = text_embedders.find(model_name); - if(text_embedder_it == text_embedders.end()) { - if(is_remote_model(model_name)) { - text_embedders.emplace(model_name, std::make_shared(model_config)); - } else { - const std::shared_ptr& embedder = std::make_shared( - get_model_name_without_namespace(model_name)); - auto validate_op = embedder->validate(num_dim); - if(!validate_op.ok()) { - return validate_op; - } - - text_embedders.emplace(model_name, embedder); - } + if(text_embedder_it != text_embedders.end()) { + num_dims = text_embedder_it->second->get_num_dim(); + return Option(true); } + const std::shared_ptr& embedder = std::make_shared( + get_model_name_without_namespace(model_name)); + + auto validate_op = embedder->validate(); + if(!validate_op.ok()) { + return validate_op; + } + + num_dims = embedder->get_num_dim(); + text_embedders.emplace(model_name, embedder); return Option(true); }