Support embedding generation from both image and text fields (#1826)

* Support embedding generation from both image and text fields

* Fix tests

* Fix test query

* Remove unnecessary logs

* Remove commented code

* Refactoring

* Refactor batch_embed_fields

* Refactor batch_embed_fields
This commit is contained in:
Ozan Armağan 2024-07-10 14:13:03 +03:00 committed by GitHub
parent 74e09ed8f5
commit 0b75cdcc3d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 142 additions and 45 deletions

View File

@ -565,6 +565,10 @@ private:
const tsl::htrie_map<char, field> & search_schema, const size_t remote_embedding_batch_size = 200,
const size_t remote_embedding_timeout_ms = 60000, const size_t remote_embedding_num_tries = 2);
// Folds the embedding results that belong to `record` into that record's document
// under `the_field.name`. `values_to_embed[i]` is paired with `embedding_results[i]`;
// only entries whose record pointer equals `record` are considered.
// `count` is incremented once for every embedding written or accumulated, so the
// caller (batch_embed_fields) can average the accumulated vector when count > 1.
// A failed result is recorded on the record (embedding_res + index_failure) instead.
static void process_embed_results(std::vector<std::pair<index_record*, std::string>>& values_to_embed,
const index_record* record,
const std::vector<embedding_res_t>& embedding_results,
size_t& count, const field& the_field);
public:
// for limiting number of results on multiple candidates / query rewrites
enum {TYPO_TOKENS_THRESHOLD = 1};

View File

@ -741,27 +741,10 @@ Option<bool> field::validate_and_init_embed_field(const tsl::htrie_map<char, fie
} else if (embed_field2->type != field_types::STRING && embed_field2->type != field_types::STRING_ARRAY && embed_field2->type != field_types::IMAGE) {
return Option<bool>(400, err_msg);
}
if(embed_field2->type == field_types::IMAGE) {
if(found_image_field) {
return Option<bool>(400, "Only one field can be used in the `embed.from` property of an embed field when embedding from an image field.");
}
if(field_json[fields::embed][fields::from].get<std::vector<std::string>>().size() > 1) {
return Option<bool>(400, "Only one field can be used in the `embed.from` property of an embed field when embedding from an image field.");
}
found_image_field = true;
}
} else if((*embed_field)[fields::type] != field_types::STRING &&
(*embed_field)[fields::type] != field_types::STRING_ARRAY &&
(*embed_field)[fields::type] != field_types::IMAGE) {
return Option<bool>(400, err_msg);
} else if((*embed_field)[fields::type] == field_types::IMAGE) {
if(found_image_field) {
return Option<bool>(400, "Only one field can be used in the `embed.from` property of an embed field when embedding from an image field.");
}
if(field_json[fields::embed][fields::from].get<std::vector<std::string>>().size() > 1) {
return Option<bool>(400, "Only one field can be used in the `embed.from` property of an embed field when embedding from an image field.");
}
found_image_field = true;
}
}

View File

@ -7419,8 +7419,7 @@ void Index::batch_embed_fields(std::vector<index_record*>& records,
const tsl::htrie_map<char, field> & search_schema, const size_t remote_embedding_batch_size,
const size_t remote_embedding_timeout_ms, const size_t remote_embedding_num_tries) {
for(const auto& field : embedding_fields) {
std::vector<std::pair<index_record*, std::string>> values_to_embed;
bool is_image_embedding = false;
std::vector<std::pair<index_record*, std::string>> values_to_embed_text, values_to_embed_image;
auto indexing_prefix = EmbedderManager::get_instance().get_indexing_prefix(field.embed[fields::model_config]);
for(auto& record : records) {
if(!record->indexed.ok()) {
@ -7451,8 +7450,7 @@ void Index::batch_embed_fields(std::vector<index_record*>& records,
continue;
}
if(field_it.value().type == field_types::IMAGE) {
is_image_embedding = true;
value = doc_field_it->get<std::string>();
values_to_embed_image.push_back(std::make_pair(record, doc_field_it->get<std::string>()));
continue;
}
if(field_it.value().type == field_types::STRING) {
@ -7464,19 +7462,19 @@ void Index::batch_embed_fields(std::vector<index_record*>& records,
}
}
if(value != indexing_prefix) {
values_to_embed.push_back(std::make_pair(record, value));
values_to_embed_text.push_back(std::make_pair(record, value));
}
}
if(values_to_embed.empty()) {
if(values_to_embed_text.empty() && values_to_embed_image.empty()) {
continue;
}
std::vector<embedding_res_t> embeddings;
std::vector<embedding_res_t> embeddings_text, embeddings_image;
// sort texts by length
if(!is_image_embedding) {
std::sort(values_to_embed.begin(), values_to_embed.end(),
if(!values_to_embed_text.empty()) {
std::sort(values_to_embed_text.begin(), values_to_embed_text.end(),
[](const std::pair<index_record*, std::string>& a,
const std::pair<index_record*, std::string>& b) {
return a.second.size() < b.second.size();
@ -7484,13 +7482,21 @@ void Index::batch_embed_fields(std::vector<index_record*>& records,
}
// get vector of values
std::vector<std::string> values;
for(const auto& value_to_embed : values_to_embed) {
values.push_back(value_to_embed.second);
std::vector<std::string> values_text, values_image;
std::unordered_set<index_record*> records_to_index;
for(const auto& value_to_embed : values_to_embed_text) {
values_text.push_back(value_to_embed.second);
records_to_index.insert(value_to_embed.first);
}
for(const auto& value_to_embed : values_to_embed_image) {
values_image.push_back(value_to_embed.second);
records_to_index.insert(value_to_embed.first);
}
EmbedderManager& embedder_manager = EmbedderManager::get_instance();
if(is_image_embedding) {
if(!values_image.empty()) {
auto embedder_op = embedder_manager.get_image_embedder(field.embed[fields::model_config]);
if(!embedder_op.ok()) {
const std::string& error_msg = "Could not find image embedder for model: " + field.embed[fields::model_config][fields::model_name].get<std::string>();
@ -7500,33 +7506,80 @@ void Index::batch_embed_fields(std::vector<index_record*>& records,
LOG(ERROR) << "Error: " << error_msg;
return;
}
embeddings = embedder_op.get()->batch_embed(values);
} else {
embeddings_image = embedder_op.get()->batch_embed(values_image);
}
if(!values_text.empty()) {
auto embedder_op = embedder_manager.get_text_embedder(field.embed[fields::model_config]);
if(!embedder_op.ok()) {
LOG(ERROR) << "Error while getting embedder for model: " << field.embed[fields::model_config];
LOG(ERROR) << "Error: " << embedder_op.error();
return;
}
embeddings = embedder_op.get()->batch_embed(values, remote_embedding_batch_size, remote_embedding_timeout_ms,
embeddings_text = embedder_op.get()->batch_embed(values_text, remote_embedding_batch_size, remote_embedding_timeout_ms,
remote_embedding_num_tries);
}
for(size_t i = 0; i < embeddings.size(); i++) {
auto& embedding_res = embeddings[i];
if(!embedding_res.success) {
values_to_embed[i].first->embedding_res = embedding_res.error;
values_to_embed[i].first->index_failure(embedding_res.status_code, "");
continue;
for(auto& record: records_to_index) {
size_t count = 0;
if(!values_to_embed_text.empty()) {
process_embed_results(values_to_embed_text, record, embeddings_text, count, field);
}
if(!values_to_embed_image.empty()) {
process_embed_results(values_to_embed_image, record, embeddings_image, count, field);
}
if(count > 1) {
auto& doc = record->is_update ? record->new_doc : record->doc;
std::vector<float> existing_embedding = doc[field.name].get<std::vector<float>>();
// average embeddings
for(size_t i = 0; i < existing_embedding.size(); i++) {
existing_embedding[i] /= count;
}
doc[field.name] = existing_embedding;
}
if(values_to_embed[i].first->is_update) {
values_to_embed[i].first->new_doc[field.name] = embedding_res.embedding;
}
values_to_embed[i].first->doc[field.name] = embedding_res.embedding;
}
}
}
/**
 * Merges the embedding results that belong to `record` into the record's document.
 *
 * `values_to_embed[i]` and `embedding_results[i]` are treated as a pair. For every
 * pair whose record pointer equals `record`:
 *   - if the record already carries an embedding error, the pair is skipped;
 *   - if the result failed, the error is stored on the record and the record is
 *     marked as failed via index_failure();
 *   - otherwise the embedding is written to (first success) or summed into (later
 *     successes) `doc[the_field.name]`, and `count` is incremented.
 *
 * The caller is expected to start `count` at 0 for each record and to divide the
 * stored vector by `count` afterwards when more than one embedding was summed.
 *
 * @param values_to_embed    (record, input value) pairs that were sent to the embedder.
 * @param record             the record whose results should be merged.
 * @param embedding_results  embedder output, positionally aligned with values_to_embed.
 * @param count              in/out: number of embeddings merged so far for `record`.
 * @param the_field          the embedding field being populated.
 */
void Index::process_embed_results(std::vector<std::pair<index_record*, std::string>>& values_to_embed,
                                  const index_record* record,
                                  const std::vector<embedding_res_t>& embedding_results,
                                  size_t& count, const field& the_field) {
    // Never index past the results the embedder actually returned.
    const size_t num_pairs = values_to_embed.size() < embedding_results.size()
                             ? values_to_embed.size() : embedding_results.size();

    for(size_t i = 0; i < num_pairs; i++) {
        index_record* target = values_to_embed[i].first;
        if(target != record) {
            continue;
        }

        // A previous result for this record already failed; keep the first error.
        if(!target->embedding_res.empty()) {
            continue;
        }

        const embedding_res_t& result = embedding_results[i];
        if(!result.success) {
            target->embedding_res = result.error;
            target->index_failure(result.status_code, "");
            continue;
        }

        auto& doc = target->is_update ? target->new_doc : target->doc;
        std::vector<float> merged;

        if(count == 0 || doc.count(the_field.name) == 0) {
            // First successful embedding for this record in this pass: whatever is
            // already stored under the field was not produced by this pass, so it
            // must be replaced, not summed into (the caller only averages by `count`).
            merged = result.embedding;
        } else {
            merged = doc[the_field.name].get<std::vector<float>>();
            // accumulate embeddings; bound by the shorter vector to avoid reading
            // past the end of the new embedding
            const size_t dims = merged.size() < result.embedding.size()
                                ? merged.size() : result.embedding.size();
            for(size_t j = 0; j < dims; j++) {
                merged[j] += result.embedding[j];
            }
        }

        doc[the_field.name] = std::move(merged);
        count++;
    }
}
void Index::repair_hnsw_index() {
std::vector<std::string> vector_fields;

File diff suppressed because one or more lines are too long