diff --git a/include/collection.h b/include/collection.h index c81c4123..5fcb54c8 100644 --- a/include/collection.h +++ b/include/collection.h @@ -162,6 +162,8 @@ private: void remove_document(const nlohmann::json & document, const uint32_t seq_id, bool remove_from_store); + void process_remove_field_for_embedding_fields(const field& the_field); + void curate_results(string& actual_query, const string& filter_query, bool enable_overrides, bool already_segmented, const std::map>& pinned_hits, const std::vector& hidden_hits, @@ -355,6 +357,8 @@ public: Option embed_fields(nlohmann::json& document); + Option embed_fields_update(const nlohmann::json& old_doc, nlohmann::json& new_doc); + static uint32_t get_seq_id_from_key(const std::string & key); Option get_document_from_store(const std::string & seq_id_key, nlohmann::json & document, bool raw_doc = false) const; diff --git a/src/collection.cpp b/src/collection.cpp index bea45361..3937bbfc 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -72,10 +72,6 @@ Option Collection::to_doc(const std::string & json_str, nlohmann:: const std::string& id) { try { document = nlohmann::json::parse(json_str); - auto embed_res = embed_fields(document); - if (!embed_res.ok()) { - return Option(400, embed_res.error()); - } } catch(const std::exception& e) { LOG(ERROR) << "JSON error: " << e.what(); return Option(400, std::string("Bad JSON: ") + e.what()); @@ -106,6 +102,13 @@ Option Collection::to_doc(const std::string & json_str, nlohmann:: uint32_t seq_id = get_next_seq_id(); document["id"] = std::to_string(seq_id); + // Handle embedding here for UPSERT, EMPLACE or CREATE when we treat is as a new doc + auto embed_res = embed_fields(document); + if (!embed_res.ok()) { + return Option(400, embed_res.error()); + } + + // Add reference helper fields in the document. for (auto const& pair: reference_fields) { auto field_name = pair.first; @@ -176,9 +179,20 @@ Option Collection::to_doc(const std::string & json_str, nlohmann:: if(operation == CREATE) { return Option(409, std::string("A document with id ") + doc_id + " already exists."); } + + // UPSERT, EMPLACE or UPDATE uint32_t seq_id = (uint32_t) std::stoul(seq_id_str); + + //Handle embedding here for UPDATE + nlohmann::json old_doc; + get_document_from_store(get_seq_id_key(seq_id), old_doc); + auto embed_res = embed_fields_update(old_doc, document); + if (!embed_res.ok()) { + return Option(400, embed_res.error()); + } + return Option(doc_seq_id_t{seq_id, false}); } else { @@ -188,6 +202,12 @@ Option Collection::to_doc(const std::string & json_str, nlohmann:: } else { // for UPSERT, EMPLACE or CREATE, if a document with given ID is not found, we will treat it as a new doc uint32_t seq_id = get_next_seq_id(); + + // Handle embedding here for UPSERT, EMPLACE or CREATE when we treat is as a new doc + auto embed_res = embed_fields(document); + if (!embed_res.ok()) { + return Option(400, embed_res.error()); + } return Option(doc_seq_id_t{seq_id, true}); } } @@ -298,7 +318,6 @@ nlohmann::json Collection::add_many(std::vector& json_lines, nlohma const std::string & json_line = json_lines[i]; Option doc_seq_id_op = to_doc(json_line, document, operation, dirty_values, id); - const uint32_t seq_id = doc_seq_id_op.ok() ? doc_seq_id_op.get().seq_id : 0; index_record record(i, seq_id, document, operation, dirty_values); @@ -3759,6 +3778,8 @@ Option Collection::batch_alter_data(const std::vector& alter_fields if(del_field.name == default_sorting_field) { default_sorting_field = ""; } + + process_remove_field_for_embedding_fields(del_field); } index->refresh_schemas({}, del_fields); @@ -4781,3 +4802,73 @@ Option Collection::embed_fields(nlohmann::json& document) { return Option(true); } + +Option Collection::embed_fields_update(const nlohmann::json& old_doc, nlohmann::json& new_doc) { + nlohmann::json new_doc_copy = new_doc; + for(const auto& field : embedding_fields) { + if(TextEmbedderManager::model_dir.empty()) { + return Option(400, "Text embedding is not enabled. Please set `model-dir` at startup."); + } + std::string text_to_embed; + for(const auto& field_name : field.create_from) { + auto field_it = search_schema.find(field_name); + if(field_it != search_schema.end()) { + nlohmann::json value = (new_doc.find(field_name) != new_doc.end()) ? new_doc[field_name] : old_doc[field_name]; + if(field_it.value().type == field_types::STRING) { + if(value.is_string()) { + text_to_embed += value.get() + " "; + } else { + return Option(400, "Field `" + field_name + "` has malformed data."); + } + } else if(field_it.value().type == field_types::STRING_ARRAY) { + if(value.is_array()) { + for(const auto& val : value) { + if(val.is_string()) { + text_to_embed += val.get() + " "; + } else { + return Option(400, "Field `" + field_name + "` has malformed data."); + } + } + } else { + return Option(400, "Field `" + field_name + "` has malformed data."); + } + } + else { + return Option(400, "Field `" + field_name + "` is not a string nor string array. Can not create vector from it."); + } + } else { + return Option(400, "Field `" + field_name + "` is not a valid field."); + } + } + + TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance(); + auto embedder = embedder_manager.get_text_embedder(field.model_name.size() > 0 ? field.model_name : TextEmbedderManager::DEFAULT_MODEL_NAME); + std::vector embedding = embedder->Embed(text_to_embed); + new_doc_copy[field.name] = embedding; + } + new_doc = new_doc_copy; + return Option(true); +} + +void Collection::process_remove_field_for_embedding_fields(const field& the_field) { + std::vector::iterator> empty_fields; + for(auto& embedding_field : embedding_fields) { + const auto& actual_field = std::find_if(fields.begin(), fields.end(), [&embedding_field] (field other_field) { + return other_field.name == embedding_field.name; + }); + actual_field->create_from.erase(std::remove_if(actual_field->create_from.begin(), actual_field->create_from.end(), [&the_field](std::string field_name) { + return the_field.name == field_name; + })); + embedding_field = *actual_field; + // store to remove embedding field if it has no field names in 'create_from' anymore. + if(embedding_field.create_from.size() == 0) { + empty_fields.push_back(actual_field); + } + } + + for(const auto& empty_field : empty_fields) { + search_schema.erase(empty_field->name); + embedding_fields.erase(empty_field->name); + fields.erase(empty_field); + } +} \ No newline at end of file