Mirror of https://github.com/typesense/typesense.git
Add remote_embedding_batch_size parameter for doc import
parent a769eeb0a7
commit ab0be22489
@@ -385,7 +385,7 @@ public:
                      const std::vector<sort_by>& sort_by_fields);

     void batch_index(std::vector<index_record>& index_records, std::vector<std::string>& json_out, size_t &num_indexed,
-                     const bool& return_doc, const bool& return_id);
+                     const bool& return_doc, const bool& return_id, const size_t remote_embedding_batch_size = 200);

     bool is_exceeding_memory_threshold() const;

@@ -398,7 +398,7 @@ public:

     nlohmann::json get_summary_json() const;

-    size_t batch_index_in_memory(std::vector<index_record>& index_records, const bool generate_embeddings = true);
+    size_t batch_index_in_memory(std::vector<index_record>& index_records, const size_t remote_embedding_batch_size = 200, const bool generate_embeddings = true);

     Option<nlohmann::json> add(const std::string & json_str,
                                const index_operation_t& operation=CREATE, const std::string& id="",

@@ -407,7 +407,7 @@ public:
     nlohmann::json add_many(std::vector<std::string>& json_lines, nlohmann::json& document,
                             const index_operation_t& operation=CREATE, const std::string& id="",
                             const DIRTY_VALUES& dirty_values=DIRTY_VALUES::COERCE_OR_REJECT,
-                            const bool& return_doc=false, const bool& return_id=false);
+                            const bool& return_doc=false, const bool& return_id=false, const size_t remote_embedding_batch_size=100);

     Option<nlohmann::json> update_matching_filter(const std::string& filter_query,
                                                   const std::string & json_str,

@@ -545,7 +545,7 @@ private:

     static void batch_embed_fields(std::vector<index_record*>& documents,
                                    const tsl::htrie_map<char, field>& embedding_fields,
-                                   const tsl::htrie_map<char, field> & search_schema);
+                                   const tsl::htrie_map<char, field> & search_schema, const size_t remote_embedding_batch_size = 200);

 public:
     // for limiting number of results on multiple candidates / query rewrites
@@ -678,7 +678,7 @@ public:
                                         const std::string& fallback_field_type,
                                         const std::vector<char>& token_separators,
                                         const std::vector<char>& symbols_to_index,
-                                        const bool do_validation, const bool generate_embeddings = true);
+                                        const bool do_validation, const size_t remote_embedding_batch_size = 200, const bool generate_embeddings = true);

     static size_t batch_memory_index(Index *index,
                                      std::vector<index_record>& iter_batch,

@@ -688,7 +688,7 @@ public:
                                      const std::string& fallback_field_type,
                                      const std::vector<char>& token_separators,
                                      const std::vector<char>& symbols_to_index,
-                                     const bool do_validation, const bool generate_embeddings = true);
+                                     const bool do_validation, const size_t remote_embedding_batch_size = 200, const bool generate_embeddings = true);

     void index_field_in_memory(const field& afield, std::vector<index_record>& iter_batch);
@@ -17,7 +17,7 @@ class TextEmbedder {
         TextEmbedder(const nlohmann::json& model_config);
         ~TextEmbedder();
         embedding_res_t Embed(const std::string& text);
-        std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs);
+        std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size = 0);
         const std::string& get_vocab_file_name() const;
         bool is_remote() {
             return remote_embedder_ != nullptr;
@@ -29,7 +29,7 @@ class RemoteEmbedder {
         static inline ReplicationState* raft_server = nullptr;
     public:
         virtual embedding_res_t Embed(const std::string& text) = 0;
-        virtual std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs) = 0;
+        virtual std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size = 200) = 0;
         static void init(ReplicationState* rs) {
             raft_server = rs;
         }

@@ -48,7 +48,7 @@ class OpenAIEmbedder : public RemoteEmbedder {
         OpenAIEmbedder(const std::string& openai_model_path, const std::string& api_key);
         static Option<bool> is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims);
         embedding_res_t Embed(const std::string& text) override;
-        std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs) override;
+        std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size = 200) override;
 };

@@ -63,7 +63,7 @@ class GoogleEmbedder : public RemoteEmbedder {
         GoogleEmbedder(const std::string& google_api_key);
         static Option<bool> is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims);
         embedding_res_t Embed(const std::string& text) override;
-        std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs) override;
+        std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size = 200) override;
 };

@@ -89,7 +89,7 @@ class GCPEmbedder : public RemoteEmbedder {
                     const std::string& refresh_token, const std::string& client_id, const std::string& client_secret);
         static Option<bool> is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims);
         embedding_res_t Embed(const std::string& text) override;
-        std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs) override;
+        std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size = 200) override;
 };
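For context, the interface change above restated as a self-contained sketch: a defaulted batch-size parameter on a pure-virtual batch method, overridden by each backend. All type and class names below are illustrative stand-ins, not typesense's actual declarations.

    #include <cstddef>
    #include <iostream>
    #include <string>
    #include <vector>

    // Illustrative stand-in for typesense's embedding_res_t.
    struct EmbeddingResult {
        std::vector<float> embedding;
    };

    class Embedder {
    public:
        virtual ~Embedder() = default;
        // Every implementation accepts a batch-size hint, mirroring the diff above.
        virtual std::vector<EmbeddingResult> batch_embed(const std::vector<std::string>& inputs,
                                                         size_t batch_size = 200) = 0;
    };

    class StubEmbedder : public Embedder {
    public:
        std::vector<EmbeddingResult> batch_embed(const std::vector<std::string>& inputs,
                                                 size_t batch_size = 200) override {
            // A real embedder would issue one remote request per batch_size chunk.
            return std::vector<EmbeddingResult>(inputs.size(), EmbeddingResult{{0.0f}});
        }
    };

    int main() {
        StubEmbedder embedder;
        auto results = embedder.batch_embed({"hello", "world"}, 64);
        std::cout << results.size() << " embeddings\n";  // prints: 2 embeddings
    }

Defaulting the parameter in the signature keeps existing call sites compiling, while new call sites can pass an explicit size.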
@@ -295,7 +295,8 @@ Option<nlohmann::json> Collection::add(const std::string & json_str,

 nlohmann::json Collection::add_many(std::vector<std::string>& json_lines, nlohmann::json& document,
                                     const index_operation_t& operation, const std::string& id,
-                                    const DIRTY_VALUES& dirty_values, const bool& return_doc, const bool& return_id) {
+                                    const DIRTY_VALUES& dirty_values, const bool& return_doc, const bool& return_id,
+                                    const size_t remote_embedding_batch_size) {
     //LOG(INFO) << "Memory ratio. Max = " << max_memory_ratio << ", Used = " << SystemMetrics::used_memory_ratio();
     std::vector<index_record> index_records;

@@ -385,7 +386,7 @@ nlohmann::json Collection::add_many(std::vector<std::string>& json_lines, nlohma

         if((i+1) % index_batch_size == 0 || i == json_lines.size()-1 || repeated_doc) {
-            batch_index(index_records, json_lines, num_indexed, return_doc, return_id);
+            batch_index(index_records, json_lines, num_indexed, return_doc, return_id, remote_embedding_batch_size);

             // to return the document for the single doc add cases
             if(index_records.size() == 1) {

@@ -502,9 +503,9 @@ bool Collection::is_exceeding_memory_threshold() const {
 }

 void Collection::batch_index(std::vector<index_record>& index_records, std::vector<std::string>& json_out,
-                             size_t &num_indexed, const bool& return_doc, const bool& return_id) {
+                             size_t &num_indexed, const bool& return_doc, const bool& return_id, const size_t remote_embedding_batch_size) {

-    batch_index_in_memory(index_records);
+    batch_index_in_memory(index_records, remote_embedding_batch_size);

     // store only documents that were indexed in-memory successfully
     for(auto& index_record: index_records) {

@@ -608,11 +609,11 @@ Option<uint32_t> Collection::index_in_memory(nlohmann::json &document, uint32_t
     return Option<>(200);
 }

-size_t Collection::batch_index_in_memory(std::vector<index_record>& index_records, const bool generate_embeddings) {
+size_t Collection::batch_index_in_memory(std::vector<index_record>& index_records, const size_t remote_embedding_batch_size, const bool generate_embeddings) {
     std::unique_lock lock(mutex);
     size_t num_indexed = Index::batch_memory_index(index, index_records, default_sorting_field,
                                                    search_schema, embedding_fields, fallback_field_type,
-                                                   token_separators, symbols_to_index, true, generate_embeddings);
+                                                   token_separators, symbols_to_index, true, remote_embedding_batch_size, generate_embeddings);
     num_documents += num_indexed;
     return num_indexed;
 }

@@ -3807,7 +3808,7 @@ Option<bool> Collection::batch_alter_data(const std::vector<field>& alter_fields
         }

         Index::batch_memory_index(index, iter_batch, default_sorting_field, schema_additions, embedding_fields,
-                                  fallback_field_type, token_separators, symbols_to_index, true, false);
+                                  fallback_field_type, token_separators, symbols_to_index, true, 100, false);

         iter_batch.clear();
     }
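Taken together, these hunks thread a single value down the chain add_many -> batch_index -> batch_index_in_memory -> Index::batch_memory_index, with declaration defaults keeping internal callers compiling (note the defaults differ in this patch: 100 on add_many, 200 further down, and a literal 100 on the schema-alter path). A toy model of that forwarding chain, with signatures reduced to just the new parameter:

    #include <cstddef>
    #include <iostream>

    // Illustrative stubs; each layer simply forwards the batch size it was given.
    size_t batch_memory_index(size_t remote_embedding_batch_size) {
        std::cout << "embedding with batch size " << remote_embedding_batch_size << "\n";
        return 0;
    }

    size_t batch_index_in_memory(size_t remote_embedding_batch_size = 200) {
        return batch_memory_index(remote_embedding_batch_size);
    }

    void batch_index(size_t remote_embedding_batch_size) {
        batch_index_in_memory(remote_embedding_batch_size);
    }

    void add_many(size_t remote_embedding_batch_size = 100) {
        batch_index(remote_embedding_batch_size);
    }

    int main() {
        add_many();    // internal callers fall back to the declaration defaults
        add_many(64);  // the HTTP handler passes the user-supplied value instead
    }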
@@ -745,11 +745,16 @@ bool post_import_documents(const std::shared_ptr<http_req>& req, const std::shar
     const char *DIRTY_VALUES = "dirty_values";
     const char *RETURN_DOC = "return_doc";
     const char *RETURN_ID = "return_id";
+    const char *REMOTE_EMBEDDING_BATCH_SIZE = "remote_embedding_batch_size";

     if(req->params.count(BATCH_SIZE) == 0) {
         req->params[BATCH_SIZE] = "40";
     }

+    if(req->params.count(REMOTE_EMBEDDING_BATCH_SIZE) == 0) {
+        req->params[REMOTE_EMBEDDING_BATCH_SIZE] = "200";
+    }
+
     if(req->params.count(ACTION) == 0) {
         req->params[ACTION] = "create";
     }

@@ -796,6 +801,7 @@ bool post_import_documents(const std::shared_ptr<http_req>& req, const std::shar
     }

     const size_t IMPORT_BATCH_SIZE = std::stoi(req->params[BATCH_SIZE]);
+    const size_t REMOTE_EMBEDDING_BATCH_SIZE_VAL = std::stoi(req->params[REMOTE_EMBEDDING_BATCH_SIZE]);

     if(IMPORT_BATCH_SIZE == 0) {
         res->final = true;

@@ -804,6 +810,13 @@ bool post_import_documents(const std::shared_ptr<http_req>& req, const std::shar
         return false;
     }

+    if(REMOTE_EMBEDDING_BATCH_SIZE_VAL == 0) {
+        res->final = true;
+        res->set_400("Parameter `" + std::string(REMOTE_EMBEDDING_BATCH_SIZE) + "` must be a positive integer.");
+        stream_response(req, res);
+        return false;
+    }
+
     if(req->body_index == 0) {
         // will log for every major chunk of request body
         //LOG(INFO) << "Import, req->body.size=" << req->body.size() << ", batch_size=" << IMPORT_BATCH_SIZE;
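The handler's defaulting-and-validation pattern, restated as a minimal standalone sketch (the params map is an illustrative stand-in for req->params):

    #include <iostream>
    #include <map>
    #include <string>

    int main() {
        // Stand-in for req->params, as parsed from the query string.
        std::map<std::string, std::string> params = { /* e.g. {"remote_embedding_batch_size", "64"} */ };

        // Default the parameter when absent, as the handler above does.
        if(params.count("remote_embedding_batch_size") == 0) {
            params["remote_embedding_batch_size"] = "200";
        }

        // Reject zero; std::stoi throws on non-numeric input, which a
        // production handler would also need to guard against.
        const size_t batch_size = std::stoi(params["remote_embedding_batch_size"]);
        if(batch_size == 0) {
            std::cout << "400: `remote_embedding_batch_size` must be a positive integer\n";
            return 1;
        }
        std::cout << "using batch size " << batch_size << "\n";
    }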
@@ -873,7 +886,7 @@ bool post_import_documents(const std::shared_ptr<http_req>& req, const std::shar
     const bool& return_doc = req->params[RETURN_DOC] == "true";
     const bool& return_id = req->params[RETURN_ID] == "true";
     nlohmann::json json_res = collection->add_many(json_lines, document, operation, "",
-                                                   dirty_values, return_doc, return_id);
+                                                   dirty_values, return_doc, return_id, REMOTE_EMBEDDING_BATCH_SIZE_VAL);
     //const std::string& import_summary_json = json_res->dump();
     //response_stream << import_summary_json << "\n";
@@ -416,7 +416,7 @@ void Index::validate_and_preprocess(Index *index, std::vector<index_record>& ite
                                     const std::string& fallback_field_type,
                                     const std::vector<char>& token_separators,
                                     const std::vector<char>& symbols_to_index,
-                                    const bool do_validation, const bool generate_embeddings) {
+                                    const bool do_validation, const size_t remote_embedding_batch_size, const bool generate_embeddings) {

     // runs in a partitioned thread
     std::vector<index_record*> records_to_embed;

@@ -505,7 +505,7 @@ void Index::validate_and_preprocess(Index *index, std::vector<index_record>& ite
         }
     }
     if(generate_embeddings) {
-        batch_embed_fields(records_to_embed, embedding_fields, search_schema);
+        batch_embed_fields(records_to_embed, embedding_fields, search_schema, remote_embedding_batch_size);
     }
 }

@@ -516,7 +516,7 @@ size_t Index::batch_memory_index(Index *index, std::vector<index_record>& iter_b
                                  const std::string& fallback_field_type,
                                  const std::vector<char>& token_separators,
                                  const std::vector<char>& symbols_to_index,
-                                 const bool do_validation, const bool generate_embeddings) {
+                                 const bool do_validation, const size_t remote_embedding_batch_size, const bool generate_embeddings) {

     const size_t concurrency = 4;
     const size_t num_threads = std::min(concurrency, iter_batch.size());

@@ -546,7 +546,7 @@ size_t Index::batch_memory_index(Index *index, std::vector<index_record>& iter_b
         index->thread_pool->enqueue([&, batch_index, batch_len]() {
             write_log_index = local_write_log_index;
             validate_and_preprocess(index, iter_batch, batch_index, batch_len, default_sorting_field, search_schema,
-                                    embedding_fields, fallback_field_type, token_separators, symbols_to_index, do_validation, generate_embeddings);
+                                    embedding_fields, fallback_field_type, token_separators, symbols_to_index, do_validation, remote_embedding_batch_size, generate_embeddings);

             std::unique_lock<std::mutex> lock(m_process);
             num_processed++;

@@ -6468,7 +6468,7 @@ bool Index::common_results_exist(std::vector<art_leaf*>& leaves, bool must_match

 void Index::batch_embed_fields(std::vector<index_record*>& records,
                                const tsl::htrie_map<char, field>& embedding_fields,
-                               const tsl::htrie_map<char, field> & search_schema) {
+                               const tsl::htrie_map<char, field> & search_schema, const size_t remote_embedding_batch_size) {
     for(const auto& field : embedding_fields) {
         std::vector<std::pair<index_record*, std::string>> texts_to_embed;
         auto indexing_prefix = TextEmbedderManager::get_instance().get_indexing_prefix(field.embed[fields::model_config]);

@@ -6529,7 +6529,7 @@ void Index::batch_embed_fields(std::vector<index_record*>& records,
             texts.push_back(text_to_embed.second);
         }

-        auto embeddings = embedder_op.get()->batch_embed(texts);
+        auto embeddings = embedder_op.get()->batch_embed(texts, remote_embedding_batch_size);

         for(size_t i = 0; i < embeddings.size(); i++) {
             auto& embedding_res = embeddings[i];
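batch_embed_fields gathers the text of every record that needs an embedding for a field, embeds them in one batched call, and writes results back by position. A minimal standalone sketch of that gather/scatter pattern (the Record type and stub embedder are illustrative):

    #include <iostream>
    #include <string>
    #include <utility>
    #include <vector>

    struct Record {
        std::string text;
        std::vector<float> embedding;  // filled in after the batched call
    };

    // Stub: a real embedder would chunk `texts` by batch_size and call out remotely.
    std::vector<std::vector<float>> batch_embed(const std::vector<std::string>& texts, size_t batch_size) {
        return std::vector<std::vector<float>>(texts.size(), {0.0f, 1.0f});
    }

    int main() {
        std::vector<Record> records = {{"doc one", {}}, {"doc two", {}}};

        // Gather: remember which record each text came from.
        std::vector<std::pair<Record*, std::string>> texts_to_embed;
        for(auto& record : records) {
            texts_to_embed.emplace_back(&record, record.text);
        }

        std::vector<std::string> texts;
        for(auto& entry : texts_to_embed) {
            texts.push_back(entry.second);
        }

        // One batched call for all texts of this field.
        auto embeddings = batch_embed(texts, 200);

        // Scatter: results come back in input order, so index i maps to record i.
        for(size_t i = 0; i < embeddings.size(); i++) {
            texts_to_embed[i].first->embedding = std::move(embeddings[i]);
        }
        std::cout << "embedded " << records.size() << " records\n";
    }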
@@ -133,7 +133,7 @@ embedding_res_t TextEmbedder::Embed(const std::string& text) {
     }
 }

-std::vector<embedding_res_t> TextEmbedder::batch_embed(const std::vector<std::string>& inputs) {
+std::vector<embedding_res_t> TextEmbedder::batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size) {
     std::vector<embedding_res_t> outputs;
     if(!is_remote()) {
         for(int i = 0; i < inputs.size(); i += 8) {

@@ -215,7 +215,7 @@ std::vector<embedding_res_t> TextEmbedder::batch_embed(const std::vector<std::st
             }
         }
     } else {
-        outputs = std::move(remote_embedder_->batch_embed(inputs));
+        outputs = std::move(remote_embedder_->batch_embed(inputs, remote_embedding_batch_size));
     }

     return outputs;
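Note the asymmetry visible above: local inference keeps its fixed stride of 8, while only the remote path receives the configurable batch size. A self-contained sketch of that dispatch (the embed functions are stubs, not typesense's implementations):

    #include <algorithm>
    #include <iostream>
    #include <string>
    #include <vector>

    std::vector<std::vector<float>> embed_locally(const std::vector<std::string>& chunk) {
        return std::vector<std::vector<float>>(chunk.size(), {0.0f});  // stub
    }
    std::vector<std::vector<float>> embed_remotely(const std::vector<std::string>& inputs, size_t batch_size) {
        return std::vector<std::vector<float>>(inputs.size(), {0.0f});  // stub
    }

    std::vector<std::vector<float>> batch_embed(const std::vector<std::string>& inputs,
                                                bool is_remote, size_t remote_batch_size) {
        std::vector<std::vector<float>> outputs;
        if(!is_remote) {
            // Local models keep their own fixed stride (8 in the diff above),
            // independent of the remote batch size.
            for(size_t i = 0; i < inputs.size(); i += 8) {
                size_t end = std::min(i + 8, inputs.size());
                std::vector<std::string> chunk(inputs.begin() + i, inputs.begin() + end);
                auto chunk_out = embed_locally(chunk);
                outputs.insert(outputs.end(), chunk_out.begin(), chunk_out.end());
            }
        } else {
            // Remote embedders receive the user-configurable batch size.
            outputs = embed_remotely(inputs, remote_batch_size);
        }
        return outputs;
    }

    int main() {
        auto out = batch_embed(std::vector<std::string>(20, "text"), false, 200);
        std::cout << out.size() << "\n";  // 20
    }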
@@ -182,7 +182,17 @@ embedding_res_t OpenAIEmbedder::Embed(const std::string& text) {
     }
 }

-std::vector<embedding_res_t> OpenAIEmbedder::batch_embed(const std::vector<std::string>& inputs) {
+std::vector<embedding_res_t> OpenAIEmbedder::batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size) {
+    // call recursively if inputs larger than remote_embedding_batch_size
+    if(inputs.size() > remote_embedding_batch_size) {
+        std::vector<embedding_res_t> outputs;
+        for(size_t i = 0; i < inputs.size(); i += remote_embedding_batch_size) {
+            auto batch = std::vector<std::string>(inputs.begin() + i, inputs.begin() + std::min(i + remote_embedding_batch_size, inputs.size()));
+            auto batch_outputs = batch_embed(batch, remote_embedding_batch_size);
+            outputs.insert(outputs.end(), batch_outputs.begin(), batch_outputs.end());
+        }
+        return outputs;
+    }
     nlohmann::json req_body;
     req_body["input"] = inputs;
     // remove "openai/" prefix
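The chunking above recurses: each slice is at most remote_embedding_batch_size long, so the recursive call falls through to the single-request path. The same pattern as a standalone, runnable sketch with a stubbed request (all names illustrative):

    #include <algorithm>
    #include <iostream>
    #include <string>
    #include <vector>

    // Stub for the single-request path; the real method POSTs all
    // `inputs` to the embeddings endpoint in one call.
    std::vector<std::vector<float>> embed_one_request(const std::vector<std::string>& inputs) {
        return std::vector<std::vector<float>>(inputs.size(), {0.0f});
    }

    std::vector<std::vector<float>> batch_embed(const std::vector<std::string>& inputs, size_t batch_size) {
        // Recurse on oversized input: every slice is <= batch_size,
        // so each recursive call takes the single-request path below.
        if(inputs.size() > batch_size) {
            std::vector<std::vector<float>> outputs;
            for(size_t i = 0; i < inputs.size(); i += batch_size) {
                std::vector<std::string> batch(inputs.begin() + i,
                                               inputs.begin() + std::min(i + batch_size, inputs.size()));
                auto batch_outputs = batch_embed(batch, batch_size);
                outputs.insert(outputs.end(), batch_outputs.begin(), batch_outputs.end());
            }
            return outputs;
        }
        return embed_one_request(inputs);
    }

    int main() {
        auto out = batch_embed(std::vector<std::string>(450, "text"), 200);
        std::cout << out.size() << "\n";  // 450, produced by requests of 200/200/50
    }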
@@ -326,7 +336,7 @@ embedding_res_t GoogleEmbedder::Embed(const std::string& text) {
         nlohmann::json json_res = nlohmann::json::parse(res);
     } catch (const std::exception& e) {
         json_res = nlohmann::json::object();
-        json_res["error"] = "Malformed response from Google API."
+        json_res["error"] = "Malformed response from Google API.";
     }
     nlohmann::json embedding_res = nlohmann::json::object();
     embedding_res["response"] = json_res;

@@ -356,7 +366,7 @@ embedding_res_t GoogleEmbedder::Embed(const std::string& text) {
 }

-std::vector<embedding_res_t> GoogleEmbedder::batch_embed(const std::vector<std::string>& inputs) {
+std::vector<embedding_res_t> GoogleEmbedder::batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size) {
     std::vector<embedding_res_t> outputs;
     for(auto& input : inputs) {
         auto res = Embed(input);

@@ -513,7 +523,7 @@ embedding_res_t GCPEmbedder::Embed(const std::string& text) {
 }

-std::vector<embedding_res_t> GCPEmbedder::batch_embed(const std::vector<std::string>& inputs) {
+std::vector<embedding_res_t> GCPEmbedder::batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size) {
     // GCP API has a limit of 5 instances per request
     if(inputs.size() > 5) {
         std::vector<embedding_res_t> res;
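As the hunks above show, GoogleEmbedder embeds one input at a time and GCPEmbedder is capped at 5 instances per request, so for those backends the effective batch is bounded by the provider, not by the new parameter. A tiny sketch of that capping idea (the min() policy is an illustration of the constraint, not necessarily what this patch does):

    #include <algorithm>
    #include <cstddef>
    #include <iostream>

    int main() {
        const size_t requested_batch_size = 200;  // what the user asked for
        const size_t provider_limit = 5;          // e.g. GCP's 5-instances-per-request cap

        // The effective batch can never exceed the provider's own limit.
        const size_t effective = std::min(requested_batch_size, provider_limit);
        std::cout << "sending batches of " << effective << "\n";  // 5
    }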