Move conversation model init to raft init db.

This also resolves a clustering quirk: the initialization now runs later, after a follower has restored its data from the leader.
This commit is contained in:
Kishore Nallan 2024-08-01 23:00:40 +05:30
parent ca2ec6a383
commit 73bde1ed0d
3 changed files with 15 additions and 16 deletions

View File

@ -14,7 +14,6 @@ Option<nlohmann::json> ConversationModelManager::get_model(const std::string& mo
// Thread-safe entry point for adding a conversation model: takes the models
// mutex, then delegates to add_model_unsafe(), which validates the model,
// assigns an id (a fresh UUID when model_id is empty) and stores it.
// @param model     JSON definition of the conversation model (taken by value).
// @param model_id  Desired model id; may be empty to auto-generate one.
// @return the stored model JSON on success, or an error-carrying Option.
Option<nlohmann::json> ConversationModelManager::add_model(nlohmann::json model, const std::string& model_id) {
std::unique_lock lock(models_mutex);
return add_model_unsafe(model, model_id);
}
@ -24,7 +23,6 @@ Option<nlohmann::json> ConversationModelManager::add_model_unsafe(nlohmann::json
return Option<nlohmann::json>(validate_res.code(), validate_res.error());
}
model["id"] = model_id.empty() ? sole::uuid4().str() : model_id;
auto model_key = get_model_key(model_id);
@ -35,14 +33,12 @@ Option<nlohmann::json> ConversationModelManager::add_model_unsafe(nlohmann::json
models[model_id] = model;
ConversationManager::get_instance().add_history_collection(model["history_collection"]);
return Option<nlohmann::json>(model);
}
// Thread-safe entry point for deleting a conversation model: takes the models
// mutex, then delegates to delete_model_unsafe() (body not visible here —
// presumably removes the model from the in-memory map and the store; verify).
// @param model_id  Id of the model to delete.
// @return the deleted model JSON on success, or an error-carrying Option.
Option<nlohmann::json> ConversationModelManager::delete_model(const std::string& model_id) {
std::unique_lock lock(models_mutex);
return delete_model_unsafe(model_id);
}
@ -111,16 +107,22 @@ Option<int> ConversationModelManager::init(Store* store) {
std::vector<std::string> model_strs;
store->scan_fill(std::string(MODEL_KEY_PREFIX) + "_", std::string(MODEL_KEY_PREFIX) + "`", model_strs);
if(!model_strs.empty()) {
LOG(INFO) << "Found " << model_strs.size() << " conversation model(s).";
}
int loaded_models = 0;
for(auto& model_str : model_strs) {
nlohmann::json model_json = nlohmann::json::parse(model_str);
std::string model_id = model_json["id"];
// Migrate cloudflare models to new namespace convention, change namespace from `cf` to `cloudflare`
if(EmbedderManager::get_model_namespace(model_json["model_name"]) == "cf") {
auto delete_op = delete_model(model_id);
if(!delete_op.ok()) {
return Option<int>(delete_op.code(), delete_op.error());
}
model_json["model_name"] = "cloudflare/" + EmbedderManager::get_model_name_without_namespace(model_json["model_name"]);
auto add_res = add_model(model_json, model_id);
if(!add_res.ok()) {
@ -128,8 +130,6 @@ Option<int> ConversationModelManager::init(Store* store) {
}
}
// Migrate models that don't have a conversation collection
if(model_json.count("history_collection") == 0) {
auto migrate_op = migrate_model(model_json);

View File

@ -7,6 +7,7 @@
#include <file_utils.h>
#include <collection_manager.h>
#include <http_client.h>
#include <conversation_model_manager.h>
#include "rocksdb/utilities/checkpoint.h"
#include "thread_local_vars.h"
#include "core_api.h"
@ -628,6 +629,14 @@ int ReplicationState::init_db() {
return 1;
}
// important to init conversation models only after all collections have been loaded
auto conversation_models_init = ConversationModelManager::init(store);
if(!conversation_models_init.ok()) {
LOG(INFO) << "Failed to initialize conversation model manager: " << conversation_models_init.error();
} else {
LOG(INFO) << "Loaded " << conversation_models_init.get() << " conversation model(s).";
}
if(batched_indexer != nullptr) {
LOG(INFO) << "Initializing batched indexer from snapshot state...";
std::string batched_indexer_state_str;

View File

@ -20,11 +20,9 @@
#include "ratelimit_manager.h"
#include "embedder_manager.h"
#include "typesense_server_utils.h"
#include "file_utils.h"
#include "threadpool.h"
#include "stopwords_manager.h"
#include "conversation_manager.h"
#include "conversation_model_manager.h"
#include "vq_model_manager.h"
#ifndef ASAN_BUILD
@ -293,14 +291,6 @@ int start_raft_server(ReplicationState& replication_state, Store& store,
exit(-1);
}
// important to init conversation models only after all collections have been loaded
auto conversation_models_init = ConversationModelManager::init(&store);
if(!conversation_models_init.ok()) {
LOG(INFO) << "Failed to initialize conversation model manager: " << conversation_models_init.error();
} else {
LOG(INFO) << "Loaded " << conversation_models_init.get() << "(s) conversation models.";
}
LOG(INFO) << "Typesense peering service is running on " << raft_server.listen_address();
LOG(INFO) << "Snapshot interval configured as: " << snapshot_interval_seconds << "s";
LOG(INFO) << "Snapshot max byte count configured as: " << snapshot_max_byte_count_per_rpc;