Merge pull request #1410 from ozanarmagan/v0.26-facets

Make ```ConversationManager``` singleton & update conversation garbag…
This commit is contained in:
Kishore Nallan 2023-12-03 14:37:36 +05:30 committed by GitHub
commit bd50ffe747
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 105 additions and 79 deletions

View File

@ -12,31 +12,45 @@
class ConversationManager {
public:
ConversationManager() = delete;
static Option<std::string> create_conversation(const nlohmann::json& conversation);
static Option<nlohmann::json> get_conversation(const std::string& conversation_id);
static Option<bool> append_conversation(const std::string& conversation_id, const nlohmann::json& message);
static Option<nlohmann::json> truncate_conversation(nlohmann::json conversation);
static Option<nlohmann::json> update_conversation(nlohmann::json conversation);
ConversationManager(const ConversationManager&) = delete;
ConversationManager(ConversationManager&&) = delete;
ConversationManager& operator=(const ConversationManager&) = delete;
ConversationManager& operator=(ConversationManager&&) = delete;
static ConversationManager& get_instance() {
static ConversationManager instance;
return instance;
}
Option<std::string> create_conversation(const nlohmann::json& conversation);
Option<nlohmann::json> get_conversation(const std::string& conversation_id);
Option<bool> append_conversation(const std::string& conversation_id, const nlohmann::json& message);
Option<nlohmann::json> truncate_conversation(nlohmann::json conversation);
Option<nlohmann::json> update_conversation(nlohmann::json conversation);
static size_t get_token_count(const nlohmann::json& message);
static Option<nlohmann::json> delete_conversation(const std::string& conversation_id);
static Option<nlohmann::json> get_all_conversations();
Option<nlohmann::json> delete_conversation(const std::string& conversation_id);
Option<nlohmann::json> get_all_conversations();
static constexpr size_t MAX_TOKENS = 3000;
static Option<int> init(Store* store);
static void clear_expired_conversations();
static void _set_ttl_offset(size_t offset) {
Option<int> init(Store* store);
void clear_expired_conversations();
void run();
void stop();
void _set_ttl_offset(size_t offset) {
TTL_OFFSET = offset;
}
private:
static inline std::unordered_map<std::string, nlohmann::json> conversations;
static inline std::shared_mutex conversations_mutex;
ConversationManager() {}
std::unordered_map<std::string, nlohmann::json> conversations;
std::mutex conversations_mutex;
static constexpr char* CONVERSATION_RPEFIX = "$CNVP";
static inline Store* store;
Store* store;
static const std::string get_conversation_key(const std::string& conversation_id);
static constexpr size_t CONVERSATION_TTL = 60 * 60 * 24;
static inline size_t TTL_OFFSET = 0;
size_t TTL_OFFSET = 0;
std::atomic<bool> quit = false;
std::condition_variable cv;
};

View File

@ -1712,14 +1712,14 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
return Option<nlohmann::json>(400, "Conversation ID provided but conversation is not enabled for this collection.");
}
auto conversation_history_op = ConversationManager::get_conversation(conversation_id);
auto conversation_history_op = ConversationManager::get_instance().get_conversation(conversation_id);
if(!conversation_history_op.ok()) {
return Option<nlohmann::json>(400, conversation_history_op.error());
}
auto conversation_history = conversation_history_op.get();
auto truncate_conversation_history = ConversationManager::truncate_conversation(conversation_history_op.get()["conversation"]);
auto truncate_conversation_history = ConversationManager::get_instance().truncate_conversation(conversation_history_op.get()["conversation"]);
conversation_history["conversation"] = truncate_conversation_history.get();
@ -2594,7 +2594,7 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
}
// remove document with lowest score until total tokens is less than MAX_TOKENS
while(ConversationManager::get_token_count(docs_array) > ConversationManager::MAX_TOKENS) {
while(ConversationManager::get_instance().get_token_count(docs_array) > ConversationManager::get_instance().MAX_TOKENS) {
try {
docs_array.erase(docs_array.size() - 1);
} catch(...) {
@ -2623,9 +2623,9 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
}
if(has_conversation_history) {
ConversationManager::append_conversation(conversation_id, formatted_question_op.get());
ConversationManager::append_conversation(conversation_id, formatted_answer_op.get());
auto get_conversation_op = ConversationManager::get_conversation(conversation_id);
ConversationManager::get_instance().append_conversation(conversation_id, formatted_question_op.get());
ConversationManager::get_instance().append_conversation(conversation_id, formatted_answer_op.get());
auto get_conversation_op = ConversationManager::get_instance().get_conversation(conversation_id);
if(!get_conversation_op.ok()) {
return Option<nlohmann::json>(get_conversation_op.code(), get_conversation_op.error());
}
@ -2640,12 +2640,12 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
conversation_history.push_back(formatted_question_op.get());
conversation_history.push_back(formatted_answer_op.get());
auto create_conversation_op = ConversationManager::create_conversation(conversation_history);
auto create_conversation_op = ConversationManager::get_instance().create_conversation(conversation_history);
if(!create_conversation_op.ok()) {
return Option<nlohmann::json>(create_conversation_op.code(), create_conversation_op.error());
}
auto get_conversation_op = ConversationManager::get_conversation(create_conversation_op.get());
auto get_conversation_op = ConversationManager::get_instance().get_conversation(create_conversation_op.get());
if(!get_conversation_op.ok()) {
return Option<nlohmann::json>(get_conversation_op.code(), get_conversation_op.error());
}

View File

@ -26,7 +26,7 @@ Option<std::string> ConversationManager::create_conversation(const nlohmann::jso
}
Option<nlohmann::json> ConversationManager::get_conversation(const std::string& conversation_id) {
std::shared_lock lock(conversations_mutex);
std::unique_lock lock(conversations_mutex);
auto conversation = conversations.find(conversation_id);
if (conversation == conversations.end()) {
return Option<nlohmann::json>(404, "Conversation not found");
@ -117,7 +117,7 @@ Option<nlohmann::json> ConversationManager::delete_conversation(const std::strin
}
Option<nlohmann::json> ConversationManager::get_all_conversations() {
std::shared_lock lock(conversations_mutex);
std::unique_lock lock(conversations_mutex);
nlohmann::json all_conversations = nlohmann::json::array();
for(auto& conversation : conversations) {
all_conversations.push_back(conversation.second);
@ -227,4 +227,22 @@ Option<nlohmann::json> ConversationManager::update_conversation(nlohmann::json c
conversations[conversation_id] = actual_conversation;
return Option<nlohmann::json>(actual_conversation);
}
void ConversationManager::run() {
while(!quit) {
std::unique_lock lock(conversations_mutex);
cv.wait_for(lock, std::chrono::seconds(60), [&] { return quit.load(); });
if(quit) {
return;
}
clear_expired_conversations();
}
}
void ConversationManager::stop() {
quit = true;
cv.notify_all();
}

View File

@ -575,7 +575,7 @@ bool post_multi_search(const std::shared_ptr<http_req>& req, const std::shared_p
if(conversation_history) {
std::string conversation_id = orig_req_params["conversation_id"];
auto conversation_history = ConversationManager::get_conversation(conversation_id);
auto conversation_history = ConversationManager::get_instance().get_conversation(conversation_id);
if(!conversation_history.ok()) {
res->set_400("`conversation_id` is invalid.");
@ -589,7 +589,7 @@ bool post_multi_search(const std::shared_ptr<http_req>& req, const std::shared_p
auto conversation_model_id = std::stoul(orig_req_params["conversation_model_id"]);
auto conversation_id = orig_req_params["conversation_id"];
auto conversation_model = ConversationModelManager::get_model(conversation_model_id).get();
auto conversation_history = ConversationManager::get_conversation(conversation_id).get();
auto conversation_history = ConversationManager::get_instance().get_conversation(conversation_id).get();
auto generate_standalone_q = ConversationModel::get_standalone_question(conversation_history, common_query, conversation_model);
if(!generate_standalone_q.ok()) {
@ -709,7 +709,7 @@ bool post_multi_search(const std::shared_ptr<http_req>& req, const std::shared_p
// We have to pop a document from the search result with max size
// Until we do not exceed MAX_TOKENS limit
while(ConversationManager::get_token_count(result_docs_arr) > ConversationManager::MAX_TOKENS) {
while(ConversationManager::get_instance().get_token_count(result_docs_arr) > ConversationManager::get_instance().MAX_TOKENS) {
// sort the result_docs_arr by size descending
std::sort(result_docs_arr.begin(), result_docs_arr.end(), [](const auto& a, const auto& b) {
return a.size() > b.size();
@ -765,9 +765,9 @@ bool post_multi_search(const std::shared_ptr<http_req>& req, const std::shared_p
if(conversation_history) {
std::string conversation_id = orig_req_params["conversation_id"];
ConversationManager::append_conversation(conversation_id, formatted_question_op.get());
ConversationManager::append_conversation(conversation_id, formatted_answer_op.get());
auto get_conversation_op = ConversationManager::get_conversation(conversation_id);
ConversationManager::get_instance().append_conversation(conversation_id, formatted_question_op.get());
ConversationManager::get_instance().append_conversation(conversation_id, formatted_answer_op.get());
auto get_conversation_op = ConversationManager::get_instance().get_conversation(conversation_id);
if(!get_conversation_op.ok()) {
res->set_400(get_conversation_op.error());
return false;
@ -783,13 +783,13 @@ bool post_multi_search(const std::shared_ptr<http_req>& req, const std::shared_p
conversation_history.push_back(formatted_question_op.get());
conversation_history.push_back(formatted_answer_op.get());
auto create_conversation_op = ConversationManager::create_conversation(conversation_history);
auto create_conversation_op = ConversationManager::get_instance().create_conversation(conversation_history);
if(!create_conversation_op.ok()) {
res->set_400(create_conversation_op.error());
return false;
}
auto get_conversation_op = ConversationManager::get_conversation(create_conversation_op.get());
auto get_conversation_op = ConversationManager::get_instance().get_conversation(create_conversation_op.get());
if(!get_conversation_op.ok()) {
res->set_400(get_conversation_op.error());
return false;
@ -2612,7 +2612,7 @@ bool post_proxy(const std::shared_ptr<http_req>& req, const std::shared_ptr<http
bool get_conversation(const std::shared_ptr<http_req>& req, const std::shared_ptr<http_res>& res) {
std::string conversation_id = req->params["id"];
auto conversation_op = ConversationManager::get_conversation(conversation_id);
auto conversation_op = ConversationManager::get_instance().get_conversation(conversation_id);
if(!conversation_op.ok()) {
res->set(conversation_op.code(), conversation_op.error());
@ -2627,7 +2627,7 @@ bool get_conversation(const std::shared_ptr<http_req>& req, const std::shared_pt
bool del_conversation(const std::shared_ptr<http_req>& req, const std::shared_ptr<http_res>& res) {
std::string conversation_id = req->params["id"];
auto conversation_op = ConversationManager::delete_conversation(conversation_id);
auto conversation_op = ConversationManager::get_instance().delete_conversation(conversation_id);
if(!conversation_op.ok()) {
res->set(conversation_op.code(), conversation_op.error());
@ -2639,7 +2639,7 @@ bool del_conversation(const std::shared_ptr<http_req>& req, const std::shared_pt
}
bool get_conversations(const std::shared_ptr<http_req>& req, const std::shared_ptr<http_res>& res) {
auto conversations_op = ConversationManager::get_all_conversations();
auto conversations_op = ConversationManager::get_instance().get_all_conversations();
if(!conversations_op.ok()) {
res->set(conversations_op.code(), conversations_op.error());
@ -2665,7 +2665,7 @@ bool put_conversation(const std::shared_ptr<http_req>& req, const std::shared_pt
req_json["id"] = conversation_id;
auto conversation_op = ConversationManager::update_conversation(req_json);
auto conversation_op = ConversationManager::get_instance().update_conversation(req_json);
if(!conversation_op.ok()) {
res->set(conversation_op.code(), conversation_op.error());

View File

@ -449,7 +449,7 @@ int run_server(const Config & config, const std::string & version, void (*master
}
EmbedderManager::set_model_dir(config.get_data_dir() + "/models");
auto conversations_init = ConversationManager::init(&store);
auto conversations_init = ConversationManager::get_instance().init(&store);
if(!conversations_init.ok()) {
LOG(INFO) << "Failed to initialize conversation manager: " << conversations_init.error();
@ -487,14 +487,7 @@ int run_server(const Config & config, const std::string & version, void (*master
std::thread conersation_garbage_collector_thread([]() {
LOG(INFO) << "Conversation garbage collector thread started.";
int last_clear_time = 0;
while(!brpc::IsAskedToQuit()) {
if(last_clear_time + 60 < std::chrono::duration_cast<std::chrono::seconds>(std::chrono::system_clock::now().time_since_epoch()).count()) {
last_clear_time = std::chrono::duration_cast<std::chrono::seconds>(std::chrono::system_clock::now().time_since_epoch()).count();
ConversationManager::clear_expired_conversations();
}
std::this_thread::sleep_for(std::chrono::seconds(5));
}
ConversationManager::get_instance().run();
});
HouseKeeper::get_instance().init(config.get_housekeeping_interval());
@ -526,6 +519,12 @@ int run_server(const Config & config, const std::string & version, void (*master
LOG(INFO) << "Waiting for event sink thread to be done...";
event_sink_thread.join();
LOG(INFO) << "Shutting down conversation garbage collector thread...";
ConversationManager::get_instance().stop();
LOG(INFO) << "Waiting for conversation garbage collector thread to be done...";
conersation_garbage_collector_thread.join();
LOG(INFO) << "Waiting for housekeeping thread to be done...";
HouseKeeper::get_instance().stop();
housekeeping_thread.join();
@ -539,13 +538,8 @@ int run_server(const Config & config, const std::string & version, void (*master
app_thread_pool.shutdown();
LOG(INFO) << "Shutting down replication_thread_pool.";
replication_thread_pool.shutdown();
LOG(INFO) << "Shutting down conversation garbage collector thread.";
conersation_garbage_collector_thread.join();
server->stop();
});

View File

@ -27,7 +27,7 @@ protected:
collectionManager.load(8, 1000);
ConversationModelManager::init(store);
ConversationManager::init(store);
ConversationManager::get_instance().init(store);
}
virtual void SetUp() {
@ -2931,7 +2931,7 @@ TEST_F(CollectionVectorTest, TestQAConversation) {
// test getting conversation history
auto history_op = ConversationManager::get_conversation(conversation_id);
auto history_op = ConversationManager::get_instance().get_conversation(conversation_id);
ASSERT_TRUE(history_op.ok());

View File

@ -8,13 +8,13 @@ class ConversationTest : public ::testing::Test {
std::string state_dir_path = "/tmp/typesense_test/conversation_test";
system(("rm -rf "+state_dir_path+" && mkdir -p "+state_dir_path).c_str());
store = new Store(state_dir_path);
ConversationManager::init(store);
ConversationManager::get_instance().init(store);
}
void TearDown() override {
auto conversations = ConversationManager::get_all_conversations();
auto conversations = ConversationManager::get_instance().get_all_conversations();
for (auto& conversation : conversations.get()) {
ConversationManager::delete_conversation(conversation["id"]);
ConversationManager::get_instance().delete_conversation(conversation["id"]);
}
delete store;
}
@ -25,20 +25,20 @@ class ConversationTest : public ::testing::Test {
TEST_F(ConversationTest, CreateConversation) {
nlohmann::json conversation = nlohmann::json::array();
auto create_res = ConversationManager::create_conversation(conversation);
auto create_res = ConversationManager::get_instance().create_conversation(conversation);
ASSERT_TRUE(create_res.ok());
}
TEST_F(ConversationTest, CreateConversationInvalidType) {
nlohmann::json conversation = nlohmann::json::object();
auto create_res = ConversationManager::create_conversation(conversation);
auto create_res = ConversationManager::get_instance().create_conversation(conversation);
ASSERT_FALSE(create_res.ok());
ASSERT_EQ(create_res.code(), 400);
ASSERT_EQ(create_res.error(), "Conversation is not an array");
}
TEST_F(ConversationTest, GetInvalidConversation) {
auto get_res = ConversationManager::get_conversation("qwerty");
auto get_res = ConversationManager::get_instance().get_conversation("qwerty");
ASSERT_FALSE(get_res.ok());
ASSERT_EQ(get_res.code(), 404);
ASSERT_EQ(get_res.error(), "Conversation not found");
@ -49,16 +49,16 @@ TEST_F(ConversationTest, AppendConversation) {
nlohmann::json message = nlohmann::json::object();
message["user"] = "Hello";
conversation.push_back(message);
auto create_res = ConversationManager::create_conversation(conversation);
auto create_res = ConversationManager::get_instance().create_conversation(conversation);
ASSERT_TRUE(create_res.ok());
std::string conversation_id = create_res.get();
auto append_res = ConversationManager::append_conversation(conversation_id, message);
auto append_res = ConversationManager::get_instance().append_conversation(conversation_id, message);
ASSERT_TRUE(append_res.ok());
ASSERT_EQ(append_res.get(), true);
auto get_res = ConversationManager::get_conversation(conversation_id);
auto get_res = ConversationManager::get_instance().get_conversation(conversation_id);
ASSERT_TRUE(get_res.ok());
ASSERT_TRUE(get_res.get()["conversation"].is_array());
@ -73,14 +73,14 @@ TEST_F(ConversationTest, AppendInvalidConversation) {
nlohmann::json conversation = nlohmann::json::array();
nlohmann::json message = nlohmann::json::object();
message["user"] = "Hello";
auto create_res = ConversationManager::create_conversation(conversation);
auto create_res = ConversationManager::get_instance().create_conversation(conversation);
ASSERT_TRUE(create_res.ok());
std::string conversation_id = create_res.get();
message = "invalid";
auto append_res = ConversationManager::append_conversation(conversation_id, message);
auto append_res = ConversationManager::get_instance().append_conversation(conversation_id, message);
ASSERT_FALSE(append_res.ok());
ASSERT_EQ(append_res.code(), 400);
ASSERT_EQ(append_res.error(), "Message is not an object or array");
@ -88,11 +88,11 @@ TEST_F(ConversationTest, AppendInvalidConversation) {
TEST_F(ConversationTest, DeleteConversation) {
nlohmann::json conversation = nlohmann::json::array();
auto create_res = ConversationManager::create_conversation(conversation);
auto create_res = ConversationManager::get_instance().create_conversation(conversation);
ASSERT_TRUE(create_res.ok());
std::string conversation_id = create_res.get();
auto delete_res = ConversationManager::delete_conversation(conversation_id);
auto delete_res = ConversationManager::get_instance().delete_conversation(conversation_id);
ASSERT_TRUE(delete_res.ok());
auto delete_res_json = delete_res.get();
@ -100,14 +100,14 @@ TEST_F(ConversationTest, DeleteConversation) {
ASSERT_EQ(delete_res_json["id"], conversation_id);
ASSERT_TRUE(delete_res_json["conversation"].is_array());
auto get_res = ConversationManager::get_conversation(conversation_id);
auto get_res = ConversationManager::get_instance().get_conversation(conversation_id);
ASSERT_FALSE(get_res.ok());
ASSERT_EQ(get_res.code(), 404);
ASSERT_EQ(get_res.error(), "Conversation not found");
}
TEST_F(ConversationTest, DeleteInvalidConversation) {
auto delete_res = ConversationManager::delete_conversation("qwerty");
auto delete_res = ConversationManager::get_instance().delete_conversation("qwerty");
ASSERT_FALSE(delete_res.ok());
ASSERT_EQ(delete_res.code(), 404);
ASSERT_EQ(delete_res.error(), "Conversation not found");
@ -121,21 +121,21 @@ TEST_F(ConversationTest, TruncateConversation) {
conversation.push_back(message);
}
auto truncated = ConversationManager::truncate_conversation(conversation);
auto truncated = ConversationManager::get_instance().truncate_conversation(conversation);
ASSERT_TRUE(truncated.ok());
ASSERT_TRUE(truncated.get().size() < conversation.size());
}
TEST_F(ConversationTest, TruncateConversationEmpty) {
nlohmann::json conversation = nlohmann::json::array();
auto truncated = ConversationManager::truncate_conversation(conversation);
auto truncated = ConversationManager::get_instance().truncate_conversation(conversation);
ASSERT_TRUE(truncated.ok());
ASSERT_TRUE(truncated.get().size() == 0);
}
TEST_F(ConversationTest, TruncateConversationInvalidType) {
nlohmann::json conversation = nlohmann::json::object();
auto truncated = ConversationManager::truncate_conversation(conversation);
auto truncated = ConversationManager::get_instance().truncate_conversation(conversation);
ASSERT_FALSE(truncated.ok());
ASSERT_EQ(truncated.code(), 400);
ASSERT_EQ(truncated.error(), "Conversation history is not an array");
@ -147,28 +147,28 @@ TEST_F(ConversationTest, TestConversationExpire) {
nlohmann::json message = nlohmann::json::object();
message["user"] = "Hello";
conversation.push_back(message);
auto create_res = ConversationManager::create_conversation(conversation);
auto create_res = ConversationManager::get_instance().create_conversation(conversation);
ASSERT_TRUE(create_res.ok());
std::string conversation_id = create_res.get();
ConversationManager::clear_expired_conversations();
ConversationManager::get_instance().clear_expired_conversations();
auto get_res = ConversationManager::get_conversation(conversation_id);
auto get_res = ConversationManager::get_instance().get_conversation(conversation_id);
ASSERT_TRUE(get_res.ok());
ASSERT_TRUE(get_res.get()["conversation"].is_array());
ASSERT_EQ(get_res.get()["id"], conversation_id);
ASSERT_EQ(get_res.get()["conversation"].size(), 1);
ConversationManager::_set_ttl_offset(24 * 60 * 60 * 2);
ConversationManager::clear_expired_conversations();
ConversationManager::get_instance()._set_ttl_offset(24 * 60 * 60 * 2);
ConversationManager::get_instance().clear_expired_conversations();
get_res = ConversationManager::get_conversation(conversation_id);
get_res = ConversationManager::get_instance().get_conversation(conversation_id);
ASSERT_FALSE(get_res.ok());
ASSERT_EQ(get_res.code(), 404);
ASSERT_EQ(get_res.error(), "Conversation not found");
ConversationManager::_set_ttl_offset(0);
ConversationManager::get_instance()._set_ttl_offset(0);
}

View File

@ -27,7 +27,7 @@ protected:
collectionManager.load(8, 1000);
ConversationModelManager::init(store);
ConversationManager::init(store);
ConversationManager::get_instance().init(store);
}
virtual void SetUp() {