Review Changes II

This commit is contained in:
ozanarmagan 2023-04-07 23:56:25 +03:00
parent b7c988ab45
commit 7ae3cc9781
13 changed files with 228 additions and 208 deletions

View File

@ -210,8 +210,6 @@ private:
bool is_wildcard_query, bool is_group_by_query = false) const;
Option<bool> validate_embed_fields(const nlohmann::json& document, const bool& error_if_field_not_found) const;
Option<bool> persist_collection_meta();
Option<bool> batch_alter_data(const std::vector<field>& alter_fields,
@ -358,9 +356,6 @@ public:
const DIRTY_VALUES dirty_values,
const std::string& id="");
Option<bool> embed_fields(nlohmann::json& document);
Option<bool> embed_fields_update(const nlohmann::json& old_doc, nlohmann::json& new_doc);
static uint32_t get_seq_id_from_key(const std::string & key);

View File

@ -50,7 +50,7 @@ namespace fields {
static const std::string num_dim = "num_dim";
static const std::string vec_dist = "vec_dist";
static const std::string reference = "reference";
static const std::string create_from = "create_from";
static const std::string embed_from = "embed_from";
static const std::string model_name = "model_name";
}
@ -77,7 +77,7 @@ struct field {
int nested_array;
size_t num_dim;
std::vector<std::string> create_from;
std::vector<std::string> embed_from;
std::string model_name;
vector_distance_type_t vec_dist;
@ -89,9 +89,9 @@ struct field {
field(const std::string &name, const std::string &type, const bool facet, const bool optional = false,
bool index = true, std::string locale = "", int sort = -1, int infix = -1, bool nested = false,
int nested_array = 0, size_t num_dim = 0, vector_distance_type_t vec_dist = cosine, std::string reference = "", const std::vector<std::string> &create_from = {}, const std::string& model_name = "") :
int nested_array = 0, size_t num_dim = 0, vector_distance_type_t vec_dist = cosine, std::string reference = "", const std::vector<std::string> &embed_from = {}, const std::string& model_name = "") :
name(name), type(type), facet(facet), optional(optional), index(index), locale(locale),
nested(nested), nested_array(nested_array), num_dim(num_dim), vec_dist(vec_dist), reference(reference), create_from(create_from), model_name(model_name) {
nested(nested), nested_array(nested_array), num_dim(num_dim), vec_dist(vec_dist), reference(reference), embed_from(embed_from), model_name(model_name) {
set_computed_defaults(sort, infix);
}
@ -319,8 +319,8 @@ struct field {
if (!field.reference.empty()) {
field_val[fields::reference] = field.reference;
}
if(!field.create_from.empty()) {
field_val[fields::create_from] = field.create_from;
if(!field.embed_from.empty()) {
field_val[fields::embed_from] = field.embed_from;
if(!field.model_name.empty()) {
field_val[fields::model_name] = field.model_name;
}
@ -421,36 +421,36 @@ struct field {
for(nlohmann::json & field_json: fields_json) {
if(field_json.count(fields::create_from) != 0) {
if(field_json.count(fields::embed_from) != 0) {
if(TextEmbedderManager::model_dir.empty()) {
return Option<bool>(400, "Text embedding is not enabled. Please set `model-dir` at startup.");
}
if(!field_json[fields::create_from].is_array()) {
return Option<bool>(400, "Property `" + fields::create_from + "` must be an array.");
if(!field_json[fields::embed_from].is_array()) {
return Option<bool>(400, "Property `" + fields::embed_from + "` must be an array.");
}
if(field_json[fields::create_from].empty()) {
return Option<bool>(400, "Property `" + fields::create_from + "` must have at least one element.");
if(field_json[fields::embed_from].empty()) {
return Option<bool>(400, "Property `" + fields::embed_from + "` must have at least one element.");
}
for(auto& create_from_field : field_json[fields::create_from]) {
if(!create_from_field.is_string()) {
return Option<bool>(400, "Property `" + fields::create_from + "` must contain only field names as strings.");
for(auto& embed_from_field : field_json[fields::embed_from]) {
if(!embed_from_field.is_string()) {
return Option<bool>(400, "Property `" + fields::embed_from + "` must contain only field names as strings.");
}
}
if(field_json[fields::type] != field_types::FLOAT_ARRAY) {
return Option<bool>(400, "Property `" + fields::create_from + "` is only allowed on a float array field.");
return Option<bool>(400, "Property `" + fields::embed_from + "` is only allowed on a float array field.");
}
for(auto& create_from_field : field_json[fields::create_from]) {
for(auto& embed_from_field : field_json[fields::embed_from]) {
bool flag = false;
for(const auto& field : fields_json) {
if(field[fields::name] == create_from_field) {
if(field[fields::name] == embed_from_field) {
if(field[fields::type] != field_types::STRING && field[fields::type] != field_types::STRING_ARRAY) {
return Option<bool>(400, "Property `" + fields::create_from + "` can only have string or string array fields.");
return Option<bool>(400, "Property `" + fields::embed_from + "` can only have string or string array fields.");
}
flag = true;
break;
@ -458,9 +458,9 @@ struct field {
}
if(!flag) {
for(const auto& field : the_fields) {
if(field.name == create_from_field) {
if(field.name == embed_from_field) {
if(field.type != field_types::STRING && field.type != field_types::STRING_ARRAY) {
return Option<bool>(400, "Property `" + fields::create_from + "` can only have used with string or string array fields.");
return Option<bool>(400, "Property `" + fields::embed_from + "` can only have used with string or string array fields.");
}
flag = true;
break;
@ -468,7 +468,7 @@ struct field {
}
}
if(!flag) {
return Option<bool>(400, "Property `" + fields::create_from + "` can only be used with string or string array fields.");
return Option<bool>(400, "Property `" + fields::embed_from + "` can only be used with string or string array fields.");
}
}
}

View File

@ -532,6 +532,16 @@ private:
const std::string& token, uint32_t seq_id);
void initialize_facet_indexes(const field& facet_field);
static Option<bool> embed_fields(nlohmann::json& document,
const tsl::htrie_map<char, field>& embedding_fields,
const tsl::htrie_map<char, field> & search_schema);
static Option<bool> embed_fields_update(const nlohmann::json& old_doc, nlohmann::json& new_doc,
const tsl::htrie_map<char, field>& embedding_fields,
const tsl::htrie_map<char, field> & search_schema);
public:
// for limiting number of results on multiple candidates / query rewrites
@ -663,6 +673,7 @@ public:
const size_t batch_start_index, const size_t batch_size,
const std::string & default_sorting_field,
const tsl::htrie_map<char, field> & search_schema,
const tsl::htrie_map<char, field> & embedding_fields,
const std::string& fallback_field_type,
const std::vector<char>& token_separators,
const std::vector<char>& symbols_to_index,
@ -672,6 +683,7 @@ public:
std::vector<index_record>& iter_batch,
const std::string& default_sorting_field,
const tsl::htrie_map<char, field>& search_schema,
const tsl::htrie_map<char, field> & embedding_fields,
const std::string& fallback_field_type,
const std::vector<char>& token_separators,
const std::vector<char>& symbols_to_index,

View File

@ -27,6 +27,7 @@ public:
static Option<uint32_t> validate_index_in_memory(nlohmann::json &document, uint32_t seq_id,
const std::string & default_sorting_field,
const tsl::htrie_map<char, field> & search_schema,
const tsl::htrie_map<char, field> & embedding_fields,
const index_operation_t op,
const bool is_update,
const std::string& fallback_field_type,
@ -67,4 +68,9 @@ public:
nlohmann::json::iterator& array_iter,
bool is_array, bool& array_ele_erased);
static Option<bool> validate_embed_fields(const nlohmann::json& document,
const tsl::htrie_map<char, field>& embedding_fields,
const tsl::htrie_map<char, field> & search_schema,
const bool& error_if_field_not_found);
};

View File

@ -52,7 +52,7 @@ Collection::Collection(const std::string& name, const uint32_t collection_id, co
index(init_index()) {
for (auto const& field: fields) {
if (!field.create_from.empty()) {
if (!field.embed_from.empty()) {
embedding_fields.emplace(field.name, field);
}
}
@ -108,12 +108,6 @@ Option<doc_seq_id_t> Collection::to_doc(const std::string & json_str, nlohmann::
uint32_t seq_id = get_next_seq_id();
document["id"] = std::to_string(seq_id);
// Handle embedding here for UPSERT, EMPLACE or CREATE when we treat is as a new doc
auto embed_res = embed_fields(document);
if (!embed_res.ok()) {
return Option<doc_seq_id_t>(400, embed_res.error());
}
// Add reference helper fields in the document.
for (auto const& pair: reference_fields) {
@ -191,14 +185,6 @@ Option<doc_seq_id_t> Collection::to_doc(const std::string & json_str, nlohmann::
// UPSERT, EMPLACE or UPDATE
uint32_t seq_id = (uint32_t) std::stoul(seq_id_str);
//Handle embedding here for UPDATE
nlohmann::json old_doc;
get_document_from_store(get_seq_id_key(seq_id), old_doc);
auto embed_res = embed_fields_update(old_doc, document);
if (!embed_res.ok()) {
return Option<doc_seq_id_t>(400, embed_res.error());
}
return Option<doc_seq_id_t>(doc_seq_id_t{seq_id, false});
} else {
@ -209,11 +195,6 @@ Option<doc_seq_id_t> Collection::to_doc(const std::string & json_str, nlohmann::
// for UPSERT, EMPLACE or CREATE, if a document with given ID is not found, we will treat it as a new doc
uint32_t seq_id = get_next_seq_id();
// Handle embedding here for UPSERT, EMPLACE or CREATE when we treat is as a new doc
auto embed_res = embed_fields(document);
if (!embed_res.ok()) {
return Option<doc_seq_id_t>(400, embed_res.error());
}
return Option<doc_seq_id_t>(doc_seq_id_t{seq_id, true});
}
}
@ -259,8 +240,8 @@ nlohmann::json Collection::get_summary_json() const {
field_json[fields::infix] = coll_field.infix;
field_json[fields::locale] = coll_field.locale;
if(!coll_field.create_from.empty()) {
field_json[fields::create_from] = coll_field.create_from;
if(!coll_field.embed_from.empty()) {
field_json[fields::embed_from] = coll_field.embed_from;
}
if(coll_field.model_name.size() > 0) {
@ -397,6 +378,7 @@ nlohmann::json Collection::add_many(std::vector<std::string>& json_lines, nlohma
do_batched_index:
if((i+1) % index_batch_size == 0 || i == json_lines.size()-1 || repeated_doc) {
batch_index(index_records, json_lines, num_indexed, return_doc, return_id);
@ -593,7 +575,7 @@ Option<uint32_t> Collection::index_in_memory(nlohmann::json &document, uint32_t
std::unique_lock lock(mutex);
Option<uint32_t> validation_op = validator_t::validate_index_in_memory(document, seq_id, default_sorting_field,
search_schema, op, false,
search_schema, embedding_fields, op, false,
fallback_field_type, dirty_values);
if(!validation_op.ok()) {
@ -604,7 +586,7 @@ Option<uint32_t> Collection::index_in_memory(nlohmann::json &document, uint32_t
std::vector<index_record> index_batch;
index_batch.emplace_back(std::move(rec));
Index::batch_memory_index(index, index_batch, default_sorting_field, search_schema,
Index::batch_memory_index(index, index_batch, default_sorting_field, search_schema, embedding_fields,
fallback_field_type, token_separators, symbols_to_index, true);
num_documents += 1;
@ -614,7 +596,7 @@ Option<uint32_t> Collection::index_in_memory(nlohmann::json &document, uint32_t
size_t Collection::batch_index_in_memory(std::vector<index_record>& index_records) {
std::unique_lock lock(mutex);
size_t num_indexed = Index::batch_memory_index(index, index_records, default_sorting_field,
search_schema, fallback_field_type,
search_schema, embedding_fields, fallback_field_type,
token_separators, symbols_to_index, true);
num_documents += num_indexed;
return num_indexed;
@ -1014,7 +996,7 @@ Option<bool> Collection::extract_field_name(const std::string& field_name,
for(auto kv = prefix_it.first; kv != prefix_it.second; ++kv) {
bool exact_key_match = (kv.key().size() == field_name.size());
bool exact_primitive_match = exact_key_match && !kv.value().is_object();
bool text_embedding = kv.value().type == field_types::FLOAT_ARRAY && !kv.value().create_from.empty();
bool text_embedding = kv.value().type == field_types::FLOAT_ARRAY && !kv.value().embed_from.empty();
if(extract_only_string_fields && !kv.value().is_string() && !text_embedding) {
if(exact_primitive_match && !is_wildcard) {
@ -3735,7 +3717,7 @@ Option<bool> Collection::batch_alter_data(const std::vector<field>& alter_fields
}
}
Index::batch_memory_index(index, iter_batch, default_sorting_field, schema_additions,
Index::batch_memory_index(index, iter_batch, default_sorting_field, schema_additions, embedding_fields,
fallback_field_type, token_separators, symbols_to_index, true);
iter_batch.clear();
@ -3771,7 +3753,7 @@ Option<bool> Collection::batch_alter_data(const std::vector<field>& alter_fields
nested_fields.erase(del_field.name);
}
if(!del_field.create_from.empty()) {
if(!del_field.embed_from.empty()) {
embedding_fields.erase(del_field.name);
}
@ -4069,7 +4051,7 @@ Option<bool> Collection::validate_alter_payload(nlohmann::json& schema_changes,
return Option<bool>(400, "Field `" + field_name + "` is not part of collection schema.");
}
if(found_field && !field_it.value().create_from.empty()) {
if(found_field && !field_it.value().embed_from.empty()) {
updated_embedding_fields.erase(field_it.key());
}
@ -4078,7 +4060,7 @@ Option<bool> Collection::validate_alter_payload(nlohmann::json& schema_changes,
updated_search_schema.erase(field_it.key());
updated_nested_fields.erase(field_it.key());
if(!field_it.value().create_from.empty()) {
if(!field_it.value().embed_from.empty()) {
updated_embedding_fields.erase(field_it.key());
}
@ -4092,7 +4074,7 @@ Option<bool> Collection::validate_alter_payload(nlohmann::json& schema_changes,
updated_search_schema.erase(prefix_kv.key());
updated_nested_fields.erase(prefix_kv.key());
if(!prefix_kv.value().create_from.empty()) {
if(!prefix_kv.value().embed_from.empty()) {
updated_embedding_fields.erase(prefix_kv.key());
}
}
@ -4145,7 +4127,7 @@ Option<bool> Collection::validate_alter_payload(nlohmann::json& schema_changes,
addition_fields.push_back(f);
}
if(!f.create_from.empty()) {
if(!f.embed_from.empty()) {
return Option<bool>(400, "Embedding fields can only be added at the time of collection creation.");
}
@ -4160,7 +4142,7 @@ Option<bool> Collection::validate_alter_payload(nlohmann::json& schema_changes,
updated_search_schema.emplace(prefix_kv.key(), prefix_kv.value());
updated_nested_fields.emplace(prefix_kv.key(), prefix_kv.value());
if(!prefix_kv.value().create_from.empty()) {
if(!prefix_kv.value().embed_from.empty()) {
return Option<bool>(400, "Embedding fields can only be added at the time of collection creation.");
}
@ -4234,6 +4216,7 @@ Option<bool> Collection::validate_alter_payload(nlohmann::json& schema_changes,
// validate existing data on disk for compatibility via updated_search_schema
auto validate_op = validator_t::validate_index_in_memory(document, seq_id, default_sorting_field,
updated_search_schema,
updated_embedding_fields,
index_operation_t::CREATE,
false,
fallback_field_type,
@ -4484,7 +4467,7 @@ Index* Collection::init_index() {
nested_fields.emplace(field.name, field);
}
if(!field.create_from.empty()) {
if(!field.embed_from.empty()) {
embedding_fields.emplace(field.name, field);
}
@ -4759,108 +4742,18 @@ Option<bool> Collection::populate_include_exclude_fields_lk(const spp::sparse_ha
}
Option<bool> Collection::embed_fields(nlohmann::json& document) {
auto validate_res = validate_embed_fields(document, true);
if(!validate_res.ok()) {
return validate_res;
}
for(const auto& field : embedding_fields) {
std::string text_to_embed;
for(const auto& field_name : field.create_from) {
auto field_it = search_schema.find(field_name);
if(field_it.value().type == field_types::STRING) {
text_to_embed += document[field_name].get<std::string>() + " ";
} else if(field_it.value().type == field_types::STRING_ARRAY) {
for(const auto& val : document[field_name]) {
text_to_embed += val.get<std::string>() + " ";
}
}
}
TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance();
auto embedder = embedder_manager.get_text_embedder(field.model_name.size() > 0 ? field.model_name : TextEmbedderManager::DEFAULT_MODEL_NAME);
std::vector<float> embedding = embedder->Embed(text_to_embed);
document[field.name] = embedding;
}
return Option<bool>(true);
}
Option<bool> Collection::validate_embed_fields(const nlohmann::json& document, const bool& error_if_field_not_found) const {
if(!embedding_fields.empty() && TextEmbedderManager::model_dir.empty()) {
return Option<bool>(400, "Text embedding is not enabled. Please set `model-dir` at startup.");
}
for(const auto& field : embedding_fields) {
for(const auto& field_name : field.create_from) {
auto schema_field_it = search_schema.find(field_name);
auto doc_field_it = document.find(field_name);
if(schema_field_it == search_schema.end()) {
return Option<bool>(400, "Field `" + field.name + "` has invalid fields to create embeddings from.");
}
if(doc_field_it == document.end()) {
if(error_if_field_not_found) {
return Option<bool>(400, "Field `" + field_name + "` is needed to create embedding.");
} else {
continue;
}
}
if((schema_field_it.value().type == field_types::STRING && !doc_field_it.value().is_string()) ||
(schema_field_it.value().type == field_types::STRING_ARRAY && !doc_field_it.value().is_array())) {
return Option<bool>(400, "Field `" + field_name + "` has malformed data.");
}
if(doc_field_it.value().is_array()) {
for(const auto& val : doc_field_it.value()) {
if(!val.is_string()) {
return Option<bool>(400, "Field `" + field_name + "` has malformed data.");
}
}
}
}
}
return Option<bool>(true);
}
Option<bool> Collection::embed_fields_update(const nlohmann::json& old_doc, nlohmann::json& new_doc) {
auto validate_res = validate_embed_fields(new_doc, false);
if(!validate_res.ok()) {
return validate_res;
}
nlohmann::json new_doc_copy = new_doc;
for(const auto& field : embedding_fields) {
std::string text_to_embed;
for(const auto& field_name : field.create_from) {
auto field_it = search_schema.find(field_name);
nlohmann::json value = (new_doc.find(field_name) != new_doc.end()) ? new_doc[field_name] : old_doc[field_name];
if(field_it.value().type == field_types::STRING) {
text_to_embed += value.get<std::string>() + " ";
} else if(field_it.value().type == field_types::STRING_ARRAY) {
for(const auto& val : value) {
text_to_embed += val.get<std::string>() + " ";
}
}
}
TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance();
auto embedder = embedder_manager.get_text_embedder(field.model_name.size() > 0 ? field.model_name : TextEmbedderManager::DEFAULT_MODEL_NAME);
std::vector<float> embedding = embedder->Embed(text_to_embed);
new_doc_copy[field.name] = embedding;
}
new_doc = new_doc_copy;
return Option<bool>(true);
}
void Collection::process_remove_field_for_embedding_fields(const field& the_field) {
std::vector<std::vector<field>::iterator> empty_fields;
for(auto& embedding_field : embedding_fields) {
const auto& actual_field = std::find_if(fields.begin(), fields.end(), [&embedding_field] (field other_field) {
return other_field.name == embedding_field.name;
});
actual_field->create_from.erase(std::remove_if(actual_field->create_from.begin(), actual_field->create_from.end(), [&the_field](std::string field_name) {
actual_field->embed_from.erase(std::remove_if(actual_field->embed_from.begin(), actual_field->embed_from.end(), [&the_field](std::string field_name) {
return the_field.name == field_name;
}));
embedding_field = *actual_field;
// store to remove embedding field if it has no field names in 'create_from' anymore.
if(embedding_field.create_from.empty()) {
// store to remove embedding field if it has no field names in 'embed_from' anymore.
if(embedding_field.embed_from.empty()) {
empty_fields.push_back(actual_field);
}
}

View File

@ -58,8 +58,8 @@ Collection* CollectionManager::init_collection(const nlohmann::json & collection
field_obj[fields::reference] = "";
}
if(field_obj.count(fields::create_from) == 0) {
field_obj[fields::create_from] = std::vector<std::string>();
if(field_obj.count(fields::embed_from) == 0) {
field_obj[fields::embed_from] = std::vector<std::string>();
}
if(field_obj.count(fields::model_name) == 0) {
@ -78,7 +78,7 @@ Collection* CollectionManager::init_collection(const nlohmann::json & collection
field f(field_obj[fields::name], field_obj[fields::type], field_obj[fields::facet],
field_obj[fields::optional], field_obj[fields::index], field_obj[fields::locale],
-1, field_obj[fields::infix], field_obj[fields::nested], field_obj[fields::nested_array],
field_obj[fields::num_dim], vec_dist_type, field_obj[fields::reference], field_obj[fields::create_from],
field_obj[fields::num_dim], vec_dist_type, field_obj[fields::reference], field_obj[fields::embed_from],
field_obj[fields::model_name]);
// value of `sort` depends on field type

View File

@ -672,11 +672,11 @@ Option<bool> field::json_field_to_field(bool enable_nested_fields, nlohmann::jso
}
}
if(field_json.count(fields::model_name) > 0 && field_json.count(fields::create_from) == 0) {
return Option<bool>(400, "Property `" + fields::model_name + "` can only be used with `" + fields::create_from + "`.");
if(field_json.count(fields::model_name) > 0 && field_json.count(fields::embed_from) == 0) {
return Option<bool>(400, "Property `" + fields::model_name + "` can only be used with `" + fields::embed_from + "`.");
}
if(field_json.count(fields::create_from) != 0) {
if(field_json.count(fields::embed_from) != 0) {
// If the model path is not specified, use the default model and set the number of dimensions to 384 (number of dimensions of the default model)
field_json[fields::num_dim] = static_cast<unsigned int>(384);
if(field_json.count(fields::model_name) != 0) {
@ -695,7 +695,7 @@ Option<bool> field::json_field_to_field(bool enable_nested_fields, nlohmann::jso
}
}
} else {
field_json[fields::create_from] = std::vector<std::string>();
field_json[fields::embed_from] = std::vector<std::string>();
}
@ -784,7 +784,7 @@ Option<bool> field::json_field_to_field(bool enable_nested_fields, nlohmann::jso
field_json[fields::optional], field_json[fields::index], field_json[fields::locale],
field_json[fields::sort], field_json[fields::infix], field_json[fields::nested],
field_json[fields::nested_array], field_json[fields::num_dim], vec_dist,
field_json[fields::reference], field_json[fields::create_from].get<std::vector<std::string>>(),
field_json[fields::reference], field_json[fields::embed_from].get<std::vector<std::string>>(),
field_json[fields::model_name])
);

View File

@ -411,6 +411,7 @@ void Index::validate_and_preprocess(Index *index, std::vector<index_record>& ite
const size_t batch_start_index, const size_t batch_size,
const std::string& default_sorting_field,
const tsl::htrie_map<char, field>& search_schema,
const tsl::htrie_map<char, field>& embedding_fields,
const std::string& fallback_field_type,
const std::vector<char>& token_separators,
const std::vector<char>& symbols_to_index,
@ -435,6 +436,7 @@ void Index::validate_and_preprocess(Index *index, std::vector<index_record>& ite
Option<uint32_t> validation_op = validator_t::validate_index_in_memory(index_rec.doc, index_rec.seq_id,
default_sorting_field,
search_schema,
embedding_fields,
index_rec.operation,
index_rec.is_update,
fallback_field_type,
@ -451,6 +453,9 @@ void Index::validate_and_preprocess(Index *index, std::vector<index_record>& ite
get_doc_changes(index_rec.operation, index_rec.doc, index_rec.old_doc, index_rec.new_doc,
index_rec.del_doc);
scrub_reindex_doc(search_schema, index_rec.doc, index_rec.del_doc, index_rec.old_doc);
embed_fields(index_rec.new_doc, embedding_fields, search_schema);
} else {
embed_fields(index_rec.doc, embedding_fields, search_schema);
}
compute_token_offsets_facets(index_rec, search_schema, token_separators, symbols_to_index);
@ -485,6 +490,7 @@ void Index::validate_and_preprocess(Index *index, std::vector<index_record>& ite
size_t Index::batch_memory_index(Index *index, std::vector<index_record>& iter_batch,
const std::string & default_sorting_field,
const tsl::htrie_map<char, field> & search_schema,
const tsl::htrie_map<char, field> & embedding_fields,
const std::string& fallback_field_type,
const std::vector<char>& token_separators,
const std::vector<char>& symbols_to_index,
@ -518,7 +524,7 @@ size_t Index::batch_memory_index(Index *index, std::vector<index_record>& iter_b
index->thread_pool->enqueue([&, batch_index, batch_len]() {
write_log_index = local_write_log_index;
validate_and_preprocess(index, iter_batch, batch_index, batch_len, default_sorting_field, search_schema,
fallback_field_type, token_separators, symbols_to_index, do_validation);
embedding_fields, fallback_field_type, token_separators, symbols_to_index, do_validation);
std::unique_lock<std::mutex> lock(m_process);
num_processed++;
@ -6257,6 +6263,61 @@ bool Index::common_results_exist(std::vector<art_leaf*>& leaves, bool must_match
return phrase_exists;
}
Option<bool> Index::embed_fields(nlohmann::json& document,
const tsl::htrie_map<char, field>& embedding_fields,
const tsl::htrie_map<char, field> & search_schema) {
for(const auto& field : embedding_fields) {
std::string text_to_embed;
for(const auto& field_name : field.embed_from) {
auto field_it = search_schema.find(field_name);
if(field_it.value().type == field_types::STRING) {
text_to_embed += document[field_name].get<std::string>() + " ";
} else if(field_it.value().type == field_types::STRING_ARRAY) {
for(const auto& val : document[field_name]) {
text_to_embed += val.get<std::string>() + " ";
}
}
}
TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance();
auto embedder = embedder_manager.get_text_embedder(field.model_name.size() > 0 ? field.model_name : TextEmbedderManager::DEFAULT_MODEL_NAME);
std::vector<float> embedding = embedder->Embed(text_to_embed);
document[field.name] = embedding;
}
return Option<bool>(true);
}
Option<bool> Index::embed_fields_update(const nlohmann::json& old_doc, nlohmann::json& new_doc,
const tsl::htrie_map<char, field>& embedding_fields,
const tsl::htrie_map<char, field> & search_schema) {
nlohmann::json new_doc_copy = new_doc;
for(const auto& field : embedding_fields) {
std::string text_to_embed;
for(const auto& field_name : field.embed_from) {
auto field_it = search_schema.find(field_name);
nlohmann::json value = (new_doc.find(field_name) != new_doc.end()) ? new_doc[field_name] : old_doc[field_name];
if(field_it.value().type == field_types::STRING) {
text_to_embed += value.get<std::string>() + " ";
} else if(field_it.value().type == field_types::STRING_ARRAY) {
for(const auto& val : value) {
text_to_embed += val.get<std::string>() + " ";
}
}
}
TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance();
auto embedder = embedder_manager.get_text_embedder(field.model_name.size() > 0 ? field.model_name : TextEmbedderManager::DEFAULT_MODEL_NAME);
std::vector<float> embedding = embedder->Embed(text_to_embed);
new_doc_copy[field.name] = embedding;
}
new_doc = new_doc_copy;
return Option<bool>(true);
}
/*
// https://stackoverflow.com/questions/924171/geo-fencing-point-inside-outside-polygon
// NOTE: polygon and point should have been transformed with `transform_for_180th_meridian`
@ -6302,3 +6363,4 @@ void Index::transform_for_180th_meridian(GeoCoord &point, double offset) {
point.lon = point.lon < 0.0 ? point.lon + offset : point.lon;
}
*/

View File

@ -529,6 +529,7 @@ Option<uint32_t> validator_t::coerce_float(const DIRTY_VALUES& dirty_values, con
Option<uint32_t> validator_t::validate_index_in_memory(nlohmann::json& document, uint32_t seq_id,
const std::string & default_sorting_field,
const tsl::htrie_map<char, field> & search_schema,
const tsl::htrie_map<char, field> & embedding_fields,
const index_operation_t op,
const bool is_update,
const std::string& fallback_field_type,
@ -544,6 +545,11 @@ Option<uint32_t> validator_t::validate_index_in_memory(nlohmann::json& document,
for(const auto& a_field: search_schema) {
const std::string& field_name = a_field.name;
// ignore embedding fields, they will be validated later
if(embedding_fields.count(field_name) > 0) {
continue;
}
if(field_name == "id" || a_field.is_object()) {
continue;
}
@ -574,5 +580,50 @@ Option<uint32_t> validator_t::validate_index_in_memory(nlohmann::json& document,
}
}
// validate embedding fields
auto validate_embed_op = validate_embed_fields(document, embedding_fields, search_schema, !is_update);
if(!validate_embed_op.ok()) {
return Option<>(validate_embed_op.code(), validate_embed_op.error());
}
return Option<>(200);
}
Option<bool> validator_t::validate_embed_fields(const nlohmann::json& document,
const tsl::htrie_map<char, field>& embedding_fields,
const tsl::htrie_map<char, field> & search_schema,
const bool& error_if_field_not_found) {
if(!embedding_fields.empty() && TextEmbedderManager::model_dir.empty()) {
return Option<bool>(400, "Text embedding is not enabled. Please set `model-dir` at startup.");
}
for(const auto& field : embedding_fields) {
for(const auto& field_name : field.embed_from) {
auto schema_field_it = search_schema.find(field_name);
auto doc_field_it = document.find(field_name);
if(schema_field_it == search_schema.end()) {
return Option<bool>(400, "Field `" + field.name + "` has invalid fields to create embeddings from.");
}
if(doc_field_it == document.end()) {
if(error_if_field_not_found) {
return Option<bool>(400, "Field `" + field_name + "` is needed to create embedding.");
} else {
continue;
}
}
if((schema_field_it.value().type == field_types::STRING && !doc_field_it.value().is_string()) ||
(schema_field_it.value().type == field_types::STRING_ARRAY && !doc_field_it.value().is_array())) {
return Option<bool>(400, "Field `" + field_name + "` has malformed data.");
}
if(doc_field_it.value().is_array()) {
for(const auto& val : doc_field_it.value()) {
if(!val.is_string()) {
return Option<bool>(400, "Field `" + field_name + "` has malformed data.");
}
}
}
}
}
return Option<bool>(true);
}

View File

@ -1597,7 +1597,7 @@ TEST_F(CollectionAllFieldsTest, CreateFromFieldJSONInvalidField) {
nlohmann::json field_json;
field_json["name"] = "embedding";
field_json["type"] = "float[]";
field_json["create_from"] = {"name"};
field_json["embed_from"] = {"name"};
std::vector<field> fields;
std::string fallback_field_type;
@ -1607,7 +1607,7 @@ TEST_F(CollectionAllFieldsTest, CreateFromFieldJSONInvalidField) {
auto field_op = field::json_fields_to_fields(false, arr, fallback_field_type, fields);
ASSERT_FALSE(field_op.ok());
ASSERT_EQ("Property `create_from` can only be used with string or string array fields.", field_op.error());
ASSERT_EQ("Property `embed_from` can only be used with string or string array fields.", field_op.error());
}
TEST_F(CollectionAllFieldsTest, CreateFromFieldNoModelDir) {
@ -1615,7 +1615,7 @@ TEST_F(CollectionAllFieldsTest, CreateFromFieldNoModelDir) {
nlohmann::json field_json;
field_json["name"] = "embedding";
field_json["type"] = "float[]";
field_json["create_from"] = {"name"};
field_json["embed_from"] = {"name"};
std::vector<field> fields;
std::string fallback_field_type;
@ -1633,7 +1633,7 @@ TEST_F(CollectionAllFieldsTest, CreateFromNotArray) {
nlohmann::json field_json;
field_json["name"] = "embedding";
field_json["type"] = "float[]";
field_json["create_from"] = "name";
field_json["embed_from"] = "name";
std::vector<field> fields;
std::string fallback_field_type;
@ -1643,7 +1643,7 @@ TEST_F(CollectionAllFieldsTest, CreateFromNotArray) {
auto field_op = field::json_fields_to_fields(false, arr, fallback_field_type, fields);
ASSERT_FALSE(field_op.ok());
ASSERT_EQ("Property `create_from` must be an array.", field_op.error());
ASSERT_EQ("Property `embed_from` must be an array.", field_op.error());
}
TEST_F(CollectionAllFieldsTest, ModelPathWithoutCreateFrom) {
@ -1660,7 +1660,7 @@ TEST_F(CollectionAllFieldsTest, ModelPathWithoutCreateFrom) {
auto field_op = field::json_fields_to_fields(false, arr, fallback_field_type, fields);
ASSERT_FALSE(field_op.ok());
ASSERT_EQ("Property `model_name` can only be used with `create_from`.", field_op.error());
ASSERT_EQ("Property `model_name` can only be used with `embed_from`.", field_op.error());
}
@ -1670,7 +1670,7 @@ TEST_F(CollectionAllFieldsTest, CreateFromBasicValid) {
TextEmbedderManager::download_default_model();
field embedding = field("embedding", field_types::FLOAT_ARRAY, false);
embedding.create_from.push_back("name");
embedding.embed_from.push_back("name");
std::vector<field> fields = {field("name", field_types::STRING, false),
embedding};
auto obj_coll_op = collectionManager.create_collection("obj_coll", 1, fields, "", 0, field_types::AUTO);

View File

@ -1462,7 +1462,7 @@ TEST_F(CollectionSchemaChangeTest, UpdateSchemaWithNewEmbeddingField) {
nlohmann::json update_schema = R"({
"fields": [
{"name": "embedding", "type":"float[]", "create_from": ["names"]}
{"name": "embedding", "type":"float[]", "embed_from": ["names"]}
]
})"_json;
@ -1478,7 +1478,7 @@ TEST_F(CollectionSchemaChangeTest, DropFieldUsedForEmbedding) {
"fields": [
{"name": "names", "type": "string[]"},
{"name": "category", "type":"string"},
{"name": "embedding", "type":"float[]", "create_from": ["names","category"]}
{"name": "embedding", "type":"float[]", "embed_from": ["names","category"]}
]
})"_json;
@ -1500,7 +1500,7 @@ TEST_F(CollectionSchemaChangeTest, DropFieldUsedForEmbedding) {
LOG(INFO) << "Dropping field";
auto embedding_fields = coll->get_embedding_fields();
ASSERT_EQ(2, embedding_fields["embedding"].create_from.size());
ASSERT_EQ(2, embedding_fields["embedding"].embed_from.size());
LOG(INFO) << "Before alter";
@ -1510,8 +1510,8 @@ TEST_F(CollectionSchemaChangeTest, DropFieldUsedForEmbedding) {
LOG(INFO) << "After alter";
embedding_fields = coll->get_embedding_fields();
ASSERT_EQ(1, embedding_fields["embedding"].create_from.size());
ASSERT_EQ("category", embedding_fields["embedding"].create_from[0]);
ASSERT_EQ(1, embedding_fields["embedding"].embed_from.size());
ASSERT_EQ("category", embedding_fields["embedding"].embed_from[0]);
}
TEST_F(CollectionSchemaChangeTest, EmbeddingFieldsMapTest) {
@ -1519,7 +1519,7 @@ TEST_F(CollectionSchemaChangeTest, EmbeddingFieldsMapTest) {
"name": "objects",
"fields": [
{"name": "name", "type": "string"},
{"name": "embedding", "type":"float[]", "create_from": ["name"]}
{"name": "embedding", "type":"float[]", "embed_from": ["name"]}
]
})"_json;
@ -1535,8 +1535,8 @@ TEST_F(CollectionSchemaChangeTest, EmbeddingFieldsMapTest) {
auto embedding_field_it = embedding_fields_map.find("embedding");
ASSERT_TRUE(embedding_field_it != embedding_fields_map.end());
ASSERT_EQ("embedding", embedding_field_it.value().name);
ASSERT_EQ(1, embedding_field_it.value().create_from.size());
ASSERT_EQ("name", embedding_field_it.value().create_from[0]);
ASSERT_EQ(1, embedding_field_it.value().embed_from.size());
ASSERT_EQ("name", embedding_field_it.value().embed_from[0]);
// drop the embedding field
nlohmann::json schema_without_embedding = R"({

View File

@ -4616,7 +4616,7 @@ TEST_F(CollectionTest, SemanticSearchTest) {
"name": "objects",
"fields": [
{"name": "name", "type": "string"},
{"name": "embedding", "type":"float[]", "create_from": ["name"]}
{"name": "embedding", "type":"float[]", "embed_from": ["name"]}
]
})"_json;
@ -4651,7 +4651,7 @@ TEST_F(CollectionTest, InvalidSemanticSearch) {
"name": "objects",
"fields": [
{"name": "name", "type": "string"},
{"name": "embedding", "type":"float[]", "create_from": ["name"]}
{"name": "embedding", "type":"float[]", "embed_from": ["name"]}
]
})"_json;
@ -4682,7 +4682,7 @@ TEST_F(CollectionTest, HybridSearch) {
"name": "objects",
"fields": [
{"name": "name", "type": "string"},
{"name": "embedding", "type":"float[]", "create_from": ["name"]}
{"name": "embedding", "type":"float[]", "embed_from": ["name"]}
]
})"_json;
@ -4695,6 +4695,7 @@ TEST_F(CollectionTest, HybridSearch) {
nlohmann::json object;
object["name"] = "apple";
auto add_op = coll->add(object.dump());
LOG(INFO) << "add_op.error(): " << add_op.error();
ASSERT_TRUE(add_op.ok());
ASSERT_EQ("apple", add_op.get()["name"]);
@ -4710,40 +4711,40 @@ TEST_F(CollectionTest, HybridSearch) {
ASSERT_EQ(384, search_res["hits"][0]["document"]["embedding"].size());
}
TEST_F(CollectionTest, EmbedFielsTest) {
nlohmann::json schema = R"({
"name": "objects",
"fields": [
{"name": "name", "type": "string"},
{"name": "embedding", "type":"float[]", "create_from": ["name"]}
]
})"_json;
// TEST_F(CollectionTest, EmbedFielsTest) {
// nlohmann::json schema = R"({
// "name": "objects",
// "fields": [
// {"name": "name", "type": "string"},
// {"name": "embedding", "type":"float[]", "embed_from": ["name"]}
// ]
// })"_json;
TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
TextEmbedderManager::download_default_model();
// TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
// TextEmbedderManager::download_default_model();
auto op = collectionManager.create_collection(schema);
ASSERT_TRUE(op.ok());
Collection* coll = op.get();
// auto op = collectionManager.create_collection(schema);
// ASSERT_TRUE(op.ok());
// Collection* coll = op.get();
nlohmann::json object = R"({
"name": "apple"
})"_json;
// nlohmann::json object = R"({
// "name": "apple"
// })"_json;
auto embed_op = coll->embed_fields(object);
// auto embed_op = coll->embed_fields(object);
ASSERT_TRUE(embed_op.ok());
// ASSERT_TRUE(embed_op.ok());
ASSERT_EQ("apple", object["name"]);
ASSERT_EQ(384, object["embedding"].get<std::vector<float>>().size());
}
// ASSERT_EQ("apple", object["name"]);
// ASSERT_EQ(384, object["embedding"].get<std::vector<float>>().size());
// }
TEST_F(CollectionTest, HybridSearchRankFusionTest) {
nlohmann::json schema = R"({
"name": "objects",
"fields": [
{"name": "name", "type": "string"},
{"name": "embedding", "type":"float[]", "create_from": ["name"]}
{"name": "embedding", "type":"float[]", "embed_from": ["name"]}
]
})"_json;
@ -4817,7 +4818,7 @@ TEST_F(CollectionTest, WildcardSearchWithEmbeddingField) {
"name": "objects",
"fields": [
{"name": "name", "type": "string"},
{"name": "embedding", "type":"float[]", "create_from": ["name"]}
{"name": "embedding", "type":"float[]", "embed_from": ["name"]}
]
})"_json;
@ -4851,7 +4852,7 @@ TEST_F(CollectionTest, EmbedStringArrayField) {
"name": "objects",
"fields": [
{"name": "names", "type": "string[]"},
{"name": "embedding", "type":"float[]", "create_from": ["names"]}
{"name": "embedding", "type":"float[]", "embed_from": ["names"]}
]
})"_json;
@ -4876,8 +4877,8 @@ TEST_F(CollectionTest, MissingFieldForEmbedding) {
"name": "objects",
"fields": [
{"name": "names", "type": "string[]"},
{"name": "category", "type": "string"},
{"name": "embedding", "type":"float[]", "create_from": ["names", "category"]}
{"name": "category", "type": "string", "optional": true},
{"name": "embedding", "type":"float[]", "embed_from": ["names", "category"]}
]
})"_json;
@ -4903,7 +4904,7 @@ TEST_F(CollectionTest, UpdateEmbeddingsForUpdatedDocument) {
"name": "objects",
"fields": [
{"name": "name", "type": "string"},
{"name": "embedding", "type":"float[]", "create_from": ["name"]}
{"name": "embedding", "type":"float[]", "embed_from": ["name"]}
]
})"_json;

View File

@ -683,7 +683,7 @@ TEST_F(CollectionVectorTest, HybridSearchWithExplicitVector) {
"name": "coll1",
"fields": [
{"name": "name", "type": "string"},
{"name": "vec", "type": "float[]", "create_from": ["name"]}
{"name": "vec", "type": "float[]", "embed_from": ["name"]}
]
})"_json;