typesense/include/ratelimit_manager.h
Ozan Armağan 07b456c229
Rate limit API updates & improvements (#846)
* Rate limit procedure updates & improvements

* Review changes

* Update for adding support to delete temp throttles

* Review Changes II

* Review Changes III
2023-01-31 20:57:02 +05:30

304 lines
10 KiB
C++

#pragma once
#include <string>
#include <vector>
#include <tuple>
#include <unordered_map>
#include <mutex>
#include <shared_mutex>
#include <json.hpp>
#include <magic_enum.hpp>
#include "lru/lru.hpp"
#include "option.h"
#include "store.h"
// Possible actions of a rate limit rule (held in rate_limit_rule_t::action).
// Enumerator values are spelled out; they match the implicit numbering.
enum class RateLimitAction {
    allow = 0,
    block = 1,
    throttle = 2
};
// Entity type enum for rate limit rules
enum class RateLimitedEntityType {
    ip,
    api_key
};

// Overload of operator! that maps an entity type to the other one:
// ip -> api_key and api_key -> ip.
inline RateLimitedEntityType operator!(const RateLimitedEntityType& entity_type) {
    switch (entity_type) {
        case RateLimitedEntityType::ip:
            return RateLimitedEntityType::api_key;
        case RateLimitedEntityType::api_key:
            return RateLimitedEntityType::ip;
    }
    // Only reached for a value that is not a declared enumerator (e.g. produced by a cast).
    // Flowing off the end of a non-void function is undefined behaviour, so the value is
    // handed back unchanged instead. The switch deliberately has no `default:` label so
    // compilers can still warn when an enumerator is added without a matching case.
    return entity_type;
}
// Request thresholds of a rate limit rule (per minute and per hour).
// Both thresholds default to -1.
struct rate_limit_max_requests_t {
    int64_t minute_threshold{-1};
    int64_t hour_threshold{-1};
};
// Identifies a rate limited entity by type and id.
// Used as the key of the hash map that points entities to their rules.
struct rate_limit_entity_t {
    RateLimitedEntityType entity_type;
    std::string entity_id;

    // Two entities are equal when both their type and their id match
    bool operator==(const rate_limit_entity_t& other) const {
        return entity_type == other.entity_type && entity_id == other.entity_id;
    }
};
// A single rate limit rule.
// `id` and `action` carry no default and must be set by whoever builds the rule.
struct rate_limit_rule_t {
    uint32_t id;
    RateLimitAction action;
    std::vector<rate_limit_entity_t> entities;
    rate_limit_max_requests_t max_requests;
    int64_t auto_ban_1m_threshold{-1};
    int64_t auto_ban_1m_duration_hours{-1};
    bool apply_limit_per_entity{false};
    uint32_t priority{0};

    // JSON representation of the rule (declaration only; the definition is not visible here)
    const nlohmann::json to_json() const;
};
// Request counters kept per cache key: request totals for the current and the previous
// sampling period (per minute and per hour), the number of times the minute threshold
// was exceeded, and the times the minute/hour counters were last reset.
struct request_counter_t {
    int64_t current_requests_count_minute = 0;
    int64_t current_requests_count_hour = 0;
    int64_t previous_requests_count_minute = 0;
    int64_t previous_requests_count_hour = 0;
    int64_t threshold_exceed_count_minute = 0;
    time_t last_reset_time_minute = 0;
    time_t last_reset_time_hour = 0;

    // Zeroes every counter and stamps both reset times with the current system time
    void reset() {
        current_requests_count_minute = 0;
        current_requests_count_hour = 0;
        threshold_exceed_count_minute = 0;
        previous_requests_count_minute = 0;
        previous_requests_count_hour = 0;
        // Read the clock once so the minute and the hour window always start at the same
        // instant; two separate reads could fall on different seconds.
        const time_t now = std::chrono::system_clock::to_time_t(std::chrono::system_clock::now());
        last_reset_time_minute = now;
        last_reset_time_hour = now;
    }

    // not-equal operator overload: compares the four request counters and both reset times.
    // NOTE(review): threshold_exceed_count_minute is not part of the comparison — confirm
    // that this is intentional.
    bool operator!=(const request_counter_t& other) const {
        return std::tie(current_requests_count_minute, current_requests_count_hour,
                        previous_requests_count_minute, previous_requests_count_hour,
                        last_reset_time_minute, last_reset_time_hour) !=
               std::tie(other.current_requests_count_minute, other.current_requests_count_hour,
                        other.previous_requests_count_minute, other.previous_requests_count_hour,
                        other.last_reset_time_minute, other.last_reset_time_hour);
    }
};
// Throttling/ban status of an entity: a status id, the throttling time range, the affected
// entity and an optional second entity for AND bans.
struct rate_limit_status_t {
    // Scalars and the entity are initialised so that a default-constructed status never
    // holds indeterminate values (the defaulted constructor would otherwise leave them unset).
    uint32_t status_id = 0;
    int64_t throttling_from = 0;
    int64_t throttling_to = 0;
    rate_limit_entity_t entity{};
    // optional second entity for AND bans; holds a 404 "Not Found" error when absent
    Option<rate_limit_entity_t> and_entity = Option<rate_limit_entity_t>(404, "Not Found");

    // Builds a status for `entity`. `and_entity` may be nullptr; when it is not, the
    // pointed-to entity is copied into the optional second entity.
    rate_limit_status_t(const uint32_t status_id, const int64_t throttling_from, const int64_t throttling_to,
                        const rate_limit_entity_t &entity, const rate_limit_entity_t* and_entity = nullptr)
        : status_id(status_id), throttling_from(throttling_from), throttling_to(throttling_to), entity(entity) {
        if (and_entity != nullptr) {
            this->and_entity = Option<rate_limit_entity_t>(*and_entity);
        }
    }

    // default constructor
    rate_limit_status_t() = default;

    // JSON conversion (declarations only; the definitions are not visible here)
    const nlohmann::json to_json() const;
    void parse_json(const nlohmann::json& json);
};
// Struct to store how many requests were made by entities that exceeded a rate limit
struct rate_limit_exceed_t {
    uint32_t rule_id = 0;
    // Combined key, split by to_json() into an "api_key" part and an "ip" part around '_'
    std::string entities;
    uint64_t request_count = 0;

    // Builds {"id", "api_key", "ip", "request_count"} from this record.
    const nlohmann::json to_json() const {
        // Split at the LAST '_': an IP address cannot contain '_', whereas an API key may,
        // so splitting at the first one would cut such a key in half.
        // NOTE(review): assumes the key is built as "<api_key>_<ip>" — confirm against the
        // code that fills `entities`.
        // Without any '_' the position is npos: substr(0, npos) and substr(npos + 1), which
        // wraps to substr(0), both yield the whole string, so both fields get the full key.
        const std::string::size_type separator = entities.rfind('_');

        nlohmann::json json;
        json["id"] = rule_id;
        json["api_key"] = entities.substr(0, separator);
        json["ip"] = entities.substr(separator + 1);
        json["request_count"] = request_count;
        return json;
    }
};
// std::hash specialisation for rate_limit_entity_t, so it can serve as an unordered_map key
namespace std {
    template <>
    struct hash<rate_limit_entity_t> {
        // Combines the hash of the entity type (as int) with the hash of the entity id
        std::size_t operator()(const rate_limit_entity_t& k) const {
            const std::size_t type_hash = std::hash<int>()(static_cast<int>(k.entity_type));
            const std::size_t id_hash = std::hash<std::string>()(k.entity_id);
            return (type_hash ^ (id_hash << 1)) >> 1;
        }
    };
}
// Manager for rate limit rules, request counters and throttled entities.
// Copy and move are deleted and the constructor is private; the instance is obtained
// through getInstance().
class RateLimitManager
{
public:
    RateLimitManager(const RateLimitManager&) = delete;
    RateLimitManager& operator=(const RateLimitManager&) = delete;
    RateLimitManager(RateLimitManager&&) = delete;
    RateLimitManager& operator=(RateLimitManager&&) = delete;

    static RateLimitManager* getInstance();

    // Get vector of banned entities
    const std::vector<rate_limit_status_t> get_banned_entities(const RateLimitedEntityType entity_type);
    // Check if request is rate limited for given entities
    bool is_rate_limited(const rate_limit_entity_t& api_key_entity, const rate_limit_entity_t& ip_entity);
    // Add rule by JSON
    Option<nlohmann::json> add_rule(const nlohmann::json &rule_json);
    // Edit rule by JSON
    Option<nlohmann::json> edit_rule(const uint64_t id, const nlohmann::json &rule_json);
    // Find rule by ID
    Option<nlohmann::json> find_rule_by_id(const uint64_t id);
    // Delete rule by ID
    bool delete_rule_by_id(const uint64_t id);
    // Delete ban by ID
    bool delete_ban_by_id(const uint64_t id);
    // Delete throttle by ID
    bool delete_throttle_by_id(const uint32_t id);
    // Get all rules as vector
    const std::vector<rate_limit_rule_t> get_all_rules();
    // Get all rules as JSON
    const nlohmann::json get_all_rules_json();
    // Get exceeded entities and request counts as JSON
    const nlohmann::json get_exceeded_entities_json();
    // Get throttled entities as JSON
    const nlohmann::json get_throttled_entities_json();
    // Clear all rules
    void clear_all();
    // Internal function to set base time
    void _set_base_timestamp(const time_t& base_time);
    // Set store
    Option<bool> init(Store* store);

private:
    // Private constructor; calls capacity(10000) on the request counter cache
    RateLimitManager() {
        rate_limit_request_counts.capacity(10000);
    }

    // Store for rate limit rules. Starts as nullptr so that it is never an indeterminate
    // pointer before init() has been called.
    Store *store = nullptr;

    // Using a $ prefix so that these meta keys stay above record entries in a lexicographically ordered KV store
    // Next-ID key and prefix for rate limit rules
    static constexpr const char* RULES_NEXT_ID = "$RLN";
    static constexpr const char* RULES_PREFIX = "$RLRP";
    // Next-ID key and prefix for bans
    static constexpr const char* BANS_NEXT_ID = "$RLBN";
    static constexpr const char* BANS_PREFIX = "$RLBP";

    // Static instances of wildcard entities (entity id ".*")
    inline static const rate_limit_entity_t WILDCARD_IP = rate_limit_entity_t{RateLimitedEntityType::ip, ".*"};
    inline static const rate_limit_entity_t WILDCARD_API_KEY = rate_limit_entity_t{RateLimitedEntityType::api_key, ".*"};

    // ID of latest added rule
    inline static uint32_t last_rule_id = 0;
    // ID of latest added ban
    inline static uint32_t last_ban_id = 0;
    // ID of latest added throttle
    inline static uint32_t last_throttle_id = 0;

    // Store for rate_limit_rule_t, keyed by an unsigned 64-bit integer
    std::unordered_map<uint64_t,rate_limit_rule_t> rule_store;
    // LRU Cache to store rate limit and request counts for entities
    LRU::Cache<std::string, request_counter_t> rate_limit_request_counts;
    // Unordered map to point rules from rule store for entities
    std::unordered_map<rate_limit_entity_t, std::vector<rate_limit_rule_t*>> rate_limit_entities;
    // Unordered map to store throttled entities
    std::unordered_map<std::string, rate_limit_status_t> throttled_entities;
    // Hash map to store exceeds
    std::unordered_map<std::string, rate_limit_exceed_t> rate_limit_exceeds;
    // Shared mutex for the rate limit state.
    // NOTE(review): the previous comment named ip_rate_limits and api_key_rate_limits, which are
    // not members of this class — confirm exactly which members it guards.
    std::shared_mutex rate_limit_mutex;

    // Helper function to ban an entity temporarily
    void temp_ban_entity(const rate_limit_entity_t& entity, const uint64_t number_of_hours);
    // Helper function to ban an entity temporarily without locking mutex
    void temp_ban_entity_wrapped(const rate_limit_entity_t& entity, const uint64_t number_of_hours, const rate_limit_entity_t* and_entity = nullptr);
    // Helper function to check if JSON rule is valid
    Option<bool> is_valid_rule(const nlohmann::json &rule_json);
    // Parse JSON rule to rate_limit_rule_t
    static Option<rate_limit_rule_t> parse_rule(const nlohmann::json &rule_json);
    // Helper function to insert rule in store
    void insert_rule(const rate_limit_rule_t &rule);
    // Helper function to get rule key for DB store from ID
    std::string get_rule_key(const uint32_t id);
    // Helper function to get ban key for DB store from ID
    std::string get_ban_key(const uint32_t id);

    // Base timestamp
    time_t base_timestamp = 0;
    // Helper function to get current timestamp
    time_t get_current_time();

    // Helper function to get throttle key for entity if exists
    Option<std::string> get_throttle_key(const rate_limit_entity_t& ip_entity, const rate_limit_entity_t& api_key_entity);
    // Helper function to get request counter key according to rule type
    static const std::string get_request_counter_key(const rate_limit_rule_t& rule, const rate_limit_entity_t& ip_entity, const rate_limit_entity_t& api_key_entity);
    // Fill the rules bucket for the given entity
    void fill_bucket(const rate_limit_entity_t& target_entity, const rate_limit_entity_t& other_entity, std::vector<rate_limit_rule_t*> &rules_bucket);

    // Singleton instance
    inline static RateLimitManager *instance;
};