Fix invalid usage of conversation-related parameters in POST body for multi search

This commit is contained in:
ozanarmagan 2023-12-21 00:19:11 +03:00
parent 3972b66d5c
commit c451b77629
2 changed files with 128 additions and 0 deletions

View File

@ -619,6 +619,31 @@ bool post_multi_search(const std::shared_ptr<http_req>& req, const std::shared_p
continue;
}
if(conversation && search_item.key() == "q") {
// q is common for all searches
res->set_400("`q` parameter cannot be used in POST body if `conversation` is enabled. Please set `q` as a query parameter in the request, instead of inside the POST body");
return false;
}
if(conversation && search_item.key() == "conversation_model_id") {
// conversation_model_id is common for all searches
res->set_400("`conversation_model_id` cannot be used in POST body. Please set `conversation_model_id` as a query parameter in the request, instead of inside the POST body");
return false;
}
if(conversation && search_item.key() == "conversation_id") {
// conversation_id is common for all searches
res->set_400("`conversation_id` cannot be used in POST body. Please set `conversation_id` as a query parameter in the request, instead of inside the POST body");
return false;
}
if(search_item.key() == "conversation") {
res->set_400("`conversation` cannot be used in POST body. Please set `conversation` as a query parameter in the request, instead of inside the POST body");
return false;
}
// overwrite = false since req params will contain embedded params and so has higher priority
bool populated = AuthManager::add_item_to_params(req->params, search_item, false);
if(!populated) {

View File

@ -8,6 +8,7 @@
#include "conversation_manager.h"
#include "conversation_model_manager.h"
#include "index.h"
#include "core_api.h"
class CollectionVectorTest : public ::testing::Test {
protected:
@ -3293,4 +3294,106 @@ TEST_F(CollectionVectorTest, TestEmbeddingValues) {
for (int i = 0; i < 384; i++) {
EXPECT_NEAR(normalized_embeddings[i], actual_values[i], 0.00001);
}
}
TEST_F(CollectionVectorTest, InvalidMultiSearchConversation) {
auto schema_json =
R"({
"name": "test",
"fields": [
{"name": "name", "type": "string"},
{"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/all-MiniLM-L12-v2"}}}
]
})"_json;
EmbedderManager::set_model_dir("/tmp/typesense_test/models");
if (std::getenv("api_key") == nullptr) {
LOG(INFO) << "Skipping test as api_key is not set.";
return;
}
auto api_key = std::string(std::getenv("api_key"));
auto conversation_model_config = R"({
"model_name": "openai/gpt-3.5-turbo"
})"_json;
conversation_model_config["api_key"] = api_key;
auto model_add_op = ConversationModelManager::add_model(conversation_model_config);
ASSERT_TRUE(model_add_op.ok());
auto model_id = model_add_op.get()["id"];
auto collection_create_op = collectionManager.create_collection(schema_json);
ASSERT_TRUE(collection_create_op.ok());
nlohmann::json search_body;
search_body["searches"] = nlohmann::json::array();
nlohmann::json search1;
search1["collection"] = "test";
search1["q"] = "dog";
search1["query_by"] = "embedding";
search_body["searches"].push_back(search1);
std::shared_ptr<http_req> req = std::make_shared<http_req>();
std::shared_ptr<http_res> res = std::make_shared<http_res>(nullptr);
req->params["conversation"] = "true";
req->params["conversation_model_id"] = to_string(model_id);
req->params["q"] = "cat";
req->body = search_body.dump();
nlohmann::json embedded_params;
req->embedded_params_vec.push_back(embedded_params);
post_multi_search(req, res);
auto res_json = nlohmann::json::parse(res->body);
ASSERT_EQ(res->status_code, 400);
ASSERT_EQ(res_json["message"], "`q` parameter cannot be used in POST body if `conversation` is enabled. Please set `q` as a query parameter in the request, instead of inside the POST body");
search_body["searches"][0].erase("q");
search_body["searches"][0]["conversation_model_id"] = to_string(model_id);
req->body = search_body.dump();
post_multi_search(req, res);
res_json = nlohmann::json::parse(res->body);
ASSERT_EQ(res->status_code, 400);
ASSERT_EQ(res_json["message"], "`conversation_model_id` cannot be used in POST body. Please set `conversation_model_id` as a query parameter in the request, instead of inside the POST body");
search_body["searches"][0].erase("conversation_model_id");
search_body["searches"][0]["conversation_id"] = "123";
req->body = search_body.dump();
post_multi_search(req, res);
res_json = nlohmann::json::parse(res->body);
ASSERT_EQ(res->status_code, 400);
ASSERT_EQ(res_json["message"], "`conversation_id` cannot be used in POST body. Please set `conversation_id` as a query parameter in the request, instead of inside the POST body");
search_body["searches"][0].erase("conversation_id");
search_body["searches"][0]["conversation"] = true;
req->body = search_body.dump();
post_multi_search(req, res);
res_json = nlohmann::json::parse(res->body);
ASSERT_EQ(res->status_code, 400);
ASSERT_EQ(res_json["message"], "`conversation` cannot be used in POST body. Please set `conversation` as a query parameter in the request, instead of inside the POST body");
}