diff --git a/src/field.cpp b/src/field.cpp index f892a9f5..bf979c56 100644 --- a/src/field.cpp +++ b/src/field.cpp @@ -181,10 +181,18 @@ Option toFilter(const std::string expression, } return Option(true); } - if (search_schema.count(field_name) == 0) { + + auto field_it = search_schema.find(field_name); + + if (field_it == search_schema.end()) { return Option(404, "Could not find a filter field named `" + field_name + "` in the schema."); } - field _field = search_schema.at(field_name); + + if (field_it->num_dim > 0) { + return Option(404, "Cannot filter on vector field `" + field_name + "`."); + } + + const field& _field = field_it.value(); std::string&& raw_value = expression.substr(found_index + 1, std::string::npos); StringUtils::trim(raw_value); // skip past optional `:=` operator, which has no meaning for non-string fields @@ -568,7 +576,11 @@ Option field::json_field_to_field(bool enable_nested_fields, nlohmann::jso if(field_json["type"] == field_types::INT32 || field_json["type"] == field_types::INT64 || field_json["type"] == field_types::FLOAT || field_json["type"] == field_types::BOOL || field_json["type"] == field_types::GEOPOINT || field_json["type"] == field_types::GEOPOINT_ARRAY) { - field_json[fields::sort] = true; + if(field_json.count(fields::num_dim) == 0) { + field_json[fields::sort] = true; + } else { + field_json[fields::sort] = false; + } } else { field_json[fields::sort] = false; } @@ -592,6 +604,14 @@ Option field::json_field_to_field(bool enable_nested_fields, nlohmann::jso return Option(400, "Property `" + fields::num_dim + "` is only allowed on a float array field."); } + if(field_json[fields::facet].get()) { + return Option(400, "Property `" + fields::facet + "` is not allowed on a vector field."); + } + + if(field_json[fields::sort].get()) { + return Option(400, "Property `" + fields::sort + "` cannot be enabled on a vector field."); + } + if(field_json.count(fields::vec_dist) == 0) { field_json[fields::vec_dist] = DEFAULT_VEC_DIST_METRIC; } else { diff --git a/src/index.cpp b/src/index.cpp index 7413d94d..1e60a32d 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -64,6 +64,12 @@ Index::Index(const std::string& name, const uint32_t collection_id, const Store* continue; } + if(a_field.num_dim > 0) { + auto hnsw_index = new hnsw_index_t(a_field.num_dim, 1024, a_field.vec_dist); + vector_index.emplace(a_field.name, hnsw_index); + continue; + } + if(a_field.is_string()) { art_tree *t = new art_tree; art_tree_init(t); @@ -116,11 +122,6 @@ Index::Index(const std::string& name, const uint32_t collection_id, const Store* infix_index.emplace(a_field.name, infix_sets); } - - if(a_field.num_dim) { - auto hnsw_index = new hnsw_index_t(a_field.num_dim, 1024, a_field.vec_dist); - vector_index.emplace(a_field.name, hnsw_index); - } } num_documents = 0; @@ -917,9 +918,77 @@ void Index::index_field_in_memory(const field& afield, std::vector geo_array_index.at(afield.name)->emplace(seq_id, packed_latlongs); }); } else if(afield.is_array()) { + // handle vector index first + if(afield.type == field_types::FLOAT_ARRAY && afield.num_dim > 0) { + auto vec_index = vector_index[afield.name]->vecdex; + size_t curr_ele_count = vec_index->getCurrentElementCount(); + if(curr_ele_count + iter_batch.size() > vec_index->getMaxElements()) { + vec_index->resizeIndex((curr_ele_count + iter_batch.size()) * 1.3); + } + + const size_t num_threads = std::min(4, iter_batch.size()); + const size_t window_size = (num_threads == 0) ? 0 : + (iter_batch.size() + num_threads - 1) / num_threads; // rounds up + size_t num_processed = 0; + std::mutex m_process; + std::condition_variable cv_process; + + size_t num_queued = 0; + size_t result_index = 0; + + for(size_t thread_id = 0; thread_id < num_threads && result_index < iter_batch.size(); thread_id++) { + size_t batch_len = window_size; + + if(result_index + window_size > iter_batch.size()) { + batch_len = iter_batch.size() - result_index; + } + + num_queued++; + + thread_pool->enqueue([thread_id, &afield, &vec_index, &records = iter_batch, + result_index, batch_len, &num_processed, &m_process, &cv_process]() { + + size_t batch_counter = 0; + while(batch_counter < batch_len) { + auto& record = records[result_index + batch_counter]; + if(record.doc.count(afield.name) == 0) { + batch_counter++; + continue; + } + + const std::vector& float_vals = record.doc[afield.name].get>(); + + try { + if(afield.vec_dist == cosine) { + std::vector normalized_vals(afield.num_dim); + hnsw_index_t::normalize_vector(float_vals, normalized_vals); + vec_index->insertPoint(normalized_vals.data(), (size_t)record.seq_id); + } else { + vec_index->insertPoint(float_vals.data(), (size_t)record.seq_id); + } + } catch(const std::exception &e) { + record.index_failure(400, e.what()); + } + + batch_counter++; + } + + std::unique_lock lock(m_process); + num_processed++; + cv_process.notify_one(); + }); + + result_index += batch_len; + } + + std::unique_lock lock_process(m_process); + cv_process.wait(lock_process, [&](){ return num_processed == num_queued; }); + return; + } + // all other numerical arrays auto num_tree = numerical_index.at(afield.name); - iterate_and_index_numerical_field(iter_batch, afield, [&afield, num_tree, &vector_index=vector_index] + iterate_and_index_numerical_field(iter_batch, afield, [&afield, num_tree] (const index_record& record, uint32_t seq_id) { for(size_t arr_i = 0; arr_i < record.doc[afield.name].size(); arr_i++) { const auto& arr_value = record.doc[afield.name][arr_i]; @@ -945,23 +1014,6 @@ void Index::index_field_in_memory(const field& afield, std::vector num_tree->insert(int64_t(value), seq_id); } } - - if(afield.type == field_types::FLOAT_ARRAY && afield.num_dim > 0) { - auto vec_index = vector_index[afield.name]->vecdex; - size_t curr_ele_count = vec_index->getCurrentElementCount(); - if(curr_ele_count == vec_index->getMaxElements()) { - vec_index->resizeIndex(curr_ele_count * 1.3); - } - - const std::vector& float_vals = record.doc[afield.name].get>(); - if(afield.vec_dist == cosine) { - std::vector normalized_vals(afield.num_dim); - hnsw_index_t::normalize_vector(float_vals, normalized_vals); - vector_index[afield.name]->vecdex->insertPoint(normalized_vals.data(), (size_t)seq_id); - } else { - vector_index[afield.name]->vecdex->insertPoint(float_vals.data(), (size_t)seq_id); - } - } }); } @@ -5148,10 +5200,13 @@ void Index::remove_field(uint32_t seq_id, const nlohmann::json& document, const remove_facet_token(search_field, search_index, std::to_string(value), seq_id); } } + } else if(search_field.num_dim) { + vector_index[search_field.name]->vecdex->markDelete(seq_id); } else if(search_field.is_float()) { const std::vector& values = search_field.is_single_float() ? std::vector{document[field_name].get()} : document[field_name].get>(); + for(float value: values) { num_tree_t* num_tree = numerical_index.at(field_name); int64_t fintval = float_to_int64_t(value); @@ -5160,10 +5215,6 @@ void Index::remove_field(uint32_t seq_id, const nlohmann::json& document, const remove_facet_token(search_field, search_index, StringUtils::float_to_str(value), seq_id); } } - - if(search_field.num_dim) { - vector_index[search_field.name]->vecdex->markDelete(seq_id); - } } else if(search_field.is_bool()) { const std::vector& values = search_field.is_single_bool() ? @@ -5322,6 +5373,12 @@ void Index::refresh_schemas(const std::vector& new_fields, const std::vec search_schema.emplace(new_field.name, new_field); + if(new_field.type == field_types::FLOAT_ARRAY && new_field.num_dim > 0) { + auto hnsw_index = new hnsw_index_t(new_field.num_dim, 1024, new_field.vec_dist); + vector_index.emplace(new_field.name, hnsw_index); + continue; + } + if(new_field.is_sortable()) { if(new_field.is_num_sortable()) { spp::sparse_hash_map * doc_to_score = new spp::sparse_hash_map(); @@ -5373,11 +5430,6 @@ void Index::refresh_schemas(const std::vector& new_fields, const std::vec infix_index.emplace(new_field.name, infix_sets); } - - if(new_field.type == field_types::FLOAT_ARRAY && new_field.num_dim) { - auto hnsw_index = new hnsw_index_t(new_field.num_dim, 1024, new_field.vec_dist); - vector_index.emplace(new_field.name, hnsw_index); - } } for(const auto & del_field: del_fields) { diff --git a/test/collection_vector_search_test.cpp b/test/collection_vector_search_test.cpp index 6467068e..68372ac7 100644 --- a/test/collection_vector_search_test.cpp +++ b/test/collection_vector_search_test.cpp @@ -336,6 +336,7 @@ TEST_F(CollectionVectorTest, VecSearchWithFilteringWithMissingVectorValues) { std::uniform_real_distribution<> distrib; size_t num_docs = 20; + std::vector json_lines; for (size_t i = 0; i < num_docs; i++) { nlohmann::json doc; @@ -351,9 +352,14 @@ TEST_F(CollectionVectorTest, VecSearchWithFilteringWithMissingVectorValues) { if(i != 5 && i != 15) { doc["vec"] = values; } - ASSERT_TRUE(coll1->add(doc.dump()).ok()); + + json_lines.push_back(doc.dump()); } + nlohmann::json insert_doc; + auto res = coll1->add_many(json_lines, insert_doc, UPSERT); + ASSERT_TRUE(res["success"].get()); + auto results = coll1->search("*", {}, "", {}, {}, {0}, 20, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, spp::sparse_hash_set(), spp::sparse_hash_set(), 10, "", 30, 5, @@ -418,6 +424,41 @@ TEST_F(CollectionVectorTest, VecSearchWithFilteringWithMissingVectorValues) { ASSERT_EQ(1, results["found"].get()); ASSERT_EQ(1, results["hits"].size()); + + ASSERT_EQ(1, coll1->_get_index()->_get_numerical_index().size()); + ASSERT_EQ(1, coll1->_get_index()->_get_numerical_index().count("points")); + + // should not be able to filter / sort / facet on vector fields + auto res_op = coll1->search("*", {}, "vec:1", {}, {}, {0}, 20, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, + spp::sparse_hash_set(), + spp::sparse_hash_set()); + + ASSERT_FALSE(res_op.ok()); + ASSERT_EQ("Cannot filter on vector field `vec`.", res_op.error()); + + schema = R"({ + "name": "coll2", + "fields": [ + {"name": "title", "type": "string"}, + {"name": "vec", "type": "float[]", "num_dim": 4, "facet": true} + ] + })"_json; + + auto coll_op = collectionManager.create_collection(schema); + ASSERT_FALSE(coll_op.ok()); + ASSERT_EQ("Property `facet` is not allowed on a vector field.", coll_op.error()); + + schema = R"({ + "name": "coll2", + "fields": [ + {"name": "title", "type": "string"}, + {"name": "vec", "type": "float[]", "num_dim": 4, "sort": true} + ] + })"_json; + + coll_op = collectionManager.create_collection(schema); + ASSERT_FALSE(coll_op.ok()); + ASSERT_EQ("Property `sort` cannot be enabled on a vector field.", coll_op.error()); } TEST_F(CollectionVectorTest, VectorSearchTestDeletion) {