Infix basics.

This commit is contained in:
Kishore Nallan 2022-01-27 14:54:50 +05:30
parent 7ce33cc94e
commit ba101d0b40
16 changed files with 7534 additions and 12 deletions

View File

@ -89,6 +89,7 @@ FILE(GLOB SRC_FILES src/*.cpp ${DEP_ROOT_DIR}/${KAKASI_NAME}/data/*.cpp)
FILE(GLOB TEST_FILES test/*.cpp)
include_directories(include)
include_directories(include/tsl)
include_directories(/usr/local/include)
include_directories(${OPENSSL_INCLUDE_DIR})
include_directories(${CURL_INCLUDE_DIR})

View File

@ -395,7 +395,10 @@ public:
size_t min_len_1typo = 4,
size_t min_len_2typo = 7,
bool split_join_tokens = true,
size_t max_candidates = 4) const;
size_t max_candidates = 4,
const std::vector<infix_t>& infixes = {off},
const size_t max_extra_prefix = INT16_MAX,
const size_t max_extra_suffix = INT16_MAX) const;
Option<bool> get_filter_ids(const std::string & simple_filter_query,
std::vector<std::pair<size_t, uint32_t*>>& index_ids);

View File

@ -38,6 +38,7 @@ namespace fields {
static const std::string optional = "optional";
static const std::string index = "index";
static const std::string sort = "sort";
static const std::string infix = "infix";
static const std::string locale = "locale";
}
@ -49,9 +50,10 @@ struct field {
bool index;
std::string locale;
bool sort;
bool infix;
field(const std::string &name, const std::string &type, const bool facet, const bool optional = false,
bool index = true, std::string locale = "", int sort = -1) :
bool index = true, std::string locale = "", int sort = -1, int infix = -1) :
name(name), type(type), facet(facet), optional(optional), index(index), locale(locale) {
if(sort != -1) {
@ -59,6 +61,8 @@ struct field {
} else {
this->sort = is_num_sort_field();
}
this->infix = (infix != -1) ? bool(infix) : false;
}
bool is_auto() const {
@ -361,6 +365,11 @@ struct field {
field_json[fields::name].get<std::string>() + std::string("` should be a boolean."));
}
if(field_json.count(fields::infix) != 0 && !field_json.at(fields::infix).is_boolean()) {
return Option<bool>(400, std::string("The `infix` property of the field `") +
field_json[fields::name].get<std::string>() + std::string("` should be a boolean."));
}
if(field_json.count(fields::locale) != 0){
if(!field_json.at(fields::locale).is_string()) {
return Option<bool>(400, std::string("The `locale` property of the field `") +
@ -395,6 +404,10 @@ struct field {
field_json[fields::sort] = false;
}
if(field_json.count(fields::infix) == 0) {
field_json[fields::infix] = false;
}
if(field_json[fields::optional] == false) {
return Option<bool>(400, "Field `.*` must be an optional field.");
}
@ -409,7 +422,7 @@ struct field {
field fallback_field(field_json["name"], field_json["type"], field_json["facet"],
field_json["optional"], field_json[fields::index], field_json[fields::locale],
field_json[fields::sort]);
field_json[fields::sort], field_json[fields::infix]);
if(fallback_field.has_valid_type()) {
fallback_field_type = fallback_field.type;
@ -444,6 +457,10 @@ struct field {
}
}
if(field_json.count(fields::infix) == 0) {
field_json[fields::infix] = false;
}
if(field_json.count(fields::optional) == 0) {
// dynamic fields are always optional
bool is_dynamic = field::is_dynamic(field_json[fields::name], field_json[fields::type]);
@ -453,7 +470,7 @@ struct field {
fields.emplace_back(
field(field_json[fields::name], field_json[fields::type], field_json[fields::facet],
field_json[fields::optional], field_json[fields::index], field_json[fields::locale],
field_json[fields::sort])
field_json[fields::sort], field_json[fields::infix])
);
}

View File

@ -22,11 +22,15 @@
#include "posting_list.h"
#include "threadpool.h"
#include "adi_tree.h"
#include "tsl/htrie_set.h"
static constexpr size_t ARRAY_FACET_DIM = 4;
using facet_map_t = spp::sparse_hash_map<uint32_t, facet_hash_values_t>;
using array_mapped_facet_t = std::array<facet_map_t*, ARRAY_FACET_DIM>;
static constexpr size_t ARRAY_INFIX_DIM = 4;
using array_mapped_infix_t = std::vector<tsl::htrie_set<char>*>;
struct token_t {
size_t position;
std::string value;
@ -258,6 +262,12 @@ struct override_t {
}
};
enum infix_t {
always,
fallback,
off
};
struct search_args {
std::vector<query_tokens_t> field_query_tokens;
std::vector<search_field_t> search_fields;
@ -287,6 +297,9 @@ struct search_args {
size_t min_len_1typo;
size_t min_len_2typo;
size_t max_candidates;
std::vector<infix_t> infixes;
const size_t max_extra_prefix;
const size_t max_extra_suffix;
spp::sparse_hash_set<uint64_t> groups_processed;
std::vector<std::vector<art_leaf*>> searched_queries;
@ -312,7 +325,10 @@ struct search_args {
size_t search_cutoff_ms,
size_t min_len_1typo,
size_t min_len_2typo,
size_t max_candidates):
size_t max_candidates,
const std::vector<infix_t>& infixes,
const size_t max_extra_prefix,
const size_t max_extra_suffix):
field_query_tokens(field_query_tokens),
search_fields(search_fields), filters(filters), facets(facets),
included_ids(included_ids), excluded_ids(excluded_ids), sort_fields_std(sort_fields_std),
@ -323,7 +339,8 @@ struct search_args {
prioritize_exact_match(prioritize_exact_match), all_result_ids_len(0),
exhaustive_search(exhaustive_search), concurrency(concurrency),
filter_overrides(dynamic_overrides), search_cutoff_ms(search_cutoff_ms),
min_len_1typo(min_len_1typo), min_len_2typo(min_len_2typo), max_candidates(max_candidates) {
min_len_1typo(min_len_1typo), min_len_2typo(min_len_2typo), max_candidates(max_candidates),
infixes(infixes), max_extra_prefix(max_extra_prefix), max_extra_suffix(max_extra_suffix) {
const size_t topster_size = std::max((size_t)1, max_hits); // needs to be atleast 1 since scoring is mandatory
topster = new Topster(topster_size, group_limit);
@ -431,6 +448,9 @@ private:
// str_sort_field => adi_tree_t
spp::sparse_hash_map<std::string, adi_tree_t*> str_sort_index;
// infix field => value
spp::sparse_hash_map<std::string, array_mapped_infix_t> infix_index;
// geo_array_field => (seq_id => values) used for exact filtering of geo array records
spp::sparse_hash_map<std::string, spp::sparse_hash_map<uint32_t, int64_t*>*> geo_array_index;
@ -697,7 +717,10 @@ public:
size_t search_cutoff_ms,
size_t min_len_1typo,
size_t min_len_2typo,
size_t max_candidates) const;
size_t max_candidates,
const std::vector<infix_t>& infixes,
const size_t max_extra_prefix,
const size_t max_extra_suffix) const;
Option<uint32_t> remove(const uint32_t seq_id, const nlohmann::json & document, const bool is_update);
@ -758,6 +781,9 @@ public:
uint32_t*& all_result_ids, size_t& all_result_ids_len, const uint32_t* filter_ids,
uint32_t filter_ids_length, const size_t concurrency) const;
void search_infix(const std::string& query, const std::string& field_name, std::vector<uint32_t>& ids,
size_t max_extra_prefix, size_t max_extra_suffix) const;
void curate_filtered_ids(const std::vector<filter>& filters, const std::set<uint32_t>& curated_ids,
const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, uint32_t*& filter_ids,
uint32_t& filter_ids_length, const std::vector<uint32_t>& curated_ids_sorted) const;

View File

@ -0,0 +1,296 @@
/**
* MIT License
*
* Copyright (c) 2017 Thibaut Goetghebuer-Planchon <tessil@gmx.com>
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#ifndef TSL_ARRAY_GROWTH_POLICY_H
#define TSL_ARRAY_GROWTH_POLICY_H
#include <algorithm>
#include <array>
#include <climits>
#include <cmath>
#include <cstddef>
#include <iterator>
#include <limits>
#include <ratio>
#include <stdexcept>
namespace tsl {
namespace ah {
/**
* Grow the hash table by a factor of GrowthFactor keeping the bucket count to a
* power of two. It allows the table to use a mask operation instead of a modulo
* operation to map a hash to a bucket.
*
* GrowthFactor must be a power of two >= 2.
*/
template <std::size_t GrowthFactor>
class power_of_two_growth_policy {
public:
/**
* Called on the hash table creation and on rehash. The number of buckets for
* the table is passed in parameter. This number is a minimum, the policy may
* update this value with a higher value if needed (but not lower).
*
* If 0 is given, min_bucket_count_in_out must still be 0 after the policy
* creation and bucket_for_hash must always return 0 in this case.
*/
explicit power_of_two_growth_policy(std::size_t& min_bucket_count_in_out) {
if (min_bucket_count_in_out > max_bucket_count()) {
throw std::length_error("The hash table exceeds its maximum size.");
}
if (min_bucket_count_in_out > 0) {
min_bucket_count_in_out =
round_up_to_power_of_two(min_bucket_count_in_out);
m_mask = min_bucket_count_in_out - 1;
} else {
m_mask = 0;
}
}
/**
* Return the bucket [0, bucket_count()) to which the hash belongs.
* If bucket_count() is 0, it must always return 0.
*/
std::size_t bucket_for_hash(std::size_t hash) const noexcept {
return hash & m_mask;
}
/**
* Return the number of buckets that should be used on next growth.
*/
std::size_t next_bucket_count() const {
if ((m_mask + 1) > max_bucket_count() / GrowthFactor) {
throw std::length_error("The hash table exceeds its maximum size.");
}
return (m_mask + 1) * GrowthFactor;
}
/**
* Return the maximum number of buckets supported by the policy.
*/
std::size_t max_bucket_count() const {
// Largest power of two.
return (std::numeric_limits<std::size_t>::max() / 2) + 1;
}
/**
* Reset the growth policy as if it was created with a bucket count of 0.
* After a clear, the policy must always return 0 when bucket_for_hash is
* called.
*/
void clear() noexcept { m_mask = 0; }
private:
static std::size_t round_up_to_power_of_two(std::size_t value) {
if (is_power_of_two(value)) {
return value;
}
if (value == 0) {
return 1;
}
--value;
for (std::size_t i = 1; i < sizeof(std::size_t) * CHAR_BIT; i *= 2) {
value |= value >> i;
}
return value + 1;
}
static constexpr bool is_power_of_two(std::size_t value) {
return value != 0 && (value & (value - 1)) == 0;
}
protected:
static_assert(is_power_of_two(GrowthFactor) && GrowthFactor >= 2,
"GrowthFactor must be a power of two >= 2.");
std::size_t m_mask;
};
/**
* Grow the hash table by GrowthFactor::num / GrowthFactor::den and use a modulo
* to map a hash to a bucket. Slower but it can be useful if you want a slower
* growth.
*/
template <class GrowthFactor = std::ratio<3, 2>>
class mod_growth_policy {
public:
explicit mod_growth_policy(std::size_t& min_bucket_count_in_out) {
if (min_bucket_count_in_out > max_bucket_count()) {
throw std::length_error("The hash table exceeds its maximum size.");
}
if (min_bucket_count_in_out > 0) {
m_mod = min_bucket_count_in_out;
} else {
m_mod = 1;
}
}
std::size_t bucket_for_hash(std::size_t hash) const noexcept {
return hash % m_mod;
}
std::size_t next_bucket_count() const {
if (m_mod == max_bucket_count()) {
throw std::length_error("The hash table exceeds its maximum size.");
}
const double next_bucket_count =
std::ceil(double(m_mod) * REHASH_SIZE_MULTIPLICATION_FACTOR);
if (!std::isnormal(next_bucket_count)) {
throw std::length_error("The hash table exceeds its maximum size.");
}
if (next_bucket_count > double(max_bucket_count())) {
return max_bucket_count();
} else {
return std::size_t(next_bucket_count);
}
}
std::size_t max_bucket_count() const { return MAX_BUCKET_COUNT; }
void clear() noexcept { m_mod = 1; }
private:
static constexpr double REHASH_SIZE_MULTIPLICATION_FACTOR =
1.0 * GrowthFactor::num / GrowthFactor::den;
static const std::size_t MAX_BUCKET_COUNT =
std::size_t(double(std::numeric_limits<std::size_t>::max() /
REHASH_SIZE_MULTIPLICATION_FACTOR));
static_assert(REHASH_SIZE_MULTIPLICATION_FACTOR >= 1.1,
"Growth factor should be >= 1.1.");
std::size_t m_mod;
};
namespace detail {
static constexpr const std::array<std::size_t, 40> PRIMES = {
{1ul, 5ul, 17ul, 29ul, 37ul,
53ul, 67ul, 79ul, 97ul, 131ul,
193ul, 257ul, 389ul, 521ul, 769ul,
1031ul, 1543ul, 2053ul, 3079ul, 6151ul,
12289ul, 24593ul, 49157ul, 98317ul, 196613ul,
393241ul, 786433ul, 1572869ul, 3145739ul, 6291469ul,
12582917ul, 25165843ul, 50331653ul, 100663319ul, 201326611ul,
402653189ul, 805306457ul, 1610612741ul, 3221225473ul, 4294967291ul}};
template <unsigned int IPrime>
static constexpr std::size_t mod(std::size_t hash) {
return hash % PRIMES[IPrime];
}
// MOD_PRIME[iprime](hash) returns hash % PRIMES[iprime]. This table allows for
// faster modulo as the compiler can optimize the modulo code better with a
// constant known at the compilation.
static constexpr const std::array<std::size_t (*)(std::size_t), 40> MOD_PRIME =
{{&mod<0>, &mod<1>, &mod<2>, &mod<3>, &mod<4>, &mod<5>, &mod<6>,
&mod<7>, &mod<8>, &mod<9>, &mod<10>, &mod<11>, &mod<12>, &mod<13>,
&mod<14>, &mod<15>, &mod<16>, &mod<17>, &mod<18>, &mod<19>, &mod<20>,
&mod<21>, &mod<22>, &mod<23>, &mod<24>, &mod<25>, &mod<26>, &mod<27>,
&mod<28>, &mod<29>, &mod<30>, &mod<31>, &mod<32>, &mod<33>, &mod<34>,
&mod<35>, &mod<36>, &mod<37>, &mod<38>, &mod<39>}};
} // namespace detail
/**
* Grow the hash table by using prime numbers as bucket count. Slower than
* tsl::ah::power_of_two_growth_policy in general but will probably distribute
* the values around better in the buckets with a poor hash function.
*
* To allow the compiler to optimize the modulo operation, a lookup table is
* used with constant primes numbers.
*
* With a switch the code would look like:
* \code
* switch(iprime) { // iprime is the current prime of the hash table
* case 0: hash % 5ul;
* break;
* case 1: hash % 17ul;
* break;
* case 2: hash % 29ul;
* break;
* ...
* }
* \endcode
*
* Due to the constant variable in the modulo the compiler is able to optimize
* the operation by a series of multiplications, substractions and shifts.
*
* The 'hash % 5' could become something like 'hash - (hash * 0xCCCCCCCD) >> 34)
* * 5' in a 64 bits environment.
*/
class prime_growth_policy {
public:
explicit prime_growth_policy(std::size_t& min_bucket_count_in_out) {
auto it_prime = std::lower_bound(
detail::PRIMES.begin(), detail::PRIMES.end(), min_bucket_count_in_out);
if (it_prime == detail::PRIMES.end()) {
throw std::length_error("The hash table exceeds its maximum size.");
}
m_iprime = static_cast<unsigned int>(
std::distance(detail::PRIMES.begin(), it_prime));
if (min_bucket_count_in_out > 0) {
min_bucket_count_in_out = *it_prime;
} else {
min_bucket_count_in_out = 0;
}
}
std::size_t bucket_for_hash(std::size_t hash) const noexcept {
return detail::MOD_PRIME[m_iprime](hash);
}
std::size_t next_bucket_count() const {
if (m_iprime + 1 >= detail::PRIMES.size()) {
throw std::length_error("The hash table exceeds its maximum size.");
}
return detail::PRIMES[m_iprime + 1];
}
std::size_t max_bucket_count() const { return detail::PRIMES.back(); }
void clear() noexcept { m_iprime = 0; }
private:
unsigned int m_iprime;
static_assert(std::numeric_limits<decltype(m_iprime)>::max() >=
detail::PRIMES.size(),
"The type of m_iprime is not big enough.");
};
} // namespace ah
} // namespace tsl
#endif

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,929 @@
/**
* MIT License
*
* Copyright (c) 2017 Thibaut Goetghebuer-Planchon <tessil@gmx.com>
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#ifndef TSL_ARRAY_MAP_H
#define TSL_ARRAY_MAP_H
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <iterator>
#include <string>
#include <type_traits>
#include <utility>
#include "array_hash.h"
namespace tsl {
/**
* Implementation of a cache-conscious string hash map.
*
* The map stores the strings as `const CharT*`. If `StoreNullTerminator` is
* true, the strings are stored with the a null-terminator (the `key()` method
* of the iterators will return a pointer to this null-terminated string).
* Otherwise the null character is not stored (which allow an economy of 1 byte
* per string).
*
* The value `T` must be either nothrow move-constructible, copy-constructible
* or both.
*
* The size of a key string is limited to `std::numeric_limits<KeySizeT>::max()
* - 1`. That is 65 535 characters by default, but can be raised with the
* `KeySizeT` template parameter. See `max_key_size()` for an easy access to
* this limit.
*
* The number of elements in the map is limited to
* `std::numeric_limits<IndexSizeT>::max()`. That is 4 294 967 296 elements, but
* can be raised with the `IndexSizeT` template parameter. See `max_size()` for
* an easy access to this limit.
*
* Iterators invalidation:
* - clear, operator=: always invalidate the iterators.
* - insert, emplace, operator[]: always invalidate the iterators.
* - erase: always invalidate the iterators.
* - shrink_to_fit: always invalidate the iterators.
*/
template <class CharT, class T, class Hash = tsl::ah::str_hash<CharT>,
class KeyEqual = tsl::ah::str_equal<CharT>,
bool StoreNullTerminator = true, class KeySizeT = std::uint16_t,
class IndexSizeT = std::uint32_t,
class GrowthPolicy = tsl::ah::power_of_two_growth_policy<2>>
class array_map {
private:
template <typename U>
using is_iterator = tsl::detail_array_hash::is_iterator<U>;
using ht = tsl::detail_array_hash::array_hash<CharT, T, Hash, KeyEqual,
StoreNullTerminator, KeySizeT,
IndexSizeT, GrowthPolicy>;
public:
using char_type = typename ht::char_type;
using mapped_type = T;
using key_size_type = typename ht::key_size_type;
using index_size_type = typename ht::index_size_type;
using size_type = typename ht::size_type;
using hasher = typename ht::hasher;
using key_equal = typename ht::key_equal;
using iterator = typename ht::iterator;
using const_iterator = typename ht::const_iterator;
public:
array_map() : array_map(ht::DEFAULT_INIT_BUCKET_COUNT) {}
explicit array_map(size_type bucket_count, const Hash& hash = Hash())
: m_ht(bucket_count, hash, ht::DEFAULT_MAX_LOAD_FACTOR) {}
template <class InputIt, typename std::enable_if<
is_iterator<InputIt>::value>::type* = nullptr>
array_map(InputIt first, InputIt last,
size_type bucket_count = ht::DEFAULT_INIT_BUCKET_COUNT,
const Hash& hash = Hash())
: array_map(bucket_count, hash) {
insert(first, last);
}
#ifdef TSL_AH_HAS_STRING_VIEW
array_map(
std::initializer_list<std::pair<std::basic_string_view<CharT>, T>> init,
size_type bucket_count = ht::DEFAULT_INIT_BUCKET_COUNT,
const Hash& hash = Hash())
: array_map(bucket_count, hash) {
insert(init);
}
#else
array_map(std::initializer_list<std::pair<const CharT*, T>> init,
size_type bucket_count = ht::DEFAULT_INIT_BUCKET_COUNT,
const Hash& hash = Hash())
: array_map(bucket_count, hash) {
insert(init);
}
#endif
#ifdef TSL_AH_HAS_STRING_VIEW
array_map& operator=(
std::initializer_list<std::pair<std::basic_string_view<CharT>, T>>
ilist) {
clear();
reserve(ilist.size());
insert(ilist);
return *this;
}
#else
array_map& operator=(
std::initializer_list<std::pair<const CharT*, T>> ilist) {
clear();
reserve(ilist.size());
insert(ilist);
return *this;
}
#endif
/*
* Iterators
*/
iterator begin() noexcept { return m_ht.begin(); }
const_iterator begin() const noexcept { return m_ht.begin(); }
const_iterator cbegin() const noexcept { return m_ht.cbegin(); }
iterator end() noexcept { return m_ht.end(); }
const_iterator end() const noexcept { return m_ht.end(); }
const_iterator cend() const noexcept { return m_ht.cend(); }
/*
* Capacity
*/
bool empty() const noexcept { return m_ht.empty(); }
size_type size() const noexcept { return m_ht.size(); }
size_type max_size() const noexcept { return m_ht.max_size(); }
size_type max_key_size() const noexcept { return m_ht.max_key_size(); }
void shrink_to_fit() { m_ht.shrink_to_fit(); }
/*
* Modifiers
*/
void clear() noexcept { m_ht.clear(); }
#ifdef TSL_AH_HAS_STRING_VIEW
std::pair<iterator, bool> insert(const std::basic_string_view<CharT>& key,
const T& value) {
return m_ht.emplace(key.data(), key.size(), value);
}
#else
std::pair<iterator, bool> insert(const CharT* key, const T& value) {
return m_ht.emplace(key, std::char_traits<CharT>::length(key), value);
}
std::pair<iterator, bool> insert(const std::basic_string<CharT>& key,
const T& value) {
return m_ht.emplace(key.data(), key.size(), value);
}
#endif
std::pair<iterator, bool> insert_ks(const CharT* key, size_type key_size,
const T& value) {
return m_ht.emplace(key, key_size, value);
}
#ifdef TSL_AH_HAS_STRING_VIEW
std::pair<iterator, bool> insert(const std::basic_string_view<CharT>& key,
T&& value) {
return m_ht.emplace(key.data(), key.size(), std::move(value));
}
#else
std::pair<iterator, bool> insert(const CharT* key, T&& value) {
return m_ht.emplace(key, std::char_traits<CharT>::length(key),
std::move(value));
}
std::pair<iterator, bool> insert(const std::basic_string<CharT>& key,
T&& value) {
return m_ht.emplace(key.data(), key.size(), std::move(value));
}
#endif
std::pair<iterator, bool> insert_ks(const CharT* key, size_type key_size,
T&& value) {
return m_ht.emplace(key, key_size, std::move(value));
}
template <class InputIt, typename std::enable_if<
is_iterator<InputIt>::value>::type* = nullptr>
void insert(InputIt first, InputIt last) {
if (std::is_base_of<
std::forward_iterator_tag,
typename std::iterator_traits<InputIt>::iterator_category>::value) {
const auto nb_elements_insert = std::distance(first, last);
const std::size_t nb_free_buckets =
std::size_t(float(bucket_count()) * max_load_factor()) - size();
if (nb_elements_insert > 0 &&
nb_free_buckets < std::size_t(nb_elements_insert)) {
reserve(size() + std::size_t(nb_elements_insert));
}
}
for (auto it = first; it != last; ++it) {
insert_pair(*it);
}
}
#ifdef TSL_AH_HAS_STRING_VIEW
void insert(std::initializer_list<std::pair<std::basic_string_view<CharT>, T>>
ilist) {
insert(ilist.begin(), ilist.end());
}
#else
void insert(std::initializer_list<std::pair<const CharT*, T>> ilist) {
insert(ilist.begin(), ilist.end());
}
#endif
#ifdef TSL_AH_HAS_STRING_VIEW
template <class M>
std::pair<iterator, bool> insert_or_assign(
const std::basic_string_view<CharT>& key, M&& obj) {
return m_ht.insert_or_assign(key.data(), key.size(), std::forward<M>(obj));
}
#else
template <class M>
std::pair<iterator, bool> insert_or_assign(const CharT* key, M&& obj) {
return m_ht.insert_or_assign(key, std::char_traits<CharT>::length(key),
std::forward<M>(obj));
}
template <class M>
std::pair<iterator, bool> insert_or_assign(
const std::basic_string<CharT>& key, M&& obj) {
return m_ht.insert_or_assign(key.data(), key.size(), std::forward<M>(obj));
}
#endif
template <class M>
std::pair<iterator, bool> insert_or_assign_ks(const CharT* key,
size_type key_size, M&& obj) {
return m_ht.insert_or_assign(key, key_size, std::forward<M>(obj));
}
#ifdef TSL_AH_HAS_STRING_VIEW
template <class... Args>
std::pair<iterator, bool> emplace(const std::basic_string_view<CharT>& key,
Args&&... args) {
return m_ht.emplace(key.data(), key.size(), std::forward<Args>(args)...);
}
#else
template <class... Args>
std::pair<iterator, bool> emplace(const CharT* key, Args&&... args) {
return m_ht.emplace(key, std::char_traits<CharT>::length(key),
std::forward<Args>(args)...);
}
template <class... Args>
std::pair<iterator, bool> emplace(const std::basic_string<CharT>& key,
Args&&... args) {
return m_ht.emplace(key.data(), key.size(), std::forward<Args>(args)...);
}
#endif
template <class... Args>
std::pair<iterator, bool> emplace_ks(const CharT* key, size_type key_size,
Args&&... args) {
return m_ht.emplace(key, key_size, std::forward<Args>(args)...);
}
/**
* Erase has an amortized O(1) runtime complexity, but even if it removes the
* key immediately, it doesn't do the same for the associated value T.
*
* T will only be removed when the ratio between the size of the map and
* the size of the map + the number of deleted values still stored is low
* enough.
*
* To force the deletion you can call shrink_to_fit.
*/
iterator erase(const_iterator pos) { return m_ht.erase(pos); }
/**
* @copydoc erase(const_iterator pos)
*/
iterator erase(const_iterator first, const_iterator last) {
return m_ht.erase(first, last);
}
#ifdef TSL_AH_HAS_STRING_VIEW
/**
* @copydoc erase(const_iterator pos)
*/
size_type erase(const std::basic_string_view<CharT>& key) {
return m_ht.erase(key.data(), key.size());
}
#else
/**
* @copydoc erase(const_iterator pos)
*/
size_type erase(const CharT* key) {
return m_ht.erase(key, std::char_traits<CharT>::length(key));
}
/**
* @copydoc erase(const_iterator pos)
*/
size_type erase(const std::basic_string<CharT>& key) {
return m_ht.erase(key.data(), key.size());
}
#endif
/**
* @copydoc erase(const_iterator pos)
*/
size_type erase_ks(const CharT* key, size_type key_size) {
return m_ht.erase(key, key_size);
}
#ifdef TSL_AH_HAS_STRING_VIEW
/**
* @copydoc erase_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash)
*/
size_type erase(const std::basic_string_view<CharT>& key,
std::size_t precalculated_hash) {
return m_ht.erase(key.data(), key.size(), precalculated_hash);
}
#else
/**
* @copydoc erase_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash)
*/
size_type erase(const CharT* key, std::size_t precalculated_hash) {
return m_ht.erase(key, std::char_traits<CharT>::length(key),
precalculated_hash);
}
/**
* @copydoc erase_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash)
*/
size_type erase(const std::basic_string<CharT>& key,
std::size_t precalculated_hash) {
return m_ht.erase(key.data(), key.size(), precalculated_hash);
}
#endif
/**
* @copydoc erase(const_iterator pos)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The
* hash value should be the same as hash_function()(key). Useful to speed-up
* the lookup to the value if you already have the hash.
*/
size_type erase_ks(const CharT* key, size_type key_size,
std::size_t precalculated_hash) {
return m_ht.erase(key, key_size, precalculated_hash);
}
void swap(array_map& other) { other.m_ht.swap(m_ht); }
/*
* Lookup
*/
#ifdef TSL_AH_HAS_STRING_VIEW
T& at(const std::basic_string_view<CharT>& key) {
return m_ht.at(key.data(), key.size());
}
const T& at(const std::basic_string_view<CharT>& key) const {
return m_ht.at(key.data(), key.size());
}
#else
T& at(const CharT* key) {
return m_ht.at(key, std::char_traits<CharT>::length(key));
}
const T& at(const CharT* key) const {
return m_ht.at(key, std::char_traits<CharT>::length(key));
}
T& at(const std::basic_string<CharT>& key) {
return m_ht.at(key.data(), key.size());
}
const T& at(const std::basic_string<CharT>& key) const {
return m_ht.at(key.data(), key.size());
}
#endif
T& at_ks(const CharT* key, size_type key_size) {
return m_ht.at(key, key_size);
}
const T& at_ks(const CharT* key, size_type key_size) const {
return m_ht.at(key, key_size);
}
#ifdef TSL_AH_HAS_STRING_VIEW
/**
* @copydoc at_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash)
*/
T& at(const std::basic_string_view<CharT>& key,
std::size_t precalculated_hash) {
return m_ht.at(key.data(), key.size(), precalculated_hash);
}
/**
* @copydoc at_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash)
*/
const T& at(const std::basic_string_view<CharT>& key,
std::size_t precalculated_hash) const {
return m_ht.at(key.data(), key.size(), precalculated_hash);
}
#else
/**
* @copydoc at_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash)
*/
T& at(const CharT* key, std::size_t precalculated_hash) {
return m_ht.at(key, std::char_traits<CharT>::length(key),
precalculated_hash);
}
/**
* @copydoc at_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash)
*/
const T& at(const CharT* key, std::size_t precalculated_hash) const {
return m_ht.at(key, std::char_traits<CharT>::length(key),
precalculated_hash);
}
/**
* @copydoc at_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash)
*/
T& at(const std::basic_string<CharT>& key, std::size_t precalculated_hash) {
return m_ht.at(key.data(), key.size(), precalculated_hash);
}
/**
* @copydoc at_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash)
*/
const T& at(const std::basic_string<CharT>& key,
std::size_t precalculated_hash) const {
return m_ht.at(key.data(), key.size(), precalculated_hash);
}
#endif
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The
* hash value should be the same as hash_function()(key). Useful to speed-up
* the lookup to the value if you already have the hash.
*/
T& at_ks(const CharT* key, size_type key_size,
std::size_t precalculated_hash) {
return m_ht.at(key, key_size, precalculated_hash);
}
/**
* @copydoc at_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash)
*/
const T& at_ks(const CharT* key, size_type key_size,
std::size_t precalculated_hash) const {
return m_ht.at(key, key_size, precalculated_hash);
}
#ifdef TSL_AH_HAS_STRING_VIEW
T& operator[](const std::basic_string_view<CharT>& key) {
return m_ht.access_operator(key.data(), key.size());
}
#else
T& operator[](const CharT* key) {
return m_ht.access_operator(key, std::char_traits<CharT>::length(key));
}
T& operator[](const std::basic_string<CharT>& key) {
return m_ht.access_operator(key.data(), key.size());
}
#endif
#ifdef TSL_AH_HAS_STRING_VIEW
size_type count(const std::basic_string_view<CharT>& key) const {
return m_ht.count(key.data(), key.size());
}
#else
size_type count(const CharT* key) const {
return m_ht.count(key, std::char_traits<CharT>::length(key));
}
size_type count(const std::basic_string<CharT>& key) const {
return m_ht.count(key.data(), key.size());
}
#endif
size_type count_ks(const CharT* key, size_type key_size) const {
return m_ht.count(key, key_size);
}
#ifdef TSL_AH_HAS_STRING_VIEW
/**
* @copydoc count_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash) const
*/
size_type count(const std::basic_string_view<CharT>& key,
std::size_t precalculated_hash) const {
return m_ht.count(key.data(), key.size(), precalculated_hash);
}
#else
/**
* @copydoc count_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash) const
*/
size_type count(const CharT* key, std::size_t precalculated_hash) const {
return m_ht.count(key, std::char_traits<CharT>::length(key),
precalculated_hash);
}
/**
* @copydoc count_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash) const
*/
size_type count(const std::basic_string<CharT>& key,
std::size_t precalculated_hash) const {
return m_ht.count(key.data(), key.size(), precalculated_hash);
}
#endif
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The
* hash value should be the same as hash_function()(key). Useful to speed-up
* the lookup to the value if you already have the hash.
*/
size_type count_ks(const CharT* key, size_type key_size,
std::size_t precalculated_hash) const {
return m_ht.count(key, key_size, precalculated_hash);
}
#ifdef TSL_AH_HAS_STRING_VIEW
iterator find(const std::basic_string_view<CharT>& key) {
return m_ht.find(key.data(), key.size());
}
const_iterator find(const std::basic_string_view<CharT>& key) const {
return m_ht.find(key.data(), key.size());
}
#else
iterator find(const CharT* key) {
return m_ht.find(key, std::char_traits<CharT>::length(key));
}
const_iterator find(const CharT* key) const {
return m_ht.find(key, std::char_traits<CharT>::length(key));
}
iterator find(const std::basic_string<CharT>& key) {
return m_ht.find(key.data(), key.size());
}
const_iterator find(const std::basic_string<CharT>& key) const {
return m_ht.find(key.data(), key.size());
}
#endif
iterator find_ks(const CharT* key, size_type key_size) {
return m_ht.find(key, key_size);
}
const_iterator find_ks(const CharT* key, size_type key_size) const {
return m_ht.find(key, key_size);
}
#ifdef TSL_AH_HAS_STRING_VIEW
/**
* @copydoc find_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash)
*/
iterator find(const std::basic_string_view<CharT>& key,
std::size_t precalculated_hash) {
return m_ht.find(key.data(), key.size(), precalculated_hash);
}
/**
* @copydoc find_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash)
*/
const_iterator find(const std::basic_string_view<CharT>& key,
std::size_t precalculated_hash) const {
return m_ht.find(key.data(), key.size(), precalculated_hash);
}
#else
/**
* @copydoc find_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash)
*/
iterator find(const CharT* key, std::size_t precalculated_hash) {
return m_ht.find(key, std::char_traits<CharT>::length(key),
precalculated_hash);
}
/**
* @copydoc find_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash)
*/
const_iterator find(const CharT* key, std::size_t precalculated_hash) const {
return m_ht.find(key, std::char_traits<CharT>::length(key),
precalculated_hash);
}
/**
* @copydoc find_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash)
*/
iterator find(const std::basic_string<CharT>& key,
std::size_t precalculated_hash) {
return m_ht.find(key.data(), key.size(), precalculated_hash);
}
/**
* @copydoc find_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash)
*/
const_iterator find(const std::basic_string<CharT>& key,
std::size_t precalculated_hash) const {
return m_ht.find(key.data(), key.size(), precalculated_hash);
}
#endif
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The
* hash value should be the same as hash_function()(key). Useful to speed-up
* the lookup to the value if you already have the hash.
*/
iterator find_ks(const CharT* key, size_type key_size,
std::size_t precalculated_hash) {
return m_ht.find(key, key_size, precalculated_hash);
}
/**
* @copydoc find_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash)
*/
const_iterator find_ks(const CharT* key, size_type key_size,
std::size_t precalculated_hash) const {
return m_ht.find(key, key_size, precalculated_hash);
}
#ifdef TSL_AH_HAS_STRING_VIEW
std::pair<iterator, iterator> equal_range(
const std::basic_string_view<CharT>& key) {
return m_ht.equal_range(key.data(), key.size());
}
std::pair<const_iterator, const_iterator> equal_range(
const std::basic_string_view<CharT>& key) const {
return m_ht.equal_range(key.data(), key.size());
}
#else
std::pair<iterator, iterator> equal_range(const CharT* key) {
return m_ht.equal_range(key, std::char_traits<CharT>::length(key));
}
std::pair<const_iterator, const_iterator> equal_range(
const CharT* key) const {
return m_ht.equal_range(key, std::char_traits<CharT>::length(key));
}
std::pair<iterator, iterator> equal_range(
const std::basic_string<CharT>& key) {
return m_ht.equal_range(key.data(), key.size());
}
std::pair<const_iterator, const_iterator> equal_range(
const std::basic_string<CharT>& key) const {
return m_ht.equal_range(key.data(), key.size());
}
#endif
std::pair<iterator, iterator> equal_range_ks(const CharT* key,
size_type key_size) {
return m_ht.equal_range(key, key_size);
}
std::pair<const_iterator, const_iterator> equal_range_ks(
const CharT* key, size_type key_size) const {
return m_ht.equal_range(key, key_size);
}
#ifdef TSL_AH_HAS_STRING_VIEW
/**
* @copydoc equal_range_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash)
*/
std::pair<iterator, iterator> equal_range(
const std::basic_string_view<CharT>& key,
std::size_t precalculated_hash) {
return m_ht.equal_range(key.data(), key.size(), precalculated_hash);
}
/**
* @copydoc equal_range_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash)
*/
std::pair<const_iterator, const_iterator> equal_range(
const std::basic_string_view<CharT>& key,
std::size_t precalculated_hash) const {
return m_ht.equal_range(key.data(), key.size(), precalculated_hash);
}
#else
/**
* @copydoc equal_range_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash)
*/
std::pair<iterator, iterator> equal_range(const CharT* key,
std::size_t precalculated_hash) {
return m_ht.equal_range(key, std::char_traits<CharT>::length(key),
precalculated_hash);
}
/**
* @copydoc equal_range_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash)
*/
std::pair<const_iterator, const_iterator> equal_range(
const CharT* key, std::size_t precalculated_hash) const {
return m_ht.equal_range(key, std::char_traits<CharT>::length(key),
precalculated_hash);
}
/**
* @copydoc equal_range_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash)
*/
std::pair<iterator, iterator> equal_range(const std::basic_string<CharT>& key,
std::size_t precalculated_hash) {
return m_ht.equal_range(key.data(), key.size(), precalculated_hash);
}
/**
* @copydoc equal_range_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash)
*/
std::pair<const_iterator, const_iterator> equal_range(
const std::basic_string<CharT>& key,
std::size_t precalculated_hash) const {
return m_ht.equal_range(key.data(), key.size(), precalculated_hash);
}
#endif
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The
* hash value should be the same as hash_function()(key). Useful to speed-up
* the lookup to the value if you already have the hash.
*/
std::pair<iterator, iterator> equal_range_ks(const CharT* key,
size_type key_size,
std::size_t precalculated_hash) {
return m_ht.equal_range(key, key_size, precalculated_hash);
}
/**
* @copydoc equal_range_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash)
*/
std::pair<const_iterator, const_iterator> equal_range_ks(
const CharT* key, size_type key_size,
std::size_t precalculated_hash) const {
return m_ht.equal_range(key, key_size, precalculated_hash);
}
/*
* Bucket interface
*/
size_type bucket_count() const { return m_ht.bucket_count(); }
size_type max_bucket_count() const { return m_ht.max_bucket_count(); }
/*
* Hash policy
*/
float load_factor() const { return m_ht.load_factor(); }
float max_load_factor() const { return m_ht.max_load_factor(); }
void max_load_factor(float ml) { m_ht.max_load_factor(ml); }
void rehash(size_type count) { m_ht.rehash(count); }
void reserve(size_type count) { m_ht.reserve(count); }
/*
* Observers
*/
hasher hash_function() const { return m_ht.hash_function(); }
key_equal key_eq() const { return m_ht.key_eq(); }
/*
* Other
*/
/**
* Return the `const_iterator it` as an `iterator`.
*/
iterator mutable_iterator(const_iterator it) noexcept {
return m_ht.mutable_iterator(it);
}
/**
* Serialize the map through the `serializer` parameter.
*
* The `serializer` parameter must be a function object that supports the
* following calls:
* - `template<typename U> void operator()(const U& value);` where the types
* `std::uint64_t`, `float` and `T` must be supported for U.
* - `void operator()(const CharT* value, std::size_t value_size);`
*
* The implementation leaves binary compatibility (endianness, IEEE 754 for
* floats, ...) of the types it serializes in the hands of the `Serializer`
* function object if compatibility is required.
*/
template <class Serializer>
void serialize(Serializer& serializer) const {
m_ht.serialize(serializer);
}
/**
* Deserialize a previously serialized map through the `deserializer`
* parameter.
*
* The `deserializer` parameter must be a function object that supports the
* following calls:
* - `template<typename U> U operator()();` where the types `std::uint64_t`,
* `float` and `T` must be supported for U.
* - `void operator()(CharT* value_out, std::size_t value_size);`
*
* If the deserialized hash map type is hash compatible with the serialized
* map, the deserialization process can be sped up by setting
* `hash_compatible` to true. To be hash compatible, the Hash (take care of
* the 32-bits vs 64 bits), KeyEqual, GrowthPolicy, StoreNullTerminator,
* KeySizeT and IndexSizeT must behave the same than the ones used on the
* serialized map. Otherwise the behaviour is undefined with `hash_compatible`
* sets to true.
*
* The behaviour is undefined if the type `CharT` and `T` of the `array_map`
* are not the same as the types used during serialization.
*
* The implementation leaves binary compatibility (endianness, IEEE 754 for
* floats, size of int, ...) of the types it deserializes in the hands of the
* `Deserializer` function object if compatibility is required.
*/
template <class Deserializer>
static array_map deserialize(Deserializer& deserializer,
bool hash_compatible = false) {
array_map map(0);
map.m_ht.deserialize(deserializer, hash_compatible);
return map;
}
friend bool operator==(const array_map& lhs, const array_map& rhs) {
if (lhs.size() != rhs.size()) {
return false;
}
for (auto it = lhs.cbegin(); it != lhs.cend(); ++it) {
const auto it_element_rhs = rhs.find_ks(it.key(), it.key_size());
if (it_element_rhs == rhs.cend() ||
it.value() != it_element_rhs.value()) {
return false;
}
}
return true;
}
friend bool operator!=(const array_map& lhs, const array_map& rhs) {
return !operator==(lhs, rhs);
}
friend void swap(array_map& lhs, array_map& rhs) { lhs.swap(rhs); }
private:
template <class U, class V>
void insert_pair(const std::pair<U, V>& value) {
insert(value.first, value.second);
}
template <class U, class V>
void insert_pair(std::pair<U, V>&& value) {
insert(value.first, std::move(value.second));
}
public:
static const size_type MAX_KEY_SIZE = ht::MAX_KEY_SIZE;
private:
ht m_ht;
};
/**
* Same as
* `tsl::array_map<CharT, T, Hash, KeyEqual, StoreNullTerminator, KeySizeT,
* IndexSizeT, tsl::ah::prime_growth_policy>`.
*/
template <class CharT, class T, class Hash = tsl::ah::str_hash<CharT>,
class KeyEqual = tsl::ah::str_equal<CharT>,
bool StoreNullTerminator = true, class KeySizeT = std::uint16_t,
class IndexSizeT = std::uint32_t>
using array_pg_map =
array_map<CharT, T, Hash, KeyEqual, StoreNullTerminator, KeySizeT,
IndexSizeT, tsl::ah::prime_growth_policy>;
} // end namespace tsl
#endif

View File

@ -0,0 +1,716 @@
/**
* MIT License
*
* Copyright (c) 2017 Thibaut Goetghebuer-Planchon <tessil@gmx.com>
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#ifndef TSL_ARRAY_SET_H
#define TSL_ARRAY_SET_H
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <iterator>
#include <string>
#include <type_traits>
#include <utility>
#include "array_hash.h"
namespace tsl {
/**
* Implementation of a cache-conscious string hash set.
*
* The set stores the strings as `const CharT*`. If `StoreNullTerminator` is
* true, the strings are stored with the a null-terminator (the `key()` method
* of the iterators will return a pointer to this null-terminated string).
* Otherwise the null character is not stored (which allow an economy of 1 byte
* per string).
*
* The size of a key string is limited to `std::numeric_limits<KeySizeT>::max()
* - 1`. That is 65 535 characters by default, but can be raised with the
* `KeySizeT` template parameter. See `max_key_size()` for an easy access to
* this limit.
*
* The number of elements in the set is limited to
* `std::numeric_limits<IndexSizeT>::max()`. That is 4 294 967 296 elements, but
* can be raised with the `IndexSizeT` template parameter. See `max_size()` for
* an easy access to this limit.
*
* Iterators invalidation:
* - clear, operator=: always invalidate the iterators.
* - insert, emplace, operator[]: always invalidate the iterators.
* - erase: always invalidate the iterators.
* - shrink_to_fit: always invalidate the iterators.
*/
template <class CharT, class Hash = tsl::ah::str_hash<CharT>,
class KeyEqual = tsl::ah::str_equal<CharT>,
bool StoreNullTerminator = true, class KeySizeT = std::uint16_t,
class IndexSizeT = std::uint32_t,
class GrowthPolicy = tsl::ah::power_of_two_growth_policy<2>>
class array_set {
private:
template <typename U>
using is_iterator = tsl::detail_array_hash::is_iterator<U>;
using ht = tsl::detail_array_hash::array_hash<CharT, void, Hash, KeyEqual,
StoreNullTerminator, KeySizeT,
IndexSizeT, GrowthPolicy>;
public:
using char_type = typename ht::char_type;
using key_size_type = typename ht::key_size_type;
using index_size_type = typename ht::index_size_type;
using size_type = typename ht::size_type;
using hasher = typename ht::hasher;
using key_equal = typename ht::key_equal;
using iterator = typename ht::iterator;
using const_iterator = typename ht::const_iterator;
array_set() : array_set(ht::DEFAULT_INIT_BUCKET_COUNT) {}
explicit array_set(size_type bucket_count, const Hash& hash = Hash())
: m_ht(bucket_count, hash, ht::DEFAULT_MAX_LOAD_FACTOR) {}
template <class InputIt, typename std::enable_if<
is_iterator<InputIt>::value>::type* = nullptr>
array_set(InputIt first, InputIt last,
size_type bucket_count = ht::DEFAULT_INIT_BUCKET_COUNT,
const Hash& hash = Hash())
: array_set(bucket_count, hash) {
insert(first, last);
}
#ifdef TSL_AH_HAS_STRING_VIEW
array_set(std::initializer_list<std::basic_string_view<CharT>> init,
size_type bucket_count = ht::DEFAULT_INIT_BUCKET_COUNT,
const Hash& hash = Hash())
: array_set(bucket_count, hash) {
insert(init);
}
#else
array_set(std::initializer_list<const CharT*> init,
size_type bucket_count = ht::DEFAULT_INIT_BUCKET_COUNT,
const Hash& hash = Hash())
: array_set(bucket_count, hash) {
insert(init);
}
#endif
#ifdef TSL_AH_HAS_STRING_VIEW
array_set& operator=(
std::initializer_list<std::basic_string_view<CharT>> ilist) {
clear();
reserve(ilist.size());
insert(ilist);
return *this;
}
#else
array_set& operator=(std::initializer_list<const CharT*> ilist) {
clear();
reserve(ilist.size());
insert(ilist);
return *this;
}
#endif
/*
* Iterators
*/
iterator begin() noexcept { return m_ht.begin(); }
const_iterator begin() const noexcept { return m_ht.begin(); }
const_iterator cbegin() const noexcept { return m_ht.cbegin(); }
iterator end() noexcept { return m_ht.end(); }
const_iterator end() const noexcept { return m_ht.end(); }
const_iterator cend() const noexcept { return m_ht.cend(); }
/*
* Capacity
*/
bool empty() const noexcept { return m_ht.empty(); }
size_type size() const noexcept { return m_ht.size(); }
size_type max_size() const noexcept { return m_ht.max_size(); }
size_type max_key_size() const noexcept { return m_ht.max_key_size(); }
void shrink_to_fit() { m_ht.shrink_to_fit(); }
/*
* Modifiers
*/
void clear() noexcept { m_ht.clear(); }
#ifdef TSL_AH_HAS_STRING_VIEW
std::pair<iterator, bool> insert(const std::basic_string_view<CharT>& key) {
return m_ht.emplace(key.data(), key.size());
}
#else
std::pair<iterator, bool> insert(const CharT* key) {
return m_ht.emplace(key, std::char_traits<CharT>::length(key));
}
std::pair<iterator, bool> insert(const std::basic_string<CharT>& key) {
return m_ht.emplace(key.data(), key.size());
}
#endif
std::pair<iterator, bool> insert_ks(const CharT* key, size_type key_size) {
return m_ht.emplace(key, key_size);
}
template <class InputIt, typename std::enable_if<
is_iterator<InputIt>::value>::type* = nullptr>
void insert(InputIt first, InputIt last) {
if (std::is_base_of<
std::forward_iterator_tag,
typename std::iterator_traits<InputIt>::iterator_category>::value) {
const auto nb_elements_insert = std::distance(first, last);
const std::size_t nb_free_buckets =
std::size_t(float(bucket_count()) * max_load_factor()) - size();
if (nb_elements_insert > 0 &&
nb_free_buckets < std::size_t(nb_elements_insert)) {
reserve(size() + std::size_t(nb_elements_insert));
}
}
for (auto it = first; it != last; ++it) {
insert(*it);
}
}
#ifdef TSL_AH_HAS_STRING_VIEW
void insert(std::initializer_list<std::basic_string_view<CharT>> ilist) {
insert(ilist.begin(), ilist.end());
}
#else
void insert(std::initializer_list<const CharT*> ilist) {
insert(ilist.begin(), ilist.end());
}
#endif
#ifdef TSL_AH_HAS_STRING_VIEW
/**
* @copydoc emplace_ks(const CharT* key, size_type key_size)
*/
std::pair<iterator, bool> emplace(const std::basic_string_view<CharT>& key) {
return m_ht.emplace(key.data(), key.size());
}
#else
/**
* @copydoc emplace_ks(const CharT* key, size_type key_size)
*/
std::pair<iterator, bool> emplace(const CharT* key) {
return m_ht.emplace(key, std::char_traits<CharT>::length(key));
}
/**
* @copydoc emplace_ks(const CharT* key, size_type key_size)
*/
std::pair<iterator, bool> emplace(const std::basic_string<CharT>& key) {
return m_ht.emplace(key.data(), key.size());
}
#endif
/**
* No difference compared to the insert method. Mainly here for coherence with
* array_map.
*/
std::pair<iterator, bool> emplace_ks(const CharT* key, size_type key_size) {
return m_ht.emplace(key, key_size);
}
iterator erase(const_iterator pos) { return m_ht.erase(pos); }
iterator erase(const_iterator first, const_iterator last) {
return m_ht.erase(first, last);
}
#ifdef TSL_AH_HAS_STRING_VIEW
size_type erase(const std::basic_string_view<CharT>& key) {
return m_ht.erase(key.data(), key.size());
}
#else
size_type erase(const CharT* key) {
return m_ht.erase(key, std::char_traits<CharT>::length(key));
}
size_type erase(const std::basic_string<CharT>& key) {
return m_ht.erase(key.data(), key.size());
}
#endif
size_type erase_ks(const CharT* key, size_type key_size) {
return m_ht.erase(key, key_size);
}
#ifdef TSL_AH_HAS_STRING_VIEW
/**
* @copydoc erase_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash)
*/
size_type erase(const std::basic_string_view<CharT>& key,
std::size_t precalculated_hash) {
return m_ht.erase(key.data(), key.size(), precalculated_hash);
}
#else
/**
* @copydoc erase_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash)
*/
size_type erase(const CharT* key, std::size_t precalculated_hash) {
return m_ht.erase(key, std::char_traits<CharT>::length(key),
precalculated_hash);
}
/**
* @copydoc erase_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash)
*/
size_type erase(const std::basic_string<CharT>& key,
std::size_t precalculated_hash) {
return m_ht.erase(key.data(), key.size(), precalculated_hash);
}
#endif
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The
* hash value should be the same as hash_function()(key). Useful to speed-up
* the lookup to the value if you already have the hash.
*/
size_type erase_ks(const CharT* key, size_type key_size,
std::size_t precalculated_hash) {
return m_ht.erase(key, key_size, precalculated_hash);
}
void swap(array_set& other) { other.m_ht.swap(m_ht); }
/*
* Lookup
*/
#ifdef TSL_AH_HAS_STRING_VIEW
size_type count(const std::basic_string_view<CharT>& key) const {
return m_ht.count(key.data(), key.size());
}
#else
size_type count(const CharT* key) const {
return m_ht.count(key, std::char_traits<CharT>::length(key));
}
size_type count(const std::basic_string<CharT>& key) const {
return m_ht.count(key.data(), key.size());
}
#endif
size_type count_ks(const CharT* key, size_type key_size) const {
return m_ht.count(key, key_size);
}
#ifdef TSL_AH_HAS_STRING_VIEW
/**
* @copydoc count_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash) const
*/
size_type count(const std::basic_string_view<CharT>& key,
std::size_t precalculated_hash) const {
return m_ht.count(key.data(), key.size(), precalculated_hash);
}
#else
/**
* @copydoc count_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash) const
*/
size_type count(const CharT* key, std::size_t precalculated_hash) const {
return m_ht.count(key, std::char_traits<CharT>::length(key),
precalculated_hash);
}
/**
* @copydoc count_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash) const
*/
size_type count(const std::basic_string<CharT>& key,
std::size_t precalculated_hash) const {
return m_ht.count(key.data(), key.size(), precalculated_hash);
}
#endif
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The
* hash value should be the same as hash_function()(key). Useful to speed-up
* the lookup to the value if you already have the hash.
*/
size_type count_ks(const CharT* key, size_type key_size,
std::size_t precalculated_hash) const {
return m_ht.count(key, key_size, precalculated_hash);
}
#ifdef TSL_AH_HAS_STRING_VIEW
iterator find(const std::basic_string_view<CharT>& key) {
return m_ht.find(key.data(), key.size());
}
const_iterator find(const std::basic_string_view<CharT>& key) const {
return m_ht.find(key.data(), key.size());
}
#else
iterator find(const CharT* key) {
return m_ht.find(key, std::char_traits<CharT>::length(key));
}
const_iterator find(const CharT* key) const {
return m_ht.find(key, std::char_traits<CharT>::length(key));
}
iterator find(const std::basic_string<CharT>& key) {
return m_ht.find(key.data(), key.size());
}
const_iterator find(const std::basic_string<CharT>& key) const {
return m_ht.find(key.data(), key.size());
}
#endif
iterator find_ks(const CharT* key, size_type key_size) {
return m_ht.find(key, key_size);
}
const_iterator find_ks(const CharT* key, size_type key_size) const {
return m_ht.find(key, key_size);
}
#ifdef TSL_AH_HAS_STRING_VIEW
/**
* @copydoc find_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash)
*/
iterator find(const std::basic_string_view<CharT>& key,
std::size_t precalculated_hash) {
return m_ht.find(key.data(), key.size(), precalculated_hash);
}
/**
* @copydoc find_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash)
*/
const_iterator find(const std::basic_string_view<CharT>& key,
std::size_t precalculated_hash) const {
return m_ht.find(key.data(), key.size(), precalculated_hash);
}
#else
/**
* @copydoc find_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash)
*/
iterator find(const CharT* key, std::size_t precalculated_hash) {
return m_ht.find(key, std::char_traits<CharT>::length(key),
precalculated_hash);
}
/**
* @copydoc find_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash)
*/
const_iterator find(const CharT* key, std::size_t precalculated_hash) const {
return m_ht.find(key, std::char_traits<CharT>::length(key),
precalculated_hash);
}
/**
* @copydoc find_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash)
*/
iterator find(const std::basic_string<CharT>& key,
std::size_t precalculated_hash) {
return m_ht.find(key.data(), key.size(), precalculated_hash);
}
/**
* @copydoc find_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash)
*/
const_iterator find(const std::basic_string<CharT>& key,
std::size_t precalculated_hash) const {
return m_ht.find(key.data(), key.size(), precalculated_hash);
}
#endif
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The
* hash value should be the same as hash_function()(key). Useful to speed-up
* the lookup to the value if you already have the hash.
*/
iterator find_ks(const CharT* key, size_type key_size,
std::size_t precalculated_hash) {
return m_ht.find(key, key_size, precalculated_hash);
}
/**
* @copydoc find_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash)
*/
const_iterator find_ks(const CharT* key, size_type key_size,
std::size_t precalculated_hash) const {
return m_ht.find(key, key_size, precalculated_hash);
}
#ifdef TSL_AH_HAS_STRING_VIEW
std::pair<iterator, iterator> equal_range(
const std::basic_string_view<CharT>& key) {
return m_ht.equal_range(key.data(), key.size());
}
std::pair<const_iterator, const_iterator> equal_range(
const std::basic_string_view<CharT>& key) const {
return m_ht.equal_range(key.data(), key.size());
}
#else
std::pair<iterator, iterator> equal_range(const CharT* key) {
return m_ht.equal_range(key, std::char_traits<CharT>::length(key));
}
std::pair<const_iterator, const_iterator> equal_range(
const CharT* key) const {
return m_ht.equal_range(key, std::char_traits<CharT>::length(key));
}
std::pair<iterator, iterator> equal_range(
const std::basic_string<CharT>& key) {
return m_ht.equal_range(key.data(), key.size());
}
std::pair<const_iterator, const_iterator> equal_range(
const std::basic_string<CharT>& key) const {
return m_ht.equal_range(key.data(), key.size());
}
#endif
std::pair<iterator, iterator> equal_range_ks(const CharT* key,
size_type key_size) {
return m_ht.equal_range(key, key_size);
}
std::pair<const_iterator, const_iterator> equal_range_ks(
const CharT* key, size_type key_size) const {
return m_ht.equal_range(key, key_size);
}
#ifdef TSL_AH_HAS_STRING_VIEW
/**
* @copydoc equal_range_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash)
*/
std::pair<iterator, iterator> equal_range(
const std::basic_string_view<CharT>& key,
std::size_t precalculated_hash) {
return m_ht.equal_range(key.data(), key.size(), precalculated_hash);
}
/**
* @copydoc equal_range_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash)
*/
std::pair<const_iterator, const_iterator> equal_range(
const std::basic_string_view<CharT>& key,
std::size_t precalculated_hash) const {
return m_ht.equal_range(key.data(), key.size(), precalculated_hash);
}
#else
/**
* @copydoc equal_range_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash)
*/
std::pair<iterator, iterator> equal_range(const CharT* key,
std::size_t precalculated_hash) {
return m_ht.equal_range(key, std::char_traits<CharT>::length(key),
precalculated_hash);
}
/**
* @copydoc equal_range_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash)
*/
std::pair<const_iterator, const_iterator> equal_range(
const CharT* key, std::size_t precalculated_hash) const {
return m_ht.equal_range(key, std::char_traits<CharT>::length(key),
precalculated_hash);
}
/**
* @copydoc equal_range_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash)
*/
std::pair<iterator, iterator> equal_range(const std::basic_string<CharT>& key,
std::size_t precalculated_hash) {
return m_ht.equal_range(key.data(), key.size(), precalculated_hash);
}
/**
* @copydoc equal_range_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash)
*/
std::pair<const_iterator, const_iterator> equal_range(
const std::basic_string<CharT>& key,
std::size_t precalculated_hash) const {
return m_ht.equal_range(key.data(), key.size(), precalculated_hash);
}
#endif
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The
* hash value should be the same as hash_function()(key). Useful to speed-up
* the lookup to the value if you already have the hash.
*/
std::pair<iterator, iterator> equal_range_ks(const CharT* key,
size_type key_size,
std::size_t precalculated_hash) {
return m_ht.equal_range(key, key_size, precalculated_hash);
}
/**
* @copydoc equal_range_ks(const CharT* key, size_type key_size, std::size_t
* precalculated_hash)
*/
std::pair<const_iterator, const_iterator> equal_range_ks(
const CharT* key, size_type key_size,
std::size_t precalculated_hash) const {
return m_ht.equal_range(key, key_size, precalculated_hash);
}
/*
* Bucket interface
*/
size_type bucket_count() const { return m_ht.bucket_count(); }
size_type max_bucket_count() const { return m_ht.max_bucket_count(); }
/*
* Hash policy
*/
float load_factor() const { return m_ht.load_factor(); }
float max_load_factor() const { return m_ht.max_load_factor(); }
void max_load_factor(float ml) { m_ht.max_load_factor(ml); }
void rehash(size_type count) { m_ht.rehash(count); }
void reserve(size_type count) { m_ht.reserve(count); }
/*
* Observers
*/
hasher hash_function() const { return m_ht.hash_function(); }
key_equal key_eq() const { return m_ht.key_eq(); }
/*
* Other
*/
/**
* Return the `const_iterator it` as an `iterator`.
*/
iterator mutable_iterator(const_iterator it) noexcept {
return m_ht.mutable_iterator(it);
}
/**
* Serialize the set through the `serializer` parameter.
*
* The `serializer` parameter must be a function object that supports the
* following calls:
* - `template<typename U> void operator()(const U& value);` where the types
* `std::uint64_t` and `float` must be supported for U.
* - `void operator()(const CharT* value, std::size_t value_size);`
*
* The implementation leaves binary compatibility (endianness, IEEE 754 for
* floats, ...) of the types it serializes in the hands of the `Serializer`
* function object if compatibility is required.
*/
template <class Serializer>
void serialize(Serializer& serializer) const {
m_ht.serialize(serializer);
}
/**
* Deserialize a previously serialized set through the `deserializer`
* parameter.
*
* The `deserializer` parameter must be a function object that supports the
* following calls:
* - `template<typename U> U operator()();` where the types `std::uint64_t`
* and `float` must be supported for U.
* - `void operator()(CharT* value_out, std::size_t value_size);`
*
* If the deserialized hash set type is hash compatible with the serialized
* set, the deserialization process can be sped up by setting
* `hash_compatible` to true. To be hash compatible, the Hash (take care of
* the 32-bits vs 64 bits), KeyEqual, GrowthPolicy, StoreNullTerminator,
* KeySizeT and IndexSizeT must behave the same than the ones used on the
* serialized set. Otherwise the behaviour is undefined with `hash_compatible`
* sets to true.
*
* The behaviour is undefined if the type `CharT` of the `array_set` is not
* the same as the type used during serialization.
*
* The implementation leaves binary compatibility (endianness, IEEE 754 for
* floats, size of int, ...) of the types it deserializes in the hands of the
* `Deserializer` function object if compatibility is required.
*/
template <class Deserializer>
static array_set deserialize(Deserializer& deserializer,
bool hash_compatible = false) {
array_set set(0);
set.m_ht.deserialize(deserializer, hash_compatible);
return set;
}
friend bool operator==(const array_set& lhs, const array_set& rhs) {
if (lhs.size() != rhs.size()) {
return false;
}
for (auto it = lhs.cbegin(); it != lhs.cend(); ++it) {
const auto it_element_rhs = rhs.find_ks(it.key(), it.key_size());
if (it_element_rhs == rhs.cend()) {
return false;
}
}
return true;
}
friend bool operator!=(const array_set& lhs, const array_set& rhs) {
return !operator==(lhs, rhs);
}
friend void swap(array_set& lhs, array_set& rhs) { lhs.swap(rhs); }
public:
static const size_type MAX_KEY_SIZE = ht::MAX_KEY_SIZE;
private:
ht m_ht;
};
/**
* Same as
* `tsl::array_set<CharT, Hash, KeyEqual, StoreNullTerminator, KeySizeT,
* IndexSizeT, tsl::ah::prime_growth_policy>`.
*/
template <class CharT, class Hash = tsl::ah::str_hash<CharT>,
class KeyEqual = tsl::ah::str_equal<CharT>,
bool StoreNullTerminator = true, class KeySizeT = std::uint16_t,
class IndexSizeT = std::uint32_t>
using array_pg_set =
array_set<CharT, Hash, KeyEqual, StoreNullTerminator, KeySizeT, IndexSizeT,
tsl::ah::prime_growth_policy>;
} // end namespace tsl
#endif

2076
include/tsl/htrie_hash.h Normal file

File diff suppressed because it is too large Load Diff

668
include/tsl/htrie_map.h Normal file
View File

@ -0,0 +1,668 @@
/**
* MIT License
*
* Copyright (c) 2017 Thibaut Goetghebuer-Planchon <tessil@gmx.com>
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#ifndef TSL_HTRIE_MAP_H
#define TSL_HTRIE_MAP_H
#include <cstddef>
#include <cstring>
#include <initializer_list>
#include <string>
#include <utility>
#include "htrie_hash.h"
namespace tsl {
/**
* Implementation of a hat-trie map.
*
* The value T must be either nothrow move-constructible/assignable,
* copy-constructible or both.
*
* The size of a key string is limited to std::numeric_limits<KeySizeT>::max()
* - 1. That is 65 535 characters by default, but can be raised with the
* KeySizeT template parameter. See max_key_size() for an easy access to this
* limit.
*
* Iterators invalidation:
* - clear, operator=: always invalidate the iterators.
* - insert, emplace, operator[]: always invalidate the iterators.
* - erase: always invalidate the iterators.
*/
template <class CharT, class T, class Hash = tsl::ah::str_hash<CharT>,
class KeySizeT = std::uint16_t>
class htrie_map {
private:
template <typename U>
using is_iterator = tsl::detail_array_hash::is_iterator<U>;
using ht = tsl::detail_htrie_hash::htrie_hash<CharT, T, Hash, KeySizeT>;
public:
using char_type = typename ht::char_type;
using mapped_type = T;
using key_size_type = typename ht::key_size_type;
using size_type = typename ht::size_type;
using hasher = typename ht::hasher;
using iterator = typename ht::iterator;
using const_iterator = typename ht::const_iterator;
using prefix_iterator = typename ht::prefix_iterator;
using const_prefix_iterator = typename ht::const_prefix_iterator;
public:
explicit htrie_map(const Hash& hash = Hash())
: m_ht(hash, ht::HASH_NODE_DEFAULT_MAX_LOAD_FACTOR,
ht::DEFAULT_BURST_THRESHOLD) {}
explicit htrie_map(size_type burst_threshold, const Hash& hash = Hash())
: m_ht(hash, ht::HASH_NODE_DEFAULT_MAX_LOAD_FACTOR, burst_threshold) {}
template <class InputIt, typename std::enable_if<
is_iterator<InputIt>::value>::type* = nullptr>
htrie_map(InputIt first, InputIt last, const Hash& hash = Hash())
: htrie_map(hash) {
insert(first, last);
}
#ifdef TSL_HT_HAS_STRING_VIEW
htrie_map(
std::initializer_list<std::pair<std::basic_string_view<CharT>, T>> init,
const Hash& hash = Hash())
: htrie_map(hash) {
insert(init);
}
#else
htrie_map(std::initializer_list<std::pair<const CharT*, T>> init,
const Hash& hash = Hash())
: htrie_map(hash) {
insert(init);
}
#endif
#ifdef TSL_HT_HAS_STRING_VIEW
htrie_map& operator=(
std::initializer_list<std::pair<std::basic_string_view<CharT>, T>>
ilist) {
clear();
insert(ilist);
return *this;
}
#else
htrie_map& operator=(
std::initializer_list<std::pair<const CharT*, T>> ilist) {
clear();
insert(ilist);
return *this;
}
#endif
/*
* Iterators
*/
iterator begin() noexcept { return m_ht.begin(); }
const_iterator begin() const noexcept { return m_ht.begin(); }
const_iterator cbegin() const noexcept { return m_ht.cbegin(); }
iterator end() noexcept { return m_ht.end(); }
const_iterator end() const noexcept { return m_ht.end(); }
const_iterator cend() const noexcept { return m_ht.cend(); }
/*
* Capacity
*/
bool empty() const noexcept { return m_ht.empty(); }
size_type size() const noexcept { return m_ht.size(); }
size_type max_size() const noexcept { return m_ht.max_size(); }
size_type max_key_size() const noexcept { return m_ht.max_key_size(); }
/**
* Call shrink_to_fit() on each hash node of the hat-trie to reduce its size.
*/
void shrink_to_fit() { m_ht.shrink_to_fit(); }
/*
* Modifiers
*/
void clear() noexcept { m_ht.clear(); }
std::pair<iterator, bool> insert_ks(const CharT* key, size_type key_size,
const T& value) {
return m_ht.insert(key, key_size, value);
}
#ifdef TSL_HT_HAS_STRING_VIEW
std::pair<iterator, bool> insert(const std::basic_string_view<CharT>& key,
const T& value) {
return m_ht.insert(key.data(), key.size(), value);
}
#else
std::pair<iterator, bool> insert(const CharT* key, const T& value) {
return m_ht.insert(key, std::strlen(key), value);
}
std::pair<iterator, bool> insert(const std::basic_string<CharT>& key,
const T& value) {
return m_ht.insert(key.data(), key.size(), value);
}
#endif
std::pair<iterator, bool> insert_ks(const CharT* key, size_type key_size,
T&& value) {
return m_ht.insert(key, key_size, std::move(value));
}
#ifdef TSL_HT_HAS_STRING_VIEW
std::pair<iterator, bool> insert(const std::basic_string_view<CharT>& key,
T&& value) {
return m_ht.insert(key.data(), key.size(), std::move(value));
}
#else
std::pair<iterator, bool> insert(const CharT* key, T&& value) {
return m_ht.insert(key, std::strlen(key), std::move(value));
}
std::pair<iterator, bool> insert(const std::basic_string<CharT>& key,
T&& value) {
return m_ht.insert(key.data(), key.size(), std::move(value));
}
#endif
template <class InputIt, typename std::enable_if<
is_iterator<InputIt>::value>::type* = nullptr>
void insert(InputIt first, InputIt last) {
for (auto it = first; it != last; ++it) {
insert_pair(*it);
}
}
#ifdef TSL_HT_HAS_STRING_VIEW
void insert(std::initializer_list<std::pair<std::basic_string_view<CharT>, T>>
ilist) {
insert(ilist.begin(), ilist.end());
}
#else
void insert(std::initializer_list<std::pair<const CharT*, T>> ilist) {
insert(ilist.begin(), ilist.end());
}
#endif
template <class... Args>
std::pair<iterator, bool> emplace_ks(const CharT* key, size_type key_size,
Args&&... args) {
return m_ht.insert(key, key_size, std::forward<Args>(args)...);
}
#ifdef TSL_HT_HAS_STRING_VIEW
template <class... Args>
std::pair<iterator, bool> emplace(const std::basic_string_view<CharT>& key,
Args&&... args) {
return m_ht.insert(key.data(), key.size(), std::forward<Args>(args)...);
}
#else
template <class... Args>
std::pair<iterator, bool> emplace(const CharT* key, Args&&... args) {
return m_ht.insert(key, std::strlen(key), std::forward<Args>(args)...);
}
template <class... Args>
std::pair<iterator, bool> emplace(const std::basic_string<CharT>& key,
Args&&... args) {
return m_ht.insert(key.data(), key.size(), std::forward<Args>(args)...);
}
#endif
iterator erase(const_iterator pos) { return m_ht.erase(pos); }
iterator erase(const_iterator first, const_iterator last) {
return m_ht.erase(first, last);
}
size_type erase_ks(const CharT* key, size_type key_size) {
return m_ht.erase(key, key_size);
}
#ifdef TSL_HT_HAS_STRING_VIEW
size_type erase(const std::basic_string_view<CharT>& key) {
return m_ht.erase(key.data(), key.size());
}
#else
size_type erase(const CharT* key) {
return m_ht.erase(key, std::strlen(key));
}
size_type erase(const std::basic_string<CharT>& key) {
return m_ht.erase(key.data(), key.size());
}
#endif
/**
* Erase all the elements which have 'prefix' as prefix. Return the number of
* erase elements.
*/
size_type erase_prefix_ks(const CharT* prefix, size_type prefix_size) {
return m_ht.erase_prefix(prefix, prefix_size);
}
#ifdef TSL_HT_HAS_STRING_VIEW
/**
* @copydoc erase_prefix_ks(const CharT* prefix, size_type prefix_size)
*/
size_type erase_prefix(const std::basic_string_view<CharT>& prefix) {
return m_ht.erase_prefix(prefix.data(), prefix.size());
}
#else
/**
* @copydoc erase_prefix_ks(const CharT* prefix, size_type prefix_size)
*/
size_type erase_prefix(const CharT* prefix) {
return m_ht.erase_prefix(prefix, std::strlen(prefix));
}
/**
* @copydoc erase_prefix_ks(const CharT* prefix, size_type prefix_size)
*/
size_type erase_prefix(const std::basic_string<CharT>& prefix) {
return m_ht.erase_prefix(prefix.data(), prefix.size());
}
#endif
void swap(htrie_map& other) { other.m_ht.swap(m_ht); }
/*
* Lookup
*/
T& at_ks(const CharT* key, size_type key_size) {
return m_ht.at(key, key_size);
}
const T& at_ks(const CharT* key, size_type key_size) const {
return m_ht.at(key, key_size);
}
#ifdef TSL_HT_HAS_STRING_VIEW
T& at(const std::basic_string_view<CharT>& key) {
return m_ht.at(key.data(), key.size());
}
const T& at(const std::basic_string_view<CharT>& key) const {
return m_ht.at(key.data(), key.size());
}
#else
T& at(const CharT* key) { return m_ht.at(key, std::strlen(key)); }
const T& at(const CharT* key) const { return m_ht.at(key, std::strlen(key)); }
T& at(const std::basic_string<CharT>& key) {
return m_ht.at(key.data(), key.size());
}
const T& at(const std::basic_string<CharT>& key) const {
return m_ht.at(key.data(), key.size());
}
#endif
#ifdef TSL_HT_HAS_STRING_VIEW
T& operator[](const std::basic_string_view<CharT>& key) {
return m_ht.access_operator(key.data(), key.size());
}
#else
T& operator[](const CharT* key) {
return m_ht.access_operator(key, std::strlen(key));
}
T& operator[](const std::basic_string<CharT>& key) {
return m_ht.access_operator(key.data(), key.size());
}
#endif
size_type count_ks(const CharT* key, size_type key_size) const {
return m_ht.count(key, key_size);
}
#ifdef TSL_HT_HAS_STRING_VIEW
size_type count(const std::basic_string_view<CharT>& key) const {
return m_ht.count(key.data(), key.size());
}
#else
size_type count(const CharT* key) const {
return m_ht.count(key, std::strlen(key));
}
size_type count(const std::basic_string<CharT>& key) const {
return m_ht.count(key.data(), key.size());
}
#endif
iterator find_ks(const CharT* key, size_type key_size) {
return m_ht.find(key, key_size);
}
const_iterator find_ks(const CharT* key, size_type key_size) const {
return m_ht.find(key, key_size);
}
#ifdef TSL_HT_HAS_STRING_VIEW
iterator find(const std::basic_string_view<CharT>& key) {
return m_ht.find(key.data(), key.size());
}
const_iterator find(const std::basic_string_view<CharT>& key) const {
return m_ht.find(key.data(), key.size());
}
#else
iterator find(const CharT* key) { return m_ht.find(key, std::strlen(key)); }
const_iterator find(const CharT* key) const {
return m_ht.find(key, std::strlen(key));
}
iterator find(const std::basic_string<CharT>& key) {
return m_ht.find(key.data(), key.size());
}
const_iterator find(const std::basic_string<CharT>& key) const {
return m_ht.find(key.data(), key.size());
}
#endif
std::pair<iterator, iterator> equal_range_ks(const CharT* key,
size_type key_size) {
return m_ht.equal_range(key, key_size);
}
std::pair<const_iterator, const_iterator> equal_range_ks(
const CharT* key, size_type key_size) const {
return m_ht.equal_range(key, key_size);
}
#ifdef TSL_HT_HAS_STRING_VIEW
std::pair<iterator, iterator> equal_range(
const std::basic_string_view<CharT>& key) {
return m_ht.equal_range(key.data(), key.size());
}
std::pair<const_iterator, const_iterator> equal_range(
const std::basic_string_view<CharT>& key) const {
return m_ht.equal_range(key.data(), key.size());
}
#else
std::pair<iterator, iterator> equal_range(const CharT* key) {
return m_ht.equal_range(key, std::strlen(key));
}
std::pair<const_iterator, const_iterator> equal_range(
const CharT* key) const {
return m_ht.equal_range(key, std::strlen(key));
}
std::pair<iterator, iterator> equal_range(
const std::basic_string<CharT>& key) {
return m_ht.equal_range(key.data(), key.size());
}
std::pair<const_iterator, const_iterator> equal_range(
const std::basic_string<CharT>& key) const {
return m_ht.equal_range(key.data(), key.size());
}
#endif
/**
* Return a range containing all the elements which have 'prefix' as prefix.
* The range is defined by a pair of iterator, the first being the begin
* iterator and the second being the end iterator.
*/
std::pair<prefix_iterator, prefix_iterator> equal_prefix_range_ks(
const CharT* prefix, size_type prefix_size) {
return m_ht.equal_prefix_range(prefix, prefix_size);
}
/**
* @copydoc equal_prefix_range_ks(const CharT* prefix, size_type prefix_size)
*/
std::pair<const_prefix_iterator, const_prefix_iterator> equal_prefix_range_ks(
const CharT* prefix, size_type prefix_size) const {
return m_ht.equal_prefix_range(prefix, prefix_size);
}
#ifdef TSL_HT_HAS_STRING_VIEW
/**
* @copydoc equal_prefix_range_ks(const CharT* prefix, size_type prefix_size)
*/
std::pair<prefix_iterator, prefix_iterator> equal_prefix_range(
const std::basic_string_view<CharT>& prefix) {
return m_ht.equal_prefix_range(prefix.data(), prefix.size());
}
/**
* @copydoc equal_prefix_range_ks(const CharT* prefix, size_type prefix_size)
*/
std::pair<const_prefix_iterator, const_prefix_iterator> equal_prefix_range(
const std::basic_string_view<CharT>& prefix) const {
return m_ht.equal_prefix_range(prefix.data(), prefix.size());
}
#else
/**
* @copydoc equal_prefix_range_ks(const CharT* prefix, size_type prefix_size)
*/
std::pair<prefix_iterator, prefix_iterator> equal_prefix_range(
const CharT* prefix) {
return m_ht.equal_prefix_range(prefix, std::strlen(prefix));
}
/**
* @copydoc equal_prefix_range_ks(const CharT* prefix, size_type prefix_size)
*/
std::pair<const_prefix_iterator, const_prefix_iterator> equal_prefix_range(
const CharT* prefix) const {
return m_ht.equal_prefix_range(prefix, std::strlen(prefix));
}
/**
* @copydoc equal_prefix_range_ks(const CharT* prefix, size_type prefix_size)
*/
std::pair<prefix_iterator, prefix_iterator> equal_prefix_range(
const std::basic_string<CharT>& prefix) {
return m_ht.equal_prefix_range(prefix.data(), prefix.size());
}
/**
* @copydoc equal_prefix_range_ks(const CharT* prefix, size_type prefix_size)
*/
std::pair<const_prefix_iterator, const_prefix_iterator> equal_prefix_range(
const std::basic_string<CharT>& prefix) const {
return m_ht.equal_prefix_range(prefix.data(), prefix.size());
}
#endif
/**
* Return the element in the trie which is the longest prefix of `key`. If no
* element in the trie is a prefix of `key`, the end iterator is returned.
*
* Example:
*
* tsl::htrie_map<char, int> map = {{"/foo", 1}, {"/foo/bar", 1}};
*
* map.longest_prefix("/foo"); // returns {"/foo", 1}
* map.longest_prefix("/foo/baz"); // returns {"/foo", 1}
* map.longest_prefix("/foo/bar/baz"); // returns {"/foo/bar", 1}
* map.longest_prefix("/foo/bar/"); // returns {"/foo/bar", 1}
* map.longest_prefix("/bar"); // returns end()
* map.longest_prefix(""); // returns end()
*/
iterator longest_prefix_ks(const CharT* key, size_type key_size) {
return m_ht.longest_prefix(key, key_size);
}
/**
* @copydoc longest_prefix_ks(const CharT* key, size_type key_size)
*/
const_iterator longest_prefix_ks(const CharT* key, size_type key_size) const {
return m_ht.longest_prefix(key, key_size);
}
#ifdef TSL_HT_HAS_STRING_VIEW
/**
* @copydoc longest_prefix_ks(const CharT* key, size_type key_size)
*/
iterator longest_prefix(const std::basic_string_view<CharT>& key) {
return m_ht.longest_prefix(key.data(), key.size());
}
/**
* @copydoc longest_prefix_ks(const CharT* key, size_type key_size)
*/
const_iterator longest_prefix(
const std::basic_string_view<CharT>& key) const {
return m_ht.longest_prefix(key.data(), key.size());
}
#else
/**
* @copydoc longest_prefix_ks(const CharT* key, size_type key_size)
*/
iterator longest_prefix(const CharT* key) {
return m_ht.longest_prefix(key, std::strlen(key));
}
/**
* @copydoc longest_prefix_ks(const CharT* key, size_type key_size)
*/
const_iterator longest_prefix(const CharT* key) const {
return m_ht.longest_prefix(key, std::strlen(key));
}
/**
* @copydoc longest_prefix_ks(const CharT* key, size_type key_size)
*/
iterator longest_prefix(const std::basic_string<CharT>& key) {
return m_ht.longest_prefix(key.data(), key.size());
}
/**
* @copydoc longest_prefix_ks(const CharT* key, size_type key_size)
*/
const_iterator longest_prefix(const std::basic_string<CharT>& key) const {
return m_ht.longest_prefix(key.data(), key.size());
}
#endif
/*
* Hash policy
*/
float max_load_factor() const { return m_ht.max_load_factor(); }
void max_load_factor(float ml) { m_ht.max_load_factor(ml); }
/*
* Burst policy
*/
size_type burst_threshold() const { return m_ht.burst_threshold(); }
void burst_threshold(size_type threshold) { m_ht.burst_threshold(threshold); }
/*
* Observers
*/
hasher hash_function() const { return m_ht.hash_function(); }
/*
* Other
*/
/**
* Serialize the map through the `serializer` parameter.
*
* The `serializer` parameter must be a function object that supports the
* following calls:
* - `void operator()(const U& value);` where the types `std::uint64_t`,
* `float` and `T` must be supported for U.
* - `void operator()(const CharT* value, std::size_t value_size);`
*
* The implementation leaves binary compatibility (endianness, IEEE 754 for
* floats, ...) of the types it serializes in the hands of the `Serializer`
* function object if compatibility is required.
*/
template <class Serializer>
void serialize(Serializer& serializer) const {
m_ht.serialize(serializer);
}
/**
* Deserialize a previously serialized map through the `deserializer`
* parameter.
*
* The `deserializer` parameter must be a function object that supports the
* following calls:
* - `template<typename U> U operator()();` where the types `std::uint64_t`,
* `float` and `T` must be supported for U.
* - `void operator()(CharT* value_out, std::size_t value_size);`
*
* If the deserialized hash map part of the hat-trie is hash compatible with
* the serialized map, the deserialization process can be sped up by setting
* `hash_compatible` to true. To be hash compatible, the Hash (take care of
* the 32-bits vs 64 bits), and KeySizeT must behave the same than the ones
* used in the serialized map. Otherwise the behaviour is undefined with
* `hash_compatible` sets to true.
*
* The behaviour is undefined if the type `CharT` and `T` of the `htrie_map`
* are not the same as the types used during serialization.
*
* The implementation leaves binary compatibility (endianness, IEEE 754 for
* floats, size of int, ...) of the types it deserializes in the hands of the
* `Deserializer` function object if compatibility is required.
*/
template <class Deserializer>
static htrie_map deserialize(Deserializer& deserializer,
bool hash_compatible = false) {
htrie_map map;
map.m_ht.deserialize(deserializer, hash_compatible);
return map;
}
friend bool operator==(const htrie_map& lhs, const htrie_map& rhs) {
if (lhs.size() != rhs.size()) {
return false;
}
std::string key_buffer;
for (auto it = lhs.cbegin(); it != lhs.cend(); ++it) {
it.key(key_buffer);
const auto it_element_rhs = rhs.find(key_buffer);
if (it_element_rhs == rhs.cend() ||
it.value() != it_element_rhs.value()) {
return false;
}
}
return true;
}
friend bool operator!=(const htrie_map& lhs, const htrie_map& rhs) {
return !operator==(lhs, rhs);
}
friend void swap(htrie_map& lhs, htrie_map& rhs) { lhs.swap(rhs); }
private:
template <class U, class V>
void insert_pair(const std::pair<U, V>& value) {
insert(value.first, value.second);
}
template <class U, class V>
void insert_pair(std::pair<U, V>&& value) {
insert(value.first, std::move(value.second));
}
private:
ht m_ht;
};
} // end namespace tsl
#endif

578
include/tsl/htrie_set.h Normal file
View File

@ -0,0 +1,578 @@
/**
* MIT License
*
* Copyright (c) 2017 Thibaut Goetghebuer-Planchon <tessil@gmx.com>
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#ifndef TSL_HTRIE_SET_H
#define TSL_HTRIE_SET_H
#include <cstddef>
#include <cstring>
#include <initializer_list>
#include <string>
#include <utility>
#include "htrie_hash.h"
namespace tsl {
/**
* Implementation of a hat-trie set.
*
* The size of a key string is limited to std::numeric_limits<KeySizeT>::max()
* - 1. That is 65 535 characters by default, but can be raised with the
* KeySizeT template parameter. See max_key_size() for an easy access to this
* limit.
*
* Iterators invalidation:
* - clear, operator=: always invalidate the iterators.
* - insert: always invalidate the iterators.
* - erase: always invalidate the iterators.
*/
template <class CharT, class Hash = tsl::ah::str_hash<CharT>,
class KeySizeT = std::uint16_t>
class htrie_set {
private:
template <typename U>
using is_iterator = tsl::detail_array_hash::is_iterator<U>;
using ht = tsl::detail_htrie_hash::htrie_hash<CharT, void, Hash, KeySizeT>;
public:
using char_type = typename ht::char_type;
using key_size_type = typename ht::key_size_type;
using size_type = typename ht::size_type;
using hasher = typename ht::hasher;
using iterator = typename ht::iterator;
using const_iterator = typename ht::const_iterator;
using prefix_iterator = typename ht::prefix_iterator;
using const_prefix_iterator = typename ht::const_prefix_iterator;
public:
explicit htrie_set(const Hash& hash = Hash())
: m_ht(hash, ht::HASH_NODE_DEFAULT_MAX_LOAD_FACTOR,
ht::DEFAULT_BURST_THRESHOLD) {}
explicit htrie_set(size_type burst_threshold, const Hash& hash = Hash())
: m_ht(hash, ht::HASH_NODE_DEFAULT_MAX_LOAD_FACTOR, burst_threshold) {}
template <class InputIt, typename std::enable_if<
is_iterator<InputIt>::value>::type* = nullptr>
htrie_set(InputIt first, InputIt last, const Hash& hash = Hash())
: htrie_set(hash) {
insert(first, last);
}
#ifdef TSL_HT_HAS_STRING_VIEW
htrie_set(std::initializer_list<std::basic_string_view<CharT>> init,
const Hash& hash = Hash())
: htrie_set(hash) {
insert(init);
}
#else
htrie_set(std::initializer_list<const CharT*> init, const Hash& hash = Hash())
: htrie_set(hash) {
insert(init);
}
#endif
#ifdef TSL_HT_HAS_STRING_VIEW
htrie_set& operator=(
std::initializer_list<std::basic_string_view<CharT>> ilist) {
clear();
insert(ilist);
return *this;
}
#else
htrie_set& operator=(std::initializer_list<const CharT*> ilist) {
clear();
insert(ilist);
return *this;
}
#endif
/*
* Iterators
*/
iterator begin() noexcept { return m_ht.begin(); }
const_iterator begin() const noexcept { return m_ht.begin(); }
const_iterator cbegin() const noexcept { return m_ht.cbegin(); }
iterator end() noexcept { return m_ht.end(); }
const_iterator end() const noexcept { return m_ht.end(); }
const_iterator cend() const noexcept { return m_ht.cend(); }
/*
* Capacity
*/
bool empty() const noexcept { return m_ht.empty(); }
size_type size() const noexcept { return m_ht.size(); }
size_type max_size() const noexcept { return m_ht.max_size(); }
size_type max_key_size() const noexcept { return m_ht.max_key_size(); }
/**
* Call shrink_to_fit() on each hash node of the hat-trie to reduce its size.
*/
void shrink_to_fit() { m_ht.shrink_to_fit(); }
/*
* Modifiers
*/
void clear() noexcept { m_ht.clear(); }
std::pair<iterator, bool> insert_ks(const CharT* key, size_type key_size) {
return m_ht.insert(key, key_size);
}
#ifdef TSL_HT_HAS_STRING_VIEW
std::pair<iterator, bool> insert(const std::basic_string_view<CharT>& key) {
return m_ht.insert(key.data(), key.size());
}
#else
std::pair<iterator, bool> insert(const CharT* key) {
return m_ht.insert(key, std::strlen(key));
}
std::pair<iterator, bool> insert(const std::basic_string<CharT>& key) {
return m_ht.insert(key.data(), key.size());
}
#endif
template <class InputIt, typename std::enable_if<
is_iterator<InputIt>::value>::type* = nullptr>
void insert(InputIt first, InputIt last) {
for (auto it = first; it != last; ++it) {
insert(*it);
}
}
#ifdef TSL_HT_HAS_STRING_VIEW
void insert(std::initializer_list<std::basic_string_view<CharT>> ilist) {
insert(ilist.begin(), ilist.end());
}
#else
void insert(std::initializer_list<const CharT*> ilist) {
insert(ilist.begin(), ilist.end());
}
#endif
std::pair<iterator, bool> emplace_ks(const CharT* key, size_type key_size) {
return m_ht.insert(key, key_size);
}
#ifdef TSL_HT_HAS_STRING_VIEW
std::pair<iterator, bool> emplace(const std::basic_string_view<CharT>& key) {
return m_ht.insert(key.data(), key.size());
}
#else
std::pair<iterator, bool> emplace(const CharT* key) {
return m_ht.insert(key, std::strlen(key));
}
std::pair<iterator, bool> emplace(const std::basic_string<CharT>& key) {
return m_ht.insert(key.data(), key.size());
}
#endif
iterator erase(const_iterator pos) { return m_ht.erase(pos); }
iterator erase(const_iterator first, const_iterator last) {
return m_ht.erase(first, last);
}
size_type erase_ks(const CharT* key, size_type key_size) {
return m_ht.erase(key, key_size);
}
#ifdef TSL_HT_HAS_STRING_VIEW
size_type erase(const std::basic_string_view<CharT>& key) {
return m_ht.erase(key.data(), key.size());
}
#else
size_type erase(const CharT* key) {
return m_ht.erase(key, std::strlen(key));
}
size_type erase(const std::basic_string<CharT>& key) {
return m_ht.erase(key.data(), key.size());
}
#endif
/**
* Erase all the elements which have 'prefix' as prefix. Return the number of
* erase elements.
*/
size_type erase_prefix_ks(const CharT* prefix, size_type prefix_size) {
return m_ht.erase_prefix(prefix, prefix_size);
}
#ifdef TSL_HT_HAS_STRING_VIEW
/**
* @copydoc erase_prefix_ks(const CharT* prefix, size_type prefix_size)
*/
size_type erase_prefix(const std::basic_string_view<CharT>& prefix) {
return m_ht.erase_prefix(prefix.data(), prefix.size());
}
#else
/**
* @copydoc erase_prefix_ks(const CharT* prefix, size_type prefix_size)
*/
size_type erase_prefix(const CharT* prefix) {
return m_ht.erase_prefix(prefix, std::strlen(prefix));
}
/**
* @copydoc erase_prefix_ks(const CharT* prefix, size_type prefix_size)
*/
size_type erase_prefix(const std::basic_string<CharT>& prefix) {
return m_ht.erase_prefix(prefix.data(), prefix.size());
}
#endif
void swap(htrie_set& other) { other.m_ht.swap(m_ht); }
/*
* Lookup
*/
size_type count_ks(const CharT* key, size_type key_size) const {
return m_ht.count(key, key_size);
}
#ifdef TSL_HT_HAS_STRING_VIEW
size_type count(const std::basic_string_view<CharT>& key) const {
return m_ht.count(key.data(), key.size());
}
#else
size_type count(const CharT* key) const {
return m_ht.count(key, std::strlen(key));
}
size_type count(const std::basic_string<CharT>& key) const {
return m_ht.count(key.data(), key.size());
}
#endif
iterator find_ks(const CharT* key, size_type key_size) {
return m_ht.find(key, key_size);
}
const_iterator find_ks(const CharT* key, size_type key_size) const {
return m_ht.find(key, key_size);
}
#ifdef TSL_HT_HAS_STRING_VIEW
iterator find(const std::basic_string_view<CharT>& key) {
return m_ht.find(key.data(), key.size());
}
const_iterator find(const std::basic_string_view<CharT>& key) const {
return m_ht.find(key.data(), key.size());
}
#else
iterator find(const CharT* key) { return m_ht.find(key, std::strlen(key)); }
const_iterator find(const CharT* key) const {
return m_ht.find(key, std::strlen(key));
}
iterator find(const std::basic_string<CharT>& key) {
return m_ht.find(key.data(), key.size());
}
const_iterator find(const std::basic_string<CharT>& key) const {
return m_ht.find(key.data(), key.size());
}
#endif
std::pair<iterator, iterator> equal_range_ks(const CharT* key,
size_type key_size) {
return m_ht.equal_range(key, key_size);
}
std::pair<const_iterator, const_iterator> equal_range_ks(
const CharT* key, size_type key_size) const {
return m_ht.equal_range(key, key_size);
}
#ifdef TSL_HT_HAS_STRING_VIEW
std::pair<iterator, iterator> equal_range(
const std::basic_string_view<CharT>& key) {
return m_ht.equal_range(key.data(), key.size());
}
std::pair<const_iterator, const_iterator> equal_range(
const std::basic_string_view<CharT>& key) const {
return m_ht.equal_range(key.data(), key.size());
}
#else
std::pair<iterator, iterator> equal_range(const CharT* key) {
return m_ht.equal_range(key, std::strlen(key));
}
std::pair<const_iterator, const_iterator> equal_range(
const CharT* key) const {
return m_ht.equal_range(key, std::strlen(key));
}
std::pair<iterator, iterator> equal_range(
const std::basic_string<CharT>& key) {
return m_ht.equal_range(key.data(), key.size());
}
std::pair<const_iterator, const_iterator> equal_range(
const std::basic_string<CharT>& key) const {
return m_ht.equal_range(key.data(), key.size());
}
#endif
/**
* Return a range containing all the elements which have 'prefix' as prefix.
* The range is defined by a pair of iterator, the first being the begin
* iterator and the second being the end iterator.
*/
std::pair<prefix_iterator, prefix_iterator> equal_prefix_range_ks(
const CharT* prefix, size_type prefix_size) {
return m_ht.equal_prefix_range(prefix, prefix_size);
}
/**
* @copydoc equal_prefix_range_ks(const CharT* prefix, size_type prefix_size)
*/
std::pair<const_prefix_iterator, const_prefix_iterator> equal_prefix_range_ks(
const CharT* prefix, size_type prefix_size) const {
return m_ht.equal_prefix_range(prefix, prefix_size);
}
#ifdef TSL_HT_HAS_STRING_VIEW
/**
* @copydoc equal_prefix_range_ks(const CharT* prefix, size_type prefix_size)
*/
std::pair<prefix_iterator, prefix_iterator> equal_prefix_range(
const std::basic_string_view<CharT>& prefix) {
return m_ht.equal_prefix_range(prefix.data(), prefix.size());
}
/**
* @copydoc equal_prefix_range_ks(const CharT* prefix, size_type prefix_size)
*/
std::pair<const_prefix_iterator, const_prefix_iterator> equal_prefix_range(
const std::basic_string_view<CharT>& prefix) const {
return m_ht.equal_prefix_range(prefix.data(), prefix.size());
}
#else
/**
* @copydoc equal_prefix_range_ks(const CharT* prefix, size_type prefix_size)
*/
std::pair<prefix_iterator, prefix_iterator> equal_prefix_range(
const CharT* prefix) {
return m_ht.equal_prefix_range(prefix, std::strlen(prefix));
}
/**
* @copydoc equal_prefix_range_ks(const CharT* prefix, size_type prefix_size)
*/
std::pair<const_prefix_iterator, const_prefix_iterator> equal_prefix_range(
const CharT* prefix) const {
return m_ht.equal_prefix_range(prefix, std::strlen(prefix));
}
/**
* @copydoc equal_prefix_range_ks(const CharT* prefix, size_type prefix_size)
*/
std::pair<prefix_iterator, prefix_iterator> equal_prefix_range(
const std::basic_string<CharT>& prefix) {
return m_ht.equal_prefix_range(prefix.data(), prefix.size());
}
/**
* @copydoc equal_prefix_range_ks(const CharT* prefix, size_type prefix_size)
*/
std::pair<const_prefix_iterator, const_prefix_iterator> equal_prefix_range(
const std::basic_string<CharT>& prefix) const {
return m_ht.equal_prefix_range(prefix.data(), prefix.size());
}
#endif
/**
* Return the element in the trie which is the longest prefix of `key`. If no
* element in the trie is a prefix of `key`, the end iterator is returned.
*
* Example:
*
* tsl::htrie_set<char> set = {"/foo", "/foo/bar"};
*
* set.longest_prefix("/foo"); // returns "/foo"
* set.longest_prefix("/foo/baz"); // returns "/foo"
* set.longest_prefix("/foo/bar/baz"); // returns "/foo/bar"
* set.longest_prefix("/foo/bar/"); // returns "/foo/bar"
* set.longest_prefix("/bar"); // returns end()
* set.longest_prefix(""); // returns end()
*/
iterator longest_prefix_ks(const CharT* key, size_type key_size) {
return m_ht.longest_prefix(key, key_size);
}
/**
* @copydoc longest_prefix_ks(const CharT* key, size_type key_size)
*/
const_iterator longest_prefix_ks(const CharT* key, size_type key_size) const {
return m_ht.longest_prefix(key, key_size);
}
#ifdef TSL_HT_HAS_STRING_VIEW
/**
* @copydoc longest_prefix_ks(const CharT* key, size_type key_size)
*/
iterator longest_prefix(const std::basic_string_view<CharT>& key) {
return m_ht.longest_prefix(key.data(), key.size());
}
/**
* @copydoc longest_prefix_ks(const CharT* key, size_type key_size)
*/
const_iterator longest_prefix(
const std::basic_string_view<CharT>& key) const {
return m_ht.longest_prefix(key.data(), key.size());
}
#else
/**
* @copydoc longest_prefix_ks(const CharT* key, size_type key_size)
*/
iterator longest_prefix(const CharT* key) {
return m_ht.longest_prefix(key, std::strlen(key));
}
/**
* @copydoc longest_prefix_ks(const CharT* key, size_type key_size)
*/
const_iterator longest_prefix(const CharT* key) const {
return m_ht.longest_prefix(key, std::strlen(key));
}
/**
* @copydoc longest_prefix_ks(const CharT* key, size_type key_size)
*/
iterator longest_prefix(const std::basic_string<CharT>& key) {
return m_ht.longest_prefix(key.data(), key.size());
}
/**
* @copydoc longest_prefix_ks(const CharT* key, size_type key_size)
*/
const_iterator longest_prefix(const std::basic_string<CharT>& key) const {
return m_ht.longest_prefix(key.data(), key.size());
}
#endif
/*
* Hash policy
*/
float max_load_factor() const { return m_ht.max_load_factor(); }
void max_load_factor(float ml) { m_ht.max_load_factor(ml); }
/*
* Burst policy
*/
size_type burst_threshold() const { return m_ht.burst_threshold(); }
void burst_threshold(size_type threshold) { m_ht.burst_threshold(threshold); }
/*
* Observers
*/
hasher hash_function() const { return m_ht.hash_function(); }
/*
* Other
*/
/**
* Serialize the set through the `serializer` parameter.
*
* The `serializer` parameter must be a function object that supports the
* following calls:
* - `void operator()(const U& value);` where the types `std::uint64_t` and
* `float` must be supported for U.
* - `void operator()(const CharT* value, std::size_t value_size);`
*
* The implementation leaves binary compatibility (endianness, IEEE 754 for
* floats, ...) of the types it serializes in the hands of the `Serializer`
* function object if compatibility is required.
*/
template <class Serializer>
void serialize(Serializer& serializer) const {
m_ht.serialize(serializer);
}
/**
* Deserialize a previously serialized set through the `deserializer`
* parameter.
*
* The `deserializer` parameter must be a function object that supports the
* following calls:
* - `template<typename U> U operator()();` where the types `std::uint64_t`
* and `float` must be supported for U.
* - `void operator()(CharT* value_out, std::size_t value_size);`
*
* If the deserialized hash set part of the hat-trie is hash compatible with
* the serialized set, the deserialization process can be sped up by setting
* `hash_compatible` to true. To be hash compatible, the Hash (take care of
* the 32-bits vs 64 bits), and KeySizeT must behave the same than the ones
* used in the serialized set. Otherwise the behaviour is undefined with
* `hash_compatible` sets to true.
*
* The behaviour is undefined if the type `CharT` of the `htrie_set` is not
* the same as the type used during serialization.
*
* The implementation leaves binary compatibility (endianness, IEEE 754 for
* floats, size of int, ...) of the types it deserializes in the hands of the
* `Deserializer` function object if compatibility is required.
*/
template <class Deserializer>
static htrie_set deserialize(Deserializer& deserializer,
bool hash_compatible = false) {
htrie_set set;
set.m_ht.deserialize(deserializer, hash_compatible);
return set;
}
friend bool operator==(const htrie_set& lhs, const htrie_set& rhs) {
if (lhs.size() != rhs.size()) {
return false;
}
std::string key_buffer;
for (auto it = lhs.cbegin(); it != lhs.cend(); ++it) {
it.key(key_buffer);
const auto it_element_rhs = rhs.find(key_buffer);
if (it_element_rhs == rhs.cend()) {
return false;
}
}
return true;
}
friend bool operator!=(const htrie_set& lhs, const htrie_set& rhs) {
return !operator==(lhs, rhs);
}
friend void swap(htrie_set& lhs, htrie_set& rhs) { lhs.swap(rhs); }
private:
ht m_ht;
};
} // end namespace tsl
#endif

View File

@ -689,7 +689,10 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query, const s
const size_t min_len_1typo,
const size_t min_len_2typo,
bool split_join_tokens,
const size_t max_candidates) const {
const size_t max_candidates,
const std::vector<infix_t>& infixes,
const size_t max_extra_prefix,
const size_t max_extra_suffix) const {
std::shared_lock lock(mutex);
@ -721,6 +724,13 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query, const s
}
}
if(!search_fields.empty() && search_fields.size() != infixes.size()) {
if(infixes.size() != 1) {
return Option<nlohmann::json>(400, "Number of infix values in `infix` does not match "
"number of `query_by` fields.");
}
}
// process weights for search fields
std::vector<search_field_t> weighted_search_fields;
size_t max_weight = 20;
@ -1005,7 +1015,8 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query, const s
group_by_fields, group_limit, default_sorting_field, prioritize_exact_match,
exhaustive_search, 4, filter_overrides,
search_stop_millis,
min_len_1typo, min_len_2typo, max_candidates);
min_len_1typo, min_len_2typo, max_candidates, infixes,
max_extra_prefix, max_extra_suffix);
index->run_search(search_params);

View File

@ -5,6 +5,7 @@
#include "collection_manager.h"
#include "batched_indexer.h"
#include "logger.h"
#include "magic_enum.hpp"
constexpr const size_t CollectionManager::DEFAULT_NUM_MEMORY_SHARDS;
@ -567,6 +568,10 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
const char *ENABLE_OVERRIDES = "enable_overrides";
const char *MAX_CANDIDATES = "max_candidates";
const char *INFIX = "infix";
const char *MAX_EXTRA_PREFIX = "max_extra_prefix";
const char *MAX_EXTRA_SUFFIX = "max_extra_suffix";
// strings under this length will be fully highlighted, instead of showing a snippet of relevant portion
const char *SNIPPET_THRESHOLD = "snippet_threshold";
@ -708,6 +713,14 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
req_params[SPLIT_JOIN_TOKENS] = "true";
}
if(req_params.count(MAX_EXTRA_PREFIX) == 0) {
req_params[MAX_EXTRA_PREFIX] = std::to_string(INT16_MAX);
}
if(req_params.count(MAX_EXTRA_SUFFIX) == 0) {
req_params[MAX_EXTRA_SUFFIX] = std::to_string(INT16_MAX);
}
std::vector<std::string> query_by_weights_str;
std::vector<size_t> query_by_weights;
@ -781,6 +794,14 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
return Option<bool>(400,"Parameter `" + std::string(SEARCH_CUTOFF_MS) + "` must be an unsigned integer.");
}
if(!StringUtils::is_uint32_t(req_params[MAX_EXTRA_PREFIX])) {
return Option<bool>(400,"Parameter `" + std::string(MAX_EXTRA_PREFIX) + "` must be an unsigned integer.");
}
if(!StringUtils::is_uint32_t(req_params[MAX_EXTRA_SUFFIX])) {
return Option<bool>(400,"Parameter `" + std::string(MAX_EXTRA_SUFFIX) + "` must be an unsigned integer.");
}
bool prioritize_exact_match = (req_params[PRIORITIZE_EXACT_MATCH] == "true");
bool pre_segmented_query = (req_params[PRE_SEGMENTED_QUERY] == "true");
bool exhaustive_search = (req_params[EXHAUSTIVE_SEARCH] == "true");
@ -833,6 +854,21 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
req_params[MAX_CANDIDATES] = exhaustive_search ? "10000" : "4";
}
std::vector<infix_t> infixes;
if(req_params.count(INFIX) != 0) {
std::vector<std::string> infix_strs;
StringUtils::split(req_params[INFIX], infix_strs, ",");
for(auto& infix_str: infix_strs) {
auto infix_op = magic_enum::enum_cast<infix_t>(infix_str);
if(infix_op.has_value()) {
infixes.push_back(infix_op.value());
}
}
} else {
infixes.push_back(off);
}
bool enable_overrides = (req_params[ENABLE_OVERRIDES] == "true");
CollectionManager & collectionManager = CollectionManager::get_instance();
@ -898,7 +934,10 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
static_cast<size_t>(std::stol(req_params[MIN_LEN_1TYPO])),
static_cast<size_t>(std::stol(req_params[MIN_LEN_2TYPO])),
split_join_tokens,
static_cast<size_t>(std::stol(req_params[MAX_CANDIDATES]))
static_cast<size_t>(std::stol(req_params[MAX_CANDIDATES])),
infixes,
static_cast<size_t>(std::stol(req_params[MAX_EXTRA_PREFIX])),
static_cast<size_t>(std::stol(req_params[MAX_EXTRA_SUFFIX]))
);
uint64_t timeMillis = std::chrono::duration_cast<std::chrono::milliseconds>(

View File

@ -72,6 +72,16 @@ Index::Index(const std::string& name, const uint32_t collection_id, const Store*
art_tree_init(ft);
search_index.emplace(fname_field.second.faceted_name(), ft);
}
if(fname_field.second.infix) {
array_mapped_infix_t infix_sets(ARRAY_INFIX_DIM);
for(auto& infix_set: infix_sets) {
infix_set = new tsl::htrie_set<char>();
}
infix_index.emplace(fname_field.second.name, infix_sets);
}
}
for(const auto & pair: sort_schema) {
@ -139,6 +149,15 @@ Index::~Index() {
sort_index.clear();
for(auto& kv: infix_index) {
for(auto& infix_set: kv.second) {
delete infix_set;
infix_set = nullptr;
}
}
infix_index.clear();
for(auto& name_tree: str_sort_index) {
delete name_tree.second;
name_tree.second = nullptr;
@ -660,6 +679,12 @@ void Index::index_field_in_memory(const field& afield, std::vector<index_record>
for(auto &token_offsets: field_index_it->second.offsets) {
token_to_doc_offsets[token_offsets.first].emplace_back(seq_id, record.points, token_offsets.second);
if(afield.infix) {
auto strhash = StringUtils::hash_wy(token_offsets.first.c_str(), token_offsets.first.size());
const auto& infix_sets = infix_index.at(afield.name);
infix_sets[strhash % 4]->insert(token_offsets.first);
}
}
}
@ -1638,7 +1663,10 @@ void Index::run_search(search_args* search_params) {
search_params->search_cutoff_ms,
search_params->min_len_1typo,
search_params->min_len_2typo,
search_params->max_candidates);
search_params->max_candidates,
search_params->infixes,
search_params->max_extra_prefix,
search_params->max_extra_suffix);
}
void Index::collate_included_ids(const std::vector<std::string>& q_included_tokens,
@ -2013,6 +2041,59 @@ bool Index::check_for_overrides(const token_ordering& token_order, const string&
return false;
}
void Index::search_infix(const std::string& query, const std::string& field_name,
std::vector<uint32_t>& ids, const size_t max_extra_prefix, const size_t max_extra_suffix) const {
auto infix_maps_it = infix_index.find(field_name);
if(infix_maps_it == infix_index.end()) {
return ;
}
auto infix_sets = infix_maps_it->second;
std::vector<art_leaf*> leaves;
size_t num_processed = 0;
std::mutex m_process;
std::condition_variable cv_process;
auto search_tree = search_index.at(field_name);
for(auto infix_set: infix_sets) {
thread_pool->enqueue([infix_set, &leaves, search_tree, &query, max_extra_prefix, max_extra_suffix,
&num_processed, &m_process, &cv_process]() {
std::vector<art_leaf*> this_leaves;
std::string key_buffer;
for(auto it = infix_set->begin(); it != infix_set->end(); it++) {
it.key(key_buffer);
auto start_index = key_buffer.find(query);
if(start_index != std::string::npos && start_index <= max_extra_prefix &&
(key_buffer.size() - (start_index + query.size())) <= max_extra_suffix) {
art_leaf* l = (art_leaf *) art_search(search_tree,
(const unsigned char *) key_buffer.c_str(),
key_buffer.size()+1);
if(l != nullptr) {
this_leaves.push_back(l);
}
}
}
std::unique_lock<std::mutex> lock(m_process);
leaves.insert(leaves.end(), this_leaves.begin(), this_leaves.end());
num_processed++;
cv_process.notify_one();
});
}
std::unique_lock<std::mutex> lock_process(m_process);
cv_process.wait(lock_process, [&](){ return num_processed == infix_sets.size(); });
for(auto leaf: leaves) {
posting_t::merge({leaf->values}, ids);
}
}
void Index::search(std::vector<query_tokens_t>& field_query_tokens,
const std::vector<search_field_t>& search_fields,
std::vector<filter>& filters,
@ -2040,7 +2121,10 @@ void Index::search(std::vector<query_tokens_t>& field_query_tokens,
const size_t search_cutoff_ms,
size_t min_len_1typo,
size_t min_len_2typo,
const size_t max_candidates) const {
const size_t max_candidates,
const std::vector<infix_t>& infixes,
const size_t max_extra_prefix,
const size_t max_extra_suffix) const {
search_begin = std::chrono::high_resolution_clock::now();
search_stop_ms = search_cutoff_ms;
@ -2247,6 +2331,7 @@ void Index::search(std::vector<query_tokens_t>& field_query_tokens,
int field_num_typos = (i < num_typos.size()) ? num_typos[i] : num_typos[0];
bool field_prefix = (i < prefixes.size()) ? prefixes[i] : prefixes[0];
infix_t field_infix = (i < infixes.size()) ? infixes[i] : infixes[0];
// proceed to query search only when no filters are provided or when filtering produces results
if(filters.empty() || actual_filter_ids_length > 0) {
@ -2282,6 +2367,34 @@ void Index::search(std::vector<query_tokens_t>& field_query_tokens,
query_hashes, token_order, field_prefix,
drop_tokens_threshold, typo_tokens_threshold, exhaustive_search,
min_len_1typo, min_len_2typo, max_candidates);
if(field_infix == always || (field_infix == fallback && field_num_results == 0)) {
std::vector<uint32_t> infix_ids;
search_infix(query_tokens[0].value, field_name, infix_ids, max_extra_prefix, max_extra_suffix);
if(!infix_ids.empty()) {
int sort_order[3]; // 1 or -1 based on DESC or ASC respectively
std::array<spp::sparse_hash_map<uint32_t, int64_t>*, 3> field_values;
std::vector<size_t> geopoint_indices;
populate_sort_mapping(sort_order, geopoint_indices, sort_fields_std, field_values);
uint32_t token_bits = 255;
for(auto seq_id: infix_ids) {
score_results(sort_fields_std, (uint16_t) searched_queries.size(), field_id, false, 2,
actual_topster, {}, groups_processed, seq_id, sort_order, field_values,
geopoint_indices, group_limit, group_by_fields, token_bits,
false, false, {});
}
std::sort(infix_ids.begin(), infix_ids.end());
infix_ids.erase(std::unique( infix_ids.begin(), infix_ids.end() ), infix_ids.end());
uint32_t* new_all_result_ids = nullptr;
all_result_ids_len = ArrayUtils::or_scalar(all_result_ids, all_result_ids_len, &infix_ids[0],
infix_ids.size(), &new_all_result_ids);
delete[] all_result_ids;
all_result_ids = new_all_result_ids;
}
}
} else if(actual_filter_ids_length != 0) {
// indicates exact match query
curate_filtered_ids(filters, curated_ids, exclude_token_ids,
@ -3708,6 +3821,15 @@ void Index::refresh_schemas(const std::vector<field>& new_fields) {
search_index.emplace(new_field.faceted_name(), ft);
}
}
if(new_field.infix) {
array_mapped_infix_t infix_sets(ARRAY_INFIX_DIM);
for(auto& infix_set: infix_sets) {
infix_set = new tsl::htrie_set<char>();
}
infix_index.emplace(new_field.name, infix_sets);
}
}
}

View File

@ -581,6 +581,17 @@ void posting_list_t::merge(const std::vector<posting_list_t*>& posting_lists, st
sum_sizes += posting_list->num_ids();
}
if(its.size() == 1) {
result_ids.reserve(posting_lists[0]->ids_length);
auto it = posting_lists[0]->new_iterator();
while(it.valid()) {
result_ids.push_back(it.id());
it.next();
}
return ;
}
result_ids.reserve(sum_sizes);
size_t num_lists = its.size();

View File

@ -0,0 +1,222 @@
#include <gtest/gtest.h>
#include <string>
#include <vector>
#include <fstream>
#include <algorithm>
#include <collection_manager.h>
#include "collection.h"
class CollectionInfixSearchTest : public ::testing::Test {
protected:
Store *store;
CollectionManager & collectionManager = CollectionManager::get_instance();
std::atomic<bool> quit = false;
std::vector<std::string> query_fields;
std::vector<sort_by> sort_fields;
void setupCollection() {
std::string state_dir_path = "/tmp/typesense_test/collection_infix";
LOG(INFO) << "Truncating and creating: " << state_dir_path;
system(("rm -rf "+state_dir_path+" && mkdir -p "+state_dir_path).c_str());
store = new Store(state_dir_path);
collectionManager.init(store, 1.0, "auth_key", quit);
collectionManager.load(8, 1000);
}
virtual void SetUp() {
setupCollection();
}
virtual void TearDown() {
collectionManager.dispose();
delete store;
}
};
TEST_F(CollectionInfixSearchTest, InfixBasics) {
std::vector<field> fields = {field("title", field_types::STRING, false, false, true, "", -1, 1),
field("points", field_types::INT32, false),};
Collection* coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get();
nlohmann::json doc;
doc["id"] = "0";
doc["title"] = "GH100037IN8900X";
doc["points"] = 100;
ASSERT_TRUE(coll1->add(doc.dump()).ok());
auto results = coll1->search("100037",
{"title"}, "", {}, {}, {0}, 3, 1, FREQUENCY, {true}, 5,
spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0,
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, true,
4, {always}).get();
ASSERT_EQ(1, results["found"].get<size_t>());
ASSERT_EQ(1, results["hits"].size());
ASSERT_STREQ("0", results["hits"][0]["document"]["id"].get<std::string>().c_str());
// verify off behavior
results = coll1->search("100037",
{"title"}, "", {}, {}, {0}, 3, 1, FREQUENCY, {true}, 5,
spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0,
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, true,
4, {off}).get();
ASSERT_EQ(0, results["found"].get<size_t>());
ASSERT_EQ(0, results["hits"].size());
// when fallback is used, only the prefix result is returned
doc["id"] = "1";
doc["title"] = "100037SG7120X";
doc["points"] = 100;
ASSERT_TRUE(coll1->add(doc.dump()).ok());
results = coll1->search("100037",
{"title"}, "", {}, {}, {0}, 3, 1, FREQUENCY, {true}, 5,
spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0,
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, true,
4, {fallback}).get();
ASSERT_EQ(1, results["found"].get<size_t>());
ASSERT_EQ(1, results["hits"].size());
ASSERT_STREQ("1", results["hits"][0]["document"]["id"].get<std::string>().c_str());
// always behavior: both prefix and infix matches are returned but ranked below prefix match
results = coll1->search("100037",
{"title"}, "", {}, {}, {0}, 3, 1, FREQUENCY, {true}, 5,
spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0,
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, true,
4, {always}).get();
ASSERT_EQ(2, results["found"].get<size_t>());
ASSERT_EQ(2, results["hits"].size());
ASSERT_STREQ("1", results["hits"][0]["document"]["id"].get<std::string>().c_str());
ASSERT_STREQ("0", results["hits"][1]["document"]["id"].get<std::string>().c_str());
ASSERT_TRUE(results["hits"][0]["text_match"].get<size_t>() > results["hits"][1]["text_match"].get<size_t>());
collectionManager.drop_collection("coll1");
}
TEST_F(CollectionInfixSearchTest, RespectPrefixAndSuffixLimits) {
std::vector<field> fields = {field("title", field_types::STRING, false, false, true, "", -1, 1),
field("points", field_types::INT32, false),};
Collection* coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get();
nlohmann::json doc;
doc["id"] = "0";
doc["title"] = "GH100037IN8900X";
doc["points"] = 100;
ASSERT_TRUE(coll1->add(doc.dump()).ok());
doc["id"] = "1";
doc["title"] = "X100037SG89007120X";
doc["points"] = 100;
ASSERT_TRUE(coll1->add(doc.dump()).ok());
// check extra prefixes
auto results = coll1->search("100037",
{"title"}, "", {}, {}, {0}, 3, 1, FREQUENCY, {true}, 5,
spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0,
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, true,
4, {always}, 1).get();
ASSERT_EQ(1, results["found"].get<size_t>());
ASSERT_EQ(1, results["hits"].size());
ASSERT_STREQ("1", results["hits"][0]["document"]["id"].get<std::string>().c_str());
results = coll1->search("100037",
{"title"}, "", {}, {}, {0}, 3, 1, FREQUENCY, {true}, 5,
spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0,
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, true,
4, {always}, 2).get();
ASSERT_EQ(2, results["found"].get<size_t>());
ASSERT_EQ(2, results["hits"].size());
ASSERT_STREQ("1", results["hits"][0]["document"]["id"].get<std::string>().c_str());
ASSERT_STREQ("0", results["hits"][1]["document"]["id"].get<std::string>().c_str());
// check extra suffixes
results = coll1->search("8900",
{"title"}, "", {}, {}, {0}, 3, 1, FREQUENCY, {true}, 5,
spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0,
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, true,
4, {always}, INT16_MAX, 2).get();
ASSERT_EQ(1, results["found"].get<size_t>());
ASSERT_EQ(1, results["hits"].size());
ASSERT_STREQ("0", results["hits"][0]["document"]["id"].get<std::string>().c_str());
results = coll1->search("8900",
{"title"}, "", {}, {}, {0}, 3, 1, FREQUENCY, {true}, 5,
spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0,
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, true,
4, {always}, INT16_MAX, 5).get();
ASSERT_EQ(2, results["found"].get<size_t>());
ASSERT_EQ(2, results["hits"].size());
ASSERT_STREQ("1", results["hits"][0]["document"]["id"].get<std::string>().c_str());
ASSERT_STREQ("0", results["hits"][1]["document"]["id"].get<std::string>().c_str());
collectionManager.drop_collection("coll1");
}
TEST_F(CollectionInfixSearchTest, InfixSpecificField) {
std::vector<field> fields = {field("title", field_types::STRING, false, false, true, "", -1, 1),
field("description", field_types::STRING, false, false, true, "", -1, 1),
field("points", field_types::INT32, false),};
Collection* coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get();
nlohmann::json doc;
doc["id"] = "0";
doc["title"] = "GH100037IN8900X";
doc["description"] = "foobar";
doc["points"] = 100;
ASSERT_TRUE(coll1->add(doc.dump()).ok());
doc["id"] = "1";
doc["title"] = "foobar";
doc["description"] = "GH100037IN8900X";
doc["points"] = 100;
ASSERT_TRUE(coll1->add(doc.dump()).ok());
auto results = coll1->search("100037",
{"title", "description"}, "", {}, {}, {0}, 3, 1, FREQUENCY, {true}, 5,
spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0,
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, true,
4, {always, off}).get();
ASSERT_EQ(1, results["found"].get<size_t>());
ASSERT_EQ(1, results["hits"].size());
ASSERT_STREQ("0", results["hits"][0]["document"]["id"].get<std::string>().c_str());
results = coll1->search("100037",
{"title", "description"}, "", {}, {}, {0}, 3, 1, FREQUENCY, {true}, 5,
spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0,
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, true,
4, {off, always}).get();
ASSERT_EQ(1, results["found"].get<size_t>());
ASSERT_EQ(1, results["hits"].size());
ASSERT_STREQ("1", results["hits"][0]["document"]["id"].get<std::string>().c_str());
collectionManager.drop_collection("coll1");
}