String based filtering.

This commit is contained in:
Kishore Nallan 2017-03-04 18:16:37 +05:30
parent 14168c48fc
commit 0760e4d01b
4 changed files with 142 additions and 37 deletions

View File

@ -41,6 +41,8 @@ private:
size_t union_of_leaf_ids(std::vector<const art_leaf *> &leaves, uint32_t **results_out);
uint32_t do_filtering(uint32_t** filter_ids_out, const std::vector<filter> & filters);
void search(uint32_t* filter_ids, size_t filter_ids_length, std::string & query, const std::string & field,
const int num_typos, const size_t num_results, Topster<100> & topster, size_t & num_found,
const token_ordering token_order = FREQUENCY, const bool prefix = false);

View File

@ -266,11 +266,7 @@ size_t Collection::union_of_leaf_ids(std::vector<const art_leaf *> &leaves, uint
return results_length;
}
nlohmann::json Collection::search(std::string query, const std::vector<std::string> fields, const std::vector<filter> filters,
const int num_typos, const size_t num_results,
const token_ordering token_order, const bool prefix) {
size_t num_found = 0;
uint32_t Collection::do_filtering(uint32_t** filter_ids_out, const std::vector<filter> & filters) {
uint32_t* filter_ids = nullptr;
uint32_t filter_ids_length = 0;
@ -279,36 +275,59 @@ nlohmann::json Collection::search(std::string query, const std::vector<std::stri
if(index_map.count(a_filter.field_name) != 0) {
art_tree* t = index_map.at(a_filter.field_name);
field f = schema.at(a_filter.field_name);
std::vector<const art_leaf*> leaves;
if(f.type == field_types::INT64) {
} else if(f.type == field_types::INT32 || f.type == field_types::INT32_ARRAY) {
std::vector<const art_leaf*> leaves;
if(f.type == field_types::INT32 || f.type == field_types::INT32_ARRAY ||
f.type == field_types::INT64 || f.type == field_types::INT64_ARRAY) {
for(const std::string & filter_value: a_filter.values) {
int32_t value = (int32_t) std::stoi(filter_value);
NUM_COMPARATOR comparator = a_filter.get_comparator();
art_int32_search(t, value, comparator, leaves);
if(f.type == field_types::INT32 || f.type == field_types::INT32_ARRAY) {
int32_t value = (int32_t) std::stoi(filter_value);
NUM_COMPARATOR comparator = a_filter.get_comparator();
art_int32_search(t, value, comparator, leaves);
} else {
int64_t value = (int64_t) std::stoi(filter_value);
NUM_COMPARATOR comparator = a_filter.get_comparator();
art_int64_search(t, value, comparator, leaves);
}
}
uint32_t* result_ids = nullptr;
size_t result_ids_length = union_of_leaf_ids(leaves, &result_ids);
if(filter_ids == nullptr) {
filter_ids = result_ids;
filter_ids_length = result_ids_length;
} else {
uint32_t* filtered_results = new uint32_t[std::min((size_t)filter_ids_length, result_ids_length)];
filter_ids_length = Intersection::scalar(filter_ids, filter_ids_length, result_ids, result_ids_length, filtered_results);
delete [] filter_ids;
delete [] result_ids;
filter_ids = filtered_results;
} else if(f.type == field_types::STRING || f.type == field_types::STRING_ARRAY) {
for(const std::string & filter_value: a_filter.values) {
art_leaf* leaf = (art_leaf *) art_search(t, (const unsigned char*) filter_value.c_str(), filter_value.length()+1);
if(leaf != nullptr) {
leaves.push_back(leaf);
}
}
}
uint32_t* result_ids = nullptr;
size_t result_ids_length = union_of_leaf_ids(leaves, &result_ids);
if(filter_ids == nullptr) {
filter_ids = result_ids;
filter_ids_length = result_ids_length;
} else {
uint32_t* filtered_results = new uint32_t[std::min((size_t)filter_ids_length, result_ids_length)];
filter_ids_length = Intersection::scalar(filter_ids, filter_ids_length, result_ids, result_ids_length, filtered_results);
delete [] filter_ids;
delete [] result_ids;
filter_ids = filtered_results;
}
}
}
*filter_ids_out = filter_ids;
return filter_ids_length;
}
nlohmann::json Collection::search(std::string query, const std::vector<std::string> fields, const std::vector<filter> filters,
const int num_typos, const size_t num_results,
const token_ordering token_order, const bool prefix) {
size_t num_found = 0;
// process the filters first
uint32_t* filter_ids = nullptr;
uint32_t filter_ids_length = do_filtering(&filter_ids, filters);
// The order of `fields` is used to rank results
auto begin = std::chrono::high_resolution_clock::now();
std::vector<std::pair<int, Topster<100>::KV>> field_order_kvs;
@ -316,9 +335,12 @@ nlohmann::json Collection::search(std::string query, const std::vector<std::stri
for(int i = 0; i < fields.size(); i++) {
Topster<100> topster;
const std::string & field = fields[i];
search(filter_ids, filter_ids_length, query, field, num_typos, num_results,
topster, num_found, token_order, prefix);
topster.sort();
// proceed to query search only when no filters are provided or when filtering produces results
if(filters.size() == 0 || filter_ids_length > 0) {
search(filter_ids, filter_ids_length, query, field, num_typos, num_results,
topster, num_found, token_order, prefix);
topster.sort();
}
for(auto t = 0; t < topster.size && t < num_results; t++) {
field_order_kvs.push_back(std::make_pair(fields.size() - i, topster.getKV(t)));

View File

@ -378,7 +378,7 @@ TEST_F(CollectionTest, MultipleFields) {
}
}
TEST_F(CollectionTest, SearchInt32Fields) {
TEST_F(CollectionTest, FilterOnNumericFields) {
Collection *coll_array_fields;
std::ifstream infile(std::string(ROOT_DIR)+"test/numeric_array_documents.jsonl");
@ -488,7 +488,7 @@ TEST_F(CollectionTest, SearchInt32Fields) {
ASSERT_STREQ(id.c_str(), result_id.c_str());
}
// multiple search values against an int array field
// multiple search values against an int32 array field
filters = {(filter) {"years", {"2015", "1985", "1999"}, "EQUALS"}};
results = coll_array_fields->search("Jeremy", search_fields, filters, 0, 10, FREQUENCY, false);
ASSERT_EQ(4, results["hits"].size());
@ -500,4 +500,85 @@ TEST_F(CollectionTest, SearchInt32Fields) {
std::string id = ids.at(i);
ASSERT_STREQ(id.c_str(), result_id.c_str());
}
// searching on an int64 array field
filters = {(filter) {"timestamps", {"475205222"}, "GREATER_THAN"}};
results = coll_array_fields->search("Jeremy", search_fields, filters, 0, 10, FREQUENCY, false);
ASSERT_EQ(4, results["hits"].size());
ids = {"1", "4", "0", "2"};
for(size_t i = 0; i < results["hits"].size(); i++) {
nlohmann::json result = results["hits"].at(i);
std::string result_id = result["id"];
std::string id = ids.at(i);
ASSERT_STREQ(id.c_str(), result_id.c_str());
}
// when filters don't match any record, no results should be returned
filters = {(filter) {"timestamps", {"1"}, "LESS_THAN"}};
results = coll_array_fields->search("Jeremy", search_fields, filters, 0, 10, FREQUENCY, false);
ASSERT_EQ(0, results["hits"].size());
collectionManager.drop_collection("coll_array_fields");
}
// Exercises EQUALS filters on the string array field "tags": searches for
// "Jeremy" on the "name" field and checks both the number of hits and the
// order of the returned ids for each filter value.
TEST_F(CollectionTest, FilterOnTextFields) {
    // Schema used when the collection has to be created; "tags" is the
    // string array field the filters below target.
    std::vector<field> fields = {field("name", field_types::STRING), field("age", field_types::INT32),
                                 field("years", field_types::INT32_ARRAY),
                                 field("tags", field_types::STRING_ARRAY)};
    std::vector<std::string> rank_fields = {"age"};

    // Reuse the collection when the manager already knows it, otherwise create it.
    Collection *coll_array_fields = collectionManager.get_collection("coll_array_fields");
    if(coll_array_fields == nullptr) {
        coll_array_fields = collectionManager.create_collection("coll_array_fields", fields, rank_fields);
    }

    // Feed every line of the fixture file to the collection.
    std::ifstream infile(std::string(ROOT_DIR)+"test/numeric_array_documents.jsonl");
    for(std::string json_line; std::getline(infile, json_line); ) {
        coll_array_fields->add(json_line);
    }
    infile.close();

    search_fields = {"name"};

    // Filter on the "gold" tag: four hits expected, in this id order.
    std::vector<filter> filters = {(filter) {"tags", {"gold"}, "EQUALS"}};
    nlohmann::json results = coll_array_fields->search("Jeremy", search_fields, filters, 0, 10, FREQUENCY, false);
    ASSERT_EQ(4, results["hits"].size());

    std::vector<std::string> ids = {"1", "4", "0", "2"};
    const size_t gold_hits = results["hits"].size();
    for(size_t pos = 0; pos < gold_hits; pos++) {
        nlohmann::json hit = results["hits"].at(pos);
        std::string actual_id = hit["id"];
        ASSERT_STREQ(ids.at(pos).c_str(), actual_id.c_str());
    }

    // Filter on the "bronze" tag: two hits expected, in this id order.
    filters = {(filter) {"tags", {"bronze"}, "EQUALS"}};
    results = coll_array_fields->search("Jeremy", search_fields, filters, 0, 10, FREQUENCY, false);
    ASSERT_EQ(2, results["hits"].size());

    ids = {"4", "2"};
    const size_t bronze_hits = results["hits"].size();
    for(size_t pos = 0; pos < bronze_hits; pos++) {
        nlohmann::json hit = results["hits"].at(pos);
        std::string actual_id = hit["id"];
        ASSERT_STREQ(ids.at(pos).c_str(), actual_id.c_str());
    }

    // should be exact matches (no normalization or fuzzy searching should happen)
    filters = {(filter) {"tags", {"BRONZE"}, "EQUALS"}};
    results = coll_array_fields->search("Jeremy", search_fields, filters, 0, 10, FREQUENCY, false);
    ASSERT_EQ(0, results["hits"].size());

    // Drop the collection at the end of the test.
    collectionManager.drop_collection("coll_array_fields");
}

View File

@ -1,5 +1,5 @@
{"name": "Jeremy Howard", "age": 24, "years": [2014, 2015, 2016], "timestamps": [1390354022, 1421890022, 1453426022]}
{"name": "Jeremy Howard", "age": 44, "years": [2015, 2016], "timestamps": [1421890022, 1453426022]}
{"name": "Jeremy Howard", "age": 21, "years": [2016], "timestamps": [1453426022]}
{"name": "Jeremy Howard", "age": 63, "years": [1981, 1985], "timestamps": [348974822, 475205222]}
{"name": "Jeremy Howard", "age": 32, "years": [1999, 2000, 2001, 2002], "timestamps": [916968422, 948504422, 980126822, 1011662822]}
{"name": "Jeremy Howard", "age": 24, "years": [2014, 2015, 2016], "timestamps": [1390354022, 1421890022, 1453426022], "tags": ["gold", "silver"]}
{"name": "Jeremy Howard", "age": 44, "years": [2015, 2016], "timestamps": [1421890022, 1453426022], "tags": ["gold"]}
{"name": "Jeremy Howard", "age": 21, "years": [2016], "timestamps": [1453426022], "tags": ["bronze", "gold"]}
{"name": "Jeremy Howard", "age": 63, "years": [1981, 1985], "timestamps": [348974822, 475205222], "tags": ["silver"]}
{"name": "Jeremy Howard", "age": 32, "years": [1999, 2000, 2001, 2002], "timestamps": [916968422, 948504422, 980126822, 1011662822], "tags": ["silver", "gold", "bronze"]}