diff --git a/include/collection_manager.h b/include/collection_manager.h index ddde53c8..84f27ac2 100644 --- a/include/collection_manager.h +++ b/include/collection_manager.h @@ -25,6 +25,8 @@ private: static constexpr const char* COLLECTION_SORT_FIELDS_KEY = "sort_fields"; static constexpr const char* COLLECTION_TOKEN_ORDERING_FIELD_KEY = "token_ranking_field"; + std::string auth_key; + CollectionManager(); ~CollectionManager() = default; @@ -38,7 +40,9 @@ public: CollectionManager(CollectionManager const&) = delete; void operator=(CollectionManager const&) = delete; - void init(Store *store); + void init(Store *store, const std::string & auth_key); + + bool auth_key_matches(std::string auth_key_sent); Collection* create_collection(std::string name, const std::vector & search_fields, const std::vector & facet_fields, diff --git a/include/http_server.h b/include/http_server.h index a1297e10..e5ae5855 100644 --- a/include/http_server.h +++ b/include/http_server.h @@ -33,13 +33,18 @@ struct http_res { body = "{\"message\": \"" + message + "\"}"; } + void send_403() { + status_code = 403; + body = "{\"message\": \"Forbidden\"}"; + } + void send_404() { status_code = 404; body = "{\"message\": \"Not Found\"}"; } void send_409(const std::string & message) { - status_code = 400; + status_code = 409; body = "{\"message\": \"" + message + "\"}"; } @@ -63,6 +68,7 @@ struct route_path { std::string http_method; std::vector path_parts; void (*handler)(http_req & req, http_res &); + bool authenticated; inline bool operator< (const route_path& rhs) const { return true; @@ -76,9 +82,9 @@ private: static h2o_accept_ctx_t accept_ctx; static std::vector routes; - std::string listen_address; + const std::string listen_address; - uint32_t listen_port; + const uint32_t listen_port; h2o_hostconf_t *hostconf; @@ -95,18 +101,22 @@ private: static int catch_all_handler(h2o_handler_t *self, h2o_req_t *req); + static int send_403_forbidden(h2o_req_t *req); + public: HttpServer(std::string listen_address, uint32_t listen_port); ~HttpServer(); - void get(const std::string & path, void (*handler)(http_req & req, http_res &)); + void get(const std::string & path, void (*handler)(http_req & req, http_res &), bool authenticated); - void post(const std::string & path, void (*handler)(http_req &, http_res &)); + void post(const std::string & path, void (*handler)(http_req &, http_res &), bool authenticated); - void put(const std::string & path, void (*handler)(http_req &, http_res &)); + void put(const std::string & path, void (*handler)(http_req &, http_res &), bool authenticated); - void del(const std::string & path, void (*handler)(http_req &, http_res &)); + void del(const std::string & path, void (*handler)(http_req &, http_res &), bool authenticated); int run(); + + static constexpr const char* AUTH_HEADER = "x-api-key"; }; \ No newline at end of file diff --git a/src/collection_manager.cpp b/src/collection_manager.cpp index 4e2e1ddc..eebd2210 100644 --- a/src/collection_manager.cpp +++ b/src/collection_manager.cpp @@ -8,8 +8,9 @@ CollectionManager::CollectionManager() { } -void CollectionManager::init(Store *store) { +void CollectionManager::init(Store *store, const std::string & auth_key) { this->store = store; + this->auth_key = auth_key; std::string next_collection_id_str; store->get(NEXT_COLLECTION_ID_KEY, next_collection_id_str); @@ -85,6 +86,10 @@ void CollectionManager::init(Store *store) { std::cout << "Finished restoring all collections from disk." << std::endl; } +bool CollectionManager::auth_key_matches(std::string auth_key_sent) { + return (auth_key == auth_key_sent); +} + Collection* CollectionManager::create_collection(std::string name, const std::vector & search_fields, const std::vector & facet_fields, const std::vector & sort_fields, diff --git a/src/http_server.cpp b/src/http_server.cpp index 23cac504..2e386a4a 100644 --- a/src/http_server.cpp +++ b/src/http_server.cpp @@ -2,6 +2,7 @@ #include "string_utils.h" #include #include +#include h2o_globalconf_t HttpServer::config; h2o_context_t HttpServer::ctx; @@ -54,7 +55,6 @@ int HttpServer::create_listener(void) { } int HttpServer::run() { - signal(SIGPIPE, SIG_IGN); h2o_context_init(&ctx, h2o_evloop_create(), &config); @@ -85,6 +85,7 @@ const char* HttpServer::get_status_reason(uint32_t status_code) { case 200: return "OK"; case 201: return "Created"; case 400: return "Bad Request"; + case 403: return "Forbidden"; case 404: return "Not Found"; case 409: return "Conflict"; case 500: return "Internal Server Error"; @@ -161,6 +162,25 @@ int HttpServer::catch_all_handler(h2o_handler_t *self, h2o_req_t *req) { } } + if(rpath.authenticated) { + CollectionManager & collectionManager = CollectionManager::get_instance(); + ssize_t auth_header_cursor = h2o_find_header_by_str(&req->headers, AUTH_HEADER, + strlen(AUTH_HEADER), -1); + + if(auth_header_cursor == -1) { + // requires authentication, but API Key is not present in the headers + return send_403_forbidden(req); + } else { + // api key is found, let's validate + h2o_iovec_t & slot = req->headers.entries[auth_header_cursor].value; + std::string auth_key_from_header = std::string(slot.base, slot.len); + + if(!collectionManager.auth_key_matches(auth_key_from_header)) { + return send_403_forbidden(req); + } + } + } + http_req request = {query_map, req_body}; http_res response; (rpath.handler)(request, response); @@ -186,31 +206,43 @@ int HttpServer::catch_all_handler(h2o_handler_t *self, h2o_req_t *req) { return 0; } -void HttpServer::get(const std::string & path, void (*handler)(http_req &, http_res &)) { +int HttpServer::send_403_forbidden(h2o_req_t *req) { + h2o_generator_t generator = {NULL, NULL}; + std::string res_body = std::string("{\"message\": \"Forbidden - ") + AUTH_HEADER + " header is invalid or not present.\"}"; + h2o_iovec_t body = h2o_strdup(&req->pool, res_body.c_str(), SIZE_MAX); + req->res.status = 403; + req->res.reason = get_status_reason(req->res.status); + h2o_add_header(&req->pool, &req->res.headers, H2O_TOKEN_CONTENT_TYPE, H2O_STRLIT("application/json; charset=utf-8")); + h2o_start_response(req, &generator); + h2o_send(req, &body, 1, 1); + return 0; +} + +void HttpServer::get(const std::string & path, void (*handler)(http_req &, http_res &), bool authenticated) { std::vector path_parts; StringUtils::split(path, path_parts, "/"); - route_path rpath = {"GET", path_parts, handler}; + route_path rpath = {"GET", path_parts, handler, authenticated}; routes.push_back(rpath); } -void HttpServer::post(const std::string & path, void (*handler)(http_req &, http_res &)) { +void HttpServer::post(const std::string & path, void (*handler)(http_req &, http_res &), bool authenticated) { std::vector path_parts; StringUtils::split(path, path_parts, "/"); - route_path rpath = {"POST", path_parts, handler}; + route_path rpath = {"POST", path_parts, handler, authenticated}; routes.push_back(rpath); } -void HttpServer::put(const std::string & path, void (*handler)(http_req &, http_res &)) { +void HttpServer::put(const std::string & path, void (*handler)(http_req &, http_res &), bool authenticated) { std::vector path_parts; StringUtils::split(path, path_parts, "/"); - route_path rpath = {"PUT", path_parts, handler}; + route_path rpath = {"PUT", path_parts, handler, authenticated}; routes.push_back(rpath); } -void HttpServer::del(const std::string & path, void (*handler)(http_req &, http_res &)) { +void HttpServer::del(const std::string & path, void (*handler)(http_req &, http_res &), bool authenticated) { std::vector path_parts; StringUtils::split(path, path_parts, "/"); - route_path rpath = {"DELETE", path_parts, handler}; + route_path rpath = {"DELETE", path_parts, handler, authenticated}; routes.push_back(rpath); } diff --git a/src/main/typesense_server.cpp b/src/main/typesense_server.cpp index c6dba806..4aea7a79 100644 --- a/src/main/typesense_server.cpp +++ b/src/main/typesense_server.cpp @@ -5,20 +5,24 @@ int main(int argc, char **argv) { cmdline::parser options; options.add("data-dir", 'd', "Directory where data will be stored.", true); + options.add("api-auth-key", 'k', "Key for authenticating the API endpoints.", true); options.add("listen-address", 'a', "Address to which Typesense server binds.", false, "0.0.0.0"); options.add("listen-port", 'p', "Port on which Typesense server listens.", false, 8080); options.parse_check(argc, argv); Store store(options.get("data-dir")); CollectionManager & collectionManager = CollectionManager::get_instance(); - collectionManager.init(&store); + collectionManager.init(&store, options.get("api-auth-key")); - HttpServer server(options.get("listen-address"), options.get("listen-port")); + HttpServer server( + options.get("listen-address"), + options.get("listen-port") + ); - server.post("/collection", post_create_collection); - server.post("/collection/:collection", post_add_document); - server.get("/collection/:collection/search", get_search); - server.del("/collection/:collection/:id", del_remove_document); + server.post("/collection", post_create_collection, true); + server.post("/collection/:collection", post_add_document, true); + server.get("/collection/:collection/search", get_search, false); + server.del("/collection/:collection/:id", del_remove_document, true); server.run(); return 0; diff --git a/test/collection_manager_test.cpp b/test/collection_manager_test.cpp index fffc9982..f65eb0a6 100644 --- a/test/collection_manager_test.cpp +++ b/test/collection_manager_test.cpp @@ -22,7 +22,7 @@ protected: system(("rm -rf "+state_dir_path+" && mkdir -p "+state_dir_path).c_str()); store = new Store(state_dir_path); - collectionManager.init(store); + collectionManager.init(store, "auth_key"); search_fields = {field("title", field_types::STRING), field("starring", field_types::STRING)}; facet_fields = {field("starring", field_types::STRING)}; @@ -107,7 +107,7 @@ TEST_F(CollectionManagerTest, RestoreRecordsOnRestart) { // create a new collection manager to ensure that it restores the records from the disk backed store CollectionManager & collectionManager2 = CollectionManager::get_instance(); - collectionManager2.init(store); + collectionManager2.init(store, "auth_key"); collection1 = collectionManager2.get_collection("collection1"); ASSERT_NE(nullptr, collection1); diff --git a/test/collection_test.cpp b/test/collection_test.cpp index 6a7a9c8d..33560b09 100644 --- a/test/collection_test.cpp +++ b/test/collection_test.cpp @@ -22,7 +22,7 @@ protected: system(("rm -rf "+state_dir_path+" && mkdir -p "+state_dir_path).c_str()); store = new Store(state_dir_path); - collectionManager.init(store); + collectionManager.init(store, "auth_key"); std::ifstream infile(std::string(ROOT_DIR)+"test/documents.jsonl"); std::vector search_fields = {field("title", field_types::STRING)};