Fix vector search filter: don't assume order of IDs.

This commit is contained in:
Kishore Nallan 2022-09-08 16:19:46 +05:30
parent d489702fca
commit ba2d5d10ce
2 changed files with 79 additions and 27 deletions

View File

@ -235,50 +235,28 @@ struct index_record {
class VectorFilterFunctor: public hnswlib::FilterFunctor {
const uint32_t* filter_ids = nullptr;
const uint32_t filter_ids_length = 0;
uint32 filter_ids_index = 0;
public:
explicit VectorFilterFunctor(const uint32_t* filter_ids, const uint32_t filter_ids_length) :
filter_ids(filter_ids), filter_ids_length(filter_ids_length) {}
bool operator()(unsigned int id) {
if(filter_ids_length != 0) {
if(filter_ids_index >= filter_ids_length) {
return false;
}
// Returns iterator to the first element that is >= to value or last if no such element is found.
size_t found_index = std::lower_bound(filter_ids + filter_ids_index,
filter_ids + filter_ids_length, id) - filter_ids;
if(found_index == filter_ids_length) {
// all elements are lesser than lowest value (id), so we can stop looking
filter_ids_index = found_index + 1;
return false;
} else {
if(filter_ids[found_index] == id) {
filter_ids_index = found_index + 1;
return true;
}
filter_ids_index = found_index;
}
return false;
if(filter_ids_length == 0) {
return true;
}
return true;
return std::binary_search(filter_ids, filter_ids + filter_ids_length, id);
}
};
struct hnsw_index_t {
hnswlib::L2Space* space;
hnswlib::InnerProductSpace* space;
hnswlib::HierarchicalNSW<float, VectorFilterFunctor>* vecdex;
size_t num_dim;
vector_distance_type_t distance_type;
hnsw_index_t(size_t num_dim, size_t init_size, vector_distance_type_t distance_type):
space(new hnswlib::L2Space(num_dim)),
space(new hnswlib::InnerProductSpace(num_dim)),
vecdex(new hnswlib::HierarchicalNSW<float, VectorFilterFunctor>(space, init_size)),
num_dim(num_dim), distance_type(distance_type) {

View File

@ -200,3 +200,77 @@ TEST_F(CollectionVectorTest, IndexGreaterThan1KVectors) {
ASSERT_EQ(1500, results["found"].get<size_t>());
}
TEST_F(CollectionVectorTest, VecSearchWithFiltering) {
nlohmann::json schema = R"({
"name": "coll1",
"fields": [
{"name": "title", "type": "string"},
{"name": "points", "type": "int32"},
{"name": "vec", "type": "float[]", "num_dim": 4}
]
})"_json;
Collection* coll1 = collectionManager.create_collection(schema).get();
std::mt19937 rng;
rng.seed(47);
std::uniform_real_distribution<> distrib;
size_t num_docs = 20;
for (size_t i = 0; i < num_docs; i++) {
nlohmann::json doc;
doc["id"] = std::to_string(i);
doc["title"] = std::to_string(i) + " title";
doc["points"] = i;
std::vector<float> values;
for(size_t j = 0; j < 4; j++) {
values.push_back(distrib(rng));
}
doc["vec"] = values;
ASSERT_TRUE(coll1->add(doc.dump()).ok());
}
auto results = coll1->search("*", {}, "", {}, {}, {0}, 20, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
"", 10, {}, {}, {}, 0,
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7,
fallback,
4, {off}, 32767, 32767, 2,
false, true, "vec:([0.96826, 0.94, 0.39557, 0.306488])").get();
ASSERT_EQ(num_docs, results["found"].get<size_t>());
ASSERT_EQ(num_docs, results["hits"].size());
// with points:<10
results = coll1->search("*", {}, "points:<10", {}, {}, {0}, 20, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
"", 10, {}, {}, {}, 0,
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7,
fallback,
4, {off}, 32767, 32767, 2,
false, true, "vec:([0.96826, 0.94, 0.39557, 0.306488])").get();
ASSERT_EQ(10, results["found"].get<size_t>());
ASSERT_EQ(10, results["hits"].size());
// single point
results = coll1->search("*", {}, "points:1", {}, {}, {0}, 20, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
"", 10, {}, {}, {}, 0,
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7,
fallback,
4, {off}, 32767, 32767, 2,
false, true, "vec:([0.96826, 0.94, 0.39557, 0.306488])").get();
ASSERT_EQ(1, results["found"].get<size_t>());
ASSERT_EQ(1, results["hits"].size());
}