Review changes II

This commit is contained in:
ozanarmagan 2023-05-07 01:47:12 +03:00
parent 85ed9090b2
commit 992cbc9080
11 changed files with 167 additions and 143 deletions

View File

@ -50,6 +50,8 @@ namespace fields {
static const std::string num_dim = "num_dim";
static const std::string vec_dist = "vec_dist";
static const std::string reference = "reference";
static const std::string embed = "embed";
static const std::string from = "from";
static const std::string embed_from = "embed_from";
static const std::string model_name = "model_name";
@ -58,7 +60,7 @@ namespace fields {
static const std::string indexing_prefix = "indexing_prefix";
static const std::string query_prefix = "query_prefix";
static const std::string api_key = "api_key";
static const std::string model_parameters = "model_parameters";
static const std::string model_config = "model_config";
}
enum vector_distance_type_t {
@ -85,7 +87,7 @@ struct field {
size_t num_dim;
std::vector<std::string> embed_from;
nlohmann::json model_parameters;
nlohmann::json model_config;
vector_distance_type_t vec_dist;
static constexpr int VAL_UNKNOWN = 2;
@ -97,9 +99,9 @@ struct field {
field(const std::string &name, const std::string &type, const bool facet, const bool optional = false,
bool index = true, std::string locale = "", int sort = -1, int infix = -1, bool nested = false,
int nested_array = 0, size_t num_dim = 0, vector_distance_type_t vec_dist = cosine, std::string reference = "", const std::vector<std::string> &embed_from = {},
const nlohmann::json& model_parameters = nlohmann::json()) :
const nlohmann::json& model_config = nlohmann::json()) :
name(name), type(type), facet(facet), optional(optional), index(index), locale(locale),
nested(nested), nested_array(nested_array), num_dim(num_dim), vec_dist(vec_dist), reference(reference), embed_from(embed_from), model_parameters(model_parameters) {
nested(nested), nested_array(nested_array), num_dim(num_dim), vec_dist(vec_dist), reference(reference), embed_from(embed_from), model_config(model_config) {
set_computed_defaults(sort, infix);
}
@ -328,10 +330,9 @@ struct field {
field_val[fields::reference] = field.reference;
}
if(!field.embed_from.empty()) {
field_val[fields::embed_from] = field.embed_from;
if(!field.model_parameters.empty()) {
field_val[fields::model_parameters] = field.model_parameters;
}
field_val[fields::embed] = nlohmann::json::object();
field_val[fields::embed][fields::from] = field.embed_from;
field_val[fields::embed][fields::model_config] = field.model_config;
}
fields_json.push_back(field_val);
@ -428,33 +429,40 @@ struct field {
size_t num_auto_detect_fields = 0;
for(nlohmann::json & field_json: fields_json) {
if(field_json.count(fields::embed_from) != 0) {
if(!field_json[fields::embed_from].is_array()) {
return Option<bool>(400, "Property `" + fields::embed_from + "` must be an array.");
if(field_json.count(fields::embed) != 0) {
if(!field_json[fields::embed].is_object()) {
return Option<bool>(400, "Property `" + fields::embed + "` must be an object.");
}
if(field_json[fields::embed_from].empty()) {
return Option<bool>(400, "Property `" + fields::embed_from + "` must have at least one element.");
if(field_json[fields::embed].count(fields::from) == 0) {
return Option<bool>(400, "Property `" + fields::embed + "` must contain a `" + fields::from + "` property.");
}
for(auto& embed_from_field : field_json[fields::embed_from]) {
if(!field_json[fields::embed][fields::from].is_array()) {
return Option<bool>(400, "Property `" + fields::embed + "." + fields::from + "` must be an array.");
}
if(field_json[fields::embed][fields::from].empty()) {
return Option<bool>(400, "Property `" + fields::embed + "." + fields::from + "` must have at least one element.");
}
for(auto& embed_from_field : field_json[fields::embed][fields::from]) {
if(!embed_from_field.is_string()) {
return Option<bool>(400, "Property `" + fields::embed_from + "` must contain only field names as strings.");
return Option<bool>(400, "Property `" + fields::embed + "." + fields::from + "` must contain only field names as strings.");
}
}
if(field_json[fields::type] != field_types::FLOAT_ARRAY) {
return Option<bool>(400, "Property `" + fields::embed_from + "` is only allowed on a float array field.");
return Option<bool>(400, "Property `" + fields::embed + "." + fields::from + "` is only allowed on a float array field.");
}
for(auto& embed_from_field : field_json[fields::embed_from]) {
for(auto& embed_from_field : field_json[fields::embed][fields::from]) {
bool flag = false;
for(const auto& field : fields_json) {
if(field[fields::name] == embed_from_field) {
if(field[fields::type] != field_types::STRING && field[fields::type] != field_types::STRING_ARRAY) {
return Option<bool>(400, "Property `" + fields::embed_from + "` can only refer to string or string array fields.");
return Option<bool>(400, "Property `" + fields::embed + "." + fields::from + "` can only refer to string or string array fields.");
}
flag = true;
break;
@ -464,7 +472,7 @@ struct field {
for(const auto& field : the_fields) {
if(field.name == embed_from_field) {
if(field.type != field_types::STRING && field.type != field_types::STRING_ARRAY) {
return Option<bool>(400, "Property `" + fields::embed_from + "` can only refer to string or string array fields.");
return Option<bool>(400, "Property `" + fields::embed + "." + fields::from + "` can only refer to string or string array fields.");
}
flag = true;
break;
@ -472,9 +480,9 @@ struct field {
}
}
if(!flag) {
return Option<bool>(400, "Property `" + fields::embed_from + "` can only refer to string or string array fields.");
return Option<bool>(400, "Property `" + fields::embed + "." + fields::from + "` can only refer to string or string array fields.");
}
}
}
}
auto op = json_field_to_field(enable_nested_fields,

View File

@ -34,13 +34,13 @@ public:
TextEmbedderManager(const TextEmbedderManager&) = delete;
TextEmbedderManager& operator=(const TextEmbedderManager&) = delete;
TextEmbedder* get_text_embedder(const nlohmann::json& model_parameters);
TextEmbedder* get_text_embedder(const nlohmann::json& model_config);
void delete_text_embedder(const std::string& model_path);
void delete_all_text_embedders();
static const TokenizerType get_tokenizer_type(const nlohmann::json& model_parameters);
const std::string get_indexing_prefix(const nlohmann::json& model_parameters);
const std::string get_query_prefix(const nlohmann::json& model_parameters);
static const TokenizerType get_tokenizer_type(const nlohmann::json& model_config);
const std::string get_indexing_prefix(const nlohmann::json& model_config);
const std::string get_query_prefix(const nlohmann::json& model_config);
static void set_model_dir(const std::string& dir);
static const std::string& get_model_dir();

View File

@ -239,16 +239,17 @@ nlohmann::json Collection::get_summary_json() const {
field_json[fields::sort] = coll_field.sort;
field_json[fields::infix] = coll_field.infix;
field_json[fields::locale] = coll_field.locale;
field_json[fields::embed] = nlohmann::json::object();
if(!coll_field.embed_from.empty()) {
field_json[fields::embed_from] = coll_field.embed_from;
field_json[fields::embed][fields::from] = coll_field.embed_from;
}
if(coll_field.model_parameters.size() > 0) {
field_json[fields::model_parameters] = coll_field.model_parameters;
if(coll_field.model_config.size() > 0) {
field_json[fields::embed][fields::model_config] = coll_field.model_config;
// Hide OpenAI API key from the response.
if(field_json[fields::model_parameters].count(fields::api_key) != 0) {
field_json[fields::model_parameters][fields::api_key] = "<hidden>";
if(field_json[fields::embed][fields::model_config].count(fields::api_key) != 0) {
field_json[fields::embed][fields::model_config][fields::api_key] = "<hidden>";
}
}
@ -1172,7 +1173,7 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
auto search_field = search_schema.at(expanded_search_field);
if(search_field.num_dim > 0) {
if(has_embedding_query) {
if(!vector_query.field_name.empty()) {
std::string error = "Only one embedding field is allowed in the query.";
return Option<nlohmann::json>(400, error);
}
@ -1188,9 +1189,9 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
}
TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance();
auto embedder = embedder_manager.get_text_embedder(search_field.model_parameters);
auto embedder = embedder_manager.get_text_embedder(search_field.model_config);
std::string embed_query = embedder_manager.get_query_prefix(search_field.model_parameters) + raw_query;
std::string embed_query = embedder_manager.get_query_prefix(search_field.model_config) + raw_query;
auto embedding_op = embedder->Embed(embed_query);
if(!embedding_op.ok()) {
return Option<nlohmann::json>(400, embedding_op.error());
@ -1200,7 +1201,6 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
vector_query._reset();
vector_query.values = embedding;
vector_query.field_name = field_name;
has_embedding_query = true;
continue;
}
@ -1212,7 +1212,7 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
}
std::string real_raw_query = raw_query;
if(has_embedding_query && processed_search_fields.size() == 0) {
if(!vector_query.field_name.empty() && processed_search_fields.size() == 0) {
raw_query = "*";
}

View File

@ -62,8 +62,8 @@ Collection* CollectionManager::init_collection(const nlohmann::json & collection
field_obj[fields::embed_from] = std::vector<std::string>();
}
if(field_obj.count(fields::model_parameters) == 0) {
field_obj[fields::model_parameters] = nlohmann::json::object();
if(field_obj.count(fields::model_config) == 0) {
field_obj[fields::model_config] = nlohmann::json::object();
}
vector_distance_type_t vec_dist_type = vector_distance_type_t::cosine;
@ -78,7 +78,7 @@ Collection* CollectionManager::init_collection(const nlohmann::json & collection
field_obj[fields::optional], field_obj[fields::index], field_obj[fields::locale],
-1, field_obj[fields::infix], field_obj[fields::nested], field_obj[fields::nested_array],
field_obj[fields::num_dim], vec_dist_type, field_obj[fields::reference], field_obj[fields::embed_from],
field_obj[fields::model_parameters]);
field_obj[fields::model_config]);
// value of `sort` depends on field type
if(field_obj.count(fields::sort) == 0) {

View File

@ -672,64 +672,62 @@ Option<bool> field::json_field_to_field(bool enable_nested_fields, nlohmann::jso
}
}
if(field_json.count(fields::model_parameters) > 0 && field_json.count(fields::embed_from) == 0) {
return Option<bool>(400, "Property `" + fields::model_parameters + "` can only be used with `" + fields::embed_from + "`.");
if(field_json.count(fields::model_config) > 0 && field_json.count(fields::embed_from) == 0) {
return Option<bool>(400, "Property `" + fields::model_config + "` can only be used with `" + fields::embed_from + "`.");
}
if(field_json.count(fields::embed_from) != 0) {
if(field_json.count(fields::embed) != 0) {
// If the model path is not specified, use the default model and set the number of dimensions to 384 (number of dimensions of the default model)
field_json[fields::num_dim] = static_cast<unsigned int>(384);
if(field_json.count(fields::model_parameters) == 0) {
return Option<bool>(400, "Property `" + fields::model_parameters + "` must be specified with `" + fields::embed_from + "`.");
auto& embed_json = field_json[fields::embed];
if(embed_json.count(fields::model_config) == 0) {
return Option<bool>(400, "Property `" + fields::embed + "." + fields::model_config + "` not found.");
}
auto& model_parameters = field_json[fields::model_parameters];
auto& model_config = embed_json[fields::model_config];
if(model_parameters.count(fields::model_name) == 0) {
return Option<bool>(400, "Property `" + fields::model_parameters + "." + fields::model_name + "` must be specified with `" + fields::embed_from + "`.");
if(model_config.count(fields::model_name) == 0) {
return Option<bool>(400, "Property `" + fields::embed + "." + fields::model_config + "." + fields::model_name + "`not found");
}
unsigned int num_dim = 0;
if(!model_parameters[fields::model_name].is_string()) {
return Option<bool>(400, "Property `" + fields::model_parameters + "." + fields::model_name + "` must be a string.");
if(!model_config[fields::model_name].is_string()) {
return Option<bool>(400, "Property `" + fields::embed + "." + fields::model_config + "." + fields::model_name + "` must be a string.");
}
if(model_parameters[fields::model_name].get<std::string>().empty()) {
return Option<bool>(400, "Property `" + fields::model_parameters + "." + fields::model_name + "` cannot be empty.");
if(model_config[fields::model_name].get<std::string>().empty()) {
return Option<bool>(400, "Property `" + fields::embed + "." + fields::model_config + "." + fields::model_name + "` cannot be empty.");
}
if(model_parameters.count(fields::indexing_prefix) != 0) {
if(!model_parameters[fields::indexing_prefix].is_string()) {
return Option<bool>(400, "Property `" + fields::model_parameters + "." + fields::indexing_prefix + "` must be a string.");
if(model_config.count(fields::indexing_prefix) != 0) {
if(!model_config[fields::indexing_prefix].is_string()) {
return Option<bool>(400, "Property `" + fields::embed + "." + fields::model_config + "." + fields::indexing_prefix + "` must be a string.");
}
}
if(model_parameters.count(fields::query_prefix) != 0) {
if(!model_parameters[fields::query_prefix].is_string()) {
return Option<bool>(400, "Property `" + fields::model_parameters + "." + fields::query_prefix + "` must be a string.");
if(model_config.count(fields::query_prefix) != 0) {
if(!model_config[fields::query_prefix].is_string()) {
return Option<bool>(400, "Property `" + fields::embed + "." + fields::model_config + "." + fields::query_prefix + "` must be a string.");
}
}
if(model_parameters.count("api_key") != 0) {
auto res = TextEmbedder::is_model_valid(model_parameters[fields::model_name].get<std::string>(), model_parameters[fields::api_key].get<std::string>(), num_dim);
if(res.ok()) {
field_json[fields::num_dim] = num_dim;
} else {
if(model_config.count("api_key") != 0) {
auto res = TextEmbedder::is_model_valid(model_config[fields::model_name].get<std::string>(), model_config[fields::api_key].get<std::string>(), num_dim);
if(!res.ok()) {
return Option<bool>(res.code(), res.error());
}
} else {
if(TextEmbedder::is_model_valid(model_parameters[fields::model_name].get<std::string>(), num_dim)) {
field_json[fields::num_dim] = num_dim;
} else {
return Option<bool>(400, "Property `" + fields::model_parameters + "." + fields::model_name + "` is invalid.");
if(!TextEmbedder::is_model_valid(model_config[fields::model_name].get<std::string>(), num_dim)) {
return Option<bool>(400, "Property `" + fields::embed + "." + fields::model_config + "." + fields::model_name + "` is invalid.");
}
}
field_json[fields::num_dim] = num_dim;
} else {
field_json[fields::embed_from] = std::vector<std::string>();
field_json[fields::embed] = nlohmann::json::object();
field_json[fields::embed][fields::from] = nlohmann::json::array();
}
auto DEFAULT_VEC_DIST_METRIC = magic_enum::enum_name(vector_distance_type_t::cosine);
if(field_json.count(fields::num_dim) == 0) {
@ -810,8 +808,8 @@ Option<bool> field::json_field_to_field(bool enable_nested_fields, nlohmann::jso
field_json[fields::optional], field_json[fields::index], field_json[fields::locale],
field_json[fields::sort], field_json[fields::infix], field_json[fields::nested],
field_json[fields::nested_array], field_json[fields::num_dim], vec_dist,
field_json[fields::reference], field_json[fields::embed_from].get<std::vector<std::string>>(),
field_json[fields::model_parameters])
field_json[fields::reference], field_json[fields::embed][fields::from].get<std::vector<std::string>>(),
field_json[fields::embed][fields::model_config])
);
if (!field_json[fields::reference].get<std::string>().empty()) {

View File

@ -6349,7 +6349,7 @@ Option<bool> Index::batch_embed_fields(std::vector<nlohmann::json*>& documents,
for(const auto& field : embedding_fields) {
std::vector<std::string> text_to_embed;
for(auto& document : documents) {
std::string text = TextEmbedderManager::get_instance().get_indexing_prefix(field.model_parameters);
std::string text = TextEmbedderManager::get_instance().get_indexing_prefix(field.model_config);
for(const auto& field_name : field.embed_from) {
auto field_it = search_schema.find(field_name);
if(field_it.value().type == field_types::STRING) {
@ -6363,7 +6363,7 @@ Option<bool> Index::batch_embed_fields(std::vector<nlohmann::json*>& documents,
text_to_embed.push_back(text);
}
TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance();
auto embedder = embedder_manager.get_text_embedder(field.model_parameters);
auto embedder = embedder_manager.get_text_embedder(field.model_config);
auto embedding_op = embedder->batch_embed(text_to_embed);
if(!embedding_op.ok()) {

View File

@ -6,18 +6,18 @@ TextEmbedderManager& TextEmbedderManager::get_instance() {
return instance;
}
TextEmbedder* TextEmbedderManager::get_text_embedder(const nlohmann::json& model_parameters) {
TextEmbedder* TextEmbedderManager::get_text_embedder(const nlohmann::json& model_config) {
std::unique_lock<std::mutex> lock(text_embedders_mutex);
const std::string& model_name = model_parameters.at("model_name");
const std::string& model_name = model_config.at("model_name");
if(text_embedders[model_name] == nullptr) {
if(model_parameters.count("api_key") == 0) {
if(model_config.count("api_key") == 0) {
if(is_public_model(model_name)) {
// download the model if it doesn't exist
download_public_model(model_name);
}
text_embedders[model_name] = std::make_shared<TextEmbedder>(get_model_name_without_namespace(model_name));
} else {
text_embedders[model_name] = std::make_shared<TextEmbedder>(model_name, model_parameters.at("api_key").get<std::string>());
text_embedders[model_name] = std::make_shared<TextEmbedder>(model_name, model_config.at("api_key").get<std::string>());
}
}
return text_embedders[model_name].get();
@ -35,11 +35,11 @@ void TextEmbedderManager::delete_all_text_embedders() {
text_embedders.clear();
}
const TokenizerType TextEmbedderManager::get_tokenizer_type(const nlohmann::json& model_parameters) {
if(model_parameters.find("model_type") == model_parameters.end()) {
const TokenizerType TextEmbedderManager::get_tokenizer_type(const nlohmann::json& model_config) {
if(model_config.find("model_type") == model_config.end()) {
return TokenizerType::bert;
} else {
std::string tokenizer_type = model_parameters.at("model_type").get<std::string>();
std::string tokenizer_type = model_config.at("model_type").get<std::string>();
if(tokenizer_type == "distilbert") {
return TokenizerType::distilbert;
} else if(tokenizer_type == "xlm_roberta") {
@ -50,12 +50,12 @@ const TokenizerType TextEmbedderManager::get_tokenizer_type(const nlohmann::json
}
}
const std::string TextEmbedderManager::get_indexing_prefix(const nlohmann::json& model_parameters) {
const std::string TextEmbedderManager::get_indexing_prefix(const nlohmann::json& model_config) {
std::string val;
if(is_public_model(model_parameters["model_name"].get<std::string>())) {
val = public_models[model_parameters["model_name"].get<std::string>()].indexing_prefix;
if(is_public_model(model_config["model_name"].get<std::string>())) {
val = public_models[model_config["model_name"].get<std::string>()].indexing_prefix;
} else {
val = model_parameters.count("indexing_prefix") == 0 ? "" : model_parameters["indexing_prefix"].get<std::string>();
val = model_config.count("indexing_prefix") == 0 ? "" : model_config["indexing_prefix"].get<std::string>();
}
if(!val.empty()) {
val += " ";
@ -64,12 +64,12 @@ const std::string TextEmbedderManager::get_indexing_prefix(const nlohmann::json&
return val;
}
const std::string TextEmbedderManager::get_query_prefix(const nlohmann::json& model_parameters) {
const std::string TextEmbedderManager::get_query_prefix(const nlohmann::json& model_config) {
std::string val;
if(is_public_model(model_parameters["model_name"].get<std::string>())) {
val = public_models[model_parameters["model_name"].get<std::string>()].query_prefix;
if(is_public_model(model_config["model_name"].get<std::string>())) {
val = public_models[model_config["model_name"].get<std::string>()].query_prefix;
} else {
val = model_parameters.count("query_prefix") == 0 ? "" : model_parameters["query_prefix"].get<std::string>();
val = model_config.count("query_prefix") == 0 ? "" : model_config["query_prefix"].get<std::string>();
}
if(!val.empty()) {
val += " ";

View File

@ -1597,9 +1597,10 @@ TEST_F(CollectionAllFieldsTest, EmbedFromFieldJSONInvalidField) {
nlohmann::json field_json;
field_json["name"] = "embedding";
field_json["type"] = "float[]";
field_json["embed_from"] = {"name"};
field_json["model_parameters"] = nlohmann::json::object();
field_json["model_parameters"]["model_name"] = "ts/e5-small";
field_json["embed"] = nlohmann::json::object();
field_json["embed"]["from"] = {"name"};
field_json["embed"]["model_config"] = nlohmann::json::object();
field_json["embed"]["model_config"]["model_name"] = "ts/e5-small";
std::vector<field> fields;
std::string fallback_field_type;
@ -1609,35 +1610,18 @@ TEST_F(CollectionAllFieldsTest, EmbedFromFieldJSONInvalidField) {
auto field_op = field::json_fields_to_fields(false, arr, fallback_field_type, fields);
ASSERT_FALSE(field_op.ok());
ASSERT_EQ("Property `embed_from` can only refer to string or string array fields.", field_op.error());
ASSERT_EQ("Property `embed.from` can only refer to string or string array fields.", field_op.error());
}
// TEST_F(CollectionAllFieldsTest, EmbedFromFieldNoModelDir) {
// TextEmbedderManager::model_dir = std::string();
// nlohmann::json field_json;
// field_json["name"] = "embedding";
// field_json["type"] = "float[]";
// field_json["embed_from"] = {"name"};
// std::vector<field> fields;
// std::string fallback_field_type;
// auto arr = nlohmann::json::array();
// arr.push_back(field_json);
// auto field_op = field::json_fields_to_fields(false, arr, fallback_field_type, fields);
// ASSERT_FALSE(field_op.ok());
// ASSERT_EQ("Text embedding is not enabled. Please set `model-dir` at startup.", field_op.error());
// }
TEST_F(CollectionAllFieldsTest, EmbedFromNotArray) {
TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
nlohmann::json field_json;
field_json["name"] = "embedding";
field_json["type"] = "float[]";
field_json["embed_from"] = "name";
field_json["model_parameters"] = nlohmann::json::object();
field_json["model_parameters"]["model_name"] = "ts/e5-small";
field_json["embed"] = nlohmann::json::object();
field_json["embed"]["from"] = "name";
field_json["embed"]["model_config"] = nlohmann::json::object();
field_json["embed"]["model_config"]["model_name"] = "ts/e5-small";
std::vector<field> fields;
std::string fallback_field_type;
@ -1647,7 +1631,7 @@ TEST_F(CollectionAllFieldsTest, EmbedFromNotArray) {
auto field_op = field::json_fields_to_fields(false, arr, fallback_field_type, fields);
ASSERT_FALSE(field_op.ok());
ASSERT_EQ("Property `embed_from` must be an array.", field_op.error());
ASSERT_EQ("Property `embed.from` must be an array.", field_op.error());
}
TEST_F(CollectionAllFieldsTest, ModelParametersWithoutEmbedFrom) {
@ -1655,8 +1639,8 @@ TEST_F(CollectionAllFieldsTest, ModelParametersWithoutEmbedFrom) {
nlohmann::json field_json;
field_json["name"] = "embedding";
field_json["type"] = "float[]";
field_json["model_parameters"] = nlohmann::json::object();
field_json["model_parameters"]["model_name"] = "ts/e5-small";
field_json["embed"]["model_config"] = nlohmann::json::object();
field_json["embed"]["model_config"]["model_name"] = "ts/e5-small";
std::vector<field> fields;
std::string fallback_field_type;
@ -1665,7 +1649,7 @@ TEST_F(CollectionAllFieldsTest, ModelParametersWithoutEmbedFrom) {
auto field_op = field::json_fields_to_fields(false, arr, fallback_field_type, fields);
ASSERT_FALSE(field_op.ok());
ASSERT_EQ("Property `model_parameters` can only be used with `embed_from`.", field_op.error());
ASSERT_EQ("Property `embed` must contain a `from` property.", field_op.error());
}
@ -1675,7 +1659,7 @@ TEST_F(CollectionAllFieldsTest, EmbedFromBasicValid) {
field embedding = field("embedding", field_types::FLOAT_ARRAY, false);
embedding.embed_from.push_back("name");
embedding.model_parameters["model_name"] = "ts/e5-small";
embedding.model_config["model_name"] = "ts/e5-small";
std::vector<field> fields = {field("name", field_types::STRING, false),
embedding};
auto obj_coll_op = collectionManager.create_collection("obj_coll", 1, fields, "", 0, field_types::AUTO);
@ -1700,9 +1684,10 @@ TEST_F(CollectionAllFieldsTest, WrongDataTypeForEmbedFrom) {
nlohmann::json field_json;
field_json["name"] = "embedding";
field_json["type"] = "float[]";
field_json["embed_from"] = {"age"};
field_json["model_parameters"] = nlohmann::json::object();
field_json["model_parameters"]["model_name"] = "ts/e5-small";
field_json["embed"] = nlohmann::json::object();
field_json["embed"]["from"] = {"age"};
field_json["embed"]["model_config"] = nlohmann::json::object();
field_json["embed"]["model_config"]["model_name"] = "ts/e5-small";
std::vector<field> fields;
std::string fallback_field_type;
@ -1713,5 +1698,5 @@ TEST_F(CollectionAllFieldsTest, WrongDataTypeForEmbedFrom) {
arr.push_back(field_json);
auto field_op = field::json_fields_to_fields(false, arr, fallback_field_type, fields);
ASSERT_FALSE(field_op.ok());
ASSERT_EQ("Property `embed_from` can only refer to string or string array fields.", field_op.error());
ASSERT_EQ("Property `embed.from` can only refer to string or string array fields.", field_op.error());
}

View File

@ -1464,7 +1464,7 @@ TEST_F(CollectionSchemaChangeTest, UpdateSchemaWithNewEmbeddingField) {
nlohmann::json update_schema = R"({
"fields": [
{"name": "embedding", "type":"float[]", "embed_from": ["names"], "model_parameters": {"model_name": "ts/e5-small"}}
{"name": "embedding", "type":"float[]", "embed":{"from": ["names"], "model_config": {"model_name": "ts/e5-small"}}}
]
})"_json;
@ -1489,7 +1489,7 @@ TEST_F(CollectionSchemaChangeTest, DropFieldUsedForEmbedding) {
"fields": [
{"name": "names", "type": "string[]"},
{"name": "category", "type":"string"},
{"name": "embedding", "type":"float[]", "embed_from": ["names","category"], "model_parameters": {"model_name": "ts/e5-small"}}
{"name": "embedding", "type":"float[]", "embed":{"from": ["names","category"], "model_config": {"model_name": "ts/e5-small"}}}
]
})"_json;
@ -1537,7 +1537,7 @@ TEST_F(CollectionSchemaChangeTest, EmbeddingFieldsMapTest) {
"name": "objects",
"fields": [
{"name": "name", "type": "string"},
{"name": "embedding", "type":"float[]", "embed_from": ["name"], "model_parameters": {"model_name": "ts/e5-small"}}
{"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/e5-small"}}}
]
})"_json;

View File

@ -4617,7 +4617,7 @@ TEST_F(CollectionTest, SemanticSearchTest) {
"name": "objects",
"fields": [
{"name": "name", "type": "string"},
{"name": "embedding", "type":"float[]", "embed_from": ["name"], "model_parameters": {"model_name": "ts/e5-small"}}
{"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/e5-small"}}}
]
})"_json;
@ -4651,7 +4651,7 @@ TEST_F(CollectionTest, InvalidSemanticSearch) {
"name": "objects",
"fields": [
{"name": "name", "type": "string"},
{"name": "embedding", "type":"float[]", "embed_from": ["name"], "model_parameters": {"model_name": "ts/e5-small"}}
{"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/e5-small"}}}
]
})"_json;
@ -4680,7 +4680,7 @@ TEST_F(CollectionTest, HybridSearch) {
"name": "objects",
"fields": [
{"name": "name", "type": "string"},
{"name": "embedding", "type":"float[]", "embed_from": ["name"], "model_parameters": {"model_name": "ts/e5-small"}}
{"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/e5-small"}}}
]
})"_json;
@ -4713,7 +4713,7 @@ TEST_F(CollectionTest, HybridSearch) {
// "name": "objects",
// "fields": [
// {"name": "name", "type": "string"},
// {"name": "embedding", "type":"float[]", "embed_from": ["name"]}
// {"name": "embedding", "type":"float[]", "embed":{"from": ["name"]}
// ]
// })"_json;
@ -4741,7 +4741,7 @@ TEST_F(CollectionTest, HybridSearchRankFusionTest) {
"name": "objects",
"fields": [
{"name": "name", "type": "string"},
{"name": "embedding", "type":"float[]", "embed_from": ["name"], "model_parameters": {"model_name": "ts/e5-small"}}
{"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/e5-small"}}}
]
})"_json;
@ -4814,7 +4814,7 @@ TEST_F(CollectionTest, WildcardSearchWithEmbeddingField) {
"name": "objects",
"fields": [
{"name": "name", "type": "string"},
{"name": "embedding", "type":"float[]", "embed_from": ["name"], "model_parameters": {"model_name": "ts/e5-small"}}
{"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/e5-small"}}}
]
})"_json;
@ -4847,7 +4847,7 @@ TEST_F(CollectionTest, EmbedStringArrayField) {
"name": "objects",
"fields": [
{"name": "names", "type": "string[]"},
{"name": "embedding", "type":"float[]", "embed_from": ["names"], "model_parameters": {"model_name": "ts/e5-small"}}
{"name": "embedding", "type":"float[]", "embed":{"from": ["names"], "model_config": {"model_name": "ts/e5-small"}}}
]
})"_json;
@ -4872,7 +4872,7 @@ TEST_F(CollectionTest, MissingFieldForEmbedding) {
"fields": [
{"name": "names", "type": "string[]"},
{"name": "category", "type": "string", "optional": true},
{"name": "embedding", "type":"float[]", "embed_from": ["names", "category"], "model_parameters": {"model_name": "ts/e5-small"}}
{"name": "embedding", "type":"float[]", "embed":{"from": ["names", "category"], "model_config": {"model_name": "ts/e5-small"}}}
]
})"_json;
@ -4898,7 +4898,7 @@ TEST_F(CollectionTest, WrongTypeForEmbedding) {
"name": "objects",
"fields": [
{"name": "category", "type": "string"},
{"name": "embedding", "type":"float[]", "embed_from": ["category"], "model_parameters": {"model_name": "ts/e5-small"}}
{"name": "embedding", "type":"float[]", "embed":{"from": ["category"], "model_config": {"model_name": "ts/e5-small"}}}
]
})"_json;
@ -4921,7 +4921,7 @@ TEST_F(CollectionTest, WrongTypeOfElementForEmbeddingInStringArray) {
"name": "objects",
"fields": [
{"name": "category", "type": "string[]"},
{"name": "embedding", "type":"float[]", "embed_from": ["category"], "model_parameters": {"model_name": "ts/e5-small"}}
{"name": "embedding", "type":"float[]", "embed":{"from": ["category"], "model_config": {"model_name": "ts/e5-small"}}}
]
})"_json;
@ -4944,7 +4944,7 @@ TEST_F(CollectionTest, UpdateEmbeddingsForUpdatedDocument) {
"name": "objects",
"fields": [
{"name": "name", "type": "string"},
{"name": "embedding", "type":"float[]", "embed_from": ["name"], "model_parameters": {"model_name": "ts/e5-small"}}
{"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/e5-small"}}}
]
})"_json;
@ -4991,7 +4991,7 @@ TEST_F(CollectionTest, DISABLED_CreateOpenAIEmbeddingField) {
"name": "objects",
"fields": [
{"name": "name", "type": "string"},
{"name": "embedding", "type":"float[]", "embed_from": ["name"], "model_parameters": {"model_name": "openai/text-embedding-ada-002"}}
{"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "openai/text-embedding-ada-002"}}
]
})"_json;
@ -5001,15 +5001,15 @@ TEST_F(CollectionTest, DISABLED_CreateOpenAIEmbeddingField) {
}
auto api_key = std::string(std::getenv("api_key"));
schema["fields"][1]["model_parameters"]["api_key"] = api_key;
schema["fields"][1]["model_config"]["api_key"] = api_key;
TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
auto op = collectionManager.create_collection(schema);
ASSERT_TRUE(op.ok());
auto summary = op.get()->get_summary_json();
ASSERT_EQ("openai/text-embedding-ada-002", summary["fields"][1]["model_parameters"]["model_name"]);
ASSERT_EQ("openai/text-embedding-ada-002", summary["fields"][1]["model_config"]["model_name"]);
ASSERT_EQ(1536, summary["fields"][1]["num_dim"]);
// make sure api_key is <hidden>
ASSERT_EQ("<hidden>", summary["fields"][1]["model_parameters"]["api_key"]);
ASSERT_EQ("<hidden>", summary["fields"][1]["model_config"]["api_key"]);
nlohmann::json doc;
doc["name"] = "butter";
@ -5019,4 +5019,37 @@ TEST_F(CollectionTest, DISABLED_CreateOpenAIEmbeddingField) {
ASSERT_EQ(1536, add_op.get()["embedding"].size());
}
TEST_F(CollectionTest, MoreThganOneEmbeddingField) {
nlohmann::json schema = R"({
"name": "objects",
"fields": [
{"name": "name", "type": "string"},
{"name": "name2", "type": "string"},
{"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/e5-small"}}},
{"name": "embedding2", "type":"float[]", "embed":{"from": ["name2"], "model_config": {"model_name": "ts/e5-small"}}}
]
})"_json;
TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
auto op = collectionManager.create_collection(schema);
ASSERT_TRUE(op.ok());
auto coll = op.get();
nlohmann::json doc;
doc["name"] = "butter";
doc["name2"] = "butterball";
auto add_op = validator_t::validate_embed_fields(doc, op.get()->get_embedding_fields(), op.get()->get_schema(), true);
ASSERT_TRUE(add_op.ok());
spp::sparse_hash_set<std::string> dummy_include_exclude;
auto search_res_op = coll->search("butter", {"name", "embedding", "embedding2"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, dummy_include_exclude, dummy_include_exclude, 10, "", 30, 4, "");
ASSERT_FALSE(search_res_op.ok());
ASSERT_EQ("Only one embedding field is allowed in the query.", search_res_op.error());
}

View File

@ -683,7 +683,7 @@ TEST_F(CollectionVectorTest, HybridSearchWithExplicitVector) {
"name": "coll1",
"fields": [
{"name": "name", "type": "string"},
{"name": "vec", "type": "float[]", "embed_from": ["name"], "model_parameters": {"model_name": "ts/e5-small"}}
{"name": "vec", "type": "float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/e5-small"}}}
]
})"_json;
@ -781,7 +781,7 @@ TEST_F(CollectionVectorTest, EmbeddingFieldVectorIndexTest) {
"name": "objects",
"fields": [
{"name": "name", "type": "string"},
{"name": "embedding", "type":"float[]", "embed_from": ["name"], "model_parameters": {"model_name": "ts/e5-small"}}
{"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/e5-small"}}}
]
})"_json;