Add a TimedCertificateReloader

This commit is contained in:
Gus Cairo 2025-04-30 14:24:59 +01:00
parent 0fc472ba34
commit 554f22e189
3 changed files with 616 additions and 1 deletions

View File

@ -257,6 +257,27 @@ var targets: [PackageDescription.Target] = [
],
swiftSettings: strictConcurrencySettings
),
.target(
name: "NIOCertificateHotReloading",
dependencies: [
.product(name: "NIOCore", package: "swift-nio"),
.product(name: "NIOSSL", package: "swift-nio-ssl"),
.product(name: "X509", package: "swift-certificates"),
.product(name: "SwiftASN1", package: "swift-asn1")
],
swiftSettings: strictConcurrencySettings
),
.testTarget(
name: "NIOCertificateHotReloadingTests",
dependencies: [
"NIOCertificateHotReloading",
.product(name: "NIOCore", package: "swift-nio"),
.product(name: "NIOSSL", package: "swift-nio-ssl"),
.product(name: "X509", package: "swift-certificates"),
.product(name: "SwiftASN1", package: "swift-asn1")
],
swiftSettings: strictConcurrencySettings
),
]
let package = Package(
@ -270,6 +291,7 @@ let package = Package(
.library(name: "NIOHTTPTypesHTTP2", targets: ["NIOHTTPTypesHTTP2"]),
.library(name: "NIOResumableUpload", targets: ["NIOResumableUpload"]),
.library(name: "NIOHTTPResponsiveness", targets: ["NIOHTTPResponsiveness"]),
.library(name: "NIOCertificateHotReloading", targets: ["NIOCertificateHotReloading"]),
],
dependencies: [
.package(url: "https://github.com/apple/swift-nio.git", from: "2.81.0"),
@ -278,7 +300,9 @@ let package = Package(
.package(url: "https://github.com/apple/swift-http-structured-headers.git", from: "1.2.0"),
.package(url: "https://github.com/apple/swift-atomics.git", from: "1.2.0"),
.package(url: "https://github.com/apple/swift-algorithms.git", from: "1.2.0"),
.package(url: "https://github.com/apple/swift-certificates.git", branch: "1.10.0"),
.package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.29.3"),
.package(url: "https://github.com/apple/swift-asn1.git", from: "1.3.1"),
],
targets: targets
)

View File

@ -0,0 +1,278 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2025 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
import X509
import SwiftASN1
import NIOSSL
import struct NIOCore.TimeAmount
import NIOConcurrencyHelpers
#if canImport(FoundationEssentials)
import FoundationEssentials
#else
import Foundation
#endif
/// A protocol that defines a certificate reloader.
///
/// A certificate reloader is a service that can provide you with updated versions of a certificate and private key pair, in
/// the form of a `NIOSSLContextConfigurationOverride`, which will be used when performing a TLS handshake in NIO.
/// Each implementation can choose how to observe for changes, but they all require an ``sslContextConfigurationOverride``
/// to be exposed.
public protocol CertificateReloader: Sendable {
/// A `NIOSSLContextConfigurationOverride` that will be used as part of the NIO application's TLS configuration.
/// Its certificate and private key will be kept up-to-date via whatever mechanism the specific ``CertificateReloader``
/// implementation provides.
var sslContextConfigurationOverride: NIOSSLContextConfigurationOverride { get }
}
/// A ``TimedCertificateReloader`` is an implementation of a ``CertificateReloader``, where the certificate and private
/// key pair is updated at a fixed interval from the file path or memory location configured.
///
/// You initialize a ``TimedCertificateReloader`` by providing a refresh interval, and locations for the certificate and the private
/// key. You must then call ``run()`` on this reloader for it to start observing changes.
/// Once the reloader is running, call ``sslContextConfigurationOverride`` to get a
/// `NIOSSLContextConfigurationOverride` which can be set on NIO's `TLSConfiguration`: this will keep the certificate
/// and private key pair up to date.
/// You may instead call `TLSConfiguration.withAutomaticCertificateReloading(using:)` to get a
/// `TLSConfiguration` with a configured reloader.
///
/// If any errors occur during a reload attempt (such as: being unable to find the file(s) containing the certificate or the key; the format
/// not being recognized or not matching the configured one; not being able to verify a certificate's signature against the given
/// private key; etc), then that attempt will be aborted but the service will keep on trying at the configured interval.
/// The last-valid certificate-key pair (if any) will be returned as the ``sslContextConfigurationOverride``.
@available(macOS 11.0, iOS 14, tvOS 14, watchOS 7, *)
public struct TimedCertificateReloader: CertificateReloader {
/// The encoding for the certificate or the key.
public struct Encoding: Sendable, Equatable {
fileprivate enum _Backing {
case der
case pem
}
fileprivate let _backing: _Backing
private init(_ backing: _Backing) {
self._backing = backing
}
/// The encoding of this certificate/key is DER bytes.
public static let der = Encoding(.der)
/// The encoding of this certificate/key is PEM.
public static let pem = Encoding(.pem)
}
/// A location specification for a certificate or key.
public struct Location: Sendable {
fileprivate enum _Backing {
case file(path: String)
case memory(provider: @Sendable () -> [UInt8]?)
}
fileprivate let _backing: _Backing
private init(_ backing: _Backing) {
self._backing = backing
}
/// This certificate/key can be found at the given filepath.
/// - Parameter path: The filepath where the certificate/key can be found.
/// - Returns: A `Location`.
public static func file(path: String) -> Self { Self(_Backing.file(path: path)) }
/// This certificate/key is available in memory, and will be provided by the given closure.
/// - Parameter provider: A closure providing the bytes for the given certificate or key.
/// - Returns: A `Location`.
public static func memory(provider: @Sendable @escaping () -> [UInt8]?) -> Self {
Self(_Backing.memory(provider: provider))
}
}
/// A description of a certificate, in terms of its ``Location`` and ``Encoding``.
public struct CertificateDescription: Sendable {
public var location: Location
public var format: Encoding
public init(location: Location, format: Encoding) {
self.location = location
self.format = format
}
}
/// A description of a private key, in terms of its ``Location`` and ``Encoding``.
public struct PrivateKeyDescription: Sendable {
public var location: Location
public var format: Encoding
public init(location: Location, format: Encoding) {
self.location = location
self.format = format
}
}
private struct CertificateKeyPair {
var certificate: NIOSSLCertificateSource
var privateKey: NIOSSLPrivateKeySource
}
private let refreshInterval: TimeAmount
private let certificateDescription: CertificateDescription
private let privateKeyDescription: PrivateKeyDescription
private let state: NIOLockedValueBox<CertificateKeyPair?>
/// A `NIOSSLContextConfigurationOverride` that will be used as part of the NIO application's TLS configuration.
/// Its certificate and private key will be kept up-to-date via the reload mechanism the ``TimedCertificateReloader``
/// implementation provides.
public var sslContextConfigurationOverride: NIOSSLContextConfigurationOverride {
get {
var override = NIOSSLContextConfigurationOverride()
guard let certificateKeyPair = self.state.withLockedValue({ $0 }) else {
return override
}
override.certificateChain = [certificateKeyPair.certificate]
override.privateKey = certificateKeyPair.privateKey
return override
}
}
/// Initialize a new ``TimedCertificateReloader``.
/// - Parameters:
/// - refreshInterval: The interval at which attempts to update the certificate and private key should be made.
/// - certificateDescription: A ``CertificateDescription``.
/// - privateKeyDescription: A ``PrivateKeyDescription``.
public init(
refreshingEvery refreshInterval: TimeAmount,
certificateDescription: CertificateDescription,
privateKeyDescription: PrivateKeyDescription
) {
self.refreshInterval = refreshInterval
self.certificateDescription = certificateDescription
self.privateKeyDescription = privateKeyDescription
// TODO: try parsing key and cert here too
self.state = NIOLockedValueBox(nil)
}
/// A long-running method to run the ``TimedCertificateReloader`` and start observing updates for the certificate and
/// private key pair.
/// - Important: You *must* call this method to get certificate and key updates.
public func run() async throws {
while !Task.isCancelled {
try await Task.sleep(nanoseconds: UInt64(self.refreshInterval.nanoseconds))
let certificateBytes: [UInt8]?
switch self.certificateDescription.location._backing {
case .file(path: let path):
let bytes = FileManager.default.contents(atPath: path)
certificateBytes = bytes.map { Array($0) }
case .memory(let bytesProvider):
certificateBytes = bytesProvider()
}
let keyBytes: [UInt8]?
switch self.privateKeyDescription.location._backing {
case .file(path: let path):
let bytes = FileManager.default.contents(atPath: path)
keyBytes = bytes.map { Array($0) }
case .memory(let bytesProvider):
keyBytes = bytesProvider()
}
if let certificateBytes, let keyBytes {
let certificate: Certificate
switch self.certificateDescription.format._backing {
case .der:
let parsedCertificate = try? Certificate(derEncoded: Array(certificateBytes))
guard let parsedCertificate else {
// could not parse certificate, ignore this update
continue
}
certificate = parsedCertificate
case .pem:
let parsedCertificate = String(bytes: certificateBytes, encoding: .utf8)
.flatMap { try? Certificate(pemEncoded: $0) }
guard let parsedCertificate else {
// could not parse certificate, ignore this update
continue
}
certificate = parsedCertificate
}
let key: Certificate.PrivateKey
switch self.privateKeyDescription.format._backing {
case .der:
let parsedKey = try? Certificate.PrivateKey(derBytes: Array(keyBytes))
guard let parsedKey else {
// could not parse key, ignore the update
continue
}
key = parsedKey
case .pem:
let parsedKey = String(bytes: keyBytes, encoding: .utf8)
.flatMap { try? Certificate.PrivateKey(pemEncoded: $0) }
guard let parsedKey else {
// could not parse key, ignore the update
continue
}
key = parsedKey
}
if key.publicKey.isValidSignature(certificate.signature, for: certificate) {
let nioSSLCertificate = try? NIOSSLCertificate(
bytes: certificate.serializeAsPEM().derBytes,
format: .der
)
let nioSSLPrivateKey = try? NIOSSLPrivateKey(
bytes: key.serializeAsPEM().derBytes,
format: .der
)
guard let nioSSLCertificate, let nioSSLPrivateKey else {
continue
}
self.state.withLockedValue {
$0 = CertificateKeyPair(
certificate: .certificate(nioSSLCertificate),
privateKey: .privateKey(nioSSLPrivateKey)
)
}
}
}
}
}
}
extension TLSConfiguration {
/// Configure a ``CertificateReloader`` to observe updates for the certificate and key pair used.
/// - Parameter reloader: A ``CertificateReloader`` to watch for certificate and key pair updates.
/// - Returns: A `TLSConfiguration` that reloads the certificate and key used in its SSL handshake.
mutating public func withAutomaticCertificateReloading(using reloader: any CertificateReloader) -> Self {
self.sslContextCallback = { _, promise in
promise.succeed(reloader.sslContextConfigurationOverride)
}
return self
}
}

View File

@ -0,0 +1,313 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2017-2021 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
import NIOCertificateHotReloading
import Foundation
import XCTest
import X509
import NIOSSL
import NIOConcurrencyHelpers
@preconcurrency import Crypto
final class TimedCertificateReloaderTests: XCTestCase {
func testCertificatePathDoesNotExist() async throws {
try await runTimedCertificateReloaderTest(
certificate: .init(location: .file(path: "doesnotexist"), format: .der),
privateKey: .init(
location: .memory(provider: { Array(Self.samplePrivateKey.derRepresentation) }),
format: .der
)
) { reloader in
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
let override = reloader.sslContextConfigurationOverride
XCTAssertNil(override.certificateChain)
XCTAssertNil(override.privateKey)
}
}
func testKeyPathDoesNotExist() async throws {
try await runTimedCertificateReloaderTest(
certificate: .init(
location: .memory(provider: { try? Self.sampleCert.serializeAsPEM().derBytes }),
format: .der
),
privateKey: .init(
location: .file(path: "doesnotexist"),
format: .der
)
) { reloader in
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
let override = reloader.sslContextConfigurationOverride
XCTAssertNil(override.certificateChain)
XCTAssertNil(override.privateKey)
}
}
func testCertificateIsInUnexpectedFormat() async throws {
try await runTimedCertificateReloaderTest(
certificate: .init(
location: .memory(provider: { try? Self.sampleCert.serializeAsPEM().derBytes }),
format: .pem
),
privateKey: .init(
location: .memory(provider: { Array(Self.samplePrivateKey.derRepresentation) }),
format: .der
)
) { reloader in
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
let override = reloader.sslContextConfigurationOverride
XCTAssertNil(override.certificateChain)
XCTAssertNil(override.privateKey)
}
}
func testKeyIsInUnexpectedFormat() async throws {
try await runTimedCertificateReloaderTest(
certificate: .init(
location: .memory(provider: { try? Self.sampleCert.serializeAsPEM().derBytes }),
format: .der
),
privateKey: .init(
location: .memory(provider: { Array(Self.samplePrivateKey.derRepresentation) }),
format: .pem
)
) { reloader in
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
let override = reloader.sslContextConfigurationOverride
XCTAssertNil(override.certificateChain)
XCTAssertNil(override.privateKey)
}
}
func testCertificateAndKeyDoNotMatch() async throws {
try await runTimedCertificateReloaderTest(
certificate: .init(
location: .memory(provider: { try? Self.sampleCert.serializeAsPEM().derBytes }),
format: .der
),
privateKey: .init(
location: .memory(provider: { Array(P384.Signing.PrivateKey().derRepresentation) }),
format: .der
)
) { reloader in
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
let override = reloader.sslContextConfigurationOverride
XCTAssertNil(override.certificateChain)
XCTAssertNil(override.privateKey)
}
}
func testReloadSuccessfully() async throws {
let certificateBox = NIOLockedValueBox([UInt8]())
try await runTimedCertificateReloaderTest(
certificate: .init(
location: .memory(provider: { certificateBox.withLockedValue({ $0 }) }),
format: .der
),
privateKey: .init(
location: .memory(provider: { Array(Self.samplePrivateKey.derRepresentation) }),
format: .der
)
) { reloader in
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
var override = reloader.sslContextConfigurationOverride
XCTAssertNil(override.certificateChain)
XCTAssertNil(override.privateKey)
certificateBox.withLockedValue({ $0 = try! Self.sampleCert.serializeAsPEM().derBytes })
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
override = reloader.sslContextConfigurationOverride
XCTAssertEqual(
override.certificateChain,
[.certificate(try .init(bytes: Self.sampleCert.serializeAsPEM().derBytes, format: .der))]
)
XCTAssertEqual(
override.privateKey,
.privateKey(try .init(bytes: Array(Self.samplePrivateKey.derRepresentation), format: .der))
)
}
}
func testCertificateNotFoundAtReload() async throws {
let certificateBox: NIOLockedValueBox<[UInt8]?> = NIOLockedValueBox(
try! Self.sampleCert.serializeAsPEM().derBytes
)
try await runTimedCertificateReloaderTest(
certificate: .init(
location: .memory(provider: { certificateBox.withLockedValue({ $0 }) }),
format: .der
),
privateKey: .init(
location: .memory(provider: { Array(Self.samplePrivateKey.derRepresentation) }),
format: .der
)
) { reloader in
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
var override = reloader.sslContextConfigurationOverride
XCTAssertEqual(
override.certificateChain,
[.certificate(try .init(bytes: Self.sampleCert.serializeAsPEM().derBytes, format: .der))]
)
XCTAssertEqual(
override.privateKey,
.privateKey(try .init(bytes: Array(Self.samplePrivateKey.derRepresentation), format: .der))
)
certificateBox.withLockedValue({ $0 = nil })
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
override = reloader.sslContextConfigurationOverride
XCTAssertEqual(
override.certificateChain,
[.certificate(try .init(bytes: Self.sampleCert.serializeAsPEM().derBytes, format: .der))]
)
XCTAssertEqual(
override.privateKey,
.privateKey(try .init(bytes: Array(Self.samplePrivateKey.derRepresentation), format: .der))
)
}
}
func testKeyNotFoundAtReload() async throws {
let keyBox: NIOLockedValueBox<[UInt8]?> = NIOLockedValueBox(
Array(Self.samplePrivateKey.derRepresentation)
)
try await runTimedCertificateReloaderTest(
certificate: .init(
location: .memory(provider: { try! Self.sampleCert.serializeAsPEM().derBytes }),
format: .der
),
privateKey: .init(
location: .memory(provider: { keyBox.withLockedValue({ $0 }) }),
format: .der
)
) { reloader in
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
var override = reloader.sslContextConfigurationOverride
XCTAssertEqual(
override.certificateChain,
[.certificate(try .init(bytes: Self.sampleCert.serializeAsPEM().derBytes, format: .der))]
)
XCTAssertEqual(
override.privateKey,
.privateKey(try .init(bytes: Array(Self.samplePrivateKey.derRepresentation), format: .der))
)
keyBox.withLockedValue({ $0 = nil })
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
override = reloader.sslContextConfigurationOverride
XCTAssertEqual(
override.certificateChain,
[.certificate(try .init(bytes: Self.sampleCert.serializeAsPEM().derBytes, format: .der))]
)
XCTAssertEqual(
override.privateKey,
.privateKey(try .init(bytes: Array(Self.samplePrivateKey.derRepresentation), format: .der))
)
}
}
func testCertificateAndKeyDoNotMatchOnReload() async throws {
let keyBox: NIOLockedValueBox<[UInt8]?> = NIOLockedValueBox(
Array(Self.samplePrivateKey.derRepresentation)
)
try await runTimedCertificateReloaderTest(
certificate: .init(
location: .memory(provider: { try! Self.sampleCert.serializeAsPEM().derBytes }),
format: .der
),
privateKey: .init(
location: .memory(provider: { keyBox.withLockedValue({ $0 }) }),
format: .der
)
) { reloader in
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
var override = reloader.sslContextConfigurationOverride
XCTAssertEqual(
override.certificateChain,
[.certificate(try .init(bytes: Self.sampleCert.serializeAsPEM().derBytes, format: .der))]
)
XCTAssertEqual(
override.privateKey,
.privateKey(try .init(bytes: Array(Self.samplePrivateKey.derRepresentation), format: .der))
)
keyBox.withLockedValue({ $0 = Array(P384.Signing.PrivateKey().derRepresentation) })
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
override = reloader.sslContextConfigurationOverride
XCTAssertEqual(
override.certificateChain,
[.certificate(try .init(bytes: Self.sampleCert.serializeAsPEM().derBytes, format: .der))]
)
XCTAssertEqual(
override.privateKey,
.privateKey(try .init(bytes: Array(Self.samplePrivateKey.derRepresentation), format: .der))
)
}
}
static let startDate = Date()
static let samplePrivateKey = P384.Signing.PrivateKey()
static let sampleCertName = try! DistinguishedName {
CountryName("US")
OrganizationName("Apple")
CommonName("Swift Certificate Test")
}
static let sampleCert: Certificate = {
return try! Certificate(
version: .v3,
serialNumber: .init(),
publicKey: .init(samplePrivateKey.publicKey),
notValidBefore: startDate.advanced(by: -60 * 60 * 24 * 360),
notValidAfter: startDate.advanced(by: 60 * 60 * 24 * 360),
issuer: sampleCertName,
subject: sampleCertName,
signatureAlgorithm: .ecdsaWithSHA384,
extensions: Certificate.Extensions {
Critical(
BasicConstraints.isCertificateAuthority(maxPathLength: nil)
)
},
issuerPrivateKey: .init(samplePrivateKey)
)
}()
private func runTimedCertificateReloaderTest(
certificate: TimedCertificateReloader.CertificateDescription,
privateKey: TimedCertificateReloader.PrivateKeyDescription,
_ body: @escaping @Sendable (TimedCertificateReloader) async throws -> Void
) async throws {
let reloader = TimedCertificateReloader(
refreshingEvery: .milliseconds(50),
certificateDescription: .init(
location: certificate.location,
format: certificate.format
),
privateKeyDescription: .init(location: privateKey.location, format: privateKey.format)
)
try await withThrowingTaskGroup(of: Void.self) { group in
group.addTask {
try await reloader.run()
}
group.addTask {
try await body(reloader)
}
try await group.next()
group.cancelAll()
}
}
}