Expose search, sort and facet fields to the API.

This commit is contained in:
Kishore Nallan 2017-05-16 20:55:06 +05:30
parent a25d2f590d
commit 56c539d0df
4 changed files with 54 additions and 9 deletions

View File

@ -34,8 +34,11 @@
- ~~Found count is wrong~~
- ~~Filter query in the API~~
- Fix API response codes
- Test for search without any sort_by given
- Test for asc/desc upper/lower casing
- Test for collection creation validation
- Proper pagination
- Deprecate old split function
- Prevent string copy during indexing
- Clean special characters before indexing
- Minimum results should be a variable instead of blindly going with max_results

View File

@ -122,4 +122,8 @@ struct StringUtils {
strtol(s.c_str(), &p, 10);
return (*p == 0);
}
// Upper-cases `str` in place, byte by byte, using the C library's ::toupper
// (so the mapping follows the current C locale; multi-byte encodings are not
// interpreted, each byte is converted independently).
static void toupper(std::string& str) {
    // ::toupper requires its int argument to be representable as unsigned char
    // (or EOF). Passing a plain `char` directly is undefined behaviour for any
    // byte >= 0x80 on platforms where `char` is signed, so every byte is
    // widened through `unsigned char` before the call.
    std::transform(str.begin(), str.end(), str.begin(),
                   [](unsigned char c) { return static_cast<char>(::toupper(c)); });
}
};

View File

@ -1,5 +1,6 @@
#include <regex>
#include "api.h"
#include "string_utils.h"
#include "collection.h"
#include "collection_manager.h"
@ -105,8 +106,11 @@ void post_create_collection(http_req & req, http_res & res) {
void get_search(http_req & req, http_res & res) {
const char *NUM_TYPOS = "num_typos";
const char *PREFIX = "prefix";
const char *FILTER = "filter";
const char *SEARCH_BY = "search_by";
const char *SORT_BY = "sort_by";
const char *FACET_BY = "facet_by";
const char *TOKEN_ORDERING = "token_ordering";
const char *FILTERS = "filters";
if(req.params.count(NUM_TYPOS) == 0) {
req.params[NUM_TYPOS] = "2";
@ -120,15 +124,43 @@ void get_search(http_req & req, http_res & res) {
req.params[TOKEN_ORDERING] = "FREQUENCY";
}
std::string filter_str = req.params.count(FILTERS) != 0 ? req.params[FILTERS] : "";
//std::cout << "filter_str: " << filter_str << std::endl;
if(req.params.count(SEARCH_BY) == 0) {
return res.send_400(std::string("Parameter `") + SEARCH_BY + "` is required.");
}
std::string filter_str = req.params.count(FILTER) != 0 ? req.params[FILTER] : "";
token_ordering token_order = (req.params[TOKEN_ORDERING] == "MAX_SCORE") ? MAX_SCORE : FREQUENCY;
//printf("Query: %s\n", req.params["q"].c_str());
auto begin = std::chrono::high_resolution_clock::now();
std::vector<std::string> search_fields;
StringUtils::split(req.params[SEARCH_BY], search_fields, ",");
std::vector<std::string> search_fields = {"title"};
std::vector<std::string> facet_fields;
StringUtils::split(req.params[FACET_BY], facet_fields, ",");
std::vector<sort_field> sort_fields;
if(req.params.count(SORT_BY) != 0) {
std::vector<std::string> sort_field_strs;
StringUtils::split(req.params[SORT_BY], sort_field_strs, ",");
if(sort_field_strs.size() > 2) {
return res.send_400("Only upto 2 sort fields are allowed.");
}
for(const std::string & sort_field_str: sort_field_strs) {
std::vector<std::string> expression_parts;
StringUtils::split(sort_field_str, expression_parts, ":");
if(expression_parts.size() != 2) {
return res.send_400(std::string("Parameter `") + SORT_BY + "` is malformed.");
}
StringUtils::toupper(expression_parts[1]);
sort_fields.push_back(sort_field(expression_parts[0], expression_parts[1]));
}
}
auto begin = std::chrono::high_resolution_clock::now();
CollectionManager & collectionManager = CollectionManager::get_instance();
Collection* collection = collectionManager.get_collection(req.params["collection"]);
@ -137,8 +169,8 @@ void get_search(http_req & req, http_res & res) {
return res.send_404();
}
nlohmann::json result = collection->search(req.params["q"], search_fields, filter_str, { },
{sort_field("points", "DESC")}, std::stoi(req.params[NUM_TYPOS]), 100,
nlohmann::json result = collection->search(req.params["q"], search_fields, filter_str, facet_fields,
sort_fields, std::stoi(req.params[NUM_TYPOS]), 100,
token_order, false);
const std::string & json_str = result.dump();
//std::cout << "JSON:" << json_str << std::endl;

View File

@ -605,6 +605,11 @@ nlohmann::json Collection::search(std::string query, const std::vector<std::stri
result["error"] = "Could not find a sort field named `" + _sort_field.name + "` in the schema.";
return result;
}
if(_sort_field.order != sort_field_const::asc && _sort_field.order != sort_field_const::desc) {
result["error"] = "Order for sort field` " + _sort_field.name + "` should be either ASC or DESC.";
return result;
}
}
// process the filters first
@ -927,7 +932,8 @@ void Collection::score_results(const std::vector<sort_field> & sort_fields, cons
((uint64_t)(mscore.words_present) << 8) +
(MAX_SEARCH_TOKENS - mscore.distance);
int64_t primary_rank_score = primary_rank_scores->count(seq_id) > 0 ? primary_rank_scores->at(seq_id) : 0;
int64_t primary_rank_score = (primary_rank_scores && primary_rank_scores->count(seq_id) > 0) ?
primary_rank_scores->at(seq_id) : 0;
int64_t secondary_rank_score = (secondary_rank_scores && secondary_rank_scores->count(seq_id) > 0) ?
secondary_rank_scores->at(seq_id) : 0;
topster.add(seq_id, match_score,