diff --git a/include/collection.h b/include/collection.h index 30eed5f6..c6ab24af 100644 --- a/include/collection.h +++ b/include/collection.h @@ -388,6 +388,8 @@ public: void synonym_reduction(const std::vector& tokens, std::vector>& results) const; + SynonymIndex* get_synonym_index(); + // highlight ops static void highlight_text(const string& highlight_start_tag, const string& highlight_end_tag, diff --git a/include/collection_manager.h b/include/collection_manager.h index c0e1dbd0..78feeaa2 100644 --- a/include/collection_manager.h +++ b/include/collection_manager.h @@ -118,6 +118,8 @@ public: const StoreStatus& next_coll_id_status, const std::atomic& quit); + Option clone_collection(const std::string& existing_name, const nlohmann::json& req_json); + void add_to_collections(Collection* collection); std::vector get_collections() const; diff --git a/src/collection.cpp b/src/collection.cpp index 4cc94ae7..03c585cc 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -2637,19 +2637,19 @@ Option Collection::parse_pinned_hits(const std::string& pinned_hits_str, } if(index == 0) { - return Option(false, "Pinned hits are not in expected format."); + return Option(400, "Pinned hits are not in expected format."); } std::string pinned_id = pinned_hits_part.substr(0, index); std::string pinned_pos = pinned_hits_part.substr(index+1); if(!StringUtils::is_positive_integer(pinned_pos)) { - return Option(false, "Pinned hits are not in expected format."); + return Option(400, "Pinned hits are not in expected format."); } int position = std::stoi(pinned_pos); if(position == 0) { - return Option(false, "Pinned hits must start from position 1."); + return Option(400, "Pinned hits must start from position 1."); } pinned_hits[position].emplace_back(pinned_id); @@ -2692,6 +2692,10 @@ spp::sparse_hash_map Collection::get_synonyms() { return synonym_index->get_synonyms(); } +SynonymIndex* Collection::get_synonym_index() { + return synonym_index; +} + Option Collection::persist_collection_meta() { std::string coll_meta_json; StoreStatus status = store->get(Collection::get_meta_key(name), coll_meta_json); @@ -2777,7 +2781,7 @@ Option Collection::batch_alter_data(const std::unordered_mapvalue().ToString()); } catch(const std::exception& e) { - return Option(false, "Bad JSON in document: " + document.dump(-1, ' ', false, + return Option(400, "Bad JSON in document: " + document.dump(-1, ' ', false, nlohmann::detail::error_handler_t::ignore)); } @@ -3070,7 +3074,7 @@ Option Collection::validate_alter_payload(nlohmann::json& schema_changes, try { document = nlohmann::json::parse(iter->value().ToString()); } catch(const std::exception& e) { - return Option(false, "Bad JSON in document: " + document.dump(-1, ' ', false, + return Option(400, "Bad JSON in document: " + document.dump(-1, ' ', false, nlohmann::detail::error_handler_t::ignore)); } diff --git a/src/collection_manager.cpp b/src/collection_manager.cpp index 8b293309..bb3b7801 100644 --- a/src/collection_manager.cpp +++ b/src/collection_manager.cpp @@ -1163,7 +1163,7 @@ Option CollectionManager::load_collection(const nlohmann::json &collection document = nlohmann::json::parse(iter->value().ToString()); } catch(const std::exception& e) { LOG(ERROR) << "JSON error: " << e.what(); - return Option(false, "Bad JSON."); + return Option(400, "Bad JSON."); } auto dirty_values = DIRTY_VALUES::DROP; @@ -1185,7 +1185,7 @@ Option CollectionManager::load_collection(const nlohmann::json &collection if(num_indexed != num_records) { const Option & index_error_op = get_first_index_error(index_records); if(!index_error_op.ok()) { - return Option(false, index_error_op.get()); + return Option(400, index_error_op.get()); } } @@ -1261,3 +1261,62 @@ Option CollectionManager::delete_preset(const string& preset_name) { preset_configs.erase(preset_name); return Option(true); } + +Option CollectionManager::clone_collection(const string& existing_name, const nlohmann::json& req_json) { + std::shared_lock lock(mutex); + + if(collections.count(existing_name) == 0) { + return Option(400, "Collection with name `" + existing_name + "` not found."); + } + + if(req_json.count("name") == 0 || !req_json["name"].is_string()) { + return Option(400, "Collection name must be provided."); + } + + const std::string& new_name = req_json["name"].get(); + + if(collections.count(new_name) != 0) { + return Option(400, "Collection with name `" + new_name + "` already exists."); + } + + Collection* existing_coll = collections[existing_name]; + + std::vector symbols_to_index; + std::vector token_separators; + + for(auto c: existing_coll->get_symbols_to_index()) { + symbols_to_index.emplace_back(1, c); + } + + for(auto c: existing_coll->get_token_separators()) { + token_separators.emplace_back(1, c); + } + + lock.unlock(); + + auto coll_create_op = create_collection(new_name, DEFAULT_NUM_MEMORY_SHARDS, existing_coll->get_fields(), + existing_coll->get_default_sorting_field(), static_cast(std::time(nullptr)), + existing_coll->get_fallback_field_type(), symbols_to_index, token_separators); + + lock.lock(); + + if(!coll_create_op.ok()) { + return Option(coll_create_op.code(), coll_create_op.error()); + } + + Collection* new_coll = coll_create_op.get(); + + // copy synonyms + auto synonyms = existing_coll->get_synonyms(); + for(const auto& synonym: synonyms) { + new_coll->get_synonym_index()->add_synonym(new_name, synonym.second); + } + + // copy overrides + auto overrides = existing_coll->get_overrides(); + for(const auto& override: overrides) { + new_coll->add_override(override.second); + } + + return Option(new_coll); +} diff --git a/src/core_api.cpp b/src/core_api.cpp index 60f25f75..c13f9ed7 100644 --- a/src/core_api.cpp +++ b/src/core_api.cpp @@ -165,8 +165,12 @@ bool post_create_collection(const std::shared_ptr& req, const std::sha return false; } - CollectionManager & collectionManager = CollectionManager::get_instance(); - const Option & collection_op = collectionManager.create_collection(req_json); + const std::string SRC_COLL_NAME = "src_name"; + + CollectionManager& collectionManager = CollectionManager::get_instance(); + const Option &collection_op = req->params.count(SRC_COLL_NAME) != 0 ? + collectionManager.clone_collection(req->params[SRC_COLL_NAME], req_json) : + CollectionManager::create_collection(req_json); if(collection_op.ok()) { nlohmann::json json_response = collection_op.get()->get_summary_json(); diff --git a/test/collection_manager_test.cpp b/test/collection_manager_test.cpp index b8faad9e..2b40068f 100644 --- a/test/collection_manager_test.cpp +++ b/test/collection_manager_test.cpp @@ -826,3 +826,60 @@ TEST_F(CollectionManagerTest, Presets) { preset_op = collectionManager.get_preset("preset1", preset); ASSERT_TRUE(preset_op.ok()); } + +TEST_F(CollectionManagerTest, CloneCollection) { + nlohmann::json schema = R"({ + "name": "coll1", + "fields": [ + {"name": "title", "type": "string"} + ], + "symbols_to_index":["+"], + "token_separators":["-", "?"] + })"_json; + + auto create_op = collectionManager.create_collection(schema); + ASSERT_TRUE(create_op.ok()); + auto coll1 = create_op.get(); + + nlohmann::json synonym1 = R"({ + "id": "ipod-synonyms", + "synonyms": ["ipod", "i pod", "pod"] + })"_json; + + ASSERT_TRUE(coll1->add_synonym(synonym1).ok()); + + nlohmann::json override_json = { + {"id", "dynamic-cat-filter"}, + { + "rule", { + {"query", "{categories}"}, + {"match", override_t::MATCH_EXACT} + } + }, + {"remove_matched_tokens", true}, + {"filter_by", "category: {categories}"} + }; + + override_t override; + auto op = override_t::parse(override_json, "dynamic-cat-filter", override); + ASSERT_TRUE(op.ok()); + coll1->add_override(override); + + nlohmann::json req = R"({"name": "coll2"})"_json; + collectionManager.clone_collection("coll1", req); + + auto coll2 = collectionManager.get_collection_unsafe("coll2"); + ASSERT_FALSE(coll2 == nullptr); + ASSERT_EQ("coll2", coll2->get_name()); + ASSERT_EQ(1, coll2->get_fields().size()); + ASSERT_EQ(1, coll2->get_synonyms().size()); + ASSERT_EQ(1, coll2->get_overrides().size()); + ASSERT_EQ("", coll2->get_fallback_field_type()); + + ASSERT_EQ(1, coll2->get_symbols_to_index().size()); + ASSERT_EQ(2, coll2->get_token_separators().size()); + + ASSERT_EQ('+', coll2->get_symbols_to_index().at(0)); + ASSERT_EQ('-', coll2->get_token_separators().at(0)); + ASSERT_EQ('?', coll2->get_token_separators().at(1)); +}