mirror of
https://github.com/typesense/typesense.git
synced 2025-05-21 22:33:27 +08:00
Merge branch 'v0.25-join' into v0.26-filter
# Conflicts: # src/index.cpp
This commit is contained in:
commit
d5048f689b
@ -207,7 +207,8 @@ private:
|
||||
|
||||
Option<bool> validate_and_standardize_sort_fields(const std::vector<sort_by> & sort_fields,
|
||||
std::vector<sort_by>& sort_fields_std,
|
||||
bool is_wildcard_query, bool is_group_by_query = false) const;
|
||||
bool is_wildcard_query,const bool is_vector_query,
|
||||
bool is_group_by_query = false) const;
|
||||
|
||||
|
||||
Option<bool> persist_collection_meta();
|
||||
@ -397,7 +398,7 @@ public:
|
||||
|
||||
nlohmann::json get_summary_json() const;
|
||||
|
||||
size_t batch_index_in_memory(std::vector<index_record>& index_records);
|
||||
size_t batch_index_in_memory(std::vector<index_record>& index_records, const bool generate_embeddings = true);
|
||||
|
||||
Option<nlohmann::json> add(const std::string & json_str,
|
||||
const index_operation_t& operation=CREATE, const std::string& id="",
|
||||
@ -462,7 +463,8 @@ public:
|
||||
const text_match_type_t match_type = max_score,
|
||||
const size_t facet_sample_percent = 100,
|
||||
const size_t facet_sample_threshold = 0,
|
||||
const size_t page_offset = UINT32_MAX) const;
|
||||
const size_t page_offset = UINT32_MAX,
|
||||
const size_t vector_query_hits = 250) const;
|
||||
|
||||
Option<bool> get_filter_ids(const std::string & filter_query, filter_result_t& filter_result) const;
|
||||
|
||||
|
@ -550,6 +550,8 @@ namespace sort_field_const {
|
||||
static const std::string precision = "precision";
|
||||
|
||||
static const std::string missing_values = "missing_values";
|
||||
|
||||
static const std::string vector_distance = "_vector_distance";
|
||||
}
|
||||
|
||||
struct sort_by {
|
||||
|
@ -51,7 +51,10 @@ public:
|
||||
|
||||
bool is_res_start = true;
|
||||
h2o_send_state_t send_state = H2O_SEND_STATE_IN_PROGRESS;
|
||||
h2o_iovec_t res_body{};
|
||||
|
||||
std::string res_body;
|
||||
h2o_iovec_t res_buff;
|
||||
|
||||
h2o_iovec_t res_content_type{};
|
||||
int status = 0;
|
||||
const char* reason = nullptr;
|
||||
@ -64,8 +67,10 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
void set_response(uint32_t status_code, const std::string& content_type, const std::string& body) {
|
||||
res_body = h2o_strdup(&req->pool, body.c_str(), SIZE_MAX);
|
||||
void set_response(uint32_t status_code, const std::string& content_type, std::string& body) {
|
||||
std::string().swap(res_body);
|
||||
res_body = std::move(body);
|
||||
res_buff = h2o_iovec_t{.base = res_body.data(), .len = res_body.size()};
|
||||
|
||||
if(is_res_start) {
|
||||
res_content_type = h2o_strdup(&req->pool, content_type.c_str(), SIZE_MAX);
|
||||
|
@ -202,6 +202,9 @@ struct index_record {
|
||||
nlohmann::json new_doc; // new *full* document to be stored into disk
|
||||
nlohmann::json del_doc; // document containing the fields that should be deleted
|
||||
|
||||
nlohmann::json embedding_res; // embedding result
|
||||
int embedding_status_code; // embedding status code
|
||||
|
||||
index_operation_t operation;
|
||||
bool is_update;
|
||||
|
||||
@ -344,6 +347,7 @@ private:
|
||||
static spp::sparse_hash_map<uint32_t, int64_t> eval_sentinel_value;
|
||||
static spp::sparse_hash_map<uint32_t, int64_t> geo_sentinel_value;
|
||||
static spp::sparse_hash_map<uint32_t, int64_t> str_sentinel_value;
|
||||
static spp::sparse_hash_map<uint32_t, int64_t> vector_distance_sentinel_value;
|
||||
|
||||
// Internal utility functions
|
||||
|
||||
@ -523,7 +527,7 @@ private:
|
||||
|
||||
void initialize_facet_indexes(const field& facet_field);
|
||||
|
||||
static Option<bool> batch_embed_fields(std::vector<nlohmann::json*>& documents,
|
||||
static void batch_embed_fields(std::vector<index_record*>& documents,
|
||||
const tsl::htrie_map<char, field>& embedding_fields,
|
||||
const tsl::htrie_map<char, field> & search_schema);
|
||||
|
||||
@ -661,7 +665,7 @@ public:
|
||||
const std::string& fallback_field_type,
|
||||
const std::vector<char>& token_separators,
|
||||
const std::vector<char>& symbols_to_index,
|
||||
const bool do_validation);
|
||||
const bool do_validation, const bool generate_embeddings = true);
|
||||
|
||||
static size_t batch_memory_index(Index *index,
|
||||
std::vector<index_record>& iter_batch,
|
||||
@ -671,7 +675,7 @@ public:
|
||||
const std::string& fallback_field_type,
|
||||
const std::vector<char>& token_separators,
|
||||
const std::vector<char>& symbols_to_index,
|
||||
const bool do_validation);
|
||||
const bool do_validation, const bool generate_embeddings = true);
|
||||
|
||||
void index_field_in_memory(const field& afield, std::vector<index_record>& iter_batch);
|
||||
|
||||
@ -929,7 +933,7 @@ public:
|
||||
size_t filter_index,
|
||||
int64_t max_field_match_score,
|
||||
int64_t* scores,
|
||||
int64_t& match_score_index) const;
|
||||
int64_t& match_score_index, float vector_distance = 0) const;
|
||||
|
||||
void process_curated_ids(const std::vector<std::pair<uint32_t, uint32_t>>& included_ids,
|
||||
const std::vector<uint32_t>& excluded_ids, const std::vector<std::string>& group_by_fields,
|
||||
|
@ -16,8 +16,8 @@ class TextEmbedder {
|
||||
// Constructor for remote models
|
||||
TextEmbedder(const nlohmann::json& model_config);
|
||||
~TextEmbedder();
|
||||
Option<std::vector<float>> Embed(const std::string& text);
|
||||
Option<std::vector<std::vector<float>>> batch_embed(const std::vector<std::string>& inputs);
|
||||
embedding_res_t Embed(const std::string& text);
|
||||
std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs);
|
||||
const std::string& get_vocab_file_name() const;
|
||||
bool is_remote() {
|
||||
return remote_embedder_ != nullptr;
|
||||
|
@ -9,6 +9,18 @@
|
||||
|
||||
|
||||
|
||||
struct embedding_res_t {
|
||||
std::vector<float> embedding;
|
||||
nlohmann::json error = nlohmann::json::object();
|
||||
int status_code;
|
||||
bool success;
|
||||
|
||||
embedding_res_t(const std::vector<float>& embedding) : embedding(embedding), success(true) {}
|
||||
|
||||
embedding_res_t(int status_code, const nlohmann::json& error) : error(error), success(false), status_code(status_code) {}
|
||||
};
|
||||
|
||||
|
||||
|
||||
class RemoteEmbedder {
|
||||
protected:
|
||||
@ -16,11 +28,12 @@ class RemoteEmbedder {
|
||||
static long call_remote_api(const std::string& method, const std::string& url, const std::string& body, std::string& res_body, std::map<std::string, std::string>& headers, const std::unordered_map<std::string, std::string>& req_headers);
|
||||
static inline ReplicationState* raft_server = nullptr;
|
||||
public:
|
||||
virtual Option<std::vector<float>> Embed(const std::string& text) = 0;
|
||||
virtual Option<std::vector<std::vector<float>>> batch_embed(const std::vector<std::string>& inputs) = 0;
|
||||
virtual embedding_res_t Embed(const std::string& text) = 0;
|
||||
virtual std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs) = 0;
|
||||
static void init(ReplicationState* rs) {
|
||||
raft_server = rs;
|
||||
}
|
||||
virtual ~RemoteEmbedder() = default;
|
||||
|
||||
};
|
||||
|
||||
@ -34,8 +47,8 @@ class OpenAIEmbedder : public RemoteEmbedder {
|
||||
public:
|
||||
OpenAIEmbedder(const std::string& openai_model_path, const std::string& api_key);
|
||||
static Option<bool> is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims);
|
||||
Option<std::vector<float>> Embed(const std::string& text) override;
|
||||
Option<std::vector<std::vector<float>>> batch_embed(const std::vector<std::string>& inputs) override;
|
||||
embedding_res_t Embed(const std::string& text) override;
|
||||
std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs) override;
|
||||
};
|
||||
|
||||
|
||||
@ -49,8 +62,8 @@ class GoogleEmbedder : public RemoteEmbedder {
|
||||
public:
|
||||
GoogleEmbedder(const std::string& google_api_key);
|
||||
static Option<bool> is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims);
|
||||
Option<std::vector<float>> Embed(const std::string& text) override;
|
||||
Option<std::vector<std::vector<float>>> batch_embed(const std::vector<std::string>& inputs) override;
|
||||
embedding_res_t Embed(const std::string& text) override;
|
||||
std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs) override;
|
||||
};
|
||||
|
||||
|
||||
@ -75,8 +88,8 @@ class GCPEmbedder : public RemoteEmbedder {
|
||||
GCPEmbedder(const std::string& project_id, const std::string& model_name, const std::string& access_token,
|
||||
const std::string& refresh_token, const std::string& client_id, const std::string& client_secret);
|
||||
static Option<bool> is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims);
|
||||
Option<std::vector<float>> Embed(const std::string& text) override;
|
||||
Option<std::vector<std::vector<float>>> batch_embed(const std::vector<std::string>& inputs) override;
|
||||
embedding_res_t Embed(const std::string& text) override;
|
||||
std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs) override;
|
||||
};
|
||||
|
||||
|
||||
|
@ -30,6 +30,7 @@ class TextEmbeddingTokenizer {
|
||||
private:
|
||||
public:
|
||||
virtual encoded_input_t Encode(const std::string& text) = 0;
|
||||
virtual ~TextEmbeddingTokenizer() = default;
|
||||
};
|
||||
|
||||
class BertTokenizerWrapper : public TextEmbeddingTokenizer {
|
||||
|
@ -15,6 +15,11 @@ struct KV {
|
||||
uint64_t key{};
|
||||
uint64_t distinct_key{};
|
||||
int64_t scores[3]{}; // match score + 2 custom attributes
|
||||
|
||||
// only to be used in hybrid search
|
||||
float vector_distance = 0.0f;
|
||||
int64_t text_match_score = 0;
|
||||
|
||||
reference_filter_result_t* reference_filter_result = nullptr;
|
||||
|
||||
// to be used only in final aggregation
|
||||
@ -43,6 +48,9 @@ struct KV {
|
||||
|
||||
query_indices = kv.query_indices;
|
||||
kv.query_indices = nullptr;
|
||||
|
||||
vector_distance = kv.vector_distance;
|
||||
text_match_score = kv.text_match_score;
|
||||
}
|
||||
|
||||
KV& operator=(KV&& kv) noexcept {
|
||||
@ -60,6 +68,9 @@ struct KV {
|
||||
delete[] query_indices;
|
||||
query_indices = kv.query_indices;
|
||||
kv.query_indices = nullptr;
|
||||
|
||||
vector_distance = kv.vector_distance;
|
||||
text_match_score = kv.text_match_score;
|
||||
}
|
||||
|
||||
return *this;
|
||||
@ -80,6 +91,9 @@ struct KV {
|
||||
delete[] query_indices;
|
||||
query_indices = kv.query_indices;
|
||||
kv.query_indices = nullptr;
|
||||
|
||||
vector_distance = kv.vector_distance;
|
||||
text_match_score = kv.text_match_score;
|
||||
}
|
||||
|
||||
return *this;
|
||||
|
@ -16,7 +16,7 @@ public:
|
||||
const index_operation_t op,
|
||||
const bool is_update,
|
||||
const std::string& fallback_field_type,
|
||||
const DIRTY_VALUES& dirty_values);
|
||||
const DIRTY_VALUES& dirty_values, const bool validate_embedding_fields = true);
|
||||
|
||||
|
||||
static Option<uint32_t> coerce_element(const field& a_field, nlohmann::json& document,
|
||||
|
@ -562,12 +562,22 @@ void Collection::batch_index(std::vector<index_record>& index_records, std::vect
|
||||
if(!index_record.indexed.ok()) {
|
||||
res["document"] = json_out[index_record.position];
|
||||
res["error"] = index_record.indexed.error();
|
||||
if (!index_record.embedding_res.empty()) {
|
||||
res["embedding_error"] = nlohmann::json::object();
|
||||
res["embedding_error"] = index_record.embedding_res;
|
||||
res["error"] = index_record.embedding_res["error"];
|
||||
}
|
||||
res["code"] = index_record.indexed.code();
|
||||
}
|
||||
} else {
|
||||
res["success"] = false;
|
||||
res["document"] = json_out[index_record.position];
|
||||
res["error"] = index_record.indexed.error();
|
||||
if (!index_record.embedding_res.empty()) {
|
||||
res["embedding_error"] = nlohmann::json::object();
|
||||
res["error"] = index_record.embedding_res["error"];
|
||||
res["embedding_error"] = index_record.embedding_res;
|
||||
}
|
||||
res["code"] = index_record.indexed.code();
|
||||
}
|
||||
|
||||
@ -599,11 +609,11 @@ Option<uint32_t> Collection::index_in_memory(nlohmann::json &document, uint32_t
|
||||
return Option<>(200);
|
||||
}
|
||||
|
||||
size_t Collection::batch_index_in_memory(std::vector<index_record>& index_records) {
|
||||
size_t Collection::batch_index_in_memory(std::vector<index_record>& index_records, const bool generate_embeddings) {
|
||||
std::unique_lock lock(mutex);
|
||||
size_t num_indexed = Index::batch_memory_index(index, index_records, default_sorting_field,
|
||||
search_schema, embedding_fields, fallback_field_type,
|
||||
token_separators, symbols_to_index, true);
|
||||
token_separators, symbols_to_index, true, generate_embeddings);
|
||||
num_documents += num_indexed;
|
||||
return num_indexed;
|
||||
}
|
||||
@ -732,6 +742,7 @@ void Collection::curate_results(string& actual_query, const string& filter_query
|
||||
Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<sort_by> & sort_fields,
|
||||
std::vector<sort_by>& sort_fields_std,
|
||||
const bool is_wildcard_query,
|
||||
const bool is_vector_query,
|
||||
const bool is_group_by_query) const {
|
||||
|
||||
size_t num_sort_expressions = 0;
|
||||
@ -906,7 +917,7 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
|
||||
}
|
||||
|
||||
if (sort_field_std.name != sort_field_const::text_match && sort_field_std.name != sort_field_const::eval &&
|
||||
sort_field_std.name != sort_field_const::seq_id && sort_field_std.name != sort_field_const::group_found) {
|
||||
sort_field_std.name != sort_field_const::seq_id && sort_field_std.name != sort_field_const::group_found && sort_field_std.name != sort_field_const::vector_distance) {
|
||||
|
||||
const auto field_it = search_schema.find(sort_field_std.name);
|
||||
if(field_it == search_schema.end() || !field_it.value().sort || !field_it.value().index) {
|
||||
@ -920,6 +931,11 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
|
||||
std::string error = "group_by parameters should not be empty when using sort_by group_found";
|
||||
return Option<bool>(404, error);
|
||||
}
|
||||
|
||||
if(sort_field_std.name == sort_field_const::vector_distance && !is_vector_query) {
|
||||
std::string error = "sort_by vector_distance is only supported for vector queries, semantic search and hybrid search.";
|
||||
return Option<bool>(404, error);
|
||||
}
|
||||
|
||||
StringUtils::toupper(sort_field_std.order);
|
||||
|
||||
@ -942,6 +958,10 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
|
||||
sort_fields_std.emplace_back(sort_field_const::text_match, sort_field_const::desc);
|
||||
}
|
||||
|
||||
if(is_vector_query) {
|
||||
sort_fields_std.emplace_back(sort_field_const::vector_distance, sort_field_const::asc);
|
||||
}
|
||||
|
||||
if(!default_sorting_field.empty()) {
|
||||
sort_fields_std.emplace_back(default_sorting_field, sort_field_const::desc);
|
||||
} else {
|
||||
@ -950,9 +970,15 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
|
||||
}
|
||||
|
||||
bool found_match_score = false;
|
||||
bool found_vector_distance = false;
|
||||
for(const auto & sort_field : sort_fields_std) {
|
||||
if(sort_field.name == sort_field_const::text_match) {
|
||||
found_match_score = true;
|
||||
}
|
||||
if(sort_field.name == sort_field_const::vector_distance) {
|
||||
found_vector_distance = true;
|
||||
}
|
||||
if(found_match_score && found_vector_distance) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -961,6 +987,10 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
|
||||
sort_fields_std.emplace_back(sort_field_const::text_match, sort_field_const::desc);
|
||||
}
|
||||
|
||||
if(!found_vector_distance && is_vector_query && sort_fields.size() < 3) {
|
||||
sort_fields_std.emplace_back(sort_field_const::vector_distance, sort_field_const::asc);
|
||||
}
|
||||
|
||||
if(sort_fields_std.size() > 3) {
|
||||
std::string message = "Only upto 3 sort_by fields can be specified.";
|
||||
return Option<bool>(422, message);
|
||||
@ -1077,7 +1107,8 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
|
||||
const text_match_type_t match_type,
|
||||
const size_t facet_sample_percent,
|
||||
const size_t facet_sample_threshold,
|
||||
const size_t page_offset) const {
|
||||
const size_t page_offset,
|
||||
const size_t vector_query_hits) const {
|
||||
|
||||
std::shared_lock lock(mutex);
|
||||
|
||||
@ -1132,6 +1163,7 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
|
||||
group_limit = 0;
|
||||
}
|
||||
|
||||
|
||||
vector_query_t vector_query;
|
||||
if(!vector_query_str.empty()) {
|
||||
auto parse_vector_op = VectorQueryOps::parse_vector_query_str(vector_query_str, vector_query, this);
|
||||
@ -1185,8 +1217,8 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
|
||||
// }
|
||||
|
||||
if(raw_query == "*") {
|
||||
std::string error = "Wildcard query is not supported for embedding fields.";
|
||||
return Option<nlohmann::json>(400, error);
|
||||
// ignore embedding field if query is a wildcard
|
||||
continue;
|
||||
}
|
||||
|
||||
TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance();
|
||||
@ -1206,14 +1238,18 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
|
||||
|
||||
std::string embed_query = embedder_manager.get_query_prefix(search_field.embed[fields::model_config]) + raw_query;
|
||||
auto embedding_op = embedder->Embed(embed_query);
|
||||
if(!embedding_op.ok()) {
|
||||
return Option<nlohmann::json>(400, embedding_op.error());
|
||||
if(!embedding_op.success) {
|
||||
if(!embedding_op.error["error"].get<std::string>().empty()) {
|
||||
return Option<nlohmann::json>(400, embedding_op.error["error"].get<std::string>());
|
||||
} else {
|
||||
return Option<nlohmann::json>(400, embedding_op.error.dump());
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<float> embedding = embedding_op.get();
|
||||
std::vector<float> embedding = embedding_op.embedding;
|
||||
vector_query._reset();
|
||||
vector_query.values = embedding;
|
||||
vector_query.field_name = field_name;
|
||||
vector_query.k = vector_query_hits;
|
||||
continue;
|
||||
}
|
||||
|
||||
@ -1443,10 +1479,11 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
|
||||
|
||||
bool is_wildcard_query = (query == "*");
|
||||
bool is_group_by_query = group_by_fields.size() > 0;
|
||||
bool is_vector_query = !vector_query.field_name.empty();
|
||||
|
||||
if(curated_sort_by.empty()) {
|
||||
auto sort_validation_op = validate_and_standardize_sort_fields(sort_fields,
|
||||
sort_fields_std, is_wildcard_query, is_group_by_query);
|
||||
sort_fields_std, is_wildcard_query, is_vector_query, is_group_by_query);
|
||||
if(!sort_validation_op.ok()) {
|
||||
return Option<nlohmann::json>(sort_validation_op.code(), sort_validation_op.error());
|
||||
}
|
||||
@ -1458,7 +1495,7 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
|
||||
}
|
||||
|
||||
auto sort_validation_op = validate_and_standardize_sort_fields(curated_sort_fields,
|
||||
sort_fields_std, is_wildcard_query, is_group_by_query);
|
||||
sort_fields_std, is_wildcard_query, is_vector_query, is_group_by_query);
|
||||
if(!sort_validation_op.ok()) {
|
||||
return Option<nlohmann::json>(sort_validation_op.code(), sort_validation_op.error());
|
||||
}
|
||||
@ -1888,7 +1925,10 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
|
||||
populate_text_match_info(wrapper_doc["text_match_info"],
|
||||
field_order_kv->scores[field_order_kv->match_score_index], match_type);
|
||||
} else {
|
||||
wrapper_doc["rank_fusion_score"] = Index::int64_t_to_float(field_order_kv->scores[field_order_kv->match_score_index]);
|
||||
wrapper_doc["hybrid_search_info"] = nlohmann::json::object();
|
||||
wrapper_doc["hybrid_search_info"]["rank_fusion_score"] = Index::int64_t_to_float(field_order_kv->scores[field_order_kv->match_score_index]);
|
||||
wrapper_doc["hybrid_search_info"]["text_match_score"] = field_order_kv->text_match_score;
|
||||
wrapper_doc["hybrid_search_info"]["vector_distance"] = field_order_kv->vector_distance;
|
||||
}
|
||||
}
|
||||
|
||||
@ -1906,7 +1946,7 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
|
||||
}
|
||||
|
||||
if(!vector_query.field_name.empty() && query == "*") {
|
||||
wrapper_doc["vector_distance"] = Index::int64_t_to_float(-field_order_kv->scores[0]);
|
||||
wrapper_doc["vector_distance"] = field_order_kv->vector_distance;
|
||||
}
|
||||
|
||||
hits_array.push_back(wrapper_doc);
|
||||
@ -3708,6 +3748,10 @@ Option<bool> Collection::batch_alter_data(const std::vector<field>& alter_fields
|
||||
nested_field_names.push_back(f.name);
|
||||
}
|
||||
|
||||
if(f.embed.count(fields::from) != 0) {
|
||||
embedding_fields.emplace(f.name, f);
|
||||
}
|
||||
|
||||
fields.push_back(f);
|
||||
}
|
||||
|
||||
@ -3762,9 +3806,9 @@ Option<bool> Collection::batch_alter_data(const std::vector<field>& alter_fields
|
||||
index->remove(seq_id, rec.doc, del_fields, true);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Index::batch_memory_index(index, iter_batch, default_sorting_field, schema_additions, embedding_fields,
|
||||
fallback_field_type, token_separators, symbols_to_index, true);
|
||||
fallback_field_type, token_separators, symbols_to_index, true, false);
|
||||
|
||||
iter_batch.clear();
|
||||
}
|
||||
@ -3869,6 +3913,21 @@ Option<bool> Collection::alter(nlohmann::json& alter_payload) {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// hide credentials in the alter payload return
|
||||
for(auto& field_json : alter_payload["fields"]) {
|
||||
if(field_json[fields::embed].count(fields::model_config) != 0) {
|
||||
hide_credential(field_json[fields::embed][fields::model_config], "api_key");
|
||||
hide_credential(field_json[fields::embed][fields::model_config], "access_token");
|
||||
hide_credential(field_json[fields::embed][fields::model_config], "refresh_token");
|
||||
hide_credential(field_json[fields::embed][fields::model_config], "client_id");
|
||||
hide_credential(field_json[fields::embed][fields::model_config], "client_secret");
|
||||
hide_credential(field_json[fields::embed][fields::model_config], "project_id");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
return Option<bool>(true);
|
||||
}
|
||||
|
||||
@ -4212,6 +4271,65 @@ Option<bool> Collection::validate_alter_payload(nlohmann::json& schema_changes,
|
||||
}
|
||||
}
|
||||
|
||||
for(const auto& kv: schema_changes["fields"].items()) {
|
||||
// validate embedding fields externally
|
||||
auto& field_json = kv.value();
|
||||
if(field_json.count(fields::embed) != 0 && !field_json[fields::embed].empty()) {
|
||||
if(!field_json[fields::embed].is_object()) {
|
||||
return Option<bool>(400, "Property `" + fields::embed + "` must be an object.");
|
||||
}
|
||||
|
||||
if(field_json[fields::embed].count(fields::from) == 0) {
|
||||
return Option<bool>(400, "Property `" + fields::embed + "` must contain a `" + fields::from + "` property.");
|
||||
}
|
||||
|
||||
if(!field_json[fields::embed][fields::from].is_array()) {
|
||||
return Option<bool>(400, "Property `" + fields::embed + "." + fields::from + "` must be an array.");
|
||||
}
|
||||
|
||||
if(field_json[fields::embed][fields::from].empty()) {
|
||||
return Option<bool>(400, "Property `" + fields::embed + "." + fields::from + "` must have at least one element.");
|
||||
}
|
||||
|
||||
for(auto& embed_from_field : field_json[fields::embed][fields::from]) {
|
||||
if(!embed_from_field.is_string()) {
|
||||
return Option<bool>(400, "Property `" + fields::embed + "." + fields::from + "` must contain only field names as strings.");
|
||||
}
|
||||
}
|
||||
|
||||
if(field_json[fields::type] != field_types::FLOAT_ARRAY) {
|
||||
return Option<bool>(400, "Property `" + fields::embed + "." + fields::from + "` is only allowed on a float array field.");
|
||||
}
|
||||
|
||||
for(auto& embed_from_field : field_json[fields::embed][fields::from]) {
|
||||
bool flag = false;
|
||||
for(const auto& field : search_schema) {
|
||||
if(field.name == embed_from_field) {
|
||||
if(field.type != field_types::STRING && field.type != field_types::STRING_ARRAY) {
|
||||
return Option<bool>(400, "Property `" + fields::embed + "." + fields::from + "` can only refer to string or string array fields.");
|
||||
}
|
||||
flag = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if(!flag) {
|
||||
for(const auto& other_kv: schema_changes["fields"].items()) {
|
||||
if(other_kv.value()["name"] == embed_from_field) {
|
||||
if(other_kv.value()[fields::type] != field_types::STRING && other_kv.value()[fields::type] != field_types::STRING_ARRAY) {
|
||||
return Option<bool>(400, "Property `" + fields::embed + "." + fields::from + "` can only refer to string or string array fields.");
|
||||
}
|
||||
flag = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if(!flag) {
|
||||
return Option<bool>(400, "Property `" + fields::embed + "." + fields::from + "` can only refer to string or string array fields.");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if(num_auto_detect_fields > 1) {
|
||||
return Option<bool>(400, "There can be only one field named `.*`.");
|
||||
}
|
||||
@ -4819,10 +4937,10 @@ void Collection::process_remove_field_for_embedding_fields(const field& the_fiel
|
||||
|
||||
void Collection::hide_credential(nlohmann::json& json, const std::string& credential_name) {
|
||||
if(json.count(credential_name) != 0) {
|
||||
// hide api key with * except first 3 chars
|
||||
// hide api key with * except first 5 chars
|
||||
std::string credential_name_str = json[credential_name];
|
||||
if(credential_name_str.size() > 3) {
|
||||
json[credential_name] = credential_name_str.replace(3, credential_name_str.size() - 3, credential_name_str.size() - 3, '*');
|
||||
if(credential_name_str.size() > 5) {
|
||||
json[credential_name] = credential_name_str.replace(5, credential_name_str.size() - 5, credential_name_str.size() - 5, '*');
|
||||
} else {
|
||||
json[credential_name] = credential_name_str.replace(0, credential_name_str.size(), credential_name_str.size(), '*');
|
||||
}
|
||||
|
@ -669,6 +669,7 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
|
||||
const char *MAX_FACET_VALUES = "max_facet_values";
|
||||
|
||||
const char *VECTOR_QUERY = "vector_query";
|
||||
const char *VECTOR_QUERY_HITS = "vector_query_hits";
|
||||
|
||||
const char *GROUP_BY = "group_by";
|
||||
const char *GROUP_LIMIT = "group_limit";
|
||||
@ -821,6 +822,7 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
|
||||
size_t max_extra_suffix = INT16_MAX;
|
||||
bool enable_highlight_v1 = true;
|
||||
text_match_type_t match_type = max_score;
|
||||
size_t vector_query_hits = 250;
|
||||
|
||||
size_t facet_sample_percent = 100;
|
||||
size_t facet_sample_threshold = 0;
|
||||
@ -847,6 +849,7 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
|
||||
{FILTER_CURATED_HITS, &filter_curated_hits_option},
|
||||
{FACET_SAMPLE_PERCENT, &facet_sample_percent},
|
||||
{FACET_SAMPLE_THRESHOLD, &facet_sample_threshold},
|
||||
{VECTOR_QUERY_HITS, &vector_query_hits},
|
||||
};
|
||||
|
||||
std::unordered_map<std::string, std::string*> str_values = {
|
||||
@ -1383,7 +1386,7 @@ Option<bool> CollectionManager::load_collection(const nlohmann::json &collection
|
||||
// batch must match atleast the number of shards
|
||||
if(exceeds_batch_mem_threshold || (num_valid_docs % batch_size == 0) || last_record) {
|
||||
size_t num_records = index_records.size();
|
||||
size_t num_indexed = collection->batch_index_in_memory(index_records);
|
||||
size_t num_indexed = collection->batch_index_in_memory(index_records, false);
|
||||
batch_doc_str_size = 0;
|
||||
|
||||
if(num_indexed != num_records) {
|
||||
|
@ -928,14 +928,42 @@ bool post_add_document(const std::shared_ptr<http_req>& req, const std::shared_p
|
||||
const index_operation_t operation = get_index_operation(req->params[ACTION]);
|
||||
const auto& dirty_values = collection->parse_dirty_values_option(req->params[DIRTY_VALUES_PARAM]);
|
||||
|
||||
Option<nlohmann::json> inserted_doc_op = collection->add(req->body, operation, "", dirty_values);
|
||||
nlohmann::json document;
|
||||
std::vector<std::string> json_lines = {req->body};
|
||||
const nlohmann::json& inserted_doc_op = collection->add_many(json_lines, document, operation, "", dirty_values, false, false);
|
||||
|
||||
if(!inserted_doc_op.ok()) {
|
||||
res->set(inserted_doc_op.code(), inserted_doc_op.error());
|
||||
if(!inserted_doc_op["success"].get<bool>()) {
|
||||
nlohmann::json res_doc;
|
||||
|
||||
try {
|
||||
res_doc = nlohmann::json::parse(json_lines[0]);
|
||||
} catch(const std::exception& e) {
|
||||
LOG(ERROR) << "JSON error: " << e.what();
|
||||
res->set_400("Bad JSON.");
|
||||
return false;
|
||||
}
|
||||
|
||||
res->status_code = res_doc["code"].get<size_t>();
|
||||
// erase keys from res_doc except error and embedding_error
|
||||
for(auto it = res_doc.begin(); it != res_doc.end(); ) {
|
||||
if(it.key() != "error" && it.key() != "embedding_error") {
|
||||
it = res_doc.erase(it);
|
||||
} else {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
|
||||
// rename error to message if not empty and exists
|
||||
if(res_doc.count("error") != 0 && !res_doc["error"].get<std::string>().empty()) {
|
||||
res_doc["message"] = res_doc["error"];
|
||||
res_doc.erase("error");
|
||||
}
|
||||
|
||||
res->body = res_doc.dump();
|
||||
return false;
|
||||
}
|
||||
|
||||
res->set_201(inserted_doc_op.get().dump(-1, ' ', false, nlohmann::detail::error_handler_t::ignore));
|
||||
res->set_201(document.dump(-1, ' ', false, nlohmann::detail::error_handler_t::ignore));
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -279,6 +279,8 @@ size_t HttpClient::curl_write_async_done(void *context, curl_socket_t item) {
|
||||
|
||||
if(!req_res->res->is_alive) {
|
||||
// underlying client request is dead, don't try to send anymore data
|
||||
// also, close the socket as we've overridden the close socket handler!
|
||||
close(item);
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
@ -504,6 +504,9 @@ int HttpServer::catch_all_handler(h2o_handler_t *_h2o_handler, h2o_req_t *req) {
|
||||
);
|
||||
*allocated_generator = custom_gen;
|
||||
|
||||
// ensures that the first response need not wait for previous chunk to be done sending
|
||||
response->notify();
|
||||
|
||||
//LOG(INFO) << "Init res: " << custom_gen->response << ", ref count: " << custom_gen->response.use_count();
|
||||
|
||||
if(root_resource == "multi_search") {
|
||||
@ -544,7 +547,9 @@ bool HttpServer::is_write_request(const std::string& root_resource, const std::s
|
||||
return false;
|
||||
}
|
||||
|
||||
bool write_free_request = (root_resource == "multi_search" || root_resource == "operations");
|
||||
bool write_free_request = (root_resource == "multi_search" || root_resource == "proxy" ||
|
||||
root_resource == "operations");
|
||||
|
||||
if(!write_free_request &&
|
||||
(http_method == "POST" || http_method == "PUT" ||
|
||||
http_method == "DELETE" || http_method == "PATCH")) {
|
||||
@ -632,9 +637,6 @@ int HttpServer::async_req_cb(void *ctx, int is_end_stream) {
|
||||
|
||||
if(request->first_chunk_aggregate) {
|
||||
request->first_chunk_aggregate = false;
|
||||
|
||||
// ensures that the first response need not wait for previous chunk to be done sending
|
||||
response->notify();
|
||||
}
|
||||
|
||||
// default value for last_chunk_aggregate is false
|
||||
@ -821,7 +823,7 @@ void HttpServer::stream_response(stream_response_state_t& state) {
|
||||
h2o_start_response(req, state.generator);
|
||||
}
|
||||
|
||||
h2o_send(req, &state.res_body, 1, H2O_SEND_STATE_FINAL);
|
||||
h2o_send(req, &state.res_buff, 1, H2O_SEND_STATE_FINAL);
|
||||
h2o_dispose_request(req);
|
||||
|
||||
return ;
|
||||
@ -833,13 +835,13 @@ void HttpServer::stream_response(stream_response_state_t& state) {
|
||||
h2o_start_response(req, state.generator);
|
||||
}
|
||||
|
||||
if(state.res_body.len == 0 && state.send_state != H2O_SEND_STATE_FINAL) {
|
||||
if(state.res_buff.len == 0 && state.send_state != H2O_SEND_STATE_FINAL) {
|
||||
// without this guard, http streaming will break
|
||||
state.generator->proceed(state.generator, req);
|
||||
return;
|
||||
}
|
||||
|
||||
h2o_send(req, &state.res_body, 1, state.send_state);
|
||||
h2o_send(req, &state.res_buff, 1, state.send_state);
|
||||
|
||||
//LOG(INFO) << "stream_response after send";
|
||||
}
|
||||
|
168
src/index.cpp
168
src/index.cpp
@ -43,6 +43,7 @@ spp::sparse_hash_map<uint32_t, int64_t> Index::seq_id_sentinel_value;
|
||||
spp::sparse_hash_map<uint32_t, int64_t> Index::eval_sentinel_value;
|
||||
spp::sparse_hash_map<uint32_t, int64_t> Index::geo_sentinel_value;
|
||||
spp::sparse_hash_map<uint32_t, int64_t> Index::str_sentinel_value;
|
||||
spp::sparse_hash_map<uint32_t, int64_t> Index::vector_distance_sentinel_value;
|
||||
|
||||
struct token_posting_t {
|
||||
uint32_t token_id;
|
||||
@ -427,10 +428,10 @@ void Index::validate_and_preprocess(Index *index, std::vector<index_record>& ite
|
||||
const std::string& fallback_field_type,
|
||||
const std::vector<char>& token_separators,
|
||||
const std::vector<char>& symbols_to_index,
|
||||
const bool do_validation) {
|
||||
const bool do_validation, const bool generate_embeddings) {
|
||||
|
||||
// runs in a partitioned thread
|
||||
std::vector<nlohmann::json*> docs_to_embed;
|
||||
std::vector<index_record*> records_to_embed;
|
||||
|
||||
for(size_t i = 0; i < batch_size; i++) {
|
||||
index_record& index_rec = iter_batch[batch_start_index + i];
|
||||
@ -453,7 +454,7 @@ void Index::validate_and_preprocess(Index *index, std::vector<index_record>& ite
|
||||
index_rec.operation,
|
||||
index_rec.is_update,
|
||||
fallback_field_type,
|
||||
index_rec.dirty_values);
|
||||
index_rec.dirty_values, generate_embeddings);
|
||||
|
||||
if(!validation_op.ok()) {
|
||||
index_rec.index_failure(validation_op.code(), validation_op.error());
|
||||
@ -467,14 +468,16 @@ void Index::validate_and_preprocess(Index *index, std::vector<index_record>& ite
|
||||
index_rec.new_doc, index_rec.del_doc);
|
||||
scrub_reindex_doc(search_schema, index_rec.doc, index_rec.del_doc, index_rec.old_doc);
|
||||
|
||||
for(auto& field: index_rec.doc.items()) {
|
||||
for(auto& embedding_field : embedding_fields) {
|
||||
if(!embedding_field.embed[fields::from].is_null()) {
|
||||
auto embed_from_vector = embedding_field.embed[fields::from].get<std::vector<std::string>>();
|
||||
for(auto& embed_from: embed_from_vector) {
|
||||
if(embed_from == field.key()) {
|
||||
docs_to_embed.push_back(&index_rec.new_doc);
|
||||
break;
|
||||
if(generate_embeddings) {
|
||||
for(auto& field: index_rec.doc.items()) {
|
||||
for(auto& embedding_field : embedding_fields) {
|
||||
if(!embedding_field.embed[fields::from].is_null()) {
|
||||
auto embed_from_vector = embedding_field.embed[fields::from].get<std::vector<std::string>>();
|
||||
for(auto& embed_from: embed_from_vector) {
|
||||
if(embed_from == field.key()) {
|
||||
records_to_embed.push_back(&index_rec);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -482,7 +485,9 @@ void Index::validate_and_preprocess(Index *index, std::vector<index_record>& ite
|
||||
}
|
||||
} else {
|
||||
handle_doc_ops(search_schema, index_rec.doc, index_rec.old_doc);
|
||||
docs_to_embed.push_back(&index_rec.doc);
|
||||
if(generate_embeddings) {
|
||||
records_to_embed.push_back(&index_rec);
|
||||
}
|
||||
}
|
||||
|
||||
compute_token_offsets_facets(index_rec, search_schema, token_separators, symbols_to_index);
|
||||
@ -512,13 +517,8 @@ void Index::validate_and_preprocess(Index *index, std::vector<index_record>& ite
|
||||
index_rec.index_failure(400, e.what());
|
||||
}
|
||||
}
|
||||
|
||||
auto embed_op = batch_embed_fields(docs_to_embed, embedding_fields, search_schema);
|
||||
if(!embed_op.ok()) {
|
||||
for(size_t i = 0; i < batch_size; i++) {
|
||||
index_record& index_rec = iter_batch[batch_start_index + i];
|
||||
index_rec.index_failure(embed_op.code(), embed_op.error());
|
||||
}
|
||||
if(generate_embeddings) {
|
||||
batch_embed_fields(records_to_embed, embedding_fields, search_schema);
|
||||
}
|
||||
}
|
||||
|
||||
@ -529,7 +529,7 @@ size_t Index::batch_memory_index(Index *index, std::vector<index_record>& iter_b
|
||||
const std::string& fallback_field_type,
|
||||
const std::vector<char>& token_separators,
|
||||
const std::vector<char>& symbols_to_index,
|
||||
const bool do_validation) {
|
||||
const bool do_validation, const bool generate_embeddings) {
|
||||
|
||||
const size_t concurrency = 4;
|
||||
const size_t num_threads = std::min(concurrency, iter_batch.size());
|
||||
@ -559,7 +559,7 @@ size_t Index::batch_memory_index(Index *index, std::vector<index_record>& iter_b
|
||||
index->thread_pool->enqueue([&, batch_index, batch_len]() {
|
||||
write_log_index = local_write_log_index;
|
||||
validate_and_preprocess(index, iter_batch, batch_index, batch_len, default_sorting_field, search_schema,
|
||||
embedding_fields, fallback_field_type, token_separators, symbols_to_index, do_validation);
|
||||
embedding_fields, fallback_field_type, token_separators, symbols_to_index, do_validation, generate_embeddings);
|
||||
|
||||
std::unique_lock<std::mutex> lock(m_process);
|
||||
num_processed++;
|
||||
@ -2530,12 +2530,12 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
|
||||
}
|
||||
|
||||
int64_t scores[3] = {0};
|
||||
scores[0] = -float_to_int64_t(vec_dist_score);
|
||||
int64_t match_score_index = -1;
|
||||
|
||||
//LOG(INFO) << "SEQ_ID: " << seq_id << ", score: " << dist_label.first;
|
||||
compute_sort_scores(sort_fields_std, sort_order, field_values, geopoint_indices, seq_id, 0, 0, scores, match_score_index, vec_dist_score);
|
||||
|
||||
KV kv(searched_queries.size(), seq_id, distinct_id, match_score_index, scores, nullptr);
|
||||
kv.vector_distance = vec_dist_score;
|
||||
int ret = topster->add(&kv);
|
||||
|
||||
if(group_limit != 0 && ret < 2) {
|
||||
@ -2756,7 +2756,9 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
|
||||
VectorFilterFunctor filterFunctor(filter_result_iterator);
|
||||
auto& field_vector_index = vector_index.at(vector_query.field_name);
|
||||
std::vector<std::pair<float, size_t>> dist_labels;
|
||||
auto k = std::max<size_t>(vector_query.k, fetch_size);
|
||||
// use k as 100 by default for ensuring results stability in pagination
|
||||
size_t default_k = 100;
|
||||
auto k = std::max<size_t>(vector_query.k, default_k);
|
||||
|
||||
if(field_vector_index->distance_type == cosine) {
|
||||
std::vector<float> normalized_q(vector_query.values.size());
|
||||
@ -2792,29 +2794,62 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
|
||||
continue;
|
||||
}
|
||||
// (1 / rank_of_document) * WEIGHT)
|
||||
result->text_match_score = result->scores[result->match_score_index];
|
||||
result->scores[result->match_score_index] = float_to_int64_t((1.0 / (i + 1)) * TEXT_MATCH_WEIGHT);
|
||||
}
|
||||
|
||||
for(int i = 0; i < vec_results.size(); i++) {
|
||||
auto& result = vec_results[i];
|
||||
auto doc_id = result.first;
|
||||
std::vector<uint32_t> vec_search_ids; // list of IDs found only in vector search
|
||||
|
||||
for(size_t res_index = 0; res_index < vec_results.size(); res_index++) {
|
||||
auto& vec_result = vec_results[res_index];
|
||||
auto doc_id = vec_result.first;
|
||||
auto result_it = topster->kv_map.find(doc_id);
|
||||
|
||||
if(result_it != topster->kv_map.end()&& result_it->second->match_score_index >= 0 && result_it->second->match_score_index <= 2) {
|
||||
if(result_it != topster->kv_map.end()) {
|
||||
if(result_it->second->match_score_index < 0 || result_it->second->match_score_index > 2) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// result overlaps with keyword search: we have to combine the scores
|
||||
|
||||
auto result = result_it->second;
|
||||
// old_score + (1 / rank_of_document) * WEIGHT)
|
||||
result->scores[result->match_score_index] = float_to_int64_t((int64_t_to_float(result->scores[result->match_score_index])) + ((1.0 / (i + 1)) * VECTOR_SEARCH_WEIGHT));
|
||||
result->vector_distance = vec_result.second;
|
||||
result->scores[result->match_score_index] = float_to_int64_t(
|
||||
(int64_t_to_float(result->scores[result->match_score_index])) +
|
||||
((1.0 / (res_index + 1)) * VECTOR_SEARCH_WEIGHT));
|
||||
|
||||
for(size_t i = 0;i < 3; i++) {
|
||||
if(field_values[i] == &vector_distance_sentinel_value) {
|
||||
result->scores[i] = float_to_int64_t(vec_result.second);
|
||||
}
|
||||
|
||||
if(sort_order[i] == -1) {
|
||||
result->scores[i] = -result->scores[i];
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
int64_t scores[3] = {0};
|
||||
// Result has been found only in vector search: we have to add it to both KV and result_ids
|
||||
// (1 / rank_of_document) * WEIGHT)
|
||||
scores[0] = float_to_int64_t((1.0 / (i + 1)) * VECTOR_SEARCH_WEIGHT);
|
||||
int64_t match_score_index = 0;
|
||||
int64_t scores[3] = {0};
|
||||
int64_t match_score = float_to_int64_t((1.0 / (res_index + 1)) * VECTOR_SEARCH_WEIGHT);
|
||||
int64_t match_score_index = -1;
|
||||
compute_sort_scores(sort_fields_std, sort_order, field_values, geopoint_indices, doc_id, 0, match_score, scores, match_score_index, vec_result.second);
|
||||
KV kv(searched_queries.size(), doc_id, doc_id, match_score_index, scores);
|
||||
kv.vector_distance = vec_result.second;
|
||||
topster->add(&kv);
|
||||
++all_result_ids_len;
|
||||
vec_search_ids.push_back(doc_id);
|
||||
}
|
||||
}
|
||||
|
||||
if(!vec_search_ids.empty()) {
|
||||
uint32_t* new_all_result_ids = nullptr;
|
||||
all_result_ids_len = ArrayUtils::or_scalar(all_result_ids, all_result_ids_len, &vec_search_ids[0],
|
||||
vec_search_ids.size(), &new_all_result_ids);
|
||||
delete[] all_result_ids;
|
||||
all_result_ids = new_all_result_ids;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -3778,7 +3813,7 @@ void Index::compute_sort_scores(const std::vector<sort_by>& sort_fields, const i
|
||||
std::array<spp::sparse_hash_map<uint32_t, int64_t>*, 3> field_values,
|
||||
const std::vector<size_t>& geopoint_indices,
|
||||
uint32_t seq_id, size_t filter_index, int64_t max_field_match_score,
|
||||
int64_t* scores, int64_t& match_score_index) const {
|
||||
int64_t* scores, int64_t& match_score_index, float vector_distance) const {
|
||||
|
||||
int64_t geopoint_distances[3];
|
||||
|
||||
@ -3873,6 +3908,8 @@ void Index::compute_sort_scores(const std::vector<sort_by>& sort_fields, const i
|
||||
}
|
||||
|
||||
scores[0] = int64_t(found);
|
||||
} else if(field_values[0] == &vector_distance_sentinel_value) {
|
||||
scores[0] = float_to_int64_t(vector_distance);
|
||||
} else {
|
||||
auto it = field_values[0]->find(seq_id);
|
||||
scores[0] = (it == field_values[0]->end()) ? default_score : it->second;
|
||||
@ -3929,6 +3966,8 @@ void Index::compute_sort_scores(const std::vector<sort_by>& sort_fields, const i
|
||||
}
|
||||
|
||||
scores[1] = int64_t(found);
|
||||
} else if(field_values[1] == &vector_distance_sentinel_value) {
|
||||
scores[1] = float_to_int64_t(vector_distance);
|
||||
} else {
|
||||
auto it = field_values[1]->find(seq_id);
|
||||
scores[1] = (it == field_values[1]->end()) ? default_score : it->second;
|
||||
@ -3981,6 +4020,8 @@ void Index::compute_sort_scores(const std::vector<sort_by>& sort_fields, const i
|
||||
}
|
||||
|
||||
scores[2] = int64_t(found);
|
||||
} else if(field_values[2] == &vector_distance_sentinel_value) {
|
||||
scores[2] = float_to_int64_t(vector_distance);
|
||||
} else {
|
||||
auto it = field_values[2]->find(seq_id);
|
||||
scores[2] = (it == field_values[2]->end()) ? default_score : it->second;
|
||||
@ -4679,15 +4720,14 @@ void Index::populate_sort_mapping(int* sort_order, std::vector<size_t>& geopoint
|
||||
field_values[i] = &seq_id_sentinel_value;
|
||||
} else if (sort_fields_std[i].name == sort_field_const::eval) {
|
||||
field_values[i] = &eval_sentinel_value;
|
||||
|
||||
auto filter_result_iterator = filter_result_iterator_t("", this, sort_fields_std[i].eval.filter_tree_root);
|
||||
auto filter_init_op = filter_result_iterator.init_status();
|
||||
if (!filter_init_op.ok()) {
|
||||
return;
|
||||
}
|
||||
|
||||
sort_fields_std[i].eval.size = filter_result_iterator.to_filter_id_array(sort_fields_std[i].eval.ids);
|
||||
|
||||
} else if(sort_fields_std[i].name == sort_field_const::vector_distance) {
|
||||
field_values[i] = &vector_distance_sentinel_value;
|
||||
} else if (search_schema.count(sort_fields_std[i].name) != 0 && search_schema.at(sort_fields_std[i].name).sort) {
|
||||
if (search_schema.at(sort_fields_std[i].name).type == field_types::GEOPOINT_ARRAY) {
|
||||
geopoint_indices.push_back(i);
|
||||
@ -6091,13 +6131,26 @@ bool Index::common_results_exist(std::vector<art_leaf*>& leaves, bool must_match
|
||||
}
|
||||
|
||||
|
||||
Option<bool> Index::batch_embed_fields(std::vector<nlohmann::json*>& documents,
|
||||
void Index::batch_embed_fields(std::vector<index_record*>& records,
|
||||
const tsl::htrie_map<char, field>& embedding_fields,
|
||||
const tsl::htrie_map<char, field> & search_schema) {
|
||||
for(const auto& field : embedding_fields) {
|
||||
std::vector<std::pair<nlohmann::json*, std::string>> texts_to_embed;
|
||||
std::vector<std::pair<index_record*, std::string>> texts_to_embed;
|
||||
auto indexing_prefix = TextEmbedderManager::get_instance().get_indexing_prefix(field.embed[fields::model_config]);
|
||||
for(auto& document : documents) {
|
||||
for(auto& record : records) {
|
||||
if(!record->indexed.ok()) {
|
||||
continue;
|
||||
}
|
||||
nlohmann::json* document;
|
||||
if(record->is_update) {
|
||||
document = &record->new_doc;
|
||||
} else {
|
||||
document = &record->doc;
|
||||
}
|
||||
|
||||
if(document == nullptr) {
|
||||
continue;
|
||||
}
|
||||
std::string text = indexing_prefix;
|
||||
auto embed_from = field.embed[fields::from].get<std::vector<std::string>>();
|
||||
for(const auto& field_name : embed_from) {
|
||||
@ -6110,8 +6163,8 @@ Option<bool> Index::batch_embed_fields(std::vector<nlohmann::json*>& documents,
|
||||
}
|
||||
}
|
||||
}
|
||||
if(!text.empty()) {
|
||||
texts_to_embed.push_back(std::make_pair(document, text));
|
||||
if(text != indexing_prefix) {
|
||||
texts_to_embed.push_back(std::make_pair(record, text));
|
||||
}
|
||||
}
|
||||
|
||||
@ -6123,13 +6176,15 @@ Option<bool> Index::batch_embed_fields(std::vector<nlohmann::json*>& documents,
|
||||
auto embedder_op = embedder_manager.get_text_embedder(field.embed[fields::model_config]);
|
||||
|
||||
if(!embedder_op.ok()) {
|
||||
return Option<bool>(400, embedder_op.error());
|
||||
LOG(ERROR) << "Error while getting embedder for model: " << field.embed[fields::model_config];
|
||||
LOG(ERROR) << "Error: " << embedder_op.error();
|
||||
return;
|
||||
}
|
||||
|
||||
// sort texts by length
|
||||
std::sort(texts_to_embed.begin(), texts_to_embed.end(),
|
||||
[](const std::pair<nlohmann::json*, std::string>& a,
|
||||
const std::pair<nlohmann::json*, std::string>& b) {
|
||||
[](const std::pair<index_record*, std::string>& a,
|
||||
const std::pair<index_record*, std::string>& b) {
|
||||
return a.second.size() < b.second.size();
|
||||
});
|
||||
|
||||
@ -6139,19 +6194,24 @@ Option<bool> Index::batch_embed_fields(std::vector<nlohmann::json*>& documents,
|
||||
texts.push_back(text_to_embed.second);
|
||||
}
|
||||
|
||||
auto embedding_op = embedder_op.get()->batch_embed(texts);
|
||||
if(!embedding_op.ok()) {
|
||||
return Option<bool>(400, embedding_op.error());
|
||||
}
|
||||
auto embeddings = embedder_op.get()->batch_embed(texts);
|
||||
|
||||
auto embeddings = embedding_op.get();
|
||||
for(size_t i = 0; i < embeddings.size(); i++) {
|
||||
auto& embedding = embeddings[i];
|
||||
auto& document = texts_to_embed[i].first;
|
||||
(*document)[field.name] = embedding;
|
||||
auto& embedding_res = embeddings[i];
|
||||
if(!embedding_res.success) {
|
||||
texts_to_embed[i].first->embedding_res = embedding_res.error;
|
||||
texts_to_embed[i].first->index_failure(embedding_res.status_code, "");
|
||||
continue;
|
||||
}
|
||||
nlohmann::json* document;
|
||||
if(texts_to_embed[i].first->is_update) {
|
||||
document = &texts_to_embed[i].first->new_doc;
|
||||
} else {
|
||||
document = &texts_to_embed[i].first->doc;
|
||||
}
|
||||
(*document)[field.name] = embedding_res.embedding;
|
||||
}
|
||||
}
|
||||
return Option<bool>(true);
|
||||
}
|
||||
|
||||
/*
|
||||
|
@ -194,7 +194,8 @@ void ReplicationState::write(const std::shared_ptr<http_req>& request, const std
|
||||
auto resource_check = cached_resource_stat_t::get_instance().has_enough_resources(raft_dir_path,
|
||||
config->get_disk_used_max_percentage(), config->get_memory_used_max_percentage());
|
||||
|
||||
if (resource_check != cached_resource_stat_t::OK && request->http_method != "DELETE") {
|
||||
if (resource_check != cached_resource_stat_t::OK &&
|
||||
request->http_method != "DELETE" && request->path_without_query != "/health") {
|
||||
response->set_422("Rejecting write: running out of resource type: " +
|
||||
std::string(magic_enum::enum_name(resource_check)));
|
||||
response->final = true;
|
||||
@ -967,6 +968,36 @@ bool ReplicationState::get_ext_snapshot_succeeded() {
|
||||
std::string ReplicationState::get_leader_url() const {
|
||||
std::shared_lock lock(node_mutex);
|
||||
|
||||
if(!node) {
|
||||
const Option<std::string> & refreshed_nodes_op = Config::fetch_nodes_config(config->get_nodes());
|
||||
|
||||
if(!refreshed_nodes_op.ok()) {
|
||||
LOG(WARNING) << "Error while fetching peer configuration: " << refreshed_nodes_op.error();
|
||||
return "";
|
||||
}
|
||||
|
||||
const std::string& nodes_config = ReplicationState::to_nodes_config(peering_endpoint,
|
||||
Config::get_instance().get_api_port(),
|
||||
|
||||
refreshed_nodes_op.get());
|
||||
std::vector<braft::PeerId> peers;
|
||||
braft::Configuration peer_config;
|
||||
peer_config.parse_from(nodes_config);
|
||||
peer_config.list_peers(&peers);
|
||||
|
||||
if(peers.empty()) {
|
||||
LOG(WARNING) << "No peers found in nodes config: " << nodes_config;
|
||||
return "";
|
||||
}
|
||||
|
||||
|
||||
const std::string protocol = api_uses_ssl ? "https" : "http";
|
||||
std::string url = get_node_url_path(peers[0].to_string(), "/", protocol);
|
||||
|
||||
LOG(INFO) << "Returning first peer as leader URL: " << url;
|
||||
return url;
|
||||
}
|
||||
|
||||
if(node->leader_id().is_empty()) {
|
||||
LOG(ERROR) << "Could not get leader status, as node does not have a leader!";
|
||||
return "";
|
||||
|
@ -46,7 +46,7 @@ TextEmbedder::TextEmbedder(const std::string& model_name) {
|
||||
|
||||
TextEmbedder::TextEmbedder(const nlohmann::json& model_config) {
|
||||
auto model_name = model_config["model_name"].get<std::string>();
|
||||
LOG(INFO) << "Loading model from remote: " << model_name;
|
||||
LOG(INFO) << "Initializing remote embedding model: " << model_name;
|
||||
auto model_namespace = TextEmbedderManager::get_model_namespace(model_name);
|
||||
|
||||
if(model_namespace == "openai") {
|
||||
@ -83,7 +83,7 @@ std::vector<float> TextEmbedder::mean_pooling(const std::vector<std::vector<floa
|
||||
return pooled_output;
|
||||
}
|
||||
|
||||
Option<std::vector<float>> TextEmbedder::Embed(const std::string& text) {
|
||||
embedding_res_t TextEmbedder::Embed(const std::string& text) {
|
||||
if(is_remote()) {
|
||||
return remote_embedder_->Embed(text);
|
||||
} else {
|
||||
@ -129,12 +129,12 @@ Option<std::vector<float>> TextEmbedder::Embed(const std::string& text) {
|
||||
}
|
||||
auto pooled_output = mean_pooling(output);
|
||||
|
||||
return Option<std::vector<float>>(pooled_output);
|
||||
return embedding_res_t(pooled_output);
|
||||
}
|
||||
}
|
||||
|
||||
Option<std::vector<std::vector<float>>> TextEmbedder::batch_embed(const std::vector<std::string>& inputs) {
|
||||
std::vector<std::vector<float>> outputs;
|
||||
std::vector<embedding_res_t> TextEmbedder::batch_embed(const std::vector<std::string>& inputs) {
|
||||
std::vector<embedding_res_t> outputs;
|
||||
if(!is_remote()) {
|
||||
for(int i = 0; i < inputs.size(); i += 8) {
|
||||
auto input_batch = std::vector<std::string>(inputs.begin() + i, inputs.begin() + std::min(i + 8, static_cast<int>(inputs.size())));
|
||||
@ -193,7 +193,7 @@ Option<std::vector<std::vector<float>>> TextEmbedder::batch_embed(const std::vec
|
||||
// if seq length is 0, return empty vector
|
||||
if(input_shapes[0][1] == 0) {
|
||||
for(int i = 0; i < input_batch.size(); i++) {
|
||||
outputs.push_back(std::vector<float>());
|
||||
outputs.push_back(embedding_res_t(400, nlohmann::json({{"error", "Invalid input: empty sequence"}})));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
@ -211,17 +211,14 @@ Option<std::vector<std::vector<float>>> TextEmbedder::batch_embed(const std::vec
|
||||
}
|
||||
output.push_back(output_row);
|
||||
}
|
||||
outputs.push_back(mean_pooling(output));
|
||||
outputs.push_back(embedding_res_t(mean_pooling(output)));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
auto embed_op = remote_embedder_->batch_embed(inputs);
|
||||
if(!embed_op.ok()) {
|
||||
return Option<std::vector<std::vector<float>>>(embed_op.code(), embed_op.error());
|
||||
}
|
||||
outputs = embed_op.get();
|
||||
outputs = std::move(remote_embedder_->batch_embed(inputs));
|
||||
}
|
||||
return Option<std::vector<std::vector<float>>>(outputs);
|
||||
|
||||
return outputs;
|
||||
}
|
||||
|
||||
TextEmbedder::~TextEmbedder() {
|
||||
|
@ -94,7 +94,6 @@ const std::string& TextEmbedderManager::get_model_dir() {
|
||||
}
|
||||
|
||||
TextEmbedderManager::~TextEmbedderManager() {
|
||||
delete_all_text_embedders();
|
||||
}
|
||||
|
||||
const std::string TextEmbedderManager::get_absolute_model_path(const std::string& model_name) {
|
||||
@ -117,8 +116,9 @@ const bool TextEmbedderManager::check_md5(const std::string& file_path, const st
|
||||
std::stringstream ss,res;
|
||||
ss << stream.rdbuf();
|
||||
MD5((unsigned char*)ss.str().c_str(), ss.str().length(), md5);
|
||||
for (int i = 0; i < MD5_DIGEST_LENGTH; i++) {
|
||||
res << std::hex << (int)md5[i];
|
||||
// convert md5 to hex string with leading zeros
|
||||
for(int i = 0; i < MD5_DIGEST_LENGTH; i++) {
|
||||
res << std::hex << std::setfill('0') << std::setw(2) << (int)md5[i];
|
||||
}
|
||||
return res.str() == target_md5;
|
||||
}
|
||||
|
@ -13,15 +13,16 @@ Option<bool> RemoteEmbedder::validate_string_properties(const nlohmann::json& mo
|
||||
|
||||
long RemoteEmbedder::call_remote_api(const std::string& method, const std::string& url, const std::string& body, std::string& res_body,
|
||||
std::map<std::string, std::string>& headers, const std::unordered_map<std::string, std::string>& req_headers) {
|
||||
if(raft_server == nullptr) {
|
||||
if(raft_server == nullptr || raft_server->get_leader_url().empty()) {
|
||||
if(method == "GET") {
|
||||
return HttpClient::get_instance().get_response(url, res_body, headers, req_headers, 10000, true);
|
||||
return HttpClient::get_instance().get_response(url, res_body, headers, req_headers, 100000, true);
|
||||
} else if(method == "POST") {
|
||||
return HttpClient::get_instance().post_response(url, body, res_body, headers, req_headers, 10000, true);
|
||||
return HttpClient::get_instance().post_response(url, body, res_body, headers, req_headers, 100000, true);
|
||||
} else {
|
||||
return 400;
|
||||
}
|
||||
}
|
||||
|
||||
auto leader_url = raft_server->get_leader_url();
|
||||
leader_url += "proxy";
|
||||
nlohmann::json req_body;
|
||||
@ -103,7 +104,7 @@ Option<bool> OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config,
|
||||
return Option<bool>(true);
|
||||
}
|
||||
|
||||
Option<std::vector<float>> OpenAIEmbedder::Embed(const std::string& text) {
|
||||
embedding_res_t OpenAIEmbedder::Embed(const std::string& text) {
|
||||
std::unordered_map<std::string, std::string> headers;
|
||||
std::map<std::string, std::string> res_headers;
|
||||
headers["Authorization"] = "Bearer " + api_key;
|
||||
@ -116,15 +117,23 @@ Option<std::vector<float>> OpenAIEmbedder::Embed(const std::string& text) {
|
||||
auto res_code = call_remote_api("POST", OPENAI_CREATE_EMBEDDING, req_body.dump(), res, res_headers, headers);
|
||||
if (res_code != 200) {
|
||||
nlohmann::json json_res = nlohmann::json::parse(res);
|
||||
if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) {
|
||||
return Option<std::vector<float>>(400, "OpenAI API error: " + res);
|
||||
nlohmann::json embedding_res = nlohmann::json::object();
|
||||
embedding_res["response"] = json_res;
|
||||
embedding_res["request"] = nlohmann::json::object();
|
||||
embedding_res["request"]["url"] = OPENAI_CREATE_EMBEDDING;
|
||||
embedding_res["request"]["method"] = "POST";
|
||||
embedding_res["request"]["body"] = req_body;
|
||||
|
||||
if(json_res.count("error") != 0 && json_res["error"].count("message") != 0) {
|
||||
embedding_res["error"] = "OpenAI API error: " + json_res["error"]["message"].get<std::string>();
|
||||
}
|
||||
return Option<std::vector<float>>(400, "OpenAI API error: " + res);
|
||||
return embedding_res_t(res_code, embedding_res);
|
||||
}
|
||||
return Option<std::vector<float>>(nlohmann::json::parse(res)["data"][0]["embedding"].get<std::vector<float>>());
|
||||
|
||||
return embedding_res_t(nlohmann::json::parse(res)["data"][0]["embedding"].get<std::vector<float>>());
|
||||
}
|
||||
|
||||
Option<std::vector<std::vector<float>>> OpenAIEmbedder::batch_embed(const std::vector<std::string>& inputs) {
|
||||
std::vector<embedding_res_t> OpenAIEmbedder::batch_embed(const std::vector<std::string>& inputs) {
|
||||
nlohmann::json req_body;
|
||||
req_body["input"] = inputs;
|
||||
// remove "openai/" prefix
|
||||
@ -137,20 +146,35 @@ Option<std::vector<std::vector<float>>> OpenAIEmbedder::batch_embed(const std::v
|
||||
auto res_code = call_remote_api("POST", OPENAI_CREATE_EMBEDDING, req_body.dump(), res, res_headers, headers);
|
||||
|
||||
if(res_code != 200) {
|
||||
std::vector<embedding_res_t> outputs;
|
||||
|
||||
nlohmann::json json_res = nlohmann::json::parse(res);
|
||||
if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) {
|
||||
return Option<std::vector<std::vector<float>>>(400, "OpenAI API error: " + res);
|
||||
LOG(INFO) << "OpenAI API error: " << json_res.dump();
|
||||
nlohmann::json embedding_res = nlohmann::json::object();
|
||||
embedding_res["response"] = json_res;
|
||||
embedding_res["request"] = nlohmann::json::object();
|
||||
embedding_res["request"]["url"] = OPENAI_CREATE_EMBEDDING;
|
||||
embedding_res["request"]["method"] = "POST";
|
||||
embedding_res["request"]["body"] = req_body;
|
||||
embedding_res["request"]["body"]["input"] = std::vector<std::string>{inputs[0]};
|
||||
if(json_res.count("error") != 0 && json_res["error"].count("message") != 0) {
|
||||
embedding_res["error"] = "OpenAI API error: " + json_res["error"]["message"].get<std::string>();
|
||||
}
|
||||
return Option<std::vector<std::vector<float>>>(400, res);
|
||||
|
||||
for(size_t i = 0; i < inputs.size(); i++) {
|
||||
embedding_res["request"]["body"]["input"][0] = inputs[i];
|
||||
outputs.push_back(embedding_res_t(res_code, embedding_res));
|
||||
}
|
||||
return outputs;
|
||||
}
|
||||
|
||||
nlohmann::json res_json = nlohmann::json::parse(res);
|
||||
std::vector<std::vector<float>> outputs;
|
||||
std::vector<embedding_res_t> outputs;
|
||||
for(auto& data : res_json["data"]) {
|
||||
outputs.push_back(data["embedding"].get<std::vector<float>>());
|
||||
outputs.push_back(embedding_res_t(data["embedding"].get<std::vector<float>>()));
|
||||
}
|
||||
|
||||
return Option<std::vector<std::vector<float>>>(outputs);
|
||||
return outputs;
|
||||
}
|
||||
|
||||
|
||||
@ -198,7 +222,7 @@ Option<bool> GoogleEmbedder::is_model_valid(const nlohmann::json& model_config,
|
||||
return Option<bool>(true);
|
||||
}
|
||||
|
||||
Option<std::vector<float>> GoogleEmbedder::Embed(const std::string& text) {
|
||||
embedding_res_t GoogleEmbedder::Embed(const std::string& text) {
|
||||
std::unordered_map<std::string, std::string> headers;
|
||||
std::map<std::string, std::string> res_headers;
|
||||
headers["Content-Type"] = "application/json";
|
||||
@ -210,27 +234,30 @@ Option<std::vector<float>> GoogleEmbedder::Embed(const std::string& text) {
|
||||
|
||||
if(res_code != 200) {
|
||||
nlohmann::json json_res = nlohmann::json::parse(res);
|
||||
if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) {
|
||||
return Option<std::vector<float>>(400, "Google API error: " + res);
|
||||
nlohmann::json embedding_res = nlohmann::json::object();
|
||||
embedding_res["response"] = json_res;
|
||||
embedding_res["request"] = nlohmann::json::object();
|
||||
embedding_res["request"]["url"] = GOOGLE_CREATE_EMBEDDING;
|
||||
embedding_res["request"]["method"] = "POST";
|
||||
embedding_res["request"]["body"] = req_body;
|
||||
if(json_res.count("error") != 0 && json_res["error"].count("message") != 0) {
|
||||
embedding_res["error"] = "Google API error: " + json_res["error"]["message"].get<std::string>();
|
||||
}
|
||||
return Option<std::vector<float>>(400, "Google API error: " + nlohmann::json::parse(res)["error"]["message"].get<std::string>());
|
||||
return embedding_res_t(res_code, embedding_res);
|
||||
}
|
||||
|
||||
return Option<std::vector<float>>(nlohmann::json::parse(res)["embedding"]["value"].get<std::vector<float>>());
|
||||
return embedding_res_t(nlohmann::json::parse(res)["embedding"]["value"].get<std::vector<float>>());
|
||||
}
|
||||
|
||||
|
||||
Option<std::vector<std::vector<float>>> GoogleEmbedder::batch_embed(const std::vector<std::string>& inputs) {
|
||||
std::vector<std::vector<float>> outputs;
|
||||
std::vector<embedding_res_t> GoogleEmbedder::batch_embed(const std::vector<std::string>& inputs) {
|
||||
std::vector<embedding_res_t> outputs;
|
||||
for(auto& input : inputs) {
|
||||
auto res = Embed(input);
|
||||
if(!res.ok()) {
|
||||
return Option<std::vector<std::vector<float>>>(res.code(), res.error());
|
||||
}
|
||||
outputs.push_back(res.get());
|
||||
outputs.push_back(res);
|
||||
}
|
||||
|
||||
return Option<std::vector<std::vector<float>>>(outputs);
|
||||
return outputs;
|
||||
}
|
||||
|
||||
|
||||
@ -298,7 +325,7 @@ Option<bool> GCPEmbedder::is_model_valid(const nlohmann::json& model_config, uns
|
||||
return Option<bool>(true);
|
||||
}
|
||||
|
||||
Option<std::vector<float>> GCPEmbedder::Embed(const std::string& text) {
|
||||
embedding_res_t GCPEmbedder::Embed(const std::string& text) {
|
||||
nlohmann::json req_body;
|
||||
req_body["instances"] = nlohmann::json::array();
|
||||
nlohmann::json instance;
|
||||
@ -316,7 +343,9 @@ Option<std::vector<float>> GCPEmbedder::Embed(const std::string& text) {
|
||||
if(res_code == 401) {
|
||||
auto refresh_op = generate_access_token(refresh_token, client_id, client_secret);
|
||||
if(!refresh_op.ok()) {
|
||||
return Option<std::vector<float>>(refresh_op.code(), refresh_op.error());
|
||||
nlohmann::json embedding_res = nlohmann::json::object();
|
||||
embedding_res["error"] = refresh_op.error();
|
||||
return embedding_res_t(refresh_op.code(), embedding_res);
|
||||
}
|
||||
access_token = refresh_op.get();
|
||||
// retry
|
||||
@ -327,32 +356,32 @@ Option<std::vector<float>> GCPEmbedder::Embed(const std::string& text) {
|
||||
|
||||
if(res_code != 200) {
|
||||
nlohmann::json json_res = nlohmann::json::parse(res);
|
||||
if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) {
|
||||
return Option<std::vector<float>>(400, "GCP API error: " + res);
|
||||
nlohmann::json embedding_res = nlohmann::json::object();
|
||||
embedding_res["response"] = json_res;
|
||||
embedding_res["request"] = nlohmann::json::object();
|
||||
embedding_res["request"]["url"] = get_gcp_embedding_url(project_id, model_name);
|
||||
embedding_res["request"]["method"] = "POST";
|
||||
embedding_res["request"]["body"] = req_body;
|
||||
if(json_res.count("error") != 0 && json_res["error"].count("message") != 0) {
|
||||
embedding_res["error"] = "GCP API error: " + json_res["error"]["message"].get<std::string>();
|
||||
}
|
||||
return Option<std::vector<float>>(400, "GCP API error: " + nlohmann::json::parse(res)["error"]["message"].get<std::string>());
|
||||
return embedding_res_t(res_code, embedding_res);
|
||||
}
|
||||
|
||||
nlohmann::json res_json = nlohmann::json::parse(res);
|
||||
return Option<std::vector<float>>(res_json["predictions"][0]["embeddings"]["values"].get<std::vector<float>>());
|
||||
return embedding_res_t(res_json["predictions"][0]["embeddings"]["values"].get<std::vector<float>>());
|
||||
}
|
||||
|
||||
|
||||
Option<std::vector<std::vector<float>>> GCPEmbedder::batch_embed(const std::vector<std::string>& inputs) {
|
||||
std::vector<embedding_res_t> GCPEmbedder::batch_embed(const std::vector<std::string>& inputs) {
|
||||
// GCP API has a limit of 5 instances per request
|
||||
if(inputs.size() > 5) {
|
||||
std::vector<std::vector<float>> res;
|
||||
std::vector<embedding_res_t> res;
|
||||
for(size_t i = 0; i < inputs.size(); i += 5) {
|
||||
auto batch_res = batch_embed(std::vector<std::string>(inputs.begin() + i, inputs.begin() + std::min(i + 5, inputs.size())));
|
||||
if(!batch_res.ok()) {
|
||||
LOG(INFO) << "Batch embedding failed: " << batch_res.error();
|
||||
return Option<std::vector<std::vector<float>>>(batch_res.code(), batch_res.error());
|
||||
}
|
||||
auto batch = batch_res.get();
|
||||
res.insert(res.end(), batch.begin(), batch.end());
|
||||
res.insert(res.end(), batch_res.begin(), batch_res.end());
|
||||
}
|
||||
auto opt = Option<std::vector<std::vector<float>>>(res);
|
||||
return opt;
|
||||
return res;
|
||||
}
|
||||
nlohmann::json req_body;
|
||||
req_body["instances"] = nlohmann::json::array();
|
||||
@ -371,7 +400,13 @@ Option<std::vector<std::vector<float>>> GCPEmbedder::batch_embed(const std::vect
|
||||
if(res_code == 401) {
|
||||
auto refresh_op = generate_access_token(refresh_token, client_id, client_secret);
|
||||
if(!refresh_op.ok()) {
|
||||
return Option<std::vector<std::vector<float>>>(refresh_op.code(), refresh_op.error());
|
||||
nlohmann::json embedding_res = nlohmann::json::object();
|
||||
embedding_res["error"] = refresh_op.error();
|
||||
std::vector<embedding_res_t> outputs;
|
||||
for(size_t i = 0; i < inputs.size(); i++) {
|
||||
outputs.push_back(embedding_res_t(refresh_op.code(), embedding_res));
|
||||
}
|
||||
return outputs;
|
||||
}
|
||||
access_token = refresh_op.get();
|
||||
// retry
|
||||
@ -382,19 +417,29 @@ Option<std::vector<std::vector<float>>> GCPEmbedder::batch_embed(const std::vect
|
||||
|
||||
if(res_code != 200) {
|
||||
nlohmann::json json_res = nlohmann::json::parse(res);
|
||||
if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) {
|
||||
return Option<std::vector<std::vector<float>>>(400, "GCP API error: " + res);
|
||||
nlohmann::json embedding_res = nlohmann::json::object();
|
||||
embedding_res["response"] = json_res;
|
||||
embedding_res["request"] = nlohmann::json::object();
|
||||
embedding_res["request"]["url"] = get_gcp_embedding_url(project_id, model_name);
|
||||
embedding_res["request"]["method"] = "POST";
|
||||
embedding_res["request"]["body"] = req_body;
|
||||
if(json_res.count("error") != 0 && json_res["error"].count("message") != 0) {
|
||||
embedding_res["error"] = "GCP API error: " + json_res["error"]["message"].get<std::string>();
|
||||
}
|
||||
return Option<std::vector<std::vector<float>>>(400, "GCP API error: " + nlohmann::json::parse(res)["error"]["message"].get<std::string>());
|
||||
std::vector<embedding_res_t> outputs;
|
||||
for(size_t i = 0; i < inputs.size(); i++) {
|
||||
outputs.push_back(embedding_res_t(res_code, embedding_res));
|
||||
}
|
||||
return outputs;
|
||||
}
|
||||
|
||||
nlohmann::json res_json = nlohmann::json::parse(res);
|
||||
std::vector<std::vector<float>> outputs;
|
||||
std::vector<embedding_res_t> outputs;
|
||||
for(const auto& prediction : res_json["predictions"]) {
|
||||
outputs.push_back(prediction["embeddings"]["values"].get<std::vector<float>>());
|
||||
outputs.push_back(embedding_res_t(prediction["embeddings"]["values"].get<std::vector<float>>()));
|
||||
}
|
||||
|
||||
return Option<std::vector<std::vector<float>>>(outputs);
|
||||
return outputs;
|
||||
}
|
||||
|
||||
Option<std::string> GCPEmbedder::generate_access_token(const std::string& refresh_token, const std::string& client_id, const std::string& client_secret) {
|
||||
|
@ -606,7 +606,7 @@ Option<uint32_t> validator_t::validate_index_in_memory(nlohmann::json& document,
|
||||
const index_operation_t op,
|
||||
const bool is_update,
|
||||
const std::string& fallback_field_type,
|
||||
const DIRTY_VALUES& dirty_values) {
|
||||
const DIRTY_VALUES& dirty_values, const bool validate_embedding_fields) {
|
||||
|
||||
bool missing_default_sort_field = (!default_sorting_field.empty() && document.count(default_sorting_field) == 0);
|
||||
|
||||
@ -653,10 +653,12 @@ Option<uint32_t> validator_t::validate_index_in_memory(nlohmann::json& document,
|
||||
}
|
||||
}
|
||||
|
||||
// validate embedding fields
|
||||
auto validate_embed_op = validate_embed_fields(document, embedding_fields, search_schema, !is_update);
|
||||
if(!validate_embed_op.ok()) {
|
||||
return Option<>(validate_embed_op.code(), validate_embed_op.error());
|
||||
if(validate_embedding_fields) {
|
||||
// validate embedding fields
|
||||
auto validate_embed_op = validate_embed_fields(document, embedding_fields, search_schema, !is_update);
|
||||
if(!validate_embed_op.ok()) {
|
||||
return Option<>(validate_embed_op.code(), validate_embed_op.error());
|
||||
}
|
||||
}
|
||||
|
||||
return Option<>(200);
|
||||
|
@ -1629,4 +1629,46 @@ TEST_F(CollectionSchemaChangeTest, EmbeddingFieldsMapTest) {
|
||||
|
||||
embedding_fields_map = coll->get_embedding_fields();
|
||||
ASSERT_EQ(0, embedding_fields_map.size());
|
||||
}
|
||||
|
||||
TEST_F(CollectionSchemaChangeTest, DropAndReindexEmbeddingField) {
|
||||
nlohmann::json schema = R"({
|
||||
"name": "objects",
|
||||
"fields": [
|
||||
{"name": "name", "type": "string"},
|
||||
{"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/e5-small"}}}
|
||||
]
|
||||
})"_json;
|
||||
|
||||
TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
|
||||
|
||||
auto op = collectionManager.create_collection(schema);
|
||||
|
||||
ASSERT_TRUE(op.ok());
|
||||
|
||||
// drop the embedding field and reindex
|
||||
nlohmann::json schema_without_embedding = R"({
|
||||
"fields": [
|
||||
{"name": "embedding", "drop": true},
|
||||
{"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/e5-small"}}}
|
||||
]
|
||||
})"_json;
|
||||
|
||||
auto update_op = op.get()->alter(schema_without_embedding);
|
||||
|
||||
ASSERT_TRUE(update_op.ok());
|
||||
|
||||
auto embedding_fields_map = op.get()->get_embedding_fields();
|
||||
|
||||
ASSERT_EQ(1, embedding_fields_map.size());
|
||||
|
||||
// try adding a document
|
||||
nlohmann::json doc;
|
||||
doc["name"] = "hello";
|
||||
auto add_op = op.get()->add(doc.dump());
|
||||
|
||||
ASSERT_TRUE(add_op.ok());
|
||||
auto added_doc = add_op.get();
|
||||
|
||||
ASSERT_EQ(384, added_doc["embedding"].get<std::vector<float>>().size());
|
||||
}
|
@ -2246,3 +2246,145 @@ TEST_F(CollectionSortingTest, OptionalFilteringViaSortingSecondThirdParams) {
|
||||
ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get<std::string>());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
TEST_F(CollectionSortingTest, AscendingVectorDistance) {
|
||||
std::string coll_schema = R"(
|
||||
{
|
||||
"name": "coll1",
|
||||
"fields": [
|
||||
{"name": "title", "type": "string" },
|
||||
{"name": "points", "type": "float[]", "num_dim": 2}
|
||||
]
|
||||
}
|
||||
)";
|
||||
|
||||
nlohmann::json schema = nlohmann::json::parse(coll_schema);
|
||||
Collection* coll1 = collectionManager.create_collection(schema).get();
|
||||
|
||||
std::vector<std::vector<float>> points = {
|
||||
{3.0, 4.0},
|
||||
{9.0, 21.0},
|
||||
{8.0, 15.0},
|
||||
{1.0, 1.0},
|
||||
{5.0, 7.0}
|
||||
};
|
||||
|
||||
for(size_t i = 0; i < points.size(); i++) {
|
||||
nlohmann::json doc;
|
||||
doc["title"] = "Title " + std::to_string(i);
|
||||
doc["points"] = points[i];
|
||||
ASSERT_TRUE(coll1->add(doc.dump()).ok());
|
||||
}
|
||||
|
||||
std::vector<sort_by> sort_fields = {
|
||||
sort_by("_vector_distance", "asc"),
|
||||
};
|
||||
|
||||
auto results = coll1->search("*", {}, "", {}, sort_fields, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
|
||||
spp::sparse_hash_set<std::string>(),
|
||||
spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
|
||||
"", 10, {}, {}, {}, 0,
|
||||
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback,
|
||||
4, {off}, 32767, 32767, 2,
|
||||
false, true, "points:([8.0, 15.0])").get();
|
||||
|
||||
ASSERT_EQ(5, results["hits"].size());
|
||||
std::vector<std::string> expected_ids = {"2", "1", "4", "0", "3"};
|
||||
for(size_t i = 0; i < expected_ids.size(); i++) {
|
||||
ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get<std::string>());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
TEST_F(CollectionSortingTest, DescendingVectorDistance) {
|
||||
std::string coll_schema = R"(
|
||||
{
|
||||
"name": "coll1",
|
||||
"fields": [
|
||||
{"name": "title", "type": "string" },
|
||||
{"name": "points", "type": "float[]", "num_dim": 2}
|
||||
]
|
||||
}
|
||||
)";
|
||||
|
||||
nlohmann::json schema = nlohmann::json::parse(coll_schema);
|
||||
Collection* coll1 = collectionManager.create_collection(schema).get();
|
||||
|
||||
std::vector<std::vector<float>> points = {
|
||||
{3.0, 4.0},
|
||||
{9.0, 21.0},
|
||||
{8.0, 15.0},
|
||||
{1.0, 1.0},
|
||||
{5.0, 7.0}
|
||||
};
|
||||
|
||||
for(size_t i = 0; i < points.size(); i++) {
|
||||
nlohmann::json doc;
|
||||
doc["title"] = "Title " + std::to_string(i);
|
||||
doc["points"] = points[i];
|
||||
ASSERT_TRUE(coll1->add(doc.dump()).ok());
|
||||
}
|
||||
|
||||
std::vector<sort_by> sort_fields = {
|
||||
sort_by("_vector_distance", "DESC"),
|
||||
};
|
||||
|
||||
auto results = coll1->search("*", {}, "", {}, sort_fields, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
|
||||
spp::sparse_hash_set<std::string>(),
|
||||
spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
|
||||
"", 10, {}, {}, {}, 0,
|
||||
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback,
|
||||
4, {off}, 32767, 32767, 2,
|
||||
false, true, "points:([8.0, 15.0])").get();
|
||||
|
||||
ASSERT_EQ(5, results["hits"].size());
|
||||
std::vector<std::string> expected_ids = {"3", "0", "4", "1", "2"};
|
||||
|
||||
for(size_t i = 0; i < expected_ids.size(); i++) {
|
||||
ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get<std::string>());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
TEST_F(CollectionSortingTest, InvalidVectorDistanceSorting) {
|
||||
std::string coll_schema = R"(
|
||||
{
|
||||
"name": "coll1",
|
||||
"fields": [
|
||||
{"name": "title", "type": "string" },
|
||||
{"name": "points", "type": "float[]", "num_dim": 2}
|
||||
]
|
||||
}
|
||||
)";
|
||||
|
||||
nlohmann::json schema = nlohmann::json::parse(coll_schema);
|
||||
Collection* coll1 = collectionManager.create_collection(schema).get();
|
||||
|
||||
std::vector<std::vector<float>> points = {
|
||||
{1.0, 1.0},
|
||||
{2.0, 2.0},
|
||||
{3.0, 3.0},
|
||||
{4.0, 4.0},
|
||||
{5.0, 5.0},
|
||||
};
|
||||
|
||||
for(size_t i = 0; i < points.size(); i++) {
|
||||
nlohmann::json doc;
|
||||
doc["title"] = "Title " + std::to_string(i);
|
||||
doc["points"] = points[i];
|
||||
ASSERT_TRUE(coll1->add(doc.dump()).ok());
|
||||
}
|
||||
|
||||
std::vector<sort_by> sort_fields = {
|
||||
sort_by("_vector_distance", "desc"),
|
||||
};
|
||||
|
||||
|
||||
|
||||
auto results = coll1->search("title", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10);
|
||||
|
||||
ASSERT_FALSE(results.ok());
|
||||
|
||||
ASSERT_EQ("sort_by vector_distance is only supported for vector queries, semantic search and hybrid search.", results.error());
|
||||
}
|
@ -4790,9 +4790,9 @@ TEST_F(CollectionTest, HybridSearchRankFusionTest) {
|
||||
ASSERT_EQ("butterfly", search_res["hits"][1]["document"]["name"].get<std::string>());
|
||||
ASSERT_EQ("butterball", search_res["hits"][2]["document"]["name"].get<std::string>());
|
||||
|
||||
ASSERT_FLOAT_EQ((1.0/1.0 * 0.7) + (1.0/1.0 * 0.3), search_res["hits"][0]["rank_fusion_score"].get<float>());
|
||||
ASSERT_FLOAT_EQ((1.0/2.0 * 0.7) + (1.0/3.0 * 0.3), search_res["hits"][1]["rank_fusion_score"].get<float>());
|
||||
ASSERT_FLOAT_EQ((1.0/3.0 * 0.7) + (1.0/2.0 * 0.3), search_res["hits"][2]["rank_fusion_score"].get<float>());
|
||||
ASSERT_FLOAT_EQ((1.0/1.0 * 0.7) + (1.0/1.0 * 0.3), search_res["hits"][0]["hybrid_search_info"]["rank_fusion_score"].get<float>());
|
||||
ASSERT_FLOAT_EQ((1.0/2.0 * 0.7) + (1.0/3.0 * 0.3), search_res["hits"][1]["hybrid_search_info"]["rank_fusion_score"].get<float>());
|
||||
ASSERT_FLOAT_EQ((1.0/3.0 * 0.7) + (1.0/2.0 * 0.3), search_res["hits"][2]["hybrid_search_info"]["rank_fusion_score"].get<float>());
|
||||
}
|
||||
|
||||
TEST_F(CollectionTest, WildcardSearchWithEmbeddingField) {
|
||||
@ -4813,8 +4813,7 @@ TEST_F(CollectionTest, WildcardSearchWithEmbeddingField) {
|
||||
spp::sparse_hash_set<std::string> dummy_include_exclude;
|
||||
auto search_res_op = coll->search("*", {"name","embedding"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, dummy_include_exclude, dummy_include_exclude, 10, "", 30, 4, "");
|
||||
|
||||
ASSERT_FALSE(search_res_op.ok());
|
||||
ASSERT_EQ("Wildcard query is not supported for embedding fields.", search_res_op.error());
|
||||
ASSERT_TRUE(search_res_op.ok());
|
||||
}
|
||||
|
||||
TEST_F(CollectionTest, CreateModelDirIfNotExists) {
|
||||
@ -5060,7 +5059,7 @@ TEST_F(CollectionTest, HideOpenAIApiKey) {
|
||||
ASSERT_TRUE(op.ok());
|
||||
auto summary = op.get()->get_summary_json();
|
||||
// hide api key with * after first 3 characters
|
||||
ASSERT_EQ(summary["fields"][1]["embed"]["model_config"]["api_key"].get<std::string>(), api_key.replace(3, api_key.size() - 3, api_key.size() - 3, '*'));
|
||||
ASSERT_EQ(summary["fields"][1]["embed"]["model_config"]["api_key"].get<std::string>(), api_key.replace(5, api_key.size() - 5, api_key.size() - 5, '*'));
|
||||
}
|
||||
|
||||
TEST_F(CollectionTest, PrefixSearchDisabledForOpenAI) {
|
||||
@ -5134,3 +5133,48 @@ TEST_F(CollectionTest, MoreThanOneEmbeddingField) {
|
||||
ASSERT_EQ("Only one embedding field is allowed in the query.", search_res_op.error());
|
||||
}
|
||||
|
||||
|
||||
TEST_F(CollectionTest, EmbeddingFieldEmptyArrayInDocument) {
|
||||
nlohmann::json schema = R"({
|
||||
"name": "objects",
|
||||
"fields": [
|
||||
{"name": "names", "type": "string[]"},
|
||||
{"name": "embedding", "type":"float[]", "embed":{"from": ["names"], "model_config": {"model_name": "ts/e5-small"}}}
|
||||
]
|
||||
})"_json;
|
||||
|
||||
TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
|
||||
|
||||
auto op = collectionManager.create_collection(schema);
|
||||
ASSERT_TRUE(op.ok());
|
||||
|
||||
auto coll = op.get();
|
||||
|
||||
nlohmann::json doc;
|
||||
doc["names"] = nlohmann::json::array();
|
||||
|
||||
// try adding
|
||||
auto add_op = coll->add(doc.dump());
|
||||
|
||||
ASSERT_TRUE(add_op.ok());
|
||||
|
||||
ASSERT_TRUE(add_op.get()["embedding"].is_null());
|
||||
|
||||
// try updating
|
||||
auto id = add_op.get()["id"];
|
||||
doc["names"].push_back("butter");
|
||||
std::string dirty_values;
|
||||
|
||||
|
||||
auto update_op = coll->update_matching_filter("id:=" + id.get<std::string>(), doc.dump(), dirty_values);
|
||||
ASSERT_TRUE(update_op.ok());
|
||||
ASSERT_EQ(1, update_op.get()["num_updated"]);
|
||||
|
||||
|
||||
auto get_op = coll->get(id);
|
||||
ASSERT_TRUE(get_op.ok());
|
||||
|
||||
ASSERT_FALSE(get_op.get()["embedding"].is_null());
|
||||
|
||||
ASSERT_EQ(384, get_op.get()["embedding"].size());
|
||||
}
|
@ -711,12 +711,42 @@ TEST_F(CollectionVectorTest, HybridSearchWithExplicitVector) {
|
||||
4, {off}, 32767, 32767, 2,
|
||||
false, true, "vec:(" + dummy_vec_string +")");
|
||||
ASSERT_EQ(true, results_op.ok());
|
||||
|
||||
|
||||
ASSERT_EQ(1, results_op.get()["found"].get<size_t>());
|
||||
ASSERT_EQ(1, results_op.get()["hits"].size());
|
||||
}
|
||||
|
||||
TEST_F(CollectionVectorTest, HybridSearchOnlyVectorMatches) {
|
||||
nlohmann::json schema = R"({
|
||||
"name": "coll1",
|
||||
"fields": [
|
||||
{"name": "name", "type": "string", "facet": true},
|
||||
{"name": "vec", "type": "float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/e5-small"}}}
|
||||
]
|
||||
})"_json;
|
||||
|
||||
TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
|
||||
Collection* coll1 = collectionManager.create_collection(schema).get();
|
||||
|
||||
nlohmann::json doc;
|
||||
doc["name"] = "john doe";
|
||||
ASSERT_TRUE(coll1->add(doc.dump()).ok());
|
||||
|
||||
auto results_op = coll1->search("zzz", {"name", "vec"}, "", {"name"}, {}, {0}, 20, 1, FREQUENCY, {true},
|
||||
Index::DROP_TOKENS_THRESHOLD,
|
||||
spp::sparse_hash_set<std::string>(),
|
||||
spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
|
||||
"", 10, {}, {}, {}, 0,
|
||||
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7,
|
||||
fallback,
|
||||
4, {off}, 32767, 32767, 2);
|
||||
ASSERT_EQ(true, results_op.ok());
|
||||
ASSERT_EQ(1, results_op.get()["found"].get<size_t>());
|
||||
ASSERT_EQ(1, results_op.get()["hits"].size());
|
||||
ASSERT_EQ(1, results_op.get()["facet_counts"].size());
|
||||
ASSERT_EQ(4, results_op.get()["facet_counts"][0].size());
|
||||
ASSERT_EQ("name", results_op.get()["facet_counts"][0]["field_name"]);
|
||||
}
|
||||
|
||||
TEST_F(CollectionVectorTest, DistanceThresholdTest) {
|
||||
nlohmann::json schema = R"({
|
||||
"name": "test",
|
||||
|
Loading…
x
Reference in New Issue
Block a user