mirror of
https://github.com/typesense/typesense.git
synced 2025-05-19 21:22:25 +08:00
Review changes II
This commit is contained in:
parent
85ed9090b2
commit
992cbc9080
@ -50,6 +50,8 @@ namespace fields {
|
||||
static const std::string num_dim = "num_dim";
|
||||
static const std::string vec_dist = "vec_dist";
|
||||
static const std::string reference = "reference";
|
||||
static const std::string embed = "embed";
|
||||
static const std::string from = "from";
|
||||
static const std::string embed_from = "embed_from";
|
||||
static const std::string model_name = "model_name";
|
||||
|
||||
@ -58,7 +60,7 @@ namespace fields {
|
||||
static const std::string indexing_prefix = "indexing_prefix";
|
||||
static const std::string query_prefix = "query_prefix";
|
||||
static const std::string api_key = "api_key";
|
||||
static const std::string model_parameters = "model_parameters";
|
||||
static const std::string model_config = "model_config";
|
||||
}
|
||||
|
||||
enum vector_distance_type_t {
|
||||
@ -85,7 +87,7 @@ struct field {
|
||||
|
||||
size_t num_dim;
|
||||
std::vector<std::string> embed_from;
|
||||
nlohmann::json model_parameters;
|
||||
nlohmann::json model_config;
|
||||
vector_distance_type_t vec_dist;
|
||||
|
||||
static constexpr int VAL_UNKNOWN = 2;
|
||||
@ -97,9 +99,9 @@ struct field {
|
||||
field(const std::string &name, const std::string &type, const bool facet, const bool optional = false,
|
||||
bool index = true, std::string locale = "", int sort = -1, int infix = -1, bool nested = false,
|
||||
int nested_array = 0, size_t num_dim = 0, vector_distance_type_t vec_dist = cosine, std::string reference = "", const std::vector<std::string> &embed_from = {},
|
||||
const nlohmann::json& model_parameters = nlohmann::json()) :
|
||||
const nlohmann::json& model_config = nlohmann::json()) :
|
||||
name(name), type(type), facet(facet), optional(optional), index(index), locale(locale),
|
||||
nested(nested), nested_array(nested_array), num_dim(num_dim), vec_dist(vec_dist), reference(reference), embed_from(embed_from), model_parameters(model_parameters) {
|
||||
nested(nested), nested_array(nested_array), num_dim(num_dim), vec_dist(vec_dist), reference(reference), embed_from(embed_from), model_config(model_config) {
|
||||
|
||||
set_computed_defaults(sort, infix);
|
||||
}
|
||||
@ -328,10 +330,9 @@ struct field {
|
||||
field_val[fields::reference] = field.reference;
|
||||
}
|
||||
if(!field.embed_from.empty()) {
|
||||
field_val[fields::embed_from] = field.embed_from;
|
||||
if(!field.model_parameters.empty()) {
|
||||
field_val[fields::model_parameters] = field.model_parameters;
|
||||
}
|
||||
field_val[fields::embed] = nlohmann::json::object();
|
||||
field_val[fields::embed][fields::from] = field.embed_from;
|
||||
field_val[fields::embed][fields::model_config] = field.model_config;
|
||||
}
|
||||
fields_json.push_back(field_val);
|
||||
|
||||
@ -428,33 +429,40 @@ struct field {
|
||||
size_t num_auto_detect_fields = 0;
|
||||
|
||||
for(nlohmann::json & field_json: fields_json) {
|
||||
if(field_json.count(fields::embed_from) != 0) {
|
||||
|
||||
if(!field_json[fields::embed_from].is_array()) {
|
||||
return Option<bool>(400, "Property `" + fields::embed_from + "` must be an array.");
|
||||
if(field_json.count(fields::embed) != 0) {
|
||||
|
||||
if(!field_json[fields::embed].is_object()) {
|
||||
return Option<bool>(400, "Property `" + fields::embed + "` must be an object.");
|
||||
}
|
||||
|
||||
if(field_json[fields::embed_from].empty()) {
|
||||
return Option<bool>(400, "Property `" + fields::embed_from + "` must have at least one element.");
|
||||
if(field_json[fields::embed].count(fields::from) == 0) {
|
||||
return Option<bool>(400, "Property `" + fields::embed + "` must contain a `" + fields::from + "` property.");
|
||||
}
|
||||
|
||||
for(auto& embed_from_field : field_json[fields::embed_from]) {
|
||||
if(!field_json[fields::embed][fields::from].is_array()) {
|
||||
return Option<bool>(400, "Property `" + fields::embed + "." + fields::from + "` must be an array.");
|
||||
}
|
||||
|
||||
if(field_json[fields::embed][fields::from].empty()) {
|
||||
return Option<bool>(400, "Property `" + fields::embed + "." + fields::from + "` must have at least one element.");
|
||||
}
|
||||
|
||||
for(auto& embed_from_field : field_json[fields::embed][fields::from]) {
|
||||
if(!embed_from_field.is_string()) {
|
||||
return Option<bool>(400, "Property `" + fields::embed_from + "` must contain only field names as strings.");
|
||||
return Option<bool>(400, "Property `" + fields::embed + "." + fields::from + "` must contain only field names as strings.");
|
||||
}
|
||||
}
|
||||
|
||||
if(field_json[fields::type] != field_types::FLOAT_ARRAY) {
|
||||
return Option<bool>(400, "Property `" + fields::embed_from + "` is only allowed on a float array field.");
|
||||
return Option<bool>(400, "Property `" + fields::embed + "." + fields::from + "` is only allowed on a float array field.");
|
||||
}
|
||||
|
||||
|
||||
for(auto& embed_from_field : field_json[fields::embed_from]) {
|
||||
for(auto& embed_from_field : field_json[fields::embed][fields::from]) {
|
||||
bool flag = false;
|
||||
for(const auto& field : fields_json) {
|
||||
if(field[fields::name] == embed_from_field) {
|
||||
if(field[fields::type] != field_types::STRING && field[fields::type] != field_types::STRING_ARRAY) {
|
||||
return Option<bool>(400, "Property `" + fields::embed_from + "` can only refer to string or string array fields.");
|
||||
return Option<bool>(400, "Property `" + fields::embed + "." + fields::from + "` can only refer to string or string array fields.");
|
||||
}
|
||||
flag = true;
|
||||
break;
|
||||
@ -464,7 +472,7 @@ struct field {
|
||||
for(const auto& field : the_fields) {
|
||||
if(field.name == embed_from_field) {
|
||||
if(field.type != field_types::STRING && field.type != field_types::STRING_ARRAY) {
|
||||
return Option<bool>(400, "Property `" + fields::embed_from + "` can only refer to string or string array fields.");
|
||||
return Option<bool>(400, "Property `" + fields::embed + "." + fields::from + "` can only refer to string or string array fields.");
|
||||
}
|
||||
flag = true;
|
||||
break;
|
||||
@ -472,9 +480,9 @@ struct field {
|
||||
}
|
||||
}
|
||||
if(!flag) {
|
||||
return Option<bool>(400, "Property `" + fields::embed_from + "` can only refer to string or string array fields.");
|
||||
return Option<bool>(400, "Property `" + fields::embed + "." + fields::from + "` can only refer to string or string array fields.");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto op = json_field_to_field(enable_nested_fields,
|
||||
|
@ -34,13 +34,13 @@ public:
|
||||
TextEmbedderManager(const TextEmbedderManager&) = delete;
|
||||
TextEmbedderManager& operator=(const TextEmbedderManager&) = delete;
|
||||
|
||||
TextEmbedder* get_text_embedder(const nlohmann::json& model_parameters);
|
||||
TextEmbedder* get_text_embedder(const nlohmann::json& model_config);
|
||||
void delete_text_embedder(const std::string& model_path);
|
||||
void delete_all_text_embedders();
|
||||
|
||||
static const TokenizerType get_tokenizer_type(const nlohmann::json& model_parameters);
|
||||
const std::string get_indexing_prefix(const nlohmann::json& model_parameters);
|
||||
const std::string get_query_prefix(const nlohmann::json& model_parameters);
|
||||
static const TokenizerType get_tokenizer_type(const nlohmann::json& model_config);
|
||||
const std::string get_indexing_prefix(const nlohmann::json& model_config);
|
||||
const std::string get_query_prefix(const nlohmann::json& model_config);
|
||||
static void set_model_dir(const std::string& dir);
|
||||
static const std::string& get_model_dir();
|
||||
|
||||
|
@ -239,16 +239,17 @@ nlohmann::json Collection::get_summary_json() const {
|
||||
field_json[fields::sort] = coll_field.sort;
|
||||
field_json[fields::infix] = coll_field.infix;
|
||||
field_json[fields::locale] = coll_field.locale;
|
||||
field_json[fields::embed] = nlohmann::json::object();
|
||||
|
||||
if(!coll_field.embed_from.empty()) {
|
||||
field_json[fields::embed_from] = coll_field.embed_from;
|
||||
field_json[fields::embed][fields::from] = coll_field.embed_from;
|
||||
}
|
||||
|
||||
if(coll_field.model_parameters.size() > 0) {
|
||||
field_json[fields::model_parameters] = coll_field.model_parameters;
|
||||
if(coll_field.model_config.size() > 0) {
|
||||
field_json[fields::embed][fields::model_config] = coll_field.model_config;
|
||||
// Hide OpenAI API key from the response.
|
||||
if(field_json[fields::model_parameters].count(fields::api_key) != 0) {
|
||||
field_json[fields::model_parameters][fields::api_key] = "<hidden>";
|
||||
if(field_json[fields::embed][fields::model_config].count(fields::api_key) != 0) {
|
||||
field_json[fields::embed][fields::model_config][fields::api_key] = "<hidden>";
|
||||
}
|
||||
}
|
||||
|
||||
@ -1172,7 +1173,7 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
|
||||
auto search_field = search_schema.at(expanded_search_field);
|
||||
|
||||
if(search_field.num_dim > 0) {
|
||||
if(has_embedding_query) {
|
||||
if(!vector_query.field_name.empty()) {
|
||||
std::string error = "Only one embedding field is allowed in the query.";
|
||||
return Option<nlohmann::json>(400, error);
|
||||
}
|
||||
@ -1188,9 +1189,9 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
|
||||
}
|
||||
|
||||
TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance();
|
||||
auto embedder = embedder_manager.get_text_embedder(search_field.model_parameters);
|
||||
auto embedder = embedder_manager.get_text_embedder(search_field.model_config);
|
||||
|
||||
std::string embed_query = embedder_manager.get_query_prefix(search_field.model_parameters) + raw_query;
|
||||
std::string embed_query = embedder_manager.get_query_prefix(search_field.model_config) + raw_query;
|
||||
auto embedding_op = embedder->Embed(embed_query);
|
||||
if(!embedding_op.ok()) {
|
||||
return Option<nlohmann::json>(400, embedding_op.error());
|
||||
@ -1200,7 +1201,6 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
|
||||
vector_query._reset();
|
||||
vector_query.values = embedding;
|
||||
vector_query.field_name = field_name;
|
||||
has_embedding_query = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
@ -1212,7 +1212,7 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
|
||||
}
|
||||
|
||||
std::string real_raw_query = raw_query;
|
||||
if(has_embedding_query && processed_search_fields.size() == 0) {
|
||||
if(!vector_query.field_name.empty() && processed_search_fields.size() == 0) {
|
||||
raw_query = "*";
|
||||
}
|
||||
|
||||
|
@ -62,8 +62,8 @@ Collection* CollectionManager::init_collection(const nlohmann::json & collection
|
||||
field_obj[fields::embed_from] = std::vector<std::string>();
|
||||
}
|
||||
|
||||
if(field_obj.count(fields::model_parameters) == 0) {
|
||||
field_obj[fields::model_parameters] = nlohmann::json::object();
|
||||
if(field_obj.count(fields::model_config) == 0) {
|
||||
field_obj[fields::model_config] = nlohmann::json::object();
|
||||
}
|
||||
vector_distance_type_t vec_dist_type = vector_distance_type_t::cosine;
|
||||
|
||||
@ -78,7 +78,7 @@ Collection* CollectionManager::init_collection(const nlohmann::json & collection
|
||||
field_obj[fields::optional], field_obj[fields::index], field_obj[fields::locale],
|
||||
-1, field_obj[fields::infix], field_obj[fields::nested], field_obj[fields::nested_array],
|
||||
field_obj[fields::num_dim], vec_dist_type, field_obj[fields::reference], field_obj[fields::embed_from],
|
||||
field_obj[fields::model_parameters]);
|
||||
field_obj[fields::model_config]);
|
||||
|
||||
// value of `sort` depends on field type
|
||||
if(field_obj.count(fields::sort) == 0) {
|
||||
|
@ -672,64 +672,62 @@ Option<bool> field::json_field_to_field(bool enable_nested_fields, nlohmann::jso
|
||||
}
|
||||
}
|
||||
|
||||
if(field_json.count(fields::model_parameters) > 0 && field_json.count(fields::embed_from) == 0) {
|
||||
return Option<bool>(400, "Property `" + fields::model_parameters + "` can only be used with `" + fields::embed_from + "`.");
|
||||
if(field_json.count(fields::model_config) > 0 && field_json.count(fields::embed_from) == 0) {
|
||||
return Option<bool>(400, "Property `" + fields::model_config + "` can only be used with `" + fields::embed_from + "`.");
|
||||
}
|
||||
|
||||
if(field_json.count(fields::embed_from) != 0) {
|
||||
if(field_json.count(fields::embed) != 0) {
|
||||
// If the model path is not specified, use the default model and set the number of dimensions to 384 (number of dimensions of the default model)
|
||||
field_json[fields::num_dim] = static_cast<unsigned int>(384);
|
||||
|
||||
if(field_json.count(fields::model_parameters) == 0) {
|
||||
return Option<bool>(400, "Property `" + fields::model_parameters + "` must be specified with `" + fields::embed_from + "`.");
|
||||
auto& embed_json = field_json[fields::embed];
|
||||
|
||||
if(embed_json.count(fields::model_config) == 0) {
|
||||
return Option<bool>(400, "Property `" + fields::embed + "." + fields::model_config + "` not found.");
|
||||
}
|
||||
|
||||
auto& model_parameters = field_json[fields::model_parameters];
|
||||
auto& model_config = embed_json[fields::model_config];
|
||||
|
||||
if(model_parameters.count(fields::model_name) == 0) {
|
||||
return Option<bool>(400, "Property `" + fields::model_parameters + "." + fields::model_name + "` must be specified with `" + fields::embed_from + "`.");
|
||||
if(model_config.count(fields::model_name) == 0) {
|
||||
return Option<bool>(400, "Property `" + fields::embed + "." + fields::model_config + "." + fields::model_name + "`not found");
|
||||
}
|
||||
|
||||
unsigned int num_dim = 0;
|
||||
if(!model_parameters[fields::model_name].is_string()) {
|
||||
return Option<bool>(400, "Property `" + fields::model_parameters + "." + fields::model_name + "` must be a string.");
|
||||
if(!model_config[fields::model_name].is_string()) {
|
||||
return Option<bool>(400, "Property `" + fields::embed + "." + fields::model_config + "." + fields::model_name + "` must be a string.");
|
||||
}
|
||||
if(model_parameters[fields::model_name].get<std::string>().empty()) {
|
||||
return Option<bool>(400, "Property `" + fields::model_parameters + "." + fields::model_name + "` cannot be empty.");
|
||||
if(model_config[fields::model_name].get<std::string>().empty()) {
|
||||
return Option<bool>(400, "Property `" + fields::embed + "." + fields::model_config + "." + fields::model_name + "` cannot be empty.");
|
||||
}
|
||||
|
||||
if(model_parameters.count(fields::indexing_prefix) != 0) {
|
||||
if(!model_parameters[fields::indexing_prefix].is_string()) {
|
||||
return Option<bool>(400, "Property `" + fields::model_parameters + "." + fields::indexing_prefix + "` must be a string.");
|
||||
if(model_config.count(fields::indexing_prefix) != 0) {
|
||||
if(!model_config[fields::indexing_prefix].is_string()) {
|
||||
return Option<bool>(400, "Property `" + fields::embed + "." + fields::model_config + "." + fields::indexing_prefix + "` must be a string.");
|
||||
}
|
||||
}
|
||||
|
||||
if(model_parameters.count(fields::query_prefix) != 0) {
|
||||
if(!model_parameters[fields::query_prefix].is_string()) {
|
||||
return Option<bool>(400, "Property `" + fields::model_parameters + "." + fields::query_prefix + "` must be a string.");
|
||||
if(model_config.count(fields::query_prefix) != 0) {
|
||||
if(!model_config[fields::query_prefix].is_string()) {
|
||||
return Option<bool>(400, "Property `" + fields::embed + "." + fields::model_config + "." + fields::query_prefix + "` must be a string.");
|
||||
}
|
||||
}
|
||||
|
||||
if(model_parameters.count("api_key") != 0) {
|
||||
auto res = TextEmbedder::is_model_valid(model_parameters[fields::model_name].get<std::string>(), model_parameters[fields::api_key].get<std::string>(), num_dim);
|
||||
if(res.ok()) {
|
||||
field_json[fields::num_dim] = num_dim;
|
||||
} else {
|
||||
if(model_config.count("api_key") != 0) {
|
||||
auto res = TextEmbedder::is_model_valid(model_config[fields::model_name].get<std::string>(), model_config[fields::api_key].get<std::string>(), num_dim);
|
||||
if(!res.ok()) {
|
||||
return Option<bool>(res.code(), res.error());
|
||||
}
|
||||
} else {
|
||||
if(TextEmbedder::is_model_valid(model_parameters[fields::model_name].get<std::string>(), num_dim)) {
|
||||
field_json[fields::num_dim] = num_dim;
|
||||
} else {
|
||||
return Option<bool>(400, "Property `" + fields::model_parameters + "." + fields::model_name + "` is invalid.");
|
||||
if(!TextEmbedder::is_model_valid(model_config[fields::model_name].get<std::string>(), num_dim)) {
|
||||
return Option<bool>(400, "Property `" + fields::embed + "." + fields::model_config + "." + fields::model_name + "` is invalid.");
|
||||
}
|
||||
}
|
||||
field_json[fields::num_dim] = num_dim;
|
||||
} else {
|
||||
field_json[fields::embed_from] = std::vector<std::string>();
|
||||
field_json[fields::embed] = nlohmann::json::object();
|
||||
field_json[fields::embed][fields::from] = nlohmann::json::array();
|
||||
}
|
||||
|
||||
|
||||
|
||||
auto DEFAULT_VEC_DIST_METRIC = magic_enum::enum_name(vector_distance_type_t::cosine);
|
||||
|
||||
if(field_json.count(fields::num_dim) == 0) {
|
||||
@ -810,8 +808,8 @@ Option<bool> field::json_field_to_field(bool enable_nested_fields, nlohmann::jso
|
||||
field_json[fields::optional], field_json[fields::index], field_json[fields::locale],
|
||||
field_json[fields::sort], field_json[fields::infix], field_json[fields::nested],
|
||||
field_json[fields::nested_array], field_json[fields::num_dim], vec_dist,
|
||||
field_json[fields::reference], field_json[fields::embed_from].get<std::vector<std::string>>(),
|
||||
field_json[fields::model_parameters])
|
||||
field_json[fields::reference], field_json[fields::embed][fields::from].get<std::vector<std::string>>(),
|
||||
field_json[fields::embed][fields::model_config])
|
||||
);
|
||||
|
||||
if (!field_json[fields::reference].get<std::string>().empty()) {
|
||||
|
@ -6349,7 +6349,7 @@ Option<bool> Index::batch_embed_fields(std::vector<nlohmann::json*>& documents,
|
||||
for(const auto& field : embedding_fields) {
|
||||
std::vector<std::string> text_to_embed;
|
||||
for(auto& document : documents) {
|
||||
std::string text = TextEmbedderManager::get_instance().get_indexing_prefix(field.model_parameters);
|
||||
std::string text = TextEmbedderManager::get_instance().get_indexing_prefix(field.model_config);
|
||||
for(const auto& field_name : field.embed_from) {
|
||||
auto field_it = search_schema.find(field_name);
|
||||
if(field_it.value().type == field_types::STRING) {
|
||||
@ -6363,7 +6363,7 @@ Option<bool> Index::batch_embed_fields(std::vector<nlohmann::json*>& documents,
|
||||
text_to_embed.push_back(text);
|
||||
}
|
||||
TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance();
|
||||
auto embedder = embedder_manager.get_text_embedder(field.model_parameters);
|
||||
auto embedder = embedder_manager.get_text_embedder(field.model_config);
|
||||
auto embedding_op = embedder->batch_embed(text_to_embed);
|
||||
|
||||
if(!embedding_op.ok()) {
|
||||
|
@ -6,18 +6,18 @@ TextEmbedderManager& TextEmbedderManager::get_instance() {
|
||||
return instance;
|
||||
}
|
||||
|
||||
TextEmbedder* TextEmbedderManager::get_text_embedder(const nlohmann::json& model_parameters) {
|
||||
TextEmbedder* TextEmbedderManager::get_text_embedder(const nlohmann::json& model_config) {
|
||||
std::unique_lock<std::mutex> lock(text_embedders_mutex);
|
||||
const std::string& model_name = model_parameters.at("model_name");
|
||||
const std::string& model_name = model_config.at("model_name");
|
||||
if(text_embedders[model_name] == nullptr) {
|
||||
if(model_parameters.count("api_key") == 0) {
|
||||
if(model_config.count("api_key") == 0) {
|
||||
if(is_public_model(model_name)) {
|
||||
// download the model if it doesn't exist
|
||||
download_public_model(model_name);
|
||||
}
|
||||
text_embedders[model_name] = std::make_shared<TextEmbedder>(get_model_name_without_namespace(model_name));
|
||||
} else {
|
||||
text_embedders[model_name] = std::make_shared<TextEmbedder>(model_name, model_parameters.at("api_key").get<std::string>());
|
||||
text_embedders[model_name] = std::make_shared<TextEmbedder>(model_name, model_config.at("api_key").get<std::string>());
|
||||
}
|
||||
}
|
||||
return text_embedders[model_name].get();
|
||||
@ -35,11 +35,11 @@ void TextEmbedderManager::delete_all_text_embedders() {
|
||||
text_embedders.clear();
|
||||
}
|
||||
|
||||
const TokenizerType TextEmbedderManager::get_tokenizer_type(const nlohmann::json& model_parameters) {
|
||||
if(model_parameters.find("model_type") == model_parameters.end()) {
|
||||
const TokenizerType TextEmbedderManager::get_tokenizer_type(const nlohmann::json& model_config) {
|
||||
if(model_config.find("model_type") == model_config.end()) {
|
||||
return TokenizerType::bert;
|
||||
} else {
|
||||
std::string tokenizer_type = model_parameters.at("model_type").get<std::string>();
|
||||
std::string tokenizer_type = model_config.at("model_type").get<std::string>();
|
||||
if(tokenizer_type == "distilbert") {
|
||||
return TokenizerType::distilbert;
|
||||
} else if(tokenizer_type == "xlm_roberta") {
|
||||
@ -50,12 +50,12 @@ const TokenizerType TextEmbedderManager::get_tokenizer_type(const nlohmann::json
|
||||
}
|
||||
}
|
||||
|
||||
const std::string TextEmbedderManager::get_indexing_prefix(const nlohmann::json& model_parameters) {
|
||||
const std::string TextEmbedderManager::get_indexing_prefix(const nlohmann::json& model_config) {
|
||||
std::string val;
|
||||
if(is_public_model(model_parameters["model_name"].get<std::string>())) {
|
||||
val = public_models[model_parameters["model_name"].get<std::string>()].indexing_prefix;
|
||||
if(is_public_model(model_config["model_name"].get<std::string>())) {
|
||||
val = public_models[model_config["model_name"].get<std::string>()].indexing_prefix;
|
||||
} else {
|
||||
val = model_parameters.count("indexing_prefix") == 0 ? "" : model_parameters["indexing_prefix"].get<std::string>();
|
||||
val = model_config.count("indexing_prefix") == 0 ? "" : model_config["indexing_prefix"].get<std::string>();
|
||||
}
|
||||
if(!val.empty()) {
|
||||
val += " ";
|
||||
@ -64,12 +64,12 @@ const std::string TextEmbedderManager::get_indexing_prefix(const nlohmann::json&
|
||||
return val;
|
||||
}
|
||||
|
||||
const std::string TextEmbedderManager::get_query_prefix(const nlohmann::json& model_parameters) {
|
||||
const std::string TextEmbedderManager::get_query_prefix(const nlohmann::json& model_config) {
|
||||
std::string val;
|
||||
if(is_public_model(model_parameters["model_name"].get<std::string>())) {
|
||||
val = public_models[model_parameters["model_name"].get<std::string>()].query_prefix;
|
||||
if(is_public_model(model_config["model_name"].get<std::string>())) {
|
||||
val = public_models[model_config["model_name"].get<std::string>()].query_prefix;
|
||||
} else {
|
||||
val = model_parameters.count("query_prefix") == 0 ? "" : model_parameters["query_prefix"].get<std::string>();
|
||||
val = model_config.count("query_prefix") == 0 ? "" : model_config["query_prefix"].get<std::string>();
|
||||
}
|
||||
if(!val.empty()) {
|
||||
val += " ";
|
||||
|
@ -1597,9 +1597,10 @@ TEST_F(CollectionAllFieldsTest, EmbedFromFieldJSONInvalidField) {
|
||||
nlohmann::json field_json;
|
||||
field_json["name"] = "embedding";
|
||||
field_json["type"] = "float[]";
|
||||
field_json["embed_from"] = {"name"};
|
||||
field_json["model_parameters"] = nlohmann::json::object();
|
||||
field_json["model_parameters"]["model_name"] = "ts/e5-small";
|
||||
field_json["embed"] = nlohmann::json::object();
|
||||
field_json["embed"]["from"] = {"name"};
|
||||
field_json["embed"]["model_config"] = nlohmann::json::object();
|
||||
field_json["embed"]["model_config"]["model_name"] = "ts/e5-small";
|
||||
|
||||
std::vector<field> fields;
|
||||
std::string fallback_field_type;
|
||||
@ -1609,35 +1610,18 @@ TEST_F(CollectionAllFieldsTest, EmbedFromFieldJSONInvalidField) {
|
||||
auto field_op = field::json_fields_to_fields(false, arr, fallback_field_type, fields);
|
||||
|
||||
ASSERT_FALSE(field_op.ok());
|
||||
ASSERT_EQ("Property `embed_from` can only refer to string or string array fields.", field_op.error());
|
||||
ASSERT_EQ("Property `embed.from` can only refer to string or string array fields.", field_op.error());
|
||||
}
|
||||
|
||||
// TEST_F(CollectionAllFieldsTest, EmbedFromFieldNoModelDir) {
|
||||
// TextEmbedderManager::model_dir = std::string();
|
||||
// nlohmann::json field_json;
|
||||
// field_json["name"] = "embedding";
|
||||
// field_json["type"] = "float[]";
|
||||
// field_json["embed_from"] = {"name"};
|
||||
|
||||
// std::vector<field> fields;
|
||||
// std::string fallback_field_type;
|
||||
// auto arr = nlohmann::json::array();
|
||||
// arr.push_back(field_json);
|
||||
|
||||
// auto field_op = field::json_fields_to_fields(false, arr, fallback_field_type, fields);
|
||||
|
||||
// ASSERT_FALSE(field_op.ok());
|
||||
// ASSERT_EQ("Text embedding is not enabled. Please set `model-dir` at startup.", field_op.error());
|
||||
// }
|
||||
|
||||
TEST_F(CollectionAllFieldsTest, EmbedFromNotArray) {
|
||||
TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
|
||||
nlohmann::json field_json;
|
||||
field_json["name"] = "embedding";
|
||||
field_json["type"] = "float[]";
|
||||
field_json["embed_from"] = "name";
|
||||
field_json["model_parameters"] = nlohmann::json::object();
|
||||
field_json["model_parameters"]["model_name"] = "ts/e5-small";
|
||||
field_json["embed"] = nlohmann::json::object();
|
||||
field_json["embed"]["from"] = "name";
|
||||
field_json["embed"]["model_config"] = nlohmann::json::object();
|
||||
field_json["embed"]["model_config"]["model_name"] = "ts/e5-small";
|
||||
|
||||
std::vector<field> fields;
|
||||
std::string fallback_field_type;
|
||||
@ -1647,7 +1631,7 @@ TEST_F(CollectionAllFieldsTest, EmbedFromNotArray) {
|
||||
auto field_op = field::json_fields_to_fields(false, arr, fallback_field_type, fields);
|
||||
|
||||
ASSERT_FALSE(field_op.ok());
|
||||
ASSERT_EQ("Property `embed_from` must be an array.", field_op.error());
|
||||
ASSERT_EQ("Property `embed.from` must be an array.", field_op.error());
|
||||
}
|
||||
|
||||
TEST_F(CollectionAllFieldsTest, ModelParametersWithoutEmbedFrom) {
|
||||
@ -1655,8 +1639,8 @@ TEST_F(CollectionAllFieldsTest, ModelParametersWithoutEmbedFrom) {
|
||||
nlohmann::json field_json;
|
||||
field_json["name"] = "embedding";
|
||||
field_json["type"] = "float[]";
|
||||
field_json["model_parameters"] = nlohmann::json::object();
|
||||
field_json["model_parameters"]["model_name"] = "ts/e5-small";
|
||||
field_json["embed"]["model_config"] = nlohmann::json::object();
|
||||
field_json["embed"]["model_config"]["model_name"] = "ts/e5-small";
|
||||
|
||||
std::vector<field> fields;
|
||||
std::string fallback_field_type;
|
||||
@ -1665,7 +1649,7 @@ TEST_F(CollectionAllFieldsTest, ModelParametersWithoutEmbedFrom) {
|
||||
|
||||
auto field_op = field::json_fields_to_fields(false, arr, fallback_field_type, fields);
|
||||
ASSERT_FALSE(field_op.ok());
|
||||
ASSERT_EQ("Property `model_parameters` can only be used with `embed_from`.", field_op.error());
|
||||
ASSERT_EQ("Property `embed` must contain a `from` property.", field_op.error());
|
||||
}
|
||||
|
||||
|
||||
@ -1675,7 +1659,7 @@ TEST_F(CollectionAllFieldsTest, EmbedFromBasicValid) {
|
||||
|
||||
field embedding = field("embedding", field_types::FLOAT_ARRAY, false);
|
||||
embedding.embed_from.push_back("name");
|
||||
embedding.model_parameters["model_name"] = "ts/e5-small";
|
||||
embedding.model_config["model_name"] = "ts/e5-small";
|
||||
std::vector<field> fields = {field("name", field_types::STRING, false),
|
||||
embedding};
|
||||
auto obj_coll_op = collectionManager.create_collection("obj_coll", 1, fields, "", 0, field_types::AUTO);
|
||||
@ -1700,9 +1684,10 @@ TEST_F(CollectionAllFieldsTest, WrongDataTypeForEmbedFrom) {
|
||||
nlohmann::json field_json;
|
||||
field_json["name"] = "embedding";
|
||||
field_json["type"] = "float[]";
|
||||
field_json["embed_from"] = {"age"};
|
||||
field_json["model_parameters"] = nlohmann::json::object();
|
||||
field_json["model_parameters"]["model_name"] = "ts/e5-small";
|
||||
field_json["embed"] = nlohmann::json::object();
|
||||
field_json["embed"]["from"] = {"age"};
|
||||
field_json["embed"]["model_config"] = nlohmann::json::object();
|
||||
field_json["embed"]["model_config"]["model_name"] = "ts/e5-small";
|
||||
|
||||
std::vector<field> fields;
|
||||
std::string fallback_field_type;
|
||||
@ -1713,5 +1698,5 @@ TEST_F(CollectionAllFieldsTest, WrongDataTypeForEmbedFrom) {
|
||||
arr.push_back(field_json);
|
||||
auto field_op = field::json_fields_to_fields(false, arr, fallback_field_type, fields);
|
||||
ASSERT_FALSE(field_op.ok());
|
||||
ASSERT_EQ("Property `embed_from` can only refer to string or string array fields.", field_op.error());
|
||||
ASSERT_EQ("Property `embed.from` can only refer to string or string array fields.", field_op.error());
|
||||
}
|
@ -1464,7 +1464,7 @@ TEST_F(CollectionSchemaChangeTest, UpdateSchemaWithNewEmbeddingField) {
|
||||
|
||||
nlohmann::json update_schema = R"({
|
||||
"fields": [
|
||||
{"name": "embedding", "type":"float[]", "embed_from": ["names"], "model_parameters": {"model_name": "ts/e5-small"}}
|
||||
{"name": "embedding", "type":"float[]", "embed":{"from": ["names"], "model_config": {"model_name": "ts/e5-small"}}}
|
||||
]
|
||||
})"_json;
|
||||
|
||||
@ -1489,7 +1489,7 @@ TEST_F(CollectionSchemaChangeTest, DropFieldUsedForEmbedding) {
|
||||
"fields": [
|
||||
{"name": "names", "type": "string[]"},
|
||||
{"name": "category", "type":"string"},
|
||||
{"name": "embedding", "type":"float[]", "embed_from": ["names","category"], "model_parameters": {"model_name": "ts/e5-small"}}
|
||||
{"name": "embedding", "type":"float[]", "embed":{"from": ["names","category"], "model_config": {"model_name": "ts/e5-small"}}}
|
||||
]
|
||||
})"_json;
|
||||
|
||||
@ -1537,7 +1537,7 @@ TEST_F(CollectionSchemaChangeTest, EmbeddingFieldsMapTest) {
|
||||
"name": "objects",
|
||||
"fields": [
|
||||
{"name": "name", "type": "string"},
|
||||
{"name": "embedding", "type":"float[]", "embed_from": ["name"], "model_parameters": {"model_name": "ts/e5-small"}}
|
||||
{"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/e5-small"}}}
|
||||
]
|
||||
})"_json;
|
||||
|
||||
|
@ -4617,7 +4617,7 @@ TEST_F(CollectionTest, SemanticSearchTest) {
|
||||
"name": "objects",
|
||||
"fields": [
|
||||
{"name": "name", "type": "string"},
|
||||
{"name": "embedding", "type":"float[]", "embed_from": ["name"], "model_parameters": {"model_name": "ts/e5-small"}}
|
||||
{"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/e5-small"}}}
|
||||
]
|
||||
})"_json;
|
||||
|
||||
@ -4651,7 +4651,7 @@ TEST_F(CollectionTest, InvalidSemanticSearch) {
|
||||
"name": "objects",
|
||||
"fields": [
|
||||
{"name": "name", "type": "string"},
|
||||
{"name": "embedding", "type":"float[]", "embed_from": ["name"], "model_parameters": {"model_name": "ts/e5-small"}}
|
||||
{"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/e5-small"}}}
|
||||
]
|
||||
})"_json;
|
||||
|
||||
@ -4680,7 +4680,7 @@ TEST_F(CollectionTest, HybridSearch) {
|
||||
"name": "objects",
|
||||
"fields": [
|
||||
{"name": "name", "type": "string"},
|
||||
{"name": "embedding", "type":"float[]", "embed_from": ["name"], "model_parameters": {"model_name": "ts/e5-small"}}
|
||||
{"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/e5-small"}}}
|
||||
]
|
||||
})"_json;
|
||||
|
||||
@ -4713,7 +4713,7 @@ TEST_F(CollectionTest, HybridSearch) {
|
||||
// "name": "objects",
|
||||
// "fields": [
|
||||
// {"name": "name", "type": "string"},
|
||||
// {"name": "embedding", "type":"float[]", "embed_from": ["name"]}
|
||||
// {"name": "embedding", "type":"float[]", "embed":{"from": ["name"]}
|
||||
// ]
|
||||
// })"_json;
|
||||
|
||||
@ -4741,7 +4741,7 @@ TEST_F(CollectionTest, HybridSearchRankFusionTest) {
|
||||
"name": "objects",
|
||||
"fields": [
|
||||
{"name": "name", "type": "string"},
|
||||
{"name": "embedding", "type":"float[]", "embed_from": ["name"], "model_parameters": {"model_name": "ts/e5-small"}}
|
||||
{"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/e5-small"}}}
|
||||
]
|
||||
})"_json;
|
||||
|
||||
@ -4814,7 +4814,7 @@ TEST_F(CollectionTest, WildcardSearchWithEmbeddingField) {
|
||||
"name": "objects",
|
||||
"fields": [
|
||||
{"name": "name", "type": "string"},
|
||||
{"name": "embedding", "type":"float[]", "embed_from": ["name"], "model_parameters": {"model_name": "ts/e5-small"}}
|
||||
{"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/e5-small"}}}
|
||||
]
|
||||
})"_json;
|
||||
|
||||
@ -4847,7 +4847,7 @@ TEST_F(CollectionTest, EmbedStringArrayField) {
|
||||
"name": "objects",
|
||||
"fields": [
|
||||
{"name": "names", "type": "string[]"},
|
||||
{"name": "embedding", "type":"float[]", "embed_from": ["names"], "model_parameters": {"model_name": "ts/e5-small"}}
|
||||
{"name": "embedding", "type":"float[]", "embed":{"from": ["names"], "model_config": {"model_name": "ts/e5-small"}}}
|
||||
]
|
||||
})"_json;
|
||||
|
||||
@ -4872,7 +4872,7 @@ TEST_F(CollectionTest, MissingFieldForEmbedding) {
|
||||
"fields": [
|
||||
{"name": "names", "type": "string[]"},
|
||||
{"name": "category", "type": "string", "optional": true},
|
||||
{"name": "embedding", "type":"float[]", "embed_from": ["names", "category"], "model_parameters": {"model_name": "ts/e5-small"}}
|
||||
{"name": "embedding", "type":"float[]", "embed":{"from": ["names", "category"], "model_config": {"model_name": "ts/e5-small"}}}
|
||||
]
|
||||
})"_json;
|
||||
|
||||
@ -4898,7 +4898,7 @@ TEST_F(CollectionTest, WrongTypeForEmbedding) {
|
||||
"name": "objects",
|
||||
"fields": [
|
||||
{"name": "category", "type": "string"},
|
||||
{"name": "embedding", "type":"float[]", "embed_from": ["category"], "model_parameters": {"model_name": "ts/e5-small"}}
|
||||
{"name": "embedding", "type":"float[]", "embed":{"from": ["category"], "model_config": {"model_name": "ts/e5-small"}}}
|
||||
]
|
||||
})"_json;
|
||||
|
||||
@ -4921,7 +4921,7 @@ TEST_F(CollectionTest, WrongTypeOfElementForEmbeddingInStringArray) {
|
||||
"name": "objects",
|
||||
"fields": [
|
||||
{"name": "category", "type": "string[]"},
|
||||
{"name": "embedding", "type":"float[]", "embed_from": ["category"], "model_parameters": {"model_name": "ts/e5-small"}}
|
||||
{"name": "embedding", "type":"float[]", "embed":{"from": ["category"], "model_config": {"model_name": "ts/e5-small"}}}
|
||||
]
|
||||
})"_json;
|
||||
|
||||
@ -4944,7 +4944,7 @@ TEST_F(CollectionTest, UpdateEmbeddingsForUpdatedDocument) {
|
||||
"name": "objects",
|
||||
"fields": [
|
||||
{"name": "name", "type": "string"},
|
||||
{"name": "embedding", "type":"float[]", "embed_from": ["name"], "model_parameters": {"model_name": "ts/e5-small"}}
|
||||
{"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/e5-small"}}}
|
||||
]
|
||||
})"_json;
|
||||
|
||||
@ -4991,7 +4991,7 @@ TEST_F(CollectionTest, DISABLED_CreateOpenAIEmbeddingField) {
|
||||
"name": "objects",
|
||||
"fields": [
|
||||
{"name": "name", "type": "string"},
|
||||
{"name": "embedding", "type":"float[]", "embed_from": ["name"], "model_parameters": {"model_name": "openai/text-embedding-ada-002"}}
|
||||
{"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "openai/text-embedding-ada-002"}}
|
||||
]
|
||||
})"_json;
|
||||
|
||||
@ -5001,15 +5001,15 @@ TEST_F(CollectionTest, DISABLED_CreateOpenAIEmbeddingField) {
|
||||
}
|
||||
|
||||
auto api_key = std::string(std::getenv("api_key"));
|
||||
schema["fields"][1]["model_parameters"]["api_key"] = api_key;
|
||||
schema["fields"][1]["model_config"]["api_key"] = api_key;
|
||||
TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
|
||||
auto op = collectionManager.create_collection(schema);
|
||||
ASSERT_TRUE(op.ok());
|
||||
auto summary = op.get()->get_summary_json();
|
||||
ASSERT_EQ("openai/text-embedding-ada-002", summary["fields"][1]["model_parameters"]["model_name"]);
|
||||
ASSERT_EQ("openai/text-embedding-ada-002", summary["fields"][1]["model_config"]["model_name"]);
|
||||
ASSERT_EQ(1536, summary["fields"][1]["num_dim"]);
|
||||
// make sure api_key is <hidden>
|
||||
ASSERT_EQ("<hidden>", summary["fields"][1]["model_parameters"]["api_key"]);
|
||||
ASSERT_EQ("<hidden>", summary["fields"][1]["model_config"]["api_key"]);
|
||||
|
||||
nlohmann::json doc;
|
||||
doc["name"] = "butter";
|
||||
@ -5019,4 +5019,37 @@ TEST_F(CollectionTest, DISABLED_CreateOpenAIEmbeddingField) {
|
||||
ASSERT_EQ(1536, add_op.get()["embedding"].size());
|
||||
}
|
||||
|
||||
TEST_F(CollectionTest, MoreThganOneEmbeddingField) {
|
||||
nlohmann::json schema = R"({
|
||||
"name": "objects",
|
||||
"fields": [
|
||||
{"name": "name", "type": "string"},
|
||||
{"name": "name2", "type": "string"},
|
||||
{"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/e5-small"}}},
|
||||
{"name": "embedding2", "type":"float[]", "embed":{"from": ["name2"], "model_config": {"model_name": "ts/e5-small"}}}
|
||||
]
|
||||
})"_json;
|
||||
|
||||
TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
|
||||
|
||||
auto op = collectionManager.create_collection(schema);
|
||||
ASSERT_TRUE(op.ok());
|
||||
|
||||
auto coll = op.get();
|
||||
|
||||
nlohmann::json doc;
|
||||
doc["name"] = "butter";
|
||||
doc["name2"] = "butterball";
|
||||
|
||||
auto add_op = validator_t::validate_embed_fields(doc, op.get()->get_embedding_fields(), op.get()->get_schema(), true);
|
||||
|
||||
ASSERT_TRUE(add_op.ok());
|
||||
spp::sparse_hash_set<std::string> dummy_include_exclude;
|
||||
|
||||
auto search_res_op = coll->search("butter", {"name", "embedding", "embedding2"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, dummy_include_exclude, dummy_include_exclude, 10, "", 30, 4, "");
|
||||
|
||||
ASSERT_FALSE(search_res_op.ok());
|
||||
|
||||
ASSERT_EQ("Only one embedding field is allowed in the query.", search_res_op.error());
|
||||
}
|
||||
|
||||
|
@ -683,7 +683,7 @@ TEST_F(CollectionVectorTest, HybridSearchWithExplicitVector) {
|
||||
"name": "coll1",
|
||||
"fields": [
|
||||
{"name": "name", "type": "string"},
|
||||
{"name": "vec", "type": "float[]", "embed_from": ["name"], "model_parameters": {"model_name": "ts/e5-small"}}
|
||||
{"name": "vec", "type": "float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/e5-small"}}}
|
||||
]
|
||||
})"_json;
|
||||
|
||||
@ -781,7 +781,7 @@ TEST_F(CollectionVectorTest, EmbeddingFieldVectorIndexTest) {
|
||||
"name": "objects",
|
||||
"fields": [
|
||||
{"name": "name", "type": "string"},
|
||||
{"name": "embedding", "type":"float[]", "embed_from": ["name"], "model_parameters": {"model_name": "ts/e5-small"}}
|
||||
{"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/e5-small"}}}
|
||||
]
|
||||
})"_json;
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user