From 4802f2d3e9200f7785dcbf5f34f9fa153da9baf7 Mon Sep 17 00:00:00 2001
From: ozanarmagan
Date: Mon, 21 Aug 2023 19:17:54 +0300
Subject: [PATCH] Add support for system_prompt

---
 include/field.h  |  1 +
 src/field.cpp    |  5 +++++
 src/qa_model.cpp | 10 ++++++----
 3 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/include/field.h b/include/field.h
index c3dd53bb..83790bb7 100644
--- a/include/field.h
+++ b/include/field.h
@@ -56,6 +56,7 @@
     static const std::string model_name = "model_name";
     static const std::string range_index = "range_index";
     static const std::string qa = "qa";
+    static const std::string system_prompt = "system_prompt";
 
     // Some models require additional parameters to be passed to the model during indexing/querying
     // For e.g. e5-small model requires prefix "passage:" for indexing and "query:" for querying
diff --git a/src/field.cpp b/src/field.cpp
index 805b06ce..948a4168 100644
--- a/src/field.cpp
+++ b/src/field.cpp
@@ -290,6 +290,11 @@ Option<bool> field::json_field_to_field(bool enable_nested_fields, nlohmann::json
             return Option<bool>(400, "Property `" + fields::qa + "` should be an object containing a `model_name` property.");
         }
 
+        // it may contain "system_prompt", if exists it should be a string
+        if(field_json[fields::qa].count(fields::system_prompt) != 0 && !field_json[fields::qa][fields::system_prompt].is_string()) {
+            return Option<bool>(400, "`qa.system_prompt` should be a string.");
+        }
+
         auto validate_qa_res = QAModel::validate_model(field_json[fields::qa]);
         if(!validate_qa_res.ok()) {
             return validate_qa_res;
diff --git a/src/qa_model.cpp b/src/qa_model.cpp
index 03b5f511..3f688011 100644
--- a/src/qa_model.cpp
+++ b/src/qa_model.cpp
@@ -114,10 +114,12 @@ Option<std::string> OpenAIQAModel::get_answer(const std::string& context, const
     nlohmann::json req_body;
     req_body["model"] = model_name;
     req_body["messages"] = nlohmann::json::array();
-    static nlohmann::json system_message = nlohmann::json::object();
-    system_message["role"] = "system";
-    system_message["content"] = "You are an assistant, answer questions according to the data in JSON format. Do not mention the data content in your answer.";
-    req_body["messages"].push_back(system_message);
+    if(model_config.count("system_prompt") != 0) {
+        nlohmann::json system_message = nlohmann::json::object();
+        system_message["role"] = "system";
+        system_message["content"] = model_config["system_prompt"].get<std::string>();
+        req_body["messages"].push_back(system_message);
+    }
     nlohmann::json message = nlohmann::json::object();
     message["role"] = "user";
     message["content"] = "Data:\n" + context + "\n\nQuestion:\n" + prompt;