Merge pull request #971 from ozanarmagan/v0.25-join

Fix for text embedding prefixes & adding text embedding fields with schema update
This commit is contained in:
Kishore Nallan 2023-04-12 08:13:41 +05:30 committed by GitHub
commit 412df009ca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 20 additions and 6 deletions

View File

@@ -1186,7 +1186,9 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
 TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance();
 auto embedder = embedder_manager.get_text_embedder(search_field.model_name.size() > 0 ? search_field.model_name : TextEmbedderManager::DEFAULT_MODEL_NAME);
-std::vector<float> embedding = embedder->Embed(raw_query);
+std::string embed_query = "query: " + raw_query;
+std::vector<float> embedding = embedder->Embed(embed_query);
 vector_query._reset();
 vector_query.values = embedding;
 vector_query.field_name = field_name;
@@ -4129,7 +4131,7 @@ Option<bool> Collection::validate_alter_payload(nlohmann::json& schema_changes,
 }
 if(!f.embed_from.empty()) {
-return Option<bool>(400, "Embedding fields can only be added at the time of collection creation.");
+embedding_fields.emplace(f.name, f);
 }
 if(f.nested && enable_nested_fields) {
@@ -4144,7 +4146,7 @@ Option<bool> Collection::validate_alter_payload(nlohmann::json& schema_changes,
 updated_nested_fields.emplace(prefix_kv.key(), prefix_kv.value());
 if(!prefix_kv.value().embed_from.empty()) {
-return Option<bool>(400, "Embedding fields can only be added at the time of collection creation.");
+embedding_fields.emplace(prefix_kv.key(), prefix_kv.value());
 }
 if(is_reindex) {

View File

@@ -6337,7 +6337,7 @@ Option<bool> Index::embed_fields(nlohmann::json& document,
 const tsl::htrie_map<char, field>& embedding_fields,
 const tsl::htrie_map<char, field> & search_schema) {
 for(const auto& field : embedding_fields) {
-std::string text_to_embed;
+std::string text_to_embed = "passage: ";
 for(const auto& field_name : field.embed_from) {
 auto field_it = search_schema.find(field_name);
 if(field_it.value().type == field_types::STRING) {

View File

@@ -1455,10 +1455,13 @@ TEST_F(CollectionSchemaChangeTest, UpdateSchemaWithNewEmbeddingField) {
 ]
 })"_json;
+TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
+TextEmbedderManager::download_default_model();
 auto op = collectionManager.create_collection(schema);
 ASSERT_TRUE(op.ok());
 Collection* coll = op.get();
 nlohmann::json update_schema = R"({
 "fields": [
@@ -1468,8 +1471,17 @@ TEST_F(CollectionSchemaChangeTest, UpdateSchemaWithNewEmbeddingField) {
 auto res = coll->alter(update_schema);
-ASSERT_FALSE(res.ok());
-ASSERT_EQ("Embedding fields can only be added at the time of collection creation.", res.error());
+ASSERT_TRUE(res.ok());
+ASSERT_EQ(1, coll->get_embedding_fields().size());
+nlohmann::json doc;
+doc["names"] = {"hello", "world"};
+auto add_op = coll->add(doc.dump());
+ASSERT_TRUE(add_op.ok());
+auto added_doc = add_op.get();
+ASSERT_EQ(384, added_doc["embedding"].get<std::vector<float>>().size());
 }
 TEST_F(CollectionSchemaChangeTest, DropFieldUsedForEmbedding) {