#pragma once #include #include #include #include #include #include #include #include #include "art.h" #include "index.h" #include "number.h" #include "store.h" #include "topster.h" #include "json.hpp" #include "field.h" #include "option.h" #include "tsl/htrie_map.h" #include "tokenizer.h" #include "synonym_index.h" #include "vq_model_manager.h" struct doc_seq_id_t { uint32_t seq_id; bool is_new; }; struct highlight_field_t { std::string name; bool fully_highlighted; bool infix; bool is_string; tsl::htrie_map qtoken_leaves; highlight_field_t(const std::string& name, bool fully_highlighted, bool infix, bool is_string): name(name), fully_highlighted(fully_highlighted), infix(infix), is_string(is_string) { } }; struct reference_pair { std::string collection; std::string field; reference_pair(std::string collection, std::string field) : collection(std::move(collection)), field(std::move(field)) {} bool operator < (const reference_pair& pair) const { return collection < pair.collection; } }; class Collection { private: mutable std::shared_mutex mutex; // ensures that a Collection* is not destructed while in use by multiple threads mutable std::shared_mutex lifecycle_mutex; const uint8_t CURATED_RECORD_IDENTIFIER = 100; const size_t DEFAULT_TOPSTER_SIZE = 250; struct highlight_t { size_t field_index; std::string field; std::vector snippets; std::vector values; std::vector indices; uint64_t match_score; std::vector> matched_tokens; highlight_t(): field_index(0), match_score(0) { } bool operator<(const highlight_t& a) const { return std::tie(match_score, field_index) > std::tie(a.match_score, field_index); } }; struct match_index_t { Match match; uint64_t match_score = 0; size_t index; match_index_t(Match match, uint64_t match_score, size_t index): match(match), match_score(match_score), index(index) { } bool operator<(const match_index_t& a) const { if(match_score != a.match_score) { return match_score > a.match_score; } return index < a.index; } }; const std::string name; const std::atomic collection_id; const std::atomic created_at; std::atomic num_documents; // Auto incrementing record ID used internally for indexing - not exposed to the client std::atomic next_seq_id; Store* store; std::vector fields; tsl::htrie_map search_schema; std::map overrides; // maps tag name => override_ids std::map> override_tags; std::string default_sorting_field; const float max_memory_ratio; std::string fallback_field_type; std::unordered_map dynamic_fields; tsl::htrie_map nested_fields; tsl::htrie_map embedding_fields; bool enable_nested_fields; std::vector symbols_to_index; std::vector token_separators; SynonymIndex* synonym_index; /// "field name" -> reference_pair(referenced_collection_name, referenced_field_name) spp::sparse_hash_map reference_fields; /// Contains the info where the current collection is referenced. /// Useful to perform operations such as cascading delete. /// collection_name -> field_name spp::sparse_hash_map referenced_in; /// Reference helper fields that are part of an object. The reference doc of these fields will be included in the /// object rather than in the document. tsl::htrie_set object_reference_helper_fields; // Keep index as the last field since it is initialized in the constructor via init_index(). Add a new field before it. Index* index; // methods std::string get_doc_id_key(const std::string & doc_id) const; std::string get_seq_id_key(uint32_t seq_id) const; void highlight_result(const std::string& h_obj, const field &search_field, const size_t search_field_index, const tsl::htrie_map& qtoken_leaves, const KV* field_order_kv, const nlohmann::json &document, nlohmann::json& highlight_doc, StringUtils & string_utils, const size_t snippet_threshold, const size_t highlight_affix_num_tokens, bool highlight_fully, bool is_infix_search, const std::string& highlight_start_tag, const std::string& highlight_end_tag, const uint8_t* index_symbols, highlight_t &highlight, bool& found_highlight, bool& found_full_highlight) const; void remove_document(const nlohmann::json & document, const uint32_t seq_id, bool remove_from_store); void process_remove_field_for_embedding_fields(const field& del_field, std::vector& garbage_embed_fields); bool does_override_match(const override_t& override, std::string& query, std::set& excluded_set, std::string& actual_query, const std::string& filter_query, bool already_segmented, const bool tags_matched, const bool wildcard_tag_matched, const std::map>& pinned_hits, const std::vector& hidden_hits, std::vector>& included_ids, std::vector& excluded_ids, std::vector& filter_overrides, bool& filter_curated_hits, std::string& curated_sort_by, nlohmann::json& override_metadata) const; void curate_results(std::string& actual_query, const std::string& filter_query, bool enable_overrides, bool already_segmented, const std::set& tags, const std::map>& pinned_hits, const std::vector& hidden_hits, std::vector>& included_ids, std::vector& excluded_ids, std::vector& filter_overrides, bool& filter_curated_hits, std::string& curated_sort_by, nlohmann::json& override_metadata) const; static Option detect_new_fields(nlohmann::json& document, const DIRTY_VALUES& dirty_values, const tsl::htrie_map& schema, const std::unordered_map& dyn_fields, tsl::htrie_map& nested_fields, const std::string& fallback_field_type, bool is_update, std::vector& new_fields, bool enable_nested_fields, const spp::sparse_hash_map& reference_fields, tsl::htrie_set& object_reference_helper_fields); static bool facet_count_compare(const facet_count_t& a, const facet_count_t& b) { return std::tie(a.count, a.fhash) > std::tie(b.count, b.fhash); } static bool facet_count_str_compare(const facet_value_t& a, const facet_value_t& b) { size_t a_count = a.count; size_t b_count = b.count; size_t a_value_size = UINT64_MAX - a.value.size(); size_t b_value_size = UINT64_MAX - b.value.size(); return std::tie(a_count, a_value_size, a.value) > std::tie(b_count, b_value_size, b.value); } static Option parse_pinned_hits(const std::string& pinned_hits_str, std::map>& pinned_hits); static Option parse_drop_tokens_mode(const std::string& drop_tokens_mode); Index* init_index(); static std::vector to_char_array(const std::vector& strs); Option validate_and_standardize_sort_fields_with_lock(const std::vector & sort_fields, std::vector& sort_fields_std, bool is_wildcard_query,const bool is_vector_query, const std::string& query, bool is_group_by_query = false, const size_t remote_embedding_timeout_ms = 30000, const size_t remote_embedding_num_tries = 2) const; Option validate_and_standardize_sort_fields(const std::vector & sort_fields, std::vector& sort_fields_std, const bool is_wildcard_query, const bool is_vector_query, const std::string& query, bool is_group_by_query = false, const size_t remote_embedding_timeout_ms = 30000, const size_t remote_embedding_num_tries = 2, const bool is_reference_sort = false) const; Option persist_collection_meta(); Option batch_alter_data(const std::vector& alter_fields, const std::vector& del_fields, const std::string& this_fallback_field_type); Option validate_alter_payload(nlohmann::json& schema_changes, std::vector& addition_fields, std::vector& reindex_fields, std::vector& del_fields, std::string& fallback_field_type); void process_filter_overrides(std::vector& filter_overrides, std::vector& q_include_tokens, token_ordering token_order, filter_node_t*& filter_tree_root, std::vector>& included_ids, std::vector& excluded_ids, nlohmann::json& override_metadata, bool enable_typos_for_numerical_tokens=true, bool enable_typos_for_alpha_numerical_tokens=true) const; void populate_text_match_info(nlohmann::json& info, uint64_t match_score, const text_match_type_t match_type, const size_t total_tokens) const; bool handle_highlight_text(std::string& text, bool normalise, const field &search_field, const std::vector& symbols_to_index, const std::vector& token_separators, highlight_t& highlight, StringUtils & string_utils, bool use_word_tokenizer, const size_t highlight_affix_num_tokens, const tsl::htrie_map& qtoken_leaves, int last_valid_offset_index, const size_t prefix_token_num_chars, bool highlight_fully, const size_t snippet_threshold, bool is_infix_search, std::vector& raw_query_tokens, size_t last_valid_offset, const std::string& highlight_start_tag, const std::string& highlight_end_tag, const uint8_t* index_symbols, const match_index_t& match_index) const; static Option extract_field_name(const std::string& field_name, const tsl::htrie_map& search_schema, std::vector& processed_search_fields, bool extract_only_string_fields, bool enable_nested_fields, const bool handle_wildcard = true, const bool& include_id = false); bool is_nested_array(const nlohmann::json& obj, std::vector path_parts, size_t part_i) const; template static bool highlight_nested_field(const nlohmann::json& hdoc, nlohmann::json& hobj, std::vector& path_parts, size_t path_index, bool is_arr_obj_ele, int array_index, T func); static Option resolve_field_type(field& new_field, nlohmann::detail::iter_impl>& kv, nlohmann::json& document, const DIRTY_VALUES& dirty_values, const bool found_dynamic_field, const std::string& fallback_field_type, bool enable_nested_fields, std::vector& new_fields); static uint64_t extract_bits(uint64_t value, unsigned lsb_offset, unsigned n); Option populate_include_exclude_fields(const spp::sparse_hash_set& include_fields, const spp::sparse_hash_set& exclude_fields, tsl::htrie_set& include_fields_full, tsl::htrie_set& exclude_fields_full) const; Option get_referenced_in_field(const std::string& collection_name) const; Option get_related_ids(const std::string& ref_field_name, const uint32_t& seq_id, std::vector& result) const; Option get_object_array_related_id(const std::string& ref_field_name, const uint32_t& seq_id, const uint32_t& object_index, uint32_t& result) const; void remove_embedding_field(const std::string& field_name); Option parse_and_validate_vector_query(const std::string& vector_query_str, vector_query_t& vector_query, const bool is_wildcard_query, const size_t remote_embedding_timeout_ms, const size_t remote_embedding_num_tries, size_t& per_page) const; std::shared_ptr vq_model = nullptr; public: enum {MAX_ARRAY_MATCHES = 5}; const size_t GROUP_LIMIT_MAX = 99; // Using a $ prefix so that these meta keys stay above record entries in a lexicographically ordered KV store static constexpr const char* COLLECTION_META_PREFIX = "$CM"; static constexpr const char* COLLECTION_NEXT_SEQ_PREFIX = "$CS"; static constexpr const char* COLLECTION_OVERRIDE_PREFIX = "$CO"; static constexpr const char* SEQ_ID_PREFIX = "$SI"; static constexpr const char* DOC_ID_PREFIX = "$DI"; static constexpr const char* COLLECTION_NAME_KEY = "name"; static constexpr const char* COLLECTION_ID_KEY = "id"; static constexpr const char* COLLECTION_SEARCH_FIELDS_KEY = "fields"; static constexpr const char* COLLECTION_DEFAULT_SORTING_FIELD_KEY = "default_sorting_field"; static constexpr const char* COLLECTION_CREATED = "created_at"; static constexpr const char* COLLECTION_NUM_MEMORY_SHARDS = "num_memory_shards"; static constexpr const char* COLLECTION_FALLBACK_FIELD_TYPE = "fallback_field_type"; static constexpr const char* COLLECTION_ENABLE_NESTED_FIELDS = "enable_nested_fields"; static constexpr const char* COLLECTION_SYMBOLS_TO_INDEX = "symbols_to_index"; static constexpr const char* COLLECTION_SEPARATORS = "token_separators"; static constexpr const char* COLLECTION_VOICE_QUERY_MODEL = "voice_query_model"; static constexpr const char* COLLECTION_METADATA = "metadata"; // methods Collection() = delete; Collection(const std::string& name, const uint32_t collection_id, const uint64_t created_at, const uint32_t next_seq_id, Store *store, const std::vector& fields, const std::string& default_sorting_field, const float max_memory_ratio, const std::string& fallback_field_type, const std::vector& symbols_to_index, const std::vector& token_separators, const bool enable_nested_fields, std::shared_ptr vq_model = nullptr, spp::sparse_hash_map referenced_in = spp::sparse_hash_map()); ~Collection(); static std::string get_next_seq_id_key(const std::string & collection_name); static std::string get_meta_key(const std::string & collection_name); static std::string get_override_key(const std::string & collection_name, const std::string & override_id); std::string get_seq_id_collection_prefix() const; std::string get_name() const; uint64_t get_created_at() const; uint32_t get_collection_id() const; uint32_t get_next_seq_id(); Option doc_id_to_seq_id_with_lock(const std::string & doc_id) const; Option doc_id_to_seq_id(const std::string & doc_id) const; std::vector get_facet_fields(); std::vector get_sort_fields(); std::vector get_fields(); bool contains_field(const std::string&); std::unordered_map get_dynamic_fields(); tsl::htrie_map get_schema(); tsl::htrie_map get_nested_fields(); tsl::htrie_map get_embedding_fields(); tsl::htrie_map get_embedding_fields_unsafe(); tsl::htrie_set get_object_reference_helper_fields(); std::string get_default_sorting_field(); static Option add_reference_helper_fields(nlohmann::json& document, const tsl::htrie_map& schema, const spp::sparse_hash_map& reference_fields, tsl::htrie_set& object_reference_helper_fields, const bool& is_update); Option to_doc(const std::string& json_str, nlohmann::json& document, const index_operation_t& operation, const DIRTY_VALUES dirty_values, const std::string& id=""); static uint32_t get_seq_id_from_key(const std::string & key); Option get_document_from_store(const std::string & seq_id_key, nlohmann::json & document, bool raw_doc = false) const; Option get_document_from_store(const uint32_t& seq_id, nlohmann::json & document, bool raw_doc = false) const; Option index_in_memory(nlohmann::json & document, uint32_t seq_id, const index_operation_t op, const DIRTY_VALUES& dirty_values); static void remove_flat_fields(nlohmann::json& document); static void remove_reference_helper_fields(nlohmann::json& document); static Option prune_ref_doc(nlohmann::json& doc, const reference_filter_result_t& references, const tsl::htrie_set& ref_include_fields_full, const tsl::htrie_set& ref_exclude_fields_full, const bool& is_reference_array, const ref_include_exclude_fields& ref_include_exclude); static Option include_references(nlohmann::json& doc, const uint32_t& seq_id, Collection *const collection, const std::map& reference_filter_results, const std::vector& ref_include_exclude_fields_vec); static Option prune_doc(nlohmann::json& doc, const tsl::htrie_set& include_names, const tsl::htrie_set& exclude_names, const std::string& parent_name = "", size_t depth = 0, const std::map& reference_filter_results = {}, Collection *const collection = nullptr, const uint32_t& seq_id = 0, const std::vector& ref_include_exclude_fields_vec = {}); const Index* _get_index() const; bool facet_value_to_string(const facet &a_facet, const facet_count_t &facet_count, nlohmann::json &document, std::string &value) const; nlohmann::json get_facet_parent(const std::string& facet_field_name, const nlohmann::json& document) const; static void populate_result_kvs(Topster *topster, std::vector> &result_kvs, const spp::sparse_hash_map& groups_processed, const std::vector& sort_by_fields); void batch_index(std::vector& index_records, std::vector& json_out, size_t &num_indexed, const bool& return_doc, const bool& return_id, const size_t remote_embedding_batch_size = 200, const size_t remote_embedding_timeout_ms = 60000, const size_t remote_embedding_num_tries = 2); bool is_exceeding_memory_threshold() const; void parse_search_query(const std::string &query, std::vector& q_include_tokens, std::vector>& q_exclude_tokens, std::vector>& q_phrases, const std::string& locale, const bool already_segmented, const std::string& stopword_set="") const; // PUBLIC OPERATIONS nlohmann::json get_summary_json() const; size_t batch_index_in_memory(std::vector& index_records, const size_t remote_embedding_batch_size, const size_t remote_embedding_timeout_ms, const size_t remote_embedding_num_tries, const bool generate_embeddings); Option add(const std::string & json_str, const index_operation_t& operation=CREATE, const std::string& id="", const DIRTY_VALUES& dirty_values=DIRTY_VALUES::COERCE_OR_REJECT); nlohmann::json add_many(std::vector& json_lines, nlohmann::json& document, const index_operation_t& operation=CREATE, const std::string& id="", const DIRTY_VALUES& dirty_values=DIRTY_VALUES::COERCE_OR_REJECT, const bool& return_doc=false, const bool& return_id=false, const size_t remote_embedding_batch_size=200, const size_t remote_embedding_timeout_ms=60000, const size_t remote_embedding_num_tries=2); Option update_matching_filter(const std::string& filter_query, const std::string & json_str, std::string& req_dirty_values, const int batch_size = 1000); Option populate_include_exclude_fields_lk(const spp::sparse_hash_set& include_fields, const spp::sparse_hash_set& exclude_fields, tsl::htrie_set& include_fields_full, tsl::htrie_set& exclude_fields_full) const; void do_housekeeping(); Option search(std::string query, const std::vector & search_fields, const std::string & filter_query, const std::vector & facet_fields, const std::vector & sort_fields, const std::vector& num_typos, size_t per_page = 10, size_t page = 1, token_ordering token_order = FREQUENCY, const std::vector& prefixes = {true}, size_t drop_tokens_threshold = Index::DROP_TOKENS_THRESHOLD, const spp::sparse_hash_set & include_fields = spp::sparse_hash_set(), const spp::sparse_hash_set & exclude_fields = spp::sparse_hash_set(), size_t max_facet_values=10, const std::string & simple_facet_query = "", const size_t snippet_threshold = 30, const size_t highlight_affix_num_tokens = 4, const std::string & highlight_full_fields = "", size_t typo_tokens_threshold = Index::TYPO_TOKENS_THRESHOLD, const std::string& pinned_hits_str="", const std::string& hidden_hits="", const std::vector& group_by_fields={}, size_t group_limit = 3, const std::string& highlight_start_tag="", const std::string& highlight_end_tag="", std::vector raw_query_by_weights={}, size_t limit_hits=UINT32_MAX, bool prioritize_exact_match=true, bool pre_segmented_query=false, bool enable_overrides=true, const std::string& highlight_fields="", const bool exhaustive_search = false, size_t search_stop_millis = 6000*1000, size_t min_len_1typo = 4, size_t min_len_2typo = 7, enable_t split_join_tokens = fallback, size_t max_candidates = 4, const std::vector& infixes = {off}, const size_t max_extra_prefix = INT16_MAX, const size_t max_extra_suffix = INT16_MAX, const size_t facet_query_num_typos = 2, const size_t filter_curated_hits_option = 2, const bool prioritize_token_position = false, const std::string& vector_query_str = "", const bool enable_highlight_v1 = true, const uint64_t search_time_start_us = 0, const text_match_type_t match_type = max_score, const size_t facet_sample_percent = 100, const size_t facet_sample_threshold = 0, const size_t page_offset = 0, facet_index_type_t facet_index_type = HASH, const size_t remote_embedding_timeout_ms = 30000, const size_t remote_embedding_num_tries = 2, const std::string& stopwords_set="", const std::vector& facet_return_parent = {}, const std::vector& ref_include_exclude_fields_vec = {}, const std::string& drop_tokens_mode = "right_to_left", const bool prioritize_num_matching_fields = true, const bool group_missing_values = true, const bool converstaion = false, const std::string& conversation_model_id = "", std::string conversation_id = "", const std::string& override_tags_str = "", const std::string& voice_query = "", bool enable_typos_for_numerical_tokens = true, bool enable_synonyms = true, bool synonym_prefix = false, uint32_t synonym_num_typos = 0, bool enable_lazy_filter = false, bool enable_typos_for_alpha_numerical_tokens = true) const; Option get_filter_ids(const std::string & filter_query, filter_result_t& filter_result) const; Option get_reference_filter_ids(const std::string& filter_query, filter_result_t& filter_result, const std::string& reference_field_name) const; Option get(const std::string & id) const; Option remove(const std::string & id, bool remove_from_store = true); Option remove_if_found(uint32_t seq_id, bool remove_from_store = true); size_t get_num_documents() const; DIRTY_VALUES parse_dirty_values_option(std::string& dirty_values) const; std::vector get_symbols_to_index(); std::vector get_token_separators(); std::string get_fallback_field_type(); bool get_enable_nested_fields(); std::shared_ptr get_vq_model(); Option parse_facet(const std::string& facet_field, std::vector& facets) const; // Override operations Option add_override(const override_t & override, bool write_to_store = true); Option remove_override(const std::string & id); Option> get_overrides(uint32_t limit=0, uint32_t offset=0); Option get_override(const std::string& override_id); // synonym operations Option> get_synonyms(uint32_t limit=0, uint32_t offset=0); bool get_synonym(const std::string& id, synonym_t& synonym); Option add_synonym(const nlohmann::json& syn_json, bool write_to_store = true); Option remove_synonym(const std::string & id); void synonym_reduction(const std::vector& tokens, std::vector>& results, bool synonym_prefix = false, uint32_t synonym_num_typos = 0) const; SynonymIndex* get_synonym_index(); spp::sparse_hash_map get_reference_fields(); // highlight ops static void highlight_text(const std::string& highlight_start_tag, const std::string& highlight_end_tag, const std::string& text, const std::map& token_offsets, size_t snippet_end_offset, std::vector& matched_tokens, std::map::iterator& offset_it, std::stringstream& highlighted_text, const uint8_t* index_symbols, size_t snippet_start_offset) ; void process_highlight_fields(const std::vector& search_fields, const std::vector& raw_search_fields, const tsl::htrie_set& include_fields, const tsl::htrie_set& exclude_fields, const std::vector& highlight_field_names, const std::vector& highlight_full_field_names, const std::vector& infixes, std::vector& q_tokens, const tsl::htrie_map& qtoken_set, std::vector& highlight_items) const; static void copy_highlight_doc(std::vector& hightlight_items, const bool nested_fields_enabled, const nlohmann::json& src, nlohmann::json& dst); Option alter(nlohmann::json& alter_payload); void process_search_field_weights(const std::vector& search_fields, std::vector& query_by_weights, std::vector& weighted_search_fields) const; Option truncate_after_top_k(const std::string& field_name, size_t k); void reference_populate_sort_mapping(int* sort_order, std::vector& geopoint_indices, std::vector& sort_fields_std, std::array*, 3>& field_values) const; int64_t reference_string_sort_score(const std::string& field_name, const uint32_t& seq_id) const; bool is_referenced_in(const std::string& collection_name) const; void add_referenced_in(const reference_pair& pair); void add_referenced_ins(const std::set& pairs); void add_referenced_in(const std::string& collection_name, const std::string& field_name); Option get_referenced_in_field_with_lock(const std::string& collection_name) const; Option get_related_ids_with_lock(const std::string& field_name, const uint32_t& seq_id, std::vector& result) const; Option get_sort_index_value_with_lock(const std::string& field_name, const uint32_t& seq_id) const; static void hide_credential(nlohmann::json& json, const std::string& credential_name); friend class filter_result_iterator_t; std::shared_mutex& get_lifecycle_mutex(); void expand_search_query(const std::string& raw_query, size_t offset, size_t total, const search_args* search_params, const std::vector>& result_group_kvs, const std::vector& raw_search_fields, std::string& first_q) const; }; template bool Collection::highlight_nested_field(const nlohmann::json& hdoc, nlohmann::json& hobj, std::vector& path_parts, size_t path_index, bool is_arr_obj_ele, int array_index, T func) { if(path_index == path_parts.size()) { func(hobj, is_arr_obj_ele, array_index); return true; } const std::string& fragment = path_parts[path_index]; const auto& it = hobj.find(fragment); if(it != hobj.end()) { if(it.value().is_array()) { bool resolved = false; for(size_t i = 0; i < it.value().size(); i++) { auto& h_ele = it.value().at(i); is_arr_obj_ele = is_arr_obj_ele || h_ele.is_object(); resolved = highlight_nested_field(hdoc, h_ele, path_parts, path_index + 1, is_arr_obj_ele, i, func) || resolved; } return resolved; } else { return highlight_nested_field(hdoc, it.value(), path_parts, path_index + 1, is_arr_obj_ele, 0, func); } } { return false; } }