mirror of
https://github.com/typesense/typesense.git
synced 2025-05-24 07:40:35 +08:00
Merge pull request #1065 from ozanarmagan/v0.25-join
Bug fixes for embedding generation process
This commit is contained in:
commit
23dbfd20d5
@ -397,7 +397,7 @@ public:
|
||||
|
||||
nlohmann::json get_summary_json() const;
|
||||
|
||||
size_t batch_index_in_memory(std::vector<index_record>& index_records);
|
||||
size_t batch_index_in_memory(std::vector<index_record>& index_records, const bool generate_embeddings = true);
|
||||
|
||||
Option<nlohmann::json> add(const std::string & json_str,
|
||||
const index_operation_t& operation=CREATE, const std::string& id="",
|
||||
|
@ -201,6 +201,9 @@ struct index_record {
|
||||
nlohmann::json new_doc; // new *full* document to be stored into disk
|
||||
nlohmann::json del_doc; // document containing the fields that should be deleted
|
||||
|
||||
nlohmann::json embedding_res; // embedding result
|
||||
int embedding_status_code; // embedding status code
|
||||
|
||||
index_operation_t operation;
|
||||
bool is_update;
|
||||
|
||||
@ -539,7 +542,7 @@ private:
|
||||
|
||||
void initialize_facet_indexes(const field& facet_field);
|
||||
|
||||
static Option<bool> batch_embed_fields(std::vector<nlohmann::json*>& documents,
|
||||
static void batch_embed_fields(std::vector<index_record*>& documents,
|
||||
const tsl::htrie_map<char, field>& embedding_fields,
|
||||
const tsl::htrie_map<char, field> & search_schema);
|
||||
|
||||
@ -677,7 +680,7 @@ public:
|
||||
const std::string& fallback_field_type,
|
||||
const std::vector<char>& token_separators,
|
||||
const std::vector<char>& symbols_to_index,
|
||||
const bool do_validation);
|
||||
const bool do_validation, const bool generate_embeddings = true);
|
||||
|
||||
static size_t batch_memory_index(Index *index,
|
||||
std::vector<index_record>& iter_batch,
|
||||
@ -687,7 +690,7 @@ public:
|
||||
const std::string& fallback_field_type,
|
||||
const std::vector<char>& token_separators,
|
||||
const std::vector<char>& symbols_to_index,
|
||||
const bool do_validation);
|
||||
const bool do_validation, const bool generate_embeddings = true);
|
||||
|
||||
void index_field_in_memory(const field& afield, std::vector<index_record>& iter_batch);
|
||||
|
||||
|
@ -16,8 +16,8 @@ class TextEmbedder {
|
||||
// Constructor for remote models
|
||||
TextEmbedder(const nlohmann::json& model_config);
|
||||
~TextEmbedder();
|
||||
Option<std::vector<float>> Embed(const std::string& text);
|
||||
Option<std::vector<std::vector<float>>> batch_embed(const std::vector<std::string>& inputs);
|
||||
embedding_res_t Embed(const std::string& text);
|
||||
std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs);
|
||||
const std::string& get_vocab_file_name() const;
|
||||
bool is_remote() {
|
||||
return remote_embedder_ != nullptr;
|
||||
|
@ -9,6 +9,18 @@
|
||||
|
||||
|
||||
|
||||
struct embedding_res_t {
|
||||
std::vector<float> embedding;
|
||||
nlohmann::json error = nlohmann::json::object();
|
||||
int status_code;
|
||||
bool success;
|
||||
|
||||
embedding_res_t(const std::vector<float>& embedding) : embedding(embedding), success(true) {}
|
||||
|
||||
embedding_res_t(int status_code, const nlohmann::json& error) : error(error), success(false), status_code(status_code) {}
|
||||
};
|
||||
|
||||
|
||||
|
||||
class RemoteEmbedder {
|
||||
protected:
|
||||
@ -16,11 +28,12 @@ class RemoteEmbedder {
|
||||
static long call_remote_api(const std::string& method, const std::string& url, const std::string& body, std::string& res_body, std::map<std::string, std::string>& headers, const std::unordered_map<std::string, std::string>& req_headers);
|
||||
static inline ReplicationState* raft_server = nullptr;
|
||||
public:
|
||||
virtual Option<std::vector<float>> Embed(const std::string& text) = 0;
|
||||
virtual Option<std::vector<std::vector<float>>> batch_embed(const std::vector<std::string>& inputs) = 0;
|
||||
virtual embedding_res_t Embed(const std::string& text) = 0;
|
||||
virtual std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs) = 0;
|
||||
static void init(ReplicationState* rs) {
|
||||
raft_server = rs;
|
||||
}
|
||||
virtual ~RemoteEmbedder() = default;
|
||||
|
||||
};
|
||||
|
||||
@ -34,8 +47,8 @@ class OpenAIEmbedder : public RemoteEmbedder {
|
||||
public:
|
||||
OpenAIEmbedder(const std::string& openai_model_path, const std::string& api_key);
|
||||
static Option<bool> is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims);
|
||||
Option<std::vector<float>> Embed(const std::string& text) override;
|
||||
Option<std::vector<std::vector<float>>> batch_embed(const std::vector<std::string>& inputs) override;
|
||||
embedding_res_t Embed(const std::string& text) override;
|
||||
std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs) override;
|
||||
};
|
||||
|
||||
|
||||
@ -49,8 +62,8 @@ class GoogleEmbedder : public RemoteEmbedder {
|
||||
public:
|
||||
GoogleEmbedder(const std::string& google_api_key);
|
||||
static Option<bool> is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims);
|
||||
Option<std::vector<float>> Embed(const std::string& text) override;
|
||||
Option<std::vector<std::vector<float>>> batch_embed(const std::vector<std::string>& inputs) override;
|
||||
embedding_res_t Embed(const std::string& text) override;
|
||||
std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs) override;
|
||||
};
|
||||
|
||||
|
||||
@ -75,8 +88,8 @@ class GCPEmbedder : public RemoteEmbedder {
|
||||
GCPEmbedder(const std::string& project_id, const std::string& model_name, const std::string& access_token,
|
||||
const std::string& refresh_token, const std::string& client_id, const std::string& client_secret);
|
||||
static Option<bool> is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims);
|
||||
Option<std::vector<float>> Embed(const std::string& text) override;
|
||||
Option<std::vector<std::vector<float>>> batch_embed(const std::vector<std::string>& inputs) override;
|
||||
embedding_res_t Embed(const std::string& text) override;
|
||||
std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs) override;
|
||||
};
|
||||
|
||||
|
||||
|
@ -30,6 +30,7 @@ class TextEmbeddingTokenizer {
|
||||
private:
|
||||
public:
|
||||
virtual encoded_input_t Encode(const std::string& text) = 0;
|
||||
virtual ~TextEmbeddingTokenizer() = default;
|
||||
};
|
||||
|
||||
class BertTokenizerWrapper : public TextEmbeddingTokenizer {
|
||||
|
@ -14,6 +14,11 @@ struct KV {
|
||||
uint64_t key{};
|
||||
uint64_t distinct_key{};
|
||||
int64_t scores[3]{}; // match score + 2 custom attributes
|
||||
|
||||
// only to be used in hybrid search
|
||||
float vector_distance = 0.0f;
|
||||
int64_t text_match_score = 0;
|
||||
|
||||
reference_filter_result_t* reference_filter_result = nullptr;
|
||||
|
||||
// to be used only in final aggregation
|
||||
@ -42,6 +47,9 @@ struct KV {
|
||||
|
||||
query_indices = kv.query_indices;
|
||||
kv.query_indices = nullptr;
|
||||
|
||||
vector_distance = kv.vector_distance;
|
||||
text_match_score = kv.text_match_score;
|
||||
}
|
||||
|
||||
KV& operator=(KV&& kv) noexcept {
|
||||
@ -59,6 +67,9 @@ struct KV {
|
||||
delete[] query_indices;
|
||||
query_indices = kv.query_indices;
|
||||
kv.query_indices = nullptr;
|
||||
|
||||
vector_distance = kv.vector_distance;
|
||||
text_match_score = kv.text_match_score;
|
||||
}
|
||||
|
||||
return *this;
|
||||
@ -79,6 +90,9 @@ struct KV {
|
||||
delete[] query_indices;
|
||||
query_indices = kv.query_indices;
|
||||
kv.query_indices = nullptr;
|
||||
|
||||
vector_distance = kv.vector_distance;
|
||||
text_match_score = kv.text_match_score;
|
||||
}
|
||||
|
||||
return *this;
|
||||
|
@ -31,7 +31,7 @@ public:
|
||||
const index_operation_t op,
|
||||
const bool is_update,
|
||||
const std::string& fallback_field_type,
|
||||
const DIRTY_VALUES& dirty_values);
|
||||
const DIRTY_VALUES& dirty_values, const bool validate_embedding_fields = true);
|
||||
|
||||
|
||||
static Option<uint32_t> coerce_element(const field& a_field, nlohmann::json& document,
|
||||
|
@ -561,12 +561,22 @@ void Collection::batch_index(std::vector<index_record>& index_records, std::vect
|
||||
if(!index_record.indexed.ok()) {
|
||||
res["document"] = json_out[index_record.position];
|
||||
res["error"] = index_record.indexed.error();
|
||||
if (!index_record.embedding_res.empty()) {
|
||||
res["embedding_error"] = nlohmann::json::object();
|
||||
res["embedding_error"] = index_record.embedding_res;
|
||||
res["error"] = index_record.embedding_res["error"];
|
||||
}
|
||||
res["code"] = index_record.indexed.code();
|
||||
}
|
||||
} else {
|
||||
res["success"] = false;
|
||||
res["document"] = json_out[index_record.position];
|
||||
res["error"] = index_record.indexed.error();
|
||||
if (!index_record.embedding_res.empty()) {
|
||||
res["embedding_error"] = nlohmann::json::object();
|
||||
res["error"] = index_record.embedding_res["error"];
|
||||
res["embedding_error"] = index_record.embedding_res;
|
||||
}
|
||||
res["code"] = index_record.indexed.code();
|
||||
}
|
||||
|
||||
@ -598,11 +608,11 @@ Option<uint32_t> Collection::index_in_memory(nlohmann::json &document, uint32_t
|
||||
return Option<>(200);
|
||||
}
|
||||
|
||||
size_t Collection::batch_index_in_memory(std::vector<index_record>& index_records) {
|
||||
size_t Collection::batch_index_in_memory(std::vector<index_record>& index_records, const bool generate_embeddings) {
|
||||
std::unique_lock lock(mutex);
|
||||
size_t num_indexed = Index::batch_memory_index(index, index_records, default_sorting_field,
|
||||
search_schema, embedding_fields, fallback_field_type,
|
||||
token_separators, symbols_to_index, true);
|
||||
token_separators, symbols_to_index, true, generate_embeddings);
|
||||
num_documents += num_indexed;
|
||||
return num_indexed;
|
||||
}
|
||||
@ -1131,6 +1141,7 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
|
||||
group_limit = 0;
|
||||
}
|
||||
|
||||
|
||||
vector_query_t vector_query;
|
||||
if(!vector_query_str.empty()) {
|
||||
auto parse_vector_op = VectorQueryOps::parse_vector_query_str(vector_query_str, vector_query, this);
|
||||
@ -1184,8 +1195,8 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
|
||||
// }
|
||||
|
||||
if(raw_query == "*") {
|
||||
std::string error = "Wildcard query is not supported for embedding fields.";
|
||||
return Option<nlohmann::json>(400, error);
|
||||
// ignore embedding field if query is a wildcard
|
||||
continue;
|
||||
}
|
||||
|
||||
TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance();
|
||||
@ -1205,11 +1216,14 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
|
||||
|
||||
std::string embed_query = embedder_manager.get_query_prefix(search_field.embed[fields::model_config]) + raw_query;
|
||||
auto embedding_op = embedder->Embed(embed_query);
|
||||
if(!embedding_op.ok()) {
|
||||
return Option<nlohmann::json>(400, embedding_op.error());
|
||||
if(!embedding_op.success) {
|
||||
if(!embedding_op.error["error"].get<std::string>().empty()) {
|
||||
return Option<nlohmann::json>(400, embedding_op.error["error"].get<std::string>());
|
||||
} else {
|
||||
return Option<nlohmann::json>(400, embedding_op.error.dump());
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<float> embedding = embedding_op.get();
|
||||
std::vector<float> embedding = embedding_op.embedding;
|
||||
vector_query._reset();
|
||||
vector_query.values = embedding;
|
||||
vector_query.field_name = field_name;
|
||||
@ -1887,7 +1901,10 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
|
||||
populate_text_match_info(wrapper_doc["text_match_info"],
|
||||
field_order_kv->scores[field_order_kv->match_score_index], match_type);
|
||||
} else {
|
||||
wrapper_doc["rank_fusion_score"] = Index::int64_t_to_float(field_order_kv->scores[field_order_kv->match_score_index]);
|
||||
wrapper_doc["hybrid_search_info"] = nlohmann::json::object();
|
||||
wrapper_doc["hybrid_search_info"]["rank_fusion_score"] = Index::int64_t_to_float(field_order_kv->scores[field_order_kv->match_score_index]);
|
||||
wrapper_doc["hybrid_search_info"]["text_match_score"] = field_order_kv->text_match_score;
|
||||
wrapper_doc["hybrid_search_info"]["vector_distance"] = field_order_kv->vector_distance;
|
||||
}
|
||||
}
|
||||
|
||||
@ -3707,6 +3724,10 @@ Option<bool> Collection::batch_alter_data(const std::vector<field>& alter_fields
|
||||
nested_field_names.push_back(f.name);
|
||||
}
|
||||
|
||||
if(f.embed.count(fields::from) != 0) {
|
||||
embedding_fields.emplace(f.name, f);
|
||||
}
|
||||
|
||||
fields.push_back(f);
|
||||
}
|
||||
|
||||
@ -3761,9 +3782,9 @@ Option<bool> Collection::batch_alter_data(const std::vector<field>& alter_fields
|
||||
index->remove(seq_id, rec.doc, del_fields, true);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Index::batch_memory_index(index, iter_batch, default_sorting_field, schema_additions, embedding_fields,
|
||||
fallback_field_type, token_separators, symbols_to_index, true);
|
||||
fallback_field_type, token_separators, symbols_to_index, true, false);
|
||||
|
||||
iter_batch.clear();
|
||||
}
|
||||
@ -3868,6 +3889,21 @@ Option<bool> Collection::alter(nlohmann::json& alter_payload) {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// hide credentials in the alter payload return
|
||||
for(auto& field_json : alter_payload["fields"]) {
|
||||
if(field_json[fields::embed].count(fields::model_config) != 0) {
|
||||
hide_credential(field_json[fields::embed][fields::model_config], "api_key");
|
||||
hide_credential(field_json[fields::embed][fields::model_config], "access_token");
|
||||
hide_credential(field_json[fields::embed][fields::model_config], "refresh_token");
|
||||
hide_credential(field_json[fields::embed][fields::model_config], "client_id");
|
||||
hide_credential(field_json[fields::embed][fields::model_config], "client_secret");
|
||||
hide_credential(field_json[fields::embed][fields::model_config], "project_id");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
return Option<bool>(true);
|
||||
}
|
||||
|
||||
@ -4211,6 +4247,65 @@ Option<bool> Collection::validate_alter_payload(nlohmann::json& schema_changes,
|
||||
}
|
||||
}
|
||||
|
||||
for(const auto& kv: schema_changes["fields"].items()) {
|
||||
// validate embedding fields externally
|
||||
auto& field_json = kv.value();
|
||||
if(field_json.count(fields::embed) != 0 && !field_json[fields::embed].empty()) {
|
||||
if(!field_json[fields::embed].is_object()) {
|
||||
return Option<bool>(400, "Property `" + fields::embed + "` must be an object.");
|
||||
}
|
||||
|
||||
if(field_json[fields::embed].count(fields::from) == 0) {
|
||||
return Option<bool>(400, "Property `" + fields::embed + "` must contain a `" + fields::from + "` property.");
|
||||
}
|
||||
|
||||
if(!field_json[fields::embed][fields::from].is_array()) {
|
||||
return Option<bool>(400, "Property `" + fields::embed + "." + fields::from + "` must be an array.");
|
||||
}
|
||||
|
||||
if(field_json[fields::embed][fields::from].empty()) {
|
||||
return Option<bool>(400, "Property `" + fields::embed + "." + fields::from + "` must have at least one element.");
|
||||
}
|
||||
|
||||
for(auto& embed_from_field : field_json[fields::embed][fields::from]) {
|
||||
if(!embed_from_field.is_string()) {
|
||||
return Option<bool>(400, "Property `" + fields::embed + "." + fields::from + "` must contain only field names as strings.");
|
||||
}
|
||||
}
|
||||
|
||||
if(field_json[fields::type] != field_types::FLOAT_ARRAY) {
|
||||
return Option<bool>(400, "Property `" + fields::embed + "." + fields::from + "` is only allowed on a float array field.");
|
||||
}
|
||||
|
||||
for(auto& embed_from_field : field_json[fields::embed][fields::from]) {
|
||||
bool flag = false;
|
||||
for(const auto& field : search_schema) {
|
||||
if(field.name == embed_from_field) {
|
||||
if(field.type != field_types::STRING && field.type != field_types::STRING_ARRAY) {
|
||||
return Option<bool>(400, "Property `" + fields::embed + "." + fields::from + "` can only refer to string or string array fields.");
|
||||
}
|
||||
flag = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if(!flag) {
|
||||
for(const auto& other_kv: schema_changes["fields"].items()) {
|
||||
if(other_kv.value()["name"] == embed_from_field) {
|
||||
if(other_kv.value()[fields::type] != field_types::STRING && other_kv.value()[fields::type] != field_types::STRING_ARRAY) {
|
||||
return Option<bool>(400, "Property `" + fields::embed + "." + fields::from + "` can only refer to string or string array fields.");
|
||||
}
|
||||
flag = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if(!flag) {
|
||||
return Option<bool>(400, "Property `" + fields::embed + "." + fields::from + "` can only refer to string or string array fields.");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if(num_auto_detect_fields > 1) {
|
||||
return Option<bool>(400, "There can be only one field named `.*`.");
|
||||
}
|
||||
@ -4818,10 +4913,10 @@ void Collection::process_remove_field_for_embedding_fields(const field& the_fiel
|
||||
|
||||
void Collection::hide_credential(nlohmann::json& json, const std::string& credential_name) {
|
||||
if(json.count(credential_name) != 0) {
|
||||
// hide api key with * except first 3 chars
|
||||
// hide api key with * except first 5 chars
|
||||
std::string credential_name_str = json[credential_name];
|
||||
if(credential_name_str.size() > 3) {
|
||||
json[credential_name] = credential_name_str.replace(3, credential_name_str.size() - 3, credential_name_str.size() - 3, '*');
|
||||
if(credential_name_str.size() > 5) {
|
||||
json[credential_name] = credential_name_str.replace(5, credential_name_str.size() - 5, credential_name_str.size() - 5, '*');
|
||||
} else {
|
||||
json[credential_name] = credential_name_str.replace(0, credential_name_str.size(), credential_name_str.size(), '*');
|
||||
}
|
||||
|
@ -1383,7 +1383,7 @@ Option<bool> CollectionManager::load_collection(const nlohmann::json &collection
|
||||
// batch must match atleast the number of shards
|
||||
if(exceeds_batch_mem_threshold || (num_valid_docs % batch_size == 0) || last_record) {
|
||||
size_t num_records = index_records.size();
|
||||
size_t num_indexed = collection->batch_index_in_memory(index_records);
|
||||
size_t num_indexed = collection->batch_index_in_memory(index_records, false);
|
||||
batch_doc_str_size = 0;
|
||||
|
||||
if(num_indexed != num_records) {
|
||||
|
@ -928,14 +928,42 @@ bool post_add_document(const std::shared_ptr<http_req>& req, const std::shared_p
|
||||
const index_operation_t operation = get_index_operation(req->params[ACTION]);
|
||||
const auto& dirty_values = collection->parse_dirty_values_option(req->params[DIRTY_VALUES_PARAM]);
|
||||
|
||||
Option<nlohmann::json> inserted_doc_op = collection->add(req->body, operation, "", dirty_values);
|
||||
nlohmann::json document;
|
||||
std::vector<std::string> json_lines = {req->body};
|
||||
const nlohmann::json& inserted_doc_op = collection->add_many(json_lines, document, operation, "", dirty_values, false, false);
|
||||
|
||||
if(!inserted_doc_op.ok()) {
|
||||
res->set(inserted_doc_op.code(), inserted_doc_op.error());
|
||||
if(!inserted_doc_op["success"].get<bool>()) {
|
||||
nlohmann::json res_doc;
|
||||
|
||||
try {
|
||||
res_doc = nlohmann::json::parse(json_lines[0]);
|
||||
} catch(const std::exception& e) {
|
||||
LOG(ERROR) << "JSON error: " << e.what();
|
||||
res->set_400("Bad JSON.");
|
||||
return false;
|
||||
}
|
||||
|
||||
res->status_code = res_doc["code"].get<size_t>();
|
||||
// erase keys from res_doc except error and embedding_error
|
||||
for(auto it = res_doc.begin(); it != res_doc.end(); ) {
|
||||
if(it.key() != "error" && it.key() != "embedding_error") {
|
||||
it = res_doc.erase(it);
|
||||
} else {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
|
||||
// rename error to message if not empty and exists
|
||||
if(res_doc.count("error") != 0 && !res_doc["error"].get<std::string>().empty()) {
|
||||
res_doc["message"] = res_doc["error"];
|
||||
res_doc.erase("error");
|
||||
}
|
||||
|
||||
res->body = res_doc.dump();
|
||||
return false;
|
||||
}
|
||||
|
||||
res->set_201(inserted_doc_op.get().dump(-1, ' ', false, nlohmann::detail::error_handler_t::ignore));
|
||||
res->set_201(document.dump(-1, ' ', false, nlohmann::detail::error_handler_t::ignore));
|
||||
return true;
|
||||
}
|
||||
|
||||
|
103
src/index.cpp
103
src/index.cpp
@ -415,10 +415,10 @@ void Index::validate_and_preprocess(Index *index, std::vector<index_record>& ite
|
||||
const std::string& fallback_field_type,
|
||||
const std::vector<char>& token_separators,
|
||||
const std::vector<char>& symbols_to_index,
|
||||
const bool do_validation) {
|
||||
const bool do_validation, const bool generate_embeddings) {
|
||||
|
||||
// runs in a partitioned thread
|
||||
std::vector<nlohmann::json*> docs_to_embed;
|
||||
std::vector<index_record*> records_to_embed;
|
||||
|
||||
for(size_t i = 0; i < batch_size; i++) {
|
||||
index_record& index_rec = iter_batch[batch_start_index + i];
|
||||
@ -441,7 +441,7 @@ void Index::validate_and_preprocess(Index *index, std::vector<index_record>& ite
|
||||
index_rec.operation,
|
||||
index_rec.is_update,
|
||||
fallback_field_type,
|
||||
index_rec.dirty_values);
|
||||
index_rec.dirty_values, generate_embeddings);
|
||||
|
||||
if(!validation_op.ok()) {
|
||||
index_rec.index_failure(validation_op.code(), validation_op.error());
|
||||
@ -455,14 +455,16 @@ void Index::validate_and_preprocess(Index *index, std::vector<index_record>& ite
|
||||
index_rec.new_doc, index_rec.del_doc);
|
||||
scrub_reindex_doc(search_schema, index_rec.doc, index_rec.del_doc, index_rec.old_doc);
|
||||
|
||||
for(auto& field: index_rec.doc.items()) {
|
||||
for(auto& embedding_field : embedding_fields) {
|
||||
if(!embedding_field.embed[fields::from].is_null()) {
|
||||
auto embed_from_vector = embedding_field.embed[fields::from].get<std::vector<std::string>>();
|
||||
for(auto& embed_from: embed_from_vector) {
|
||||
if(embed_from == field.key()) {
|
||||
docs_to_embed.push_back(&index_rec.new_doc);
|
||||
break;
|
||||
if(generate_embeddings) {
|
||||
for(auto& field: index_rec.doc.items()) {
|
||||
for(auto& embedding_field : embedding_fields) {
|
||||
if(!embedding_field.embed[fields::from].is_null()) {
|
||||
auto embed_from_vector = embedding_field.embed[fields::from].get<std::vector<std::string>>();
|
||||
for(auto& embed_from: embed_from_vector) {
|
||||
if(embed_from == field.key()) {
|
||||
records_to_embed.push_back(&index_rec);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -470,7 +472,9 @@ void Index::validate_and_preprocess(Index *index, std::vector<index_record>& ite
|
||||
}
|
||||
} else {
|
||||
handle_doc_ops(search_schema, index_rec.doc, index_rec.old_doc);
|
||||
docs_to_embed.push_back(&index_rec.doc);
|
||||
if(generate_embeddings) {
|
||||
records_to_embed.push_back(&index_rec);
|
||||
}
|
||||
}
|
||||
|
||||
compute_token_offsets_facets(index_rec, search_schema, token_separators, symbols_to_index);
|
||||
@ -500,13 +504,8 @@ void Index::validate_and_preprocess(Index *index, std::vector<index_record>& ite
|
||||
index_rec.index_failure(400, e.what());
|
||||
}
|
||||
}
|
||||
|
||||
auto embed_op = batch_embed_fields(docs_to_embed, embedding_fields, search_schema);
|
||||
if(!embed_op.ok()) {
|
||||
for(size_t i = 0; i < batch_size; i++) {
|
||||
index_record& index_rec = iter_batch[batch_start_index + i];
|
||||
index_rec.index_failure(embed_op.code(), embed_op.error());
|
||||
}
|
||||
if(generate_embeddings) {
|
||||
batch_embed_fields(records_to_embed, embedding_fields, search_schema);
|
||||
}
|
||||
}
|
||||
|
||||
@ -517,7 +516,7 @@ size_t Index::batch_memory_index(Index *index, std::vector<index_record>& iter_b
|
||||
const std::string& fallback_field_type,
|
||||
const std::vector<char>& token_separators,
|
||||
const std::vector<char>& symbols_to_index,
|
||||
const bool do_validation) {
|
||||
const bool do_validation, const bool generate_embeddings) {
|
||||
|
||||
const size_t concurrency = 4;
|
||||
const size_t num_threads = std::min(concurrency, iter_batch.size());
|
||||
@ -547,7 +546,7 @@ size_t Index::batch_memory_index(Index *index, std::vector<index_record>& iter_b
|
||||
index->thread_pool->enqueue([&, batch_index, batch_len]() {
|
||||
write_log_index = local_write_log_index;
|
||||
validate_and_preprocess(index, iter_batch, batch_index, batch_len, default_sorting_field, search_schema,
|
||||
embedding_fields, fallback_field_type, token_separators, symbols_to_index, do_validation);
|
||||
embedding_fields, fallback_field_type, token_separators, symbols_to_index, do_validation, generate_embeddings);
|
||||
|
||||
std::unique_lock<std::mutex> lock(m_process);
|
||||
num_processed++;
|
||||
@ -3178,18 +3177,21 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
|
||||
continue;
|
||||
}
|
||||
// (1 / rank_of_document) * WEIGHT)
|
||||
result->text_match_score = result->scores[result->match_score_index];
|
||||
LOG(INFO) << "SEQ_ID: " << result->key << ", score: " << result->text_match_score;
|
||||
result->scores[result->match_score_index] = float_to_int64_t((1.0 / (i + 1)) * TEXT_MATCH_WEIGHT);
|
||||
}
|
||||
|
||||
for(int i = 0; i < vec_results.size(); i++) {
|
||||
auto& result = vec_results[i];
|
||||
auto doc_id = result.first;
|
||||
auto& vec_result = vec_results[i];
|
||||
auto doc_id = vec_result.first;
|
||||
|
||||
auto result_it = topster->kv_map.find(doc_id);
|
||||
|
||||
if(result_it != topster->kv_map.end()&& result_it->second->match_score_index >= 0 && result_it->second->match_score_index <= 2) {
|
||||
auto result = result_it->second;
|
||||
// old_score + (1 / rank_of_document) * WEIGHT)
|
||||
result->vector_distance = vec_result.second;
|
||||
result->scores[result->match_score_index] = float_to_int64_t((int64_t_to_float(result->scores[result->match_score_index])) + ((1.0 / (i + 1)) * VECTOR_SEARCH_WEIGHT));
|
||||
} else {
|
||||
int64_t scores[3] = {0};
|
||||
@ -3197,6 +3199,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
|
||||
scores[0] = float_to_int64_t((1.0 / (i + 1)) * VECTOR_SEARCH_WEIGHT);
|
||||
int64_t match_score_index = 0;
|
||||
KV kv(searched_queries.size(), doc_id, doc_id, match_score_index, scores);
|
||||
kv.vector_distance = vec_result.second;
|
||||
topster->add(&kv);
|
||||
++all_result_ids_len;
|
||||
}
|
||||
@ -6470,13 +6473,26 @@ bool Index::common_results_exist(std::vector<art_leaf*>& leaves, bool must_match
|
||||
}
|
||||
|
||||
|
||||
Option<bool> Index::batch_embed_fields(std::vector<nlohmann::json*>& documents,
|
||||
void Index::batch_embed_fields(std::vector<index_record*>& records,
|
||||
const tsl::htrie_map<char, field>& embedding_fields,
|
||||
const tsl::htrie_map<char, field> & search_schema) {
|
||||
for(const auto& field : embedding_fields) {
|
||||
std::vector<std::pair<nlohmann::json*, std::string>> texts_to_embed;
|
||||
std::vector<std::pair<index_record*, std::string>> texts_to_embed;
|
||||
auto indexing_prefix = TextEmbedderManager::get_instance().get_indexing_prefix(field.embed[fields::model_config]);
|
||||
for(auto& document : documents) {
|
||||
for(auto& record : records) {
|
||||
if(!record->indexed.ok()) {
|
||||
continue;
|
||||
}
|
||||
nlohmann::json* document;
|
||||
if(record->is_update) {
|
||||
document = &record->new_doc;
|
||||
} else {
|
||||
document = &record->doc;
|
||||
}
|
||||
|
||||
if(document == nullptr) {
|
||||
continue;
|
||||
}
|
||||
std::string text = indexing_prefix;
|
||||
auto embed_from = field.embed[fields::from].get<std::vector<std::string>>();
|
||||
for(const auto& field_name : embed_from) {
|
||||
@ -6489,8 +6505,8 @@ Option<bool> Index::batch_embed_fields(std::vector<nlohmann::json*>& documents,
|
||||
}
|
||||
}
|
||||
}
|
||||
if(!text.empty()) {
|
||||
texts_to_embed.push_back(std::make_pair(document, text));
|
||||
if(text != indexing_prefix) {
|
||||
texts_to_embed.push_back(std::make_pair(record, text));
|
||||
}
|
||||
}
|
||||
|
||||
@ -6502,13 +6518,15 @@ Option<bool> Index::batch_embed_fields(std::vector<nlohmann::json*>& documents,
|
||||
auto embedder_op = embedder_manager.get_text_embedder(field.embed[fields::model_config]);
|
||||
|
||||
if(!embedder_op.ok()) {
|
||||
return Option<bool>(400, embedder_op.error());
|
||||
LOG(ERROR) << "Error while getting embedder for model: " << field.embed[fields::model_config];
|
||||
LOG(ERROR) << "Error: " << embedder_op.error();
|
||||
return;
|
||||
}
|
||||
|
||||
// sort texts by length
|
||||
std::sort(texts_to_embed.begin(), texts_to_embed.end(),
|
||||
[](const std::pair<nlohmann::json*, std::string>& a,
|
||||
const std::pair<nlohmann::json*, std::string>& b) {
|
||||
[](const std::pair<index_record*, std::string>& a,
|
||||
const std::pair<index_record*, std::string>& b) {
|
||||
return a.second.size() < b.second.size();
|
||||
});
|
||||
|
||||
@ -6518,19 +6536,24 @@ Option<bool> Index::batch_embed_fields(std::vector<nlohmann::json*>& documents,
|
||||
texts.push_back(text_to_embed.second);
|
||||
}
|
||||
|
||||
auto embedding_op = embedder_op.get()->batch_embed(texts);
|
||||
if(!embedding_op.ok()) {
|
||||
return Option<bool>(400, embedding_op.error());
|
||||
}
|
||||
auto embeddings = embedder_op.get()->batch_embed(texts);
|
||||
|
||||
auto embeddings = embedding_op.get();
|
||||
for(size_t i = 0; i < embeddings.size(); i++) {
|
||||
auto& embedding = embeddings[i];
|
||||
auto& document = texts_to_embed[i].first;
|
||||
(*document)[field.name] = embedding;
|
||||
auto& embedding_res = embeddings[i];
|
||||
if(!embedding_res.success) {
|
||||
texts_to_embed[i].first->embedding_res = embedding_res.error;
|
||||
texts_to_embed[i].first->index_failure(embedding_res.status_code, "");
|
||||
continue;
|
||||
}
|
||||
nlohmann::json* document;
|
||||
if(texts_to_embed[i].first->is_update) {
|
||||
document = &texts_to_embed[i].first->new_doc;
|
||||
} else {
|
||||
document = &texts_to_embed[i].first->doc;
|
||||
}
|
||||
(*document)[field.name] = embedding_res.embedding;
|
||||
}
|
||||
}
|
||||
return Option<bool>(true);
|
||||
}
|
||||
|
||||
/*
|
||||
|
@ -968,6 +968,36 @@ bool ReplicationState::get_ext_snapshot_succeeded() {
|
||||
std::string ReplicationState::get_leader_url() const {
|
||||
std::shared_lock lock(node_mutex);
|
||||
|
||||
if(!node) {
|
||||
const Option<std::string> & refreshed_nodes_op = Config::fetch_nodes_config(config->get_nodes());
|
||||
|
||||
if(!refreshed_nodes_op.ok()) {
|
||||
LOG(WARNING) << "Error while fetching peer configuration: " << refreshed_nodes_op.error();
|
||||
return "";
|
||||
}
|
||||
|
||||
const std::string& nodes_config = ReplicationState::to_nodes_config(peering_endpoint,
|
||||
Config::get_instance().get_api_port(),
|
||||
|
||||
refreshed_nodes_op.get());
|
||||
std::vector<braft::PeerId> peers;
|
||||
braft::Configuration peer_config;
|
||||
peer_config.parse_from(nodes_config);
|
||||
peer_config.list_peers(&peers);
|
||||
|
||||
if(peers.empty()) {
|
||||
LOG(WARNING) << "No peers found in nodes config: " << nodes_config;
|
||||
return "";
|
||||
}
|
||||
|
||||
|
||||
const std::string protocol = api_uses_ssl ? "https" : "http";
|
||||
std::string url = get_node_url_path(peers[0].to_string(), "/", protocol);
|
||||
|
||||
LOG(INFO) << "Returning first peer as leader URL: " << url;
|
||||
return url;
|
||||
}
|
||||
|
||||
if(node->leader_id().is_empty()) {
|
||||
LOG(ERROR) << "Could not get leader status, as node does not have a leader!";
|
||||
return "";
|
||||
|
@ -83,7 +83,7 @@ std::vector<float> TextEmbedder::mean_pooling(const std::vector<std::vector<floa
|
||||
return pooled_output;
|
||||
}
|
||||
|
||||
Option<std::vector<float>> TextEmbedder::Embed(const std::string& text) {
|
||||
embedding_res_t TextEmbedder::Embed(const std::string& text) {
|
||||
if(is_remote()) {
|
||||
return remote_embedder_->Embed(text);
|
||||
} else {
|
||||
@ -129,12 +129,12 @@ Option<std::vector<float>> TextEmbedder::Embed(const std::string& text) {
|
||||
}
|
||||
auto pooled_output = mean_pooling(output);
|
||||
|
||||
return Option<std::vector<float>>(pooled_output);
|
||||
return embedding_res_t(pooled_output);
|
||||
}
|
||||
}
|
||||
|
||||
Option<std::vector<std::vector<float>>> TextEmbedder::batch_embed(const std::vector<std::string>& inputs) {
|
||||
std::vector<std::vector<float>> outputs;
|
||||
std::vector<embedding_res_t> TextEmbedder::batch_embed(const std::vector<std::string>& inputs) {
|
||||
std::vector<embedding_res_t> outputs;
|
||||
if(!is_remote()) {
|
||||
for(int i = 0; i < inputs.size(); i += 8) {
|
||||
auto input_batch = std::vector<std::string>(inputs.begin() + i, inputs.begin() + std::min(i + 8, static_cast<int>(inputs.size())));
|
||||
@ -193,7 +193,7 @@ Option<std::vector<std::vector<float>>> TextEmbedder::batch_embed(const std::vec
|
||||
// if seq length is 0, return empty vector
|
||||
if(input_shapes[0][1] == 0) {
|
||||
for(int i = 0; i < input_batch.size(); i++) {
|
||||
outputs.push_back(std::vector<float>());
|
||||
outputs.push_back(embedding_res_t(400, nlohmann::json({{"error", "Invalid input: empty sequence"}})));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
@ -211,17 +211,14 @@ Option<std::vector<std::vector<float>>> TextEmbedder::batch_embed(const std::vec
|
||||
}
|
||||
output.push_back(output_row);
|
||||
}
|
||||
outputs.push_back(mean_pooling(output));
|
||||
outputs.push_back(embedding_res_t(mean_pooling(output)));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
auto embed_op = remote_embedder_->batch_embed(inputs);
|
||||
if(!embed_op.ok()) {
|
||||
return Option<std::vector<std::vector<float>>>(embed_op.code(), embed_op.error());
|
||||
}
|
||||
outputs = embed_op.get();
|
||||
outputs = std::move(remote_embedder_->batch_embed(inputs));
|
||||
}
|
||||
return Option<std::vector<std::vector<float>>>(outputs);
|
||||
|
||||
return outputs;
|
||||
}
|
||||
|
||||
TextEmbedder::~TextEmbedder() {
|
||||
|
@ -94,7 +94,6 @@ const std::string& TextEmbedderManager::get_model_dir() {
|
||||
}
|
||||
|
||||
TextEmbedderManager::~TextEmbedderManager() {
|
||||
delete_all_text_embedders();
|
||||
}
|
||||
|
||||
const std::string TextEmbedderManager::get_absolute_model_path(const std::string& model_name) {
|
||||
@ -117,8 +116,9 @@ const bool TextEmbedderManager::check_md5(const std::string& file_path, const st
|
||||
std::stringstream ss,res;
|
||||
ss << stream.rdbuf();
|
||||
MD5((unsigned char*)ss.str().c_str(), ss.str().length(), md5);
|
||||
for (int i = 0; i < MD5_DIGEST_LENGTH; i++) {
|
||||
res << std::hex << (int)md5[i];
|
||||
// convert md5 to hex string with leading zeros
|
||||
for(int i = 0; i < MD5_DIGEST_LENGTH; i++) {
|
||||
res << std::hex << std::setfill('0') << std::setw(2) << (int)md5[i];
|
||||
}
|
||||
return res.str() == target_md5;
|
||||
}
|
||||
|
@ -13,15 +13,16 @@ Option<bool> RemoteEmbedder::validate_string_properties(const nlohmann::json& mo
|
||||
|
||||
long RemoteEmbedder::call_remote_api(const std::string& method, const std::string& url, const std::string& body, std::string& res_body,
|
||||
std::map<std::string, std::string>& headers, const std::unordered_map<std::string, std::string>& req_headers) {
|
||||
if(raft_server == nullptr) {
|
||||
if(raft_server == nullptr || raft_server->get_leader_url().empty()) {
|
||||
if(method == "GET") {
|
||||
return HttpClient::get_instance().get_response(url, res_body, headers, req_headers, 10000, true);
|
||||
return HttpClient::get_instance().get_response(url, res_body, headers, req_headers, 100000, true);
|
||||
} else if(method == "POST") {
|
||||
return HttpClient::get_instance().post_response(url, body, res_body, headers, req_headers, 10000, true);
|
||||
return HttpClient::get_instance().post_response(url, body, res_body, headers, req_headers, 100000, true);
|
||||
} else {
|
||||
return 400;
|
||||
}
|
||||
}
|
||||
|
||||
auto leader_url = raft_server->get_leader_url();
|
||||
leader_url += "proxy";
|
||||
nlohmann::json req_body;
|
||||
@ -103,7 +104,7 @@ Option<bool> OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config,
|
||||
return Option<bool>(true);
|
||||
}
|
||||
|
||||
Option<std::vector<float>> OpenAIEmbedder::Embed(const std::string& text) {
|
||||
embedding_res_t OpenAIEmbedder::Embed(const std::string& text) {
|
||||
std::unordered_map<std::string, std::string> headers;
|
||||
std::map<std::string, std::string> res_headers;
|
||||
headers["Authorization"] = "Bearer " + api_key;
|
||||
@ -116,15 +117,23 @@ Option<std::vector<float>> OpenAIEmbedder::Embed(const std::string& text) {
|
||||
auto res_code = call_remote_api("POST", OPENAI_CREATE_EMBEDDING, req_body.dump(), res, res_headers, headers);
|
||||
if (res_code != 200) {
|
||||
nlohmann::json json_res = nlohmann::json::parse(res);
|
||||
if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) {
|
||||
return Option<std::vector<float>>(400, "OpenAI API error: " + res);
|
||||
nlohmann::json embedding_res = nlohmann::json::object();
|
||||
embedding_res["response"] = json_res;
|
||||
embedding_res["request"] = nlohmann::json::object();
|
||||
embedding_res["request"]["url"] = OPENAI_CREATE_EMBEDDING;
|
||||
embedding_res["request"]["method"] = "POST";
|
||||
embedding_res["request"]["body"] = req_body;
|
||||
|
||||
if(json_res.count("error") != 0 && json_res["error"].count("message") != 0) {
|
||||
embedding_res["error"] = "OpenAI API error: " + json_res["error"]["message"].get<std::string>();
|
||||
}
|
||||
return Option<std::vector<float>>(400, "OpenAI API error: " + res);
|
||||
return embedding_res_t(res_code, embedding_res);
|
||||
}
|
||||
return Option<std::vector<float>>(nlohmann::json::parse(res)["data"][0]["embedding"].get<std::vector<float>>());
|
||||
|
||||
return embedding_res_t(nlohmann::json::parse(res)["data"][0]["embedding"].get<std::vector<float>>());
|
||||
}
|
||||
|
||||
Option<std::vector<std::vector<float>>> OpenAIEmbedder::batch_embed(const std::vector<std::string>& inputs) {
|
||||
std::vector<embedding_res_t> OpenAIEmbedder::batch_embed(const std::vector<std::string>& inputs) {
|
||||
nlohmann::json req_body;
|
||||
req_body["input"] = inputs;
|
||||
// remove "openai/" prefix
|
||||
@ -137,20 +146,35 @@ Option<std::vector<std::vector<float>>> OpenAIEmbedder::batch_embed(const std::v
|
||||
auto res_code = call_remote_api("POST", OPENAI_CREATE_EMBEDDING, req_body.dump(), res, res_headers, headers);
|
||||
|
||||
if(res_code != 200) {
|
||||
std::vector<embedding_res_t> outputs;
|
||||
|
||||
nlohmann::json json_res = nlohmann::json::parse(res);
|
||||
if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) {
|
||||
return Option<std::vector<std::vector<float>>>(400, "OpenAI API error: " + res);
|
||||
LOG(INFO) << "OpenAI API error: " << json_res.dump();
|
||||
nlohmann::json embedding_res = nlohmann::json::object();
|
||||
embedding_res["response"] = json_res;
|
||||
embedding_res["request"] = nlohmann::json::object();
|
||||
embedding_res["request"]["url"] = OPENAI_CREATE_EMBEDDING;
|
||||
embedding_res["request"]["method"] = "POST";
|
||||
embedding_res["request"]["body"] = req_body;
|
||||
embedding_res["request"]["body"]["input"] = std::vector<std::string>{inputs[0]};
|
||||
if(json_res.count("error") != 0 && json_res["error"].count("message") != 0) {
|
||||
embedding_res["error"] = "OpenAI API error: " + json_res["error"]["message"].get<std::string>();
|
||||
}
|
||||
return Option<std::vector<std::vector<float>>>(400, res);
|
||||
|
||||
for(size_t i = 0; i < inputs.size(); i++) {
|
||||
embedding_res["request"]["body"]["input"][0] = inputs[i];
|
||||
outputs.push_back(embedding_res_t(res_code, embedding_res));
|
||||
}
|
||||
return outputs;
|
||||
}
|
||||
|
||||
nlohmann::json res_json = nlohmann::json::parse(res);
|
||||
std::vector<std::vector<float>> outputs;
|
||||
std::vector<embedding_res_t> outputs;
|
||||
for(auto& data : res_json["data"]) {
|
||||
outputs.push_back(data["embedding"].get<std::vector<float>>());
|
||||
outputs.push_back(embedding_res_t(data["embedding"].get<std::vector<float>>()));
|
||||
}
|
||||
|
||||
return Option<std::vector<std::vector<float>>>(outputs);
|
||||
return outputs;
|
||||
}
|
||||
|
||||
|
||||
@ -198,7 +222,7 @@ Option<bool> GoogleEmbedder::is_model_valid(const nlohmann::json& model_config,
|
||||
return Option<bool>(true);
|
||||
}
|
||||
|
||||
Option<std::vector<float>> GoogleEmbedder::Embed(const std::string& text) {
|
||||
embedding_res_t GoogleEmbedder::Embed(const std::string& text) {
|
||||
std::unordered_map<std::string, std::string> headers;
|
||||
std::map<std::string, std::string> res_headers;
|
||||
headers["Content-Type"] = "application/json";
|
||||
@ -210,27 +234,30 @@ Option<std::vector<float>> GoogleEmbedder::Embed(const std::string& text) {
|
||||
|
||||
if(res_code != 200) {
|
||||
nlohmann::json json_res = nlohmann::json::parse(res);
|
||||
if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) {
|
||||
return Option<std::vector<float>>(400, "Google API error: " + res);
|
||||
nlohmann::json embedding_res = nlohmann::json::object();
|
||||
embedding_res["response"] = json_res;
|
||||
embedding_res["request"] = nlohmann::json::object();
|
||||
embedding_res["request"]["url"] = GOOGLE_CREATE_EMBEDDING;
|
||||
embedding_res["request"]["method"] = "POST";
|
||||
embedding_res["request"]["body"] = req_body;
|
||||
if(json_res.count("error") != 0 && json_res["error"].count("message") != 0) {
|
||||
embedding_res["error"] = "Google API error: " + json_res["error"]["message"].get<std::string>();
|
||||
}
|
||||
return Option<std::vector<float>>(400, "Google API error: " + nlohmann::json::parse(res)["error"]["message"].get<std::string>());
|
||||
return embedding_res_t(res_code, embedding_res);
|
||||
}
|
||||
|
||||
return Option<std::vector<float>>(nlohmann::json::parse(res)["embedding"]["value"].get<std::vector<float>>());
|
||||
return embedding_res_t(nlohmann::json::parse(res)["embedding"]["value"].get<std::vector<float>>());
|
||||
}
|
||||
|
||||
|
||||
Option<std::vector<std::vector<float>>> GoogleEmbedder::batch_embed(const std::vector<std::string>& inputs) {
|
||||
std::vector<std::vector<float>> outputs;
|
||||
std::vector<embedding_res_t> GoogleEmbedder::batch_embed(const std::vector<std::string>& inputs) {
|
||||
std::vector<embedding_res_t> outputs;
|
||||
for(auto& input : inputs) {
|
||||
auto res = Embed(input);
|
||||
if(!res.ok()) {
|
||||
return Option<std::vector<std::vector<float>>>(res.code(), res.error());
|
||||
}
|
||||
outputs.push_back(res.get());
|
||||
outputs.push_back(res);
|
||||
}
|
||||
|
||||
return Option<std::vector<std::vector<float>>>(outputs);
|
||||
return outputs;
|
||||
}
|
||||
|
||||
|
||||
@ -298,7 +325,7 @@ Option<bool> GCPEmbedder::is_model_valid(const nlohmann::json& model_config, uns
|
||||
return Option<bool>(true);
|
||||
}
|
||||
|
||||
Option<std::vector<float>> GCPEmbedder::Embed(const std::string& text) {
|
||||
embedding_res_t GCPEmbedder::Embed(const std::string& text) {
|
||||
nlohmann::json req_body;
|
||||
req_body["instances"] = nlohmann::json::array();
|
||||
nlohmann::json instance;
|
||||
@ -316,7 +343,9 @@ Option<std::vector<float>> GCPEmbedder::Embed(const std::string& text) {
|
||||
if(res_code == 401) {
|
||||
auto refresh_op = generate_access_token(refresh_token, client_id, client_secret);
|
||||
if(!refresh_op.ok()) {
|
||||
return Option<std::vector<float>>(refresh_op.code(), refresh_op.error());
|
||||
nlohmann::json embedding_res = nlohmann::json::object();
|
||||
embedding_res["error"] = refresh_op.error();
|
||||
return embedding_res_t(refresh_op.code(), embedding_res);
|
||||
}
|
||||
access_token = refresh_op.get();
|
||||
// retry
|
||||
@ -327,32 +356,32 @@ Option<std::vector<float>> GCPEmbedder::Embed(const std::string& text) {
|
||||
|
||||
if(res_code != 200) {
|
||||
nlohmann::json json_res = nlohmann::json::parse(res);
|
||||
if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) {
|
||||
return Option<std::vector<float>>(400, "GCP API error: " + res);
|
||||
nlohmann::json embedding_res = nlohmann::json::object();
|
||||
embedding_res["response"] = json_res;
|
||||
embedding_res["request"] = nlohmann::json::object();
|
||||
embedding_res["request"]["url"] = get_gcp_embedding_url(project_id, model_name);
|
||||
embedding_res["request"]["method"] = "POST";
|
||||
embedding_res["request"]["body"] = req_body;
|
||||
if(json_res.count("error") != 0 && json_res["error"].count("message") != 0) {
|
||||
embedding_res["error"] = "GCP API error: " + json_res["error"]["message"].get<std::string>();
|
||||
}
|
||||
return Option<std::vector<float>>(400, "GCP API error: " + nlohmann::json::parse(res)["error"]["message"].get<std::string>());
|
||||
return embedding_res_t(res_code, embedding_res);
|
||||
}
|
||||
|
||||
nlohmann::json res_json = nlohmann::json::parse(res);
|
||||
return Option<std::vector<float>>(res_json["predictions"][0]["embeddings"]["values"].get<std::vector<float>>());
|
||||
return embedding_res_t(res_json["predictions"][0]["embeddings"]["values"].get<std::vector<float>>());
|
||||
}
|
||||
|
||||
|
||||
Option<std::vector<std::vector<float>>> GCPEmbedder::batch_embed(const std::vector<std::string>& inputs) {
|
||||
std::vector<embedding_res_t> GCPEmbedder::batch_embed(const std::vector<std::string>& inputs) {
|
||||
// GCP API has a limit of 5 instances per request
|
||||
if(inputs.size() > 5) {
|
||||
std::vector<std::vector<float>> res;
|
||||
std::vector<embedding_res_t> res;
|
||||
for(size_t i = 0; i < inputs.size(); i += 5) {
|
||||
auto batch_res = batch_embed(std::vector<std::string>(inputs.begin() + i, inputs.begin() + std::min(i + 5, inputs.size())));
|
||||
if(!batch_res.ok()) {
|
||||
LOG(INFO) << "Batch embedding failed: " << batch_res.error();
|
||||
return Option<std::vector<std::vector<float>>>(batch_res.code(), batch_res.error());
|
||||
}
|
||||
auto batch = batch_res.get();
|
||||
res.insert(res.end(), batch.begin(), batch.end());
|
||||
res.insert(res.end(), batch_res.begin(), batch_res.end());
|
||||
}
|
||||
auto opt = Option<std::vector<std::vector<float>>>(res);
|
||||
return opt;
|
||||
return res;
|
||||
}
|
||||
nlohmann::json req_body;
|
||||
req_body["instances"] = nlohmann::json::array();
|
||||
@ -371,7 +400,13 @@ Option<std::vector<std::vector<float>>> GCPEmbedder::batch_embed(const std::vect
|
||||
if(res_code == 401) {
|
||||
auto refresh_op = generate_access_token(refresh_token, client_id, client_secret);
|
||||
if(!refresh_op.ok()) {
|
||||
return Option<std::vector<std::vector<float>>>(refresh_op.code(), refresh_op.error());
|
||||
nlohmann::json embedding_res = nlohmann::json::object();
|
||||
embedding_res["error"] = refresh_op.error();
|
||||
std::vector<embedding_res_t> outputs;
|
||||
for(size_t i = 0; i < inputs.size(); i++) {
|
||||
outputs.push_back(embedding_res_t(refresh_op.code(), embedding_res));
|
||||
}
|
||||
return outputs;
|
||||
}
|
||||
access_token = refresh_op.get();
|
||||
// retry
|
||||
@ -382,19 +417,29 @@ Option<std::vector<std::vector<float>>> GCPEmbedder::batch_embed(const std::vect
|
||||
|
||||
if(res_code != 200) {
|
||||
nlohmann::json json_res = nlohmann::json::parse(res);
|
||||
if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) {
|
||||
return Option<std::vector<std::vector<float>>>(400, "GCP API error: " + res);
|
||||
nlohmann::json embedding_res = nlohmann::json::object();
|
||||
embedding_res["response"] = json_res;
|
||||
embedding_res["request"] = nlohmann::json::object();
|
||||
embedding_res["request"]["url"] = get_gcp_embedding_url(project_id, model_name);
|
||||
embedding_res["request"]["method"] = "POST";
|
||||
embedding_res["request"]["body"] = req_body;
|
||||
if(json_res.count("error") != 0 && json_res["error"].count("message") != 0) {
|
||||
embedding_res["error"] = "GCP API error: " + json_res["error"]["message"].get<std::string>();
|
||||
}
|
||||
return Option<std::vector<std::vector<float>>>(400, "GCP API error: " + nlohmann::json::parse(res)["error"]["message"].get<std::string>());
|
||||
std::vector<embedding_res_t> outputs;
|
||||
for(size_t i = 0; i < inputs.size(); i++) {
|
||||
outputs.push_back(embedding_res_t(res_code, embedding_res));
|
||||
}
|
||||
return outputs;
|
||||
}
|
||||
|
||||
nlohmann::json res_json = nlohmann::json::parse(res);
|
||||
std::vector<std::vector<float>> outputs;
|
||||
std::vector<embedding_res_t> outputs;
|
||||
for(const auto& prediction : res_json["predictions"]) {
|
||||
outputs.push_back(prediction["embeddings"]["values"].get<std::vector<float>>());
|
||||
outputs.push_back(embedding_res_t(prediction["embeddings"]["values"].get<std::vector<float>>()));
|
||||
}
|
||||
|
||||
return Option<std::vector<std::vector<float>>>(outputs);
|
||||
return outputs;
|
||||
}
|
||||
|
||||
Option<std::string> GCPEmbedder::generate_access_token(const std::string& refresh_token, const std::string& client_id, const std::string& client_secret) {
|
||||
|
@ -605,7 +605,7 @@ Option<uint32_t> validator_t::validate_index_in_memory(nlohmann::json& document,
|
||||
const index_operation_t op,
|
||||
const bool is_update,
|
||||
const std::string& fallback_field_type,
|
||||
const DIRTY_VALUES& dirty_values) {
|
||||
const DIRTY_VALUES& dirty_values, const bool validate_embedding_fields) {
|
||||
|
||||
bool missing_default_sort_field = (!default_sorting_field.empty() && document.count(default_sorting_field) == 0);
|
||||
|
||||
@ -652,10 +652,12 @@ Option<uint32_t> validator_t::validate_index_in_memory(nlohmann::json& document,
|
||||
}
|
||||
}
|
||||
|
||||
// validate embedding fields
|
||||
auto validate_embed_op = validate_embed_fields(document, embedding_fields, search_schema, !is_update);
|
||||
if(!validate_embed_op.ok()) {
|
||||
return Option<>(validate_embed_op.code(), validate_embed_op.error());
|
||||
if(validate_embedding_fields) {
|
||||
// validate embedding fields
|
||||
auto validate_embed_op = validate_embed_fields(document, embedding_fields, search_schema, !is_update);
|
||||
if(!validate_embed_op.ok()) {
|
||||
return Option<>(validate_embed_op.code(), validate_embed_op.error());
|
||||
}
|
||||
}
|
||||
|
||||
return Option<>(200);
|
||||
|
@ -1629,4 +1629,46 @@ TEST_F(CollectionSchemaChangeTest, EmbeddingFieldsMapTest) {
|
||||
|
||||
embedding_fields_map = coll->get_embedding_fields();
|
||||
ASSERT_EQ(0, embedding_fields_map.size());
|
||||
}
|
||||
|
||||
TEST_F(CollectionSchemaChangeTest, DropAndReindexEmbeddingField) {
|
||||
nlohmann::json schema = R"({
|
||||
"name": "objects",
|
||||
"fields": [
|
||||
{"name": "name", "type": "string"},
|
||||
{"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/e5-small"}}}
|
||||
]
|
||||
})"_json;
|
||||
|
||||
TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
|
||||
|
||||
auto op = collectionManager.create_collection(schema);
|
||||
|
||||
ASSERT_TRUE(op.ok());
|
||||
|
||||
// drop the embedding field and reindex
|
||||
nlohmann::json schema_without_embedding = R"({
|
||||
"fields": [
|
||||
{"name": "embedding", "drop": true},
|
||||
{"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/e5-small"}}}
|
||||
]
|
||||
})"_json;
|
||||
|
||||
auto update_op = op.get()->alter(schema_without_embedding);
|
||||
|
||||
ASSERT_TRUE(update_op.ok());
|
||||
|
||||
auto embedding_fields_map = op.get()->get_embedding_fields();
|
||||
|
||||
ASSERT_EQ(1, embedding_fields_map.size());
|
||||
|
||||
// try adding a document
|
||||
nlohmann::json doc;
|
||||
doc["name"] = "hello";
|
||||
auto add_op = op.get()->add(doc.dump());
|
||||
|
||||
ASSERT_TRUE(add_op.ok());
|
||||
auto added_doc = add_op.get();
|
||||
|
||||
ASSERT_EQ(384, added_doc["embedding"].get<std::vector<float>>().size());
|
||||
}
|
@ -4789,9 +4789,9 @@ TEST_F(CollectionTest, HybridSearchRankFusionTest) {
|
||||
ASSERT_EQ("butterfly", search_res["hits"][1]["document"]["name"].get<std::string>());
|
||||
ASSERT_EQ("butterball", search_res["hits"][2]["document"]["name"].get<std::string>());
|
||||
|
||||
ASSERT_FLOAT_EQ((1.0/1.0 * 0.7) + (1.0/1.0 * 0.3), search_res["hits"][0]["rank_fusion_score"].get<float>());
|
||||
ASSERT_FLOAT_EQ((1.0/2.0 * 0.7) + (1.0/3.0 * 0.3), search_res["hits"][1]["rank_fusion_score"].get<float>());
|
||||
ASSERT_FLOAT_EQ((1.0/3.0 * 0.7) + (1.0/2.0 * 0.3), search_res["hits"][2]["rank_fusion_score"].get<float>());
|
||||
ASSERT_FLOAT_EQ((1.0/1.0 * 0.7) + (1.0/1.0 * 0.3), search_res["hits"][0]["hybrid_search_info"]["rank_fusion_score"].get<float>());
|
||||
ASSERT_FLOAT_EQ((1.0/2.0 * 0.7) + (1.0/3.0 * 0.3), search_res["hits"][1]["hybrid_search_info"]["rank_fusion_score"].get<float>());
|
||||
ASSERT_FLOAT_EQ((1.0/3.0 * 0.7) + (1.0/2.0 * 0.3), search_res["hits"][2]["hybrid_search_info"]["rank_fusion_score"].get<float>());
|
||||
}
|
||||
|
||||
TEST_F(CollectionTest, WildcardSearchWithEmbeddingField) {
|
||||
@ -4812,8 +4812,7 @@ TEST_F(CollectionTest, WildcardSearchWithEmbeddingField) {
|
||||
spp::sparse_hash_set<std::string> dummy_include_exclude;
|
||||
auto search_res_op = coll->search("*", {"name","embedding"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, dummy_include_exclude, dummy_include_exclude, 10, "", 30, 4, "");
|
||||
|
||||
ASSERT_FALSE(search_res_op.ok());
|
||||
ASSERT_EQ("Wildcard query is not supported for embedding fields.", search_res_op.error());
|
||||
ASSERT_TRUE(search_res_op.ok());
|
||||
}
|
||||
|
||||
TEST_F(CollectionTest, CreateModelDirIfNotExists) {
|
||||
@ -5059,7 +5058,7 @@ TEST_F(CollectionTest, HideOpenAIApiKey) {
|
||||
ASSERT_TRUE(op.ok());
|
||||
auto summary = op.get()->get_summary_json();
|
||||
// hide api key with * after first 3 characters
|
||||
ASSERT_EQ(summary["fields"][1]["embed"]["model_config"]["api_key"].get<std::string>(), api_key.replace(3, api_key.size() - 3, api_key.size() - 3, '*'));
|
||||
ASSERT_EQ(summary["fields"][1]["embed"]["model_config"]["api_key"].get<std::string>(), api_key.replace(5, api_key.size() - 5, api_key.size() - 5, '*'));
|
||||
}
|
||||
|
||||
TEST_F(CollectionTest, PrefixSearchDisabledForOpenAI) {
|
||||
@ -5133,3 +5132,48 @@ TEST_F(CollectionTest, MoreThanOneEmbeddingField) {
|
||||
ASSERT_EQ("Only one embedding field is allowed in the query.", search_res_op.error());
|
||||
}
|
||||
|
||||
|
||||
TEST_F(CollectionTest, EmbeddingFieldEmptyArrayInDocument) {
|
||||
nlohmann::json schema = R"({
|
||||
"name": "objects",
|
||||
"fields": [
|
||||
{"name": "names", "type": "string[]"},
|
||||
{"name": "embedding", "type":"float[]", "embed":{"from": ["names"], "model_config": {"model_name": "ts/e5-small"}}}
|
||||
]
|
||||
})"_json;
|
||||
|
||||
TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
|
||||
|
||||
auto op = collectionManager.create_collection(schema);
|
||||
ASSERT_TRUE(op.ok());
|
||||
|
||||
auto coll = op.get();
|
||||
|
||||
nlohmann::json doc;
|
||||
doc["names"] = nlohmann::json::array();
|
||||
|
||||
// try adding
|
||||
auto add_op = coll->add(doc.dump());
|
||||
|
||||
ASSERT_TRUE(add_op.ok());
|
||||
|
||||
ASSERT_TRUE(add_op.get()["embedding"].is_null());
|
||||
|
||||
// try updating
|
||||
auto id = add_op.get()["id"];
|
||||
doc["names"].push_back("butter");
|
||||
std::string dirty_values;
|
||||
|
||||
|
||||
auto update_op = coll->update_matching_filter("id:=" + id.get<std::string>(), doc.dump(), dirty_values);
|
||||
ASSERT_TRUE(update_op.ok());
|
||||
ASSERT_EQ(1, update_op.get()["num_updated"]);
|
||||
|
||||
|
||||
auto get_op = coll->get(id);
|
||||
ASSERT_TRUE(get_op.ok());
|
||||
|
||||
ASSERT_FALSE(get_op.get()["embedding"].is_null());
|
||||
|
||||
ASSERT_EQ(384, get_op.get()["embedding"].size());
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user