diff --git a/include/collection.h b/include/collection.h index 2b658bb6..90586fc1 100644 --- a/include/collection.h +++ b/include/collection.h @@ -354,6 +354,8 @@ public: tsl::htrie_map get_embedding_fields(); + tsl::htrie_map get_embedding_fields_unsafe(); + std::string get_default_sorting_field(); Option to_doc(const std::string& json_str, nlohmann::json& document, @@ -566,8 +568,6 @@ public: std::vector& reordered_search_fields) const; Option truncate_after_top_k(const std::string& field_name, size_t k); - - static void process_embedding_field_delete(const std::string& model_name); }; template diff --git a/include/collection_manager.h b/include/collection_manager.h index 03ef6b43..b92f29e5 100644 --- a/include/collection_manager.h +++ b/include/collection_manager.h @@ -201,4 +201,6 @@ public: Option upsert_preset(const std::string & preset_name, const nlohmann::json& preset_config); Option delete_preset(const std::string & preset_name); + + void process_embedding_field_delete(const std::string& model_name); }; \ No newline at end of file diff --git a/src/collection.cpp b/src/collection.cpp index 38133b3f..31b7cc34 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -3801,7 +3801,10 @@ Option Collection::batch_alter_data(const std::vector& alter_fields auto model_name = f.embed[fields::model_config][fields::model_name].get(); if(text_embedders.count(model_name) == 0) { size_t dummy_num_dim = 0; - TextEmbedderManager::get_instance().validate_and_init_model(f.embed[fields::model_config], dummy_num_dim); + auto validate_model_res = TextEmbedderManager::get_instance().validate_and_init_model(f.embed[fields::model_config], dummy_num_dim); + if(!validate_model_res.ok()) { + return Option(validate_model_res.code(), validate_model_res.error()); + } } embedding_fields.emplace(f.name, f); } @@ -5037,34 +5040,12 @@ void Collection::remove_embedding_field(const std::string& field_name) { return; } - auto del_field = embedding_fields[field_name]; - + const auto& del_field = embedding_fields[field_name]; + const auto& model_name = del_field.embed[fields::model_config]["model_name"].get(); embedding_fields.erase(field_name); - - auto model_name = del_field.embed[fields::model_config]["model_name"].get(); - process_embedding_field_delete(model_name); + CollectionManager::get_instance().process_embedding_field_delete(model_name); } -void Collection::process_embedding_field_delete(const std::string& model_name) { - auto collections = CollectionManager::get_instance().get_collections(); - bool found = false; - - for(const auto& collection: collections) { - auto embedding_fields_other = collection->embedding_fields; - - for(auto& embedding_field: embedding_fields_other) { - if(embedding_field.embed.count(fields::model_config) != 0) { - auto model_config = embedding_field.embed[fields::model_config]; - if(model_config["model_name"].get() == model_name) { - found = true; - break; - } - } - } - } - - if(!found) { - LOG(INFO) << "Deleting text embedder: " << model_name; - TextEmbedderManager::get_instance().delete_text_embedder(model_name); - } +tsl::htrie_map Collection::get_embedding_fields_unsafe() { + return embedding_fields; } \ No newline at end of file diff --git a/src/collection_manager.cpp b/src/collection_manager.cpp index fc2992d7..8c1b27aa 100644 --- a/src/collection_manager.cpp +++ b/src/collection_manager.cpp @@ -532,16 +532,15 @@ Option CollectionManager::drop_collection(const std::string& col collections.erase(actual_coll_name); collection_id_names.erase(collection->get_collection_id()); - - u_lock.unlock(); - const auto& embedding_fields = collection->get_embedding_fields(); for(const auto& embedding_field : embedding_fields) { - auto model_name = embedding_field.embed[fields::model_config]["model_name"].get(); - Collection::process_embedding_field_delete(model_name); + const auto& model_name = embedding_field.embed[fields::model_config]["model_name"].get(); + process_embedding_field_delete(model_name); } + u_lock.unlock(); + // don't hold any collection manager locks here, since this can take some time delete collection; @@ -1564,3 +1563,31 @@ Option CollectionManager::clone_collection(const string& existing_n return Option(new_coll); } + +void CollectionManager::process_embedding_field_delete(const std::string& model_name) { + // Can'T have a shared lock here + // because we will be already acquiring a lock on collection manager if we are deleting a collection + //std::shared_lock lock(mutex); + bool found = false; + + for(const auto& collection: collections) { + // will be deadlock if we try to acquire lock on collection here + // caller of this function should have already acquired lock on collection + const auto& embedding_fields = collection.second->get_embedding_fields_unsafe(); + + for(const auto& embedding_field: embedding_fields) { + if(embedding_field.embed.count(fields::model_config) != 0) { + const auto& model_config = embedding_field.embed[fields::model_config]; + if(model_config["model_name"].get() == model_name) { + found = true; + break; + } + } + } + } + + if(!found) { + LOG(INFO) << "Deleting text embedder: " << model_name; + TextEmbedderManager::get_instance().delete_text_embedder(model_name); + } +} \ No newline at end of file