Retrieval Augmented Generation

This commit is contained in:
ozanarmagan 2023-08-21 12:08:00 +03:00
parent 16b59fd23a
commit 6560952eea
8 changed files with 260 additions and 8 deletions

View File

@ -478,7 +478,8 @@ public:
const size_t remote_embedding_timeout_ms = 30000,
const size_t remote_embedding_num_tries = 2,
const std::string& stopwords_set="",
const std::vector<std::string>& facet_return_parent = {}) const;
const std::vector<std::string>& facet_return_parent = {},
const std::string& prompt = "") const;
Option<bool> get_filter_ids(const std::string & filter_query, filter_result_t& filter_result) const;

View File

@ -55,6 +55,7 @@ namespace fields {
static const std::string from = "from";
static const std::string model_name = "model_name";
static const std::string range_index = "range_index";
static const std::string qa = "qa";
// Some models require additional parameters to be passed to the model during indexing/querying
// For e.g. e5-small model requires prefix "passage:" for indexing and "query:" for querying
@ -88,6 +89,7 @@ struct field {
size_t num_dim;
nlohmann::json embed;
nlohmann::json qa;
vector_distance_type_t vec_dist;
static constexpr int VAL_UNKNOWN = 2;
@ -101,10 +103,10 @@ struct field {
field(const std::string &name, const std::string &type, const bool facet, const bool optional = false,
bool index = true, std::string locale = "", int sort = -1, int infix = -1, bool nested = false,
int nested_array = 0, size_t num_dim = 0, vector_distance_type_t vec_dist = cosine,
std::string reference = "", const nlohmann::json& embed = nlohmann::json(), const bool range_index = false) :
std::string reference = "", const nlohmann::json& embed = nlohmann::json(),const nlohmann::json& qa = nlohmann::json(), const bool range_index = false) :
name(name), type(type), facet(facet), optional(optional), index(index), locale(locale),
nested(nested), nested_array(nested_array), num_dim(num_dim), vec_dist(vec_dist), reference(reference),
embed(embed), range_index(range_index) {
embed(embed), range_index(range_index), qa(qa) {
set_computed_defaults(sort, infix);
}

26
include/qa_model.h Normal file
View File

@ -0,0 +1,26 @@
#pragma once
#include <string>
#include <json.hpp>
#include "option.h"

// Dispatches question-answering (RAG) requests to the backend selected by the
// namespace prefix of `model_name` in `model_config` (currently only "openai").
class QAModel {
public:
    virtual ~QAModel() {};

    // Generates an answer for `prompt` given `context` (search results
    // serialized as JSON). Returns an error Option when the model namespace
    // is not supported or the remote call fails.
    static Option<std::string> get_answer(const std::string& context, const std::string& prompt, const nlohmann::json& model_config);

    // Validates `model_config` (api key, model name, ...) against the backend.
    static Option<bool> validate_model(const nlohmann::json& model_config);
private:
};

// OpenAI-backed QA model; talks to the chat completions endpoint.
class OpenAIQAModel : public QAModel {
public:
    static Option<std::string> get_answer(const std::string& context, const std::string& prompt, const nlohmann::json& model_config);
    static Option<bool> validate_model(const nlohmann::json& model_config);

    // prevent instantiation
    OpenAIQAModel() = delete;
private:
    // BUG FIX: string literals are `const char[]`; binding them to a plain
    // `char*` is ill-formed in standard C++, so the pointers must be const.
    static constexpr const char* OPENAI_LIST_MODELS = "https://api.openai.com/v1/models";
    static constexpr const char* OPENAI_CHAT_COMPLETION = "https://api.openai.com/v1/chat/completions";
};

View File

@ -25,9 +25,9 @@ struct embedding_res_t {
class RemoteEmbedder {
protected:
static Option<bool> validate_string_properties(const nlohmann::json& model_config, const std::vector<std::string>& properties);
static long call_remote_api(const std::string& method, const std::string& url, const std::string& req_body, std::string& res_body, std::map<std::string, std::string>& res_headers, std::unordered_map<std::string, std::string>& req_headers);
static inline ReplicationState* raft_server = nullptr;
public:
static long call_remote_api(const std::string& method, const std::string& url, const std::string& req_body, std::string& res_body, std::map<std::string, std::string>& res_headers, std::unordered_map<std::string, std::string>& req_headers);
virtual nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) = 0;
virtual embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_tries = 2) = 0;
virtual std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size = 200) = 0;

View File

@ -19,6 +19,7 @@
#include "vector_query_ops.h"
#include "text_embedder_manager.h"
#include "stopwords_manager.h"
#include "qa_model.h"
const std::string override_t::MATCH_EXACT = "exact";
const std::string override_t::MATCH_CONTAINS = "contains";
@ -241,6 +242,7 @@ nlohmann::json Collection::get_summary_json() const {
field_json[fields::sort] = coll_field.sort;
field_json[fields::infix] = coll_field.infix;
field_json[fields::locale] = coll_field.locale;
field_json[fields::qa] = coll_field.qa;
if(coll_field.embed.count(fields::from) != 0) {
field_json[fields::embed] = coll_field.embed;
@ -254,6 +256,15 @@ nlohmann::json Collection::get_summary_json() const {
}
}
if(!coll_field.qa.empty()) {
hide_credential(field_json[fields::qa], "api_key");
hide_credential(field_json[fields::qa], "access_token");
hide_credential(field_json[fields::qa], "refresh_token");
hide_credential(field_json[fields::qa], "client_id");
hide_credential(field_json[fields::qa], "client_secret");
hide_credential(field_json[fields::qa], "project_id");
}
if(coll_field.num_dim > 0) {
field_json[fields::num_dim] = coll_field.num_dim;
}
@ -1120,7 +1131,8 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
const size_t remote_embedding_timeout_ms,
const size_t remote_embedding_num_tries,
const std::string& stopwords_set,
const std::vector<std::string>& facet_return_parent) const {
const std::vector<std::string>& facet_return_parent,
const std::string& prompt) const {
std::shared_lock lock(mutex);
@ -1293,6 +1305,8 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
// Set query to * if it is semantic search
if(!vector_query.field_name.empty() && processed_search_fields.empty()) {
query = "*";
} else if(!prompt.empty()) {
return Option<nlohmann::json>(400, "Prompt is only supported for semantic search.");
}
if(!vector_query.field_name.empty() && vector_query.values.empty() && num_embed_fields == 0) {
@ -1783,6 +1797,8 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
index_symbols[uint8_t(c)] = 1;
}
nlohmann::json docs_array = nlohmann::json::array();
// construct results array
for(long result_kvs_index = start_result_index; result_kvs_index <= end_result_index; result_kvs_index++) {
const std::vector<KV*> & kv_group = result_group_kvs[result_kvs_index];
@ -1951,6 +1967,10 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
return Option<nlohmann::json>(prune_op.code(), prune_op.error());
}
if(!prompt.empty()) {
docs_array.push_back(document);
}
wrapper_doc["document"] = document;
wrapper_doc["highlight"] = highlight_res;
@ -1999,6 +2019,25 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
}
}
if(!prompt.empty()) {
result["qa"] = nlohmann::json::object();
result["qa"]["prompt"] = prompt;
auto embedding_field_it = search_schema.find(vector_query.field_name);
if(embedding_field_it == search_schema.end()) {
return Option<nlohmann::json>(400, "Invalid embedding field name.");
}
// for each document, remove embedding field to reduce request size
for(auto& doc : docs_array) {
doc.erase(embedding_field_it->name);
}
auto qa_op = QAModel::get_answer(docs_array.dump(), prompt, embedding_field_it->qa);
if(!qa_op.ok()) {
return Option<nlohmann::json>(qa_op.code(), qa_op.error());
}
result["qa"]["answer"] = qa_op.get();
}
result["facet_counts"] = nlohmann::json::array();
// populate facets
@ -4413,6 +4452,7 @@ Option<bool> Collection::validate_alter_payload(nlohmann::json& schema_changes,
embedding_fields.emplace(f.name, f);
}
if(f.nested && enable_nested_fields) {
updated_nested_fields.emplace(f.name, f);

View File

@ -65,6 +65,10 @@ Collection* CollectionManager::init_collection(const nlohmann::json & collection
field_obj[fields::embed] = nlohmann::json::object();
}
if(field_obj.count(fields::qa) == 0) {
field_obj[fields::qa] = nlohmann::json::object();
}
if(field_obj.count(fields::model_config) == 0) {
field_obj[fields::model_config] = nlohmann::json::object();
}
@ -95,7 +99,7 @@ Collection* CollectionManager::init_collection(const nlohmann::json & collection
field f(field_obj[fields::name], field_obj[fields::type], field_obj[fields::facet],
field_obj[fields::optional], field_obj[fields::index], field_obj[fields::locale],
-1, field_obj[fields::infix], field_obj[fields::nested], field_obj[fields::nested_array],
field_obj[fields::num_dim], vec_dist_type, field_obj[fields::reference], field_obj[fields::embed]);
field_obj[fields::num_dim], vec_dist_type, field_obj[fields::reference], field_obj[fields::embed], field_obj[fields::qa]);
// value of `sort` depends on field type
if(field_obj.count(fields::sort) == 0) {
@ -808,6 +812,8 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
const char *FACET_RETURN_PARENT = "facet_return_parent";
const char *PROMPT = "prompt";
const char *VECTOR_QUERY = "vector_query";
const char* REMOTE_EMBEDDING_TIMEOUT_MS = "remote_embedding_timeout_ms";
@ -977,6 +983,8 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
size_t remote_embedding_timeout_ms = 5000;
size_t remote_embedding_num_tries = 2;
std::string prompt;
size_t facet_sample_percent = 100;
size_t facet_sample_threshold = 0;
@ -1017,6 +1025,7 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
{HIGHLIGHT_END_TAG, &highlight_end_tag},
{PINNED_HITS, &pinned_hits_str},
{HIDDEN_HITS, &hidden_hits_str},
{PROMPT, &prompt},
};
std::unordered_map<std::string, bool*> bool_values = {
@ -1225,7 +1234,8 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
remote_embedding_timeout_ms,
remote_embedding_num_tries,
stopwords_set,
facet_return_parent
facet_return_parent,
prompt
);
uint64_t timeMillis = std::chrono::duration_cast<std::chrono::milliseconds>(

View File

@ -2,6 +2,7 @@
#include "field.h"
#include "magic_enum.hpp"
#include "text_embedder_manager.h"
#include "qa_model.h"
#include <stack>
#include <collection_manager.h>
#include <regex>
@ -278,6 +279,24 @@ Option<bool> field::json_field_to_field(bool enable_nested_fields, nlohmann::jso
}
}
if(field_json.count(fields::qa) != 0) {
if(field_json.count(fields::embed) == 0 || !field_json[fields::embed].is_object() || field_json[fields::embed].count(fields::from) == 0) {
return Option<bool>(400, "Property `" + fields::qa + "` is allowed only on an embedded field.");
}
// qa object should contain "model_name" and it should be a string
if(!field_json[fields::qa].is_object() || field_json[fields::qa].count(fields::model_name) == 0 ||
!field_json[fields::qa][fields::model_name].is_string()) {
return Option<bool>(400, "Property `" + fields::qa + "` should be an object containing a `model_name` property.");
}
auto validate_qa_res = QAModel::validate_model(field_json[fields::qa]);
if(!validate_qa_res.ok()) {
return validate_qa_res;
}
}
if(field_json.count(fields::optional) == 0) {
// dynamic type fields are always optional
bool is_dynamic = field::is_dynamic(field_json[fields::name], field_json[fields::type]);
@ -321,7 +340,7 @@ Option<bool> field::json_field_to_field(bool enable_nested_fields, nlohmann::jso
field_json[fields::optional], field_json[fields::index], field_json[fields::locale],
field_json[fields::sort], field_json[fields::infix], field_json[fields::nested],
field_json[fields::nested_array], field_json[fields::num_dim], vec_dist,
field_json[fields::reference], field_json[fields::embed], field_json[fields::range_index])
field_json[fields::reference], field_json[fields::embed], field_json[fields::qa], field_json[fields::range_index])
);
if (!field_json[fields::reference].get<std::string>().empty()) {

154
src/qa_model.cpp Normal file
View File

@ -0,0 +1,154 @@
#include "qa_model.h"
#include "text_embedder_manager.h"
#include "text_embedder_remote.h"
// Routes model validation to the backend matching the namespace prefix of
// `model_name` (e.g. "openai/gpt-3.5-turbo" -> "openai").
Option<bool> QAModel::validate_model(const nlohmann::json& model_config) {
    const auto full_model_name = model_config["model_name"].get<std::string>();
    const std::string ns = TextEmbedderManager::get_model_namespace(full_model_name);

    if(ns != "openai") {
        // Unknown backend — reject the config.
        return Option<bool>(400, "Model namespace `" + ns + "` is not supported.");
    }

    return OpenAIQAModel::validate_model(model_config);
}
// Routes an answer-generation request to the backend matching the namespace
// prefix of `model_name`. Returns a 400 error Option for unknown namespaces.
Option<std::string> QAModel::get_answer(const std::string& context, const std::string& prompt, const nlohmann::json& model_config) {
    const std::string model_namespace = TextEmbedderManager::get_model_namespace(model_config["model_name"].get<std::string>());

    if(model_namespace == "openai") {
        return OpenAIQAModel::get_answer(context, prompt, model_config);
    }

    // BUG FIX: this previously `throw`ed the Option instead of returning it.
    // Callers only check Option::ok() (see Collection::search), so the thrown
    // object would escape as an uncaught exception. Message now matches the
    // backtick style used by validate_model for consistency.
    return Option<std::string>(400, "Model namespace `" + model_namespace + "` is not supported.");
}
// Validates an OpenAI QA config in three steps: (1) the api key can list
// models, (2) the configured model exists, (3) a minimal chat completion
// succeeds. Returns an error Option describing the first failing step.
Option<bool> OpenAIQAModel::validate_model(const nlohmann::json& model_config) {
    if(model_config.count("api_key") == 0) {
        return Option<bool>(400, "API key is not provided");
    }

    std::unordered_map<std::string, std::string> headers;
    std::map<std::string, std::string> res_headers;
    headers["Authorization"] = "Bearer " + model_config["api_key"].get<std::string>();
    headers["Content-Type"] = "application/json";
    std::string res;
    auto res_code = RemoteEmbedder::call_remote_api("GET", OPENAI_LIST_MODELS, "", res, res_headers, headers);

    if(res_code == 408) {
        return Option<bool>(408, "OpenAI API timeout.");
    }

    if (res_code != 200) {
        nlohmann::json json_res;
        try {
            json_res = nlohmann::json::parse(res);
        } catch (const std::exception& e) {
            return Option<bool>(400, "OpenAI API error: " + res);
        }
        if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) {
            return Option<bool>(400, "OpenAI API error: " + res);
        }
        // FIX: reuse the already-parsed body instead of re-parsing `res`.
        return Option<bool>(400, "OpenAI API error: " + json_res["error"]["message"].get<std::string>());
    }

    nlohmann::json models_json;
    try {
        models_json = nlohmann::json::parse(res);
    } catch (const std::exception& e) {
        return Option<bool>(400, "Got malformed response from OpenAI API.");
    }

    // extract model name by removing "openai/" prefix
    auto model_name_without_namespace = TextEmbedderManager::get_model_name_without_namespace(model_config["model_name"].get<std::string>());
    bool found = false;
    for (auto& model : models_json["data"]) {
        if (model["id"] == model_name_without_namespace) {
            found = true;
            break;
        }
    }

    if(!found) {
        return Option<bool>(400, "Property `qa.model_name` is not a valid OpenAI model.");
    }

    // Smoke-test the chat completions endpoint with a trivial message.
    nlohmann::json req_body;
    req_body["model"] = model_name_without_namespace;
    req_body["messages"] = R"([
        {
            "role":"user",
            "content":"hello"
        }
    ])"_json;

    std::string chat_res;
    res_code = RemoteEmbedder::call_remote_api("POST", OPENAI_CHAT_COMPLETION, req_body.dump(), chat_res, res_headers, headers);

    if(res_code == 408) {
        return Option<bool>(408, "OpenAI API timeout.");
    }

    if (res_code != 200) {
        nlohmann::json json_res;
        try {
            json_res = nlohmann::json::parse(chat_res);
        } catch (const std::exception& e) {
            return Option<bool>(400, "OpenAI API error: " + chat_res);
        }
        if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) {
            return Option<bool>(400, "OpenAI API error: " + chat_res);
        }
        // BUG FIX: the original parsed `res` (the models-list response) here,
        // reporting the wrong response's error message; use the already-parsed
        // chat completion body instead.
        return Option<bool>(400, "OpenAI API error: " + json_res["error"]["message"].get<std::string>());
    }

    return Option<bool>(true);
}
// Sends `context` (search results as JSON) plus the user `prompt` to the
// OpenAI chat completions API and returns the assistant's answer.
// All failures are reported via the returned Option, never by throwing.
Option<std::string> OpenAIQAModel::get_answer(const std::string& context, const std::string& prompt, const nlohmann::json& model_config) {
    const std::string model_name = TextEmbedderManager::get_model_name_without_namespace(model_config["model_name"].get<std::string>());
    const std::string api_key = model_config["api_key"].get<std::string>();

    std::unordered_map<std::string, std::string> headers;
    std::map<std::string, std::string> res_headers;
    headers["Authorization"] = "Bearer " + api_key;
    headers["Content-Type"] = "application/json";

    nlohmann::json req_body;
    req_body["model"] = model_name;
    req_body["messages"] = nlohmann::json::array();

    // BUG FIX: this was a function-local `static` that was mutated on every
    // call — a data race under concurrent searches with no benefit. Build the
    // message locally instead.
    nlohmann::json system_message = nlohmann::json::object();
    system_message["role"] = "system";
    system_message["content"] = "You are an assistant, answer questions according to the data in JSON format. Do not mention the data content in your answer.";
    req_body["messages"].push_back(system_message);

    nlohmann::json message = nlohmann::json::object();
    message["role"] = "user";
    message["content"] = "Data:\n" + context + "\n\nQuestion:\n" + prompt;
    req_body["messages"].push_back(message);

    std::string res;
    auto res_code = RemoteEmbedder::call_remote_api("POST", OPENAI_CHAT_COMPLETION, req_body.dump(), res, res_headers, headers);

    if(res_code == 408) {
        // BUG FIX: was `throw Option(400, ...)` — callers only check
        // Option::ok(), so the throw would escape uncaught. Also report 408
        // for timeouts, consistent with validate_model.
        return Option<std::string>(408, "OpenAI API timeout.");
    }

    if (res_code != 200) {
        nlohmann::json json_res;
        try {
            json_res = nlohmann::json::parse(res);
        } catch (const std::exception& e) {
            return Option<std::string>(400, "OpenAI API error: " + res);
        }
        if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) {
            return Option<std::string>(400, "OpenAI API error: " + res);
        }
        // Reuse the already-parsed body instead of re-parsing `res`.
        return Option<std::string>(400, "OpenAI API error: " + json_res["error"]["message"].get<std::string>());
    }

    nlohmann::json json_res;
    try {
        json_res = nlohmann::json::parse(res);
    } catch (const std::exception& e) {
        return Option<std::string>(400, "Got malformed response from OpenAI API.");
    }

    return Option<std::string>(json_res["choices"][0]["message"]["content"].get<std::string>());
}