Fix for text embedding when schema or document updated

This commit is contained in:
ozanarmagan 2023-04-02 12:48:33 +03:00
parent e85ae5d7d2
commit 401ebbe481
2 changed files with 100 additions and 5 deletions

View File

@ -162,6 +162,8 @@ private:
void remove_document(const nlohmann::json & document, const uint32_t seq_id, bool remove_from_store);
void process_remove_field_for_embedding_fields(const field& the_field);
void curate_results(string& actual_query, const string& filter_query, bool enable_overrides, bool already_segmented,
const std::map<size_t, std::vector<std::string>>& pinned_hits,
const std::vector<std::string>& hidden_hits,
@ -355,6 +357,8 @@ public:
Option<bool> embed_fields(nlohmann::json& document);
Option<bool> embed_fields_update(const nlohmann::json& old_doc, nlohmann::json& new_doc);
static uint32_t get_seq_id_from_key(const std::string & key);
Option<bool> get_document_from_store(const std::string & seq_id_key, nlohmann::json & document, bool raw_doc = false) const;

View File

@ -72,10 +72,6 @@ Option<doc_seq_id_t> Collection::to_doc(const std::string & json_str, nlohmann::
const std::string& id) {
try {
document = nlohmann::json::parse(json_str);
auto embed_res = embed_fields(document);
if (!embed_res.ok()) {
return Option<doc_seq_id_t>(400, embed_res.error());
}
} catch(const std::exception& e) {
LOG(ERROR) << "JSON error: " << e.what();
return Option<doc_seq_id_t>(400, std::string("Bad JSON: ") + e.what());
@ -106,6 +102,13 @@ Option<doc_seq_id_t> Collection::to_doc(const std::string & json_str, nlohmann::
uint32_t seq_id = get_next_seq_id();
document["id"] = std::to_string(seq_id);
// Handle embedding here for UPSERT, EMPLACE or CREATE when we treat is as a new doc
auto embed_res = embed_fields(document);
if (!embed_res.ok()) {
return Option<doc_seq_id_t>(400, embed_res.error());
}
// Add reference helper fields in the document.
for (auto const& pair: reference_fields) {
auto field_name = pair.first;
@ -176,9 +179,20 @@ Option<doc_seq_id_t> Collection::to_doc(const std::string & json_str, nlohmann::
if(operation == CREATE) {
return Option<doc_seq_id_t>(409, std::string("A document with id ") + doc_id + " already exists.");
}
// UPSERT, EMPLACE or UPDATE
uint32_t seq_id = (uint32_t) std::stoul(seq_id_str);
//Handle embedding here for UPDATE
nlohmann::json old_doc;
get_document_from_store(get_seq_id_key(seq_id), old_doc);
auto embed_res = embed_fields_update(old_doc, document);
if (!embed_res.ok()) {
return Option<doc_seq_id_t>(400, embed_res.error());
}
return Option<doc_seq_id_t>(doc_seq_id_t{seq_id, false});
} else {
@ -188,6 +202,12 @@ Option<doc_seq_id_t> Collection::to_doc(const std::string & json_str, nlohmann::
} else {
// for UPSERT, EMPLACE or CREATE, if a document with given ID is not found, we will treat it as a new doc
uint32_t seq_id = get_next_seq_id();
// Handle embedding here for UPSERT, EMPLACE or CREATE when we treat is as a new doc
auto embed_res = embed_fields(document);
if (!embed_res.ok()) {
return Option<doc_seq_id_t>(400, embed_res.error());
}
return Option<doc_seq_id_t>(doc_seq_id_t{seq_id, true});
}
}
@ -298,7 +318,6 @@ nlohmann::json Collection::add_many(std::vector<std::string>& json_lines, nlohma
const std::string & json_line = json_lines[i];
Option<doc_seq_id_t> doc_seq_id_op = to_doc(json_line, document, operation, dirty_values, id);
const uint32_t seq_id = doc_seq_id_op.ok() ? doc_seq_id_op.get().seq_id : 0;
index_record record(i, seq_id, document, operation, dirty_values);
@ -3759,6 +3778,8 @@ Option<bool> Collection::batch_alter_data(const std::vector<field>& alter_fields
if(del_field.name == default_sorting_field) {
default_sorting_field = "";
}
process_remove_field_for_embedding_fields(del_field);
}
index->refresh_schemas({}, del_fields);
@ -4781,3 +4802,73 @@ Option<bool> Collection::embed_fields(nlohmann::json& document) {
return Option<bool>(true);
}
Option<bool> Collection::embed_fields_update(const nlohmann::json& old_doc, nlohmann::json& new_doc) {
nlohmann::json new_doc_copy = new_doc;
for(const auto& field : embedding_fields) {
if(TextEmbedderManager::model_dir.empty()) {
return Option<bool>(400, "Text embedding is not enabled. Please set `model-dir` at startup.");
}
std::string text_to_embed;
for(const auto& field_name : field.create_from) {
auto field_it = search_schema.find(field_name);
if(field_it != search_schema.end()) {
nlohmann::json value = (new_doc.find(field_name) != new_doc.end()) ? new_doc[field_name] : old_doc[field_name];
if(field_it.value().type == field_types::STRING) {
if(value.is_string()) {
text_to_embed += value.get<std::string>() + " ";
} else {
return Option<bool>(400, "Field `" + field_name + "` has malformed data.");
}
} else if(field_it.value().type == field_types::STRING_ARRAY) {
if(value.is_array()) {
for(const auto& val : value) {
if(val.is_string()) {
text_to_embed += val.get<std::string>() + " ";
} else {
return Option<bool>(400, "Field `" + field_name + "` has malformed data.");
}
}
} else {
return Option<bool>(400, "Field `" + field_name + "` has malformed data.");
}
}
else {
return Option<bool>(400, "Field `" + field_name + "` is not a string nor string array. Can not create vector from it.");
}
} else {
return Option<bool>(400, "Field `" + field_name + "` is not a valid field.");
}
}
TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance();
auto embedder = embedder_manager.get_text_embedder(field.model_name.size() > 0 ? field.model_name : TextEmbedderManager::DEFAULT_MODEL_NAME);
std::vector<float> embedding = embedder->Embed(text_to_embed);
new_doc_copy[field.name] = embedding;
}
new_doc = new_doc_copy;
return Option<bool>(true);
}
void Collection::process_remove_field_for_embedding_fields(const field& the_field) {
std::vector<std::vector<field>::iterator> empty_fields;
for(auto& embedding_field : embedding_fields) {
const auto& actual_field = std::find_if(fields.begin(), fields.end(), [&embedding_field] (field other_field) {
return other_field.name == embedding_field.name;
});
actual_field->create_from.erase(std::remove_if(actual_field->create_from.begin(), actual_field->create_from.end(), [&the_field](std::string field_name) {
return the_field.name == field_name;
}));
embedding_field = *actual_field;
// store to remove embedding field if it has no field names in 'create_from' anymore.
if(embedding_field.create_from.size() == 0) {
empty_fields.push_back(actual_field);
}
}
for(const auto& empty_field : empty_fields) {
search_schema.erase(empty_field->name);
embedding_fields.erase(empty_field->name);
fields.erase(empty_field);
}
}