Merge pull request #1124 from ozanarmagan/v0.25-join

Change 'remote_embedding_num_try' to 'remote_embedding_num_tries'
This commit is contained in:
Kishore Nallan 2023-07-30 19:30:57 +05:30 committed by GitHub
commit 558734287f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 22 additions and 22 deletions

View File

@ -467,7 +467,7 @@ public:
const size_t facet_sample_threshold = 0,
const size_t page_offset = 0,
const size_t remote_embedding_timeout_ms = 30000,
const size_t remote_embedding_num_try = 2) const;
const size_t remote_embedding_num_tries = 2) const;
Option<bool> get_filter_ids(const std::string & filter_query, filter_result_t& filter_result) const;

View File

@ -16,7 +16,7 @@ class TextEmbedder {
// Constructor for remote models
TextEmbedder(const nlohmann::json& model_config);
~TextEmbedder();
embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_try = 2);
embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_tries = 2);
std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size = 200);
const std::string& get_vocab_file_name() const;
bool is_remote() {

View File

@ -29,7 +29,7 @@ class RemoteEmbedder {
static inline ReplicationState* raft_server = nullptr;
public:
virtual nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) = 0;
virtual embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_try = 2) = 0;
virtual embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_tries = 2) = 0;
virtual std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size = 200) = 0;
static void init(ReplicationState* rs) {
raft_server = rs;
@ -48,7 +48,7 @@ class OpenAIEmbedder : public RemoteEmbedder {
public:
OpenAIEmbedder(const std::string& openai_model_path, const std::string& api_key);
static Option<bool> is_model_valid(const nlohmann::json& model_config, size_t& num_dims);
embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_try = 2) override;
embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_tries = 2) override;
std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size = 200) override;
nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override;
};
@ -65,7 +65,7 @@ class GoogleEmbedder : public RemoteEmbedder {
public:
GoogleEmbedder(const std::string& google_api_key);
static Option<bool> is_model_valid(const nlohmann::json& model_config, size_t& num_dims);
embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_try = 2) override;
embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_tries = 2) override;
std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size = 200) override;
nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override;
};
@ -92,7 +92,7 @@ class GCPEmbedder : public RemoteEmbedder {
GCPEmbedder(const std::string& project_id, const std::string& model_name, const std::string& access_token,
const std::string& refresh_token, const std::string& client_id, const std::string& client_secret);
static Option<bool> is_model_valid(const nlohmann::json& model_config, size_t& num_dims);
embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_try = 2) override;
embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_tries = 2) override;
std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size = 200) override;
nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override;
};

View File

@ -1120,7 +1120,7 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
const size_t facet_sample_threshold,
const size_t page_offset,
const size_t remote_embedding_timeout_ms,
const size_t remote_embedding_num_try) const {
const size_t remote_embedding_num_tries) const {
std::shared_lock lock(mutex);
@ -1260,14 +1260,14 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
return Option<nlohmann::json>(400, error);
}
if(remote_embedding_num_try == 0) {
std::string error = "`remote-embedding-num-try` must be greater than 0.";
if(remote_embedding_num_tries == 0) {
std::string error = "`remote_embedding_num_tries` must be greater than 0.";
return Option<nlohmann::json>(400, error);
}
}
std::string embed_query = embedder_manager.get_query_prefix(search_field.embed[fields::model_config]) + raw_query;
auto embedding_op = embedder->Embed(embed_query, remote_embedding_timeout_ms, remote_embedding_num_try);
auto embedding_op = embedder->Embed(embed_query, remote_embedding_timeout_ms, remote_embedding_num_tries);
if(!embedding_op.success) {
if(!embedding_op.error["error"].get<std::string>().empty()) {
return Option<nlohmann::json>(400, embedding_op.error["error"].get<std::string>());

View File

@ -696,7 +696,7 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
const char *VECTOR_QUERY = "vector_query";
const char* REMOTE_EMBEDDING_TIMEOUT_MS = "remote_embedding_timeout_ms";
const char* REMOTE_EMBEDDING_NUM_TRY = "remote_embedding_num_try";
const char* REMOTE_EMBEDDING_NUM_TRIES = "remote_embedding_num_tries";
const char *GROUP_BY = "group_by";
const char *GROUP_LIMIT = "group_limit";
@ -851,7 +851,7 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
text_match_type_t match_type = max_score;
size_t remote_embedding_timeout_ms = 5000;
size_t remote_embedding_num_try = 2;
size_t remote_embedding_num_tries = 2;
size_t facet_sample_percent = 100;
size_t facet_sample_threshold = 0;
@ -879,7 +879,7 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
{FACET_SAMPLE_PERCENT, &facet_sample_percent},
{FACET_SAMPLE_THRESHOLD, &facet_sample_threshold},
{REMOTE_EMBEDDING_TIMEOUT_MS, &remote_embedding_timeout_ms},
{REMOTE_EMBEDDING_NUM_TRY, &remote_embedding_num_try},
{REMOTE_EMBEDDING_NUM_TRIES, &remote_embedding_num_tries},
};
std::unordered_map<std::string, std::string*> str_values = {
@ -1094,7 +1094,7 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
facet_sample_threshold,
offset,
remote_embedding_timeout_ms,
remote_embedding_num_try
remote_embedding_num_tries
);
uint64_t timeMillis = std::chrono::duration_cast<std::chrono::milliseconds>(

View File

@ -94,9 +94,9 @@ std::vector<float> TextEmbedder::mean_pooling(const std::vector<std::vector<floa
return pooled_output;
}
embedding_res_t TextEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedding_num_try) {
embedding_res_t TextEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedding_num_tries) {
if(is_remote()) {
return remote_embedder_->Embed(text, remote_embedder_timeout_ms, remote_embedding_num_try);
return remote_embedder_->Embed(text, remote_embedder_timeout_ms, remote_embedding_num_tries);
} else {
// Cannot run same model in parallel, so lock the mutex
std::lock_guard<std::mutex> lock(mutex_);

View File

@ -149,13 +149,13 @@ Option<bool> OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config,
return Option<bool>(true);
}
embedding_res_t OpenAIEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedding_num_try) {
embedding_res_t OpenAIEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedding_num_tries) {
std::unordered_map<std::string, std::string> headers;
std::map<std::string, std::string> res_headers;
headers["Authorization"] = "Bearer " + api_key;
headers["Content-Type"] = "application/json";
headers["timeout_ms"] = std::to_string(remote_embedder_timeout_ms);
headers["num_try"] = std::to_string(remote_embedding_num_try);
headers["num_try"] = std::to_string(remote_embedding_num_tries);
std::string res;
nlohmann::json req_body;
req_body["input"] = std::vector<std::string>{text};
@ -314,12 +314,12 @@ Option<bool> GoogleEmbedder::is_model_valid(const nlohmann::json& model_config,
return Option<bool>(true);
}
embedding_res_t GoogleEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedding_num_try) {
embedding_res_t GoogleEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedding_num_tries) {
std::unordered_map<std::string, std::string> headers;
std::map<std::string, std::string> res_headers;
headers["Content-Type"] = "application/json";
headers["timeout_ms"] = std::to_string(remote_embedder_timeout_ms);
headers["num_try"] = std::to_string(remote_embedding_num_try);
headers["num_try"] = std::to_string(remote_embedding_num_tries);
std::string res;
nlohmann::json req_body;
req_body["text"] = text;
@ -449,7 +449,7 @@ Option<bool> GCPEmbedder::is_model_valid(const nlohmann::json& model_config, siz
return Option<bool>(true);
}
embedding_res_t GCPEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedding_num_try) {
embedding_res_t GCPEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedding_num_tries) {
nlohmann::json req_body;
req_body["instances"] = nlohmann::json::array();
nlohmann::json instance;
@ -459,7 +459,7 @@ embedding_res_t GCPEmbedder::Embed(const std::string& text, const size_t remote_
headers["Authorization"] = "Bearer " + access_token;
headers["Content-Type"] = "application/json";
headers["timeout_ms"] = std::to_string(remote_embedder_timeout_ms);
headers["num_try"] = std::to_string(remote_embedding_num_try);
headers["num_try"] = std::to_string(remote_embedding_num_tries);
std::map<std::string, std::string> res_headers;
std::string res;