mirror of
https://github.com/typesense/typesense.git
synced 2025-05-17 20:22:32 +08:00
Update conversation_model_id parameter to use string
This commit is contained in:
parent
1cfff7e886
commit
c08beadb51
@ -561,7 +561,7 @@ public:
|
||||
const bool prioritize_num_matching_fields = true,
|
||||
const bool group_missing_values = true,
|
||||
const bool converstaion = false,
|
||||
const int conversation_model_id = -1,
|
||||
const std::string& conversation_model_id = "",
|
||||
std::string conversation_id = "",
|
||||
const std::string& override_tags_str = "") const;
|
||||
|
||||
|
@ -6,7 +6,7 @@
|
||||
#include <json.hpp>
|
||||
#include <option.h>
|
||||
#include "store.h"
|
||||
|
||||
#include "sole.hpp"
|
||||
|
||||
class ConversationModelManager
|
||||
{
|
||||
@ -16,22 +16,21 @@ class ConversationModelManager
|
||||
ConversationModelManager(ConversationModelManager&&) = delete;
|
||||
ConversationModelManager& operator=(const ConversationModelManager&) = delete;
|
||||
|
||||
static Option<nlohmann::json> get_model(const uint32_t model_id);
|
||||
static Option<nlohmann::json> get_model(const std::string& model_id);
|
||||
static Option<nlohmann::json> add_model(nlohmann::json model);
|
||||
static Option<nlohmann::json> delete_model(const uint32_t model_id);
|
||||
static Option<nlohmann::json> update_model(const uint32_t model_id, nlohmann::json model);
|
||||
static Option<nlohmann::json> delete_model(const std::string& model_id);
|
||||
static Option<nlohmann::json> update_model(const std::string& model_id, nlohmann::json model);
|
||||
static Option<nlohmann::json> get_all_models();
|
||||
|
||||
|
||||
static Option<int> init(Store* store);
|
||||
private:
|
||||
static inline std::unordered_map<uint32_t, nlohmann::json> models;
|
||||
static inline uint32_t model_id = 0;
|
||||
static inline std::unordered_map<std::string, nlohmann::json> models;
|
||||
static inline std::shared_mutex models_mutex;
|
||||
|
||||
static constexpr char* MODEL_NEXT_ID = "$CVMN";
|
||||
static constexpr char* MODEL_KEY_PREFIX = "$CVMP";
|
||||
|
||||
static inline Store* store;
|
||||
static const std::string get_model_key(uint32_t model_id);
|
||||
static const std::string get_model_key(const std::string& model_id);
|
||||
};
|
@ -1616,7 +1616,7 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
|
||||
const bool prioritize_num_matching_fields,
|
||||
const bool group_missing_values,
|
||||
const bool conversation,
|
||||
const int conversation_model_id,
|
||||
const std::string& conversation_model_id,
|
||||
std::string conversation_id,
|
||||
const std::string& override_tags_str) const {
|
||||
std::shared_lock lock(mutex);
|
||||
@ -1715,7 +1715,7 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
|
||||
std::string query = raw_query;
|
||||
|
||||
if(conversation) {
|
||||
if(conversation_model_id == -1) {
|
||||
if(conversation_model_id.empty()) {
|
||||
return Option<nlohmann::json>(400, "Conversation is enabled but no conversation model ID is provided.");
|
||||
}
|
||||
|
||||
|
@ -1260,7 +1260,7 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
|
||||
|
||||
bool conversation = false;
|
||||
std::string conversation_id;
|
||||
size_t conversation_model_id = std::numeric_limits<size_t>::max();
|
||||
std::string conversation_model_id;
|
||||
|
||||
std::string drop_tokens_mode_str = "right_to_left";
|
||||
bool prioritize_num_matching_fields = true;
|
||||
@ -1291,7 +1291,6 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
|
||||
{FACET_SAMPLE_THRESHOLD, &facet_sample_threshold},
|
||||
{REMOTE_EMBEDDING_TIMEOUT_MS, &remote_embedding_timeout_ms},
|
||||
{REMOTE_EMBEDDING_NUM_TRIES, &remote_embedding_num_tries},
|
||||
{CONVERSATION_MODEL_ID, &conversation_model_id},
|
||||
};
|
||||
|
||||
std::unordered_map<std::string, std::string*> str_values = {
|
||||
@ -1307,6 +1306,7 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
|
||||
{CONVERSATION_ID, &conversation_id},
|
||||
{DROP_TOKENS_MODE, &drop_tokens_mode_str},
|
||||
{OVERRIDE_TAGS, &override_tags},
|
||||
{CONVERSATION_MODEL_ID, &conversation_model_id},
|
||||
};
|
||||
|
||||
std::unordered_map<std::string, bool*> bool_values = {
|
||||
@ -1525,7 +1525,7 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
|
||||
prioritize_num_matching_fields,
|
||||
group_missing_values,
|
||||
conversation,
|
||||
(conversation_model_id == std::numeric_limits<size_t>::max()) ? -1 : static_cast<int>(conversation_model_id),
|
||||
conversation_model_id,
|
||||
conversation_id,
|
||||
override_tags);
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
#include "conversation_model_manager.h"
|
||||
#include "conversation_model.h"
|
||||
|
||||
Option<nlohmann::json> ConversationModelManager::get_model(const uint32_t model_id) {
|
||||
Option<nlohmann::json> ConversationModelManager::get_model(const std::string& model_id) {
|
||||
std::shared_lock lock(models_mutex);
|
||||
auto it = models.find(model_id);
|
||||
if (it == models.end()) {
|
||||
@ -17,6 +17,7 @@ Option<nlohmann::json> ConversationModelManager::add_model(nlohmann::json model)
|
||||
if (!validate_res.ok()) {
|
||||
return Option<nlohmann::json>(validate_res.code(), validate_res.error());
|
||||
}
|
||||
auto model_id = sole::uuid4().str();
|
||||
model["id"] = model_id;
|
||||
|
||||
auto model_key = get_model_key(model_id);
|
||||
@ -25,12 +26,12 @@ Option<nlohmann::json> ConversationModelManager::add_model(nlohmann::json model)
|
||||
return Option<nlohmann::json>(500, "Error while inserting model into the store");
|
||||
}
|
||||
|
||||
models[model_id++] = model;
|
||||
models[model_id] = model;
|
||||
|
||||
return Option<nlohmann::json>(model);
|
||||
}
|
||||
|
||||
Option<nlohmann::json> ConversationModelManager::delete_model(const uint32_t model_id) {
|
||||
Option<nlohmann::json> ConversationModelManager::delete_model(const std::string& model_id) {
|
||||
std::unique_lock lock(models_mutex);
|
||||
auto it = models.find(model_id);
|
||||
if (it == models.end()) {
|
||||
@ -56,7 +57,7 @@ Option<nlohmann::json> ConversationModelManager::get_all_models() {
|
||||
return Option<nlohmann::json>(models_json);
|
||||
}
|
||||
|
||||
Option<nlohmann::json> ConversationModelManager::update_model(const uint32_t model_id, nlohmann::json model) {
|
||||
Option<nlohmann::json> ConversationModelManager::update_model(const std::string& model_id, nlohmann::json model) {
|
||||
std::unique_lock lock(models_mutex);
|
||||
auto validate_res = ConversationModel::validate_model(model);
|
||||
if (!validate_res.ok()) {
|
||||
@ -85,24 +86,13 @@ Option<int> ConversationModelManager::init(Store* store) {
|
||||
std::unique_lock lock(models_mutex);
|
||||
ConversationModelManager::store = store;
|
||||
|
||||
std::string last_id_str;
|
||||
StoreStatus last_id_str_status = store->get(std::string(MODEL_NEXT_ID), last_id_str);
|
||||
|
||||
if(last_id_str_status == StoreStatus::ERROR) {
|
||||
return Option<int>(500, "Error while loading conversations next id from the store");
|
||||
} else if(last_id_str_status == StoreStatus::FOUND) {
|
||||
model_id = StringUtils::deserialize_uint32_t(last_id_str);
|
||||
} else {
|
||||
model_id = 0;
|
||||
}
|
||||
|
||||
std::vector<std::string> model_strs;
|
||||
store->scan_fill(std::string(MODEL_KEY_PREFIX) + "_", std::string(MODEL_KEY_PREFIX) + "`", model_strs);
|
||||
|
||||
int loaded_models = 0;
|
||||
for(auto& model_str : model_strs) {
|
||||
nlohmann::json model_json = nlohmann::json::parse(model_str);
|
||||
int model_id = model_json["id"];
|
||||
std::string model_id = model_json["id"];
|
||||
models[model_id] = model_json;
|
||||
loaded_models++;
|
||||
}
|
||||
@ -110,6 +100,6 @@ Option<int> ConversationModelManager::init(Store* store) {
|
||||
return Option<int>(loaded_models);
|
||||
}
|
||||
|
||||
const std::string ConversationModelManager::get_model_key(uint32_t model_id) {
|
||||
return std::string(MODEL_KEY_PREFIX) + "_" + std::to_string(model_id);
|
||||
const std::string ConversationModelManager::get_model_key(const std::string& model_id) {
|
||||
return std::string(MODEL_KEY_PREFIX) + "_" + model_id;
|
||||
}
|
@ -559,15 +559,10 @@ bool post_multi_search(const std::shared_ptr<http_req>& req, const std::shared_p
|
||||
return false;
|
||||
}
|
||||
|
||||
if(StringUtils::is_uint32_t(orig_req_params["conversation_model_id"])) {
|
||||
uint32_t conversation_model_id = std::stoul(orig_req_params["conversation_model_id"]);
|
||||
auto conversation_model = ConversationModelManager::get_model(conversation_model_id);
|
||||
const std::string& conversation_model_id = orig_req_params["conversation_model_id"];
|
||||
auto conversation_model = ConversationModelManager::get_model(conversation_model_id);
|
||||
|
||||
if(!conversation_model.ok()) {
|
||||
res->set_400("`conversation_model_id` is invalid.");
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
if(!conversation_model.ok()) {
|
||||
res->set_400("`conversation_model_id` is invalid.");
|
||||
return false;
|
||||
}
|
||||
@ -586,7 +581,7 @@ bool post_multi_search(const std::shared_ptr<http_req>& req, const std::shared_p
|
||||
common_query = orig_req_params["q"];
|
||||
|
||||
if(conversation_history) {
|
||||
auto conversation_model_id = std::stoul(orig_req_params["conversation_model_id"]);
|
||||
const std::string& conversation_model_id = orig_req_params["conversation_model_id"];
|
||||
auto conversation_id = orig_req_params["conversation_id"];
|
||||
auto conversation_model = ConversationModelManager::get_model(conversation_model_id).get();
|
||||
auto conversation_history = ConversationManager::get_instance().get_conversation(conversation_id).get();
|
||||
@ -601,8 +596,6 @@ bool post_multi_search(const std::shared_ptr<http_req>& req, const std::shared_p
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
for(size_t i = 0; i < searches.size(); i++) {
|
||||
auto& search_params = searches[i];
|
||||
|
||||
@ -753,7 +746,7 @@ bool post_multi_search(const std::shared_ptr<http_req>& req, const std::shared_p
|
||||
}
|
||||
}
|
||||
|
||||
auto conversation_model_id = std::stoul(orig_req_params["conversation_model_id"]);
|
||||
const std::string& conversation_model_id = orig_req_params["conversation_model_id"];
|
||||
auto conversation_model = ConversationModelManager::get_model(conversation_model_id).get();
|
||||
|
||||
auto prompt = req->params["q"];
|
||||
@ -2714,11 +2707,7 @@ bool post_conversation_model(const std::shared_ptr<http_req>& req, const std::sh
|
||||
}
|
||||
|
||||
bool get_conversation_model(const std::shared_ptr<http_req>& req, const std::shared_ptr<http_res>& res) {
|
||||
if(!StringUtils::is_uint32_t(req->params["id"])) {
|
||||
res->set_400("Invalid ID.");
|
||||
return false;
|
||||
}
|
||||
const int model_id = std::stoi(req->params["id"]);
|
||||
const std::string& model_id = req->params["id"];
|
||||
|
||||
auto model_op = ConversationModelManager::get_model(model_id);
|
||||
|
||||
@ -2752,13 +2741,8 @@ bool get_conversation_models(const std::shared_ptr<http_req>& req, const std::sh
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
bool del_conversation_model(const std::shared_ptr<http_req>& req, const std::shared_ptr<http_res>& res) {
|
||||
if(!StringUtils::is_uint32_t(req->params["id"])) {
|
||||
res->set_400("Invalid ID.");
|
||||
return false;
|
||||
}
|
||||
const int model_id = std::stoi(req->params["id"]);
|
||||
const std::string& model_id = req->params["id"];
|
||||
|
||||
auto model_op = ConversationModelManager::delete_model(model_id);
|
||||
|
||||
@ -2776,11 +2760,7 @@ bool del_conversation_model(const std::shared_ptr<http_req>& req, const std::sha
|
||||
}
|
||||
|
||||
bool put_conversation_model(const std::shared_ptr<http_req>& req, const std::shared_ptr<http_res>& res) {
|
||||
if(!StringUtils::is_uint32_t(req->params["id"])) {
|
||||
res->set_400("Invalid ID.");
|
||||
return false;
|
||||
}
|
||||
const int model_id = std::stoi(req->params["id"]);
|
||||
const std::string& model_id = req->params["id"];
|
||||
|
||||
nlohmann::json req_json;
|
||||
|
||||
|
@ -3397,7 +3397,7 @@ TEST_F(CollectionOverrideTest, OverrideWithTags) {
|
||||
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 10000,
|
||||
4, 7, fallback, 4, {off}, 100, 100, 2, 2, false, "", true, 0, max_score, 100, 0,
|
||||
0, HASH, 30000, 2, "", {}, {}, "right_to_left",
|
||||
true, true, false, -1, "", "foo").get();
|
||||
true, true, false, "", "", "foo").get();
|
||||
|
||||
ASSERT_EQ(2, results["hits"].size());
|
||||
|
||||
@ -3410,7 +3410,7 @@ TEST_F(CollectionOverrideTest, OverrideWithTags) {
|
||||
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 10000,
|
||||
4, 7, fallback, 4, {off}, 100, 100, 2, 2, false, "", true, 0, max_score, 100, 0,
|
||||
0, HASH, 30000, 2, "", {}, {}, "right_to_left",
|
||||
true, true, false, -1, "", "alpha").get();
|
||||
true, true, false, "", "", "alpha").get();
|
||||
|
||||
ASSERT_EQ(1, results["hits"].size());
|
||||
ASSERT_EQ("0", results["hits"][0]["document"]["id"].get<std::string>());
|
||||
@ -3424,7 +3424,7 @@ TEST_F(CollectionOverrideTest, OverrideWithTags) {
|
||||
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 10000,
|
||||
4, 7, fallback, 4, {off}, 100, 100, 2, 2, false, "", true, 0, max_score, 100, 0,
|
||||
0, HASH, 30000, 2, "", {}, {}, "right_to_left",
|
||||
true, true, false, -1, "", "beta").get();
|
||||
true, true, false, "", "", "beta").get();
|
||||
|
||||
ASSERT_EQ(1, results["hits"].size());
|
||||
ASSERT_EQ("0", results["hits"][0]["document"]["id"].get<std::string>());
|
||||
@ -3438,7 +3438,7 @@ TEST_F(CollectionOverrideTest, OverrideWithTags) {
|
||||
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 10000,
|
||||
4, 7, fallback, 4, {off}, 100, 100, 2, 2, false, "", true, 0, max_score, 100, 0,
|
||||
0, HASH, 30000, 2, "", {}, {}, "right_to_left",
|
||||
true, true, false, -1, "", "alpha,beta").get();
|
||||
true, true, false, "", "", "alpha,beta").get();
|
||||
|
||||
ASSERT_EQ(1, results["hits"].size());
|
||||
ASSERT_EQ("0", results["hits"][0]["document"]["id"].get<std::string>());
|
||||
@ -3452,7 +3452,7 @@ TEST_F(CollectionOverrideTest, OverrideWithTags) {
|
||||
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 10000,
|
||||
4, 7, fallback, 4, {off}, 100, 100, 2, 2, false, "", true, 0, max_score, 100, 0,
|
||||
0, HASH, 30000, 2, "", {}, {}, "right_to_left",
|
||||
true, true, false, -1, "", "").get();
|
||||
true, true, false, "", "", "").get();
|
||||
|
||||
ASSERT_EQ(1, results["hits"].size());
|
||||
ASSERT_EQ("2", results["hits"][0]["document"]["id"].get<std::string>());
|
||||
@ -3532,7 +3532,7 @@ TEST_F(CollectionOverrideTest, OverrideWithTagsPartialMatch) {
|
||||
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 10000,
|
||||
4, 7, fallback, 4, {off}, 100, 100, 2, 2, false, "", true, 0, max_score, 100, 0,
|
||||
0, HASH, 30000, 2, "", {}, {}, "right_to_left",
|
||||
true, true, false, -1, "", "alpha,zeta").get();
|
||||
true, true, false, "", "", "alpha,zeta").get();
|
||||
|
||||
ASSERT_EQ(1, results["hits"].size());
|
||||
ASSERT_EQ("1", results["hits"][0]["document"]["id"].get<std::string>());
|
||||
@ -3632,7 +3632,7 @@ TEST_F(CollectionOverrideTest, OverrideWithTagsWithoutStopProcessing) {
|
||||
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 10000,
|
||||
4, 7, fallback, 4, {off}, 100, 100, 2, 2, false, "", true, 0, max_score, 100, 0,
|
||||
0, HASH, 30000, 2, "", {}, {}, "right_to_left",
|
||||
true, true, false, -1, "", "alpha").get();
|
||||
true, true, false, "", "", "alpha").get();
|
||||
|
||||
ASSERT_EQ(1, results["hits"].size());
|
||||
ASSERT_EQ("1", results["hits"][0]["document"]["id"].get<std::string>());
|
||||
@ -3708,7 +3708,7 @@ TEST_F(CollectionOverrideTest, WildcardTagRuleThatMatchesAllQueries) {
|
||||
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 10000,
|
||||
4, 7, fallback, 4, {off}, 100, 100, 2, 2, false, "", true, 0, max_score, 100, 0,
|
||||
0, HASH, 30000, 2, "", {}, {}, "right_to_left",
|
||||
true, true, false, -1, "", override_tags).get();
|
||||
true, true, false, "", "", override_tags).get();
|
||||
|
||||
ASSERT_EQ(1, results["hits"].size());
|
||||
ASSERT_EQ("0", results["hits"][0]["document"]["id"].get<std::string>());
|
||||
@ -3721,7 +3721,7 @@ TEST_F(CollectionOverrideTest, WildcardTagRuleThatMatchesAllQueries) {
|
||||
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 10000,
|
||||
4, 7, fallback, 4, {off}, 100, 100, 2, 2, false, "", true, 0, max_score, 100, 0,
|
||||
0, HASH, 30000, 2, "", {}, {}, "right_to_left",
|
||||
true, true, false, -1, "", override_tags).get();
|
||||
true, true, false, "", "", override_tags).get();
|
||||
|
||||
ASSERT_EQ(1, results["hits"].size());
|
||||
ASSERT_EQ("0", results["hits"][0]["document"]["id"].get<std::string>());
|
||||
@ -3750,7 +3750,7 @@ TEST_F(CollectionOverrideTest, WildcardTagRuleThatMatchesAllQueries) {
|
||||
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 10000,
|
||||
4, 7, fallback, 4, {off}, 100, 100, 2, 2, false, "", true, 0, max_score, 100, 0,
|
||||
0, HASH, 30000, 2, "", {}, {}, "right_to_left",
|
||||
true, true, false, -1, "", override_tags).get();
|
||||
true, true, false, "", "", override_tags).get();
|
||||
|
||||
ASSERT_EQ(1, results["hits"].size());
|
||||
ASSERT_EQ("1", results["hits"][0]["document"]["id"].get<std::string>());
|
||||
@ -3804,7 +3804,7 @@ TEST_F(CollectionOverrideTest, TagsOnlyRule) {
|
||||
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 10000,
|
||||
4, 7, fallback, 4, {off}, 100, 100, 2, 2, false, "", true, 0, max_score, 100, 0,
|
||||
0, HASH, 30000, 2, "", {}, {}, "right_to_left",
|
||||
true, true, false, -1, "", "listing").get();
|
||||
true, true, false, "", "", "listing").get();
|
||||
|
||||
ASSERT_EQ(1, results["hits"].size());
|
||||
ASSERT_EQ("0", results["hits"][0]["document"]["id"].get<std::string>());
|
||||
@ -3833,7 +3833,7 @@ TEST_F(CollectionOverrideTest, TagsOnlyRule) {
|
||||
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 10000,
|
||||
4, 7, fallback, 4, {off}, 100, 100, 2, 2, false, "", true, 0, max_score, 100, 0,
|
||||
0, HASH, 30000, 2, "", {}, {}, "right_to_left",
|
||||
true, true, false, -1, "", "listing2").get();
|
||||
true, true, false, "", "", "listing2").get();
|
||||
|
||||
ASSERT_EQ(1, results["hits"].size());
|
||||
ASSERT_EQ("1", results["hits"][0]["document"]["id"].get<std::string>());
|
||||
@ -3848,7 +3848,7 @@ TEST_F(CollectionOverrideTest, TagsOnlyRule) {
|
||||
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 10000,
|
||||
4, 7, fallback, 4, {off}, 100, 100, 2, 2, false, "", true, 0, max_score, 100, 0,
|
||||
0, HASH, 30000, 2, "", {}, {}, "right_to_left",
|
||||
true, true, false, -1, "", override_tag).get();
|
||||
true, true, false, "", "", override_tag).get();
|
||||
|
||||
ASSERT_EQ(0, results["hits"].size());
|
||||
|
||||
|
@ -2919,7 +2919,7 @@ TEST_F(CollectionVectorTest, TestQAConversation) {
|
||||
0, spp::sparse_hash_set<std::string>(), {},
|
||||
10, "", 30, 4, "", 1, "", "", {}, 3, "<mark>", "</mark>", {}, 4294967295UL, true, false,
|
||||
true, "", false, 6000000UL, 4, 7, fallback, 4, {off}, 32767UL, 32767UL, 2, 2, false, "",
|
||||
true, 0, max_score, 100, 0, 0, HASH, 30000, 2, "", {}, {}, "right_to_left", true, true, true, 0);
|
||||
true, 0, max_score, 100, 0, 0, HASH, 30000, 2, "", {}, {}, "right_to_left", true, true, true, model_add_op.get()["id"]);
|
||||
|
||||
ASSERT_TRUE(results_op.ok());
|
||||
|
||||
@ -3344,7 +3344,7 @@ TEST_F(CollectionVectorTest, InvalidMultiSearchConversation) {
|
||||
std::shared_ptr<http_res> res = std::make_shared<http_res>(nullptr);
|
||||
|
||||
req->params["conversation"] = "true";
|
||||
req->params["conversation_model_id"] = to_string(model_id);
|
||||
req->params["conversation_model_id"] = model_id;
|
||||
req->params["q"] = "cat";
|
||||
|
||||
req->body = search_body.dump();
|
||||
|
Loading…
x
Reference in New Issue
Block a user