mirror of
https://github.com/typesense/typesense.git
synced 2025-05-16 19:55:21 +08:00
Address prefix search issues.
Score based comparison was broken - test has been enhanced.
This commit is contained in:
parent
38fbbea71f
commit
f5848be750
@@ -95,7 +95,7 @@ private:
|
||||
std::vector<std::vector<art_leaf*>> & token_to_candidates,
|
||||
std::vector<std::vector<art_leaf*>> & searched_queries, Topster<100> & topster,
|
||||
size_t & total_results, uint32_t** all_result_ids, size_t & all_result_ids_len,
|
||||
const size_t & max_results);
|
||||
const size_t & max_results, const bool prefix);
|
||||
|
||||
void index_string_field(const std::string & text, const uint32_t score, art_tree *t, uint32_t seq_id,
|
||||
const bool verbatim) const;
|
||||
|
22
src/api.cpp
22
src/api.cpp
@@ -116,14 +116,15 @@ void post_create_collection(http_req & req, http_res & res) {
|
||||
sort_fields.push_back(field(sort_field_json["name"], sort_field_json["type"]));
|
||||
}
|
||||
|
||||
const char* PREFIX_RANKING_FIELD = "prefix_ranking_field";
|
||||
std::string token_ranking_field = "";
|
||||
|
||||
if(req_json.count("token_ranking_field") != 0) {
|
||||
if(!req_json["token_ranking_field"].is_string()) {
|
||||
return res.send_400("Wrong format for `token_ranking_field`. It should be a string (name of a field).");
|
||||
if(req_json.count(PREFIX_RANKING_FIELD) != 0) {
|
||||
if(!req_json[PREFIX_RANKING_FIELD].is_string()) {
|
||||
return res.send_400(std::string("Wrong format for `") + PREFIX_RANKING_FIELD + "`. It should be the name of an unsigned INT32 field.");
|
||||
}
|
||||
|
||||
token_ranking_field = req_json["token_ranking_field"].get<std::string>();
|
||||
token_ranking_field = req_json[PREFIX_RANKING_FIELD].get<std::string>();
|
||||
}
|
||||
|
||||
collectionManager.create_collection(req_json["name"], search_fields, facet_fields, sort_fields, token_ranking_field);
|
||||
@@ -157,6 +158,7 @@ void get_search(http_req & req, http_res & res) {
|
||||
const char *PER_PAGE = "per_page";
|
||||
const char *PAGE = "page";
|
||||
const char *CALLBACK = "callback";
|
||||
const char *SORT_PREFIXES_BY = "sort_prefixes_by";
|
||||
|
||||
if(req.params.count(NUM_TYPOS) == 0) {
|
||||
req.params[NUM_TYPOS] = "2";
|
||||
@@ -217,11 +219,17 @@ void get_search(http_req & req, http_res & res) {
|
||||
|
||||
bool prefix = (req.params[PREFIX] == "true");
|
||||
|
||||
token_ordering token_order = FREQUENCY;
|
||||
if(prefix && !collection->get_token_ranking_field().empty()) {
|
||||
token_order = MAX_SCORE;
|
||||
if(req.params.count(SORT_PREFIXES_BY) == 0) {
|
||||
if(prefix && !collection->get_token_ranking_field().empty()) {
|
||||
req.params[SORT_PREFIXES_BY] = "PREFIX_SORT_FIELD";
|
||||
} else {
|
||||
req.params[SORT_PREFIXES_BY] = "TERM_FREQUENCY";
|
||||
}
|
||||
}
|
||||
|
||||
StringUtils::toupper(req.params[SORT_PREFIXES_BY]);
|
||||
token_ordering token_order = (req.params[SORT_PREFIXES_BY] == "PREFIX_SORT_FIELD") ? MAX_SCORE : FREQUENCY;
|
||||
|
||||
Option<nlohmann::json> result_op = collection->search(req.params["q"], search_fields, filter_str, facet_fields,
|
||||
sort_fields, std::stoi(req.params[NUM_TYPOS]),
|
||||
std::stoi(req.params[PER_PAGE]), std::stoi(req.params[PAGE]),
|
||||
|
@@ -91,7 +91,7 @@ bool compare_art_node_frequency_pq(const art_node *a, const art_node *b) {
|
||||
}
|
||||
|
||||
bool compare_art_node_score_pq(const art_node* a, const art_node* b) {
|
||||
return !compare_art_node_frequency(a, b);
|
||||
return !compare_art_node_score(a, b);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@@ -405,7 +405,7 @@ void Collection::search_candidates(uint32_t* filter_ids, size_t filter_ids_lengt
|
||||
std::vector<std::vector<art_leaf*>> & token_to_candidates,
|
||||
std::vector<std::vector<art_leaf*>> & searched_queries, Topster<100> & topster,
|
||||
size_t & total_results, uint32_t** all_result_ids, size_t & all_result_ids_len,
|
||||
const size_t & max_results) {
|
||||
const size_t & max_results, const bool prefix) {
|
||||
const size_t combination_limit = 10;
|
||||
auto product = []( long long a, std::vector<art_leaf*>& b ) { return a*b.size(); };
|
||||
long long int N = std::accumulate(token_to_candidates.begin(), token_to_candidates.end(), 1LL, product);
|
||||
@@ -422,10 +422,17 @@ void Collection::search_candidates(uint32_t* filter_ids, size_t filter_ids_lengt
|
||||
uint32_t* result_ids = query_suggestion[0]->values->ids.uncompress();
|
||||
size_t result_size = query_suggestion[0]->values->ids.getLength();
|
||||
|
||||
if(result_size == 0) continue;
|
||||
if(result_size == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
candidate_rank += 1;
|
||||
|
||||
int actual_candidate_rank = candidate_rank;
|
||||
if(prefix) {
|
||||
actual_candidate_rank = 0;
|
||||
}
|
||||
|
||||
// intersect the document ids for each token to find docs that contain all the tokens (stored in `result_ids`)
|
||||
for(auto i=1; i < query_suggestion.size(); i++) {
|
||||
uint32_t* out = nullptr;
|
||||
@@ -449,7 +456,7 @@ void Collection::search_candidates(uint32_t* filter_ids, size_t filter_ids_lengt
|
||||
do_facets(facets, filtered_result_ids, filtered_results_size);
|
||||
|
||||
// go through each matching document id and calculate match score
|
||||
score_results(sort_fields, searched_queries.size(), candidate_rank, topster, query_suggestion,
|
||||
score_results(sort_fields, searched_queries.size(), actual_candidate_rank, topster, query_suggestion,
|
||||
filtered_result_ids, filtered_results_size);
|
||||
|
||||
delete[] filtered_result_ids;
|
||||
@@ -463,14 +470,19 @@ void Collection::search_candidates(uint32_t* filter_ids, size_t filter_ids_lengt
|
||||
delete [] *all_result_ids;
|
||||
*all_result_ids = new_all_result_ids;
|
||||
|
||||
score_results(sort_fields, searched_queries.size(), candidate_rank, topster, query_suggestion, result_ids, result_size);
|
||||
score_results(sort_fields, searched_queries.size(), actual_candidate_rank, topster, query_suggestion,
|
||||
result_ids, result_size);
|
||||
delete[] result_ids;
|
||||
}
|
||||
|
||||
total_results += topster.size;
|
||||
searched_queries.push_back(query_suggestion);
|
||||
|
||||
if(total_results >= max_results) {
|
||||
if(!prefix && total_results >= max_results) {
|
||||
break;
|
||||
}
|
||||
|
||||
if(prefix && candidate_rank >= max_results) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -944,7 +956,6 @@ void Collection::search_field(std::string & query, const std::string & field, ui
|
||||
leaves = token_cost_cache[token_cost_hash];
|
||||
} else {
|
||||
int token_len = prefix ? (int) token.length() : (int) token.length() + 1;
|
||||
|
||||
int count = search_index.count(field);
|
||||
|
||||
art_fuzzy_search(search_index.at(field), (const unsigned char *) token.c_str(), token_len,
|
||||
@@ -985,12 +996,17 @@ void Collection::search_field(std::string & query, const std::string & field, ui
|
||||
if(token_to_candidates.size() != 0 && token_to_candidates.size() == tokens.size()) {
|
||||
// If all tokens were found, go ahead and search for candidates with what we have so far
|
||||
search_candidates(filter_ids, filter_ids_length, facets, sort_fields, candidate_rank, token_to_candidates,
|
||||
searched_queries, topster, total_results, all_result_ids, all_result_ids_len, max_results);
|
||||
searched_queries, topster, total_results, all_result_ids, all_result_ids_len,
|
||||
max_results, prefix);
|
||||
|
||||
if (total_results >= max_results) {
|
||||
if (!prefix && total_results >= max_results) {
|
||||
// If we don't find enough results, we continue outerloop (looking at tokens with greater cost)
|
||||
break;
|
||||
}
|
||||
|
||||
if(prefix && candidate_rank > 10) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
n++;
|
||||
|
@@ -24,7 +24,7 @@ int main(int argc, char* argv[]) {
|
||||
|
||||
Collection *collection = collectionManager.get_collection("hnstories_direct");
|
||||
if(collection == nullptr) {
|
||||
collection = collectionManager.create_collection("hnstories_direct", fields_to_index, {}, sort_fields);
|
||||
collection = collectionManager.create_collection("hnstories_direct", fields_to_index, {}, sort_fields, "points");
|
||||
}
|
||||
|
||||
std::ifstream infile("/Users/kishore/Downloads/hnstories.jsonl");
|
||||
|
@@ -349,7 +349,7 @@ TEST_F(CollectionTest, PrefixSearching) {
|
||||
std::vector<std::string> facets;
|
||||
nlohmann::json results = collection->search("ex", query_fields, "", facets, sort_fields, 0, 10, 1, FREQUENCY, true).get();
|
||||
ASSERT_EQ(2, results["hits"].size());
|
||||
std::vector<std::string> ids = {"12", "6"};
|
||||
std::vector<std::string> ids = {"6", "12"};
|
||||
|
||||
for(size_t i = 0; i < results["hits"].size(); i++) {
|
||||
nlohmann::json result = results["hits"].at(i);
|
||||
@@ -379,6 +379,29 @@ TEST_F(CollectionTest, PrefixSearching) {
|
||||
std::string id = ids.at(i);
|
||||
ASSERT_STREQ(id.c_str(), result_id.c_str());
|
||||
}
|
||||
|
||||
// restrict to only 2 results and differentiate between MAX_SCORE and FREQUENCY
|
||||
results = collection->search("t", query_fields, "", facets, sort_fields, 0, 2, 1, MAX_SCORE, true).get();
|
||||
ASSERT_EQ(2, results["hits"].size());
|
||||
ids = {"19", "22"};
|
||||
|
||||
for(size_t i = 0; i < results["hits"].size(); i++) {
|
||||
nlohmann::json result = results["hits"].at(i);
|
||||
std::string result_id = result["id"];
|
||||
std::string id = ids.at(i);
|
||||
ASSERT_STREQ(id.c_str(), result_id.c_str());
|
||||
}
|
||||
|
||||
results = collection->search("t", query_fields, "", facets, sort_fields, 0, 2, 1, FREQUENCY, true).get();
|
||||
ASSERT_EQ(2, results["hits"].size());
|
||||
ids = {"1", "6"};
|
||||
|
||||
for(size_t i = 0; i < results["hits"].size(); i++) {
|
||||
nlohmann::json result = results["hits"].at(i);
|
||||
std::string result_id = result["id"];
|
||||
std::string id = ids.at(i);
|
||||
ASSERT_STREQ(id.c_str(), result_id.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(CollectionTest, MultipleFields) {
|
||||
|
@@ -16,7 +16,7 @@
|
||||
{"points":10,"title":"How late do the launch propellants ionize in a chemical rocket mission?"}
|
||||
{"points":8,"title":"How much does it cost to launch (right from start) a rocket today?"}
|
||||
{"points":16,"title":"Difference between Space Dynamics & Astrodynamics in engineering perspective?"}
|
||||
{"points":18,"title":"What kind of biological research does ISS do?"}
|
||||
{"points":18,"title":"What kind of biological research does ISS do then?"}
|
||||
{"points":10,"title":"Which kinds of radiation hit ISX ?"}
|
||||
{"points":7,"title":"What kinds of things have been tossed out of ISS in space?"}
|
||||
{"points":17,"title":"What does triple redundant closed loop digital avionics system mean?"}
|
||||
|
Loading…
x
Reference in New Issue
Block a user