Implement authentication against an API auth key.

The key should be passed via X-API-KEY HTTP header.
This commit is contained in:
Kishore Nallan 2017-07-04 22:18:47 +05:30
parent 06ff49df4a
commit c471cd50c3
7 changed files with 82 additions and 27 deletions

View File

@ -25,6 +25,8 @@ private:
static constexpr const char* COLLECTION_SORT_FIELDS_KEY = "sort_fields";
static constexpr const char* COLLECTION_TOKEN_ORDERING_FIELD_KEY = "token_ranking_field";
std::string auth_key;
CollectionManager();
~CollectionManager() = default;
@ -38,7 +40,9 @@ public:
CollectionManager(CollectionManager const&) = delete;
void operator=(CollectionManager const&) = delete;
void init(Store *store);
void init(Store *store, const std::string & auth_key);
bool auth_key_matches(std::string auth_key_sent);
Collection* create_collection(std::string name, const std::vector<field> & search_fields,
const std::vector<field> & facet_fields,

View File

@ -33,13 +33,18 @@ struct http_res {
body = "{\"message\": \"" + message + "\"}";
}
void send_403() {
status_code = 403;
body = "{\"message\": \"Forbidden\"}";
}
void send_404() {
status_code = 404;
body = "{\"message\": \"Not Found\"}";
}
void send_409(const std::string & message) {
status_code = 400;
status_code = 409;
body = "{\"message\": \"" + message + "\"}";
}
@ -63,6 +68,7 @@ struct route_path {
std::string http_method;
std::vector<std::string> path_parts;
void (*handler)(http_req & req, http_res &);
bool authenticated;
inline bool operator< (const route_path& rhs) const {
return true;
@ -76,9 +82,9 @@ private:
static h2o_accept_ctx_t accept_ctx;
static std::vector<route_path> routes;
std::string listen_address;
const std::string listen_address;
uint32_t listen_port;
const uint32_t listen_port;
h2o_hostconf_t *hostconf;
@ -95,18 +101,22 @@ private:
static int catch_all_handler(h2o_handler_t *self, h2o_req_t *req);
static int send_403_forbidden(h2o_req_t *req);
public:
HttpServer(std::string listen_address, uint32_t listen_port);
~HttpServer();
void get(const std::string & path, void (*handler)(http_req & req, http_res &));
void get(const std::string & path, void (*handler)(http_req & req, http_res &), bool authenticated);
void post(const std::string & path, void (*handler)(http_req &, http_res &));
void post(const std::string & path, void (*handler)(http_req &, http_res &), bool authenticated);
void put(const std::string & path, void (*handler)(http_req &, http_res &));
void put(const std::string & path, void (*handler)(http_req &, http_res &), bool authenticated);
void del(const std::string & path, void (*handler)(http_req &, http_res &));
void del(const std::string & path, void (*handler)(http_req &, http_res &), bool authenticated);
int run();
static constexpr const char* AUTH_HEADER = "x-api-key";
};

View File

@ -8,8 +8,9 @@ CollectionManager::CollectionManager() {
}
void CollectionManager::init(Store *store) {
void CollectionManager::init(Store *store, const std::string & auth_key) {
this->store = store;
this->auth_key = auth_key;
std::string next_collection_id_str;
store->get(NEXT_COLLECTION_ID_KEY, next_collection_id_str);
@ -85,6 +86,10 @@ void CollectionManager::init(Store *store) {
std::cout << "Finished restoring all collections from disk." << std::endl;
}
bool CollectionManager::auth_key_matches(std::string auth_key_sent) {
return (auth_key == auth_key_sent);
}
Collection* CollectionManager::create_collection(std::string name, const std::vector<field> & search_fields,
const std::vector<field> & facet_fields,
const std::vector<field> & sort_fields,

View File

@ -2,6 +2,7 @@
#include "string_utils.h"
#include <regex>
#include <signal.h>
#include <h2o.h>
h2o_globalconf_t HttpServer::config;
h2o_context_t HttpServer::ctx;
@ -54,7 +55,6 @@ int HttpServer::create_listener(void) {
}
int HttpServer::run() {
signal(SIGPIPE, SIG_IGN);
h2o_context_init(&ctx, h2o_evloop_create(), &config);
@ -85,6 +85,7 @@ const char* HttpServer::get_status_reason(uint32_t status_code) {
case 200: return "OK";
case 201: return "Created";
case 400: return "Bad Request";
case 403: return "Forbidden";
case 404: return "Not Found";
case 409: return "Conflict";
case 500: return "Internal Server Error";
@ -161,6 +162,25 @@ int HttpServer::catch_all_handler(h2o_handler_t *self, h2o_req_t *req) {
}
}
if(rpath.authenticated) {
CollectionManager & collectionManager = CollectionManager::get_instance();
ssize_t auth_header_cursor = h2o_find_header_by_str(&req->headers, AUTH_HEADER,
strlen(AUTH_HEADER), -1);
if(auth_header_cursor == -1) {
// requires authentication, but API Key is not present in the headers
return send_403_forbidden(req);
} else {
// api key is found, let's validate
h2o_iovec_t & slot = req->headers.entries[auth_header_cursor].value;
std::string auth_key_from_header = std::string(slot.base, slot.len);
if(!collectionManager.auth_key_matches(auth_key_from_header)) {
return send_403_forbidden(req);
}
}
}
http_req request = {query_map, req_body};
http_res response;
(rpath.handler)(request, response);
@ -186,31 +206,43 @@ int HttpServer::catch_all_handler(h2o_handler_t *self, h2o_req_t *req) {
return 0;
}
void HttpServer::get(const std::string & path, void (*handler)(http_req &, http_res &)) {
int HttpServer::send_403_forbidden(h2o_req_t *req) {
h2o_generator_t generator = {NULL, NULL};
std::string res_body = std::string("{\"message\": \"Forbidden - ") + AUTH_HEADER + " header is invalid or not present.\"}";
h2o_iovec_t body = h2o_strdup(&req->pool, res_body.c_str(), SIZE_MAX);
req->res.status = 403;
req->res.reason = get_status_reason(req->res.status);
h2o_add_header(&req->pool, &req->res.headers, H2O_TOKEN_CONTENT_TYPE, H2O_STRLIT("application/json; charset=utf-8"));
h2o_start_response(req, &generator);
h2o_send(req, &body, 1, 1);
return 0;
}
void HttpServer::get(const std::string & path, void (*handler)(http_req &, http_res &), bool authenticated) {
std::vector<std::string> path_parts;
StringUtils::split(path, path_parts, "/");
route_path rpath = {"GET", path_parts, handler};
route_path rpath = {"GET", path_parts, handler, authenticated};
routes.push_back(rpath);
}
void HttpServer::post(const std::string & path, void (*handler)(http_req &, http_res &)) {
void HttpServer::post(const std::string & path, void (*handler)(http_req &, http_res &), bool authenticated) {
std::vector<std::string> path_parts;
StringUtils::split(path, path_parts, "/");
route_path rpath = {"POST", path_parts, handler};
route_path rpath = {"POST", path_parts, handler, authenticated};
routes.push_back(rpath);
}
void HttpServer::put(const std::string & path, void (*handler)(http_req &, http_res &)) {
void HttpServer::put(const std::string & path, void (*handler)(http_req &, http_res &), bool authenticated) {
std::vector<std::string> path_parts;
StringUtils::split(path, path_parts, "/");
route_path rpath = {"PUT", path_parts, handler};
route_path rpath = {"PUT", path_parts, handler, authenticated};
routes.push_back(rpath);
}
void HttpServer::del(const std::string & path, void (*handler)(http_req &, http_res &)) {
void HttpServer::del(const std::string & path, void (*handler)(http_req &, http_res &), bool authenticated) {
std::vector<std::string> path_parts;
StringUtils::split(path, path_parts, "/");
route_path rpath = {"DELETE", path_parts, handler};
route_path rpath = {"DELETE", path_parts, handler, authenticated};
routes.push_back(rpath);
}

View File

@ -5,20 +5,24 @@
int main(int argc, char **argv) {
cmdline::parser options;
options.add<std::string>("data-dir", 'd', "Directory where data will be stored.", true);
options.add<std::string>("api-auth-key", 'k', "Key for authenticating the API endpoints.", true);
options.add<std::string>("listen-address", 'a', "Address to which Typesense server binds.", false, "0.0.0.0");
options.add<uint32_t>("listen-port", 'p', "Port on which Typesense server listens.", false, 8080);
options.parse_check(argc, argv);
Store store(options.get<std::string>("data-dir"));
CollectionManager & collectionManager = CollectionManager::get_instance();
collectionManager.init(&store);
collectionManager.init(&store, options.get<std::string>("api-auth-key"));
HttpServer server(options.get<std::string>("listen-address"), options.get<uint32_t>("listen-port"));
HttpServer server(
options.get<std::string>("listen-address"),
options.get<uint32_t>("listen-port")
);
server.post("/collection", post_create_collection);
server.post("/collection/:collection", post_add_document);
server.get("/collection/:collection/search", get_search);
server.del("/collection/:collection/:id", del_remove_document);
server.post("/collection", post_create_collection, true);
server.post("/collection/:collection", post_add_document, true);
server.get("/collection/:collection/search", get_search, false);
server.del("/collection/:collection/:id", del_remove_document, true);
server.run();
return 0;

View File

@ -22,7 +22,7 @@ protected:
system(("rm -rf "+state_dir_path+" && mkdir -p "+state_dir_path).c_str());
store = new Store(state_dir_path);
collectionManager.init(store);
collectionManager.init(store, "auth_key");
search_fields = {field("title", field_types::STRING), field("starring", field_types::STRING)};
facet_fields = {field("starring", field_types::STRING)};
@ -107,7 +107,7 @@ TEST_F(CollectionManagerTest, RestoreRecordsOnRestart) {
// create a new collection manager to ensure that it restores the records from the disk backed store
CollectionManager & collectionManager2 = CollectionManager::get_instance();
collectionManager2.init(store);
collectionManager2.init(store, "auth_key");
collection1 = collectionManager2.get_collection("collection1");
ASSERT_NE(nullptr, collection1);

View File

@ -22,7 +22,7 @@ protected:
system(("rm -rf "+state_dir_path+" && mkdir -p "+state_dir_path).c_str());
store = new Store(state_dir_path);
collectionManager.init(store);
collectionManager.init(store, "auth_key");
std::ifstream infile(std::string(ROOT_DIR)+"test/documents.jsonl");
std::vector<field> search_fields = {field("title", field_types::STRING)};