Change param name to 'remote_embedding_num_try'

This commit is contained in:
ozanarmagan 2023-07-07 17:11:25 +03:00
parent 6f7efa5d72
commit 8b1aa13ffe
6 changed files with 24 additions and 19 deletions

View File

@ -466,7 +466,7 @@ public:
const size_t page_offset = UINT32_MAX,
const size_t vector_query_hits = 250,
const size_t remote_embedding_timeout_ms = 30000,
const size_t remote_embedding_num_retry = 2) const;
const size_t remote_embedding_num_try = 2) const;
Option<bool> get_filter_ids(const std::string & filter_query, filter_result_t& filter_result) const;

View File

@ -16,7 +16,7 @@ class TextEmbedder {
// Constructor for remote models
TextEmbedder(const nlohmann::json& model_config);
~TextEmbedder();
embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedder_max_retries = 2);
embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedder_num_try = 2);
std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs);
const std::string& get_vocab_file_name() const;
bool is_remote() {

View File

@ -1109,7 +1109,7 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
const size_t page_offset,
const size_t vector_query_hits,
const size_t remote_embedding_timeout_ms,
const size_t remote_embedding_num_retry) const {
const size_t remote_embedding_num_try) const {
std::shared_lock lock(mutex);
@ -1235,10 +1235,15 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
std::string error = "Prefix search is not supported for remote embedders. Please set `prefix=false` as an additional search parameter to disable prefix searching.";
return Option<nlohmann::json>(400, error);
}
if(remote_embedding_num_try == 0) {
std::string error = "`remote-embedding-num-try` must be greater than 0.";
return Option<nlohmann::json>(400, error);
}
}
std::string embed_query = embedder_manager.get_query_prefix(search_field.embed[fields::model_config]) + raw_query;
auto embedding_op = embedder->Embed(embed_query, remote_embedding_timeout_ms, remote_embedding_num_retry);
auto embedding_op = embedder->Embed(embed_query, remote_embedding_timeout_ms, remote_embedding_num_try);
if(!embedding_op.success) {
if(!embedding_op.error["error"].get<std::string>().empty()) {
return Option<nlohmann::json>(400, embedding_op.error["error"].get<std::string>());

View File

@ -672,7 +672,7 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
const char *VECTOR_QUERY_HITS = "vector_query_hits";
const char* REMOTE_EMBEDDING_TIMEOUT_MS = "remote_embedding_timeout_ms";
const char* REMOTE_EMBEDDING_NUM_RETRY = "remote_embedding_num_retry";
const char* REMOTE_EMBEDDING_NUM_TRY = "remote_embedding_num_try";
const char *GROUP_BY = "group_by";
const char *GROUP_LIMIT = "group_limit";
@ -828,7 +828,7 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
size_t vector_query_hits = 250;
size_t remote_embedding_timeout_ms = 30000;
size_t remote_embedding_num_retry = 2;
size_t remote_embedding_num_try = 2;
size_t facet_sample_percent = 100;
size_t facet_sample_threshold = 0;
@ -857,7 +857,7 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
{FACET_SAMPLE_THRESHOLD, &facet_sample_threshold},
{VECTOR_QUERY_HITS, &vector_query_hits},
{REMOTE_EMBEDDING_TIMEOUT_MS, &remote_embedding_timeout_ms},
{REMOTE_EMBEDDING_NUM_RETRY, &remote_embedding_num_retry},
{REMOTE_EMBEDDING_NUM_TRY, &remote_embedding_num_try},
};
std::unordered_map<std::string, std::string*> str_values = {
@ -1081,7 +1081,7 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
offset,
vector_query_hits,
remote_embedding_timeout_ms,
remote_embedding_num_retry
remote_embedding_num_try
);
uint64_t timeMillis = std::chrono::duration_cast<std::chrono::milliseconds>(

View File

@ -38,16 +38,16 @@ http_proxy_res_t HttpProxy::send(const std::string& url, const std::string& meth
key = StringUtils::hash_combine(key, StringUtils::hash_wy(body.c_str(), body.size()));
size_t timeout_ms = 30000;
size_t num_retry = 2;
size_t num_try = 2;
if(headers.find("timeout_ms") != headers.end()){
timeout_ms = std::stoul(headers.at("timeout_ms"));
headers.erase("timeout_ms");
}
if(headers.find("num_retry") != headers.end()){
num_retry = std::stoul(headers.at("num_retry"));
headers.erase("num_retry");
if(headers.find("num_try") != headers.end()){
num_try = std::stoul(headers.at("num_try"));
headers.erase("num_try");
}
for(auto& header : headers){
@ -59,7 +59,7 @@ http_proxy_res_t HttpProxy::send(const std::string& url, const std::string& meth
}
http_proxy_res_t res;
for(size_t i = 0; i < num_retry; i++){
for(size_t i = 0; i < num_try; i++){
res = call(url, method, body, headers, timeout_ms);
if(res.status_code != 408 && res.status_code < 500){

View File

@ -132,13 +132,13 @@ Option<bool> OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config,
return Option<bool>(true);
}
embedding_res_t OpenAIEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedder_num_retry) {
embedding_res_t OpenAIEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedder_num_try) {
std::unordered_map<std::string, std::string> headers;
std::map<std::string, std::string> res_headers;
headers["Authorization"] = "Bearer " + api_key;
headers["Content-Type"] = "application/json";
headers["timeout_ms"] = std::to_string(remote_embedder_timeout_ms);
headers["num_retry"] = std::to_string(remote_embedder_num_retry);
headers["num_try"] = std::to_string(remote_embedder_num_try);
std::string res;
nlohmann::json req_body;
req_body["input"] = std::vector<std::string>{text};
@ -287,12 +287,12 @@ Option<bool> GoogleEmbedder::is_model_valid(const nlohmann::json& model_config,
return Option<bool>(true);
}
embedding_res_t GoogleEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedder_num_retry) {
embedding_res_t GoogleEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedder_num_try) {
std::unordered_map<std::string, std::string> headers;
std::map<std::string, std::string> res_headers;
headers["Content-Type"] = "application/json";
headers["timeout_ms"] = std::to_string(remote_embedder_timeout_ms);
headers["num_retry"] = std::to_string(remote_embedder_num_retry);
headers["num_try"] = std::to_string(remote_embedder_num_try);
std::string res;
nlohmann::json req_body;
req_body["text"] = text;
@ -422,7 +422,7 @@ Option<bool> GCPEmbedder::is_model_valid(const nlohmann::json& model_config, uns
return Option<bool>(true);
}
embedding_res_t GCPEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedder_num_retry) {
embedding_res_t GCPEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedder_num_try) {
nlohmann::json req_body;
req_body["instances"] = nlohmann::json::array();
nlohmann::json instance;
@ -432,7 +432,7 @@ embedding_res_t GCPEmbedder::Embed(const std::string& text, const size_t remote_
headers["Authorization"] = "Bearer " + access_token;
headers["Content-Type"] = "application/json";
headers["timeout_ms"] = std::to_string(remote_embedder_timeout_ms);
headers["num_retry"] = std::to_string(remote_embedder_num_retry);
headers["num_try"] = std::to_string(remote_embedder_num_try);
std::map<std::string, std::string> res_headers;
std::string res;