Add remote_embedding_batch_size parameter for doc import

ozanarmagan 2023-07-08 13:11:45 +03:00
parent a769eeb0a7
commit ab0be22489
9 changed files with 55 additions and 31 deletions

View File

@ -385,7 +385,7 @@ public:
const std::vector<sort_by>& sort_by_fields);
void batch_index(std::vector<index_record>& index_records, std::vector<std::string>& json_out, size_t &num_indexed,
const bool& return_doc, const bool& return_id);
const bool& return_doc, const bool& return_id, const size_t remote_embedding_batch_size = 200);
bool is_exceeding_memory_threshold() const;
@ -398,7 +398,7 @@ public:
nlohmann::json get_summary_json() const;
size_t batch_index_in_memory(std::vector<index_record>& index_records, const bool generate_embeddings = true);
size_t batch_index_in_memory(std::vector<index_record>& index_records, const size_t remote_embedding_batch_size = 200, const bool generate_embeddings = true);
Option<nlohmann::json> add(const std::string & json_str,
const index_operation_t& operation=CREATE, const std::string& id="",
@ -407,7 +407,7 @@ public:
nlohmann::json add_many(std::vector<std::string>& json_lines, nlohmann::json& document,
const index_operation_t& operation=CREATE, const std::string& id="",
const DIRTY_VALUES& dirty_values=DIRTY_VALUES::COERCE_OR_REJECT,
const bool& return_doc=false, const bool& return_id=false);
const bool& return_doc=false, const bool& return_id=false, const size_t remote_embedding_batch_size=100);
Option<nlohmann::json> update_matching_filter(const std::string& filter_query,
const std::string & json_str,

View File

@ -545,7 +545,7 @@ private:
static void batch_embed_fields(std::vector<index_record*>& documents,
const tsl::htrie_map<char, field>& embedding_fields,
const tsl::htrie_map<char, field> & search_schema);
const tsl::htrie_map<char, field> & search_schema, const size_t remote_embedding_batch_size = 200);
public:
// for limiting number of results on multiple candidates / query rewrites
@ -678,7 +678,7 @@ public:
const std::string& fallback_field_type,
const std::vector<char>& token_separators,
const std::vector<char>& symbols_to_index,
const bool do_validation, const bool generate_embeddings = true);
const bool do_validation, const size_t remote_embedding_batch_size = 200, const bool generate_embeddings = true);
static size_t batch_memory_index(Index *index,
std::vector<index_record>& iter_batch,
@ -688,7 +688,7 @@ public:
const std::string& fallback_field_type,
const std::vector<char>& token_separators,
const std::vector<char>& symbols_to_index,
const bool do_validation, const bool generate_embeddings = true);
const bool do_validation, const size_t remote_embedding_batch_size = 200, const bool generate_embeddings = true);
void index_field_in_memory(const field& afield, std::vector<index_record>& iter_batch);

View File

@ -17,7 +17,7 @@ class TextEmbedder {
TextEmbedder(const nlohmann::json& model_config);
~TextEmbedder();
embedding_res_t Embed(const std::string& text);
std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs);
std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size = 0);
const std::string& get_vocab_file_name() const;
bool is_remote() {
return remote_embedder_ != nullptr;

View File

@ -29,7 +29,7 @@ class RemoteEmbedder {
static inline ReplicationState* raft_server = nullptr;
public:
virtual embedding_res_t Embed(const std::string& text) = 0;
virtual std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs) = 0;
virtual std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size = 200) = 0;
static void init(ReplicationState* rs) {
raft_server = rs;
}
@ -48,7 +48,7 @@ class OpenAIEmbedder : public RemoteEmbedder {
OpenAIEmbedder(const std::string& openai_model_path, const std::string& api_key);
static Option<bool> is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims);
embedding_res_t Embed(const std::string& text) override;
std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs) override;
std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size = 200) override;
};
@ -63,7 +63,7 @@ class GoogleEmbedder : public RemoteEmbedder {
GoogleEmbedder(const std::string& google_api_key);
static Option<bool> is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims);
embedding_res_t Embed(const std::string& text) override;
std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs) override;
std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size = 200) override;
};
@ -89,7 +89,7 @@ class GCPEmbedder : public RemoteEmbedder {
const std::string& refresh_token, const std::string& client_id, const std::string& client_secret);
static Option<bool> is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims);
embedding_res_t Embed(const std::string& text) override;
std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs) override;
std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size = 200) override;
};

View File

@ -295,7 +295,8 @@ Option<nlohmann::json> Collection::add(const std::string & json_str,
nlohmann::json Collection::add_many(std::vector<std::string>& json_lines, nlohmann::json& document,
const index_operation_t& operation, const std::string& id,
const DIRTY_VALUES& dirty_values, const bool& return_doc, const bool& return_id) {
const DIRTY_VALUES& dirty_values, const bool& return_doc, const bool& return_id,
const size_t remote_embedding_batch_size) {
//LOG(INFO) << "Memory ratio. Max = " << max_memory_ratio << ", Used = " << SystemMetrics::used_memory_ratio();
std::vector<index_record> index_records;
@ -385,7 +386,7 @@ nlohmann::json Collection::add_many(std::vector<std::string>& json_lines, nlohma
if((i+1) % index_batch_size == 0 || i == json_lines.size()-1 || repeated_doc) {
batch_index(index_records, json_lines, num_indexed, return_doc, return_id);
batch_index(index_records, json_lines, num_indexed, return_doc, return_id, remote_embedding_batch_size);
// to return the document for the single doc add cases
if(index_records.size() == 1) {
@ -502,9 +503,9 @@ bool Collection::is_exceeding_memory_threshold() const {
}
void Collection::batch_index(std::vector<index_record>& index_records, std::vector<std::string>& json_out,
size_t &num_indexed, const bool& return_doc, const bool& return_id) {
size_t &num_indexed, const bool& return_doc, const bool& return_id, const size_t remote_embedding_batch_size) {
batch_index_in_memory(index_records);
batch_index_in_memory(index_records, remote_embedding_batch_size);
// store only documents that were indexed in-memory successfully
for(auto& index_record: index_records) {
@ -608,11 +609,11 @@ Option<uint32_t> Collection::index_in_memory(nlohmann::json &document, uint32_t
return Option<>(200);
}
size_t Collection::batch_index_in_memory(std::vector<index_record>& index_records, const bool generate_embeddings) {
size_t Collection::batch_index_in_memory(std::vector<index_record>& index_records, const size_t remote_embedding_batch_size, const bool generate_embeddings) {
std::unique_lock lock(mutex);
size_t num_indexed = Index::batch_memory_index(index, index_records, default_sorting_field,
search_schema, embedding_fields, fallback_field_type,
token_separators, symbols_to_index, true, generate_embeddings);
token_separators, symbols_to_index, true, remote_embedding_batch_size, generate_embeddings);
num_documents += num_indexed;
return num_indexed;
}
@ -3807,7 +3808,7 @@ Option<bool> Collection::batch_alter_data(const std::vector<field>& alter_fields
}
Index::batch_memory_index(index, iter_batch, default_sorting_field, schema_additions, embedding_fields,
fallback_field_type, token_separators, symbols_to_index, true, false);
fallback_field_type, token_separators, symbols_to_index, true, 100, false);
iter_batch.clear();
}
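
For reference, a hypothetical caller of the updated Collection::add_many (the documents, the batch size of 64 and the collection pointer are illustrative and not part of this commit; the remaining arguments mirror the defaults declared in collection.h):

std::vector<std::string> json_lines = {
    R"({"title": "first doc"})",
    R"({"title": "second doc"})"
};
nlohmann::json document;
// Remote embedding calls are now issued in chunks of 64 texts per request;
// omitting the last argument keeps the declared default of 100.
nlohmann::json summary = collection->add_many(json_lines, document, CREATE, "",
                                              DIRTY_VALUES::COERCE_OR_REJECT,
                                              /*return_doc=*/false, /*return_id=*/false,
                                              /*remote_embedding_batch_size=*/64);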

View File

@ -745,11 +745,16 @@ bool post_import_documents(const std::shared_ptr<http_req>& req, const std::shar
const char *DIRTY_VALUES = "dirty_values";
const char *RETURN_DOC = "return_doc";
const char *RETURN_ID = "return_id";
const char *REMOTE_EMBEDDING_BATCH_SIZE = "remote_embedding_batch_size";
if(req->params.count(BATCH_SIZE) == 0) {
req->params[BATCH_SIZE] = "40";
}
if(req->params.count(REMOTE_EMBEDDING_BATCH_SIZE) == 0) {
req->params[REMOTE_EMBEDDING_BATCH_SIZE] = "200";
}
if(req->params.count(ACTION) == 0) {
req->params[ACTION] = "create";
}
@ -796,6 +801,7 @@ bool post_import_documents(const std::shared_ptr<http_req>& req, const std::shar
}
const size_t IMPORT_BATCH_SIZE = std::stoi(req->params[BATCH_SIZE]);
const size_t REMOTE_EMBEDDING_BATCH_SIZE_VAL = std::stoi(req->params[REMOTE_EMBEDDING_BATCH_SIZE]);
if(IMPORT_BATCH_SIZE == 0) {
res->final = true;
@ -804,6 +810,13 @@ bool post_import_documents(const std::shared_ptr<http_req>& req, const std::shar
return false;
}
if(REMOTE_EMBEDDING_BATCH_SIZE_VAL == 0) {
res->final = true;
res->set_400("Parameter `" + std::string(REMOTE_EMBEDDING_BATCH_SIZE) + "` must be a positive integer.");
stream_response(req, res);
return false;
}
if(req->body_index == 0) {
// will log for every major chunk of request body
//LOG(INFO) << "Import, req->body.size=" << req->body.size() << ", batch_size=" << IMPORT_BATCH_SIZE;
@ -873,7 +886,7 @@ bool post_import_documents(const std::shared_ptr<http_req>& req, const std::shar
const bool& return_doc = req->params[RETURN_DOC] == "true";
const bool& return_id = req->params[RETURN_ID] == "true";
nlohmann::json json_res = collection->add_many(json_lines, document, operation, "",
dirty_values, return_doc, return_id);
dirty_values, return_doc, return_id, REMOTE_EMBEDDING_BATCH_SIZE_VAL);
//const std::string& import_summary_json = json_res->dump();
//response_stream << import_summary_json << "\n";
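
A minimal request-level sketch of the same flow (collection name and value are illustrative; the route is the existing import endpoint served by post_import_documents above):

POST /collections/products/documents/import?action=create&remote_embedding_batch_size=64
{"title": "first doc"}
{"title": "second doc"}

Omitting the parameter falls back to the default of "200" set earlier in the handler, and a value of 0 is rejected with the 400 error added above.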

View File

@ -416,7 +416,7 @@ void Index::validate_and_preprocess(Index *index, std::vector<index_record>& ite
const std::string& fallback_field_type,
const std::vector<char>& token_separators,
const std::vector<char>& symbols_to_index,
const bool do_validation, const bool generate_embeddings) {
const bool do_validation, const size_t remote_embedding_batch_size, const bool generate_embeddings) {
// runs in a partitioned thread
std::vector<index_record*> records_to_embed;
@ -505,7 +505,7 @@ void Index::validate_and_preprocess(Index *index, std::vector<index_record>& ite
}
}
if(generate_embeddings) {
batch_embed_fields(records_to_embed, embedding_fields, search_schema);
batch_embed_fields(records_to_embed, embedding_fields, search_schema, remote_embedding_batch_size);
}
}
@ -516,7 +516,7 @@ size_t Index::batch_memory_index(Index *index, std::vector<index_record>& iter_b
const std::string& fallback_field_type,
const std::vector<char>& token_separators,
const std::vector<char>& symbols_to_index,
const bool do_validation, const bool generate_embeddings) {
const bool do_validation, const size_t remote_embedding_batch_size, const bool generate_embeddings) {
const size_t concurrency = 4;
const size_t num_threads = std::min(concurrency, iter_batch.size());
@ -546,7 +546,7 @@ size_t Index::batch_memory_index(Index *index, std::vector<index_record>& iter_b
index->thread_pool->enqueue([&, batch_index, batch_len]() {
write_log_index = local_write_log_index;
validate_and_preprocess(index, iter_batch, batch_index, batch_len, default_sorting_field, search_schema,
embedding_fields, fallback_field_type, token_separators, symbols_to_index, do_validation, generate_embeddings);
embedding_fields, fallback_field_type, token_separators, symbols_to_index, do_validation, remote_embedding_batch_size, generate_embeddings);
std::unique_lock<std::mutex> lock(m_process);
num_processed++;
@ -6468,7 +6468,7 @@ bool Index::common_results_exist(std::vector<art_leaf*>& leaves, bool must_match
void Index::batch_embed_fields(std::vector<index_record*>& records,
const tsl::htrie_map<char, field>& embedding_fields,
const tsl::htrie_map<char, field> & search_schema) {
const tsl::htrie_map<char, field> & search_schema, const size_t remote_embedding_batch_size) {
for(const auto& field : embedding_fields) {
std::vector<std::pair<index_record*, std::string>> texts_to_embed;
auto indexing_prefix = TextEmbedderManager::get_instance().get_indexing_prefix(field.embed[fields::model_config]);
@ -6529,7 +6529,7 @@ void Index::batch_embed_fields(std::vector<index_record*>& records,
texts.push_back(text_to_embed.second);
}
auto embeddings = embedder_op.get()->batch_embed(texts);
auto embeddings = embedder_op.get()->batch_embed(texts, remote_embedding_batch_size);
for(size_t i = 0; i < embeddings.size(); i++) {
auto& embedding_res = embeddings[i];

View File

@ -133,7 +133,7 @@ embedding_res_t TextEmbedder::Embed(const std::string& text) {
}
}
std::vector<embedding_res_t> TextEmbedder::batch_embed(const std::vector<std::string>& inputs) {
std::vector<embedding_res_t> TextEmbedder::batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size) {
std::vector<embedding_res_t> outputs;
if(!is_remote()) {
for(int i = 0; i < inputs.size(); i += 8) {
@ -215,7 +215,7 @@ std::vector<embedding_res_t> TextEmbedder::batch_embed(const std::vector<std::st
}
}
} else {
outputs = std::move(remote_embedder_->batch_embed(inputs));
outputs = std::move(remote_embedder_->batch_embed(inputs, remote_embedding_batch_size));
}
return outputs;

View File

@ -182,7 +182,17 @@ embedding_res_t OpenAIEmbedder::Embed(const std::string& text) {
}
}
std::vector<embedding_res_t> OpenAIEmbedder::batch_embed(const std::vector<std::string>& inputs) {
std::vector<embedding_res_t> OpenAIEmbedder::batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size) {
// call recursively if inputs larger than remote_embedding_batch_size
if(inputs.size() > remote_embedding_batch_size) {
std::vector<embedding_res_t> outputs;
for(size_t i = 0; i < inputs.size(); i += remote_embedding_batch_size) {
auto batch = std::vector<std::string>(inputs.begin() + i, inputs.begin() + std::min(i + remote_embedding_batch_size, inputs.size()));
auto batch_outputs = batch_embed(batch, remote_embedding_batch_size);
outputs.insert(outputs.end(), batch_outputs.begin(), batch_outputs.end());
}
return outputs;
}
nlohmann::json req_body;
req_body["input"] = inputs;
// remove "openai/" prefix
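
To make the splitting above concrete, a small standalone sketch of the same arithmetic (values are made up for illustration; this is not part of the commit):

#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
    // 450 inputs with remote_embedding_batch_size = 200 yield sub-batches of
    // 200, 200 and 50 texts; each sub-batch then takes the single-request path.
    const std::size_t total = 450, batch = 200;
    for (std::size_t i = 0; i < total; i += batch) {
        const std::size_t len = std::min(i + batch, total) - i;
        std::printf("sub-batch starting at %zu has %zu texts\n", i, len);
    }
}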
@ -326,7 +336,7 @@ embedding_res_t GoogleEmbedder::Embed(const std::string& text) {
nlohmann::json json_res = nlohmann::json::parse(res);
} catch (const std::exception& e) {
json_res = nlohmann::json::object();
json_res["error"] = "Malformed response from Google API."
json_res["error"] = "Malformed response from Google API.";
}
nlohmann::json embedding_res = nlohmann::json::object();
embedding_res["response"] = json_res;
@ -356,7 +366,7 @@ embedding_res_t GoogleEmbedder::Embed(const std::string& text) {
}
std::vector<embedding_res_t> GoogleEmbedder::batch_embed(const std::vector<std::string>& inputs) {
std::vector<embedding_res_t> GoogleEmbedder::batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size) {
std::vector<embedding_res_t> outputs;
for(auto& input : inputs) {
auto res = Embed(input);
@ -513,7 +523,7 @@ embedding_res_t GCPEmbedder::Embed(const std::string& text) {
}
std::vector<embedding_res_t> GCPEmbedder::batch_embed(const std::vector<std::string>& inputs) {
std::vector<embedding_res_t> GCPEmbedder::batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size) {
// GCP API has a limit of 5 instances per request
if(inputs.size() > 5) {
std::vector<embedding_res_t> res;