Extract authentication logic into a separate handler.

It is now possible to define a separate key for search end-point alone as well. This will be used when one needs to call the API directly from Javascript.
This commit is contained in:
Kishore Nallan 2017-12-12 18:10:08 +05:30
parent 8904aba598
commit de80250de9
7 changed files with 56 additions and 31 deletions

View File

@ -2,6 +2,8 @@
#include "http_server.h"
bool handle_authentication(const route_path & rpath, const std::string & auth_key);
void get_collections(http_req & req, http_res & res);
void post_create_collection(http_req & req, http_res & res);

View File

@ -26,6 +26,7 @@ private:
static constexpr const char* COLLECTION_TOKEN_ORDERING_FIELD_KEY = "token_ranking_field";
std::string auth_key;
std::string search_only_auth_key;
CollectionManager();
@ -40,7 +41,7 @@ public:
CollectionManager(CollectionManager const&) = delete;
void operator=(CollectionManager const&) = delete;
Option<bool> init(Store *store, const std::string & auth_key);
Option<bool> init(Store *store, const std::string & auth_key, const std::string & search_only_auth_key);
// frees in-memory data structures when server is shutdown - helps us run a memory leak detecter properly
void dispose();
@ -51,6 +52,8 @@ public:
bool auth_key_matches(std::string auth_key_sent);
bool search_only_auth_key_matches(std::string auth_key_sent);
Option<Collection*> create_collection(std::string name, const std::vector<field> & fields,
const std::string & token_ranking_field = "");

View File

@ -13,8 +13,6 @@ extern "C" {
#include <string>
#include <stdio.h>
#include "http_data.h"
#include "collection.h"
#include "collection_manager.h"
struct request_response {
http_req* req;
@ -44,6 +42,8 @@ private:
std::string ssl_cert_key_path;
bool (*auth_handler)(const route_path & rpath, const std::string & auth_key);
static void on_accept(h2o_socket_t *listener, const char *err);
int setup_ssl(const char *cert_file, const char *key_file);
@ -71,6 +71,8 @@ public:
~HttpServer();
void set_auth_handler(bool (*handler)(const route_path & rpath, const std::string & auth_key));
void get(const std::string & path, void (*handler)(http_req & req, http_res & res), bool authenticated, bool async = false);
void post(const std::string & path, void (*handler)(http_req & req, http_res & res), bool authenticated, bool async = false);
@ -95,5 +97,4 @@ public:
static constexpr const char* AUTH_HEADER = "x-typesense-api-key";
static constexpr const char* STOP_SERVER_MESSAGE = "STOP_SERVER";
};

View File

@ -7,6 +7,15 @@
#include "collection.h"
#include "collection_manager.h"
bool handle_authentication(const route_path & rpath, const std::string & auth_key) {
CollectionManager & collectionManager = CollectionManager::get_instance();
if(rpath.handler == get_search) {
return collectionManager.auth_key_matches(auth_key) || collectionManager.search_only_auth_key_matches(auth_key);
}
return collectionManager.auth_key_matches(auth_key);
}
void get_collections(http_req & req, http_res & res) {
CollectionManager & collectionManager = CollectionManager::get_instance();
std::vector<Collection*> collections = collectionManager.get_collections();
@ -141,15 +150,15 @@ void get_search(http_req & req, http_res & res) {
}
if(!StringUtils::is_uint64_t(req.params[NUM_TYPOS])) {
return res.send_400("Parameter `" + NUM_TYPOS + "` must be an unsigned integer.");
return res.send_400("Parameter `" + std::string(NUM_TYPOS) + "` must be an unsigned integer.");
}
if(!StringUtils::is_uint64_t(req.params[PER_PAGE])) {
return res.send_400("Parameter `" + PER_PAGE + "` must be an unsigned integer.");
return res.send_400("Parameter `" + std::string(PER_PAGE) + "` must be an unsigned integer.");
}
if(!StringUtils::is_uint64_t(req.params[PAGE])) {
return res.send_400("Parameter `" + PAGE + "` must be an unsigned integer.");
return res.send_400("Parameter `" + std::string(PAGE) + "` must be an unsigned integer.");
}
std::string filter_str = req.params.count(FILTER) != 0 ? req.params[FILTER] : "";

View File

@ -36,9 +36,11 @@ void CollectionManager::add_to_collections(Collection* collection) {
collection_id_names.emplace(collection->get_collection_id(), collection->get_name());
}
Option<bool> CollectionManager::init(Store *store, const std::string & auth_key) {
Option<bool> CollectionManager::init(Store *store, const std::string & auth_key,
const std::string & search_only_auth_key) {
this->store = store;
this->auth_key = auth_key;
this->search_only_auth_key = search_only_auth_key;
std::string next_collection_id_str;
store->get(NEXT_COLLECTION_ID_KEY, next_collection_id_str);
@ -112,6 +114,10 @@ bool CollectionManager::auth_key_matches(std::string auth_key_sent) {
return (auth_key == auth_key_sent);
}
bool CollectionManager::search_only_auth_key_matches(std::string auth_key_sent) {
return (search_only_auth_key == auth_key_sent);
}
Option<Collection*> CollectionManager::create_collection(std::string name, const std::vector<field> & fields,
const std::string & token_ranking_field) {
if(store->contains(Collection::get_meta_key(name))) {

View File

@ -5,6 +5,7 @@
#include <thread>
#include <signal.h>
#include <h2o.h>
#include <iostream>
struct h2o_custom_req_handler_t {
h2o_handler_t super;
@ -261,6 +262,15 @@ int HttpServer::catch_all_handler(h2o_handler_t *_self, h2o_req_t *req) {
std::map<std::string, std::string> query_map = parse_query(query_str);
const std::string & req_body = std::string(req->entity.base, req->entity.len);
// extract auth key from header if present
std::string auth_key_from_header = "";
ssize_t auth_header_cursor = h2o_find_header_by_str(&req->headers, AUTH_HEADER,
strlen(AUTH_HEADER), -1);
if(auth_header_cursor != -1) {
h2o_iovec_t & slot = req->headers.entries[auth_header_cursor].value;
auth_key_from_header = std::string(slot.base, slot.len);
}
for(const route_path & rpath: self->http_server->routes) {
if(rpath.path_parts.size() != path_parts.size() || rpath.http_method != http_method) {
continue;
@ -280,7 +290,12 @@ int HttpServer::catch_all_handler(h2o_handler_t *_self, h2o_req_t *req) {
check_next_route:
if(found) {
// routes match - iterate and extract path params
bool authenticated = self->http_server->auth_handler(rpath, auth_key_from_header);
if(!authenticated) {
return send_401_unauthorized(req);
}
// routes match and is an authenticated request - iterate and extract path params
for(size_t i = 0; i < rpath.path_parts.size(); i++) {
const std::string & path_part = rpath.path_parts[i];
if(path_part[0] == ':') {
@ -288,25 +303,6 @@ int HttpServer::catch_all_handler(h2o_handler_t *_self, h2o_req_t *req) {
}
}
if(rpath.authenticated) {
CollectionManager & collectionManager = CollectionManager::get_instance();
ssize_t auth_header_cursor = h2o_find_header_by_str(&req->headers, AUTH_HEADER,
strlen(AUTH_HEADER), -1);
if(auth_header_cursor == -1) {
// requires authentication, but API Key is not present in the headers
return send_401_unauthorized(req);
} else {
// api key is found, let's validate
h2o_iovec_t & slot = req->headers.entries[auth_header_cursor].value;
std::string auth_key_from_header = std::string(slot.base, slot.len);
if(!collectionManager.auth_key_matches(auth_key_from_header)) {
return send_401_unauthorized(req);
}
}
}
http_req* request = new http_req{req, query_map, req_body};
http_res* response = new http_res();
response->server = self->http_server;
@ -366,6 +362,10 @@ int HttpServer::send_401_unauthorized(h2o_req_t *req) {
return 0;
}
void HttpServer::set_auth_handler(bool (*handler)(const route_path & rpath, const std::string & auth_key)) {
auth_handler = handler;
}
void HttpServer::get(const std::string & path, void (*handler)(http_req &, http_res &), bool authenticated, bool async) {
std::vector<std::string> path_parts;
StringUtils::split(path, path_parts, "/");

View File

@ -20,14 +20,16 @@ void catch_interrupt(int sig) {
int main(int argc, char **argv) {
cmdline::parser options;
options.add<std::string>("data-dir", 'd', "Directory where data will be stored.", true);
options.add<std::string>("api-auth-key", 'a', "Key for authenticating the API endpoints.", true);
options.add<std::string>("api-key", 'k', "API key that allows all operations.", true);
options.add<std::string>("search-only-api-key", 's', "API key that allows only searches.", false);
options.add<std::string>("listen-address", 'h', "Address to which Typesense server binds.", false, "0.0.0.0");
options.add<uint32_t>("listen-port", 'p', "Port on which Typesense server listens.", false, 8108);
options.add<std::string>("master", 'm', "Master host in http(s)://<master_address>:<master_port> format "
"to start the server as a read-only replica.", false, "");
options.add<std::string>("ssl-certificate", 'c', "Path to the SSL certificate file.", false, "");
options.add<std::string>("ssl-certificate-key", 'k', "Path to the SSL certificate key file.", false, "");
options.add<std::string>("ssl-certificate-key", 'e', "Path to the SSL certificate key file.", false, "");
options.parse_check(argc, argv);
@ -35,7 +37,8 @@ int main(int argc, char **argv) {
Store store(options.get<std::string>("data-dir"));
CollectionManager & collectionManager = CollectionManager::get_instance();
Option<bool> init_op = collectionManager.init(&store, options.get<std::string>("api-auth-key"));
Option<bool> init_op = collectionManager.init(&store, options.get<std::string>("api-key"),
options.get<std::string>("search-only-api-key"));
if(init_op.ok()) {
std::cout << "Finished restoring all collections from disk." << std::endl;
@ -52,6 +55,7 @@ int main(int argc, char **argv) {
);
// collection management
server->set_auth_handler(handle_authentication);
server->post("/collections", post_create_collection, true);
server->get("/collections", get_collections, true);
server->del("/collections/:collection", del_drop_collection, true);