Merge remote-tracking branch 'up/v0.26-facets' into v0.26-facets

This commit is contained in:
Harpreet Sangar 2023-08-23 16:08:20 +05:30
commit 7264977490
14 changed files with 263 additions and 35 deletions

1
BUILD
View File

@ -137,6 +137,7 @@ filegroup(
"test/**/*.txt",
"test/**/*.ini",
"test/**/*.jsonl",
"test/**/*.gz",
]),
)

View File

@ -167,7 +167,7 @@ private:
void remove_document(const nlohmann::json & document, const uint32_t seq_id, bool remove_from_store);
void process_remove_field_for_embedding_fields(const field& the_field, std::vector<field>& garbage_fields);
void process_remove_field_for_embedding_fields(const field& del_field, std::vector<field>& garbage_embed_fields);
void curate_results(string& actual_query, const string& filter_query, bool enable_overrides, bool already_segmented,
const std::map<size_t, std::vector<std::string>>& pinned_hits,

View File

@ -142,8 +142,6 @@ private:
butil::EndPoint peering_endpoint;
Option<bool> handle_gzip(const std::shared_ptr<http_req>& request);
public:
static constexpr const char* log_dir_name = "log";
@ -241,6 +239,8 @@ public:
std::string get_leader_url() const;
static Option<bool> handle_gzip(const std::shared_ptr<http_req>& request);
private:
friend class ReplicationClosure;

View File

@ -72,6 +72,10 @@ public:
Option<bool> validate_and_init_local_model(const nlohmann::json& model_config, size_t& num_dims);
Option<bool> validate_and_init_model(const nlohmann::json& model_config, size_t& num_dims);
std::unordered_map<std::string, std::shared_ptr<TextEmbedder>> _get_text_embedders() {
return text_embedders;
}
private:
TextEmbedderManager() = default;

View File

@ -31,6 +31,7 @@ class RemoteEmbedder {
virtual nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) = 0;
virtual embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_tries = 2) = 0;
virtual std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size = 200) = 0;
static const std::string get_model_key(const nlohmann::json& model_config);
static void init(ReplicationState* rs) {
raft_server = rs;
}
@ -51,6 +52,7 @@ class OpenAIEmbedder : public RemoteEmbedder {
embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_tries = 2) override;
std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size = 200) override;
nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override;
static std::string get_model_key(const nlohmann::json& model_config);
};
@ -68,6 +70,7 @@ class GoogleEmbedder : public RemoteEmbedder {
embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_tries = 2) override;
std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size = 200) override;
nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override;
static std::string get_model_key(const nlohmann::json& model_config);
};
@ -95,6 +98,7 @@ class GCPEmbedder : public RemoteEmbedder {
embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_tries = 2) override;
std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size = 200) override;
nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override;
static std::string get_model_key(const nlohmann::json& model_config);
};

View File

@ -4254,7 +4254,6 @@ Option<bool> Collection::alter(nlohmann::json& alter_payload) {
}
}
// hide credentials in the alter payload return
for(auto& field_json : alter_payload["fields"]) {
if(field_json[fields::embed].count(fields::model_config) != 0) {
@ -4267,8 +4266,6 @@ Option<bool> Collection::alter(nlohmann::json& alter_payload) {
}
}
return Option<bool>(true);
}
@ -5346,27 +5343,43 @@ Option<bool> Collection::populate_include_exclude_fields_lk(const spp::sparse_ha
}
// Removes the dropped field from embed_from of all embedding fields.
void Collection::process_remove_field_for_embedding_fields(const field& the_field, std::vector<field>& garbage_fields) {
void Collection::process_remove_field_for_embedding_fields(const field& del_field,
std::vector<field>& garbage_embed_fields) {
for(auto& field : fields) {
if(field.embed.count(fields::from) == 0) {
continue;
}
auto embed_from = field.embed[fields::from].get<std::vector<std::string>>();
embed_from.erase(std::remove_if(embed_from.begin(), embed_from.end(), [&the_field](std::string field_name) {
return the_field.name == field_name;
}));
field.embed[fields::from] = std::move(embed_from);
embedding_fields[field.name] = field;
// mark this embedding field as "garbage" if it has no more embed_from fields
if(embed_from.empty()) {
embedding_fields.erase(field.name);
garbage_fields.push_back(field);
bool found_field = false;
nlohmann::json& embed_from_names = field.embed[fields::from];
for(auto it = embed_from_names.begin(); it != embed_from_names.end();) {
if(it.value() == del_field.name) {
it = embed_from_names.erase(it);
found_field = true;
} else {
it++;
}
}
if(found_field) {
// mark this embedding field as "garbage" if it has no more embed_from fields
if(embed_from_names.empty()) {
garbage_embed_fields.push_back(field);
} else {
// the dropped field was present in `embed_from`, so we have to update the field objects
field.embed[fields::from] = embed_from_names;
embedding_fields[field.name].embed[fields::from] = embed_from_names;
}
}
}
for(auto& garbage_field: garbage_embed_fields) {
embedding_fields.erase(garbage_field.name);
search_schema.erase(garbage_field.name);
fields.erase(std::remove_if(fields.begin(), fields.end(), [&garbage_field](const auto &f) {
return f.name == garbage_field.name;
}), fields.end());
}
}
void Collection::hide_credential(nlohmann::json& json, const std::string& credential_name) {

View File

@ -730,7 +730,7 @@ bool get_export_documents(const std::shared_ptr<http_req>& req, const std::share
}
}
res->content_type_header = "text/plain; charset=utf8";
res->content_type_header = "text/plain; charset=utf-8";
res->status_code = 200;
stream_response(req, res);
@ -903,7 +903,7 @@ bool post_import_documents(const std::shared_ptr<http_req>& req, const std::shar
}
}
res->content_type_header = "text/plain; charset=utf8";
res->content_type_header = "text/plain; charset=utf-8";
res->status_code = 200;
res->body = response_stream.str();

View File

@ -953,12 +953,16 @@ void Index::index_field_in_memory(const field& afield, std::vector<index_record>
try {
const std::vector<float>& float_vals = record.doc[afield.name].get<std::vector<float>>();
if(afield.vec_dist == cosine) {
std::vector<float> normalized_vals(afield.num_dim);
hnsw_index_t::normalize_vector(float_vals, normalized_vals);
vec_index->addPoint(normalized_vals.data(), (size_t)record.seq_id, true);
if(float_vals.size() != afield.num_dim) {
record.index_failure(400, "Vector size mismatch.");
} else {
vec_index->addPoint(float_vals.data(), (size_t)record.seq_id, true);
if(afield.vec_dist == cosine) {
std::vector<float> normalized_vals(afield.num_dim);
hnsw_index_t::normalize_vector(float_vals, normalized_vals);
vec_index->addPoint(normalized_vals.data(), (size_t)record.seq_id, true);
} else {
vec_index->addPoint(float_vals.data(), (size_t)record.seq_id, true);
}
}
} catch(const std::exception &e) {
record.index_failure(400, e.what());

View File

@ -43,9 +43,10 @@ Option<bool> TextEmbedderManager::validate_and_init_remote_model(const nlohmann:
}
std::unique_lock<std::mutex> lock(text_embedders_mutex);
auto text_embedder_it = text_embedders.find(model_name);
std::string model_key = is_remote_model(model_name) ? RemoteEmbedder::get_model_key(model_config) : model_name;
auto text_embedder_it = text_embedders.find(model_key);
if(text_embedder_it == text_embedders.end()) {
text_embedders.emplace(model_name, std::make_shared<TextEmbedder>(model_config, num_dims));
text_embedders.emplace(model_key, std::make_shared<TextEmbedder>(model_config, num_dims));
}
return Option<bool>(true);
@ -122,7 +123,8 @@ Option<bool> TextEmbedderManager::validate_and_init_local_model(const nlohmann::
Option<TextEmbedder*> TextEmbedderManager::get_text_embedder(const nlohmann::json& model_config) {
std::unique_lock<std::mutex> lock(text_embedders_mutex);
const std::string& model_name = model_config.at("model_name");
auto text_embedder_it = text_embedders.find(model_name);
std::string model_key = is_remote_model(model_name) ? RemoteEmbedder::get_model_key(model_config) : model_name;
auto text_embedder_it = text_embedders.find(model_key);
if(text_embedder_it == text_embedders.end()) {
return Option<TextEmbedder*>(404, "Text embedder was not found.");

View File

@ -53,6 +53,21 @@ long RemoteEmbedder::call_remote_api(const std::string& method, const std::strin
proxy_call_timeout_ms, true);
}
const std::string RemoteEmbedder::get_model_key(const nlohmann::json& model_config) {
// Dispatches to the concrete embedder's key builder based on the model
// namespace (the prefix of the model name, e.g. "openai/...").
const std::string ns = TextEmbedderManager::get_model_namespace(model_config["model_name"].get<std::string>());
if(ns == "openai") {
return OpenAIEmbedder::get_model_key(model_config);
}
if(ns == "google") {
return GoogleEmbedder::get_model_key(model_config);
}
if(ns == "gcp") {
return GCPEmbedder::get_model_key(model_config);
}
// Unknown namespace: yield an empty key.
return "";
}
OpenAIEmbedder::OpenAIEmbedder(const std::string& openai_model_path, const std::string& api_key) : api_key(api_key), openai_model_path(openai_model_path) {
}
@ -255,6 +270,9 @@ nlohmann::json OpenAIEmbedder::get_error_json(const nlohmann::json& req_body, lo
return embedding_res;
}
std::string OpenAIEmbedder::get_model_key(const nlohmann::json& model_config) {
// Cache key is "<model_name>:<api_key>" so that embedders sharing a model
// but using different credentials get distinct entries.
const auto model_name = model_config["model_name"].get<std::string>();
const auto api_key = model_config["api_key"].get<std::string>();
return model_name + ":" + api_key;
}
GoogleEmbedder::GoogleEmbedder(const std::string& google_api_key) : google_api_key(google_api_key) {
@ -372,6 +390,10 @@ nlohmann::json GoogleEmbedder::get_error_json(const nlohmann::json& req_body, lo
return embedding_res;
}
std::string GoogleEmbedder::get_model_key(const nlohmann::json& model_config) {
// Cache key is "<model_name>:<api_key>" — mirrors the OpenAI embedder's
// per-credential keying scheme.
const auto model_name = model_config["model_name"].get<std::string>();
return model_name + ":" + model_config["api_key"].get<std::string>();
}
GCPEmbedder::GCPEmbedder(const std::string& project_id, const std::string& model_name, const std::string& access_token,
const std::string& refresh_token, const std::string& client_id, const std::string& client_secret) :
@ -625,3 +647,7 @@ Option<std::string> GCPEmbedder::generate_access_token(const std::string& refres
return Option<std::string>(access_token);
}
std::string GCPEmbedder::get_model_key(const nlohmann::json& model_config) {
// GCP embedders are keyed by model name, project id and client secret so
// that distinct projects/credentials map to distinct embedder instances.
const auto model_name = model_config["model_name"].get<std::string>();
const auto project_id = model_config["project_id"].get<std::string>();
const auto client_secret = model_config["client_secret"].get<std::string>();
return model_name + ":" + project_id + ":" + client_secret;
}

View File

@ -1580,9 +1580,13 @@ TEST_F(CollectionSchemaChangeTest, DropFieldUsedForEmbedding) {
nlohmann::json schema = R"({
"name": "objects",
"fields": [
{"name": "names", "type": "string[]"},
{"name": "category", "type":"string"},
{"name": "embedding", "type":"float[]", "embed":{"from": ["names","category"], "model_config": {"model_name": "ts/e5-small"}}}
{"name": "title", "type": "string"},
{"name": "names", "type": "string[]"},
{"name": "category", "type":"string"},
{"name": "embedding", "type":"float[]", "embed":{"from": ["names","category"],
"model_config": {"model_name": "ts/e5-small"}}},
{"name": "embedding2", "type":"float[]", "embed":{"from": ["names"],
"model_config": {"model_name": "ts/e5-small"}}}
]
})"_json;
@ -1594,20 +1598,28 @@ TEST_F(CollectionSchemaChangeTest, DropFieldUsedForEmbedding) {
LOG(INFO) << "Created collection";
auto embedding_fields = coll->get_embedding_fields();
ASSERT_EQ(2, embedding_fields.size());
ASSERT_EQ(2, embedding_fields["embedding"].embed[fields::from].get<std::vector<std::string>>().size());
ASSERT_EQ(1, embedding_fields["embedding2"].embed[fields::from].get<std::vector<std::string>>().size());
auto coll_schema = coll->get_schema();
ASSERT_EQ(5, coll_schema.size());
auto the_fields = coll->get_fields();
ASSERT_EQ(5, the_fields.size());
auto schema_changes = R"({
"fields": [
{"name": "names", "drop": true}
]
})"_json;
auto embedding_fields = coll->get_embedding_fields();
ASSERT_EQ(2, embedding_fields["embedding"].embed[fields::from].get<std::vector<std::string>>().size());
auto alter_op = coll->alter(schema_changes);
ASSERT_TRUE(alter_op.ok());
embedding_fields = coll->get_embedding_fields();
ASSERT_EQ(1, embedding_fields.size());
ASSERT_EQ(1, embedding_fields["embedding"].embed[fields::from].get<std::vector<std::string>>().size());
ASSERT_EQ("category", embedding_fields["embedding"].embed[fields::from].get<std::vector<std::string>>()[0]);
@ -1623,6 +1635,16 @@ TEST_F(CollectionSchemaChangeTest, DropFieldUsedForEmbedding) {
embedding_fields = coll->get_embedding_fields();
ASSERT_EQ(0, embedding_fields.size());
ASSERT_EQ(0, coll->_get_index()->_get_vector_index().size());
// only title remains
coll_schema = coll->get_schema();
ASSERT_EQ(1, coll_schema.size());
ASSERT_EQ("title", coll_schema["title"].name);
the_fields = coll->get_fields();
ASSERT_EQ(1, the_fields.size());
ASSERT_EQ("title", the_fields[0].name);
}
TEST_F(CollectionSchemaChangeTest, EmbeddingFieldsMapTest) {

View File

@ -1342,3 +1342,98 @@ TEST_F(CollectionVectorTest, HybridSearchReturnAllInfo) {
ASSERT_EQ(1, results["hits"][0].count("text_match_info"));
ASSERT_EQ(1, results["hits"][0].count("hybrid_search_info"));
}
// Hybrid search ordering test: builds a small collection with an
// auto-embedded field, then checks that a hybrid (keyword + vector) search
// sorted by `_text_match:desc,_vector_distance:asc` returns its first four
// hits in the same order as a plain keyword search.
// NOTE(review): the DISABLED_ prefix makes gtest skip this test by default —
// presumably it needs the local ts/e5-small model available; confirm.
TEST_F(CollectionVectorTest, DISABLED_HybridSortingTest) {
auto schema_json =
R"({
"name": "TEST",
"fields": [
{"name": "name", "type": "string"},
{"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/e5-small"}}}
]
})"_json;
// Embedding models are loaded from this local directory.
TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
auto collection_create_op = collectionManager.create_collection(schema_json);
ASSERT_TRUE(collection_create_op.ok());
auto coll1 = collection_create_op.get();
// Index four documents that all match the keyword "john".
auto add_op = coll1->add(R"({
"name": "john doe"
})"_json.dump());
ASSERT_TRUE(add_op.ok());
add_op = coll1->add(R"({
"name": "john legend"
})"_json.dump());
ASSERT_TRUE(add_op.ok());
add_op = coll1->add(R"({
"name": "john krasinski"
})"_json.dump());
ASSERT_TRUE(add_op.ok());
add_op = coll1->add(R"({
"name": "john abraham"
})"_json.dump());
ASSERT_TRUE(add_op.ok());
// first do keyword search
auto results = coll1->search("john", {"name"},
"", {}, {}, {2}, 10,
1, FREQUENCY, {true},
0, spp::sparse_hash_set<std::string>()).get();
ASSERT_EQ(4, results["hits"].size());
// now do hybrid search with sort_by: _text_match:desc,_vector_distance:asc
std::vector<sort_by> sort_by_list = {{"_text_match", "desc"}, {"_vector_distance", "asc"}};
auto hybrid_results = coll1->search("john", {"name", "embedding"},
"", {}, sort_by_list, {2}, 10,
1, FREQUENCY, {true},
0, spp::sparse_hash_set<std::string>()).get();
// first 4 results should be same as keyword search
ASSERT_EQ(results["hits"][0]["document"]["name"].get<std::string>(), hybrid_results["hits"][0]["document"]["name"].get<std::string>());
ASSERT_EQ(results["hits"][1]["document"]["name"].get<std::string>(), hybrid_results["hits"][1]["document"]["name"].get<std::string>());
ASSERT_EQ(results["hits"][2]["document"]["name"].get<std::string>(), hybrid_results["hits"][2]["document"]["name"].get<std::string>());
ASSERT_EQ(results["hits"][3]["document"]["name"].get<std::string>(), hybrid_results["hits"][3]["document"]["name"].get<std::string>());
}
// Verifies that remote OpenAI embedders are cached per (model name, API key)
// pair: initializing the same model with two different keys must create two
// distinct entries keyed "openai/text-embedding-ada-002:<api_key>", and no
// entry under the bare model name.
// Skipped unless both api_key_1 and api_key_2 env vars are set, since real
// credentials are required to initialize the remote model.
TEST_F(CollectionVectorTest, TestDifferentOpenAIApiKeys) {
if (std::getenv("api_key_1") == nullptr || std::getenv("api_key_2") == nullptr) {
LOG(INFO) << "Skipping test as api_key_1 or api_key_2 is not set";
return;
}
auto api_key1 = std::string(std::getenv("api_key_1"));
auto api_key2 = std::string(std::getenv("api_key_2"));
// Precondition: neither key-qualified nor bare-name entries exist yet.
auto embedder_map = TextEmbedderManager::get_instance()._get_text_embedders();
ASSERT_EQ(embedder_map.find("openai/text-embedding-ada-002:" + api_key1), embedder_map.end());
ASSERT_EQ(embedder_map.find("openai/text-embedding-ada-002:" + api_key2), embedder_map.end());
ASSERT_EQ(embedder_map.find("openai/text-embedding-ada-002"), embedder_map.end());
nlohmann::json model_config1 = R"({
"model_name": "openai/text-embedding-ada-002"
})"_json;
nlohmann::json model_config2 = model_config1;
model_config1["api_key"] = api_key1;
model_config2["api_key"] = api_key2;
size_t num_dim;
TextEmbedderManager::get_instance().validate_and_init_remote_model(model_config1, num_dim);
TextEmbedderManager::get_instance().validate_and_init_remote_model(model_config2, num_dim);
// After init: one embedder per API key; still nothing under the bare name.
embedder_map = TextEmbedderManager::get_instance()._get_text_embedders();
ASSERT_NE(embedder_map.find("openai/text-embedding-ada-002:" + api_key1), embedder_map.end());
ASSERT_NE(embedder_map.find("openai/text-embedding-ada-002:" + api_key2), embedder_map.end());
ASSERT_EQ(embedder_map.find("openai/text-embedding-ada-002"), embedder_map.end());
}

View File

@ -4,6 +4,7 @@
#include <collection_manager.h>
#include <core_api.h>
#include "core_api_utils.h"
#include "raft_server.h"
class CoreAPIUtilsTest : public ::testing::Test {
protected:
@ -621,6 +622,7 @@ TEST_F(CoreAPIUtilsTest, PresetSingleSearch) {
auto op = collectionManager.create_collection(schema);
ASSERT_TRUE(op.ok());
Collection* coll1 = op.get();
auto preset_value = R"(
{"collection":"preset_coll", "per_page": "12"}
@ -1157,4 +1159,59 @@ TEST_F(CoreAPIUtilsTest, TestProxyTimeout) {
ASSERT_EQ(408, resp->status_code);
ASSERT_EQ("Server error on remote server. Please try again later.", nlohmann::json::parse(resp->body)["message"]);
}
// Reads a gzipped JSONL fixture into a request body, runs it through
// ReplicationState::handle_gzip (which decompresses req->body in place), and
// asserts the 14 expected documents come out line by line.
TEST_F(CoreAPIUtilsTest, SampleGzipIndexTest) {
Collection *coll_hnstories;
std::vector<field> fields = {field("title", field_types::STRING, false),
field("points", field_types::INT32, false),};
coll_hnstories = collectionManager.get_collection("coll_hnstories").get();
if(coll_hnstories == nullptr) {
coll_hnstories = collectionManager.create_collection("coll_hnstories", 4, fields, "title").get();
}
auto req = std::make_shared<http_req>();
// Slurp the whole .gz file into the request body (seek to end to size it).
std::ifstream infile(std::string(ROOT_DIR)+"test/resources/hnstories.jsonl.gz");
std::stringstream outbuffer;
infile.seekg (0, infile.end);
// NOTE(review): tellg() returns -1 if the file failed to open, and
// resize(-1) would then throw — assumes the fixture always exists; confirm.
int length = infile.tellg();
infile.seekg (0, infile.beg);
req->body.resize(length);
infile.read(&req->body[0], length);
auto res = ReplicationState::handle_gzip(req);
// NOTE(review): success is detected via an empty error() rather than
// res.ok() — assumes Option's error() is empty on success; verify.
if (!res.error().empty()) {
LOG(ERROR) << res.error();
FAIL();
} else {
outbuffer << req->body;
}
// Split the decompressed body into individual JSONL documents.
std::vector<std::string> doc_lines;
std::string line;
while(std::getline(outbuffer, line)) {
doc_lines.push_back(line);
}
ASSERT_EQ(14, doc_lines.size());
ASSERT_EQ("{\"points\":1,\"title\":\"DuckDuckGo Settings\"}", doc_lines[0]);
ASSERT_EQ("{\"points\":1,\"title\":\"Making Twitter Easier to Use\"}", doc_lines[1]);
ASSERT_EQ("{\"points\":2,\"title\":\"London refers Uber app row to High Court\"}", doc_lines[2]);
ASSERT_EQ("{\"points\":1,\"title\":\"Young Global Leaders, who should be nominated? (World Economic Forum)\"}", doc_lines[3]);
ASSERT_EQ("{\"points\":1,\"title\":\"Blooki.st goes BETA in a few hours\"}", doc_lines[4]);
ASSERT_EQ("{\"points\":1,\"title\":\"Unicode Security Data: Beta Review\"}", doc_lines[5]);
ASSERT_EQ("{\"points\":2,\"title\":\"FileMap: MapReduce on the CLI\"}", doc_lines[6]);
ASSERT_EQ("{\"points\":1,\"title\":\"[Full Video] NBC News Interview with Edward Snowden\"}", doc_lines[7]);
ASSERT_EQ("{\"points\":1,\"title\":\"Hybrid App Monetization Example with Mobile Ads and In-App Purchases\"}", doc_lines[8]);
ASSERT_EQ("{\"points\":1,\"title\":\"We need oppinion from Android Developers\"}", doc_lines[9]);
ASSERT_EQ("{\"points\":1,\"title\":\"\\\\t Why Mobile Developers Should Care About Deep Linking\"}", doc_lines[10]);
ASSERT_EQ("{\"points\":2,\"title\":\"Are we getting too Sassy? Weighing up micro-optimisation vs. maintainability\"}", doc_lines[11]);
ASSERT_EQ("{\"points\":2,\"title\":\"Google's XSS game\"}", doc_lines[12]);
ASSERT_EQ("{\"points\":1,\"title\":\"Telemba Turns Your Old Roomba and Tablet Into a Telepresence Robot\"}", doc_lines[13]);
infile.close();
}

Binary file not shown.