diff --git a/TODO.md b/TODO.md
index 84266554..72ea0835 100644
--- a/TODO.md
+++ b/TODO.md
@@ -15,6 +15,7 @@
 - ~~Speed up UUID generation~~
 - Prefix-search strings should not be null terminated
 - Make the search score computation customizable
+- string_utils::tokenize should not have max length
 
 **API**
 
diff --git a/include/collection.h b/include/collection.h
index 19819eaf..1acf0123 100644
--- a/include/collection.h
+++ b/include/collection.h
@@ -26,8 +26,12 @@ private:
     std::string get_seq_id_key(uint32_t seq_id);
     std::string get_id_key(std::string id);
 
-    static inline std::vector<art_leaf *> _next_suggestion(const std::vector<std::vector<art_leaf *>> &token_leaves,
-                                                            long long int n);
+    static inline std::vector<art_leaf *> next_suggestion(const std::vector<std::vector<art_leaf *>> &token_leaves,
+                                                          long long int n);
+    void log_leaves(const int max_cost, const std::string &token, const std::vector<art_leaf *> &leaves) const;
+
+    void search_candidates(std::vector<std::vector<art_leaf*>> & token_leaves, Topster<100> & topster,
+                           size_t & total_results, const size_t & max_results);
 
 public:
     Collection() = delete;
@@ -39,5 +43,8 @@ public:
 
     void score_results(Topster<100> &topster, const std::vector<art_leaf *> &query_suggestion,
                        const uint32_t *result_ids, size_t result_size) const;
+
+    enum {MAX_SEARCH_TOKENS = 20};
+    enum {MAX_RESULTS = 100};
 
 };
diff --git a/include/match_score.h b/include/match_score.h
index 5c02a384..ccbb7eea 100644
--- a/include/match_score.h
+++ b/include/match_score.h
@@ -60,9 +60,8 @@ struct MatchScore {
      * compute the max_match and min_displacement of target tokens across the windows.
      */
     static MatchScore match_score(uint32_t doc_id, std::vector<std::vector<uint16_t>> &token_offsets) {
-        const size_t WINDOW_SIZE = 20;
-        const size_t MAX_TOKENS_IN_A_QUERY = 20;
-        const uint16_t MAX_DISPLACEMENT = 20;
+        const size_t WINDOW_SIZE = Collection::MAX_SEARCH_TOKENS;
+        const uint16_t MAX_DISPLACEMENT = Collection::MAX_SEARCH_TOKENS;
 
         std::priority_queue<TokenOffset, std::vector<TokenOffset>, TokenOffset> heap;
 
@@ -76,8 +75,8 @@ struct MatchScore {
         uint16_t min_displacement = MAX_DISPLACEMENT;
 
         std::queue<TokenOffset> window;
-        uint16_t token_offset[MAX_TOKENS_IN_A_QUERY] = { };
-        std::fill_n(token_offset, MAX_TOKENS_IN_A_QUERY, MAX_DISPLACEMENT);
+        uint16_t token_offset[Collection::MAX_SEARCH_TOKENS] = { };
+        std::fill_n(token_offset, Collection::MAX_SEARCH_TOKENS, MAX_DISPLACEMENT);
 
         do {
             if(window.empty()) {
diff --git a/src/collection.cpp b/src/collection.cpp
index a320752e..a56f9991 100644
--- a/src/collection.cpp
+++ b/src/collection.cpp
@@ -80,12 +80,49 @@ std::string Collection::add(std::string json_str) {
     return document["id"];
 }
 
+void Collection::search_candidates(std::vector<std::vector<art_leaf*>> & token_leaves, Topster<100> & topster,
+                                   size_t & total_results, const size_t & max_results) {
+    const size_t combination_limit = 10;
+    auto product = []( long long a, std::vector<art_leaf*>& b ) { return a*b.size(); };
+    long long int N = std::accumulate(token_leaves.begin(), token_leaves.end(), 1LL, product);
+
+    for(long long n=0; n<N && n<combination_limit; ++n) {
+        std::vector<art_leaf*> query_suggestion = next_suggestion(token_leaves, n);
+
+        // initialize results with the starting element (for further intersection)
+        uint32_t* result_ids = query_suggestion[0]->values->ids.uncompress();
+        size_t result_size = query_suggestion[0]->values->ids.getLength();
+
+        if(result_size == 0) continue;
+
+        // intersect the document ids for each token to find docs that contain all the tokens (stored in `result_ids`)
+        for(auto i=1; i < query_suggestion.size(); i++) {
+            uint32_t* out = new uint32_t[result_size];
+            uint32_t* curr = query_suggestion[i]->values->ids.uncompress();
+            result_size = Intersection::scalar(result_ids, result_size, curr,
+                                               query_suggestion[i]->values->ids.getLength(), out);
+            delete[] result_ids;
+            delete[] curr;
+            result_ids = out;
+        }
+
+        // go through each matching document id and calculate match score
+        score_results(topster, query_suggestion, result_ids, result_size);
+
+        total_results += result_size;
+        delete[] result_ids;
+
+        if(total_results >= max_results) break;
+    }
+
+}
+
 /*
    1. Split the query into tokens
-   2. For each token, look up ids using exact lookup
-       a. If a token has no result, try again with edit distance of 1, and then 2
-   3. Do a limited cartesian product of the word suggestions for each token to form possible corrected search phrases
-      (adapted from: http://stackoverflow.com/a/31169617/131050)
+   2. Outer loop will generate bounded cartesian product with costs for each token
+   3. Inner loop will iterate on each token with associated cost
+   4. Cartesian product of the results of the token searches will be used to form search phrases
+      (cartesian product adapted from: http://stackoverflow.com/a/31169617/131050)
    4. Intersect the lists to find docs that match each phrase
    5. Sort the docs based on some ranking criteria
 */
@@ -94,109 +131,127 @@ std::vector<nlohmann::json> Collection::search(std::string query, const int num_
     StringUtils::tokenize(query, tokens, " ", true);
 
     const int max_cost = (num_typos < 0 || num_typos > 2) ? 2 : num_typos;
-    const size_t max_results = std::min(num_results, (size_t) 100);
+    const size_t max_results = std::min(num_results, (size_t) Collection::MAX_RESULTS);
 
-    int cost = 0;
     size_t total_results = 0;
     std::vector<nlohmann::json> results;
     Topster<100> topster;
 
     auto begin = std::chrono::high_resolution_clock::now();
 
-    while(cost <= max_cost) {
-        std::cout << "Searching with cost=" << cost << std::endl;
+    std::vector<std::vector<int>> token_to_costs;
+    std::vector<int> all_costs;
+    for(int cost = 0; cost <= max_cost; cost++) {
+        all_costs.push_back(cost);
+    }
 
-        std::vector<std::vector<art_leaf*>> token_leaves;
-        for(std::string token: tokens) {
-            std::transform(token.begin(), token.end(), token.begin(), ::tolower);
+    for(size_t token_index = 0; token_index < tokens.size(); token_index++) {
+        token_to_costs.push_back(all_costs);
+        std::transform(tokens[token_index].begin(), tokens[token_index].end(), tokens[token_index].begin(), ::tolower);
+    }
+
+    std::vector<std::vector<art_leaf*>> token_leaves;
+    const size_t combination_limit = 10;
+    auto product = []( long long a, std::vector<int>& b ) { return a*b.size(); };
+    long long n = 0;
+    long long int N = std::accumulate(token_to_costs.begin(), token_to_costs.end(), 1LL, product);
+
+    while(n < N && n < combination_limit) {
+        // Outerloop generates combinations of [cost to max_cost] for each token
+        // For e.g. for a 3-token query: [0, 0, 0], [0, 0, 1], [0, 1, 1] etc.
+        std::vector<int> costs(token_to_costs.size());
+        ldiv_t q { n, 0 };
+        for(long long i = (token_to_costs.size() - 1); 0 <= i ; --i ) {
+            q = ldiv(q.quot, token_to_costs[i].size());
+            costs[i] = token_to_costs[i][q.rem];
+        }
+
+        token_leaves.clear();
+        size_t token_index = 0;
+        bool retry_with_larger_cost = false;
+
+        while(token_index < tokens.size()) {
+            // For each token, look up the generated cost for this iteration and search using that cost
+            std::string token = tokens[token_index];
 
             std::vector<art_leaf*> leaves;
-            art_fuzzy_search(&t, (const unsigned char *) token.c_str(), (int) token.length() + 1, cost, 3, leaves);
+            art_fuzzy_search(&t, (const unsigned char *) token.c_str(), (int) token.length() + 1, costs[token_index], 3, leaves);
+
             if(!leaves.empty()) {
-                for(auto i=0; i<leaves.size(); i++) {
-                    printf("%s - ", token.c_str());
-                    printf("%.*s", leaves[i]->key_len, leaves[i]->key);
-                    printf(" - max_cost: %d, - num_ids: %d\n", max_cost, leaves[i]->values->ids.getLength());
-                    /*for(auto j=0; j<leaves[i]->values->ids.getLength(); j++) {
-                        printf("id: %d\n", leaves[i]->values->ids.at(j));
-                    }*/
-                }
+                //log_leaves(max_cost, token, leaves);
                 token_leaves.push_back(leaves);
+            } else {
+                // no result when `cost = costs[token_index]` => remove cost for token and re-do combinations
+                auto it = std::find(token_to_costs[token_index].begin(), token_to_costs[token_index].end(), costs[token_index]);
+                if(it != token_to_costs[token_index].end()) {
+                    token_to_costs[token_index].erase(it);
+
+                    // no more costs left for this token, clean up
+                    if(token_to_costs[token_index].empty()) {
+                        token_to_costs.erase(token_to_costs.begin()+token_index);
+                        tokens.erase(tokens.begin()+token_index);
+                        token_index--;
+                    }
+                }
+
+                n = -1;
+                N = std::accumulate(token_to_costs.begin(), token_to_costs.end(), 1LL, product);
+
+                if(costs[token_index] != max_cost) {
+                    // Unless we're already at max_cost for this token, don't look at remaining tokens since we would
+                    // see them again in a future iteration when we retry with a larger cost
+                    retry_with_larger_cost = true;
+                    break;
+                }
+            }
+
+            token_index++;
+        }
+
+        if(token_leaves.size() != 0 && !retry_with_larger_cost) {
+            // If a) all tokens were found, or b) Some were skipped because they don't exist within max_cost,
+            // go ahead and search for candidates with what we have so far
+            search_candidates(token_leaves, topster, total_results, max_results);
+            topster.sort();
+
+            for (uint32_t i = 0; i < topster.size; i++) {
+                uint64_t seq_id = topster.getKeyAt(i);
+                std::string value;
+                store->get(get_seq_id_key((uint32_t) seq_id), value);
+                nlohmann::json document = nlohmann::json::parse(value);
+                results.push_back(document);
+            }
+
+            if (total_results > 0) {
+                // Unless there are results, we continue outerloop (looking at tokens with greater cost)
+                break;
             }
         }
 
-        if(token_leaves.size() != tokens.size() && cost != max_cost) {
-            // There could have been a typo in one of the tokens, so let's try again with greater cost
-            // Or this could be a token that does not exist at all (rare)
-            //std::cout << "token_leaves.size() != tokens.size(), continuing..." << std::endl << std::endl;
-            cost++;
-            continue;
-        }
-
-        const size_t combination_limit = 10;
-        auto product = []( long long a, std::vector<art_leaf*>& b ) { return a*b.size(); };
-        long long int N = std::accumulate(token_leaves.begin(), token_leaves.end(), 1LL, product );
-
-        for(long long n=0; n<N && n<combination_limit; ++n) {
-            std::vector<art_leaf*> query_suggestion = _next_suggestion(token_leaves, n);
-
-            // initialize results with the starting element (for further intersection)
-            uint32_t* result_ids = query_suggestion[0]->values->ids.uncompress();
-            size_t result_size = query_suggestion[0]->values->ids.getLength();
-
-            if(result_size == 0) continue;
-
-            // intersect the document ids for each token to find docs that contain all the tokens (stored in `result_ids`)
-            for(auto i=1; i < query_suggestion.size(); i++) {
-                uint32_t* out = new uint32_t[result_size];
-                uint32_t* curr = query_suggestion[i]->values->ids.uncompress();
-                result_size = Intersection::scalar(result_ids, result_size, curr, query_suggestion[i]->values->ids.getLength(), out);
-                delete[] result_ids;
-                delete[] curr;
-                result_ids = out;
-            }
-
-            // go through each matching document id and calculate match score
-            score_results(topster, query_suggestion, result_ids, result_size);
-
-            total_results += result_size;
-            delete[] result_ids;
-
-            if(total_results >= max_results) break;
-        }
-
-        topster.sort();
-
-        for(uint32_t i=0; i<topster.size; i++) {
-            uint64_t seq_id = topster.getKeyAt(i);
-            std::string value;
-            store->get(get_seq_id_key((uint32_t) seq_id), value);
-            nlohmann::json document = nlohmann::json::parse(value);
-            results.push_back(document);
-        }
-
-        if(total_results > 0) {
-            break;
-        }
-
-        cost++;
+        n++;
     }
 
     if(results.size() == 0) {
-        // We could drop certain tokens and try
+        // FIXME: We could drop certain tokens and try searching again
    }
 
     long long int timeMillis = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - begin).count();
     std::cout << "Time taken for result calc: " << timeMillis << "us" << std::endl;
-    store->print_memory_usage();
-
     return results;
 }
 
+void Collection::log_leaves(const int max_cost, const std::string &token, const std::vector<art_leaf *> &leaves) const {
+    for(auto i=0; i < leaves.size(); i++) {
+        printf("%s - ", token.c_str());
+        printf("%.*s", leaves[i]->key_len, leaves[i]->key);
+        printf(" - max_cost: %d, - num_ids: %d\n", max_cost, leaves[i]->values->ids.getLength());
+        for(auto j=0; j<leaves[i]->values->ids.getLength(); j++) {
+            printf("id: %d\n", leaves[i]->values->ids.at(j));
+        }
+    }
+}
+
 void Collection::score_results(Topster<100> &topster, const std::vector<art_leaf *> &query_suggestion,
                                const uint32_t *result_ids, size_t result_size) const {
     for(auto i=0; i<result_size; i++) {
@@ … @@ void Collection::score_results(Topster<100> &topster, const std::vector<art_lea
     }
 }
 
-inline std::vector<art_leaf *> Collection::_next_suggestion(
+inline std::vector<art_leaf *> Collection::next_suggestion(
         const std::vector<std::vector<art_leaf *>> &token_leaves,
         long long int n) {
     std::vector<art_leaf *> query_suggestion(token_leaves.size());
diff --git a/test/collection_test.cpp b/test/collection_test.cpp
index 1b1ea824..3fd48c51 100644
--- a/test/collection_test.cpp
+++ b/test/collection_test.cpp
@@ -65,7 +65,7 @@ TEST_F(CollectionTest, ExactPhraseSearch) {
 
 TEST_F(CollectionTest, SkipUnindexedTokensDuringPhraseSearch) {
     // Tokens that are not found in the index should be skipped
-    std::vector<nlohmann::json> results = collection->search("from DoesNotExist", 0, 10);
+    std::vector<nlohmann::json> results = collection->search("DoesNotExist from", 0, 10);
     ASSERT_EQ(2, results.size());
 
     std::vector<std::string> ids = {"2", "17"};
@@ -76,4 +76,43 @@ TEST_F(CollectionTest, SkipUnindexedTokensDuringPhraseSearch) {
         std::string result_id = result["id"];
         ASSERT_STREQ(id.c_str(), result_id.c_str());
     }
+
+    // with non-zero cost
+    results = collection->search("DoesNotExist from", 2, 10);
+    ASSERT_EQ(2, results.size());
+
+    for(size_t i = 0; i < results.size(); i++) {
+        nlohmann::json result = results.at(i);
+        std::string id = ids.at(i);
+        std::string result_id = result["id"];
+        ASSERT_STREQ(id.c_str(), result_id.c_str());
+    }
+
+    // with 2 indexed words
+    results = collection->search("from DoesNotExist insTruments", 2, 10);
+    ASSERT_EQ(1, results.size());
+    nlohmann::json result = results.at(0);
+    std::string result_id = result["id"];
+    ASSERT_STREQ("2", result_id.c_str());
+
+    results.clear();
+    results = collection->search("DoesNotExist1 DoesNotExist2", 0, 10);
+    ASSERT_EQ(0, results.size());
+
+    results.clear();
+    results = collection->search("DoesNotExist1 DoesNotExist2", 2, 10);
+    ASSERT_EQ(0, results.size());
+}
+
+TEST_F(CollectionTest, PartialPhraseSearch) {
+    std::vector<nlohmann::json> results = collection->search("rocket research", 0, 10);
+    //ASSERT_EQ(1, results.size());
+}
+
+TEST_F(CollectionTest, RegressionTest1) {
+    std::vector<nlohmann::json> results = collection->search("kind biologcal", 2, 10);
+    ASSERT_EQ(1, results.size());
+
+    std::string result_id = results.at(0)["id"];
+    ASSERT_STREQ("19", result_id.c_str());
 }
\ No newline at end of file
diff --git a/test/documents.jsonl b/test/documents.jsonl
index 48bda24d..ca7c6efa 100644
--- a/test/documents.jsonl
+++ b/test/documents.jsonl
@@ -16,4 +16,6 @@
 {"points":10,"title":"How late do the launch propellants ionize in a chemical rocket mission?"}
 {"points":8,"title":"How much does it cost to launch (right from start) a rocket today?"}
 {"points":16,"title":"Difference between Space Dynamics & Astrodynamics in engineering perspective?"}
-{"points":18,"title":"What kind of biological research does ISS do?"}
\ No newline at end of file
+{"points":18,"title":"What kind of biological research does ISS do?"}
+{"points":10,"title":"What kinds of radiation hit ISS?"}
+{"points":7,"title":"What kinds of things have been tossed out of ISS?"}
\ No newline at end of file