Review Changes

This commit is contained in:
ozanarmagan 2023-04-05 18:05:45 +03:00
parent 2a46120ad4
commit 1ad7bcdce3
8 changed files with 236 additions and 107 deletions

View File

@ -209,6 +209,9 @@ private:
std::vector<sort_by>& sort_fields_std,
bool is_wildcard_query, bool is_group_by_query = false) const;
Option<bool> validate_embed_fields(const nlohmann::json& document, const bool& error_if_field_not_found) const;
Option<bool> persist_collection_meta();
Option<bool> batch_alter_data(const std::vector<field>& alter_fields,

View File

@ -10,7 +10,7 @@ struct vector_query_t {
std::string field_name;
size_t k = 0;
size_t flat_search_cutoff = 0;
float similarity_cutoff = 0.0;
float distance_threshold = 2.01;
std::vector<float> values;
uint32_t seq_id = 0;
@ -20,7 +20,7 @@ struct vector_query_t {
// used for testing only
field_name.clear();
k = 0;
similarity_cutoff = 0.0;
distance_threshold = 2.01;
values.clear();
seq_id = 0;
query_doc_given = false;

View File

@ -51,6 +51,12 @@ Collection::Collection(const std::string& name, const uint32_t collection_id, co
symbols_to_index(to_char_array(symbols_to_index)), token_separators(to_char_array(token_separators)),
index(init_index()) {
for (auto const& field: fields) {
if (!field.create_from.empty()) {
embedding_fields.emplace(field.name, field);
}
}
this->num_documents = 0;
}
@ -253,7 +259,7 @@ nlohmann::json Collection::get_summary_json() const {
field_json[fields::infix] = coll_field.infix;
field_json[fields::locale] = coll_field.locale;
if(coll_field.create_from.size() > 0) {
if(!coll_field.create_from.empty()) {
field_json[fields::create_from] = coll_field.create_from;
}
@ -1008,7 +1014,7 @@ Option<bool> Collection::extract_field_name(const std::string& field_name,
for(auto kv = prefix_it.first; kv != prefix_it.second; ++kv) {
bool exact_key_match = (kv.key().size() == field_name.size());
bool exact_primitive_match = exact_key_match && !kv.value().is_object();
bool text_embedding = kv.value().type == field_types::FLOAT_ARRAY && kv.value().create_from.size() > 0;
bool text_embedding = kv.value().type == field_types::FLOAT_ARRAY && !kv.value().create_from.empty();
if(extract_only_string_fields && !kv.value().is_string() && !text_embedding) {
if(exact_primitive_match && !is_wildcard) {
@ -3765,7 +3771,7 @@ Option<bool> Collection::batch_alter_data(const std::vector<field>& alter_fields
nested_fields.erase(del_field.name);
}
if(del_field.create_from.size() > 0) {
if(!del_field.create_from.empty()) {
embedding_fields.erase(del_field.name);
}
@ -4063,7 +4069,7 @@ Option<bool> Collection::validate_alter_payload(nlohmann::json& schema_changes,
return Option<bool>(400, "Field `" + field_name + "` is not part of collection schema.");
}
if(found_field && field_it.value().create_from.size() > 0) {
if(found_field && !field_it.value().create_from.empty()) {
updated_embedding_fields.erase(field_it.key());
}
@ -4072,7 +4078,7 @@ Option<bool> Collection::validate_alter_payload(nlohmann::json& schema_changes,
updated_search_schema.erase(field_it.key());
updated_nested_fields.erase(field_it.key());
if(field_it.value().create_from.size() > 0) {
if(!field_it.value().create_from.empty()) {
updated_embedding_fields.erase(field_it.key());
}
@ -4086,7 +4092,7 @@ Option<bool> Collection::validate_alter_payload(nlohmann::json& schema_changes,
updated_search_schema.erase(prefix_kv.key());
updated_nested_fields.erase(prefix_kv.key());
if(prefix_kv.value().create_from.size() > 0) {
if(!prefix_kv.value().create_from.empty()) {
updated_embedding_fields.erase(prefix_kv.key());
}
}
@ -4139,7 +4145,7 @@ Option<bool> Collection::validate_alter_payload(nlohmann::json& schema_changes,
addition_fields.push_back(f);
}
if(f.create_from.size() > 0) {
if(!f.create_from.empty()) {
return Option<bool>(400, "Embedding fields can only be added at the time of collection creation.");
}
@ -4154,7 +4160,7 @@ Option<bool> Collection::validate_alter_payload(nlohmann::json& schema_changes,
updated_search_schema.emplace(prefix_kv.key(), prefix_kv.value());
updated_nested_fields.emplace(prefix_kv.key(), prefix_kv.value());
if(prefix_kv.value().create_from.size() > 0) {
if(!prefix_kv.value().create_from.empty()) {
return Option<bool>(400, "Embedding fields can only be added at the time of collection creation.");
}
@ -4478,7 +4484,7 @@ Index* Collection::init_index() {
nested_fields.emplace(field.name, field);
}
if(field.create_from.size() > 0) {
if(!field.create_from.empty()) {
embedding_fields.emplace(field.name, field);
}
@ -4754,45 +4760,22 @@ Option<bool> Collection::populate_include_exclude_fields_lk(const spp::sparse_ha
Option<bool> Collection::embed_fields(nlohmann::json& document) {
auto validate_res = validate_embed_fields(document, true);
if(!validate_res.ok()) {
return validate_res;
}
for(const auto& field : embedding_fields) {
if(TextEmbedderManager::model_dir.empty()) {
return Option<bool>(400, "Text embedding is not enabled. Please set `model-dir` at startup.");
}
std::string text_to_embed;
for(const auto& field_name : field.create_from) {
auto field_it = search_schema.find(field_name);
if(field_it != search_schema.end()) {
if(field_it.value().type == field_types::STRING) {
if(document.find(field_name) != document.end()) {
if(document[field_name].is_string()) {
text_to_embed += document[field_name].get<std::string>() + " ";
} else {
return Option<bool>(400, "Field `" + field_name + "` has malformed data.");
}
}
} else if(field_it.value().type == field_types::STRING_ARRAY) {
if(document.find(field_name) != document.end()) {
if(document[field_name].is_array()) {
for(const auto& val : document[field_name]) {
if(val.is_string()) {
text_to_embed += val.get<std::string>() + " ";
} else {
return Option<bool>(400, "Field `" + field_name + "` has malformed data.");
}
}
} else {
return Option<bool>(400, "Field `" + field_name + "` has malformed data.");
}
}
if(field_it.value().type == field_types::STRING) {
text_to_embed += document[field_name].get<std::string>() + " ";
} else if(field_it.value().type == field_types::STRING_ARRAY) {
for(const auto& val : document[field_name]) {
text_to_embed += val.get<std::string>() + " ";
}
else {
return Option<bool>(400, "Field `" + field_name + "` is not a string nor string array. Can not create vector from it.");
}
} else {
return Option<bool>(400, "Field `" + field_name + "` is not a valid field.");
}
}
TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance();
auto embedder = embedder_manager.get_text_embedder(field.model_name.size() > 0 ? field.model_name : TextEmbedderManager::DEFAULT_MODEL_NAME);
std::vector<float> embedding = embedder->Embed(text_to_embed);
@ -4802,41 +4785,58 @@ Option<bool> Collection::embed_fields(nlohmann::json& document) {
return Option<bool>(true);
}
Option<bool> Collection::validate_embed_fields(const nlohmann::json& document, const bool& error_if_field_not_found) const {
if(!embedding_fields.empty() && TextEmbedderManager::model_dir.empty()) {
return Option<bool>(400, "Text embedding is not enabled. Please set `model-dir` at startup.");
}
for(const auto& field : embedding_fields) {
for(const auto& field_name : field.create_from) {
auto schema_field_it = search_schema.find(field_name);
auto doc_field_it = document.find(field_name);
if(schema_field_it == search_schema.end()) {
return Option<bool>(400, "Field `" + field.name + "` has invalid fields to create embeddings from.");
}
if(doc_field_it == document.end()) {
if(error_if_field_not_found) {
return Option<bool>(400, "Field `" + field_name + "` is needed to create embedding.");
} else {
continue;
}
}
if((schema_field_it.value().type == field_types::STRING && !doc_field_it.value().is_string()) ||
(schema_field_it.value().type == field_types::STRING_ARRAY && !doc_field_it.value().is_array())) {
return Option<bool>(400, "Field `" + field_name + "` has malformed data.");
}
if(doc_field_it.value().is_array()) {
for(const auto& val : doc_field_it.value()) {
if(!val.is_string()) {
return Option<bool>(400, "Field `" + field_name + "` has malformed data.");
}
}
}
}
}
return Option<bool>(true);
}
Option<bool> Collection::embed_fields_update(const nlohmann::json& old_doc, nlohmann::json& new_doc) {
auto validate_res = validate_embed_fields(new_doc, false);
if(!validate_res.ok()) {
return validate_res;
}
nlohmann::json new_doc_copy = new_doc;
for(const auto& field : embedding_fields) {
if(TextEmbedderManager::model_dir.empty()) {
return Option<bool>(400, "Text embedding is not enabled. Please set `model-dir` at startup.");
}
std::string text_to_embed;
for(const auto& field_name : field.create_from) {
auto field_it = search_schema.find(field_name);
if(field_it != search_schema.end()) {
nlohmann::json value = (new_doc.find(field_name) != new_doc.end()) ? new_doc[field_name] : old_doc[field_name];
if(field_it.value().type == field_types::STRING) {
if(value.is_string()) {
text_to_embed += value.get<std::string>() + " ";
} else {
return Option<bool>(400, "Field `" + field_name + "` has malformed data.");
}
} else if(field_it.value().type == field_types::STRING_ARRAY) {
if(value.is_array()) {
for(const auto& val : value) {
if(val.is_string()) {
text_to_embed += val.get<std::string>() + " ";
} else {
return Option<bool>(400, "Field `" + field_name + "` has malformed data.");
}
}
} else {
return Option<bool>(400, "Field `" + field_name + "` has malformed data.");
}
nlohmann::json value = (new_doc.find(field_name) != new_doc.end()) ? new_doc[field_name] : old_doc[field_name];
if(field_it.value().type == field_types::STRING) {
text_to_embed += value.get<std::string>() + " ";
} else if(field_it.value().type == field_types::STRING_ARRAY) {
for(const auto& val : value) {
text_to_embed += val.get<std::string>() + " ";
}
else {
return Option<bool>(400, "Field `" + field_name + "` is not a string nor string array. Can not create vector from it.");
}
} else {
return Option<bool>(400, "Field `" + field_name + "` is not a valid field.");
}
}
@ -4860,7 +4860,7 @@ void Collection::process_remove_field_for_embedding_fields(const field& the_fiel
}));
embedding_field = *actual_field;
// store to remove embedding field if it has no field names in 'create_from' anymore.
if(embedding_field.create_from.size() == 0) {
if(embedding_field.create_from.empty()) {
empty_fields.push_back(actual_field);
}
}

View File

@ -2882,7 +2882,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
auto vec_dist_score = (field_vector_index->distance_type == cosine) ? std::abs(dist_label.first) :
dist_label.first;
if(vector_query.similarity_cutoff > 0 && vec_dist_score > vector_query.similarity_cutoff) {
if(vec_dist_score > vector_query.distance_threshold) {
continue;
}
@ -3105,10 +3105,8 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
auto vec_dist_score = (field_vector_index->distance_type == cosine) ? std::abs(dist_label.first) :
dist_label.first;
if(vector_query.similarity_cutoff > 0) {
if(vec_dist_score > vector_query.similarity_cutoff) {
continue;
}
if(vec_dist_score > vector_query.distance_threshold) {
continue;
}
vec_results.emplace_back(seq_id, vec_dist_score);
}

View File

@ -146,13 +146,13 @@ Option<bool> VectorQueryOps::parse_vector_query_str(std::string vector_query_str
vector_query.flat_search_cutoff = std::stoi(param_kv[1]);
}
if(param_kv[0] == "similarity_cutoff") {
if(!StringUtils::is_float(param_kv[1])) {
if(param_kv[0] == "distance_threshold") {
if(!StringUtils::is_float(param_kv[1]) || std::stof(param_kv[1]) < 0.0 || std::stof(param_kv[1]) > 2.0) {
return Option<bool>(400, "Malformed vector query string: "
"`similarity_cutoff` parameter must be a float.");
"`distance_threshold` parameter must be a float between 0.0-2.0.");
}
vector_query.similarity_cutoff = std::stof(param_kv[1]);
vector_query.distance_threshold = std::stof(param_kv[1]);
}
}

View File

@ -1446,3 +1446,70 @@ TEST_F(CollectionSchemaChangeTest, GeoFieldSchemaAddition) {
ASSERT_TRUE(res_op.ok());
ASSERT_EQ(2, res_op.get()["found"].get<size_t>());
}
TEST_F(CollectionSchemaChangeTest, UpdateSchemaWithNewEmbeddingField) {
nlohmann::json schema = R"({
"name": "objects",
"fields": [
{"name": "names", "type": "string[]"}
]
})"_json;
auto op = collectionManager.create_collection(schema);
ASSERT_TRUE(op.ok());
Collection* coll = op.get();
nlohmann::json update_schema = R"({
"fields": [
{"name": "embedding", "type":"float[]", "create_from": ["names"]}
]
})"_json;
auto res = coll->alter(update_schema);
ASSERT_FALSE(res.ok());
ASSERT_EQ("Embedding fields can only be added at the time of collection creation.", res.error());
}
TEST_F(CollectionSchemaChangeTest, DropFieldUsedForEmbedding) {
nlohmann::json schema = R"({
"name": "objects",
"fields": [
{"name": "names", "type": "string[]"},
{"name": "category", "type":"string"},
{"name": "embedding", "type":"float[]", "create_from": ["names","category"]}
]
})"_json;
TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
TextEmbedderManager::download_default_model();
auto op = collectionManager.create_collection(schema);
ASSERT_TRUE(op.ok());
Collection* coll = op.get();
LOG(INFO) << "Created collection";
auto schema_changes = R"({
"fields": [
{"name": "names", "drop": true}
]
})"_json;
LOG(INFO) << "Dropping field";
auto embedding_fields = coll->get_embedding_fields();
ASSERT_EQ(2, embedding_fields["embedding"].create_from.size());
LOG(INFO) << "Before alter";
auto alter_op = coll->alter(schema_changes);
ASSERT_TRUE(alter_op.ok());
LOG(INFO) << "After alter";
embedding_fields = coll->get_embedding_fields();
ASSERT_EQ(1, embedding_fields["embedding"].create_from.size());
ASSERT_EQ("category", embedding_fields["embedding"].create_from[0]);
}

View File

@ -4620,7 +4620,7 @@ TEST_F(CollectionTest, SemanticSearchTest) {
]
})"_json;
TextEmbedderManager::model_dir = "/tmp/typesense_test/models";
TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
TextEmbedderManager::download_default_model();
auto op = collectionManager.create_collection(schema);
@ -4655,7 +4655,7 @@ TEST_F(CollectionTest, InvalidSemanticSearch) {
]
})"_json;
TextEmbedderManager::model_dir = "/tmp/typesense_test/models";
TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
TextEmbedderManager::download_default_model();
auto op = collectionManager.create_collection(schema);
@ -4686,7 +4686,7 @@ TEST_F(CollectionTest, HybridSearch) {
]
})"_json;
TextEmbedderManager::model_dir = "/tmp/typesense_test/models";
TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
TextEmbedderManager::download_default_model();
auto op = collectionManager.create_collection(schema);
@ -4719,7 +4719,7 @@ TEST_F(CollectionTest, EmbedFielsTest) {
]
})"_json;
TextEmbedderManager::model_dir = "/tmp/typesense_test/models";
TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
TextEmbedderManager::download_default_model();
auto op = collectionManager.create_collection(schema);
@ -4747,7 +4747,7 @@ TEST_F(CollectionTest, HybridSearchRankFusionTest) {
]
})"_json;
TextEmbedderManager::model_dir = "/tmp/typesense_test/models";
TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
TextEmbedderManager::download_default_model();
auto op = collectionManager.create_collection(schema);
@ -4821,7 +4821,7 @@ TEST_F(CollectionTest, WildcardSearchWithEmbeddingField) {
]
})"_json;
TextEmbedderManager::model_dir = "/tmp/typesense_test/models";
TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
TextEmbedderManager::download_default_model();
auto op = collectionManager.create_collection(schema);
@ -4853,7 +4853,7 @@ TEST_F(CollectionTest, EmbeddingFieldsMapTest) {
]
})"_json;
TextEmbedderManager::model_dir = "/tmp/typesense_test/models";
TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
TextEmbedderManager::download_default_model();
auto op = collectionManager.create_collection(schema);
@ -4891,7 +4891,7 @@ TEST_F(CollectionTest, EmbedStringArrayField) {
]
})"_json;
TextEmbedderManager::model_dir = "/tmp/typesense_test/models";
TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
TextEmbedderManager::download_default_model();
auto op = collectionManager.create_collection(schema);
@ -4907,28 +4907,29 @@ TEST_F(CollectionTest, EmbedStringArrayField) {
ASSERT_TRUE(add_op.ok());
}
TEST_F(CollectionTest, UpdateSchemaWithNewEmbeddingField) {
TEST_F(CollectionTest, MissingFieldForEmbedding) {
nlohmann::json schema = R"({
"name": "objects",
"fields": [
{"name": "names", "type": "string[]"}
]
})"_json;
"name": "objects",
"fields": [
{"name": "names", "type": "string[]"},
{"name": "category", "type": "string"},
{"name": "embedding", "type":"float[]", "create_from": ["names", "category"]}
]
})"_json;
TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
TextEmbedderManager::download_default_model();
auto op = collectionManager.create_collection(schema);
ASSERT_TRUE(op.ok());
Collection* coll = op.get();
nlohmann::json update_schema = R"({
"fields": [
{"name": "embedding", "type":"float[]", "create_from": ["names"]}
]
})"_json;
auto res = coll->alter(update_schema);
ASSERT_FALSE(res.ok());
ASSERT_EQ("Embedding fields can only be added at the time of collection creation.", res.error());
}
nlohmann::json doc;
doc["names"].push_back("butter");
doc["names"].push_back("butterfly");
doc["names"].push_back("butterball");
auto add_op = coll->add(doc.dump());
ASSERT_FALSE(add_op.ok());
ASSERT_EQ("Field `category` is needed to create embedding.", add_op.error());
}

View File

@ -716,4 +716,64 @@ TEST_F(CollectionVectorTest, HybridSearchWithExplicitVector) {
ASSERT_EQ(1, results_op.get()["found"].get<size_t>());
ASSERT_EQ(1, results_op.get()["hits"].size());
}
TEST_F(CollectionVectorTest, DistanceThresholdTest) {
nlohmann::json schema = R"({
"name": "test",
"fields": [
{"name": "vec", "type": "float[]", "num_dim": 3}
]
})"_json;
Collection* coll1 = collectionManager.create_collection(schema).get();
nlohmann::json doc;
doc["vec"] = {0.1, 0.2, 0.3};
ASSERT_TRUE(coll1->add(doc.dump()).ok());
// write a vector which is 0.5 away from the first vector
doc["vec"] = {0.6, 0.7, 0.8};
ASSERT_TRUE(coll1->add(doc.dump()).ok());
auto results_op = coll1->search("*", {}, "", {}, {}, {0}, 20, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
"", 10, {}, {}, {}, 0,
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7,
fallback,
4, {off}, 32767, 32767, 2,
false, true, "vec:([0.3,0.4,0.5])");
ASSERT_EQ(true, results_op.ok());
ASSERT_EQ(2, results_op.get()["found"].get<size_t>());
ASSERT_EQ(2, results_op.get()["hits"].size());
ASSERT_FLOAT_EQ(0.6, results_op.get()["hits"][0]["document"]["vec"].get<std::vector<float>>()[0]);
ASSERT_FLOAT_EQ(0.7, results_op.get()["hits"][0]["document"]["vec"].get<std::vector<float>>()[1]);
ASSERT_FLOAT_EQ(0.8, results_op.get()["hits"][0]["document"]["vec"].get<std::vector<float>>()[2]);
ASSERT_FLOAT_EQ(0.1, results_op.get()["hits"][1]["document"]["vec"].get<std::vector<float>>()[0]);
ASSERT_FLOAT_EQ(0.2, results_op.get()["hits"][1]["document"]["vec"].get<std::vector<float>>()[1]);
ASSERT_FLOAT_EQ(0.3, results_op.get()["hits"][1]["document"]["vec"].get<std::vector<float>>()[2]);
results_op = coll1->search("*", {}, "", {}, {}, {0}, 20, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
"", 10, {}, {}, {}, 0,
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7,
fallback,
4, {off}, 32767, 32767, 2,
false, true, "vec:([0.3,0.4,0.5], distance_threshold:0.01)");
ASSERT_EQ(true, results_op.ok());
ASSERT_EQ(1, results_op.get()["found"].get<size_t>());
ASSERT_EQ(1, results_op.get()["hits"].size());
ASSERT_FLOAT_EQ(0.6, results_op.get()["hits"][0]["document"]["vec"].get<std::vector<float>>()[0]);
ASSERT_FLOAT_EQ(0.7, results_op.get()["hits"][0]["document"]["vec"].get<std::vector<float>>()[1]);
ASSERT_FLOAT_EQ(0.8, results_op.get()["hits"][0]["document"]["vec"].get<std::vector<float>>()[2]);
}