Address prefix search issues.

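Score-based comparison was broken (the score priority-queue comparator delegated to the frequency comparator); the prefix search test has been enhanced to distinguish MAX_SCORE from FREQUENCY ordering.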
This commit is contained in:
Kishore Nallan 2017-08-18 15:26:17 +05:30
parent 38fbbea71f
commit f5848be750
7 changed files with 67 additions and 20 deletions

View File

@ -95,7 +95,7 @@ private:
std::vector<std::vector<art_leaf*>> & token_to_candidates,
std::vector<std::vector<art_leaf*>> & searched_queries, Topster<100> & topster,
size_t & total_results, uint32_t** all_result_ids, size_t & all_result_ids_len,
const size_t & max_results);
const size_t & max_results, const bool prefix);
void index_string_field(const std::string & text, const uint32_t score, art_tree *t, uint32_t seq_id,
const bool verbatim) const;

View File

@ -116,14 +116,15 @@ void post_create_collection(http_req & req, http_res & res) {
sort_fields.push_back(field(sort_field_json["name"], sort_field_json["type"]));
}
const char* PREFIX_RANKING_FIELD = "prefix_ranking_field";
std::string token_ranking_field = "";
if(req_json.count("token_ranking_field") != 0) {
if(!req_json["token_ranking_field"].is_string()) {
return res.send_400("Wrong format for `token_ranking_field`. It should be a string (name of a field).");
if(req_json.count(PREFIX_RANKING_FIELD) != 0) {
if(!req_json[PREFIX_RANKING_FIELD].is_string()) {
return res.send_400(std::string("Wrong format for `") + PREFIX_RANKING_FIELD + "`. It should be the name of an unsigned INT32 field.");
}
token_ranking_field = req_json["token_ranking_field"].get<std::string>();
token_ranking_field = req_json[PREFIX_RANKING_FIELD].get<std::string>();
}
collectionManager.create_collection(req_json["name"], search_fields, facet_fields, sort_fields, token_ranking_field);
@ -157,6 +158,7 @@ void get_search(http_req & req, http_res & res) {
const char *PER_PAGE = "per_page";
const char *PAGE = "page";
const char *CALLBACK = "callback";
const char *SORT_PREFIXES_BY = "sort_prefixes_by";
if(req.params.count(NUM_TYPOS) == 0) {
req.params[NUM_TYPOS] = "2";
@ -217,11 +219,17 @@ void get_search(http_req & req, http_res & res) {
bool prefix = (req.params[PREFIX] == "true");
token_ordering token_order = FREQUENCY;
if(prefix && !collection->get_token_ranking_field().empty()) {
token_order = MAX_SCORE;
if(req.params.count(SORT_PREFIXES_BY) == 0) {
if(prefix && !collection->get_token_ranking_field().empty()) {
req.params[SORT_PREFIXES_BY] = "PREFIX_SORT_FIELD";
} else {
req.params[SORT_PREFIXES_BY] = "TERM_FREQUENCY";
}
}
StringUtils::toupper(req.params[SORT_PREFIXES_BY]);
token_ordering token_order = (req.params[SORT_PREFIXES_BY] == "PREFIX_SORT_FIELD") ? MAX_SCORE : FREQUENCY;
Option<nlohmann::json> result_op = collection->search(req.params["q"], search_fields, filter_str, facet_fields,
sort_fields, std::stoi(req.params[NUM_TYPOS]),
std::stoi(req.params[PER_PAGE]), std::stoi(req.params[PAGE]),

View File

@ -91,7 +91,7 @@ bool compare_art_node_frequency_pq(const art_node *a, const art_node *b) {
}
bool compare_art_node_score_pq(const art_node* a, const art_node* b) {
return !compare_art_node_frequency(a, b);
return !compare_art_node_score(a, b);
}
/**

View File

@ -405,7 +405,7 @@ void Collection::search_candidates(uint32_t* filter_ids, size_t filter_ids_lengt
std::vector<std::vector<art_leaf*>> & token_to_candidates,
std::vector<std::vector<art_leaf*>> & searched_queries, Topster<100> & topster,
size_t & total_results, uint32_t** all_result_ids, size_t & all_result_ids_len,
const size_t & max_results) {
const size_t & max_results, const bool prefix) {
const size_t combination_limit = 10;
auto product = []( long long a, std::vector<art_leaf*>& b ) { return a*b.size(); };
long long int N = std::accumulate(token_to_candidates.begin(), token_to_candidates.end(), 1LL, product);
@ -422,10 +422,17 @@ void Collection::search_candidates(uint32_t* filter_ids, size_t filter_ids_lengt
uint32_t* result_ids = query_suggestion[0]->values->ids.uncompress();
size_t result_size = query_suggestion[0]->values->ids.getLength();
if(result_size == 0) continue;
if(result_size == 0) {
continue;
}
candidate_rank += 1;
int actual_candidate_rank = candidate_rank;
if(prefix) {
actual_candidate_rank = 0;
}
// intersect the document ids for each token to find docs that contain all the tokens (stored in `result_ids`)
for(auto i=1; i < query_suggestion.size(); i++) {
uint32_t* out = nullptr;
@ -449,7 +456,7 @@ void Collection::search_candidates(uint32_t* filter_ids, size_t filter_ids_lengt
do_facets(facets, filtered_result_ids, filtered_results_size);
// go through each matching document id and calculate match score
score_results(sort_fields, searched_queries.size(), candidate_rank, topster, query_suggestion,
score_results(sort_fields, searched_queries.size(), actual_candidate_rank, topster, query_suggestion,
filtered_result_ids, filtered_results_size);
delete[] filtered_result_ids;
@ -463,14 +470,19 @@ void Collection::search_candidates(uint32_t* filter_ids, size_t filter_ids_lengt
delete [] *all_result_ids;
*all_result_ids = new_all_result_ids;
score_results(sort_fields, searched_queries.size(), candidate_rank, topster, query_suggestion, result_ids, result_size);
score_results(sort_fields, searched_queries.size(), actual_candidate_rank, topster, query_suggestion,
result_ids, result_size);
delete[] result_ids;
}
total_results += topster.size;
searched_queries.push_back(query_suggestion);
if(total_results >= max_results) {
if(!prefix && total_results >= max_results) {
break;
}
if(prefix && candidate_rank >= max_results) {
break;
}
}
@ -944,7 +956,6 @@ void Collection::search_field(std::string & query, const std::string & field, ui
leaves = token_cost_cache[token_cost_hash];
} else {
int token_len = prefix ? (int) token.length() : (int) token.length() + 1;
int count = search_index.count(field);
art_fuzzy_search(search_index.at(field), (const unsigned char *) token.c_str(), token_len,
@ -985,12 +996,17 @@ void Collection::search_field(std::string & query, const std::string & field, ui
if(token_to_candidates.size() != 0 && token_to_candidates.size() == tokens.size()) {
// If all tokens were found, go ahead and search for candidates with what we have so far
search_candidates(filter_ids, filter_ids_length, facets, sort_fields, candidate_rank, token_to_candidates,
searched_queries, topster, total_results, all_result_ids, all_result_ids_len, max_results);
searched_queries, topster, total_results, all_result_ids, all_result_ids_len,
max_results, prefix);
if (total_results >= max_results) {
if (!prefix && total_results >= max_results) {
// If we don't find enough results, we continue outerloop (looking at tokens with greater cost)
break;
}
if(prefix && candidate_rank > 10) {
break;
}
}
n++;

View File

@ -24,7 +24,7 @@ int main(int argc, char* argv[]) {
Collection *collection = collectionManager.get_collection("hnstories_direct");
if(collection == nullptr) {
collection = collectionManager.create_collection("hnstories_direct", fields_to_index, {}, sort_fields);
collection = collectionManager.create_collection("hnstories_direct", fields_to_index, {}, sort_fields, "points");
}
std::ifstream infile("/Users/kishore/Downloads/hnstories.jsonl");

View File

@ -349,7 +349,7 @@ TEST_F(CollectionTest, PrefixSearching) {
std::vector<std::string> facets;
nlohmann::json results = collection->search("ex", query_fields, "", facets, sort_fields, 0, 10, 1, FREQUENCY, true).get();
ASSERT_EQ(2, results["hits"].size());
std::vector<std::string> ids = {"12", "6"};
std::vector<std::string> ids = {"6", "12"};
for(size_t i = 0; i < results["hits"].size(); i++) {
nlohmann::json result = results["hits"].at(i);
@ -379,6 +379,29 @@ TEST_F(CollectionTest, PrefixSearching) {
std::string id = ids.at(i);
ASSERT_STREQ(id.c_str(), result_id.c_str());
}
// restrict to only 2 results and differentiate between MAX_SCORE and FREQUENCY
results = collection->search("t", query_fields, "", facets, sort_fields, 0, 2, 1, MAX_SCORE, true).get();
ASSERT_EQ(2, results["hits"].size());
ids = {"19", "22"};
for(size_t i = 0; i < results["hits"].size(); i++) {
nlohmann::json result = results["hits"].at(i);
std::string result_id = result["id"];
std::string id = ids.at(i);
ASSERT_STREQ(id.c_str(), result_id.c_str());
}
results = collection->search("t", query_fields, "", facets, sort_fields, 0, 2, 1, FREQUENCY, true).get();
ASSERT_EQ(2, results["hits"].size());
ids = {"1", "6"};
for(size_t i = 0; i < results["hits"].size(); i++) {
nlohmann::json result = results["hits"].at(i);
std::string result_id = result["id"];
std::string id = ids.at(i);
ASSERT_STREQ(id.c_str(), result_id.c_str());
}
}
TEST_F(CollectionTest, MultipleFields) {

View File

@ -16,7 +16,7 @@
{"points":10,"title":"How late do the launch propellants ionize in a chemical rocket mission?"}
{"points":8,"title":"How much does it cost to launch (right from start) a rocket today?"}
{"points":16,"title":"Difference between Space Dynamics & Astrodynamics in engineering perspective?"}
{"points":18,"title":"What kind of biological research does ISS do?"}
{"points":18,"title":"What kind of biological research does ISS do then?"}
{"points":10,"title":"Which kinds of radiation hit ISX ?"}
{"points":7,"title":"What kinds of things have been tossed out of ISS in space?"}
{"points":17,"title":"What does triple redundant closed loop digital avionics system mean?"}