# foundationdb/tests/TestRunner/local_cluster.py

import json
from pathlib import Path
import random
import string
import subprocess
import os
import socket
import time
CLUSTER_UPDATE_TIMEOUT_SEC = 10
EXCLUDE_SERVERS_TIMEOUT_SEC = 120
RETRY_INTERVAL_SEC = 0.5
def _get_free_port_internal():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("0.0.0.0", 0))
return s.getsockname()[1]
_used_ports = set()
def get_free_port():
global _used_ports
port = _get_free_port_internal()
while port in _used_ports:
port = _get_free_port_internal()
_used_ports.add(port)
return port
def is_port_in_use(port):
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        return s.connect_ex(("localhost", port)) == 0
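
# Example (a minimal sketch): get_free_port() returns a port that was free at
# the time of the call, but nothing reserves it at the OS level, so a small
# race window remains until fdbserver actually binds it:
#
#   port = get_free_port()
#   assert not is_port_in_use(port)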
valid_letters_for_secret = string.ascii_letters + string.digits
class TLSConfig:
    # Passing a negative chain length generates an expired leaf certificate
def __init__(
self,
server_chain_len: int = 3,
client_chain_len: int = 2,
verify_peers = "Check.Valid=1",
):
self.server_chain_len = server_chain_len
self.client_chain_len = client_chain_len
self.verify_peers = verify_peers
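
# Example TLSConfig usage (a minimal sketch): a negative chain length requests
# an expired leaf certificate, which is useful for negative testing:
#
#   tls = TLSConfig()                             # default chains, valid certs
#   tls_expired = TLSConfig(server_chain_len=-3)  # server leaf cert is expired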
def random_secret_string(length):
return "".join(random.choice(valid_letters_for_secret) for _ in range(length))
class LocalCluster:
configuration_template = """
## foundationdb.conf
##
## Configuration file for FoundationDB server processes
## Full documentation is available at
## https://apple.github.io/foundationdb/configuration.html#the-configuration-file
[fdbmonitor]
[general]
restart-delay = 10
## by default, restart-backoff = restart-delay-reset-interval = restart-delay
# initial-restart-delay = 0
# restart-backoff = 60
# restart-delay-reset-interval = 60
cluster-file = {etcdir}/fdb.cluster
# delete-envvars =
# kill-on-configuration-change = true
## Default parameters for individual fdbserver processes
[fdbserver]
command = {fdbserver_bin}
public-address = {ip_address}:$ID{optional_tls}
listen-address = public
datadir = {datadir}/$ID
logdir = {logdir}
{bg_knob_line}
{tls_config}
# logsize = 10MiB
# maxlogssize = 100MiB
# machine-id =
# datacenter-id =
# class =
# memory = 8GiB
# storage-memory = 1GiB
# cache-memory = 2GiB
# metrics-cluster =
# metrics-prefix =
## An individual fdbserver process with id 4000
## Parameters set here override defaults from the [fdbserver] section
"""
def __init__(
self,
basedir: str,
fdbserver_binary: str,
fdbmonitor_binary: str,
fdbcli_binary: str,
process_number: int,
create_config=True,
port=None,
ip_address=None,
blob_granules_enabled: bool = False,
redundancy: str = "single",
tls_config: TLSConfig = None,
mkcert_binary: str = "",
):
self.basedir = Path(basedir)
self.etc = self.basedir.joinpath("etc")
self.log = self.basedir.joinpath("log")
self.data = self.basedir.joinpath("data")
self.cert = self.basedir.joinpath("cert")
self.conf_file = self.etc.joinpath("foundationdb.conf")
self.cluster_file = self.etc.joinpath("fdb.cluster")
self.fdbserver_binary = Path(fdbserver_binary)
self.fdbmonitor_binary = Path(fdbmonitor_binary)
self.fdbcli_binary = Path(fdbcli_binary)
for b in (self.fdbserver_binary, self.fdbmonitor_binary, self.fdbcli_binary):
assert b.exists(), "{} does not exist".format(b)
self.etc.mkdir(exist_ok=True)
self.log.mkdir(exist_ok=True)
self.data.mkdir(exist_ok=True)
self.process_number = process_number
self.redundancy = redundancy
self.ip_address = "127.0.0.1" if ip_address is None else ip_address
self.first_port = port
self.blob_granules_enabled = blob_granules_enabled
if blob_granules_enabled:
# add extra process for blob_worker
self.process_number += 1
if self.first_port is not None:
self.last_used_port = int(self.first_port) - 1
self.server_ports = {server_id: self.__next_port() for server_id in range(self.process_number)}
self.server_by_port = {port: server_id for server_id, port in self.server_ports.items()}
self.next_server_id = self.process_number
self.cluster_desc = random_secret_string(8)
self.cluster_secret = random_secret_string(8)
self.env_vars = {}
self.running = False
self.process = None
self.fdbmonitor_logfile = None
self.use_legacy_conf_syntax = False
self.coordinators = set()
self.active_servers = set(self.server_ports.keys())
self.tls_config = tls_config
self.mkcert_binary = Path(mkcert_binary)
self.server_cert_file = self.cert.joinpath("server_cert.pem")
self.client_cert_file = self.cert.joinpath("client_cert.pem")
self.server_key_file = self.cert.joinpath("server_key.pem")
self.client_key_file = self.cert.joinpath("client_key.pem")
self.server_ca_file = self.cert.joinpath("server_ca.pem")
self.client_ca_file = self.cert.joinpath("client_ca.pem")
if create_config:
self.create_cluster_file()
self.save_config()
if self.tls_config is not None:
self.create_tls_cert()
def __next_port(self):
if self.first_port is None:
return get_free_port()
else:
self.last_used_port += 1
return self.last_used_port
def save_config(self):
new_conf_file = self.conf_file.parent / (self.conf_file.name + ".new")
with open(new_conf_file, "x") as f:
conf_template = LocalCluster.configuration_template
bg_knob_line = ""
if self.use_legacy_conf_syntax:
conf_template = conf_template.replace("-", "_")
if self.blob_granules_enabled:
bg_knob_line = "knob_bg_url=file://" + str(self.data) + "/fdbblob/"
f.write(
conf_template.format(
etcdir=self.etc,
fdbserver_bin=self.fdbserver_binary,
datadir=self.data,
logdir=self.log,
ip_address=self.ip_address,
bg_knob_line=bg_knob_line,
tls_config=self.tls_conf_string(),
optional_tls=":tls" if self.tls_config is not None else "",
)
)
            # By default, the cluster only has one process.
            # If a port number is given and process_number > 1, consecutive
            # port numbers are used: e.g., with port = 4000 and
            # process_number = 5, ports 4000-4004 are used.
            # If no port number is given, free ports are picked at random.
for server_id in self.active_servers:
f.write("[fdbserver.{server_port}]\n".format(server_port=self.server_ports[server_id]))
if self.use_legacy_conf_syntax:
f.write("machine_id = {}\n".format(server_id))
else:
f.write("machine-id = {}\n".format(server_id))
if self.blob_granules_enabled:
# make last process a blob_worker class
f.write("class = blob_worker\n")
f.flush()
os.fsync(f.fileno())
os.replace(new_conf_file, self.conf_file)
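
    # For illustration, with ip_address=127.0.0.1, port=4000 and
    # process_number=2, the generated per-server sections look like:
    #
    #   [fdbserver.4000]
    #   machine-id = 0
    #   [fdbserver.4001]
    #   machine-id = 1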
def create_cluster_file(self):
with open(self.cluster_file, "x") as f:
f.write(
"{desc}:{secret}@{ip_addr}:{server_port}{optional_tls}".format(
desc=self.cluster_desc,
secret=self.cluster_secret,
ip_addr=self.ip_address,
server_port=self.server_ports[0],
optional_tls=":tls" if self.tls_config is not None else "",
)
)
self.coordinators = {0}
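
    # The cluster file written above holds a single connection string of the
    # form <desc>:<secret>@<ip>:<port>[:tls], e.g. (with made-up secrets):
    #
    #   aB3dE9xY:Qw7zK2pL@127.0.0.1:4000:tls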
def start_cluster(self):
        assert not self.running, "Can't start a cluster that is already running"
args = [
str(self.fdbmonitor_binary),
"--conffile",
str(self.etc.joinpath("foundationdb.conf")),
"--lockfile",
str(self.etc.joinpath("fdbmonitor.lock")),
]
self.fdbmonitor_logfile = open(self.log.joinpath("fdbmonitor.log"), "w")
self.process = subprocess.Popen(
args,
stdout=self.fdbmonitor_logfile,
stderr=self.fdbmonitor_logfile,
env=self.process_env(),
)
self.running = True
def stop_cluster(self):
        assert self.running, "Cluster is not running"
if self.process.poll() is None:
self.process.terminate()
self.running = False
def ensure_ports_released(self, timeout_sec=5):
sec = 0
while sec < timeout_sec:
in_use = False
for server_id in self.active_servers:
port = self.server_ports[server_id]
if is_port_in_use(port):
print("Port {} in use. Waiting for it to be released".format(port))
in_use = True
break
if not in_use:
return
time.sleep(0.5)
sec += 0.5
assert False, "Failed to release ports in {}s".format(timeout_sec)
def __enter__(self):
self.start_cluster()
return self
    def __exit__(self, exc_type, exc_value, traceback):
self.stop_cluster()
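
    # Example usage of LocalCluster as a context manager (a minimal sketch;
    # all paths are hypothetical):
    #
    #   cluster = LocalCluster("/tmp/testcluster", "/path/to/fdbserver",
    #                          "/path/to/fdbmonitor", "/path/to/fdbcli",
    #                          process_number=1)
    #   with cluster:
    #       cluster.create_database()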
def __fdbcli_exec(self, cmd, stdout, stderr, timeout):
args = [self.fdbcli_binary, "-C", self.cluster_file, "--exec", cmd]
if self.tls_config:
args += ["--tls-certificate-file", self.client_cert_file,
"--tls-key-file", self.client_key_file,
"--tls-ca-file", self.server_ca_file]
res = subprocess.run(args, env=self.process_env(), stderr=stderr, stdout=stdout, timeout=timeout)
assert res.returncode == 0, "fdbcli command {} failed with {}".format(cmd, res.returncode)
return res.stdout
    # Execute an fdbcli command
def fdbcli_exec(self, cmd, timeout=None):
self.__fdbcli_exec(cmd, None, None, timeout)
    # Execute an fdbcli command and return its output
def fdbcli_exec_and_get(self, cmd, timeout=None):
return self.__fdbcli_exec(cmd, subprocess.PIPE, None, timeout)
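
    # Example (a minimal sketch): fdbcli accepts multiple commands separated
    # by semicolons in a single --exec invocation:
    #
    #   cluster.fdbcli_exec("writemode on; set test_key test_value")
    #   raw_status = cluster.fdbcli_exec_and_get("status json")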
def create_database(self, storage="ssd", enable_tenants=True):
db_config = "configure new {} {}".format(self.redundancy, storage)
if enable_tenants:
db_config += " tenant_mode=optional_experimental"
if self.blob_granules_enabled:
db_config += " blob_granules_enabled:=1"
self.fdbcli_exec(db_config)
if self.blob_granules_enabled:
self.fdbcli_exec("blobrange start \\x00 \\xff")
# Generate and install test certificate chains and keys
def create_tls_cert(self):
assert self.tls_config is not None, "TLS not enabled"
assert self.mkcert_binary.exists() and self.mkcert_binary.is_file(), "{} does not exist".format(self.mkcert_binary)
self.cert.mkdir(exist_ok=True)
server_chain_len = abs(self.tls_config.server_chain_len)
client_chain_len = abs(self.tls_config.client_chain_len)
expire_server_cert = (self.tls_config.server_chain_len < 0)
expire_client_cert = (self.tls_config.client_chain_len < 0)
args = [
str(self.mkcert_binary),
"--server-chain-length", str(server_chain_len),
"--client-chain-length", str(client_chain_len),
"--server-cert-file", str(self.server_cert_file),
"--client-cert-file", str(self.client_cert_file),
"--server-key-file", str(self.server_key_file),
"--client-key-file", str(self.client_key_file),
"--server-ca-file", str(self.server_ca_file),
"--client-ca-file", str(self.client_ca_file),
"--print-args",
]
if expire_server_cert:
args.append("--expire-server-cert")
if expire_client_cert:
args.append("--expire-client-cert")
subprocess.run(args, check=True)
# Materialize server's TLS configuration section
def tls_conf_string(self):
if self.tls_config is None:
return ""
else:
conf_map = {
"tls-certificate-file": self.server_cert_file,
"tls-key-file": self.server_key_file,
"tls-ca-file": self.client_ca_file,
"tls-verify-peers": self.tls_config.verify_peers,
}
return "\n".join("{} = {}".format(k, v) for k, v in conf_map.items())
# Get cluster status using fdbcli
def get_status(self):
status_output = self.fdbcli_exec_and_get("status json")
return json.loads(status_output)
# Get the set of servers from the cluster status matching the given filter
def get_servers_from_status(self, filter):
status = self.get_status()
if "processes" not in status["cluster"]:
return {}
servers_found = set()
addresses = [proc_info["address"] for proc_info in status["cluster"]["processes"].values() if filter(proc_info)]
for addr in addresses:
port = int(addr.split(":", 1)[1])
assert port in self.server_by_port, "Unknown server port {}".format(port)
servers_found.add(self.server_by_port[port])
return servers_found
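
    # The relevant part of the "status json" document consumed above looks
    # like (abridged):
    #
    #   {"cluster": {"processes": {"<process_id>": {
    #       "address": "127.0.0.1:4000",
    #       "roles": [{"role": "coordinator", ...}, ...]}}}}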
# Get the set of all servers from the cluster status
def get_all_servers_from_status(self):
return self.get_servers_from_status(lambda _: True)
# Get the set of all servers with coordinator role from the cluster status
def get_coordinators_from_status(self):
def is_coordinator(proc_status):
return any(entry["role"] == "coordinator" for entry in proc_status["roles"])
return self.get_servers_from_status(is_coordinator)
def process_env(self):
env = dict(os.environ)
env.update(self.env_vars)
return env
def set_env_var(self, var_name, var_val):
self.env_vars[var_name] = var_val
# Add a new server process to the cluster and return its ID
# Need to call save_config to apply the changes
def add_server(self):
server_id = self.next_server_id
assert server_id not in self.server_ports, "Server ID {} is already in use".format(server_id)
self.next_server_id += 1
port = self.__next_port()
self.server_ports[server_id] = port
self.server_by_port[port] = server_id
self.active_servers.add(server_id)
return server_id
# Remove the server with the given ID from the cluster
# Need to call save_config to apply the changes
def remove_server(self, server_id):
assert server_id in self.active_servers, "Server {} does not exist".format(server_id)
self.active_servers.remove(server_id)
# Wait until changes to the set of servers (additions & removals) are applied
def wait_for_server_update(self, timeout=CLUSTER_UPDATE_TIMEOUT_SEC):
time_limit = time.time() + timeout
servers_found = set()
        while time.time() <= time_limit:
            servers_found = self.get_all_servers_from_status()
            if servers_found == self.active_servers:
                break
            time.sleep(RETRY_INTERVAL_SEC)
        assert servers_found == self.active_servers, \
            "Failed to apply server changes after {}sec. Expected: {}, Actual: {}".format(
                timeout, self.active_servers, servers_found)
# Apply changes to the set of the coordinators, based on the current value of self.coordinators
def update_coordinators(self):
urls = ["{}:{}".format(self.ip_address, self.server_ports[id]) for id in self.coordinators]
self.fdbcli_exec("coordinators {}".format(" ".join(urls)))
# Wait until the changes to the set of the coordinators are applied
def wait_for_coordinator_update(self, timeout=CLUSTER_UPDATE_TIMEOUT_SEC):
time_limit = time.time() + timeout
coord_found = set()
        while time.time() <= time_limit:
            coord_found = self.get_coordinators_from_status()
            if coord_found == self.coordinators:
                break
            time.sleep(RETRY_INTERVAL_SEC)
        assert coord_found == self.coordinators, \
            "Failed to apply coordinator changes after {}sec. Expected: {}, Actual: {}".format(
                timeout, self.coordinators, coord_found)
# Check if the cluster file was successfully updated too
        with open(self.cluster_file, "r") as f:
            connection_string = f.read()
        for server_id in self.coordinators:
            assert str(self.server_ports[server_id]) in connection_string, \
                "Missing coordinator {} port {} in the cluster file".format(server_id, self.server_ports[server_id])
# Exclude the servers with the given ID from the cluster, i.e. move out their data
# The method waits until the changes are applied
def exclude_servers(self, server_ids):
urls = ["{}:{}".format(self.ip_address, self.server_ports[id]) for id in server_ids]
self.fdbcli_exec("exclude FORCE {}".format(" ".join(urls)), timeout=EXCLUDE_SERVERS_TIMEOUT_SEC)
# Perform a cluster wiggle: replace all servers with new ones
def cluster_wiggle(self):
old_servers = self.active_servers.copy()
new_servers = set()
print("Starting cluster wiggle")
print("Old servers: {} on ports {}".format(old_servers, [
self.server_ports[server_id] for server_id in old_servers]))
print("Old coordinators: {}".format(self.coordinators))
# Step 1: add new servers
start_time = time.time()
for _ in range(len(old_servers)):
new_servers.add(self.add_server())
print("New servers: {} on ports {}".format(new_servers, [
self.server_ports[server_id] for server_id in new_servers]))
self.save_config()
self.wait_for_server_update()
print("New servers successfully added to the cluster. Time: {}s".format(time.time()-start_time))
# Step 2: change coordinators
start_time = time.time()
        # sort first: random.sample() no longer accepts sets (Python 3.11+)
        new_coordinators = set(random.sample(sorted(new_servers), len(self.coordinators)))
print("New coordinators: {}".format(new_coordinators))
self.coordinators = new_coordinators.copy()
self.update_coordinators()
self.wait_for_coordinator_update()
print("Coordinators successfully changed. Time: {}s".format(time.time()-start_time))
# Step 3: exclude old servers from the cluster, i.e. move out their data
start_time = time.time()
self.exclude_servers(old_servers)
print("Old servers successfully excluded from the cluster. Time: {}s".format(time.time()-start_time))
# Step 4: remove the old servers
start_time = time.time()
for server_id in old_servers:
self.remove_server(server_id)
self.save_config()
self.wait_for_server_update()
print("Old servers successfully removed from the cluster. Time: {}s".format(time.time()-start_time))