From f742417443c764a0fcf64a4902ce005c74fa1df2 Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Fri, 1 Dec 2023 23:44:36 +0300 Subject: [PATCH] Make ```ConversationManager``` singleton & update conversation garbage collector thread --- include/conversation_manager.h | 44 +++++++++++++++-------- src/collection.cpp | 16 ++++----- src/conversation_manager.cpp | 22 ++++++++++-- src/core_api.cpp | 24 ++++++------- src/typesense_server_utils.cpp | 22 +++++------- test/collection_vector_search_test.cpp | 4 +-- test/conversation_test.cpp | 50 +++++++++++++------------- test/core_api_utils_test.cpp | 2 +- 8 files changed, 105 insertions(+), 79 deletions(-) diff --git a/include/conversation_manager.h b/include/conversation_manager.h index 437876da..6abeffe1 100644 --- a/include/conversation_manager.h +++ b/include/conversation_manager.h @@ -12,31 +12,45 @@ class ConversationManager { public: - ConversationManager() = delete; - static Option create_conversation(const nlohmann::json& conversation); - static Option get_conversation(const std::string& conversation_id); - static Option append_conversation(const std::string& conversation_id, const nlohmann::json& message); - static Option truncate_conversation(nlohmann::json conversation); - static Option update_conversation(nlohmann::json conversation); + + ConversationManager(const ConversationManager&) = delete; + ConversationManager(ConversationManager&&) = delete; + ConversationManager& operator=(const ConversationManager&) = delete; + ConversationManager& operator=(ConversationManager&&) = delete; + static ConversationManager& get_instance() { + static ConversationManager instance; + return instance; + } + Option create_conversation(const nlohmann::json& conversation); + Option get_conversation(const std::string& conversation_id); + Option append_conversation(const std::string& conversation_id, const nlohmann::json& message); + Option truncate_conversation(nlohmann::json conversation); + Option update_conversation(nlohmann::json conversation); static size_t get_token_count(const nlohmann::json& message); - static Option delete_conversation(const std::string& conversation_id); - static Option get_all_conversations(); + Option delete_conversation(const std::string& conversation_id); + Option get_all_conversations(); static constexpr size_t MAX_TOKENS = 3000; - static Option init(Store* store); - static void clear_expired_conversations(); - static void _set_ttl_offset(size_t offset) { + Option init(Store* store); + void clear_expired_conversations(); + void run(); + void stop(); + void _set_ttl_offset(size_t offset) { TTL_OFFSET = offset; } private: - static inline std::unordered_map conversations; - static inline std::shared_mutex conversations_mutex; + ConversationManager() {} + std::unordered_map conversations; + std::mutex conversations_mutex; static constexpr char* CONVERSATION_RPEFIX = "$CNVP"; - static inline Store* store; + Store* store; static const std::string get_conversation_key(const std::string& conversation_id); static constexpr size_t CONVERSATION_TTL = 60 * 60 * 24; - static inline size_t TTL_OFFSET = 0; + size_t TTL_OFFSET = 0; + + std::atomic quit = false; + std::condition_variable cv; }; \ No newline at end of file diff --git a/src/collection.cpp b/src/collection.cpp index 694ea1bb..78ec5762 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -1712,14 +1712,14 @@ Option Collection::search(std::string raw_query, return Option(400, "Conversation ID provided but conversation is not enabled for this collection."); } - auto conversation_history_op = ConversationManager::get_conversation(conversation_id); + auto conversation_history_op = ConversationManager::get_instance().get_conversation(conversation_id); if(!conversation_history_op.ok()) { return Option(400, conversation_history_op.error()); } auto conversation_history = conversation_history_op.get(); - auto truncate_conversation_history = ConversationManager::truncate_conversation(conversation_history_op.get()["conversation"]); + auto truncate_conversation_history = ConversationManager::get_instance().truncate_conversation(conversation_history_op.get()["conversation"]); conversation_history["conversation"] = truncate_conversation_history.get(); @@ -2594,7 +2594,7 @@ Option Collection::search(std::string raw_query, } // remove document with lowest score until total tokens is less than MAX_TOKENS - while(ConversationManager::get_token_count(docs_array) > ConversationManager::MAX_TOKENS) { + while(ConversationManager::get_instance().get_token_count(docs_array) > ConversationManager::get_instance().MAX_TOKENS) { try { docs_array.erase(docs_array.size() - 1); } catch(...) { @@ -2623,9 +2623,9 @@ Option Collection::search(std::string raw_query, } if(has_conversation_history) { - ConversationManager::append_conversation(conversation_id, formatted_question_op.get()); - ConversationManager::append_conversation(conversation_id, formatted_answer_op.get()); - auto get_conversation_op = ConversationManager::get_conversation(conversation_id); + ConversationManager::get_instance().append_conversation(conversation_id, formatted_question_op.get()); + ConversationManager::get_instance().append_conversation(conversation_id, formatted_answer_op.get()); + auto get_conversation_op = ConversationManager::get_instance().get_conversation(conversation_id); if(!get_conversation_op.ok()) { return Option(get_conversation_op.code(), get_conversation_op.error()); } @@ -2640,12 +2640,12 @@ Option Collection::search(std::string raw_query, conversation_history.push_back(formatted_question_op.get()); conversation_history.push_back(formatted_answer_op.get()); - auto create_conversation_op = ConversationManager::create_conversation(conversation_history); + auto create_conversation_op = ConversationManager::get_instance().create_conversation(conversation_history); if(!create_conversation_op.ok()) { return Option(create_conversation_op.code(), create_conversation_op.error()); } - auto get_conversation_op = ConversationManager::get_conversation(create_conversation_op.get()); + auto get_conversation_op = ConversationManager::get_instance().get_conversation(create_conversation_op.get()); if(!get_conversation_op.ok()) { return Option(get_conversation_op.code(), get_conversation_op.error()); } diff --git a/src/conversation_manager.cpp b/src/conversation_manager.cpp index 3a785f53..52cf9922 100644 --- a/src/conversation_manager.cpp +++ b/src/conversation_manager.cpp @@ -26,7 +26,7 @@ Option ConversationManager::create_conversation(const nlohmann::jso } Option ConversationManager::get_conversation(const std::string& conversation_id) { - std::shared_lock lock(conversations_mutex); + std::unique_lock lock(conversations_mutex); auto conversation = conversations.find(conversation_id); if (conversation == conversations.end()) { return Option(404, "Conversation not found"); @@ -117,7 +117,7 @@ Option ConversationManager::delete_conversation(const std::strin } Option ConversationManager::get_all_conversations() { - std::shared_lock lock(conversations_mutex); + std::unique_lock lock(conversations_mutex); nlohmann::json all_conversations = nlohmann::json::array(); for(auto& conversation : conversations) { all_conversations.push_back(conversation.second); @@ -227,4 +227,22 @@ Option ConversationManager::update_conversation(nlohmann::json c conversations[conversation_id] = actual_conversation; return Option(actual_conversation); +} + +void ConversationManager::run() { + while(!quit) { + std::unique_lock lock(conversations_mutex); + cv.wait_for(lock, std::chrono::seconds(60), [&] { return quit.load(); }); + + if(quit) { + return; + } + + clear_expired_conversations(); + } +} + +void ConversationManager::stop() { + quit = true; + cv.notify_all(); } \ No newline at end of file diff --git a/src/core_api.cpp b/src/core_api.cpp index e3afb1d2..e0c5d030 100644 --- a/src/core_api.cpp +++ b/src/core_api.cpp @@ -570,7 +570,7 @@ bool post_multi_search(const std::shared_ptr& req, const std::shared_p if(conversation_history) { std::string conversation_id = orig_req_params["conversation_id"]; - auto conversation_history = ConversationManager::get_conversation(conversation_id); + auto conversation_history = ConversationManager::get_instance().get_conversation(conversation_id); if(!conversation_history.ok()) { res->set_400("`conversation_id` is invalid."); @@ -584,7 +584,7 @@ bool post_multi_search(const std::shared_ptr& req, const std::shared_p auto conversation_model_id = std::stoul(orig_req_params["conversation_model_id"]); auto conversation_id = orig_req_params["conversation_id"]; auto conversation_model = ConversationModelManager::get_model(conversation_model_id).get(); - auto conversation_history = ConversationManager::get_conversation(conversation_id).get(); + auto conversation_history = ConversationManager::get_instance().get_conversation(conversation_id).get(); auto generate_standalone_q = ConversationModel::get_standalone_question(conversation_history, common_query, conversation_model); if(!generate_standalone_q.ok()) { @@ -704,7 +704,7 @@ bool post_multi_search(const std::shared_ptr& req, const std::shared_p // We have to pop a document from the search result with max size // Until we do not exceed MAX_TOKENS limit - while(ConversationManager::get_token_count(result_docs_arr) > ConversationManager::MAX_TOKENS) { + while(ConversationManager::get_instance().get_token_count(result_docs_arr) > ConversationManager::get_instance().MAX_TOKENS) { // sort the result_docs_arr by size descending std::sort(result_docs_arr.begin(), result_docs_arr.end(), [](const auto& a, const auto& b) { return a.size() > b.size(); @@ -760,9 +760,9 @@ bool post_multi_search(const std::shared_ptr& req, const std::shared_p if(conversation_history) { std::string conversation_id = orig_req_params["conversation_id"]; - ConversationManager::append_conversation(conversation_id, formatted_question_op.get()); - ConversationManager::append_conversation(conversation_id, formatted_answer_op.get()); - auto get_conversation_op = ConversationManager::get_conversation(conversation_id); + ConversationManager::get_instance().append_conversation(conversation_id, formatted_question_op.get()); + ConversationManager::get_instance().append_conversation(conversation_id, formatted_answer_op.get()); + auto get_conversation_op = ConversationManager::get_instance().get_conversation(conversation_id); if(!get_conversation_op.ok()) { res->set_400(get_conversation_op.error()); return false; @@ -778,13 +778,13 @@ bool post_multi_search(const std::shared_ptr& req, const std::shared_p conversation_history.push_back(formatted_question_op.get()); conversation_history.push_back(formatted_answer_op.get()); - auto create_conversation_op = ConversationManager::create_conversation(conversation_history); + auto create_conversation_op = ConversationManager::get_instance().create_conversation(conversation_history); if(!create_conversation_op.ok()) { res->set_400(create_conversation_op.error()); return false; } - auto get_conversation_op = ConversationManager::get_conversation(create_conversation_op.get()); + auto get_conversation_op = ConversationManager::get_instance().get_conversation(create_conversation_op.get()); if(!get_conversation_op.ok()) { res->set_400(get_conversation_op.error()); return false; @@ -2607,7 +2607,7 @@ bool post_proxy(const std::shared_ptr& req, const std::shared_ptr& req, const std::shared_ptr& res) { std::string conversation_id = req->params["id"]; - auto conversation_op = ConversationManager::get_conversation(conversation_id); + auto conversation_op = ConversationManager::get_instance().get_conversation(conversation_id); if(!conversation_op.ok()) { res->set(conversation_op.code(), conversation_op.error()); @@ -2622,7 +2622,7 @@ bool get_conversation(const std::shared_ptr& req, const std::shared_pt bool del_conversation(const std::shared_ptr& req, const std::shared_ptr& res) { std::string conversation_id = req->params["id"]; - auto conversation_op = ConversationManager::delete_conversation(conversation_id); + auto conversation_op = ConversationManager::get_instance().delete_conversation(conversation_id); if(!conversation_op.ok()) { res->set(conversation_op.code(), conversation_op.error()); @@ -2634,7 +2634,7 @@ bool del_conversation(const std::shared_ptr& req, const std::shared_pt } bool get_conversations(const std::shared_ptr& req, const std::shared_ptr& res) { - auto conversations_op = ConversationManager::get_all_conversations(); + auto conversations_op = ConversationManager::get_instance().get_all_conversations(); if(!conversations_op.ok()) { res->set(conversations_op.code(), conversations_op.error()); @@ -2660,7 +2660,7 @@ bool put_conversation(const std::shared_ptr& req, const std::shared_pt req_json["id"] = conversation_id; - auto conversation_op = ConversationManager::update_conversation(req_json); + auto conversation_op = ConversationManager::get_instance().update_conversation(req_json); if(!conversation_op.ok()) { res->set(conversation_op.code(), conversation_op.error()); diff --git a/src/typesense_server_utils.cpp b/src/typesense_server_utils.cpp index 3105e102..f8c775ed 100644 --- a/src/typesense_server_utils.cpp +++ b/src/typesense_server_utils.cpp @@ -449,7 +449,7 @@ int run_server(const Config & config, const std::string & version, void (*master } EmbedderManager::set_model_dir(config.get_data_dir() + "/models"); - auto conversations_init = ConversationManager::init(&store); + auto conversations_init = ConversationManager::get_instance().init(&store); if(!conversations_init.ok()) { LOG(INFO) << "Failed to initialize conversation manager: " << conversations_init.error(); @@ -487,14 +487,7 @@ int run_server(const Config & config, const std::string & version, void (*master std::thread conersation_garbage_collector_thread([]() { LOG(INFO) << "Conversation garbage collector thread started."; - int last_clear_time = 0; - while(!brpc::IsAskedToQuit()) { - if(last_clear_time + 60 < std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count()) { - last_clear_time = std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count(); - ConversationManager::clear_expired_conversations(); - } - std::this_thread::sleep_for(std::chrono::seconds(5)); - } + ConversationManager::get_instance().run(); }); HouseKeeper::get_instance().init(config.get_housekeeping_interval()); @@ -526,6 +519,12 @@ int run_server(const Config & config, const std::string & version, void (*master LOG(INFO) << "Waiting for event sink thread to be done..."; event_sink_thread.join(); + LOG(INFO) << "Shutting down conversation garbage collector thread..."; + ConversationManager::get_instance().stop(); + + LOG(INFO) << "Waiting for conversation garbage collector thread to be done..."; + conersation_garbage_collector_thread.join(); + LOG(INFO) << "Waiting for housekeeping thread to be done..."; HouseKeeper::get_instance().stop(); housekeeping_thread.join(); @@ -539,13 +538,8 @@ int run_server(const Config & config, const std::string & version, void (*master app_thread_pool.shutdown(); LOG(INFO) << "Shutting down replication_thread_pool."; - replication_thread_pool.shutdown(); - LOG(INFO) << "Shutting down conversation garbage collector thread."; - - conersation_garbage_collector_thread.join(); - server->stop(); }); diff --git a/test/collection_vector_search_test.cpp b/test/collection_vector_search_test.cpp index 243a0594..9c0d8a02 100644 --- a/test/collection_vector_search_test.cpp +++ b/test/collection_vector_search_test.cpp @@ -27,7 +27,7 @@ protected: collectionManager.load(8, 1000); ConversationModelManager::init(store); - ConversationManager::init(store); + ConversationManager::get_instance().init(store); } virtual void SetUp() { @@ -2931,7 +2931,7 @@ TEST_F(CollectionVectorTest, TestQAConversation) { // test getting conversation history - auto history_op = ConversationManager::get_conversation(conversation_id); + auto history_op = ConversationManager::get_instance().get_conversation(conversation_id); ASSERT_TRUE(history_op.ok()); diff --git a/test/conversation_test.cpp b/test/conversation_test.cpp index abfbb32a..375e5a3c 100644 --- a/test/conversation_test.cpp +++ b/test/conversation_test.cpp @@ -8,13 +8,13 @@ class ConversationTest : public ::testing::Test { std::string state_dir_path = "/tmp/typesense_test/conversation_test"; system(("rm -rf "+state_dir_path+" && mkdir -p "+state_dir_path).c_str()); store = new Store(state_dir_path); - ConversationManager::init(store); + ConversationManager::get_instance().init(store); } void TearDown() override { - auto conversations = ConversationManager::get_all_conversations(); + auto conversations = ConversationManager::get_instance().get_all_conversations(); for (auto& conversation : conversations.get()) { - ConversationManager::delete_conversation(conversation["id"]); + ConversationManager::get_instance().delete_conversation(conversation["id"]); } delete store; } @@ -25,20 +25,20 @@ class ConversationTest : public ::testing::Test { TEST_F(ConversationTest, CreateConversation) { nlohmann::json conversation = nlohmann::json::array(); - auto create_res = ConversationManager::create_conversation(conversation); + auto create_res = ConversationManager::get_instance().create_conversation(conversation); ASSERT_TRUE(create_res.ok()); } TEST_F(ConversationTest, CreateConversationInvalidType) { nlohmann::json conversation = nlohmann::json::object(); - auto create_res = ConversationManager::create_conversation(conversation); + auto create_res = ConversationManager::get_instance().create_conversation(conversation); ASSERT_FALSE(create_res.ok()); ASSERT_EQ(create_res.code(), 400); ASSERT_EQ(create_res.error(), "Conversation is not an array"); } TEST_F(ConversationTest, GetInvalidConversation) { - auto get_res = ConversationManager::get_conversation("qwerty"); + auto get_res = ConversationManager::get_instance().get_conversation("qwerty"); ASSERT_FALSE(get_res.ok()); ASSERT_EQ(get_res.code(), 404); ASSERT_EQ(get_res.error(), "Conversation not found"); @@ -49,16 +49,16 @@ TEST_F(ConversationTest, AppendConversation) { nlohmann::json message = nlohmann::json::object(); message["user"] = "Hello"; conversation.push_back(message); - auto create_res = ConversationManager::create_conversation(conversation); + auto create_res = ConversationManager::get_instance().create_conversation(conversation); ASSERT_TRUE(create_res.ok()); std::string conversation_id = create_res.get(); - auto append_res = ConversationManager::append_conversation(conversation_id, message); + auto append_res = ConversationManager::get_instance().append_conversation(conversation_id, message); ASSERT_TRUE(append_res.ok()); ASSERT_EQ(append_res.get(), true); - auto get_res = ConversationManager::get_conversation(conversation_id); + auto get_res = ConversationManager::get_instance().get_conversation(conversation_id); ASSERT_TRUE(get_res.ok()); ASSERT_TRUE(get_res.get()["conversation"].is_array()); @@ -73,14 +73,14 @@ TEST_F(ConversationTest, AppendInvalidConversation) { nlohmann::json conversation = nlohmann::json::array(); nlohmann::json message = nlohmann::json::object(); message["user"] = "Hello"; - auto create_res = ConversationManager::create_conversation(conversation); + auto create_res = ConversationManager::get_instance().create_conversation(conversation); ASSERT_TRUE(create_res.ok()); std::string conversation_id = create_res.get(); message = "invalid"; - auto append_res = ConversationManager::append_conversation(conversation_id, message); + auto append_res = ConversationManager::get_instance().append_conversation(conversation_id, message); ASSERT_FALSE(append_res.ok()); ASSERT_EQ(append_res.code(), 400); ASSERT_EQ(append_res.error(), "Message is not an object or array"); @@ -88,11 +88,11 @@ TEST_F(ConversationTest, AppendInvalidConversation) { TEST_F(ConversationTest, DeleteConversation) { nlohmann::json conversation = nlohmann::json::array(); - auto create_res = ConversationManager::create_conversation(conversation); + auto create_res = ConversationManager::get_instance().create_conversation(conversation); ASSERT_TRUE(create_res.ok()); std::string conversation_id = create_res.get(); - auto delete_res = ConversationManager::delete_conversation(conversation_id); + auto delete_res = ConversationManager::get_instance().delete_conversation(conversation_id); ASSERT_TRUE(delete_res.ok()); auto delete_res_json = delete_res.get(); @@ -100,14 +100,14 @@ TEST_F(ConversationTest, DeleteConversation) { ASSERT_EQ(delete_res_json["id"], conversation_id); ASSERT_TRUE(delete_res_json["conversation"].is_array()); - auto get_res = ConversationManager::get_conversation(conversation_id); + auto get_res = ConversationManager::get_instance().get_conversation(conversation_id); ASSERT_FALSE(get_res.ok()); ASSERT_EQ(get_res.code(), 404); ASSERT_EQ(get_res.error(), "Conversation not found"); } TEST_F(ConversationTest, DeleteInvalidConversation) { - auto delete_res = ConversationManager::delete_conversation("qwerty"); + auto delete_res = ConversationManager::get_instance().delete_conversation("qwerty"); ASSERT_FALSE(delete_res.ok()); ASSERT_EQ(delete_res.code(), 404); ASSERT_EQ(delete_res.error(), "Conversation not found"); @@ -121,21 +121,21 @@ TEST_F(ConversationTest, TruncateConversation) { conversation.push_back(message); } - auto truncated = ConversationManager::truncate_conversation(conversation); + auto truncated = ConversationManager::get_instance().truncate_conversation(conversation); ASSERT_TRUE(truncated.ok()); ASSERT_TRUE(truncated.get().size() < conversation.size()); } TEST_F(ConversationTest, TruncateConversationEmpty) { nlohmann::json conversation = nlohmann::json::array(); - auto truncated = ConversationManager::truncate_conversation(conversation); + auto truncated = ConversationManager::get_instance().truncate_conversation(conversation); ASSERT_TRUE(truncated.ok()); ASSERT_TRUE(truncated.get().size() == 0); } TEST_F(ConversationTest, TruncateConversationInvalidType) { nlohmann::json conversation = nlohmann::json::object(); - auto truncated = ConversationManager::truncate_conversation(conversation); + auto truncated = ConversationManager::get_instance().truncate_conversation(conversation); ASSERT_FALSE(truncated.ok()); ASSERT_EQ(truncated.code(), 400); ASSERT_EQ(truncated.error(), "Conversation history is not an array"); @@ -147,28 +147,28 @@ TEST_F(ConversationTest, TestConversationExpire) { nlohmann::json message = nlohmann::json::object(); message["user"] = "Hello"; conversation.push_back(message); - auto create_res = ConversationManager::create_conversation(conversation); + auto create_res = ConversationManager::get_instance().create_conversation(conversation); ASSERT_TRUE(create_res.ok()); std::string conversation_id = create_res.get(); - ConversationManager::clear_expired_conversations(); + ConversationManager::get_instance().clear_expired_conversations(); - auto get_res = ConversationManager::get_conversation(conversation_id); + auto get_res = ConversationManager::get_instance().get_conversation(conversation_id); ASSERT_TRUE(get_res.ok()); ASSERT_TRUE(get_res.get()["conversation"].is_array()); ASSERT_EQ(get_res.get()["id"], conversation_id); ASSERT_EQ(get_res.get()["conversation"].size(), 1); - ConversationManager::_set_ttl_offset(24 * 60 * 60 * 2); - ConversationManager::clear_expired_conversations(); + ConversationManager::get_instance()._set_ttl_offset(24 * 60 * 60 * 2); + ConversationManager::get_instance().clear_expired_conversations(); - get_res = ConversationManager::get_conversation(conversation_id); + get_res = ConversationManager::get_instance().get_conversation(conversation_id); ASSERT_FALSE(get_res.ok()); ASSERT_EQ(get_res.code(), 404); ASSERT_EQ(get_res.error(), "Conversation not found"); - ConversationManager::_set_ttl_offset(0); + ConversationManager::get_instance()._set_ttl_offset(0); } diff --git a/test/core_api_utils_test.cpp b/test/core_api_utils_test.cpp index 9ad3a006..d2734ec3 100644 --- a/test/core_api_utils_test.cpp +++ b/test/core_api_utils_test.cpp @@ -27,7 +27,7 @@ protected: collectionManager.load(8, 1000); ConversationModelManager::init(store); - ConversationManager::init(store); + ConversationManager::get_instance().init(store); } virtual void SetUp() {