mirror of
https://github.com/typesense/typesense.git
synced 2025-05-21 14:12:27 +08:00
Fix conversation model initialization (#1806)
* Fix conversation model initialization * Rename `conversation_collection` to `history_collection`
This commit is contained in:
parent
fed1a6300b
commit
7b1ed4ee79
@ -23,7 +23,7 @@ class ConversationManager {
|
||||
static ConversationManager instance;
|
||||
return instance;
|
||||
}
|
||||
Option<std::string> add_conversation(const nlohmann::json& conversation, const std::string& conversation_collection, const std::string& id = "");
|
||||
Option<std::string> add_conversation(const nlohmann::json& conversation, const std::string& history_collection, const std::string& id = "");
|
||||
Option<nlohmann::json> get_conversation(const std::string& conversation_id);
|
||||
static Option<nlohmann::json> truncate_conversation(nlohmann::json conversation, size_t limit);
|
||||
Option<nlohmann::json> update_conversation(nlohmann::json conversation);
|
||||
@ -42,9 +42,9 @@ class ConversationManager {
|
||||
|
||||
Option<bool> validate_conversation_store_schema(Collection* collection);
|
||||
Option<bool> validate_conversation_store_collection(const std::string& collection);
|
||||
Option<bool> add_conversation_collection(const std::string& collection);
|
||||
Option<bool> remove_conversation_collection(const std::string& collection);
|
||||
Option<Collection*> get_conversation_collection(const std::string& conversation_id);
|
||||
Option<bool> add_history_collection(const std::string& collection);
|
||||
Option<bool> remove_history_collection(const std::string& collection);
|
||||
Option<Collection*> get_history_collection(const std::string& conversation_id);
|
||||
private:
|
||||
ConversationManager() {}
|
||||
std::mutex conversations_mutex;
|
||||
@ -56,6 +56,6 @@ class ConversationManager {
|
||||
|
||||
std::atomic<bool> quit = false;
|
||||
std::condition_variable cv;
|
||||
std::unordered_map<std::string, uint32_t> conversation_collection_map;
|
||||
std::unordered_map<std::string, uint32_t> history_collection_map;
|
||||
std::unordered_map<std::string, std::string> conversation_mapper;
|
||||
};
|
@ -30,10 +30,10 @@ class ConversationModelManager
|
||||
|
||||
static constexpr char* MODEL_NEXT_ID = "$CVMN";
|
||||
static constexpr char* MODEL_KEY_PREFIX = "$CVMP";
|
||||
static inline int64_t DEFAULT_CONVERSATION_COLLECTION_SUFFIX = 0;
|
||||
static inline int64_t DEFAULT_HISTORY_COLLECTION_SUFFIX = 0;
|
||||
static inline Store* store;
|
||||
static const std::string get_model_key(const std::string& model_id);
|
||||
static Option<Collection*> get_default_conversation_collection();
|
||||
static Option<Collection*> get_default_history_collection();
|
||||
static Option<nlohmann::json> delete_model_unsafe(const std::string& model_id);
|
||||
static Option<nlohmann::json> add_model_unsafe(nlohmann::json model, const std::string& model_id);
|
||||
|
||||
|
@ -173,6 +173,13 @@ std::string BatchedIndexer::get_collection_name(const std::shared_ptr<http_req>&
|
||||
obj.count("name") != 0 && obj["name"].is_string()) {
|
||||
coll_name = obj["name"];
|
||||
}
|
||||
} else if(route_found && rpath->handler == post_conversation_model) {
|
||||
nlohmann::json obj = nlohmann::json::parse(req->body, nullptr, false);
|
||||
|
||||
if(!obj.is_discarded() && obj.is_object() &&
|
||||
obj.count("history_collection") != 0 && obj["history_collection"].is_string()) {
|
||||
coll_name = obj["history_collection"];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -85,7 +85,7 @@ Collection::~Collection() {
|
||||
}
|
||||
}
|
||||
|
||||
ConversationManager::get_instance().remove_conversation_collection(name);
|
||||
ConversationManager::get_instance().remove_history_collection(name);
|
||||
}
|
||||
|
||||
uint32_t Collection::get_next_seq_id() {
|
||||
@ -2883,7 +2883,7 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
|
||||
conversation_history.push_back(formatted_question_op.get());
|
||||
conversation_history.push_back(formatted_answer_op.get());
|
||||
|
||||
auto add_conversation_op = ConversationManager::get_instance().add_conversation(conversation_history, conversation_model["conversation_collection"].get<std::string>(), conversation_id);
|
||||
auto add_conversation_op = ConversationManager::get_instance().add_conversation(conversation_history, conversation_model["history_collection"].get<std::string>(), conversation_id);
|
||||
if(!add_conversation_op.ok()) {
|
||||
return Option<nlohmann::json>(add_conversation_op.code(), add_conversation_op.error());
|
||||
}
|
||||
|
@ -4,7 +4,7 @@
|
||||
#include "http_client.h"
|
||||
#include "core_api.h"
|
||||
|
||||
Option<std::string> ConversationManager::add_conversation(const nlohmann::json& conversation, const std::string& conversation_collection, const std::string& id) {
|
||||
Option<std::string> ConversationManager::add_conversation(const nlohmann::json& conversation, const std::string& history_collection, const std::string& id) {
|
||||
std::unique_lock lock(conversations_mutex);
|
||||
if(!conversation.is_array()) {
|
||||
return Option<std::string>(400, "Conversation is not an array");
|
||||
@ -19,7 +19,7 @@ Option<std::string> ConversationManager::add_conversation(const nlohmann::json&
|
||||
|
||||
std::string conversation_id = id.empty() ? sole::uuid4().str() : id;
|
||||
|
||||
auto collection = CollectionManager::get_instance().get_collection(conversation_collection).get();
|
||||
auto collection = CollectionManager::get_instance().get_collection(history_collection).get();
|
||||
if(!collection) {
|
||||
return Option<std::string>(404, "Conversation store collection not found");
|
||||
}
|
||||
@ -54,7 +54,7 @@ Option<std::string> ConversationManager::add_conversation(const nlohmann::json&
|
||||
auto resp = std::make_shared<http_res>(nullptr);
|
||||
|
||||
req->params["action"] = "emplace";
|
||||
req->params["collection"] = conversation_collection;
|
||||
req->params["collection"] = history_collection;
|
||||
req->body = body;
|
||||
|
||||
auto api_res = post_import_documents(req, resp);
|
||||
@ -62,7 +62,7 @@ Option<std::string> ConversationManager::add_conversation(const nlohmann::json&
|
||||
return Option<std::string>(resp->status_code, resp->body);
|
||||
}
|
||||
|
||||
conversation_mapper[conversation_id] = conversation_collection;
|
||||
conversation_mapper[conversation_id] = history_collection;
|
||||
return Option<std::string>(conversation_id);
|
||||
}
|
||||
|
||||
@ -70,7 +70,7 @@ Option<std::string> ConversationManager::add_conversation(const nlohmann::json&
|
||||
std::string leader_url = raft_server->get_leader_url();
|
||||
|
||||
if(!leader_url.empty()) {
|
||||
std::string base_url = leader_url + "collections/" + conversation_collection;
|
||||
std::string base_url = leader_url + "collections/" + history_collection;
|
||||
std::string res;
|
||||
std::string url = base_url + "/documents/import?action=emplace";
|
||||
std::map<std::string, std::string> res_headers;
|
||||
@ -82,7 +82,7 @@ Option<std::string> ConversationManager::add_conversation(const nlohmann::json&
|
||||
LOG(ERROR) << "Status: " << status;
|
||||
return Option<std::string>(400, "Error while creating conversation");
|
||||
} else {
|
||||
conversation_mapper[conversation_id] = conversation_collection;
|
||||
conversation_mapper[conversation_id] = history_collection;
|
||||
return Option<std::string>(conversation_id);
|
||||
}
|
||||
} else {
|
||||
@ -92,7 +92,7 @@ Option<std::string> ConversationManager::add_conversation(const nlohmann::json&
|
||||
|
||||
Option<nlohmann::json> ConversationManager::get_conversation(const std::string& conversation_id) {
|
||||
|
||||
auto collection_op = get_conversation_collection(conversation_id);
|
||||
auto collection_op = get_history_collection(conversation_id);
|
||||
if(!collection_op.ok()) {
|
||||
return Option<nlohmann::json>(collection_op.code(), collection_op.error());
|
||||
}
|
||||
@ -169,19 +169,19 @@ Option<nlohmann::json> ConversationManager::delete_conversation(const std::strin
|
||||
return Option<nlohmann::json>(conversation_exists.code(), conversation_exists.error());
|
||||
}
|
||||
|
||||
auto conversation_collection_op = get_conversation_collection(conversation_id);
|
||||
if(!conversation_collection_op.ok()) {
|
||||
return Option<nlohmann::json>(conversation_collection_op.code(), conversation_collection_op.error());
|
||||
auto history_collection_op = get_history_collection(conversation_id);
|
||||
if(!history_collection_op.ok()) {
|
||||
return Option<nlohmann::json>(history_collection_op.code(), history_collection_op.error());
|
||||
}
|
||||
|
||||
auto conversation_collection = conversation_collection_op.get()->get_name();
|
||||
auto history_collection = history_collection_op.get()->get_name();
|
||||
|
||||
if(!raft_server) {
|
||||
auto req = std::make_shared<http_req>();
|
||||
auto resp = std::make_shared<http_res>(nullptr);
|
||||
|
||||
req->params["filter_by"] = "conversation_id:" + conversation_id;
|
||||
req->params["collection"] = conversation_collection;
|
||||
req->params["collection"] = history_collection;
|
||||
|
||||
auto api_res = del_remove_documents(req, resp);
|
||||
if(!api_res) {
|
||||
@ -200,7 +200,7 @@ Option<nlohmann::json> ConversationManager::delete_conversation(const std::strin
|
||||
return Option<nlohmann::json>(500, "Leader URL is empty");
|
||||
}
|
||||
|
||||
std::string base_url = leader_url + "collections/" + conversation_collection;
|
||||
std::string base_url = leader_url + "collections/" + history_collection;
|
||||
std::string res;
|
||||
std::string url = base_url + "/documents?filter_by=conversation_id:" + conversation_id;
|
||||
std::map<std::string, std::string> res_headers;
|
||||
@ -267,7 +267,7 @@ void ConversationManager::clear_expired_conversations() {
|
||||
std::vector<sort_by> sort_by_vec = {{"timestamp", sort_field_const::desc}};
|
||||
|
||||
for(auto& conversation_id : conversation_ids) {
|
||||
auto collection_op = get_conversation_collection(conversation_id);
|
||||
auto collection_op = get_history_collection(conversation_id);
|
||||
if(!collection_op.ok()) {
|
||||
LOG(ERROR) << collection_op.error();
|
||||
continue;
|
||||
@ -441,7 +441,7 @@ Option<bool> ConversationManager::validate_conversation_store_schema(Collection*
|
||||
|
||||
|
||||
Option<bool> ConversationManager::check_conversation_exists(const std::string& conversation_id) {
|
||||
auto collection_op = get_conversation_collection(conversation_id);
|
||||
auto collection_op = get_history_collection(conversation_id);
|
||||
if(!collection_op.ok()) {
|
||||
return Option<bool>(collection_op.code(), collection_op.error());
|
||||
}
|
||||
@ -472,10 +472,10 @@ Option<std::unordered_set<std::string>> ConversationManager::get_conversation_id
|
||||
return Option<std::unordered_set<std::string>>(conversation_ids);
|
||||
}
|
||||
|
||||
Option<bool> ConversationManager::add_conversation_collection(const std::string& collection) {
|
||||
Option<bool> ConversationManager::add_history_collection(const std::string& collection) {
|
||||
std::unique_lock lock(conversations_mutex);
|
||||
if(conversation_collection_map.count(collection) > 0) {
|
||||
conversation_collection_map[collection]++;
|
||||
if(history_collection_map.count(collection) > 0) {
|
||||
history_collection_map[collection]++;
|
||||
} else {
|
||||
auto collection_ptr = CollectionManager::get_instance().get_collection(collection).get();
|
||||
if(!collection_ptr) {
|
||||
@ -487,21 +487,21 @@ Option<bool> ConversationManager::add_conversation_collection(const std::string&
|
||||
return Option<bool>(validate_op.code(), validate_op.error());
|
||||
}
|
||||
|
||||
conversation_collection_map[collection] = 1;
|
||||
history_collection_map[collection] = 1;
|
||||
}
|
||||
|
||||
return Option<bool>(true);
|
||||
}
|
||||
|
||||
Option<bool> ConversationManager::remove_conversation_collection(const std::string& collection) {
|
||||
Option<bool> ConversationManager::remove_history_collection(const std::string& collection) {
|
||||
std::unique_lock lock(conversations_mutex);
|
||||
if(conversation_collection_map.count(collection) == 0) {
|
||||
if(history_collection_map.count(collection) == 0) {
|
||||
return Option<bool>(404, "Collection not found");
|
||||
}
|
||||
|
||||
conversation_collection_map[collection]--;
|
||||
history_collection_map[collection]--;
|
||||
|
||||
if(conversation_collection_map[collection] == 0) {
|
||||
if(history_collection_map[collection] == 0) {
|
||||
std::vector<std::string> conversations_to_delete;
|
||||
for(auto& conversation : conversation_mapper) {
|
||||
if(conversation.second == collection) {
|
||||
@ -511,13 +511,13 @@ Option<bool> ConversationManager::remove_conversation_collection(const std::stri
|
||||
for(auto conversation_id : conversations_to_delete) {
|
||||
conversation_mapper.erase(conversation_id);
|
||||
}
|
||||
conversation_collection_map.erase(collection);
|
||||
history_collection_map.erase(collection);
|
||||
}
|
||||
|
||||
return Option<bool>(true);
|
||||
}
|
||||
|
||||
Option<Collection*> ConversationManager::get_conversation_collection(const std::string& conversation_id) {
|
||||
Option<Collection*> ConversationManager::get_history_collection(const std::string& conversation_id) {
|
||||
|
||||
if(conversation_mapper.count(conversation_id) > 0) {
|
||||
auto collection = CollectionManager::get_instance().get_collection(conversation_mapper[conversation_id]).get();
|
||||
@ -526,7 +526,7 @@ Option<Collection*> ConversationManager::get_conversation_collection(const std::
|
||||
}
|
||||
}
|
||||
|
||||
for(auto& collection : conversation_collection_map) {
|
||||
for(auto& collection : history_collection_map) {
|
||||
auto collection_ptr = CollectionManager::get_instance().get_collection(collection.first).get();
|
||||
if(!collection_ptr) {
|
||||
continue;
|
||||
|
@ -28,11 +28,11 @@ Option<bool> ConversationModel::validate_model(const nlohmann::json& model_confi
|
||||
return Option<bool>(400, "Property `max_bytes` is not provided or not a positive integer.");
|
||||
}
|
||||
|
||||
if(model_config.count("conversation_collection") == 0 || !model_config["conversation_collection"].is_string()) {
|
||||
return Option<bool>(400, "Property `conversation_collection` is not provided or not a string.");
|
||||
if(model_config.count("history_collection") == 0 || !model_config["history_collection"].is_string()) {
|
||||
return Option<bool>(400, "Property `history_collection` is not provided or not a string.");
|
||||
}
|
||||
|
||||
auto validate_converson_collection_op = ConversationManager::get_instance().validate_conversation_store_collection(model_config["conversation_collection"].get<std::string>());
|
||||
auto validate_converson_collection_op = ConversationManager::get_instance().validate_conversation_store_collection(model_config["history_collection"].get<std::string>());
|
||||
if(!validate_converson_collection_op.ok()) {
|
||||
return Option<bool>(400, validate_converson_collection_op.error());
|
||||
}
|
||||
|
@ -36,7 +36,7 @@ Option<nlohmann::json> ConversationModelManager::add_model_unsafe(nlohmann::json
|
||||
models[model_id] = model;
|
||||
|
||||
|
||||
ConversationManager::get_instance().add_conversation_collection(model["conversation_collection"]);
|
||||
ConversationManager::get_instance().add_history_collection(model["history_collection"]);
|
||||
return Option<nlohmann::json>(model);
|
||||
}
|
||||
|
||||
@ -57,8 +57,8 @@ Option<nlohmann::json> ConversationModelManager::delete_model_unsafe(const std::
|
||||
auto model_key = get_model_key(model_id);
|
||||
bool delete_op = store->remove(model_key);
|
||||
|
||||
if(model.count("conversation_collection") != 0) {
|
||||
ConversationManager::get_instance().remove_conversation_collection(model["conversation_collection"].get<std::string>());
|
||||
if(model.count("history_collection") != 0) {
|
||||
ConversationManager::get_instance().remove_history_collection(model["history_collection"].get<std::string>());
|
||||
}
|
||||
models.erase(it);
|
||||
return Option<nlohmann::json>(model);
|
||||
@ -94,9 +94,9 @@ Option<nlohmann::json> ConversationModelManager::update_model(const std::string&
|
||||
return Option<nlohmann::json>(500, "Error while inserting model into the store");
|
||||
}
|
||||
|
||||
if(it->second["conversation_collection"] != model["conversation_collection"]) {
|
||||
ConversationManager::get_instance().remove_conversation_collection(it->second["conversation_collection"]);
|
||||
ConversationManager::get_instance().add_conversation_collection(model["conversation_collection"]);
|
||||
if(it->second["history_collection"] != model["history_collection"]) {
|
||||
ConversationManager::get_instance().remove_history_collection(it->second["history_collection"]);
|
||||
ConversationManager::get_instance().add_history_collection(model["history_collection"]);
|
||||
}
|
||||
|
||||
models[model_id] = model;
|
||||
@ -131,7 +131,7 @@ Option<int> ConversationModelManager::init(Store* store) {
|
||||
|
||||
|
||||
// Migrate models that don't have a conversation collection
|
||||
if(model_json.count("conversation_collection") == 0) {
|
||||
if(model_json.count("history_collection") == 0) {
|
||||
auto delete_op = delete_model_unsafe(model_id);
|
||||
if(!delete_op.ok()) {
|
||||
return Option<int>(delete_op.code(), delete_op.error());
|
||||
@ -144,7 +144,7 @@ Option<int> ConversationModelManager::init(Store* store) {
|
||||
}
|
||||
|
||||
models[model_id] = model_json;
|
||||
ConversationManager::get_instance().add_conversation_collection(model_json["conversation_collection"].get<std::string>());
|
||||
ConversationManager::get_instance().add_history_collection(model_json["history_collection"].get<std::string>());
|
||||
loaded_models++;
|
||||
}
|
||||
|
||||
@ -156,13 +156,13 @@ const std::string ConversationModelManager::get_model_key(const std::string& mod
|
||||
}
|
||||
|
||||
|
||||
Option<Collection*> ConversationModelManager::get_default_conversation_collection() {
|
||||
Option<Collection*> ConversationModelManager::get_default_history_collection() {
|
||||
int64_t time_epoch;
|
||||
if(DEFAULT_CONVERSATION_COLLECTION_SUFFIX == 0) {
|
||||
if(DEFAULT_HISTORY_COLLECTION_SUFFIX == 0) {
|
||||
time_epoch = std::chrono::duration_cast<std::chrono::seconds>(std::chrono::system_clock::now().time_since_epoch()).count();
|
||||
DEFAULT_CONVERSATION_COLLECTION_SUFFIX = time_epoch;
|
||||
DEFAULT_HISTORY_COLLECTION_SUFFIX = time_epoch;
|
||||
} else {
|
||||
time_epoch = DEFAULT_CONVERSATION_COLLECTION_SUFFIX;
|
||||
time_epoch = DEFAULT_HISTORY_COLLECTION_SUFFIX;
|
||||
}
|
||||
|
||||
std::string collection_id = "default_conversation_history_" + std::to_string(time_epoch);
|
||||
@ -206,11 +206,11 @@ Option<Collection*> ConversationModelManager::get_default_conversation_collectio
|
||||
|
||||
Option<nlohmann::json> ConversationModelManager::migrate_model(nlohmann::json model) {
|
||||
auto model_id = model["id"];
|
||||
auto default_collection = get_default_conversation_collection();
|
||||
auto default_collection = get_default_history_collection();
|
||||
if(!default_collection.ok()) {
|
||||
return Option<nlohmann::json>(default_collection.code(), default_collection.error());
|
||||
}
|
||||
model["conversation_collection"] = default_collection.get()->get_name();
|
||||
model["history_collection"] = default_collection.get()->get_name();
|
||||
auto add_res = add_model_unsafe(model, model_id);
|
||||
if(!add_res.ok()) {
|
||||
return Option<nlohmann::json>(add_res.code(), add_res.error());
|
||||
|
@ -965,7 +965,7 @@ bool post_multi_search(const std::shared_ptr<http_req>& req, const std::shared_p
|
||||
new_conversation_history.push_back(formatted_answer_op.get());
|
||||
std::string conversation_id = conversation_history ? orig_req_params["conversation_id"] : "";
|
||||
|
||||
auto add_conversation_op = ConversationManager::get_instance().add_conversation(new_conversation_history, conversation_model["conversation_collection"], conversation_id);
|
||||
auto add_conversation_op = ConversationManager::get_instance().add_conversation(new_conversation_history, conversation_model["history_collection"], conversation_id);
|
||||
if(!add_conversation_op.ok()) {
|
||||
res->set_400(add_conversation_op.error());
|
||||
return false;
|
||||
|
@ -3130,7 +3130,7 @@ TEST_F(CollectionVectorTest, TestQAConversation) {
|
||||
auto conversation_model_config = R"({
|
||||
"model_name": "openai/gpt-3.5-turbo",
|
||||
"max_bytes: 1000,
|
||||
"conversation_collection": "conversation_store",
|
||||
"history_collection": "conversation_store",
|
||||
})"_json;
|
||||
|
||||
conversation_model_config["api_key"] = api_key;
|
||||
@ -3578,7 +3578,7 @@ TEST_F(CollectionVectorTest, InvalidMultiSearchConversation) {
|
||||
auto conversation_model_config = R"({
|
||||
"model_name": "openai/gpt-3.5-turbo",
|
||||
"max_bytes": 1000,
|
||||
"conversation_collection": "conversation_store"
|
||||
"history_collection": "conversation_store"
|
||||
})"_json;
|
||||
|
||||
conversation_model_config["api_key"] = api_key;
|
||||
@ -3658,7 +3658,7 @@ TEST_F(CollectionVectorTest, TestMigratingConversationModel) {
|
||||
auto conversation_model_config = R"({
|
||||
"model_name": "openai/gpt-3.5-turbo",
|
||||
"max_bytes": 1000,
|
||||
"conversation_collection": "conversation_store"
|
||||
"history_collection": "conversation_store"
|
||||
})"_json;
|
||||
|
||||
if (std::getenv("api_key") == nullptr) {
|
||||
@ -3671,7 +3671,7 @@ TEST_F(CollectionVectorTest, TestMigratingConversationModel) {
|
||||
auto migrate_res = ConversationModelManager::migrate_model(conversation_model_config);
|
||||
ASSERT_TRUE(migrate_res.ok());
|
||||
auto migrated_model = migrate_res.get();
|
||||
ASSERT_TRUE(migrated_model.count("conversation_collection") == 1);
|
||||
ASSERT_TRUE(migrated_model.count("history_collection") == 1);
|
||||
|
||||
auto collection = CollectionManager::get_instance().get_collection("conversation_store").get();
|
||||
ASSERT_TRUE(collection != nullptr);
|
||||
|
@ -40,11 +40,11 @@ class ConversationTest : public ::testing::Test {
|
||||
})"_json;
|
||||
|
||||
collectionManager.create_collection(schema_json);
|
||||
ConversationManager::get_instance().add_conversation_collection("conversation_store");
|
||||
ConversationManager::get_instance().add_history_collection("conversation_store");
|
||||
}
|
||||
|
||||
void TearDown() override {
|
||||
ConversationManager::get_instance().remove_conversation_collection("conversation_store");
|
||||
ConversationManager::get_instance().remove_history_collection("conversation_store");
|
||||
collectionManager.dispose();
|
||||
delete store;
|
||||
}
|
||||
@ -239,7 +239,7 @@ TEST_F(ConversationTest, TestRemoveConversationCollection) {
|
||||
})"_json;
|
||||
LOG(INFO) << "Creating collection";
|
||||
collectionManager.create_collection(schema_json);
|
||||
ConversationManager::get_instance().add_conversation_collection("conversation_store2");
|
||||
ConversationManager::get_instance().add_history_collection("conversation_store2");
|
||||
LOG(INFO) << "Collection created";
|
||||
nlohmann::json conversation = nlohmann::json::array();
|
||||
nlohmann::json message = nlohmann::json::object();
|
||||
@ -257,7 +257,7 @@ TEST_F(ConversationTest, TestRemoveConversationCollection) {
|
||||
|
||||
|
||||
LOG(INFO) << "Removing collection";
|
||||
auto remove_res = ConversationManager::get_instance().remove_conversation_collection("conversation_store2");
|
||||
auto remove_res = ConversationManager::get_instance().remove_history_collection("conversation_store2");
|
||||
ASSERT_TRUE(remove_res.ok());
|
||||
|
||||
LOG(INFO) << "Getting conversation";
|
||||
@ -267,7 +267,7 @@ TEST_F(ConversationTest, TestRemoveConversationCollection) {
|
||||
ASSERT_EQ(get_res.error(), "Conversation not found");
|
||||
|
||||
LOG(INFO) << "Adding collection again";
|
||||
ConversationManager::get_instance().add_conversation_collection("conversation_store2");
|
||||
ConversationManager::get_instance().add_history_collection("conversation_store2");
|
||||
|
||||
|
||||
LOG(INFO) << "Adding conversation";
|
||||
@ -313,8 +313,8 @@ TEST_F(ConversationTest, TestMultipleRefConversationCollection) {
|
||||
})"_json;
|
||||
|
||||
collectionManager.create_collection(schema_json);
|
||||
ConversationManager::get_instance().add_conversation_collection("conversation_store2");
|
||||
ConversationManager::get_instance().add_conversation_collection("conversation_store2");
|
||||
ConversationManager::get_instance().add_history_collection("conversation_store2");
|
||||
ConversationManager::get_instance().add_history_collection("conversation_store2");
|
||||
|
||||
nlohmann::json conversation = nlohmann::json::array();
|
||||
nlohmann::json message = nlohmann::json::object();
|
||||
@ -328,7 +328,7 @@ TEST_F(ConversationTest, TestMultipleRefConversationCollection) {
|
||||
ASSERT_TRUE(get_res.ok());
|
||||
ASSERT_EQ(get_res.get()["conversation"].size(), 1);
|
||||
|
||||
auto remove_res = ConversationManager::get_instance().remove_conversation_collection("conversation_store2");
|
||||
auto remove_res = ConversationManager::get_instance().remove_history_collection("conversation_store2");
|
||||
ASSERT_TRUE(remove_res.ok());
|
||||
|
||||
get_res = ConversationManager::get_instance().get_conversation(create_res.get());
|
||||
@ -348,7 +348,7 @@ TEST_F(ConversationTest, TestInvalidConversationCollection) {
|
||||
})"_json;
|
||||
|
||||
collectionManager.create_collection(schema_json);
|
||||
auto res = ConversationManager::get_instance().add_conversation_collection("conversation_store2");
|
||||
auto res = ConversationManager::get_instance().add_history_collection("conversation_store2");
|
||||
ASSERT_FALSE(res.ok());
|
||||
ASSERT_EQ(res.code(), 400);
|
||||
ASSERT_EQ(res.error(), "Schema is missing `conversation_id` field");
|
||||
|
@ -54,7 +54,7 @@ protected:
|
||||
})"_json;
|
||||
|
||||
collectionManager.create_collection(schema_json);
|
||||
ConversationManager::get_instance().add_conversation_collection("conversation_store");
|
||||
ConversationManager::get_instance().add_history_collection("conversation_store");
|
||||
}
|
||||
|
||||
virtual void SetUp() {
|
||||
@ -1415,7 +1415,7 @@ TEST_F(CoreAPIUtilsTest, TestConversationModels) {
|
||||
nlohmann::json model_config = R"({
|
||||
"model_name": "openai/gpt-3.5-turbo",
|
||||
"max_bytes": 10000,
|
||||
"conversation_collection": "conversation_store"
|
||||
"history_collection": "conversation_store"
|
||||
})"_json;
|
||||
|
||||
EmbedderManager::set_model_dir("/tmp/typesense_test/models");
|
||||
@ -1456,7 +1456,7 @@ TEST_F(CoreAPIUtilsTest, TestConversationModels) {
|
||||
TEST_F(CoreAPIUtilsTest, TestInvalidConversationModels) {
|
||||
// test with no model_name
|
||||
nlohmann::json model_config = R"({
|
||||
"conversation_collection": "conversation_store"
|
||||
"history_collection": "conversation_store"
|
||||
})"_json;
|
||||
|
||||
if (std::getenv("api_key") == nullptr) {
|
||||
@ -1546,17 +1546,17 @@ TEST_F(CoreAPIUtilsTest, TestInvalidConversationModels) {
|
||||
ASSERT_EQ("Property `max_bytes` must be a positive number.", nlohmann::json::parse(resp->body)["message"]);
|
||||
|
||||
model_config["max_bytes"] = 10000;
|
||||
model_config["conversation_collection"] = 123;
|
||||
model_config["history_collection"] = 123;
|
||||
|
||||
// test with conversation_collection as integer
|
||||
// test with history_collection as integer
|
||||
req->body = model_config.dump();
|
||||
post_conversation_model(req, resp);
|
||||
|
||||
ASSERT_EQ(400, resp->status_code);
|
||||
ASSERT_EQ("Property `conversation_collection` is not provided or not a string.", nlohmann::json::parse(resp->body)["message"]);
|
||||
ASSERT_EQ("Property `history_collection` is not provided or not a string.", nlohmann::json::parse(resp->body)["message"]);
|
||||
|
||||
// test with conversation_collection as empty string
|
||||
model_config["conversation_collection"] = "";
|
||||
// test with history_collection as empty string
|
||||
model_config["history_collection"] = "";
|
||||
|
||||
req->body = model_config.dump();
|
||||
post_conversation_model(req, resp);
|
||||
|
Loading…
x
Reference in New Issue
Block a user