From d440e510c0c4ce6a8bf4288d9e59f98adcaa7c5b Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Fri, 26 May 2023 16:00:22 +0530 Subject: [PATCH] Add `NumericTrie`. --- include/numeric_range_trie_test.h | 60 +++++++++++++++ src/numeric_range_trie.cpp | 122 ++++++++++++++++++++++++++++++ test/numeric_range_trie_test.cpp | 41 ++++++++++ 3 files changed, 223 insertions(+) create mode 100644 include/numeric_range_trie_test.h create mode 100644 src/numeric_range_trie.cpp create mode 100644 test/numeric_range_trie_test.cpp diff --git a/include/numeric_range_trie_test.h b/include/numeric_range_trie_test.h new file mode 100644 index 00000000..ce1da83d --- /dev/null +++ b/include/numeric_range_trie_test.h @@ -0,0 +1,60 @@ +#pragma once + +#include +#include "sorted_array.h" + +constexpr char MAX_LEVEL = 4; +constexpr short EXPANSE = 256; + +class NumericTrieNode { + NumericTrieNode** children = nullptr; + sorted_array seq_ids; + + void insert(const int32_t& value, const uint32_t& seq_id, char& level); + + void search_lesser_helper(const int32_t& value, char& level, std::vector& matches); +public: + + ~NumericTrieNode() { + if (children != nullptr) { + for (auto i = 0; i < EXPANSE; i++) { + delete children[i]; + } + } + + delete [] children; + } + + void insert(const int32_t& value, const uint32_t& seq_id); + + void search_range(const int32_t& low,const int32_t& high, + uint32_t*& ids, uint32_t& ids_length); + + void search_lesser(const int32_t& value, uint32_t*& ids, uint32_t& ids_length); + + void search_greater(const int32_t& value, uint32_t*& ids, uint32_t& ids_length); +}; + +class NumericTrie { + NumericTrieNode* negative_trie = nullptr; + NumericTrieNode* positive_trie = nullptr; + +public: + + ~NumericTrie() { + delete negative_trie; + delete positive_trie; + } + + void insert(const int32_t& value, const uint32_t& seq_id); + + void search_range(const int32_t& low, const bool& low_inclusive, + const int32_t& high, const bool& high_inclusive, + uint32_t*& ids, uint32_t& ids_length); + + void search_lesser(const int32_t& value, const bool& inclusive, + uint32_t*& ids, uint32_t& ids_length); + + void search_greater(const int32_t& value, const bool& inclusive, + uint32_t*& ids, uint32_t& ids_length); +}; diff --git a/src/numeric_range_trie.cpp b/src/numeric_range_trie.cpp new file mode 100644 index 00000000..c62765a5 --- /dev/null +++ b/src/numeric_range_trie.cpp @@ -0,0 +1,122 @@ +#include "numeric_range_trie_test.h" +#include "array_utils.h" + +void NumericTrie::insert(const int32_t& value, const uint32_t& seq_id) { + if (value < 0) { + if (negative_trie == nullptr) { + negative_trie = new NumericTrieNode(); + } + + negative_trie->insert(std::abs(value), seq_id); + } else { + if (positive_trie == nullptr) { + positive_trie = new NumericTrieNode(); + } + + positive_trie->insert(value, seq_id); + } +} + +void NumericTrie::search_range(const int32_t& low, const bool& low_inclusive, + const int32_t& high, const bool& high_inclusive, + uint32_t*& ids, uint32_t& ids_length) { + if (low < 0 && high >= 0) { + // Have to combine the results of >low from negative_trie and low from negative_trie. + negative_trie->search_lesser(low_inclusive ? abs_low : abs_low - 1, negative_ids, negative_ids_length); + } + + uint32_t* positive_ids = nullptr; + uint32_t positive_ids_length = 0; + if (!(high == 0 && !high_inclusive)) { + positive_trie->search_lesser(high_inclusive ? high : high - 1, positive_ids, positive_ids_length); + } + + ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, positive_ids, positive_ids_length, &ids); + + delete [] negative_ids; + delete [] positive_ids; + return; + } +} + +void NumericTrieNode::insert(const int32_t& value, const uint32_t& seq_id) { + char level = 0; + return insert(value, seq_id, level); +} + +inline int get_index(const int32_t& value, char& level) { + return (value >> (8 * (MAX_LEVEL - level))) & 0xFF; +} + +void NumericTrieNode::insert(const int32_t& value, const uint32_t& seq_id, char& level) { + if (level > MAX_LEVEL) { + return; + } + + if (!seq_ids.contains(seq_id)) { + seq_ids.append(seq_id); + } + + if (++level <= MAX_LEVEL) { + if (children == nullptr) { + children = new NumericTrieNode* [EXPANSE]{nullptr}; + } + + auto index = get_index(value, level); + if (children[index] == nullptr) { + children[index] = new NumericTrieNode(); + } + + return children[index]->insert(value, seq_id, level); + } +} + +void NumericTrieNode::search_lesser(const int32_t& value, uint32_t*& ids, uint32_t& ids_length) { + char level = 0; + std::vector matches; + search_lesser_helper(value, level, matches); + + for (auto const& match: matches) { + uint32_t* out = nullptr; + auto const& m_seq_ids = match->seq_ids.uncompress(); + ids_length = ArrayUtils::or_scalar(m_seq_ids, match->seq_ids.getLength(), ids, ids_length, &out); + + delete [] m_seq_ids; + delete [] ids; + ids = out; + } +} + +void NumericTrieNode::search_lesser_helper(const int32_t& value, char& level, std::vector& matches) { + if (level > MAX_LEVEL) { + return; + } else if (level == MAX_LEVEL) { + matches.push_back(this); + return; + } + + if (children == nullptr) { + return; + } + + auto index = get_index(value, ++level); + if (children[index] == nullptr) { + return; + } + + children[index]->search_lesser_helper(value, level, matches); + + while (--index >= 0) { + if (children[index] != nullptr) { + matches.push_back(children[index]); + } + } + + --level; +} diff --git a/test/numeric_range_trie_test.cpp b/test/numeric_range_trie_test.cpp new file mode 100644 index 00000000..b21b659a --- /dev/null +++ b/test/numeric_range_trie_test.cpp @@ -0,0 +1,41 @@ +#include +#include "numeric_range_trie_test.h" + +class NumericRangeTrieTest : public ::testing::Test { +protected: + + virtual void SetUp() {} + + virtual void TearDown() {} +}; + +TEST_F(NumericRangeTrieTest, Insert) { + auto trie = new NumericTrie(); + std::vector> pairs = { + {-8192, 8}, + {-16384, 32}, + {-24576, 35}, + {-32768, 43}, + {8192, 49}, + {16384, 56}, + {24576, 58}, + {32768, 91} + }; + + for (auto const pair: pairs) { + trie->insert(pair.first, pair.second); + } + + uint32_t* ids = nullptr; + uint32_t ids_length = 0; + + trie->search_range(-32768, true, 32768, true, ids, ids_length); + + ASSERT_EQ(pairs.size(), ids_length); + for (uint32_t i = 0; i < pairs.size(); i++) { + ASSERT_EQ(pairs[i].second, ids[i]); + } + + delete [] ids; + delete trie; +} \ No newline at end of file