Support for grouping overrides.

This commit is contained in:
kishorenc 2020-06-15 19:20:00 +05:30
parent c5010a6a5f
commit 3f6f13baf1
7 changed files with 159 additions and 112 deletions

View File

@ -99,8 +99,8 @@
- ~~Use snappy compression for storage~~
- ~~Fix exclude_scalar early returns~~
- ~~Fix result ids length during grouped overrides~~
- Fix override grouping (collate_included_ids)
- Test for overriding result on second page
- ~~Fix override grouping (collate_included_ids)~~
- ~~Test for overriding result on second page~~
- at least 1 token match for proceeding with drop tokens
- support wildcard query with filters
- API for optimizing on disk storage

View File

@ -155,10 +155,10 @@ private:
void remove_document(nlohmann::json & document, const uint32_t seq_id, bool remove_from_store);
void populate_overrides(std::string query,
const std::map<std::string, size_t>& pinned_hits,
const std::map<size_t, std::vector<std::string>>& pinned_hits,
const std::vector<std::string>& hidden_hits,
std::map<uint32_t, size_t> & id_pos_map,
std::vector<uint32_t> & included_ids, std::vector<uint32_t> & excluded_ids);
std::map<size_t, std::vector<uint32_t>>& include_ids,
std::vector<uint32_t> & excluded_ids);
static bool facet_count_compare(const std::pair<uint64_t, facet_count_t>& a,
const std::pair<uint64_t, facet_count_t>& b) {
@ -236,7 +236,7 @@ public:
const size_t snippet_threshold = 30,
const std::string & highlight_full_fields = "",
size_t typo_tokens_threshold = Index::TYPO_TOKENS_THRESHOLD,
const std::map<std::string, size_t>& pinned_hits={},
const std::map<size_t, std::vector<std::string>>& pinned_hits={},
const std::vector<std::string>& hidden_hits={},
const std::vector<std::string>& group_by_fields={},
const size_t group_limit = 0);

View File

@ -27,7 +27,7 @@ struct search_args {
std::vector<std::string> search_fields;
std::vector<filter> filters;
std::vector<facet> facets;
std::vector<uint32_t> included_ids;
std::map<size_t, std::vector<uint32_t>> included_ids;
std::vector<uint32_t> excluded_ids;
std::vector<sort_by> sort_fields_std;
facet_query_t facet_query;
@ -47,7 +47,7 @@ struct search_args {
Topster* topster;
Topster* curated_topster;
std::vector<std::vector<KV*>> raw_result_kvs;
std::vector<KV*> override_result_kvs;
std::vector<std::vector<KV*>> override_result_kvs;
Option<uint32_t> outcome;
search_args(): outcome(0) {
@ -55,7 +55,7 @@ struct search_args {
}
search_args(std::string query, std::vector<std::string> search_fields, std::vector<filter> filters,
std::vector<facet> facets, std::vector<uint32_t> included_ids, std::vector<uint32_t> excluded_ids,
std::vector<facet> facets, std::map<size_t, std::vector<uint32_t>> included_ids, std::vector<uint32_t> excluded_ids,
std::vector<sort_by> sort_fields_std, facet_query_t facet_query, int num_typos, size_t max_facet_values,
size_t max_hits, size_t per_page, size_t page, token_ordering token_order, bool prefix,
size_t drop_tokens_threshold, size_t typo_tokens_threshold,
@ -70,7 +70,7 @@ struct search_args {
const size_t topster_size = std::max((size_t)1, max_hits); // needs to be atleast 1 since scoring is mandatory
topster = new Topster(topster_size, group_limit);
curated_topster = new Topster(topster_size);
curated_topster = new Topster(topster_size, group_limit);
}
~search_args() {
@ -213,7 +213,7 @@ private:
const uint32_t indices_length);
void collate_included_ids(const std::string & query, const std::string & field, const uint8_t field_id,
const std::vector<uint32_t> & included_ids,
const std::map<size_t, std::vector<uint32_t>> & included_ids_map,
Topster* curated_topster, std::vector<std::vector<art_leaf*>> & searched_queries);
uint64_t facet_token_hash(const field & a_field, const std::string &token);
@ -239,7 +239,8 @@ public:
void search(Option<uint32_t> & outcome, const std::string & query, const std::vector<std::string> & search_fields,
const std::vector<filter> & filters, std::vector<facet> & facets,
facet_query_t & facet_query,
const std::vector<uint32_t> & included_ids, const std::vector<uint32_t> & excluded_ids,
const std::map<size_t, std::vector<uint32_t>> & included_ids_map,
const std::vector<uint32_t> & excluded_ids,
const std::vector<sort_by> & sort_fields_std, const int num_typos,
Topster* topster, Topster* curated_topster,
const size_t per_page, const size_t page, const token_ordering token_order,
@ -247,7 +248,8 @@ public:
size_t & all_result_ids_len,
spp::sparse_hash_set<uint64_t>& groups_processed,
std::vector<std::vector<art_leaf*>> & searched_queries,
std::vector<std::vector<KV*>> & raw_result_kvs, std::vector<KV*> & override_result_kvs,
std::vector<std::vector<KV*>> & raw_result_kvs,
std::vector<std::vector<KV*>> & override_result_kvs,
const size_t typo_tokens_threshold);
Option<uint32_t> remove(const uint32_t seq_id, nlohmann::json & document);

View File

@ -305,54 +305,70 @@ void Collection::prune_document(nlohmann::json &document, const spp::sparse_hash
}
void Collection::populate_overrides(std::string query,
const std::map<std::string, size_t>& pinned_hits,
const std::map<size_t, std::vector<std::string>>& pinned_hits,
const std::vector<std::string>& hidden_hits,
std::map<uint32_t, size_t> & id_pos_map,
std::vector<uint32_t> & included_ids,
std::map<size_t, std::vector<uint32_t>>& include_ids,
std::vector<uint32_t> & excluded_ids) {
StringUtils::tolowercase(query);
std::set<uint32_t> excluded_set;
// If pinned or hidden hits are provided, they take precedence over overrides
// have to ensure that hidden hits take precedence over included hits
if(!hidden_hits.empty()) {
for(const auto & hit: hidden_hits) {
Option<uint32_t> seq_id_op = doc_id_to_seq_id(hit);
if(seq_id_op.ok()) {
excluded_ids.push_back(seq_id_op.get());
excluded_set.insert(seq_id_op.get());
}
}
}
for(const auto & override_kv: overrides) {
const auto & override = override_kv.second;
if( (override.rule.match == override_t::MATCH_EXACT && override.rule.query == query) ||
(override.rule.match == override_t::MATCH_CONTAINS && query.find(override.rule.query) != std::string::npos) ) {
for(const auto & hit: override.add_hits) {
Option<uint32_t> seq_id_op = doc_id_to_seq_id(hit.doc_id);
if(seq_id_op.ok()) {
included_ids.push_back(seq_id_op.get());
id_pos_map[seq_id_op.get()] = hit.position;
}
}
// have to ensure that dropped hits take precedence over added hits
for(const auto & hit: override.drop_hits) {
Option<uint32_t> seq_id_op = doc_id_to_seq_id(hit.doc_id);
if(seq_id_op.ok()) {
excluded_ids.push_back(seq_id_op.get());
excluded_set.insert(seq_id_op.get());
}
}
for(const auto & hit: override.add_hits) {
Option<uint32_t> seq_id_op = doc_id_to_seq_id(hit.doc_id);
if(!seq_id_op.ok()) {
continue;
}
uint32_t seq_id = seq_id_op.get();
bool excluded = (excluded_set.count(seq_id) != 0);
if(!excluded) {
include_ids[hit.position].push_back(seq_id);
}
}
break;
}
}
// If pinned or hidden hits are provided, they take precedence over overrides
if(!pinned_hits.empty()) {
for(const auto & hit: pinned_hits) {
Option<uint32_t> seq_id_op = doc_id_to_seq_id(hit.first);
if(seq_id_op.ok()) {
included_ids.push_back(seq_id_op.get());
id_pos_map[seq_id_op.get()] = hit.second;
}
}
}
if(!hidden_hits.empty()) {
for(const auto & hit: hidden_hits) {
Option<uint32_t> seq_id_op = doc_id_to_seq_id(hit);
if(seq_id_op.ok()) {
included_ids.erase(std::remove(included_ids.begin(), included_ids.end(), seq_id_op.get()), included_ids.end());
id_pos_map.erase(seq_id_op.get());
excluded_ids.push_back(seq_id_op.get());
for(const auto& pos_ids: pinned_hits) {
size_t pos = pos_ids.first;
for(const std::string& id: pos_ids.second) {
Option<uint32_t> seq_id_op = doc_id_to_seq_id(id);
if(!seq_id_op.ok()) {
continue;
}
uint32_t seq_id = seq_id_op.get();
bool excluded = (excluded_set.count(seq_id) != 0);
if(!excluded) {
include_ids[pos].push_back(seq_id);
}
}
}
}
@ -371,22 +387,40 @@ Option<nlohmann::json> Collection::search(const std::string & query, const std::
const size_t snippet_threshold,
const std::string & highlight_full_fields,
size_t typo_tokens_threshold,
const std::map<std::string, size_t>& pinned_hits,
const std::map<size_t, std::vector<std::string>>& pinned_hits,
const std::vector<std::string>& hidden_hits,
const std::vector<std::string>& group_by_fields,
const size_t group_limit) {
std::vector<uint32_t> included_ids;
std::vector<uint32_t> excluded_ids;
std::map<uint32_t, size_t> id_pos_map;
populate_overrides(query, pinned_hits, hidden_hits, id_pos_map, included_ids, excluded_ids);
std::map<size_t, std::vector<uint32_t>> include_ids; // position => list of IDs
populate_overrides(query, pinned_hits, hidden_hits, include_ids, excluded_ids);
std::map<uint32_t, std::vector<uint32_t>> index_to_included_ids;
/*for(auto kv: include_ids) {
LOG(INFO) << "key: " << kv.first;
for(auto val: kv.second) {
LOG(INFO) << val;
}
}
LOG(INFO) << "Excludes:";
for(auto id: excluded_ids) {
LOG(INFO) << id;
}*/
//LOG(INFO) << "include_ids size: " << include_ids.size();
//LOG(INFO) << "Pos 1: " << include_ids[1][0];
std::map<uint32_t, std::map<size_t, std::vector<uint32_t>>> index_to_included_ids;
std::map<uint32_t, std::vector<uint32_t>> index_to_excluded_ids;
for(auto seq_id: included_ids) {
auto index_id = (seq_id % num_indices);
index_to_included_ids[index_id].push_back(seq_id);
for(const auto& pos_ids: include_ids) {
size_t position = pos_ids.first;
for(auto seq_id: pos_ids.second) {
auto index_id = (seq_id % num_indices);
index_to_included_ids[index_id][position].push_back(seq_id);
}
}
for(auto seq_id: excluded_ids) {
@ -650,7 +684,7 @@ Option<nlohmann::json> Collection::search(const std::string & query, const std::
std::vector<std::vector<art_leaf*>> searched_queries; // search queries used for generating the results
std::vector<std::vector<KV*>> raw_result_kvs;
std::vector<KV*> override_result_kvs;
std::vector<std::vector<KV*>> override_result_kvs;
size_t total_found = 0;
spp::sparse_hash_set<uint64_t> groups_processed; // used to calculate total_found for grouped query
@ -697,9 +731,9 @@ Option<nlohmann::json> Collection::search(const std::string & query, const std::
raw_result_kvs.push_back(kv_group);
}
for(auto & field_order_kv: index->search_params->override_result_kvs) {
field_order_kv->query_index += searched_queries.size();
override_result_kvs.push_back(field_order_kv);
for(const std::vector<KV*> & kv_group: index->search_params->override_result_kvs) {
kv_group[0]->query_index += searched_queries.size();
override_result_kvs.push_back(kv_group);
}
searched_queries.insert(searched_queries.end(), index->search_params->searched_queries.begin(),
@ -797,8 +831,8 @@ Option<nlohmann::json> Collection::search(const std::string & query, const std::
// Sort based on position in overridden list
std::sort(
override_result_kvs.begin(), override_result_kvs.end(),
[&id_pos_map](const KV* a, const KV* b) -> bool {
return id_pos_map[a->key] < id_pos_map[b->key];
[](const std::vector<KV*>& a, std::vector<KV*>& b) -> bool {
return a[0]->distinct_key < b[0]->distinct_key;
}
);
@ -808,11 +842,12 @@ Option<nlohmann::json> Collection::search(const std::string & query, const std::
// merge raw results and override results
while(override_kv_index < override_result_kvs.size() && raw_results_index < raw_result_kvs.size()) {
if(override_kv_index < override_result_kvs.size() &&
id_pos_map.count(override_result_kvs[override_kv_index]->key) != 0 &&
result_group_kvs.size() + 1 == id_pos_map[override_result_kvs[override_kv_index]->key]) {
result_group_kvs.push_back({override_result_kvs[override_kv_index]});
override_kv_index++;
size_t result_position = result_group_kvs.size() + 1;
uint64_t override_position = override_result_kvs[override_kv_index][0]->distinct_key;
if(result_position == override_position) {
result_group_kvs.push_back(override_result_kvs[override_kv_index]);
override_kv_index++;
} else {
result_group_kvs.push_back(raw_result_kvs[raw_results_index]);
raw_results_index++;

View File

@ -390,7 +390,8 @@ bool get_search(http_req & req, http_res & res) {
}
}
std::map<std::string, size_t> pinned_hits;
std::map<size_t, std::vector<std::string>> pinned_hits;
if(req.params.count(PINNED_HITS) != 0) {
std::vector<std::string> pinned_hits_strs;
StringUtils::split(req.params[PINNED_HITS], pinned_hits_strs, ",");
@ -415,7 +416,7 @@ bool get_search(http_req & req, http_res & res) {
return false;
}
pinned_hits.emplace(expression_parts[0], position);
pinned_hits[position].emplace_back(expression_parts[0]);
}
}

View File

@ -1038,11 +1038,11 @@ void Index::run_search() {
}
void Index::collate_included_ids(const std::string & query, const std::string & field, const uint8_t field_id,
const std::vector<uint32_t> & included_ids,
const std::map<size_t, std::vector<uint32_t>> & included_ids_map,
Topster* curated_topster,
std::vector<std::vector<art_leaf*>> & searched_queries) {
if(included_ids.empty()) {
if(included_ids_map.empty()) {
return;
}
@ -1061,50 +1061,28 @@ void Index::collate_included_ids(const std::string & query, const std::string &
art_fuzzy_search(search_index.at(field), (const unsigned char *) token.c_str(), token_len,
0, 0, 1, token_ordering::MAX_SCORE, false, leaves);
if(leaves.size() > 0) {
if(!leaves.empty()) {
override_query.push_back(leaves[0]);
}
}
spp::sparse_hash_map<const art_leaf*, uint32_t*> leaf_to_indices;
for(const auto& pos_ids: included_ids_map) {
const size_t pos = pos_ids.first;
for (art_leaf *token_leaf : override_query) {
uint32_t *indices = new uint32_t[included_ids.size()];
token_leaf->values->ids.indexOf(&included_ids[0], included_ids.size(), indices);
leaf_to_indices.emplace(token_leaf, indices);
}
for(size_t i = 0; i < pos_ids.second.size(); i++) {
uint32_t seq_id = pos_ids.second[i];
// curated_topster->MAX_SIZE is initialized based on max_hits.
// Even if override has more IDs, we should restrict to max hits.
size_t iter_size = std::min((size_t)curated_topster->MAX_SIZE, included_ids.size());
uint64_t distinct_id = pos; // position is the group distinct key
uint64_t match_score = (64000 - i); // index within a group is the match score
for(size_t j=0; j<iter_size; j++) {
const uint32_t seq_id = included_ids[j];
int64_t scores[3];
scores[0] = match_score;
scores[1] = int64_t(1);
scores[2] = int64_t(1);
std::vector<std::vector<std::vector<uint16_t>>> array_token_positions;
populate_token_positions(override_query, leaf_to_indices, j, array_token_positions);
uint64_t match_score = 0;
for(const std::vector<std::vector<uint16_t>> & token_positions: array_token_positions) {
if(token_positions.empty()) {
continue;
}
const Match & match = Match::match(seq_id, token_positions);
uint64_t this_match_score = match.get_match_score(0, field_id);
if(this_match_score > match_score) {
match_score = this_match_score;
}
KV kv(field_id, searched_queries.size(), seq_id, distinct_id, match_score, scores);
curated_topster->add(&kv);
}
int64_t scores[3];
scores[0] = int64_t(match_score);
scores[1] = int64_t(1);
scores[2] = int64_t(1);
KV kv(field_id, searched_queries.size(), seq_id, seq_id, match_score, scores);
curated_topster->add(&kv);
}
searched_queries.push_back(override_query);
@ -1115,7 +1093,7 @@ void Index::search(Option<uint32_t> & outcome,
const std::vector<std::string> & search_fields,
const std::vector<filter> & filters,
std::vector<facet> & facets, facet_query_t & facet_query,
const std::vector<uint32_t> & included_ids,
const std::map<size_t, std::vector<uint32_t>> & included_ids_map,
const std::vector<uint32_t> & excluded_ids,
const std::vector<sort_by> & sort_fields_std, const int num_typos,
Topster* topster,
@ -1126,7 +1104,7 @@ void Index::search(Option<uint32_t> & outcome,
spp::sparse_hash_set<uint64_t>& groups_processed,
std::vector<std::vector<art_leaf*>>& searched_queries,
std::vector<std::vector<KV*>> & raw_result_kvs,
std::vector<KV*> & override_result_kvs,
std::vector<std::vector<KV*>> & override_result_kvs,
const size_t typo_tokens_threshold) {
// process the filters
@ -1141,7 +1119,16 @@ void Index::search(Option<uint32_t> & outcome,
uint32_t filter_ids_length = op_filter_ids_length.get();
// we will be removing all curated IDs from organic result ids before running topster
std::set<uint32_t> curated_ids(included_ids.begin(), included_ids.end());
std::set<uint32_t> curated_ids;
std::vector<uint32_t> included_ids;
for(const auto& pos_ids: included_ids_map) {
for(const uint32_t id: pos_ids.second) {
curated_ids.insert(id);
included_ids.push_back(id);
}
}
curated_ids.insert(excluded_ids.begin(), excluded_ids.end());
std::vector<uint32_t> curated_ids_sorted(curated_ids.begin(), curated_ids.end());
@ -1166,7 +1153,7 @@ void Index::search(Option<uint32_t> & outcome,
score_results(sort_fields_std, (uint16_t) searched_queries.size(), field_id, 0, topster, {},
groups_processed, filter_ids, filter_ids_length);
collate_included_ids(query, field, field_id, included_ids, curated_topster, searched_queries);
collate_included_ids(query, field, field_id, included_ids_map, curated_topster, searched_queries);
all_result_ids_len = filter_ids_length;
all_result_ids = filter_ids;
@ -1182,7 +1169,7 @@ void Index::search(Option<uint32_t> & outcome,
search_field(field_id, query, field, filter_ids, filter_ids_length, curated_ids_sorted, facets, sort_fields_std,
num_typos, searched_queries, topster, groups_processed, &all_result_ids, all_result_ids_len,
token_order, prefix, drop_tokens_threshold, typo_tokens_threshold);
collate_included_ids(query, field, field_id, included_ids, curated_topster, searched_queries);
collate_included_ids(query, field, field_id, included_ids_map, curated_topster, searched_queries);
}
}
}
@ -1202,16 +1189,23 @@ void Index::search(Option<uint32_t> & outcome,
const std::vector<KV*> group_kvs(group_topster->kvs, group_topster->kvs+group_topster->size);
raw_result_kvs.emplace_back(group_kvs);
}
for(auto &curated_topster_entry: curated_topster->group_kv_map) {
Topster* group_topster = curated_topster_entry.second;
const std::vector<KV*> group_kvs(group_topster->kvs, group_topster->kvs+group_topster->size);
override_result_kvs.emplace_back(group_kvs);
}
} else {
for(uint32_t t = 0; t < topster->size; t++) {
KV* kv = topster->getKV(t);
raw_result_kvs.push_back({kv});
}
}
for(uint32_t t = 0; t < curated_topster->size; t++) {
KV* kv = curated_topster->getKV(t);
override_result_kvs.push_back(kv);
for(uint32_t t = 0; t < curated_topster->size; t++) {
KV* kv = curated_topster->getKV(t);
override_result_kvs.push_back({kv});
}
}
// add curated IDs to result count

View File

@ -256,10 +256,9 @@ TEST_F(CollectionOverrideTest, ExcludeIncludeFacetFilterQuery) {
}
TEST_F(CollectionOverrideTest, IncludeExcludeHitsQuery) {
std::map<std::string, size_t> pinned_hits;
std::vector<std::string> hidden_hits;
pinned_hits["13"] = 1;
pinned_hits["4"] = 2;
std::map<size_t, std::vector<std::string>> pinned_hits;
pinned_hits[1] = {"13"};
pinned_hits[2] = {"4"};
// basic pinning
@ -279,6 +278,7 @@ TEST_F(CollectionOverrideTest, IncludeExcludeHitsQuery) {
// both pinning and hiding
std::vector<std::string> hidden_hits;
hidden_hits = {"11", "16"};
results = coll_mul_fields->search("the", {"title"}, "", {"starring"}, {}, 0, 50, 1, FREQUENCY,
false, Index::DROP_TOKENS_THRESHOLD,
@ -291,6 +291,21 @@ TEST_F(CollectionOverrideTest, IncludeExcludeHitsQuery) {
ASSERT_STREQ("4", results["hits"][1]["document"]["id"].get<std::string>().c_str());
ASSERT_STREQ("6", results["hits"][2]["document"]["id"].get<std::string>().c_str());
// paginating such that pinned hits appear on second page
pinned_hits.clear();
pinned_hits[4] = {"13"};
pinned_hits[5] = {"4"};
results = coll_mul_fields->search("the", {"title"}, "", {"starring"}, {}, 0, 2, 2, FREQUENCY,
false, Index::DROP_TOKENS_THRESHOLD,
spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "starring: will", 30,
"", 10,
pinned_hits, hidden_hits).get();
ASSERT_STREQ("1", results["hits"][0]["document"]["id"].get<std::string>().c_str());
ASSERT_STREQ("13", results["hits"][1]["document"]["id"].get<std::string>().c_str());
// take precedence over override rules
nlohmann::json override_json_include = {