mirror of
https://github.com/typesense/typesense.git
synced 2025-05-20 13:42:26 +08:00
Support embedding generation from both image and text fields (#1826)
* Support embedding generation from both image and text fields * Fix tests * Fix test query * Remove unnecessary logs * Remove commented code * Refactoring * Refactor batch_embed_fields * Refactor batch_embed_fields
This commit is contained in:
parent
74e09ed8f5
commit
0b75cdcc3d
@ -565,6 +565,10 @@ private:
|
||||
const tsl::htrie_map<char, field> & search_schema, const size_t remote_embedding_batch_size = 200,
|
||||
const size_t remote_embedding_timeout_ms = 60000, const size_t remote_embedding_num_tries = 2);
|
||||
|
||||
// Applies the entries of `embedding_results` that belong to `record` to that record's document.
// Entry i of `values_to_embed` is paired with entry i of `embedding_results`.
//  - values_to_embed:   (record, value) pairs that were sent for embedding.
//  - record:            the record whose entries should be processed; other entries are ignored.
//  - embedding_results: results of the embedding call.
//  - count:             in/out; incremented for every successful result written into the document.
//  - the_field:         embedding field whose name is used as the document key.
// A failed result is stored in the record's `embedding_res` and reported through index_failure().
static void process_embed_results(std::vector<std::pair<index_record*, std::string>>& values_to_embed,
|
||||
const index_record* record,
|
||||
const std::vector<embedding_res_t>& embedding_results,
|
||||
size_t& count, const field& the_field);
|
||||
public:
|
||||
// for limiting number of results on multiple candidates / query rewrites
|
||||
enum {TYPO_TOKENS_THRESHOLD = 1};
|
||||
|
@ -741,27 +741,10 @@ Option<bool> field::validate_and_init_embed_field(const tsl::htrie_map<char, fie
|
||||
} else if (embed_field2->type != field_types::STRING && embed_field2->type != field_types::STRING_ARRAY && embed_field2->type != field_types::IMAGE) {
|
||||
return Option<bool>(400, err_msg);
|
||||
}
|
||||
if(embed_field2->type == field_types::IMAGE) {
|
||||
if(found_image_field) {
|
||||
return Option<bool>(400, "Only one field can be used in the `embed.from` property of an embed field when embedding from an image field.");
|
||||
}
|
||||
if(field_json[fields::embed][fields::from].get<std::vector<std::string>>().size() > 1) {
|
||||
return Option<bool>(400, "Only one field can be used in the `embed.from` property of an embed field when embedding from an image field.");
|
||||
}
|
||||
found_image_field = true;
|
||||
}
|
||||
} else if((*embed_field)[fields::type] != field_types::STRING &&
|
||||
(*embed_field)[fields::type] != field_types::STRING_ARRAY &&
|
||||
(*embed_field)[fields::type] != field_types::IMAGE) {
|
||||
return Option<bool>(400, err_msg);
|
||||
} else if((*embed_field)[fields::type] == field_types::IMAGE) {
|
||||
if(found_image_field) {
|
||||
return Option<bool>(400, "Only one field can be used in the `embed.from` property of an embed field when embedding from an image field.");
|
||||
}
|
||||
if(field_json[fields::embed][fields::from].get<std::vector<std::string>>().size() > 1) {
|
||||
return Option<bool>(400, "Only one field can be used in the `embed.from` property of an embed field when embedding from an image field.");
|
||||
}
|
||||
found_image_field = true;
|
||||
}
|
||||
}
|
||||
|
||||
|
105
src/index.cpp
105
src/index.cpp
@ -7419,8 +7419,7 @@ void Index::batch_embed_fields(std::vector<index_record*>& records,
|
||||
const tsl::htrie_map<char, field> & search_schema, const size_t remote_embedding_batch_size,
|
||||
const size_t remote_embedding_timeout_ms, const size_t remote_embedding_num_tries) {
|
||||
for(const auto& field : embedding_fields) {
|
||||
std::vector<std::pair<index_record*, std::string>> values_to_embed;
|
||||
bool is_image_embedding = false;
|
||||
std::vector<std::pair<index_record*, std::string>> values_to_embed_text, values_to_embed_image;
|
||||
auto indexing_prefix = EmbedderManager::get_instance().get_indexing_prefix(field.embed[fields::model_config]);
|
||||
for(auto& record : records) {
|
||||
if(!record->indexed.ok()) {
|
||||
@ -7451,8 +7450,7 @@ void Index::batch_embed_fields(std::vector<index_record*>& records,
|
||||
continue;
|
||||
}
|
||||
if(field_it.value().type == field_types::IMAGE) {
|
||||
is_image_embedding = true;
|
||||
value = doc_field_it->get<std::string>();
|
||||
values_to_embed_image.push_back(std::make_pair(record, doc_field_it->get<std::string>()));
|
||||
continue;
|
||||
}
|
||||
if(field_it.value().type == field_types::STRING) {
|
||||
@ -7464,19 +7462,19 @@ void Index::batch_embed_fields(std::vector<index_record*>& records,
|
||||
}
|
||||
}
|
||||
if(value != indexing_prefix) {
|
||||
values_to_embed.push_back(std::make_pair(record, value));
|
||||
values_to_embed_text.push_back(std::make_pair(record, value));
|
||||
}
|
||||
}
|
||||
|
||||
if(values_to_embed.empty()) {
|
||||
if(values_to_embed_text.empty() && values_to_embed_image.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::vector<embedding_res_t> embeddings;
|
||||
std::vector<embedding_res_t> embeddings_text, embeddings_image;
|
||||
|
||||
// sort texts by length
|
||||
if(!is_image_embedding) {
|
||||
std::sort(values_to_embed.begin(), values_to_embed.end(),
|
||||
if(!values_to_embed_text.empty()) {
|
||||
std::sort(values_to_embed_text.begin(), values_to_embed_text.end(),
|
||||
[](const std::pair<index_record*, std::string>& a,
|
||||
const std::pair<index_record*, std::string>& b) {
|
||||
return a.second.size() < b.second.size();
|
||||
@ -7484,13 +7482,21 @@ void Index::batch_embed_fields(std::vector<index_record*>& records,
|
||||
}
|
||||
|
||||
// get vector of values
|
||||
std::vector<std::string> values;
|
||||
for(const auto& value_to_embed : values_to_embed) {
|
||||
values.push_back(value_to_embed.second);
|
||||
std::vector<std::string> values_text, values_image;
|
||||
|
||||
std::unordered_set<index_record*> records_to_index;
|
||||
for(const auto& value_to_embed : values_to_embed_text) {
|
||||
values_text.push_back(value_to_embed.second);
|
||||
records_to_index.insert(value_to_embed.first);
|
||||
}
|
||||
|
||||
for(const auto& value_to_embed : values_to_embed_image) {
|
||||
values_image.push_back(value_to_embed.second);
|
||||
records_to_index.insert(value_to_embed.first);
|
||||
}
|
||||
|
||||
EmbedderManager& embedder_manager = EmbedderManager::get_instance();
|
||||
if(is_image_embedding) {
|
||||
if(!values_image.empty()) {
|
||||
auto embedder_op = embedder_manager.get_image_embedder(field.embed[fields::model_config]);
|
||||
if(!embedder_op.ok()) {
|
||||
const std::string& error_msg = "Could not find image embedder for model: " + field.embed[fields::model_config][fields::model_name].get<std::string>();
|
||||
@ -7500,33 +7506,80 @@ void Index::batch_embed_fields(std::vector<index_record*>& records,
|
||||
LOG(ERROR) << "Error: " << error_msg;
|
||||
return;
|
||||
}
|
||||
embeddings = embedder_op.get()->batch_embed(values);
|
||||
} else {
|
||||
embeddings_image = embedder_op.get()->batch_embed(values_image);
|
||||
}
|
||||
|
||||
if(!values_text.empty()) {
|
||||
auto embedder_op = embedder_manager.get_text_embedder(field.embed[fields::model_config]);
|
||||
if(!embedder_op.ok()) {
|
||||
LOG(ERROR) << "Error while getting embedder for model: " << field.embed[fields::model_config];
|
||||
LOG(ERROR) << "Error: " << embedder_op.error();
|
||||
return;
|
||||
}
|
||||
embeddings = embedder_op.get()->batch_embed(values, remote_embedding_batch_size, remote_embedding_timeout_ms,
|
||||
embeddings_text = embedder_op.get()->batch_embed(values_text, remote_embedding_batch_size, remote_embedding_timeout_ms,
|
||||
remote_embedding_num_tries);
|
||||
}
|
||||
|
||||
for(size_t i = 0; i < embeddings.size(); i++) {
|
||||
auto& embedding_res = embeddings[i];
|
||||
if(!embedding_res.success) {
|
||||
values_to_embed[i].first->embedding_res = embedding_res.error;
|
||||
values_to_embed[i].first->index_failure(embedding_res.status_code, "");
|
||||
continue;
|
||||
for(auto& record: records_to_index) {
|
||||
size_t count = 0;
|
||||
if(!values_to_embed_text.empty()) {
|
||||
process_embed_results(values_to_embed_text, record, embeddings_text, count, field);
|
||||
}
|
||||
|
||||
if(!values_to_embed_image.empty()) {
|
||||
process_embed_results(values_to_embed_image, record, embeddings_image, count, field);
|
||||
}
|
||||
|
||||
if(count > 1) {
|
||||
auto& doc = record->is_update ? record->new_doc : record->doc;
|
||||
std::vector<float> existing_embedding = doc[field.name].get<std::vector<float>>();
|
||||
// average embeddings
|
||||
for(size_t i = 0; i < existing_embedding.size(); i++) {
|
||||
existing_embedding[i] /= count;
|
||||
}
|
||||
doc[field.name] = existing_embedding;
|
||||
}
|
||||
if(values_to_embed[i].first->is_update) {
|
||||
values_to_embed[i].first->new_doc[field.name] = embedding_res.embedding;
|
||||
}
|
||||
values_to_embed[i].first->doc[field.name] = embedding_res.embedding;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Folds the embedding results that belong to `record` into that record's document.
 *
 * Entry i of `values_to_embed` is paired with entry i of `embedding_results`. For each
 * pair whose record pointer equals `record`:
 *   - nothing is done if the record's `embedding_res` is already non-empty;
 *   - a failed result is copied into `embedding_res` and reported via index_failure();
 *   - a successful result is stored under `the_field.name` in the record's document
 *     (`new_doc` when `is_update` is set, `doc` otherwise). If the document already
 *     holds a vector for that field, the new embedding is added to it element-wise.
 *
 * @param values_to_embed    (record, value) pairs that were submitted for embedding.
 * @param record             the record to process; entries of other records are skipped.
 * @param embedding_results  embedding outcomes, index-aligned with `values_to_embed`.
 * @param count              in/out counter, incremented once per successful result applied.
 *                           batch_embed_fields divides the accumulated vector by this value
 *                           when it is greater than one.
 * @param the_field          embedding field whose name is the document key.
 */
void Index::process_embed_results(std::vector<std::pair<index_record*, std::string>>& values_to_embed,
                                  const index_record* record,
                                  const std::vector<embedding_res_t>& embedding_results,
                                  size_t& count, const field& the_field) {
    // Never index past the end of either container, even if the embedder
    // returned fewer results than the number of values submitted.
    const size_t num_pairs = values_to_embed.size() < embedding_results.size()
                             ? values_to_embed.size()
                             : embedding_results.size();

    for(size_t i = 0; i < num_pairs; i++) {
        index_record* const target = values_to_embed[i].first;
        if(target != record) {
            continue;
        }

        // `embedding_res` is populated below when a result fails; once it is
        // non-empty no further results are applied to this record.
        if(!target->embedding_res.empty()) {
            continue;
        }

        const embedding_res_t& result = embedding_results[i];
        if(!result.success) {
            target->embedding_res = result.error;
            target->index_failure(result.status_code, "");
            continue;
        }

        auto& doc = target->is_update ? target->new_doc : target->doc;

        std::vector<float> embedding_vals;
        if(doc.count(the_field.name) == 0) {
            embedding_vals = result.embedding;
        } else {
            embedding_vals = doc[the_field.name].get<std::vector<float>>();

            // accumulate embeddings; only the overlapping dimensions are summed so that a
            // stored vector longer than the new embedding cannot cause an out-of-bounds read
            const size_t dims = embedding_vals.size() < result.embedding.size()
                                ? embedding_vals.size()
                                : result.embedding.size();
            for(size_t j = 0; j < dims; j++) {
                embedding_vals[j] += result.embedding[j];
            }
        }

        doc[the_field.name] = embedding_vals;
        count++;
    }
}
|
||||
|
||||
|
||||
void Index::repair_hnsw_index() {
|
||||
std::vector<std::string> vector_fields;
|
||||
|
||||
|
File diff suppressed because one or more lines are too long
Loading…
x
Reference in New Issue
Block a user