From 1992d92eafde8a698b78508b930fc3bc131116e9 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Sun, 14 May 2017 12:25:59 +0530 Subject: [PATCH] Tests for asc/desc sort order. --- TODO.md | 4 ++-- include/field.h | 2 ++ src/collection.cpp | 16 +++++++++++++++- test/collection_test.cpp | 34 ++++++++++++++++++++++++++++++++++ 4 files changed, 53 insertions(+), 3 deletions(-) diff --git a/TODO.md b/TODO.md index b005565b..9ca2d9c9 100644 --- a/TODO.md +++ b/TODO.md @@ -30,8 +30,8 @@ - ~~Schema validation during insertion (missing fields + type errors)~~ - ~~Proper score field for ranking tokens~~ - ~~Throw errors when schema is broken~~ -- Desc/Asc ordering with tests -- Found count is wrong +- ~~Desc/Asc ordering with tests~~ +- ~~Found count is wrong~~ - Proper pagination - Filter query in the API - Prevent string copy during indexing diff --git a/include/field.h b/include/field.h index f3673501..07193a33 100644 --- a/include/field.h +++ b/include/field.h @@ -67,6 +67,8 @@ struct filter { namespace sort_field_const { static const std::string name = "name"; static const std::string order = "order"; + static const std::string asc = "ASC"; + static const std::string desc = "DESC"; } struct sort_field { diff --git a/src/collection.cpp b/src/collection.cpp index efe156ad..52dc2979 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -640,6 +640,7 @@ nlohmann::json Collection::search(std::string query, const std::vector::KV> & a, const std::pair::KV> & b) { if(a.second.match_score != b.second.match_score) return a.second.match_score > b.second.match_score; @@ -867,13 +868,23 @@ void Collection::score_results(const std::vector & sort_fields, cons spp::sparse_hash_map * primary_rank_scores = nullptr; spp::sparse_hash_map * secondary_rank_scores = nullptr; + // Used for asc/desc ordering. NOTE: Topster keeps biggest keys (i.e. it's desc in nature) + int64_t primary_rank_factor = 1; + int64_t secondary_rank_factor = 1; + if(sort_fields.size() > 0) { // assumed that rank field exists in the index - checked earlier in the chain primary_rank_scores = sort_index.at(sort_fields[0].name); + if(sort_fields[0].order == sort_field_const::asc) { + primary_rank_factor = -1; + } } if(sort_fields.size() > 1) { secondary_rank_scores = sort_index.at(sort_fields[1].name); + if(sort_fields[1].order == sort_field_const::asc) { + secondary_rank_factor = -1; + } } for(auto i=0; i & sort_fields, cons int64_t primary_rank_score = primary_rank_scores->count(seq_id) > 0 ? primary_rank_scores->at(seq_id) : 0; int64_t secondary_rank_score = (secondary_rank_scores && secondary_rank_scores->count(seq_id) > 0) ? secondary_rank_scores->at(seq_id) : 0; - topster.add(seq_id, match_score, primary_rank_score, secondary_rank_score); + topster.add(seq_id, match_score, + primary_rank_factor * primary_rank_score, + secondary_rank_factor * secondary_rank_score); + /*std::cout << "token_rank_score: " << token_rank_score << ", match_score: " << match_score << ", primary_rank_score: " << primary_rank_score << ", seq_id: " << seq_id << std::endl;*/ } diff --git a/test/collection_test.cpp b/test/collection_test.cpp index 56a91593..4e79501d 100644 --- a/test/collection_test.cpp +++ b/test/collection_test.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include "collection.h" @@ -73,6 +74,22 @@ TEST_F(CollectionTest, ExactSearchShouldBeStable) { std::string result_id = result["id"]; ASSERT_STREQ(id.c_str(), result_id.c_str()); } + + // check ASC sorting + std::vector sort_fields_asc = { sort_field("points", "ASC") }; + + results = collection->search("the", query_fields, "", facets, sort_fields_asc, 0, 10); + ASSERT_EQ(7, results["hits"].size()); + ASSERT_EQ(7, results["found"].get()); + + ids = {"16", "13", "10", "8", "6", "foo", "1"}; + + for(size_t i = 0; i < results["hits"].size(); i++) { + nlohmann::json result = results["hits"].at(i); + std::string id = ids.at(i); + std::string result_id = result["id"]; + ASSERT_STREQ(id.c_str(), result_id.c_str()); + } } TEST_F(CollectionTest, ExactPhraseSearch) { @@ -99,11 +116,28 @@ TEST_F(CollectionTest, ExactPhraseSearch) { ASSERT_STREQ(id.c_str(), result_id.c_str()); } + // Check ASC sort order + std::vector sort_fields_asc = { sort_field("points", "ASC") }; + results = collection->search("rocket launch", query_fields, "", facets, sort_fields_asc, 0, 10); + ASSERT_EQ(5, results["hits"].size()); + ASSERT_EQ(5, results["found"].get()); + + ids = {"8", "17", "1", "16", "13"}; + + for(size_t i = 0; i < results["hits"].size(); i++) { + nlohmann::json result = results["hits"].at(i); + std::string id = ids.at(i); + std::string result_id = result["id"]; + ASSERT_STREQ(id.c_str(), result_id.c_str()); + } + // Check pagination results = collection->search("rocket launch", query_fields, "", facets, sort_fields, 0, 3); ASSERT_EQ(3, results["hits"].size()); ASSERT_EQ(4, results["found"].get()); + ids = {"8", "1", "17", "16", "13"}; + for(size_t i = 0; i < 3; i++) { nlohmann::json result = results["hits"].at(i); std::string id = ids.at(i);