mirror of
https://github.com/typesense/typesense.git
synced 2025-05-19 05:08:43 +08:00
Fix for text embedding when schema or document updated
This commit is contained in:
parent
e85ae5d7d2
commit
401ebbe481
@ -162,6 +162,8 @@ private:
|
||||
|
||||
void remove_document(const nlohmann::json & document, const uint32_t seq_id, bool remove_from_store);
|
||||
|
||||
void process_remove_field_for_embedding_fields(const field& the_field);
|
||||
|
||||
void curate_results(string& actual_query, const string& filter_query, bool enable_overrides, bool already_segmented,
|
||||
const std::map<size_t, std::vector<std::string>>& pinned_hits,
|
||||
const std::vector<std::string>& hidden_hits,
|
||||
@ -355,6 +357,8 @@ public:
|
||||
|
||||
Option<bool> embed_fields(nlohmann::json& document);
|
||||
|
||||
Option<bool> embed_fields_update(const nlohmann::json& old_doc, nlohmann::json& new_doc);
|
||||
|
||||
static uint32_t get_seq_id_from_key(const std::string & key);
|
||||
|
||||
Option<bool> get_document_from_store(const std::string & seq_id_key, nlohmann::json & document, bool raw_doc = false) const;
|
||||
|
@ -72,10 +72,6 @@ Option<doc_seq_id_t> Collection::to_doc(const std::string & json_str, nlohmann::
|
||||
const std::string& id) {
|
||||
try {
|
||||
document = nlohmann::json::parse(json_str);
|
||||
auto embed_res = embed_fields(document);
|
||||
if (!embed_res.ok()) {
|
||||
return Option<doc_seq_id_t>(400, embed_res.error());
|
||||
}
|
||||
} catch(const std::exception& e) {
|
||||
LOG(ERROR) << "JSON error: " << e.what();
|
||||
return Option<doc_seq_id_t>(400, std::string("Bad JSON: ") + e.what());
|
||||
@ -106,6 +102,13 @@ Option<doc_seq_id_t> Collection::to_doc(const std::string & json_str, nlohmann::
|
||||
uint32_t seq_id = get_next_seq_id();
|
||||
document["id"] = std::to_string(seq_id);
|
||||
|
||||
// Handle embedding here for UPSERT, EMPLACE or CREATE when we treat is as a new doc
|
||||
auto embed_res = embed_fields(document);
|
||||
if (!embed_res.ok()) {
|
||||
return Option<doc_seq_id_t>(400, embed_res.error());
|
||||
}
|
||||
|
||||
|
||||
// Add reference helper fields in the document.
|
||||
for (auto const& pair: reference_fields) {
|
||||
auto field_name = pair.first;
|
||||
@ -176,9 +179,20 @@ Option<doc_seq_id_t> Collection::to_doc(const std::string & json_str, nlohmann::
|
||||
if(operation == CREATE) {
|
||||
return Option<doc_seq_id_t>(409, std::string("A document with id ") + doc_id + " already exists.");
|
||||
}
|
||||
|
||||
|
||||
|
||||
// UPSERT, EMPLACE or UPDATE
|
||||
uint32_t seq_id = (uint32_t) std::stoul(seq_id_str);
|
||||
|
||||
//Handle embedding here for UPDATE
|
||||
nlohmann::json old_doc;
|
||||
get_document_from_store(get_seq_id_key(seq_id), old_doc);
|
||||
auto embed_res = embed_fields_update(old_doc, document);
|
||||
if (!embed_res.ok()) {
|
||||
return Option<doc_seq_id_t>(400, embed_res.error());
|
||||
}
|
||||
|
||||
return Option<doc_seq_id_t>(doc_seq_id_t{seq_id, false});
|
||||
|
||||
} else {
|
||||
@ -188,6 +202,12 @@ Option<doc_seq_id_t> Collection::to_doc(const std::string & json_str, nlohmann::
|
||||
} else {
|
||||
// for UPSERT, EMPLACE or CREATE, if a document with given ID is not found, we will treat it as a new doc
|
||||
uint32_t seq_id = get_next_seq_id();
|
||||
|
||||
// Handle embedding here for UPSERT, EMPLACE or CREATE when we treat is as a new doc
|
||||
auto embed_res = embed_fields(document);
|
||||
if (!embed_res.ok()) {
|
||||
return Option<doc_seq_id_t>(400, embed_res.error());
|
||||
}
|
||||
return Option<doc_seq_id_t>(doc_seq_id_t{seq_id, true});
|
||||
}
|
||||
}
|
||||
@ -298,7 +318,6 @@ nlohmann::json Collection::add_many(std::vector<std::string>& json_lines, nlohma
|
||||
const std::string & json_line = json_lines[i];
|
||||
Option<doc_seq_id_t> doc_seq_id_op = to_doc(json_line, document, operation, dirty_values, id);
|
||||
|
||||
|
||||
const uint32_t seq_id = doc_seq_id_op.ok() ? doc_seq_id_op.get().seq_id : 0;
|
||||
index_record record(i, seq_id, document, operation, dirty_values);
|
||||
|
||||
@ -3759,6 +3778,8 @@ Option<bool> Collection::batch_alter_data(const std::vector<field>& alter_fields
|
||||
if(del_field.name == default_sorting_field) {
|
||||
default_sorting_field = "";
|
||||
}
|
||||
|
||||
process_remove_field_for_embedding_fields(del_field);
|
||||
}
|
||||
|
||||
index->refresh_schemas({}, del_fields);
|
||||
@ -4781,3 +4802,73 @@ Option<bool> Collection::embed_fields(nlohmann::json& document) {
|
||||
|
||||
return Option<bool>(true);
|
||||
}
|
||||
|
||||
Option<bool> Collection::embed_fields_update(const nlohmann::json& old_doc, nlohmann::json& new_doc) {
|
||||
nlohmann::json new_doc_copy = new_doc;
|
||||
for(const auto& field : embedding_fields) {
|
||||
if(TextEmbedderManager::model_dir.empty()) {
|
||||
return Option<bool>(400, "Text embedding is not enabled. Please set `model-dir` at startup.");
|
||||
}
|
||||
std::string text_to_embed;
|
||||
for(const auto& field_name : field.create_from) {
|
||||
auto field_it = search_schema.find(field_name);
|
||||
if(field_it != search_schema.end()) {
|
||||
nlohmann::json value = (new_doc.find(field_name) != new_doc.end()) ? new_doc[field_name] : old_doc[field_name];
|
||||
if(field_it.value().type == field_types::STRING) {
|
||||
if(value.is_string()) {
|
||||
text_to_embed += value.get<std::string>() + " ";
|
||||
} else {
|
||||
return Option<bool>(400, "Field `" + field_name + "` has malformed data.");
|
||||
}
|
||||
} else if(field_it.value().type == field_types::STRING_ARRAY) {
|
||||
if(value.is_array()) {
|
||||
for(const auto& val : value) {
|
||||
if(val.is_string()) {
|
||||
text_to_embed += val.get<std::string>() + " ";
|
||||
} else {
|
||||
return Option<bool>(400, "Field `" + field_name + "` has malformed data.");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return Option<bool>(400, "Field `" + field_name + "` has malformed data.");
|
||||
}
|
||||
}
|
||||
else {
|
||||
return Option<bool>(400, "Field `" + field_name + "` is not a string nor string array. Can not create vector from it.");
|
||||
}
|
||||
} else {
|
||||
return Option<bool>(400, "Field `" + field_name + "` is not a valid field.");
|
||||
}
|
||||
}
|
||||
|
||||
TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance();
|
||||
auto embedder = embedder_manager.get_text_embedder(field.model_name.size() > 0 ? field.model_name : TextEmbedderManager::DEFAULT_MODEL_NAME);
|
||||
std::vector<float> embedding = embedder->Embed(text_to_embed);
|
||||
new_doc_copy[field.name] = embedding;
|
||||
}
|
||||
new_doc = new_doc_copy;
|
||||
return Option<bool>(true);
|
||||
}
|
||||
|
||||
void Collection::process_remove_field_for_embedding_fields(const field& the_field) {
|
||||
std::vector<std::vector<field>::iterator> empty_fields;
|
||||
for(auto& embedding_field : embedding_fields) {
|
||||
const auto& actual_field = std::find_if(fields.begin(), fields.end(), [&embedding_field] (field other_field) {
|
||||
return other_field.name == embedding_field.name;
|
||||
});
|
||||
actual_field->create_from.erase(std::remove_if(actual_field->create_from.begin(), actual_field->create_from.end(), [&the_field](std::string field_name) {
|
||||
return the_field.name == field_name;
|
||||
}));
|
||||
embedding_field = *actual_field;
|
||||
// store to remove embedding field if it has no field names in 'create_from' anymore.
|
||||
if(embedding_field.create_from.size() == 0) {
|
||||
empty_fields.push_back(actual_field);
|
||||
}
|
||||
}
|
||||
|
||||
for(const auto& empty_field : empty_fields) {
|
||||
search_schema.erase(empty_field->name);
|
||||
embedding_fields.erase(empty_field->name);
|
||||
fields.erase(empty_field);
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user