diff --git a/include/collection.h b/include/collection.h index 64f6dd30..bb0f581a 100644 --- a/include/collection.h +++ b/include/collection.h @@ -563,7 +563,7 @@ public: const bool prioritize_num_matching_fields = true, const bool group_missing_values = true, const bool converstaion = false, - const int conversation_model_id = -1, + const std::string& conversation_model_id = "", std::string conversation_id = "", const std::string& override_tags_str = "") const; diff --git a/include/conversation_model_manager.h b/include/conversation_model_manager.h index c695c552..8caf1f15 100644 --- a/include/conversation_model_manager.h +++ b/include/conversation_model_manager.h @@ -6,7 +6,7 @@ #include #include #include "store.h" - +#include "sole.hpp" class ConversationModelManager { @@ -16,22 +16,21 @@ class ConversationModelManager ConversationModelManager(ConversationModelManager&&) = delete; ConversationModelManager& operator=(const ConversationModelManager&) = delete; - static Option get_model(const uint32_t model_id); + static Option get_model(const std::string& model_id); static Option add_model(nlohmann::json model); - static Option delete_model(const uint32_t model_id); - static Option update_model(const uint32_t model_id, nlohmann::json model); + static Option delete_model(const std::string& model_id); + static Option update_model(const std::string& model_id, nlohmann::json model); static Option get_all_models(); static Option init(Store* store); private: - static inline std::unordered_map models; - static inline uint32_t model_id = 0; + static inline std::unordered_map models; static inline std::shared_mutex models_mutex; static constexpr char* MODEL_NEXT_ID = "$CVMN"; static constexpr char* MODEL_KEY_PREFIX = "$CVMP"; static inline Store* store; - static const std::string get_model_key(uint32_t model_id); + static const std::string get_model_key(const std::string& model_id); }; \ No newline at end of file diff --git a/src/collection.cpp b/src/collection.cpp index e08e8c59..e2564a40 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -1616,7 +1616,7 @@ Option Collection::search(std::string raw_query, const bool prioritize_num_matching_fields, const bool group_missing_values, const bool conversation, - const int conversation_model_id, + const std::string& conversation_model_id, std::string conversation_id, const std::string& override_tags_str) const { std::shared_lock lock(mutex); @@ -1715,7 +1715,7 @@ Option Collection::search(std::string raw_query, std::string query = raw_query; if(conversation) { - if(conversation_model_id == -1) { + if(conversation_model_id.empty()) { return Option(400, "Conversation is enabled but no conversation model ID is provided."); } diff --git a/src/collection_manager.cpp b/src/collection_manager.cpp index 35016481..3939162f 100644 --- a/src/collection_manager.cpp +++ b/src/collection_manager.cpp @@ -1260,7 +1260,7 @@ Option CollectionManager::do_search(std::map& re bool conversation = false; std::string conversation_id; - size_t conversation_model_id = std::numeric_limits::max(); + std::string conversation_model_id; std::string drop_tokens_mode_str = "right_to_left"; bool prioritize_num_matching_fields = true; @@ -1291,7 +1291,6 @@ Option CollectionManager::do_search(std::map& re {FACET_SAMPLE_THRESHOLD, &facet_sample_threshold}, {REMOTE_EMBEDDING_TIMEOUT_MS, &remote_embedding_timeout_ms}, {REMOTE_EMBEDDING_NUM_TRIES, &remote_embedding_num_tries}, - {CONVERSATION_MODEL_ID, &conversation_model_id}, }; std::unordered_map str_values = { @@ -1307,6 +1306,7 @@ Option CollectionManager::do_search(std::map& re {CONVERSATION_ID, &conversation_id}, {DROP_TOKENS_MODE, &drop_tokens_mode_str}, {OVERRIDE_TAGS, &override_tags}, + {CONVERSATION_MODEL_ID, &conversation_model_id}, }; std::unordered_map bool_values = { @@ -1525,7 +1525,7 @@ Option CollectionManager::do_search(std::map& re prioritize_num_matching_fields, group_missing_values, conversation, - (conversation_model_id == std::numeric_limits::max()) ? -1 : static_cast(conversation_model_id), + conversation_model_id, conversation_id, override_tags); diff --git a/src/conversation_model_manager.cpp b/src/conversation_model_manager.cpp index 95e71c3d..2a72056a 100644 --- a/src/conversation_model_manager.cpp +++ b/src/conversation_model_manager.cpp @@ -1,7 +1,7 @@ #include "conversation_model_manager.h" #include "conversation_model.h" -Option ConversationModelManager::get_model(const uint32_t model_id) { +Option ConversationModelManager::get_model(const std::string& model_id) { std::shared_lock lock(models_mutex); auto it = models.find(model_id); if (it == models.end()) { @@ -17,6 +17,7 @@ Option ConversationModelManager::add_model(nlohmann::json model) if (!validate_res.ok()) { return Option(validate_res.code(), validate_res.error()); } + auto model_id = sole::uuid4().str(); model["id"] = model_id; auto model_key = get_model_key(model_id); @@ -25,12 +26,12 @@ Option ConversationModelManager::add_model(nlohmann::json model) return Option(500, "Error while inserting model into the store"); } - models[model_id++] = model; + models[model_id] = model; return Option(model); } -Option ConversationModelManager::delete_model(const uint32_t model_id) { +Option ConversationModelManager::delete_model(const std::string& model_id) { std::unique_lock lock(models_mutex); auto it = models.find(model_id); if (it == models.end()) { @@ -56,7 +57,7 @@ Option ConversationModelManager::get_all_models() { return Option(models_json); } -Option ConversationModelManager::update_model(const uint32_t model_id, nlohmann::json model) { +Option ConversationModelManager::update_model(const std::string& model_id, nlohmann::json model) { std::unique_lock lock(models_mutex); auto validate_res = ConversationModel::validate_model(model); if (!validate_res.ok()) { @@ -85,24 +86,13 @@ Option ConversationModelManager::init(Store* store) { std::unique_lock lock(models_mutex); ConversationModelManager::store = store; - std::string last_id_str; - StoreStatus last_id_str_status = store->get(std::string(MODEL_NEXT_ID), last_id_str); - - if(last_id_str_status == StoreStatus::ERROR) { - return Option(500, "Error while loading conversations next id from the store"); - } else if(last_id_str_status == StoreStatus::FOUND) { - model_id = StringUtils::deserialize_uint32_t(last_id_str); - } else { - model_id = 0; - } - std::vector model_strs; store->scan_fill(std::string(MODEL_KEY_PREFIX) + "_", std::string(MODEL_KEY_PREFIX) + "`", model_strs); int loaded_models = 0; for(auto& model_str : model_strs) { nlohmann::json model_json = nlohmann::json::parse(model_str); - int model_id = model_json["id"]; + std::string model_id = model_json["id"]; models[model_id] = model_json; loaded_models++; } @@ -110,6 +100,6 @@ Option ConversationModelManager::init(Store* store) { return Option(loaded_models); } -const std::string ConversationModelManager::get_model_key(uint32_t model_id) { - return std::string(MODEL_KEY_PREFIX) + "_" + std::to_string(model_id); +const std::string ConversationModelManager::get_model_key(const std::string& model_id) { + return std::string(MODEL_KEY_PREFIX) + "_" + model_id; } \ No newline at end of file diff --git a/src/core_api.cpp b/src/core_api.cpp index d1e5e5a9..ffa7ba6a 100644 --- a/src/core_api.cpp +++ b/src/core_api.cpp @@ -559,15 +559,10 @@ bool post_multi_search(const std::shared_ptr& req, const std::shared_p return false; } - if(StringUtils::is_uint32_t(orig_req_params["conversation_model_id"])) { - uint32_t conversation_model_id = std::stoul(orig_req_params["conversation_model_id"]); - auto conversation_model = ConversationModelManager::get_model(conversation_model_id); + const std::string& conversation_model_id = orig_req_params["conversation_model_id"]; + auto conversation_model = ConversationModelManager::get_model(conversation_model_id); - if(!conversation_model.ok()) { - res->set_400("`conversation_model_id` is invalid."); - return false; - } - } else { + if(!conversation_model.ok()) { res->set_400("`conversation_model_id` is invalid."); return false; } @@ -586,7 +581,7 @@ bool post_multi_search(const std::shared_ptr& req, const std::shared_p common_query = orig_req_params["q"]; if(conversation_history) { - auto conversation_model_id = std::stoul(orig_req_params["conversation_model_id"]); + const std::string& conversation_model_id = orig_req_params["conversation_model_id"]; auto conversation_id = orig_req_params["conversation_id"]; auto conversation_model = ConversationModelManager::get_model(conversation_model_id).get(); auto conversation_history = ConversationManager::get_instance().get_conversation(conversation_id).get(); @@ -601,8 +596,6 @@ bool post_multi_search(const std::shared_ptr& req, const std::shared_p } } - - for(size_t i = 0; i < searches.size(); i++) { auto& search_params = searches[i]; @@ -753,7 +746,7 @@ bool post_multi_search(const std::shared_ptr& req, const std::shared_p } } - auto conversation_model_id = std::stoul(orig_req_params["conversation_model_id"]); + const std::string& conversation_model_id = orig_req_params["conversation_model_id"]; auto conversation_model = ConversationModelManager::get_model(conversation_model_id).get(); auto prompt = req->params["q"]; @@ -2714,11 +2707,7 @@ bool post_conversation_model(const std::shared_ptr& req, const std::sh } bool get_conversation_model(const std::shared_ptr& req, const std::shared_ptr& res) { - if(!StringUtils::is_uint32_t(req->params["id"])) { - res->set_400("Invalid ID."); - return false; - } - const int model_id = std::stoi(req->params["id"]); + const std::string& model_id = req->params["id"]; auto model_op = ConversationModelManager::get_model(model_id); @@ -2752,13 +2741,8 @@ bool get_conversation_models(const std::shared_ptr& req, const std::sh return true; } - bool del_conversation_model(const std::shared_ptr& req, const std::shared_ptr& res) { - if(!StringUtils::is_uint32_t(req->params["id"])) { - res->set_400("Invalid ID."); - return false; - } - const int model_id = std::stoi(req->params["id"]); + const std::string& model_id = req->params["id"]; auto model_op = ConversationModelManager::delete_model(model_id); @@ -2776,11 +2760,7 @@ bool del_conversation_model(const std::shared_ptr& req, const std::sha } bool put_conversation_model(const std::shared_ptr& req, const std::shared_ptr& res) { - if(!StringUtils::is_uint32_t(req->params["id"])) { - res->set_400("Invalid ID."); - return false; - } - const int model_id = std::stoi(req->params["id"]); + const std::string& model_id = req->params["id"]; nlohmann::json req_json; diff --git a/test/collection_override_test.cpp b/test/collection_override_test.cpp index 6f78041b..6d06df10 100644 --- a/test/collection_override_test.cpp +++ b/test/collection_override_test.cpp @@ -3397,7 +3397,7 @@ TEST_F(CollectionOverrideTest, OverrideWithTags) { "", "", {}, 1000, true, false, true, "", false, 10000, 4, 7, fallback, 4, {off}, 100, 100, 2, 2, false, "", true, 0, max_score, 100, 0, 0, HASH, 30000, 2, "", {}, {}, "right_to_left", - true, true, false, -1, "", "foo").get(); + true, true, false, "", "", "foo").get(); ASSERT_EQ(2, results["hits"].size()); @@ -3410,7 +3410,7 @@ TEST_F(CollectionOverrideTest, OverrideWithTags) { "", "", {}, 1000, true, false, true, "", false, 10000, 4, 7, fallback, 4, {off}, 100, 100, 2, 2, false, "", true, 0, max_score, 100, 0, 0, HASH, 30000, 2, "", {}, {}, "right_to_left", - true, true, false, -1, "", "alpha").get(); + true, true, false, "", "", "alpha").get(); ASSERT_EQ(1, results["hits"].size()); ASSERT_EQ("0", results["hits"][0]["document"]["id"].get()); @@ -3424,7 +3424,7 @@ TEST_F(CollectionOverrideTest, OverrideWithTags) { "", "", {}, 1000, true, false, true, "", false, 10000, 4, 7, fallback, 4, {off}, 100, 100, 2, 2, false, "", true, 0, max_score, 100, 0, 0, HASH, 30000, 2, "", {}, {}, "right_to_left", - true, true, false, -1, "", "beta").get(); + true, true, false, "", "", "beta").get(); ASSERT_EQ(1, results["hits"].size()); ASSERT_EQ("0", results["hits"][0]["document"]["id"].get()); @@ -3438,7 +3438,7 @@ TEST_F(CollectionOverrideTest, OverrideWithTags) { "", "", {}, 1000, true, false, true, "", false, 10000, 4, 7, fallback, 4, {off}, 100, 100, 2, 2, false, "", true, 0, max_score, 100, 0, 0, HASH, 30000, 2, "", {}, {}, "right_to_left", - true, true, false, -1, "", "alpha,beta").get(); + true, true, false, "", "", "alpha,beta").get(); ASSERT_EQ(1, results["hits"].size()); ASSERT_EQ("0", results["hits"][0]["document"]["id"].get()); @@ -3452,7 +3452,7 @@ TEST_F(CollectionOverrideTest, OverrideWithTags) { "", "", {}, 1000, true, false, true, "", false, 10000, 4, 7, fallback, 4, {off}, 100, 100, 2, 2, false, "", true, 0, max_score, 100, 0, 0, HASH, 30000, 2, "", {}, {}, "right_to_left", - true, true, false, -1, "", "").get(); + true, true, false, "", "", "").get(); ASSERT_EQ(1, results["hits"].size()); ASSERT_EQ("2", results["hits"][0]["document"]["id"].get()); @@ -3532,7 +3532,7 @@ TEST_F(CollectionOverrideTest, OverrideWithTagsPartialMatch) { "", "", {}, 1000, true, false, true, "", false, 10000, 4, 7, fallback, 4, {off}, 100, 100, 2, 2, false, "", true, 0, max_score, 100, 0, 0, HASH, 30000, 2, "", {}, {}, "right_to_left", - true, true, false, -1, "", "alpha,zeta").get(); + true, true, false, "", "", "alpha,zeta").get(); ASSERT_EQ(1, results["hits"].size()); ASSERT_EQ("1", results["hits"][0]["document"]["id"].get()); @@ -3632,7 +3632,7 @@ TEST_F(CollectionOverrideTest, OverrideWithTagsWithoutStopProcessing) { "", "", {}, 1000, true, false, true, "", false, 10000, 4, 7, fallback, 4, {off}, 100, 100, 2, 2, false, "", true, 0, max_score, 100, 0, 0, HASH, 30000, 2, "", {}, {}, "right_to_left", - true, true, false, -1, "", "alpha").get(); + true, true, false, "", "", "alpha").get(); ASSERT_EQ(1, results["hits"].size()); ASSERT_EQ("1", results["hits"][0]["document"]["id"].get()); @@ -3708,7 +3708,7 @@ TEST_F(CollectionOverrideTest, WildcardTagRuleThatMatchesAllQueries) { "", "", {}, 1000, true, false, true, "", false, 10000, 4, 7, fallback, 4, {off}, 100, 100, 2, 2, false, "", true, 0, max_score, 100, 0, 0, HASH, 30000, 2, "", {}, {}, "right_to_left", - true, true, false, -1, "", override_tags).get(); + true, true, false, "", "", override_tags).get(); ASSERT_EQ(1, results["hits"].size()); ASSERT_EQ("0", results["hits"][0]["document"]["id"].get()); @@ -3721,7 +3721,7 @@ TEST_F(CollectionOverrideTest, WildcardTagRuleThatMatchesAllQueries) { "", "", {}, 1000, true, false, true, "", false, 10000, 4, 7, fallback, 4, {off}, 100, 100, 2, 2, false, "", true, 0, max_score, 100, 0, 0, HASH, 30000, 2, "", {}, {}, "right_to_left", - true, true, false, -1, "", override_tags).get(); + true, true, false, "", "", override_tags).get(); ASSERT_EQ(1, results["hits"].size()); ASSERT_EQ("0", results["hits"][0]["document"]["id"].get()); @@ -3751,7 +3751,7 @@ TEST_F(CollectionOverrideTest, WildcardTagRuleThatMatchesAllQueries) { "", "", {}, 1000, true, false, true, "", false, 10000, 4, 7, fallback, 4, {off}, 100, 100, 2, 2, false, "", true, 0, max_score, 100, 0, 0, HASH, 30000, 2, "", {}, {}, "right_to_left", - true, true, false, -1, "", override_tags).get(); + true, true, false, "", "", override_tags).get(); ASSERT_EQ(1, results["hits"].size()); ASSERT_EQ("1", results["hits"][0]["document"]["id"].get()); @@ -3805,7 +3805,7 @@ TEST_F(CollectionOverrideTest, TagsOnlyRule) { "", "", {}, 1000, true, false, true, "", false, 10000, 4, 7, fallback, 4, {off}, 100, 100, 2, 2, false, "", true, 0, max_score, 100, 0, 0, HASH, 30000, 2, "", {}, {}, "right_to_left", - true, true, false, -1, "", "listing").get(); + true, true, false, "", "", "listing").get(); ASSERT_EQ(1, results["hits"].size()); ASSERT_EQ("0", results["hits"][0]["document"]["id"].get()); @@ -3834,7 +3834,7 @@ TEST_F(CollectionOverrideTest, TagsOnlyRule) { "", "", {}, 1000, true, false, true, "", false, 10000, 4, 7, fallback, 4, {off}, 100, 100, 2, 2, false, "", true, 0, max_score, 100, 0, 0, HASH, 30000, 2, "", {}, {}, "right_to_left", - true, true, false, -1, "", "listing2").get(); + true, true, false, "", "", "listing2").get(); ASSERT_EQ(1, results["hits"].size()); ASSERT_EQ("1", results["hits"][0]["document"]["id"].get()); @@ -3849,7 +3849,7 @@ TEST_F(CollectionOverrideTest, TagsOnlyRule) { "", "", {}, 1000, true, false, true, "", false, 10000, 4, 7, fallback, 4, {off}, 100, 100, 2, 2, false, "", true, 0, max_score, 100, 0, 0, HASH, 30000, 2, "", {}, {}, "right_to_left", - true, true, false, -1, "", override_tag).get(); + true, true, false, "", "", override_tag).get(); ASSERT_EQ(0, results["hits"].size()); @@ -3964,7 +3964,7 @@ TEST_F(CollectionOverrideTest, WildcardSearchOverride) { "", "", {}, 1000, true, false, true, "", false, 10000, 4, 7, fallback, 4, {off}, 100, 100, 2, 2, false, "", true, 0, max_score, 100, 0, 0, HASH, 30000, 2, "", {}, {}, "right_to_left", - true, true, false, -1, "", override_tags).get(); + true, true, false, "", "", override_tags).get(); ASSERT_EQ(1, results["hits"].size()); ASSERT_EQ("0", results["hits"][0]["document"]["id"].get()); @@ -3996,7 +3996,7 @@ TEST_F(CollectionOverrideTest, WildcardSearchOverride) { "", "", {}, 1000, true, false, true, "", false, 10000, 4, 7, fallback, 4, {off}, 100, 100, 2, 2, false, "", true, 0, max_score, 100, 0, 0, HASH, 30000, 2, "", {}, {}, "right_to_left", - true, true, false, -1, "", override_tags).get(); + true, true, false, "", "", override_tags).get(); ASSERT_EQ(3, results["hits"].size()); ASSERT_EQ("1", results["hits"][0]["document"]["id"].get()); diff --git a/test/collection_vector_search_test.cpp b/test/collection_vector_search_test.cpp index bbef267f..9c0b6b0e 100644 --- a/test/collection_vector_search_test.cpp +++ b/test/collection_vector_search_test.cpp @@ -2919,7 +2919,7 @@ TEST_F(CollectionVectorTest, TestQAConversation) { 0, spp::sparse_hash_set(), {}, 10, "", 30, 4, "", 1, "", "", {}, 3, "", "", {}, 4294967295UL, true, false, true, "", false, 6000000UL, 4, 7, fallback, 4, {off}, 32767UL, 32767UL, 2, 2, false, "", - true, 0, max_score, 100, 0, 0, HASH, 30000, 2, "", {}, {}, "right_to_left", true, true, true, 0); + true, 0, max_score, 100, 0, 0, HASH, 30000, 2, "", {}, {}, "right_to_left", true, true, true, model_add_op.get()["id"]); ASSERT_TRUE(results_op.ok()); @@ -3344,7 +3344,7 @@ TEST_F(CollectionVectorTest, InvalidMultiSearchConversation) { std::shared_ptr res = std::make_shared(nullptr); req->params["conversation"] = "true"; - req->params["conversation_model_id"] = to_string(model_id); + req->params["conversation_model_id"] = model_id; req->params["q"] = "cat"; req->body = search_body.dump();