Tests for asc/desc sort order.

This commit is contained in:
Kishore Nallan 2017-05-14 12:25:59 +05:30
parent 060959ad70
commit 1992d92eaf
4 changed files with 53 additions and 3 deletions

View File

@ -30,8 +30,8 @@
- ~~Schema validation during insertion (missing fields + type errors)~~
- ~~Proper score field for ranking tokens~~
- ~~Throw errors when schema is broken~~
- Desc/Asc ordering with tests
- Found count is wrong
- ~~Desc/Asc ordering with tests~~
- ~~Found count is wrong~~
- Proper pagination
- Filter query in the API
- Prevent string copy during indexing

View File

@ -67,6 +67,8 @@ struct filter {
// String constants associated with sort fields.
namespace sort_field_const {
    // NOTE(review): presumably the keys under which a sort field's name and
    // ordering are supplied by callers - confirm against the request parser.
    static const std::string name{"name"};
    static const std::string order{"order"};

    // Ordering values; a sort field's `order` is compared against these
    // to decide between ascending and descending ranking.
    static const std::string asc{"ASC"};
    static const std::string desc{"DESC"};
}
struct sort_field {

View File

@ -640,6 +640,7 @@ nlohmann::json Collection::search(std::string query, const std::vector<std::stri
delete [] filter_ids;
// All fields are sorted descending
std::sort(field_order_kvs.begin(), field_order_kvs.end(),
[](const std::pair<int, Topster<100>::KV> & a, const std::pair<int, Topster<100>::KV> & b) {
if(a.second.match_score != b.second.match_score) return a.second.match_score > b.second.match_score;
@ -867,13 +868,23 @@ void Collection::score_results(const std::vector<sort_field> & sort_fields, cons
spp::sparse_hash_map<uint32_t, int64_t> * primary_rank_scores = nullptr;
spp::sparse_hash_map<uint32_t, int64_t> * secondary_rank_scores = nullptr;
// Used for asc/desc ordering. NOTE: Topster keeps biggest keys (i.e. it's desc in nature)
int64_t primary_rank_factor = 1;
int64_t secondary_rank_factor = 1;
if(sort_fields.size() > 0) {
// assumed that rank field exists in the index - checked earlier in the chain
primary_rank_scores = sort_index.at(sort_fields[0].name);
if(sort_fields[0].order == sort_field_const::asc) {
primary_rank_factor = -1;
}
}
if(sort_fields.size() > 1) {
secondary_rank_scores = sort_index.at(sort_fields[1].name);
if(sort_fields[1].order == sort_field_const::asc) {
secondary_rank_factor = -1;
}
}
for(auto i=0; i<result_size; i++) {
@ -919,7 +930,10 @@ void Collection::score_results(const std::vector<sort_field> & sort_fields, cons
int64_t primary_rank_score = primary_rank_scores->count(seq_id) > 0 ? primary_rank_scores->at(seq_id) : 0;
int64_t secondary_rank_score = (secondary_rank_scores && secondary_rank_scores->count(seq_id) > 0) ?
secondary_rank_scores->at(seq_id) : 0;
topster.add(seq_id, match_score, primary_rank_score, secondary_rank_score);
topster.add(seq_id, match_score,
primary_rank_factor * primary_rank_score,
secondary_rank_factor * secondary_rank_score);
/*std::cout << "token_rank_score: " << token_rank_score << ", match_score: "
<< match_score << ", primary_rank_score: " << primary_rank_score << ", seq_id: " << seq_id << std::endl;*/
}

View File

@ -2,6 +2,7 @@
#include <string>
#include <vector>
#include <fstream>
#include <algorithm>
#include <collection_manager.h>
#include "collection.h"
@ -73,6 +74,22 @@ TEST_F(CollectionTest, ExactSearchShouldBeStable) {
std::string result_id = result["id"];
ASSERT_STREQ(id.c_str(), result_id.c_str());
}
// check ASC sorting
std::vector<sort_field> sort_fields_asc = { sort_field("points", "ASC") };
results = collection->search("the", query_fields, "", facets, sort_fields_asc, 0, 10);
ASSERT_EQ(7, results["hits"].size());
ASSERT_EQ(7, results["found"].get<int>());
ids = {"16", "13", "10", "8", "6", "foo", "1"};
for(size_t i = 0; i < results["hits"].size(); i++) {
nlohmann::json result = results["hits"].at(i);
std::string id = ids.at(i);
std::string result_id = result["id"];
ASSERT_STREQ(id.c_str(), result_id.c_str());
}
}
TEST_F(CollectionTest, ExactPhraseSearch) {
@ -99,11 +116,28 @@ TEST_F(CollectionTest, ExactPhraseSearch) {
ASSERT_STREQ(id.c_str(), result_id.c_str());
}
// Check ASC sort order
std::vector<sort_field> sort_fields_asc = { sort_field("points", "ASC") };
results = collection->search("rocket launch", query_fields, "", facets, sort_fields_asc, 0, 10);
ASSERT_EQ(5, results["hits"].size());
ASSERT_EQ(5, results["found"].get<uint32_t>());
ids = {"8", "17", "1", "16", "13"};
for(size_t i = 0; i < results["hits"].size(); i++) {
nlohmann::json result = results["hits"].at(i);
std::string id = ids.at(i);
std::string result_id = result["id"];
ASSERT_STREQ(id.c_str(), result_id.c_str());
}
// Check pagination
results = collection->search("rocket launch", query_fields, "", facets, sort_fields, 0, 3);
ASSERT_EQ(3, results["hits"].size());
ASSERT_EQ(4, results["found"].get<uint32_t>());
ids = {"8", "1", "17", "16", "13"};
for(size_t i = 0; i < 3; i++) {
nlohmann::json result = results["hits"].at(i);
std::string id = ids.at(i);