Mirror of https://github.com/typesense/typesense.git (synced 2025-05-20 21:52:23 +08:00)
Fix vector search filter: don't assume order of IDs.

commit ba2d5d10ce (parent d489702fca)
@@ -235,50 +235,28 @@ struct index_record {
 
 class VectorFilterFunctor: public hnswlib::FilterFunctor {
     const uint32_t* filter_ids = nullptr;
     const uint32_t filter_ids_length = 0;
-    uint32_t filter_ids_index = 0;
 
 public:
     explicit VectorFilterFunctor(const uint32_t* filter_ids, const uint32_t filter_ids_length) :
             filter_ids(filter_ids), filter_ids_length(filter_ids_length) {}
 
     bool operator()(unsigned int id) {
-        if(filter_ids_length != 0) {
-            if(filter_ids_index >= filter_ids_length) {
-                return false;
-            }
-
-            // Returns iterator to the first element that is >= to value or last if no such element is found.
-            size_t found_index = std::lower_bound(filter_ids + filter_ids_index,
-                                                  filter_ids + filter_ids_length, id) - filter_ids;
-
-            if(found_index == filter_ids_length) {
-                // all elements are lesser than lowest value (id), so we can stop looking
-                filter_ids_index = found_index + 1;
-                return false;
-            } else {
-                if(filter_ids[found_index] == id) {
-                    filter_ids_index = found_index + 1;
-                    return true;
-                }
-
-                filter_ids_index = found_index;
-            }
-
-            return false;
+        if(filter_ids_length == 0) {
+            return true;
         }
 
-        return true;
+        return std::binary_search(filter_ids, filter_ids + filter_ids_length, id);
     }
 };
 
 struct hnsw_index_t {
-    hnswlib::L2Space* space;
+    hnswlib::InnerProductSpace* space;
     hnswlib::HierarchicalNSW<float, VectorFilterFunctor>* vecdex;
     size_t num_dim;
     vector_distance_type_t distance_type;
 
     hnsw_index_t(size_t num_dim, size_t init_size, vector_distance_type_t distance_type):
-            space(new hnswlib::L2Space(num_dim)),
+            space(new hnswlib::InnerProductSpace(num_dim)),
             vecdex(new hnswlib::HierarchicalNSW<float, VectorFilterFunctor>(space, init_size)),
             num_dim(num_dim), distance_type(distance_type) {
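Why the old functor was wrong: it kept a forward-only cursor (filter_ids_index) into the sorted filter_ids array, which implicitly assumed hnswlib would ask about candidate ids in ascending order. HNSW graph traversal visits ids in arbitrary order, so once the cursor had advanced past a position, smaller ids that were actually in the filter got rejected. The fixed version is stateless and simply runs std::binary_search over the whole array. A minimal standalone sketch of the difference (CursorFilter and BinarySearchFilter are illustrative names, not Typesense code):

// Standalone sketch, not Typesense code: CursorFilter mirrors the removed logic,
// BinarySearchFilter mirrors the fixed logic. HNSW graph traversal asks about
// candidate ids in arbitrary order, which breaks the forward-only cursor.
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

struct CursorFilter {                       // old approach (simplified)
    const uint32_t* ids;
    uint32_t len;
    uint32_t index;

    bool operator()(uint32_t id) {
        if(len == 0) return true;
        if(index >= len) return false;
        std::size_t found = std::lower_bound(ids + index, ids + len, id) - ids;
        if(found == len) { index = found + 1; return false; }
        if(ids[found] == id) { index = found + 1; return true; }
        index = found;
        return false;
    }
};

struct BinarySearchFilter {                 // fixed approach (simplified)
    const uint32_t* ids;
    uint32_t len;

    bool operator()(uint32_t id) const {
        if(len == 0) return true;
        return std::binary_search(ids, ids + len, id);
    }
};

int main() {
    std::vector<uint32_t> allowed = {2, 5, 9};   // sorted ids matching the filter
    CursorFilter old_filter{allowed.data(), (uint32_t) allowed.size(), 0};
    BinarySearchFilter new_filter{allowed.data(), (uint32_t) allowed.size()};

    // The graph happens to visit id 9 before id 2.
    assert(old_filter(9));
    assert(!old_filter(2));    // bug: the cursor already moved past id 2
    assert(new_filter(9));
    assert(new_filter(2));     // stateless lookup stays correct
    return 0;
}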
@@ -200,3 +200,77 @@ TEST_F(CollectionVectorTest, IndexGreaterThan1KVectors) {
 
     ASSERT_EQ(1500, results["found"].get<size_t>());
 }
+
+TEST_F(CollectionVectorTest, VecSearchWithFiltering) {
+    nlohmann::json schema = R"({
+        "name": "coll1",
+        "fields": [
+            {"name": "title", "type": "string"},
+            {"name": "points", "type": "int32"},
+            {"name": "vec", "type": "float[]", "num_dim": 4}
+        ]
+    })"_json;
+
+    Collection* coll1 = collectionManager.create_collection(schema).get();
+
+    std::mt19937 rng;
+    rng.seed(47);
+    std::uniform_real_distribution<> distrib;
+
+    size_t num_docs = 20;
+
+    for (size_t i = 0; i < num_docs; i++) {
+        nlohmann::json doc;
+        doc["id"] = std::to_string(i);
+        doc["title"] = std::to_string(i) + " title";
+        doc["points"] = i;
+
+        std::vector<float> values;
+        for(size_t j = 0; j < 4; j++) {
+            values.push_back(distrib(rng));
+        }
+
+        doc["vec"] = values;
+        ASSERT_TRUE(coll1->add(doc.dump()).ok());
+    }
+
+    auto results = coll1->search("*", {}, "", {}, {}, {0}, 20, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
+                                 spp::sparse_hash_set<std::string>(),
+                                 spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
+                                 "", 10, {}, {}, {}, 0,
+                                 "<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7,
+                                 fallback,
+                                 4, {off}, 32767, 32767, 2,
+                                 false, true, "vec:([0.96826, 0.94, 0.39557, 0.306488])").get();
+
+    ASSERT_EQ(num_docs, results["found"].get<size_t>());
+    ASSERT_EQ(num_docs, results["hits"].size());
+
+    // with points:<10
+
+    results = coll1->search("*", {}, "points:<10", {}, {}, {0}, 20, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
+                            spp::sparse_hash_set<std::string>(),
+                            spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
+                            "", 10, {}, {}, {}, 0,
+                            "<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7,
+                            fallback,
+                            4, {off}, 32767, 32767, 2,
+                            false, true, "vec:([0.96826, 0.94, 0.39557, 0.306488])").get();
+
+    ASSERT_EQ(10, results["found"].get<size_t>());
+    ASSERT_EQ(10, results["hits"].size());
+
+    // single point
+
+    results = coll1->search("*", {}, "points:1", {}, {}, {0}, 20, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
+                            spp::sparse_hash_set<std::string>(),
+                            spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
+                            "", 10, {}, {}, {}, 0,
+                            "<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7,
+                            fallback,
+                            4, {off}, 32767, 32767, 2,
+                            false, true, "vec:([0.96826, 0.94, 0.39557, 0.306488])").get();
+
+    ASSERT_EQ(1, results["found"].get<size_t>());
+    ASSERT_EQ(1, results["hits"].size());
+}
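Separately from the filter fix, the first hunk also swaps hnswlib::L2Space for hnswlib::InnerProductSpace in hnsw_index_t. As a rough illustration of what that changes, the sketch below computes the two distances those spaces correspond to (squared Euclidean distance vs. 1 minus the dot product, which is how hnswlib's inner-product space reports distance). The query vector is taken from the test above, the document vector is made up, and none of this is Typesense code:

// Illustrative only, not Typesense code: the distance functions that the two
// hnswlib spaces correspond to.
#include <cstddef>
#include <cstdio>
#include <vector>

// hnswlib::L2Space ranks by squared Euclidean distance.
static float l2_sqr(const std::vector<float>& a, const std::vector<float>& b) {
    float d = 0.0f;
    for(std::size_t i = 0; i < a.size(); i++) {
        float diff = a[i] - b[i];
        d += diff * diff;
    }
    return d;
}

// hnswlib::InnerProductSpace ranks by 1 - dot(a, b), so larger dot products
// mean smaller distances.
static float inner_product_dist(const std::vector<float>& a, const std::vector<float>& b) {
    float dot = 0.0f;
    for(std::size_t i = 0; i < a.size(); i++) {
        dot += a[i] * b[i];
    }
    return 1.0f - dot;
}

int main() {
    std::vector<float> query = {0.96826f, 0.94f, 0.39557f, 0.306488f};  // query vector from the test
    std::vector<float> doc   = {0.5f, 0.5f, 0.5f, 0.5f};                // made-up document vector
    std::printf("l2_sqr = %f, inner_product_dist = %f\n",
                l2_sqr(query, doc), inner_product_dist(query, doc));
    return 0;
}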