mirror of
https://github.com/typesense/typesense.git
synced 2025-05-22 06:40:30 +08:00
Retrieval Augmented Generation
This commit is contained in:
parent
16b59fd23a
commit
6560952eea
@ -478,7 +478,8 @@ public:
|
||||
const size_t remote_embedding_timeout_ms = 30000,
|
||||
const size_t remote_embedding_num_tries = 2,
|
||||
const std::string& stopwords_set="",
|
||||
const std::vector<std::string>& facet_return_parent = {}) const;
|
||||
const std::vector<std::string>& facet_return_parent = {},
|
||||
const std::string& prompt = "") const;
|
||||
|
||||
Option<bool> get_filter_ids(const std::string & filter_query, filter_result_t& filter_result) const;
|
||||
|
||||
|
@ -55,6 +55,7 @@ namespace fields {
|
||||
static const std::string from = "from";
|
||||
static const std::string model_name = "model_name";
|
||||
static const std::string range_index = "range_index";
|
||||
static const std::string qa = "qa";
|
||||
|
||||
// Some models require additional parameters to be passed to the model during indexing/querying
|
||||
// For e.g. e5-small model requires prefix "passage:" for indexing and "query:" for querying
|
||||
@ -88,6 +89,7 @@ struct field {
|
||||
|
||||
size_t num_dim;
|
||||
nlohmann::json embed;
|
||||
nlohmann::json qa;
|
||||
vector_distance_type_t vec_dist;
|
||||
|
||||
static constexpr int VAL_UNKNOWN = 2;
|
||||
@ -101,10 +103,10 @@ struct field {
|
||||
field(const std::string &name, const std::string &type, const bool facet, const bool optional = false,
|
||||
bool index = true, std::string locale = "", int sort = -1, int infix = -1, bool nested = false,
|
||||
int nested_array = 0, size_t num_dim = 0, vector_distance_type_t vec_dist = cosine,
|
||||
std::string reference = "", const nlohmann::json& embed = nlohmann::json(), const bool range_index = false) :
|
||||
std::string reference = "", const nlohmann::json& embed = nlohmann::json(),const nlohmann::json& qa = nlohmann::json(), const bool range_index = false) :
|
||||
name(name), type(type), facet(facet), optional(optional), index(index), locale(locale),
|
||||
nested(nested), nested_array(nested_array), num_dim(num_dim), vec_dist(vec_dist), reference(reference),
|
||||
embed(embed), range_index(range_index) {
|
||||
embed(embed), range_index(range_index), qa(qa) {
|
||||
|
||||
set_computed_defaults(sort, infix);
|
||||
}
|
||||
|
26
include/qa_model.h
Normal file
26
include/qa_model.h
Normal file
@ -0,0 +1,26 @@
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <json.hpp>
|
||||
#include "option.h"
|
||||
|
||||
|
||||
// Dispatch facade for question-answering (RAG) models: routes calls to the
// concrete backend based on the namespace prefix of `model_config["model_name"]`
// (e.g. "openai/..."). Only static entry points; instances are never created
// directly (the virtual dtor exists for derived-class hygiene).
class QAModel {
    public:
        virtual ~QAModel() {};
        // Generates an answer for `prompt` grounded on `context` (serialized JSON docs).
        // Returns an error Option (non-2xx code + message) on failure.
        static Option<std::string> get_answer(const std::string& context, const std::string& prompt, const nlohmann::json& model_config);
        // Validates `model_config` (model name / credentials) against the backend.
        static Option<bool> validate_model(const nlohmann::json& model_config);
    private:
};
|
||||
|
||||
|
||||
class OpenAIQAModel : public QAModel {
|
||||
public:
|
||||
static Option<std::string> get_answer(const std::string& context, const std::string& prompt, const nlohmann::json& model_config);
|
||||
static Option<bool> validate_model(const nlohmann::json& model_config);
|
||||
// prevent instantiation
|
||||
OpenAIQAModel() = delete;
|
||||
private:
|
||||
static constexpr char* OPENAI_LIST_MODELS = "https://api.openai.com/v1/models";
|
||||
static constexpr char* OPENAI_CHAT_COMPLETION = "https://api.openai.com/v1/chat/completions";
|
||||
};
|
@ -25,9 +25,9 @@ struct embedding_res_t {
|
||||
class RemoteEmbedder {
|
||||
protected:
|
||||
static Option<bool> validate_string_properties(const nlohmann::json& model_config, const std::vector<std::string>& properties);
|
||||
static long call_remote_api(const std::string& method, const std::string& url, const std::string& req_body, std::string& res_body, std::map<std::string, std::string>& res_headers, std::unordered_map<std::string, std::string>& req_headers);
|
||||
static inline ReplicationState* raft_server = nullptr;
|
||||
public:
|
||||
static long call_remote_api(const std::string& method, const std::string& url, const std::string& req_body, std::string& res_body, std::map<std::string, std::string>& res_headers, std::unordered_map<std::string, std::string>& req_headers);
|
||||
virtual nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) = 0;
|
||||
virtual embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_tries = 2) = 0;
|
||||
virtual std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size = 200) = 0;
|
||||
|
@ -19,6 +19,7 @@
|
||||
#include "vector_query_ops.h"
|
||||
#include "text_embedder_manager.h"
|
||||
#include "stopwords_manager.h"
|
||||
#include "qa_model.h";
|
||||
|
||||
const std::string override_t::MATCH_EXACT = "exact";
|
||||
const std::string override_t::MATCH_CONTAINS = "contains";
|
||||
@ -241,6 +242,7 @@ nlohmann::json Collection::get_summary_json() const {
|
||||
field_json[fields::sort] = coll_field.sort;
|
||||
field_json[fields::infix] = coll_field.infix;
|
||||
field_json[fields::locale] = coll_field.locale;
|
||||
field_json[fields::qa] = coll_field.qa;
|
||||
if(coll_field.embed.count(fields::from) != 0) {
|
||||
field_json[fields::embed] = coll_field.embed;
|
||||
|
||||
@ -254,6 +256,15 @@ nlohmann::json Collection::get_summary_json() const {
|
||||
}
|
||||
}
|
||||
|
||||
if(!coll_field.qa.empty()) {
|
||||
hide_credential(field_json[fields::qa], "api_key");
|
||||
hide_credential(field_json[fields::qa], "access_token");
|
||||
hide_credential(field_json[fields::qa], "refresh_token");
|
||||
hide_credential(field_json[fields::qa], "client_id");
|
||||
hide_credential(field_json[fields::qa], "client_secret");
|
||||
hide_credential(field_json[fields::qa], "project_id");
|
||||
}
|
||||
|
||||
if(coll_field.num_dim > 0) {
|
||||
field_json[fields::num_dim] = coll_field.num_dim;
|
||||
}
|
||||
@ -1120,7 +1131,8 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
|
||||
const size_t remote_embedding_timeout_ms,
|
||||
const size_t remote_embedding_num_tries,
|
||||
const std::string& stopwords_set,
|
||||
const std::vector<std::string>& facet_return_parent) const {
|
||||
const std::vector<std::string>& facet_return_parent,
|
||||
const std::string& prompt) const {
|
||||
|
||||
std::shared_lock lock(mutex);
|
||||
|
||||
@ -1293,6 +1305,8 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
|
||||
// Set query to * if it is semantic search
|
||||
if(!vector_query.field_name.empty() && processed_search_fields.empty()) {
|
||||
query = "*";
|
||||
} else if(!prompt.empty()) {
|
||||
return Option<nlohmann::json>(400, "Prompt is only supported for semantic search.");
|
||||
}
|
||||
|
||||
if(!vector_query.field_name.empty() && vector_query.values.empty() && num_embed_fields == 0) {
|
||||
@ -1783,6 +1797,8 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
|
||||
index_symbols[uint8_t(c)] = 1;
|
||||
}
|
||||
|
||||
nlohmann::json docs_array = nlohmann::json::array();
|
||||
|
||||
// construct results array
|
||||
for(long result_kvs_index = start_result_index; result_kvs_index <= end_result_index; result_kvs_index++) {
|
||||
const std::vector<KV*> & kv_group = result_group_kvs[result_kvs_index];
|
||||
@ -1951,6 +1967,10 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
|
||||
return Option<nlohmann::json>(prune_op.code(), prune_op.error());
|
||||
}
|
||||
|
||||
if(!prompt.empty()) {
|
||||
docs_array.push_back(document);
|
||||
}
|
||||
|
||||
wrapper_doc["document"] = document;
|
||||
wrapper_doc["highlight"] = highlight_res;
|
||||
|
||||
@ -1999,6 +2019,25 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
|
||||
}
|
||||
}
|
||||
|
||||
if(!prompt.empty()) {
|
||||
result["qa"] = nlohmann::json::object();
|
||||
result["qa"]["prompt"] = prompt;
|
||||
auto embedding_field_it = search_schema.find(vector_query.field_name);
|
||||
if(embedding_field_it == search_schema.end()) {
|
||||
return Option<nlohmann::json>(400, "Invalid embedding field name.");
|
||||
}
|
||||
|
||||
// for each document, remove embedding field to reduce request size
|
||||
for(auto& doc : docs_array) {
|
||||
doc.erase(embedding_field_it->name);
|
||||
}
|
||||
auto qa_op = QAModel::get_answer(docs_array.dump(), prompt, embedding_field_it->qa);
|
||||
if(!qa_op.ok()) {
|
||||
return Option<nlohmann::json>(qa_op.code(), qa_op.error());
|
||||
}
|
||||
result["qa"]["answer"] = qa_op.get();
|
||||
}
|
||||
|
||||
result["facet_counts"] = nlohmann::json::array();
|
||||
|
||||
// populate facets
|
||||
@ -4413,6 +4452,7 @@ Option<bool> Collection::validate_alter_payload(nlohmann::json& schema_changes,
|
||||
embedding_fields.emplace(f.name, f);
|
||||
}
|
||||
|
||||
|
||||
if(f.nested && enable_nested_fields) {
|
||||
updated_nested_fields.emplace(f.name, f);
|
||||
|
||||
|
@ -65,6 +65,10 @@ Collection* CollectionManager::init_collection(const nlohmann::json & collection
|
||||
field_obj[fields::embed] = nlohmann::json::object();
|
||||
}
|
||||
|
||||
if(field_obj.count(fields::qa) == 0) {
|
||||
field_obj[fields::qa] = nlohmann::json::object();
|
||||
}
|
||||
|
||||
if(field_obj.count(fields::model_config) == 0) {
|
||||
field_obj[fields::model_config] = nlohmann::json::object();
|
||||
}
|
||||
@ -95,7 +99,7 @@ Collection* CollectionManager::init_collection(const nlohmann::json & collection
|
||||
field f(field_obj[fields::name], field_obj[fields::type], field_obj[fields::facet],
|
||||
field_obj[fields::optional], field_obj[fields::index], field_obj[fields::locale],
|
||||
-1, field_obj[fields::infix], field_obj[fields::nested], field_obj[fields::nested_array],
|
||||
field_obj[fields::num_dim], vec_dist_type, field_obj[fields::reference], field_obj[fields::embed]);
|
||||
field_obj[fields::num_dim], vec_dist_type, field_obj[fields::reference], field_obj[fields::embed], field_obj[fields::qa]);
|
||||
|
||||
// value of `sort` depends on field type
|
||||
if(field_obj.count(fields::sort) == 0) {
|
||||
@ -808,6 +812,8 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
|
||||
|
||||
const char *FACET_RETURN_PARENT = "facet_return_parent";
|
||||
|
||||
const char *PROMPT = "prompt";
|
||||
|
||||
const char *VECTOR_QUERY = "vector_query";
|
||||
|
||||
const char* REMOTE_EMBEDDING_TIMEOUT_MS = "remote_embedding_timeout_ms";
|
||||
@ -977,6 +983,8 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
|
||||
|
||||
size_t remote_embedding_timeout_ms = 5000;
|
||||
size_t remote_embedding_num_tries = 2;
|
||||
|
||||
std::string prompt;
|
||||
|
||||
size_t facet_sample_percent = 100;
|
||||
size_t facet_sample_threshold = 0;
|
||||
@ -1017,6 +1025,7 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
|
||||
{HIGHLIGHT_END_TAG, &highlight_end_tag},
|
||||
{PINNED_HITS, &pinned_hits_str},
|
||||
{HIDDEN_HITS, &hidden_hits_str},
|
||||
{PROMPT, &prompt},
|
||||
};
|
||||
|
||||
std::unordered_map<std::string, bool*> bool_values = {
|
||||
@ -1225,7 +1234,8 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
|
||||
remote_embedding_timeout_ms,
|
||||
remote_embedding_num_tries,
|
||||
stopwords_set,
|
||||
facet_return_parent
|
||||
facet_return_parent,
|
||||
prompt
|
||||
);
|
||||
|
||||
uint64_t timeMillis = std::chrono::duration_cast<std::chrono::milliseconds>(
|
||||
|
@ -2,6 +2,7 @@
|
||||
#include "field.h"
|
||||
#include "magic_enum.hpp"
|
||||
#include "text_embedder_manager.h"
|
||||
#include "qa_model.h"
|
||||
#include <stack>
|
||||
#include <collection_manager.h>
|
||||
#include <regex>
|
||||
@ -278,6 +279,24 @@ Option<bool> field::json_field_to_field(bool enable_nested_fields, nlohmann::jso
|
||||
}
|
||||
}
|
||||
|
||||
if(field_json.count(fields::qa) != 0) {
|
||||
if(field_json.count(fields::embed) == 0 || !field_json[fields::embed].is_object() || field_json[fields::embed].count(fields::from) == 0) {
|
||||
return Option<bool>(400, "Property `" + fields::qa + "` is allowed only on an embedded field.");
|
||||
}
|
||||
|
||||
// qa object should contain "model_name" and it should be a string
|
||||
if(!field_json[fields::qa].is_object() || field_json[fields::qa].count(fields::model_name) == 0 ||
|
||||
!field_json[fields::qa][fields::model_name].is_string()) {
|
||||
return Option<bool>(400, "Property `" + fields::qa + "` should be an object containing a `model_name` property.");
|
||||
}
|
||||
|
||||
auto validate_qa_res = QAModel::validate_model(field_json[fields::qa]);
|
||||
if(!validate_qa_res.ok()) {
|
||||
return validate_qa_res;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if(field_json.count(fields::optional) == 0) {
|
||||
// dynamic type fields are always optional
|
||||
bool is_dynamic = field::is_dynamic(field_json[fields::name], field_json[fields::type]);
|
||||
@ -321,7 +340,7 @@ Option<bool> field::json_field_to_field(bool enable_nested_fields, nlohmann::jso
|
||||
field_json[fields::optional], field_json[fields::index], field_json[fields::locale],
|
||||
field_json[fields::sort], field_json[fields::infix], field_json[fields::nested],
|
||||
field_json[fields::nested_array], field_json[fields::num_dim], vec_dist,
|
||||
field_json[fields::reference], field_json[fields::embed], field_json[fields::range_index])
|
||||
field_json[fields::reference], field_json[fields::embed], field_json[fields::qa], field_json[fields::range_index])
|
||||
);
|
||||
|
||||
if (!field_json[fields::reference].get<std::string>().empty()) {
|
||||
|
154
src/qa_model.cpp
Normal file
154
src/qa_model.cpp
Normal file
@ -0,0 +1,154 @@
|
||||
#include "qa_model.h"
|
||||
#include "text_embedder_manager.h"
|
||||
#include "text_embedder_remote.h"
|
||||
|
||||
|
||||
// Validate a QA model config by dispatching on the namespace prefix of its
// model name (currently only "openai" is supported).
Option<bool> QAModel::validate_model(const nlohmann::json& model_config) {
    const auto& full_model_name = model_config["model_name"].get<std::string>();
    const std::string model_namespace = TextEmbedderManager::get_model_namespace(full_model_name);

    if(model_namespace != "openai") {
        return Option<bool>(400, "Model namespace `" + model_namespace + "` is not supported.");
    }

    return OpenAIQAModel::validate_model(model_config);
}
|
||||
|
||||
// Answer `prompt` grounded on `context` by dispatching to the backend named by
// the model namespace. Errors are reported via the returned Option, never thrown.
Option<std::string> QAModel::get_answer(const std::string& context, const std::string& prompt, const nlohmann::json& model_config) {
    const std::string model_namespace = TextEmbedderManager::get_model_namespace(model_config["model_name"].get<std::string>());

    if(model_namespace == "openai") {
        return OpenAIQAModel::get_answer(context, prompt, model_config);
    }

    // Fixed: this previously `throw`-ed the Option instead of returning it, which
    // would bypass the caller's `qa_op.ok()` check and propagate an exception of
    // type Option<std::string>. Message now matches validate_model's backtick style.
    return Option<std::string>(400, "Model namespace `" + model_namespace + "` is not supported.");
}
|
||||
|
||||
|
||||
// Validate an OpenAI QA config in three steps:
//   1. require `api_key`;
//   2. confirm the model id exists via the list-models endpoint;
//   3. confirm credentials/model work with a minimal chat-completion call.
Option<bool> OpenAIQAModel::validate_model(const nlohmann::json& model_config) {
    if(model_config.count("api_key") == 0) {
        return Option<bool>(400, "API key is not provided");
    }

    std::unordered_map<std::string, std::string> headers;
    std::map<std::string, std::string> res_headers;
    headers["Authorization"] = "Bearer " + model_config["api_key"].get<std::string>();
    headers["Content-Type"] = "application/json";
    std::string res;
    auto res_code = RemoteEmbedder::call_remote_api("GET", OPENAI_LIST_MODELS, "", res, res_headers, headers);

    if(res_code == 408) {
        return Option<bool>(408, "OpenAI API timeout.");
    }

    if (res_code != 200) {
        nlohmann::json json_res;
        try {
            json_res = nlohmann::json::parse(res);
        } catch (const std::exception& e) {
            return Option<bool>(400, "OpenAI API error: " + res);
        }
        if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) {
            return Option<bool>(400, "OpenAI API error: " + res);
        }
        // Reuse the already-parsed body instead of parsing `res` a second time.
        return Option<bool>(400, "OpenAI API error: " + json_res["error"]["message"].get<std::string>());
    }

    nlohmann::json models_json;
    try {
        models_json = nlohmann::json::parse(res);
    } catch (const std::exception& e) {
        return Option<bool>(400, "Got malformed response from OpenAI API.");
    }
    bool found = false;
    // extract model name by removing "openai/" prefix
    auto model_name_without_namespace = TextEmbedderManager::get_model_name_without_namespace(model_config["model_name"].get<std::string>());
    for (auto& model : models_json["data"]) {
        if (model["id"] == model_name_without_namespace) {
            found = true;
            break;
        }
    }

    if(!found) {
        return Option<bool>(400, "Property `qa.model_name` is not a valid OpenAI model.");
    }

    // Minimal round-trip to verify the key can actually run chat completions.
    nlohmann::json req_body;
    req_body["model"] = model_name_without_namespace;
    req_body["messages"] = R"([
        {
            "role":"user",
            "content":"hello"
        }
    ])"_json;
    std::string chat_res;

    res_code = RemoteEmbedder::call_remote_api("POST", OPENAI_CHAT_COMPLETION, req_body.dump(), chat_res, res_headers, headers);

    if(res_code == 408) {
        return Option<bool>(408, "OpenAI API timeout.");
    }

    if (res_code != 200) {
        nlohmann::json json_res;
        try {
            json_res = nlohmann::json::parse(chat_res);
        } catch (const std::exception& e) {
            return Option<bool>(400, "OpenAI API error: " + chat_res);
        }
        if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) {
            return Option<bool>(400, "OpenAI API error: " + chat_res);
        }
        // Fixed: previously parsed `res` (the list-models body) here instead of
        // `chat_res`, reporting the wrong response's error message.
        return Option<bool>(400, "OpenAI API error: " + json_res["error"]["message"].get<std::string>());
    }

    return Option<bool>(true);
}
|
||||
|
||||
// Ask OpenAI's chat-completions endpoint to answer `prompt` using `context`
// (serialized JSON docs) as grounding data. All failures are reported through
// the returned Option — this function must not throw, since the caller
// (Collection::search) checks `qa_op.ok()` on the returned value.
Option<std::string> OpenAIQAModel::get_answer(const std::string& context, const std::string& prompt, const nlohmann::json& model_config) {
    const std::string model_name = TextEmbedderManager::get_model_name_without_namespace(model_config["model_name"].get<std::string>());
    const std::string api_key = model_config["api_key"].get<std::string>();
    std::unordered_map<std::string, std::string> headers;
    std::map<std::string, std::string> res_headers;
    headers["Authorization"] = "Bearer " + api_key;
    headers["Content-Type"] = "application/json";
    nlohmann::json req_body;
    req_body["model"] = model_name;
    req_body["messages"] = nlohmann::json::array();
    // Fixed: was `static`, i.e. mutable state shared (and re-mutated) across all
    // concurrent searches — a data race for no benefit. A plain local is correct.
    nlohmann::json system_message = nlohmann::json::object();
    system_message["role"] = "system";
    system_message["content"] = "You are an assistant, answer questions according to the data in JSON format. Do not mention the data content in your answer.";
    req_body["messages"].push_back(system_message);
    nlohmann::json message = nlohmann::json::object();
    message["role"] = "user";
    message["content"] = "Data:\n" + context + "\n\nQuestion:\n" + prompt;
    req_body["messages"].push_back(message);

    std::string res;
    auto res_code = RemoteEmbedder::call_remote_api("POST", OPENAI_CHAT_COMPLETION, req_body.dump(), res, res_headers, headers);

    if(res_code == 408) {
        // Fixed: was `throw Option(400, ...)`; return 408 like every other timeout path.
        return Option<std::string>(408, "OpenAI API timeout.");
    }

    if (res_code != 200) {
        nlohmann::json json_res;
        try {
            json_res = nlohmann::json::parse(res);
        } catch (const std::exception& e) {
            return Option<std::string>(400, "OpenAI API error: " + res);
        }
        if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) {
            return Option<std::string>(400, "OpenAI API error: " + res);
        }
        return Option<std::string>(400, "OpenAI API error: " + json_res["error"]["message"].get<std::string>());
    }

    nlohmann::json json_res;
    try {
        json_res = nlohmann::json::parse(res);
    } catch (const std::exception& e) {
        return Option<std::string>(400, "Got malformed response from OpenAI API.");
    }

    return Option<std::string>(json_res["choices"][0]["message"]["content"].get<std::string>());
}
|
Loading…
x
Reference in New Issue
Block a user