From aa9945c3c09408b10a1bb7d13c3649835e4a3fe1 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Sun, 12 Feb 2017 12:51:28 +0530 Subject: [PATCH] Implemented filter on a single int32 value. --- TODO.md | 3 +- include/collection.h | 11 +++-- include/field.h | 15 +++++++ src/collection.cpp | 97 +++++++++++++++++++++++++++++++--------- src/sorted_array.cpp | 6 +++ test/collection_test.cpp | 28 +++++++----- 6 files changed, 123 insertions(+), 37 deletions(-) diff --git a/TODO.md b/TODO.md index 3ad58a09..99420241 100644 --- a/TODO.md +++ b/TODO.md @@ -28,7 +28,8 @@ - Pagination parameter - UTF-8 support for fuzzy search - ~~Multi-key binary search during scoring~~ -- Assumption that all tokens match for scoring is no longer true +- ~~Assumption that all tokens match for scoring is no longer true~~ +- Handle searching for non-existing fields gracefully - Intersection without unpacking - Facets - Filters diff --git a/include/collection.h b/include/collection.h index a839e775..fdc39b92 100644 --- a/include/collection.h +++ b/include/collection.h @@ -39,11 +39,14 @@ private: void log_leaves(const int cost, const std::string &token, const std::vector &leaves) const; - void search(std::string & query, const std::string & field, const int num_typos, const size_t num_results, - Topster<100> & topster, size_t & num_found, const token_ordering token_order = FREQUENCY, - const bool prefix = false); + void union_with_filter_ids(std::vector & leaves, uint32_t** filter_ids, uint32_t & filter_ids_length); - void search_candidates(int & token_rank, std::vector> & token_leaves, Topster<100> & topster, + void search(uint32_t* filter_ids, size_t filter_ids_length, std::string & query, const std::string & field, + const int num_typos, const size_t num_results, Topster<100> & topster, size_t & num_found, + const token_ordering token_order = FREQUENCY, const bool prefix = false); + + void search_candidates(uint32_t* filter_ids, size_t filter_ids_length, int & token_rank, + std::vector> & token_leaves, Topster<100> & topster, size_t & total_results, size_t & num_found, const size_t & max_results); void index_string_field(const std::string & text, const uint32_t score, art_tree *t, uint32_t seq_id) const; diff --git a/include/field.h b/include/field.h index 2603cf00..b83b933d 100644 --- a/include/field.h +++ b/include/field.h @@ -1,6 +1,7 @@ #pragma once #include +#include "art.h" namespace field_types { static const std::string STRING = "STRING"; @@ -29,4 +30,18 @@ struct filter { std::string field_name; std::string value_json; std::string compare_operator; + + NUM_COMPARATOR get_comparator() const { + if(compare_operator == "LESS_THAN") { + return LESS_THAN; + } else if(compare_operator == "LESS_THAN_EQUALS") { + return LESS_THAN_EQUALS; + } else if(compare_operator == "EQUALS") { + return EQUALS; + } else if(compare_operator == "GREATER_THAN") { + return GREATER_THAN; + } else { + return GREATER_THAN_EQUALS; + } + } }; \ No newline at end of file diff --git a/src/collection.cpp b/src/collection.cpp index 8c2bb041..ecfc4ffa 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -53,7 +53,11 @@ void Collection::index_in_memory(const nlohmann::json &document, uint32_t seq_id for(const std::pair & field_pair: schema) { const std::string & field_name = field_pair.first; art_tree *t = index_map.at(field_name); - uint32_t points = document["points"]; + + uint32_t points = 0; + if(document.count("points") != 0) { + points = document["points"]; + } if(field_pair.second.type == field_types::STRING) { const std::string & text = document[field_name]; @@ -68,9 +72,14 @@ void Collection::index_in_memory(const nlohmann::json &document, uint32_t seq_id std::vector strings = document[field_name]; index_string_array_field(strings, points, t, seq_id); } else if(field_pair.second.type == field_types::INT32_ARRAY) { - std::vector strings = document[field_name]; + std::vector values = document[field_name]; + index_int32_array_field(values, points, t, seq_id); + } else if(field_pair.second.type == field_types::INT64_ARRAY) { + std::vector values = document[field_name]; + index_int64_array_field(values, points, t, seq_id); } } + if(rank_fields.size() > 0 && document.count(rank_fields[0])) { primary_rank_scores[seq_id] = document[rank_fields[0]].get(); } @@ -186,8 +195,9 @@ void Collection::index_int64_array_field(const std::vector & values, co } } -void Collection::search_candidates(int & token_rank, std::vector> & token_leaves, - Topster<100> & topster, size_t & total_results, size_t & num_found, const size_t & max_results) { +void Collection::search_candidates(uint32_t* filter_ids, size_t filter_ids_length, int & token_rank, + std::vector> & token_leaves, Topster<100> & topster, + size_t & total_results, size_t & num_found, const size_t & max_results) { const size_t combination_limit = 10; auto product = []( long long a, std::vector& b ) { return a*b.size(); }; long long int N = std::accumulate(token_leaves.begin(), token_leaves.end(), 1LL, product); @@ -211,11 +221,24 @@ void Collection::search_candidates(int & token_rank, std::vector= max_results) { @@ -224,25 +247,51 @@ void Collection::search_candidates(int & token_rank, std::vector & leaves, uint32_t** filter_ids, + uint32_t & filter_ids_length) { + + for(const art_leaf* leaf: leaves) { + uint32_t* results = new uint32_t[filter_ids_length + leaf->values->ids.getLength()]; + size_t results_length = leaf->values->ids.do_union(*filter_ids, filter_ids_length, results); + + delete [] *filter_ids; + + *filter_ids = results; + filter_ids_length = results_length; + } +} + nlohmann::json Collection::search(std::string query, const std::vector fields, const std::vector filters, const int num_typos, const size_t num_results, const token_ordering token_order, const bool prefix) { size_t num_found = 0; + uint32_t* filter_ids = nullptr; + uint32_t filter_ids_length = 0; + // process the filters first - /*for(const filter & a_filter: filters) { + for(const filter & a_filter: filters) { if(index_map.count(a_filter.field_name) != 0) { art_tree* t = index_map.at(a_filter.field_name); - nlohmann::json json_value = nlohmann::json::parse(a_filter.value_json); - if(json_value.is_number()) { - // do integer art search - } else if(json_value.is_string()) { + field f = schema.at(a_filter.field_name); - } else if(json_value.is_array()) { + nlohmann::json json_value = nlohmann::json::parse(a_filter.value_json); + if(f.type == field_types::INT64) { + + + } else if(f.type == field_types::INT32) { + int32_t value = json_value.get(); + NUM_COMPARATOR comparator = a_filter.get_comparator(); + + std::vector leaves; + art_int32_search(t, value, comparator, leaves); + + union_with_filter_ids(leaves, &filter_ids, filter_ids_length); + } else if(f.type == field_types::INT32_ARRAY) { } } - }*/ + } // Order of `fields` are used to rank results auto begin = std::chrono::high_resolution_clock::now(); @@ -251,8 +300,8 @@ nlohmann::json Collection::search(std::string query, const std::vector topster; const std::string & field = fields[i]; - - search(query, field, num_typos, num_results, topster, num_found, token_order, prefix); + search(filter_ids, filter_ids_length, query, field, num_typos, num_results, + topster, num_found, token_order, prefix); topster.sort(); for(auto t = 0; t < topster.size && t < num_results; t++) { @@ -297,7 +346,8 @@ nlohmann::json Collection::search(std::string query, const std::vector & topster, size_t & num_found, const token_ordering token_order, const bool prefix) { std::vector tokens; StringUtils::tokenize(query, tokens, " ", true); @@ -396,7 +446,8 @@ void Collection::search(std::string & query, const std::string & field, const in if(token_leaves.size() != 0 && token_leaves.size() == tokens.size()) { // If a) all tokens were found, or b) Some were skipped because they don't exist within max_cost, // go ahead and search for candidates with what we have so far - search_candidates(token_rank, token_leaves, topster, total_results, num_found, max_results); + search_candidates(filter_ids, filter_ids_length, token_rank, token_leaves, topster, + total_results, num_found, max_results); if (total_results >= max_results) { // If we don't find enough results, we continue outerloop (looking at tokens with greater cost) @@ -429,7 +480,8 @@ void Collection::search(std::string & query, const std::string & field, const in } } - return search(truncated_query, field, num_typos, num_results, topster, num_found, token_order, prefix); + return search(filter_ids, filter_ids_length, truncated_query, field, num_typos, num_results, topster, + num_found, token_order, prefix); } } @@ -506,6 +558,11 @@ void Collection::score_results(Topster<100> &topster, const int & token_rank, /*std::cout << "token_rank_score: " << token_rank_score << ", match_score: " << match_score << ", primary_rank_score: " << primary_rank_score << ", seq_id: " << seq_id << std::endl;*/ } + + for (auto it = leaf_to_indices.begin(); it != leaf_to_indices.end(); it++) { + delete [] it->second; + it->second = nullptr; + } } inline std::vector Collection::next_suggestion(const std::vector> &token_leaves, diff --git a/src/sorted_array.cpp b/src/sorted_array.cpp index e38b2a69..b9384dca 100644 --- a/src/sorted_array.cpp +++ b/src/sorted_array.cpp @@ -180,6 +180,12 @@ size_t sorted_array::do_union(uint32_t *arr, const size_t arr_length, uint32_t * size_t curr_index = 0, arr_index = 0, res_index = 0; uint32_t* curr = uncompress(); + if(arr == nullptr) { + memcpy(results, curr, length * sizeof(uint32_t)); + delete[] curr; + return length; + } + while (curr_index < length && arr_index < arr_length) { if (curr[curr_index] < arr[arr_index]) { if(res_index == 0 || results[res_index-1] != curr[curr_index]) { diff --git a/test/collection_test.cpp b/test/collection_test.cpp index 169acd3a..16d22967 100644 --- a/test/collection_test.cpp +++ b/test/collection_test.cpp @@ -378,12 +378,12 @@ TEST_F(CollectionTest, MultipleFields) { } } -/* TEST_F(CollectionTest, SearchNumericFields) { Collection *coll_array_fields; std::ifstream infile(std::string(ROOT_DIR)+"test/numeric_array_documents.jsonl"); - std::vector fields = {field("name", field_types::STRING), field("years", field_types::INT32_ARRAY), + std::vector fields = {field("name", field_types::STRING), field("age", field_types::INT32), + field("years", field_types::INT32_ARRAY), field("timestamps", field_types::INT64_ARRAY)}; std::vector rank_fields = {"age"}; @@ -400,11 +400,12 @@ TEST_F(CollectionTest, SearchNumericFields) { infile.close(); - search_fields = {"years"}; + // Plain search with no filters - results should be sorted by rank fields + search_fields = {"name"}; nlohmann::json results = coll_array_fields->search("Jeremy", search_fields, {}, 0, 10, FREQUENCY, false); - ASSERT_EQ(4, results["hits"].size()); + ASSERT_EQ(5, results["hits"].size()); - std::vector ids = {"3", "2", "1", "0"}; + std::vector ids = {"3", "0", "4", "1", "2"}; for(size_t i = 0; i < results["hits"].size(); i++) { nlohmann::json result = results["hits"].at(i); @@ -413,21 +414,25 @@ TEST_F(CollectionTest, SearchNumericFields) { ASSERT_STREQ(id.c_str(), result_id.c_str()); } + search_fields = {"name"}; + std::vector filters; + filter f1 = {"age", "24", "GREATER_THAN"}; + filters.push_back(f1); - search_fields = {"starring", "title"}; - results = coll_array_fields->search("thomas", search_fields, {}, 0, 10, FREQUENCY, false); - ASSERT_EQ(4, results["hits"].size()); + results = coll_array_fields->search("Jeremy", search_fields, filters, 0, 10, FREQUENCY, false); + ASSERT_EQ(3, results["hits"].size()); - ids = {"15", "14", "12", "13"}; + ids = {"3", "0", "4"}; for(size_t i = 0; i < results["hits"].size(); i++) { nlohmann::json result = results["hits"].at(i); std::string result_id = result["id"]; std::string id = ids.at(i); ASSERT_STREQ(id.c_str(), result_id.c_str()); + //std::cout << result_id << std::endl; } - search_fields = {"starring", "title", "cast"}; + /*search_fields = {"starring", "title", "cast"}; results = coll_array_fields->search("ben affleck", search_fields, {}, 0, 10, FREQUENCY, false); ASSERT_EQ(1, results["hits"].size()); @@ -453,6 +458,5 @@ TEST_F(CollectionTest, SearchNumericFields) { std::string result_id = result["id"]; std::string id = ids.at(i); ASSERT_STREQ(id.c_str(), result_id.c_str()); - } + }*/ } -*/