From c451b77629105e2993a91697681ffdc7126d5bf5 Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Thu, 21 Dec 2023 00:19:11 +0300 Subject: [PATCH 1/2] Fix invalid usage of conversation-related parameters in POST body for multi search --- src/core_api.cpp | 25 ++++++ test/collection_vector_search_test.cpp | 103 +++++++++++++++++++++++++ 2 files changed, 128 insertions(+) diff --git a/src/core_api.cpp b/src/core_api.cpp index bfdcce1e..519a5a42 100644 --- a/src/core_api.cpp +++ b/src/core_api.cpp @@ -619,6 +619,31 @@ bool post_multi_search(const std::shared_ptr& req, const std::shared_p continue; } + + + if(conversation && search_item.key() == "q") { + // q is common for all searches + res->set_400("`q` parameter cannot be used in POST body if `conversation` is enabled. Please set `q` as a query parameter in the request, instead of inside the POST body"); + return false; + } + + if(conversation && search_item.key() == "conversation_model_id") { + // conversation_model_id is common for all searches + res->set_400("`conversation_model_id` cannot be used in POST body. Please set `conversation_model_id` as a query parameter in the request, instead of inside the POST body"); + return false; + } + + if(conversation && search_item.key() == "conversation_id") { + // conversation_id is common for all searches + res->set_400("`conversation_id` cannot be used in POST body. Please set `conversation_id` as a query parameter in the request, instead of inside the POST body"); + return false; + } + + if(search_item.key() == "conversation") { + res->set_400("`conversation` cannot be used in POST body. Please set `conversation` as a query parameter in the request, instead of inside the POST body"); + return false; + } + // overwrite = false since req params will contain embedded params and so has higher priority bool populated = AuthManager::add_item_to_params(req->params, search_item, false); if(!populated) { diff --git a/test/collection_vector_search_test.cpp b/test/collection_vector_search_test.cpp index 17b75aeb..7a2d6467 100644 --- a/test/collection_vector_search_test.cpp +++ b/test/collection_vector_search_test.cpp @@ -8,6 +8,7 @@ #include "conversation_manager.h" #include "conversation_model_manager.h" #include "index.h" +#include "core_api.h" class CollectionVectorTest : public ::testing::Test { protected: @@ -3293,4 +3294,106 @@ TEST_F(CollectionVectorTest, TestEmbeddingValues) { for (int i = 0; i < 384; i++) { EXPECT_NEAR(normalized_embeddings[i], actual_values[i], 0.00001); } +} + + +TEST_F(CollectionVectorTest, InvalidMultiSearchConversation) { + auto schema_json = + R"({ + "name": "test", + "fields": [ + {"name": "name", "type": "string"}, + {"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/all-MiniLM-L12-v2"}}} + ] + })"_json; + + EmbedderManager::set_model_dir("/tmp/typesense_test/models"); + + if (std::getenv("api_key") == nullptr) { + LOG(INFO) << "Skipping test as api_key is not set."; + return; + } + + auto api_key = std::string(std::getenv("api_key")); + + auto conversation_model_config = R"({ + "model_name": "openai/gpt-3.5-turbo" + })"_json; + + conversation_model_config["api_key"] = api_key; + + auto model_add_op = ConversationModelManager::add_model(conversation_model_config); + + ASSERT_TRUE(model_add_op.ok()); + + auto model_id = model_add_op.get()["id"]; + + + auto collection_create_op = collectionManager.create_collection(schema_json); + + ASSERT_TRUE(collection_create_op.ok()); + + nlohmann::json search_body; + search_body["searches"] = nlohmann::json::array(); + + nlohmann::json search1; + search1["collection"] = "test"; + search1["q"] = "dog"; + search1["query_by"] = "embedding"; + + search_body["searches"].push_back(search1); + + std::shared_ptr req = std::make_shared(); + std::shared_ptr res = std::make_shared(nullptr); + + req->params["conversation"] = "true"; + req->params["conversation_model_id"] = to_string(model_id); + req->params["q"] = "cat"; + + req->body = search_body.dump(); + nlohmann::json embedded_params; + req->embedded_params_vec.push_back(embedded_params); + + post_multi_search(req, res); + auto res_json = nlohmann::json::parse(res->body); + ASSERT_EQ(res->status_code, 400); + ASSERT_EQ(res_json["message"], "`q` parameter cannot be used in POST body if `conversation` is enabled. Please set `q` as a query parameter in the request, instead of inside the POST body"); + + search_body["searches"][0].erase("q"); + search_body["searches"][0]["conversation_model_id"] = to_string(model_id); + + req->body = search_body.dump(); + + post_multi_search(req, res); + + res_json = nlohmann::json::parse(res->body); + ASSERT_EQ(res->status_code, 400); + ASSERT_EQ(res_json["message"], "`conversation_model_id` cannot be used in POST body. Please set `conversation_model_id` as a query parameter in the request, instead of inside the POST body"); + + + search_body["searches"][0].erase("conversation_model_id"); + search_body["searches"][0]["conversation_id"] = "123"; + + req->body = search_body.dump(); + + post_multi_search(req, res); + + res_json = nlohmann::json::parse(res->body); + ASSERT_EQ(res->status_code, 400); + + ASSERT_EQ(res_json["message"], "`conversation_id` cannot be used in POST body. Please set `conversation_id` as a query parameter in the request, instead of inside the POST body"); + + + search_body["searches"][0].erase("conversation_id"); + search_body["searches"][0]["conversation"] = true; + + req->body = search_body.dump(); + + post_multi_search(req, res); + + res_json = nlohmann::json::parse(res->body); + ASSERT_EQ(res->status_code, 400); + + + ASSERT_EQ(res_json["message"], "`conversation` cannot be used in POST body. Please set `conversation` as a query parameter in the request, instead of inside the POST body"); } \ No newline at end of file From 1cfff7e886c2a76b58b0eb8ce58efa4f5ab0f1b1 Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Thu, 21 Dec 2023 11:42:24 +0300 Subject: [PATCH 2/2] Remove unnecesarry blank lines --- src/core_api.cpp | 2 -- test/collection_vector_search_test.cpp | 6 ------ 2 files changed, 8 deletions(-) diff --git a/src/core_api.cpp b/src/core_api.cpp index 519a5a42..42bda405 100644 --- a/src/core_api.cpp +++ b/src/core_api.cpp @@ -619,8 +619,6 @@ bool post_multi_search(const std::shared_ptr& req, const std::shared_p continue; } - - if(conversation && search_item.key() == "q") { // q is common for all searches res->set_400("`q` parameter cannot be used in POST body if `conversation` is enabled. Please set `q` as a query parameter in the request, instead of inside the POST body"); diff --git a/test/collection_vector_search_test.cpp b/test/collection_vector_search_test.cpp index 7a2d6467..bbef267f 100644 --- a/test/collection_vector_search_test.cpp +++ b/test/collection_vector_search_test.cpp @@ -3296,7 +3296,6 @@ TEST_F(CollectionVectorTest, TestEmbeddingValues) { } } - TEST_F(CollectionVectorTest, InvalidMultiSearchConversation) { auto schema_json = R"({ @@ -3327,8 +3326,6 @@ TEST_F(CollectionVectorTest, InvalidMultiSearchConversation) { ASSERT_TRUE(model_add_op.ok()); auto model_id = model_add_op.get()["id"]; - - auto collection_create_op = collectionManager.create_collection(schema_json); ASSERT_TRUE(collection_create_op.ok()); @@ -3370,7 +3367,6 @@ TEST_F(CollectionVectorTest, InvalidMultiSearchConversation) { ASSERT_EQ(res->status_code, 400); ASSERT_EQ(res_json["message"], "`conversation_model_id` cannot be used in POST body. Please set `conversation_model_id` as a query parameter in the request, instead of inside the POST body"); - search_body["searches"][0].erase("conversation_model_id"); search_body["searches"][0]["conversation_id"] = "123"; @@ -3383,7 +3379,6 @@ TEST_F(CollectionVectorTest, InvalidMultiSearchConversation) { ASSERT_EQ(res_json["message"], "`conversation_id` cannot be used in POST body. Please set `conversation_id` as a query parameter in the request, instead of inside the POST body"); - search_body["searches"][0].erase("conversation_id"); search_body["searches"][0]["conversation"] = true; @@ -3394,6 +3389,5 @@ TEST_F(CollectionVectorTest, InvalidMultiSearchConversation) { res_json = nlohmann::json::parse(res->body); ASSERT_EQ(res->status_code, 400); - ASSERT_EQ(res_json["message"], "`conversation` cannot be used in POST body. Please set `conversation` as a query parameter in the request, instead of inside the POST body"); } \ No newline at end of file