diff --git a/src/collection_manager.cpp b/src/collection_manager.cpp index acbfc4d2..82c26a81 100644 --- a/src/collection_manager.cpp +++ b/src/collection_manager.cpp @@ -623,31 +623,70 @@ bool CollectionManager::parse_sort_by_str(std::string sort_by_str, std::vector 0) { - if (sort_by_str[i] == '(') { - paren_count++; - } else if (sort_by_str[i] == ')') { - paren_count--; + if (sort_field_expr.empty()) { + if (sort_by_str[i] == '$') { + // Sort by reference field + auto open_paren_pos = sort_by_str.find('(', i); + if (open_paren_pos == std::string::npos) { + return false; } - sort_field_expr += sort_by_str[i]; - } - if (paren_count != 0) { - return false; - } + sort_field_expr = sort_by_str.substr(i, open_paren_pos - i + 1); - sort_fields.emplace_back(sort_field_expr, ""); - sort_field_expr = ""; - continue; + i = open_paren_pos; + int paren_count = 1; + while (++i < sort_by_str.size() && paren_count > 0) { + if (sort_by_str[i] == '(') { + paren_count++; + } else if (sort_by_str[i] == ')') { + paren_count--; + } + sort_field_expr += sort_by_str[i]; + } + if (paren_count != 0) { + return false; + } + + sort_fields.emplace_back(sort_field_expr, ""); + sort_field_expr = ""; + continue; + } else if (sort_by_str.substr(i, 5) == sort_field_const::eval) { + // Optional filtering + auto open_paren_pos = sort_by_str.find('(', i); + if (open_paren_pos == std::string::npos) { + return false; + } + sort_field_expr = sort_field_const::eval + "("; + + i = open_paren_pos; + int paren_count = 1; + while (++i < sort_by_str.size() && paren_count > 0) { + if (sort_by_str[i] == '(') { + paren_count++; + } else if (sort_by_str[i] == ')') { + paren_count--; + } + sort_field_expr += sort_by_str[i]; + } + if (paren_count != 0 || i >= sort_by_str.size()) { + return false; + } + + while (sort_by_str[i] != ':' && ++i < sort_by_str.size()); + if (i >= sort_by_str.size()) { + return false; + } + + std::string order_str; + while (++i < sort_by_str.size() && sort_by_str[i] != ',') { + order_str += sort_by_str[i]; + } + StringUtils::trim(order_str); + StringUtils::toupper(order_str); + + sort_fields.emplace_back(sort_field_expr, order_str); + sort_field_expr = ""; + continue; + } } if(i == sort_by_str.size()-1 || (sort_by_str[i] == ',' && !isdigit(prev_non_space_char))) { diff --git a/src/core_api.cpp b/src/core_api.cpp index 88e8d036..e82b5963 100644 --- a/src/core_api.cpp +++ b/src/core_api.cpp @@ -1902,9 +1902,9 @@ bool get_stopwords(const std::shared_ptr& req, const std::shared_ptr& req, const std::shared_ptr& req, const std::shared_ptr& res) { - const std::string & stopword_name = req->params["stopword_name"]; + const std::string & stopword_name = req->params["name"]; StopwordsManager& stopwordManager = StopwordsManager::get_instance(); spp::sparse_hash_set stopwords; @@ -1980,9 +1980,9 @@ bool del_stopword(const std::shared_ptr& req, const std::shared_ptrset_200(res_json.dump()); diff --git a/src/stopwords_manager.cpp b/src/stopwords_manager.cpp index f24cb8ab..af72c39f 100644 --- a/src/stopwords_manager.cpp +++ b/src/stopwords_manager.cpp @@ -28,6 +28,7 @@ Option StopwordsManager::upsert_stopword(const std::string& stopword_name, const char* STOPWORD_VALUES = "stopwords"; const char* STOPWORD_LOCALE = "locale"; + std::string locale = ""; if(stopwords_json.count(STOPWORD_VALUES) == 0){ return Option(400, (std::string("Parameter `") + STOPWORD_VALUES + "` is required")); @@ -37,12 +38,11 @@ Option StopwordsManager::upsert_stopword(const std::string& stopword_name, return Option(400, (std::string("Parameter `") + STOPWORD_VALUES + "` is required as string array value")); } - if(stopwords_json.count(STOPWORD_LOCALE) == 0) { - return Option(400, (std::string("Parameter `") + STOPWORD_LOCALE + "` is required")); - } - - if(!stopwords_json[STOPWORD_LOCALE].is_string()) { - return Option(400, (std::string("Parameter `") + STOPWORD_LOCALE + "` is required as string value")); + if(stopwords_json.count(STOPWORD_LOCALE) != 0) { + if (!stopwords_json[STOPWORD_LOCALE].is_string()) { + return Option(400, (std::string("Parameter `") + STOPWORD_LOCALE + "` is required as string value")); + } + locale = stopwords_json[STOPWORD_LOCALE]; } if(write_to_store) { @@ -55,7 +55,6 @@ Option StopwordsManager::upsert_stopword(const std::string& stopword_name, std::vector tokens; spp::sparse_hash_set stopwords_set; const auto& stopwords = stopwords_json[STOPWORD_VALUES]; - const auto& locale = stopwords_json[STOPWORD_LOCALE]; for (const auto &stopword: stopwords.items()) { const auto& val = stopword.value().get(); diff --git a/test/collection_sorting_test.cpp b/test/collection_sorting_test.cpp index e7b354b9..4ae097d0 100644 --- a/test/collection_sorting_test.cpp +++ b/test/collection_sorting_test.cpp @@ -2135,6 +2135,22 @@ TEST_F(CollectionSortingTest, OptionalFilteringViaSortingSearch) { results = coll1->search("title", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10).get(); ASSERT_EQ(5, results["hits"].size()); + std::map req_params = { + {"collection", "coll1"}, + {"q", "title"}, + {"query_by", "title"}, + {"sort_by", "_eval(brand:[nike, adidas] && points:0):desc, points:DESC"} + }; + nlohmann::json embedded_params; + std::string json_res; + auto now_ts = std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()).count(); + + auto search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts); + ASSERT_TRUE(search_op.ok()); + results = nlohmann::json::parse(json_res); + ASSERT_EQ(5, results["hits"].size()); + expected_ids = {"0", "4", "3", "2", "1"}; for(size_t i = 0; i < expected_ids.size(); i++) { ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get()); diff --git a/test/stopwords_manager_test.cpp b/test/stopwords_manager_test.cpp index 8653af19..bb06eb10 100644 --- a/test/stopwords_manager_test.cpp +++ b/test/stopwords_manager_test.cpp @@ -328,20 +328,8 @@ TEST_F(StopwordsManagerTest, StopwordsValidation) { std::shared_ptr req = std::make_shared(); std::shared_ptr res = std::make_shared(nullptr); - auto stopword_value = R"( - {"stopwords": ["america", "europe"]} - )"_json; - - req->params["collection"] = "coll1"; - req->params["name"] = "continents"; - req->body = stopword_value.dump(); - - auto result = put_upsert_stopword(req, res); - ASSERT_EQ(400, res->status_code); - ASSERT_EQ("{\"message\": \"Parameter `locale` is required\"}", res->body); - //with a typo - stopword_value = R"( + auto stopword_value = R"( {"stopword": ["america", "europe"], "locale": "en"} )"_json; @@ -349,7 +337,7 @@ TEST_F(StopwordsManagerTest, StopwordsValidation) { req->params["name"] = "continents"; req->body = stopword_value.dump(); - result = put_upsert_stopword(req, res); + auto result = put_upsert_stopword(req, res); ASSERT_EQ(400, res->status_code); ASSERT_STREQ("{\"message\": \"Parameter `stopwords` is required\"}", res->body.c_str());