diff --git a/include/field.h b/include/field.h index d34d32ae..0f959b7d 100644 --- a/include/field.h +++ b/include/field.h @@ -424,10 +424,11 @@ struct field { std::string& fallback_field_type, std::vector& the_fields); - static Option validate_and_init_embed_fields(const std::vector>& embed_json_field_indices, - const tsl::htrie_map& search_schema, - nlohmann::json& fields_json, - std::vector& fields_vec); + static Option validate_and_init_embed_field(const tsl::htrie_map& search_schema, + nlohmann::json& field_json, + const nlohmann::json& fields_json, + field& the_field); + static bool flatten_obj(nlohmann::json& doc, nlohmann::json& value, bool has_array, bool has_obj_array, bool is_update, const field& the_field, const std::string& flat_name, diff --git a/src/collection.cpp b/src/collection.cpp index ba00b5ab..9140e8bb 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -4156,7 +4156,6 @@ Option Collection::validate_alter_payload(nlohmann::json& schema_changes, } std::unordered_map new_dynamic_fields; - std::vector> embed_json_field_indices; int json_array_index = -1; for(const auto& kv: schema_changes["fields"].items()) { @@ -4244,7 +4243,7 @@ Option Collection::validate_alter_payload(nlohmann::json& schema_changes, return parse_op; } - const auto& f = diff_fields.back(); + auto& f = diff_fields.back(); if(f.is_dynamic()) { new_dynamic_fields[f.name] = f; @@ -4252,6 +4251,14 @@ Option Collection::validate_alter_payload(nlohmann::json& schema_changes, updated_search_schema[f.name] = f; } + if(!f.embed.empty()) { + auto validate_res = field::validate_and_init_embed_field(search_schema, schema_changes["fields"][json_array_index], schema_changes["fields"], f); + + if(!validate_res.ok()) { + return validate_res; + } + } + if(is_reindex) { reindex_fields.push_back(f); } else { @@ -4286,9 +4293,7 @@ Option Collection::validate_alter_payload(nlohmann::json& schema_changes, } } - if(!f.embed.empty() && !diff_fields.empty()) { - embed_json_field_indices.emplace_back(json_array_index, diff_fields.size()-1); - } + } else { // partial update is not supported for now @@ -4298,12 +4303,6 @@ Option Collection::validate_alter_payload(nlohmann::json& schema_changes, } } - auto validation_op = field::validate_and_init_embed_fields(embed_json_field_indices, search_schema, - schema_changes["fields"], diff_fields); - if(!validation_op.ok()) { - return validation_op; - } - if(num_auto_detect_fields > 1) { return Option(400, "There can be only one field named `.*`."); } diff --git a/src/field.cpp b/src/field.cpp index 3c4d2e1c..d053e578 100644 --- a/src/field.cpp +++ b/src/field.cpp @@ -1083,7 +1083,7 @@ void field::compact_nested_fields(tsl::htrie_map& nested_fields) { Option field::json_fields_to_fields(bool enable_nested_fields, nlohmann::json &fields_json, string &fallback_field_type, std::vector& the_fields) { size_t num_auto_detect_fields = 0; - std::vector> embed_json_field_indices; + const tsl::htrie_map dummy_search_schema; for(size_t i = 0; i < fields_json.size(); i++) { nlohmann::json& field_json = fields_json[i]; @@ -1094,17 +1094,13 @@ Option field::json_fields_to_fields(bool enable_nested_fields, nlohmann::j } if(!the_fields.empty() && !the_fields.back().embed.empty()) { - embed_json_field_indices.emplace_back(i, the_fields.size()-1); + auto validate_res = validate_and_init_embed_field(dummy_search_schema, field_json, fields_json, the_fields.back()); + if(!validate_res.ok()) { + return validate_res; + } } } - const tsl::htrie_map dummy_search_schema; - auto validation_op = field::validate_and_init_embed_fields(embed_json_field_indices, dummy_search_schema, - fields_json, the_fields); - if(!validation_op.ok()) { - return validation_op; - } - if(num_auto_detect_fields > 1) { return Option(400,"There can be only one field named `.*`."); } @@ -1112,49 +1108,47 @@ Option field::json_fields_to_fields(bool enable_nested_fields, nlohmann::j return Option(true); } -Option field::validate_and_init_embed_fields(const std::vector>& embed_json_field_indices, - const tsl::htrie_map& search_schema, - nlohmann::json& fields_json, - std::vector& fields_vec) { - - for(const auto& json_field_index: embed_json_field_indices) { - auto& field_json = fields_json[json_field_index.first]; - const std::string err_msg = "Property `" + fields::embed + "." + fields::from + +Option field::validate_and_init_embed_field(const tsl::htrie_map& search_schema, nlohmann::json& field_json, + const nlohmann::json& fields_json, + field& the_field) { + const std::string err_msg = "Property `" + fields::embed + "." + fields::from + "` can only refer to string or string array fields."; - for(auto& field_name : field_json[fields::embed][fields::from].get>()) { - auto embed_field = std::find_if(fields_json.begin(), fields_json.end(), [&field_name](const nlohmann::json& x) { - return x["name"].get() == field_name; - }); + for(auto& field_name : field_json[fields::embed][fields::from].get>()) { - if(embed_field == fields_json.end()) { - const auto& embed_field2 = search_schema.find(field_name); - if (embed_field2 == search_schema.end()) { - return Option(400, err_msg); - } else if (embed_field2->type != field_types::STRING && embed_field2->type != field_types::STRING_ARRAY) { - return Option(400, err_msg); - } - } else if((*embed_field)[fields::type] != field_types::STRING && - (*embed_field)[fields::type] != field_types::STRING_ARRAY) { + auto embed_field = std::find_if(fields_json.begin(), fields_json.end(), [&field_name](const nlohmann::json& x) { + return x["name"].get() == field_name; + }); + + + if(embed_field == fields_json.end()) { + const auto& embed_field2 = search_schema.find(field_name); + if (embed_field2 == search_schema.end()) { + return Option(400, err_msg); + } else if (embed_field2->type != field_types::STRING && embed_field2->type != field_types::STRING_ARRAY) { return Option(400, err_msg); } + } else if((*embed_field)[fields::type] != field_types::STRING && + (*embed_field)[fields::type] != field_types::STRING_ARRAY) { + return Option(400, err_msg); } - - const auto& model_config = field_json[fields::embed][fields::model_config]; - size_t num_dim = 0; - auto res = TextEmbedderManager::get_instance().validate_and_init_model(model_config, num_dim); - if(!res.ok()) { - return Option(res.code(), res.error()); - } - - LOG(INFO) << "Model init done."; - field_json[fields::num_dim] = num_dim; - fields_vec[json_field_index.second].num_dim = num_dim; } + const auto& model_config = field_json[fields::embed][fields::model_config]; + size_t num_dim = 0; + auto res = TextEmbedderManager::get_instance().validate_and_init_model(model_config, num_dim); + if(!res.ok()) { + return Option(res.code(), res.error()); + } + + LOG(INFO) << "Model init done."; + field_json[fields::num_dim] = num_dim; + the_field.num_dim = num_dim; + return Option(true); } + void filter_result_t::and_filter_results(const filter_result_t& a, const filter_result_t& b, filter_result_t& result) { auto lenA = a.count, lenB = b.count; if (lenA == 0 || lenB == 0) { diff --git a/src/text_embedder.cpp b/src/text_embedder.cpp index 61244054..126655f9 100644 --- a/src/text_embedder.cpp +++ b/src/text_embedder.cpp @@ -117,6 +117,11 @@ embedding_res_t TextEmbedder::Embed(const std::string& text, const size_t remote input_shapes.push_back({1, static_cast(encoded_input.input_ids.size())}); input_shapes.push_back({1, static_cast(encoded_input.attention_mask.size())}); if(session_->GetInputCount() == 3) { + // edge case: xlm_roberta does not have token_type_ids, but if the model has it as input, we need to fill it with 0s + if(encoded_input.token_type_ids.size() == 0) { + encoded_input.token_type_ids.resize(encoded_input.input_ids.size(), 0); + } + input_shapes.push_back({1, static_cast(encoded_input.token_type_ids.size())}); } input_tensors.push_back(Ort::Value::CreateTensor(memory_info, encoded_input.input_ids.data(), encoded_input.input_ids.size(), input_shapes[0].data(), input_shapes[0].size())); diff --git a/test/collection_schema_change_test.cpp b/test/collection_schema_change_test.cpp index 0d8364fd..c936ce78 100644 --- a/test/collection_schema_change_test.cpp +++ b/test/collection_schema_change_test.cpp @@ -1566,6 +1566,14 @@ TEST_F(CollectionSchemaChangeTest, UpdateSchemaWithNewEmbeddingField) { ASSERT_TRUE(res.ok()); ASSERT_EQ(1, coll->get_embedding_fields().size()); + auto search_schema = coll->get_schema(); + + auto embedding_field_it = search_schema.find("embedding"); + ASSERT_TRUE(embedding_field_it != coll->get_schema().end()); + ASSERT_EQ("embedding", embedding_field_it.value().name); + ASSERT_EQ("float[]", embedding_field_it.value().type); + ASSERT_EQ(384, embedding_field_it.value().num_dim); + nlohmann::json doc; doc["names"] = {"hello", "world"}; auto add_op = coll->add(doc.dump()); diff --git a/test/collection_vector_search_test.cpp b/test/collection_vector_search_test.cpp index ddd5a16f..4b77bb95 100644 --- a/test/collection_vector_search_test.cpp +++ b/test/collection_vector_search_test.cpp @@ -1718,4 +1718,41 @@ TEST_F(CollectionVectorTest, TestDifferentOpenAIApiKeys) { ASSERT_NE(embedder_map.find("openai/text-embedding-ada-002:" + api_key1), embedder_map.end()); ASSERT_NE(embedder_map.find("openai/text-embedding-ada-002:" + api_key2), embedder_map.end()); ASSERT_EQ(embedder_map.find("openai/text-embedding-ada-002"), embedder_map.end()); +} + + +TEST_F(CollectionVectorTest, TestMultilingualE5) { + auto schema_json = + R"({ + "name": "TEST", + "fields": [ + {"name": "name", "type": "string"}, + {"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/multilingual-e5-small"}}} + ] + })"_json; + + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); + + auto collection_create_op = collectionManager.create_collection(schema_json); + + ASSERT_TRUE(collection_create_op.ok()); + auto coll1 = collection_create_op.get(); + + auto add_op = coll1->add(R"({ + "name": "john doe" + })"_json.dump()); + + auto hybrid_results = coll1->search("john", {"name", "embedding"}, + "", {}, {}, {2}, 10, + 1, FREQUENCY, {true}, + 0, spp::sparse_hash_set()); + + ASSERT_TRUE(hybrid_results.ok()); + + auto semantic_results = coll1->search("john", {"embedding"}, + "", {}, {}, {2}, 10, + 1, FREQUENCY, {true}, + 0, spp::sparse_hash_set()); + + ASSERT_TRUE(semantic_results.ok()); } \ No newline at end of file