diff --git a/BUILD b/BUILD
index c3e679ec..0b9f865a 100644
--- a/BUILD
+++ b/BUILD
@@ -212,7 +212,7 @@ mkdir -p $INSTALLDIR/lib/_deps/google_nsync-build
 cp $BUILD_TMPDIR/_deps/onnx-build/libonnx.a $INSTALLDIR/lib/_deps/onnx-build
 cp $BUILD_TMPDIR/_deps/onnx-build/libonnx_proto.a $INSTALLDIR/lib/_deps/onnx-build
 cp $BUILD_TMPDIR/_deps/re2-build/libre2.a $INSTALLDIR/lib/_deps/re2-build
-cp $BUILD_TMPDIR/_deps/abseil_cpp-build/. $INSTALLDIR/lib/_deps/abseil_cpp-build -r
+cp -r $BUILD_TMPDIR/_deps/abseil_cpp-build/. $INSTALLDIR/lib/_deps/abseil_cpp-build
 cp $BUILD_TMPDIR/_deps/google_nsync-build/libnsync_cpp.a $INSTALLDIR/lib/_deps/google_nsync-build
 cp $BUILD_TMPDIR/_deps/pytorch_cpuinfo-build/deps/clog/libclog.a $INSTALLDIR/lib/_deps/pytorch_cpuinfo-build/deps/clog
 cp $BUILD_TMPDIR/_deps/pytorch_cpuinfo-build/libcpuinfo.a $INSTALLDIR/lib/_deps/pytorch_cpuinfo-build
diff --git a/include/http_data.h b/include/http_data.h
index 4dd08917..9a2bedc9 100644
--- a/include/http_data.h
+++ b/include/http_data.h
@@ -206,6 +206,49 @@ public:
     virtual ~req_state_t() = default;
 };
 
+struct stream_response_state_t {
+private:
+
+    h2o_req_t* req = nullptr;
+
+public:
+
+    bool is_req_early_exit = false;
+
+    bool is_res_start = true;
+    h2o_send_state_t send_state = H2O_SEND_STATE_IN_PROGRESS;
+
+    std::string res_body;
+    h2o_iovec_t res_buff;
+
+    std::string res_content_type;
+    int status = 0;
+    const char* reason = nullptr;
+
+    h2o_generator_t* generator = nullptr;
+
+    void set_response(uint32_t status_code, const std::string& content_type, std::string& body) {
+        std::string().swap(res_body);
+        res_body = std::move(body);
+        res_buff = h2o_iovec_t{.base = res_body.data(), .len = res_body.size()};
+
+        if(is_res_start) {
+            res_content_type = std::move(content_type);
+            status = (int)status_code;
+            reason = http_res::get_status_reason(status_code);
+            is_res_start = false;
+        }
+    }
+
+    void set_req(h2o_req_t* _req) {
+        req = _req;
+    }
+
+    h2o_req_t* get_req() {
+        return req;
+    }
+};
+
 struct http_req {
     static constexpr const char* AUTH_HEADER = "x-typesense-api-key";
     static constexpr const char* USER_HEADER = "x-typesense-user-id";
@@ -252,6 +295,9 @@ struct http_req {
     z_stream zs;
     bool zstream_initialized = false;
 
+    // stores http lib related datastructures to avoid race conditions between indexing and http write threads
+    stream_response_state_t res_state;
+
     http_req(): _req(nullptr), route_hash(1),
                 first_chunk_aggregate(true), last_chunk_aggregate(false),
                 chunk_len(0), body_index(0), data(nullptr), ready(false), log_index(0),
diff --git a/include/http_proxy.h b/include/http_proxy.h
index 4b8315b3..851daa91 100644
--- a/include/http_proxy.h
+++ b/include/http_proxy.h
@@ -36,6 +36,7 @@ class HttpProxy {
 private:
     HttpProxy();
     ~HttpProxy() = default;
+    http_proxy_res_t call(const std::string& url, const std::string& method, const std::string& body = "", const std::unordered_map<std::string, std::string>& headers = {});
 
     // lru cache for http requests
diff --git a/include/http_server.h b/include/http_server.h
index ae017abf..7443b466 100644
--- a/include/http_server.h
+++ b/include/http_server.h
@@ -40,50 +40,6 @@ struct h2o_custom_generator_t {
     }
 };
 
-struct stream_response_state_t {
-private:
-
-    h2o_req_t* req = nullptr;
-
-public:
-
-    bool is_req_early_exit = false;
-
-    bool is_res_start = true;
-    h2o_send_state_t send_state = H2O_SEND_STATE_IN_PROGRESS;
-
-    std::string res_body;
-    h2o_iovec_t res_buff;
-
-    h2o_iovec_t res_content_type{};
-    int status = 0;
-    const char* reason = nullptr;
-
-    h2o_generator_t* generator = nullptr;
-
-    explicit stream_response_state_t(h2o_req_t* _req): req(_req) {
-        if(req != nullptr) {
-            is_res_start = (req->res.status == 0);
-        }
-    }
-
-    void set_response(uint32_t status_code, const std::string& content_type, std::string& body) {
-        std::string().swap(res_body);
-        res_body = std::move(body);
-        res_buff = h2o_iovec_t{.base = res_body.data(), .len = res_body.size()};
-
-        if(is_res_start) {
-            res_content_type = h2o_strdup(&req->pool, content_type.c_str(), SIZE_MAX);
-            status = status_code;
-            reason = http_res::get_status_reason(status_code);
-        }
-    }
-
-    h2o_req_t* get_req() {
-        return req;
-    }
-};
-
 struct deferred_req_res_t {
     const std::shared_ptr<http_req> req;
     const std::shared_ptr<http_res> res;
@@ -110,13 +66,9 @@ public:
     // used to manage lifecycle of async actions
     const bool destroy_after_use;
 
-    // stores http lib related datastructures to avoid race conditions between indexing and http write threads
-    stream_response_state_t res_state;
-
     async_req_res_t(const std::shared_ptr<http_req>& h_req, const std::shared_ptr<http_res>& h_res,
                     const bool destroy_after_use) :
-            req(h_req), res(h_res), destroy_after_use(destroy_after_use),
-            res_state((std::shared_lock(res->mres), h_req->is_diposed ? nullptr : h_req->_req)) {
+            req(h_req), res(h_res), destroy_after_use(destroy_after_use) {
 
         std::shared_lock lk(res->mres);
 
@@ -124,12 +76,10 @@ public:
             return;
         }
 
-        // ***IMPORTANT***
-        // We limit writing to fields of `res_state.req` to prevent race conditions with http thread
-        // Check `HttpServer::stream_response()` for overlapping writes.
-
         h2o_custom_generator_t* res_generator = static_cast<h2o_custom_generator_t*>(res->generator.load());
+        auto& res_state = req->res_state;
+        res_state.set_req(h_req->is_diposed ? nullptr : h_req->_req);
 
         res_state.is_req_early_exit = (res_generator->rpath->async_req && res->final && !req->last_chunk_aggregate);
         res_state.send_state = res->final ? H2O_SEND_STATE_FINAL : H2O_SEND_STATE_IN_PROGRESS;
         res_state.generator = (res_generator == nullptr) ? nullptr : &res_generator->h2o_generator;
@@ -147,6 +97,10 @@ public:
     void res_notify() {
         return res->notify();
     }
+
+    stream_response_state_t& get_res_state() {
+        return req->res_state;
+    }
 };
 
 struct defer_processing_t {
diff --git a/src/http_proxy.cpp b/src/http_proxy.cpp
index b2d2c684..eb562849 100644
--- a/src/http_proxy.cpp
+++ b/src/http_proxy.cpp
@@ -8,39 +8,58 @@ HttpProxy::HttpProxy() : cache(30s){
 }
 
-http_proxy_res_t HttpProxy::send(const std::string& url, const std::string& method, const std::string& body, const std::unordered_map<std::string, std::string>& headers) {
-    // check if url is in cache
-    uint64_t key = StringUtils::hash_wy(url.c_str(), url.size());
-    key = StringUtils::hash_combine(key, StringUtils::hash_wy(method.c_str(), method.size()));
-    key = StringUtils::hash_combine(key, StringUtils::hash_wy(body.c_str(), body.size()));
-    for (auto& header : headers) {
-        key = StringUtils::hash_combine(key, StringUtils::hash_wy(header.first.c_str(), header.first.size()));
-        key = StringUtils::hash_combine(key, StringUtils::hash_wy(header.second.c_str(), header.second.size()));
-    }
-    if (cache.contains(key)) {
-        return cache[key];
-    }
-    // if not, make http request
+http_proxy_res_t HttpProxy::call(const std::string& url, const std::string& method, const std::string& body, const std::unordered_map<std::string, std::string>& headers) {
     HttpClient& client = HttpClient::get_instance();
     http_proxy_res_t res;
-
     if(method == "GET") {
-        res.status_code = client.get_response(url, res.body, res.headers, headers, 30 * 1000);
+        res.status_code = client.get_response(url, res.body, res.headers, headers, 20 * 1000);
     } else if(method == "POST") {
-        res.status_code = client.post_response(url, body, res.body, res.headers, headers, 30 * 1000);
+        res.status_code = client.post_response(url, body, res.body, res.headers, headers, 20 * 1000);
    } else if(method == "PUT") {
-        res.status_code = client.put_response(url, body, res.body, res.headers, 30 * 1000);
+        res.status_code = client.put_response(url, body, res.body, res.headers, 20 * 1000);
     } else if(method == "DELETE") {
-        res.status_code = client.delete_response(url, res.body, res.headers, 30 * 1000);
+        res.status_code = client.delete_response(url, res.body, res.headers, 20 * 1000);
     } else {
         res.status_code = 400;
         nlohmann::json j;
         j["message"] = "Parameter `method` must be one of GET, POST, PUT, DELETE.";
         res.body = j.dump();
     }
+    return res;
+}
-
-    // add to cache
-    cache.insert(key, res);
+
+http_proxy_res_t HttpProxy::send(const std::string& url, const std::string& method, const std::string& body, const std::unordered_map<std::string, std::string>& headers) {
+    // check if url is in cache
+    uint64_t key = StringUtils::hash_wy(url.c_str(), url.size());
+    key = StringUtils::hash_combine(key, StringUtils::hash_wy(method.c_str(), method.size()));
+    key = StringUtils::hash_combine(key, StringUtils::hash_wy(body.c_str(), body.size()));
+    for(auto& header : headers){
+        key = StringUtils::hash_combine(key, StringUtils::hash_wy(header.first.c_str(), header.first.size()));
+        key = StringUtils::hash_combine(key, StringUtils::hash_wy(header.second.c_str(), header.second.size()));
+    }
+    if(cache.contains(key)){
+        return cache[key];
+    }
+
+    auto res = call(url, method, body, headers);
+
+    if(res.status_code == 500){
+        // retry
+        res = call(url, method, body, headers);
+    }
+
+    if(res.status_code == 500){
+        nlohmann::json j;
+        j["message"] = "Server error on remote server. Please try again later.";
+        res.body = j.dump();
+    }
+
+
+    // add to cache
+    if(res.status_code != 500){
+        cache.insert(key, res);
+    }
     return res;
 }
\ No newline at end of file
diff --git a/src/http_server.cpp b/src/http_server.cpp
index 265596ca..c6eb3cc2 100644
--- a/src/http_server.cpp
+++ b/src/http_server.cpp
@@ -807,9 +807,11 @@ void HttpServer::stream_response(stream_response_state_t& state) {
 
     h2o_req_t* req = state.get_req();
 
-    if(state.is_res_start) {
+    bool start_of_res = (req->res.status == 0);
+
+    if(start_of_res) {
         h2o_add_header(&req->pool, &req->res.headers, H2O_TOKEN_CONTENT_TYPE, NULL,
-                       state.res_content_type.base, state.res_content_type.len);
+                       state.res_content_type.data(), state.res_content_type.size());
         req->res.status = state.status;
         req->res.reason = state.reason;
     }
@@ -828,7 +830,7 @@ void HttpServer::stream_response(stream_response_state_t& state) {
         return ;
     }
 
-    if (state.is_res_start) {
+    if (start_of_res) {
         /*LOG(INFO) << "h2o_start_response, content_type=" << state.res_content_type
                   << ",response.status_code=" << state.res_status_code;*/
         h2o_start_response(req, state.generator);
@@ -968,7 +970,7 @@ bool HttpServer::on_stream_response_message(void *data) {
 
     // NOTE: access to `req` and `res` objects must be synchronized and wrapped by `req_res`
     if(req_res->is_alive()) {
-        stream_response(req_res->res_state);
+        stream_response(req_res->get_res_state());
     } else {
         // serialized request or generator has been disposed (underlying request is probably dead)
         req_res->req_notify();
diff --git a/src/text_embedder_remote.cpp b/src/text_embedder_remote.cpp
index 2e12f3b9..bb984e86 100644
--- a/src/text_embedder_remote.cpp
+++ b/src/text_embedder_remote.cpp
@@ -15,9 +15,9 @@ long RemoteEmbedder::call_remote_api(const std::string& method, const std::strin
                                      std::map<std::string, std::string>& headers, const std::unordered_map<std::string, std::string>& req_headers) {
     if(raft_server == nullptr || raft_server->get_leader_url().empty()) {
         if(method == "GET") {
-            return HttpClient::get_instance().get_response(url, res_body, headers, req_headers, 100000, true);
+            return HttpClient::get_instance().get_response(url, res_body, headers, req_headers, 45000, true);
         } else if(method == "POST") {
-            return HttpClient::get_instance().post_response(url, body, res_body, headers, req_headers, 100000, true);
+            return HttpClient::get_instance().post_response(url, body, res_body, headers, req_headers, 45000, true);
         } else {
             return 400;
         }
@@ -30,7 +30,7 @@ long RemoteEmbedder::call_remote_api(const std::string& method, const std::strin
     req_body["url"] = url;
     req_body["body"] = body;
     req_body["headers"] = req_headers;
-    return HttpClient::get_instance().post_response(leader_url, req_body.dump(), res_body, headers, {}, 10000, true);
+    return HttpClient::get_instance().post_response(leader_url, req_body.dump(), res_body, headers, {}, 45000, true);
 }
diff --git a/test/collection_vector_search_test.cpp b/test/collection_vector_search_test.cpp
index b6b598dc..797b610b 100644
--- a/test/collection_vector_search_test.cpp
+++ b/test/collection_vector_search_test.cpp
@@ -679,42 +679,73 @@ TEST_F(CollectionVectorTest, VectorWithNullValue) {
 }
 
 TEST_F(CollectionVectorTest, HybridSearchWithExplicitVector) {
-    nlohmann::json schema = R"({
-        "name": "coll1",
-        "fields": [
-            {"name": "name", "type": "string"},
-            {"name": "vec", "type": "float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/e5-small"}}}
-        ]
-    })"_json;
-
+    nlohmann::json schema = R"({
+        "name": "objects",
+        "fields": [
+            {"name": "name", "type": "string"},
+            {"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/e5-small"}}}
+        ]
+    })"_json;
+
     TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
 
-    Collection* coll1 = collectionManager.create_collection(schema).get();
+    auto op = collectionManager.create_collection(schema);
+    ASSERT_TRUE(op.ok());
+    Collection* coll = op.get();
 
+    nlohmann::json object;
+    object["name"] = "butter";
+    auto add_op = coll->add(object.dump());
+    ASSERT_TRUE(add_op.ok());
 
-    nlohmann::json doc;
+    object["name"] = "butterball";
+    add_op = coll->add(object.dump());
+    ASSERT_TRUE(add_op.ok());
 
-    doc["name"] = "john doe";
-    ASSERT_TRUE(coll1->add(doc.dump()).ok());
+    object["name"] = "butterfly";
+    add_op = coll->add(object.dump());
+    ASSERT_TRUE(add_op.ok());
 
-    std::string dummy_vec_string = "[0.9";
-    for (int i = 0; i < 382; i++) {
-        dummy_vec_string += ", 0.9";
+    nlohmann::json model_config = R"({
+        "model_name": "ts/e5-small"
+    })"_json;
+
+    auto query_embedding = TextEmbedderManager::get_instance().get_text_embedder(model_config).get()->Embed("butter");
+
+    std::string vec_string = "[";
+    for(size_t i = 0; i < query_embedding.embedding.size(); i++) {
+        vec_string += std::to_string(query_embedding.embedding[i]);
+        if(i != query_embedding.embedding.size() - 1) {
+            vec_string += ",";
+        }
     }
-    dummy_vec_string += ", 0.9]";
-
-    auto results_op = coll1->search("john", {"name"}, "", {}, {}, {0}, 20, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
+    vec_string += "]";
+    auto search_res_op = coll->search("butter", {"name"}, "", {}, {}, {0}, 20, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
                                  spp::sparse_hash_set<std::string>(),
                                  spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
                                  "", 10, {}, {}, {}, 0, "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback,
                                  4, {off}, 32767, 32767, 2,
-                                 false, true, "vec:(" + dummy_vec_string +")");
-    ASSERT_EQ(true, results_op.ok());
-    ASSERT_EQ(1, results_op.get()["found"].get<size_t>());
-    ASSERT_EQ(1, results_op.get()["hits"].size());
+                                 false, true, "embedding:(" + vec_string + ")");
+
+    ASSERT_TRUE(search_res_op.ok());
+    auto search_res = search_res_op.get();
+    ASSERT_EQ(3, search_res["found"].get<size_t>());
+    ASSERT_EQ(3, search_res["hits"].size());
+    // Hybrid search with rank fusion order:
+    // 1. butter (1/1 * 0.7) + (1/1 * 0.3) = 1
+    // 2. butterfly (1/2 * 0.7) + (1/3 * 0.3) = 0.45
+    // 3. butterball (1/3 * 0.7) + (1/2 * 0.3) = 0.383
+    ASSERT_EQ("butter", search_res["hits"][0]["document"]["name"].get<std::string>());
+    ASSERT_EQ("butterfly", search_res["hits"][1]["document"]["name"].get<std::string>());
+    ASSERT_EQ("butterball", search_res["hits"][2]["document"]["name"].get<std::string>());
+
+    ASSERT_FLOAT_EQ((1.0/1.0 * 0.7) + (1.0/1.0 * 0.3), search_res["hits"][0]["hybrid_search_info"]["rank_fusion_score"].get<float>());
+    ASSERT_FLOAT_EQ((1.0/2.0 * 0.7) + (1.0/3.0 * 0.3), search_res["hits"][1]["hybrid_search_info"]["rank_fusion_score"].get<float>());
+    ASSERT_FLOAT_EQ((1.0/3.0 * 0.7) + (1.0/2.0 * 0.3), search_res["hits"][2]["hybrid_search_info"]["rank_fusion_score"].get<float>());
 }
+
 
 TEST_F(CollectionVectorTest, HybridSearchOnlyVectorMatches) {
     nlohmann::json schema = R"({
         "name": "coll1",