mirror of
https://github.com/typesense/typesense.git
synced 2025-05-22 14:55:26 +08:00
Merge branch 'v0.26-filter' into v0.26-facets
# Conflicts: # include/index.h # src/index.cpp
This commit is contained in:
commit
9c5e553240
@ -167,3 +167,6 @@ bool is_doc_del_route(uint64_t route_hash);
|
||||
Option<std::pair<std::string,std::string>> get_api_key_and_ip(const std::string& metadata);
|
||||
|
||||
void init_api(uint32_t cache_num_entries);
|
||||
|
||||
|
||||
bool post_proxy(const std::shared_ptr<http_req>& req, const std::shared_ptr<http_res>& res);
|
||||
|
@ -54,6 +54,7 @@ namespace fields {
|
||||
static const std::string from = "from";
|
||||
static const std::string embed_from = "embed_from";
|
||||
static const std::string model_name = "model_name";
|
||||
static const std::string range_index = "range_index";
|
||||
|
||||
// Some models require additional parameters to be passed to the model during indexing/querying
|
||||
// For e.g. e5-small model requires prefix "passage:" for indexing and "query:" for querying
|
||||
@ -93,13 +94,17 @@ struct field {
|
||||
|
||||
std::string reference; // Foo.bar (reference to bar field in Foo collection).
|
||||
|
||||
bool range_index;
|
||||
|
||||
field() {}
|
||||
|
||||
field(const std::string &name, const std::string &type, const bool facet, const bool optional = false,
|
||||
bool index = true, std::string locale = "", int sort = -1, int infix = -1, bool nested = false,
|
||||
int nested_array = 0, size_t num_dim = 0, vector_distance_type_t vec_dist = cosine, std::string reference = "", const nlohmann::json& embed = nlohmann::json()) :
|
||||
int nested_array = 0, size_t num_dim = 0, vector_distance_type_t vec_dist = cosine,
|
||||
std::string reference = "", const nlohmann::json& embed = nlohmann::json(), const bool range_index = false) :
|
||||
name(name), type(type), facet(facet), optional(optional), index(index), locale(locale),
|
||||
nested(nested), nested_array(nested_array), num_dim(num_dim), vec_dist(vec_dist), reference(reference), embed(embed) {
|
||||
nested(nested), nested_array(nested_array), num_dim(num_dim), vec_dist(vec_dist), reference(reference),
|
||||
embed(embed), range_index(range_index) {
|
||||
|
||||
set_computed_defaults(sort, infix);
|
||||
}
|
||||
|
@ -19,7 +19,7 @@ struct filter {
|
||||
std::string field_name;
|
||||
std::vector<std::string> values;
|
||||
std::vector<NUM_COMPARATOR> comparators;
|
||||
// Would be set when `field: != ...` is encountered with a string field or `field: != [ ... ]` is encountered in the
|
||||
// Would be set when `field: != ...` is encountered with id/string field or `field: != [ ... ]` is encountered in the
|
||||
// case of int and float fields. During filtering, all the results of matching the field against the values are
|
||||
// aggregated and then this flag is checked if negation on the aggregated result is required.
|
||||
bool apply_not_equals = false;
|
||||
|
@ -160,6 +160,9 @@ public:
|
||||
/// Returns the status of the initialization of iterator tree.
|
||||
Option<bool> init_status();
|
||||
|
||||
/// Recursively computes the result of each node and stores the final result in the root node.
|
||||
void compute_result();
|
||||
|
||||
/// Returns a tri-state:
|
||||
/// 0: id is not valid
|
||||
/// 1: id is valid
|
||||
|
43
include/http_proxy.h
Normal file
43
include/http_proxy.h
Normal file
@ -0,0 +1,43 @@
|
||||
#pragma once
|
||||
|
||||
#include <unordered_map>
|
||||
#include <string>
|
||||
#include "http_client.h"
|
||||
#include "lru/lru.hpp"
|
||||
|
||||
|
||||
/// Result of a proxied HTTP call: the payload, the headers and the numeric status.
struct http_proxy_res_t {
    std::string body;
    std::map<std::string, std::string> headers;
    // Zero-initialised so that a default-constructed value never carries an
    // indeterminate status into operator== / operator!=.
    long status_code = 0;

    /// Equal when body, headers and status code all match.
    bool operator==(const http_proxy_res_t& other) const {
        return body == other.body && headers == other.headers && status_code == other.status_code;
    }

    /// Exact negation of operator==.
    bool operator!=(const http_proxy_res_t& other) const {
        return !(*this == other);
    }
};
|
||||
|
||||
|
||||
class HttpProxy {
|
||||
// singleton class for http proxy
|
||||
public:
|
||||
static HttpProxy& get_instance() {
|
||||
static HttpProxy instance;
|
||||
return instance;
|
||||
}
|
||||
HttpProxy(const HttpProxy&) = delete;
|
||||
void operator=(const HttpProxy&) = delete;
|
||||
HttpProxy(HttpProxy&&) = delete;
|
||||
void operator=(HttpProxy&&) = delete;
|
||||
http_proxy_res_t send(const std::string& url, const std::string& method, const std::string& body, const std::unordered_map<std::string, std::string>& headers);
|
||||
private:
|
||||
HttpProxy();
|
||||
~HttpProxy() = default;
|
||||
|
||||
|
||||
// lru cache for http requests
|
||||
LRU::TimedCache<uint64_t, http_proxy_res_t> cache;
|
||||
};
|
@ -52,6 +52,9 @@ public:
|
||||
bool is_res_start = true;
|
||||
h2o_send_state_t send_state = H2O_SEND_STATE_IN_PROGRESS;
|
||||
h2o_iovec_t res_body{};
|
||||
h2o_iovec_t res_content_type{};
|
||||
int status = 0;
|
||||
const char* reason = nullptr;
|
||||
|
||||
h2o_generator_t* generator = nullptr;
|
||||
|
||||
@ -65,10 +68,9 @@ public:
|
||||
res_body = h2o_strdup(&req->pool, body.c_str(), SIZE_MAX);
|
||||
|
||||
if(is_res_start) {
|
||||
req->res.status = status_code;
|
||||
req->res.reason = http_res::get_status_reason(status_code);
|
||||
h2o_add_header(&req->pool, &req->res.headers, H2O_TOKEN_CONTENT_TYPE, NULL,
|
||||
content_type.c_str(), content_type.size());
|
||||
res_content_type = h2o_strdup(&req->pool, content_type.c_str(), SIZE_MAX);
|
||||
status = status_code;
|
||||
reason = http_res::get_status_reason(status_code);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -31,6 +31,7 @@
|
||||
#include "hnswlib/hnswlib.h"
|
||||
#include "filter.h"
|
||||
#include "facet_index.h"
|
||||
#include "numeric_range_trie_test.h"
|
||||
|
||||
static constexpr size_t ARRAY_FACET_DIM = 4;
|
||||
using facet_map_t = spp::sparse_hash_map<uint32_t, facet_hash_values_t>;
|
||||
@ -309,7 +310,9 @@ private:
|
||||
|
||||
spp::sparse_hash_map<std::string, num_tree_t*> numerical_index;
|
||||
|
||||
spp::sparse_hash_map<std::string, spp::sparse_hash_map<std::string, std::vector<uint32_t>>*> geopoint_index;
|
||||
spp::sparse_hash_map<std::string, NumericTrie*> range_index;
|
||||
|
||||
spp::sparse_hash_map<std::string, NumericTrie*> geo_range_index;
|
||||
|
||||
// geo_array_field => (seq_id => values) used for exact filtering of geo array records
|
||||
spp::sparse_hash_map<std::string, spp::sparse_hash_map<uint32_t, int64_t*>*> geo_array_index;
|
||||
|
154
include/numeric_range_trie_test.h
Normal file
154
include/numeric_range_trie_test.h
Normal file
@ -0,0 +1,154 @@
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
#include "sorted_array.h"
|
||||
|
||||
constexpr short EXPANSE = 256;
|
||||
|
||||
class NumericTrie {
|
||||
char max_level = 4;
|
||||
|
||||
class Node {
|
||||
Node** children = nullptr;
|
||||
sorted_array seq_ids;
|
||||
|
||||
void insert_helper(const int64_t& value, const uint32_t& seq_id, char& level, const char& max_level);
|
||||
|
||||
void insert_geopoint_helper(const uint64_t& cell_id, const uint32_t& seq_id, char& level, const char& max_level);
|
||||
|
||||
void search_geopoints_helper(const uint64_t& cell_id, const char& max_index_level, std::set<Node*>& matches);
|
||||
|
||||
void search_range_helper(const int64_t& low,const int64_t& high, const char& max_level,
|
||||
std::vector<Node*>& matches);
|
||||
|
||||
void search_less_than_helper(const int64_t& value, char& level, const char& max_level,
|
||||
std::vector<Node*>& matches);
|
||||
|
||||
void search_greater_than_helper(const int64_t& value, char& level, const char& max_level,
|
||||
std::vector<Node*>& matches);
|
||||
|
||||
public:
|
||||
|
||||
~Node() {
|
||||
if (children != nullptr) {
|
||||
for (auto i = 0; i < EXPANSE; i++) {
|
||||
delete children[i];
|
||||
}
|
||||
}
|
||||
|
||||
delete [] children;
|
||||
}
|
||||
|
||||
void insert(const int64_t& cell_id, const uint32_t& seq_id, const char& max_level);
|
||||
|
||||
void remove(const int64_t& cell_id, const uint32_t& seq_id, const char& max_level);
|
||||
|
||||
void insert_geopoint(const uint64_t& cell_id, const uint32_t& seq_id, const char& max_level);
|
||||
|
||||
void search_geopoints(const std::vector<uint64_t>& cell_ids, const char& max_level,
|
||||
std::vector<uint32_t>& geo_result_ids);
|
||||
|
||||
void delete_geopoint(const uint64_t& cell_id, uint32_t id, const char& max_level);
|
||||
|
||||
void get_all_ids(uint32_t*& ids, uint32_t& ids_length);
|
||||
|
||||
void search_range(const int64_t& low, const int64_t& high, const char& max_level,
|
||||
uint32_t*& ids, uint32_t& ids_length);
|
||||
|
||||
void search_range(const int64_t& low, const int64_t& high, const char& max_level, std::vector<Node*>& matches);
|
||||
|
||||
void search_less_than(const int64_t& value, const char& max_level, uint32_t*& ids, uint32_t& ids_length);
|
||||
|
||||
void search_less_than(const int64_t& value, const char& max_level, std::vector<Node*>& matches);
|
||||
|
||||
void search_greater_than(const int64_t& value, const char& max_level, uint32_t*& ids, uint32_t& ids_length);
|
||||
|
||||
void search_greater_than(const int64_t& value, const char& max_level, std::vector<Node*>& matches);
|
||||
|
||||
void search_equal_to(const int64_t& value, const char& max_level, uint32_t*& ids, uint32_t& ids_length);
|
||||
|
||||
void search_equal_to(const int64_t& value, const char& max_level, std::vector<Node*>& matches);
|
||||
};
|
||||
|
||||
Node* negative_trie = nullptr;
|
||||
Node* positive_trie = nullptr;
|
||||
|
||||
public:
|
||||
|
||||
explicit NumericTrie(char num_bits = 32) {
|
||||
max_level = num_bits / 8;
|
||||
}
|
||||
|
||||
~NumericTrie() {
|
||||
delete negative_trie;
|
||||
delete positive_trie;
|
||||
}
|
||||
|
||||
class iterator_t {
|
||||
struct match_state {
|
||||
uint32_t* ids = nullptr;
|
||||
uint32_t ids_length = 0;
|
||||
uint32_t index = 0;
|
||||
|
||||
explicit match_state(uint32_t*& ids, uint32_t& ids_length) : ids(ids), ids_length(ids_length) {}
|
||||
|
||||
~match_state() {
|
||||
delete [] ids;
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<match_state*> matches;
|
||||
|
||||
void set_seq_id();
|
||||
|
||||
public:
|
||||
|
||||
explicit iterator_t(std::vector<Node*>& matches);
|
||||
|
||||
~iterator_t() {
|
||||
for (auto& match: matches) {
|
||||
delete match;
|
||||
}
|
||||
}
|
||||
|
||||
iterator_t& operator=(iterator_t&& obj) noexcept;
|
||||
|
||||
uint32_t seq_id = 0;
|
||||
bool is_valid = true;
|
||||
|
||||
void next();
|
||||
void skip_to(uint32_t id);
|
||||
void reset();
|
||||
};
|
||||
|
||||
void insert(const int64_t& value, const uint32_t& seq_id);
|
||||
|
||||
void remove(const int64_t& value, const uint32_t& seq_id);
|
||||
|
||||
void insert_geopoint(const uint64_t& cell_id, const uint32_t& seq_id);
|
||||
|
||||
void search_geopoints(const std::vector<uint64_t>& cell_ids, std::vector<uint32_t>& geo_result_ids);
|
||||
|
||||
void delete_geopoint(const uint64_t& cell_id, uint32_t id);
|
||||
|
||||
void search_range(const int64_t& low, const bool& low_inclusive,
|
||||
const int64_t& high, const bool& high_inclusive,
|
||||
uint32_t*& ids, uint32_t& ids_length);
|
||||
|
||||
iterator_t search_range(const int64_t& low, const bool& low_inclusive,
|
||||
const int64_t& high, const bool& high_inclusive);
|
||||
|
||||
void search_less_than(const int64_t& value, const bool& inclusive,
|
||||
uint32_t*& ids, uint32_t& ids_length);
|
||||
|
||||
iterator_t search_less_than(const int64_t& value, const bool& inclusive);
|
||||
|
||||
void search_greater_than(const int64_t& value, const bool& inclusive,
|
||||
uint32_t*& ids, uint32_t& ids_length);
|
||||
|
||||
iterator_t search_greater_than(const int64_t& value, const bool& inclusive);
|
||||
|
||||
void search_equal_to(const int64_t& value, uint32_t*& ids, uint32_t& ids_length);
|
||||
|
||||
iterator_t search_equal_to(const int64_t& value);
|
||||
};
|
@ -4,6 +4,7 @@
|
||||
#include <string>
|
||||
#include <mutex>
|
||||
#include "http_client.h"
|
||||
#include "raft_server.h"
|
||||
#include "option.h"
|
||||
|
||||
|
||||
@ -12,9 +13,15 @@
|
||||
class RemoteEmbedder {
|
||||
protected:
|
||||
static Option<bool> validate_string_properties(const nlohmann::json& model_config, const std::vector<std::string>& properties);
|
||||
static long call_remote_api(const std::string& method, const std::string& url, const std::string& body, std::string& res_body, std::map<std::string, std::string>& headers, const std::unordered_map<std::string, std::string>& req_headers);
|
||||
static inline ReplicationState* raft_server = nullptr;
|
||||
public:
|
||||
virtual Option<std::vector<float>> Embed(const std::string& text) = 0;
|
||||
virtual Option<std::vector<std::vector<float>>> batch_embed(const std::vector<std::string>& inputs) = 0;
|
||||
static void init(ReplicationState* rs) {
|
||||
raft_server = rs;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
|
||||
|
@ -28,7 +28,6 @@ struct sort_fields_guard_t {
|
||||
~sort_fields_guard_t() {
|
||||
for(auto& sort_by_clause: sort_fields_std) {
|
||||
delete sort_by_clause.eval.filter_tree_root;
|
||||
|
||||
if(sort_by_clause.eval.ids) {
|
||||
delete [] sort_by_clause.eval.ids;
|
||||
sort_by_clause.eval.ids = nullptr;
|
||||
@ -1387,7 +1386,7 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
|
||||
std::vector<std::vector<KV*>> raw_result_kvs;
|
||||
std::vector<std::vector<KV*>> override_result_kvs;
|
||||
|
||||
size_t total_found = 0;
|
||||
size_t total = 0;
|
||||
|
||||
std::vector<uint32_t> excluded_ids;
|
||||
std::vector<std::pair<uint32_t, uint32_t>> included_ids; // ID -> position
|
||||
@ -1566,12 +1565,13 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
|
||||
|
||||
// for grouping we have to aggregate group set sizes to a count value
|
||||
if(group_limit) {
|
||||
total_found = search_params->groups_processed.size() + override_result_kvs.size();
|
||||
total = search_params->groups_processed.size() + override_result_kvs.size();
|
||||
} else {
|
||||
total_found = search_params->all_result_ids_len;
|
||||
total = search_params->all_result_ids_len;
|
||||
}
|
||||
|
||||
|
||||
if(search_cutoff && total_found == 0) {
|
||||
if(search_cutoff && total == 0) {
|
||||
// this can happen if other requests stopped this request from being processed
|
||||
// we should return an error so that request can be retried by client
|
||||
return Option<nlohmann::json>(408, "Request Timeout");
|
||||
@ -1692,7 +1692,10 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
|
||||
}
|
||||
|
||||
nlohmann::json result = nlohmann::json::object();
|
||||
result["found"] = total_found;
|
||||
result["found"] = total;
|
||||
if(group_limit != 0) {
|
||||
result["found_docs"] = search_params->all_result_ids_len;
|
||||
}
|
||||
|
||||
if(exclude_fields.count("out_of") == 0) {
|
||||
result["out_of"] = num_documents.load();
|
||||
|
@ -731,6 +731,27 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
|
||||
AuthManager::add_item_to_params(req_params, item, true);
|
||||
}
|
||||
|
||||
const auto preset_it = req_params.find("preset");
|
||||
|
||||
if(preset_it != req_params.end()) {
|
||||
nlohmann::json preset;
|
||||
const auto& preset_op = CollectionManager::get_instance().get_preset(preset_it->second, preset);
|
||||
|
||||
if(preset_op.ok()) {
|
||||
if(!preset.is_object()) {
|
||||
return Option<bool>(400, "Search preset is not an object.");
|
||||
}
|
||||
|
||||
for(const auto& search_item: preset.items()) {
|
||||
// overwrite = false since req params will contain embedded params and so has higher priority
|
||||
bool populated = AuthManager::add_item_to_params(req_params, search_item, false);
|
||||
if(!populated) {
|
||||
return Option<bool>(400, "One or more search parameters are malformed.");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CollectionManager & collectionManager = CollectionManager::get_instance();
|
||||
const std::string& orig_coll_name = req_params["collection"];
|
||||
auto collection = collectionManager.get_collection(orig_coll_name);
|
||||
|
115
src/core_api.cpp
115
src/core_api.cpp
@ -15,6 +15,7 @@
|
||||
#include "lru/lru.hpp"
|
||||
#include "ratelimit_manager.h"
|
||||
#include "event_manager.h"
|
||||
#include "http_proxy.h"
|
||||
|
||||
using namespace std::chrono_literals;
|
||||
|
||||
@ -57,10 +58,8 @@ void stream_response(const std::shared_ptr<http_req>& req, const std::shared_ptr
|
||||
return ;
|
||||
}
|
||||
|
||||
if(req->_req->res.status != 0) {
|
||||
// not the first response chunk, so wait for previous chunk to finish
|
||||
res->wait();
|
||||
}
|
||||
// wait for previous chunk to finish (if any)
|
||||
res->wait();
|
||||
|
||||
auto req_res = new async_req_res_t(req, res, true);
|
||||
server->get_message_dispatcher()->send_message(HttpServer::STREAM_RESPONSE_MESSAGE, req_res);
|
||||
@ -368,29 +367,6 @@ bool get_search(const std::shared_ptr<http_req>& req, const std::shared_ptr<http
|
||||
}
|
||||
}
|
||||
|
||||
const auto preset_it = req->params.find("preset");
|
||||
|
||||
if(preset_it != req->params.end()) {
|
||||
nlohmann::json preset;
|
||||
const auto& preset_op = CollectionManager::get_instance().get_preset(preset_it->second, preset);
|
||||
|
||||
if(preset_op.ok()) {
|
||||
if(!preset.is_object()) {
|
||||
res->set_400("Search preset is not an object.");
|
||||
return false;
|
||||
}
|
||||
|
||||
for(const auto& search_item: preset.items()) {
|
||||
// overwrite = false since req params will contain embedded params and so has higher priority
|
||||
bool populated = AuthManager::add_item_to_params(req->params, search_item, false);
|
||||
if(!populated) {
|
||||
res->set_400("One or more search parameters are malformed.");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if(req->embedded_params_vec.empty()) {
|
||||
res->set_500("Embedded params is empty.");
|
||||
return false;
|
||||
@ -569,27 +545,6 @@ bool post_multi_search(const std::shared_ptr<http_req>& req, const std::shared_p
|
||||
}
|
||||
}
|
||||
|
||||
if(search_params.count("preset") != 0) {
|
||||
nlohmann::json preset;
|
||||
auto preset_op = CollectionManager::get_instance().get_preset(search_params["preset"].get<std::string>(),
|
||||
preset);
|
||||
if(preset_op.ok()) {
|
||||
if(!search_params.is_object()) {
|
||||
res->set_400("Search preset is not an object.");
|
||||
return false;
|
||||
}
|
||||
|
||||
for(const auto& search_item: preset.items()) {
|
||||
// overwrite = false since req params will contain embedded params and so has higher priority
|
||||
bool populated = AuthManager::add_item_to_params(req->params, search_item, false);
|
||||
if(!populated) {
|
||||
res->set_400("One or more search parameters are malformed.");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::string results_json_str;
|
||||
Option<bool> search_op = CollectionManager::do_search(req->params, req->embedded_params_vec[i],
|
||||
results_json_str, req->conn_ts);
|
||||
@ -774,7 +729,7 @@ bool get_export_documents(const std::shared_ptr<http_req>& req, const std::share
|
||||
}
|
||||
}
|
||||
|
||||
res->content_type_header = "application/octet-stream";
|
||||
res->content_type_header = "text/plain; charset=utf8";
|
||||
res->status_code = 200;
|
||||
|
||||
stream_response(req, res);
|
||||
@ -2136,3 +2091,65 @@ bool del_analytics_rules(const std::shared_ptr<http_req>& req, const std::shared
|
||||
res->set_200(R"({"ok": true)");
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
bool post_proxy(const std::shared_ptr<http_req>& req, const std::shared_ptr<http_res>& res) {
|
||||
HttpProxy& proxy = HttpProxy::get_instance();
|
||||
|
||||
nlohmann::json req_json;
|
||||
|
||||
try {
|
||||
req_json = nlohmann::json::parse(req->body);
|
||||
} catch(const nlohmann::json::parse_error& e) {
|
||||
LOG(ERROR) << "JSON error: " << e.what();
|
||||
res->set_400("Bad JSON.");
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string body, url, method;
|
||||
std::unordered_map<std::string, std::string> headers;
|
||||
|
||||
if(req_json.count("url") == 0 || req_json.count("method") == 0) {
|
||||
res->set_400("Missing required fields.");
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!req_json["url"].is_string() || !req_json["method"].is_string() || req_json["url"].get<std::string>().empty() || req_json["method"].get<std::string>().empty()) {
|
||||
res->set_400("URL and method must be non-empty strings.");
|
||||
return false;
|
||||
}
|
||||
|
||||
try {
|
||||
if(req_json.count("body") != 0 && !req_json["body"].is_string()) {
|
||||
res->set_400("Body must be a string.");
|
||||
return false;
|
||||
}
|
||||
if(req_json.count("headers") != 0 && !req_json["headers"].is_object()) {
|
||||
res->set_400("Headers must be a JSON object.");
|
||||
return false;
|
||||
}
|
||||
if(req_json.count("body")) {
|
||||
body = req_json["body"].get<std::string>();
|
||||
}
|
||||
url = req_json["url"].get<std::string>();
|
||||
method = req_json["method"].get<std::string>();
|
||||
if(req_json.count("headers")) {
|
||||
headers = req_json["headers"].get<std::unordered_map<std::string, std::string>>();
|
||||
}
|
||||
} catch(const std::exception& e) {
|
||||
LOG(ERROR) << "JSON error: " << e.what();
|
||||
res->set_400("Bad JSON.");
|
||||
return false;
|
||||
}
|
||||
|
||||
auto response = proxy.send(url, method, body, headers);
|
||||
|
||||
if(response.status_code != 200) {
|
||||
int code = response.status_code;
|
||||
res->set_body(code, response.body);
|
||||
return false;
|
||||
}
|
||||
|
||||
res->set_200(response.body);
|
||||
return true;
|
||||
}
|
@ -75,6 +75,24 @@ Option<bool> field::json_field_to_field(bool enable_nested_fields, nlohmann::jso
|
||||
field_json[fields::reference] = "";
|
||||
}
|
||||
|
||||
if (field_json.count(fields::range_index) != 0) {
|
||||
if (!field_json.at(fields::range_index).is_boolean()) {
|
||||
return Option<bool>(400, std::string("The `range_index` property of the field `") +
|
||||
field_json[fields::name].get<std::string>() +
|
||||
std::string("` should be a boolean."));
|
||||
}
|
||||
|
||||
auto const& type = field_json["type"];
|
||||
if (field_json[fields::range_index] &&
|
||||
type != field_types::INT32 && type != field_types::INT32_ARRAY &&
|
||||
type != field_types::INT64 && type != field_types::INT64_ARRAY &&
|
||||
type != field_types::FLOAT && type != field_types::FLOAT_ARRAY) {
|
||||
return Option<bool>(400, std::string("The `range_index` property is only allowed for the numerical fields`"));
|
||||
}
|
||||
} else {
|
||||
field_json[fields::range_index] = false;
|
||||
}
|
||||
|
||||
if(field_json["name"] == ".*") {
|
||||
if(field_json.count(fields::facet) == 0) {
|
||||
field_json[fields::facet] = false;
|
||||
@ -297,7 +315,7 @@ Option<bool> field::json_field_to_field(bool enable_nested_fields, nlohmann::jso
|
||||
field_json[fields::optional], field_json[fields::index], field_json[fields::locale],
|
||||
field_json[fields::sort], field_json[fields::infix], field_json[fields::nested],
|
||||
field_json[fields::nested_array], field_json[fields::num_dim], vec_dist,
|
||||
field_json[fields::reference], field_json[fields::embed])
|
||||
field_json[fields::reference], field_json[fields::embed], field_json[fields::range_index])
|
||||
);
|
||||
|
||||
if (!field_json[fields::reference].get<std::string>().empty()) {
|
||||
|
@ -422,7 +422,10 @@ Option<bool> toFilter(const std::string expression,
|
||||
id_comparator = EQUALS;
|
||||
while (++filter_value_index < raw_value.size() && raw_value[filter_value_index] == ' ');
|
||||
} else if (raw_value.size() >= 2 && raw_value[0] == '!' && raw_value[1] == '=') {
|
||||
return Option<bool>(400, "Not equals filtering is not supported on the `id` field.");
|
||||
id_comparator = NOT_EQUALS;
|
||||
filter_exp.apply_not_equals = true;
|
||||
filter_value_index++;
|
||||
while (++filter_value_index < raw_value.size() && raw_value[filter_value_index] == ' ');
|
||||
}
|
||||
if (filter_value_index != 0) {
|
||||
raw_value = raw_value.substr(filter_value_index);
|
||||
|
@ -401,6 +401,22 @@ void filter_result_iterator_t::next() {
|
||||
return;
|
||||
}
|
||||
|
||||
// No need to traverse iterator tree if there's only one filter or compute_result() has been called.
|
||||
if (is_filter_result_initialized) {
|
||||
if (++result_index >= filter_result.count) {
|
||||
is_valid = false;
|
||||
return;
|
||||
}
|
||||
|
||||
seq_id = filter_result.docs[result_index];
|
||||
reference.clear();
|
||||
for (auto const& item: filter_result.reference_filter_results) {
|
||||
reference[item.first] = item.second[result_index];
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
if (filter_node->isOperator) {
|
||||
// Advance the subtrees and then apply operators to arrive at the next valid doc.
|
||||
if (filter_node->filter_operator == AND) {
|
||||
@ -423,21 +439,6 @@ void filter_result_iterator_t::next() {
|
||||
return;
|
||||
}
|
||||
|
||||
if (is_filter_result_initialized) {
|
||||
if (++result_index >= filter_result.count) {
|
||||
is_valid = false;
|
||||
return;
|
||||
}
|
||||
|
||||
seq_id = filter_result.docs[result_index];
|
||||
reference.clear();
|
||||
for (auto const& item: filter_result.reference_filter_results) {
|
||||
reference[item.first] = item.second[result_index];
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const filter a_filter = filter_node->filter_exp;
|
||||
|
||||
if (!index->field_is_indexed(a_filter.field_name)) {
|
||||
@ -619,11 +620,6 @@ void filter_result_iterator_t::init() {
|
||||
}
|
||||
|
||||
if (a_filter.field_name == "id") {
|
||||
if (a_filter.values.empty()) {
|
||||
is_valid = false;
|
||||
return;
|
||||
}
|
||||
|
||||
// we handle `ids` separately
|
||||
std::vector<uint32_t> result_ids;
|
||||
for (const auto& id_str : a_filter.values) {
|
||||
@ -636,6 +632,16 @@ void filter_result_iterator_t::init() {
|
||||
filter_result.docs = new uint32_t[result_ids.size()];
|
||||
std::copy(result_ids.begin(), result_ids.end(), filter_result.docs);
|
||||
|
||||
if (a_filter.apply_not_equals) {
|
||||
apply_not_equals(index->seq_ids->uncompress(), index->seq_ids->num_ids(),
|
||||
filter_result.docs, filter_result.count);
|
||||
}
|
||||
|
||||
if (filter_result.count == 0) {
|
||||
is_valid = false;
|
||||
return;
|
||||
}
|
||||
|
||||
seq_id = filter_result.docs[result_index];
|
||||
is_filter_result_initialized = true;
|
||||
approx_filter_ids_length = filter_result.count;
|
||||
@ -650,27 +656,62 @@ void filter_result_iterator_t::init() {
|
||||
field f = index->search_schema.at(a_filter.field_name);
|
||||
|
||||
if (f.is_integer()) {
|
||||
auto num_tree = index->numerical_index.at(a_filter.field_name);
|
||||
if (f.range_index) {
|
||||
auto const& trie = index->range_index.at(a_filter.field_name);
|
||||
|
||||
for (size_t fi = 0; fi < a_filter.values.size(); fi++) {
|
||||
const std::string& filter_value = a_filter.values[fi];
|
||||
int64_t value = (int64_t)std::stol(filter_value);
|
||||
for (size_t fi = 0; fi < a_filter.values.size(); fi++) {
|
||||
const std::string& filter_value = a_filter.values[fi];
|
||||
auto const& value = (int32_t)std::stoi(filter_value);
|
||||
|
||||
size_t result_size = filter_result.count;
|
||||
if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) {
|
||||
const std::string& next_filter_value = a_filter.values[fi + 1];
|
||||
auto const range_end_value = (int64_t)std::stol(next_filter_value);
|
||||
num_tree->range_inclusive_search(value, range_end_value, &filter_result.docs, result_size);
|
||||
fi++;
|
||||
} else if (a_filter.comparators[fi] == NOT_EQUALS) {
|
||||
numeric_not_equals_filter(num_tree, value,
|
||||
index->seq_ids->uncompress(), index->seq_ids->num_ids(),
|
||||
filter_result.docs, result_size);
|
||||
} else {
|
||||
num_tree->search(a_filter.comparators[fi], value, &filter_result.docs, result_size);
|
||||
if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) {
|
||||
const std::string& next_filter_value = a_filter.values[fi + 1];
|
||||
auto const& range_end_value = (int32_t)std::stoi(next_filter_value);
|
||||
trie->search_range(value, true, range_end_value, true, filter_result.docs, filter_result.count);
|
||||
fi++;
|
||||
} else if (a_filter.comparators[fi] == EQUALS) {
|
||||
trie->search_equal_to(value, filter_result.docs, filter_result.count);
|
||||
} else if (a_filter.comparators[fi] == NOT_EQUALS) {
|
||||
uint32_t* to_exclude_ids = nullptr;
|
||||
uint32_t to_exclude_ids_len = 0;
|
||||
trie->search_equal_to(value, to_exclude_ids, to_exclude_ids_len);
|
||||
|
||||
auto all_ids = index->seq_ids->uncompress();
|
||||
filter_result.count = ArrayUtils::exclude_scalar(all_ids, index->seq_ids->num_ids(),
|
||||
to_exclude_ids, to_exclude_ids_len, &filter_result.docs);
|
||||
|
||||
delete[] all_ids;
|
||||
delete[] to_exclude_ids;
|
||||
} else if (a_filter.comparators[fi] == GREATER_THAN || a_filter.comparators[fi] == GREATER_THAN_EQUALS) {
|
||||
trie->search_greater_than(value, a_filter.comparators[fi] == GREATER_THAN_EQUALS,
|
||||
filter_result.docs, filter_result.count);
|
||||
} else if (a_filter.comparators[fi] == LESS_THAN || a_filter.comparators[fi] == LESS_THAN_EQUALS) {
|
||||
trie->search_less_than(value, a_filter.comparators[fi] == LESS_THAN_EQUALS,
|
||||
filter_result.docs, filter_result.count);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
auto num_tree = index->numerical_index.at(a_filter.field_name);
|
||||
|
||||
filter_result.count = result_size;
|
||||
for (size_t fi = 0; fi < a_filter.values.size(); fi++) {
|
||||
const std::string& filter_value = a_filter.values[fi];
|
||||
int64_t value = (int64_t)std::stol(filter_value);
|
||||
|
||||
size_t result_size = filter_result.count;
|
||||
if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) {
|
||||
const std::string& next_filter_value = a_filter.values[fi + 1];
|
||||
auto const range_end_value = (int64_t)std::stol(next_filter_value);
|
||||
num_tree->range_inclusive_search(value, range_end_value, &filter_result.docs, result_size);
|
||||
fi++;
|
||||
} else if (a_filter.comparators[fi] == NOT_EQUALS) {
|
||||
numeric_not_equals_filter(num_tree, value,
|
||||
index->seq_ids->uncompress(), index->seq_ids->num_ids(),
|
||||
filter_result.docs, result_size);
|
||||
} else {
|
||||
num_tree->search(a_filter.comparators[fi], value, &filter_result.docs, result_size);
|
||||
}
|
||||
|
||||
filter_result.count = result_size;
|
||||
}
|
||||
}
|
||||
|
||||
if (a_filter.apply_not_equals) {
|
||||
@ -688,28 +729,64 @@ void filter_result_iterator_t::init() {
|
||||
approx_filter_ids_length = filter_result.count;
|
||||
return;
|
||||
} else if (f.is_float()) {
|
||||
auto num_tree = index->numerical_index.at(a_filter.field_name);
|
||||
if (f.range_index) {
|
||||
auto const& trie = index->range_index.at(a_filter.field_name);
|
||||
|
||||
for (size_t fi = 0; fi < a_filter.values.size(); fi++) {
|
||||
const std::string& filter_value = a_filter.values[fi];
|
||||
float value = (float)std::atof(filter_value.c_str());
|
||||
int64_t float_int64 = Index::float_to_int64_t(value);
|
||||
for (size_t fi = 0; fi < a_filter.values.size(); fi++) {
|
||||
const std::string& filter_value = a_filter.values[fi];
|
||||
float value = (float)std::atof(filter_value.c_str());
|
||||
int64_t float_int64 = Index::float_to_int64_t(value);
|
||||
|
||||
size_t result_size = filter_result.count;
|
||||
if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) {
|
||||
const std::string& next_filter_value = a_filter.values[fi+1];
|
||||
int64_t range_end_value = Index::float_to_int64_t((float) std::atof(next_filter_value.c_str()));
|
||||
num_tree->range_inclusive_search(float_int64, range_end_value, &filter_result.docs, result_size);
|
||||
fi++;
|
||||
} else if (a_filter.comparators[fi] == NOT_EQUALS) {
|
||||
numeric_not_equals_filter(num_tree, float_int64,
|
||||
index->seq_ids->uncompress(), index->seq_ids->num_ids(),
|
||||
filter_result.docs, result_size);
|
||||
} else {
|
||||
num_tree->search(a_filter.comparators[fi], float_int64, &filter_result.docs, result_size);
|
||||
if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) {
|
||||
const std::string& next_filter_value = a_filter.values[fi + 1];
|
||||
int64_t range_end_value = Index::float_to_int64_t((float) std::atof(next_filter_value.c_str()));
|
||||
trie->search_range(float_int64, true, range_end_value, true, filter_result.docs, filter_result.count);
|
||||
fi++;
|
||||
} else if (a_filter.comparators[fi] == EQUALS) {
|
||||
trie->search_equal_to(float_int64, filter_result.docs, filter_result.count);
|
||||
} else if (a_filter.comparators[fi] == NOT_EQUALS) {
|
||||
uint32_t* to_exclude_ids = nullptr;
|
||||
uint32_t to_exclude_ids_len = 0;
|
||||
trie->search_equal_to(float_int64, to_exclude_ids, to_exclude_ids_len);
|
||||
|
||||
auto all_ids = index->seq_ids->uncompress();
|
||||
filter_result.count = ArrayUtils::exclude_scalar(all_ids, index->seq_ids->num_ids(),
|
||||
to_exclude_ids, to_exclude_ids_len, &filter_result.docs);
|
||||
|
||||
delete[] all_ids;
|
||||
delete[] to_exclude_ids;
|
||||
} else if (a_filter.comparators[fi] == GREATER_THAN || a_filter.comparators[fi] == GREATER_THAN_EQUALS) {
|
||||
trie->search_greater_than(float_int64, a_filter.comparators[fi] == GREATER_THAN_EQUALS,
|
||||
filter_result.docs, filter_result.count);
|
||||
} else if (a_filter.comparators[fi] == LESS_THAN || a_filter.comparators[fi] == LESS_THAN_EQUALS) {
|
||||
trie->search_less_than(float_int64, a_filter.comparators[fi] == LESS_THAN_EQUALS,
|
||||
filter_result.docs, filter_result.count);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
auto num_tree = index->numerical_index.at(a_filter.field_name);
|
||||
|
||||
filter_result.count = result_size;
|
||||
for (size_t fi = 0; fi < a_filter.values.size(); fi++) {
|
||||
const std::string& filter_value = a_filter.values[fi];
|
||||
float value = (float)std::atof(filter_value.c_str());
|
||||
int64_t float_int64 = Index::float_to_int64_t(value);
|
||||
|
||||
size_t result_size = filter_result.count;
|
||||
if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) {
|
||||
const std::string& next_filter_value = a_filter.values[fi+1];
|
||||
int64_t range_end_value = Index::float_to_int64_t((float) std::atof(next_filter_value.c_str()));
|
||||
num_tree->range_inclusive_search(float_int64, range_end_value, &filter_result.docs, result_size);
|
||||
fi++;
|
||||
} else if (a_filter.comparators[fi] == NOT_EQUALS) {
|
||||
numeric_not_equals_filter(num_tree, float_int64,
|
||||
index->seq_ids->uncompress(), index->seq_ids->num_ids(),
|
||||
filter_result.docs, result_size);
|
||||
} else {
|
||||
num_tree->search(a_filter.comparators[fi], float_int64, &filter_result.docs, result_size);
|
||||
}
|
||||
|
||||
filter_result.count = result_size;
|
||||
}
|
||||
}
|
||||
|
||||
if (a_filter.apply_not_equals) {
|
||||
@ -821,17 +898,15 @@ void filter_result_iterator_t::init() {
|
||||
S2RegionTermIndexer::Options options;
|
||||
options.set_index_contains_points_only(true);
|
||||
S2RegionTermIndexer indexer(options);
|
||||
auto const& geo_range_index = index->geo_range_index.at(a_filter.field_name);
|
||||
|
||||
std::vector<uint64_t> cell_ids;
|
||||
for (const auto& term : indexer.GetQueryTerms(*query_region, "")) {
|
||||
auto geo_index = index->geopoint_index.at(a_filter.field_name);
|
||||
const auto& ids_it = geo_index->find(term);
|
||||
if(ids_it != geo_index->end()) {
|
||||
geo_result_ids.insert(geo_result_ids.end(), ids_it->second.begin(), ids_it->second.end());
|
||||
}
|
||||
auto cell = S2CellId::FromToken(term);
|
||||
cell_ids.push_back(cell.id());
|
||||
}
|
||||
|
||||
gfx::timsort(geo_result_ids.begin(), geo_result_ids.end());
|
||||
geo_result_ids.erase(std::unique( geo_result_ids.begin(), geo_result_ids.end() ), geo_result_ids.end());
|
||||
geo_range_index->search_geopoints(cell_ids, geo_result_ids);
|
||||
|
||||
// Skip exact filtering step if query radius is greater than the threshold.
|
||||
if (fi < a_filter.params.size() &&
|
||||
@ -955,20 +1030,7 @@ void filter_result_iterator_t::skip_to(uint32_t id) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (filter_node->isOperator) {
|
||||
// Skip the subtrees to id and then apply operators to arrive at the next valid doc.
|
||||
left_it->skip_to(id);
|
||||
right_it->skip_to(id);
|
||||
|
||||
if (filter_node->filter_operator == AND) {
|
||||
and_filter_iterators();
|
||||
} else {
|
||||
or_filter_iterators();
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
// No need to traverse iterator tree if there's only one filter or compute_result() has been called.
|
||||
if (is_filter_result_initialized) {
|
||||
ArrayUtils::skip_index_to_id(result_index, filter_result.docs, filter_result.count, id);
|
||||
|
||||
@ -986,6 +1048,20 @@ void filter_result_iterator_t::skip_to(uint32_t id) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (filter_node->isOperator) {
|
||||
// Skip the subtrees to id and then apply operators to arrive at the next valid doc.
|
||||
left_it->skip_to(id);
|
||||
right_it->skip_to(id);
|
||||
|
||||
if (filter_node->filter_operator == AND) {
|
||||
and_filter_iterators();
|
||||
} else {
|
||||
or_filter_iterators();
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const filter a_filter = filter_node->filter_exp;
|
||||
|
||||
if (!index->field_is_indexed(a_filter.field_name)) {
|
||||
@ -1068,6 +1144,12 @@ int filter_result_iterator_t::valid(uint32_t id) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
// No need to traverse iterator tree if there's only one filter or compute_result() has been called.
|
||||
if (is_filter_result_initialized) {
|
||||
skip_to(id);
|
||||
return is_valid ? (seq_id == id ? 1 : 0) : -1;
|
||||
}
|
||||
|
||||
if (filter_node->isOperator) {
|
||||
auto left_valid = left_it->valid(id), right_valid = right_it->valid(id);
|
||||
|
||||
@ -1181,21 +1263,7 @@ void filter_result_iterator_t::reset() {
|
||||
return;
|
||||
}
|
||||
|
||||
if (filter_node->isOperator) {
|
||||
// Reset the subtrees then apply operators to arrive at the first valid doc.
|
||||
left_it->reset();
|
||||
right_it->reset();
|
||||
is_valid = true;
|
||||
|
||||
if (filter_node->filter_operator == AND) {
|
||||
and_filter_iterators();
|
||||
} else {
|
||||
or_filter_iterators();
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
// No need to traverse iterator tree if there's only one filter or compute_result() has been called.
|
||||
if (is_filter_result_initialized) {
|
||||
if (filter_result.count == 0) {
|
||||
is_valid = false;
|
||||
@ -1214,6 +1282,21 @@ void filter_result_iterator_t::reset() {
|
||||
return;
|
||||
}
|
||||
|
||||
if (filter_node->isOperator) {
|
||||
// Reset the subtrees then apply operators to arrive at the first valid doc.
|
||||
left_it->reset();
|
||||
right_it->reset();
|
||||
is_valid = true;
|
||||
|
||||
if (filter_node->filter_operator == AND) {
|
||||
and_filter_iterators();
|
||||
} else {
|
||||
or_filter_iterators();
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const filter a_filter = filter_node->filter_exp;
|
||||
|
||||
if (!index->field_is_indexed(a_filter.field_name)) {
|
||||
@ -1459,3 +1542,136 @@ void filter_result_iterator_t::add_phrase_ids(filter_result_iterator_t*& filter_
|
||||
root_iterator->seq_id = left_it->seq_id;
|
||||
filter_result_iterator = root_iterator;
|
||||
}
|
||||
|
||||
void filter_result_iterator_t::compute_result() {
|
||||
if (filter_node->isOperator) {
|
||||
left_it->compute_result();
|
||||
right_it->compute_result();
|
||||
|
||||
if (filter_node->filter_operator == AND) {
|
||||
filter_result_t::and_filter_results(left_it->filter_result, right_it->filter_result, filter_result);
|
||||
} else {
|
||||
filter_result_t::or_filter_results(left_it->filter_result, right_it->filter_result, filter_result);
|
||||
}
|
||||
|
||||
seq_id = filter_result.docs[result_index];
|
||||
is_filter_result_initialized = true;
|
||||
approx_filter_ids_length = filter_result.count;
|
||||
return;
|
||||
}
|
||||
|
||||
// Only string field filter needs to be evaluated.
|
||||
if (is_filter_result_initialized || index->search_index.count(filter_node->filter_exp.field_name) == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto const& a_filter = filter_node->filter_exp;
|
||||
auto const& f = index->search_schema.at(a_filter.field_name);
|
||||
art_tree* t = index->search_index.at(a_filter.field_name);
|
||||
|
||||
uint32_t* or_ids = nullptr;
|
||||
size_t or_ids_size = 0;
|
||||
|
||||
// aggregates IDs across array of filter values and reduces excessive ORing
|
||||
std::vector<uint32_t> f_id_buff;
|
||||
|
||||
for (const std::string& filter_value : a_filter.values) {
|
||||
std::vector<void*> posting_lists;
|
||||
|
||||
// there could be multiple tokens in a filter value, which we have to treat as ANDs
|
||||
// e.g. country: South Africa
|
||||
Tokenizer tokenizer(filter_value, true, false, f.locale, index->symbols_to_index, index->token_separators);
|
||||
|
||||
std::string str_token;
|
||||
size_t token_index = 0;
|
||||
std::vector<std::string> str_tokens;
|
||||
|
||||
while (tokenizer.next(str_token, token_index)) {
|
||||
str_tokens.push_back(str_token);
|
||||
|
||||
art_leaf* leaf = (art_leaf *) art_search(t, (const unsigned char*) str_token.c_str(),
|
||||
str_token.length()+1);
|
||||
if (leaf == nullptr) {
|
||||
continue;
|
||||
}
|
||||
|
||||
posting_lists.push_back(leaf->values);
|
||||
}
|
||||
|
||||
if (posting_lists.size() != str_tokens.size()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if(a_filter.comparators[0] == EQUALS || a_filter.comparators[0] == NOT_EQUALS) {
|
||||
// needs intersection + exact matching (unlike CONTAINS)
|
||||
std::vector<uint32_t> result_id_vec;
|
||||
posting_t::intersect(posting_lists, result_id_vec);
|
||||
|
||||
if (result_id_vec.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// need to do exact match
|
||||
uint32_t* exact_str_ids = new uint32_t[result_id_vec.size()];
|
||||
size_t exact_str_ids_size = 0;
|
||||
std::unique_ptr<uint32_t[]> exact_str_ids_guard(exact_str_ids);
|
||||
|
||||
posting_t::get_exact_matches(posting_lists, f.is_array(), result_id_vec.data(), result_id_vec.size(),
|
||||
exact_str_ids, exact_str_ids_size);
|
||||
|
||||
if (exact_str_ids_size == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (size_t ei = 0; ei < exact_str_ids_size; ei++) {
|
||||
f_id_buff.push_back(exact_str_ids[ei]);
|
||||
}
|
||||
} else {
|
||||
// CONTAINS
|
||||
size_t before_size = f_id_buff.size();
|
||||
posting_t::intersect(posting_lists, f_id_buff);
|
||||
if (f_id_buff.size() == before_size) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (f_id_buff.size() > 100000 || a_filter.values.size() == 1) {
|
||||
gfx::timsort(f_id_buff.begin(), f_id_buff.end());
|
||||
f_id_buff.erase(std::unique( f_id_buff.begin(), f_id_buff.end() ), f_id_buff.end());
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
or_ids_size = ArrayUtils::or_scalar(or_ids, or_ids_size, f_id_buff.data(), f_id_buff.size(), &out);
|
||||
delete[] or_ids;
|
||||
or_ids = out;
|
||||
std::vector<uint32_t>().swap(f_id_buff); // clears out memory
|
||||
}
|
||||
}
|
||||
|
||||
if (!f_id_buff.empty()) {
|
||||
gfx::timsort(f_id_buff.begin(), f_id_buff.end());
|
||||
f_id_buff.erase(std::unique( f_id_buff.begin(), f_id_buff.end() ), f_id_buff.end());
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
or_ids_size = ArrayUtils::or_scalar(or_ids, or_ids_size, f_id_buff.data(), f_id_buff.size(), &out);
|
||||
delete[] or_ids;
|
||||
or_ids = out;
|
||||
std::vector<uint32_t>().swap(f_id_buff); // clears out memory
|
||||
}
|
||||
|
||||
filter_result.docs = or_ids;
|
||||
filter_result.count = or_ids_size;
|
||||
|
||||
if (a_filter.apply_not_equals) {
|
||||
apply_not_equals(index->seq_ids->uncompress(), index->seq_ids->num_ids(), filter_result.docs, filter_result.count);
|
||||
}
|
||||
|
||||
if (filter_result.count == 0) {
|
||||
is_valid = false;
|
||||
return;
|
||||
}
|
||||
|
||||
result_index = 0;
|
||||
seq_id = filter_result.docs[result_index];
|
||||
is_filter_result_initialized = true;
|
||||
approx_filter_ids_length = filter_result.count;
|
||||
}
|
||||
|
46
src/http_proxy.cpp
Normal file
46
src/http_proxy.cpp
Normal file
@ -0,0 +1,46 @@
|
||||
#include "http_proxy.h"
|
||||
#include "logger.h"
|
||||
#include <chrono>
|
||||
|
||||
using namespace std::chrono_literals;
|
||||
|
||||
// Builds the proxy with its response cache constructed from a 30 second duration.
// NOTE(review): presumably this is the lifetime of a cached entry — confirm against the cache type.
HttpProxy::HttpProxy() : cache(std::chrono::seconds(30)) {}
|
||||
|
||||
|
||||
/**
 * Forwards an HTTP request to `url` and returns the upstream response, answering repeated identical
 * requests from the in-memory response cache.
 *
 * @param url     Target URL of the proxied request.
 * @param method  One of "GET", "POST", "PUT" or "DELETE"; anything else yields a 400 response whose
 *                body is a JSON object with a `message` field.
 * @param body    Request body; forwarded for POST and PUT only.
 * @param headers Request headers; forwarded for GET and POST only (the PUT / DELETE client calls used
 *                here take no request headers).
 * @return The status code, body and headers of the (possibly cached) response.
 *
 * Only 2xx responses are stored in the cache, so a transient upstream failure or an invalid `method`
 * is not replayed to later callers.
 */
http_proxy_res_t HttpProxy::send(const std::string& url, const std::string& method, const std::string& body, const std::unordered_map<std::string, std::string>& headers) {
    // The cache key covers url, method, body and all headers.
    uint64_t key = StringUtils::hash_wy(url.c_str(), url.size());
    key = StringUtils::hash_combine(key, StringUtils::hash_wy(method.c_str(), method.size()));
    key = StringUtils::hash_combine(key, StringUtils::hash_wy(body.c_str(), body.size()));

    // The iteration order of an unordered_map is unspecified, so the per-header hashes are merged with a
    // commutative operation (wrapping addition): equal header sets always produce the same key.
    uint64_t headers_hash = 0;
    for (const auto& header : headers) {
        uint64_t header_hash = StringUtils::hash_wy(header.first.c_str(), header.first.size());
        header_hash = StringUtils::hash_combine(header_hash, StringUtils::hash_wy(header.second.c_str(), header.second.size()));
        headers_hash += header_hash;
    }
    key = StringUtils::hash_combine(key, headers_hash);

    // check if the request is in cache
    if (cache.contains(key)) {
        return cache[key];
    }

    // if not, make http request
    HttpClient& client = HttpClient::get_instance();
    http_proxy_res_t res;

    // NOTE(review): `30 * 1000` is presumably a 30 second timeout expressed in milliseconds — confirm
    // against the HttpClient signatures.
    if(method == "GET") {
        res.status_code = client.get_response(url, res.body, res.headers, headers, 30 * 1000);
    } else if(method == "POST") {
        res.status_code = client.post_response(url, body, res.body, res.headers, headers, 30 * 1000);
    } else if(method == "PUT") {
        res.status_code = client.put_response(url, body, res.body, res.headers, 30 * 1000);
    } else if(method == "DELETE") {
        res.status_code = client.delete_response(url, res.body, res.headers, 30 * 1000);
    } else {
        res.status_code = 400;
        nlohmann::json j;
        j["message"] = "Parameter `method` must be one of GET, POST, PUT, DELETE.";
        res.body = j.dump();
    }

    // add to cache, but only when the request succeeded
    if (res.status_code >= 200 && res.status_code < 300) {
        cache.insert(key, res);
    }

    return res;
}
|
@ -632,6 +632,9 @@ int HttpServer::async_req_cb(void *ctx, int is_end_stream) {
|
||||
|
||||
if(request->first_chunk_aggregate) {
|
||||
request->first_chunk_aggregate = false;
|
||||
|
||||
// ensures that the first response need not wait for previous chunk to be done sending
|
||||
response->notify();
|
||||
}
|
||||
|
||||
// default value for last_chunk_aggregate is false
|
||||
@ -803,6 +806,13 @@ void HttpServer::stream_response(stream_response_state_t& state) {
|
||||
|
||||
h2o_req_t* req = state.get_req();
|
||||
|
||||
if(state.is_res_start) {
|
||||
h2o_add_header(&req->pool, &req->res.headers, H2O_TOKEN_CONTENT_TYPE, NULL,
|
||||
state.res_content_type.base, state.res_content_type.len);
|
||||
req->res.status = state.status;
|
||||
req->res.reason = state.reason;
|
||||
}
|
||||
|
||||
if(state.is_req_early_exit) {
|
||||
// premature termination of async request: handle this explicitly as otherwise, request is not being closed
|
||||
LOG(INFO) << "Premature termination of async request.";
|
||||
|
139
src/index.cpp
139
src/index.cpp
@ -69,8 +69,7 @@ Index::Index(const std::string& name, const uint32_t collection_id, const Store*
|
||||
art_tree_init(t);
|
||||
search_index.emplace(a_field.name, t);
|
||||
} else if(a_field.is_geopoint()) {
|
||||
auto field_geo_index = new spp::sparse_hash_map<std::string, std::vector<uint32_t>>();
|
||||
geopoint_index.emplace(a_field.name, field_geo_index);
|
||||
geo_range_index.emplace(a_field.name, new NumericTrie());
|
||||
|
||||
if(!a_field.is_single_geopoint()) {
|
||||
spp::sparse_hash_map<uint32_t, int64_t*> * doc_to_geos = new spp::sparse_hash_map<uint32_t, int64_t*>();
|
||||
@ -79,6 +78,11 @@ Index::Index(const std::string& name, const uint32_t collection_id, const Store*
|
||||
} else {
|
||||
num_tree_t* num_tree = new num_tree_t;
|
||||
numerical_index.emplace(a_field.name, num_tree);
|
||||
|
||||
if (a_field.range_index) {
|
||||
auto trie = a_field.is_int32() ? new NumericTrie() : new NumericTrie(64);
|
||||
range_index.emplace(a_field.name, trie);
|
||||
}
|
||||
}
|
||||
|
||||
if(a_field.sort) {
|
||||
@ -127,12 +131,12 @@ Index::~Index() {
|
||||
|
||||
search_index.clear();
|
||||
|
||||
for(auto & name_index: geopoint_index) {
|
||||
for(auto & name_index: geo_range_index) {
|
||||
delete name_index.second;
|
||||
name_index.second = nullptr;
|
||||
}
|
||||
|
||||
geopoint_index.clear();
|
||||
geo_range_index.clear();
|
||||
|
||||
for(auto& name_index: geo_array_index) {
|
||||
for(auto& kv: *name_index.second) {
|
||||
@ -152,6 +156,13 @@ Index::~Index() {
|
||||
|
||||
numerical_index.clear();
|
||||
|
||||
for(auto & name_tree: range_index) {
|
||||
delete name_tree.second;
|
||||
name_tree.second = nullptr;
|
||||
}
|
||||
|
||||
range_index.clear();
|
||||
|
||||
for(auto & name_map: sort_index) {
|
||||
delete name_map.second;
|
||||
name_map.second = nullptr;
|
||||
@ -484,7 +495,7 @@ void Index::validate_and_preprocess(Index *index,
|
||||
index_rec.index_failure(400, e.what());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
auto embed_op = batch_embed_fields(docs_to_embed, embedding_fields, search_schema);
|
||||
if(!embed_op.ok()) {
|
||||
for(size_t i = 0; i < batch_size; i++) {
|
||||
@ -783,6 +794,15 @@ void Index::index_field_in_memory(const field& afield, std::vector<index_record>
|
||||
|
||||
if(!afield.is_string()) {
|
||||
if (afield.type == field_types::INT32) {
|
||||
if (afield.range_index) {
|
||||
auto const& trie = range_index.at(afield.name);
|
||||
iterate_and_index_numerical_field(iter_batch, afield, [&afield, trie]
|
||||
(const index_record& record, uint32_t seq_id) {
|
||||
int32_t value = record.doc[afield.name].get<int32_t>();
|
||||
trie->insert(value, seq_id);
|
||||
});
|
||||
}
|
||||
|
||||
auto num_tree = numerical_index.at(afield.name);
|
||||
iterate_and_index_numerical_field(iter_batch, afield, [&afield, num_tree]
|
||||
(const index_record& record, uint32_t seq_id) {
|
||||
@ -792,6 +812,15 @@ void Index::index_field_in_memory(const field& afield, std::vector<index_record>
|
||||
}
|
||||
|
||||
else if(afield.type == field_types::INT64) {
|
||||
if (afield.range_index) {
|
||||
auto const& trie = range_index.at(afield.name);
|
||||
iterate_and_index_numerical_field(iter_batch, afield, [&afield, trie]
|
||||
(const index_record& record, uint32_t seq_id) {
|
||||
int64_t value = record.doc[afield.name].get<int64_t>();
|
||||
trie->insert(value, seq_id);
|
||||
});
|
||||
}
|
||||
|
||||
auto num_tree = numerical_index.at(afield.name);
|
||||
iterate_and_index_numerical_field(iter_batch, afield, [&afield, num_tree]
|
||||
(const index_record& record, uint32_t seq_id) {
|
||||
@ -801,6 +830,16 @@ void Index::index_field_in_memory(const field& afield, std::vector<index_record>
|
||||
}
|
||||
|
||||
else if(afield.type == field_types::FLOAT) {
|
||||
if (afield.range_index) {
|
||||
auto const& trie = range_index.at(afield.name);
|
||||
iterate_and_index_numerical_field(iter_batch, afield, [&afield, trie]
|
||||
(const index_record& record, uint32_t seq_id) {
|
||||
float fvalue = record.doc[afield.name].get<float>();
|
||||
int64_t value = float_to_int64_t(fvalue);
|
||||
trie->insert(value, seq_id);
|
||||
});
|
||||
}
|
||||
|
||||
auto num_tree = numerical_index.at(afield.name);
|
||||
iterate_and_index_numerical_field(iter_batch, afield, [&afield, num_tree]
|
||||
(const index_record& record, uint32_t seq_id) {
|
||||
@ -816,10 +855,10 @@ void Index::index_field_in_memory(const field& afield, std::vector<index_record>
|
||||
num_tree->insert(value, seq_id, afield.is_facet());
|
||||
});
|
||||
} else if(afield.type == field_types::GEOPOINT || afield.type == field_types::GEOPOINT_ARRAY) {
|
||||
auto geo_index = geopoint_index.at(afield.name);
|
||||
auto geopoint_range_index = geo_range_index.at(afield.name);
|
||||
|
||||
iterate_and_index_numerical_field(iter_batch, afield,
|
||||
[&afield, &geo_array_index=geo_array_index, geo_index](const index_record& record, uint32_t seq_id) {
|
||||
[&afield, &geo_array_index=geo_array_index, geopoint_range_index](const index_record& record, uint32_t seq_id) {
|
||||
// nested geopoint value inside an array of object will be a simple array so must be treated as geopoint
|
||||
bool nested_obj_arr_geopoint = (afield.nested && afield.type == field_types::GEOPOINT_ARRAY &&
|
||||
!record.doc[afield.name].empty() && record.doc[afield.name][0].is_number());
|
||||
@ -833,9 +872,8 @@ void Index::index_field_in_memory(const field& afield, std::vector<index_record>
|
||||
S2RegionTermIndexer indexer(options);
|
||||
S2Point point = S2LatLng::FromDegrees(latlongs[li], latlongs[li+1]).ToPoint();
|
||||
|
||||
for(const auto& term: indexer.GetIndexTerms(point, "")) {
|
||||
(*geo_index)[term].push_back(seq_id);
|
||||
}
|
||||
auto cell = S2CellId(point);
|
||||
geopoint_range_index->insert_geopoint(cell.id(), seq_id);
|
||||
}
|
||||
|
||||
if(nested_obj_arr_geopoint) {
|
||||
@ -863,9 +901,9 @@ void Index::index_field_in_memory(const field& afield, std::vector<index_record>
|
||||
for(size_t li = 0; li < latlongs.size(); li++) {
|
||||
auto& latlong = latlongs[li];
|
||||
S2Point point = S2LatLng::FromDegrees(latlong[0], latlong[1]).ToPoint();
|
||||
for(const auto& term: indexer.GetIndexTerms(point, "")) {
|
||||
(*geo_index)[term].push_back(seq_id);
|
||||
}
|
||||
|
||||
auto cell = S2CellId(point);
|
||||
geopoint_range_index->insert_geopoint(cell.id(), seq_id);
|
||||
|
||||
int64_t packed_latlong = GeoPoint::pack_lat_lng(latlong[0], latlong[1]);
|
||||
packed_latlongs[li + 1] = packed_latlong;
|
||||
@ -945,7 +983,8 @@ void Index::index_field_in_memory(const field& afield, std::vector<index_record>
|
||||
|
||||
// all other numerical arrays
|
||||
auto num_tree = numerical_index.at(afield.name);
|
||||
iterate_and_index_numerical_field(iter_batch, afield, [&afield, num_tree]
|
||||
auto trie = range_index.count(afield.name) > 0 ? range_index.at(afield.name) : nullptr;
|
||||
iterate_and_index_numerical_field(iter_batch, afield, [&afield, num_tree, trie]
|
||||
(const index_record& record, uint32_t seq_id) {
|
||||
for(size_t arr_i = 0; arr_i < record.doc[afield.name].size(); arr_i++) {
|
||||
const auto& arr_value = record.doc[afield.name][arr_i];
|
||||
@ -953,17 +992,26 @@ void Index::index_field_in_memory(const field& afield, std::vector<index_record>
|
||||
if(afield.type == field_types::INT32_ARRAY) {
|
||||
const int32_t value = arr_value;
|
||||
num_tree->insert(value, seq_id, afield.is_facet());
|
||||
if (afield.range_index) {
|
||||
trie->insert(value, seq_id);
|
||||
}
|
||||
}
|
||||
|
||||
else if(afield.type == field_types::INT64_ARRAY) {
|
||||
const int64_t value = arr_value;
|
||||
num_tree->insert(value, seq_id, afield.is_facet());
|
||||
if (afield.range_index) {
|
||||
trie->insert(value, seq_id);
|
||||
}
|
||||
}
|
||||
|
||||
else if(afield.type == field_types::FLOAT_ARRAY) {
|
||||
const float fvalue = arr_value;
|
||||
int64_t value = float_to_int64_t(fvalue);
|
||||
num_tree->insert(value, seq_id, afield.is_facet());
|
||||
if (afield.range_index) {
|
||||
trie->insert(value, seq_id);
|
||||
}
|
||||
}
|
||||
|
||||
else if(afield.type == field_types::BOOL_ARRAY) {
|
||||
@ -1628,7 +1676,7 @@ void Index::numeric_not_equals_filter(num_tree_t* const num_tree,
|
||||
bool Index::field_is_indexed(const std::string& field_name) const {
|
||||
return search_index.count(field_name) != 0 ||
|
||||
numerical_index.count(field_name) != 0 ||
|
||||
geopoint_index.count(field_name) != 0;
|
||||
geo_range_index.count(field_name) != 0;
|
||||
}
|
||||
|
||||
void Index::aproximate_numerical_match(num_tree_t* const num_tree,
|
||||
@ -4597,7 +4645,9 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root,
|
||||
std::array<spp::sparse_hash_map<uint32_t, int64_t>*, 3>& field_values,
|
||||
const std::vector<size_t>& geopoint_indices) const {
|
||||
|
||||
filter_result_iterator->compute_result();
|
||||
auto const& approx_filter_ids_length = filter_result_iterator->approx_filter_ids_length;
|
||||
|
||||
uint32_t token_bits = 0;
|
||||
const bool check_for_circuit_break = (approx_filter_ids_length > 1000000);
|
||||
|
||||
@ -5459,6 +5509,11 @@ void Index::remove_field(uint32_t seq_id, const nlohmann::json& document, const
|
||||
std::vector<int32_t>{document[field_name].get<int32_t>()} :
|
||||
document[field_name].get<std::vector<int32_t>>();
|
||||
for(int32_t value: values) {
|
||||
if (search_field.range_index) {
|
||||
auto const& trie = range_index.at(search_field.name);
|
||||
trie->remove(value, seq_id);
|
||||
}
|
||||
|
||||
num_tree_t* num_tree = numerical_index.at(field_name);
|
||||
num_tree->remove(value, seq_id);
|
||||
if(search_field.facet) {
|
||||
@ -5470,6 +5525,11 @@ void Index::remove_field(uint32_t seq_id, const nlohmann::json& document, const
|
||||
std::vector<int64_t>{document[field_name].get<int64_t>()} :
|
||||
document[field_name].get<std::vector<int64_t>>();
|
||||
for(int64_t value: values) {
|
||||
if (search_field.range_index) {
|
||||
auto const& trie = range_index.at(search_field.name);
|
||||
trie->remove(value, seq_id);
|
||||
}
|
||||
|
||||
num_tree_t* num_tree = numerical_index.at(field_name);
|
||||
num_tree->remove(value, seq_id);
|
||||
if(search_field.facet) {
|
||||
@ -5484,8 +5544,14 @@ void Index::remove_field(uint32_t seq_id, const nlohmann::json& document, const
|
||||
document[field_name].get<std::vector<float>>();
|
||||
|
||||
for(float value: values) {
|
||||
num_tree_t* num_tree = numerical_index.at(field_name);
|
||||
int64_t fintval = float_to_int64_t(value);
|
||||
|
||||
if (search_field.range_index) {
|
||||
auto const& trie = range_index.at(search_field.name);
|
||||
trie->remove(fintval, seq_id);
|
||||
}
|
||||
|
||||
num_tree_t* num_tree = numerical_index.at(field_name);
|
||||
num_tree->remove(fintval, seq_id);
|
||||
if(search_field.facet) {
|
||||
remove_facet_token(search_field, search_index, StringUtils::float_to_str(value), seq_id);
|
||||
@ -5505,7 +5571,7 @@ void Index::remove_field(uint32_t seq_id, const nlohmann::json& document, const
|
||||
}
|
||||
}
|
||||
} else if(search_field.is_geopoint()) {
|
||||
auto geo_index = geopoint_index[field_name];
|
||||
auto geopoint_range_index = geo_range_index[field_name];
|
||||
S2RegionTermIndexer::Options options;
|
||||
options.set_index_contains_points_only(true);
|
||||
S2RegionTermIndexer indexer(options);
|
||||
@ -5516,17 +5582,8 @@ void Index::remove_field(uint32_t seq_id, const nlohmann::json& document, const
|
||||
|
||||
for(const std::vector<double>& latlong: latlongs) {
|
||||
S2Point point = S2LatLng::FromDegrees(latlong[0], latlong[1]).ToPoint();
|
||||
for(const auto& term: indexer.GetIndexTerms(point, "")) {
|
||||
auto term_it = geo_index->find(term);
|
||||
if(term_it == geo_index->end()) {
|
||||
continue;
|
||||
}
|
||||
std::vector<uint32_t>& ids = term_it->second;
|
||||
ids.erase(std::remove(ids.begin(), ids.end(), seq_id), ids.end());
|
||||
if(ids.empty()) {
|
||||
geo_index->erase(term);
|
||||
}
|
||||
}
|
||||
auto cell = S2CellId(point);
|
||||
geopoint_range_index->delete_geopoint(cell.id(), seq_id);
|
||||
}
|
||||
|
||||
if(!search_field.is_single_geopoint()) {
|
||||
@ -5664,8 +5721,7 @@ void Index::refresh_schemas(const std::vector<field>& new_fields, const std::vec
|
||||
art_tree_init(t);
|
||||
search_index.emplace(new_field.name, t);
|
||||
} else if(new_field.is_geopoint()) {
|
||||
auto field_geo_index = new spp::sparse_hash_map<std::string, std::vector<uint32_t>>();
|
||||
geopoint_index.emplace(new_field.name, field_geo_index);
|
||||
geo_range_index.emplace(new_field.name, new NumericTrie());
|
||||
if(!new_field.is_single_geopoint()) {
|
||||
auto geo_array_map = new spp::sparse_hash_map<uint32_t, int64_t*>();
|
||||
geo_array_index.emplace(new_field.name, geo_array_map);
|
||||
@ -5673,6 +5729,10 @@ void Index::refresh_schemas(const std::vector<field>& new_fields, const std::vec
|
||||
} else {
|
||||
num_tree_t* num_tree = new num_tree_t;
|
||||
numerical_index.emplace(new_field.name, num_tree);
|
||||
|
||||
if (new_field.range_index) {
|
||||
range_index.emplace(new_field.name, new NumericTrie(new_field.is_int32() ? 32 : 64));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -5715,8 +5775,8 @@ void Index::refresh_schemas(const std::vector<field>& new_fields, const std::vec
|
||||
delete search_index[del_field.name];
|
||||
search_index.erase(del_field.name);
|
||||
} else if(del_field.is_geopoint()) {
|
||||
delete geopoint_index[del_field.name];
|
||||
geopoint_index.erase(del_field.name);
|
||||
delete geo_range_index[del_field.name];
|
||||
geo_range_index.erase(del_field.name);
|
||||
|
||||
if(!del_field.is_single_geopoint()) {
|
||||
spp::sparse_hash_map<uint32_t, int64_t*>* geo_array_map = geo_array_index[del_field.name];
|
||||
@ -5729,6 +5789,11 @@ void Index::refresh_schemas(const std::vector<field>& new_fields, const std::vec
|
||||
} else {
|
||||
delete numerical_index[del_field.name];
|
||||
numerical_index.erase(del_field.name);
|
||||
|
||||
if (del_field.range_index) {
|
||||
delete range_index[del_field.name];
|
||||
range_index.erase(del_field.name);
|
||||
}
|
||||
}
|
||||
|
||||
if(del_field.is_sortable()) {
|
||||
@ -6105,8 +6170,15 @@ Option<bool> Index::batch_embed_fields(std::vector<nlohmann::json*>& documents,
|
||||
}
|
||||
}
|
||||
}
|
||||
texts_to_embed.push_back(std::make_pair(document, text));
|
||||
if(!text.empty()) {
|
||||
texts_to_embed.push_back(std::make_pair(document, text));
|
||||
}
|
||||
}
|
||||
|
||||
if(texts_to_embed.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance();
|
||||
auto embedder_op = embedder_manager.get_text_embedder(field.embed[fields::model_config]);
|
||||
|
||||
@ -6128,7 +6200,6 @@ Option<bool> Index::batch_embed_fields(std::vector<nlohmann::json*>& documents,
|
||||
}
|
||||
|
||||
auto embedding_op = embedder_op.get()->batch_embed(texts);
|
||||
|
||||
if(!embedding_op.ok()) {
|
||||
return Option<bool>(400, embedding_op.error());
|
||||
}
|
||||
|
@ -98,6 +98,9 @@ void master_server_routes() {
|
||||
server->del("/limits/active/:id", del_throttle);
|
||||
server->del("/limits/exceeds/:id", del_exceed);
|
||||
server->post("/config", post_config, false, false);
|
||||
|
||||
// for proxying remote embedders
|
||||
server->post("/proxy", post_proxy);
|
||||
}
|
||||
|
||||
void (*backward::SignalHandling::_callback)(int sig, backward::StackTrace&) = nullptr;
|
||||
@ -144,6 +147,7 @@ int main(int argc, char **argv) {
|
||||
Option<bool> config_validitation = config.is_valid();
|
||||
|
||||
if(!config_validitation.ok()) {
|
||||
std::cerr << "Typesense " << TYPESENSE_VERSION << std::endl;
|
||||
std::cerr << "Invalid configuration: " << config_validitation.error() << std::endl;
|
||||
std::cerr << "Command line " << options.usage() << std::endl;
|
||||
std::cerr << "You can also pass these arguments as environment variables such as "
|
||||
|
908
src/numeric_range_trie.cpp
Normal file
908
src/numeric_range_trie.cpp
Normal file
@ -0,0 +1,908 @@
|
||||
#include <timsort.hpp>
|
||||
#include <set>
|
||||
#include "numeric_range_trie_test.h"
|
||||
#include "array_utils.h"
|
||||
|
||||
void NumericTrie::insert(const int64_t& value, const uint32_t& seq_id) {
|
||||
if (value < 0) {
|
||||
if (negative_trie == nullptr) {
|
||||
negative_trie = new NumericTrie::Node();
|
||||
}
|
||||
|
||||
negative_trie->insert(std::abs(value), seq_id, max_level);
|
||||
} else {
|
||||
if (positive_trie == nullptr) {
|
||||
positive_trie = new NumericTrie::Node();
|
||||
}
|
||||
|
||||
positive_trie->insert(value, seq_id, max_level);
|
||||
}
|
||||
}
|
||||
|
||||
void NumericTrie::remove(const int64_t& value, const uint32_t& seq_id) {
|
||||
if ((value < 0 && negative_trie == nullptr) || (value >= 0 && positive_trie == nullptr)) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (value < 0) {
|
||||
negative_trie->remove(std::abs(value), seq_id, max_level);
|
||||
} else {
|
||||
positive_trie->remove(value, seq_id, max_level);
|
||||
}
|
||||
}
|
||||
|
||||
void NumericTrie::insert_geopoint(const uint64_t& cell_id, const uint32_t& seq_id) {
|
||||
if (positive_trie == nullptr) {
|
||||
positive_trie = new NumericTrie::Node();
|
||||
}
|
||||
|
||||
positive_trie->insert_geopoint(cell_id, seq_id, max_level);
|
||||
}
|
||||
|
||||
// Collects into `geo_result_ids` the ids matching any of `cell_ids`.
// Leaves `geo_result_ids` untouched when no geopoint has been indexed yet.
void NumericTrie::search_geopoints(const std::vector<uint64_t>& cell_ids, std::vector<uint32_t>& geo_result_ids) {
    if (positive_trie != nullptr) {
        positive_trie->search_geopoints(cell_ids, max_level, geo_result_ids);
    }
}
|
||||
|
||||
// Removes `id` from the entry for geo `cell_id`; a no-op when no geopoint has
// been indexed yet.
void NumericTrie::delete_geopoint(const uint64_t& cell_id, uint32_t id) {
    if (positive_trie != nullptr) {
        positive_trie->delete_geopoint(cell_id, id, max_level);
    }
}
|
||||
|
||||
void NumericTrie::search_range(const int64_t& low, const bool& low_inclusive,
|
||||
const int64_t& high, const bool& high_inclusive,
|
||||
uint32_t*& ids, uint32_t& ids_length) {
|
||||
if (low > high) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (low < 0 && high >= 0) {
|
||||
// Have to combine the results of >low from negative_trie and <high from positive_trie
|
||||
|
||||
if (negative_trie != nullptr && !(low == -1 && !low_inclusive)) { // No need to search for (-1, ...
|
||||
uint32_t* negative_ids = nullptr;
|
||||
uint32_t negative_ids_length = 0;
|
||||
auto abs_low = std::abs(low);
|
||||
|
||||
// Since we store absolute values, search_lesser would yield result for >low from negative_trie.
|
||||
negative_trie->search_less_than(low_inclusive ? abs_low : abs_low - 1, max_level,
|
||||
negative_ids, negative_ids_length);
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, ids, ids_length, &out);
|
||||
|
||||
delete [] negative_ids;
|
||||
delete [] ids;
|
||||
ids = out;
|
||||
}
|
||||
|
||||
if (positive_trie != nullptr && !(high == 0 && !high_inclusive)) { // No need to search for ..., 0)
|
||||
uint32_t* positive_ids = nullptr;
|
||||
uint32_t positive_ids_length = 0;
|
||||
positive_trie->search_less_than(high_inclusive ? high : high - 1, max_level,
|
||||
positive_ids, positive_ids_length);
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
ids_length = ArrayUtils::or_scalar(positive_ids, positive_ids_length, ids, ids_length, &out);
|
||||
|
||||
delete [] positive_ids;
|
||||
delete [] ids;
|
||||
ids = out;
|
||||
}
|
||||
} else if (low >= 0) {
|
||||
// Search only in positive_trie
|
||||
if (positive_trie == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
uint32_t* positive_ids = nullptr;
|
||||
uint32_t positive_ids_length = 0;
|
||||
positive_trie->search_range(low_inclusive ? low : low + 1, high_inclusive ? high : high - 1, max_level,
|
||||
positive_ids, positive_ids_length);
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
ids_length = ArrayUtils::or_scalar(positive_ids, positive_ids_length, ids, ids_length, &out);
|
||||
|
||||
delete [] positive_ids;
|
||||
delete [] ids;
|
||||
ids = out;
|
||||
} else {
|
||||
// Search only in negative_trie
|
||||
if (negative_trie == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
uint32_t* negative_ids = nullptr;
|
||||
uint32_t negative_ids_length = 0;
|
||||
// Since we store absolute values, switching low and high would produce the correct result.
|
||||
auto abs_high = std::abs(high), abs_low = std::abs(low);
|
||||
negative_trie->search_range(high_inclusive ? abs_high : abs_high + 1, low_inclusive ? abs_low : abs_low - 1,
|
||||
max_level,
|
||||
negative_ids, negative_ids_length);
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, ids, ids_length, &out);
|
||||
|
||||
delete [] negative_ids;
|
||||
delete [] ids;
|
||||
ids = out;
|
||||
}
|
||||
}
|
||||
|
||||
// Builds an iterator over the nodes covering the requested interval.
// `low_inclusive` / `high_inclusive` choose closed or open bounds. The iterator
// is built from an empty node list when `low > high` or when the relevant trie
// has not been created.
NumericTrie::iterator_t NumericTrie::search_range(const int64_t& low, const bool& low_inclusive,
                                                  const int64_t& high, const bool& high_inclusive) {
    std::vector<Node*> matches;

    if (low <= high) {
        if (low >= 0) {
            // The whole interval is non-negative.
            if (positive_trie != nullptr) {
                positive_trie->search_range(low_inclusive ? low : low + 1, high_inclusive ? high : high - 1,
                                            max_level, matches);
            }
        } else if (high < 0) {
            // The whole interval is negative. Magnitudes are stored, so the bounds swap roles.
            if (negative_trie != nullptr) {
                auto high_magnitude = std::abs(high), low_magnitude = std::abs(low);
                negative_trie->search_range(high_inclusive ? high_magnitude : high_magnitude + 1,
                                            low_inclusive ? low_magnitude : low_magnitude - 1,
                                            max_level, matches);
            }
        } else {
            // The interval straddles zero: >low from negative_trie plus <high from positive_trie.
            const bool negative_side_empty = (low == -1 && !low_inclusive);   // (-1, ...
            if (negative_trie != nullptr && !negative_side_empty) {
                auto low_magnitude = std::abs(low);
                negative_trie->search_less_than(low_inclusive ? low_magnitude : low_magnitude - 1,
                                                max_level, matches);
            }

            const bool positive_side_empty = (high == 0 && !high_inclusive);  // ..., 0)
            if (positive_trie != nullptr && !positive_side_empty) {
                positive_trie->search_less_than(high_inclusive ? high : high - 1, max_level, matches);
            }
        }
    }

    return NumericTrie::iterator_t(matches);
}
|
||||
|
||||
void NumericTrie::search_greater_than(const int64_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length) {
|
||||
if ((value == 0 && inclusive) || (value == -1 && !inclusive)) { // [0, ∞), (-1, ∞)
|
||||
if (positive_trie != nullptr) {
|
||||
uint32_t* positive_ids = nullptr;
|
||||
uint32_t positive_ids_length = 0;
|
||||
positive_trie->get_all_ids(positive_ids, positive_ids_length);
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
ids_length = ArrayUtils::or_scalar(positive_ids, positive_ids_length, ids, ids_length, &out);
|
||||
|
||||
delete [] positive_ids;
|
||||
delete [] ids;
|
||||
ids = out;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (value >= 0) {
|
||||
if (positive_trie == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
uint32_t* positive_ids = nullptr;
|
||||
uint32_t positive_ids_length = 0;
|
||||
positive_trie->search_greater_than(inclusive ? value : value + 1, max_level, positive_ids, positive_ids_length);
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
ids_length = ArrayUtils::or_scalar(positive_ids, positive_ids_length, ids, ids_length, &out);
|
||||
|
||||
delete [] positive_ids;
|
||||
delete [] ids;
|
||||
ids = out;
|
||||
} else {
|
||||
// Have to combine the results of >value from negative_trie and all the ids in positive_trie
|
||||
|
||||
if (negative_trie != nullptr) {
|
||||
uint32_t* negative_ids = nullptr;
|
||||
uint32_t negative_ids_length = 0;
|
||||
auto abs_low = std::abs(value);
|
||||
|
||||
// Since we store absolute values, search_lesser would yield result for >value from negative_trie.
|
||||
negative_trie->search_less_than(inclusive ? abs_low : abs_low - 1, max_level,
|
||||
negative_ids, negative_ids_length);
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, ids, ids_length, &out);
|
||||
|
||||
delete [] negative_ids;
|
||||
delete [] ids;
|
||||
ids = out;
|
||||
}
|
||||
|
||||
if (positive_trie == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
uint32_t* positive_ids = nullptr;
|
||||
uint32_t positive_ids_length = 0;
|
||||
positive_trie->get_all_ids(positive_ids, positive_ids_length);
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
ids_length = ArrayUtils::or_scalar(positive_ids, positive_ids_length, ids, ids_length, &out);
|
||||
|
||||
delete [] positive_ids;
|
||||
delete [] ids;
|
||||
ids = out;
|
||||
}
|
||||
}
|
||||
|
||||
// Builds an iterator over the nodes holding values greater than `value`
// (or greater-or-equal when `inclusive`).
NumericTrie::iterator_t NumericTrie::search_greater_than(const int64_t& value, const bool& inclusive) {
    std::vector<Node*> matches;

    const bool exactly_positive_side = (value == 0 && inclusive) || (value == -1 && !inclusive);  // [0, ∞), (-1, ∞)

    if (exactly_positive_side) {
        // The root of positive_trie stands for every non-negative value.
        if (positive_trie != nullptr) {
            matches.push_back(positive_trie);
        }
    } else if (value >= 0) {
        if (positive_trie != nullptr) {
            positive_trie->search_greater_than(inclusive ? value : value + 1, max_level, matches);
        }
    } else {
        // Negative bound: >value from negative_trie combined with all of positive_trie.
        if (negative_trie != nullptr) {
            auto magnitude = std::abs(value);
            // On magnitudes, "less than |value|" is the same as "greater than value".
            negative_trie->search_less_than(inclusive ? magnitude : magnitude - 1, max_level, matches);
        }
        if (positive_trie != nullptr) {
            matches.push_back(positive_trie);
        }
    }

    return NumericTrie::iterator_t(matches);
}
|
||||
|
||||
void NumericTrie::search_less_than(const int64_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length) {
|
||||
if ((value == 0 && !inclusive) || (value == -1 && inclusive)) { // (-∞, 0), (-∞, -1]
|
||||
if (negative_trie != nullptr) {
|
||||
uint32_t* negative_ids = nullptr;
|
||||
uint32_t negative_ids_length = 0;
|
||||
negative_trie->get_all_ids(negative_ids, negative_ids_length);
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, ids, ids_length, &out);
|
||||
|
||||
delete [] negative_ids;
|
||||
delete [] ids;
|
||||
ids = out;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (value < 0) {
|
||||
if (negative_trie == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
uint32_t* negative_ids = nullptr;
|
||||
uint32_t negative_ids_length = 0;
|
||||
auto abs_low = std::abs(value);
|
||||
|
||||
// Since we store absolute values, search_greater would yield result for <value from negative_trie.
|
||||
negative_trie->search_greater_than(inclusive ? abs_low : abs_low + 1, max_level,
|
||||
negative_ids, negative_ids_length);
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, ids, ids_length, &out);
|
||||
|
||||
delete [] negative_ids;
|
||||
delete [] ids;
|
||||
ids = out;
|
||||
} else {
|
||||
// Have to combine the results of <value from positive_trie and all the ids in negative_trie
|
||||
|
||||
if (positive_trie != nullptr) {
|
||||
uint32_t* positive_ids = nullptr;
|
||||
uint32_t positive_ids_length = 0;
|
||||
positive_trie->search_less_than(inclusive ? value : value - 1, max_level,
|
||||
positive_ids, positive_ids_length);
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
ids_length = ArrayUtils::or_scalar(positive_ids, positive_ids_length, ids, ids_length, &out);
|
||||
|
||||
delete [] positive_ids;
|
||||
delete [] ids;
|
||||
ids = out;
|
||||
}
|
||||
|
||||
if (negative_trie == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
uint32_t* negative_ids = nullptr;
|
||||
uint32_t negative_ids_length = 0;
|
||||
negative_trie->get_all_ids(negative_ids, negative_ids_length);
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, ids, ids_length, &out);
|
||||
|
||||
delete [] negative_ids;
|
||||
delete [] ids;
|
||||
ids = out;
|
||||
}
|
||||
}
|
||||
|
||||
// Builds an iterator over the nodes holding values less than `value`
// (or less-or-equal when `inclusive`).
NumericTrie::iterator_t NumericTrie::search_less_than(const int64_t& value, const bool& inclusive) {
    std::vector<Node*> matches;

    const bool exactly_negative_side = (value == 0 && !inclusive) || (value == -1 && inclusive);  // (-∞, 0), (-∞, -1]

    if (exactly_negative_side) {
        // The root of negative_trie stands for every negative value.
        if (negative_trie != nullptr) {
            matches.push_back(negative_trie);
        }
    } else if (value < 0) {
        if (negative_trie != nullptr) {
            auto magnitude = std::abs(value);
            // On magnitudes, "greater than |value|" is the same as "less than value".
            negative_trie->search_greater_than(inclusive ? magnitude : magnitude + 1, max_level, matches);
        }
    } else {
        // Non-negative bound: <value from positive_trie combined with all of negative_trie.
        if (positive_trie != nullptr) {
            positive_trie->search_less_than(inclusive ? value : value - 1, max_level, matches);
        }
        if (negative_trie != nullptr) {
            matches.push_back(negative_trie);
        }
    }

    return NumericTrie::iterator_t(matches);
}
|
||||
|
||||
void NumericTrie::search_equal_to(const int64_t& value, uint32_t*& ids, uint32_t& ids_length) {
|
||||
if ((value < 0 && negative_trie == nullptr) || (value >= 0 && positive_trie == nullptr)) {
|
||||
return;
|
||||
}
|
||||
|
||||
uint32_t* equal_ids = nullptr;
|
||||
uint32_t equal_ids_length = 0;
|
||||
|
||||
if (value < 0) {
|
||||
negative_trie->search_equal_to(std::abs(value), max_level, equal_ids, equal_ids_length);
|
||||
} else {
|
||||
positive_trie->search_equal_to(value, max_level, equal_ids, equal_ids_length);
|
||||
}
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
ids_length = ArrayUtils::or_scalar(equal_ids, equal_ids_length, ids, ids_length, &out);
|
||||
|
||||
delete [] equal_ids;
|
||||
delete [] ids;
|
||||
ids = out;
|
||||
}
|
||||
|
||||
// Builds an iterator over the node holding exactly `value`, if any.
NumericTrie::iterator_t NumericTrie::search_equal_to(const int64_t& value) {
    std::vector<Node*> matches;

    if (value < 0) {
        if (negative_trie != nullptr) {
            // The negative trie is keyed by magnitude.
            negative_trie->search_equal_to(std::abs(value), max_level, matches);
        }
    } else if (positive_trie != nullptr) {
        positive_trie->search_equal_to(value, max_level, matches);
    }

    return NumericTrie::iterator_t(matches);
}
|
||||
|
||||
void NumericTrie::Node::insert(const int64_t& cell_id, const uint32_t& seq_id, const char& max_level) {
|
||||
char level = 0;
|
||||
return insert_helper(cell_id, seq_id, level, max_level);
|
||||
}
|
||||
|
||||
void NumericTrie::Node::insert_geopoint(const uint64_t& cell_id, const uint32_t& seq_id, const char& max_level) {
|
||||
char level = 0;
|
||||
return insert_geopoint_helper(cell_id, seq_id, level, max_level);
|
||||
}
|
||||
|
||||
// Returns the child slot (0-255) used for `value` at trie depth `level`.
//
// Values are indexed most-significant byte first, over `max_level` bytes.
// With max_level = 4, 0x01020408 (16909320) is indexed as follows:
// Level   Index
//   1       1
//   2       2
//   3       4
//   4       8
inline int get_index(const int64_t& value, const char& level, const char& max_level) {
    const int shift_bits = 8 * (max_level - level);
    return (value >> shift_bits) & 0xFF;
}
|
||||
|
||||
// Returns the child slot (0-255) used for `cell_id` at trie depth `level`.
// A cell id is a 64-bit number, so it is always split as eight bytes with
// level 1 being the most significant one.
inline int get_geopoint_index(const uint64_t& cell_id, const char& level) {
    const int shift_bits = 8 * (8 - level);
    return (cell_id >> shift_bits) & 0xFF;
}
|
||||
|
||||
void NumericTrie::Node::remove(const int64_t& value, const uint32_t& id, const char& max_level) {
|
||||
char level = 1;
|
||||
Node* root = this;
|
||||
auto index = get_index(value, level, max_level);
|
||||
|
||||
while (level < max_level) {
|
||||
root->seq_ids.remove_value(id);
|
||||
|
||||
if (root->children == nullptr || root->children[index] == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
root = root->children[index];
|
||||
index = get_index(value, ++level, max_level);
|
||||
}
|
||||
|
||||
root->seq_ids.remove_value(id);
|
||||
if (root->children != nullptr && root->children[index] != nullptr) {
|
||||
auto& child = root->children[index];
|
||||
|
||||
child->seq_ids.remove_value(id);
|
||||
if (child->seq_ids.getLength() == 0) {
|
||||
delete child;
|
||||
child = nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void NumericTrie::Node::insert_helper(const int64_t& value, const uint32_t& seq_id, char& level, const char& max_level) {
|
||||
if (level > max_level) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Root node contains all the sequence ids present in the tree.
|
||||
if (!seq_ids.contains(seq_id)) {
|
||||
seq_ids.append(seq_id);
|
||||
}
|
||||
|
||||
if (++level <= max_level) {
|
||||
if (children == nullptr) {
|
||||
children = new NumericTrie::Node* [EXPANSE]{nullptr};
|
||||
}
|
||||
|
||||
auto index = get_index(value, level, max_level);
|
||||
if (children[index] == nullptr) {
|
||||
children[index] = new NumericTrie::Node();
|
||||
}
|
||||
|
||||
return children[index]->insert_helper(value, seq_id, level, max_level);
|
||||
}
|
||||
}
|
||||
|
||||
void NumericTrie::Node::insert_geopoint_helper(const uint64_t& cell_id, const uint32_t& seq_id, char& level,
|
||||
const char& max_level) {
|
||||
if (level > max_level) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Root node contains all the sequence ids present in the tree.
|
||||
if (!seq_ids.contains(seq_id)) {
|
||||
seq_ids.append(seq_id);
|
||||
}
|
||||
|
||||
if (++level <= max_level) {
|
||||
if (children == nullptr) {
|
||||
children = new NumericTrie::Node* [EXPANSE]{nullptr};
|
||||
}
|
||||
|
||||
auto index = get_geopoint_index(cell_id, level);
|
||||
if (children[index] == nullptr) {
|
||||
children[index] = new NumericTrie::Node();
|
||||
}
|
||||
|
||||
return children[index]->insert_geopoint_helper(cell_id, seq_id, level, max_level);
|
||||
}
|
||||
}
|
||||
|
||||
// Returns how many leading bytes of `cell_id` need to be prefix-matched.
//
// Only the top `max_level` bytes of a cell id are considered. Trailing zero
// bytes within those are skipped: for cell id 0x47E66C3000000000 with
// max_level = 8 only the top four bytes are non-zero, so the result is 4.
// The result is 0 when all considered bytes are zero.
char get_max_search_level(const uint64_t& cell_id, const char& max_level) {
    // Start at the lowest byte that is still considered and move upwards.
    uint64_t byte_mask = (uint64_t) 0xFF << (8 * (8 - max_level));
    char search_level = max_level;

    while ((cell_id & byte_mask) == 0) {
        if (--search_level <= 0) {
            break;
        }
        byte_mask <<= 8;
    }

    return search_level;
}
|
||||
|
||||
void NumericTrie::Node::search_geopoints_helper(const uint64_t& cell_id, const char& max_index_level,
|
||||
std::set<Node*>& matches) {
|
||||
char level = 1;
|
||||
Node* root = this;
|
||||
auto index = get_geopoint_index(cell_id, level);
|
||||
auto max_search_level = get_max_search_level(cell_id, max_index_level);
|
||||
|
||||
while (level < max_search_level) {
|
||||
if (root->children == nullptr || root->children[index] == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
root = root->children[index];
|
||||
index = get_geopoint_index(cell_id, ++level);
|
||||
}
|
||||
|
||||
matches.insert(root);
|
||||
}
|
||||
|
||||
void NumericTrie::Node::search_geopoints(const std::vector<uint64_t>& cell_ids, const char& max_level,
|
||||
std::vector<uint32_t>& geo_result_ids) {
|
||||
std::set<Node*> matches;
|
||||
for (const auto &cell_id: cell_ids) {
|
||||
search_geopoints_helper(cell_id, max_level, matches);
|
||||
}
|
||||
|
||||
for (auto const& match: matches) {
|
||||
auto const& m_seq_ids = match->seq_ids.uncompress();
|
||||
for (uint32_t i = 0; i < match->seq_ids.getLength(); i++) {
|
||||
geo_result_ids.push_back(m_seq_ids[i]);
|
||||
}
|
||||
|
||||
delete [] m_seq_ids;
|
||||
}
|
||||
|
||||
gfx::timsort(geo_result_ids.begin(), geo_result_ids.end());
|
||||
geo_result_ids.erase(unique(geo_result_ids.begin(), geo_result_ids.end()), geo_result_ids.end());
|
||||
}
|
||||
|
||||
void NumericTrie::Node::delete_geopoint(const uint64_t& cell_id, uint32_t id, const char& max_level) {
|
||||
char level = 1;
|
||||
Node* root = this;
|
||||
auto index = get_geopoint_index(cell_id, level);
|
||||
|
||||
while (level < max_level) {
|
||||
root->seq_ids.remove_value(id);
|
||||
|
||||
if (root->children == nullptr || root->children[index] == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
root = root->children[index];
|
||||
index = get_geopoint_index(cell_id, ++level);
|
||||
}
|
||||
|
||||
root->seq_ids.remove_value(id);
|
||||
if (root->children != nullptr && root->children[index] != nullptr) {
|
||||
auto& child = root->children[index];
|
||||
|
||||
child->seq_ids.remove_value(id);
|
||||
if (child->seq_ids.getLength() == 0) {
|
||||
delete child;
|
||||
child = nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Hands out the ids stored on this node: `ids` is overwritten with the array
// returned by seq_ids.uncompress() and `ids_length` with its element count.
// Any previous value of `ids` is not freed here. Callers in this file release
// the returned array with delete[].
void NumericTrie::Node::get_all_ids(uint32_t*& ids, uint32_t& ids_length) {
    ids = seq_ids.uncompress();
    ids_length = seq_ids.getLength();
}
|
||||
|
||||
void NumericTrie::Node::search_less_than(const int64_t& value, const char& max_level,
|
||||
uint32_t*& ids, uint32_t& ids_length) {
|
||||
char level = 0;
|
||||
std::vector<NumericTrie::Node*> matches;
|
||||
search_less_than_helper(value, level, max_level, matches);
|
||||
|
||||
std::vector<uint32_t> consolidated_ids;
|
||||
for (auto const& match: matches) {
|
||||
auto const& m_seq_ids = match->seq_ids.uncompress();
|
||||
for (uint32_t i = 0; i < match->seq_ids.getLength(); i++) {
|
||||
consolidated_ids.push_back(m_seq_ids[i]);
|
||||
}
|
||||
|
||||
delete [] m_seq_ids;
|
||||
}
|
||||
|
||||
gfx::timsort(consolidated_ids.begin(), consolidated_ids.end());
|
||||
consolidated_ids.erase(unique(consolidated_ids.begin(), consolidated_ids.end()), consolidated_ids.end());
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
ids_length = ArrayUtils::or_scalar(&consolidated_ids[0], consolidated_ids.size(),
|
||||
ids, ids_length, &out);
|
||||
|
||||
delete [] ids;
|
||||
ids = out;
|
||||
}
|
||||
|
||||
void NumericTrie::Node::search_less_than(const int64_t& value, const char& max_level, std::vector<Node*>& matches) {
|
||||
char level = 0;
|
||||
search_less_than_helper(value, level, max_level, matches);
|
||||
}
|
||||
|
||||
void NumericTrie::Node::search_less_than_helper(const int64_t& value, char& level, const char& max_level,
|
||||
std::vector<Node*>& matches) {
|
||||
if (level == max_level) {
|
||||
matches.push_back(this);
|
||||
return;
|
||||
} else if (level > max_level || children == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto index = get_index(value, ++level, max_level);
|
||||
if (children[index] != nullptr) {
|
||||
children[index]->search_less_than_helper(value, level, max_level, matches);
|
||||
}
|
||||
|
||||
while (--index >= 0) {
|
||||
if (children[index] != nullptr) {
|
||||
matches.push_back(children[index]);
|
||||
}
|
||||
}
|
||||
|
||||
--level;
|
||||
}
|
||||
|
||||
void NumericTrie::Node::search_range(const int64_t& low, const int64_t& high, const char& max_level,
|
||||
uint32_t*& ids, uint32_t& ids_length) {
|
||||
if (low > high) {
|
||||
return;
|
||||
}
|
||||
std::vector<NumericTrie::Node*> matches;
|
||||
search_range_helper(low, high, max_level, matches);
|
||||
|
||||
std::vector<uint32_t> consolidated_ids;
|
||||
for (auto const& match: matches) {
|
||||
auto const& m_seq_ids = match->seq_ids.uncompress();
|
||||
for (uint32_t i = 0; i < match->seq_ids.getLength(); i++) {
|
||||
consolidated_ids.push_back(m_seq_ids[i]);
|
||||
}
|
||||
|
||||
delete [] m_seq_ids;
|
||||
}
|
||||
|
||||
gfx::timsort(consolidated_ids.begin(), consolidated_ids.end());
|
||||
consolidated_ids.erase(unique(consolidated_ids.begin(), consolidated_ids.end()), consolidated_ids.end());
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
ids_length = ArrayUtils::or_scalar(&consolidated_ids[0], consolidated_ids.size(),
|
||||
ids, ids_length, &out);
|
||||
|
||||
delete [] ids;
|
||||
ids = out;
|
||||
}
|
||||
|
||||
void NumericTrie::Node::search_range(const int64_t& low, const int64_t& high, const char& max_level,
|
||||
std::vector<Node*>& matches) {
|
||||
if (low > high) {
|
||||
return;
|
||||
}
|
||||
|
||||
search_range_helper(low, high, max_level, matches);
|
||||
}
|
||||
|
||||
void NumericTrie::Node::search_range_helper(const int64_t& low,const int64_t& high, const char& max_level,
|
||||
std::vector<Node*>& matches) {
|
||||
// Segregating the nodes into matching low, in-between, and matching high.
|
||||
|
||||
NumericTrie::Node* root = this;
|
||||
char level = 1;
|
||||
auto low_index = get_index(low, level, max_level), high_index = get_index(high, level, max_level);
|
||||
|
||||
// Keep updating the root while the range is contained within a single child node.
|
||||
while (root->children != nullptr && low_index == high_index && level < max_level) {
|
||||
if (root->children[low_index] == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
root = root->children[low_index];
|
||||
level++;
|
||||
low_index = get_index(low, level, max_level);
|
||||
high_index = get_index(high, level, max_level);
|
||||
}
|
||||
|
||||
if (root->children == nullptr) {
|
||||
return;
|
||||
} else if (low_index == high_index) { // low and high are equal
|
||||
if (root->children[low_index] != nullptr) {
|
||||
matches.push_back(root->children[low_index]);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (root->children[low_index] != nullptr) {
|
||||
// Collect all the sub-nodes that are greater than low.
|
||||
root->children[low_index]->search_greater_than_helper(low, level, max_level, matches);
|
||||
}
|
||||
|
||||
auto index = low_index + 1;
|
||||
// All the nodes in-between low and high are a match by default.
|
||||
while (index < std::min(high_index, (int)EXPANSE)) {
|
||||
if (root->children[index] != nullptr) {
|
||||
matches.push_back(root->children[index]);
|
||||
}
|
||||
|
||||
index++;
|
||||
}
|
||||
|
||||
if (index < EXPANSE && index == high_index && root->children[index] != nullptr) {
|
||||
// Collect all the sub-nodes that are lesser than high.
|
||||
root->children[index]->search_less_than_helper(high, level, max_level, matches);
|
||||
}
|
||||
}
|
||||
|
||||
void NumericTrie::Node::search_greater_than(const int64_t& value, const char& max_level,
|
||||
uint32_t*& ids, uint32_t& ids_length) {
|
||||
char level = 0;
|
||||
std::vector<NumericTrie::Node*> matches;
|
||||
search_greater_than_helper(value, level, max_level, matches);
|
||||
|
||||
std::vector<uint32_t> consolidated_ids;
|
||||
for (auto const& match: matches) {
|
||||
auto const& m_seq_ids = match->seq_ids.uncompress();
|
||||
for (uint32_t i = 0; i < match->seq_ids.getLength(); i++) {
|
||||
consolidated_ids.push_back(m_seq_ids[i]);
|
||||
}
|
||||
|
||||
delete [] m_seq_ids;
|
||||
}
|
||||
|
||||
gfx::timsort(consolidated_ids.begin(), consolidated_ids.end());
|
||||
consolidated_ids.erase(unique(consolidated_ids.begin(), consolidated_ids.end()), consolidated_ids.end());
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
ids_length = ArrayUtils::or_scalar(&consolidated_ids[0], consolidated_ids.size(),
|
||||
ids, ids_length, &out);
|
||||
|
||||
delete [] ids;
|
||||
ids = out;
|
||||
}
|
||||
|
||||
void NumericTrie::Node::search_greater_than(const int64_t& value, const char& max_level, std::vector<Node*>& matches) {
|
||||
char level = 0;
|
||||
search_greater_than_helper(value, level, max_level, matches);
|
||||
}
|
||||
|
||||
void NumericTrie::Node::search_greater_than_helper(const int64_t& value, char& level, const char& max_level,
|
||||
std::vector<Node*>& matches) {
|
||||
if (level == max_level) {
|
||||
matches.push_back(this);
|
||||
return;
|
||||
} else if (level > max_level || children == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto index = get_index(value, ++level, max_level);
|
||||
if (children[index] != nullptr) {
|
||||
children[index]->search_greater_than_helper(value, level, max_level, matches);
|
||||
}
|
||||
|
||||
while (++index < EXPANSE) {
|
||||
if (children[index] != nullptr) {
|
||||
matches.push_back(children[index]);
|
||||
}
|
||||
}
|
||||
|
||||
--level;
|
||||
}
|
||||
|
||||
void NumericTrie::Node::search_equal_to(const int64_t& value, const char& max_level,
|
||||
uint32_t*& ids, uint32_t& ids_length) {
|
||||
char level = 1;
|
||||
Node* root = this;
|
||||
auto index = get_index(value, level, max_level);
|
||||
|
||||
while (level <= max_level) {
|
||||
if (root->children == nullptr || root->children[index] == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
root = root->children[index];
|
||||
index = get_index(value, ++level, max_level);
|
||||
}
|
||||
|
||||
root->get_all_ids(ids, ids_length);
|
||||
}
|
||||
|
||||
void NumericTrie::Node::search_equal_to(const int64_t& value, const char& max_level, std::vector<Node*>& matches) {
|
||||
char level = 1;
|
||||
Node* root = this;
|
||||
auto index = get_index(value, level, max_level);
|
||||
|
||||
while (level <= max_level) {
|
||||
if (root->children == nullptr || root->children[index] == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
root = root->children[index];
|
||||
index = get_index(value, ++level, max_level);
|
||||
}
|
||||
|
||||
matches.push_back(root);
|
||||
}
|
||||
|
||||
void NumericTrie::iterator_t::reset() {
|
||||
for (auto& match: matches) {
|
||||
match->index = 0;
|
||||
}
|
||||
|
||||
is_valid = true;
|
||||
set_seq_id();
|
||||
}
|
||||
|
||||
void NumericTrie::iterator_t::skip_to(uint32_t id) {
|
||||
for (auto& match: matches) {
|
||||
ArrayUtils::skip_index_to_id(match->index, match->ids, match->ids_length, id);
|
||||
}
|
||||
|
||||
set_seq_id();
|
||||
}
|
||||
|
||||
void NumericTrie::iterator_t::next() {
|
||||
// Advance all the matches at seq_id.
|
||||
for (auto& match: matches) {
|
||||
if (match->index < match->ids_length && match->ids[match->index] == seq_id) {
|
||||
match->index++;
|
||||
}
|
||||
}
|
||||
|
||||
set_seq_id();
|
||||
}
|
||||
|
||||
// Builds an iterator over the ids of `node_matches`. Each node with at least
// one id contributes a match_state that receives the node's id array; nodes
// without ids are skipped. The current seq_id is then set to the lowest id.
NumericTrie::iterator_t::iterator_t(std::vector<Node*>& node_matches) {
    for (auto const& node_match: node_matches) {
        uint32_t* ids = nullptr;
        uint32_t ids_length = 0;
        node_match->get_all_ids(ids, ids_length);

        if (ids_length > 0) {
            matches.emplace_back(new match_state(ids, ids_length));
        } else {
            // get_all_ids() still hands out an array for an empty node, and no
            // match_state takes it over, so it has to be released here.
            delete [] ids;
        }
    }

    set_seq_id();
}
|
||||
|
||||
void NumericTrie::iterator_t::set_seq_id() {
|
||||
// Find the lowest id of all the matches and update the seq_id.
|
||||
bool one_is_valid = false;
|
||||
uint32_t lowest_id = UINT32_MAX;
|
||||
|
||||
for (auto& match: matches) {
|
||||
if (match->index < match->ids_length) {
|
||||
one_is_valid = true;
|
||||
|
||||
if (match->ids[match->index] < lowest_id) {
|
||||
lowest_id = match->ids[match->index];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (one_is_valid) {
|
||||
seq_id = lowest_id;
|
||||
}
|
||||
|
||||
is_valid = one_is_valid;
|
||||
}
|
||||
|
||||
// Move assignment: deletes the match states currently held, then takes over
// the matches, seq_id and validity of `obj`. Self-assignment is a no-op.
NumericTrie::iterator_t& NumericTrie::iterator_t::operator=(NumericTrie::iterator_t&& obj) noexcept {
    if (this != &obj) {
        for (auto& state: matches) {
            delete state;
        }
        matches.clear();

        matches = std::move(obj.matches);
        seq_id = obj.seq_id;
        is_valid = obj.is_valid;
    }

    return *this;
}
|
||||
|
@ -11,6 +11,27 @@ Option<bool> RemoteEmbedder::validate_string_properties(const nlohmann::json& mo
|
||||
return Option<bool>(true);
|
||||
}
|
||||
|
||||
long RemoteEmbedder::call_remote_api(const std::string& method, const std::string& url, const std::string& body, std::string& res_body,
|
||||
std::map<std::string, std::string>& headers, const std::unordered_map<std::string, std::string>& req_headers) {
|
||||
if(raft_server == nullptr) {
|
||||
if(method == "GET") {
|
||||
return HttpClient::get_instance().get_response(url, res_body, headers, req_headers, 10000, true);
|
||||
} else if(method == "POST") {
|
||||
return HttpClient::get_instance().post_response(url, body, res_body, headers, req_headers, 10000, true);
|
||||
} else {
|
||||
return 400;
|
||||
}
|
||||
}
|
||||
auto leader_url = raft_server->get_leader_url();
|
||||
leader_url += "proxy";
|
||||
nlohmann::json req_body;
|
||||
req_body["method"] = method;
|
||||
req_body["url"] = url;
|
||||
req_body["body"] = body;
|
||||
req_body["headers"] = req_headers;
|
||||
return HttpClient::get_instance().post_response(leader_url, req_body.dump(), res_body, headers, {}, 10000, true);
|
||||
}
|
||||
|
||||
|
||||
|
||||
OpenAIEmbedder::OpenAIEmbedder(const std::string& openai_model_path, const std::string& api_key) : api_key(api_key), openai_model_path(openai_model_path) {
|
||||
@ -32,12 +53,11 @@ Option<bool> OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config,
|
||||
return Option<bool>(400, "Property `embed.model_config.model_name` malformed.");
|
||||
}
|
||||
|
||||
HttpClient& client = HttpClient::get_instance();
|
||||
std::unordered_map<std::string, std::string> headers;
|
||||
std::map<std::string, std::string> res_headers;
|
||||
headers["Authorization"] = "Bearer " + api_key;
|
||||
std::string res;
|
||||
auto res_code = client.get_response(OPENAI_LIST_MODELS, res, res_headers, headers);
|
||||
auto res_code = call_remote_api("GET", OPENAI_LIST_MODELS, "", res, res_headers, headers);
|
||||
if (res_code != 200) {
|
||||
nlohmann::json json_res = nlohmann::json::parse(res);
|
||||
if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) {
|
||||
@ -67,7 +87,7 @@ Option<bool> OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config,
|
||||
|
||||
std::string embedding_res;
|
||||
headers["Content-Type"] = "application/json";
|
||||
res_code = client.post_response(OPENAI_CREATE_EMBEDDING, req_body.dump(), embedding_res, res_headers, headers);
|
||||
res_code = call_remote_api("POST", OPENAI_CREATE_EMBEDDING, req_body.dump(), embedding_res, res_headers, headers);
|
||||
|
||||
|
||||
if (res_code != 200) {
|
||||
@ -84,7 +104,6 @@ Option<bool> OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config,
|
||||
}
|
||||
|
||||
Option<std::vector<float>> OpenAIEmbedder::Embed(const std::string& text) {
|
||||
HttpClient& client = HttpClient::get_instance();
|
||||
std::unordered_map<std::string, std::string> headers;
|
||||
std::map<std::string, std::string> res_headers;
|
||||
headers["Authorization"] = "Bearer " + api_key;
|
||||
@ -94,7 +113,7 @@ Option<std::vector<float>> OpenAIEmbedder::Embed(const std::string& text) {
|
||||
req_body["input"] = text;
|
||||
// remove "openai/" prefix
|
||||
req_body["model"] = TextEmbedderManager::get_model_name_without_namespace(openai_model_path);
|
||||
auto res_code = client.post_response(OPENAI_CREATE_EMBEDDING, req_body.dump(), res, res_headers, headers);
|
||||
auto res_code = call_remote_api("POST", OPENAI_CREATE_EMBEDDING, req_body.dump(), res, res_headers, headers);
|
||||
if (res_code != 200) {
|
||||
nlohmann::json json_res = nlohmann::json::parse(res);
|
||||
if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) {
|
||||
@ -115,9 +134,7 @@ Option<std::vector<std::vector<float>>> OpenAIEmbedder::batch_embed(const std::v
|
||||
headers["Content-Type"] = "application/json";
|
||||
std::map<std::string, std::string> res_headers;
|
||||
std::string res;
|
||||
HttpClient& client = HttpClient::get_instance();
|
||||
|
||||
auto res_code = client.post_response(OPENAI_CREATE_EMBEDDING, req_body.dump(), res, res_headers, headers);
|
||||
auto res_code = call_remote_api("POST", OPENAI_CREATE_EMBEDDING, req_body.dump(), res, res_headers, headers);
|
||||
|
||||
if(res_code != 200) {
|
||||
nlohmann::json json_res = nlohmann::json::parse(res);
|
||||
@ -159,7 +176,6 @@ Option<bool> GoogleEmbedder::is_model_valid(const nlohmann::json& model_config,
|
||||
return Option<bool>(400, "Property `embed.model_config.model_name` is not a supported Google model.");
|
||||
}
|
||||
|
||||
HttpClient& client = HttpClient::get_instance();
|
||||
std::unordered_map<std::string, std::string> headers;
|
||||
std::map<std::string, std::string> res_headers;
|
||||
headers["Content-Type"] = "application/json";
|
||||
@ -167,7 +183,7 @@ Option<bool> GoogleEmbedder::is_model_valid(const nlohmann::json& model_config,
|
||||
nlohmann::json req_body;
|
||||
req_body["text"] = "test";
|
||||
|
||||
auto res_code = client.post_response(std::string(GOOGLE_CREATE_EMBEDDING) + api_key, req_body.dump(), res, res_headers, headers);
|
||||
auto res_code = call_remote_api("POST", std::string(GOOGLE_CREATE_EMBEDDING) + api_key, req_body.dump(), res, res_headers, headers);
|
||||
|
||||
if(res_code != 200) {
|
||||
nlohmann::json json_res = nlohmann::json::parse(res);
|
||||
@ -183,7 +199,6 @@ Option<bool> GoogleEmbedder::is_model_valid(const nlohmann::json& model_config,
|
||||
}
|
||||
|
||||
Option<std::vector<float>> GoogleEmbedder::Embed(const std::string& text) {
|
||||
HttpClient& client = HttpClient::get_instance();
|
||||
std::unordered_map<std::string, std::string> headers;
|
||||
std::map<std::string, std::string> res_headers;
|
||||
headers["Content-Type"] = "application/json";
|
||||
@ -191,7 +206,7 @@ Option<std::vector<float>> GoogleEmbedder::Embed(const std::string& text) {
|
||||
nlohmann::json req_body;
|
||||
req_body["text"] = text;
|
||||
|
||||
auto res_code = client.post_response(std::string(GOOGLE_CREATE_EMBEDDING) + google_api_key, req_body.dump(), res, res_headers, headers);
|
||||
auto res_code = call_remote_api("POST", std::string(GOOGLE_CREATE_EMBEDDING) + google_api_key, req_body.dump(), res, res_headers, headers);
|
||||
|
||||
if(res_code != 200) {
|
||||
nlohmann::json json_res = nlohmann::json::parse(res);
|
||||
@ -246,7 +261,6 @@ Option<bool> GCPEmbedder::is_model_valid(const nlohmann::json& model_config, uns
|
||||
|
||||
auto model_name_without_namespace = TextEmbedderManager::get_model_name_without_namespace(model_name);
|
||||
|
||||
HttpClient& client = HttpClient::get_instance();
|
||||
std::unordered_map<std::string, std::string> headers;
|
||||
std::map<std::string, std::string> res_headers;
|
||||
headers["Content-Type"] = "application/json";
|
||||
@ -258,7 +272,7 @@ Option<bool> GCPEmbedder::is_model_valid(const nlohmann::json& model_config, uns
|
||||
instance["content"] = "typesense";
|
||||
req_body["instances"].push_back(instance);
|
||||
|
||||
auto res_code = client.post_response(get_gcp_embedding_url(project_id, model_name_without_namespace), req_body.dump(), res, res_headers, headers);
|
||||
auto res_code = call_remote_api("POST", get_gcp_embedding_url(project_id, model_name_without_namespace), req_body.dump(), res, res_headers, headers);
|
||||
|
||||
if(res_code != 200) {
|
||||
nlohmann::json json_res = nlohmann::json::parse(res);
|
||||
@ -295,9 +309,8 @@ Option<std::vector<float>> GCPEmbedder::Embed(const std::string& text) {
|
||||
headers["Content-Type"] = "application/json";
|
||||
std::map<std::string, std::string> res_headers;
|
||||
std::string res;
|
||||
HttpClient& client = HttpClient::get_instance();
|
||||
|
||||
auto res_code = client.post_response(get_gcp_embedding_url(project_id, model_name), req_body.dump(), res, res_headers, headers);
|
||||
auto res_code = call_remote_api("POST", get_gcp_embedding_url(project_id, model_name), req_body.dump(), res, res_headers, headers);
|
||||
|
||||
if(res_code != 200) {
|
||||
if(res_code == 401) {
|
||||
@ -308,7 +321,7 @@ Option<std::vector<float>> GCPEmbedder::Embed(const std::string& text) {
|
||||
access_token = refresh_op.get();
|
||||
// retry
|
||||
headers["Authorization"] = "Bearer " + access_token;
|
||||
res_code = client.post_response(get_gcp_embedding_url(project_id, model_name), req_body.dump(), res, res_headers, headers);
|
||||
res_code = call_remote_api("POST", get_gcp_embedding_url(project_id, model_name), req_body.dump(), res, res_headers, headers);
|
||||
}
|
||||
}
|
||||
|
||||
@ -353,8 +366,7 @@ Option<std::vector<std::vector<float>>> GCPEmbedder::batch_embed(const std::vect
|
||||
headers["Content-Type"] = "application/json";
|
||||
std::map<std::string, std::string> res_headers;
|
||||
std::string res;
|
||||
HttpClient& client = HttpClient::get_instance();
|
||||
auto res_code = client.post_response(get_gcp_embedding_url(project_id, model_name), req_body.dump(), res, res_headers, headers);
|
||||
auto res_code = call_remote_api("POST", get_gcp_embedding_url(project_id, model_name), req_body.dump(), res, res_headers, headers);
|
||||
if(res_code != 200) {
|
||||
if(res_code == 401) {
|
||||
auto refresh_op = generate_access_token(refresh_token, client_id, client_secret);
|
||||
@ -364,7 +376,7 @@ Option<std::vector<std::vector<float>>> GCPEmbedder::batch_embed(const std::vect
|
||||
access_token = refresh_op.get();
|
||||
// retry
|
||||
headers["Authorization"] = "Bearer " + access_token;
|
||||
res_code = client.post_response(get_gcp_embedding_url(project_id, model_name), req_body.dump(), res, res_headers, headers);
|
||||
res_code = call_remote_api("POST", get_gcp_embedding_url(project_id, model_name), req_body.dump(), res, res_headers, headers);
|
||||
}
|
||||
}
|
||||
|
||||
@ -386,7 +398,6 @@ Option<std::vector<std::vector<float>>> GCPEmbedder::batch_embed(const std::vect
|
||||
}
|
||||
|
||||
Option<std::string> GCPEmbedder::generate_access_token(const std::string& refresh_token, const std::string& client_id, const std::string& client_secret) {
|
||||
HttpClient& client = HttpClient::get_instance();
|
||||
std::unordered_map<std::string, std::string> headers;
|
||||
headers["Content-Type"] = "application/x-www-form-urlencoded";
|
||||
std::map<std::string, std::string> res_headers;
|
||||
@ -394,7 +405,7 @@ Option<std::string> GCPEmbedder::generate_access_token(const std::string& refres
|
||||
std::string req_body;
|
||||
req_body = "grant_type=refresh_token&client_id=" + client_id + "&client_secret=" + client_secret + "&refresh_token=" + refresh_token;
|
||||
|
||||
auto res_code = client.post_response(GCP_AUTH_TOKEN_URL, req_body, res, res_headers, headers);
|
||||
auto res_code = call_remote_api("POST", GCP_AUTH_TOKEN_URL, req_body, res, res_headers, headers);
|
||||
|
||||
if(res_code != 200) {
|
||||
nlohmann::json json_res = nlohmann::json::parse(res);
|
||||
|
@ -453,6 +453,8 @@ int run_server(const Config & config, const std::string & version, void (*master
|
||||
AnalyticsManager::get_instance().run(&replication_state);
|
||||
});
|
||||
|
||||
RemoteEmbedder::init(&replication_state);
|
||||
|
||||
std::string path_to_nodes = config.get_nodes();
|
||||
start_raft_server(replication_state, state_dir, path_to_nodes,
|
||||
config.get_peering_address(),
|
||||
|
@ -1231,6 +1231,16 @@ TEST_F(CollectionFilteringTest, FilteringViaDocumentIds) {
|
||||
ASSERT_EQ(1, results["hits"].size());
|
||||
ASSERT_STREQ("123", results["hits"][0]["document"]["id"].get<std::string>().c_str());
|
||||
|
||||
results = coll1->search("*",
|
||||
{}, "id: != 123",
|
||||
{}, sort_fields, {0}, 10, 1, FREQUENCY, {true}).get();
|
||||
|
||||
ASSERT_EQ(3, results["found"].get<size_t>());
|
||||
ASSERT_EQ(3, results["hits"].size());
|
||||
ASSERT_STREQ("125", results["hits"][0]["document"]["id"].get<std::string>().c_str());
|
||||
ASSERT_STREQ("127", results["hits"][1]["document"]["id"].get<std::string>().c_str());
|
||||
ASSERT_STREQ("129", results["hits"][2]["document"]["id"].get<std::string>().c_str());
|
||||
|
||||
// single ID with backtick
|
||||
|
||||
results = coll1->search("*",
|
||||
@ -1283,6 +1293,14 @@ TEST_F(CollectionFilteringTest, FilteringViaDocumentIds) {
|
||||
ASSERT_STREQ("125", results["hits"][1]["document"]["id"].get<std::string>().c_str());
|
||||
ASSERT_STREQ("127", results["hits"][2]["document"]["id"].get<std::string>().c_str());
|
||||
|
||||
results = coll1->search("*",
|
||||
{}, "id:!= [123,125] && num_employees: <300",
|
||||
{}, sort_fields, {0}, 10, 1, FREQUENCY, {true}).get();
|
||||
|
||||
ASSERT_EQ(1, results["found"].get<size_t>());
|
||||
ASSERT_EQ(1, results["hits"].size());
|
||||
ASSERT_STREQ("127", results["hits"][0]["document"]["id"].get<std::string>().c_str());
|
||||
|
||||
// empty id list not allowed
|
||||
auto res_op = coll1->search("*", {}, "id:=", {}, sort_fields, {0}, 10, 1, FREQUENCY, {true});
|
||||
ASSERT_FALSE(res_op.ok());
|
||||
@ -1296,13 +1314,6 @@ TEST_F(CollectionFilteringTest, FilteringViaDocumentIds) {
|
||||
ASSERT_FALSE(res_op.ok());
|
||||
ASSERT_EQ("Error with filter field `id`: Filter value cannot be empty.", res_op.error());
|
||||
|
||||
// not equals is not supported yet
|
||||
res_op = coll1->search("*",
|
||||
{}, "id:!= [123,125] && num_employees: <300",
|
||||
{}, sort_fields, {0}, 10, 1, FREQUENCY, {true});
|
||||
ASSERT_FALSE(res_op.ok());
|
||||
ASSERT_EQ("Not equals filtering is not supported on the `id` field.", res_op.error());
|
||||
|
||||
// when no IDs exist
|
||||
results = coll1->search("*",
|
||||
{}, "id: [1000] && num_employees: <300",
|
||||
@ -1397,9 +1408,10 @@ TEST_F(CollectionFilteringTest, NumericalFilteringWithArray) {
|
||||
TEST_F(CollectionFilteringTest, NegationOperatorBasics) {
|
||||
Collection *coll1;
|
||||
|
||||
std::vector<field> fields = {field("title", field_types::STRING, false),
|
||||
field("artist", field_types::STRING, false),
|
||||
field("points", field_types::INT32, false),};
|
||||
std::vector<field> fields = {
|
||||
field("title", field_types::STRING, false),
|
||||
field("artist", field_types::STRING, false),
|
||||
field("points", field_types::INT32, false),};
|
||||
|
||||
coll1 = collectionManager.get_collection("coll1").get();
|
||||
if(coll1 == nullptr) {
|
||||
|
@ -69,6 +69,7 @@ TEST_F(CollectionGroupingTest, GroupingBasics) {
|
||||
"", 10,
|
||||
{}, {}, {"size"}, 2).get();
|
||||
|
||||
ASSERT_EQ(12, res["found_docs"].get<size_t>());
|
||||
ASSERT_EQ(3, res["found"].get<size_t>());
|
||||
ASSERT_EQ(3, res["grouped_hits"].size());
|
||||
ASSERT_EQ(11, res["grouped_hits"][0]["group_key"][0].get<size_t>());
|
||||
@ -116,6 +117,7 @@ TEST_F(CollectionGroupingTest, GroupingBasics) {
|
||||
{}, {}, {"rating"}, 2).get();
|
||||
|
||||
// 7 unique ratings
|
||||
ASSERT_EQ(12, res["found_docs"].get<size_t>());
|
||||
ASSERT_EQ(7, res["found"].get<size_t>());
|
||||
ASSERT_EQ(7, res["grouped_hits"].size());
|
||||
ASSERT_FLOAT_EQ(4.4, res["grouped_hits"][0]["group_key"][0].get<float>());
|
||||
@ -167,7 +169,7 @@ TEST_F(CollectionGroupingTest, GroupingCompoundKey) {
|
||||
spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
|
||||
"", 10,
|
||||
{}, {}, {"size", "brand"}, 2).get();
|
||||
|
||||
ASSERT_EQ(12, res["found_docs"].get<size_t>());
|
||||
ASSERT_EQ(10, res["found"].get<size_t>());
|
||||
ASSERT_EQ(10, res["grouped_hits"].size());
|
||||
|
||||
@ -227,6 +229,7 @@ TEST_F(CollectionGroupingTest, GroupingCompoundKey) {
|
||||
ASSERT_STREQ("0", res["grouped_hits"][0]["hits"][1]["document"]["id"].get<std::string>().c_str());
|
||||
|
||||
// total count and facet counts should be the same
|
||||
ASSERT_EQ(12, res["found_docs"].get<size_t>());
|
||||
ASSERT_EQ(10, res["found"].get<size_t>());
|
||||
ASSERT_EQ(2, res["grouped_hits"].size());
|
||||
ASSERT_EQ(10, res["grouped_hits"][0]["group_key"][0].get<size_t>());
|
||||
@ -313,6 +316,7 @@ TEST_F(CollectionGroupingTest, GroupingWithMultiFieldRelevance) {
|
||||
"", 10,
|
||||
{}, {}, {"genre"}, 2).get();
|
||||
|
||||
ASSERT_EQ(7, results["found_docs"].get<size_t>());
|
||||
ASSERT_EQ(3, results["found"].get<size_t>());
|
||||
ASSERT_EQ(3, results["grouped_hits"].size());
|
||||
|
||||
@ -345,6 +349,7 @@ TEST_F(CollectionGroupingTest, GroupingWithGropLimitOfOne) {
|
||||
"", 10,
|
||||
{}, {}, {"brand"}, 1).get();
|
||||
|
||||
ASSERT_EQ(12, res["found_docs"].get<size_t>());
|
||||
ASSERT_EQ(5, res["found"].get<size_t>());
|
||||
ASSERT_EQ(5, res["grouped_hits"].size());
|
||||
|
||||
@ -430,6 +435,7 @@ TEST_F(CollectionGroupingTest, GroupingWithArrayFieldAndOverride) {
|
||||
"", 10,
|
||||
{}, {}, {"colors"}, 2).get();
|
||||
|
||||
ASSERT_EQ(9, res["found_docs"].get<size_t>());
|
||||
ASSERT_EQ(4, res["found"].get<size_t>());
|
||||
ASSERT_EQ(4, res["grouped_hits"].size());
|
||||
|
||||
@ -611,6 +617,7 @@ TEST_F(CollectionGroupingTest, SortingOnGroupCount) {
|
||||
"", 10,
|
||||
{}, {}, {"size"}, 2).get();
|
||||
|
||||
ASSERT_EQ(12, res["found_docs"].get<size_t>());
|
||||
ASSERT_EQ(3, res["found"].get<size_t>());
|
||||
ASSERT_EQ(3, res["grouped_hits"].size());
|
||||
|
||||
@ -635,6 +642,7 @@ TEST_F(CollectionGroupingTest, SortingOnGroupCount) {
|
||||
"", 10,
|
||||
{}, {}, {"size"}, 2).get();
|
||||
|
||||
ASSERT_EQ(12, res2["found_docs"].get<size_t>());
|
||||
ASSERT_EQ(3, res2["found"].get<size_t>());
|
||||
ASSERT_EQ(3, res2["grouped_hits"].size());
|
||||
|
||||
@ -715,6 +723,7 @@ TEST_F(CollectionGroupingTest, SortingMoreThanMaxTopsterSize) {
|
||||
"", 10,
|
||||
{}, {}, {"size"}, 2).get();
|
||||
|
||||
ASSERT_EQ(1000, res["found_docs"].get<size_t>());
|
||||
ASSERT_EQ(300, res["found"].get<size_t>());
|
||||
ASSERT_EQ(100, res["grouped_hits"].size());
|
||||
|
||||
@ -734,6 +743,7 @@ TEST_F(CollectionGroupingTest, SortingMoreThanMaxTopsterSize) {
|
||||
"", 10,
|
||||
{}, {}, {"size"}, 2).get();
|
||||
|
||||
ASSERT_EQ(1000, res["found_docs"].get<size_t>());
|
||||
ASSERT_EQ(300, res["found"].get<size_t>());
|
||||
ASSERT_EQ(100, res["grouped_hits"].size());
|
||||
|
||||
@ -757,6 +767,7 @@ TEST_F(CollectionGroupingTest, SortingMoreThanMaxTopsterSize) {
|
||||
"", 10,
|
||||
{}, {}, {"size"}, 2).get();
|
||||
|
||||
ASSERT_EQ(1000, res2["found_docs"].get<size_t>());
|
||||
ASSERT_EQ(300, res2["found"].get<size_t>());
|
||||
ASSERT_EQ(100, res2["grouped_hits"].size());
|
||||
|
||||
@ -775,6 +786,7 @@ TEST_F(CollectionGroupingTest, SortingMoreThanMaxTopsterSize) {
|
||||
"", 10,
|
||||
{}, {}, {"size"}, 2).get();
|
||||
|
||||
ASSERT_EQ(1000, res2["found_docs"].get<size_t>());
|
||||
ASSERT_EQ(300, res2["found"].get<size_t>());
|
||||
ASSERT_EQ(100, res2["grouped_hits"].size());
|
||||
|
||||
|
@ -114,6 +114,33 @@ TEST_F(CollectionSpecificMoreTest, PrefixExpansionOnSingleField) {
|
||||
ASSERT_EQ("0", results["hits"][0]["document"]["id"].get<std::string>());
|
||||
}
|
||||
|
||||
// Indexes 20 titles that all differ from the query only by a numeric suffix and
// checks that a search run with max_candidates = 20 returns all of them.
TEST_F(CollectionSpecificMoreTest, TypoCorrectionShouldUseMaxCandidates) {
    std::vector<field> fields = {field("title", field_types::STRING, false),
                                 field("points", field_types::INT32, false)};

    // Reuse the collection when it already exists, otherwise create it.
    Collection* coll = collectionManager.get_collection("coll1").get();
    if(coll == nullptr) {
        coll = collectionManager.create_collection("coll1", 1, fields, "points").get();
    }

    for(size_t n = 0; n < 20; n++) {
        nlohmann::json record;
        record["title"] = "Independent" + std::to_string(n);
        record["points"] = n;
        coll->add(record.dump());
    }

    size_t max_candidates = 20;
    auto results = coll->search("independent", {"title"}, "", {}, {}, {2}, 30, 1, FREQUENCY, {false}, 0,
                                spp::sparse_hash_set<std::string>(),
                                spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0,
                                "<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000*1000, 4, 7,
                                off, max_candidates).get();

    ASSERT_EQ(20, results["hits"].size());
}
|
||||
|
||||
TEST_F(CollectionSpecificMoreTest, PrefixExpansionOnMultiField) {
|
||||
Collection *coll1;
|
||||
std::vector<field> fields = {field("location", field_types::STRING, false),
|
||||
|
@ -253,6 +253,40 @@ TEST_F(CoreAPIUtilsTest, MultiSearchEmbeddedKeys) {
|
||||
|
||||
}
|
||||
|
||||
// A preset referenced from the embedded params of a request must be expanded into
// the request params, both for a plain search and for a multi search.
TEST_F(CoreAPIUtilsTest, SearchEmbeddedPresetKey) {
    nlohmann::json preset_value = R"(
        {"per_page": 100}
    )"_json;

    Option<bool> upsert_op = collectionManager.upsert_preset("apple", preset_value);
    ASSERT_TRUE(upsert_op.ok());

    auto req = std::make_shared<http_req>();
    auto res = std::make_shared<http_res>(nullptr);

    nlohmann::json embedded_params;
    embedded_params["preset"] = "apple";
    req->embedded_params_vec.push_back(embedded_params);
    req->params["collection"] = "foo";

    get_search(req, res);
    ASSERT_EQ("100", req->params["per_page"]);

    // with multi search

    req->params.clear();

    nlohmann::json search;
    search["collection"] = "users";
    search["filter_by"] = "age: > 100";

    nlohmann::json body;
    body["searches"] = nlohmann::json::array();
    body["searches"].push_back(search);
    req->body = body.dump();

    post_multi_search(req, res);
    ASSERT_EQ("100", req->params["per_page"]);
}
|
||||
|
||||
TEST_F(CoreAPIUtilsTest, ExtractCollectionsFromRequestBody) {
|
||||
std::map<std::string, std::string> req_params;
|
||||
std::string body = R"(
|
||||
@ -956,3 +990,130 @@ TEST_F(CoreAPIUtilsTest, ExportIncludeExcludeFieldsWithFilter) {
|
||||
|
||||
collectionManager.drop_collection("coll1");
|
||||
}
|
||||
|
||||
|
||||
|
||||
// Fetches a URL directly through HttpClient and again through post_proxy, and
// expects both to report the same status code and body.
TEST_F(CoreAPIUtilsTest, TestProxy) {
    const std::string url = "https://typesense.org";

    std::unordered_map<std::string, std::string> headers;
    std::map<std::string, std::string> res_headers;
    std::string res;

    // Reference response obtained without the proxy.
    long expected_status_code = HttpClient::get_instance().get_response(url, res, res_headers, headers);

    nlohmann::json proxy_request;
    proxy_request["url"] = url;
    proxy_request["method"] = "GET";
    proxy_request["headers"] = headers;

    auto req = std::make_shared<http_req>();
    auto resp = std::make_shared<http_res>(nullptr);
    req->body = proxy_request.dump();

    post_proxy(req, resp);

    ASSERT_EQ(expected_status_code, resp->status_code);
    ASSERT_EQ(res, resp->body);
}
|
||||
|
||||
|
||||
// Feeds a series of malformed proxy requests to post_proxy and checks that each is
// rejected with status 400 and the matching error message.
TEST_F(CoreAPIUtilsTest, TestProxyInvalid) {
    nlohmann::json payload;

    auto req = std::make_shared<http_req>();
    auto resp = std::make_shared<http_res>(nullptr);

    // Serialises the current payload into the request and runs the handler.
    auto submit = [&]() {
        req->body = payload.dump();
        post_proxy(req, resp);
    };

    // test with url as empty string
    payload["url"] = "";
    payload["method"] = "GET";
    payload["headers"] = nlohmann::json::object();
    submit();

    ASSERT_EQ(400, resp->status_code);
    ASSERT_EQ("URL and method must be non-empty strings.", nlohmann::json::parse(resp->body)["message"]);

    // test with url as integer
    payload["url"] = 123;
    payload["method"] = "GET";
    submit();

    ASSERT_EQ(400, resp->status_code);
    ASSERT_EQ("URL and method must be non-empty strings.", nlohmann::json::parse(resp->body)["message"]);

    // test with no url parameter
    payload.erase("url");
    payload["method"] = "GET";
    submit();

    ASSERT_EQ(400, resp->status_code);
    ASSERT_EQ("Missing required fields.", nlohmann::json::parse(resp->body)["message"]);

    // test with invalid method
    payload["url"] = "https://typesense.org";
    payload["method"] = "INVALID";
    submit();

    ASSERT_EQ(400, resp->status_code);
    ASSERT_EQ("Parameter `method` must be one of GET, POST, PUT, DELETE.", nlohmann::json::parse(resp->body)["message"]);

    // test with method as integer
    payload["method"] = 123;
    submit();

    ASSERT_EQ(400, resp->status_code);
    ASSERT_EQ("URL and method must be non-empty strings.", nlohmann::json::parse(resp->body)["message"]);

    // test with no method parameter
    payload.erase("method");
    submit();

    ASSERT_EQ(400, resp->status_code);
    ASSERT_EQ("Missing required fields.", nlohmann::json::parse(resp->body)["message"]);

    // test with body as integer
    payload["method"] = "POST";
    payload["body"] = 123;
    submit();

    ASSERT_EQ(400, resp->status_code);
    ASSERT_EQ("Body must be a string.", nlohmann::json::parse(resp->body)["message"]);

    // test with headers as integer
    payload["body"] = "";
    payload["headers"] = 123;
    submit();

    ASSERT_EQ(400, resp->status_code);
    ASSERT_EQ("Headers must be a JSON object.", nlohmann::json::parse(resp->body)["message"]);
}
|
802
test/numeric_range_trie_test.cpp
Normal file
802
test/numeric_range_trie_test.cpp
Normal file
@ -0,0 +1,802 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include <collection_manager.h>
|
||||
#include "collection.h"
|
||||
#include "numeric_range_trie_test.h"
|
||||
|
||||
class NumericRangeTrieTest : public ::testing::Test {
|
||||
protected:
|
||||
Store *store;
|
||||
CollectionManager & collectionManager = CollectionManager::get_instance();
|
||||
std::atomic<bool> quit = false;
|
||||
|
||||
std::vector<std::string> query_fields;
|
||||
std::vector<sort_by> sort_fields;
|
||||
|
||||
void setupCollection() {
|
||||
std::string state_dir_path = "/tmp/typesense_test/collection_filtering";
|
||||
LOG(INFO) << "Truncating and creating: " << state_dir_path;
|
||||
system(("rm -rf "+state_dir_path+" && mkdir -p "+state_dir_path).c_str());
|
||||
|
||||
store = new Store(state_dir_path);
|
||||
collectionManager.init(store, 1.0, "auth_key", quit);
|
||||
collectionManager.load(8, 1000);
|
||||
}
|
||||
|
||||
virtual void SetUp() {
|
||||
setupCollection();
|
||||
}
|
||||
|
||||
virtual void TearDown() {
|
||||
collectionManager.dispose();
|
||||
delete store;
|
||||
}
|
||||
};
|
||||
|
||||
// Releases an id array obtained from a trie search and puts the pointer/length
// pair back into its empty state so it can be reused for the next search.
//
// ids        - array allocated with new[] (may be nullptr); set to nullptr on return.
// ids_length - number of ids in the array; set to 0 on return.
void reset(uint32_t*& ids, uint32_t& ids_length) {
    uint32_t* stale = ids;

    ids = nullptr;
    ids_length = 0;

    delete [] stale;
}
|
||||
|
||||
// Exercises NumericTrie::search_range over a mix of negative and positive values.
//
// The calls below pass (low, low_inclusive, high, high_inclusive, ids, ids_length);
// the inclusive flags are inferred from the expectations in this test. `pairs` is
// listed so that its seq ids are ascending, which is why the result arrays can be
// compared against it position by position.
TEST_F(NumericRangeTrieTest, SearchRange) {
    auto trie = new NumericTrie();
    std::unique_ptr<NumericTrie> trie_guard(trie);

    std::vector<std::pair<int32_t, uint32_t>> pairs = {
        {-8192, 8},
        {-16384, 32},
        {-24576, 35},
        {-32768, 43},
        {8192, 49},
        {16384, 56},
        {24576, 58},
        {32768, 91}
    };

    for (auto const& pair: pairs) {
        trie->insert(pair.first, pair.second);
    }

    uint32_t* ids = nullptr;
    uint32_t ids_length = 0;

    // Inverted range matches nothing.
    trie->search_range(32768, true, -32768, true, ids, ids_length);

    ASSERT_EQ(0, ids_length);

    reset(ids, ids_length);
    trie->search_range(-32768, true, 32768, true, ids, ids_length);

    ASSERT_EQ(pairs.size(), ids_length);
    for (uint32_t i = 0; i < pairs.size(); i++) {
        ASSERT_EQ(pairs[i].second, ids[i]);
    }

    reset(ids, ids_length);
    trie->search_range(-32768, true, 32768, false, ids, ids_length);

    ASSERT_EQ(pairs.size() - 1, ids_length);
    for (uint32_t i = 0; i < pairs.size() - 1; i++) {
        ASSERT_EQ(pairs[i].second, ids[i]);
    }

    reset(ids, ids_length);
    trie->search_range(-32768, true, 134217728, true, ids, ids_length);

    ASSERT_EQ(pairs.size(), ids_length);
    for (uint32_t i = 0; i < pairs.size(); i++) {
        ASSERT_EQ(pairs[i].second, ids[i]);
    }

    reset(ids, ids_length);
    trie->search_range(-32768, true, 0, true, ids, ids_length);

    ASSERT_EQ(4, ids_length);
    for (uint32_t i = 0; i < 4; i++) {
        ASSERT_EQ(pairs[i].second, ids[i]);
    }

    reset(ids, ids_length);
    trie->search_range(-32768, true, 0, false, ids, ids_length);

    ASSERT_EQ(4, ids_length);
    for (uint32_t i = 0; i < 4; i++) {
        ASSERT_EQ(pairs[i].second, ids[i]);
    }

    reset(ids, ids_length);
    trie->search_range(-32768, false, 32768, true, ids, ids_length);

    ASSERT_EQ(pairs.size() - 1, ids_length);
    for (uint32_t i = 0, j = 0; i < pairs.size(); i++) {
        if (i == 3) continue; // id for -32768 would not be present
        ASSERT_EQ(pairs[i].second, ids[j++]);
    }

    reset(ids, ids_length);
    trie->search_range(-134217728, true, 32768, true, ids, ids_length);

    ASSERT_EQ(pairs.size(), ids_length);
    for (uint32_t i = 0; i < pairs.size(); i++) {
        ASSERT_EQ(pairs[i].second, ids[i]);
    }

    reset(ids, ids_length);
    trie->search_range(-134217728, true, 134217728, true, ids, ids_length);

    ASSERT_EQ(pairs.size(), ids_length);
    for (uint32_t i = 0; i < pairs.size(); i++) {
        ASSERT_EQ(pairs[i].second, ids[i]);
    }

    reset(ids, ids_length);
    trie->search_range(-1, true, 32768, true, ids, ids_length);

    ASSERT_EQ(4, ids_length);
    for (uint32_t i = 4, j = 0; i < pairs.size(); i++, j++) {
        ASSERT_EQ(pairs[i].second, ids[j]);
    }

    reset(ids, ids_length);
    trie->search_range(-1, false, 32768, true, ids, ids_length);

    ASSERT_EQ(4, ids_length);
    for (uint32_t i = 4, j = 0; i < pairs.size(); i++, j++) {
        ASSERT_EQ(pairs[i].second, ids[j]);
    }

    reset(ids, ids_length);
    trie->search_range(-1, true, 0, true, ids, ids_length);
    ASSERT_EQ(0, ids_length);

    reset(ids, ids_length);
    trie->search_range(-1, false, 0, false, ids, ids_length);
    ASSERT_EQ(0, ids_length);

    reset(ids, ids_length);
    trie->search_range(8192, true, 32768, true, ids, ids_length);

    ASSERT_EQ(4, ids_length);
    // Iterate up to pairs.size(): bounding `i` (which starts at 4) by ids_length (4)
    // would skip the loop entirely and leave the ids unchecked.
    for (uint32_t i = 4, j = 0; i < pairs.size(); i++, j++) {
        ASSERT_EQ(pairs[i].second, ids[j]);
    }

    reset(ids, ids_length);
    trie->search_range(8192, true, 0x2000000, true, ids, ids_length);

    ASSERT_EQ(4, ids_length);
    // Same bound as above so that the positive ids are actually verified.
    for (uint32_t i = 4, j = 0; i < pairs.size(); i++, j++) {
        ASSERT_EQ(pairs[i].second, ids[j]);
    }

    reset(ids, ids_length);
    trie->search_range(16384, true, 16384, true, ids, ids_length);

    ASSERT_EQ(1, ids_length);
    ASSERT_EQ(56, ids[0]);

    reset(ids, ids_length);
    trie->search_range(16384, true, 16384, false, ids, ids_length);

    ASSERT_EQ(0, ids_length);

    reset(ids, ids_length);
    trie->search_range(16384, false, 16384, true, ids, ids_length);

    ASSERT_EQ(0, ids_length);

    reset(ids, ids_length);
    trie->search_range(16383, true, 16383, true, ids, ids_length);

    ASSERT_EQ(0, ids_length);

    reset(ids, ids_length);
    trie->search_range(8193, true, 16383, true, ids, ids_length);

    ASSERT_EQ(0, ids_length);

    reset(ids, ids_length);
    trie->search_range(-32768, true, -8192, true, ids, ids_length);

    ASSERT_EQ(4, ids_length);
    for (uint32_t i = 0; i < ids_length; i++) {
        ASSERT_EQ(pairs[i].second, ids[i]);
    }

    reset(ids, ids_length);
}
|
||||
|
||||
// Exercises NumericTrie::search_greater_than on a trie holding four negative and four
// positive values, using inclusive and exclusive bounds as well as bounds that lie far
// outside the stored values.
TEST_F(NumericRangeTrieTest, SearchGreaterThan) {
    std::unique_ptr<NumericTrie> trie(new NumericTrie());

    // {value, id}: entries 0..3 are negative, entries 4..7 are positive.
    const std::vector<std::pair<int32_t, uint32_t>> entries = {
        {-8192, 8},
        {-16384, 32},
        {-24576, 35},
        {-32768, 43},
        {8192, 49},
        {16384, 56},
        {24576, 58},
        {32768, 91}
    };

    for (const auto& entry : entries) {
        trie->insert(entry.first, entry.second);
    }

    uint32_t* found = nullptr;
    uint32_t found_count = 0;

    // >= 0: only the positive half is expected.
    trie->search_greater_than(0, true, found, found_count);

    ASSERT_EQ(4, found_count);
    for (size_t pos = 0; pos + 4 < entries.size(); ++pos) {
        ASSERT_EQ(entries[pos + 4].second, found[pos]);
    }

    // > -1: still only the positive half.
    reset(found, found_count);
    trie->search_greater_than(-1, false, found, found_count);

    ASSERT_EQ(4, found_count);
    for (size_t pos = 0; pos + 4 < entries.size(); ++pos) {
        ASSERT_EQ(entries[pos + 4].second, found[pos]);
    }

    // >= -1: same expectation as above.
    reset(found, found_count);
    trie->search_greater_than(-1, true, found, found_count);

    ASSERT_EQ(4, found_count);
    for (size_t pos = 0; pos + 4 < entries.size(); ++pos) {
        ASSERT_EQ(entries[pos + 4].second, found[pos]);
    }

    // >= -24576: everything except the entry for -32768.
    reset(found, found_count);
    trie->search_greater_than(-24576, true, found, found_count);

    ASSERT_EQ(7, found_count);
    {
        uint32_t out = 0;
        for (size_t src = 0; src < entries.size(); ++src) {
            if (src != 3) { // id for -32768 would not be present
                ASSERT_EQ(entries[src].second, found[out]);
                ++out;
            }
        }
    }

    // > -32768: again everything except the entry for -32768.
    reset(found, found_count);
    trie->search_greater_than(-32768, false, found, found_count);

    ASSERT_EQ(7, found_count);
    {
        uint32_t out = 0;
        for (size_t src = 0; src < entries.size(); ++src) {
            if (src != 3) { // id for -32768 would not be present
                ASSERT_EQ(entries[src].second, found[out]);
                ++out;
            }
        }
    }

    // >= 8192: the bound itself is included.
    reset(found, found_count);
    trie->search_greater_than(8192, true, found, found_count);

    ASSERT_EQ(4, found_count);
    for (size_t pos = 0; pos + 4 < entries.size(); ++pos) {
        ASSERT_EQ(entries[pos + 4].second, found[pos]);
    }

    // > 8192: the bound itself is excluded.
    reset(found, found_count);
    trie->search_greater_than(8192, false, found, found_count);

    ASSERT_EQ(3, found_count);
    for (size_t pos = 0; pos + 5 < entries.size(); ++pos) {
        ASSERT_EQ(entries[pos + 5].second, found[pos]);
    }

    // Bound above every stored value: nothing matches.
    reset(found, found_count);
    trie->search_greater_than(1000000, false, found, found_count);

    ASSERT_EQ(0, found_count);

    // Bound below every stored value: everything matches.
    reset(found, found_count);
    trie->search_greater_than(-1000000, false, found, found_count);

    ASSERT_EQ(8, found_count);
    for (size_t pos = 0; pos < entries.size(); ++pos) {
        ASSERT_EQ(entries[pos].second, found[pos]);
    }

    reset(found, found_count);
}
|
||||
|
||||
// Exercises NumericTrie::search_less_than on a trie holding four negative and four
// positive values, using inclusive and exclusive bounds as well as bounds that lie far
// outside the stored values.
TEST_F(NumericRangeTrieTest, SearchLessThan) {
    auto trie = new NumericTrie();
    std::unique_ptr<NumericTrie> trie_guard(trie);

    // {value, id}: listed in ascending value order, so a "less than" result of length n
    // is expected to match the first n entries.
    std::vector<std::pair<int32_t, uint32_t>> pairs = {
        {-32768, 8},
        {-24576, 32},
        {-16384, 35},
        {-8192, 43},
        {8192, 49},
        {16384, 56},
        {24576, 58},
        {32768, 91}
    };

    for (auto const& pair: pairs) {
        trie->insert(pair.first, pair.second);
    }

    uint32_t* ids = nullptr;
    uint32_t ids_length = 0;

    // <= 0: the four negative values. The comparison must start at index 0; starting at
    // index 4 with an `i < ids_length` bound would never execute and verify nothing.
    trie->search_less_than(0, true, ids, ids_length);

    ASSERT_EQ(4, ids_length);
    for (uint32_t i = 0; i < ids_length; i++) {
        ASSERT_EQ(pairs[i].second, ids[i]);
    }

    // < 0
    reset(ids, ids_length);
    trie->search_less_than(0, false, ids, ids_length);

    ASSERT_EQ(4, ids_length);
    for (uint32_t i = 0; i < ids_length; i++) {
        ASSERT_EQ(pairs[i].second, ids[i]);
    }

    // <= -1
    reset(ids, ids_length);
    trie->search_less_than(-1, true, ids, ids_length);

    ASSERT_EQ(4, ids_length);
    for (uint32_t i = 0; i < ids_length; i++) {
        ASSERT_EQ(pairs[i].second, ids[i]);
    }

    // <= -16384: the bound itself is included.
    reset(ids, ids_length);
    trie->search_less_than(-16384, true, ids, ids_length);

    ASSERT_EQ(3, ids_length);
    for (uint32_t i = 0; i < ids_length; i++) {
        ASSERT_EQ(pairs[i].second, ids[i]);
    }

    // < -16384: the bound itself is excluded.
    reset(ids, ids_length);
    trie->search_less_than(-16384, false, ids, ids_length);

    ASSERT_EQ(2, ids_length);
    for (uint32_t i = 0; i < ids_length; i++) {
        ASSERT_EQ(pairs[i].second, ids[i]);
    }

    // <= 8192: all negatives plus the first positive value.
    reset(ids, ids_length);
    trie->search_less_than(8192, true, ids, ids_length);

    ASSERT_EQ(5, ids_length);
    for (uint32_t i = 0; i < ids_length; i++) {
        ASSERT_EQ(pairs[i].second, ids[i]);
    }

    // < 8192: negatives only.
    reset(ids, ids_length);
    trie->search_less_than(8192, false, ids, ids_length);

    ASSERT_EQ(4, ids_length);
    for (uint32_t i = 0; i < ids_length; i++) {
        ASSERT_EQ(pairs[i].second, ids[i]);
    }

    // Bound below every stored value: nothing matches.
    reset(ids, ids_length);
    trie->search_less_than(-1000000, false, ids, ids_length);

    ASSERT_EQ(0, ids_length);

    // Bound above every stored value: everything matches.
    reset(ids, ids_length);
    trie->search_less_than(1000000, true, ids, ids_length);

    ASSERT_EQ(8, ids_length);
    for (uint32_t i = 0; i < ids_length; i++) {
        ASSERT_EQ(pairs[i].second, ids[i]);
    }

    reset(ids, ids_length);
}
|
||||
|
||||
// Exercises the array-returning NumericTrie::search_equal_to overload: values that were
// never inserted must produce no ids, and an inserted value must produce exactly its id
// even when neighbouring values (-32769 / -32767) are also present.
TEST_F(NumericRangeTrieTest, SearchEqualTo) {
    auto trie = new NumericTrie();
    std::unique_ptr<NumericTrie> trie_guard(trie);

    // {value, id}
    std::vector<std::pair<int32_t, uint32_t>> pairs = {
        {-8192, 8},
        {-16384, 32},
        {-24576, 35},
        {-32769, 41},
        {-32768, 43},
        {-32767, 45},
        {8192, 49},
        {16384, 56},
        {24576, 58},
        {32768, 91}
    };

    for (auto const& pair: pairs) {
        trie->insert(pair.first, pair.second);
    }

    uint32_t* ids = nullptr;
    uint32_t ids_length = 0;

    // 0 was never inserted.
    trie->search_equal_to(0, ids, ids_length);

    ASSERT_EQ(0, ids_length);

    reset(ids, ids_length);
    trie->search_equal_to(-32768, ids, ids_length);

    ASSERT_EQ(1, ids_length);
    ASSERT_EQ(43, ids[0]);

    reset(ids, ids_length);
    trie->search_equal_to(24576, ids, ids_length);

    ASSERT_EQ(1, ids_length);
    ASSERT_EQ(58, ids[0]);

    // 0x202020 was never inserted.
    reset(ids, ids_length);
    trie->search_equal_to(0x202020, ids, ids_length);

    ASSERT_EQ(0, ids_length);

    // Hand the final result back as well, as the sibling tests do after their last search.
    reset(ids, ids_length);
}
|
||||
|
||||
// Exercises the iterator-returning NumericTrie::search_equal_to overload: validity for
// absent values, stepping with next(), rewinding with reset() and jumping with skip_to().
TEST_F(NumericRangeTrieTest, IterateSearchEqualTo) {
    auto trie = new NumericTrie();
    std::unique_ptr<NumericTrie> trie_guard(trie);

    // {value, id}; 24576 is stored under two ids (58 and 60).
    std::vector<std::pair<int32_t, uint32_t>> pairs = {
        {-8192, 8},
        {-16384, 32},
        {-24576, 35},
        {-32769, 41},
        {-32768, 43},
        {-32767, 45},
        {8192, 49},
        {16384, 56},
        {24576, 58},
        {24576, 60},
        {32768, 91}
    };

    for (auto const& pair: pairs) {
        trie->insert(pair.first, pair.second);
    }

    // Values that were never inserted yield an invalid iterator.
    auto iterator = trie->search_equal_to(0);
    ASSERT_EQ(false, iterator.is_valid);

    iterator = trie->search_equal_to(0x202020);
    ASSERT_EQ(false, iterator.is_valid);

    // A value with a single id: valid once, invalid after one step.
    iterator = trie->search_equal_to(-32768);
    ASSERT_EQ(true, iterator.is_valid);
    ASSERT_EQ(43, iterator.seq_id);

    iterator.next();
    ASSERT_EQ(false, iterator.is_valid);

    // A value with two ids: both are visited in order, then the iterator is exhausted.
    iterator = trie->search_equal_to(24576);
    ASSERT_EQ(true, iterator.is_valid);
    ASSERT_EQ(58, iterator.seq_id);

    iterator.next();
    ASSERT_EQ(true, iterator.is_valid);
    ASSERT_EQ(60, iterator.seq_id);

    iterator.next();
    ASSERT_EQ(false, iterator.is_valid);

    // reset() makes the exhausted iterator usable again, positioned on the first id.
    iterator.reset();
    ASSERT_EQ(true, iterator.is_valid);
    ASSERT_EQ(58, iterator.seq_id);

    // Skipping to an id below the current one leaves the position unchanged.
    iterator.skip_to(4);
    ASSERT_EQ(true, iterator.is_valid);
    ASSERT_EQ(58, iterator.seq_id);

    // Skipping to an id that is not stored lands on the next stored id.
    iterator.skip_to(59);
    ASSERT_EQ(true, iterator.is_valid);
    ASSERT_EQ(60, iterator.seq_id);

    // Skipping past the last stored id invalidates the iterator.
    iterator.skip_to(66);
    ASSERT_EQ(false, iterator.is_valid);
}
|
||||
|
||||
// Exercises the search functions when the same id is stored under several values and the
// same value holds several ids; each id is expected to appear once per result.
TEST_F(NumericRangeTrieTest, MultivalueData) {
    std::unique_ptr<NumericTrie> trie(new NumericTrie());

    // {value, id}
    const std::vector<std::pair<int32_t, uint32_t>> entries = {
        {-0x202020, 32},
        {-32768, 5},
        {-32768, 8},
        {-24576, 32},
        {-16384, 35},
        {-8192, 43},
        {0, 43},
        {0, 49},
        {1, 8},
        {256, 91},
        {8192, 49},
        {16384, 56},
        {24576, 58},
        {32768, 91},
        {0x202020, 35},
    };

    for (const auto& entry : entries) {
        trie->insert(entry.first, entry.second);
    }

    uint32_t* found = nullptr;
    uint32_t found_count = 0;
    std::vector<uint32_t> expected;

    // < 0
    trie->search_less_than(0, false, found, found_count);

    expected = {5, 8, 32, 35, 43};
    ASSERT_EQ(5, found_count);
    for (uint32_t pos = 0; pos < found_count; ++pos) {
        ASSERT_EQ(expected[pos], found[pos]);
    }

    // < -16380
    reset(found, found_count);
    trie->search_less_than(-16380, false, found, found_count);

    expected = {5, 8, 32, 35};
    ASSERT_EQ(4, found_count);
    for (uint32_t pos = 0; pos < found_count; ++pos) {
        ASSERT_EQ(expected[pos], found[pos]);
    }

    // < 16384
    reset(found, found_count);
    trie->search_less_than(16384, false, found, found_count);

    expected = {5, 8, 32, 35, 43, 49, 91};
    ASSERT_EQ(7, found_count);
    for (uint32_t pos = 0; pos < found_count; ++pos) {
        ASSERT_EQ(expected[pos], found[pos]);
    }

    // >= 0
    reset(found, found_count);
    trie->search_greater_than(0, true, found, found_count);

    expected = {8, 35, 43, 49, 56, 58, 91};
    ASSERT_EQ(7, found_count);
    for (uint32_t pos = 0; pos < found_count; ++pos) {
        ASSERT_EQ(expected[pos], found[pos]);
    }

    // >= 256
    reset(found, found_count);
    trie->search_greater_than(256, true, found, found_count);

    expected = {35, 49, 56, 58, 91};
    ASSERT_EQ(5, found_count);
    for (uint32_t pos = 0; pos < found_count; ++pos) {
        ASSERT_EQ(expected[pos], found[pos]);
    }

    // >= -32768
    reset(found, found_count);
    trie->search_greater_than(-32768, true, found, found_count);

    expected = {5, 8, 32, 35, 43, 49, 56, 58, 91};
    ASSERT_EQ(9, found_count);
    for (uint32_t pos = 0; pos < found_count; ++pos) {
        ASSERT_EQ(expected[pos], found[pos]);
    }

    // [-32768, 0]
    reset(found, found_count);
    trie->search_range(-32768, true, 0, true, found, found_count);

    expected = {5, 8, 32, 35, 43, 49};
    ASSERT_EQ(6, found_count);
    for (uint32_t pos = 0; pos < found_count; ++pos) {
        ASSERT_EQ(expected[pos], found[pos]);
    }

    reset(found, found_count);
}
|
||||
|
||||
// Exercises NumericTrie::remove: results are checked before and after removing
// individual {value, id} pairs, including an id that is stored under more than one value.
TEST_F(NumericRangeTrieTest, Remove) {
    std::unique_ptr<NumericTrie> trie(new NumericTrie());

    // {value, id}
    const std::vector<std::pair<int32_t, uint32_t>> entries = {
        {-0x202020, 32},
        {-32768, 5},
        {-32768, 8},
        {-24576, 32},
        {-16384, 35},
        {-8192, 43},
        {0, 2},
        {0, 49},
        {1, 8},
        {256, 91},
        {8192, 49},
        {16384, 56},
        {24576, 58},
        {32768, 91},
        {0x202020, 35},
    };

    for (const auto& entry : entries) {
        trie->insert(entry.first, entry.second);
    }

    uint32_t* found = nullptr;
    uint32_t found_count = 0;
    std::vector<uint32_t> expected;

    // Baseline: everything below zero.
    trie->search_less_than(0, false, found, found_count);

    expected = {5, 8, 32, 35, 43};
    ASSERT_EQ(5, found_count);
    for (uint32_t pos = 0; pos < found_count; ++pos) {
        ASSERT_EQ(expected[pos], found[pos]);
    }

    // Drop both negative values that carry id 32; it must vanish from the result.
    trie->remove(-24576, 32);
    trie->remove(-0x202020, 32);

    reset(found, found_count);
    trie->search_less_than(0, false, found, found_count);

    expected = {5, 8, 35, 43};
    ASSERT_EQ(4, found_count);
    for (uint32_t pos = 0; pos < found_count; ++pos) {
        ASSERT_EQ(expected[pos], found[pos]);
    }

    // The value 0 holds two ids before the removal below.
    reset(found, found_count);
    trie->search_equal_to(0, found, found_count);

    expected = {2, 49};
    ASSERT_EQ(2, found_count);
    for (uint32_t pos = 0; pos < found_count; ++pos) {
        ASSERT_EQ(expected[pos], found[pos]);
    }

    trie->remove(0, 2);

    // Only the other id stored under 0 is left.
    reset(found, found_count);
    trie->search_equal_to(0, found, found_count);

    ASSERT_EQ(1, found_count);
    ASSERT_EQ(49, found[0]);

    reset(found, found_count);
}
|
||||
|
||||
// Runs every search variant, and remove(), against a trie that has had nothing inserted.
// Each search must report zero ids.
TEST_F(NumericRangeTrieTest, EmptyTrieOperations) {
    std::unique_ptr<NumericTrie> trie(new NumericTrie());

    uint32_t* found = nullptr;
    uint32_t found_count = 0;

    trie->search_range(-32768, true, 32768, true, found, found_count);
    // Owns whatever buffer the search handed back so it is released when the test exits.
    // NOTE(review): the reset(found) calls below assume each search either leaves `found`
    // null or stores a freshly allocated buffer in it — confirm against NumericTrie.
    std::unique_ptr<uint32_t[]> found_guard(found);
    ASSERT_EQ(0, found_count);

    trie->search_range(-32768, true, -1, true, found, found_count);
    found_guard.reset(found);
    ASSERT_EQ(0, found_count);

    trie->search_range(1, true, 32768, true, found, found_count);
    found_guard.reset(found);
    ASSERT_EQ(0, found_count);

    trie->search_greater_than(0, true, found, found_count);
    found_guard.reset(found);
    ASSERT_EQ(0, found_count);

    trie->search_greater_than(15, true, found, found_count);
    found_guard.reset(found);
    ASSERT_EQ(0, found_count);

    trie->search_greater_than(-15, true, found, found_count);
    found_guard.reset(found);
    ASSERT_EQ(0, found_count);

    trie->search_less_than(0, false, found, found_count);
    found_guard.reset(found);
    ASSERT_EQ(0, found_count);

    trie->search_less_than(-15, true, found, found_count);
    found_guard.reset(found);
    ASSERT_EQ(0, found_count);

    trie->search_less_than(15, true, found, found_count);
    found_guard.reset(found);
    ASSERT_EQ(0, found_count);

    trie->search_equal_to(15, found, found_count);
    found_guard.reset(found);
    ASSERT_EQ(0, found_count);

    // Removing values that were never inserted, one positive and one negative.
    trie->remove(15, 0);
    trie->remove(-15, 0);
}
|
||||
|
||||
// End-to-end check: builds a collection whose `age` and `timestamps` fields are created
// with the range_index flag set, loads the numeric array fixture documents and runs
// numeric filter queries through Collection::search.
TEST_F(NumericRangeTrieTest, Integration) {
    Collection *coll_array_fields;

    std::ifstream infile(std::string(ROOT_DIR)+"test/numeric_array_documents.jsonl");
    // Without this check a missing fixture would only surface later as a hit-count mismatch.
    ASSERT_TRUE(infile.is_open());

    std::vector<field> fields = {
        field("name", field_types::STRING, false),
        field("rating", field_types::FLOAT, false),
        field("age", field_types::INT32, false, false, true, "", -1, -1, false, 0, 0, cosine, "", nlohmann::json(),
              true), // Setting range index true.
        field("years", field_types::INT32_ARRAY, false),
        field("timestamps", field_types::INT64_ARRAY, false, false, true, "", -1, -1, false, 0, 0, cosine, "",
              nlohmann::json(), true), // Range index enabled here as well.
        field("tags", field_types::STRING_ARRAY, true)
    };

    std::vector<sort_by> sort_fields = { sort_by("age", "DESC") };

    coll_array_fields = collectionManager.get_collection("coll_array_fields").get();
    if(coll_array_fields == nullptr) {
        // ensure that default_sorting_field is a non-array numerical field
        auto coll_op = collectionManager.create_collection("coll_array_fields", 4, fields, "years");
        ASSERT_EQ(false, coll_op.ok());
        ASSERT_STREQ("Default sorting field `years` is not a sortable type.", coll_op.error().c_str());

        // let's try again properly
        coll_op = collectionManager.create_collection("coll_array_fields", 4, fields, "age");
        // Stop here on failure instead of dereferencing an unusable collection pointer below.
        ASSERT_TRUE(coll_op.ok());
        coll_array_fields = coll_op.get();
    }

    std::string json_line;

    while (std::getline(infile, json_line)) {
        auto add_op = coll_array_fields->add(json_line);
        ASSERT_TRUE(add_op.ok());
    }

    infile.close();

    query_fields = {"name"};
    std::vector<std::string> facets;

    // Searching on an int32 field
    nlohmann::json results = coll_array_fields->search("Jeremy", query_fields, "age:>24", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get();
    ASSERT_EQ(3, results["hits"].size());

    std::vector<std::string> ids = {"3", "1", "4"};

    for(size_t i = 0; i < results["hits"].size(); i++) {
        nlohmann::json result = results["hits"].at(i);
        std::string result_id = result["document"]["id"];
        std::string id = ids.at(i);
        ASSERT_STREQ(id.c_str(), result_id.c_str());
    }

    // searching on an int64 array field - also ensure that padded space causes no issues
    results = coll_array_fields->search("Jeremy", query_fields, "timestamps : > 475205222", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get();
    ASSERT_EQ(4, results["hits"].size());

    ids = {"1", "4", "0", "2"};

    for(size_t i = 0; i < results["hits"].size(); i++) {
        nlohmann::json result = results["hits"].at(i);
        std::string result_id = result["document"]["id"];
        std::string id = ids.at(i);
        ASSERT_STREQ(id.c_str(), result_id.c_str());
    }

    // Filtering on the float field with two ranges; only the hit count is checked.
    results = coll_array_fields->search("Jeremy", query_fields, "rating: [7.812 .. 9.999, 1.05 .. 1.09]", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get();
    ASSERT_EQ(3, results["hits"].size());
}
|
Loading…
x
Reference in New Issue
Block a user