diff --git a/TODO.md b/TODO.md index b369a5d2..79520a37 100644 --- a/TODO.md +++ b/TODO.md @@ -34,8 +34,11 @@ - ~~Found count is wrong~~ - ~~Filter query in the API~~ - Fix API response codes +- Test for search without any sort_by given +- Test for asc/desc upper/lower casing - Test for collection creation validation - Proper pagination +- Deprecate old split function - Prevent string copy during indexing - clean special chars before indexing - Minimum results should be a variable instead of blindly going with max_results diff --git a/include/string_utils.h b/include/string_utils.h index a93c6bff..200e3371 100644 --- a/include/string_utils.h +++ b/include/string_utils.h @@ -122,4 +122,8 @@ struct StringUtils { strtol(s.c_str(), &p, 10); return (*p == 0); } + + static void toupper(std::string& str) { + std::transform(str.begin(), str.end(), str.begin(), ::toupper); + } }; \ No newline at end of file diff --git a/src/api.cpp b/src/api.cpp index fdcf179f..d347a071 100644 --- a/src/api.cpp +++ b/src/api.cpp @@ -1,5 +1,6 @@ #include #include "api.h" +#include "string_utils.h" #include "collection.h" #include "collection_manager.h" @@ -105,8 +106,11 @@ void post_create_collection(http_req & req, http_res & res) { void get_search(http_req & req, http_res & res) { const char *NUM_TYPOS = "num_typos"; const char *PREFIX = "prefix"; + const char *FILTER = "filter"; + const char *SEARCH_BY = "search_by"; + const char *SORT_BY = "sort_by"; + const char *FACET_BY = "facet_by"; const char *TOKEN_ORDERING = "token_ordering"; - const char *FILTERS = "filters"; if(req.params.count(NUM_TYPOS) == 0) { req.params[NUM_TYPOS] = "2"; @@ -120,15 +124,43 @@ void get_search(http_req & req, http_res & res) { req.params[TOKEN_ORDERING] = "FREQUENCY"; } - std::string filter_str = req.params.count(FILTERS) != 0 ? req.params[FILTERS] : ""; - //std::cout << "filter_str: " << filter_str << std::endl; + if(req.params.count(SEARCH_BY) == 0) { + return res.send_400(std::string("Parameter `") + SEARCH_BY + "` is required."); + } + + std::string filter_str = req.params.count(FILTER) != 0 ? req.params[FILTER] : ""; token_ordering token_order = (req.params[TOKEN_ORDERING] == "MAX_SCORE") ? MAX_SCORE : FREQUENCY; - //printf("Query: %s\n", req.params["q"].c_str()); - auto begin = std::chrono::high_resolution_clock::now(); + std::vector search_fields; + StringUtils::split(req.params[SEARCH_BY], search_fields, ","); - std::vector search_fields = {"title"}; + std::vector facet_fields; + StringUtils::split(req.params[FACET_BY], facet_fields, ","); + + std::vector sort_fields; + if(req.params.count(SORT_BY) != 0) { + std::vector sort_field_strs; + StringUtils::split(req.params[SORT_BY], sort_field_strs, ","); + + if(sort_field_strs.size() > 2) { + return res.send_400("Only upto 2 sort fields are allowed."); + } + + for(const std::string & sort_field_str: sort_field_strs) { + std::vector expression_parts; + StringUtils::split(sort_field_str, expression_parts, ":"); + + if(expression_parts.size() != 2) { + return res.send_400(std::string("Parameter `") + SORT_BY + "` is malformed."); + } + + StringUtils::toupper(expression_parts[1]); + sort_fields.push_back(sort_field(expression_parts[0], expression_parts[1])); + } + } + + auto begin = std::chrono::high_resolution_clock::now(); CollectionManager & collectionManager = CollectionManager::get_instance(); Collection* collection = collectionManager.get_collection(req.params["collection"]); @@ -137,8 +169,8 @@ void get_search(http_req & req, http_res & res) { return res.send_404(); } - nlohmann::json result = collection->search(req.params["q"], search_fields, filter_str, { }, - {sort_field("points", "DESC")}, std::stoi(req.params[NUM_TYPOS]), 100, + nlohmann::json result = collection->search(req.params["q"], search_fields, filter_str, facet_fields, + sort_fields, std::stoi(req.params[NUM_TYPOS]), 100, token_order, false); const std::string & json_str = result.dump(); //std::cout << "JSON:" << json_str << std::endl; diff --git a/src/collection.cpp b/src/collection.cpp index a581bba8..d4d479ab 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -605,6 +605,11 @@ nlohmann::json Collection::search(std::string query, const std::vector & sort_fields, cons ((uint64_t)(mscore.words_present) << 8) + (MAX_SEARCH_TOKENS - mscore.distance); - int64_t primary_rank_score = primary_rank_scores->count(seq_id) > 0 ? primary_rank_scores->at(seq_id) : 0; + int64_t primary_rank_score = (primary_rank_scores && primary_rank_scores->count(seq_id) > 0) ? + primary_rank_scores->at(seq_id) : 0; int64_t secondary_rank_score = (secondary_rank_scores && secondary_rank_scores->count(seq_id) > 0) ? secondary_rank_scores->at(seq_id) : 0; topster.add(seq_id, match_score,