Settle for partial matches when the whole query produces no results.

This commit is contained in:
Kishore Nallan 2016-11-26 17:13:16 +05:30
parent 396e10be5d
commit 4e10fadeb7
3 changed files with 48 additions and 11 deletions

View File

@ -28,7 +28,7 @@ private:
static inline std::vector<art_leaf *> next_suggestion(const std::vector<std::vector<art_leaf *>> &token_leaves,
long long int n);
void log_leaves(const int max_cost, const std::string &token, const std::vector<art_leaf *> &leaves) const;
void log_leaves(const int cost, const std::string &token, const std::vector<art_leaf *> &leaves) const;
void search_candidates(std::vector<std::vector<art_leaf*>> & token_leaves, Topster<100> & topster,
size_t & total_results, const size_t & max_results);

View File

@ -90,6 +90,12 @@ void Collection::search_candidates(std::vector<std::vector<art_leaf*>> & token_l
// every element in `query_suggestion` represents a token and its associated hits
std::vector<art_leaf *> query_suggestion = next_suggestion(token_leaves, n);
/*std:: cout << "\nSuggestion: ";
for(auto suggestion_leaf: query_suggestion) {
std:: cout << suggestion_leaf->key << " ";
}
std::cout << std::endl;*/
// initialize results with the starting element (for further intersection)
uint32_t* result_ids = query_suggestion[0]->values->ids.uncompress();
size_t result_size = query_suggestion[0]->values->ids.getLength();
@ -136,6 +142,7 @@ std::vector<nlohmann::json> Collection::search(std::string query, const int num_
size_t total_results = 0;
std::vector<nlohmann::json> results;
Topster<100> topster;
spp::sparse_hash_map<std::string, uint32_t> token_to_count;
auto begin = std::chrono::high_resolution_clock::now();
@ -178,8 +185,9 @@ std::vector<nlohmann::json> Collection::search(std::string query, const int num_
art_fuzzy_search(&t, (const unsigned char *) token.c_str(), (int) token.length() + 1, costs[token_index], 3, leaves);
if(!leaves.empty()) {
//log_leaves(max_cost, token, leaves);
log_leaves(costs[token_index], token, leaves);
token_leaves.push_back(leaves);
token_to_count[token] = leaves.at(0)->values->ids.getLength();
} else {
// no result when `cost = costs[token_index]` => remove cost for token and re-do combinations
auto it = std::find(token_to_costs[token_index].begin(), token_to_costs[token_index].end(), costs[token_index]);
@ -231,8 +239,28 @@ std::vector<nlohmann::json> Collection::search(std::string query, const int num_
n++;
}
if(results.size() == 0) {
// FIXME: We could drop certain tokens and try searching again
if(results.size() == 0 && token_to_count.size() != 0) {
// Drop certain token with least hits and try searching again
std::string truncated_query;
std::vector<std::pair<std::string, uint32_t>> token_count_pairs;
for (auto itr = token_to_count.begin(); itr != token_to_count.end(); ++itr) {
token_count_pairs.push_back(*itr);
}
std::sort(token_count_pairs.begin(), token_count_pairs.end(), [=]
(const std::pair<std::string, uint32_t>& a, const std::pair<std::string, uint32_t>& b) {
return a.second > b.second;
}
);
for(uint32_t i = 0; i < token_count_pairs.size()-1; i++) {
if(token_to_count.count(tokens[i]) != 0) {
truncated_query += " " + token_count_pairs.at(i).first;
}
}
return search(truncated_query, num_typos, num_results);
}
long long int timeMillis = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - begin).count();
@ -241,14 +269,14 @@ std::vector<nlohmann::json> Collection::search(std::string query, const int num_
return results;
}
void Collection::log_leaves(const int max_cost, const std::string &token, const std::vector<art_leaf *> &leaves) const {
void Collection::log_leaves(const int cost, const std::string &token, const std::vector<art_leaf *> &leaves) const {
printf("Token: %s, cost: %d, candidates: \n", token.c_str(), cost);
for(auto i=0; i < leaves.size(); i++) {
printf("%s - ", token.c_str());
printf("%.*s", leaves[i]->key_len, leaves[i]->key);
printf(" - max_cost: %d, - num_ids: %d\n", max_cost, leaves[i]->values->ids.getLength());
for(auto j=0; j<leaves[i]->values->ids.getLength(); j++) {
printf("%.*s, ", leaves[i]->key_len, leaves[i]->key);
printf("num_ids: %d\n", leaves[i]->values->ids.getLength());
/*for(auto j=0; j<leaves[i]->values->ids.getLength(); j++) {
printf("id: %d\n", leaves[i]->values->ids.at(j));
}
}*/
}
}

View File

@ -106,7 +106,16 @@ TEST_F(CollectionTest, SkipUnindexedTokensDuringPhraseSearch) {
TEST_F(CollectionTest, PartialPhraseSearch) {
std::vector<nlohmann::json> results = collection->search("rocket research", 0, 10);
//ASSERT_EQ(1, results.size());
ASSERT_EQ(4, results.size());
std::vector<std::string> ids = {"1", "8", "16", "17"};
for(size_t i = 0; i < results.size(); i++) {
nlohmann::json result = results.at(i);
std::string result_id = result["id"];
std::string id = ids.at(i);
ASSERT_STREQ(id.c_str(), result_id.c_str());
}
}
TEST_F(CollectionTest, RegressionTest1) {