Refactor referenced_in initialization of a collection. (#1585)

* Refactor `referenced_in` initialization of a collection.

* Review changes.

* Review changes.

---------

Co-authored-by: Kishore Nallan <kishorenc@gmail.com>
This commit is contained in:
Harpreet Sangar 2024-02-28 20:48:31 +05:30 committed by GitHub
parent 482858d05d
commit 1f6fbed372
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 111 additions and 55 deletions

View File

@ -8,16 +8,15 @@
#include <mutex>
#include <condition_variable>
#include <shared_mutex>
#include <art.h>
#include <index.h>
#include <number.h>
#include <sparsepp.h>
#include <store.h>
#include <topster.h>
#include <json.hpp>
#include <field.h>
#include <option.h>
#include <tsl/htrie_map.h>
#include "art.h"
#include "index.h"
#include "number.h"
#include "store.h"
#include "topster.h"
#include "json.hpp"
#include "field.h"
#include "option.h"
#include "tsl/htrie_map.h"
#include "tokenizer.h"
#include "synonym_index.h"
#include "vq_model_manager.h"
@ -186,7 +185,7 @@ private:
bool does_override_match(const override_t& override, std::string& query,
std::set<uint32_t>& excluded_set,
string& actual_query, const string& filter_query,
std::string& actual_query, const std::string& filter_query,
bool already_segmented,
const bool tags_matched,
const bool wildcard_tag_matched,
@ -199,7 +198,7 @@ private:
std::string& curated_sort_by,
nlohmann::json& override_metadata) const;
void curate_results(string& actual_query, const string& filter_query, bool enable_overrides, bool already_segmented,
void curate_results(std::string& actual_query, const std::string& filter_query, bool enable_overrides, bool already_segmented,
const std::set<std::string>& tags,
const std::map<size_t, std::vector<std::string>>& pinned_hits,
const std::vector<std::string>& hidden_hits,
@ -384,7 +383,8 @@ public:
const std::string& default_sorting_field,
const float max_memory_ratio, const std::string& fallback_field_type,
const std::vector<std::string>& symbols_to_index, const std::vector<std::string>& token_separators,
const bool enable_nested_fields, std::shared_ptr<VQModel> vq_model = nullptr);
const bool enable_nested_fields, std::shared_ptr<VQModel> vq_model = nullptr,
spp::sparse_hash_map<std::string, std::string> referenced_in = spp::sparse_hash_map<std::string, std::string>());
~Collection();
@ -644,8 +644,8 @@ public:
// highlight ops
static void highlight_text(const string& highlight_start_tag, const string& highlight_end_tag,
const string& text, const std::map<size_t, size_t>& token_offsets,
static void highlight_text(const std::string& highlight_start_tag, const std::string& highlight_end_tag,
const std::string& text, const std::map<size_t, size_t>& token_offsets,
size_t snippet_end_offset,
std::vector<std::string>& matched_tokens, std::map<size_t, size_t>::iterator& offset_it,
std::stringstream& highlighted_text,
@ -703,9 +703,9 @@ public:
std::shared_mutex& get_lifecycle_mutex();
void expand_search_query(const string& raw_query, size_t offset, size_t total, const search_args* search_params,
void expand_search_query(const std::string& raw_query, size_t offset, size_t total, const search_args* search_params,
const std::vector<std::vector<KV*>>& result_group_kvs,
const std::vector<std::string>& raw_search_fields, string& first_q) const;
const std::vector<std::string>& raw_search_fields, std::string& first_q) const;
};
template<class T>

View File

@ -125,12 +125,14 @@ public:
static Collection* init_collection(const nlohmann::json & collection_meta,
const uint32_t collection_next_seq_id,
Store* store,
float max_memory_ratio);
float max_memory_ratio,
spp::sparse_hash_map<std::string, std::string>& referenced_in);
static Option<bool> load_collection(const nlohmann::json& collection_meta,
const size_t batch_size,
const StoreStatus& next_coll_id_status,
const std::atomic<bool>& quit);
const std::atomic<bool>& quit,
spp::sparse_hash_map<std::string, std::string>& referenced_in);
Option<Collection*> clone_collection(const std::string& existing_name, const nlohmann::json& req_json);
@ -233,4 +235,7 @@ public:
std::map<std::string, std::set<reference_pair>> _get_referenced_in_backlog() const;
void process_embedding_field_delete(const std::string& model_name);
static void _populate_referenced_ins(const std::string& collection_meta_json,
std::map<std::string, spp::sparse_hash_map<std::string, std::string>>& referenced_ins);
};

View File

@ -47,14 +47,16 @@ Collection::Collection(const std::string& name, const uint32_t collection_id, co
const float max_memory_ratio, const std::string& fallback_field_type,
const std::vector<std::string>& symbols_to_index,
const std::vector<std::string>& token_separators,
const bool enable_nested_fields, std::shared_ptr<VQModel> vq_model) :
const bool enable_nested_fields, std::shared_ptr<VQModel> vq_model,
spp::sparse_hash_map<std::string, std::string> referenced_in) :
name(name), collection_id(collection_id), created_at(created_at),
next_seq_id(next_seq_id), store(store),
fields(fields), default_sorting_field(default_sorting_field), enable_nested_fields(enable_nested_fields),
max_memory_ratio(max_memory_ratio),
fallback_field_type(fallback_field_type), dynamic_fields({}),
symbols_to_index(to_char_array(symbols_to_index)), token_separators(to_char_array(token_separators)),
index(init_index()), vq_model(vq_model) {
index(init_index()), vq_model(vq_model),
referenced_in(std::move(referenced_in)) {
if (vq_model) {
vq_model->inc_collection_ref_count();

View File

@ -21,7 +21,8 @@ CollectionManager::CollectionManager() {
Collection* CollectionManager::init_collection(const nlohmann::json & collection_meta,
const uint32_t collection_next_seq_id,
Store* store,
float max_memory_ratio) {
float max_memory_ratio,
spp::sparse_hash_map<std::string, std::string>& referenced_in) {
std::string this_collection_name = collection_meta[Collection::COLLECTION_NAME_KEY].get<std::string>();
std::vector<field> fields;
@ -193,7 +194,7 @@ Collection* CollectionManager::init_collection(const nlohmann::json & collection
fallback_field_type,
symbols_to_index,
token_separators,
enable_nested_fields, model);
enable_nested_fields, model, std::move(referenced_in));
return collection;
}
@ -228,6 +229,32 @@ void CollectionManager::init(Store *store, const float max_memory_ratio, const s
init(store, thread_pool, max_memory_ratio, auth_key, quit, nullptr);
}
void CollectionManager::_populate_referenced_ins(const std::string& collection_meta_json,
std::map<std::string, spp::sparse_hash_map<std::string, std::string>>& referenced_ins) {
auto const& obj = nlohmann::json::parse(collection_meta_json, nullptr, false);
if (!obj.is_discarded() && obj.is_object() && obj.contains("name") && obj["name"].is_string() &&
obj.contains("fields")) {
auto const& collection_name = obj["name"];
for (const auto &field: obj["fields"]) {
if (!field.contains("name") || !field.contains("reference")) {
continue;
}
auto field_name = std::string(field["name"]) + fields::REFERENCE_HELPER_FIELD_SUFFIX;
std::vector<std::string> split_result;
StringUtils::split(field["reference"], split_result, ".");
auto const& ref_coll_name = split_result.front();
if (referenced_ins.count(ref_coll_name) == 0) {
referenced_ins.emplace(ref_coll_name, spp::sparse_hash_map<std::string, std::string>());
}
referenced_ins[ref_coll_name].emplace(collection_name, field_name);
}
}
}
Option<bool> CollectionManager::load(const size_t collection_batch_size, const size_t document_batch_size) {
// This function must be idempotent, i.e. when called multiple times, must produce the same state without leaks
LOG(INFO) << "CollectionManager::load()";
@ -263,11 +290,16 @@ Option<bool> CollectionManager::load(const size_t collection_batch_size, const s
ThreadPool loading_pool(collection_batch_size);
// Collection name -> Ref collection name -> Ref field name
std::map<std::string, spp::sparse_hash_map<std::string, std::string>> referenced_ins;
for (const auto &collection_meta_json: collection_meta_jsons) {
_populate_referenced_ins(collection_meta_json, referenced_ins);
}
size_t num_processed = 0;
// Collection name -> Referenced in
std::map<std::string, std::set<reference_pair>> referenced_ins = {};
std::mutex m_process;
std::condition_variable cv_process;
std::string collection_name;
for(size_t coll_index = 0; coll_index < num_collections; coll_index++) {
const auto& collection_meta_json = collection_meta_jsons[coll_index];
@ -277,13 +309,16 @@ Option<bool> CollectionManager::load(const size_t collection_batch_size, const s
return Option<bool>(500, "Error while parsing collection meta.");
}
collection_name = collection_meta[Collection::COLLECTION_NAME_KEY].get<std::string>();
auto captured_store = store;
loading_pool.enqueue([captured_store, num_collections, collection_meta, document_batch_size,
&m_process, &cv_process, &num_processed, &next_coll_id_status, quit = quit,
&referenced_ins]() {
&referenced_ins, collection_name]() {
//auto begin = std::chrono::high_resolution_clock::now();
Option<bool> res = load_collection(collection_meta, document_batch_size, next_coll_id_status, *quit);
Option<bool> res = load_collection(collection_meta, document_batch_size, next_coll_id_status, *quit,
referenced_ins[collection_name]);
/*long long int timeMillis =
std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now() - begin).count();
LOG(INFO) << "Time taken for indexing: " << timeMillis << "ms";*/
@ -299,18 +334,6 @@ Option<bool> CollectionManager::load(const size_t collection_batch_size, const s
num_processed++;
auto& cm = CollectionManager::get_instance();
auto const& collection_name = collection_meta.at("name");
auto collection = cm.get_collection(collection_name);
if (collection != nullptr) {
for (const auto &item: collection->get_reference_fields()) {
auto const& ref_coll_name = item.second.collection;
if (referenced_ins.count(ref_coll_name) == 0) {
referenced_ins[ref_coll_name] = {};
}
auto const field_name = item.first + fields::REFERENCE_HELPER_FIELD_SUFFIX;
referenced_ins.at(ref_coll_name).insert(reference_pair{collection_name, field_name});
}
}
cv_process.notify_one();
size_t progress_modulo = std::max<size_t>(1, (num_collections / 10)); // every 10%
@ -326,17 +349,6 @@ Option<bool> CollectionManager::load(const size_t collection_batch_size, const s
return num_processed == num_collections;
});
// Initialize references
for (const auto &item: referenced_ins) {
auto& cm = CollectionManager::get_instance();
auto collection = cm.get_collection(item.first);
if (collection != nullptr) {
for (const auto &reference_pair: item.second) {
collection->add_referenced_in(reference_pair);
}
}
}
// load aliases
std::string symlink_prefix_key = std::string(SYMLINK_PREFIX) + "_";
@ -2063,7 +2075,8 @@ Option<Collection*> CollectionManager::create_collection(nlohmann::json& req_jso
Option<bool> CollectionManager::load_collection(const nlohmann::json &collection_meta,
const size_t batch_size,
const StoreStatus& next_coll_id_status,
const std::atomic<bool>& quit) {
const std::atomic<bool>& quit,
spp::sparse_hash_map<std::string, std::string>& referenced_in) {
auto& cm = CollectionManager::get_instance();
@ -2109,7 +2122,7 @@ Option<bool> CollectionManager::load_collection(const nlohmann::json &collection
}
}
Collection* collection = init_collection(collection_meta, collection_next_seq_id, cm.store, 1.0f);
Collection* collection = init_collection(collection_meta, collection_next_seq_id, cm.store, 1.0f, referenced_in);
LOG(INFO) << "Loading collection " << collection->get_name();

View File

@ -313,7 +313,8 @@ TEST_F(CollectionManagerTest, ShouldInitCollection) {
nlohmann::json::parse("{\"name\": \"foobar\", \"id\": 100, \"fields\": [{\"name\": \"org\", \"type\": "
"\"string\", \"facet\": false}], \"default_sorting_field\": \"foo\"}");
Collection *collection = collectionManager.init_collection(collection_meta1, 100, store, 1.0f);
spp::sparse_hash_map<std::string, std::string> referenced_in;
Collection *collection = collectionManager.init_collection(collection_meta1, 100, store, 1.0f, referenced_in);
ASSERT_EQ("foobar", collection->get_name());
ASSERT_EQ(100, collection->get_collection_id());
ASSERT_EQ(1, collection->get_fields().size());
@ -335,7 +336,7 @@ TEST_F(CollectionManagerTest, ShouldInitCollection) {
"\"symbols_to_index\": [\"+\"], \"token_separators\": [\"-\"]}");
collection = collectionManager.init_collection(collection_meta2, 100, store, 1.0f);
collection = collectionManager.init_collection(collection_meta2, 100, store, 1.0f, referenced_in);
ASSERT_EQ(12345, collection->get_created_at());
std::vector<char> expected_symbols = {'+'};
@ -1909,6 +1910,41 @@ TEST_F(CollectionManagerTest, CollectionCreationWithMetadata) {
ASSERT_EQ(expected_meta_json.dump(), actual_json.dump());
}
// Feeds three collection metadata documents through _populate_referenced_ins and checks that only
// the collection targeted by a "reference" field ends up in the resulting map.
TEST_F(CollectionManagerTest, PopulateReferencedIns) {
    // A: no reference fields.
    const auto meta_a = R"({
        "name": "A",
        "fields": [
            {"name": "a_id", "type": "string"}
        ]
    })"_json;

    // B: `b_ref` points at field `a_id` of collection A.
    const auto meta_b = R"({
        "name": "B",
        "fields": [
            {"name": "b_id", "type": "string"},
            {"name": "b_ref", "type": "string", "reference": "A.a_id"}
        ]
    })"_json;

    // C: no reference fields.
    const auto meta_c = R"({
        "name": "C",
        "fields": [
            {"name": "c_id", "type": "string"}
        ]
    })"_json;

    const std::vector<std::string> serialized_metas = {meta_a.dump(), meta_b.dump(), meta_c.dump()};

    std::map<std::string, spp::sparse_hash_map<std::string, std::string>> referenced_ins;
    for (const auto& serialized_meta: serialized_metas) {
        CollectionManager::_populate_referenced_ins(serialized_meta, referenced_ins);
    }

    // Only A is referenced, and only by B through the helper field derived from `b_ref`.
    ASSERT_EQ(1, referenced_ins.size());
    ASSERT_EQ(1, referenced_ins.count("A"));
    ASSERT_EQ(1, referenced_ins["A"].size());
    ASSERT_EQ(1, referenced_ins["A"].count("B"));
    ASSERT_EQ("b_ref_sequence_id", referenced_ins["A"]["B"]);
}
TEST_F(CollectionManagerTest, CollectionPagination) {
//remove all collections first
auto collections = collectionManager.get_collections().get();
@ -1987,4 +2023,4 @@ TEST_F(CollectionManagerTest, CollectionPagination) {
collection_op = collectionManager.get_collections(limit, offset);
ASSERT_FALSE(collection_op.ok());
ASSERT_EQ("Invalid offset param.", collection_op.error());
}
}