typesense/include/text_embedder.h
Krunal Gandhi b256736986
Some checks failed
tests / test (push) Has been cancelled
allow updating remote model api_key (#1944)
* allow updating remote model api_key

* add nullptr checks and unique_lock

* refactor field updation flow

* Revert "refactor field updation flow"

This reverts commit a6a9847be80ce3d8cee7ea944f5e9453091a949f.

* refactor api_key updation flow

* add locks in remote embedder, update api_key after validation

* fix parsing of embed field

* fix wrong params and parse

* update minor changes, update entry in text_embedders

* add comment

* update embeded_fields with new api_key

* add lock for updating embedded fields

---------

Co-authored-by: Kishore Nallan <kishorenc@gmail.com>
2024-09-25 08:58:32 +05:30

58 lines
2.2 KiB
C++

#pragma once
#include <sentencepiece_processor.h>
#include <core/session/onnxruntime_cxx_api.h>
#include <tokenizer/bert_tokenizer.hpp>
#include <vector>
#include "option.h"
#include "text_embedder_tokenizer.h"
#include "text_embedder_remote.h"
class TextEmbedder {
public:
// Constructor for local or public models
TextEmbedder(const std::string& model_path);
// Constructor for remote models
TextEmbedder(const nlohmann::json& model_config, size_t num_dims, const bool has_custom_dims = false);
~TextEmbedder();
embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_tries = 2);
std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size = 200,
const size_t remote_embedding_timeout_ms = 60000, const size_t remote_embedding_num_tries = 2);
const std::string& get_vocab_file_name() const;
const size_t get_num_dim() const;
bool is_remote() {
return remote_embedder_ != nullptr;
}
Option<bool> validate();
std::shared_ptr<Ort::Session> get_session() {
return session_;
}
std::shared_ptr<Ort::Env> get_env() {
return env_;
}
const TokenizerType get_tokenizer_type() {
return tokenizer_->get_tokenizer_type();
}
bool update_remote_embedder_apikey(const std::string& api_key) {
return remote_embedder_->update_api_key(api_key);
}
private:
std::shared_ptr<Ort::Session> session_;
std::shared_ptr<Ort::Env> env_;
encoded_input_t Encode(const std::string& text);
batch_encoded_input_t batch_encode(const std::vector<std::string>& inputs);
std::unique_ptr<TextEmbeddingTokenizer> tokenizer_;
std::unique_ptr<RemoteEmbedder> remote_embedder_;
std::string vocab_file_name;
static std::vector<float> mean_pooling(const std::vector<std::vector<float>>& input, const std::vector<int64_t>& attention_mask);
std::string output_tensor_name;
size_t num_dim;
std::mutex mutex_;
};