Implemented filter on a single int32 value.

Kishore Nallan 2017-02-12 12:51:28 +05:30
parent 60cc05fe52
commit aa9945c3c0
6 changed files with 123 additions and 37 deletions
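This commit introduces a filter struct (field name, value encoded as JSON, comparator string) and threads a precomputed list of matching document ids through Collection::search and search_candidates, so that query results can be restricted to documents whose int32 field satisfies a comparison. A minimal usage sketch, mirroring the test added in this commit (it assumes the numeric test collection, coll_array_fields, set up in collection_test.cpp):

std::vector<std::string> search_fields = {"name"};
std::vector<filter> filters;
filter f1 = {"age", "24", "GREATER_THAN"};   // field name, value as JSON, comparator string
filters.push_back(f1);
nlohmann::json results = coll_array_fields->search("Jeremy", search_fields, filters, 0, 10, FREQUENCY, false);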

View File

@@ -28,7 +28,8 @@
- Pagination parameter
- UTF-8 support for fuzzy search
- ~~Multi-key binary search during scoring~~
- Assumption that all tokens match for scoring is no longer true
- ~~Assumption that all tokens match for scoring is no longer true~~
- Handle searching for non-existing fields gracefully
- Intersection without unpacking
- Facets
- Filters

View File

@@ -39,11 +39,14 @@ private:
void log_leaves(const int cost, const std::string &token, const std::vector<art_leaf *> &leaves) const;
void search(std::string & query, const std::string & field, const int num_typos, const size_t num_results,
Topster<100> & topster, size_t & num_found, const token_ordering token_order = FREQUENCY,
const bool prefix = false);
void union_with_filter_ids(std::vector<const art_leaf*> & leaves, uint32_t** filter_ids, uint32_t & filter_ids_length);
void search_candidates(int & token_rank, std::vector<std::vector<art_leaf*>> & token_leaves, Topster<100> & topster,
void search(uint32_t* filter_ids, size_t filter_ids_length, std::string & query, const std::string & field,
const int num_typos, const size_t num_results, Topster<100> & topster, size_t & num_found,
const token_ordering token_order = FREQUENCY, const bool prefix = false);
void search_candidates(uint32_t* filter_ids, size_t filter_ids_length, int & token_rank,
std::vector<std::vector<art_leaf*>> & token_leaves, Topster<100> & topster,
size_t & total_results, size_t & num_found, const size_t & max_results);
void index_string_field(const std::string & text, const uint32_t score, art_tree *t, uint32_t seq_id) const;

View File

@@ -1,6 +1,7 @@
#pragma once
#include <string>
#include "art.h"
namespace field_types {
static const std::string STRING = "STRING";
@@ -29,4 +30,18 @@ struct filter {
std::string field_name;
std::string value_json;
std::string compare_operator;
NUM_COMPARATOR get_comparator() const {
if(compare_operator == "LESS_THAN") {
return LESS_THAN;
} else if(compare_operator == "LESS_THAN_EQUALS") {
return LESS_THAN_EQUALS;
} else if(compare_operator == "EQUALS") {
return EQUALS;
} else if(compare_operator == "GREATER_THAN") {
return GREATER_THAN;
} else {
return GREATER_THAN_EQUALS;
}
}
};
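// Illustrative note (not part of this diff): get_comparator() maps the filter's operator
// string onto the NUM_COMPARATOR enum; any unrecognized string falls through to
// GREATER_THAN_EQUALS. For example:
//   filter f1 = {"age", "24", "GREATER_THAN"};
//   NUM_COMPARATOR op = f1.get_comparator();   // yields GREATER_THAN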

View File

@@ -53,7 +53,11 @@ void Collection::index_in_memory(const nlohmann::json &document, uint32_t seq_id
for(const std::pair<std::string, field> & field_pair: schema) {
const std::string & field_name = field_pair.first;
art_tree *t = index_map.at(field_name);
uint32_t points = document["points"];
uint32_t points = 0;
if(document.count("points") != 0) {
points = document["points"];
}
if(field_pair.second.type == field_types::STRING) {
const std::string & text = document[field_name];
@@ -68,9 +72,14 @@ void Collection::index_in_memory(const nlohmann::json &document, uint32_t seq_id
std::vector<std::string> strings = document[field_name];
index_string_array_field(strings, points, t, seq_id);
} else if(field_pair.second.type == field_types::INT32_ARRAY) {
std::vector<std::string> strings = document[field_name];
std::vector<int32_t> values = document[field_name];
index_int32_array_field(values, points, t, seq_id);
} else if(field_pair.second.type == field_types::INT64_ARRAY) {
std::vector<int64_t> values = document[field_name];
index_int64_array_field(values, points, t, seq_id);
}
}
if(rank_fields.size() > 0 && document.count(rank_fields[0])) {
primary_rank_scores[seq_id] = document[rank_fields[0]].get<int64_t>();
}
@@ -186,8 +195,9 @@ void Collection::index_int64_array_field(const std::vector<int64_t> & values, co
}
}
void Collection::search_candidates(int & token_rank, std::vector<std::vector<art_leaf*>> & token_leaves,
Topster<100> & topster, size_t & total_results, size_t & num_found, const size_t & max_results) {
void Collection::search_candidates(uint32_t* filter_ids, size_t filter_ids_length, int & token_rank,
std::vector<std::vector<art_leaf*>> & token_leaves, Topster<100> & topster,
size_t & total_results, size_t & num_found, const size_t & max_results) {
const size_t combination_limit = 10;
auto product = []( long long a, std::vector<art_leaf*>& b ) { return a*b.size(); };
long long int N = std::accumulate(token_leaves.begin(), token_leaves.end(), 1LL, product);
@@ -211,11 +221,24 @@ void Collection::search_candidates(int & token_rank, std::vector<std::vector<art
result_ids = out;
}
// go through each matching document id and calculate match score
score_results(topster, token_rank, query_suggestion, result_ids, result_size);
delete[] result_ids;
if(filter_ids != nullptr) {
// intersect once again with filter ids
uint32_t* filtered_result_ids = new uint32_t[std::min(filter_ids_length, result_size)];
size_t filtered_results_size =
Intersection::scalar(filter_ids, filter_ids_length, result_ids, result_size, filtered_result_ids);
// go through each matching document id and calculate match score
score_results(topster, token_rank, query_suggestion, filtered_result_ids, filtered_results_size);
num_found += filtered_results_size;
delete[] filtered_result_ids;
delete[] result_ids;
} else {
score_results(topster, token_rank, query_suggestion, result_ids, result_size);
num_found += result_size;
delete[] result_ids;
}
num_found += result_size;
total_results += topster.size;
if(total_results >= max_results) {
@@ -224,25 +247,51 @@ void Collection::search_candidates(int & token_rank, std::vector<std::vector<art
}
}
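// Illustrative sketch (not from this diff): the extra filtering step above amounts to
// intersecting two sorted id lists, in the spirit of the Intersection::scalar call,
// assuming both inputs are sorted in ascending order:
size_t intersect_sorted(const uint32_t* a, size_t a_len, const uint32_t* b, size_t b_len, uint32_t* out) {
    size_t i = 0, j = 0, k = 0;
    while(i < a_len && j < b_len) {
        if(a[i] < b[j]) i++;
        else if(b[j] < a[i]) j++;
        else { out[k++] = a[i]; i++; j++; }
    }
    // the intersection can never exceed min(a_len, b_len), which is why the caller
    // sizes filtered_result_ids with std::min(filter_ids_length, result_size)
    return k;
}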
void Collection::union_with_filter_ids(std::vector<const art_leaf*> & leaves, uint32_t** filter_ids,
uint32_t & filter_ids_length) {
for(const art_leaf* leaf: leaves) {
uint32_t* results = new uint32_t[filter_ids_length + leaf->values->ids.getLength()];
size_t results_length = leaf->values->ids.do_union(*filter_ids, filter_ids_length, results);
delete [] *filter_ids;
*filter_ids = results;
filter_ids_length = results_length;
}
}
nlohmann::json Collection::search(std::string query, const std::vector<std::string> fields, const std::vector<filter> filters,
const int num_typos, const size_t num_results,
const token_ordering token_order, const bool prefix) {
size_t num_found = 0;
uint32_t* filter_ids = nullptr;
uint32_t filter_ids_length = 0;
// process the filters first
/*for(const filter & a_filter: filters) {
for(const filter & a_filter: filters) {
if(index_map.count(a_filter.field_name) != 0) {
art_tree* t = index_map.at(a_filter.field_name);
nlohmann::json json_value = nlohmann::json::parse(a_filter.value_json);
if(json_value.is_number()) {
// do integer art search
} else if(json_value.is_string()) {
field f = schema.at(a_filter.field_name);
} else if(json_value.is_array()) {
nlohmann::json json_value = nlohmann::json::parse(a_filter.value_json);
if(f.type == field_types::INT64) {
} else if(f.type == field_types::INT32) {
int32_t value = json_value.get<int32_t>();
NUM_COMPARATOR comparator = a_filter.get_comparator();
std::vector<const art_leaf*> leaves;
art_int32_search(t, value, comparator, leaves);
union_with_filter_ids(leaves, &filter_ids, filter_ids_length);
} else if(f.type == field_types::INT32_ARRAY) {
}
}
}*/
}
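// Illustrative note (not part of this diff): after this loop, filter_ids is either nullptr
// (no usable filter, so search_candidates takes the unfiltered branch) or a sorted,
// de-duplicated list of seq_ids accumulated by union_with_filter_ids from every ART leaf
// that satisfied the numeric comparison.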
// Order of `fields` are used to rank results
auto begin = std::chrono::high_resolution_clock::now();
@@ -251,8 +300,8 @@ nlohmann::json Collection::search(std::string query, const std::vector<std::stri
for(int i = 0; i < fields.size(); i++) {
Topster<100> topster;
const std::string & field = fields[i];
search(query, field, num_typos, num_results, topster, num_found, token_order, prefix);
search(filter_ids, filter_ids_length, query, field, num_typos, num_results,
topster, num_found, token_order, prefix);
topster.sort();
for(auto t = 0; t < topster.size && t < num_results; t++) {
@@ -297,7 +346,8 @@ nlohmann::json Collection::search(std::string query, const std::vector<std::stri
4. Intersect the lists to find docs that match each phrase
5. Sort the docs based on some ranking criteria
*/
void Collection::search(std::string & query, const std::string & field, const int num_typos, const size_t num_results,
void Collection::search(uint32_t* filter_ids, size_t filter_ids_length, std::string & query,
const std::string & field, const int num_typos, const size_t num_results,
Topster<100> & topster, size_t & num_found, const token_ordering token_order, const bool prefix) {
std::vector<std::string> tokens;
StringUtils::tokenize(query, tokens, " ", true);
@@ -396,7 +446,8 @@ void Collection::search(std::string & query, const std::string & field, const in
if(token_leaves.size() != 0 && token_leaves.size() == tokens.size()) {
// If a) all tokens were found, or b) Some were skipped because they don't exist within max_cost,
// go ahead and search for candidates with what we have so far
search_candidates(token_rank, token_leaves, topster, total_results, num_found, max_results);
search_candidates(filter_ids, filter_ids_length, token_rank, token_leaves, topster,
total_results, num_found, max_results);
if (total_results >= max_results) {
// If we don't find enough results, we continue outerloop (looking at tokens with greater cost)
@@ -429,7 +480,8 @@ void Collection::search(std::string & query, const std::string & field, const in
}
}
return search(truncated_query, field, num_typos, num_results, topster, num_found, token_order, prefix);
return search(filter_ids, filter_ids_length, truncated_query, field, num_typos, num_results, topster,
num_found, token_order, prefix);
}
}
@@ -506,6 +558,11 @@ void Collection::score_results(Topster<100> &topster, const int & token_rank,
/*std::cout << "token_rank_score: " << token_rank_score << ", match_score: "
<< match_score << ", primary_rank_score: " << primary_rank_score << ", seq_id: " << seq_id << std::endl;*/
}
for (auto it = leaf_to_indices.begin(); it != leaf_to_indices.end(); it++) {
delete [] it->second;
it->second = nullptr;
}
}
inline std::vector<art_leaf *> Collection::next_suggestion(const std::vector<std::vector<art_leaf *>> &token_leaves,

View File

@@ -180,6 +180,12 @@ size_t sorted_array::do_union(uint32_t *arr, const size_t arr_length, uint32_t *
size_t curr_index = 0, arr_index = 0, res_index = 0;
uint32_t* curr = uncompress();
if(arr == nullptr) {
memcpy(results, curr, length * sizeof(uint32_t));
delete[] curr;
return length;
}
while (curr_index < length && arr_index < arr_length) {
if (curr[curr_index] < arr[arr_index]) {
if(res_index == 0 || results[res_index-1] != curr[curr_index]) {
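// Illustrative sketch (not from this diff): do_union merges this sorted_array's ids with
// another sorted id list while skipping duplicates; the new nullptr branch above simply
// copies out this array's own ids when there is nothing to merge with yet. A stand-alone
// equivalent in spirit (not the compressed-array implementation), assuming sorted inputs
// and an out buffer of at least a_len + b_len elements:
size_t union_sorted(const uint32_t* a, size_t a_len, const uint32_t* b, size_t b_len, uint32_t* out) {
    size_t i = 0, j = 0, k = 0;
    while(i < a_len || j < b_len) {
        uint32_t v;
        if(j >= b_len || (i < a_len && a[i] <= b[j])) {
            v = a[i++];
        } else {
            v = b[j++];
        }
        if(k == 0 || out[k-1] != v) {
            out[k++] = v;   // drop duplicates within and across the two inputs
        }
    }
    return k;
}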

View File

@@ -378,12 +378,12 @@ TEST_F(CollectionTest, MultipleFields) {
}
}
/*
TEST_F(CollectionTest, SearchNumericFields) {
Collection *coll_array_fields;
std::ifstream infile(std::string(ROOT_DIR)+"test/numeric_array_documents.jsonl");
std::vector<field> fields = {field("name", field_types::STRING), field("years", field_types::INT32_ARRAY),
std::vector<field> fields = {field("name", field_types::STRING), field("age", field_types::INT32),
field("years", field_types::INT32_ARRAY),
field("timestamps", field_types::INT64_ARRAY)};
std::vector<std::string> rank_fields = {"age"};
@@ -400,11 +400,12 @@ TEST_F(CollectionTest, SearchNumericFields) {
infile.close();
search_fields = {"years"};
// Plain search with no filters - results should be sorted by rank fields
search_fields = {"name"};
nlohmann::json results = coll_array_fields->search("Jeremy", search_fields, {}, 0, 10, FREQUENCY, false);
ASSERT_EQ(4, results["hits"].size());
ASSERT_EQ(5, results["hits"].size());
std::vector<std::string> ids = {"3", "2", "1", "0"};
std::vector<std::string> ids = {"3", "0", "4", "1", "2"};
for(size_t i = 0; i < results["hits"].size(); i++) {
nlohmann::json result = results["hits"].at(i);
@@ -413,21 +414,25 @@ TEST_F(CollectionTest, SearchNumericFields) {
ASSERT_STREQ(id.c_str(), result_id.c_str());
}
search_fields = {"name"};
std::vector<filter> filters;
filter f1 = {"age", "24", "GREATER_THAN"};
filters.push_back(f1);
search_fields = {"starring", "title"};
results = coll_array_fields->search("thomas", search_fields, {}, 0, 10, FREQUENCY, false);
ASSERT_EQ(4, results["hits"].size());
results = coll_array_fields->search("Jeremy", search_fields, filters, 0, 10, FREQUENCY, false);
ASSERT_EQ(3, results["hits"].size());
ids = {"15", "14", "12", "13"};
ids = {"3", "0", "4"};
for(size_t i = 0; i < results["hits"].size(); i++) {
nlohmann::json result = results["hits"].at(i);
std::string result_id = result["id"];
std::string id = ids.at(i);
ASSERT_STREQ(id.c_str(), result_id.c_str());
//std::cout << result_id << std::endl;
}
search_fields = {"starring", "title", "cast"};
/*search_fields = {"starring", "title", "cast"};
results = coll_array_fields->search("ben affleck", search_fields, {}, 0, 10, FREQUENCY, false);
ASSERT_EQ(1, results["hits"].size());
@@ -453,6 +458,5 @@ TEST_F(CollectionTest, SearchNumericFields) {
std::string result_id = result["id"];
std::string id = ids.at(i);
ASSERT_STREQ(id.c_str(), result_id.c_str());
}
}*/
}
*/