From ba101d0b4034255db04dbefcda49987367cb7c85 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Thu, 27 Jan 2022 14:54:50 +0530 Subject: [PATCH] Infix basics. --- CMakeLists.txt | 1 + include/collection.h | 5 +- include/field.h | 23 +- include/index.h | 32 +- include/tsl/array-hash/array_growth_policy.h | 296 +++ include/tsl/array-hash/array_hash.h | 1807 +++++++++++++++ include/tsl/array-hash/array_map.h | 929 ++++++++ include/tsl/array-hash/array_set.h | 716 ++++++ include/tsl/htrie_hash.h | 2076 ++++++++++++++++++ include/tsl/htrie_map.h | 668 ++++++ include/tsl/htrie_set.h | 578 +++++ src/collection.cpp | 15 +- src/collection_manager.cpp | 41 +- src/index.cpp | 126 +- src/posting_list.cpp | 11 + test/collection_infix_search_test.cpp | 222 ++ 16 files changed, 7534 insertions(+), 12 deletions(-) create mode 100644 include/tsl/array-hash/array_growth_policy.h create mode 100644 include/tsl/array-hash/array_hash.h create mode 100644 include/tsl/array-hash/array_map.h create mode 100644 include/tsl/array-hash/array_set.h create mode 100644 include/tsl/htrie_hash.h create mode 100644 include/tsl/htrie_map.h create mode 100644 include/tsl/htrie_set.h create mode 100644 test/collection_infix_search_test.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index deea2f5c..d0b038bc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -89,6 +89,7 @@ FILE(GLOB SRC_FILES src/*.cpp ${DEP_ROOT_DIR}/${KAKASI_NAME}/data/*.cpp) FILE(GLOB TEST_FILES test/*.cpp) include_directories(include) +include_directories(include/tsl) include_directories(/usr/local/include) include_directories(${OPENSSL_INCLUDE_DIR}) include_directories(${CURL_INCLUDE_DIR}) diff --git a/include/collection.h b/include/collection.h index 28d45334..69046a86 100644 --- a/include/collection.h +++ b/include/collection.h @@ -395,7 +395,10 @@ public: size_t min_len_1typo = 4, size_t min_len_2typo = 7, bool split_join_tokens = true, - size_t max_candidates = 4) const; + size_t max_candidates = 4, + const std::vector& infixes = {off}, + const size_t max_extra_prefix = INT16_MAX, + const size_t max_extra_suffix = INT16_MAX) const; Option get_filter_ids(const std::string & simple_filter_query, std::vector>& index_ids); diff --git a/include/field.h b/include/field.h index 0102e7b1..f6753af8 100644 --- a/include/field.h +++ b/include/field.h @@ -38,6 +38,7 @@ namespace fields { static const std::string optional = "optional"; static const std::string index = "index"; static const std::string sort = "sort"; + static const std::string infix = "infix"; static const std::string locale = "locale"; } @@ -49,9 +50,10 @@ struct field { bool index; std::string locale; bool sort; + bool infix; field(const std::string &name, const std::string &type, const bool facet, const bool optional = false, - bool index = true, std::string locale = "", int sort = -1) : + bool index = true, std::string locale = "", int sort = -1, int infix = -1) : name(name), type(type), facet(facet), optional(optional), index(index), locale(locale) { if(sort != -1) { @@ -59,6 +61,8 @@ struct field { } else { this->sort = is_num_sort_field(); } + + this->infix = (infix != -1) ? 
bool(infix) : false; } bool is_auto() const { @@ -361,6 +365,11 @@ struct field { field_json[fields::name].get() + std::string("` should be a boolean.")); } + if(field_json.count(fields::infix) != 0 && !field_json.at(fields::infix).is_boolean()) { + return Option(400, std::string("The `infix` property of the field `") + + field_json[fields::name].get() + std::string("` should be a boolean.")); + } + if(field_json.count(fields::locale) != 0){ if(!field_json.at(fields::locale).is_string()) { return Option(400, std::string("The `locale` property of the field `") + @@ -395,6 +404,10 @@ struct field { field_json[fields::sort] = false; } + if(field_json.count(fields::infix) == 0) { + field_json[fields::infix] = false; + } + if(field_json[fields::optional] == false) { return Option(400, "Field `.*` must be an optional field."); } @@ -409,7 +422,7 @@ struct field { field fallback_field(field_json["name"], field_json["type"], field_json["facet"], field_json["optional"], field_json[fields::index], field_json[fields::locale], - field_json[fields::sort]); + field_json[fields::sort], field_json[fields::infix]); if(fallback_field.has_valid_type()) { fallback_field_type = fallback_field.type; @@ -444,6 +457,10 @@ struct field { } } + if(field_json.count(fields::infix) == 0) { + field_json[fields::infix] = false; + } + if(field_json.count(fields::optional) == 0) { // dynamic fields are always optional bool is_dynamic = field::is_dynamic(field_json[fields::name], field_json[fields::type]); @@ -453,7 +470,7 @@ struct field { fields.emplace_back( field(field_json[fields::name], field_json[fields::type], field_json[fields::facet], field_json[fields::optional], field_json[fields::index], field_json[fields::locale], - field_json[fields::sort]) + field_json[fields::sort], field_json[fields::infix]) ); } diff --git a/include/index.h b/include/index.h index af1bf097..db43a6da 100644 --- a/include/index.h +++ b/include/index.h @@ -22,11 +22,15 @@ #include "posting_list.h" #include "threadpool.h" #include "adi_tree.h" +#include "tsl/htrie_set.h" static constexpr size_t ARRAY_FACET_DIM = 4; using facet_map_t = spp::sparse_hash_map; using array_mapped_facet_t = std::array; +static constexpr size_t ARRAY_INFIX_DIM = 4; +using array_mapped_infix_t = std::vector*>; + struct token_t { size_t position; std::string value; @@ -258,6 +262,12 @@ struct override_t { } }; +enum infix_t { + always, + fallback, + off +}; + struct search_args { std::vector field_query_tokens; std::vector search_fields; @@ -287,6 +297,9 @@ struct search_args { size_t min_len_1typo; size_t min_len_2typo; size_t max_candidates; + std::vector infixes; + const size_t max_extra_prefix; + const size_t max_extra_suffix; spp::sparse_hash_set groups_processed; std::vector> searched_queries; @@ -312,7 +325,10 @@ struct search_args { size_t search_cutoff_ms, size_t min_len_1typo, size_t min_len_2typo, - size_t max_candidates): + size_t max_candidates, + const std::vector& infixes, + const size_t max_extra_prefix, + const size_t max_extra_suffix): field_query_tokens(field_query_tokens), search_fields(search_fields), filters(filters), facets(facets), included_ids(included_ids), excluded_ids(excluded_ids), sort_fields_std(sort_fields_std), @@ -323,7 +339,8 @@ struct search_args { prioritize_exact_match(prioritize_exact_match), all_result_ids_len(0), exhaustive_search(exhaustive_search), concurrency(concurrency), filter_overrides(dynamic_overrides), search_cutoff_ms(search_cutoff_ms), - min_len_1typo(min_len_1typo), min_len_2typo(min_len_2typo), 
max_candidates(max_candidates) { + min_len_1typo(min_len_1typo), min_len_2typo(min_len_2typo), max_candidates(max_candidates), + infixes(infixes), max_extra_prefix(max_extra_prefix), max_extra_suffix(max_extra_suffix) { const size_t topster_size = std::max((size_t)1, max_hits); // needs to be atleast 1 since scoring is mandatory topster = new Topster(topster_size, group_limit); @@ -431,6 +448,9 @@ private: // str_sort_field => adi_tree_t spp::sparse_hash_map str_sort_index; + // infix field => value + spp::sparse_hash_map infix_index; + // geo_array_field => (seq_id => values) used for exact filtering of geo array records spp::sparse_hash_map*> geo_array_index; @@ -697,7 +717,10 @@ public: size_t search_cutoff_ms, size_t min_len_1typo, size_t min_len_2typo, - size_t max_candidates) const; + size_t max_candidates, + const std::vector& infixes, + const size_t max_extra_prefix, + const size_t max_extra_suffix) const; Option remove(const uint32_t seq_id, const nlohmann::json & document, const bool is_update); @@ -758,6 +781,9 @@ public: uint32_t*& all_result_ids, size_t& all_result_ids_len, const uint32_t* filter_ids, uint32_t filter_ids_length, const size_t concurrency) const; + void search_infix(const std::string& query, const std::string& field_name, std::vector& ids, + size_t max_extra_prefix, size_t max_extra_suffix) const; + void curate_filtered_ids(const std::vector& filters, const std::set& curated_ids, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, uint32_t*& filter_ids, uint32_t& filter_ids_length, const std::vector& curated_ids_sorted) const; diff --git a/include/tsl/array-hash/array_growth_policy.h b/include/tsl/array-hash/array_growth_policy.h new file mode 100644 index 00000000..65bba795 --- /dev/null +++ b/include/tsl/array-hash/array_growth_policy.h @@ -0,0 +1,296 @@ +/** + * MIT License + * + * Copyright (c) 2017 Thibaut Goetghebuer-Planchon + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef TSL_ARRAY_GROWTH_POLICY_H +#define TSL_ARRAY_GROWTH_POLICY_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tsl { +namespace ah { + +/** + * Grow the hash table by a factor of GrowthFactor keeping the bucket count to a + * power of two. It allows the table to use a mask operation instead of a modulo + * operation to map a hash to a bucket. + * + * GrowthFactor must be a power of two >= 2. 
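+ *
+ * For example, with 8 buckets the policy keeps m_mask = 7, so mapping a hash
+ * is a single AND instead of a division:
+ * \code
+ * std::size_t bucket = hash & 7; // same bucket as hash % 8
+ * \endcode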
+ */ +template +class power_of_two_growth_policy { + public: + /** + * Called on the hash table creation and on rehash. The number of buckets for + * the table is passed in parameter. This number is a minimum, the policy may + * update this value with a higher value if needed (but not lower). + * + * If 0 is given, min_bucket_count_in_out must still be 0 after the policy + * creation and bucket_for_hash must always return 0 in this case. + */ + explicit power_of_two_growth_policy(std::size_t& min_bucket_count_in_out) { + if (min_bucket_count_in_out > max_bucket_count()) { + throw std::length_error("The hash table exceeds its maximum size."); + } + + if (min_bucket_count_in_out > 0) { + min_bucket_count_in_out = + round_up_to_power_of_two(min_bucket_count_in_out); + m_mask = min_bucket_count_in_out - 1; + } else { + m_mask = 0; + } + } + + /** + * Return the bucket [0, bucket_count()) to which the hash belongs. + * If bucket_count() is 0, it must always return 0. + */ + std::size_t bucket_for_hash(std::size_t hash) const noexcept { + return hash & m_mask; + } + + /** + * Return the number of buckets that should be used on next growth. + */ + std::size_t next_bucket_count() const { + if ((m_mask + 1) > max_bucket_count() / GrowthFactor) { + throw std::length_error("The hash table exceeds its maximum size."); + } + + return (m_mask + 1) * GrowthFactor; + } + + /** + * Return the maximum number of buckets supported by the policy. + */ + std::size_t max_bucket_count() const { + // Largest power of two. + return (std::numeric_limits::max() / 2) + 1; + } + + /** + * Reset the growth policy as if it was created with a bucket count of 0. + * After a clear, the policy must always return 0 when bucket_for_hash is + * called. + */ + void clear() noexcept { m_mask = 0; } + + private: + static std::size_t round_up_to_power_of_two(std::size_t value) { + if (is_power_of_two(value)) { + return value; + } + + if (value == 0) { + return 1; + } + + --value; + for (std::size_t i = 1; i < sizeof(std::size_t) * CHAR_BIT; i *= 2) { + value |= value >> i; + } + + return value + 1; + } + + static constexpr bool is_power_of_two(std::size_t value) { + return value != 0 && (value & (value - 1)) == 0; + } + + protected: + static_assert(is_power_of_two(GrowthFactor) && GrowthFactor >= 2, + "GrowthFactor must be a power of two >= 2."); + + std::size_t m_mask; +}; + +/** + * Grow the hash table by GrowthFactor::num / GrowthFactor::den and use a modulo + * to map a hash to a bucket. Slower but it can be useful if you want a slower + * growth. 
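+ *
+ * For illustration, assuming GrowthFactor = std::ratio<3, 2>:
+ * \code
+ * std::size_t buckets = 1000;
+ * tsl::ah::mod_growth_policy<std::ratio<3, 2>> policy(buckets);
+ * // policy.next_bucket_count() == 1500, i.e. ceil(1000 * 3 / 2),
+ * // instead of jumping from 1024 to 2048 as the power-of-two policy would.
+ * \endcode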
+ */ +template > +class mod_growth_policy { + public: + explicit mod_growth_policy(std::size_t& min_bucket_count_in_out) { + if (min_bucket_count_in_out > max_bucket_count()) { + throw std::length_error("The hash table exceeds its maximum size."); + } + + if (min_bucket_count_in_out > 0) { + m_mod = min_bucket_count_in_out; + } else { + m_mod = 1; + } + } + + std::size_t bucket_for_hash(std::size_t hash) const noexcept { + return hash % m_mod; + } + + std::size_t next_bucket_count() const { + if (m_mod == max_bucket_count()) { + throw std::length_error("The hash table exceeds its maximum size."); + } + + const double next_bucket_count = + std::ceil(double(m_mod) * REHASH_SIZE_MULTIPLICATION_FACTOR); + if (!std::isnormal(next_bucket_count)) { + throw std::length_error("The hash table exceeds its maximum size."); + } + + if (next_bucket_count > double(max_bucket_count())) { + return max_bucket_count(); + } else { + return std::size_t(next_bucket_count); + } + } + + std::size_t max_bucket_count() const { return MAX_BUCKET_COUNT; } + + void clear() noexcept { m_mod = 1; } + + private: + static constexpr double REHASH_SIZE_MULTIPLICATION_FACTOR = + 1.0 * GrowthFactor::num / GrowthFactor::den; + static const std::size_t MAX_BUCKET_COUNT = + std::size_t(double(std::numeric_limits::max() / + REHASH_SIZE_MULTIPLICATION_FACTOR)); + + static_assert(REHASH_SIZE_MULTIPLICATION_FACTOR >= 1.1, + "Growth factor should be >= 1.1."); + + std::size_t m_mod; +}; + +namespace detail { + +static constexpr const std::array PRIMES = { + {1ul, 5ul, 17ul, 29ul, 37ul, + 53ul, 67ul, 79ul, 97ul, 131ul, + 193ul, 257ul, 389ul, 521ul, 769ul, + 1031ul, 1543ul, 2053ul, 3079ul, 6151ul, + 12289ul, 24593ul, 49157ul, 98317ul, 196613ul, + 393241ul, 786433ul, 1572869ul, 3145739ul, 6291469ul, + 12582917ul, 25165843ul, 50331653ul, 100663319ul, 201326611ul, + 402653189ul, 805306457ul, 1610612741ul, 3221225473ul, 4294967291ul}}; + +template +static constexpr std::size_t mod(std::size_t hash) { + return hash % PRIMES[IPrime]; +} + +// MOD_PRIME[iprime](hash) returns hash % PRIMES[iprime]. This table allows for +// faster modulo as the compiler can optimize the modulo code better with a +// constant known at the compilation. +static constexpr const std::array MOD_PRIME = + {{&mod<0>, &mod<1>, &mod<2>, &mod<3>, &mod<4>, &mod<5>, &mod<6>, + &mod<7>, &mod<8>, &mod<9>, &mod<10>, &mod<11>, &mod<12>, &mod<13>, + &mod<14>, &mod<15>, &mod<16>, &mod<17>, &mod<18>, &mod<19>, &mod<20>, + &mod<21>, &mod<22>, &mod<23>, &mod<24>, &mod<25>, &mod<26>, &mod<27>, + &mod<28>, &mod<29>, &mod<30>, &mod<31>, &mod<32>, &mod<33>, &mod<34>, + &mod<35>, &mod<36>, &mod<37>, &mod<38>, &mod<39>}}; + +} // namespace detail + +/** + * Grow the hash table by using prime numbers as bucket count. Slower than + * tsl::ah::power_of_two_growth_policy in general but will probably distribute + * the values around better in the buckets with a poor hash function. + * + * To allow the compiler to optimize the modulo operation, a lookup table is + * used with constant primes numbers. + * + * With a switch the code would look like: + * \code + * switch(iprime) { // iprime is the current prime of the hash table + * case 0: hash % 5ul; + * break; + * case 1: hash % 17ul; + * break; + * case 2: hash % 29ul; + * break; + * ... + * } + * \endcode + * + * Due to the constant variable in the modulo the compiler is able to optimize + * the operation by a series of multiplications, substractions and shifts. 
+ * + * The 'hash % 5' could become something like 'hash - (hash * 0xCCCCCCCD) >> 34) + * * 5' in a 64 bits environment. + */ +class prime_growth_policy { + public: + explicit prime_growth_policy(std::size_t& min_bucket_count_in_out) { + auto it_prime = std::lower_bound( + detail::PRIMES.begin(), detail::PRIMES.end(), min_bucket_count_in_out); + if (it_prime == detail::PRIMES.end()) { + throw std::length_error("The hash table exceeds its maximum size."); + } + + m_iprime = static_cast( + std::distance(detail::PRIMES.begin(), it_prime)); + if (min_bucket_count_in_out > 0) { + min_bucket_count_in_out = *it_prime; + } else { + min_bucket_count_in_out = 0; + } + } + + std::size_t bucket_for_hash(std::size_t hash) const noexcept { + return detail::MOD_PRIME[m_iprime](hash); + } + + std::size_t next_bucket_count() const { + if (m_iprime + 1 >= detail::PRIMES.size()) { + throw std::length_error("The hash table exceeds its maximum size."); + } + + return detail::PRIMES[m_iprime + 1]; + } + + std::size_t max_bucket_count() const { return detail::PRIMES.back(); } + + void clear() noexcept { m_iprime = 0; } + + private: + unsigned int m_iprime; + + static_assert(std::numeric_limits::max() >= + detail::PRIMES.size(), + "The type of m_iprime is not big enough."); +}; + +} // namespace ah +} // namespace tsl + +#endif diff --git a/include/tsl/array-hash/array_hash.h b/include/tsl/array-hash/array_hash.h new file mode 100644 index 00000000..b639e8a4 --- /dev/null +++ b/include/tsl/array-hash/array_hash.h @@ -0,0 +1,1807 @@ +/** + * MIT License + * + * Copyright (c) 2017 Thibaut Goetghebuer-Planchon + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef TSL_ARRAY_HASH_H +#define TSL_ARRAY_HASH_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "array_growth_policy.h" + +/* + * __has_include is a bit useless + * (https://gcc.gnu.org/bugzilla/show_bug.cgi?id=79433), check also __cplusplus + * version. + */ +#ifdef __has_include +#if __has_include() && __cplusplus >= 201703L +#define TSL_AH_HAS_STRING_VIEW +#endif +#endif + +#ifdef TSL_AH_HAS_STRING_VIEW +#include +#endif + +#ifdef TSL_DEBUG +#define tsl_ah_assert(expr) assert(expr) +#else +#define tsl_ah_assert(expr) (static_cast(0)) +#endif + +/** + * Implementation of the array hash structure described in the + * "Cache-conscious collision resolution in string hash tables." 
(Askitis + * Nikolas and Justin Zobel, 2005) paper. + */ +namespace tsl { + +namespace ah { + +template +struct str_hash { +#ifdef TSL_AH_HAS_STRING_VIEW + std::size_t operator()(const CharT* key, std::size_t key_size) const { + return std::hash>()( + std::basic_string_view(key, key_size)); + } +#else + /** + * FNV-1a hash + */ + std::size_t operator()(const CharT* key, std::size_t key_size) const { + static const std::size_t init = std::size_t( + (sizeof(std::size_t) == 8) ? 0xcbf29ce484222325 : 0x811c9dc5); + static const std::size_t multiplier = + std::size_t((sizeof(std::size_t) == 8) ? 0x100000001b3 : 0x1000193); + + std::size_t hash = init; + for (std::size_t i = 0; i < key_size; ++i) { + hash ^= key[i]; + hash *= multiplier; + } + + return hash; + } +#endif +}; + +template +struct str_equal { + bool operator()(const CharT* key_lhs, std::size_t key_size_lhs, + const CharT* key_rhs, std::size_t key_size_rhs) const { + if (key_size_lhs != key_size_rhs) { + return false; + } else { + return std::memcmp(key_lhs, key_rhs, key_size_lhs * sizeof(CharT)) == 0; + } + } +}; +} // namespace ah + +namespace detail_array_hash { + +template +struct is_iterator : std::false_type {}; + +template +struct is_iterator::iterator_category, + void>::value>::type> : std::true_type {}; + +static constexpr bool is_power_of_two(std::size_t value) { + return value != 0 && (value & (value - 1)) == 0; +} + +template +static T numeric_cast(U value, + const char* error_message = "numeric_cast() failed.") { + T ret = static_cast(value); + if (static_cast(ret) != value) { + throw std::runtime_error(error_message); + } + + const bool is_same_signedness = + (std::is_unsigned::value && std::is_unsigned::value) || + (std::is_signed::value && std::is_signed::value); + if (!is_same_signedness && (ret < T{}) != (value < U{})) { + throw std::runtime_error(error_message); + } + + return ret; +} + +/** + * Fixed size type used to represent size_type values on serialization. Need to + * be big enough to represent a std::size_t on 32 and 64 bits platforms, and + * must be the same size on both platforms. + */ +using slz_size_type = std::uint64_t; + +template +static T deserialize_value(Deserializer& deserializer) { + // MSVC < 2017 is not conformant, circumvent the problem by removing the + // template keyword +#if defined(_MSC_VER) && _MSC_VER < 1910 + return deserializer.Deserializer::operator()(); +#else + return deserializer.Deserializer::template operator()(); +#endif +} + +/** + * For each string in the bucket, store the size of the string, the chars of the + * string and T, if it's not void. T should be either void or an unsigned type. + * + * End the buffer with END_OF_BUCKET flag. END_OF_BUCKET has the same type as + * the string size variable. + * + * m_buffer (CharT*): + * | size of str1 (KeySizeT) | str1 (const CharT*) | value (T if T != void) | + * ... | | size of strN (KeySizeT) | strN (const CharT*) | value (T if T != + * void) | END_OF_BUCKET (KeySizeT) | + * + * m_buffer is null if there is no string in the bucket. + * + * KeySizeT and T are extended to be a multiple of CharT when stored in the + * buffer. + * + * Use std::malloc and std::free instead of new and delete so we can have access + * to std::realloc. 
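+ *
+ * For illustration, assuming CharT = char, KeySizeT = std::uint16_t,
+ * T = std::uint32_t and StoreNullTerminator = true, a bucket holding the
+ * single key "ab" with an associated value of 7 is laid out as:
+ *
+ * | 0x0002 (key size) | 'a' 'b' '\0' | 0x00000007 (value) | 0xFFFF (END_OF_BUCKET) |
+ *
+ * i.e. 2 + 3 + 4 + 2 = 11 bytes in m_buffer.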
+ */ +template +class array_bucket { + template + using has_mapped_type = + typename std::integral_constant::value>; + + static_assert(!has_mapped_type::value || std::is_unsigned::value, + "T should be either void or an unsigned type."); + + static_assert(std::is_unsigned::value, + "KeySizeT should be an unsigned type."); + + public: + template + class array_bucket_iterator; + + using char_type = CharT; + using key_size_type = KeySizeT; + using mapped_type = T; + using size_type = std::size_t; + using key_equal = KeyEqual; + using iterator = array_bucket_iterator; + using const_iterator = array_bucket_iterator; + + static_assert(sizeof(KeySizeT) <= sizeof(size_type), + "sizeof(KeySizeT) should be <= sizeof(std::size_t;)"); + static_assert(std::is_unsigned::value, ""); + + private: + /** + * Return how much space in bytes the type U will take when stored in the + * buffer. As the buffer is of type CharT, U may take more space than + * sizeof(U). + * + * Example: sizeof(CharT) = 4, sizeof(U) = 2 => U will take 4 bytes in the + * buffer instead of 2. + */ + template + static constexpr size_type sizeof_in_buff() noexcept { + static_assert(is_power_of_two(sizeof(U)), + "sizeof(U) should be a power of two."); + static_assert(is_power_of_two(sizeof(CharT)), + "sizeof(CharT) should be a power of two."); + + return std::max(sizeof(U), sizeof(CharT)); + } + + /** + * Same as sizeof_in_buff, but instead of returning the size in bytes + * return it in term of sizeof(CharT). + */ + template + static constexpr size_type size_as_char_t() noexcept { + return sizeof_in_buff() / sizeof(CharT); + } + + static key_size_type read_key_size(const CharT* buffer) noexcept { + key_size_type key_size; + std::memcpy(&key_size, buffer, sizeof(key_size)); + + return key_size; + } + + static mapped_type read_value(const CharT* buffer) noexcept { + mapped_type value; + std::memcpy(&value, buffer, sizeof(value)); + + return value; + } + + static bool is_end_of_bucket(const CharT* buffer) noexcept { + return read_key_size(buffer) == END_OF_BUCKET; + } + + public: + /** + * Return the size required for an entry with a key of size 'key_size'. + */ + template ::value>::type* = nullptr> + static size_type entry_required_bytes(size_type key_size) noexcept { + return sizeof_in_buff() + + (key_size + KEY_EXTRA_SIZE) * sizeof(CharT); + } + + template ::value>::type* = nullptr> + static size_type entry_required_bytes(size_type key_size) noexcept { + return sizeof_in_buff() + + (key_size + KEY_EXTRA_SIZE) * sizeof(CharT) + + sizeof_in_buff(); + } + + private: + /** + * Return the size of the current entry in buffer. 
+ */ + static size_type entry_size_bytes(const CharT* buffer) noexcept { + return entry_required_bytes(read_key_size(buffer)); + } + + public: + template + class array_bucket_iterator { + friend class array_bucket; + + using buffer_type = + typename std::conditional::type; + + explicit array_bucket_iterator(buffer_type* position) noexcept + : m_position(position) {} + + public: + using iterator_category = std::forward_iterator_tag; + using value_type = void; + using difference_type = std::ptrdiff_t; + using reference = void; + using pointer = void; + + public: + array_bucket_iterator() noexcept : m_position(nullptr) {} + + const CharT* key() const { + return m_position + size_as_char_t(); + } + + size_type key_size() const { return read_key_size(m_position); } + + template ::value>::type* = nullptr> + U value() const { + return read_value(m_position + size_as_char_t() + + key_size() + KEY_EXTRA_SIZE); + } + + template ::value && !IsConst && + std::is_same::value>::type* = nullptr> + void set_value(U value) noexcept { + std::memcpy(m_position + size_as_char_t() + key_size() + + KEY_EXTRA_SIZE, + &value, sizeof(value)); + } + + array_bucket_iterator& operator++() { + m_position += entry_size_bytes(m_position) / sizeof(CharT); + if (is_end_of_bucket(m_position)) { + m_position = nullptr; + } + + return *this; + } + + array_bucket_iterator operator++(int) { + array_bucket_iterator tmp(*this); + ++*this; + + return tmp; + } + + friend bool operator==(const array_bucket_iterator& lhs, + const array_bucket_iterator& rhs) { + return lhs.m_position == rhs.m_position; + } + + friend bool operator!=(const array_bucket_iterator& lhs, + const array_bucket_iterator& rhs) { + return !(lhs == rhs); + } + + private: + buffer_type* m_position; + }; + + static iterator end_it() noexcept { return iterator(nullptr); } + + static const_iterator cend_it() noexcept { return const_iterator(nullptr); } + + public: + array_bucket() : m_buffer(nullptr) {} + + /** + * Reserve 'size' in the buffer of the bucket. The created bucket is empty. 
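+ *
+ * 'size' is a number of CharT; room for the trailing END_OF_BUCKET marker is
+ * allocated on top of it.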
+ */ + array_bucket(std::size_t size) : m_buffer(nullptr) { + if (size == 0) { + return; + } + + m_buffer = static_cast(std::malloc( + size * sizeof(CharT) + sizeof_in_buff())); + if (m_buffer == nullptr) { + throw std::bad_alloc(); + } + + const auto end_of_bucket = END_OF_BUCKET; + std::memcpy(m_buffer, &end_of_bucket, sizeof(end_of_bucket)); + } + + ~array_bucket() { clear(); } + + array_bucket(const array_bucket& other) { + if (other.m_buffer == nullptr) { + m_buffer = nullptr; + return; + } + + const size_type other_buffer_size = other.size(); + m_buffer = static_cast( + std::malloc(other_buffer_size * sizeof(CharT) + + sizeof_in_buff())); + if (m_buffer == nullptr) { + throw std::bad_alloc(); + } + + std::memcpy(m_buffer, other.m_buffer, other_buffer_size * sizeof(CharT)); + + const auto end_of_bucket = END_OF_BUCKET; + std::memcpy(m_buffer + other_buffer_size, &end_of_bucket, + sizeof(end_of_bucket)); + } + + array_bucket(array_bucket&& other) noexcept : m_buffer(other.m_buffer) { + other.m_buffer = nullptr; + } + + array_bucket& operator=(array_bucket other) noexcept { + other.swap(*this); + + return *this; + } + + void swap(array_bucket& other) noexcept { + std::swap(m_buffer, other.m_buffer); + } + + iterator begin() noexcept { return iterator(m_buffer); } + iterator end() noexcept { return iterator(nullptr); } + const_iterator begin() const noexcept { return cbegin(); } + const_iterator end() const noexcept { return cend(); } + const_iterator cbegin() const noexcept { return const_iterator(m_buffer); } + const_iterator cend() const noexcept { return const_iterator(nullptr); } + + /** + * Return an iterator pointing to the key entry if presents or, if not there, + * to the position past the last element of the bucket. Return end() if the + * bucket has not be initialized yet. + * + * The boolean of the pair is set to true if the key is there, false + * otherwise. + */ + std::pair find_or_end_of_bucket( + const CharT* key, size_type key_size) const noexcept { + if (m_buffer == nullptr) { + return std::make_pair(cend(), false); + } + + const CharT* buffer_ptr_in_out = m_buffer; + const bool found = + find_or_end_of_bucket_impl(key, key_size, buffer_ptr_in_out); + + return std::make_pair(const_iterator(buffer_ptr_in_out), found); + } + + /** + * Append the element 'key' with its potential value at the end of the bucket. + * 'end_of_bucket' should point past the end of the last element in the + * bucket, end() if the bucket was not initialized yet. You usually get this + * value from find_or_end_of_bucket. + * + * Return the position where the element was actually inserted. + */ + template + const_iterator append(const_iterator end_of_bucket, const CharT* key, + size_type key_size, ValueArgs&&... 
value) { + const key_size_type key_sz = as_key_size_type(key_size); + + if (end_of_bucket == cend()) { + tsl_ah_assert(m_buffer == nullptr); + + const size_type buffer_size = entry_required_bytes(key_sz) + + sizeof_in_buff(); + + m_buffer = static_cast(std::malloc(buffer_size)); + if (m_buffer == nullptr) { + throw std::bad_alloc(); + } + + append_impl(key, key_sz, m_buffer, std::forward(value)...); + + return const_iterator(m_buffer); + } else { + tsl_ah_assert(is_end_of_bucket(end_of_bucket.m_position)); + + const size_type current_size = + ((end_of_bucket.m_position + + size_as_char_t()) - + m_buffer) * + sizeof(CharT); + const size_type new_size = current_size + entry_required_bytes(key_sz); + + CharT* new_buffer = static_cast(std::realloc(m_buffer, new_size)); + if (new_buffer == nullptr) { + throw std::bad_alloc(); + } + m_buffer = new_buffer; + + CharT* buffer_append_pos = m_buffer + current_size / sizeof(CharT) - + size_as_char_t(); + append_impl(key, key_sz, buffer_append_pos, + std::forward(value)...); + + return const_iterator(buffer_append_pos); + } + } + + const_iterator erase(const_iterator position) noexcept { + tsl_ah_assert(position.m_position != nullptr && + !is_end_of_bucket(position.m_position)); + + // get mutable pointers + CharT* start_entry = m_buffer + (position.m_position - m_buffer); + CharT* start_next_entry = + start_entry + entry_size_bytes(start_entry) / sizeof(CharT); + + CharT* end_buffer_ptr = start_next_entry; + while (!is_end_of_bucket(end_buffer_ptr)) { + end_buffer_ptr += entry_size_bytes(end_buffer_ptr) / sizeof(CharT); + } + end_buffer_ptr += size_as_char_t(); + + const size_type size_to_move = + (end_buffer_ptr - start_next_entry) * sizeof(CharT); + std::memmove(start_entry, start_next_entry, size_to_move); + + if (is_end_of_bucket(m_buffer)) { + clear(); + return cend(); + } else if (is_end_of_bucket(start_entry)) { + return cend(); + } else { + return const_iterator(start_entry); + } + } + + /** + * Return true if an element has been erased + */ + bool erase(const CharT* key, size_type key_size) noexcept { + if (m_buffer == nullptr) { + return false; + } + + const CharT* entry_buffer_ptr_in_out = m_buffer; + bool found = + find_or_end_of_bucket_impl(key, key_size, entry_buffer_ptr_in_out); + if (found) { + erase(const_iterator(entry_buffer_ptr_in_out)); + + return true; + } else { + return false; + } + } + + /** + * Bucket should be big enough and there is no check to see if the key already + * exists. No check on key_size. + */ + template + void append_in_reserved_bucket_no_check(const CharT* key, size_type key_size, + ValueArgs&&... 
value) noexcept { + CharT* buffer_ptr = m_buffer; + while (!is_end_of_bucket(buffer_ptr)) { + buffer_ptr += entry_size_bytes(buffer_ptr) / sizeof(CharT); + } + + append_impl(key, key_size_type(key_size), buffer_ptr, + std::forward(value)...); + } + + bool empty() const noexcept { + return m_buffer == nullptr || is_end_of_bucket(m_buffer); + } + + void clear() noexcept { + std::free(m_buffer); + m_buffer = nullptr; + } + + iterator mutable_iterator(const_iterator pos) noexcept { + return iterator(m_buffer + (pos.m_position - m_buffer)); + } + + template + void serialize(Serializer& serializer) const { + const slz_size_type bucket_size = size(); + tsl_ah_assert(m_buffer != nullptr || bucket_size == 0); + + serializer(bucket_size); + serializer(m_buffer, bucket_size); + } + + template + static array_bucket deserialize(Deserializer& deserializer) { + array_bucket bucket; + const slz_size_type bucket_size_ds = + deserialize_value(deserializer); + + if (bucket_size_ds == 0) { + return bucket; + } + + const std::size_t bucket_size = numeric_cast( + bucket_size_ds, "Deserialized bucket_size is too big."); + bucket.m_buffer = static_cast( + std::malloc(bucket_size * sizeof(CharT) + + sizeof_in_buff())); + if (bucket.m_buffer == nullptr) { + throw std::bad_alloc(); + } + + deserializer(bucket.m_buffer, bucket_size); + + const auto end_of_bucket = END_OF_BUCKET; + std::memcpy(bucket.m_buffer + bucket_size, &end_of_bucket, + sizeof(end_of_bucket)); + + tsl_ah_assert(bucket.size() == bucket_size); + return bucket; + } + + private: + key_size_type as_key_size_type(size_type key_size) const { + if (key_size > MAX_KEY_SIZE) { + throw std::length_error("Key is too long."); + } + + return key_size_type(key_size); + } + + /* + * Return true if found, false otherwise. + * If true, buffer_ptr_in_out points to the start of the entry matching 'key'. + * If false, buffer_ptr_in_out points to where the 'key' should be inserted. + * + * Start search from buffer_ptr_in_out. 
+ */ + bool find_or_end_of_bucket_impl(const CharT* key, size_type key_size, + const CharT*& buffer_ptr_in_out) const + noexcept { + while (!is_end_of_bucket(buffer_ptr_in_out)) { + const key_size_type buffer_key_size = read_key_size(buffer_ptr_in_out); + const CharT* buffer_str = + buffer_ptr_in_out + size_as_char_t(); + if (KeyEqual()(buffer_str, buffer_key_size, key, key_size)) { + return true; + } + + buffer_ptr_in_out += entry_size_bytes(buffer_ptr_in_out) / sizeof(CharT); + } + + return false; + } + + template ::value>::type* = nullptr> + void append_impl(const CharT* key, key_size_type key_size, + CharT* buffer_append_pos) noexcept { + std::memcpy(buffer_append_pos, &key_size, sizeof(key_size)); + buffer_append_pos += size_as_char_t(); + + std::memcpy(buffer_append_pos, key, key_size * sizeof(CharT)); + buffer_append_pos += key_size; + + const CharT zero = 0; + std::memcpy(buffer_append_pos, &zero, KEY_EXTRA_SIZE * sizeof(CharT)); + buffer_append_pos += KEY_EXTRA_SIZE; + + const auto end_of_bucket = END_OF_BUCKET; + std::memcpy(buffer_append_pos, &end_of_bucket, sizeof(end_of_bucket)); + } + + template ::value>::type* = nullptr> + void append_impl( + const CharT* key, key_size_type key_size, CharT* buffer_append_pos, + typename array_bucket::mapped_type value) noexcept { + std::memcpy(buffer_append_pos, &key_size, sizeof(key_size)); + buffer_append_pos += size_as_char_t(); + + std::memcpy(buffer_append_pos, key, key_size * sizeof(CharT)); + buffer_append_pos += key_size; + + const CharT zero = 0; + std::memcpy(buffer_append_pos, &zero, KEY_EXTRA_SIZE * sizeof(CharT)); + buffer_append_pos += KEY_EXTRA_SIZE; + + std::memcpy(buffer_append_pos, &value, sizeof(value)); + buffer_append_pos += size_as_char_t(); + + const auto end_of_bucket = END_OF_BUCKET; + std::memcpy(buffer_append_pos, &end_of_bucket, sizeof(end_of_bucket)); + } + + /** + * Return the number of CharT in m_buffer. As the size of the buffer is not + * stored to gain some space, the method need to find the EOF marker and is + * thus in O(n). + */ + size_type size() const noexcept { + if (m_buffer == nullptr) { + return 0; + } + + CharT* buffer_ptr = m_buffer; + while (!is_end_of_bucket(buffer_ptr)) { + buffer_ptr += entry_size_bytes(buffer_ptr) / sizeof(CharT); + } + + return buffer_ptr - m_buffer; + } + + private: + static const key_size_type END_OF_BUCKET = + std::numeric_limits::max(); + static const key_size_type KEY_EXTRA_SIZE = StoreNullTerminator ? 1 : 0; + + CharT* m_buffer; + + public: + static const key_size_type MAX_KEY_SIZE = + // -1 for END_OF_BUCKET + key_size_type(std::numeric_limits::max() - KEY_EXTRA_SIZE - + 1); +}; + +template +class value_container { + public: + void clear() noexcept { m_values.clear(); } + + void reserve(std::size_t new_cap) { m_values.reserve(new_cap); } + + void shrink_to_fit() { m_values.shrink_to_fit(); } + + friend void swap(value_container& lhs, value_container& rhs) { + lhs.m_values.swap(rhs.m_values); + } + + protected: + static constexpr float VECTOR_GROWTH_RATE = 1.5f; + + // TODO use a sparse array? or a std::deque + std::vector m_values; +}; + +template <> +class value_container { + public: + void clear() noexcept {} + + void shrink_to_fit() {} + + void reserve(std::size_t /*new_cap*/) {} +}; + +/** + * If there is no value in the array_hash (in the case of a set for example), T + * should be void. + * + * The size of a key string is limited to std::numeric_limits::max() + * - 1. + * + * The number of elements in the map is limited to + * std::numeric_limits::max(). 
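+ *
+ * For instance, with KeySizeT = std::uint16_t and IndexSizeT = std::uint32_t,
+ * a key may be up to 65,534 characters long and the table may hold up to
+ * 4,294,967,295 elements.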
+ */ +template +class array_hash : private value_container, + private Hash, + private GrowthPolicy { + private: + template + using has_mapped_type = + typename std::integral_constant::value>; + + /** + * If there is a mapped type in array_hash, we store the values in m_values of + * value_container class and we store an index to m_values in the bucket. The + * index is of type IndexSizeT. + */ + using array_bucket = tsl::detail_array_hash::array_bucket< + CharT, + typename std::conditional::value, IndexSizeT, + void>::type, + KeyEqual, KeySizeT, StoreNullTerminator>; + + public: + template + class array_hash_iterator; + + using char_type = CharT; + using key_size_type = KeySizeT; + using index_size_type = IndexSizeT; + using size_type = std::size_t; + using hasher = Hash; + using key_equal = KeyEqual; + using iterator = array_hash_iterator; + using const_iterator = array_hash_iterator; + + /* + * Iterator classes + */ + public: + template + class array_hash_iterator { + friend class array_hash; + + private: + using iterator_array_bucket = typename array_bucket::const_iterator; + + using iterator_buckets = typename std::conditional< + IsConst, typename std::vector::const_iterator, + typename std::vector::iterator>::type; + + using array_hash_ptr = typename std::conditional::type; + + public: + using iterator_category = std::forward_iterator_tag; + using value_type = + typename std::conditional::value, T, void>::type; + using difference_type = std::ptrdiff_t; + using reference = typename std::conditional< + has_mapped_type::value, + typename std::conditional< + IsConst, typename std::add_lvalue_reference::type, + typename std::add_lvalue_reference::type>::type, + void>::type; + using pointer = typename std::conditional< + has_mapped_type::value, + typename std::conditional::type, void>::type; + + private: + array_hash_iterator(iterator_buckets buckets_iterator, + iterator_array_bucket array_bucket_iterator, + array_hash_ptr array_hash_p) noexcept + : m_buckets_iterator(buckets_iterator), + m_array_bucket_iterator(array_bucket_iterator), + m_array_hash(array_hash_p) { + tsl_ah_assert(m_array_hash != nullptr); + } + + public: + array_hash_iterator() noexcept : m_array_hash(nullptr) {} + + template ::type* = nullptr> + array_hash_iterator(const array_hash_iterator& other) noexcept + : m_buckets_iterator(other.m_buckets_iterator), + m_array_bucket_iterator(other.m_array_bucket_iterator), + m_array_hash(other.m_array_hash) {} + + array_hash_iterator(const array_hash_iterator& other) = default; + array_hash_iterator(array_hash_iterator&& other) = default; + array_hash_iterator& operator=(const array_hash_iterator& other) = default; + array_hash_iterator& operator=(array_hash_iterator&& other) = default; + + const CharT* key() const { return m_array_bucket_iterator.key(); } + + size_type key_size() const { return m_array_bucket_iterator.key_size(); } + +#ifdef TSL_AH_HAS_STRING_VIEW + std::basic_string_view key_sv() const { + return std::basic_string_view(key(), key_size()); + } +#endif + + template ::value>::type* = nullptr> + reference value() const { + return this->m_array_hash->m_values[value_position()]; + } + + template ::value>::type* = nullptr> + reference operator*() const { + return value(); + } + + template ::value>::type* = nullptr> + pointer operator->() const { + return std::addressof(value()); + } + + array_hash_iterator& operator++() { + tsl_ah_assert(m_buckets_iterator != m_array_hash->m_buckets_data.end()); + tsl_ah_assert(m_array_bucket_iterator != m_buckets_iterator->cend()); + 
+ ++m_array_bucket_iterator; + if (m_array_bucket_iterator == m_buckets_iterator->cend()) { + do { + ++m_buckets_iterator; + } while (m_buckets_iterator != m_array_hash->m_buckets_data.end() && + m_buckets_iterator->empty()); + + if (m_buckets_iterator != m_array_hash->m_buckets_data.end()) { + m_array_bucket_iterator = m_buckets_iterator->cbegin(); + } + } + + return *this; + } + + array_hash_iterator operator++(int) { + array_hash_iterator tmp(*this); + ++*this; + + return tmp; + } + + friend bool operator==(const array_hash_iterator& lhs, + const array_hash_iterator& rhs) { + return lhs.m_buckets_iterator == rhs.m_buckets_iterator && + lhs.m_array_bucket_iterator == rhs.m_array_bucket_iterator && + lhs.m_array_hash == rhs.m_array_hash; + } + + friend bool operator!=(const array_hash_iterator& lhs, + const array_hash_iterator& rhs) { + return !(lhs == rhs); + } + + private: + template ::value>::type* = nullptr> + IndexSizeT value_position() const { + return this->m_array_bucket_iterator.value(); + } + + private: + iterator_buckets m_buckets_iterator; + iterator_array_bucket m_array_bucket_iterator; + + array_hash_ptr m_array_hash; + }; + + public: + array_hash(size_type bucket_count, const Hash& hash, float max_load_factor) + : value_container(), + Hash(hash), + GrowthPolicy(bucket_count), + m_buckets_data(bucket_count > max_bucket_count() + ? throw std::length_error( + "The map exceeds its maximum bucket count.") + : bucket_count), + m_buckets(m_buckets_data.empty() ? static_empty_bucket_ptr() + : m_buckets_data.data()), + m_nb_elements(0) { + this->max_load_factor(max_load_factor); + } + + array_hash(const array_hash& other) + : value_container(other), + Hash(other), + GrowthPolicy(other), + m_buckets_data(other.m_buckets_data), + m_buckets(m_buckets_data.empty() ? static_empty_bucket_ptr() + : m_buckets_data.data()), + m_nb_elements(other.m_nb_elements), + m_max_load_factor(other.m_max_load_factor), + m_load_threshold(other.m_load_threshold) {} + + array_hash(array_hash&& other) noexcept( + std::is_nothrow_move_constructible>::value&& + std::is_nothrow_move_constructible::value&& + std::is_nothrow_move_constructible::value&& + std::is_nothrow_move_constructible< + std::vector>::value) + : value_container(std::move(other)), + Hash(std::move(other)), + GrowthPolicy(std::move(other)), + m_buckets_data(std::move(other.m_buckets_data)), + m_buckets(m_buckets_data.empty() ? static_empty_bucket_ptr() + : m_buckets_data.data()), + m_nb_elements(other.m_nb_elements), + m_max_load_factor(other.m_max_load_factor), + m_load_threshold(other.m_load_threshold) { + other.value_container::clear(); + other.GrowthPolicy::clear(); + other.m_buckets_data.clear(); + other.m_buckets = static_empty_bucket_ptr(); + other.m_nb_elements = 0; + other.m_load_threshold = 0; + } + + array_hash& operator=(const array_hash& other) { + if (&other != this) { + value_container::operator=(other); + Hash::operator=(other); + GrowthPolicy::operator=(other); + + m_buckets_data = other.m_buckets_data; + m_buckets = m_buckets_data.empty() ? 
static_empty_bucket_ptr() + : m_buckets_data.data(); + m_nb_elements = other.m_nb_elements; + m_max_load_factor = other.m_max_load_factor; + m_load_threshold = other.m_load_threshold; + } + + return *this; + } + + array_hash& operator=(array_hash&& other) { + other.swap(*this); + other.clear(); + + return *this; + } + + /* + * Iterators + */ + iterator begin() noexcept { + auto begin = m_buckets_data.begin(); + while (begin != m_buckets_data.end() && begin->empty()) { + ++begin; + } + + return (begin != m_buckets_data.end()) + ? iterator(begin, begin->cbegin(), this) + : end(); + } + + const_iterator begin() const noexcept { return cbegin(); } + + const_iterator cbegin() const noexcept { + auto begin = m_buckets_data.cbegin(); + while (begin != m_buckets_data.cend() && begin->empty()) { + ++begin; + } + + return (begin != m_buckets_data.cend()) + ? const_iterator(begin, begin->cbegin(), this) + : cend(); + } + + iterator end() noexcept { + return iterator(m_buckets_data.end(), array_bucket::cend_it(), this); + } + + const_iterator end() const noexcept { return cend(); } + + const_iterator cend() const noexcept { + return const_iterator(m_buckets_data.end(), array_bucket::cend_it(), this); + } + + /* + * Capacity + */ + bool empty() const noexcept { return m_nb_elements == 0; } + + size_type size() const noexcept { return m_nb_elements; } + + size_type max_size() const noexcept { + return std::numeric_limits::max(); + } + + size_type max_key_size() const noexcept { return MAX_KEY_SIZE; } + + void shrink_to_fit() { + clear_old_erased_values(); + value_container::shrink_to_fit(); + + rehash_impl(size_type(std::ceil(float(size()) / max_load_factor()))); + } + + /* + * Modifiers + */ + void clear() noexcept { + value_container::clear(); + + for (auto& bucket : m_buckets_data) { + bucket.clear(); + } + + m_nb_elements = 0; + } + + template + std::pair emplace(const CharT* key, size_type key_size, + ValueArgs&&... value_args) { + const std::size_t hash = hash_key(key, key_size); + std::size_t ibucket = bucket_for_hash(hash); + + auto it_find = m_buckets[ibucket].find_or_end_of_bucket(key, key_size); + if (it_find.second) { + return std::make_pair( + iterator(m_buckets_data.begin() + ibucket, it_find.first, this), + false); + } + + if (grow_on_high_load()) { + ibucket = bucket_for_hash(hash); + it_find = m_buckets[ibucket].find_or_end_of_bucket(key, key_size); + } + + return emplace_impl(ibucket, it_find.first, key, key_size, + std::forward(value_args)...); + } + + template + std::pair insert_or_assign(const CharT* key, + size_type key_size, M&& obj) { + auto it = emplace(key, key_size, std::forward(obj)); + if (!it.second) { + it.first.value() = std::forward(obj); + } + + return it; + } + + iterator erase(const_iterator pos) { + if (should_clear_old_erased_values()) { + clear_old_erased_values(); + } + + return erase_from_bucket(mutable_iterator(pos)); + } + + iterator erase(const_iterator first, const_iterator last) { + if (first == last) { + return mutable_iterator(first); + } + + /** + * When erasing an element from a bucket with erase_from_bucket, it + * invalidates all the iterators in the array bucket of the element + * (m_array_bucket_iterator) but not the iterators of the buckets itself + * (m_buckets_iterator). + * + * So first erase all the values between first and last which are not part + * of the bucket of last, and then erase carefully the values in last's + * bucket. 
+ */ + auto to_delete = mutable_iterator(first); + while (to_delete.m_buckets_iterator != last.m_buckets_iterator) { + to_delete = erase_from_bucket(to_delete); + } + + std::size_t nb_elements_until_last = std::distance( + to_delete.m_array_bucket_iterator, last.m_array_bucket_iterator); + while (nb_elements_until_last > 0) { + to_delete = erase_from_bucket(to_delete); + nb_elements_until_last--; + } + + if (should_clear_old_erased_values()) { + clear_old_erased_values(); + } + + return to_delete; + } + + size_type erase(const CharT* key, size_type key_size) { + return erase(key, key_size, hash_key(key, key_size)); + } + + size_type erase(const CharT* key, size_type key_size, std::size_t hash) { + if (should_clear_old_erased_values()) { + clear_old_erased_values(); + } + + const std::size_t ibucket = bucket_for_hash(hash); + if (m_buckets[ibucket].erase(key, key_size)) { + m_nb_elements--; + return 1; + } else { + return 0; + } + } + + void swap(array_hash& other) { + using std::swap; + + swap(static_cast&>(*this), + static_cast&>(other)); + swap(static_cast(*this), static_cast(other)); + swap(static_cast(*this), static_cast(other)); + swap(m_buckets_data, other.m_buckets_data); + swap(m_buckets, other.m_buckets); + swap(m_nb_elements, other.m_nb_elements); + swap(m_max_load_factor, other.m_max_load_factor); + swap(m_load_threshold, other.m_load_threshold); + } + + /* + * Lookup + */ + template ::value>::type* = nullptr> + U& at(const CharT* key, size_type key_size) { + return at(key, key_size, hash_key(key, key_size)); + } + + template ::value>::type* = nullptr> + const U& at(const CharT* key, size_type key_size) const { + return at(key, key_size, hash_key(key, key_size)); + } + + template ::value>::type* = nullptr> + U& at(const CharT* key, size_type key_size, std::size_t hash) { + return const_cast( + static_cast(this)->at(key, key_size, hash)); + } + + template ::value>::type* = nullptr> + const U& at(const CharT* key, size_type key_size, std::size_t hash) const { + const std::size_t ibucket = bucket_for_hash(hash); + + auto it_find = m_buckets[ibucket].find_or_end_of_bucket(key, key_size); + if (it_find.second) { + return this->m_values[it_find.first.value()]; + } else { + throw std::out_of_range("Couldn't find key."); + } + } + + template ::value>::type* = nullptr> + U& access_operator(const CharT* key, size_type key_size) { + const std::size_t hash = hash_key(key, key_size); + std::size_t ibucket = bucket_for_hash(hash); + + auto it_find = m_buckets[ibucket].find_or_end_of_bucket(key, key_size); + if (it_find.second) { + return this->m_values[it_find.first.value()]; + } else { + if (grow_on_high_load()) { + ibucket = bucket_for_hash(hash); + it_find = m_buckets[ibucket].find_or_end_of_bucket(key, key_size); + } + + return emplace_impl(ibucket, it_find.first, key, key_size, U{}) + .first.value(); + } + } + + size_type count(const CharT* key, size_type key_size) const { + return count(key, key_size, hash_key(key, key_size)); + } + + size_type count(const CharT* key, size_type key_size, + std::size_t hash) const { + const std::size_t ibucket = bucket_for_hash(hash); + + auto it_find = m_buckets[ibucket].find_or_end_of_bucket(key, key_size); + if (it_find.second) { + return 1; + } else { + return 0; + } + } + + iterator find(const CharT* key, size_type key_size) { + return find(key, key_size, hash_key(key, key_size)); + } + + const_iterator find(const CharT* key, size_type key_size) const { + return find(key, key_size, hash_key(key, key_size)); + } + + iterator find(const CharT* key, 
size_type key_size, std::size_t hash) { + const std::size_t ibucket = bucket_for_hash(hash); + + auto it_find = m_buckets[ibucket].find_or_end_of_bucket(key, key_size); + if (it_find.second) { + return iterator(m_buckets_data.begin() + ibucket, it_find.first, this); + } else { + return end(); + } + } + + const_iterator find(const CharT* key, size_type key_size, + std::size_t hash) const { + const std::size_t ibucket = bucket_for_hash(hash); + + auto it_find = m_buckets[ibucket].find_or_end_of_bucket(key, key_size); + if (it_find.second) { + return const_iterator(m_buckets_data.cbegin() + ibucket, it_find.first, + this); + } else { + return cend(); + } + } + + std::pair equal_range(const CharT* key, + size_type key_size) { + return equal_range(key, key_size, hash_key(key, key_size)); + } + + std::pair equal_range( + const CharT* key, size_type key_size) const { + return equal_range(key, key_size, hash_key(key, key_size)); + } + + std::pair equal_range(const CharT* key, + size_type key_size, + std::size_t hash) { + iterator it = find(key, key_size, hash); + return std::make_pair(it, (it == end()) ? it : std::next(it)); + } + + std::pair equal_range( + const CharT* key, size_type key_size, std::size_t hash) const { + const_iterator it = find(key, key_size, hash); + return std::make_pair(it, (it == cend()) ? it : std::next(it)); + } + + /* + * Bucket interface + */ + size_type bucket_count() const { return m_buckets_data.size(); } + + size_type max_bucket_count() const { + return std::min(GrowthPolicy::max_bucket_count(), + m_buckets_data.max_size()); + } + + /* + * Hash policy + */ + float load_factor() const { + if (bucket_count() == 0) { + return 0; + } + + return float(m_nb_elements) / float(bucket_count()); + } + + float max_load_factor() const { return m_max_load_factor; } + + void max_load_factor(float ml) { + m_max_load_factor = std::max(0.1f, ml); + m_load_threshold = size_type(float(bucket_count()) * m_max_load_factor); + } + + void rehash(size_type count) { + count = std::max(count, + size_type(std::ceil(float(size()) / max_load_factor()))); + rehash_impl(count); + } + + void reserve(size_type count) { + rehash(size_type(std::ceil(float(count) / max_load_factor()))); + } + + /* + * Observers + */ + hasher hash_function() const { return static_cast(*this); } + + // TODO add support for statefull KeyEqual + key_equal key_eq() const { return KeyEqual(); } + + /* + * Other + */ + iterator mutable_iterator(const_iterator it) noexcept { + auto it_bucket = + m_buckets_data.begin() + + std::distance(m_buckets_data.cbegin(), it.m_buckets_iterator); + return iterator(it_bucket, it.m_array_bucket_iterator, this); + } + + template + void serialize(Serializer& serializer) const { + serialize_impl(serializer); + } + + template + void deserialize(Deserializer& deserializer, bool hash_compatible) { + deserialize_impl(deserializer, hash_compatible); + } + + private: + std::size_t hash_key(const CharT* key, size_type key_size) const { + return Hash::operator()(key, key_size); + } + + std::size_t bucket_for_hash(std::size_t hash) const { + return GrowthPolicy::bucket_for_hash(hash); + } + + /** + * If there is a mapped_type, the mapped value in m_values is not erased now. + * It will be erased when the ratio between the size of the map and + * the size of the map + the number of deleted values still stored is low + * enough (see clear_old_erased_values). 
+ */ + iterator erase_from_bucket(iterator pos) noexcept { + auto array_bucket_next_it = + pos.m_buckets_iterator->erase(pos.m_array_bucket_iterator); + m_nb_elements--; + + if (array_bucket_next_it != pos.m_buckets_iterator->cend()) { + return iterator(pos.m_buckets_iterator, array_bucket_next_it, this); + } else { + do { + ++pos.m_buckets_iterator; + } while (pos.m_buckets_iterator != m_buckets_data.end() && + pos.m_buckets_iterator->empty()); + + if (pos.m_buckets_iterator != m_buckets_data.end()) { + return iterator(pos.m_buckets_iterator, + pos.m_buckets_iterator->cbegin(), this); + } else { + return end(); + } + } + } + + template ::value>::type* = nullptr> + bool should_clear_old_erased_values( + float /*threshold*/ = DEFAULT_CLEAR_OLD_ERASED_VALUE_THRESHOLD) const { + return false; + } + + template ::value>::type* = nullptr> + bool should_clear_old_erased_values( + float threshold = DEFAULT_CLEAR_OLD_ERASED_VALUE_THRESHOLD) const { + if (this->m_values.size() == 0) { + return false; + } + + return float(m_nb_elements) / float(this->m_values.size()) < threshold; + } + + template ::value>::type* = nullptr> + void clear_old_erased_values() {} + + template ::value>::type* = nullptr> + void clear_old_erased_values() { + static_assert(std::is_nothrow_move_constructible::value || + std::is_copy_constructible::value, + "mapped_value must be either copy constructible or nothrow " + "move constructible."); + + if (m_nb_elements == this->m_values.size()) { + return; + } + + std::vector new_values; + new_values.reserve(size()); + + for (auto it = begin(); it != end(); ++it) { + new_values.push_back(std::move_if_noexcept(it.value())); + } + + IndexSizeT ivalue = 0; + for (auto it = begin(); it != end(); ++it) { + auto it_array_bucket = + it.m_buckets_iterator->mutable_iterator(it.m_array_bucket_iterator); + it_array_bucket.set_value(ivalue); + ivalue++; + } + + new_values.swap(this->m_values); + tsl_ah_assert(m_nb_elements == this->m_values.size()); + } + + /** + * Return true if a rehash occurred. + */ + bool grow_on_high_load() { + if (size() >= m_load_threshold) { + rehash_impl(GrowthPolicy::next_bucket_count()); + return true; + } + + return false; + } + + template ::value>::type* = nullptr> + std::pair emplace_impl( + std::size_t ibucket, typename array_bucket::const_iterator end_of_bucket, + const CharT* key, size_type key_size, ValueArgs&&... value_args) { + if (this->m_values.size() >= max_size()) { + // Try to clear old erased values lingering in m_values. Throw if it + // doesn't change anything. + clear_old_erased_values(); + if (this->m_values.size() >= max_size()) { + throw std::length_error( + "Can't insert value, too much values in the map."); + } + } + + if (this->m_values.size() == this->m_values.capacity()) { + this->m_values.reserve( + std::size_t(float(this->m_values.size()) * + value_container::VECTOR_GROWTH_RATE)); + } + + this->m_values.emplace_back(std::forward(value_args)...); + + try { + auto it = m_buckets[ibucket].append( + end_of_bucket, key, key_size, IndexSizeT(this->m_values.size() - 1)); + m_nb_elements++; + + return std::make_pair( + iterator(m_buckets_data.begin() + ibucket, it, this), true); + } catch (...) 
{ + // Rollback + this->m_values.pop_back(); + throw; + } + } + + template ::value>::type* = nullptr> + std::pair emplace_impl( + std::size_t ibucket, typename array_bucket::const_iterator end_of_bucket, + const CharT* key, size_type key_size) { + if (m_nb_elements >= max_size()) { + throw std::length_error( + "Can't insert value, too much values in the map."); + } + + auto it = m_buckets[ibucket].append(end_of_bucket, key, key_size); + m_nb_elements++; + + return std::make_pair(iterator(m_buckets_data.begin() + ibucket, it, this), + true); + } + + void rehash_impl(size_type bucket_count) { + GrowthPolicy new_growth_policy(bucket_count); + if (bucket_count == this->bucket_count()) { + return; + } + + if (should_clear_old_erased_values( + REHASH_CLEAR_OLD_ERASED_VALUE_THRESHOLD)) { + clear_old_erased_values(); + } + + std::vector required_size_for_bucket(bucket_count, 0); + std::vector bucket_for_ivalue(size(), 0); + + std::size_t ivalue = 0; + for (auto it = begin(); it != end(); ++it) { + const std::size_t hash = hash_key(it.key(), it.key_size()); + const std::size_t ibucket = new_growth_policy.bucket_for_hash(hash); + + bucket_for_ivalue[ivalue] = ibucket; + required_size_for_bucket[ibucket] += + array_bucket::entry_required_bytes(it.key_size()); + ivalue++; + } + + std::vector new_buckets; + new_buckets.reserve(bucket_count); + for (std::size_t ibucket = 0; ibucket < bucket_count; ibucket++) { + new_buckets.emplace_back(required_size_for_bucket[ibucket]); + } + + ivalue = 0; + for (auto it = begin(); it != end(); ++it) { + const std::size_t ibucket = bucket_for_ivalue[ivalue]; + append_iterator_in_reserved_bucket_no_check(new_buckets[ibucket], it); + + ivalue++; + } + + using std::swap; + swap(static_cast(*this), new_growth_policy); + + m_buckets_data.swap(new_buckets); + m_buckets = !m_buckets_data.empty() ? m_buckets_data.data() + : static_empty_bucket_ptr(); + + // Call max_load_factor to change m_load_threshold + max_load_factor(m_max_load_factor); + } + + template ::value>::type* = nullptr> + void append_iterator_in_reserved_bucket_no_check(array_bucket& bucket, + iterator it) { + bucket.append_in_reserved_bucket_no_check(it.key(), it.key_size()); + } + + template ::value>::type* = nullptr> + void append_iterator_in_reserved_bucket_no_check(array_bucket& bucket, + iterator it) { + bucket.append_in_reserved_bucket_no_check(it.key(), it.key_size(), + it.value_position()); + } + + /** + * On serialization the values of each bucket (if has_mapped_type is true) are + * serialized next to the bucket. The potential old erased values in + * value_container are thus not serialized. + * + * On deserialization, when hash_compatible is true, we reaffect the value + * index (IndexSizeT) of each bucket with set_value as the position of each + * value is no more the same in value_container compared to when they were + * serialized. + * + * It's done this way as we can't call clear_old_erased_values() because we + * want the serialize method to remain const and we don't want to + * serialize/deserialize old erased values. As we may not serialize all the + * values in value_container, the values we keep can change of index. We thus + * have to modify the value indexes in the buckets. 
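+   *
+   * Rough sketch of the stream layout produced by serialize_impl below:
+   *
+   *     protocol version  (slz_size_type)
+   *     bucket_count      (slz_size_type)
+   *     nb_elements       (slz_size_type)
+   *     max_load_factor   (float)
+   *     bucket 0 data, followed by its mapped values (if has_mapped_type)
+   *     bucket 1 data, followed by its mapped values
+   *     ...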
+ */ + template + void serialize_impl(Serializer& serializer) const { + const slz_size_type version = SERIALIZATION_PROTOCOL_VERSION; + serializer(version); + + const slz_size_type bucket_count = m_buckets_data.size(); + serializer(bucket_count); + + const slz_size_type nb_elements = m_nb_elements; + serializer(nb_elements); + + const float max_load_factor = m_max_load_factor; + serializer(max_load_factor); + + for (const array_bucket& bucket : m_buckets_data) { + bucket.serialize(serializer); + serialize_bucket_values(serializer, bucket); + } + } + + template < + class Serializer, class U = T, + typename std::enable_if::value>::type* = nullptr> + void serialize_bucket_values(Serializer& /*serializer*/, + const array_bucket& /*bucket*/) const {} + + template ::value>::type* = nullptr> + void serialize_bucket_values(Serializer& serializer, + const array_bucket& bucket) const { + for (auto it = bucket.begin(); it != bucket.end(); ++it) { + serializer(this->m_values[it.value()]); + } + } + + template + void deserialize_impl(Deserializer& deserializer, bool hash_compatible) { + tsl_ah_assert(m_buckets_data.empty()); // Current hash table must be empty + + const slz_size_type version = + deserialize_value(deserializer); + // For now we only have one version of the serialization protocol. + // If it doesn't match there is a problem with the file. + if (version != SERIALIZATION_PROTOCOL_VERSION) { + throw std::runtime_error( + "Can't deserialize the array_map/set. The protocol version header is " + "invalid."); + } + + const slz_size_type bucket_count_ds = + deserialize_value(deserializer); + const slz_size_type nb_elements = + deserialize_value(deserializer); + const float max_load_factor = deserialize_value(deserializer); + + m_nb_elements = numeric_cast( + nb_elements, "Deserialized nb_elements is too big."); + + size_type bucket_count = numeric_cast( + bucket_count_ds, "Deserialized bucket_count is too big."); + GrowthPolicy::operator=(GrowthPolicy(bucket_count)); + + this->max_load_factor(max_load_factor); + value_container::reserve(m_nb_elements); + + if (hash_compatible) { + if (bucket_count != bucket_count_ds) { + throw std::runtime_error( + "The GrowthPolicy is not the same even though hash_compatible is " + "true."); + } + + m_buckets_data.reserve(bucket_count); + for (size_type i = 0; i < bucket_count; i++) { + m_buckets_data.push_back(array_bucket::deserialize(deserializer)); + deserialize_bucket_values(deserializer, m_buckets_data.back()); + } + } else { + m_buckets_data.resize(bucket_count); + for (size_type i = 0; i < bucket_count; i++) { + // TODO use buffer to avoid reallocation on each deserialization. + array_bucket bucket = array_bucket::deserialize(deserializer); + deserialize_bucket_values(deserializer, bucket); + + for (auto it_val = bucket.cbegin(); it_val != bucket.cend(); ++it_val) { + const std::size_t ibucket = + bucket_for_hash(hash_key(it_val.key(), it_val.key_size())); + + auto it_find = m_buckets_data[ibucket].find_or_end_of_bucket( + it_val.key(), it_val.key_size()); + if (it_find.second) { + throw std::runtime_error( + "Error on deserialization, the same key is presents multiple " + "times."); + } + + append_array_bucket_iterator_in_bucket(m_buckets_data[ibucket], + it_find.first, it_val); + } + } + } + + m_buckets = m_buckets_data.data(); + + if (load_factor() > this->max_load_factor()) { + throw std::runtime_error( + "Invalid max_load_factor. 
Check that the serializer and deserializer " + "support " + "floats correctly as they can be converted implicitely to ints."); + } + } + + template < + class Deserializer, class U = T, + typename std::enable_if::value>::type* = nullptr> + void deserialize_bucket_values(Deserializer& /*deserializer*/, + array_bucket& /*bucket*/) {} + + template ::value>::type* = nullptr> + void deserialize_bucket_values(Deserializer& deserializer, + array_bucket& bucket) { + for (auto it = bucket.begin(); it != bucket.end(); ++it) { + this->m_values.emplace_back(deserialize_value(deserializer)); + + tsl_ah_assert(this->m_values.size() - 1 <= + std::numeric_limits::max()); + it.set_value(static_cast(this->m_values.size() - 1)); + } + } + + template ::value>::type* = nullptr> + void append_array_bucket_iterator_in_bucket( + array_bucket& bucket, typename array_bucket::const_iterator end_of_bucket, + typename array_bucket::const_iterator it_val) { + bucket.append(end_of_bucket, it_val.key(), it_val.key_size()); + } + + template ::value>::type* = nullptr> + void append_array_bucket_iterator_in_bucket( + array_bucket& bucket, typename array_bucket::const_iterator end_of_bucket, + typename array_bucket::const_iterator it_val) { + bucket.append(end_of_bucket, it_val.key(), it_val.key_size(), + it_val.value()); + } + + public: + static const size_type DEFAULT_INIT_BUCKET_COUNT = 0; + static constexpr float DEFAULT_MAX_LOAD_FACTOR = 2.0f; + static const size_type MAX_KEY_SIZE = array_bucket::MAX_KEY_SIZE; + + private: + /** + * Protocol version currenlty used for serialization. + */ + static const slz_size_type SERIALIZATION_PROTOCOL_VERSION = 1; + + static constexpr float DEFAULT_CLEAR_OLD_ERASED_VALUE_THRESHOLD = 0.6f; + static constexpr float REHASH_CLEAR_OLD_ERASED_VALUE_THRESHOLD = 0.9f; + + /** + * Return an always valid pointer to a static empty array_bucket. + */ + array_bucket* static_empty_bucket_ptr() { + static array_bucket empty_bucket; + return &empty_bucket; + } + + private: + std::vector m_buckets_data; + + /** + * Points to m_buckets_data.data() if !m_buckets_data.empty() otherwise points + * to static_empty_bucket_ptr. This variable is useful to avoid the cost of + * checking if m_buckets_data is empty when trying to find an element. + * + * TODO Remove m_buckets_data and only use a pointer+size instead of a + * pointer+vector to save some space in the array_hash object. + */ + array_bucket* m_buckets; + + IndexSizeT m_nb_elements; + float m_max_load_factor; + size_type m_load_threshold; +}; + +} // end namespace detail_array_hash +} // end namespace tsl + +#endif diff --git a/include/tsl/array-hash/array_map.h b/include/tsl/array-hash/array_map.h new file mode 100644 index 00000000..435fe03a --- /dev/null +++ b/include/tsl/array-hash/array_map.h @@ -0,0 +1,929 @@ +/** + * MIT License + * + * Copyright (c) 2017 Thibaut Goetghebuer-Planchon + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef TSL_ARRAY_MAP_H +#define TSL_ARRAY_MAP_H + +#include +#include +#include +#include +#include +#include +#include + +#include "array_hash.h" + +namespace tsl { + +/** + * Implementation of a cache-conscious string hash map. + * + * The map stores the strings as `const CharT*`. If `StoreNullTerminator` is + * true, the strings are stored with the a null-terminator (the `key()` method + * of the iterators will return a pointer to this null-terminated string). + * Otherwise the null character is not stored (which allow an economy of 1 byte + * per string). + * + * The value `T` must be either nothrow move-constructible, copy-constructible + * or both. + * + * The size of a key string is limited to `std::numeric_limits::max() + * - 1`. That is 65 535 characters by default, but can be raised with the + * `KeySizeT` template parameter. See `max_key_size()` for an easy access to + * this limit. + * + * The number of elements in the map is limited to + * `std::numeric_limits::max()`. That is 4 294 967 296 elements, but + * can be raised with the `IndexSizeT` template parameter. See `max_size()` for + * an easy access to this limit. + * + * Iterators invalidation: + * - clear, operator=: always invalidate the iterators. + * - insert, emplace, operator[]: always invalidate the iterators. + * - erase: always invalidate the iterators. + * - shrink_to_fit: always invalidate the iterators. 
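+ *
+ * A minimal usage sketch (illustrative only, not part of the upstream
+ * documentation; the include path is an assumption of this patch's layout):
+ *
+ *     #include <cstdint>
+ *     #include "tsl/array-hash/array_map.h"
+ *
+ *     int main() {
+ *       tsl::array_map<char, std::int64_t> map;
+ *       map.insert("apple", 1);
+ *       map["banana"] = 2;       // operator[] inserts a default value if absent
+ *       map.at("apple") += 10;   // at() throws if the key is missing
+ *
+ *       if (map.find("banana") != map.end()) {
+ *         map.erase("banana");
+ *       }
+ *       return static_cast<int>(map.count("apple"));  // returns 1
+ *     }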
+ */ +template , + class KeyEqual = tsl::ah::str_equal, + bool StoreNullTerminator = true, class KeySizeT = std::uint16_t, + class IndexSizeT = std::uint32_t, + class GrowthPolicy = tsl::ah::power_of_two_growth_policy<2>> +class array_map { + private: + template + using is_iterator = tsl::detail_array_hash::is_iterator; + + using ht = tsl::detail_array_hash::array_hash; + + public: + using char_type = typename ht::char_type; + using mapped_type = T; + using key_size_type = typename ht::key_size_type; + using index_size_type = typename ht::index_size_type; + using size_type = typename ht::size_type; + using hasher = typename ht::hasher; + using key_equal = typename ht::key_equal; + using iterator = typename ht::iterator; + using const_iterator = typename ht::const_iterator; + + public: + array_map() : array_map(ht::DEFAULT_INIT_BUCKET_COUNT) {} + + explicit array_map(size_type bucket_count, const Hash& hash = Hash()) + : m_ht(bucket_count, hash, ht::DEFAULT_MAX_LOAD_FACTOR) {} + + template ::value>::type* = nullptr> + array_map(InputIt first, InputIt last, + size_type bucket_count = ht::DEFAULT_INIT_BUCKET_COUNT, + const Hash& hash = Hash()) + : array_map(bucket_count, hash) { + insert(first, last); + } + +#ifdef TSL_AH_HAS_STRING_VIEW + array_map( + std::initializer_list, T>> init, + size_type bucket_count = ht::DEFAULT_INIT_BUCKET_COUNT, + const Hash& hash = Hash()) + : array_map(bucket_count, hash) { + insert(init); + } +#else + array_map(std::initializer_list> init, + size_type bucket_count = ht::DEFAULT_INIT_BUCKET_COUNT, + const Hash& hash = Hash()) + : array_map(bucket_count, hash) { + insert(init); + } +#endif + +#ifdef TSL_AH_HAS_STRING_VIEW + array_map& operator=( + std::initializer_list, T>> + ilist) { + clear(); + + reserve(ilist.size()); + insert(ilist); + + return *this; + } +#else + array_map& operator=( + std::initializer_list> ilist) { + clear(); + + reserve(ilist.size()); + insert(ilist); + + return *this; + } +#endif + + /* + * Iterators + */ + iterator begin() noexcept { return m_ht.begin(); } + const_iterator begin() const noexcept { return m_ht.begin(); } + const_iterator cbegin() const noexcept { return m_ht.cbegin(); } + + iterator end() noexcept { return m_ht.end(); } + const_iterator end() const noexcept { return m_ht.end(); } + const_iterator cend() const noexcept { return m_ht.cend(); } + + /* + * Capacity + */ + bool empty() const noexcept { return m_ht.empty(); } + size_type size() const noexcept { return m_ht.size(); } + size_type max_size() const noexcept { return m_ht.max_size(); } + size_type max_key_size() const noexcept { return m_ht.max_key_size(); } + void shrink_to_fit() { m_ht.shrink_to_fit(); } + + /* + * Modifiers + */ + void clear() noexcept { m_ht.clear(); } + +#ifdef TSL_AH_HAS_STRING_VIEW + std::pair insert(const std::basic_string_view& key, + const T& value) { + return m_ht.emplace(key.data(), key.size(), value); + } +#else + std::pair insert(const CharT* key, const T& value) { + return m_ht.emplace(key, std::char_traits::length(key), value); + } + + std::pair insert(const std::basic_string& key, + const T& value) { + return m_ht.emplace(key.data(), key.size(), value); + } +#endif + std::pair insert_ks(const CharT* key, size_type key_size, + const T& value) { + return m_ht.emplace(key, key_size, value); + } + +#ifdef TSL_AH_HAS_STRING_VIEW + std::pair insert(const std::basic_string_view& key, + T&& value) { + return m_ht.emplace(key.data(), key.size(), std::move(value)); + } +#else + std::pair insert(const CharT* key, T&& value) { + return 
m_ht.emplace(key, std::char_traits::length(key), + std::move(value)); + } + + std::pair insert(const std::basic_string& key, + T&& value) { + return m_ht.emplace(key.data(), key.size(), std::move(value)); + } +#endif + std::pair insert_ks(const CharT* key, size_type key_size, + T&& value) { + return m_ht.emplace(key, key_size, std::move(value)); + } + + template ::value>::type* = nullptr> + void insert(InputIt first, InputIt last) { + if (std::is_base_of< + std::forward_iterator_tag, + typename std::iterator_traits::iterator_category>::value) { + const auto nb_elements_insert = std::distance(first, last); + const std::size_t nb_free_buckets = + std::size_t(float(bucket_count()) * max_load_factor()) - size(); + + if (nb_elements_insert > 0 && + nb_free_buckets < std::size_t(nb_elements_insert)) { + reserve(size() + std::size_t(nb_elements_insert)); + } + } + + for (auto it = first; it != last; ++it) { + insert_pair(*it); + } + } + +#ifdef TSL_AH_HAS_STRING_VIEW + void insert(std::initializer_list, T>> + ilist) { + insert(ilist.begin(), ilist.end()); + } +#else + void insert(std::initializer_list> ilist) { + insert(ilist.begin(), ilist.end()); + } +#endif + +#ifdef TSL_AH_HAS_STRING_VIEW + template + std::pair insert_or_assign( + const std::basic_string_view& key, M&& obj) { + return m_ht.insert_or_assign(key.data(), key.size(), std::forward(obj)); + } +#else + template + std::pair insert_or_assign(const CharT* key, M&& obj) { + return m_ht.insert_or_assign(key, std::char_traits::length(key), + std::forward(obj)); + } + + template + std::pair insert_or_assign( + const std::basic_string& key, M&& obj) { + return m_ht.insert_or_assign(key.data(), key.size(), std::forward(obj)); + } +#endif + template + std::pair insert_or_assign_ks(const CharT* key, + size_type key_size, M&& obj) { + return m_ht.insert_or_assign(key, key_size, std::forward(obj)); + } + +#ifdef TSL_AH_HAS_STRING_VIEW + template + std::pair emplace(const std::basic_string_view& key, + Args&&... args) { + return m_ht.emplace(key.data(), key.size(), std::forward(args)...); + } +#else + template + std::pair emplace(const CharT* key, Args&&... args) { + return m_ht.emplace(key, std::char_traits::length(key), + std::forward(args)...); + } + + template + std::pair emplace(const std::basic_string& key, + Args&&... args) { + return m_ht.emplace(key.data(), key.size(), std::forward(args)...); + } +#endif + template + std::pair emplace_ks(const CharT* key, size_type key_size, + Args&&... args) { + return m_ht.emplace(key, key_size, std::forward(args)...); + } + + /** + * Erase has an amortized O(1) runtime complexity, but even if it removes the + * key immediately, it doesn't do the same for the associated value T. + * + * T will only be removed when the ratio between the size of the map and + * the size of the map + the number of deleted values still stored is low + * enough. + * + * To force the deletion you can call shrink_to_fit. 
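+   *
+   * Illustrative sketch:
+   *
+   *     tsl::array_map<char, std::string> map;
+   *     map.insert("key", "a large value");
+   *     map.erase("key");      // the key is removed, the string may linger
+   *     map.shrink_to_fit();   // compacts storage and drops stale values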
+ */ + iterator erase(const_iterator pos) { return m_ht.erase(pos); } + + /** + * @copydoc erase(const_iterator pos) + */ + iterator erase(const_iterator first, const_iterator last) { + return m_ht.erase(first, last); + } + +#ifdef TSL_AH_HAS_STRING_VIEW + /** + * @copydoc erase(const_iterator pos) + */ + size_type erase(const std::basic_string_view& key) { + return m_ht.erase(key.data(), key.size()); + } +#else + /** + * @copydoc erase(const_iterator pos) + */ + size_type erase(const CharT* key) { + return m_ht.erase(key, std::char_traits::length(key)); + } + + /** + * @copydoc erase(const_iterator pos) + */ + size_type erase(const std::basic_string& key) { + return m_ht.erase(key.data(), key.size()); + } +#endif + /** + * @copydoc erase(const_iterator pos) + */ + size_type erase_ks(const CharT* key, size_type key_size) { + return m_ht.erase(key, key_size); + } + +#ifdef TSL_AH_HAS_STRING_VIEW + /** + * @copydoc erase_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) + */ + size_type erase(const std::basic_string_view& key, + std::size_t precalculated_hash) { + return m_ht.erase(key.data(), key.size(), precalculated_hash); + } +#else + /** + * @copydoc erase_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) + */ + size_type erase(const CharT* key, std::size_t precalculated_hash) { + return m_ht.erase(key, std::char_traits::length(key), + precalculated_hash); + } + + /** + * @copydoc erase_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) + */ + size_type erase(const std::basic_string& key, + std::size_t precalculated_hash) { + return m_ht.erase(key.data(), key.size(), precalculated_hash); + } +#endif + /** + * @copydoc erase(const_iterator pos) + * + * Use the hash value 'precalculated_hash' instead of hashing the key. The + * hash value should be the same as hash_function()(key). Useful to speed-up + * the lookup to the value if you already have the hash. 
+ */ + size_type erase_ks(const CharT* key, size_type key_size, + std::size_t precalculated_hash) { + return m_ht.erase(key, key_size, precalculated_hash); + } + + void swap(array_map& other) { other.m_ht.swap(m_ht); } + + /* + * Lookup + */ +#ifdef TSL_AH_HAS_STRING_VIEW + T& at(const std::basic_string_view& key) { + return m_ht.at(key.data(), key.size()); + } + + const T& at(const std::basic_string_view& key) const { + return m_ht.at(key.data(), key.size()); + } +#else + T& at(const CharT* key) { + return m_ht.at(key, std::char_traits::length(key)); + } + + const T& at(const CharT* key) const { + return m_ht.at(key, std::char_traits::length(key)); + } + + T& at(const std::basic_string& key) { + return m_ht.at(key.data(), key.size()); + } + + const T& at(const std::basic_string& key) const { + return m_ht.at(key.data(), key.size()); + } +#endif + T& at_ks(const CharT* key, size_type key_size) { + return m_ht.at(key, key_size); + } + + const T& at_ks(const CharT* key, size_type key_size) const { + return m_ht.at(key, key_size); + } + +#ifdef TSL_AH_HAS_STRING_VIEW + /** + * @copydoc at_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) + */ + T& at(const std::basic_string_view& key, + std::size_t precalculated_hash) { + return m_ht.at(key.data(), key.size(), precalculated_hash); + } + + /** + * @copydoc at_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) + */ + const T& at(const std::basic_string_view& key, + std::size_t precalculated_hash) const { + return m_ht.at(key.data(), key.size(), precalculated_hash); + } +#else + /** + * @copydoc at_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) + */ + T& at(const CharT* key, std::size_t precalculated_hash) { + return m_ht.at(key, std::char_traits::length(key), + precalculated_hash); + } + + /** + * @copydoc at_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) + */ + const T& at(const CharT* key, std::size_t precalculated_hash) const { + return m_ht.at(key, std::char_traits::length(key), + precalculated_hash); + } + + /** + * @copydoc at_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) + */ + T& at(const std::basic_string& key, std::size_t precalculated_hash) { + return m_ht.at(key.data(), key.size(), precalculated_hash); + } + + /** + * @copydoc at_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) + */ + const T& at(const std::basic_string& key, + std::size_t precalculated_hash) const { + return m_ht.at(key.data(), key.size(), precalculated_hash); + } +#endif + /** + * Use the hash value 'precalculated_hash' instead of hashing the key. The + * hash value should be the same as hash_function()(key). Useful to speed-up + * the lookup to the value if you already have the hash. 
+ */ + T& at_ks(const CharT* key, size_type key_size, + std::size_t precalculated_hash) { + return m_ht.at(key, key_size, precalculated_hash); + } + + /** + * @copydoc at_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) + */ + const T& at_ks(const CharT* key, size_type key_size, + std::size_t precalculated_hash) const { + return m_ht.at(key, key_size, precalculated_hash); + } + +#ifdef TSL_AH_HAS_STRING_VIEW + T& operator[](const std::basic_string_view& key) { + return m_ht.access_operator(key.data(), key.size()); + } +#else + T& operator[](const CharT* key) { + return m_ht.access_operator(key, std::char_traits::length(key)); + } + T& operator[](const std::basic_string& key) { + return m_ht.access_operator(key.data(), key.size()); + } +#endif + +#ifdef TSL_AH_HAS_STRING_VIEW + size_type count(const std::basic_string_view& key) const { + return m_ht.count(key.data(), key.size()); + } +#else + size_type count(const CharT* key) const { + return m_ht.count(key, std::char_traits::length(key)); + } + + size_type count(const std::basic_string& key) const { + return m_ht.count(key.data(), key.size()); + } +#endif + size_type count_ks(const CharT* key, size_type key_size) const { + return m_ht.count(key, key_size); + } + +#ifdef TSL_AH_HAS_STRING_VIEW + /** + * @copydoc count_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) const + */ + size_type count(const std::basic_string_view& key, + std::size_t precalculated_hash) const { + return m_ht.count(key.data(), key.size(), precalculated_hash); + } +#else + /** + * @copydoc count_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) const + */ + size_type count(const CharT* key, std::size_t precalculated_hash) const { + return m_ht.count(key, std::char_traits::length(key), + precalculated_hash); + } + + /** + * @copydoc count_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) const + */ + size_type count(const std::basic_string& key, + std::size_t precalculated_hash) const { + return m_ht.count(key.data(), key.size(), precalculated_hash); + } +#endif + /** + * Use the hash value 'precalculated_hash' instead of hashing the key. The + * hash value should be the same as hash_function()(key). Useful to speed-up + * the lookup to the value if you already have the hash. 
+ */ + size_type count_ks(const CharT* key, size_type key_size, + std::size_t precalculated_hash) const { + return m_ht.count(key, key_size, precalculated_hash); + } + +#ifdef TSL_AH_HAS_STRING_VIEW + iterator find(const std::basic_string_view& key) { + return m_ht.find(key.data(), key.size()); + } + + const_iterator find(const std::basic_string_view& key) const { + return m_ht.find(key.data(), key.size()); + } +#else + iterator find(const CharT* key) { + return m_ht.find(key, std::char_traits::length(key)); + } + + const_iterator find(const CharT* key) const { + return m_ht.find(key, std::char_traits::length(key)); + } + + iterator find(const std::basic_string& key) { + return m_ht.find(key.data(), key.size()); + } + + const_iterator find(const std::basic_string& key) const { + return m_ht.find(key.data(), key.size()); + } +#endif + iterator find_ks(const CharT* key, size_type key_size) { + return m_ht.find(key, key_size); + } + + const_iterator find_ks(const CharT* key, size_type key_size) const { + return m_ht.find(key, key_size); + } + +#ifdef TSL_AH_HAS_STRING_VIEW + /** + * @copydoc find_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) + */ + iterator find(const std::basic_string_view& key, + std::size_t precalculated_hash) { + return m_ht.find(key.data(), key.size(), precalculated_hash); + } + + /** + * @copydoc find_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) + */ + const_iterator find(const std::basic_string_view& key, + std::size_t precalculated_hash) const { + return m_ht.find(key.data(), key.size(), precalculated_hash); + } +#else + /** + * @copydoc find_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) + */ + iterator find(const CharT* key, std::size_t precalculated_hash) { + return m_ht.find(key, std::char_traits::length(key), + precalculated_hash); + } + + /** + * @copydoc find_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) + */ + const_iterator find(const CharT* key, std::size_t precalculated_hash) const { + return m_ht.find(key, std::char_traits::length(key), + precalculated_hash); + } + + /** + * @copydoc find_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) + */ + iterator find(const std::basic_string& key, + std::size_t precalculated_hash) { + return m_ht.find(key.data(), key.size(), precalculated_hash); + } + + /** + * @copydoc find_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) + */ + const_iterator find(const std::basic_string& key, + std::size_t precalculated_hash) const { + return m_ht.find(key.data(), key.size(), precalculated_hash); + } +#endif + /** + * Use the hash value 'precalculated_hash' instead of hashing the key. The + * hash value should be the same as hash_function()(key). Useful to speed-up + * the lookup to the value if you already have the hash. 
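+   *
+   * Illustrative sketch of computing the hash once and reusing it (the same
+   * pattern applies to the count/erase/at/equal_range overloads that take a
+   * precalculated_hash):
+   *
+   *     tsl::array_map<char, int> map;
+   *     map.insert("apple", 1);
+   *     const std::size_t hash = map.hash_function()("apple", 5);
+   *     if (map.find("apple", hash) != map.end()) {
+   *       map.erase("apple", hash);
+   *     }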
+ */ + iterator find_ks(const CharT* key, size_type key_size, + std::size_t precalculated_hash) { + return m_ht.find(key, key_size, precalculated_hash); + } + + /** + * @copydoc find_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) + */ + const_iterator find_ks(const CharT* key, size_type key_size, + std::size_t precalculated_hash) const { + return m_ht.find(key, key_size, precalculated_hash); + } + +#ifdef TSL_AH_HAS_STRING_VIEW + std::pair equal_range( + const std::basic_string_view& key) { + return m_ht.equal_range(key.data(), key.size()); + } + + std::pair equal_range( + const std::basic_string_view& key) const { + return m_ht.equal_range(key.data(), key.size()); + } +#else + std::pair equal_range(const CharT* key) { + return m_ht.equal_range(key, std::char_traits::length(key)); + } + + std::pair equal_range( + const CharT* key) const { + return m_ht.equal_range(key, std::char_traits::length(key)); + } + + std::pair equal_range( + const std::basic_string& key) { + return m_ht.equal_range(key.data(), key.size()); + } + + std::pair equal_range( + const std::basic_string& key) const { + return m_ht.equal_range(key.data(), key.size()); + } +#endif + std::pair equal_range_ks(const CharT* key, + size_type key_size) { + return m_ht.equal_range(key, key_size); + } + + std::pair equal_range_ks( + const CharT* key, size_type key_size) const { + return m_ht.equal_range(key, key_size); + } + +#ifdef TSL_AH_HAS_STRING_VIEW + /** + * @copydoc equal_range_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) + */ + std::pair equal_range( + const std::basic_string_view& key, + std::size_t precalculated_hash) { + return m_ht.equal_range(key.data(), key.size(), precalculated_hash); + } + + /** + * @copydoc equal_range_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) + */ + std::pair equal_range( + const std::basic_string_view& key, + std::size_t precalculated_hash) const { + return m_ht.equal_range(key.data(), key.size(), precalculated_hash); + } +#else + /** + * @copydoc equal_range_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) + */ + std::pair equal_range(const CharT* key, + std::size_t precalculated_hash) { + return m_ht.equal_range(key, std::char_traits::length(key), + precalculated_hash); + } + + /** + * @copydoc equal_range_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) + */ + std::pair equal_range( + const CharT* key, std::size_t precalculated_hash) const { + return m_ht.equal_range(key, std::char_traits::length(key), + precalculated_hash); + } + + /** + * @copydoc equal_range_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) + */ + std::pair equal_range(const std::basic_string& key, + std::size_t precalculated_hash) { + return m_ht.equal_range(key.data(), key.size(), precalculated_hash); + } + + /** + * @copydoc equal_range_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) + */ + std::pair equal_range( + const std::basic_string& key, + std::size_t precalculated_hash) const { + return m_ht.equal_range(key.data(), key.size(), precalculated_hash); + } +#endif + /** + * Use the hash value 'precalculated_hash' instead of hashing the key. The + * hash value should be the same as hash_function()(key). Useful to speed-up + * the lookup to the value if you already have the hash. 
+ */ + std::pair equal_range_ks(const CharT* key, + size_type key_size, + std::size_t precalculated_hash) { + return m_ht.equal_range(key, key_size, precalculated_hash); + } + + /** + * @copydoc equal_range_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) + */ + std::pair equal_range_ks( + const CharT* key, size_type key_size, + std::size_t precalculated_hash) const { + return m_ht.equal_range(key, key_size, precalculated_hash); + } + + /* + * Bucket interface + */ + size_type bucket_count() const { return m_ht.bucket_count(); } + size_type max_bucket_count() const { return m_ht.max_bucket_count(); } + + /* + * Hash policy + */ + float load_factor() const { return m_ht.load_factor(); } + float max_load_factor() const { return m_ht.max_load_factor(); } + void max_load_factor(float ml) { m_ht.max_load_factor(ml); } + + void rehash(size_type count) { m_ht.rehash(count); } + void reserve(size_type count) { m_ht.reserve(count); } + + /* + * Observers + */ + hasher hash_function() const { return m_ht.hash_function(); } + key_equal key_eq() const { return m_ht.key_eq(); } + + /* + * Other + */ + /** + * Return the `const_iterator it` as an `iterator`. + */ + iterator mutable_iterator(const_iterator it) noexcept { + return m_ht.mutable_iterator(it); + } + + /** + * Serialize the map through the `serializer` parameter. + * + * The `serializer` parameter must be a function object that supports the + * following calls: + * - `template void operator()(const U& value);` where the types + * `std::uint64_t`, `float` and `T` must be supported for U. + * - `void operator()(const CharT* value, std::size_t value_size);` + * + * The implementation leaves binary compatibility (endianness, IEEE 754 for + * floats, ...) of the types it serializes in the hands of the `Serializer` + * function object if compatibility is required. + */ + template + void serialize(Serializer& serializer) const { + m_ht.serialize(serializer); + } + + /** + * Deserialize a previously serialized map through the `deserializer` + * parameter. + * + * The `deserializer` parameter must be a function object that supports the + * following calls: + * - `template U operator()();` where the types `std::uint64_t`, + * `float` and `T` must be supported for U. + * - `void operator()(CharT* value_out, std::size_t value_size);` + * + * If the deserialized hash map type is hash compatible with the serialized + * map, the deserialization process can be sped up by setting + * `hash_compatible` to true. To be hash compatible, the Hash (take care of + * the 32-bits vs 64 bits), KeyEqual, GrowthPolicy, StoreNullTerminator, + * KeySizeT and IndexSizeT must behave the same than the ones used on the + * serialized map. Otherwise the behaviour is undefined with `hash_compatible` + * sets to true. + * + * The behaviour is undefined if the type `CharT` and `T` of the `array_map` + * are not the same as the types used during serialization. + * + * The implementation leaves binary compatibility (endianness, IEEE 754 for + * floats, size of int, ...) of the types it deserializes in the hands of the + * `Deserializer` function object if compatibility is required. 
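+   *
+   * A minimal sketch of a compatible serializer/deserializer pair built on a
+   * raw in-memory buffer (illustrative only; it assumes a trivially copyable
+   * `T`, ignores endianness and does no error handling; needs <cstring>,
+   * <string> and <cstddef>):
+   *
+   *     struct buffer_serializer {
+   *       std::string buf;
+   *       template <class U>
+   *       void operator()(const U& value) {
+   *         buf.append(reinterpret_cast<const char*>(&value), sizeof(U));
+   *       }
+   *       void operator()(const char* value, std::size_t value_size) {
+   *         buf.append(value, value_size);
+   *       }
+   *     };
+   *
+   *     struct buffer_deserializer {
+   *       const char* pos;
+   *       template <class U>
+   *       U operator()() {
+   *         U value;
+   *         std::memcpy(&value, pos, sizeof(U));
+   *         pos += sizeof(U);
+   *         return value;
+   *       }
+   *       void operator()(char* value_out, std::size_t value_size) {
+   *         std::memcpy(value_out, pos, value_size);
+   *         pos += value_size;
+   *       }
+   *     };
+   *
+   *     // Round trip with a tsl::array_map<char, int> named `map`:
+   *     buffer_serializer ser;
+   *     map.serialize(ser);
+   *     buffer_deserializer dser{ser.buf.data()};
+   *     auto map2 = tsl::array_map<char, int>::deserialize(dser, true);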
+ */ + template + static array_map deserialize(Deserializer& deserializer, + bool hash_compatible = false) { + array_map map(0); + map.m_ht.deserialize(deserializer, hash_compatible); + + return map; + } + + friend bool operator==(const array_map& lhs, const array_map& rhs) { + if (lhs.size() != rhs.size()) { + return false; + } + + for (auto it = lhs.cbegin(); it != lhs.cend(); ++it) { + const auto it_element_rhs = rhs.find_ks(it.key(), it.key_size()); + if (it_element_rhs == rhs.cend() || + it.value() != it_element_rhs.value()) { + return false; + } + } + + return true; + } + + friend bool operator!=(const array_map& lhs, const array_map& rhs) { + return !operator==(lhs, rhs); + } + + friend void swap(array_map& lhs, array_map& rhs) { lhs.swap(rhs); } + + private: + template + void insert_pair(const std::pair& value) { + insert(value.first, value.second); + } + + template + void insert_pair(std::pair&& value) { + insert(value.first, std::move(value.second)); + } + + public: + static const size_type MAX_KEY_SIZE = ht::MAX_KEY_SIZE; + + private: + ht m_ht; +}; + +/** + * Same as + * `tsl::array_map`. + */ +template , + class KeyEqual = tsl::ah::str_equal, + bool StoreNullTerminator = true, class KeySizeT = std::uint16_t, + class IndexSizeT = std::uint32_t> +using array_pg_map = + array_map; + +} // end namespace tsl + +#endif diff --git a/include/tsl/array-hash/array_set.h b/include/tsl/array-hash/array_set.h new file mode 100644 index 00000000..f6fb0167 --- /dev/null +++ b/include/tsl/array-hash/array_set.h @@ -0,0 +1,716 @@ +/** + * MIT License + * + * Copyright (c) 2017 Thibaut Goetghebuer-Planchon + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef TSL_ARRAY_SET_H +#define TSL_ARRAY_SET_H + +#include +#include +#include +#include +#include +#include +#include + +#include "array_hash.h" + +namespace tsl { + +/** + * Implementation of a cache-conscious string hash set. + * + * The set stores the strings as `const CharT*`. If `StoreNullTerminator` is + * true, the strings are stored with the a null-terminator (the `key()` method + * of the iterators will return a pointer to this null-terminated string). + * Otherwise the null character is not stored (which allow an economy of 1 byte + * per string). + * + * The size of a key string is limited to `std::numeric_limits::max() + * - 1`. That is 65 535 characters by default, but can be raised with the + * `KeySizeT` template parameter. See `max_key_size()` for an easy access to + * this limit. 
+ * + * The number of elements in the set is limited to + * `std::numeric_limits::max()`. That is 4 294 967 296 elements, but + * can be raised with the `IndexSizeT` template parameter. See `max_size()` for + * an easy access to this limit. + * + * Iterators invalidation: + * - clear, operator=: always invalidate the iterators. + * - insert, emplace, operator[]: always invalidate the iterators. + * - erase: always invalidate the iterators. + * - shrink_to_fit: always invalidate the iterators. + */ +template , + class KeyEqual = tsl::ah::str_equal, + bool StoreNullTerminator = true, class KeySizeT = std::uint16_t, + class IndexSizeT = std::uint32_t, + class GrowthPolicy = tsl::ah::power_of_two_growth_policy<2>> +class array_set { + private: + template + using is_iterator = tsl::detail_array_hash::is_iterator; + + using ht = tsl::detail_array_hash::array_hash; + + public: + using char_type = typename ht::char_type; + using key_size_type = typename ht::key_size_type; + using index_size_type = typename ht::index_size_type; + using size_type = typename ht::size_type; + using hasher = typename ht::hasher; + using key_equal = typename ht::key_equal; + using iterator = typename ht::iterator; + using const_iterator = typename ht::const_iterator; + + array_set() : array_set(ht::DEFAULT_INIT_BUCKET_COUNT) {} + + explicit array_set(size_type bucket_count, const Hash& hash = Hash()) + : m_ht(bucket_count, hash, ht::DEFAULT_MAX_LOAD_FACTOR) {} + + template ::value>::type* = nullptr> + array_set(InputIt first, InputIt last, + size_type bucket_count = ht::DEFAULT_INIT_BUCKET_COUNT, + const Hash& hash = Hash()) + : array_set(bucket_count, hash) { + insert(first, last); + } + +#ifdef TSL_AH_HAS_STRING_VIEW + array_set(std::initializer_list> init, + size_type bucket_count = ht::DEFAULT_INIT_BUCKET_COUNT, + const Hash& hash = Hash()) + : array_set(bucket_count, hash) { + insert(init); + } +#else + array_set(std::initializer_list init, + size_type bucket_count = ht::DEFAULT_INIT_BUCKET_COUNT, + const Hash& hash = Hash()) + : array_set(bucket_count, hash) { + insert(init); + } +#endif + +#ifdef TSL_AH_HAS_STRING_VIEW + array_set& operator=( + std::initializer_list> ilist) { + clear(); + + reserve(ilist.size()); + insert(ilist); + + return *this; + } +#else + array_set& operator=(std::initializer_list ilist) { + clear(); + + reserve(ilist.size()); + insert(ilist); + + return *this; + } +#endif + + /* + * Iterators + */ + iterator begin() noexcept { return m_ht.begin(); } + const_iterator begin() const noexcept { return m_ht.begin(); } + const_iterator cbegin() const noexcept { return m_ht.cbegin(); } + + iterator end() noexcept { return m_ht.end(); } + const_iterator end() const noexcept { return m_ht.end(); } + const_iterator cend() const noexcept { return m_ht.cend(); } + + /* + * Capacity + */ + bool empty() const noexcept { return m_ht.empty(); } + size_type size() const noexcept { return m_ht.size(); } + size_type max_size() const noexcept { return m_ht.max_size(); } + size_type max_key_size() const noexcept { return m_ht.max_key_size(); } + void shrink_to_fit() { m_ht.shrink_to_fit(); } + + /* + * Modifiers + */ + void clear() noexcept { m_ht.clear(); } + +#ifdef TSL_AH_HAS_STRING_VIEW + std::pair insert(const std::basic_string_view& key) { + return m_ht.emplace(key.data(), key.size()); + } +#else + std::pair insert(const CharT* key) { + return m_ht.emplace(key, std::char_traits::length(key)); + } + + std::pair insert(const std::basic_string& key) { + return m_ht.emplace(key.data(), key.size()); + } 
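+
+  /*
+   * Illustrative usage sketch (not part of the upstream sources), using the
+   * insert/count/erase overloads of this class:
+   *
+   *     tsl::array_set<char> set;
+   *     set.insert("apple");
+   *     set.insert(std::string("banana"));
+   *     if (set.count("apple") != 0) {
+   *       set.erase("apple");
+   *     }
+   */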
+#endif + std::pair insert_ks(const CharT* key, size_type key_size) { + return m_ht.emplace(key, key_size); + } + + template ::value>::type* = nullptr> + void insert(InputIt first, InputIt last) { + if (std::is_base_of< + std::forward_iterator_tag, + typename std::iterator_traits::iterator_category>::value) { + const auto nb_elements_insert = std::distance(first, last); + const std::size_t nb_free_buckets = + std::size_t(float(bucket_count()) * max_load_factor()) - size(); + + if (nb_elements_insert > 0 && + nb_free_buckets < std::size_t(nb_elements_insert)) { + reserve(size() + std::size_t(nb_elements_insert)); + } + } + + for (auto it = first; it != last; ++it) { + insert(*it); + } + } + +#ifdef TSL_AH_HAS_STRING_VIEW + void insert(std::initializer_list> ilist) { + insert(ilist.begin(), ilist.end()); + } +#else + void insert(std::initializer_list ilist) { + insert(ilist.begin(), ilist.end()); + } +#endif + +#ifdef TSL_AH_HAS_STRING_VIEW + /** + * @copydoc emplace_ks(const CharT* key, size_type key_size) + */ + std::pair emplace(const std::basic_string_view& key) { + return m_ht.emplace(key.data(), key.size()); + } +#else + /** + * @copydoc emplace_ks(const CharT* key, size_type key_size) + */ + std::pair emplace(const CharT* key) { + return m_ht.emplace(key, std::char_traits::length(key)); + } + + /** + * @copydoc emplace_ks(const CharT* key, size_type key_size) + */ + std::pair emplace(const std::basic_string& key) { + return m_ht.emplace(key.data(), key.size()); + } +#endif + /** + * No difference compared to the insert method. Mainly here for coherence with + * array_map. + */ + std::pair emplace_ks(const CharT* key, size_type key_size) { + return m_ht.emplace(key, key_size); + } + + iterator erase(const_iterator pos) { return m_ht.erase(pos); } + iterator erase(const_iterator first, const_iterator last) { + return m_ht.erase(first, last); + } + +#ifdef TSL_AH_HAS_STRING_VIEW + size_type erase(const std::basic_string_view& key) { + return m_ht.erase(key.data(), key.size()); + } +#else + size_type erase(const CharT* key) { + return m_ht.erase(key, std::char_traits::length(key)); + } + + size_type erase(const std::basic_string& key) { + return m_ht.erase(key.data(), key.size()); + } +#endif + size_type erase_ks(const CharT* key, size_type key_size) { + return m_ht.erase(key, key_size); + } + +#ifdef TSL_AH_HAS_STRING_VIEW + /** + * @copydoc erase_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) + */ + size_type erase(const std::basic_string_view& key, + std::size_t precalculated_hash) { + return m_ht.erase(key.data(), key.size(), precalculated_hash); + } +#else + /** + * @copydoc erase_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) + */ + size_type erase(const CharT* key, std::size_t precalculated_hash) { + return m_ht.erase(key, std::char_traits::length(key), + precalculated_hash); + } + + /** + * @copydoc erase_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) + */ + size_type erase(const std::basic_string& key, + std::size_t precalculated_hash) { + return m_ht.erase(key.data(), key.size(), precalculated_hash); + } +#endif + /** + * Use the hash value 'precalculated_hash' instead of hashing the key. The + * hash value should be the same as hash_function()(key). Useful to speed-up + * the lookup to the value if you already have the hash. 
+ */ + size_type erase_ks(const CharT* key, size_type key_size, + std::size_t precalculated_hash) { + return m_ht.erase(key, key_size, precalculated_hash); + } + + void swap(array_set& other) { other.m_ht.swap(m_ht); } + + /* + * Lookup + */ +#ifdef TSL_AH_HAS_STRING_VIEW + size_type count(const std::basic_string_view& key) const { + return m_ht.count(key.data(), key.size()); + } +#else + size_type count(const CharT* key) const { + return m_ht.count(key, std::char_traits::length(key)); + } + size_type count(const std::basic_string& key) const { + return m_ht.count(key.data(), key.size()); + } +#endif + size_type count_ks(const CharT* key, size_type key_size) const { + return m_ht.count(key, key_size); + } + +#ifdef TSL_AH_HAS_STRING_VIEW + /** + * @copydoc count_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) const + */ + size_type count(const std::basic_string_view& key, + std::size_t precalculated_hash) const { + return m_ht.count(key.data(), key.size(), precalculated_hash); + } +#else + /** + * @copydoc count_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) const + */ + size_type count(const CharT* key, std::size_t precalculated_hash) const { + return m_ht.count(key, std::char_traits::length(key), + precalculated_hash); + } + + /** + * @copydoc count_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) const + */ + size_type count(const std::basic_string& key, + std::size_t precalculated_hash) const { + return m_ht.count(key.data(), key.size(), precalculated_hash); + } +#endif + /** + * Use the hash value 'precalculated_hash' instead of hashing the key. The + * hash value should be the same as hash_function()(key). Useful to speed-up + * the lookup to the value if you already have the hash. 
+ */ + size_type count_ks(const CharT* key, size_type key_size, + std::size_t precalculated_hash) const { + return m_ht.count(key, key_size, precalculated_hash); + } + +#ifdef TSL_AH_HAS_STRING_VIEW + iterator find(const std::basic_string_view& key) { + return m_ht.find(key.data(), key.size()); + } + + const_iterator find(const std::basic_string_view& key) const { + return m_ht.find(key.data(), key.size()); + } +#else + iterator find(const CharT* key) { + return m_ht.find(key, std::char_traits::length(key)); + } + + const_iterator find(const CharT* key) const { + return m_ht.find(key, std::char_traits::length(key)); + } + + iterator find(const std::basic_string& key) { + return m_ht.find(key.data(), key.size()); + } + + const_iterator find(const std::basic_string& key) const { + return m_ht.find(key.data(), key.size()); + } +#endif + iterator find_ks(const CharT* key, size_type key_size) { + return m_ht.find(key, key_size); + } + + const_iterator find_ks(const CharT* key, size_type key_size) const { + return m_ht.find(key, key_size); + } + +#ifdef TSL_AH_HAS_STRING_VIEW + /** + * @copydoc find_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) + */ + iterator find(const std::basic_string_view& key, + std::size_t precalculated_hash) { + return m_ht.find(key.data(), key.size(), precalculated_hash); + } + + /** + * @copydoc find_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) + */ + const_iterator find(const std::basic_string_view& key, + std::size_t precalculated_hash) const { + return m_ht.find(key.data(), key.size(), precalculated_hash); + } +#else + /** + * @copydoc find_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) + */ + iterator find(const CharT* key, std::size_t precalculated_hash) { + return m_ht.find(key, std::char_traits::length(key), + precalculated_hash); + } + + /** + * @copydoc find_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) + */ + const_iterator find(const CharT* key, std::size_t precalculated_hash) const { + return m_ht.find(key, std::char_traits::length(key), + precalculated_hash); + } + + /** + * @copydoc find_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) + */ + iterator find(const std::basic_string& key, + std::size_t precalculated_hash) { + return m_ht.find(key.data(), key.size(), precalculated_hash); + } + + /** + * @copydoc find_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) + */ + const_iterator find(const std::basic_string& key, + std::size_t precalculated_hash) const { + return m_ht.find(key.data(), key.size(), precalculated_hash); + } +#endif + /** + * Use the hash value 'precalculated_hash' instead of hashing the key. The + * hash value should be the same as hash_function()(key). Useful to speed-up + * the lookup to the value if you already have the hash. 
+ */ + iterator find_ks(const CharT* key, size_type key_size, + std::size_t precalculated_hash) { + return m_ht.find(key, key_size, precalculated_hash); + } + + /** + * @copydoc find_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) + */ + const_iterator find_ks(const CharT* key, size_type key_size, + std::size_t precalculated_hash) const { + return m_ht.find(key, key_size, precalculated_hash); + } + +#ifdef TSL_AH_HAS_STRING_VIEW + std::pair equal_range( + const std::basic_string_view& key) { + return m_ht.equal_range(key.data(), key.size()); + } + + std::pair equal_range( + const std::basic_string_view& key) const { + return m_ht.equal_range(key.data(), key.size()); + } +#else + std::pair equal_range(const CharT* key) { + return m_ht.equal_range(key, std::char_traits::length(key)); + } + + std::pair equal_range( + const CharT* key) const { + return m_ht.equal_range(key, std::char_traits::length(key)); + } + + std::pair equal_range( + const std::basic_string& key) { + return m_ht.equal_range(key.data(), key.size()); + } + + std::pair equal_range( + const std::basic_string& key) const { + return m_ht.equal_range(key.data(), key.size()); + } +#endif + std::pair equal_range_ks(const CharT* key, + size_type key_size) { + return m_ht.equal_range(key, key_size); + } + + std::pair equal_range_ks( + const CharT* key, size_type key_size) const { + return m_ht.equal_range(key, key_size); + } + +#ifdef TSL_AH_HAS_STRING_VIEW + /** + * @copydoc equal_range_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) + */ + std::pair equal_range( + const std::basic_string_view& key, + std::size_t precalculated_hash) { + return m_ht.equal_range(key.data(), key.size(), precalculated_hash); + } + + /** + * @copydoc equal_range_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) + */ + std::pair equal_range( + const std::basic_string_view& key, + std::size_t precalculated_hash) const { + return m_ht.equal_range(key.data(), key.size(), precalculated_hash); + } +#else + /** + * @copydoc equal_range_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) + */ + std::pair equal_range(const CharT* key, + std::size_t precalculated_hash) { + return m_ht.equal_range(key, std::char_traits::length(key), + precalculated_hash); + } + + /** + * @copydoc equal_range_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) + */ + std::pair equal_range( + const CharT* key, std::size_t precalculated_hash) const { + return m_ht.equal_range(key, std::char_traits::length(key), + precalculated_hash); + } + + /** + * @copydoc equal_range_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) + */ + std::pair equal_range(const std::basic_string& key, + std::size_t precalculated_hash) { + return m_ht.equal_range(key.data(), key.size(), precalculated_hash); + } + + /** + * @copydoc equal_range_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) + */ + std::pair equal_range( + const std::basic_string& key, + std::size_t precalculated_hash) const { + return m_ht.equal_range(key.data(), key.size(), precalculated_hash); + } +#endif + /** + * Use the hash value 'precalculated_hash' instead of hashing the key. The + * hash value should be the same as hash_function()(key). Useful to speed-up + * the lookup to the value if you already have the hash. 
+ */ + std::pair equal_range_ks(const CharT* key, + size_type key_size, + std::size_t precalculated_hash) { + return m_ht.equal_range(key, key_size, precalculated_hash); + } + + /** + * @copydoc equal_range_ks(const CharT* key, size_type key_size, std::size_t + * precalculated_hash) + */ + std::pair equal_range_ks( + const CharT* key, size_type key_size, + std::size_t precalculated_hash) const { + return m_ht.equal_range(key, key_size, precalculated_hash); + } + + /* + * Bucket interface + */ + size_type bucket_count() const { return m_ht.bucket_count(); } + size_type max_bucket_count() const { return m_ht.max_bucket_count(); } + + /* + * Hash policy + */ + float load_factor() const { return m_ht.load_factor(); } + float max_load_factor() const { return m_ht.max_load_factor(); } + void max_load_factor(float ml) { m_ht.max_load_factor(ml); } + + void rehash(size_type count) { m_ht.rehash(count); } + void reserve(size_type count) { m_ht.reserve(count); } + + /* + * Observers + */ + hasher hash_function() const { return m_ht.hash_function(); } + key_equal key_eq() const { return m_ht.key_eq(); } + + /* + * Other + */ + /** + * Return the `const_iterator it` as an `iterator`. + */ + iterator mutable_iterator(const_iterator it) noexcept { + return m_ht.mutable_iterator(it); + } + + /** + * Serialize the set through the `serializer` parameter. + * + * The `serializer` parameter must be a function object that supports the + * following calls: + * - `template void operator()(const U& value);` where the types + * `std::uint64_t` and `float` must be supported for U. + * - `void operator()(const CharT* value, std::size_t value_size);` + * + * The implementation leaves binary compatibility (endianness, IEEE 754 for + * floats, ...) of the types it serializes in the hands of the `Serializer` + * function object if compatibility is required. + */ + template + void serialize(Serializer& serializer) const { + m_ht.serialize(serializer); + } + + /** + * Deserialize a previously serialized set through the `deserializer` + * parameter. + * + * The `deserializer` parameter must be a function object that supports the + * following calls: + * - `template U operator()();` where the types `std::uint64_t` + * and `float` must be supported for U. + * - `void operator()(CharT* value_out, std::size_t value_size);` + * + * If the deserialized hash set type is hash compatible with the serialized + * set, the deserialization process can be sped up by setting + * `hash_compatible` to true. To be hash compatible, the Hash (take care of + * the 32-bits vs 64 bits), KeyEqual, GrowthPolicy, StoreNullTerminator, + * KeySizeT and IndexSizeT must behave the same than the ones used on the + * serialized set. Otherwise the behaviour is undefined with `hash_compatible` + * sets to true. + * + * The behaviour is undefined if the type `CharT` of the `array_set` is not + * the same as the type used during serialization. + * + * The implementation leaves binary compatibility (endianness, IEEE 754 for + * floats, size of int, ...) of the types it deserializes in the hands of the + * `Deserializer` function object if compatibility is required. 
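+   *
+   * A round trip looks the same as the array_map sketch earlier in this
+   * patch (buffer_serializer / buffer_deserializer are the hypothetical
+   * in-memory helpers outlined there):
+   *
+   *     buffer_serializer ser;
+   *     set.serialize(ser);
+   *     buffer_deserializer dser{ser.buf.data()};
+   *     auto set2 = tsl::array_set<char>::deserialize(dser, true);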
+ */ + template + static array_set deserialize(Deserializer& deserializer, + bool hash_compatible = false) { + array_set set(0); + set.m_ht.deserialize(deserializer, hash_compatible); + + return set; + } + + friend bool operator==(const array_set& lhs, const array_set& rhs) { + if (lhs.size() != rhs.size()) { + return false; + } + + for (auto it = lhs.cbegin(); it != lhs.cend(); ++it) { + const auto it_element_rhs = rhs.find_ks(it.key(), it.key_size()); + if (it_element_rhs == rhs.cend()) { + return false; + } + } + + return true; + } + + friend bool operator!=(const array_set& lhs, const array_set& rhs) { + return !operator==(lhs, rhs); + } + + friend void swap(array_set& lhs, array_set& rhs) { lhs.swap(rhs); } + + public: + static const size_type MAX_KEY_SIZE = ht::MAX_KEY_SIZE; + + private: + ht m_ht; +}; + +/** + * Same as + * `tsl::array_set`. + */ +template , + class KeyEqual = tsl::ah::str_equal, + bool StoreNullTerminator = true, class KeySizeT = std::uint16_t, + class IndexSizeT = std::uint32_t> +using array_pg_set = + array_set; + +} // end namespace tsl + +#endif diff --git a/include/tsl/htrie_hash.h b/include/tsl/htrie_hash.h new file mode 100644 index 00000000..0cf4632a --- /dev/null +++ b/include/tsl/htrie_hash.h @@ -0,0 +1,2076 @@ +/** + * MIT License + * + * Copyright (c) 2017 Thibaut Goetghebuer-Planchon + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef TSL_HTRIE_HASH_H +#define TSL_HTRIE_HASH_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "array-hash/array_map.h" +#include "array-hash/array_set.h" + +/* + * __has_include is a bit useless + * (https://gcc.gnu.org/bugzilla/show_bug.cgi?id=79433), check also __cplusplus + * version. 
+ */ +#ifdef __has_include +#if __has_include() && (__cplusplus >= 201703L || _MSVC_LANG >= 201703L) +#define TSL_HT_HAS_STRING_VIEW +#endif +#endif + +#ifdef TSL_HT_HAS_STRING_VIEW +#include +#endif + +#ifdef TSL_DEBUG +#define tsl_ht_assert(expr) assert(expr) +#else +#define tsl_ht_assert(expr) (static_cast(0)) +#endif + +namespace tsl { + +namespace detail_htrie_hash { + +template +struct is_iterator : std::false_type {}; + +template +struct is_iterator::iterator_category, + void>::value>::type> : std::true_type {}; + +template +struct is_related : std::false_type {}; + +template +struct is_related + : std::is_same::type>::type, + typename std::remove_cv< + typename std::remove_reference::type>::type> {}; + +template +static T numeric_cast(U value, + const char* error_message = "numeric_cast() failed.") { + T ret = static_cast(value); + if (static_cast(ret) != value) { + throw std::runtime_error(error_message); + } + + const bool is_same_signedness = + (std::is_unsigned::value && std::is_unsigned::value) || + (std::is_signed::value && std::is_signed::value); + if (!is_same_signedness && (ret < T{}) != (value < U{})) { + throw std::runtime_error(error_message); + } + + return ret; +} + +template +struct value_node { + /* + * Avoid conflict with copy constructor 'value_node(const value_node&)'. If we + * call the copy constructor with a mutable reference + * 'value_node(value_node&)', we don't want the forward constructor to be + * called. + */ + template ::value>::type* = nullptr> + value_node(Args&&... args) : m_value(std::forward(args)...) {} + + T m_value; +}; + +template <> +struct value_node {}; + +/** + * T should be void if there is no value associated to a key (in a set for + * example). + */ +template +class htrie_hash { + private: + template + using has_value = + typename std::integral_constant::value>; + + static_assert(std::is_same::value, + "char is the only supported CharT type for now."); + + static const std::size_t ALPHABET_SIZE = + std::numeric_limits::type>::max() + 1; + + public: + template + class htrie_hash_iterator; + + using char_type = CharT; + using key_size_type = KeySizeT; + using size_type = std::size_t; + using hasher = Hash; + using iterator = htrie_hash_iterator; + using const_iterator = htrie_hash_iterator; + using prefix_iterator = htrie_hash_iterator; + using const_prefix_iterator = htrie_hash_iterator; + + private: + using array_hash_type = typename std::conditional< + has_value::value, + tsl::array_map, false, KeySizeT, + std::uint16_t, tsl::ah::power_of_two_growth_policy<4>>, + tsl::array_set, false, KeySizeT, + std::uint16_t, + tsl::ah::power_of_two_growth_policy<4>>>::type; + + private: + /* + * The tree is mainly composed of two nodes types: trie_node and hash_node + * which both have anode as base class. Each child is either a hash_node or a + * trie_node. + * + * A hash_node is always a leaf node, it doesn't have any child. + * + * Example: + * | ... | a |.. ..................... | f | ... | trie_node_1 + * \ \ + * hash_node_1 |array_hash = {"dd"}| |...| a | ... | trie_node_2 + * / + * |array_hash = {"ble", "bric", "lse"}| hash_node_2 + * + * + * Each trie_node may also have a value node, which contains a value T, if the + * trie_node marks the end of a string value. + * + * A trie node should at least have one child or a value node. There can't be + * a trie node without any child and no value node. 
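+   *
+   * In the example above the stored keys are "add" ('a' into hash_node_1,
+   * suffix "dd") and "fable", "fabric", "false" ('f' then 'a' into
+   * hash_node_2, suffixes "ble", "bric", "lse"). A lookup of "fabric"
+   * therefore walks the trie characters 'f' and 'a' and then searches the
+   * remaining suffix "bric" in hash_node_2's array hash.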
+ */ + + using value_node = tsl::detail_htrie_hash::value_node; + + class trie_node; + class hash_node; + + // TODO better encapsulate operations modifying the tree. + class anode { + friend class trie_node; + + public: + /* + * TODO Avoid the virtual to economize 8 bytes. We could use a custom + * deleter in the std::unique_ptr we use (as we know if an anode is a + * trie_node or hash_node). + */ + virtual ~anode() = default; + + bool is_trie_node() const noexcept { + return m_node_type == node_type::TRIE_NODE; + } + + bool is_hash_node() const noexcept { + return m_node_type == node_type::HASH_NODE; + } + + trie_node& as_trie_node() noexcept { + tsl_ht_assert(is_trie_node()); + return static_cast(*this); + } + + hash_node& as_hash_node() noexcept { + tsl_ht_assert(is_hash_node()); + return static_cast(*this); + } + + const trie_node& as_trie_node() const noexcept { + tsl_ht_assert(is_trie_node()); + return static_cast(*this); + } + + const hash_node& as_hash_node() const noexcept { + tsl_ht_assert(is_hash_node()); + return static_cast(*this); + } + + /** + * @see m_child_of_char + */ + CharT child_of_char() const noexcept { + tsl_ht_assert(parent() != nullptr); + return m_child_of_char; + } + + /** + * Return nullptr if none. + */ + trie_node* parent() noexcept { return m_parent_node; } + + const trie_node* parent() const noexcept { return m_parent_node; } + + protected: + enum class node_type : unsigned char { HASH_NODE, TRIE_NODE }; + + anode(node_type node_type_) + : m_node_type(node_type_), m_child_of_char(0), m_parent_node(nullptr) {} + + anode(node_type node_type_, CharT child_of_char) + : m_node_type(node_type_), + m_child_of_char(child_of_char), + m_parent_node(nullptr) {} + + protected: + node_type m_node_type; + + /** + * If the node has a parent, then it's a descendant of some char. + * + * Example: + * | ... | a | b | ... | trie_node_1 + * \ + * |...| a | ... | trie_node_2 + * / + * |array_hash| hash_node_1 + * + * trie_node_2 is a child of trie_node_1 through 'b', it will have 'b' as + * m_child_of_char. hash_node_1 is a child of trie_node_2 through 'a', it + * will have 'a' as m_child_of_char. + * + * trie_node_1 has no parent, its m_child_of_char is undefined. + */ + CharT m_child_of_char; + trie_node* m_parent_node; + }; + + // Give the position in trie_node::m_children corresponding to the character c + static std::size_t as_position(CharT c) noexcept { + return static_cast( + static_cast::type>(c)); + } + + class trie_node : public anode { + public: + trie_node() + : anode(anode::node_type::TRIE_NODE), + m_value_node(nullptr), + m_children() {} + + trie_node(const trie_node& other) + : anode(anode::node_type::TRIE_NODE, other.m_child_of_char), + m_value_node(nullptr), + m_children() { + if (other.m_value_node != nullptr) { + m_value_node = make_unique(*other.m_value_node); + } + + // TODO avoid recursion + for (std::size_t ichild = 0; ichild < other.m_children.size(); ichild++) { + if (other.m_children[ichild] != nullptr) { + if (other.m_children[ichild]->is_hash_node()) { + m_children[ichild] = make_unique( + other.m_children[ichild]->as_hash_node()); + } else { + m_children[ichild] = make_unique( + other.m_children[ichild]->as_trie_node()); + } + + m_children[ichild]->m_parent_node = this; + } + } + } + + trie_node(trie_node&& other) = delete; + trie_node& operator=(const trie_node& other) = delete; + trie_node& operator=(trie_node&& other) = delete; + + /** + * Return nullptr if none. 
+ */ + anode* first_child() noexcept { + return const_cast( + static_cast(this)->first_child()); + } + + const anode* first_child() const noexcept { + for (std::size_t ichild = 0; ichild < m_children.size(); ichild++) { + if (m_children[ichild] != nullptr) { + return m_children[ichild].get(); + } + } + + return nullptr; + } + + /** + * Get the next_child that come after current_child. Return nullptr if no + * next child. + */ + anode* next_child(const anode& current_child) noexcept { + return const_cast( + static_cast(this)->next_child(current_child)); + } + + const anode* next_child(const anode& current_child) const noexcept { + tsl_ht_assert(current_child.parent() == this); + + for (std::size_t ichild = as_position(current_child.child_of_char()) + 1; + ichild < m_children.size(); ichild++) { + if (m_children[ichild] != nullptr) { + return m_children[ichild].get(); + } + } + + return nullptr; + } + + /** + * Return the first left-descendant trie node with an m_value_node. If none + * return the most left trie node. + */ + trie_node& most_left_descendant_value_trie_node() noexcept { + return const_cast( + static_cast(this) + ->most_left_descendant_value_trie_node()); + } + + const trie_node& most_left_descendant_value_trie_node() const noexcept { + const trie_node* current_node = this; + while (true) { + if (current_node->m_value_node != nullptr) { + return *current_node; + } + + const anode* first_child = current_node->first_child(); + tsl_ht_assert(first_child != + nullptr); // a trie_node must either have a value_node or + // at least one child. + if (first_child->is_hash_node()) { + return *current_node; + } + + current_node = &first_child->as_trie_node(); + } + } + + size_type nb_children() const noexcept { + return std::count_if( + m_children.cbegin(), m_children.cend(), + [](const std::unique_ptr& n) { return n != nullptr; }); + } + + bool empty() const noexcept { + return std::all_of( + m_children.cbegin(), m_children.cend(), + [](const std::unique_ptr& n) { return n == nullptr; }); + } + + std::unique_ptr& child(CharT for_char) noexcept { + return m_children[as_position(for_char)]; + } + + const std::unique_ptr& child(CharT for_char) const noexcept { + return m_children[as_position(for_char)]; + } + + typename std::array, ALPHABET_SIZE>::iterator + begin() noexcept { + return m_children.begin(); + } + + typename std::array, ALPHABET_SIZE>::iterator + end() noexcept { + return m_children.end(); + } + + void set_child(CharT for_char, std::unique_ptr child) noexcept { + if (child != nullptr) { + child->m_child_of_char = for_char; + child->m_parent_node = this; + } + + m_children[as_position(for_char)] = std::move(child); + } + + std::unique_ptr& val_node() noexcept { return m_value_node; } + + const std::unique_ptr& val_node() const noexcept { + return m_value_node; + } + + private: + // TODO Avoid storing a value_node when has_value::value is false + std::unique_ptr m_value_node; + + /** + * Each character CharT corresponds to one position in the array. To convert + * a character to a position use the as_position method. + * + * TODO Try to reduce the size of m_children with a hash map, linear/binary + * search on array, ... + * TODO Store number of non-null values in m_children. Check if we can store + * this value in the alignment space as we don't want the node to get bigger + * (empty() and nb_children() are rarely used so it is not an important + * variable). 
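+     *
+     * E.g. child('b') resolves to m_children[as_position('b')], i.e.
+     * m_children[98] on an ASCII platform; ALPHABET_SIZE is 256 for an
+     * 8-bit char.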
+ */ + std::array, ALPHABET_SIZE> m_children; + }; + + class hash_node : public anode { + public: + hash_node(const Hash& hash, float max_load_factor) + : hash_node(HASH_NODE_DEFAULT_INIT_BUCKETS_COUNT, hash, + max_load_factor) {} + + hash_node(size_type bucket_count, const Hash& hash, float max_load_factor) + : anode(anode::node_type::HASH_NODE), m_array_hash(bucket_count, hash) { + m_array_hash.max_load_factor(max_load_factor); + } + + hash_node(array_hash_type&& array_hash) noexcept( + std::is_nothrow_move_constructible::value) + : anode(anode::node_type::HASH_NODE), + m_array_hash(std::move(array_hash)) {} + + hash_node(const hash_node& other) = default; + + hash_node(hash_node&& other) = delete; + hash_node& operator=(const hash_node& other) = delete; + hash_node& operator=(hash_node&& other) = delete; + + array_hash_type& array_hash() noexcept { return m_array_hash; } + + const array_hash_type& array_hash() const noexcept { return m_array_hash; } + + private: + array_hash_type m_array_hash; + }; + + public: + template + class htrie_hash_iterator { + friend class htrie_hash; + + private: + using anode_type = + typename std::conditional::type; + using trie_node_type = + typename std::conditional::type; + using hash_node_type = + typename std::conditional::type; + + using array_hash_iterator_type = + typename std::conditional::type; + + public: + using iterator_category = std::forward_iterator_tag; + using value_type = + typename std::conditional::value, T, void>::type; + using difference_type = std::ptrdiff_t; + using reference = typename std::conditional< + has_value::value, + typename std::conditional< + IsConst, typename std::add_lvalue_reference::type, + typename std::add_lvalue_reference::type>::type, + void>::type; + using pointer = typename std::conditional< + has_value::value, + typename std::conditional::type, void>::type; + + private: + /** + * Start reading from start_hash_node->array_hash().begin(). + */ + htrie_hash_iterator(hash_node_type& start_hash_node) noexcept + : htrie_hash_iterator(start_hash_node, + start_hash_node.array_hash().begin()) {} + + /** + * Start reading from iterator begin in start_hash_node->array_hash(). + */ + htrie_hash_iterator(hash_node_type& start_hash_node, + array_hash_iterator_type begin) noexcept + : m_current_trie_node(start_hash_node.parent()), + m_current_hash_node(&start_hash_node), + m_array_hash_iterator(begin), + m_array_hash_end_iterator(start_hash_node.array_hash().end()), + m_read_trie_node_value(false) { + tsl_ht_assert(!m_current_hash_node->array_hash().empty()); + } + + /** + * Start reading from the value in start_trie_node. + * start_trie_node->val_node() should be non-null. 
+ */ + htrie_hash_iterator(trie_node_type& start_trie_node) noexcept + : m_current_trie_node(&start_trie_node), + m_current_hash_node(nullptr), + m_read_trie_node_value(true) { + tsl_ht_assert(m_current_trie_node->val_node() != nullptr); + } + + template ::type* = nullptr> + htrie_hash_iterator(trie_node_type* tnode, hash_node_type* hnode, + array_hash_iterator_type begin, + array_hash_iterator_type end, + bool read_trie_node_value) noexcept + : m_current_trie_node(tnode), + m_current_hash_node(hnode), + m_array_hash_iterator(begin), + m_array_hash_end_iterator(end), + m_read_trie_node_value(read_trie_node_value) {} + + template ::type* = nullptr> + htrie_hash_iterator(trie_node_type* tnode, hash_node_type* hnode, + array_hash_iterator_type begin, + array_hash_iterator_type end, bool read_trie_node_value, + std::basic_string prefix_filter) noexcept + : m_current_trie_node(tnode), + m_current_hash_node(hnode), + m_array_hash_iterator(begin), + m_array_hash_end_iterator(end), + m_read_trie_node_value(read_trie_node_value), + m_prefix_filter(std::move(prefix_filter)) {} + + public: + htrie_hash_iterator() noexcept {} + + // Copy constructor from iterator to const_iterator. + template ::type* = + nullptr> + htrie_hash_iterator( + const htrie_hash_iterator& other) noexcept + : m_current_trie_node(other.m_current_trie_node), + m_current_hash_node(other.m_current_hash_node), + m_array_hash_iterator(other.m_array_hash_iterator), + m_array_hash_end_iterator(other.m_array_hash_end_iterator), + m_read_trie_node_value(other.m_read_trie_node_value) {} + + // Copy constructor from iterator to const_iterator. + template < + bool TIsConst = IsConst, bool TIsPrefixIterator = IsPrefixIterator, + typename std::enable_if::type* = nullptr> + htrie_hash_iterator( + const htrie_hash_iterator& other) noexcept + : m_current_trie_node(other.m_current_trie_node), + m_current_hash_node(other.m_current_hash_node), + m_array_hash_iterator(other.m_array_hash_iterator), + m_array_hash_end_iterator(other.m_array_hash_end_iterator), + m_read_trie_node_value(other.m_read_trie_node_value), + m_prefix_filter(other.m_prefix_filter) {} + + htrie_hash_iterator(const htrie_hash_iterator& other) = default; + htrie_hash_iterator(htrie_hash_iterator&& other) = default; + htrie_hash_iterator& operator=(const htrie_hash_iterator& other) = default; + htrie_hash_iterator& operator=(htrie_hash_iterator&& other) = default; + + void key(std::basic_string& key_buffer_out) const { + key_buffer_out.clear(); + + trie_node_type* tnode = m_current_trie_node; + while (tnode != nullptr && tnode->parent() != nullptr) { + key_buffer_out.push_back(tnode->child_of_char()); + tnode = tnode->parent(); + } + + std::reverse(key_buffer_out.begin(), key_buffer_out.end()); + + if (!m_read_trie_node_value) { + tsl_ht_assert(m_current_hash_node != nullptr); + if (m_current_hash_node->parent() != nullptr) { + key_buffer_out.push_back(m_current_hash_node->child_of_char()); + } + + key_buffer_out.append(m_array_hash_iterator.key(), + m_array_hash_iterator.key_size()); + } + } + + std::basic_string key() const { + std::basic_string key_buffer; + key(key_buffer); + + return key_buffer; + } + + template ::value>::type* = nullptr> + reference value() const { + if (this->m_read_trie_node_value) { + tsl_ht_assert(this->m_current_trie_node != nullptr); + tsl_ht_assert(this->m_current_trie_node->val_node() != nullptr); + + return this->m_current_trie_node->val_node()->m_value; + } else { + return this->m_array_hash_iterator.value(); + } + } + + template ::value>::type* = 
nullptr> + reference operator*() const { + return value(); + } + + template ::value>::type* = nullptr> + pointer operator->() const { + return std::addressof(value()); + } + + htrie_hash_iterator& operator++() { + if (m_read_trie_node_value) { + tsl_ht_assert(m_current_trie_node != nullptr); + + m_read_trie_node_value = false; + + anode_type* child = m_current_trie_node->first_child(); + if (child != nullptr) { + set_most_left_descendant_as_next_node(*child); + } else if (m_current_trie_node->parent() != nullptr) { + trie_node_type* current_node_child = m_current_trie_node; + m_current_trie_node = m_current_trie_node->parent(); + + set_next_node_ascending(*current_node_child); + } else { + set_as_end_iterator(); + } + } else { + ++m_array_hash_iterator; + if (m_array_hash_iterator != m_array_hash_end_iterator) { + filter_prefix(); + } + // End of the road, set the iterator as an end node. + else if (m_current_trie_node == nullptr) { + set_as_end_iterator(); + } else { + tsl_ht_assert(m_current_hash_node != nullptr); + set_next_node_ascending(*m_current_hash_node); + } + } + + return *this; + } + + htrie_hash_iterator operator++(int) { + htrie_hash_iterator tmp(*this); + ++*this; + + return tmp; + } + + friend bool operator==(const htrie_hash_iterator& lhs, + const htrie_hash_iterator& rhs) { + if (lhs.m_current_trie_node != rhs.m_current_trie_node || + lhs.m_read_trie_node_value != rhs.m_read_trie_node_value) { + return false; + } else if (lhs.m_read_trie_node_value) { + return true; + } else { + if (lhs.m_current_hash_node != rhs.m_current_hash_node) { + return false; + } else if (lhs.m_current_hash_node == nullptr) { + return true; + } else { + return lhs.m_array_hash_iterator == rhs.m_array_hash_iterator && + lhs.m_array_hash_end_iterator == rhs.m_array_hash_end_iterator; + } + } + } + + friend bool operator!=(const htrie_hash_iterator& lhs, + const htrie_hash_iterator& rhs) { + return !(lhs == rhs); + } + + private: + void hash_node_prefix(std::basic_string& key_buffer_out) const { + tsl_ht_assert(!m_read_trie_node_value); + key_buffer_out.clear(); + + trie_node_type* tnode = m_current_trie_node; + while (tnode != nullptr && tnode->parent() != nullptr) { + key_buffer_out.push_back(tnode->child_of_char()); + tnode = tnode->parent(); + } + + std::reverse(key_buffer_out.begin(), key_buffer_out.end()); + + tsl_ht_assert(m_current_hash_node != nullptr); + if (m_current_hash_node->parent() != nullptr) { + key_buffer_out.push_back(m_current_hash_node->child_of_char()); + } + } + + template ::type* = nullptr> + void filter_prefix() {} + + template ::type* = nullptr> + void filter_prefix() { + tsl_ht_assert(m_array_hash_iterator != m_array_hash_end_iterator); + tsl_ht_assert(!m_read_trie_node_value && m_current_hash_node != nullptr); + + if (m_prefix_filter.empty()) { + return; + } + + while ((m_prefix_filter.size() > m_array_hash_iterator.key_size() || + m_prefix_filter.compare(0, m_prefix_filter.size(), + m_array_hash_iterator.key(), + m_prefix_filter.size()) != 0)) { + ++m_array_hash_iterator; + if (m_array_hash_iterator == m_array_hash_end_iterator) { + if (m_current_trie_node == nullptr) { + set_as_end_iterator(); + } else { + tsl_ht_assert(m_current_hash_node != nullptr); + set_next_node_ascending(*m_current_hash_node); + } + + return; + } + } + } + + /** + * Go back up in the tree to get the current_trie_node_child sibling. + * If none, try to go back up more in the tree to check the siblings of the + * ancestors. 
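+     *
+     * E.g. with the example tree from the class description: once iteration
+     * has exhausted hash_node_1 (the 'a' child of trie_node_1), this method
+     * asks trie_node_1 for the child that follows 'a', finds trie_node_2
+     * (the 'f' child), and iteration resumes at its most-left descendant
+     * (trie_node_2's value node if it has one, otherwise hash_node_2).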
+ */ + void set_next_node_ascending(anode_type& current_trie_node_child) { + tsl_ht_assert(m_current_trie_node != nullptr); + tsl_ht_assert(current_trie_node_child.parent() == m_current_trie_node); + + anode_type* next_node = + m_current_trie_node->next_child(current_trie_node_child); + while (next_node == nullptr && m_current_trie_node->parent() != nullptr) { + anode_type* current_child = m_current_trie_node; + m_current_trie_node = m_current_trie_node->parent(); + next_node = m_current_trie_node->next_child(*current_child); + } + + // End of the road, set the iterator as an end node. + if (next_node == nullptr) { + set_as_end_iterator(); + } else { + set_most_left_descendant_as_next_node(*next_node); + } + } + + void set_most_left_descendant_as_next_node(anode_type& search_start) { + if (search_start.is_hash_node()) { + set_current_hash_node(search_start.as_hash_node()); + } else { + m_current_trie_node = + &search_start.as_trie_node().most_left_descendant_value_trie_node(); + if (m_current_trie_node->val_node() != nullptr) { + m_read_trie_node_value = true; + } else { + anode_type* first_child = m_current_trie_node->first_child(); + // a trie_node must either have a value_node or at least one child. + tsl_ht_assert(first_child != nullptr); + + set_current_hash_node(first_child->as_hash_node()); + } + } + } + + void set_current_hash_node(hash_node_type& hnode) { + tsl_ht_assert(!hnode.array_hash().empty()); + + m_current_hash_node = &hnode; + m_array_hash_iterator = m_current_hash_node->array_hash().begin(); + m_array_hash_end_iterator = m_current_hash_node->array_hash().end(); + } + + void set_as_end_iterator() { + m_current_trie_node = nullptr; + m_current_hash_node = nullptr; + m_read_trie_node_value = false; + } + + void skip_hash_node() { + tsl_ht_assert(!m_read_trie_node_value && m_current_hash_node != nullptr); + if (m_current_trie_node == nullptr) { + set_as_end_iterator(); + } else { + tsl_ht_assert(m_current_hash_node != nullptr); + set_next_node_ascending(*m_current_hash_node); + } + } + + private: + trie_node_type* m_current_trie_node; + hash_node_type* m_current_hash_node; + + array_hash_iterator_type m_array_hash_iterator; + array_hash_iterator_type m_array_hash_end_iterator; + + bool m_read_trie_node_value; + // TODO can't have void if !IsPrefixIterator, use inheritance + typename std::conditional, + bool>::type m_prefix_filter; + }; + + public: + htrie_hash(const Hash& hash, float max_load_factor, size_type burst_threshold) + : m_root(nullptr), + m_nb_elements(0), + m_hash(hash), + m_max_load_factor(max_load_factor) { + this->burst_threshold(burst_threshold); + } + + htrie_hash(const htrie_hash& other) + : m_root(nullptr), + m_nb_elements(other.m_nb_elements), + m_hash(other.m_hash), + m_max_load_factor(other.m_max_load_factor), + m_burst_threshold(other.m_burst_threshold) { + if (other.m_root != nullptr) { + if (other.m_root->is_hash_node()) { + m_root = make_unique(other.m_root->as_hash_node()); + } else { + m_root = make_unique(other.m_root->as_trie_node()); + } + } + } + + htrie_hash(htrie_hash&& other) noexcept( + std::is_nothrow_move_constructible::value) + : m_root(std::move(other.m_root)), + m_nb_elements(other.m_nb_elements), + m_hash(std::move(other.m_hash)), + m_max_load_factor(other.m_max_load_factor), + m_burst_threshold(other.m_burst_threshold) { + other.clear(); + } + + htrie_hash& operator=(const htrie_hash& other) { + if (&other != this) { + std::unique_ptr new_root = nullptr; + if (other.m_root != nullptr) { + if (other.m_root->is_hash_node()) { + 
new_root = make_unique(other.m_root->as_hash_node()); + } else { + new_root = make_unique(other.m_root->as_trie_node()); + } + } + + m_hash = other.m_hash; + m_root = std::move(new_root); + m_nb_elements = other.m_nb_elements; + m_max_load_factor = other.m_max_load_factor; + m_burst_threshold = other.m_burst_threshold; + } + + return *this; + } + + htrie_hash& operator=(htrie_hash&& other) { + other.swap(*this); + other.clear(); + + return *this; + } + + /* + * Iterators + */ + iterator begin() noexcept { return mutable_iterator(cbegin()); } + + const_iterator begin() const noexcept { return cbegin(); } + + const_iterator cbegin() const noexcept { + if (empty()) { + return cend(); + } + + return cbegin(*m_root); + } + + iterator end() noexcept { + iterator it; + it.set_as_end_iterator(); + + return it; + } + + const_iterator end() const noexcept { return cend(); } + + const_iterator cend() const noexcept { + const_iterator it; + it.set_as_end_iterator(); + + return it; + } + + /* + * Capacity + */ + bool empty() const noexcept { return m_nb_elements == 0; } + + size_type size() const noexcept { return m_nb_elements; } + + size_type max_size() const noexcept { + return std::numeric_limits::max(); + } + + size_type max_key_size() const noexcept { + return array_hash_type::MAX_KEY_SIZE; + } + + void shrink_to_fit() { + auto first = begin(); + auto last = end(); + + while (first != last) { + if (first.m_read_trie_node_value) { + ++first; + } else { + /* + * shrink_to_fit on array_hash will invalidate the iterators of + * array_hash. Save pointer to array_hash, skip the array_hash_node and + * then call shrink_to_fit on the saved pointer. + */ + hash_node* hnode = first.m_current_hash_node; + first.skip_hash_node(); + + tsl_ht_assert(hnode != nullptr); + hnode->array_hash().shrink_to_fit(); + } + } + } + + /* + * Modifiers + */ + void clear() noexcept { + m_root.reset(nullptr); + m_nb_elements = 0; + } + + template + std::pair insert(const CharT* key, size_type key_size, + ValueArgs&&... 
value_args) { + if (key_size > max_key_size()) { + throw std::length_error("Key is too long."); + } + + if (m_root == nullptr) { + m_root = make_unique(m_hash, m_max_load_factor); + } + + return insert_impl(*m_root, key, key_size, + std::forward(value_args)...); + } + + iterator erase(const_iterator pos) { return erase(mutable_iterator(pos)); } + + iterator erase(const_iterator first, const_iterator last) { + // TODO Optimize, could avoid the call to std::distance + const std::size_t nb_to_erase = std::size_t(std::distance(first, last)); + auto to_delete = mutable_iterator(first); + for (std::size_t i = 0; i < nb_to_erase; i++) { + to_delete = erase(to_delete); + } + + return to_delete; + } + + size_type erase(const CharT* key, size_type key_size) { + auto it = find(key, key_size); + if (it != end()) { + erase(it); + return 1; + } else { + return 0; + } + } + + size_type erase_prefix(const CharT* prefix, size_type prefix_size) { + if (m_root == nullptr) { + return 0; + } + + anode* current_node = m_root.get(); + for (size_type iprefix = 0; iprefix < prefix_size; iprefix++) { + if (current_node->is_trie_node()) { + trie_node* tnode = ¤t_node->as_trie_node(); + + if (tnode->child(prefix[iprefix]) == nullptr) { + return 0; + } else { + current_node = tnode->child(prefix[iprefix]).get(); + } + } else { + hash_node& hnode = current_node->as_hash_node(); + return erase_prefix_hash_node(hnode, prefix + iprefix, + prefix_size - iprefix); + } + } + + if (current_node->is_trie_node()) { + trie_node* parent = current_node->parent(); + + if (parent != nullptr) { + const size_type nb_erased = + size_descendants(current_node->as_trie_node()); + + parent->set_child(current_node->child_of_char(), nullptr); + m_nb_elements -= nb_erased; + + if (parent->empty()) { + clear_empty_nodes(*parent); + } + + return nb_erased; + } else { + const size_type nb_erased = m_nb_elements; + m_root.reset(nullptr); + m_nb_elements = 0; + + return nb_erased; + } + } else { + const size_type nb_erased = + current_node->as_hash_node().array_hash().size(); + + current_node->as_hash_node().array_hash().clear(); + m_nb_elements -= nb_erased; + + clear_empty_nodes(current_node->as_hash_node()); + + return nb_erased; + } + } + + void swap(htrie_hash& other) { + using std::swap; + + swap(m_hash, other.m_hash); + swap(m_root, other.m_root); + swap(m_nb_elements, other.m_nb_elements); + swap(m_max_load_factor, other.m_max_load_factor); + swap(m_burst_threshold, other.m_burst_threshold); + } + + /* + * Lookup + */ + template ::value>::type* = nullptr> + U& at(const CharT* key, size_type key_size) { + return const_cast( + static_cast(this)->at(key, key_size)); + } + + template ::value>::type* = nullptr> + const U& at(const CharT* key, size_type key_size) const { + auto it_find = find(key, key_size); + if (it_find != cend()) { + return it_find.value(); + } else { + throw std::out_of_range("Couldn't find key."); + } + } + + // TODO optimize + template ::value>::type* = nullptr> + U& access_operator(const CharT* key, size_type key_size) { + auto it_find = find(key, key_size); + if (it_find != cend()) { + return it_find.value(); + } else { + return insert(key, key_size, U{}).first.value(); + } + } + + size_type count(const CharT* key, size_type key_size) const { + if (find(key, key_size) != cend()) { + return 1; + } else { + return 0; + } + } + + iterator find(const CharT* key, size_type key_size) { + if (m_root == nullptr) { + return end(); + } + + return find_impl(*m_root, key, key_size); + } + + const_iterator find(const CharT* key, 
size_type key_size) const { + if (m_root == nullptr) { + return cend(); + } + + return find_impl(*m_root, key, key_size); + } + + std::pair equal_range(const CharT* key, + size_type key_size) { + iterator it = find(key, key_size); + return std::make_pair(it, (it == end()) ? it : std::next(it)); + } + + std::pair equal_range( + const CharT* key, size_type key_size) const { + const_iterator it = find(key, key_size); + return std::make_pair(it, (it == cend()) ? it : std::next(it)); + } + + std::pair equal_prefix_range( + const CharT* prefix, size_type prefix_size) { + if (m_root == nullptr) { + return std::make_pair(prefix_end(), prefix_end()); + } + + return equal_prefix_range_impl(*m_root, prefix, prefix_size); + } + + std::pair equal_prefix_range( + const CharT* prefix, size_type prefix_size) const { + if (m_root == nullptr) { + return std::make_pair(prefix_cend(), prefix_cend()); + } + + return equal_prefix_range_impl(*m_root, prefix, prefix_size); + } + + iterator longest_prefix(const CharT* key, size_type key_size) { + if (m_root == nullptr) { + return end(); + } + + return longest_prefix_impl(*m_root, key, key_size); + } + + const_iterator longest_prefix(const CharT* key, size_type key_size) const { + if (m_root == nullptr) { + return cend(); + } + + return longest_prefix_impl(*m_root, key, key_size); + } + + /* + * Hash policy + */ + float max_load_factor() const { return m_max_load_factor; } + + void max_load_factor(float ml) { m_max_load_factor = ml; } + + /* + * Burst policy + */ + size_type burst_threshold() const { return m_burst_threshold; } + + void burst_threshold(size_type threshold) { + const size_type min_burst_threshold = MIN_BURST_THRESHOLD; + m_burst_threshold = std::max(min_burst_threshold, threshold); + } + + /* + * Observers + */ + hasher hash_function() const { return m_hash; } + + /* + * Other + */ + template + void serialize(Serializer& serializer) const { + serialize_impl(serializer); + } + + template + void deserialize(Deserializer& deserializer, bool hash_compatible) { + deserialize_impl(deserializer, hash_compatible); + } + + private: + /** + * Get the begin iterator by searching for the most left descendant node + * starting at search_start_node. + */ + template + Iterator cbegin(const anode& search_start_node) const noexcept { + if (search_start_node.is_hash_node()) { + return Iterator(search_start_node.as_hash_node()); + } + + const trie_node& tnode = + search_start_node.as_trie_node().most_left_descendant_value_trie_node(); + if (tnode.val_node() != nullptr) { + return Iterator(tnode); + } else { + const anode* first_child = tnode.first_child(); + tsl_ht_assert(first_child != nullptr); + + return Iterator(first_child->as_hash_node()); + } + } + + /** + * Get an iterator to the node that come just after the last descendant of + * search_start_node. 
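+   *
+   * E.g. with the example tree from the class description, cend(hash_node_1)
+   * returns an iterator to the first element reachable under trie_node_2
+   * (its value node if it has one, otherwise the first element stored in
+   * hash_node_2).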
+ */ + template + Iterator cend(const anode& search_start_node) const noexcept { + if (search_start_node.parent() == nullptr) { + Iterator it; + it.set_as_end_iterator(); + + return it; + } + + const trie_node* current_trie_node = search_start_node.parent(); + const anode* next_node = current_trie_node->next_child(search_start_node); + + while (next_node == nullptr && current_trie_node->parent() != nullptr) { + const anode* current_child = current_trie_node; + current_trie_node = current_trie_node->parent(); + next_node = current_trie_node->next_child(*current_child); + } + + if (next_node == nullptr) { + Iterator it; + it.set_as_end_iterator(); + + return it; + } else { + return cbegin(*next_node); + } + } + + prefix_iterator prefix_end() noexcept { + prefix_iterator it; + it.set_as_end_iterator(); + + return it; + } + + const_prefix_iterator prefix_cend() const noexcept { + const_prefix_iterator it; + it.set_as_end_iterator(); + + return it; + } + + size_type size_descendants(const anode& start_node) const { + auto first = cbegin(start_node); + auto last = cend(start_node); + + size_type nb_elements = 0; + while (first != last) { + if (first.m_read_trie_node_value) { + nb_elements++; + ++first; + } else { + nb_elements += first.m_current_hash_node->array_hash().size(); + first.skip_hash_node(); + } + } + + return nb_elements; + } + + template + std::pair insert_impl(anode& search_start_node, + const CharT* key, size_type key_size, + ValueArgs&&... value_args) { + anode* current_node = &search_start_node; + + for (size_type ikey = 0; ikey < key_size; ikey++) { + if (current_node->is_trie_node()) { + trie_node& tnode = current_node->as_trie_node(); + + if (tnode.child(key[ikey]) != nullptr) { + current_node = tnode.child(key[ikey]).get(); + } else { + auto hnode = make_unique(m_hash, m_max_load_factor); + auto insert_it = hnode->array_hash().emplace_ks( + key + ikey + 1, key_size - ikey - 1, + std::forward(value_args)...); + + tnode.set_child(key[ikey], std::move(hnode)); + m_nb_elements++; + + return std::make_pair( + iterator(tnode.child(key[ikey])->as_hash_node(), insert_it.first), + true); + } + } else { + return insert_in_hash_node(current_node->as_hash_node(), key + ikey, + key_size - ikey, + std::forward(value_args)...); + } + } + + if (current_node->is_trie_node()) { + trie_node& tnode = current_node->as_trie_node(); + if (tnode.val_node() != nullptr) { + return std::make_pair(iterator(tnode), false); + } else { + tnode.val_node() = + make_unique(std::forward(value_args)...); + m_nb_elements++; + + return std::make_pair(iterator(tnode), true); + } + } else { + return insert_in_hash_node(current_node->as_hash_node(), "", 0, + std::forward(value_args)...); + } + } + + template + std::pair insert_in_hash_node(hash_node& hnode, + const CharT* key, + size_type key_size, + ValueArgs&&... 
value_args) { + if (need_burst(hnode)) { + std::unique_ptr new_node = burst(hnode); + if (hnode.parent() == nullptr) { + tsl_ht_assert(m_root.get() == &hnode); + + m_root = std::move(new_node); + return insert_impl(*m_root, key, key_size, + std::forward(value_args)...); + } else { + trie_node* parent = hnode.parent(); + const CharT child_of_char = hnode.child_of_char(); + + parent->set_child(child_of_char, std::move(new_node)); + + return insert_impl(*parent->child(child_of_char), key, key_size, + std::forward(value_args)...); + } + } else { + auto it_insert = hnode.array_hash().emplace_ks( + key, key_size, std::forward(value_args)...); + if (it_insert.second) { + m_nb_elements++; + } + + return std::make_pair(iterator(hnode, it_insert.first), it_insert.second); + } + } + + iterator erase(iterator pos) { + iterator next_pos = std::next(pos); + + if (pos.m_read_trie_node_value) { + tsl_ht_assert(pos.m_current_trie_node != nullptr && + pos.m_current_trie_node->val_node() != nullptr); + + pos.m_current_trie_node->val_node().reset(nullptr); + m_nb_elements--; + + if (pos.m_current_trie_node->empty()) { + clear_empty_nodes(*pos.m_current_trie_node); + } + + return next_pos; + } else { + tsl_ht_assert(pos.m_current_hash_node != nullptr); + auto next_array_hash_it = pos.m_current_hash_node->array_hash().erase( + pos.m_array_hash_iterator); + m_nb_elements--; + + if (next_array_hash_it != pos.m_current_hash_node->array_hash().end()) { + // The erase on array_hash invalidated the next_pos iterator, return the + // right one. + return iterator(*pos.m_current_hash_node, next_array_hash_it); + } else { + if (pos.m_current_hash_node->array_hash().empty()) { + clear_empty_nodes(*pos.m_current_hash_node); + } + + return next_pos; + } + } + } + + /** + * Clear all the empty nodes from the tree starting from empty_node (empty for + * a hash_node means that the array hash is empty, for a trie_node it means + * the node doesn't have any child or value_node associated to it). + */ + void clear_empty_nodes(anode& empty_node) noexcept { + tsl_ht_assert(!empty_node.is_trie_node() || + (empty_node.as_trie_node().empty() && + empty_node.as_trie_node().val_node() == nullptr)); + tsl_ht_assert(!empty_node.is_hash_node() || + empty_node.as_hash_node().array_hash().empty()); + + trie_node* parent = empty_node.parent(); + if (parent == nullptr) { + tsl_ht_assert(m_root.get() == &empty_node); + tsl_ht_assert(m_nb_elements == 0); + m_root.reset(nullptr); + } else if (parent->val_node() != nullptr || parent->nb_children() > 1) { + parent->child(empty_node.child_of_char()).reset(nullptr); + } else if (parent->parent() == nullptr) { + tsl_ht_assert(m_root.get() == empty_node.parent()); + tsl_ht_assert(m_nb_elements == 0); + m_root.reset(nullptr); + } else { + /** + * Parent is empty if we remove its empty_node child. + * Put empty_node as new child of the grand parent instead of parent (move + * hnode up, and delete the parent). And recurse. + * + * We can't just set grand_parent->child(parent->child_of_char()) to + * nullptr as the grand_parent may also become empty. We don't want empty + * trie_node with no value_node in the tree. 
+ */ + trie_node* grand_parent = parent->parent(); + grand_parent->set_child( + parent->child_of_char(), + std::move(parent->child(empty_node.child_of_char()))); + + clear_empty_nodes(empty_node); + } + } + + iterator find_impl(const anode& search_start_node, const CharT* key, + size_type key_size) { + return mutable_iterator(static_cast(this)->find_impl( + search_start_node, key, key_size)); + } + + const_iterator find_impl(const anode& search_start_node, const CharT* key, + size_type key_size) const { + const anode* current_node = &search_start_node; + + for (size_type ikey = 0; ikey < key_size; ikey++) { + if (current_node->is_trie_node()) { + const trie_node* tnode = ¤t_node->as_trie_node(); + + if (tnode->child(key[ikey]) == nullptr) { + return cend(); + } else { + current_node = tnode->child(key[ikey]).get(); + } + } else { + return find_in_hash_node(current_node->as_hash_node(), key + ikey, + key_size - ikey); + } + } + + if (current_node->is_trie_node()) { + const trie_node& tnode = current_node->as_trie_node(); + return (tnode.val_node() != nullptr) ? const_iterator(tnode) : cend(); + } else { + return find_in_hash_node(current_node->as_hash_node(), "", 0); + } + } + + const_iterator find_in_hash_node(const hash_node& hnode, const CharT* key, + size_type key_size) const { + auto it = hnode.array_hash().find_ks(key, key_size); + if (it != hnode.array_hash().end()) { + return const_iterator(hnode, it); + } else { + return cend(); + } + } + + iterator longest_prefix_impl(const anode& search_start_node, + const CharT* value, size_type value_size) { + return mutable_iterator( + static_cast(this)->longest_prefix_impl( + search_start_node, value, value_size)); + } + + const_iterator longest_prefix_impl(const anode& search_start_node, + const CharT* value, + size_type value_size) const { + const anode* current_node = &search_start_node; + const_iterator longest_found_prefix = cend(); + + for (size_type ivalue = 0; ivalue < value_size; ivalue++) { + if (current_node->is_trie_node()) { + const trie_node& tnode = current_node->as_trie_node(); + + if (tnode.val_node() != nullptr) { + longest_found_prefix = const_iterator(tnode); + } + + if (tnode.child(value[ivalue]) == nullptr) { + return longest_found_prefix; + } else { + current_node = tnode.child(value[ivalue]).get(); + } + } else { + const hash_node& hnode = current_node->as_hash_node(); + + /** + * Test the presence in the hash node of each substring from the + * remaining [ivalue, value_size) string starting from the longest. + * Also test the empty string. 
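+         *
+         * E.g. if the remaining characters are "ble", the loop below looks up
+         * "ble", then "bl", then "b", then "" in the array hash and returns
+         * on the first match.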
+ */ + for (std::size_t i = ivalue; i <= value_size; i++) { + auto it = + hnode.array_hash().find_ks(value + ivalue, (value_size - i)); + if (it != hnode.array_hash().end()) { + return const_iterator(hnode, it); + } + } + + return longest_found_prefix; + } + } + + if (current_node->is_trie_node()) { + const trie_node& tnode = current_node->as_trie_node(); + + if (tnode.val_node() != nullptr) { + longest_found_prefix = const_iterator(tnode); + } + } else { + const hash_node& hnode = current_node->as_hash_node(); + + auto it = hnode.array_hash().find_ks("", 0); + if (it != hnode.array_hash().end()) { + longest_found_prefix = const_iterator(hnode, it); + } + } + + return longest_found_prefix; + } + + std::pair equal_prefix_range_impl( + anode& search_start_node, const CharT* prefix, size_type prefix_size) { + auto range = static_cast(this)->equal_prefix_range_impl( + search_start_node, prefix, prefix_size); + return std::make_pair(mutable_iterator(range.first), + mutable_iterator(range.second)); + } + + std::pair + equal_prefix_range_impl(const anode& search_start_node, const CharT* prefix, + size_type prefix_size) const { + const anode* current_node = &search_start_node; + + for (size_type iprefix = 0; iprefix < prefix_size; iprefix++) { + if (current_node->is_trie_node()) { + const trie_node* tnode = ¤t_node->as_trie_node(); + + if (tnode->child(prefix[iprefix]) == nullptr) { + return std::make_pair(prefix_cend(), prefix_cend()); + } else { + current_node = tnode->child(prefix[iprefix]).get(); + } + } else { + const hash_node& hnode = current_node->as_hash_node(); + const_prefix_iterator begin( + hnode.parent(), &hnode, hnode.array_hash().begin(), + hnode.array_hash().end(), false, + std::basic_string(prefix + iprefix, prefix_size - iprefix)); + begin.filter_prefix(); + + const_prefix_iterator end = cend(*current_node); + + return std::make_pair(begin, end); + } + } + + const_prefix_iterator begin = cbegin(*current_node); + const_prefix_iterator end = cend(*current_node); + + return std::make_pair(begin, end); + } + + size_type erase_prefix_hash_node(hash_node& hnode, const CharT* prefix, + size_type prefix_size) { + size_type nb_erased = 0; + + auto it = hnode.array_hash().begin(); + while (it != hnode.array_hash().end()) { + if (it.key_size() >= prefix_size && + std::memcmp(prefix, it.key(), prefix_size * sizeof(CharT)) == 0) { + it = hnode.array_hash().erase(it); + ++nb_erased; + --m_nb_elements; + } else { + ++it; + } + } + + return nb_erased; + } + + /* + * Burst + */ + bool need_burst(hash_node& node) const { + return node.array_hash().size() >= m_burst_threshold; + } + + /** + * Burst the node and use the copy constructor instead of move constructor for + * the values. Also use this method for trivial value types like int, int*, + * ... as it requires less book-keeping (thus faster) than the burst using + * move constructors. 
+ */ + template ::value && std::is_copy_constructible::value && + (!std::is_nothrow_move_constructible::value || + !std::is_nothrow_move_assignable::value || + std::is_arithmetic::value || + std::is_pointer::value)>::type* = nullptr> + std::unique_ptr burst(hash_node& node) { + const std::array first_char_count = + get_first_char_count(node.array_hash().cbegin(), + node.array_hash().cend()); + + auto new_node = make_unique(); + for (auto it = node.array_hash().cbegin(); it != node.array_hash().cend(); + ++it) { + if (it.key_size() == 0) { + new_node->val_node() = make_unique(it.value()); + } else { + hash_node& hnode = + get_hash_node_for_char(first_char_count, *new_node, it.key()[0]); + hnode.array_hash().insert_ks(it.key() + 1, it.key_size() - 1, + it.value()); + } + } + + tsl_ht_assert(new_node->val_node() != nullptr || !new_node->empty()); + return new_node; + } + + /** + * Burst the node and use the move constructor and move assign operator if + * they don't throw. + */ + template ::value && + std::is_nothrow_move_constructible::value && + std::is_nothrow_move_assignable::value && + !std::is_arithmetic::value && + !std::is_pointer::value>::type* = nullptr> + std::unique_ptr burst(hash_node& node) { + /** + * We burst the node->array_hash() into multiple arrays hash. While doing + * so, we move each value in the node->array_hash() into the new arrays + * hash. After each move, we save a pointer to where the value has been + * moved. In case of exception, we rollback these values into the original + * node->array_hash(). + */ + std::vector moved_values_rollback; + moved_values_rollback.reserve(node.array_hash().size()); + + try { + const std::array first_char_count = + get_first_char_count(node.array_hash().cbegin(), + node.array_hash().cend()); + + auto new_node = make_unique(); + for (auto it = node.array_hash().begin(); it != node.array_hash().end(); + ++it) { + if (it.key_size() == 0) { + new_node->val_node() = make_unique(std::move(it.value())); + moved_values_rollback.push_back( + std::addressof(new_node->val_node()->m_value)); + } else { + hash_node& hnode = + get_hash_node_for_char(first_char_count, *new_node, it.key()[0]); + auto it_insert = hnode.array_hash().insert_ks( + it.key() + 1, it.key_size() - 1, std::move(it.value())); + moved_values_rollback.push_back( + std::addressof(it_insert.first.value())); + } + } + + tsl_ht_assert(new_node->val_node() != nullptr || !new_node->empty()); + return new_node; + } catch (...) 
{ + // Rollback the values + auto it = node.array_hash().begin(); + for (std::size_t ivalue = 0; ivalue < moved_values_rollback.size(); + ivalue++) { + it.value() = std::move(*moved_values_rollback[ivalue]); + + ++it; + } + + throw; + } + } + + template ::value>::type* = nullptr> + std::unique_ptr burst(hash_node& node) { + const std::array first_char_count = + get_first_char_count(node.array_hash().begin(), + node.array_hash().end()); + + auto new_node = make_unique(); + for (auto it = node.array_hash().cbegin(); it != node.array_hash().cend(); + ++it) { + if (it.key_size() == 0) { + new_node->val_node() = make_unique(); + } else { + hash_node& hnode = + get_hash_node_for_char(first_char_count, *new_node, it.key()[0]); + hnode.array_hash().insert_ks(it.key() + 1, it.key_size() - 1); + } + } + + tsl_ht_assert(new_node->val_node() != nullptr || !new_node->empty()); + return new_node; + } + + std::array get_first_char_count( + typename array_hash_type::const_iterator begin, + typename array_hash_type::const_iterator end) const { + std::array count{{}}; + for (auto it = begin; it != end; ++it) { + if (it.key_size() == 0) { + continue; + } + + count[as_position(it.key()[0])]++; + } + + return count; + } + + hash_node& get_hash_node_for_char( + const std::array& first_char_count, + trie_node& tnode, CharT for_char) { + if (tnode.child(for_char) == nullptr) { + const size_type nb_buckets = + size_type(std::ceil(float(first_char_count[as_position(for_char)] + + HASH_NODE_DEFAULT_INIT_BUCKETS_COUNT / 2) / + m_max_load_factor)); + + tnode.set_child(for_char, make_unique(nb_buckets, m_hash, + m_max_load_factor)); + } + + return tnode.child(for_char)->as_hash_node(); + } + + iterator mutable_iterator(const_iterator it) noexcept { + // end iterator or reading from a trie node value + if (it.m_current_hash_node == nullptr || it.m_read_trie_node_value) { + typename array_hash_type::iterator default_it; + + return iterator(const_cast(it.m_current_trie_node), nullptr, + default_it, default_it, it.m_read_trie_node_value); + } else { + hash_node* hnode = const_cast(it.m_current_hash_node); + return iterator( + const_cast(it.m_current_trie_node), hnode, + hnode->array_hash().mutable_iterator(it.m_array_hash_iterator), + hnode->array_hash().mutable_iterator(it.m_array_hash_end_iterator), + it.m_read_trie_node_value); + } + } + + prefix_iterator mutable_iterator(const_prefix_iterator it) noexcept { + // end iterator or reading from a trie node value + if (it.m_current_hash_node == nullptr || it.m_read_trie_node_value) { + typename array_hash_type::iterator default_it; + + return prefix_iterator(const_cast(it.m_current_trie_node), + nullptr, default_it, default_it, + it.m_read_trie_node_value, ""); + } else { + hash_node* hnode = const_cast(it.m_current_hash_node); + return prefix_iterator( + const_cast(it.m_current_trie_node), hnode, + hnode->array_hash().mutable_iterator(it.m_array_hash_iterator), + hnode->array_hash().mutable_iterator(it.m_array_hash_end_iterator), + it.m_read_trie_node_value, it.m_prefix_filter); + } + } + + template + void serialize_impl(Serializer& serializer) const { + const slz_size_type version = SERIALIZATION_PROTOCOL_VERSION; + serializer(version); + + const slz_size_type nb_elements = m_nb_elements; + serializer(nb_elements); + + const float max_load_factor = m_max_load_factor; + serializer(max_load_factor); + + const slz_size_type burst_threshold = m_burst_threshold; + serializer(burst_threshold); + + std::basic_string str_buffer; + + auto it = begin(); + auto last = end(); + + 
while (it != last) { + // Serialize trie node value + if (it.m_read_trie_node_value) { + const CharT node_type = + static_cast::type>( + slz_node_type::TRIE_NODE); + serializer(&node_type, 1); + + it.key(str_buffer); + + const slz_size_type str_size = str_buffer.size(); + serializer(str_size); + serializer(str_buffer.data(), str_buffer.size()); + serialize_value(serializer, it); + + ++it; + } + // Serialize hash node values + else { + const CharT node_type = + static_cast::type>( + slz_node_type::HASH_NODE); + serializer(&node_type, 1); + + it.hash_node_prefix(str_buffer); + + const slz_size_type str_size = str_buffer.size(); + serializer(str_size); + serializer(str_buffer.data(), str_buffer.size()); + + const hash_node* hnode = it.m_current_hash_node; + tsl_ht_assert(hnode != nullptr); + hnode->array_hash().serialize(serializer); + + it.skip_hash_node(); + } + } + } + + template ::value>::type* = nullptr> + void serialize_value(Serializer& /*serializer*/, + const_iterator /*it*/) const {} + + template ::value>::type* = nullptr> + void serialize_value(Serializer& serializer, const_iterator it) const { + serializer(it.value()); + } + + template + void deserialize_impl(Deserializer& deserializer, bool hash_compatible) { + tsl_ht_assert(m_nb_elements == 0 && + m_root == nullptr); // Current trie must be empty + + const slz_size_type version = + deserialize_value(deserializer); + // For now we only have one version of the serialization protocol. + // If it doesn't match there is a problem with the file. + if (version != SERIALIZATION_PROTOCOL_VERSION) { + throw std::runtime_error( + "Can't deserialize the htrie_map/set. The protocol version header is " + "invalid."); + } + + const slz_size_type nb_elements = + deserialize_value(deserializer); + const float max_load_factor = deserialize_value(deserializer); + const slz_size_type burst_threshold = + deserialize_value(deserializer); + + this->burst_threshold(numeric_cast( + burst_threshold, "Deserialized burst_threshold is too big.")); + this->max_load_factor(max_load_factor); + + std::vector str_buffer; + while (m_nb_elements < nb_elements) { + CharT node_type_marker; + deserializer(&node_type_marker, 1); + + static_assert( + std::is_same< + CharT, typename std::underlying_type::type>::value, + ""); + const slz_node_type node_type = + static_cast(node_type_marker); + if (node_type == slz_node_type::TRIE_NODE) { + const std::size_t str_size = numeric_cast( + deserialize_value(deserializer), + "Deserialized str_size is too big."); + + str_buffer.resize(str_size); + deserializer(str_buffer.data(), str_size); + + trie_node* current_node = + insert_prefix_trie_nodes(str_buffer.data(), str_size); + deserialize_value_node(deserializer, current_node); + m_nb_elements++; + } else if (node_type == slz_node_type::HASH_NODE) { + const std::size_t str_size = numeric_cast( + deserialize_value(deserializer), + "Deserialized str_size is too big."); + + if (str_size == 0) { + tsl_ht_assert(m_nb_elements == 0 && !m_root); + + m_root = make_unique( + array_hash_type::deserialize(deserializer, hash_compatible)); + m_nb_elements += m_root->as_hash_node().array_hash().size(); + + tsl_ht_assert(m_nb_elements == nb_elements); + } else { + str_buffer.resize(str_size); + deserializer(str_buffer.data(), str_size); + + auto hnode = make_unique( + array_hash_type::deserialize(deserializer, hash_compatible)); + m_nb_elements += hnode->array_hash().size(); + + trie_node* current_node = + insert_prefix_trie_nodes(str_buffer.data(), str_size - 1); + 
current_node->set_child(str_buffer[str_size - 1], std::move(hnode)); + } + } else { + throw std::runtime_error("Unknown deserialized node type."); + } + } + + tsl_ht_assert(m_nb_elements == nb_elements); + } + + trie_node* insert_prefix_trie_nodes(const CharT* prefix, + std::size_t prefix_size) { + if (m_root == nullptr) { + m_root = make_unique(); + } + + trie_node* current_node = &m_root->as_trie_node(); + for (std::size_t iprefix = 0; iprefix < prefix_size; iprefix++) { + if (current_node->child(prefix[iprefix]) == nullptr) { + current_node->set_child(prefix[iprefix], make_unique()); + } + + current_node = ¤t_node->child(prefix[iprefix])->as_trie_node(); + } + + return current_node; + } + + template ::value>::type* = nullptr> + void deserialize_value_node(Deserializer& /*deserializer*/, + trie_node* current_node) { + tsl_ht_assert(!current_node->val_node()); + current_node->val_node() = make_unique(); + } + + template ::value>::type* = nullptr> + void deserialize_value_node(Deserializer& deserializer, + trie_node* current_node) { + tsl_ht_assert(!current_node->val_node()); + current_node->val_node() = + make_unique(deserialize_value(deserializer)); + } + + template + static U deserialize_value(Deserializer& deserializer) { + // MSVC < 2017 is not conformant, circumvent the problem by removing the + // template keyword +#if defined(_MSC_VER) && _MSC_VER < 1910 + return deserializer.Deserializer::operator()(); +#else + return deserializer.Deserializer::template operator()(); +#endif + } + + // Same as std::make_unique for non-array types which is only available in + // C++14 (we need to support C++11). + template + static std::unique_ptr make_unique(Args&&... args) { + return std::unique_ptr(new U(std::forward(args)...)); + } + + public: + static constexpr float HASH_NODE_DEFAULT_MAX_LOAD_FACTOR = 8.0f; + static const size_type DEFAULT_BURST_THRESHOLD = 16384; + + private: + /** + * Fixed size type used to represent size_type values on serialization. Need + * to be big enough to represent a std::size_t on 32 and 64 bits platforms, + * and must be the same size on both platforms. + */ + using slz_size_type = std::uint64_t; + enum class slz_node_type : CharT { TRIE_NODE = 0, HASH_NODE = 1 }; + + /** + * Protocol version currenlty used for serialization. + */ + static const slz_size_type SERIALIZATION_PROTOCOL_VERSION = 1; + + static const size_type HASH_NODE_DEFAULT_INIT_BUCKETS_COUNT = 32; + static const size_type MIN_BURST_THRESHOLD = 4; + + std::unique_ptr m_root; + size_type m_nb_elements; + Hash m_hash; + float m_max_load_factor; + size_type m_burst_threshold; +}; + +} // end namespace detail_htrie_hash +} // end namespace tsl + +#endif diff --git a/include/tsl/htrie_map.h b/include/tsl/htrie_map.h new file mode 100644 index 00000000..c601ac8d --- /dev/null +++ b/include/tsl/htrie_map.h @@ -0,0 +1,668 @@ +/** + * MIT License + * + * Copyright (c) 2017 Thibaut Goetghebuer-Planchon + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef TSL_HTRIE_MAP_H +#define TSL_HTRIE_MAP_H + +#include +#include +#include +#include +#include + +#include "htrie_hash.h" + +namespace tsl { + +/** + * Implementation of a hat-trie map. + * + * The value T must be either nothrow move-constructible/assignable, + * copy-constructible or both. + * + * The size of a key string is limited to std::numeric_limits::max() + * - 1. That is 65 535 characters by default, but can be raised with the + * KeySizeT template parameter. See max_key_size() for an easy access to this + * limit. + * + * Iterators invalidation: + * - clear, operator=: always invalidate the iterators. + * - insert, emplace, operator[]: always invalidate the iterators. + * - erase: always invalidate the iterators. + */ +template , + class KeySizeT = std::uint16_t> +class htrie_map { + private: + template + using is_iterator = tsl::detail_array_hash::is_iterator; + + using ht = tsl::detail_htrie_hash::htrie_hash; + + public: + using char_type = typename ht::char_type; + using mapped_type = T; + using key_size_type = typename ht::key_size_type; + using size_type = typename ht::size_type; + using hasher = typename ht::hasher; + using iterator = typename ht::iterator; + using const_iterator = typename ht::const_iterator; + using prefix_iterator = typename ht::prefix_iterator; + using const_prefix_iterator = typename ht::const_prefix_iterator; + + public: + explicit htrie_map(const Hash& hash = Hash()) + : m_ht(hash, ht::HASH_NODE_DEFAULT_MAX_LOAD_FACTOR, + ht::DEFAULT_BURST_THRESHOLD) {} + + explicit htrie_map(size_type burst_threshold, const Hash& hash = Hash()) + : m_ht(hash, ht::HASH_NODE_DEFAULT_MAX_LOAD_FACTOR, burst_threshold) {} + + template ::value>::type* = nullptr> + htrie_map(InputIt first, InputIt last, const Hash& hash = Hash()) + : htrie_map(hash) { + insert(first, last); + } + +#ifdef TSL_HT_HAS_STRING_VIEW + htrie_map( + std::initializer_list, T>> init, + const Hash& hash = Hash()) + : htrie_map(hash) { + insert(init); + } +#else + htrie_map(std::initializer_list> init, + const Hash& hash = Hash()) + : htrie_map(hash) { + insert(init); + } +#endif + +#ifdef TSL_HT_HAS_STRING_VIEW + htrie_map& operator=( + std::initializer_list, T>> + ilist) { + clear(); + insert(ilist); + + return *this; + } +#else + htrie_map& operator=( + std::initializer_list> ilist) { + clear(); + insert(ilist); + + return *this; + } +#endif + + /* + * Iterators + */ + iterator begin() noexcept { return m_ht.begin(); } + const_iterator begin() const noexcept { return m_ht.begin(); } + const_iterator cbegin() const noexcept { return m_ht.cbegin(); } + + iterator end() noexcept { return m_ht.end(); } + const_iterator end() const noexcept { return m_ht.end(); } + const_iterator cend() const noexcept { return m_ht.cend(); } + + /* + * Capacity + */ + bool empty() const noexcept { return m_ht.empty(); } + size_type size() const noexcept { return m_ht.size(); } + size_type max_size() const noexcept { return m_ht.max_size(); } + size_type max_key_size() const noexcept { return m_ht.max_key_size(); } + + /** + 
* Call shrink_to_fit() on each hash node of the hat-trie to reduce its size. + */ + void shrink_to_fit() { m_ht.shrink_to_fit(); } + + /* + * Modifiers + */ + void clear() noexcept { m_ht.clear(); } + + std::pair insert_ks(const CharT* key, size_type key_size, + const T& value) { + return m_ht.insert(key, key_size, value); + } +#ifdef TSL_HT_HAS_STRING_VIEW + std::pair insert(const std::basic_string_view& key, + const T& value) { + return m_ht.insert(key.data(), key.size(), value); + } +#else + std::pair insert(const CharT* key, const T& value) { + return m_ht.insert(key, std::strlen(key), value); + } + + std::pair insert(const std::basic_string& key, + const T& value) { + return m_ht.insert(key.data(), key.size(), value); + } +#endif + + std::pair insert_ks(const CharT* key, size_type key_size, + T&& value) { + return m_ht.insert(key, key_size, std::move(value)); + } +#ifdef TSL_HT_HAS_STRING_VIEW + std::pair insert(const std::basic_string_view& key, + T&& value) { + return m_ht.insert(key.data(), key.size(), std::move(value)); + } +#else + std::pair insert(const CharT* key, T&& value) { + return m_ht.insert(key, std::strlen(key), std::move(value)); + } + + std::pair insert(const std::basic_string& key, + T&& value) { + return m_ht.insert(key.data(), key.size(), std::move(value)); + } +#endif + + template ::value>::type* = nullptr> + void insert(InputIt first, InputIt last) { + for (auto it = first; it != last; ++it) { + insert_pair(*it); + } + } + +#ifdef TSL_HT_HAS_STRING_VIEW + void insert(std::initializer_list, T>> + ilist) { + insert(ilist.begin(), ilist.end()); + } +#else + void insert(std::initializer_list> ilist) { + insert(ilist.begin(), ilist.end()); + } +#endif + + template + std::pair emplace_ks(const CharT* key, size_type key_size, + Args&&... args) { + return m_ht.insert(key, key_size, std::forward(args)...); + } +#ifdef TSL_HT_HAS_STRING_VIEW + template + std::pair emplace(const std::basic_string_view& key, + Args&&... args) { + return m_ht.insert(key.data(), key.size(), std::forward(args)...); + } +#else + template + std::pair emplace(const CharT* key, Args&&... args) { + return m_ht.insert(key, std::strlen(key), std::forward(args)...); + } + + template + std::pair emplace(const std::basic_string& key, + Args&&... args) { + return m_ht.insert(key.data(), key.size(), std::forward(args)...); + } +#endif + + iterator erase(const_iterator pos) { return m_ht.erase(pos); } + iterator erase(const_iterator first, const_iterator last) { + return m_ht.erase(first, last); + } + + size_type erase_ks(const CharT* key, size_type key_size) { + return m_ht.erase(key, key_size); + } +#ifdef TSL_HT_HAS_STRING_VIEW + size_type erase(const std::basic_string_view& key) { + return m_ht.erase(key.data(), key.size()); + } +#else + size_type erase(const CharT* key) { + return m_ht.erase(key, std::strlen(key)); + } + + size_type erase(const std::basic_string& key) { + return m_ht.erase(key.data(), key.size()); + } +#endif + + /** + * Erase all the elements which have 'prefix' as prefix. Return the number of + * erase elements. 
+ */ + size_type erase_prefix_ks(const CharT* prefix, size_type prefix_size) { + return m_ht.erase_prefix(prefix, prefix_size); + } +#ifdef TSL_HT_HAS_STRING_VIEW + /** + * @copydoc erase_prefix_ks(const CharT* prefix, size_type prefix_size) + */ + size_type erase_prefix(const std::basic_string_view& prefix) { + return m_ht.erase_prefix(prefix.data(), prefix.size()); + } +#else + /** + * @copydoc erase_prefix_ks(const CharT* prefix, size_type prefix_size) + */ + size_type erase_prefix(const CharT* prefix) { + return m_ht.erase_prefix(prefix, std::strlen(prefix)); + } + + /** + * @copydoc erase_prefix_ks(const CharT* prefix, size_type prefix_size) + */ + size_type erase_prefix(const std::basic_string& prefix) { + return m_ht.erase_prefix(prefix.data(), prefix.size()); + } +#endif + + void swap(htrie_map& other) { other.m_ht.swap(m_ht); } + + /* + * Lookup + */ + T& at_ks(const CharT* key, size_type key_size) { + return m_ht.at(key, key_size); + } + const T& at_ks(const CharT* key, size_type key_size) const { + return m_ht.at(key, key_size); + } + +#ifdef TSL_HT_HAS_STRING_VIEW + T& at(const std::basic_string_view& key) { + return m_ht.at(key.data(), key.size()); + } + const T& at(const std::basic_string_view& key) const { + return m_ht.at(key.data(), key.size()); + } +#else + T& at(const CharT* key) { return m_ht.at(key, std::strlen(key)); } + const T& at(const CharT* key) const { return m_ht.at(key, std::strlen(key)); } + + T& at(const std::basic_string& key) { + return m_ht.at(key.data(), key.size()); + } + const T& at(const std::basic_string& key) const { + return m_ht.at(key.data(), key.size()); + } +#endif + +#ifdef TSL_HT_HAS_STRING_VIEW + T& operator[](const std::basic_string_view& key) { + return m_ht.access_operator(key.data(), key.size()); + } +#else + T& operator[](const CharT* key) { + return m_ht.access_operator(key, std::strlen(key)); + } + T& operator[](const std::basic_string& key) { + return m_ht.access_operator(key.data(), key.size()); + } +#endif + + size_type count_ks(const CharT* key, size_type key_size) const { + return m_ht.count(key, key_size); + } +#ifdef TSL_HT_HAS_STRING_VIEW + size_type count(const std::basic_string_view& key) const { + return m_ht.count(key.data(), key.size()); + } +#else + size_type count(const CharT* key) const { + return m_ht.count(key, std::strlen(key)); + } + size_type count(const std::basic_string& key) const { + return m_ht.count(key.data(), key.size()); + } +#endif + + iterator find_ks(const CharT* key, size_type key_size) { + return m_ht.find(key, key_size); + } + + const_iterator find_ks(const CharT* key, size_type key_size) const { + return m_ht.find(key, key_size); + } +#ifdef TSL_HT_HAS_STRING_VIEW + iterator find(const std::basic_string_view& key) { + return m_ht.find(key.data(), key.size()); + } + + const_iterator find(const std::basic_string_view& key) const { + return m_ht.find(key.data(), key.size()); + } +#else + iterator find(const CharT* key) { return m_ht.find(key, std::strlen(key)); } + + const_iterator find(const CharT* key) const { + return m_ht.find(key, std::strlen(key)); + } + + iterator find(const std::basic_string& key) { + return m_ht.find(key.data(), key.size()); + } + + const_iterator find(const std::basic_string& key) const { + return m_ht.find(key.data(), key.size()); + } +#endif + + std::pair equal_range_ks(const CharT* key, + size_type key_size) { + return m_ht.equal_range(key, key_size); + } + + std::pair equal_range_ks( + const CharT* key, size_type key_size) const { + return m_ht.equal_range(key, 
key_size); + } +#ifdef TSL_HT_HAS_STRING_VIEW + std::pair equal_range( + const std::basic_string_view& key) { + return m_ht.equal_range(key.data(), key.size()); + } + + std::pair equal_range( + const std::basic_string_view& key) const { + return m_ht.equal_range(key.data(), key.size()); + } +#else + std::pair equal_range(const CharT* key) { + return m_ht.equal_range(key, std::strlen(key)); + } + + std::pair equal_range( + const CharT* key) const { + return m_ht.equal_range(key, std::strlen(key)); + } + + std::pair equal_range( + const std::basic_string& key) { + return m_ht.equal_range(key.data(), key.size()); + } + + std::pair equal_range( + const std::basic_string& key) const { + return m_ht.equal_range(key.data(), key.size()); + } +#endif + + /** + * Return a range containing all the elements which have 'prefix' as prefix. + * The range is defined by a pair of iterator, the first being the begin + * iterator and the second being the end iterator. + */ + std::pair equal_prefix_range_ks( + const CharT* prefix, size_type prefix_size) { + return m_ht.equal_prefix_range(prefix, prefix_size); + } + + /** + * @copydoc equal_prefix_range_ks(const CharT* prefix, size_type prefix_size) + */ + std::pair equal_prefix_range_ks( + const CharT* prefix, size_type prefix_size) const { + return m_ht.equal_prefix_range(prefix, prefix_size); + } +#ifdef TSL_HT_HAS_STRING_VIEW + /** + * @copydoc equal_prefix_range_ks(const CharT* prefix, size_type prefix_size) + */ + std::pair equal_prefix_range( + const std::basic_string_view& prefix) { + return m_ht.equal_prefix_range(prefix.data(), prefix.size()); + } + + /** + * @copydoc equal_prefix_range_ks(const CharT* prefix, size_type prefix_size) + */ + std::pair equal_prefix_range( + const std::basic_string_view& prefix) const { + return m_ht.equal_prefix_range(prefix.data(), prefix.size()); + } +#else + /** + * @copydoc equal_prefix_range_ks(const CharT* prefix, size_type prefix_size) + */ + std::pair equal_prefix_range( + const CharT* prefix) { + return m_ht.equal_prefix_range(prefix, std::strlen(prefix)); + } + + /** + * @copydoc equal_prefix_range_ks(const CharT* prefix, size_type prefix_size) + */ + std::pair equal_prefix_range( + const CharT* prefix) const { + return m_ht.equal_prefix_range(prefix, std::strlen(prefix)); + } + + /** + * @copydoc equal_prefix_range_ks(const CharT* prefix, size_type prefix_size) + */ + std::pair equal_prefix_range( + const std::basic_string& prefix) { + return m_ht.equal_prefix_range(prefix.data(), prefix.size()); + } + + /** + * @copydoc equal_prefix_range_ks(const CharT* prefix, size_type prefix_size) + */ + std::pair equal_prefix_range( + const std::basic_string& prefix) const { + return m_ht.equal_prefix_range(prefix.data(), prefix.size()); + } +#endif + + /** + * Return the element in the trie which is the longest prefix of `key`. If no + * element in the trie is a prefix of `key`, the end iterator is returned. 
+ * + * Example: + * + * tsl::htrie_map map = {{"/foo", 1}, {"/foo/bar", 1}}; + * + * map.longest_prefix("/foo"); // returns {"/foo", 1} + * map.longest_prefix("/foo/baz"); // returns {"/foo", 1} + * map.longest_prefix("/foo/bar/baz"); // returns {"/foo/bar", 1} + * map.longest_prefix("/foo/bar/"); // returns {"/foo/bar", 1} + * map.longest_prefix("/bar"); // returns end() + * map.longest_prefix(""); // returns end() + */ + iterator longest_prefix_ks(const CharT* key, size_type key_size) { + return m_ht.longest_prefix(key, key_size); + } + + /** + * @copydoc longest_prefix_ks(const CharT* key, size_type key_size) + */ + const_iterator longest_prefix_ks(const CharT* key, size_type key_size) const { + return m_ht.longest_prefix(key, key_size); + } +#ifdef TSL_HT_HAS_STRING_VIEW + /** + * @copydoc longest_prefix_ks(const CharT* key, size_type key_size) + */ + iterator longest_prefix(const std::basic_string_view& key) { + return m_ht.longest_prefix(key.data(), key.size()); + } + + /** + * @copydoc longest_prefix_ks(const CharT* key, size_type key_size) + */ + const_iterator longest_prefix( + const std::basic_string_view& key) const { + return m_ht.longest_prefix(key.data(), key.size()); + } +#else + /** + * @copydoc longest_prefix_ks(const CharT* key, size_type key_size) + */ + iterator longest_prefix(const CharT* key) { + return m_ht.longest_prefix(key, std::strlen(key)); + } + + /** + * @copydoc longest_prefix_ks(const CharT* key, size_type key_size) + */ + const_iterator longest_prefix(const CharT* key) const { + return m_ht.longest_prefix(key, std::strlen(key)); + } + + /** + * @copydoc longest_prefix_ks(const CharT* key, size_type key_size) + */ + iterator longest_prefix(const std::basic_string& key) { + return m_ht.longest_prefix(key.data(), key.size()); + } + + /** + * @copydoc longest_prefix_ks(const CharT* key, size_type key_size) + */ + const_iterator longest_prefix(const std::basic_string& key) const { + return m_ht.longest_prefix(key.data(), key.size()); + } +#endif + + /* + * Hash policy + */ + float max_load_factor() const { return m_ht.max_load_factor(); } + void max_load_factor(float ml) { m_ht.max_load_factor(ml); } + + /* + * Burst policy + */ + size_type burst_threshold() const { return m_ht.burst_threshold(); } + void burst_threshold(size_type threshold) { m_ht.burst_threshold(threshold); } + + /* + * Observers + */ + hasher hash_function() const { return m_ht.hash_function(); } + + /* + * Other + */ + + /** + * Serialize the map through the `serializer` parameter. + * + * The `serializer` parameter must be a function object that supports the + * following calls: + * - `void operator()(const U& value);` where the types `std::uint64_t`, + * `float` and `T` must be supported for U. + * - `void operator()(const CharT* value, std::size_t value_size);` + * + * The implementation leaves binary compatibility (endianness, IEEE 754 for + * floats, ...) of the types it serializes in the hands of the `Serializer` + * function object if compatibility is required. + */ + template + void serialize(Serializer& serializer) const { + m_ht.serialize(serializer); + } + + /** + * Deserialize a previously serialized map through the `deserializer` + * parameter. + * + * The `deserializer` parameter must be a function object that supports the + * following calls: + * - `template U operator()();` where the types `std::uint64_t`, + * `float` and `T` must be supported for U. 
+ * - `void operator()(CharT* value_out, std::size_t value_size);` + * + * If the deserialized hash map part of the hat-trie is hash compatible with + * the serialized map, the deserialization process can be sped up by setting + * `hash_compatible` to true. To be hash compatible, the Hash (take care of + * the 32-bits vs 64 bits), and KeySizeT must behave the same than the ones + * used in the serialized map. Otherwise the behaviour is undefined with + * `hash_compatible` sets to true. + * + * The behaviour is undefined if the type `CharT` and `T` of the `htrie_map` + * are not the same as the types used during serialization. + * + * The implementation leaves binary compatibility (endianness, IEEE 754 for + * floats, size of int, ...) of the types it deserializes in the hands of the + * `Deserializer` function object if compatibility is required. + */ + template + static htrie_map deserialize(Deserializer& deserializer, + bool hash_compatible = false) { + htrie_map map; + map.m_ht.deserialize(deserializer, hash_compatible); + + return map; + } + + friend bool operator==(const htrie_map& lhs, const htrie_map& rhs) { + if (lhs.size() != rhs.size()) { + return false; + } + + std::string key_buffer; + for (auto it = lhs.cbegin(); it != lhs.cend(); ++it) { + it.key(key_buffer); + + const auto it_element_rhs = rhs.find(key_buffer); + if (it_element_rhs == rhs.cend() || + it.value() != it_element_rhs.value()) { + return false; + } + } + + return true; + } + + friend bool operator!=(const htrie_map& lhs, const htrie_map& rhs) { + return !operator==(lhs, rhs); + } + + friend void swap(htrie_map& lhs, htrie_map& rhs) { lhs.swap(rhs); } + + private: + template + void insert_pair(const std::pair& value) { + insert(value.first, value.second); + } + + template + void insert_pair(std::pair&& value) { + insert(value.first, std::move(value.second)); + } + + private: + ht m_ht; +}; + +} // end namespace tsl + +#endif diff --git a/include/tsl/htrie_set.h b/include/tsl/htrie_set.h new file mode 100644 index 00000000..ea71f0e9 --- /dev/null +++ b/include/tsl/htrie_set.h @@ -0,0 +1,578 @@ +/** + * MIT License + * + * Copyright (c) 2017 Thibaut Goetghebuer-Planchon + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef TSL_HTRIE_SET_H +#define TSL_HTRIE_SET_H + +#include +#include +#include +#include +#include + +#include "htrie_hash.h" + +namespace tsl { + +/** + * Implementation of a hat-trie set. + * + * The size of a key string is limited to std::numeric_limits::max() + * - 1. 
That is 65 535 characters by default, but can be raised with the + * KeySizeT template parameter. See max_key_size() for an easy access to this + * limit. + * + * Iterators invalidation: + * - clear, operator=: always invalidate the iterators. + * - insert: always invalidate the iterators. + * - erase: always invalidate the iterators. + */ +template , + class KeySizeT = std::uint16_t> +class htrie_set { + private: + template + using is_iterator = tsl::detail_array_hash::is_iterator; + + using ht = tsl::detail_htrie_hash::htrie_hash; + + public: + using char_type = typename ht::char_type; + using key_size_type = typename ht::key_size_type; + using size_type = typename ht::size_type; + using hasher = typename ht::hasher; + using iterator = typename ht::iterator; + using const_iterator = typename ht::const_iterator; + using prefix_iterator = typename ht::prefix_iterator; + using const_prefix_iterator = typename ht::const_prefix_iterator; + + public: + explicit htrie_set(const Hash& hash = Hash()) + : m_ht(hash, ht::HASH_NODE_DEFAULT_MAX_LOAD_FACTOR, + ht::DEFAULT_BURST_THRESHOLD) {} + + explicit htrie_set(size_type burst_threshold, const Hash& hash = Hash()) + : m_ht(hash, ht::HASH_NODE_DEFAULT_MAX_LOAD_FACTOR, burst_threshold) {} + + template ::value>::type* = nullptr> + htrie_set(InputIt first, InputIt last, const Hash& hash = Hash()) + : htrie_set(hash) { + insert(first, last); + } + +#ifdef TSL_HT_HAS_STRING_VIEW + htrie_set(std::initializer_list> init, + const Hash& hash = Hash()) + : htrie_set(hash) { + insert(init); + } +#else + htrie_set(std::initializer_list init, const Hash& hash = Hash()) + : htrie_set(hash) { + insert(init); + } +#endif + +#ifdef TSL_HT_HAS_STRING_VIEW + htrie_set& operator=( + std::initializer_list> ilist) { + clear(); + insert(ilist); + + return *this; + } +#else + htrie_set& operator=(std::initializer_list ilist) { + clear(); + insert(ilist); + + return *this; + } +#endif + + /* + * Iterators + */ + iterator begin() noexcept { return m_ht.begin(); } + const_iterator begin() const noexcept { return m_ht.begin(); } + const_iterator cbegin() const noexcept { return m_ht.cbegin(); } + + iterator end() noexcept { return m_ht.end(); } + const_iterator end() const noexcept { return m_ht.end(); } + const_iterator cend() const noexcept { return m_ht.cend(); } + + /* + * Capacity + */ + bool empty() const noexcept { return m_ht.empty(); } + size_type size() const noexcept { return m_ht.size(); } + size_type max_size() const noexcept { return m_ht.max_size(); } + size_type max_key_size() const noexcept { return m_ht.max_key_size(); } + + /** + * Call shrink_to_fit() on each hash node of the hat-trie to reduce its size. 
+ */ + void shrink_to_fit() { m_ht.shrink_to_fit(); } + + /* + * Modifiers + */ + void clear() noexcept { m_ht.clear(); } + + std::pair insert_ks(const CharT* key, size_type key_size) { + return m_ht.insert(key, key_size); + } +#ifdef TSL_HT_HAS_STRING_VIEW + std::pair insert(const std::basic_string_view& key) { + return m_ht.insert(key.data(), key.size()); + } +#else + std::pair insert(const CharT* key) { + return m_ht.insert(key, std::strlen(key)); + } + + std::pair insert(const std::basic_string& key) { + return m_ht.insert(key.data(), key.size()); + } +#endif + + template ::value>::type* = nullptr> + void insert(InputIt first, InputIt last) { + for (auto it = first; it != last; ++it) { + insert(*it); + } + } + +#ifdef TSL_HT_HAS_STRING_VIEW + void insert(std::initializer_list> ilist) { + insert(ilist.begin(), ilist.end()); + } +#else + void insert(std::initializer_list ilist) { + insert(ilist.begin(), ilist.end()); + } +#endif + + std::pair emplace_ks(const CharT* key, size_type key_size) { + return m_ht.insert(key, key_size); + } +#ifdef TSL_HT_HAS_STRING_VIEW + std::pair emplace(const std::basic_string_view& key) { + return m_ht.insert(key.data(), key.size()); + } +#else + std::pair emplace(const CharT* key) { + return m_ht.insert(key, std::strlen(key)); + } + + std::pair emplace(const std::basic_string& key) { + return m_ht.insert(key.data(), key.size()); + } +#endif + + iterator erase(const_iterator pos) { return m_ht.erase(pos); } + iterator erase(const_iterator first, const_iterator last) { + return m_ht.erase(first, last); + } + + size_type erase_ks(const CharT* key, size_type key_size) { + return m_ht.erase(key, key_size); + } +#ifdef TSL_HT_HAS_STRING_VIEW + size_type erase(const std::basic_string_view& key) { + return m_ht.erase(key.data(), key.size()); + } +#else + size_type erase(const CharT* key) { + return m_ht.erase(key, std::strlen(key)); + } + + size_type erase(const std::basic_string& key) { + return m_ht.erase(key.data(), key.size()); + } +#endif + + /** + * Erase all the elements which have 'prefix' as prefix. Return the number of + * erase elements. 
+ */ + size_type erase_prefix_ks(const CharT* prefix, size_type prefix_size) { + return m_ht.erase_prefix(prefix, prefix_size); + } +#ifdef TSL_HT_HAS_STRING_VIEW + /** + * @copydoc erase_prefix_ks(const CharT* prefix, size_type prefix_size) + */ + size_type erase_prefix(const std::basic_string_view& prefix) { + return m_ht.erase_prefix(prefix.data(), prefix.size()); + } +#else + /** + * @copydoc erase_prefix_ks(const CharT* prefix, size_type prefix_size) + */ + size_type erase_prefix(const CharT* prefix) { + return m_ht.erase_prefix(prefix, std::strlen(prefix)); + } + + /** + * @copydoc erase_prefix_ks(const CharT* prefix, size_type prefix_size) + */ + size_type erase_prefix(const std::basic_string& prefix) { + return m_ht.erase_prefix(prefix.data(), prefix.size()); + } +#endif + + void swap(htrie_set& other) { other.m_ht.swap(m_ht); } + + /* + * Lookup + */ + size_type count_ks(const CharT* key, size_type key_size) const { + return m_ht.count(key, key_size); + } +#ifdef TSL_HT_HAS_STRING_VIEW + size_type count(const std::basic_string_view& key) const { + return m_ht.count(key.data(), key.size()); + } +#else + size_type count(const CharT* key) const { + return m_ht.count(key, std::strlen(key)); + } + size_type count(const std::basic_string& key) const { + return m_ht.count(key.data(), key.size()); + } +#endif + + iterator find_ks(const CharT* key, size_type key_size) { + return m_ht.find(key, key_size); + } + + const_iterator find_ks(const CharT* key, size_type key_size) const { + return m_ht.find(key, key_size); + } +#ifdef TSL_HT_HAS_STRING_VIEW + iterator find(const std::basic_string_view& key) { + return m_ht.find(key.data(), key.size()); + } + + const_iterator find(const std::basic_string_view& key) const { + return m_ht.find(key.data(), key.size()); + } +#else + iterator find(const CharT* key) { return m_ht.find(key, std::strlen(key)); } + + const_iterator find(const CharT* key) const { + return m_ht.find(key, std::strlen(key)); + } + + iterator find(const std::basic_string& key) { + return m_ht.find(key.data(), key.size()); + } + + const_iterator find(const std::basic_string& key) const { + return m_ht.find(key.data(), key.size()); + } +#endif + + std::pair equal_range_ks(const CharT* key, + size_type key_size) { + return m_ht.equal_range(key, key_size); + } + + std::pair equal_range_ks( + const CharT* key, size_type key_size) const { + return m_ht.equal_range(key, key_size); + } +#ifdef TSL_HT_HAS_STRING_VIEW + std::pair equal_range( + const std::basic_string_view& key) { + return m_ht.equal_range(key.data(), key.size()); + } + + std::pair equal_range( + const std::basic_string_view& key) const { + return m_ht.equal_range(key.data(), key.size()); + } +#else + std::pair equal_range(const CharT* key) { + return m_ht.equal_range(key, std::strlen(key)); + } + + std::pair equal_range( + const CharT* key) const { + return m_ht.equal_range(key, std::strlen(key)); + } + + std::pair equal_range( + const std::basic_string& key) { + return m_ht.equal_range(key.data(), key.size()); + } + + std::pair equal_range( + const std::basic_string& key) const { + return m_ht.equal_range(key.data(), key.size()); + } +#endif + + /** + * Return a range containing all the elements which have 'prefix' as prefix. + * The range is defined by a pair of iterator, the first being the begin + * iterator and the second being the end iterator. 
+ */ + std::pair equal_prefix_range_ks( + const CharT* prefix, size_type prefix_size) { + return m_ht.equal_prefix_range(prefix, prefix_size); + } + + /** + * @copydoc equal_prefix_range_ks(const CharT* prefix, size_type prefix_size) + */ + std::pair equal_prefix_range_ks( + const CharT* prefix, size_type prefix_size) const { + return m_ht.equal_prefix_range(prefix, prefix_size); + } +#ifdef TSL_HT_HAS_STRING_VIEW + /** + * @copydoc equal_prefix_range_ks(const CharT* prefix, size_type prefix_size) + */ + std::pair equal_prefix_range( + const std::basic_string_view& prefix) { + return m_ht.equal_prefix_range(prefix.data(), prefix.size()); + } + + /** + * @copydoc equal_prefix_range_ks(const CharT* prefix, size_type prefix_size) + */ + std::pair equal_prefix_range( + const std::basic_string_view& prefix) const { + return m_ht.equal_prefix_range(prefix.data(), prefix.size()); + } +#else + /** + * @copydoc equal_prefix_range_ks(const CharT* prefix, size_type prefix_size) + */ + std::pair equal_prefix_range( + const CharT* prefix) { + return m_ht.equal_prefix_range(prefix, std::strlen(prefix)); + } + + /** + * @copydoc equal_prefix_range_ks(const CharT* prefix, size_type prefix_size) + */ + std::pair equal_prefix_range( + const CharT* prefix) const { + return m_ht.equal_prefix_range(prefix, std::strlen(prefix)); + } + + /** + * @copydoc equal_prefix_range_ks(const CharT* prefix, size_type prefix_size) + */ + std::pair equal_prefix_range( + const std::basic_string& prefix) { + return m_ht.equal_prefix_range(prefix.data(), prefix.size()); + } + + /** + * @copydoc equal_prefix_range_ks(const CharT* prefix, size_type prefix_size) + */ + std::pair equal_prefix_range( + const std::basic_string& prefix) const { + return m_ht.equal_prefix_range(prefix.data(), prefix.size()); + } +#endif + + /** + * Return the element in the trie which is the longest prefix of `key`. If no + * element in the trie is a prefix of `key`, the end iterator is returned. 
+ * + * Example: + * + * tsl::htrie_set set = {"/foo", "/foo/bar"}; + * + * set.longest_prefix("/foo"); // returns "/foo" + * set.longest_prefix("/foo/baz"); // returns "/foo" + * set.longest_prefix("/foo/bar/baz"); // returns "/foo/bar" + * set.longest_prefix("/foo/bar/"); // returns "/foo/bar" + * set.longest_prefix("/bar"); // returns end() + * set.longest_prefix(""); // returns end() + */ + iterator longest_prefix_ks(const CharT* key, size_type key_size) { + return m_ht.longest_prefix(key, key_size); + } + + /** + * @copydoc longest_prefix_ks(const CharT* key, size_type key_size) + */ + const_iterator longest_prefix_ks(const CharT* key, size_type key_size) const { + return m_ht.longest_prefix(key, key_size); + } +#ifdef TSL_HT_HAS_STRING_VIEW + /** + * @copydoc longest_prefix_ks(const CharT* key, size_type key_size) + */ + iterator longest_prefix(const std::basic_string_view& key) { + return m_ht.longest_prefix(key.data(), key.size()); + } + + /** + * @copydoc longest_prefix_ks(const CharT* key, size_type key_size) + */ + const_iterator longest_prefix( + const std::basic_string_view& key) const { + return m_ht.longest_prefix(key.data(), key.size()); + } +#else + /** + * @copydoc longest_prefix_ks(const CharT* key, size_type key_size) + */ + iterator longest_prefix(const CharT* key) { + return m_ht.longest_prefix(key, std::strlen(key)); + } + + /** + * @copydoc longest_prefix_ks(const CharT* key, size_type key_size) + */ + const_iterator longest_prefix(const CharT* key) const { + return m_ht.longest_prefix(key, std::strlen(key)); + } + + /** + * @copydoc longest_prefix_ks(const CharT* key, size_type key_size) + */ + iterator longest_prefix(const std::basic_string& key) { + return m_ht.longest_prefix(key.data(), key.size()); + } + + /** + * @copydoc longest_prefix_ks(const CharT* key, size_type key_size) + */ + const_iterator longest_prefix(const std::basic_string& key) const { + return m_ht.longest_prefix(key.data(), key.size()); + } +#endif + + /* + * Hash policy + */ + float max_load_factor() const { return m_ht.max_load_factor(); } + void max_load_factor(float ml) { m_ht.max_load_factor(ml); } + + /* + * Burst policy + */ + size_type burst_threshold() const { return m_ht.burst_threshold(); } + void burst_threshold(size_type threshold) { m_ht.burst_threshold(threshold); } + + /* + * Observers + */ + hasher hash_function() const { return m_ht.hash_function(); } + + /* + * Other + */ + + /** + * Serialize the set through the `serializer` parameter. + * + * The `serializer` parameter must be a function object that supports the + * following calls: + * - `void operator()(const U& value);` where the types `std::uint64_t` and + * `float` must be supported for U. + * - `void operator()(const CharT* value, std::size_t value_size);` + * + * The implementation leaves binary compatibility (endianness, IEEE 754 for + * floats, ...) of the types it serializes in the hands of the `Serializer` + * function object if compatibility is required. + */ + template + void serialize(Serializer& serializer) const { + m_ht.serialize(serializer); + } + + /** + * Deserialize a previously serialized set through the `deserializer` + * parameter. + * + * The `deserializer` parameter must be a function object that supports the + * following calls: + * - `template U operator()();` where the types `std::uint64_t` + * and `float` must be supported for U. 
+ * - `void operator()(CharT* value_out, std::size_t value_size);` + * + * If the deserialized hash set part of the hat-trie is hash compatible with + * the serialized set, the deserialization process can be sped up by setting + * `hash_compatible` to true. To be hash compatible, the Hash (take care of + * the 32-bits vs 64 bits), and KeySizeT must behave the same than the ones + * used in the serialized set. Otherwise the behaviour is undefined with + * `hash_compatible` sets to true. + * + * The behaviour is undefined if the type `CharT` of the `htrie_set` is not + * the same as the type used during serialization. + * + * The implementation leaves binary compatibility (endianness, IEEE 754 for + * floats, size of int, ...) of the types it deserializes in the hands of the + * `Deserializer` function object if compatibility is required. + */ + template + static htrie_set deserialize(Deserializer& deserializer, + bool hash_compatible = false) { + htrie_set set; + set.m_ht.deserialize(deserializer, hash_compatible); + + return set; + } + + friend bool operator==(const htrie_set& lhs, const htrie_set& rhs) { + if (lhs.size() != rhs.size()) { + return false; + } + + std::string key_buffer; + for (auto it = lhs.cbegin(); it != lhs.cend(); ++it) { + it.key(key_buffer); + + const auto it_element_rhs = rhs.find(key_buffer); + if (it_element_rhs == rhs.cend()) { + return false; + } + } + + return true; + } + + friend bool operator!=(const htrie_set& lhs, const htrie_set& rhs) { + return !operator==(lhs, rhs); + } + + friend void swap(htrie_set& lhs, htrie_set& rhs) { lhs.swap(rhs); } + + private: + ht m_ht; +}; + +} // end namespace tsl + +#endif diff --git a/src/collection.cpp b/src/collection.cpp index de964f9e..aa8a4cc2 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -689,7 +689,10 @@ Option Collection::search(const std::string & raw_query, const s const size_t min_len_1typo, const size_t min_len_2typo, bool split_join_tokens, - const size_t max_candidates) const { + const size_t max_candidates, + const std::vector& infixes, + const size_t max_extra_prefix, + const size_t max_extra_suffix) const { std::shared_lock lock(mutex); @@ -721,6 +724,13 @@ Option Collection::search(const std::string & raw_query, const s } } + if(!search_fields.empty() && search_fields.size() != infixes.size()) { + if(infixes.size() != 1) { + return Option(400, "Number of infix values in `infix` does not match " + "number of `query_by` fields."); + } + } + // process weights for search fields std::vector weighted_search_fields; size_t max_weight = 20; @@ -1005,7 +1015,8 @@ Option Collection::search(const std::string & raw_query, const s group_by_fields, group_limit, default_sorting_field, prioritize_exact_match, exhaustive_search, 4, filter_overrides, search_stop_millis, - min_len_1typo, min_len_2typo, max_candidates); + min_len_1typo, min_len_2typo, max_candidates, infixes, + max_extra_prefix, max_extra_suffix); index->run_search(search_params); diff --git a/src/collection_manager.cpp b/src/collection_manager.cpp index c2322b8c..0bc6695b 100644 --- a/src/collection_manager.cpp +++ b/src/collection_manager.cpp @@ -5,6 +5,7 @@ #include "collection_manager.h" #include "batched_indexer.h" #include "logger.h" +#include "magic_enum.hpp" constexpr const size_t CollectionManager::DEFAULT_NUM_MEMORY_SHARDS; @@ -567,6 +568,10 @@ Option CollectionManager::do_search(std::map& re const char *ENABLE_OVERRIDES = "enable_overrides"; const char *MAX_CANDIDATES = "max_candidates"; + const char *INFIX = "infix"; + const char 
*MAX_EXTRA_PREFIX = "max_extra_prefix"; + const char *MAX_EXTRA_SUFFIX = "max_extra_suffix"; + // strings under this length will be fully highlighted, instead of showing a snippet of relevant portion const char *SNIPPET_THRESHOLD = "snippet_threshold"; @@ -708,6 +713,14 @@ Option CollectionManager::do_search(std::map& re req_params[SPLIT_JOIN_TOKENS] = "true"; } + if(req_params.count(MAX_EXTRA_PREFIX) == 0) { + req_params[MAX_EXTRA_PREFIX] = std::to_string(INT16_MAX); + } + + if(req_params.count(MAX_EXTRA_SUFFIX) == 0) { + req_params[MAX_EXTRA_SUFFIX] = std::to_string(INT16_MAX); + } + std::vector query_by_weights_str; std::vector query_by_weights; @@ -781,6 +794,14 @@ Option CollectionManager::do_search(std::map& re return Option(400,"Parameter `" + std::string(SEARCH_CUTOFF_MS) + "` must be an unsigned integer."); } + if(!StringUtils::is_uint32_t(req_params[MAX_EXTRA_PREFIX])) { + return Option(400,"Parameter `" + std::string(MAX_EXTRA_PREFIX) + "` must be an unsigned integer."); + } + + if(!StringUtils::is_uint32_t(req_params[MAX_EXTRA_SUFFIX])) { + return Option(400,"Parameter `" + std::string(MAX_EXTRA_SUFFIX) + "` must be an unsigned integer."); + } + bool prioritize_exact_match = (req_params[PRIORITIZE_EXACT_MATCH] == "true"); bool pre_segmented_query = (req_params[PRE_SEGMENTED_QUERY] == "true"); bool exhaustive_search = (req_params[EXHAUSTIVE_SEARCH] == "true"); @@ -833,6 +854,21 @@ Option CollectionManager::do_search(std::map& re req_params[MAX_CANDIDATES] = exhaustive_search ? "10000" : "4"; } + std::vector infixes; + if(req_params.count(INFIX) != 0) { + std::vector infix_strs; + StringUtils::split(req_params[INFIX], infix_strs, ","); + + for(auto& infix_str: infix_strs) { + auto infix_op = magic_enum::enum_cast(infix_str); + if(infix_op.has_value()) { + infixes.push_back(infix_op.value()); + } + } + } else { + infixes.push_back(off); + } + bool enable_overrides = (req_params[ENABLE_OVERRIDES] == "true"); CollectionManager & collectionManager = CollectionManager::get_instance(); @@ -898,7 +934,10 @@ Option CollectionManager::do_search(std::map& re static_cast(std::stol(req_params[MIN_LEN_1TYPO])), static_cast(std::stol(req_params[MIN_LEN_2TYPO])), split_join_tokens, - static_cast(std::stol(req_params[MAX_CANDIDATES])) + static_cast(std::stol(req_params[MAX_CANDIDATES])), + infixes, + static_cast(std::stol(req_params[MAX_EXTRA_PREFIX])), + static_cast(std::stol(req_params[MAX_EXTRA_SUFFIX])) ); uint64_t timeMillis = std::chrono::duration_cast( diff --git a/src/index.cpp b/src/index.cpp index 825e7ae6..1843d700 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -72,6 +72,16 @@ Index::Index(const std::string& name, const uint32_t collection_id, const Store* art_tree_init(ft); search_index.emplace(fname_field.second.faceted_name(), ft); } + + if(fname_field.second.infix) { + array_mapped_infix_t infix_sets(ARRAY_INFIX_DIM); + + for(auto& infix_set: infix_sets) { + infix_set = new tsl::htrie_set(); + } + + infix_index.emplace(fname_field.second.name, infix_sets); + } } for(const auto & pair: sort_schema) { @@ -139,6 +149,15 @@ Index::~Index() { sort_index.clear(); + for(auto& kv: infix_index) { + for(auto& infix_set: kv.second) { + delete infix_set; + infix_set = nullptr; + } + } + + infix_index.clear(); + for(auto& name_tree: str_sort_index) { delete name_tree.second; name_tree.second = nullptr; @@ -660,6 +679,12 @@ void Index::index_field_in_memory(const field& afield, std::vector for(auto &token_offsets: field_index_it->second.offsets) { 
token_to_doc_offsets[token_offsets.first].emplace_back(seq_id, record.points, token_offsets.second); + + if(afield.infix) { + auto strhash = StringUtils::hash_wy(token_offsets.first.c_str(), token_offsets.first.size()); + const auto& infix_sets = infix_index.at(afield.name); + infix_sets[strhash % 4]->insert(token_offsets.first); + } } } @@ -1638,7 +1663,10 @@ void Index::run_search(search_args* search_params) { search_params->search_cutoff_ms, search_params->min_len_1typo, search_params->min_len_2typo, - search_params->max_candidates); + search_params->max_candidates, + search_params->infixes, + search_params->max_extra_prefix, + search_params->max_extra_suffix); } void Index::collate_included_ids(const std::vector& q_included_tokens, @@ -2013,6 +2041,59 @@ bool Index::check_for_overrides(const token_ordering& token_order, const string& return false; } +void Index::search_infix(const std::string& query, const std::string& field_name, + std::vector& ids, const size_t max_extra_prefix, const size_t max_extra_suffix) const { + + auto infix_maps_it = infix_index.find(field_name); + + if(infix_maps_it == infix_index.end()) { + return ; + } + + auto infix_sets = infix_maps_it->second; + std::vector leaves; + + size_t num_processed = 0; + std::mutex m_process; + std::condition_variable cv_process; + + auto search_tree = search_index.at(field_name); + + for(auto infix_set: infix_sets) { + thread_pool->enqueue([infix_set, &leaves, search_tree, &query, max_extra_prefix, max_extra_suffix, + &num_processed, &m_process, &cv_process]() { + std::vector this_leaves; + std::string key_buffer; + + for(auto it = infix_set->begin(); it != infix_set->end(); it++) { + it.key(key_buffer); + auto start_index = key_buffer.find(query); + if(start_index != std::string::npos && start_index <= max_extra_prefix && + (key_buffer.size() - (start_index + query.size())) <= max_extra_suffix) { + art_leaf* l = (art_leaf *) art_search(search_tree, + (const unsigned char *) key_buffer.c_str(), + key_buffer.size()+1); + if(l != nullptr) { + this_leaves.push_back(l); + } + } + } + + std::unique_lock lock(m_process); + leaves.insert(leaves.end(), this_leaves.begin(), this_leaves.end()); + num_processed++; + cv_process.notify_one(); + }); + } + + std::unique_lock lock_process(m_process); + cv_process.wait(lock_process, [&](){ return num_processed == infix_sets.size(); }); + + for(auto leaf: leaves) { + posting_t::merge({leaf->values}, ids); + } +} + void Index::search(std::vector& field_query_tokens, const std::vector& search_fields, std::vector& filters, @@ -2040,7 +2121,10 @@ void Index::search(std::vector& field_query_tokens, const size_t search_cutoff_ms, size_t min_len_1typo, size_t min_len_2typo, - const size_t max_candidates) const { + const size_t max_candidates, + const std::vector& infixes, + const size_t max_extra_prefix, + const size_t max_extra_suffix) const { search_begin = std::chrono::high_resolution_clock::now(); search_stop_ms = search_cutoff_ms; @@ -2247,6 +2331,7 @@ void Index::search(std::vector& field_query_tokens, int field_num_typos = (i < num_typos.size()) ? num_typos[i] : num_typos[0]; bool field_prefix = (i < prefixes.size()) ? prefixes[i] : prefixes[0]; + infix_t field_infix = (i < infixes.size()) ? 
infixes[i] : infixes[0]; // proceed to query search only when no filters are provided or when filtering produces results if(filters.empty() || actual_filter_ids_length > 0) { @@ -2282,6 +2367,34 @@ void Index::search(std::vector& field_query_tokens, query_hashes, token_order, field_prefix, drop_tokens_threshold, typo_tokens_threshold, exhaustive_search, min_len_1typo, min_len_2typo, max_candidates); + + if(field_infix == always || (field_infix == fallback && field_num_results == 0)) { + std::vector infix_ids; + search_infix(query_tokens[0].value, field_name, infix_ids, max_extra_prefix, max_extra_suffix); + if(!infix_ids.empty()) { + int sort_order[3]; // 1 or -1 based on DESC or ASC respectively + std::array*, 3> field_values; + std::vector geopoint_indices; + populate_sort_mapping(sort_order, geopoint_indices, sort_fields_std, field_values); + uint32_t token_bits = 255; + + for(auto seq_id: infix_ids) { + score_results(sort_fields_std, (uint16_t) searched_queries.size(), field_id, false, 2, + actual_topster, {}, groups_processed, seq_id, sort_order, field_values, + geopoint_indices, group_limit, group_by_fields, token_bits, + false, false, {}); + } + + std::sort(infix_ids.begin(), infix_ids.end()); + infix_ids.erase(std::unique( infix_ids.begin(), infix_ids.end() ), infix_ids.end()); + + uint32_t* new_all_result_ids = nullptr; + all_result_ids_len = ArrayUtils::or_scalar(all_result_ids, all_result_ids_len, &infix_ids[0], + infix_ids.size(), &new_all_result_ids); + delete[] all_result_ids; + all_result_ids = new_all_result_ids; + } + } } else if(actual_filter_ids_length != 0) { // indicates exact match query curate_filtered_ids(filters, curated_ids, exclude_token_ids, @@ -3708,6 +3821,15 @@ void Index::refresh_schemas(const std::vector& new_fields) { search_index.emplace(new_field.faceted_name(), ft); } } + + if(new_field.infix) { + array_mapped_infix_t infix_sets(ARRAY_INFIX_DIM); + for(auto& infix_set: infix_sets) { + infix_set = new tsl::htrie_set(); + } + + infix_index.emplace(new_field.name, infix_sets); + } } } diff --git a/src/posting_list.cpp b/src/posting_list.cpp index 7ec3fe08..a1fbdc98 100644 --- a/src/posting_list.cpp +++ b/src/posting_list.cpp @@ -581,6 +581,17 @@ void posting_list_t::merge(const std::vector& posting_lists, st sum_sizes += posting_list->num_ids(); } + if(its.size() == 1) { + result_ids.reserve(posting_lists[0]->ids_length); + auto it = posting_lists[0]->new_iterator(); + while(it.valid()) { + result_ids.push_back(it.id()); + it.next(); + } + + return ; + } + result_ids.reserve(sum_sizes); size_t num_lists = its.size(); diff --git a/test/collection_infix_search_test.cpp b/test/collection_infix_search_test.cpp new file mode 100644 index 00000000..14e34110 --- /dev/null +++ b/test/collection_infix_search_test.cpp @@ -0,0 +1,222 @@ +#include +#include +#include +#include +#include +#include +#include "collection.h" + +class CollectionInfixSearchTest : public ::testing::Test { +protected: + Store *store; + CollectionManager & collectionManager = CollectionManager::get_instance(); + std::atomic quit = false; + + std::vector query_fields; + std::vector sort_fields; + + void setupCollection() { + std::string state_dir_path = "/tmp/typesense_test/collection_infix"; + LOG(INFO) << "Truncating and creating: " << state_dir_path; + system(("rm -rf "+state_dir_path+" && mkdir -p "+state_dir_path).c_str()); + + store = new Store(state_dir_path); + collectionManager.init(store, 1.0, "auth_key", quit); + collectionManager.load(8, 1000); + } + + virtual void SetUp() { + 
setupCollection(); + } + + virtual void TearDown() { + collectionManager.dispose(); + delete store; + } +}; + +TEST_F(CollectionInfixSearchTest, InfixBasics) { + std::vector fields = {field("title", field_types::STRING, false, false, true, "", -1, 1), + field("points", field_types::INT32, false),}; + + Collection* coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); + + nlohmann::json doc; + doc["id"] = "0"; + doc["title"] = "GH100037IN8900X"; + doc["points"] = 100; + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + + auto results = coll1->search("100037", + {"title"}, "", {}, {}, {0}, 3, 1, FREQUENCY, {true}, 5, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0, + "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, true, + 4, {always}).get(); + + ASSERT_EQ(1, results["found"].get()); + ASSERT_EQ(1, results["hits"].size()); + ASSERT_STREQ("0", results["hits"][0]["document"]["id"].get().c_str()); + + // verify off behavior + + results = coll1->search("100037", + {"title"}, "", {}, {}, {0}, 3, 1, FREQUENCY, {true}, 5, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0, + "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, true, + 4, {off}).get(); + + ASSERT_EQ(0, results["found"].get()); + ASSERT_EQ(0, results["hits"].size()); + + // when fallback is used, only the prefix result is returned + + doc["id"] = "1"; + doc["title"] = "100037SG7120X"; + doc["points"] = 100; + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + + results = coll1->search("100037", + {"title"}, "", {}, {}, {0}, 3, 1, FREQUENCY, {true}, 5, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0, + "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, true, + 4, {fallback}).get(); + + ASSERT_EQ(1, results["found"].get()); + ASSERT_EQ(1, results["hits"].size()); + ASSERT_STREQ("1", results["hits"][0]["document"]["id"].get().c_str()); + + // always behavior: both prefix and infix matches are returned but ranked below prefix match + + results = coll1->search("100037", + {"title"}, "", {}, {}, {0}, 3, 1, FREQUENCY, {true}, 5, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0, + "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, true, + 4, {always}).get(); + + ASSERT_EQ(2, results["found"].get()); + ASSERT_EQ(2, results["hits"].size()); + ASSERT_STREQ("1", results["hits"][0]["document"]["id"].get().c_str()); + ASSERT_STREQ("0", results["hits"][1]["document"]["id"].get().c_str()); + + ASSERT_TRUE(results["hits"][0]["text_match"].get() > results["hits"][1]["text_match"].get()); + + collectionManager.drop_collection("coll1"); +} + +TEST_F(CollectionInfixSearchTest, RespectPrefixAndSuffixLimits) { + std::vector fields = {field("title", field_types::STRING, false, false, true, "", -1, 1), + field("points", field_types::INT32, false),}; + + Collection* coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); + + nlohmann::json doc; + doc["id"] = "0"; + doc["title"] = "GH100037IN8900X"; + doc["points"] = 100; + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + + doc["id"] = "1"; + doc["title"] = "X100037SG89007120X"; + doc["points"] = 100; + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + + // check extra prefixes + + auto results = coll1->search("100037", + {"title"}, "", {}, {}, {0}, 3, 1, FREQUENCY, {true}, 5, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 4, 
"title", 20, {}, {}, {}, 0, + "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, true, + 4, {always}, 1).get(); + + ASSERT_EQ(1, results["found"].get()); + ASSERT_EQ(1, results["hits"].size()); + ASSERT_STREQ("1", results["hits"][0]["document"]["id"].get().c_str()); + + results = coll1->search("100037", + {"title"}, "", {}, {}, {0}, 3, 1, FREQUENCY, {true}, 5, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0, + "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, true, + 4, {always}, 2).get(); + + ASSERT_EQ(2, results["found"].get()); + ASSERT_EQ(2, results["hits"].size()); + ASSERT_STREQ("1", results["hits"][0]["document"]["id"].get().c_str()); + ASSERT_STREQ("0", results["hits"][1]["document"]["id"].get().c_str()); + + // check extra suffixes + results = coll1->search("8900", + {"title"}, "", {}, {}, {0}, 3, 1, FREQUENCY, {true}, 5, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0, + "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, true, + 4, {always}, INT16_MAX, 2).get(); + + ASSERT_EQ(1, results["found"].get()); + ASSERT_EQ(1, results["hits"].size()); + ASSERT_STREQ("0", results["hits"][0]["document"]["id"].get().c_str()); + + results = coll1->search("8900", + {"title"}, "", {}, {}, {0}, 3, 1, FREQUENCY, {true}, 5, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0, + "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, true, + 4, {always}, INT16_MAX, 5).get(); + + ASSERT_EQ(2, results["found"].get()); + ASSERT_EQ(2, results["hits"].size()); + ASSERT_STREQ("1", results["hits"][0]["document"]["id"].get().c_str()); + ASSERT_STREQ("0", results["hits"][1]["document"]["id"].get().c_str()); + + collectionManager.drop_collection("coll1"); +} + +TEST_F(CollectionInfixSearchTest, InfixSpecificField) { + std::vector fields = {field("title", field_types::STRING, false, false, true, "", -1, 1), + field("description", field_types::STRING, false, false, true, "", -1, 1), + field("points", field_types::INT32, false),}; + + Collection* coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); + + nlohmann::json doc; + doc["id"] = "0"; + doc["title"] = "GH100037IN8900X"; + doc["description"] = "foobar"; + doc["points"] = 100; + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + + doc["id"] = "1"; + doc["title"] = "foobar"; + doc["description"] = "GH100037IN8900X"; + doc["points"] = 100; + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + + auto results = coll1->search("100037", + {"title", "description"}, "", {}, {}, {0}, 3, 1, FREQUENCY, {true}, 5, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0, + "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, true, + 4, {always, off}).get(); + + ASSERT_EQ(1, results["found"].get()); + ASSERT_EQ(1, results["hits"].size()); + ASSERT_STREQ("0", results["hits"][0]["document"]["id"].get().c_str()); + + results = coll1->search("100037", + {"title", "description"}, "", {}, {}, {0}, 3, 1, FREQUENCY, {true}, 5, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0, + "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, true, + 4, {off, always}).get(); + + ASSERT_EQ(1, results["found"].get()); + ASSERT_EQ(1, results["hits"].size()); + ASSERT_STREQ("1", results["hits"][0]["document"]["id"].get().c_str()); + + collectionManager.drop_collection("coll1"); +}