diff --git a/include/core_api.h b/include/core_api.h index 9d31d65a..780d3d1b 100644 --- a/include/core_api.h +++ b/include/core_api.h @@ -167,3 +167,6 @@ bool is_doc_del_route(uint64_t route_hash); Option> get_api_key_and_ip(const std::string& metadata); void init_api(uint32_t cache_num_entries); + + +bool post_proxy(const std::shared_ptr& req, const std::shared_ptr& res); diff --git a/include/field.h b/include/field.h index bac1becd..9d7cb928 100644 --- a/include/field.h +++ b/include/field.h @@ -54,6 +54,7 @@ namespace fields { static const std::string from = "from"; static const std::string embed_from = "embed_from"; static const std::string model_name = "model_name"; + static const std::string range_index = "range_index"; // Some models require additional parameters to be passed to the model during indexing/querying // For e.g. e5-small model requires prefix "passage:" for indexing and "query:" for querying @@ -93,13 +94,17 @@ struct field { std::string reference; // Foo.bar (reference to bar field in Foo collection). 
+ bool range_index; + field() {} field(const std::string &name, const std::string &type, const bool facet, const bool optional = false, bool index = true, std::string locale = "", int sort = -1, int infix = -1, bool nested = false, - int nested_array = 0, size_t num_dim = 0, vector_distance_type_t vec_dist = cosine, std::string reference = "", const nlohmann::json& embed = nlohmann::json()) : + int nested_array = 0, size_t num_dim = 0, vector_distance_type_t vec_dist = cosine, + std::string reference = "", const nlohmann::json& embed = nlohmann::json(), const bool range_index = false) : name(name), type(type), facet(facet), optional(optional), index(index), locale(locale), - nested(nested), nested_array(nested_array), num_dim(num_dim), vec_dist(vec_dist), reference(reference), embed(embed) { + nested(nested), nested_array(nested_array), num_dim(num_dim), vec_dist(vec_dist), reference(reference), + embed(embed), range_index(range_index) { set_computed_defaults(sort, infix); } diff --git a/include/filter.h b/include/filter.h index f52b086f..01469d5e 100644 --- a/include/filter.h +++ b/include/filter.h @@ -19,7 +19,7 @@ struct filter { std::string field_name; std::vector values; std::vector comparators; - // Would be set when `field: != ...` is encountered with a string field or `field: != [ ... ]` is encountered in the + // Would be set when `field: != ...` is encountered with id/string field or `field: != [ ... ]` is encountered in the // case of int and float fields. During filtering, all the results of matching the field against the values are // aggregated and then this flag is checked if negation on the aggregated result is required. bool apply_not_equals = false; diff --git a/include/filter_result_iterator.h b/include/filter_result_iterator.h index b3b12555..d74cb523 100644 --- a/include/filter_result_iterator.h +++ b/include/filter_result_iterator.h @@ -160,6 +160,9 @@ public: /// Returns the status of the initialization of iterator tree. 
Option init_status(); + /// Recursively computes the result of each node and stores the final result in the root node. + void compute_result(); + /// Returns a tri-state: /// 0: id is not valid /// 1: id is valid diff --git a/include/http_proxy.h b/include/http_proxy.h new file mode 100644 index 00000000..4b8315b3 --- /dev/null +++ b/include/http_proxy.h @@ -0,0 +1,43 @@ +#pragma once + +#include +#include +#include "http_client.h" +#include "lru/lru.hpp" + + +struct http_proxy_res_t { + std::string body; + std::map headers; + long status_code; + + bool operator==(const http_proxy_res_t& other) const { + return body == other.body && headers == other.headers && status_code == other.status_code; + } + + bool operator!=(const http_proxy_res_t& other) const { + return !(*this == other); + } +}; + + +class HttpProxy { + // singleton class for http proxy + public: + static HttpProxy& get_instance() { + static HttpProxy instance; + return instance; + } + HttpProxy(const HttpProxy&) = delete; + void operator=(const HttpProxy&) = delete; + HttpProxy(HttpProxy&&) = delete; + void operator=(HttpProxy&&) = delete; + http_proxy_res_t send(const std::string& url, const std::string& method, const std::string& body, const std::unordered_map& headers); + private: + HttpProxy(); + ~HttpProxy() = default; + + + // lru cache for http requests + LRU::TimedCache cache; +}; \ No newline at end of file diff --git a/include/http_server.h b/include/http_server.h index 8edef305..b6465e43 100644 --- a/include/http_server.h +++ b/include/http_server.h @@ -52,6 +52,9 @@ public: bool is_res_start = true; h2o_send_state_t send_state = H2O_SEND_STATE_IN_PROGRESS; h2o_iovec_t res_body{}; + h2o_iovec_t res_content_type{}; + int status = 0; + const char* reason = nullptr; h2o_generator_t* generator = nullptr; @@ -65,10 +68,9 @@ public: res_body = h2o_strdup(&req->pool, body.c_str(), SIZE_MAX); if(is_res_start) { - req->res.status = status_code; - req->res.reason = 
http_res::get_status_reason(status_code); - h2o_add_header(&req->pool, &req->res.headers, H2O_TOKEN_CONTENT_TYPE, NULL, - content_type.c_str(), content_type.size()); + res_content_type = h2o_strdup(&req->pool, content_type.c_str(), SIZE_MAX); + status = status_code; + reason = http_res::get_status_reason(status_code); } } diff --git a/include/index.h b/include/index.h index 0ac4c5fd..4a1e760f 100644 --- a/include/index.h +++ b/include/index.h @@ -31,6 +31,7 @@ #include "hnswlib/hnswlib.h" #include "filter.h" #include "facet_index.h" +#include "numeric_range_trie_test.h" static constexpr size_t ARRAY_FACET_DIM = 4; using facet_map_t = spp::sparse_hash_map; @@ -309,7 +310,9 @@ private: spp::sparse_hash_map numerical_index; - spp::sparse_hash_map>*> geopoint_index; + spp::sparse_hash_map range_index; + + spp::sparse_hash_map geo_range_index; // geo_array_field => (seq_id => values) used for exact filtering of geo array records spp::sparse_hash_map*> geo_array_index; diff --git a/include/numeric_range_trie_test.h b/include/numeric_range_trie_test.h new file mode 100644 index 00000000..8b7bd22c --- /dev/null +++ b/include/numeric_range_trie_test.h @@ -0,0 +1,154 @@ +#pragma once + +#include +#include "sorted_array.h" + +constexpr short EXPANSE = 256; + +class NumericTrie { + char max_level = 4; + + class Node { + Node** children = nullptr; + sorted_array seq_ids; + + void insert_helper(const int64_t& value, const uint32_t& seq_id, char& level, const char& max_level); + + void insert_geopoint_helper(const uint64_t& cell_id, const uint32_t& seq_id, char& level, const char& max_level); + + void search_geopoints_helper(const uint64_t& cell_id, const char& max_index_level, std::set& matches); + + void search_range_helper(const int64_t& low,const int64_t& high, const char& max_level, + std::vector& matches); + + void search_less_than_helper(const int64_t& value, char& level, const char& max_level, + std::vector& matches); + + void search_greater_than_helper(const int64_t& 
value, char& level, const char& max_level, + std::vector& matches); + + public: + + ~Node() { + if (children != nullptr) { + for (auto i = 0; i < EXPANSE; i++) { + delete children[i]; + } + } + + delete [] children; + } + + void insert(const int64_t& cell_id, const uint32_t& seq_id, const char& max_level); + + void remove(const int64_t& cell_id, const uint32_t& seq_id, const char& max_level); + + void insert_geopoint(const uint64_t& cell_id, const uint32_t& seq_id, const char& max_level); + + void search_geopoints(const std::vector& cell_ids, const char& max_level, + std::vector& geo_result_ids); + + void delete_geopoint(const uint64_t& cell_id, uint32_t id, const char& max_level); + + void get_all_ids(uint32_t*& ids, uint32_t& ids_length); + + void search_range(const int64_t& low, const int64_t& high, const char& max_level, + uint32_t*& ids, uint32_t& ids_length); + + void search_range(const int64_t& low, const int64_t& high, const char& max_level, std::vector& matches); + + void search_less_than(const int64_t& value, const char& max_level, uint32_t*& ids, uint32_t& ids_length); + + void search_less_than(const int64_t& value, const char& max_level, std::vector& matches); + + void search_greater_than(const int64_t& value, const char& max_level, uint32_t*& ids, uint32_t& ids_length); + + void search_greater_than(const int64_t& value, const char& max_level, std::vector& matches); + + void search_equal_to(const int64_t& value, const char& max_level, uint32_t*& ids, uint32_t& ids_length); + + void search_equal_to(const int64_t& value, const char& max_level, std::vector& matches); + }; + + Node* negative_trie = nullptr; + Node* positive_trie = nullptr; + +public: + + explicit NumericTrie(char num_bits = 32) { + max_level = num_bits / 8; + } + + ~NumericTrie() { + delete negative_trie; + delete positive_trie; + } + + class iterator_t { + struct match_state { + uint32_t* ids = nullptr; + uint32_t ids_length = 0; + uint32_t index = 0; + + explicit match_state(uint32_t*& 
ids, uint32_t& ids_length) : ids(ids), ids_length(ids_length) {} + + ~match_state() { + delete [] ids; + } + }; + + std::vector matches; + + void set_seq_id(); + + public: + + explicit iterator_t(std::vector& matches); + + ~iterator_t() { + for (auto& match: matches) { + delete match; + } + } + + iterator_t& operator=(iterator_t&& obj) noexcept; + + uint32_t seq_id = 0; + bool is_valid = true; + + void next(); + void skip_to(uint32_t id); + void reset(); + }; + + void insert(const int64_t& value, const uint32_t& seq_id); + + void remove(const int64_t& value, const uint32_t& seq_id); + + void insert_geopoint(const uint64_t& cell_id, const uint32_t& seq_id); + + void search_geopoints(const std::vector& cell_ids, std::vector& geo_result_ids); + + void delete_geopoint(const uint64_t& cell_id, uint32_t id); + + void search_range(const int64_t& low, const bool& low_inclusive, + const int64_t& high, const bool& high_inclusive, + uint32_t*& ids, uint32_t& ids_length); + + iterator_t search_range(const int64_t& low, const bool& low_inclusive, + const int64_t& high, const bool& high_inclusive); + + void search_less_than(const int64_t& value, const bool& inclusive, + uint32_t*& ids, uint32_t& ids_length); + + iterator_t search_less_than(const int64_t& value, const bool& inclusive); + + void search_greater_than(const int64_t& value, const bool& inclusive, + uint32_t*& ids, uint32_t& ids_length); + + iterator_t search_greater_than(const int64_t& value, const bool& inclusive); + + void search_equal_to(const int64_t& value, uint32_t*& ids, uint32_t& ids_length); + + iterator_t search_equal_to(const int64_t& value); +}; diff --git a/include/text_embedder_remote.h b/include/text_embedder_remote.h index 095272c2..99f9146f 100644 --- a/include/text_embedder_remote.h +++ b/include/text_embedder_remote.h @@ -4,6 +4,7 @@ #include #include #include "http_client.h" +#include "raft_server.h" #include "option.h" @@ -12,9 +13,15 @@ class RemoteEmbedder { protected: static Option 
validate_string_properties(const nlohmann::json& model_config, const std::vector& properties); + static long call_remote_api(const std::string& method, const std::string& url, const std::string& body, std::string& res_body, std::map& headers, const std::unordered_map& req_headers); + static inline ReplicationState* raft_server = nullptr; public: virtual Option> Embed(const std::string& text) = 0; virtual Option>> batch_embed(const std::vector& inputs) = 0; + static void init(ReplicationState* rs) { + raft_server = rs; + } + }; diff --git a/src/collection.cpp b/src/collection.cpp index aefc2973..098ea064 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -28,7 +28,6 @@ struct sort_fields_guard_t { ~sort_fields_guard_t() { for(auto& sort_by_clause: sort_fields_std) { delete sort_by_clause.eval.filter_tree_root; - if(sort_by_clause.eval.ids) { delete [] sort_by_clause.eval.ids; sort_by_clause.eval.ids = nullptr; @@ -1387,7 +1386,7 @@ Option Collection::search(std::string raw_query, std::vector> raw_result_kvs; std::vector> override_result_kvs; - size_t total_found = 0; + size_t total = 0; std::vector excluded_ids; std::vector> included_ids; // ID -> position @@ -1566,12 +1565,13 @@ Option Collection::search(std::string raw_query, // for grouping we have to aggregate group set sizes to a count value if(group_limit) { - total_found = search_params->groups_processed.size() + override_result_kvs.size(); + total = search_params->groups_processed.size() + override_result_kvs.size(); } else { - total_found = search_params->all_result_ids_len; + total = search_params->all_result_ids_len; } + - if(search_cutoff && total_found == 0) { + if(search_cutoff && total == 0) { // this can happen if other requests stopped this request from being processed // we should return an error so that request can be retried by client return Option(408, "Request Timeout"); @@ -1692,7 +1692,10 @@ Option Collection::search(std::string raw_query, } nlohmann::json result = 
nlohmann::json::object(); - result["found"] = total_found; + result["found"] = total; + if(group_limit != 0) { + result["found_docs"] = search_params->all_result_ids_len; + } if(exclude_fields.count("out_of") == 0) { result["out_of"] = num_documents.load(); diff --git a/src/collection_manager.cpp b/src/collection_manager.cpp index 5a3468b2..b0c76f69 100644 --- a/src/collection_manager.cpp +++ b/src/collection_manager.cpp @@ -731,6 +731,27 @@ Option CollectionManager::do_search(std::map& re AuthManager::add_item_to_params(req_params, item, true); } + const auto preset_it = req_params.find("preset"); + + if(preset_it != req_params.end()) { + nlohmann::json preset; + const auto& preset_op = CollectionManager::get_instance().get_preset(preset_it->second, preset); + + if(preset_op.ok()) { + if(!preset.is_object()) { + return Option(400, "Search preset is not an object."); + } + + for(const auto& search_item: preset.items()) { + // overwrite = false since req params will contain embedded params and so has higher priority + bool populated = AuthManager::add_item_to_params(req_params, search_item, false); + if(!populated) { + return Option(400, "One or more search parameters are malformed."); + } + } + } + } + CollectionManager & collectionManager = CollectionManager::get_instance(); const std::string& orig_coll_name = req_params["collection"]; auto collection = collectionManager.get_collection(orig_coll_name); diff --git a/src/core_api.cpp b/src/core_api.cpp index 61fd894d..7de42f6a 100644 --- a/src/core_api.cpp +++ b/src/core_api.cpp @@ -15,6 +15,7 @@ #include "lru/lru.hpp" #include "ratelimit_manager.h" #include "event_manager.h" +#include "http_proxy.h" using namespace std::chrono_literals; @@ -57,10 +58,8 @@ void stream_response(const std::shared_ptr& req, const std::shared_ptr return ; } - if(req->_req->res.status != 0) { - // not the first response chunk, so wait for previous chunk to finish - res->wait(); - } + // wait for previous chunk to finish (if any) + 
res->wait(); auto req_res = new async_req_res_t(req, res, true); server->get_message_dispatcher()->send_message(HttpServer::STREAM_RESPONSE_MESSAGE, req_res); @@ -368,29 +367,6 @@ bool get_search(const std::shared_ptr& req, const std::shared_ptrparams.find("preset"); - - if(preset_it != req->params.end()) { - nlohmann::json preset; - const auto& preset_op = CollectionManager::get_instance().get_preset(preset_it->second, preset); - - if(preset_op.ok()) { - if(!preset.is_object()) { - res->set_400("Search preset is not an object."); - return false; - } - - for(const auto& search_item: preset.items()) { - // overwrite = false since req params will contain embedded params and so has higher priority - bool populated = AuthManager::add_item_to_params(req->params, search_item, false); - if(!populated) { - res->set_400("One or more search parameters are malformed."); - return false; - } - } - } - } - if(req->embedded_params_vec.empty()) { res->set_500("Embedded params is empty."); return false; @@ -569,27 +545,6 @@ bool post_multi_search(const std::shared_ptr& req, const std::shared_p } } - if(search_params.count("preset") != 0) { - nlohmann::json preset; - auto preset_op = CollectionManager::get_instance().get_preset(search_params["preset"].get(), - preset); - if(preset_op.ok()) { - if(!search_params.is_object()) { - res->set_400("Search preset is not an object."); - return false; - } - - for(const auto& search_item: preset.items()) { - // overwrite = false since req params will contain embedded params and so has higher priority - bool populated = AuthManager::add_item_to_params(req->params, search_item, false); - if(!populated) { - res->set_400("One or more search parameters are malformed."); - return false; - } - } - } - } - std::string results_json_str; Option search_op = CollectionManager::do_search(req->params, req->embedded_params_vec[i], results_json_str, req->conn_ts); @@ -774,7 +729,7 @@ bool get_export_documents(const std::shared_ptr& req, const std::share } } - 
res->content_type_header = "application/octet-stream"; + /* "utf-8" is the registered charset name; "utf8" is only a tolerated alias */ res->content_type_header = "text/plain; charset=utf-8"; res->status_code = 200; stream_response(req, res); @@ -2136,3 +2091,65 @@ bool del_analytics_rules(const std::shared_ptr& req, const std::shared res->set_200(R"({"ok": true)"); return true; } + + +bool post_proxy(const std::shared_ptr& req, const std::shared_ptr& res) { + HttpProxy& proxy = HttpProxy::get_instance(); + + nlohmann::json req_json; + + try { + req_json = nlohmann::json::parse(req->body); + } catch(const nlohmann::json::parse_error& e) { + LOG(ERROR) << "JSON error: " << e.what(); + res->set_400("Bad JSON."); + return false; + } + + std::string body, url, method; + std::unordered_map headers; + + if(req_json.count("url") == 0 || req_json.count("method") == 0) { + res->set_400("Missing required fields."); + return false; + } + + if(!req_json["url"].is_string() || !req_json["method"].is_string() || req_json["url"].get().empty() || req_json["method"].get().empty()) { + res->set_400("URL and method must be non-empty strings."); + return false; + } + + try { + if(req_json.count("body") != 0 && !req_json["body"].is_string()) { + res->set_400("Body must be a string."); + return false; + } + if(req_json.count("headers") != 0 && !req_json["headers"].is_object()) { + res->set_400("Headers must be a JSON object."); + return false; + } + if(req_json.count("body")) { + body = req_json["body"].get(); + } + url = req_json["url"].get(); + method = req_json["method"].get(); + if(req_json.count("headers")) { + headers = req_json["headers"].get>(); + } + } catch(const std::exception& e) { + LOG(ERROR) << "JSON error: " << e.what(); + res->set_400("Bad JSON."); + return false; + } + + auto response = proxy.send(url, method, body, headers); + + if(response.status_code != 200) { + int code = response.status_code; + res->set_body(code, response.body); + return false; + } + + res->set_200(response.body); + return true; +} \ No newline at end of file diff --git
a/src/field.cpp b/src/field.cpp index e6f925ad..b20d70c4 100644 --- a/src/field.cpp +++ b/src/field.cpp @@ -75,6 +75,24 @@ Option field::json_field_to_field(bool enable_nested_fields, nlohmann::jso field_json[fields::reference] = ""; } + if (field_json.count(fields::range_index) != 0) { + if (!field_json.at(fields::range_index).is_boolean()) { + return Option(400, std::string("The `range_index` property of the field `") + + field_json[fields::name].get() + + std::string("` should be a boolean.")); + } + + auto const& type = field_json["type"]; + if (field_json[fields::range_index] && + type != field_types::INT32 && type != field_types::INT32_ARRAY && + type != field_types::INT64 && type != field_types::INT64_ARRAY && + type != field_types::FLOAT && type != field_types::FLOAT_ARRAY) { + /* message ends with a period, like the boolean check above (was a stray backtick) */ return Option(400, std::string("The `range_index` property is only allowed for the numerical fields.")); + } + } else { + field_json[fields::range_index] = false; + } + if(field_json["name"] == ".*") { if(field_json.count(fields::facet) == 0) { field_json[fields::facet] = false; @@ -297,7 +315,7 @@ Option field::json_field_to_field(bool enable_nested_fields, nlohmann::jso field_json[fields::optional], field_json[fields::index], field_json[fields::locale], field_json[fields::sort], field_json[fields::infix], field_json[fields::nested], field_json[fields::nested_array], field_json[fields::num_dim], vec_dist, - field_json[fields::reference], field_json[fields::embed]) + field_json[fields::reference], field_json[fields::embed], field_json[fields::range_index]) ); if (!field_json[fields::reference].get().empty()) { diff --git a/src/filter.cpp b/src/filter.cpp index c152d77f..5d94c66c 100644 --- a/src/filter.cpp +++ b/src/filter.cpp @@ -422,7 +422,10 @@ Option toFilter(const std::string expression, id_comparator = EQUALS; while (++filter_value_index < raw_value.size() && raw_value[filter_value_index] == ' '); } else if (raw_value.size() >= 2 && raw_value[0] == '!'
&& raw_value[1] == '=') { - return Option(400, "Not equals filtering is not supported on the `id` field."); + id_comparator = NOT_EQUALS; + filter_exp.apply_not_equals = true; + filter_value_index++; + while (++filter_value_index < raw_value.size() && raw_value[filter_value_index] == ' '); } if (filter_value_index != 0) { raw_value = raw_value.substr(filter_value_index); diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index 185e376c..42816867 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -401,6 +401,22 @@ void filter_result_iterator_t::next() { return; } + // No need to traverse iterator tree if there's only one filter or compute_result() has been called. + if (is_filter_result_initialized) { + if (++result_index >= filter_result.count) { + is_valid = false; + return; + } + + seq_id = filter_result.docs[result_index]; + reference.clear(); + for (auto const& item: filter_result.reference_filter_results) { + reference[item.first] = item.second[result_index]; + } + + return; + } + if (filter_node->isOperator) { // Advance the subtrees and then apply operators to arrive at the next valid doc. 
if (filter_node->filter_operator == AND) { @@ -423,21 +439,6 @@ void filter_result_iterator_t::next() { return; } - if (is_filter_result_initialized) { - if (++result_index >= filter_result.count) { - is_valid = false; - return; - } - - seq_id = filter_result.docs[result_index]; - reference.clear(); - for (auto const& item: filter_result.reference_filter_results) { - reference[item.first] = item.second[result_index]; - } - - return; - } - const filter a_filter = filter_node->filter_exp; if (!index->field_is_indexed(a_filter.field_name)) { @@ -619,11 +620,6 @@ void filter_result_iterator_t::init() { } if (a_filter.field_name == "id") { - if (a_filter.values.empty()) { - is_valid = false; - return; - } - // we handle `ids` separately std::vector result_ids; for (const auto& id_str : a_filter.values) { @@ -636,6 +632,16 @@ void filter_result_iterator_t::init() { filter_result.docs = new uint32_t[result_ids.size()]; std::copy(result_ids.begin(), result_ids.end(), filter_result.docs); + if (a_filter.apply_not_equals) { + apply_not_equals(index->seq_ids->uncompress(), index->seq_ids->num_ids(), + filter_result.docs, filter_result.count); + } + + if (filter_result.count == 0) { + is_valid = false; + return; + } + seq_id = filter_result.docs[result_index]; is_filter_result_initialized = true; approx_filter_ids_length = filter_result.count; @@ -650,27 +656,62 @@ void filter_result_iterator_t::init() { field f = index->search_schema.at(a_filter.field_name); if (f.is_integer()) { - auto num_tree = index->numerical_index.at(a_filter.field_name); + if (f.range_index) { + auto const& trie = index->range_index.at(a_filter.field_name); - for (size_t fi = 0; fi < a_filter.values.size(); fi++) { - const std::string& filter_value = a_filter.values[fi]; - int64_t value = (int64_t)std::stol(filter_value); + for (size_t fi = 0; fi < a_filter.values.size(); fi++) { + const std::string& filter_value = a_filter.values[fi]; + /* parse as 64-bit: the trie API takes int64_t, and std::stoi throws std::out_of_range for values beyond int */ const int64_t value = std::stoll(filter_value); - size_t
result_size = filter_result.count; - if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) { - const std::string& next_filter_value = a_filter.values[fi + 1]; - auto const range_end_value = (int64_t)std::stol(next_filter_value); - num_tree->range_inclusive_search(value, range_end_value, &filter_result.docs, result_size); - fi++; - } else if (a_filter.comparators[fi] == NOT_EQUALS) { - numeric_not_equals_filter(num_tree, value, - index->seq_ids->uncompress(), index->seq_ids->num_ids(), - filter_result.docs, result_size); - } else { - num_tree->search(a_filter.comparators[fi], value, &filter_result.docs, result_size); + if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) { + const std::string& next_filter_value = a_filter.values[fi + 1]; + /* range end is parsed as 64-bit for the same reason as the range start */ const int64_t range_end_value = std::stoll(next_filter_value); + trie->search_range(value, true, range_end_value, true, filter_result.docs, filter_result.count); + fi++; + } else if (a_filter.comparators[fi] == EQUALS) { + trie->search_equal_to(value, filter_result.docs, filter_result.count); + } else if (a_filter.comparators[fi] == NOT_EQUALS) { + uint32_t* to_exclude_ids = nullptr; + uint32_t to_exclude_ids_len = 0; + trie->search_equal_to(value, to_exclude_ids, to_exclude_ids_len); + + auto all_ids = index->seq_ids->uncompress(); + filter_result.count = ArrayUtils::exclude_scalar(all_ids, index->seq_ids->num_ids(), + to_exclude_ids, to_exclude_ids_len, &filter_result.docs); + + delete[] all_ids; + delete[] to_exclude_ids; + } else if (a_filter.comparators[fi] == GREATER_THAN || a_filter.comparators[fi] == GREATER_THAN_EQUALS) { + trie->search_greater_than(value, a_filter.comparators[fi] == GREATER_THAN_EQUALS, + filter_result.docs, filter_result.count); + } else if (a_filter.comparators[fi] == LESS_THAN || a_filter.comparators[fi] == LESS_THAN_EQUALS) { + trie->search_less_than(value, a_filter.comparators[fi] == LESS_THAN_EQUALS, + filter_result.docs,
filter_result.count); + } } + } else { + auto num_tree = index->numerical_index.at(a_filter.field_name); - filter_result.count = result_size; + for (size_t fi = 0; fi < a_filter.values.size(); fi++) { + const std::string& filter_value = a_filter.values[fi]; + int64_t value = (int64_t)std::stol(filter_value); + + size_t result_size = filter_result.count; + if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) { + const std::string& next_filter_value = a_filter.values[fi + 1]; + auto const range_end_value = (int64_t)std::stol(next_filter_value); + num_tree->range_inclusive_search(value, range_end_value, &filter_result.docs, result_size); + fi++; + } else if (a_filter.comparators[fi] == NOT_EQUALS) { + numeric_not_equals_filter(num_tree, value, + index->seq_ids->uncompress(), index->seq_ids->num_ids(), + filter_result.docs, result_size); + } else { + num_tree->search(a_filter.comparators[fi], value, &filter_result.docs, result_size); + } + + filter_result.count = result_size; + } } if (a_filter.apply_not_equals) { @@ -688,28 +729,64 @@ void filter_result_iterator_t::init() { approx_filter_ids_length = filter_result.count; return; } else if (f.is_float()) { - auto num_tree = index->numerical_index.at(a_filter.field_name); + if (f.range_index) { + auto const& trie = index->range_index.at(a_filter.field_name); - for (size_t fi = 0; fi < a_filter.values.size(); fi++) { - const std::string& filter_value = a_filter.values[fi]; - float value = (float)std::atof(filter_value.c_str()); - int64_t float_int64 = Index::float_to_int64_t(value); + for (size_t fi = 0; fi < a_filter.values.size(); fi++) { + const std::string& filter_value = a_filter.values[fi]; + float value = (float)std::atof(filter_value.c_str()); + int64_t float_int64 = Index::float_to_int64_t(value); - size_t result_size = filter_result.count; - if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) { - const std::string& next_filter_value = 
a_filter.values[fi+1]; - int64_t range_end_value = Index::float_to_int64_t((float) std::atof(next_filter_value.c_str())); - num_tree->range_inclusive_search(float_int64, range_end_value, &filter_result.docs, result_size); - fi++; - } else if (a_filter.comparators[fi] == NOT_EQUALS) { - numeric_not_equals_filter(num_tree, float_int64, - index->seq_ids->uncompress(), index->seq_ids->num_ids(), - filter_result.docs, result_size); - } else { - num_tree->search(a_filter.comparators[fi], float_int64, &filter_result.docs, result_size); + if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) { + const std::string& next_filter_value = a_filter.values[fi + 1]; + int64_t range_end_value = Index::float_to_int64_t((float) std::atof(next_filter_value.c_str())); + trie->search_range(float_int64, true, range_end_value, true, filter_result.docs, filter_result.count); + fi++; + } else if (a_filter.comparators[fi] == EQUALS) { + trie->search_equal_to(float_int64, filter_result.docs, filter_result.count); + } else if (a_filter.comparators[fi] == NOT_EQUALS) { + uint32_t* to_exclude_ids = nullptr; + uint32_t to_exclude_ids_len = 0; + trie->search_equal_to(float_int64, to_exclude_ids, to_exclude_ids_len); + + auto all_ids = index->seq_ids->uncompress(); + filter_result.count = ArrayUtils::exclude_scalar(all_ids, index->seq_ids->num_ids(), + to_exclude_ids, to_exclude_ids_len, &filter_result.docs); + + delete[] all_ids; + delete[] to_exclude_ids; + } else if (a_filter.comparators[fi] == GREATER_THAN || a_filter.comparators[fi] == GREATER_THAN_EQUALS) { + trie->search_greater_than(float_int64, a_filter.comparators[fi] == GREATER_THAN_EQUALS, + filter_result.docs, filter_result.count); + } else if (a_filter.comparators[fi] == LESS_THAN || a_filter.comparators[fi] == LESS_THAN_EQUALS) { + trie->search_less_than(float_int64, a_filter.comparators[fi] == LESS_THAN_EQUALS, + filter_result.docs, filter_result.count); + } } + } else { + auto num_tree = 
index->numerical_index.at(a_filter.field_name); - filter_result.count = result_size; + for (size_t fi = 0; fi < a_filter.values.size(); fi++) { + const std::string& filter_value = a_filter.values[fi]; + float value = (float)std::atof(filter_value.c_str()); + int64_t float_int64 = Index::float_to_int64_t(value); + + size_t result_size = filter_result.count; + if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) { + const std::string& next_filter_value = a_filter.values[fi+1]; + int64_t range_end_value = Index::float_to_int64_t((float) std::atof(next_filter_value.c_str())); + num_tree->range_inclusive_search(float_int64, range_end_value, &filter_result.docs, result_size); + fi++; + } else if (a_filter.comparators[fi] == NOT_EQUALS) { + numeric_not_equals_filter(num_tree, float_int64, + index->seq_ids->uncompress(), index->seq_ids->num_ids(), + filter_result.docs, result_size); + } else { + num_tree->search(a_filter.comparators[fi], float_int64, &filter_result.docs, result_size); + } + + filter_result.count = result_size; + } } if (a_filter.apply_not_equals) { @@ -821,17 +898,15 @@ void filter_result_iterator_t::init() { S2RegionTermIndexer::Options options; options.set_index_contains_points_only(true); S2RegionTermIndexer indexer(options); + auto const& geo_range_index = index->geo_range_index.at(a_filter.field_name); + std::vector cell_ids; for (const auto& term : indexer.GetQueryTerms(*query_region, "")) { - auto geo_index = index->geopoint_index.at(a_filter.field_name); - const auto& ids_it = geo_index->find(term); - if(ids_it != geo_index->end()) { - geo_result_ids.insert(geo_result_ids.end(), ids_it->second.begin(), ids_it->second.end()); - } + auto cell = S2CellId::FromToken(term); + cell_ids.push_back(cell.id()); } - gfx::timsort(geo_result_ids.begin(), geo_result_ids.end()); - geo_result_ids.erase(std::unique( geo_result_ids.begin(), geo_result_ids.end() ), geo_result_ids.end()); + geo_range_index->search_geopoints(cell_ids, 
geo_result_ids); // Skip exact filtering step if query radius is greater than the threshold. if (fi < a_filter.params.size() && @@ -955,20 +1030,7 @@ void filter_result_iterator_t::skip_to(uint32_t id) { return; } - if (filter_node->isOperator) { - // Skip the subtrees to id and then apply operators to arrive at the next valid doc. - left_it->skip_to(id); - right_it->skip_to(id); - - if (filter_node->filter_operator == AND) { - and_filter_iterators(); - } else { - or_filter_iterators(); - } - - return; - } - + // No need to traverse iterator tree if there's only one filter or compute_result() has been called. if (is_filter_result_initialized) { ArrayUtils::skip_index_to_id(result_index, filter_result.docs, filter_result.count, id); @@ -986,6 +1048,20 @@ void filter_result_iterator_t::skip_to(uint32_t id) { return; } + if (filter_node->isOperator) { + // Skip the subtrees to id and then apply operators to arrive at the next valid doc. + left_it->skip_to(id); + right_it->skip_to(id); + + if (filter_node->filter_operator == AND) { + and_filter_iterators(); + } else { + or_filter_iterators(); + } + + return; + } + const filter a_filter = filter_node->filter_exp; if (!index->field_is_indexed(a_filter.field_name)) { @@ -1068,6 +1144,12 @@ int filter_result_iterator_t::valid(uint32_t id) { return -1; } + // No need to traverse iterator tree if there's only one filter or compute_result() has been called. + if (is_filter_result_initialized) { + skip_to(id); + return is_valid ? (seq_id == id ? 1 : 0) : -1; + } + if (filter_node->isOperator) { auto left_valid = left_it->valid(id), right_valid = right_it->valid(id); @@ -1181,21 +1263,7 @@ void filter_result_iterator_t::reset() { return; } - if (filter_node->isOperator) { - // Reset the subtrees then apply operators to arrive at the first valid doc. 
- left_it->reset(); - right_it->reset(); - is_valid = true; - - if (filter_node->filter_operator == AND) { - and_filter_iterators(); - } else { - or_filter_iterators(); - } - - return; - } - + // No need to traverse iterator tree if there's only one filter or compute_result() has been called. if (is_filter_result_initialized) { if (filter_result.count == 0) { is_valid = false; @@ -1214,6 +1282,21 @@ void filter_result_iterator_t::reset() { return; } + if (filter_node->isOperator) { + // Reset the subtrees then apply operators to arrive at the first valid doc. + left_it->reset(); + right_it->reset(); + is_valid = true; + + if (filter_node->filter_operator == AND) { + and_filter_iterators(); + } else { + or_filter_iterators(); + } + + return; + } + const filter a_filter = filter_node->filter_exp; if (!index->field_is_indexed(a_filter.field_name)) { @@ -1459,3 +1542,136 @@ void filter_result_iterator_t::add_phrase_ids(filter_result_iterator_t*& filter_ root_iterator->seq_id = left_it->seq_id; filter_result_iterator = root_iterator; } + +void filter_result_iterator_t::compute_result() { + if (filter_node->isOperator) { + left_it->compute_result(); + right_it->compute_result(); + + if (filter_node->filter_operator == AND) { + filter_result_t::and_filter_results(left_it->filter_result, right_it->filter_result, filter_result); + } else { + filter_result_t::or_filter_results(left_it->filter_result, right_it->filter_result, filter_result); + } + + seq_id = filter_result.docs[result_index]; + is_filter_result_initialized = true; + approx_filter_ids_length = filter_result.count; + return; + } + + // Only string field filter needs to be evaluated. 
+ if (is_filter_result_initialized || index->search_index.count(filter_node->filter_exp.field_name) == 0) { + return; + } + + auto const& a_filter = filter_node->filter_exp; + auto const& f = index->search_schema.at(a_filter.field_name); + art_tree* t = index->search_index.at(a_filter.field_name); + + uint32_t* or_ids = nullptr; + size_t or_ids_size = 0; + + // aggregates IDs across array of filter values and reduces excessive ORing + std::vector f_id_buff; + + for (const std::string& filter_value : a_filter.values) { + std::vector posting_lists; + + // there could be multiple tokens in a filter value, which we have to treat as ANDs + // e.g. country: South Africa + Tokenizer tokenizer(filter_value, true, false, f.locale, index->symbols_to_index, index->token_separators); + + std::string str_token; + size_t token_index = 0; + std::vector str_tokens; + + while (tokenizer.next(str_token, token_index)) { + str_tokens.push_back(str_token); + + art_leaf* leaf = (art_leaf *) art_search(t, (const unsigned char*) str_token.c_str(), + str_token.length()+1); + if (leaf == nullptr) { + continue; + } + + posting_lists.push_back(leaf->values); + } + + if (posting_lists.size() != str_tokens.size()) { + continue; + } + + if(a_filter.comparators[0] == EQUALS || a_filter.comparators[0] == NOT_EQUALS) { + // needs intersection + exact matching (unlike CONTAINS) + std::vector result_id_vec; + posting_t::intersect(posting_lists, result_id_vec); + + if (result_id_vec.empty()) { + continue; + } + + // need to do exact match + uint32_t* exact_str_ids = new uint32_t[result_id_vec.size()]; + size_t exact_str_ids_size = 0; + std::unique_ptr exact_str_ids_guard(exact_str_ids); + + posting_t::get_exact_matches(posting_lists, f.is_array(), result_id_vec.data(), result_id_vec.size(), + exact_str_ids, exact_str_ids_size); + + if (exact_str_ids_size == 0) { + continue; + } + + for (size_t ei = 0; ei < exact_str_ids_size; ei++) { + f_id_buff.push_back(exact_str_ids[ei]); + } + } else { + // 
CONTAINS + size_t before_size = f_id_buff.size(); + posting_t::intersect(posting_lists, f_id_buff); + if (f_id_buff.size() == before_size) { + continue; + } + } + + if (f_id_buff.size() > 100000 || a_filter.values.size() == 1) { + gfx::timsort(f_id_buff.begin(), f_id_buff.end()); + f_id_buff.erase(std::unique( f_id_buff.begin(), f_id_buff.end() ), f_id_buff.end()); + + uint32_t* out = nullptr; + or_ids_size = ArrayUtils::or_scalar(or_ids, or_ids_size, f_id_buff.data(), f_id_buff.size(), &out); + delete[] or_ids; + or_ids = out; + std::vector().swap(f_id_buff); // clears out memory + } + } + + if (!f_id_buff.empty()) { + gfx::timsort(f_id_buff.begin(), f_id_buff.end()); + f_id_buff.erase(std::unique( f_id_buff.begin(), f_id_buff.end() ), f_id_buff.end()); + + uint32_t* out = nullptr; + or_ids_size = ArrayUtils::or_scalar(or_ids, or_ids_size, f_id_buff.data(), f_id_buff.size(), &out); + delete[] or_ids; + or_ids = out; + std::vector().swap(f_id_buff); // clears out memory + } + + filter_result.docs = or_ids; + filter_result.count = or_ids_size; + + if (a_filter.apply_not_equals) { + apply_not_equals(index->seq_ids->uncompress(), index->seq_ids->num_ids(), filter_result.docs, filter_result.count); + } + + if (filter_result.count == 0) { + is_valid = false; + return; + } + + result_index = 0; + seq_id = filter_result.docs[result_index]; + is_filter_result_initialized = true; + approx_filter_ids_length = filter_result.count; +} diff --git a/src/http_proxy.cpp b/src/http_proxy.cpp new file mode 100644 index 00000000..b2d2c684 --- /dev/null +++ b/src/http_proxy.cpp @@ -0,0 +1,46 @@ +#include "http_proxy.h" +#include "logger.h" +#include + +using namespace std::chrono_literals; + +HttpProxy::HttpProxy() : cache(30s){ +} + + +http_proxy_res_t HttpProxy::send(const std::string& url, const std::string& method, const std::string& body, const std::unordered_map& headers) { + // check if url is in cache + uint64_t key = StringUtils::hash_wy(url.c_str(), url.size()); + key = 
StringUtils::hash_combine(key, StringUtils::hash_wy(method.c_str(), method.size())); + key = StringUtils::hash_combine(key, StringUtils::hash_wy(body.c_str(), body.size())); + for (auto& header : headers) { + key = StringUtils::hash_combine(key, StringUtils::hash_wy(header.first.c_str(), header.first.size())); + key = StringUtils::hash_combine(key, StringUtils::hash_wy(header.second.c_str(), header.second.size())); + } + if (cache.contains(key)) { + return cache[key]; + } + // if not, make http request + HttpClient& client = HttpClient::get_instance(); + http_proxy_res_t res; + + if(method == "GET") { + res.status_code = client.get_response(url, res.body, res.headers, headers, 30 * 1000); + } else if(method == "POST") { + res.status_code = client.post_response(url, body, res.body, res.headers, headers, 30 * 1000); + } else if(method == "PUT") { + res.status_code = client.put_response(url, body, res.body, res.headers, 30 * 1000); + } else if(method == "DELETE") { + res.status_code = client.delete_response(url, res.body, res.headers, 30 * 1000); + } else { + res.status_code = 400; + nlohmann::json j; + j["message"] = "Parameter `method` must be one of GET, POST, PUT, DELETE."; + res.body = j.dump(); + } + + // add to cache + cache.insert(key, res); + + return res; +} \ No newline at end of file diff --git a/src/http_server.cpp b/src/http_server.cpp index b71a8a24..6a704153 100644 --- a/src/http_server.cpp +++ b/src/http_server.cpp @@ -632,6 +632,9 @@ int HttpServer::async_req_cb(void *ctx, int is_end_stream) { if(request->first_chunk_aggregate) { request->first_chunk_aggregate = false; + + // ensures that the first response need not wait for previous chunk to be done sending + response->notify(); } // default value for last_chunk_aggregate is false @@ -803,6 +806,13 @@ void HttpServer::stream_response(stream_response_state_t& state) { h2o_req_t* req = state.get_req(); + if(state.is_res_start) { + h2o_add_header(&req->pool, &req->res.headers, H2O_TOKEN_CONTENT_TYPE, 
NULL, + state.res_content_type.base, state.res_content_type.len); + req->res.status = state.status; + req->res.reason = state.reason; + } + if(state.is_req_early_exit) { // premature termination of async request: handle this explicitly as otherwise, request is not being closed LOG(INFO) << "Premature termination of async request."; diff --git a/src/index.cpp b/src/index.cpp index 58f29736..f2a98987 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -69,8 +69,7 @@ Index::Index(const std::string& name, const uint32_t collection_id, const Store* art_tree_init(t); search_index.emplace(a_field.name, t); } else if(a_field.is_geopoint()) { - auto field_geo_index = new spp::sparse_hash_map>(); - geopoint_index.emplace(a_field.name, field_geo_index); + geo_range_index.emplace(a_field.name, new NumericTrie()); if(!a_field.is_single_geopoint()) { spp::sparse_hash_map * doc_to_geos = new spp::sparse_hash_map(); @@ -79,6 +78,11 @@ Index::Index(const std::string& name, const uint32_t collection_id, const Store* } else { num_tree_t* num_tree = new num_tree_t; numerical_index.emplace(a_field.name, num_tree); + + if (a_field.range_index) { + auto trie = a_field.is_int32() ? 
new NumericTrie() : new NumericTrie(64); + range_index.emplace(a_field.name, trie); + } } if(a_field.sort) { @@ -127,12 +131,12 @@ Index::~Index() { search_index.clear(); - for(auto & name_index: geopoint_index) { + for(auto & name_index: geo_range_index) { delete name_index.second; name_index.second = nullptr; } - geopoint_index.clear(); + geo_range_index.clear(); for(auto& name_index: geo_array_index) { for(auto& kv: *name_index.second) { @@ -152,6 +156,13 @@ Index::~Index() { numerical_index.clear(); + for(auto & name_tree: range_index) { + delete name_tree.second; + name_tree.second = nullptr; + } + + range_index.clear(); + for(auto & name_map: sort_index) { delete name_map.second; name_map.second = nullptr; @@ -484,7 +495,7 @@ void Index::validate_and_preprocess(Index *index, index_rec.index_failure(400, e.what()); } } - + auto embed_op = batch_embed_fields(docs_to_embed, embedding_fields, search_schema); if(!embed_op.ok()) { for(size_t i = 0; i < batch_size; i++) { @@ -783,6 +794,15 @@ void Index::index_field_in_memory(const field& afield, std::vector if(!afield.is_string()) { if (afield.type == field_types::INT32) { + if (afield.range_index) { + auto const& trie = range_index.at(afield.name); + iterate_and_index_numerical_field(iter_batch, afield, [&afield, trie] + (const index_record& record, uint32_t seq_id) { + int32_t value = record.doc[afield.name].get(); + trie->insert(value, seq_id); + }); + } + auto num_tree = numerical_index.at(afield.name); iterate_and_index_numerical_field(iter_batch, afield, [&afield, num_tree] (const index_record& record, uint32_t seq_id) { @@ -792,6 +812,15 @@ void Index::index_field_in_memory(const field& afield, std::vector } else if(afield.type == field_types::INT64) { + if (afield.range_index) { + auto const& trie = range_index.at(afield.name); + iterate_and_index_numerical_field(iter_batch, afield, [&afield, trie] + (const index_record& record, uint32_t seq_id) { + int64_t value = record.doc[afield.name].get(); + 
trie->insert(value, seq_id); + }); + } + auto num_tree = numerical_index.at(afield.name); iterate_and_index_numerical_field(iter_batch, afield, [&afield, num_tree] (const index_record& record, uint32_t seq_id) { @@ -801,6 +830,16 @@ void Index::index_field_in_memory(const field& afield, std::vector } else if(afield.type == field_types::FLOAT) { + if (afield.range_index) { + auto const& trie = range_index.at(afield.name); + iterate_and_index_numerical_field(iter_batch, afield, [&afield, trie] + (const index_record& record, uint32_t seq_id) { + float fvalue = record.doc[afield.name].get(); + int64_t value = float_to_int64_t(fvalue); + trie->insert(value, seq_id); + }); + } + auto num_tree = numerical_index.at(afield.name); iterate_and_index_numerical_field(iter_batch, afield, [&afield, num_tree] (const index_record& record, uint32_t seq_id) { @@ -816,10 +855,10 @@ void Index::index_field_in_memory(const field& afield, std::vector num_tree->insert(value, seq_id, afield.is_facet()); }); } else if(afield.type == field_types::GEOPOINT || afield.type == field_types::GEOPOINT_ARRAY) { - auto geo_index = geopoint_index.at(afield.name); + auto geopoint_range_index = geo_range_index.at(afield.name); iterate_and_index_numerical_field(iter_batch, afield, - [&afield, &geo_array_index=geo_array_index, geo_index](const index_record& record, uint32_t seq_id) { + [&afield, &geo_array_index=geo_array_index, geopoint_range_index](const index_record& record, uint32_t seq_id) { // nested geopoint value inside an array of object will be a simple array so must be treated as geopoint bool nested_obj_arr_geopoint = (afield.nested && afield.type == field_types::GEOPOINT_ARRAY && !record.doc[afield.name].empty() && record.doc[afield.name][0].is_number()); @@ -833,9 +872,8 @@ void Index::index_field_in_memory(const field& afield, std::vector S2RegionTermIndexer indexer(options); S2Point point = S2LatLng::FromDegrees(latlongs[li], latlongs[li+1]).ToPoint(); - for(const auto& term: 
indexer.GetIndexTerms(point, "")) { - (*geo_index)[term].push_back(seq_id); - } + auto cell = S2CellId(point); + geopoint_range_index->insert_geopoint(cell.id(), seq_id); } if(nested_obj_arr_geopoint) { @@ -863,9 +901,9 @@ void Index::index_field_in_memory(const field& afield, std::vector for(size_t li = 0; li < latlongs.size(); li++) { auto& latlong = latlongs[li]; S2Point point = S2LatLng::FromDegrees(latlong[0], latlong[1]).ToPoint(); - for(const auto& term: indexer.GetIndexTerms(point, "")) { - (*geo_index)[term].push_back(seq_id); - } + + auto cell = S2CellId(point); + geopoint_range_index->insert_geopoint(cell.id(), seq_id); int64_t packed_latlong = GeoPoint::pack_lat_lng(latlong[0], latlong[1]); packed_latlongs[li + 1] = packed_latlong; @@ -945,7 +983,8 @@ void Index::index_field_in_memory(const field& afield, std::vector // all other numerical arrays auto num_tree = numerical_index.at(afield.name); - iterate_and_index_numerical_field(iter_batch, afield, [&afield, num_tree] + auto trie = range_index.count(afield.name) > 0 ? 
range_index.at(afield.name) : nullptr; + iterate_and_index_numerical_field(iter_batch, afield, [&afield, num_tree, trie] (const index_record& record, uint32_t seq_id) { for(size_t arr_i = 0; arr_i < record.doc[afield.name].size(); arr_i++) { const auto& arr_value = record.doc[afield.name][arr_i]; @@ -953,17 +992,26 @@ void Index::index_field_in_memory(const field& afield, std::vector if(afield.type == field_types::INT32_ARRAY) { const int32_t value = arr_value; num_tree->insert(value, seq_id, afield.is_facet()); + if (afield.range_index) { + trie->insert(value, seq_id); + } } else if(afield.type == field_types::INT64_ARRAY) { const int64_t value = arr_value; num_tree->insert(value, seq_id, afield.is_facet()); + if (afield.range_index) { + trie->insert(value, seq_id); + } } else if(afield.type == field_types::FLOAT_ARRAY) { const float fvalue = arr_value; int64_t value = float_to_int64_t(fvalue); num_tree->insert(value, seq_id, afield.is_facet()); + if (afield.range_index) { + trie->insert(value, seq_id); + } } else if(afield.type == field_types::BOOL_ARRAY) { @@ -1628,7 +1676,7 @@ void Index::numeric_not_equals_filter(num_tree_t* const num_tree, bool Index::field_is_indexed(const std::string& field_name) const { return search_index.count(field_name) != 0 || numerical_index.count(field_name) != 0 || - geopoint_index.count(field_name) != 0; + geo_range_index.count(field_name) != 0; } void Index::aproximate_numerical_match(num_tree_t* const num_tree, @@ -4597,7 +4645,9 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root, std::array*, 3>& field_values, const std::vector& geopoint_indices) const { + filter_result_iterator->compute_result(); auto const& approx_filter_ids_length = filter_result_iterator->approx_filter_ids_length; + uint32_t token_bits = 0; const bool check_for_circuit_break = (approx_filter_ids_length > 1000000); @@ -5459,6 +5509,11 @@ void Index::remove_field(uint32_t seq_id, const nlohmann::json& document, const 
std::vector{document[field_name].get()} : document[field_name].get>(); for(int32_t value: values) { + if (search_field.range_index) { + auto const& trie = range_index.at(search_field.name); + trie->remove(value, seq_id); + } + num_tree_t* num_tree = numerical_index.at(field_name); num_tree->remove(value, seq_id); if(search_field.facet) { @@ -5470,6 +5525,11 @@ void Index::remove_field(uint32_t seq_id, const nlohmann::json& document, const std::vector{document[field_name].get()} : document[field_name].get>(); for(int64_t value: values) { + if (search_field.range_index) { + auto const& trie = range_index.at(search_field.name); + trie->remove(value, seq_id); + } + num_tree_t* num_tree = numerical_index.at(field_name); num_tree->remove(value, seq_id); if(search_field.facet) { @@ -5484,8 +5544,14 @@ void Index::remove_field(uint32_t seq_id, const nlohmann::json& document, const document[field_name].get>(); for(float value: values) { - num_tree_t* num_tree = numerical_index.at(field_name); int64_t fintval = float_to_int64_t(value); + + if (search_field.range_index) { + auto const& trie = range_index.at(search_field.name); + trie->remove(fintval, seq_id); + } + + num_tree_t* num_tree = numerical_index.at(field_name); num_tree->remove(fintval, seq_id); if(search_field.facet) { remove_facet_token(search_field, search_index, StringUtils::float_to_str(value), seq_id); @@ -5505,7 +5571,7 @@ void Index::remove_field(uint32_t seq_id, const nlohmann::json& document, const } } } else if(search_field.is_geopoint()) { - auto geo_index = geopoint_index[field_name]; + auto geopoint_range_index = geo_range_index[field_name]; S2RegionTermIndexer::Options options; options.set_index_contains_points_only(true); S2RegionTermIndexer indexer(options); @@ -5516,17 +5582,8 @@ void Index::remove_field(uint32_t seq_id, const nlohmann::json& document, const for(const std::vector& latlong: latlongs) { S2Point point = S2LatLng::FromDegrees(latlong[0], latlong[1]).ToPoint(); - for(const auto& term: 
indexer.GetIndexTerms(point, "")) { - auto term_it = geo_index->find(term); - if(term_it == geo_index->end()) { - continue; - } - std::vector& ids = term_it->second; - ids.erase(std::remove(ids.begin(), ids.end(), seq_id), ids.end()); - if(ids.empty()) { - geo_index->erase(term); - } - } + auto cell = S2CellId(point); + geopoint_range_index->delete_geopoint(cell.id(), seq_id); } if(!search_field.is_single_geopoint()) { @@ -5664,8 +5721,7 @@ void Index::refresh_schemas(const std::vector& new_fields, const std::vec art_tree_init(t); search_index.emplace(new_field.name, t); } else if(new_field.is_geopoint()) { - auto field_geo_index = new spp::sparse_hash_map>(); - geopoint_index.emplace(new_field.name, field_geo_index); + geo_range_index.emplace(new_field.name, new NumericTrie()); if(!new_field.is_single_geopoint()) { auto geo_array_map = new spp::sparse_hash_map(); geo_array_index.emplace(new_field.name, geo_array_map); @@ -5673,6 +5729,10 @@ void Index::refresh_schemas(const std::vector& new_fields, const std::vec } else { num_tree_t* num_tree = new num_tree_t; numerical_index.emplace(new_field.name, num_tree); + + if (new_field.range_index) { + range_index.emplace(new_field.name, new NumericTrie(new_field.is_int32() ? 
32 : 64)); + } } } @@ -5715,8 +5775,8 @@ void Index::refresh_schemas(const std::vector& new_fields, const std::vec delete search_index[del_field.name]; search_index.erase(del_field.name); } else if(del_field.is_geopoint()) { - delete geopoint_index[del_field.name]; - geopoint_index.erase(del_field.name); + delete geo_range_index[del_field.name]; + geo_range_index.erase(del_field.name); if(!del_field.is_single_geopoint()) { spp::sparse_hash_map* geo_array_map = geo_array_index[del_field.name]; @@ -5729,6 +5789,11 @@ void Index::refresh_schemas(const std::vector& new_fields, const std::vec } else { delete numerical_index[del_field.name]; numerical_index.erase(del_field.name); + + if (del_field.range_index) { + delete range_index[del_field.name]; + range_index.erase(del_field.name); + } } if(del_field.is_sortable()) { @@ -6105,8 +6170,15 @@ Option Index::batch_embed_fields(std::vector& documents, } } } - texts_to_embed.push_back(std::make_pair(document, text)); + if(!text.empty()) { + texts_to_embed.push_back(std::make_pair(document, text)); + } } + + if(texts_to_embed.empty()) { + continue; + } + TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance(); auto embedder_op = embedder_manager.get_text_embedder(field.embed[fields::model_config]); @@ -6128,7 +6200,6 @@ Option Index::batch_embed_fields(std::vector& documents, } auto embedding_op = embedder_op.get()->batch_embed(texts); - if(!embedding_op.ok()) { return Option(400, embedding_op.error()); } diff --git a/src/main/typesense_server.cpp b/src/main/typesense_server.cpp index 30b7281d..d4df4f2a 100644 --- a/src/main/typesense_server.cpp +++ b/src/main/typesense_server.cpp @@ -98,6 +98,9 @@ void master_server_routes() { server->del("/limits/active/:id", del_throttle); server->del("/limits/exceeds/:id", del_exceed); server->post("/config", post_config, false, false); + + // for proxying remote embedders + server->post("/proxy", post_proxy); } void (*backward::SignalHandling::_callback)(int sig, 
backward::StackTrace&) = nullptr; @@ -144,6 +147,7 @@ int main(int argc, char **argv) { Option config_validitation = config.is_valid(); if(!config_validitation.ok()) { + std::cerr << "Typesense " << TYPESENSE_VERSION << std::endl; std::cerr << "Invalid configuration: " << config_validitation.error() << std::endl; std::cerr << "Command line " << options.usage() << std::endl; std::cerr << "You can also pass these arguments as environment variables such as " diff --git a/src/numeric_range_trie.cpp b/src/numeric_range_trie.cpp new file mode 100644 index 00000000..f70de113 --- /dev/null +++ b/src/numeric_range_trie.cpp @@ -0,0 +1,908 @@ +#include +#include +#include "numeric_range_trie_test.h" +#include "array_utils.h" + +void NumericTrie::insert(const int64_t& value, const uint32_t& seq_id) { + if (value < 0) { + if (negative_trie == nullptr) { + negative_trie = new NumericTrie::Node(); + } + + negative_trie->insert(std::abs(value), seq_id, max_level); + } else { + if (positive_trie == nullptr) { + positive_trie = new NumericTrie::Node(); + } + + positive_trie->insert(value, seq_id, max_level); + } +} + +void NumericTrie::remove(const int64_t& value, const uint32_t& seq_id) { + if ((value < 0 && negative_trie == nullptr) || (value >= 0 && positive_trie == nullptr)) { + return; + } + + if (value < 0) { + negative_trie->remove(std::abs(value), seq_id, max_level); + } else { + positive_trie->remove(value, seq_id, max_level); + } +} + +void NumericTrie::insert_geopoint(const uint64_t& cell_id, const uint32_t& seq_id) { + if (positive_trie == nullptr) { + positive_trie = new NumericTrie::Node(); + } + + positive_trie->insert_geopoint(cell_id, seq_id, max_level); +} + +void NumericTrie::search_geopoints(const std::vector& cell_ids, std::vector& geo_result_ids) { + if (positive_trie == nullptr) { + return; + } + + positive_trie->search_geopoints(cell_ids, max_level, geo_result_ids); +} + +void NumericTrie::delete_geopoint(const uint64_t& cell_id, uint32_t id) { + if 
(positive_trie == nullptr) { + return; + } + + positive_trie->delete_geopoint(cell_id, id, max_level); +} + +void NumericTrie::search_range(const int64_t& low, const bool& low_inclusive, + const int64_t& high, const bool& high_inclusive, + uint32_t*& ids, uint32_t& ids_length) { + if (low > high) { + return; + } + + if (low < 0 && high >= 0) { + // Have to combine the results of >low from negative_trie and low from negative_trie. + negative_trie->search_less_than(low_inclusive ? abs_low : abs_low - 1, max_level, + negative_ids, negative_ids_length); + + uint32_t* out = nullptr; + ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, ids, ids_length, &out); + + delete [] negative_ids; + delete [] ids; + ids = out; + } + + if (positive_trie != nullptr && !(high == 0 && !high_inclusive)) { // No need to search for ..., 0) + uint32_t* positive_ids = nullptr; + uint32_t positive_ids_length = 0; + positive_trie->search_less_than(high_inclusive ? high : high - 1, max_level, + positive_ids, positive_ids_length); + + uint32_t* out = nullptr; + ids_length = ArrayUtils::or_scalar(positive_ids, positive_ids_length, ids, ids_length, &out); + + delete [] positive_ids; + delete [] ids; + ids = out; + } + } else if (low >= 0) { + // Search only in positive_trie + if (positive_trie == nullptr) { + return; + } + + uint32_t* positive_ids = nullptr; + uint32_t positive_ids_length = 0; + positive_trie->search_range(low_inclusive ? low : low + 1, high_inclusive ? high : high - 1, max_level, + positive_ids, positive_ids_length); + + uint32_t* out = nullptr; + ids_length = ArrayUtils::or_scalar(positive_ids, positive_ids_length, ids, ids_length, &out); + + delete [] positive_ids; + delete [] ids; + ids = out; + } else { + // Search only in negative_trie + if (negative_trie == nullptr) { + return; + } + + uint32_t* negative_ids = nullptr; + uint32_t negative_ids_length = 0; + // Since we store absolute values, switching low and high would produce the correct result. 
+ auto abs_high = std::abs(high), abs_low = std::abs(low); + negative_trie->search_range(high_inclusive ? abs_high : abs_high + 1, low_inclusive ? abs_low : abs_low - 1, + max_level, + negative_ids, negative_ids_length); + + uint32_t* out = nullptr; + ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, ids, ids_length, &out); + + delete [] negative_ids; + delete [] ids; + ids = out; + } +} + +NumericTrie::iterator_t NumericTrie::search_range(const int64_t& low, const bool& low_inclusive, + const int64_t& high, const bool& high_inclusive) { + std::vector matches; + if (low > high) { + return NumericTrie::iterator_t(matches); + } + + if (low < 0 && high >= 0) { + // Have to combine the results of >low from negative_trie and low from negative_trie. + negative_trie->search_less_than(low_inclusive ? abs_low : abs_low - 1, max_level, matches); + } + + if (positive_trie != nullptr && !(high == 0 && !high_inclusive)) { // No need to search for ..., 0) + positive_trie->search_less_than(high_inclusive ? high : high - 1, max_level, matches); + } + } else if (low >= 0) { + // Search only in positive_trie + if (positive_trie == nullptr) { + return NumericTrie::iterator_t(matches); + } + + positive_trie->search_range(low_inclusive ? low : low + 1, high_inclusive ? high : high - 1, max_level, matches); + } else { + // Search only in negative_trie + if (negative_trie == nullptr) { + return NumericTrie::iterator_t(matches); + } + + auto abs_high = std::abs(high), abs_low = std::abs(low); + // Since we store absolute values, switching low and high would produce the correct result. + negative_trie->search_range(high_inclusive ? abs_high : abs_high + 1, low_inclusive ? 
abs_low : abs_low - 1, + max_level, matches); + } + + return NumericTrie::iterator_t(matches); +} + +void NumericTrie::search_greater_than(const int64_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length) { + if ((value == 0 && inclusive) || (value == -1 && !inclusive)) { // [0, ∞), (-1, ∞) + if (positive_trie != nullptr) { + uint32_t* positive_ids = nullptr; + uint32_t positive_ids_length = 0; + positive_trie->get_all_ids(positive_ids, positive_ids_length); + + uint32_t* out = nullptr; + ids_length = ArrayUtils::or_scalar(positive_ids, positive_ids_length, ids, ids_length, &out); + + delete [] positive_ids; + delete [] ids; + ids = out; + } + return; + } + + if (value >= 0) { + if (positive_trie == nullptr) { + return; + } + + uint32_t* positive_ids = nullptr; + uint32_t positive_ids_length = 0; + positive_trie->search_greater_than(inclusive ? value : value + 1, max_level, positive_ids, positive_ids_length); + + uint32_t* out = nullptr; + ids_length = ArrayUtils::or_scalar(positive_ids, positive_ids_length, ids, ids_length, &out); + + delete [] positive_ids; + delete [] ids; + ids = out; + } else { + // Have to combine the results of >value from negative_trie and all the ids in positive_trie + + if (negative_trie != nullptr) { + uint32_t* negative_ids = nullptr; + uint32_t negative_ids_length = 0; + auto abs_low = std::abs(value); + + // Since we store absolute values, search_lesser would yield result for >value from negative_trie. + negative_trie->search_less_than(inclusive ? 
abs_low : abs_low - 1, max_level, + negative_ids, negative_ids_length); + + uint32_t* out = nullptr; + ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, ids, ids_length, &out); + + delete [] negative_ids; + delete [] ids; + ids = out; + } + + if (positive_trie == nullptr) { + return; + } + + uint32_t* positive_ids = nullptr; + uint32_t positive_ids_length = 0; + positive_trie->get_all_ids(positive_ids, positive_ids_length); + + uint32_t* out = nullptr; + ids_length = ArrayUtils::or_scalar(positive_ids, positive_ids_length, ids, ids_length, &out); + + delete [] positive_ids; + delete [] ids; + ids = out; + } +} + +NumericTrie::iterator_t NumericTrie::search_greater_than(const int64_t& value, const bool& inclusive) { + std::vector matches; + + if ((value == 0 && inclusive) || (value == -1 && !inclusive)) { // [0, ∞), (-1, ∞) + if (positive_trie != nullptr) { + matches.push_back(positive_trie); + } + return NumericTrie::iterator_t(matches); + } + + if (value >= 0) { + if (positive_trie != nullptr) { + positive_trie->search_greater_than(inclusive ? value : value + 1, max_level, matches); + } + } else { + // Have to combine the results of >value from negative_trie and all the ids in positive_trie + if (negative_trie != nullptr) { + auto abs_low = std::abs(value); + // Since we store absolute values, search_lesser would yield result for >value from negative_trie. + negative_trie->search_less_than(inclusive ? 
abs_low : abs_low - 1, max_level, matches); + } + if (positive_trie != nullptr) { + matches.push_back(positive_trie); + } + } + + return NumericTrie::iterator_t(matches); +} + +void NumericTrie::search_less_than(const int64_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length) { + if ((value == 0 && !inclusive) || (value == -1 && inclusive)) { // (-∞, 0), (-∞, -1] + if (negative_trie != nullptr) { + uint32_t* negative_ids = nullptr; + uint32_t negative_ids_length = 0; + negative_trie->get_all_ids(negative_ids, negative_ids_length); + + uint32_t* out = nullptr; + ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, ids, ids_length, &out); + + delete [] negative_ids; + delete [] ids; + ids = out; + } + return; + } + + if (value < 0) { + if (negative_trie == nullptr) { + return; + } + + uint32_t* negative_ids = nullptr; + uint32_t negative_ids_length = 0; + auto abs_low = std::abs(value); + + // Since we store absolute values, search_greater would yield result for search_greater_than(inclusive ? abs_low : abs_low + 1, max_level, + negative_ids, negative_ids_length); + + uint32_t* out = nullptr; + ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, ids, ids_length, &out); + + delete [] negative_ids; + delete [] ids; + ids = out; + } else { + // Have to combine the results of search_less_than(inclusive ? 
value : value - 1, max_level, + positive_ids, positive_ids_length); + + uint32_t* out = nullptr; + ids_length = ArrayUtils::or_scalar(positive_ids, positive_ids_length, ids, ids_length, &out); + + delete [] positive_ids; + delete [] ids; + ids = out; + } + + if (negative_trie == nullptr) { + return; + } + + uint32_t* negative_ids = nullptr; + uint32_t negative_ids_length = 0; + negative_trie->get_all_ids(negative_ids, negative_ids_length); + + uint32_t* out = nullptr; + ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, ids, ids_length, &out); + + delete [] negative_ids; + delete [] ids; + ids = out; + } +} + +NumericTrie::iterator_t NumericTrie::search_less_than(const int64_t& value, const bool& inclusive) { + std::vector matches; + + if ((value == 0 && !inclusive) || (value == -1 && inclusive)) { // (-∞, 0), (-∞, -1] + if (negative_trie != nullptr) { + matches.push_back(negative_trie); + } + return NumericTrie::iterator_t(matches); + } + + if (value < 0) { + if (negative_trie != nullptr) { + auto abs_low = std::abs(value); + // Since we store absolute values, search_greater would yield result for search_greater_than(inclusive ? abs_low : abs_low + 1, max_level, matches); + } + } else { + // Have to combine the results of search_less_than(inclusive ? 
value : value - 1, max_level, matches); + } + if (negative_trie != nullptr) { + matches.push_back(negative_trie); + } + } + + return NumericTrie::iterator_t(matches); +} + +void NumericTrie::search_equal_to(const int64_t& value, uint32_t*& ids, uint32_t& ids_length) { + if ((value < 0 && negative_trie == nullptr) || (value >= 0 && positive_trie == nullptr)) { + return; + } + + uint32_t* equal_ids = nullptr; + uint32_t equal_ids_length = 0; + + if (value < 0) { + negative_trie->search_equal_to(std::abs(value), max_level, equal_ids, equal_ids_length); + } else { + positive_trie->search_equal_to(value, max_level, equal_ids, equal_ids_length); + } + + uint32_t* out = nullptr; + ids_length = ArrayUtils::or_scalar(equal_ids, equal_ids_length, ids, ids_length, &out); + + delete [] equal_ids; + delete [] ids; + ids = out; +} + +NumericTrie::iterator_t NumericTrie::search_equal_to(const int64_t& value) { + std::vector matches; + if (value < 0 && negative_trie != nullptr) { + negative_trie->search_equal_to(std::abs(value), max_level, matches); + } else if (value >= 0 && positive_trie != nullptr) { + positive_trie->search_equal_to(value, max_level, matches); + } + + return NumericTrie::iterator_t(matches); +} + +void NumericTrie::Node::insert(const int64_t& cell_id, const uint32_t& seq_id, const char& max_level) { + char level = 0; + return insert_helper(cell_id, seq_id, level, max_level); +} + +void NumericTrie::Node::insert_geopoint(const uint64_t& cell_id, const uint32_t& seq_id, const char& max_level) { + char level = 0; + return insert_geopoint_helper(cell_id, seq_id, level, max_level); +} + +inline int get_index(const int64_t& value, const char& level, const char& max_level) { + // Values are index considering higher order of the bytes first. 
+ // 0x01020408 (16909320) would be indexed in the trie as follows: + // Level Index + // 1 1 + // 2 2 + // 3 4 + // 4 8 + return (value >> (8 * (max_level - level))) & 0xFF; +} + +inline int get_geopoint_index(const uint64_t& cell_id, const char& level) { + // Doing 8-level since cell_id is a 64 bit number. + return (cell_id >> (8 * (8 - level))) & 0xFF; +} + +void NumericTrie::Node::remove(const int64_t& value, const uint32_t& id, const char& max_level) { + char level = 1; + Node* root = this; + auto index = get_index(value, level, max_level); + + while (level < max_level) { + root->seq_ids.remove_value(id); + + if (root->children == nullptr || root->children[index] == nullptr) { + return; + } + + root = root->children[index]; + index = get_index(value, ++level, max_level); + } + + root->seq_ids.remove_value(id); + if (root->children != nullptr && root->children[index] != nullptr) { + auto& child = root->children[index]; + + child->seq_ids.remove_value(id); + if (child->seq_ids.getLength() == 0) { + delete child; + child = nullptr; + } + } +} + +void NumericTrie::Node::insert_helper(const int64_t& value, const uint32_t& seq_id, char& level, const char& max_level) { + if (level > max_level) { + return; + } + + // Root node contains all the sequence ids present in the tree. + if (!seq_ids.contains(seq_id)) { + seq_ids.append(seq_id); + } + + if (++level <= max_level) { + if (children == nullptr) { + children = new NumericTrie::Node* [EXPANSE]{nullptr}; + } + + auto index = get_index(value, level, max_level); + if (children[index] == nullptr) { + children[index] = new NumericTrie::Node(); + } + + return children[index]->insert_helper(value, seq_id, level, max_level); + } +} + +void NumericTrie::Node::insert_geopoint_helper(const uint64_t& cell_id, const uint32_t& seq_id, char& level, + const char& max_level) { + if (level > max_level) { + return; + } + + // Root node contains all the sequence ids present in the tree. 
+ if (!seq_ids.contains(seq_id)) { + seq_ids.append(seq_id); + } + + if (++level <= max_level) { + if (children == nullptr) { + children = new NumericTrie::Node* [EXPANSE]{nullptr}; + } + + auto index = get_geopoint_index(cell_id, level); + if (children[index] == nullptr) { + children[index] = new NumericTrie::Node(); + } + + return children[index]->insert_geopoint_helper(cell_id, seq_id, level, max_level); + } +} + +char get_max_search_level(const uint64_t& cell_id, const char& max_level) { + // For cell id 0x47E66C3000000000, we only have to prefix match the top four bytes since rest of the bytes are 0. + // So the max search level would be 4 in this case. + + auto mask = (uint64_t) 0xFF << (8 * (8 - max_level)); // We're only indexing top 8-max_level bytes. + char i = max_level; + while (((cell_id & mask) == 0) && --i > 0) { + mask <<= 8; + } + + return i; +} + +void NumericTrie::Node::search_geopoints_helper(const uint64_t& cell_id, const char& max_index_level, + std::set& matches) { + char level = 1; + Node* root = this; + auto index = get_geopoint_index(cell_id, level); + auto max_search_level = get_max_search_level(cell_id, max_index_level); + + while (level < max_search_level) { + if (root->children == nullptr || root->children[index] == nullptr) { + return; + } + + root = root->children[index]; + index = get_geopoint_index(cell_id, ++level); + } + + matches.insert(root); +} + +void NumericTrie::Node::search_geopoints(const std::vector& cell_ids, const char& max_level, + std::vector& geo_result_ids) { + std::set matches; + for (const auto &cell_id: cell_ids) { + search_geopoints_helper(cell_id, max_level, matches); + } + + for (auto const& match: matches) { + auto const& m_seq_ids = match->seq_ids.uncompress(); + for (uint32_t i = 0; i < match->seq_ids.getLength(); i++) { + geo_result_ids.push_back(m_seq_ids[i]); + } + + delete [] m_seq_ids; + } + + gfx::timsort(geo_result_ids.begin(), geo_result_ids.end()); + 
geo_result_ids.erase(unique(geo_result_ids.begin(), geo_result_ids.end()), geo_result_ids.end()); +} + +void NumericTrie::Node::delete_geopoint(const uint64_t& cell_id, uint32_t id, const char& max_level) { + char level = 1; + Node* root = this; + auto index = get_geopoint_index(cell_id, level); + + while (level < max_level) { + root->seq_ids.remove_value(id); + + if (root->children == nullptr || root->children[index] == nullptr) { + return; + } + + root = root->children[index]; + index = get_geopoint_index(cell_id, ++level); + } + + root->seq_ids.remove_value(id); + if (root->children != nullptr && root->children[index] != nullptr) { + auto& child = root->children[index]; + + child->seq_ids.remove_value(id); + if (child->seq_ids.getLength() == 0) { + delete child; + child = nullptr; + } + } +} + +void NumericTrie::Node::get_all_ids(uint32_t*& ids, uint32_t& ids_length) { + ids = seq_ids.uncompress(); + ids_length = seq_ids.getLength(); +} + +void NumericTrie::Node::search_less_than(const int64_t& value, const char& max_level, + uint32_t*& ids, uint32_t& ids_length) { + char level = 0; + std::vector matches; + search_less_than_helper(value, level, max_level, matches); + + std::vector consolidated_ids; + for (auto const& match: matches) { + auto const& m_seq_ids = match->seq_ids.uncompress(); + for (uint32_t i = 0; i < match->seq_ids.getLength(); i++) { + consolidated_ids.push_back(m_seq_ids[i]); + } + + delete [] m_seq_ids; + } + + gfx::timsort(consolidated_ids.begin(), consolidated_ids.end()); + consolidated_ids.erase(unique(consolidated_ids.begin(), consolidated_ids.end()), consolidated_ids.end()); + + uint32_t* out = nullptr; + ids_length = ArrayUtils::or_scalar(&consolidated_ids[0], consolidated_ids.size(), + ids, ids_length, &out); + + delete [] ids; + ids = out; +} + +void NumericTrie::Node::search_less_than(const int64_t& value, const char& max_level, std::vector& matches) { + char level = 0; + search_less_than_helper(value, level, max_level, matches); +} + 
+void NumericTrie::Node::search_less_than_helper(const int64_t& value, char& level, const char& max_level, + std::vector& matches) { + if (level == max_level) { + matches.push_back(this); + return; + } else if (level > max_level || children == nullptr) { + return; + } + + auto index = get_index(value, ++level, max_level); + if (children[index] != nullptr) { + children[index]->search_less_than_helper(value, level, max_level, matches); + } + + while (--index >= 0) { + if (children[index] != nullptr) { + matches.push_back(children[index]); + } + } + + --level; +} + +void NumericTrie::Node::search_range(const int64_t& low, const int64_t& high, const char& max_level, + uint32_t*& ids, uint32_t& ids_length) { + if (low > high) { + return; + } + std::vector matches; + search_range_helper(low, high, max_level, matches); + + std::vector consolidated_ids; + for (auto const& match: matches) { + auto const& m_seq_ids = match->seq_ids.uncompress(); + for (uint32_t i = 0; i < match->seq_ids.getLength(); i++) { + consolidated_ids.push_back(m_seq_ids[i]); + } + + delete [] m_seq_ids; + } + + gfx::timsort(consolidated_ids.begin(), consolidated_ids.end()); + consolidated_ids.erase(unique(consolidated_ids.begin(), consolidated_ids.end()), consolidated_ids.end()); + + uint32_t* out = nullptr; + ids_length = ArrayUtils::or_scalar(&consolidated_ids[0], consolidated_ids.size(), + ids, ids_length, &out); + + delete [] ids; + ids = out; +} + +void NumericTrie::Node::search_range(const int64_t& low, const int64_t& high, const char& max_level, + std::vector& matches) { + if (low > high) { + return; + } + + search_range_helper(low, high, max_level, matches); +} + +void NumericTrie::Node::search_range_helper(const int64_t& low,const int64_t& high, const char& max_level, + std::vector& matches) { + // Segregating the nodes into matching low, in-between, and matching high. 
+ + NumericTrie::Node* root = this; + char level = 1; + auto low_index = get_index(low, level, max_level), high_index = get_index(high, level, max_level); + + // Keep updating the root while the range is contained within a single child node. + while (root->children != nullptr && low_index == high_index && level < max_level) { + if (root->children[low_index] == nullptr) { + return; + } + + root = root->children[low_index]; + level++; + low_index = get_index(low, level, max_level); + high_index = get_index(high, level, max_level); + } + + if (root->children == nullptr) { + return; + } else if (low_index == high_index) { // low and high are equal + if (root->children[low_index] != nullptr) { + matches.push_back(root->children[low_index]); + } + return; + } + + if (root->children[low_index] != nullptr) { + // Collect all the sub-nodes that are greater than low. + root->children[low_index]->search_greater_than_helper(low, level, max_level, matches); + } + + auto index = low_index + 1; + // All the nodes in-between low and high are a match by default. + while (index < std::min(high_index, (int)EXPANSE)) { + if (root->children[index] != nullptr) { + matches.push_back(root->children[index]); + } + + index++; + } + + if (index < EXPANSE && index == high_index && root->children[index] != nullptr) { + // Collect all the sub-nodes that are lesser than high. 
+ root->children[index]->search_less_than_helper(high, level, max_level, matches); + } +} + +void NumericTrie::Node::search_greater_than(const int64_t& value, const char& max_level, + uint32_t*& ids, uint32_t& ids_length) { + char level = 0; + std::vector matches; + search_greater_than_helper(value, level, max_level, matches); + + std::vector consolidated_ids; + for (auto const& match: matches) { + auto const& m_seq_ids = match->seq_ids.uncompress(); + for (uint32_t i = 0; i < match->seq_ids.getLength(); i++) { + consolidated_ids.push_back(m_seq_ids[i]); + } + + delete [] m_seq_ids; + } + + gfx::timsort(consolidated_ids.begin(), consolidated_ids.end()); + consolidated_ids.erase(unique(consolidated_ids.begin(), consolidated_ids.end()), consolidated_ids.end()); + + uint32_t* out = nullptr; + ids_length = ArrayUtils::or_scalar(&consolidated_ids[0], consolidated_ids.size(), + ids, ids_length, &out); + + delete [] ids; + ids = out; +} + +void NumericTrie::Node::search_greater_than(const int64_t& value, const char& max_level, std::vector& matches) { + char level = 0; + search_greater_than_helper(value, level, max_level, matches); +} + +void NumericTrie::Node::search_greater_than_helper(const int64_t& value, char& level, const char& max_level, + std::vector& matches) { + if (level == max_level) { + matches.push_back(this); + return; + } else if (level > max_level || children == nullptr) { + return; + } + + auto index = get_index(value, ++level, max_level); + if (children[index] != nullptr) { + children[index]->search_greater_than_helper(value, level, max_level, matches); + } + + while (++index < EXPANSE) { + if (children[index] != nullptr) { + matches.push_back(children[index]); + } + } + + --level; +} + +void NumericTrie::Node::search_equal_to(const int64_t& value, const char& max_level, + uint32_t*& ids, uint32_t& ids_length) { + char level = 1; + Node* root = this; + auto index = get_index(value, level, max_level); + + while (level <= max_level) { + if (root->children 
== nullptr || root->children[index] == nullptr) { + return; + } + + root = root->children[index]; + index = get_index(value, ++level, max_level); + } + + root->get_all_ids(ids, ids_length); +} + +void NumericTrie::Node::search_equal_to(const int64_t& value, const char& max_level, std::vector& matches) { + char level = 1; + Node* root = this; + auto index = get_index(value, level, max_level); + + while (level <= max_level) { + if (root->children == nullptr || root->children[index] == nullptr) { + return; + } + + root = root->children[index]; + index = get_index(value, ++level, max_level); + } + + matches.push_back(root); +} + +void NumericTrie::iterator_t::reset() { + for (auto& match: matches) { + match->index = 0; + } + + is_valid = true; + set_seq_id(); +} + +void NumericTrie::iterator_t::skip_to(uint32_t id) { + for (auto& match: matches) { + ArrayUtils::skip_index_to_id(match->index, match->ids, match->ids_length, id); + } + + set_seq_id(); +} + +void NumericTrie::iterator_t::next() { + // Advance all the matches at seq_id. + for (auto& match: matches) { + if (match->index < match->ids_length && match->ids[match->index] == seq_id) { + match->index++; + } + } + + set_seq_id(); +} + +NumericTrie::iterator_t::iterator_t(std::vector& node_matches) { + for (auto const& node_match: node_matches) { + uint32_t* ids = nullptr; + uint32_t ids_length; + node_match->get_all_ids(ids, ids_length); + if (ids_length > 0) { + matches.emplace_back(new match_state(ids, ids_length)); + } + } + + set_seq_id(); +} + +void NumericTrie::iterator_t::set_seq_id() { + // Find the lowest id of all the matches and update the seq_id. 
+ bool one_is_valid = false; + uint32_t lowest_id = UINT32_MAX; + + for (auto& match: matches) { + if (match->index < match->ids_length) { + one_is_valid = true; + + if (match->ids[match->index] < lowest_id) { + lowest_id = match->ids[match->index]; + } + } + } + + if (one_is_valid) { + seq_id = lowest_id; + } + + is_valid = one_is_valid; +} + +NumericTrie::iterator_t& NumericTrie::iterator_t::operator=(NumericTrie::iterator_t&& obj) noexcept { + if (&obj == this) + return *this; + + for (auto& match: matches) { + delete match; + } + matches.clear(); + + matches = std::move(obj.matches); + seq_id = obj.seq_id; + is_valid = obj.is_valid; + + return *this; +} + diff --git a/src/text_embedder_remote.cpp b/src/text_embedder_remote.cpp index 604e1071..083aa2e1 100644 --- a/src/text_embedder_remote.cpp +++ b/src/text_embedder_remote.cpp @@ -11,6 +11,27 @@ Option RemoteEmbedder::validate_string_properties(const nlohmann::json& mo return Option(true); } +long RemoteEmbedder::call_remote_api(const std::string& method, const std::string& url, const std::string& body, std::string& res_body, + std::map& headers, const std::unordered_map& req_headers) { + if(raft_server == nullptr) { + if(method == "GET") { + return HttpClient::get_instance().get_response(url, res_body, headers, req_headers, 10000, true); + } else if(method == "POST") { + return HttpClient::get_instance().post_response(url, body, res_body, headers, req_headers, 10000, true); + } else { + return 400; + } + } + auto leader_url = raft_server->get_leader_url(); + leader_url += "proxy"; + nlohmann::json req_body; + req_body["method"] = method; + req_body["url"] = url; + req_body["body"] = body; + req_body["headers"] = req_headers; + return HttpClient::get_instance().post_response(leader_url, req_body.dump(), res_body, headers, {}, 10000, true); +} + OpenAIEmbedder::OpenAIEmbedder(const std::string& openai_model_path, const std::string& api_key) : api_key(api_key), openai_model_path(openai_model_path) { @@ -32,12 
+53,11 @@ Option OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config, return Option(400, "Property `embed.model_config.model_name` malformed."); } - HttpClient& client = HttpClient::get_instance(); std::unordered_map headers; std::map res_headers; headers["Authorization"] = "Bearer " + api_key; std::string res; - auto res_code = client.get_response(OPENAI_LIST_MODELS, res, res_headers, headers); + auto res_code = call_remote_api("GET", OPENAI_LIST_MODELS, "", res, res_headers, headers); if (res_code != 200) { nlohmann::json json_res = nlohmann::json::parse(res); if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) { @@ -67,7 +87,7 @@ Option OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config, std::string embedding_res; headers["Content-Type"] = "application/json"; - res_code = client.post_response(OPENAI_CREATE_EMBEDDING, req_body.dump(), embedding_res, res_headers, headers); + res_code = call_remote_api("POST", OPENAI_CREATE_EMBEDDING, req_body.dump(), embedding_res, res_headers, headers); if (res_code != 200) { @@ -84,7 +104,6 @@ Option OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config, } Option> OpenAIEmbedder::Embed(const std::string& text) { - HttpClient& client = HttpClient::get_instance(); std::unordered_map headers; std::map res_headers; headers["Authorization"] = "Bearer " + api_key; @@ -94,7 +113,7 @@ Option> OpenAIEmbedder::Embed(const std::string& text) { req_body["input"] = text; // remove "openai/" prefix req_body["model"] = TextEmbedderManager::get_model_name_without_namespace(openai_model_path); - auto res_code = client.post_response(OPENAI_CREATE_EMBEDDING, req_body.dump(), res, res_headers, headers); + auto res_code = call_remote_api("POST", OPENAI_CREATE_EMBEDDING, req_body.dump(), res, res_headers, headers); if (res_code != 200) { nlohmann::json json_res = nlohmann::json::parse(res); if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) { @@ -115,9 +134,7 @@ 
Option>> OpenAIEmbedder::batch_embed(const std::v headers["Content-Type"] = "application/json"; std::map res_headers; std::string res; - HttpClient& client = HttpClient::get_instance(); - - auto res_code = client.post_response(OPENAI_CREATE_EMBEDDING, req_body.dump(), res, res_headers, headers); + auto res_code = call_remote_api("POST", OPENAI_CREATE_EMBEDDING, req_body.dump(), res, res_headers, headers); if(res_code != 200) { nlohmann::json json_res = nlohmann::json::parse(res); @@ -159,7 +176,6 @@ Option GoogleEmbedder::is_model_valid(const nlohmann::json& model_config, return Option(400, "Property `embed.model_config.model_name` is not a supported Google model."); } - HttpClient& client = HttpClient::get_instance(); std::unordered_map headers; std::map res_headers; headers["Content-Type"] = "application/json"; @@ -167,7 +183,7 @@ Option GoogleEmbedder::is_model_valid(const nlohmann::json& model_config, nlohmann::json req_body; req_body["text"] = "test"; - auto res_code = client.post_response(std::string(GOOGLE_CREATE_EMBEDDING) + api_key, req_body.dump(), res, res_headers, headers); + auto res_code = call_remote_api("POST", std::string(GOOGLE_CREATE_EMBEDDING) + api_key, req_body.dump(), res, res_headers, headers); if(res_code != 200) { nlohmann::json json_res = nlohmann::json::parse(res); @@ -183,7 +199,6 @@ Option GoogleEmbedder::is_model_valid(const nlohmann::json& model_config, } Option> GoogleEmbedder::Embed(const std::string& text) { - HttpClient& client = HttpClient::get_instance(); std::unordered_map headers; std::map res_headers; headers["Content-Type"] = "application/json"; @@ -191,7 +206,7 @@ Option> GoogleEmbedder::Embed(const std::string& text) { nlohmann::json req_body; req_body["text"] = text; - auto res_code = client.post_response(std::string(GOOGLE_CREATE_EMBEDDING) + google_api_key, req_body.dump(), res, res_headers, headers); + auto res_code = call_remote_api("POST", std::string(GOOGLE_CREATE_EMBEDDING) + google_api_key, req_body.dump(), res, 
res_headers, headers); if(res_code != 200) { nlohmann::json json_res = nlohmann::json::parse(res); @@ -246,7 +261,6 @@ Option GCPEmbedder::is_model_valid(const nlohmann::json& model_config, uns auto model_name_without_namespace = TextEmbedderManager::get_model_name_without_namespace(model_name); - HttpClient& client = HttpClient::get_instance(); std::unordered_map headers; std::map res_headers; headers["Content-Type"] = "application/json"; @@ -258,7 +272,7 @@ Option GCPEmbedder::is_model_valid(const nlohmann::json& model_config, uns instance["content"] = "typesense"; req_body["instances"].push_back(instance); - auto res_code = client.post_response(get_gcp_embedding_url(project_id, model_name_without_namespace), req_body.dump(), res, res_headers, headers); + auto res_code = call_remote_api("POST", get_gcp_embedding_url(project_id, model_name_without_namespace), req_body.dump(), res, res_headers, headers); if(res_code != 200) { nlohmann::json json_res = nlohmann::json::parse(res); @@ -295,9 +309,8 @@ Option> GCPEmbedder::Embed(const std::string& text) { headers["Content-Type"] = "application/json"; std::map res_headers; std::string res; - HttpClient& client = HttpClient::get_instance(); - auto res_code = client.post_response(get_gcp_embedding_url(project_id, model_name), req_body.dump(), res, res_headers, headers); + auto res_code = call_remote_api("POST", get_gcp_embedding_url(project_id, model_name), req_body.dump(), res, res_headers, headers); if(res_code != 200) { if(res_code == 401) { @@ -308,7 +321,7 @@ Option> GCPEmbedder::Embed(const std::string& text) { access_token = refresh_op.get(); // retry headers["Authorization"] = "Bearer " + access_token; - res_code = client.post_response(get_gcp_embedding_url(project_id, model_name), req_body.dump(), res, res_headers, headers); + res_code = call_remote_api("POST", get_gcp_embedding_url(project_id, model_name), req_body.dump(), res, res_headers, headers); } } @@ -353,8 +366,7 @@ Option>> 
GCPEmbedder::batch_embed(const std::vect headers["Content-Type"] = "application/json"; std::map res_headers; std::string res; - HttpClient& client = HttpClient::get_instance(); - auto res_code = client.post_response(get_gcp_embedding_url(project_id, model_name), req_body.dump(), res, res_headers, headers); + auto res_code = call_remote_api("POST", get_gcp_embedding_url(project_id, model_name), req_body.dump(), res, res_headers, headers); if(res_code != 200) { if(res_code == 401) { auto refresh_op = generate_access_token(refresh_token, client_id, client_secret); @@ -364,7 +376,7 @@ Option>> GCPEmbedder::batch_embed(const std::vect access_token = refresh_op.get(); // retry headers["Authorization"] = "Bearer " + access_token; - res_code = client.post_response(get_gcp_embedding_url(project_id, model_name), req_body.dump(), res, res_headers, headers); + res_code = call_remote_api("POST", get_gcp_embedding_url(project_id, model_name), req_body.dump(), res, res_headers, headers); } } @@ -386,7 +398,6 @@ Option>> GCPEmbedder::batch_embed(const std::vect } Option GCPEmbedder::generate_access_token(const std::string& refresh_token, const std::string& client_id, const std::string& client_secret) { - HttpClient& client = HttpClient::get_instance(); std::unordered_map headers; headers["Content-Type"] = "application/x-www-form-urlencoded"; std::map res_headers; @@ -394,7 +405,7 @@ Option GCPEmbedder::generate_access_token(const std::string& refres std::string req_body; req_body = "grant_type=refresh_token&client_id=" + client_id + "&client_secret=" + client_secret + "&refresh_token=" + refresh_token; - auto res_code = client.post_response(GCP_AUTH_TOKEN_URL, req_body, res, res_headers, headers); + auto res_code = call_remote_api("POST", GCP_AUTH_TOKEN_URL, req_body, res, res_headers, headers); if(res_code != 200) { nlohmann::json json_res = nlohmann::json::parse(res); diff --git a/src/typesense_server_utils.cpp b/src/typesense_server_utils.cpp index 46f66ab0..c628c569 100644 --- 
a/src/typesense_server_utils.cpp +++ b/src/typesense_server_utils.cpp @@ -453,6 +453,8 @@ int run_server(const Config & config, const std::string & version, void (*master AnalyticsManager::get_instance().run(&replication_state); }); + RemoteEmbedder::init(&replication_state); + std::string path_to_nodes = config.get_nodes(); start_raft_server(replication_state, state_dir, path_to_nodes, config.get_peering_address(), diff --git a/test/collection_filtering_test.cpp b/test/collection_filtering_test.cpp index 988b035a..b3a5e600 100644 --- a/test/collection_filtering_test.cpp +++ b/test/collection_filtering_test.cpp @@ -1231,6 +1231,16 @@ TEST_F(CollectionFilteringTest, FilteringViaDocumentIds) { ASSERT_EQ(1, results["hits"].size()); ASSERT_STREQ("123", results["hits"][0]["document"]["id"].get().c_str()); + results = coll1->search("*", + {}, "id: != 123", + {}, sort_fields, {0}, 10, 1, FREQUENCY, {true}).get(); + + ASSERT_EQ(3, results["found"].get()); + ASSERT_EQ(3, results["hits"].size()); + ASSERT_STREQ("125", results["hits"][0]["document"]["id"].get().c_str()); + ASSERT_STREQ("127", results["hits"][1]["document"]["id"].get().c_str()); + ASSERT_STREQ("129", results["hits"][2]["document"]["id"].get().c_str()); + // single ID with backtick results = coll1->search("*", @@ -1283,6 +1293,14 @@ TEST_F(CollectionFilteringTest, FilteringViaDocumentIds) { ASSERT_STREQ("125", results["hits"][1]["document"]["id"].get().c_str()); ASSERT_STREQ("127", results["hits"][2]["document"]["id"].get().c_str()); + results = coll1->search("*", + {}, "id:!= [123,125] && num_employees: <300", + {}, sort_fields, {0}, 10, 1, FREQUENCY, {true}).get(); + + ASSERT_EQ(1, results["found"].get()); + ASSERT_EQ(1, results["hits"].size()); + ASSERT_STREQ("127", results["hits"][0]["document"]["id"].get().c_str()); + // empty id list not allowed auto res_op = coll1->search("*", {}, "id:=", {}, sort_fields, {0}, 10, 1, FREQUENCY, {true}); ASSERT_FALSE(res_op.ok()); @@ -1296,13 +1314,6 @@ 
TEST_F(CollectionFilteringTest, FilteringViaDocumentIds) { ASSERT_FALSE(res_op.ok()); ASSERT_EQ("Error with filter field `id`: Filter value cannot be empty.", res_op.error()); - // not equals is not supported yet - res_op = coll1->search("*", - {}, "id:!= [123,125] && num_employees: <300", - {}, sort_fields, {0}, 10, 1, FREQUENCY, {true}); - ASSERT_FALSE(res_op.ok()); - ASSERT_EQ("Not equals filtering is not supported on the `id` field.", res_op.error()); - // when no IDs exist results = coll1->search("*", {}, "id: [1000] && num_employees: <300", @@ -1397,9 +1408,10 @@ TEST_F(CollectionFilteringTest, NumericalFilteringWithArray) { TEST_F(CollectionFilteringTest, NegationOperatorBasics) { Collection *coll1; - std::vector fields = {field("title", field_types::STRING, false), - field("artist", field_types::STRING, false), - field("points", field_types::INT32, false),}; + std::vector fields = { + field("title", field_types::STRING, false), + field("artist", field_types::STRING, false), + field("points", field_types::INT32, false),}; coll1 = collectionManager.get_collection("coll1").get(); if(coll1 == nullptr) { diff --git a/test/collection_grouping_test.cpp b/test/collection_grouping_test.cpp index f88d9551..1f9843e9 100644 --- a/test/collection_grouping_test.cpp +++ b/test/collection_grouping_test.cpp @@ -69,6 +69,7 @@ TEST_F(CollectionGroupingTest, GroupingBasics) { "", 10, {}, {}, {"size"}, 2).get(); + ASSERT_EQ(12, res["found_docs"].get()); ASSERT_EQ(3, res["found"].get()); ASSERT_EQ(3, res["grouped_hits"].size()); ASSERT_EQ(11, res["grouped_hits"][0]["group_key"][0].get()); @@ -116,6 +117,7 @@ TEST_F(CollectionGroupingTest, GroupingBasics) { {}, {}, {"rating"}, 2).get(); // 7 unique ratings + ASSERT_EQ(12, res["found_docs"].get()); ASSERT_EQ(7, res["found"].get()); ASSERT_EQ(7, res["grouped_hits"].size()); ASSERT_FLOAT_EQ(4.4, res["grouped_hits"][0]["group_key"][0].get()); @@ -167,7 +169,7 @@ TEST_F(CollectionGroupingTest, GroupingCompoundKey) { 
spp::sparse_hash_set(), 10, "", 30, 5, "", 10, {}, {}, {"size", "brand"}, 2).get(); - + ASSERT_EQ(12, res["found_docs"].get()); ASSERT_EQ(10, res["found"].get()); ASSERT_EQ(10, res["grouped_hits"].size()); @@ -227,6 +229,7 @@ TEST_F(CollectionGroupingTest, GroupingCompoundKey) { ASSERT_STREQ("0", res["grouped_hits"][0]["hits"][1]["document"]["id"].get().c_str()); // total count and facet counts should be the same + ASSERT_EQ(12, res["found_docs"].get()); ASSERT_EQ(10, res["found"].get()); ASSERT_EQ(2, res["grouped_hits"].size()); ASSERT_EQ(10, res["grouped_hits"][0]["group_key"][0].get()); @@ -313,6 +316,7 @@ TEST_F(CollectionGroupingTest, GroupingWithMultiFieldRelevance) { "", 10, {}, {}, {"genre"}, 2).get(); + ASSERT_EQ(7, results["found_docs"].get()); ASSERT_EQ(3, results["found"].get()); ASSERT_EQ(3, results["grouped_hits"].size()); @@ -345,6 +349,7 @@ TEST_F(CollectionGroupingTest, GroupingWithGropLimitOfOne) { "", 10, {}, {}, {"brand"}, 1).get(); + ASSERT_EQ(12, res["found_docs"].get()); ASSERT_EQ(5, res["found"].get()); ASSERT_EQ(5, res["grouped_hits"].size()); @@ -430,6 +435,7 @@ TEST_F(CollectionGroupingTest, GroupingWithArrayFieldAndOverride) { "", 10, {}, {}, {"colors"}, 2).get(); + ASSERT_EQ(9, res["found_docs"].get()); ASSERT_EQ(4, res["found"].get()); ASSERT_EQ(4, res["grouped_hits"].size()); @@ -611,6 +617,7 @@ TEST_F(CollectionGroupingTest, SortingOnGroupCount) { "", 10, {}, {}, {"size"}, 2).get(); + ASSERT_EQ(12, res["found_docs"].get()); ASSERT_EQ(3, res["found"].get()); ASSERT_EQ(3, res["grouped_hits"].size()); @@ -635,6 +642,7 @@ TEST_F(CollectionGroupingTest, SortingOnGroupCount) { "", 10, {}, {}, {"size"}, 2).get(); + ASSERT_EQ(12, res2["found_docs"].get()); ASSERT_EQ(3, res2["found"].get()); ASSERT_EQ(3, res2["grouped_hits"].size()); @@ -715,6 +723,7 @@ TEST_F(CollectionGroupingTest, SortingMoreThanMaxTopsterSize) { "", 10, {}, {}, {"size"}, 2).get(); + ASSERT_EQ(1000, res["found_docs"].get()); ASSERT_EQ(300, res["found"].get()); 
ASSERT_EQ(100, res["grouped_hits"].size()); @@ -734,6 +743,7 @@ TEST_F(CollectionGroupingTest, SortingMoreThanMaxTopsterSize) { "", 10, {}, {}, {"size"}, 2).get(); + ASSERT_EQ(1000, res["found_docs"].get()); ASSERT_EQ(300, res["found"].get()); ASSERT_EQ(100, res["grouped_hits"].size()); @@ -757,6 +767,7 @@ TEST_F(CollectionGroupingTest, SortingMoreThanMaxTopsterSize) { "", 10, {}, {}, {"size"}, 2).get(); + ASSERT_EQ(1000, res2["found_docs"].get()); ASSERT_EQ(300, res2["found"].get()); ASSERT_EQ(100, res2["grouped_hits"].size()); @@ -775,6 +786,7 @@ TEST_F(CollectionGroupingTest, SortingMoreThanMaxTopsterSize) { "", 10, {}, {}, {"size"}, 2).get(); + ASSERT_EQ(1000, res2["found_docs"].get()); ASSERT_EQ(300, res2["found"].get()); ASSERT_EQ(100, res2["grouped_hits"].size()); diff --git a/test/collection_specific_more_test.cpp b/test/collection_specific_more_test.cpp index b8419b37..0d101b0d 100644 --- a/test/collection_specific_more_test.cpp +++ b/test/collection_specific_more_test.cpp @@ -114,6 +114,33 @@ TEST_F(CollectionSpecificMoreTest, PrefixExpansionOnSingleField) { ASSERT_EQ("0", results["hits"][0]["document"]["id"].get()); } +TEST_F(CollectionSpecificMoreTest, TypoCorrectionShouldUseMaxCandidates) { + Collection *coll1; + std::vector fields = {field("title", field_types::STRING, false), + field("points", field_types::INT32, false)}; + + coll1 = collectionManager.get_collection("coll1").get(); + if(coll1 == nullptr) { + coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); + } + + for(size_t i = 0; i < 20; i++) { + nlohmann::json doc; + doc["title"] = "Independent" + std::to_string(i); + doc["points"] = i; + coll1->add(doc.dump()); + } + + size_t max_candidates = 20; + auto results = coll1->search("independent", {"title"}, "", {}, {}, {2}, 30, 1, FREQUENCY, {false}, 0, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0, + "", "", {}, 1000, true, false, true, "", false, 6000*1000, 4, 7, + off, 
max_candidates).get(); + + ASSERT_EQ(20, results["hits"].size()); +} + TEST_F(CollectionSpecificMoreTest, PrefixExpansionOnMultiField) { Collection *coll1; std::vector fields = {field("location", field_types::STRING, false), diff --git a/test/core_api_utils_test.cpp b/test/core_api_utils_test.cpp index 362ef643..95af7546 100644 --- a/test/core_api_utils_test.cpp +++ b/test/core_api_utils_test.cpp @@ -253,6 +253,40 @@ TEST_F(CoreAPIUtilsTest, MultiSearchEmbeddedKeys) { } +TEST_F(CoreAPIUtilsTest, SearchEmbeddedPresetKey) { + nlohmann::json preset_value = R"( + {"per_page": 100} + )"_json; + + Option success_op = collectionManager.upsert_preset("apple", preset_value); + ASSERT_TRUE(success_op.ok()); + + std::shared_ptr req = std::make_shared(); + std::shared_ptr res = std::make_shared(nullptr); + + nlohmann::json embedded_params; + embedded_params["preset"] = "apple"; + req->embedded_params_vec.push_back(embedded_params); + req->params["collection"] = "foo"; + + get_search(req, res); + ASSERT_EQ("100", req->params["per_page"]); + + // with multi search + + req->params.clear(); + nlohmann::json body; + body["searches"] = nlohmann::json::array(); + nlohmann::json search; + search["collection"] = "users"; + search["filter_by"] = "age: > 100"; + body["searches"].push_back(search); + req->body = body.dump(); + + post_multi_search(req, res); + ASSERT_EQ("100", req->params["per_page"]); +} + TEST_F(CoreAPIUtilsTest, ExtractCollectionsFromRequestBody) { std::map req_params; std::string body = R"( @@ -956,3 +990,130 @@ TEST_F(CoreAPIUtilsTest, ExportIncludeExcludeFieldsWithFilter) { collectionManager.drop_collection("coll1"); } + + + +TEST_F(CoreAPIUtilsTest, TestProxy) { + std::string res; + std::unordered_map headers; + std::map res_headers; + + std::string url = "https://typesense.org"; + + long expected_status_code = HttpClient::get_instance().get_response(url, res, res_headers, headers); + + auto req = std::make_shared(); + auto resp = std::make_shared(nullptr); + + 
nlohmann::json body; + body["url"] = url; + body["method"] = "GET"; + body["headers"] = headers; + + req->body = body.dump(); + + post_proxy(req, resp); + + ASSERT_EQ(expected_status_code, resp->status_code); + ASSERT_EQ(res, resp->body); +} + + +TEST_F(CoreAPIUtilsTest, TestProxyInvalid) { + nlohmann::json body; + + + + auto req = std::make_shared(); + auto resp = std::make_shared(nullptr); + + // test with url as empty string + body["url"] = ""; + body["method"] = "GET"; + body["headers"] = nlohmann::json::object(); + + req->body = body.dump(); + + post_proxy(req, resp); + + ASSERT_EQ(400, resp->status_code); + ASSERT_EQ("URL and method must be non-empty strings.", nlohmann::json::parse(resp->body)["message"]); + + // test with url as integer + body["url"] = 123; + body["method"] = "GET"; + + req->body = body.dump(); + + post_proxy(req, resp); + + ASSERT_EQ(400, resp->status_code); + ASSERT_EQ("URL and method must be non-empty strings.", nlohmann::json::parse(resp->body)["message"]); + + // test with no url parameter + body.erase("url"); + body["method"] = "GET"; + + req->body = body.dump(); + + post_proxy(req, resp); + + ASSERT_EQ(400, resp->status_code); + ASSERT_EQ("Missing required fields.", nlohmann::json::parse(resp->body)["message"]); + + + // test with invalid method + body["url"] = "https://typesense.org"; + body["method"] = "INVALID"; + + req->body = body.dump(); + + post_proxy(req, resp); + + ASSERT_EQ(400, resp->status_code); + ASSERT_EQ("Parameter `method` must be one of GET, POST, PUT, DELETE.", nlohmann::json::parse(resp->body)["message"]); + + // test with method as integer + body["method"] = 123; + + req->body = body.dump(); + + post_proxy(req, resp); + + ASSERT_EQ(400, resp->status_code); + ASSERT_EQ("URL and method must be non-empty strings.", nlohmann::json::parse(resp->body)["message"]); + + // test with no method parameter + body.erase("method"); + + req->body = body.dump(); + + post_proxy(req, resp); + + ASSERT_EQ(400, resp->status_code); + 
ASSERT_EQ("Missing required fields.", nlohmann::json::parse(resp->body)["message"]); + + + // test with body as integer + body["method"] = "POST"; + body["body"] = 123; + + req->body = body.dump(); + + post_proxy(req, resp); + + ASSERT_EQ(400, resp->status_code); + ASSERT_EQ("Body must be a string.", nlohmann::json::parse(resp->body)["message"]); + + + // test with headers as integer + body["body"] = ""; + body["headers"] = 123; + + req->body = body.dump(); + + post_proxy(req, resp); + + ASSERT_EQ(400, resp->status_code); + ASSERT_EQ("Headers must be a JSON object.", nlohmann::json::parse(resp->body)["message"]); +} \ No newline at end of file diff --git a/test/numeric_range_trie_test.cpp b/test/numeric_range_trie_test.cpp new file mode 100644 index 00000000..2412b5a5 --- /dev/null +++ b/test/numeric_range_trie_test.cpp @@ -0,0 +1,802 @@ +#include +#include +#include "collection.h" +#include "numeric_range_trie_test.h" + +class NumericRangeTrieTest : public ::testing::Test { +protected: + Store *store; + CollectionManager & collectionManager = CollectionManager::get_instance(); + std::atomic quit = false; + + std::vector query_fields; + std::vector sort_fields; + + void setupCollection() { + std::string state_dir_path = "/tmp/typesense_test/collection_filtering"; + LOG(INFO) << "Truncating and creating: " << state_dir_path; + system(("rm -rf "+state_dir_path+" && mkdir -p "+state_dir_path).c_str()); + + store = new Store(state_dir_path); + collectionManager.init(store, 1.0, "auth_key", quit); + collectionManager.load(8, 1000); + } + + virtual void SetUp() { + setupCollection(); + } + + virtual void TearDown() { + collectionManager.dispose(); + delete store; + } +}; + +void reset(uint32_t*& ids, uint32_t& ids_length) { + delete [] ids; + ids = nullptr; + ids_length = 0; +} + +TEST_F(NumericRangeTrieTest, SearchRange) { + auto trie = new NumericTrie(); + std::unique_ptr trie_guard(trie); + std::vector> pairs = { + {-8192, 8}, + {-16384, 32}, + {-24576, 35}, + 
{-32768, 43}, + {8192, 49}, + {16384, 56}, + {24576, 58}, + {32768, 91} + }; + + for (auto const& pair: pairs) { + trie->insert(pair.first, pair.second); + } + + uint32_t* ids = nullptr; + uint32_t ids_length = 0; + + trie->search_range(32768, true, -32768, true, ids, ids_length); + + ASSERT_EQ(0, ids_length); + + reset(ids, ids_length); + trie->search_range(-32768, true, 32768, true, ids, ids_length); + + ASSERT_EQ(pairs.size(), ids_length); + for (uint32_t i = 0; i < pairs.size(); i++) { + ASSERT_EQ(pairs[i].second, ids[i]); + } + + reset(ids, ids_length); + trie->search_range(-32768, true, 32768, false, ids, ids_length); + + ASSERT_EQ(pairs.size() - 1, ids_length); + for (uint32_t i = 0; i < pairs.size() - 1; i++) { + ASSERT_EQ(pairs[i].second, ids[i]); + } + + reset(ids, ids_length); + trie->search_range(-32768, true, 134217728, true, ids, ids_length); + + ASSERT_EQ(pairs.size(), ids_length); + for (uint32_t i = 0; i < pairs.size(); i++) { + ASSERT_EQ(pairs[i].second, ids[i]); + } + + reset(ids, ids_length); + trie->search_range(-32768, true, 0, true, ids, ids_length); + + ASSERT_EQ(4, ids_length); + for (uint32_t i = 0; i < 4; i++) { + ASSERT_EQ(pairs[i].second, ids[i]); + } + + reset(ids, ids_length); + trie->search_range(-32768, true, 0, false, ids, ids_length); + + ASSERT_EQ(4, ids_length); + for (uint32_t i = 0; i < 4; i++) { + ASSERT_EQ(pairs[i].second, ids[i]); + } + + reset(ids, ids_length); + trie->search_range(-32768, false, 32768, true, ids, ids_length); + + ASSERT_EQ(pairs.size() - 1, ids_length); + for (uint32_t i = 0, j = 0; i < pairs.size(); i++) { + if (i == 3) continue; // id for -32768 would not be present + ASSERT_EQ(pairs[i].second, ids[j++]); + } + + reset(ids, ids_length); + trie->search_range(-134217728, true, 32768, true, ids, ids_length); + + ASSERT_EQ(pairs.size(), ids_length); + for (uint32_t i = 0; i < pairs.size(); i++) { + ASSERT_EQ(pairs[i].second, ids[i]); + } + + reset(ids, ids_length); + trie->search_range(-134217728, true, 
134217728, true, ids, ids_length); + + ASSERT_EQ(pairs.size(), ids_length); + for (uint32_t i = 0; i < pairs.size(); i++) { + ASSERT_EQ(pairs[i].second, ids[i]); + } + + reset(ids, ids_length); + trie->search_range(-1, true, 32768, true, ids, ids_length); + + ASSERT_EQ(4, ids_length); + for (uint32_t i = 4, j = 0; i < pairs.size(); i++, j++) { + ASSERT_EQ(pairs[i].second, ids[j]); + } + + reset(ids, ids_length); + trie->search_range(-1, false, 32768, true, ids, ids_length); + + ASSERT_EQ(4, ids_length); + for (uint32_t i = 4, j = 0; i < pairs.size(); i++, j++) { + ASSERT_EQ(pairs[i].second, ids[j]); + } + + reset(ids, ids_length); + trie->search_range(-1, true, 0, true, ids, ids_length); + ASSERT_EQ(0, ids_length); + + reset(ids, ids_length); + trie->search_range(-1, false, 0, false, ids, ids_length); + ASSERT_EQ(0, ids_length); + + reset(ids, ids_length); + trie->search_range(8192, true, 32768, true, ids, ids_length); + + ASSERT_EQ(4, ids_length); + for (uint32_t i = 4, j = 0; i < ids_length; i++, j++) { + ASSERT_EQ(pairs[i].second, ids[j]); + } + + reset(ids, ids_length); + trie->search_range(8192, true, 0x2000000, true, ids, ids_length); + + ASSERT_EQ(4, ids_length); + for (uint32_t i = 4, j = 0; i < ids_length; i++, j++) { + ASSERT_EQ(pairs[i].second, ids[j]); + } + + reset(ids, ids_length); + trie->search_range(16384, true, 16384, true, ids, ids_length); + + ASSERT_EQ(1, ids_length); + ASSERT_EQ(56, ids[0]); + + reset(ids, ids_length); + trie->search_range(16384, true, 16384, false, ids, ids_length); + + ASSERT_EQ(0, ids_length); + + reset(ids, ids_length); + trie->search_range(16384, false, 16384, true, ids, ids_length); + + ASSERT_EQ(0, ids_length); + + reset(ids, ids_length); + trie->search_range(16383, true, 16383, true, ids, ids_length); + + ASSERT_EQ(0, ids_length); + + reset(ids, ids_length); + trie->search_range(8193, true, 16383, true, ids, ids_length); + + ASSERT_EQ(0, ids_length); + + reset(ids, ids_length); + trie->search_range(-32768, true, -8192, 
true, ids, ids_length); + + ASSERT_EQ(4, ids_length); + for (uint32_t i = 0; i < ids_length; i++) { + ASSERT_EQ(pairs[i].second, ids[i]); + } + + reset(ids, ids_length); +} + +TEST_F(NumericRangeTrieTest, SearchGreaterThan) { + auto trie = new NumericTrie(); + std::unique_ptr trie_guard(trie); + std::vector> pairs = { + {-8192, 8}, + {-16384, 32}, + {-24576, 35}, + {-32768, 43}, + {8192, 49}, + {16384, 56}, + {24576, 58}, + {32768, 91} + }; + + for (auto const& pair: pairs) { + trie->insert(pair.first, pair.second); + } + + uint32_t* ids = nullptr; + uint32_t ids_length = 0; + + trie->search_greater_than(0, true, ids, ids_length); + + ASSERT_EQ(4, ids_length); + for (uint32_t i = 4, j = 0; i < pairs.size(); i++, j++) { + ASSERT_EQ(pairs[i].second, ids[j]); + } + + reset(ids, ids_length); + trie->search_greater_than(-1, false, ids, ids_length); + + ASSERT_EQ(4, ids_length); + for (uint32_t i = 4, j = 0; i < pairs.size(); i++, j++) { + ASSERT_EQ(pairs[i].second, ids[j]); + } + + reset(ids, ids_length); + trie->search_greater_than(-1, true, ids, ids_length); + + ASSERT_EQ(4, ids_length); + for (uint32_t i = 4, j = 0; i < pairs.size(); i++, j++) { + ASSERT_EQ(pairs[i].second, ids[j]); + } + + reset(ids, ids_length); + trie->search_greater_than(-24576, true, ids, ids_length); + + ASSERT_EQ(7, ids_length); + for (uint32_t i = 0, j = 0; i < pairs.size(); i++) { + if (i == 3) continue; // id for -32768 would not be present + ASSERT_EQ(pairs[i].second, ids[j++]); + } + + reset(ids, ids_length); + trie->search_greater_than(-32768, false, ids, ids_length); + + ASSERT_EQ(7, ids_length); + for (uint32_t i = 0, j = 0; i < pairs.size(); i++) { + if (i == 3) continue; // id for -32768 would not be present + ASSERT_EQ(pairs[i].second, ids[j++]); + } + + reset(ids, ids_length); + trie->search_greater_than(8192, true, ids, ids_length); + + ASSERT_EQ(4, ids_length); + for (uint32_t i = 4, j = 0; i < pairs.size(); i++, j++) { + ASSERT_EQ(pairs[i].second, ids[j]); + } + + reset(ids, 
ids_length); + trie->search_greater_than(8192, false, ids, ids_length); + + ASSERT_EQ(3, ids_length); + for (uint32_t i = 5, j = 0; i < pairs.size(); i++, j++) { + ASSERT_EQ(pairs[i].second, ids[j]); + } + + reset(ids, ids_length); + trie->search_greater_than(1000000, false, ids, ids_length); + + ASSERT_EQ(0, ids_length); + + reset(ids, ids_length); + trie->search_greater_than(-1000000, false, ids, ids_length); + + ASSERT_EQ(8, ids_length); + for (uint32_t i = 0; i < pairs.size(); i++) { + ASSERT_EQ(pairs[i].second, ids[i]); + } + + reset(ids, ids_length); +} + +TEST_F(NumericRangeTrieTest, SearchLessThan) { + auto trie = new NumericTrie(); + std::unique_ptr trie_guard(trie); + std::vector> pairs = { + {-32768, 8}, + {-24576, 32}, + {-16384, 35}, + {-8192, 43}, + {8192, 49}, + {16384, 56}, + {24576, 58}, + {32768, 91} + }; + + for (auto const& pair: pairs) { + trie->insert(pair.first, pair.second); + } + + uint32_t* ids = nullptr; + uint32_t ids_length = 0; + + trie->search_less_than(0, true, ids, ids_length); + + ASSERT_EQ(4, ids_length); + for (uint32_t i = 4, j = 0; i < ids_length; i++, j++) { + ASSERT_EQ(pairs[i].second, ids[j]); + } + + reset(ids, ids_length); + trie->search_less_than(0, false, ids, ids_length); + + ASSERT_EQ(4, ids_length); + for (uint32_t i = 0; i < ids_length; i++) { + ASSERT_EQ(pairs[i].second, ids[i]); + } + + reset(ids, ids_length); + trie->search_less_than(-1, true, ids, ids_length); + + ASSERT_EQ(4, ids_length); + for (uint32_t i = 0; i < ids_length; i++) { + ASSERT_EQ(pairs[i].second, ids[i]); + } + + reset(ids, ids_length); + trie->search_less_than(-16384, true, ids, ids_length); + + ASSERT_EQ(3, ids_length); + for (uint32_t i = 0; i < ids_length; i++) { + ASSERT_EQ(pairs[i].second, ids[i]); + } + + reset(ids, ids_length); + trie->search_less_than(-16384, false, ids, ids_length); + + ASSERT_EQ(2, ids_length); + for (uint32_t i = 0; i < ids_length; i++) { + ASSERT_EQ(pairs[i].second, ids[i]); + } + + reset(ids, ids_length); + 
trie->search_less_than(8192, true, ids, ids_length); + + ASSERT_EQ(5, ids_length); + for (uint32_t i = 0; i < ids_length; i++) { + ASSERT_EQ(pairs[i].second, ids[i]); + } + + reset(ids, ids_length); + trie->search_less_than(8192, false, ids, ids_length); + + ASSERT_EQ(4, ids_length); + for (uint32_t i = 0; i < ids_length; i++) { + ASSERT_EQ(pairs[i].second, ids[i]); + } + + reset(ids, ids_length); + trie->search_less_than(-1000000, false, ids, ids_length); + + ASSERT_EQ(0, ids_length); + + reset(ids, ids_length); + trie->search_less_than(1000000, true, ids, ids_length); + + ASSERT_EQ(8, ids_length); + for (uint32_t i = 0; i < ids_length; i++) { + ASSERT_EQ(pairs[i].second, ids[i]); + } + + reset(ids, ids_length); +} + +TEST_F(NumericRangeTrieTest, SearchEqualTo) { + auto trie = new NumericTrie(); + std::unique_ptr trie_guard(trie); + std::vector> pairs = { + {-8192, 8}, + {-16384, 32}, + {-24576, 35}, + {-32769, 41}, + {-32768, 43}, + {-32767, 45}, + {8192, 49}, + {16384, 56}, + {24576, 58}, + {32768, 91} + }; + + for (auto const& pair: pairs) { + trie->insert(pair.first, pair.second); + } + + uint32_t* ids = nullptr; + uint32_t ids_length = 0; + + trie->search_equal_to(0, ids, ids_length); + + ASSERT_EQ(0, ids_length); + + reset(ids, ids_length); + trie->search_equal_to(-32768, ids, ids_length); + + ASSERT_EQ(1, ids_length); + ASSERT_EQ(43, ids[0]); + + reset(ids, ids_length); + trie->search_equal_to(24576, ids, ids_length); + + ASSERT_EQ(1, ids_length); + ASSERT_EQ(58, ids[0]); + + reset(ids, ids_length); + trie->search_equal_to(0x202020, ids, ids_length); + + ASSERT_EQ(0, ids_length); +} + +TEST_F(NumericRangeTrieTest, IterateSearchEqualTo) { + auto trie = new NumericTrie(); + std::unique_ptr trie_guard(trie); + std::vector> pairs = { + {-8192, 8}, + {-16384, 32}, + {-24576, 35}, + {-32769, 41}, + {-32768, 43}, + {-32767, 45}, + {8192, 49}, + {16384, 56}, + {24576, 58}, + {24576, 60}, + {32768, 91} + }; + + for (auto const& pair: pairs) { + 
trie->insert(pair.first, pair.second); + } + + uint32_t* ids = nullptr; + uint32_t ids_length = 0; + + auto iterator = trie->search_equal_to(0); + ASSERT_EQ(false, iterator.is_valid); + + iterator = trie->search_equal_to(0x202020); + ASSERT_EQ(false, iterator.is_valid); + + iterator = trie->search_equal_to(-32768); + ASSERT_EQ(true, iterator.is_valid); + ASSERT_EQ(43, iterator.seq_id); + + iterator.next(); + ASSERT_EQ(false, iterator.is_valid); + + iterator = trie->search_equal_to(24576); + ASSERT_EQ(true, iterator.is_valid); + ASSERT_EQ(58, iterator.seq_id); + + iterator.next(); + ASSERT_EQ(true, iterator.is_valid); + ASSERT_EQ(60, iterator.seq_id); + + iterator.next(); + ASSERT_EQ(false, iterator.is_valid); + + + iterator.reset(); + ASSERT_EQ(true, iterator.is_valid); + ASSERT_EQ(58, iterator.seq_id); + + iterator.skip_to(4); + ASSERT_EQ(true, iterator.is_valid); + ASSERT_EQ(58, iterator.seq_id); + + iterator.skip_to(59); + ASSERT_EQ(true, iterator.is_valid); + ASSERT_EQ(60, iterator.seq_id); + + iterator.skip_to(66); + ASSERT_EQ(false, iterator.is_valid); +} + +TEST_F(NumericRangeTrieTest, MultivalueData) { + auto trie = new NumericTrie(); + std::unique_ptr trie_guard(trie); + std::vector> pairs = { + {-0x202020, 32}, + {-32768, 5}, + {-32768, 8}, + {-24576, 32}, + {-16384, 35}, + {-8192, 43}, + {0, 43}, + {0, 49}, + {1, 8}, + {256, 91}, + {8192, 49}, + {16384, 56}, + {24576, 58}, + {32768, 91}, + {0x202020, 35}, + }; + + for (auto const& pair: pairs) { + trie->insert(pair.first, pair.second); + } + + uint32_t* ids = nullptr; + uint32_t ids_length = 0; + + trie->search_less_than(0, false, ids, ids_length); + + std::vector expected = {5, 8, 32, 35, 43}; + + ASSERT_EQ(5, ids_length); + for (uint32_t i = 0; i < ids_length; i++) { + ASSERT_EQ(expected[i], ids[i]); + } + + reset(ids, ids_length); + trie->search_less_than(-16380, false, ids, ids_length); + + ASSERT_EQ(4, ids_length); + + expected = {5, 8, 32, 35}; + for (uint32_t i = 0; i < ids_length; i++) { + 
ASSERT_EQ(expected[i], ids[i]); + } + + reset(ids, ids_length); + trie->search_less_than(16384, false, ids, ids_length); + + ASSERT_EQ(7, ids_length); + + expected = {5, 8, 32, 35, 43, 49, 91}; + for (uint32_t i = 0; i < ids_length; i++) { + ASSERT_EQ(expected[i], ids[i]); + } + + reset(ids, ids_length); + trie->search_greater_than(0, true, ids, ids_length); + + ASSERT_EQ(7, ids_length); + + expected = {8, 35, 43, 49, 56, 58, 91}; + for (uint32_t i = 0; i < ids_length; i++) { + ASSERT_EQ(expected[i], ids[i]); + } + + reset(ids, ids_length); + trie->search_greater_than(256, true, ids, ids_length); + + ASSERT_EQ(5, ids_length); + + expected = {35, 49, 56, 58, 91}; + for (uint32_t i = 0; i < ids_length; i++) { + ASSERT_EQ(expected[i], ids[i]); + } + + reset(ids, ids_length); + trie->search_greater_than(-32768, true, ids, ids_length); + + ASSERT_EQ(9, ids_length); + + expected = {5, 8, 32, 35, 43, 49, 56, 58, 91}; + for (uint32_t i = 0; i < ids_length; i++) { + ASSERT_EQ(expected[i], ids[i]); + } + + reset(ids, ids_length); + trie->search_range(-32768, true, 0, true, ids, ids_length); + + ASSERT_EQ(6, ids_length); + + expected = {5, 8, 32, 35, 43, 49}; + for (uint32_t i = 0; i < ids_length; i++) { + ASSERT_EQ(expected[i], ids[i]); + } + + reset(ids, ids_length); +} + +TEST_F(NumericRangeTrieTest, Remove) { + auto trie = new NumericTrie(); + std::unique_ptr trie_guard(trie); + std::vector> pairs = { + {-0x202020, 32}, + {-32768, 5}, + {-32768, 8}, + {-24576, 32}, + {-16384, 35}, + {-8192, 43}, + {0, 2}, + {0, 49}, + {1, 8}, + {256, 91}, + {8192, 49}, + {16384, 56}, + {24576, 58}, + {32768, 91}, + {0x202020, 35}, + }; + + for (auto const& pair: pairs) { + trie->insert(pair.first, pair.second); + } + + uint32_t* ids = nullptr; + uint32_t ids_length = 0; + + trie->search_less_than(0, false, ids, ids_length); + + std::vector expected = {5, 8, 32, 35, 43}; + + ASSERT_EQ(5, ids_length); + for (uint32_t i = 0; i < ids_length; i++) { + ASSERT_EQ(expected[i], ids[i]); + } + + 
trie->remove(-24576, 32); + trie->remove(-0x202020, 32); + + reset(ids, ids_length); + trie->search_less_than(0, false, ids, ids_length); + + expected = {5, 8, 35, 43}; + ASSERT_EQ(4, ids_length); + for (uint32_t i = 0; i < ids_length; i++) { + ASSERT_EQ(expected[i], ids[i]); + } + + reset(ids, ids_length); + trie->search_equal_to(0, ids, ids_length); + + expected = {2, 49}; + ASSERT_EQ(2, ids_length); + for (uint32_t i = 0; i < ids_length; i++) { + ASSERT_EQ(expected[i], ids[i]); + } + + trie->remove(0, 2); + + reset(ids, ids_length); + trie->search_equal_to(0, ids, ids_length); + + ASSERT_EQ(1, ids_length); + ASSERT_EQ(49, ids[0]); + + reset(ids, ids_length); +} + +TEST_F(NumericRangeTrieTest, EmptyTrieOperations) { + auto trie = new NumericTrie(); + std::unique_ptr trie_guard(trie); + + uint32_t* ids = nullptr; + uint32_t ids_length = 0; + + trie->search_range(-32768, true, 32768, true, ids, ids_length); + std::unique_ptr ids_guard(ids); + + ASSERT_EQ(0, ids_length); + + trie->search_range(-32768, true, -1, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(0, ids_length); + + trie->search_range(1, true, 32768, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(0, ids_length); + + trie->search_greater_than(0, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(0, ids_length); + + trie->search_greater_than(15, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(0, ids_length); + + trie->search_greater_than(-15, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(0, ids_length); + + trie->search_less_than(0, false, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(0, ids_length); + + trie->search_less_than(-15, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(0, ids_length); + + trie->search_less_than(15, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(0, ids_length); + + trie->search_equal_to(15, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(0, ids_length); + + 
trie->remove(15, 0); + trie->remove(-15, 0); +} + +TEST_F(NumericRangeTrieTest, Integration) { + Collection *coll_array_fields; + + std::ifstream infile(std::string(ROOT_DIR)+"test/numeric_array_documents.jsonl"); + std::vector fields = { + field("name", field_types::STRING, false), + field("rating", field_types::FLOAT, false), + field("age", field_types::INT32, false, false, true, "", -1, -1, false, 0, 0, cosine, "", nlohmann::json(), + true), // Setting range index true. + field("years", field_types::INT32_ARRAY, false), + field("timestamps", field_types::INT64_ARRAY, false, false, true, "", -1, -1, false, 0, 0, cosine, "", + nlohmann::json(), true), + field("tags", field_types::STRING_ARRAY, true) + }; + + std::vector sort_fields = { sort_by("age", "DESC") }; + + coll_array_fields = collectionManager.get_collection("coll_array_fields").get(); + if(coll_array_fields == nullptr) { + // ensure that default_sorting_field is a non-array numerical field + auto coll_op = collectionManager.create_collection("coll_array_fields", 4, fields, "years"); + ASSERT_EQ(false, coll_op.ok()); + ASSERT_STREQ("Default sorting field `years` is not a sortable type.", coll_op.error().c_str()); + + // let's try again properly + coll_op = collectionManager.create_collection("coll_array_fields", 4, fields, "age"); + coll_array_fields = coll_op.get(); + } + + std::string json_line; + + while (std::getline(infile, json_line)) { + auto add_op = coll_array_fields->add(json_line); + ASSERT_TRUE(add_op.ok()); + } + + infile.close(); + + query_fields = {"name"}; + std::vector facets; + // Searching on an int32 field + nlohmann::json results = coll_array_fields->search("Jeremy", query_fields, "age:>24", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get(); + ASSERT_EQ(3, results["hits"].size()); + + std::vector ids = {"3", "1", "4"}; + + for(size_t i = 0; i < results["hits"].size(); i++) { + nlohmann::json result = results["hits"].at(i); + std::string result_id = result["document"]["id"]; + 
std::string id = ids.at(i); + ASSERT_STREQ(id.c_str(), result_id.c_str()); + } + + // searching on an int64 array field - also ensure that padded space causes no issues + results = coll_array_fields->search("Jeremy", query_fields, "timestamps : > 475205222", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get(); + ASSERT_EQ(4, results["hits"].size()); + + ids = {"1", "4", "0", "2"}; + + for(size_t i = 0; i < results["hits"].size(); i++) { + nlohmann::json result = results["hits"].at(i); + std::string result_id = result["document"]["id"]; + std::string id = ids.at(i); + ASSERT_STREQ(id.c_str(), result_id.c_str()); + } + + results = coll_array_fields->search("Jeremy", query_fields, "rating: [7.812 .. 9.999, 1.05 .. 1.09]", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get(); + ASSERT_EQ(3, results["hits"].size()); +}