Fix conversation model initialization (#1806)

* Fix conversation model initialization

* Rename `conversation_collection` to `history_collection`
This commit is contained in:
Ozan Armağan 2024-06-27 03:53:05 +03:00 committed by GitHub
parent fed1a6300b
commit 7b1ed4ee79
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 81 additions and 74 deletions

View File

@ -23,7 +23,7 @@ class ConversationManager {
static ConversationManager instance;
return instance;
}
Option<std::string> add_conversation(const nlohmann::json& conversation, const std::string& conversation_collection, const std::string& id = "");
Option<std::string> add_conversation(const nlohmann::json& conversation, const std::string& history_collection, const std::string& id = "");
Option<nlohmann::json> get_conversation(const std::string& conversation_id);
static Option<nlohmann::json> truncate_conversation(nlohmann::json conversation, size_t limit);
Option<nlohmann::json> update_conversation(nlohmann::json conversation);
@ -42,9 +42,9 @@ class ConversationManager {
Option<bool> validate_conversation_store_schema(Collection* collection);
Option<bool> validate_conversation_store_collection(const std::string& collection);
Option<bool> add_conversation_collection(const std::string& collection);
Option<bool> remove_conversation_collection(const std::string& collection);
Option<Collection*> get_conversation_collection(const std::string& conversation_id);
Option<bool> add_history_collection(const std::string& collection);
Option<bool> remove_history_collection(const std::string& collection);
Option<Collection*> get_history_collection(const std::string& conversation_id);
private:
ConversationManager() {}
std::mutex conversations_mutex;
@ -56,6 +56,6 @@ class ConversationManager {
std::atomic<bool> quit = false;
std::condition_variable cv;
std::unordered_map<std::string, uint32_t> conversation_collection_map;
std::unordered_map<std::string, uint32_t> history_collection_map;
std::unordered_map<std::string, std::string> conversation_mapper;
};

View File

@ -30,10 +30,10 @@ class ConversationModelManager
static constexpr char* MODEL_NEXT_ID = "$CVMN";
static constexpr char* MODEL_KEY_PREFIX = "$CVMP";
static inline int64_t DEFAULT_CONVERSATION_COLLECTION_SUFFIX = 0;
static inline int64_t DEFAULT_HISTORY_COLLECTION_SUFFIX = 0;
static inline Store* store;
static const std::string get_model_key(const std::string& model_id);
static Option<Collection*> get_default_conversation_collection();
static Option<Collection*> get_default_history_collection();
static Option<nlohmann::json> delete_model_unsafe(const std::string& model_id);
static Option<nlohmann::json> add_model_unsafe(nlohmann::json model, const std::string& model_id);

View File

@ -173,6 +173,13 @@ std::string BatchedIndexer::get_collection_name(const std::shared_ptr<http_req>&
obj.count("name") != 0 && obj["name"].is_string()) {
coll_name = obj["name"];
}
} else if(route_found && rpath->handler == post_conversation_model) {
nlohmann::json obj = nlohmann::json::parse(req->body, nullptr, false);
if(!obj.is_discarded() && obj.is_object() &&
obj.count("history_collection") != 0 && obj["history_collection"].is_string()) {
coll_name = obj["history_collection"];
}
}
}

View File

@ -85,7 +85,7 @@ Collection::~Collection() {
}
}
ConversationManager::get_instance().remove_conversation_collection(name);
ConversationManager::get_instance().remove_history_collection(name);
}
uint32_t Collection::get_next_seq_id() {
@ -2883,7 +2883,7 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
conversation_history.push_back(formatted_question_op.get());
conversation_history.push_back(formatted_answer_op.get());
auto add_conversation_op = ConversationManager::get_instance().add_conversation(conversation_history, conversation_model["conversation_collection"].get<std::string>(), conversation_id);
auto add_conversation_op = ConversationManager::get_instance().add_conversation(conversation_history, conversation_model["history_collection"].get<std::string>(), conversation_id);
if(!add_conversation_op.ok()) {
return Option<nlohmann::json>(add_conversation_op.code(), add_conversation_op.error());
}

View File

@ -4,7 +4,7 @@
#include "http_client.h"
#include "core_api.h"
Option<std::string> ConversationManager::add_conversation(const nlohmann::json& conversation, const std::string& conversation_collection, const std::string& id) {
Option<std::string> ConversationManager::add_conversation(const nlohmann::json& conversation, const std::string& history_collection, const std::string& id) {
std::unique_lock lock(conversations_mutex);
if(!conversation.is_array()) {
return Option<std::string>(400, "Conversation is not an array");
@ -19,7 +19,7 @@ Option<std::string> ConversationManager::add_conversation(const nlohmann::json&
std::string conversation_id = id.empty() ? sole::uuid4().str() : id;
auto collection = CollectionManager::get_instance().get_collection(conversation_collection).get();
auto collection = CollectionManager::get_instance().get_collection(history_collection).get();
if(!collection) {
return Option<std::string>(404, "Conversation store collection not found");
}
@ -54,7 +54,7 @@ Option<std::string> ConversationManager::add_conversation(const nlohmann::json&
auto resp = std::make_shared<http_res>(nullptr);
req->params["action"] = "emplace";
req->params["collection"] = conversation_collection;
req->params["collection"] = history_collection;
req->body = body;
auto api_res = post_import_documents(req, resp);
@ -62,7 +62,7 @@ Option<std::string> ConversationManager::add_conversation(const nlohmann::json&
return Option<std::string>(resp->status_code, resp->body);
}
conversation_mapper[conversation_id] = conversation_collection;
conversation_mapper[conversation_id] = history_collection;
return Option<std::string>(conversation_id);
}
@ -70,7 +70,7 @@ Option<std::string> ConversationManager::add_conversation(const nlohmann::json&
std::string leader_url = raft_server->get_leader_url();
if(!leader_url.empty()) {
std::string base_url = leader_url + "collections/" + conversation_collection;
std::string base_url = leader_url + "collections/" + history_collection;
std::string res;
std::string url = base_url + "/documents/import?action=emplace";
std::map<std::string, std::string> res_headers;
@ -82,7 +82,7 @@ Option<std::string> ConversationManager::add_conversation(const nlohmann::json&
LOG(ERROR) << "Status: " << status;
return Option<std::string>(400, "Error while creating conversation");
} else {
conversation_mapper[conversation_id] = conversation_collection;
conversation_mapper[conversation_id] = history_collection;
return Option<std::string>(conversation_id);
}
} else {
@ -92,7 +92,7 @@ Option<std::string> ConversationManager::add_conversation(const nlohmann::json&
Option<nlohmann::json> ConversationManager::get_conversation(const std::string& conversation_id) {
auto collection_op = get_conversation_collection(conversation_id);
auto collection_op = get_history_collection(conversation_id);
if(!collection_op.ok()) {
return Option<nlohmann::json>(collection_op.code(), collection_op.error());
}
@ -169,19 +169,19 @@ Option<nlohmann::json> ConversationManager::delete_conversation(const std::strin
return Option<nlohmann::json>(conversation_exists.code(), conversation_exists.error());
}
auto conversation_collection_op = get_conversation_collection(conversation_id);
if(!conversation_collection_op.ok()) {
return Option<nlohmann::json>(conversation_collection_op.code(), conversation_collection_op.error());
auto history_collection_op = get_history_collection(conversation_id);
if(!history_collection_op.ok()) {
return Option<nlohmann::json>(history_collection_op.code(), history_collection_op.error());
}
auto conversation_collection = conversation_collection_op.get()->get_name();
auto history_collection = history_collection_op.get()->get_name();
if(!raft_server) {
auto req = std::make_shared<http_req>();
auto resp = std::make_shared<http_res>(nullptr);
req->params["filter_by"] = "conversation_id:" + conversation_id;
req->params["collection"] = conversation_collection;
req->params["collection"] = history_collection;
auto api_res = del_remove_documents(req, resp);
if(!api_res) {
@ -200,7 +200,7 @@ Option<nlohmann::json> ConversationManager::delete_conversation(const std::strin
return Option<nlohmann::json>(500, "Leader URL is empty");
}
std::string base_url = leader_url + "collections/" + conversation_collection;
std::string base_url = leader_url + "collections/" + history_collection;
std::string res;
std::string url = base_url + "/documents?filter_by=conversation_id:" + conversation_id;
std::map<std::string, std::string> res_headers;
@ -267,7 +267,7 @@ void ConversationManager::clear_expired_conversations() {
std::vector<sort_by> sort_by_vec = {{"timestamp", sort_field_const::desc}};
for(auto& conversation_id : conversation_ids) {
auto collection_op = get_conversation_collection(conversation_id);
auto collection_op = get_history_collection(conversation_id);
if(!collection_op.ok()) {
LOG(ERROR) << collection_op.error();
continue;
@ -441,7 +441,7 @@ Option<bool> ConversationManager::validate_conversation_store_schema(Collection*
Option<bool> ConversationManager::check_conversation_exists(const std::string& conversation_id) {
auto collection_op = get_conversation_collection(conversation_id);
auto collection_op = get_history_collection(conversation_id);
if(!collection_op.ok()) {
return Option<bool>(collection_op.code(), collection_op.error());
}
@ -472,10 +472,10 @@ Option<std::unordered_set<std::string>> ConversationManager::get_conversation_id
return Option<std::unordered_set<std::string>>(conversation_ids);
}
Option<bool> ConversationManager::add_conversation_collection(const std::string& collection) {
Option<bool> ConversationManager::add_history_collection(const std::string& collection) {
std::unique_lock lock(conversations_mutex);
if(conversation_collection_map.count(collection) > 0) {
conversation_collection_map[collection]++;
if(history_collection_map.count(collection) > 0) {
history_collection_map[collection]++;
} else {
auto collection_ptr = CollectionManager::get_instance().get_collection(collection).get();
if(!collection_ptr) {
@ -487,21 +487,21 @@ Option<bool> ConversationManager::add_conversation_collection(const std::string&
return Option<bool>(validate_op.code(), validate_op.error());
}
conversation_collection_map[collection] = 1;
history_collection_map[collection] = 1;
}
return Option<bool>(true);
}
Option<bool> ConversationManager::remove_conversation_collection(const std::string& collection) {
Option<bool> ConversationManager::remove_history_collection(const std::string& collection) {
std::unique_lock lock(conversations_mutex);
if(conversation_collection_map.count(collection) == 0) {
if(history_collection_map.count(collection) == 0) {
return Option<bool>(404, "Collection not found");
}
conversation_collection_map[collection]--;
history_collection_map[collection]--;
if(conversation_collection_map[collection] == 0) {
if(history_collection_map[collection] == 0) {
std::vector<std::string> conversations_to_delete;
for(auto& conversation : conversation_mapper) {
if(conversation.second == collection) {
@ -511,13 +511,13 @@ Option<bool> ConversationManager::remove_conversation_collection(const std::stri
for(auto conversation_id : conversations_to_delete) {
conversation_mapper.erase(conversation_id);
}
conversation_collection_map.erase(collection);
history_collection_map.erase(collection);
}
return Option<bool>(true);
}
Option<Collection*> ConversationManager::get_conversation_collection(const std::string& conversation_id) {
Option<Collection*> ConversationManager::get_history_collection(const std::string& conversation_id) {
if(conversation_mapper.count(conversation_id) > 0) {
auto collection = CollectionManager::get_instance().get_collection(conversation_mapper[conversation_id]).get();
@ -526,7 +526,7 @@ Option<Collection*> ConversationManager::get_conversation_collection(const std::
}
}
for(auto& collection : conversation_collection_map) {
for(auto& collection : history_collection_map) {
auto collection_ptr = CollectionManager::get_instance().get_collection(collection.first).get();
if(!collection_ptr) {
continue;

View File

@ -28,11 +28,11 @@ Option<bool> ConversationModel::validate_model(const nlohmann::json& model_confi
return Option<bool>(400, "Property `max_bytes` is not provided or not a positive integer.");
}
if(model_config.count("conversation_collection") == 0 || !model_config["conversation_collection"].is_string()) {
return Option<bool>(400, "Property `conversation_collection` is not provided or not a string.");
if(model_config.count("history_collection") == 0 || !model_config["history_collection"].is_string()) {
return Option<bool>(400, "Property `history_collection` is not provided or not a string.");
}
auto validate_converson_collection_op = ConversationManager::get_instance().validate_conversation_store_collection(model_config["conversation_collection"].get<std::string>());
auto validate_converson_collection_op = ConversationManager::get_instance().validate_conversation_store_collection(model_config["history_collection"].get<std::string>());
if(!validate_converson_collection_op.ok()) {
return Option<bool>(400, validate_converson_collection_op.error());
}

View File

@ -36,7 +36,7 @@ Option<nlohmann::json> ConversationModelManager::add_model_unsafe(nlohmann::json
models[model_id] = model;
ConversationManager::get_instance().add_conversation_collection(model["conversation_collection"]);
ConversationManager::get_instance().add_history_collection(model["history_collection"]);
return Option<nlohmann::json>(model);
}
@ -57,8 +57,8 @@ Option<nlohmann::json> ConversationModelManager::delete_model_unsafe(const std::
auto model_key = get_model_key(model_id);
bool delete_op = store->remove(model_key);
if(model.count("conversation_collection") != 0) {
ConversationManager::get_instance().remove_conversation_collection(model["conversation_collection"].get<std::string>());
if(model.count("history_collection") != 0) {
ConversationManager::get_instance().remove_history_collection(model["history_collection"].get<std::string>());
}
models.erase(it);
return Option<nlohmann::json>(model);
@ -94,9 +94,9 @@ Option<nlohmann::json> ConversationModelManager::update_model(const std::string&
return Option<nlohmann::json>(500, "Error while inserting model into the store");
}
if(it->second["conversation_collection"] != model["conversation_collection"]) {
ConversationManager::get_instance().remove_conversation_collection(it->second["conversation_collection"]);
ConversationManager::get_instance().add_conversation_collection(model["conversation_collection"]);
if(it->second["history_collection"] != model["history_collection"]) {
ConversationManager::get_instance().remove_history_collection(it->second["history_collection"]);
ConversationManager::get_instance().add_history_collection(model["history_collection"]);
}
models[model_id] = model;
@ -131,7 +131,7 @@ Option<int> ConversationModelManager::init(Store* store) {
// Migrate models that don't have a conversation collection
if(model_json.count("conversation_collection") == 0) {
if(model_json.count("history_collection") == 0) {
auto delete_op = delete_model_unsafe(model_id);
if(!delete_op.ok()) {
return Option<int>(delete_op.code(), delete_op.error());
@ -144,7 +144,7 @@ Option<int> ConversationModelManager::init(Store* store) {
}
models[model_id] = model_json;
ConversationManager::get_instance().add_conversation_collection(model_json["conversation_collection"].get<std::string>());
ConversationManager::get_instance().add_history_collection(model_json["history_collection"].get<std::string>());
loaded_models++;
}
@ -156,13 +156,13 @@ const std::string ConversationModelManager::get_model_key(const std::string& mod
}
Option<Collection*> ConversationModelManager::get_default_conversation_collection() {
Option<Collection*> ConversationModelManager::get_default_history_collection() {
int64_t time_epoch;
if(DEFAULT_CONVERSATION_COLLECTION_SUFFIX == 0) {
if(DEFAULT_HISTORY_COLLECTION_SUFFIX == 0) {
time_epoch = std::chrono::duration_cast<std::chrono::seconds>(std::chrono::system_clock::now().time_since_epoch()).count();
DEFAULT_CONVERSATION_COLLECTION_SUFFIX = time_epoch;
DEFAULT_HISTORY_COLLECTION_SUFFIX = time_epoch;
} else {
time_epoch = DEFAULT_CONVERSATION_COLLECTION_SUFFIX;
time_epoch = DEFAULT_HISTORY_COLLECTION_SUFFIX;
}
std::string collection_id = "default_conversation_history_" + std::to_string(time_epoch);
@ -206,11 +206,11 @@ Option<Collection*> ConversationModelManager::get_default_conversation_collectio
Option<nlohmann::json> ConversationModelManager::migrate_model(nlohmann::json model) {
auto model_id = model["id"];
auto default_collection = get_default_conversation_collection();
auto default_collection = get_default_history_collection();
if(!default_collection.ok()) {
return Option<nlohmann::json>(default_collection.code(), default_collection.error());
}
model["conversation_collection"] = default_collection.get()->get_name();
model["history_collection"] = default_collection.get()->get_name();
auto add_res = add_model_unsafe(model, model_id);
if(!add_res.ok()) {
return Option<nlohmann::json>(add_res.code(), add_res.error());

View File

@ -965,7 +965,7 @@ bool post_multi_search(const std::shared_ptr<http_req>& req, const std::shared_p
new_conversation_history.push_back(formatted_answer_op.get());
std::string conversation_id = conversation_history ? orig_req_params["conversation_id"] : "";
auto add_conversation_op = ConversationManager::get_instance().add_conversation(new_conversation_history, conversation_model["conversation_collection"], conversation_id);
auto add_conversation_op = ConversationManager::get_instance().add_conversation(new_conversation_history, conversation_model["history_collection"], conversation_id);
if(!add_conversation_op.ok()) {
res->set_400(add_conversation_op.error());
return false;

View File

@ -3130,7 +3130,7 @@ TEST_F(CollectionVectorTest, TestQAConversation) {
auto conversation_model_config = R"({
"model_name": "openai/gpt-3.5-turbo",
"max_bytes: 1000,
"conversation_collection": "conversation_store",
"history_collection": "conversation_store",
})"_json;
conversation_model_config["api_key"] = api_key;
@ -3578,7 +3578,7 @@ TEST_F(CollectionVectorTest, InvalidMultiSearchConversation) {
auto conversation_model_config = R"({
"model_name": "openai/gpt-3.5-turbo",
"max_bytes": 1000,
"conversation_collection": "conversation_store"
"history_collection": "conversation_store"
})"_json;
conversation_model_config["api_key"] = api_key;
@ -3658,7 +3658,7 @@ TEST_F(CollectionVectorTest, TestMigratingConversationModel) {
auto conversation_model_config = R"({
"model_name": "openai/gpt-3.5-turbo",
"max_bytes": 1000,
"conversation_collection": "conversation_store"
"history_collection": "conversation_store"
})"_json;
if (std::getenv("api_key") == nullptr) {
@ -3671,7 +3671,7 @@ TEST_F(CollectionVectorTest, TestMigratingConversationModel) {
auto migrate_res = ConversationModelManager::migrate_model(conversation_model_config);
ASSERT_TRUE(migrate_res.ok());
auto migrated_model = migrate_res.get();
ASSERT_TRUE(migrated_model.count("conversation_collection") == 1);
ASSERT_TRUE(migrated_model.count("history_collection") == 1);
auto collection = CollectionManager::get_instance().get_collection("conversation_store").get();
ASSERT_TRUE(collection != nullptr);

View File

@ -40,11 +40,11 @@ class ConversationTest : public ::testing::Test {
})"_json;
collectionManager.create_collection(schema_json);
ConversationManager::get_instance().add_conversation_collection("conversation_store");
ConversationManager::get_instance().add_history_collection("conversation_store");
}
void TearDown() override {
ConversationManager::get_instance().remove_conversation_collection("conversation_store");
ConversationManager::get_instance().remove_history_collection("conversation_store");
collectionManager.dispose();
delete store;
}
@ -239,7 +239,7 @@ TEST_F(ConversationTest, TestRemoveConversationCollection) {
})"_json;
LOG(INFO) << "Creating collection";
collectionManager.create_collection(schema_json);
ConversationManager::get_instance().add_conversation_collection("conversation_store2");
ConversationManager::get_instance().add_history_collection("conversation_store2");
LOG(INFO) << "Collection created";
nlohmann::json conversation = nlohmann::json::array();
nlohmann::json message = nlohmann::json::object();
@ -257,7 +257,7 @@ TEST_F(ConversationTest, TestRemoveConversationCollection) {
LOG(INFO) << "Removing collection";
auto remove_res = ConversationManager::get_instance().remove_conversation_collection("conversation_store2");
auto remove_res = ConversationManager::get_instance().remove_history_collection("conversation_store2");
ASSERT_TRUE(remove_res.ok());
LOG(INFO) << "Getting conversation";
@ -267,7 +267,7 @@ TEST_F(ConversationTest, TestRemoveConversationCollection) {
ASSERT_EQ(get_res.error(), "Conversation not found");
LOG(INFO) << "Adding collection again";
ConversationManager::get_instance().add_conversation_collection("conversation_store2");
ConversationManager::get_instance().add_history_collection("conversation_store2");
LOG(INFO) << "Adding conversation";
@ -313,8 +313,8 @@ TEST_F(ConversationTest, TestMultipleRefConversationCollection) {
})"_json;
collectionManager.create_collection(schema_json);
ConversationManager::get_instance().add_conversation_collection("conversation_store2");
ConversationManager::get_instance().add_conversation_collection("conversation_store2");
ConversationManager::get_instance().add_history_collection("conversation_store2");
ConversationManager::get_instance().add_history_collection("conversation_store2");
nlohmann::json conversation = nlohmann::json::array();
nlohmann::json message = nlohmann::json::object();
@ -328,7 +328,7 @@ TEST_F(ConversationTest, TestMultipleRefConversationCollection) {
ASSERT_TRUE(get_res.ok());
ASSERT_EQ(get_res.get()["conversation"].size(), 1);
auto remove_res = ConversationManager::get_instance().remove_conversation_collection("conversation_store2");
auto remove_res = ConversationManager::get_instance().remove_history_collection("conversation_store2");
ASSERT_TRUE(remove_res.ok());
get_res = ConversationManager::get_instance().get_conversation(create_res.get());
@ -348,7 +348,7 @@ TEST_F(ConversationTest, TestInvalidConversationCollection) {
})"_json;
collectionManager.create_collection(schema_json);
auto res = ConversationManager::get_instance().add_conversation_collection("conversation_store2");
auto res = ConversationManager::get_instance().add_history_collection("conversation_store2");
ASSERT_FALSE(res.ok());
ASSERT_EQ(res.code(), 400);
ASSERT_EQ(res.error(), "Schema is missing `conversation_id` field");

View File

@ -54,7 +54,7 @@ protected:
})"_json;
collectionManager.create_collection(schema_json);
ConversationManager::get_instance().add_conversation_collection("conversation_store");
ConversationManager::get_instance().add_history_collection("conversation_store");
}
virtual void SetUp() {
@ -1415,7 +1415,7 @@ TEST_F(CoreAPIUtilsTest, TestConversationModels) {
nlohmann::json model_config = R"({
"model_name": "openai/gpt-3.5-turbo",
"max_bytes": 10000,
"conversation_collection": "conversation_store"
"history_collection": "conversation_store"
})"_json;
EmbedderManager::set_model_dir("/tmp/typesense_test/models");
@ -1456,7 +1456,7 @@ TEST_F(CoreAPIUtilsTest, TestConversationModels) {
TEST_F(CoreAPIUtilsTest, TestInvalidConversationModels) {
// test with no model_name
nlohmann::json model_config = R"({
"conversation_collection": "conversation_store"
"history_collection": "conversation_store"
})"_json;
if (std::getenv("api_key") == nullptr) {
@ -1546,17 +1546,17 @@ TEST_F(CoreAPIUtilsTest, TestInvalidConversationModels) {
ASSERT_EQ("Property `max_bytes` must be a positive number.", nlohmann::json::parse(resp->body)["message"]);
model_config["max_bytes"] = 10000;
model_config["conversation_collection"] = 123;
model_config["history_collection"] = 123;
// test with conversation_collection as integer
// test with history_collection as integer
req->body = model_config.dump();
post_conversation_model(req, resp);
ASSERT_EQ(400, resp->status_code);
ASSERT_EQ("Property `conversation_collection` is not provided or not a string.", nlohmann::json::parse(resp->body)["message"]);
ASSERT_EQ("Property `history_collection` is not provided or not a string.", nlohmann::json::parse(resp->body)["message"]);
// test with conversation_collection as empty string
model_config["conversation_collection"] = "";
// test with history_collection as empty string
model_config["history_collection"] = "";
req->body = model_config.dump();
post_conversation_model(req, resp);