Ability to disable overrides.

This commit is contained in:
Kishore Nallan 2021-06-11 20:29:21 +05:30
parent b141e01b1b
commit 0d5eef664e
4 changed files with 74 additions and 34 deletions

View File

@ -346,11 +346,12 @@ private:
void remove_document(const nlohmann::json & document, const uint32_t seq_id, bool remove_from_store);
void populate_overrides(std::string query,
const std::map<size_t, std::vector<std::string>>& pinned_hits,
const std::vector<std::string>& hidden_hits,
std::map<size_t, std::vector<uint32_t>>& include_ids,
std::vector<uint32_t> & excluded_ids) const;
// Fills `include_ids` (keyed by result position) and `excluded_ids` with the document
// seq ids that should be forced into, or removed from, the results for `query`.
// When `enable_overrides` is true, every stored override whose rule matches the query
// (exact match, or "contains" match) is applied: its dropped hits are appended to
// `excluded_ids`, and its added hits are placed into `include_ids` unless that id is
// already in the excluded set, so drops take precedence over adds.
// When `enable_overrides` is false, the stored overrides are not consulted at all.
// NOTE(review): how `pinned_hits` / `hidden_hits` feed the two outputs is not visible
// in this hunk -- presumably they are applied regardless of the flag; confirm against
// the definition.
void curate_results(std::string query,
const std::map<size_t, std::vector<std::string>>& pinned_hits,
const std::vector<std::string>& hidden_hits,
std::map<size_t, std::vector<uint32_t>>& include_ids,
std::vector<uint32_t> & excluded_ids,
bool enable_overrides) const;
Option<bool> check_and_update_schema(nlohmann::json& document, const DIRTY_VALUES& dirty_values);
@ -532,7 +533,8 @@ public:
std::vector<size_t> query_by_weights={},
size_t limit_hits=UINT32_MAX,
bool prioritize_exact_match=true,
bool pre_segmented_query=false) const;
bool pre_segmented_query=false,
bool enable_overrides=true) const;
Option<bool> get_filter_ids(const std::string & simple_filter_query,
std::vector<std::pair<size_t, uint32_t*>>& index_ids);

View File

@ -414,11 +414,12 @@ void Collection::prune_document(nlohmann::json &document, const spp::sparse_hash
}
}
void Collection::populate_overrides(std::string query,
const std::map<size_t, std::vector<std::string>>& pinned_hits,
const std::vector<std::string>& hidden_hits,
std::map<size_t, std::vector<uint32_t>>& include_ids,
std::vector<uint32_t> & excluded_ids) const {
void Collection::curate_results(std::string query,
const std::map<size_t, std::vector<std::string>>& pinned_hits,
const std::vector<std::string>& hidden_hits,
std::map<size_t, std::vector<uint32_t>>& include_ids,
std::vector<uint32_t> & excluded_ids,
bool enable_overrides) const {
StringUtils::tolowercase(query);
std::set<uint32_t> excluded_set;
@ -435,30 +436,32 @@ void Collection::populate_overrides(std::string query,
}
}
for(const auto & override_kv: overrides) {
const auto & override = override_kv.second;
if(enable_overrides) {
for(const auto & override_kv: overrides) {
const auto & override = override_kv.second;
if( (override.rule.match == override_t::MATCH_EXACT && override.rule.query == query) ||
(override.rule.match == override_t::MATCH_CONTAINS && query.find(override.rule.query) != std::string::npos) ) {
if( (override.rule.match == override_t::MATCH_EXACT && override.rule.query == query) ||
(override.rule.match == override_t::MATCH_CONTAINS && query.find(override.rule.query) != std::string::npos) ) {
// have to ensure that dropped hits take precedence over added hits
for(const auto & hit: override.drop_hits) {
Option<uint32_t> seq_id_op = doc_id_to_seq_id(hit.doc_id);
if(seq_id_op.ok()) {
excluded_ids.push_back(seq_id_op.get());
excluded_set.insert(seq_id_op.get());
// have to ensure that dropped hits take precedence over added hits
for(const auto & hit: override.drop_hits) {
Option<uint32_t> seq_id_op = doc_id_to_seq_id(hit.doc_id);
if(seq_id_op.ok()) {
excluded_ids.push_back(seq_id_op.get());
excluded_set.insert(seq_id_op.get());
}
}
}
for(const auto & hit: override.add_hits) {
Option<uint32_t> seq_id_op = doc_id_to_seq_id(hit.doc_id);
if(!seq_id_op.ok()) {
continue;
}
uint32_t seq_id = seq_id_op.get();
bool excluded = (excluded_set.count(seq_id) != 0);
if(!excluded) {
include_ids[hit.position].push_back(seq_id);
for(const auto & hit: override.add_hits) {
Option<uint32_t> seq_id_op = doc_id_to_seq_id(hit.doc_id);
if(!seq_id_op.ok()) {
continue;
}
uint32_t seq_id = seq_id_op.get();
bool excluded = (excluded_set.count(seq_id) != 0);
if(!excluded) {
include_ids[hit.position].push_back(seq_id);
}
}
}
}
@ -505,7 +508,8 @@ Option<nlohmann::json> Collection::search(const std::string & query, const std::
std::vector<size_t> query_by_weights,
size_t limit_hits,
bool prioritize_exact_match,
bool pre_segmented_query) const {
bool pre_segmented_query,
bool enable_overrides) const {
std::shared_lock lock(mutex);
@ -550,7 +554,7 @@ Option<nlohmann::json> Collection::search(const std::string & query, const std::
std::vector<std::string> hidden_hits;
StringUtils::split(hidden_hits_str, hidden_hits, ",");
populate_overrides(query, pinned_hits, hidden_hits, include_ids, excluded_ids);
curate_results(query, pinned_hits, hidden_hits, include_ids, excluded_ids, enable_overrides);
/*for(auto& kv: include_ids) {
LOG(INFO) << "key: " << kv.first;

View File

@ -485,6 +485,7 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
const char *PINNED_HITS = "pinned_hits";
const char *HIDDEN_HITS = "hidden_hits";
const char *ENABLE_OVERRIDES = "enable_overrides";
// strings under this length will be fully highlighted, instead of showing a snippet of relevant portion
const char *SNIPPET_THRESHOLD = "snippet_threshold";
@ -705,6 +706,12 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
req_params[HIDDEN_HITS] = "";
}
// Overrides are on by default: a request that omits the parameter is treated as "true".
if(req_params.count(ENABLE_OVERRIDES) == 0) {
req_params[ENABLE_OVERRIDES] = "true";
}
// Only the exact string "true" enables overrides; any other value disables them.
bool enable_overrides = (req_params[ENABLE_OVERRIDES] == "true");
CollectionManager & collectionManager = CollectionManager::get_instance();
auto collection = collectionManager.get_collection(req_params["collection"]);
@ -760,7 +767,8 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
query_by_weights,
static_cast<size_t>(std::stol(req_params[LIMIT_HITS])),
prioritize_exact_match,
pre_segmented_query
pre_segmented_query,
enable_overrides
);
uint64_t timeMillis = std::chrono::duration_cast<std::chrono::milliseconds>(

View File

@ -170,6 +170,32 @@ TEST_F(CollectionOverrideTest, ExcludeIncludeExactQueryMatch) {
ASSERT_STREQ("2", results["hits"][2]["document"]["id"].get<std::string>().c_str());
ASSERT_STREQ("1", results["hits"][3]["document"]["id"].get<std::string>().c_str());
// ability to disable overrides
bool enable_overrides = false;
res_op = coll_mul_fields->search("will", {"title"}, "", {}, {}, {0}, 10,
1, FREQUENCY, {false}, 0, spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "", 0, {}, {}, {}, 0,
"<mark>", "</mark>", {1}, 10000, true, false, enable_overrides);
ASSERT_TRUE(res_op.ok());
results = res_op.get();
ASSERT_EQ(2, results["hits"].size());
ASSERT_EQ(2, results["found"].get<uint32_t>());
ASSERT_STREQ("3", results["hits"][0]["document"]["id"].get<std::string>().c_str());
ASSERT_STREQ("2", results["hits"][1]["document"]["id"].get<std::string>().c_str());
enable_overrides = true;
res_op = coll_mul_fields->search("will", {"title"}, "", {}, {}, {0}, 10,
1, FREQUENCY, {false}, 0, spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "", 0, {}, {}, {}, 0,
"<mark>", "</mark>", {1}, 10000, true, false, enable_overrides);
ASSERT_TRUE(res_op.ok());
results = res_op.get();
ASSERT_EQ(4, results["hits"].size());
ASSERT_EQ(4, results["found"].get<uint32_t>());
coll_mul_fields->remove_override("include-rule");
}