Add support for system_prompt

This commit is contained in:
ozanarmagan 2023-08-21 19:17:54 +03:00
parent 6560952eea
commit 4802f2d3e9
3 changed files with 12 additions and 4 deletions

View File

@ -56,6 +56,7 @@ namespace fields {
static const std::string model_name = "model_name";
static const std::string range_index = "range_index";
static const std::string qa = "qa";
static const std::string system_prompt = "system_prompt";
// Some models require additional parameters to be passed to the model during indexing/querying
// For e.g. e5-small model requires prefix "passage:" for indexing and "query:" for querying

View File

@ -290,6 +290,11 @@ Option<bool> field::json_field_to_field(bool enable_nested_fields, nlohmann::jso
return Option<bool>(400, "Property `" + fields::qa + "` should be an object containing a `model_name` property.");
}
// it may contain "system_prompt", if exists it should be a string
if(field_json[fields::qa].count(fields::system_prompt) != 0 && !field_json[fields::qa][fields::system_prompt].is_string()) {
return Option<bool>(400, "`qa.system_prompt` should be a string.");
}
auto validate_qa_res = QAModel::validate_model(field_json[fields::qa]);
if(!validate_qa_res.ok()) {
return validate_qa_res;

View File

@ -114,10 +114,12 @@ Option<std::string> OpenAIQAModel::get_answer(const std::string& context, const
nlohmann::json req_body;
req_body["model"] = model_name;
req_body["messages"] = nlohmann::json::array();
static nlohmann::json system_message = nlohmann::json::object();
system_message["role"] = "system";
system_message["content"] = "You are an assistant, answer questions according to the data in JSON format. Do not mention the data content in your answer.";
req_body["messages"].push_back(system_message);
if(model_config.count("system_prompt") != 0) {
nlohmann::json system_message = nlohmann::json::object();
system_message["role"] = "system";
system_message["content"] = model_config["system_prompt"].get<std::string>();
req_body["messages"].push_back(system_message);
}
nlohmann::json message = nlohmann::json::object();
message["role"] = "user";
message["content"] = "Data:\n" + context + "\n\nQuestion:\n" + prompt;