diff --git a/include/numeric_range_trie_test.h b/include/numeric_range_trie_test.h index 737cf5f9..f5d6add2 100644 --- a/include/numeric_range_trie_test.h +++ b/include/numeric_range_trie_test.h @@ -3,21 +3,25 @@ #include #include "sorted_array.h" -constexpr char MAX_LEVEL = 4; constexpr short EXPANSE = 256; class NumericTrie { + char max_level = 4; + class Node { Node** children = nullptr; sorted_array seq_ids; - void insert_helper(const int32_t& value, const uint32_t& seq_id, char& level); + void insert_helper(const int64_t& value, const uint32_t& seq_id, char& level, const char& max_level); - void search_range_helper(const int32_t& low,const int32_t& high, std::vector& matches); + void search_range_helper(const int64_t& low,const int64_t& high, const char& max_level, + std::vector& matches); - void search_less_than_helper(const int32_t& value, char& level, std::vector& matches); + void search_less_than_helper(const int64_t& value, char& level, const char& max_level, + std::vector& matches); - void search_greater_than_helper(const int32_t& value, char& level, std::vector& matches); + void search_greater_than_helper(const int64_t& value, char& level, const char& max_level, + std::vector& matches); public: @@ -31,18 +35,18 @@ class NumericTrie { delete [] children; } - void insert(const int32_t& value, const uint32_t& seq_id); + void insert(const int64_t& value, const uint32_t& seq_id, const char& max_level); void get_all_ids(uint32_t*& ids, uint32_t& ids_length); - void search_range(const int32_t& low, const int32_t& high, + void search_range(const int64_t& low, const int64_t& high, const char& max_level, uint32_t*& ids, uint32_t& ids_length); - void search_less_than(const int32_t& value, uint32_t*& ids, uint32_t& ids_length); + void search_less_than(const int64_t& value, const char& max_level, uint32_t*& ids, uint32_t& ids_length); - void search_greater_than(const int32_t& value, uint32_t*& ids, uint32_t& ids_length); + void search_greater_than(const int64_t& value, const char& max_level, uint32_t*& ids, uint32_t& ids_length); - void search_equal_to(const int32_t& value, uint32_t*& ids, uint32_t& ids_length); + void search_equal_to(const int64_t& value, const char& max_level, uint32_t*& ids, uint32_t& ids_length); }; Node* negative_trie = nullptr; @@ -50,22 +54,26 @@ class NumericTrie { public: + explicit NumericTrie(char num_bits = 32) { + max_level = num_bits / 8; + } + ~NumericTrie() { delete negative_trie; delete positive_trie; } - void insert(const int32_t& value, const uint32_t& seq_id); + void insert(const int64_t& value, const uint32_t& seq_id); - void search_range(const int32_t& low, const bool& low_inclusive, - const int32_t& high, const bool& high_inclusive, + void search_range(const int64_t& low, const bool& low_inclusive, + const int64_t& high, const bool& high_inclusive, uint32_t*& ids, uint32_t& ids_length); - void search_less_than(const int32_t& value, const bool& inclusive, + void search_less_than(const int64_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length); - void search_greater_than(const int32_t& value, const bool& inclusive, + void search_greater_than(const int64_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length); - void search_equal_to(const int32_t& value, uint32_t*& ids, uint32_t& ids_length); + void search_equal_to(const int64_t& value, uint32_t*& ids, uint32_t& ids_length); }; diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index 8a3c6d89..0e0a8b9a 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -646,7 +646,7 @@ void filter_result_iterator_t::init() { field f = index->search_schema.at(a_filter.field_name); if (f.is_integer()) { - if (f.is_int32() && f.range_index) { + if (f.range_index) { auto const& trie = index->range_index.at(a_filter.field_name); for (size_t fi = 0; fi < a_filter.values.size(); fi++) { @@ -718,28 +718,64 @@ void filter_result_iterator_t::init() { is_filter_result_initialized = true; return; } else if (f.is_float()) { - auto num_tree = index->numerical_index.at(a_filter.field_name); + if (f.range_index) { + auto const& trie = index->range_index.at(a_filter.field_name); - for (size_t fi = 0; fi < a_filter.values.size(); fi++) { - const std::string& filter_value = a_filter.values[fi]; - float value = (float)std::atof(filter_value.c_str()); - int64_t float_int64 = Index::float_to_int64_t(value); + for (size_t fi = 0; fi < a_filter.values.size(); fi++) { + const std::string& filter_value = a_filter.values[fi]; + float value = (float)std::atof(filter_value.c_str()); + int64_t float_int64 = Index::float_to_int64_t(value); - size_t result_size = filter_result.count; - if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) { - const std::string& next_filter_value = a_filter.values[fi+1]; - int64_t range_end_value = Index::float_to_int64_t((float) std::atof(next_filter_value.c_str())); - num_tree->range_inclusive_search(float_int64, range_end_value, &filter_result.docs, result_size); - fi++; - } else if (a_filter.comparators[fi] == NOT_EQUALS) { - numeric_not_equals_filter(num_tree, float_int64, - index->seq_ids->uncompress(), index->seq_ids->num_ids(), - filter_result.docs, result_size); - } else { - num_tree->search(a_filter.comparators[fi], float_int64, &filter_result.docs, result_size); + if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) { + const std::string& next_filter_value = a_filter.values[fi + 1]; + int64_t range_end_value = Index::float_to_int64_t((float) std::atof(next_filter_value.c_str())); + trie->search_range(float_int64, true, range_end_value, true, filter_result.docs, filter_result.count); + fi++; + } else if (a_filter.comparators[fi] == EQUALS) { + trie->search_equal_to(float_int64, filter_result.docs, filter_result.count); + } else if (a_filter.comparators[fi] == NOT_EQUALS) { + uint32_t* to_exclude_ids = nullptr; + uint32_t to_exclude_ids_len = 0; + trie->search_equal_to(float_int64, to_exclude_ids, to_exclude_ids_len); + + auto all_ids = index->seq_ids->uncompress(); + filter_result.count = ArrayUtils::exclude_scalar(all_ids, index->seq_ids->num_ids(), + to_exclude_ids, to_exclude_ids_len, &filter_result.docs); + + delete[] all_ids; + delete[] to_exclude_ids; + } else if (a_filter.comparators[fi] == GREATER_THAN || a_filter.comparators[fi] == GREATER_THAN_EQUALS) { + trie->search_greater_than(float_int64, a_filter.comparators[fi] == GREATER_THAN_EQUALS, + filter_result.docs, filter_result.count); + } else if (a_filter.comparators[fi] == LESS_THAN || a_filter.comparators[fi] == LESS_THAN_EQUALS) { + trie->search_less_than(float_int64, a_filter.comparators[fi] == LESS_THAN_EQUALS, + filter_result.docs, filter_result.count); + } } + } else { + auto num_tree = index->numerical_index.at(a_filter.field_name); - filter_result.count = result_size; + for (size_t fi = 0; fi < a_filter.values.size(); fi++) { + const std::string& filter_value = a_filter.values[fi]; + float value = (float)std::atof(filter_value.c_str()); + int64_t float_int64 = Index::float_to_int64_t(value); + + size_t result_size = filter_result.count; + if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) { + const std::string& next_filter_value = a_filter.values[fi+1]; + int64_t range_end_value = Index::float_to_int64_t((float) std::atof(next_filter_value.c_str())); + num_tree->range_inclusive_search(float_int64, range_end_value, &filter_result.docs, result_size); + fi++; + } else if (a_filter.comparators[fi] == NOT_EQUALS) { + numeric_not_equals_filter(num_tree, float_int64, + index->seq_ids->uncompress(), index->seq_ids->num_ids(), + filter_result.docs, result_size); + } else { + num_tree->search(a_filter.comparators[fi], float_int64, &filter_result.docs, result_size); + } + + filter_result.count = result_size; + } } if (a_filter.apply_not_equals) { diff --git a/src/index.cpp b/src/index.cpp index cc2510ec..d7c9940e 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -90,7 +90,7 @@ Index::Index(const std::string& name, const uint32_t collection_id, const Store* numerical_index.emplace(a_field.name, num_tree); if (a_field.range_index) { - auto trie = new NumericTrie(); + auto trie = a_field.is_int32() ? new NumericTrie() : new NumericTrie(64); range_index.emplace(a_field.name, trie); } } @@ -767,6 +767,15 @@ void Index::index_field_in_memory(const field& afield, std::vector } else if(afield.type == field_types::INT64) { + if (afield.range_index) { + auto const& trie = range_index.at(afield.name); + iterate_and_index_numerical_field(iter_batch, afield, [&afield, trie] + (const index_record& record, uint32_t seq_id) { + int64_t value = record.doc[afield.name].get(); + trie->insert(value, seq_id); + }); + } + auto num_tree = numerical_index.at(afield.name); iterate_and_index_numerical_field(iter_batch, afield, [&afield, num_tree] (const index_record& record, uint32_t seq_id) { @@ -776,6 +785,16 @@ void Index::index_field_in_memory(const field& afield, std::vector } else if(afield.type == field_types::FLOAT) { + if (afield.range_index) { + auto const& trie = range_index.at(afield.name); + iterate_and_index_numerical_field(iter_batch, afield, [&afield, trie] + (const index_record& record, uint32_t seq_id) { + float fvalue = record.doc[afield.name].get(); + int64_t value = float_to_int64_t(fvalue); + trie->insert(value, seq_id); + }); + } + auto num_tree = numerical_index.at(afield.name); iterate_and_index_numerical_field(iter_batch, afield, [&afield, num_tree] (const index_record& record, uint32_t seq_id) { @@ -928,23 +947,30 @@ void Index::index_field_in_memory(const field& afield, std::vector if(afield.type == field_types::INT32_ARRAY) { const int32_t value = arr_value; + num_tree->insert(value, seq_id); if (afield.range_index) { trie->insert(value, seq_id); } - - num_tree->insert(value, seq_id); } else if(afield.type == field_types::INT64_ARRAY) { const int64_t value = arr_value; num_tree->insert(value, seq_id); + + if (afield.range_index) { + trie->insert(value, seq_id); + } } else if(afield.type == field_types::FLOAT_ARRAY) { const float fvalue = arr_value; int64_t value = float_to_int64_t(fvalue); num_tree->insert(value, seq_id); + + if (afield.range_index) { + trie->insert(value, seq_id); + } } else if(afield.type == field_types::BOOL_ARRAY) { diff --git a/src/numeric_range_trie.cpp b/src/numeric_range_trie.cpp index 6ee805d0..3076f873 100644 --- a/src/numeric_range_trie.cpp +++ b/src/numeric_range_trie.cpp @@ -2,24 +2,24 @@ #include "numeric_range_trie_test.h" #include "array_utils.h" -void NumericTrie::insert(const int32_t& value, const uint32_t& seq_id) { +void NumericTrie::insert(const int64_t& value, const uint32_t& seq_id) { if (value < 0) { if (negative_trie == nullptr) { negative_trie = new NumericTrie::Node(); } - negative_trie->insert(std::abs(value), seq_id); + negative_trie->insert(std::abs(value), seq_id, max_level); } else { if (positive_trie == nullptr) { positive_trie = new NumericTrie::Node(); } - positive_trie->insert(value, seq_id); + positive_trie->insert(value, seq_id, max_level); } } -void NumericTrie::search_range(const int32_t& low, const bool& low_inclusive, - const int32_t& high, const bool& high_inclusive, +void NumericTrie::search_range(const int64_t& low, const bool& low_inclusive, + const int64_t& high, const bool& high_inclusive, uint32_t*& ids, uint32_t& ids_length) { if (low > high) { return; @@ -34,7 +34,8 @@ void NumericTrie::search_range(const int32_t& low, const bool& low_inclusive, auto abs_low = std::abs(low); // Since we store absolute values, search_lesser would yield result for >low from negative_trie. - negative_trie->search_less_than(low_inclusive ? abs_low : abs_low - 1, negative_ids, negative_ids_length); + negative_trie->search_less_than(low_inclusive ? abs_low : abs_low - 1, max_level, + negative_ids, negative_ids_length); uint32_t* out = nullptr; ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, ids, ids_length, &out); @@ -47,7 +48,8 @@ void NumericTrie::search_range(const int32_t& low, const bool& low_inclusive, if (positive_trie != nullptr && !(high == 0 && !high_inclusive)) { // No need to search for ..., 0) uint32_t* positive_ids = nullptr; uint32_t positive_ids_length = 0; - positive_trie->search_less_than(high_inclusive ? high : high - 1, positive_ids, positive_ids_length); + positive_trie->search_less_than(high_inclusive ? high : high - 1, max_level, + positive_ids, positive_ids_length); uint32_t* out = nullptr; ids_length = ArrayUtils::or_scalar(positive_ids, positive_ids_length, ids, ids_length, &out); @@ -64,7 +66,7 @@ void NumericTrie::search_range(const int32_t& low, const bool& low_inclusive, uint32_t* positive_ids = nullptr; uint32_t positive_ids_length = 0; - positive_trie->search_range(low_inclusive ? low : low + 1, high_inclusive ? high : high - 1, + positive_trie->search_range(low_inclusive ? low : low + 1, high_inclusive ? high : high - 1, max_level, positive_ids, positive_ids_length); uint32_t* out = nullptr; @@ -84,6 +86,7 @@ void NumericTrie::search_range(const int32_t& low, const bool& low_inclusive, // Since we store absolute values, switching low and high would produce the correct result. auto abs_high = std::abs(high), abs_low = std::abs(low); negative_trie->search_range(high_inclusive ? abs_high : abs_high + 1, low_inclusive ? abs_low : abs_low - 1, + max_level, negative_ids, negative_ids_length); uint32_t* out = nullptr; @@ -95,7 +98,7 @@ void NumericTrie::search_range(const int32_t& low, const bool& low_inclusive, } } -void NumericTrie::search_greater_than(const int32_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length) { +void NumericTrie::search_greater_than(const int64_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length) { if ((value == 0 && inclusive) || (value == -1 && !inclusive)) { // [0, ∞), (-1, ∞) if (positive_trie != nullptr) { uint32_t* positive_ids = nullptr; @@ -119,7 +122,7 @@ void NumericTrie::search_greater_than(const int32_t& value, const bool& inclusiv uint32_t* positive_ids = nullptr; uint32_t positive_ids_length = 0; - positive_trie->search_greater_than(inclusive ? value : value + 1, positive_ids, positive_ids_length); + positive_trie->search_greater_than(inclusive ? value : value + 1, max_level, positive_ids, positive_ids_length); uint32_t* out = nullptr; ids_length = ArrayUtils::or_scalar(positive_ids, positive_ids_length, ids, ids_length, &out); @@ -136,7 +139,8 @@ void NumericTrie::search_greater_than(const int32_t& value, const bool& inclusiv auto abs_low = std::abs(value); // Since we store absolute values, search_lesser would yield result for >value from negative_trie. - negative_trie->search_less_than(inclusive ? abs_low : abs_low - 1, negative_ids, negative_ids_length); + negative_trie->search_less_than(inclusive ? abs_low : abs_low - 1, max_level, + negative_ids, negative_ids_length); uint32_t* out = nullptr; ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, ids, ids_length, &out); @@ -163,7 +167,7 @@ void NumericTrie::search_greater_than(const int32_t& value, const bool& inclusiv } } -void NumericTrie::search_less_than(const int32_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length) { +void NumericTrie::search_less_than(const int64_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length) { if ((value == 0 && !inclusive) || (value == -1 && inclusive)) { // (-∞, 0), (-∞, -1] if (negative_trie != nullptr) { uint32_t* negative_ids = nullptr; @@ -190,7 +194,8 @@ void NumericTrie::search_less_than(const int32_t& value, const bool& inclusive, auto abs_low = std::abs(value); // Since we store absolute values, search_greater would yield result for search_greater_than(inclusive ? abs_low : abs_low + 1, negative_ids, negative_ids_length); + negative_trie->search_greater_than(inclusive ? abs_low : abs_low + 1, max_level, + negative_ids, negative_ids_length); uint32_t* out = nullptr; ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, ids, ids_length, &out); @@ -204,7 +209,8 @@ void NumericTrie::search_less_than(const int32_t& value, const bool& inclusive, if (positive_trie != nullptr) { uint32_t* positive_ids = nullptr; uint32_t positive_ids_length = 0; - positive_trie->search_less_than(inclusive ? value : value - 1, positive_ids, positive_ids_length); + positive_trie->search_less_than(inclusive ? value : value - 1, max_level, + positive_ids, positive_ids_length); uint32_t* out = nullptr; ids_length = ArrayUtils::or_scalar(positive_ids, positive_ids_length, ids, ids_length, &out); @@ -231,7 +237,7 @@ void NumericTrie::search_less_than(const int32_t& value, const bool& inclusive, } } -void NumericTrie::search_equal_to(const int32_t& value, uint32_t*& ids, uint32_t& ids_length) { +void NumericTrie::search_equal_to(const int64_t& value, uint32_t*& ids, uint32_t& ids_length) { if ((value < 0 && negative_trie == nullptr) || (value >= 0 && positive_trie == nullptr)) { return; } @@ -240,9 +246,9 @@ void NumericTrie::search_equal_to(const int32_t& value, uint32_t*& ids, uint32_t uint32_t equal_ids_length = 0; if (value < 0) { - negative_trie->search_equal_to(std::abs(value), equal_ids, equal_ids_length); + negative_trie->search_equal_to(std::abs(value), max_level, equal_ids, equal_ids_length); } else { - positive_trie->search_equal_to(value, equal_ids, equal_ids_length); + positive_trie->search_equal_to(value, max_level, equal_ids, equal_ids_length); } uint32_t* out = nullptr; @@ -253,12 +259,12 @@ void NumericTrie::search_equal_to(const int32_t& value, uint32_t*& ids, uint32_t ids = out; } -void NumericTrie::Node::insert(const int32_t& value, const uint32_t& seq_id) { +void NumericTrie::Node::insert(const int64_t& value, const uint32_t& seq_id, const char& max_level) { char level = 0; - return insert_helper(value, seq_id, level); + return insert_helper(value, seq_id, level, max_level); } -inline int get_index(const int32_t& value, char& level) { +inline int get_index(const int64_t& value, const char& level, const char& max_level) { // Values are index considering higher order of the bytes first. // 0x01020408 (16909320) would be indexed in the trie as follows: // Level Index @@ -266,11 +272,11 @@ inline int get_index(const int32_t& value, char& level) { // 2 2 // 3 4 // 4 8 - return (value >> (8 * (MAX_LEVEL - level))) & 0xFF; + return (value >> (8 * (max_level - level))) & 0xFF; } -void NumericTrie::Node::insert_helper(const int32_t& value, const uint32_t& seq_id, char& level) { - if (level > MAX_LEVEL) { +void NumericTrie::Node::insert_helper(const int64_t& value, const uint32_t& seq_id, char& level, const char& max_level) { + if (level > max_level) { return; } @@ -279,17 +285,17 @@ void NumericTrie::Node::insert_helper(const int32_t& value, const uint32_t& seq_ seq_ids.append(seq_id); } - if (++level <= MAX_LEVEL) { + if (++level <= max_level) { if (children == nullptr) { children = new NumericTrie::Node* [EXPANSE]{nullptr}; } - auto index = get_index(value, level); + auto index = get_index(value, level, max_level); if (children[index] == nullptr) { children[index] = new NumericTrie::Node(); } - return children[index]->insert_helper(value, seq_id, level); + return children[index]->insert_helper(value, seq_id, level, max_level); } } @@ -298,10 +304,11 @@ void NumericTrie::Node::get_all_ids(uint32_t*& ids, uint32_t& ids_length) { ids_length = seq_ids.getLength(); } -void NumericTrie::Node::search_less_than(const int32_t& value, uint32_t*& ids, uint32_t& ids_length) { +void NumericTrie::Node::search_less_than(const int64_t& value, const char& max_level, + uint32_t*& ids, uint32_t& ids_length) { char level = 0; std::vector matches; - search_less_than_helper(value, level, matches); + search_less_than_helper(value, level, max_level, matches); std::vector consolidated_ids; for (auto const& match: matches) { @@ -324,17 +331,18 @@ void NumericTrie::Node::search_less_than(const int32_t& value, uint32_t*& ids, u ids = out; } -void NumericTrie::Node::search_less_than_helper(const int32_t& value, char& level, std::vector& matches) { - if (level == MAX_LEVEL) { +void NumericTrie::Node::search_less_than_helper(const int64_t& value, char& level, const char& max_level, + std::vector& matches) { + if (level == max_level) { matches.push_back(this); return; - } else if (level > MAX_LEVEL || children == nullptr) { + } else if (level > max_level || children == nullptr) { return; } - auto index = get_index(value, ++level); + auto index = get_index(value, ++level, max_level); if (children[index] != nullptr) { - children[index]->search_less_than_helper(value, level, matches); + children[index]->search_less_than_helper(value, level, max_level, matches); } while (--index >= 0) { @@ -346,12 +354,13 @@ void NumericTrie::Node::search_less_than_helper(const int32_t& value, char& leve --level; } -void NumericTrie::Node::search_range(const int32_t& low, const int32_t& high, uint32_t*& ids, uint32_t& ids_length) { +void NumericTrie::Node::search_range(const int64_t& low, const int64_t& high, const char& max_level, + uint32_t*& ids, uint32_t& ids_length) { if (low > high) { return; } std::vector matches; - search_range_helper(low, high, matches); + search_range_helper(low, high, max_level, matches); std::vector consolidated_ids; for (auto const& match: matches) { @@ -374,24 +383,24 @@ void NumericTrie::Node::search_range(const int32_t& low, const int32_t& high, ui ids = out; } -void NumericTrie::Node::search_range_helper(const int32_t& low, const int32_t& high, - std::vector& matches) { +void NumericTrie::Node::search_range_helper(const int64_t& low,const int64_t& high, const char& max_level, + std::vector& matches) { // Segregating the nodes into matching low, in-between, and matching high. NumericTrie::Node* root = this; char level = 1; - auto low_index = get_index(low, level), high_index = get_index(high, level); + auto low_index = get_index(low, level, max_level), high_index = get_index(high, level, max_level); // Keep updating the root while the range is contained within a single child node. - while (root->children != nullptr && low_index == high_index && level < MAX_LEVEL) { + while (root->children != nullptr && low_index == high_index && level < max_level) { if (root->children[low_index] == nullptr) { return; } root = root->children[low_index]; level++; - low_index = get_index(low, level); - high_index = get_index(high, level); + low_index = get_index(low, level, max_level); + high_index = get_index(high, level, max_level); } if (root->children == nullptr) { @@ -405,7 +414,7 @@ void NumericTrie::Node::search_range_helper(const int32_t& low, const int32_t& h if (root->children[low_index] != nullptr) { // Collect all the sub-nodes that are greater than low. - root->children[low_index]->search_greater_than_helper(low, level, matches); + root->children[low_index]->search_greater_than_helper(low, level, max_level, matches); } auto index = low_index + 1; @@ -420,14 +429,15 @@ void NumericTrie::Node::search_range_helper(const int32_t& low, const int32_t& h if (index < EXPANSE && index == high_index && root->children[index] != nullptr) { // Collect all the sub-nodes that are lesser than high. - root->children[index]->search_less_than_helper(high, level, matches); + root->children[index]->search_less_than_helper(high, level, max_level, matches); } } -void NumericTrie::Node::search_greater_than(const int32_t& value, uint32_t*& ids, uint32_t& ids_length) { +void NumericTrie::Node::search_greater_than(const int64_t& value, const char& max_level, + uint32_t*& ids, uint32_t& ids_length) { char level = 0; std::vector matches; - search_greater_than_helper(value, level, matches); + search_greater_than_helper(value, level, max_level, matches); std::vector consolidated_ids; for (auto const& match: matches) { @@ -450,17 +460,18 @@ void NumericTrie::Node::search_greater_than(const int32_t& value, uint32_t*& ids ids = out; } -void NumericTrie::Node::search_greater_than_helper(const int32_t& value, char& level, std::vector& matches) { - if (level == MAX_LEVEL) { +void NumericTrie::Node::search_greater_than_helper(const int64_t& value, char& level, const char& max_level, + std::vector& matches) { + if (level == max_level) { matches.push_back(this); return; - } else if (level > MAX_LEVEL || children == nullptr) { + } else if (level > max_level || children == nullptr) { return; } - auto index = get_index(value, ++level); + auto index = get_index(value, ++level, max_level); if (children[index] != nullptr) { - children[index]->search_greater_than_helper(value, level, matches); + children[index]->search_greater_than_helper(value, level, max_level, matches); } while (++index < EXPANSE) { @@ -472,18 +483,19 @@ void NumericTrie::Node::search_greater_than_helper(const int32_t& value, char& l --level; } -void NumericTrie::Node::search_equal_to(const int32_t& value, uint32_t*& ids, uint32_t& ids_length) { +void NumericTrie::Node::search_equal_to(const int64_t& value, const char& max_level, + uint32_t*& ids, uint32_t& ids_length) { char level = 1; Node* root = this; - auto index = get_index(value, level); + auto index = get_index(value, level, max_level); - while (level <= MAX_LEVEL) { + while (level <= max_level) { if (root->children == nullptr || root->children[index] == nullptr) { return; } root = root->children[index]; - index = get_index(value, ++level); + index = get_index(value, ++level, max_level); } root->get_all_ids(ids, ids_length); diff --git a/test/numeric_range_trie_test.cpp b/test/numeric_range_trie_test.cpp index 5d9cca7d..29dff68a 100644 --- a/test/numeric_range_trie_test.cpp +++ b/test/numeric_range_trie_test.cpp @@ -604,7 +604,8 @@ TEST_F(NumericRangeTrieTest, Integration) { field("age", field_types::INT32, false, false, true, "", -1, -1, false, 0, 0, cosine, "", nlohmann::json(), true), // Setting range index true. field("years", field_types::INT32_ARRAY, false), - field("timestamps", field_types::INT64_ARRAY, false), + field("timestamps", field_types::INT64_ARRAY, false, false, true, "", -1, -1, false, 0, 0, cosine, "", + nlohmann::json(), true), field("tags", field_types::STRING_ARRAY, true) }; @@ -626,32 +627,18 @@ TEST_F(NumericRangeTrieTest, Integration) { while (std::getline(infile, json_line)) { auto add_op = coll_array_fields->add(json_line); - LOG(INFO) << add_op.error(); ASSERT_TRUE(add_op.ok()); } infile.close(); - // Plain search with no filters - results should be sorted by rank fields query_fields = {"name"}; std::vector facets; - nlohmann::json results = coll_array_fields->search("Jeremy", query_fields, "", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get(); - ASSERT_EQ(5, results["hits"].size()); - - std::vector ids = {"3", "1", "4", "0", "2"}; - - for(size_t i = 0; i < results["hits"].size(); i++) { - nlohmann::json result = results["hits"].at(i); - std::string result_id = result["document"]["id"]; - std::string id = ids.at(i); - ASSERT_STREQ(id.c_str(), result_id.c_str()); - } - // Searching on an int32 field - results = coll_array_fields->search("Jeremy", query_fields, "age:>24", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get(); + nlohmann::json results = coll_array_fields->search("Jeremy", query_fields, "age:>24", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get(); ASSERT_EQ(3, results["hits"].size()); - ids = {"3", "1", "4"}; + std::vector ids = {"3", "1", "4"}; for(size_t i = 0; i < results["hits"].size(); i++) { nlohmann::json result = results["hits"].at(i); @@ -659,4 +646,20 @@ TEST_F(NumericRangeTrieTest, Integration) { std::string id = ids.at(i); ASSERT_STREQ(id.c_str(), result_id.c_str()); } + + // searching on an int64 array field - also ensure that padded space causes no issues + results = coll_array_fields->search("Jeremy", query_fields, "timestamps : > 475205222", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get(); + ASSERT_EQ(4, results["hits"].size()); + + ids = {"1", "4", "0", "2"}; + + for(size_t i = 0; i < results["hits"].size(); i++) { + nlohmann::json result = results["hits"].at(i); + std::string result_id = result["document"]["id"]; + std::string id = ids.at(i); + ASSERT_STREQ(id.c_str(), result_id.c_str()); + } + + results = coll_array_fields->search("Jeremy", query_fields, "rating: [7.812 .. 9.999, 1.05 .. 1.09]", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get(); + ASSERT_EQ(3, results["hits"].size()); }