diff --git a/Package.swift b/Package.swift index bc8920d..e70b7ad 100644 --- a/Package.swift +++ b/Package.swift @@ -257,6 +257,27 @@ var targets: [PackageDescription.Target] = [ ], swiftSettings: strictConcurrencySettings ), + .target( + name: "NIOCertificateHotReloading", + dependencies: [ + .product(name: "NIOCore", package: "swift-nio"), + .product(name: "NIOSSL", package: "swift-nio-ssl"), + .product(name: "X509", package: "swift-certificates"), + .product(name: "SwiftASN1", package: "swift-asn1") + ], + swiftSettings: strictConcurrencySettings + ), + .testTarget( + name: "NIOCertificateHotReloadingTests", + dependencies: [ + "NIOCertificateHotReloading", + .product(name: "NIOCore", package: "swift-nio"), + .product(name: "NIOSSL", package: "swift-nio-ssl"), + .product(name: "X509", package: "swift-certificates"), + .product(name: "SwiftASN1", package: "swift-asn1") + ], + swiftSettings: strictConcurrencySettings + ), ] let package = Package( @@ -270,6 +291,7 @@ let package = Package( .library(name: "NIOHTTPTypesHTTP2", targets: ["NIOHTTPTypesHTTP2"]), .library(name: "NIOResumableUpload", targets: ["NIOResumableUpload"]), .library(name: "NIOHTTPResponsiveness", targets: ["NIOHTTPResponsiveness"]), + .library(name: "NIOCertificateHotReloading", targets: ["NIOCertificateHotReloading"]), ], dependencies: [ .package(url: "https://github.com/apple/swift-nio.git", from: "2.81.0"), @@ -278,7 +300,9 @@ let package = Package( .package(url: "https://github.com/apple/swift-http-structured-headers.git", from: "1.2.0"), .package(url: "https://github.com/apple/swift-atomics.git", from: "1.2.0"), .package(url: "https://github.com/apple/swift-algorithms.git", from: "1.2.0"), - + .package(url: "https://github.com/apple/swift-certificates.git", branch: "1.10.0"), + .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.29.3"), + .package(url: "https://github.com/apple/swift-asn1.git", from: "1.3.1"), ], targets: targets ) diff --git a/Sources/NIOCertificateHotReloading/NIOCertificateHotReloading.swift b/Sources/NIOCertificateHotReloading/NIOCertificateHotReloading.swift new file mode 100644 index 0000000..47d5cba --- /dev/null +++ b/Sources/NIOCertificateHotReloading/NIOCertificateHotReloading.swift @@ -0,0 +1,278 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2025 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import X509 +import SwiftASN1 + +import NIOSSL +import struct NIOCore.TimeAmount +import NIOConcurrencyHelpers +#if canImport(FoundationEssentials) +import FoundationEssentials +#else +import Foundation +#endif + +/// A protocol that defines a certificate reloader. +/// +/// A certificate reloader is a service that can provide you with updated versions of a certificate and private key pair, in +/// the form of a `NIOSSLContextConfigurationOverride`, which will be used when performing a TLS handshake in NIO. +/// Each implementation can choose how to observe for changes, but they all require an ``sslContextConfigurationOverride`` +/// to be exposed. +public protocol CertificateReloader: Sendable { + /// A `NIOSSLContextConfigurationOverride` that will be used as part of the NIO application's TLS configuration. + /// Its certificate and private key will be kept up-to-date via whatever mechanism the specific ``CertificateReloader`` + /// implementation provides. + var sslContextConfigurationOverride: NIOSSLContextConfigurationOverride { get } +} + +/// A ``TimedCertificateReloader`` is an implementation of a ``CertificateReloader``, where the certificate and private +/// key pair is updated at a fixed interval from the file path or memory location configured. +/// +/// You initialize a ``TimedCertificateReloader`` by providing a refresh interval, and locations for the certificate and the private +/// key. You must then call ``run()`` on this reloader for it to start observing changes. +/// Once the reloader is running, call ``sslContextConfigurationOverride`` to get a +/// `NIOSSLContextConfigurationOverride` which can be set on NIO's `TLSConfiguration`: this will keep the certificate +/// and private key pair up to date. +/// You may instead call `TLSConfiguration.withAutomaticCertificateReloading(using:)` to get a +/// `TLSConfiguration` with a configured reloader. +/// +/// If any errors occur during a reload attempt (such as: being unable to find the file(s) containing the certificate or the key; the format +/// not being recognized or not matching the configured one; not being able to verify a certificate's signature against the given +/// private key; etc), then that attempt will be aborted but the service will keep on trying at the configured interval. +/// The last-valid certificate-key pair (if any) will be returned as the ``sslContextConfigurationOverride``. +@available(macOS 11.0, iOS 14, tvOS 14, watchOS 7, *) +public struct TimedCertificateReloader: CertificateReloader { + /// The encoding for the certificate or the key. + public struct Encoding: Sendable, Equatable { + fileprivate enum _Backing { + case der + case pem + } + fileprivate let _backing: _Backing + + private init(_ backing: _Backing) { + self._backing = backing + } + + /// The encoding of this certificate/key is DER bytes. + public static let der = Encoding(.der) + + /// The encoding of this certificate/key is PEM. + public static let pem = Encoding(.pem) + } + + /// A location specification for a certificate or key. + public struct Location: Sendable { + fileprivate enum _Backing { + case file(path: String) + case memory(provider: @Sendable () -> [UInt8]?) + } + + fileprivate let _backing: _Backing + + private init(_ backing: _Backing) { + self._backing = backing + } + + /// This certificate/key can be found at the given filepath. + /// - Parameter path: The filepath where the certificate/key can be found. + /// - Returns: A `Location`. + public static func file(path: String) -> Self { Self(_Backing.file(path: path)) } + + /// This certificate/key is available in memory, and will be provided by the given closure. + /// - Parameter provider: A closure providing the bytes for the given certificate or key. + /// - Returns: A `Location`. + public static func memory(provider: @Sendable @escaping () -> [UInt8]?) -> Self { + Self(_Backing.memory(provider: provider)) + } + } + + /// A description of a certificate, in terms of its ``Location`` and ``Encoding``. + public struct CertificateDescription: Sendable { + public var location: Location + public var format: Encoding + + public init(location: Location, format: Encoding) { + self.location = location + self.format = format + } + } + + /// A description of a private key, in terms of its ``Location`` and ``Encoding``. + public struct PrivateKeyDescription: Sendable { + public var location: Location + public var format: Encoding + + public init(location: Location, format: Encoding) { + self.location = location + self.format = format + } + } + + private struct CertificateKeyPair { + var certificate: NIOSSLCertificateSource + var privateKey: NIOSSLPrivateKeySource + } + + private let refreshInterval: TimeAmount + private let certificateDescription: CertificateDescription + private let privateKeyDescription: PrivateKeyDescription + private let state: NIOLockedValueBox + + + /// A `NIOSSLContextConfigurationOverride` that will be used as part of the NIO application's TLS configuration. + /// Its certificate and private key will be kept up-to-date via the reload mechanism the ``TimedCertificateReloader`` + /// implementation provides. + public var sslContextConfigurationOverride: NIOSSLContextConfigurationOverride { + get { + var override = NIOSSLContextConfigurationOverride() + guard let certificateKeyPair = self.state.withLockedValue({ $0 }) else { + return override + } + override.certificateChain = [certificateKeyPair.certificate] + override.privateKey = certificateKeyPair.privateKey + return override + } + } + + /// Initialize a new ``TimedCertificateReloader``. + /// - Parameters: + /// - refreshInterval: The interval at which attempts to update the certificate and private key should be made. + /// - certificateDescription: A ``CertificateDescription``. + /// - privateKeyDescription: A ``PrivateKeyDescription``. + public init( + refreshingEvery refreshInterval: TimeAmount, + certificateDescription: CertificateDescription, + privateKeyDescription: PrivateKeyDescription + ) { + self.refreshInterval = refreshInterval + self.certificateDescription = certificateDescription + self.privateKeyDescription = privateKeyDescription + + // TODO: try parsing key and cert here too + self.state = NIOLockedValueBox(nil) + } + + /// A long-running method to run the ``TimedCertificateReloader`` and start observing updates for the certificate and + /// private key pair. + /// - Important: You *must* call this method to get certificate and key updates. + public func run() async throws { + while !Task.isCancelled { + try await Task.sleep(nanoseconds: UInt64(self.refreshInterval.nanoseconds)) + let certificateBytes: [UInt8]? + switch self.certificateDescription.location._backing { + case .file(path: let path): + let bytes = FileManager.default.contents(atPath: path) + certificateBytes = bytes.map { Array($0) } + + case .memory(let bytesProvider): + certificateBytes = bytesProvider() + } + + let keyBytes: [UInt8]? + switch self.privateKeyDescription.location._backing { + case .file(path: let path): + let bytes = FileManager.default.contents(atPath: path) + keyBytes = bytes.map { Array($0) } + + case .memory(let bytesProvider): + keyBytes = bytesProvider() + } + + if let certificateBytes, let keyBytes { + let certificate: Certificate + switch self.certificateDescription.format._backing { + case .der: + let parsedCertificate = try? Certificate(derEncoded: Array(certificateBytes)) + + guard let parsedCertificate else { + // could not parse certificate, ignore this update + continue + } + + certificate = parsedCertificate + + case .pem: + let parsedCertificate = String(bytes: certificateBytes, encoding: .utf8) + .flatMap { try? Certificate(pemEncoded: $0) } + + guard let parsedCertificate else { + // could not parse certificate, ignore this update + continue + } + + certificate = parsedCertificate + } + + let key: Certificate.PrivateKey + switch self.privateKeyDescription.format._backing { + case .der: + let parsedKey = try? Certificate.PrivateKey(derBytes: Array(keyBytes)) + + guard let parsedKey else { + // could not parse key, ignore the update + continue + } + + key = parsedKey + + case .pem: + let parsedKey = String(bytes: keyBytes, encoding: .utf8) + .flatMap { try? Certificate.PrivateKey(pemEncoded: $0) } + + guard let parsedKey else { + // could not parse key, ignore the update + continue + } + + key = parsedKey + } + + if key.publicKey.isValidSignature(certificate.signature, for: certificate) { + let nioSSLCertificate = try? NIOSSLCertificate( + bytes: certificate.serializeAsPEM().derBytes, + format: .der + ) + let nioSSLPrivateKey = try? NIOSSLPrivateKey( + bytes: key.serializeAsPEM().derBytes, + format: .der + ) + guard let nioSSLCertificate, let nioSSLPrivateKey else { + continue + } + + self.state.withLockedValue { + $0 = CertificateKeyPair( + certificate: .certificate(nioSSLCertificate), + privateKey: .privateKey(nioSSLPrivateKey) + ) + } + } + } + } + } +} + +extension TLSConfiguration { + /// Configure a ``CertificateReloader`` to observe updates for the certificate and key pair used. + /// - Parameter reloader: A ``CertificateReloader`` to watch for certificate and key pair updates. + /// - Returns: A `TLSConfiguration` that reloads the certificate and key used in its SSL handshake. + mutating public func withAutomaticCertificateReloading(using reloader: any CertificateReloader) -> Self { + self.sslContextCallback = { _, promise in + promise.succeed(reloader.sslContextConfigurationOverride) + } + return self + } +} diff --git a/Tests/NIOCertificateHotReloadingTests/NIOCertificateHotReloadingTests.swift b/Tests/NIOCertificateHotReloadingTests/NIOCertificateHotReloadingTests.swift new file mode 100644 index 0000000..66135f2 --- /dev/null +++ b/Tests/NIOCertificateHotReloadingTests/NIOCertificateHotReloadingTests.swift @@ -0,0 +1,313 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2017-2021 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import NIOCertificateHotReloading +import Foundation +import XCTest +import X509 +import NIOSSL +import NIOConcurrencyHelpers +@preconcurrency import Crypto + +final class TimedCertificateReloaderTests: XCTestCase { + func testCertificatePathDoesNotExist() async throws { + try await runTimedCertificateReloaderTest( + certificate: .init(location: .file(path: "doesnotexist"), format: .der), + privateKey: .init( + location: .memory(provider: { Array(Self.samplePrivateKey.derRepresentation) }), + format: .der + ) + ) { reloader in + try await Task.sleep(for: .milliseconds(100), tolerance: .zero) + let override = reloader.sslContextConfigurationOverride + XCTAssertNil(override.certificateChain) + XCTAssertNil(override.privateKey) + } + } + + func testKeyPathDoesNotExist() async throws { + try await runTimedCertificateReloaderTest( + certificate: .init( + location: .memory(provider: { try? Self.sampleCert.serializeAsPEM().derBytes }), + format: .der + ), + privateKey: .init( + location: .file(path: "doesnotexist"), + format: .der + ) + ) { reloader in + try await Task.sleep(for: .milliseconds(100), tolerance: .zero) + let override = reloader.sslContextConfigurationOverride + XCTAssertNil(override.certificateChain) + XCTAssertNil(override.privateKey) + } + } + + func testCertificateIsInUnexpectedFormat() async throws { + try await runTimedCertificateReloaderTest( + certificate: .init( + location: .memory(provider: { try? Self.sampleCert.serializeAsPEM().derBytes }), + format: .pem + ), + privateKey: .init( + location: .memory(provider: { Array(Self.samplePrivateKey.derRepresentation) }), + format: .der + ) + ) { reloader in + try await Task.sleep(for: .milliseconds(100), tolerance: .zero) + let override = reloader.sslContextConfigurationOverride + XCTAssertNil(override.certificateChain) + XCTAssertNil(override.privateKey) + } + } + + func testKeyIsInUnexpectedFormat() async throws { + try await runTimedCertificateReloaderTest( + certificate: .init( + location: .memory(provider: { try? Self.sampleCert.serializeAsPEM().derBytes }), + format: .der + ), + privateKey: .init( + location: .memory(provider: { Array(Self.samplePrivateKey.derRepresentation) }), + format: .pem + ) + ) { reloader in + try await Task.sleep(for: .milliseconds(100), tolerance: .zero) + let override = reloader.sslContextConfigurationOverride + XCTAssertNil(override.certificateChain) + XCTAssertNil(override.privateKey) + } + } + + func testCertificateAndKeyDoNotMatch() async throws { + try await runTimedCertificateReloaderTest( + certificate: .init( + location: .memory(provider: { try? Self.sampleCert.serializeAsPEM().derBytes }), + format: .der + ), + privateKey: .init( + location: .memory(provider: { Array(P384.Signing.PrivateKey().derRepresentation) }), + format: .der + ) + ) { reloader in + try await Task.sleep(for: .milliseconds(100), tolerance: .zero) + let override = reloader.sslContextConfigurationOverride + XCTAssertNil(override.certificateChain) + XCTAssertNil(override.privateKey) + } + } + + func testReloadSuccessfully() async throws { + let certificateBox = NIOLockedValueBox([UInt8]()) + try await runTimedCertificateReloaderTest( + certificate: .init( + location: .memory(provider: { certificateBox.withLockedValue({ $0 }) }), + format: .der + ), + privateKey: .init( + location: .memory(provider: { Array(Self.samplePrivateKey.derRepresentation) }), + format: .der + ) + ) { reloader in + try await Task.sleep(for: .milliseconds(100), tolerance: .zero) + var override = reloader.sslContextConfigurationOverride + XCTAssertNil(override.certificateChain) + XCTAssertNil(override.privateKey) + + certificateBox.withLockedValue({ $0 = try! Self.sampleCert.serializeAsPEM().derBytes }) + + try await Task.sleep(for: .milliseconds(100), tolerance: .zero) + override = reloader.sslContextConfigurationOverride + XCTAssertEqual( + override.certificateChain, + [.certificate(try .init(bytes: Self.sampleCert.serializeAsPEM().derBytes, format: .der))] + ) + XCTAssertEqual( + override.privateKey, + .privateKey(try .init(bytes: Array(Self.samplePrivateKey.derRepresentation), format: .der)) + ) + } + } + + func testCertificateNotFoundAtReload() async throws { + let certificateBox: NIOLockedValueBox<[UInt8]?> = NIOLockedValueBox( + try! Self.sampleCert.serializeAsPEM().derBytes + ) + try await runTimedCertificateReloaderTest( + certificate: .init( + location: .memory(provider: { certificateBox.withLockedValue({ $0 }) }), + format: .der + ), + privateKey: .init( + location: .memory(provider: { Array(Self.samplePrivateKey.derRepresentation) }), + format: .der + ) + ) { reloader in + try await Task.sleep(for: .milliseconds(100), tolerance: .zero) + var override = reloader.sslContextConfigurationOverride + XCTAssertEqual( + override.certificateChain, + [.certificate(try .init(bytes: Self.sampleCert.serializeAsPEM().derBytes, format: .der))] + ) + XCTAssertEqual( + override.privateKey, + .privateKey(try .init(bytes: Array(Self.samplePrivateKey.derRepresentation), format: .der)) + ) + + certificateBox.withLockedValue({ $0 = nil }) + + try await Task.sleep(for: .milliseconds(100), tolerance: .zero) + override = reloader.sslContextConfigurationOverride + XCTAssertEqual( + override.certificateChain, + [.certificate(try .init(bytes: Self.sampleCert.serializeAsPEM().derBytes, format: .der))] + ) + XCTAssertEqual( + override.privateKey, + .privateKey(try .init(bytes: Array(Self.samplePrivateKey.derRepresentation), format: .der)) + ) + } + } + + func testKeyNotFoundAtReload() async throws { + let keyBox: NIOLockedValueBox<[UInt8]?> = NIOLockedValueBox( + Array(Self.samplePrivateKey.derRepresentation) + ) + try await runTimedCertificateReloaderTest( + certificate: .init( + location: .memory(provider: { try! Self.sampleCert.serializeAsPEM().derBytes }), + format: .der + ), + privateKey: .init( + location: .memory(provider: { keyBox.withLockedValue({ $0 }) }), + format: .der + ) + ) { reloader in + try await Task.sleep(for: .milliseconds(100), tolerance: .zero) + var override = reloader.sslContextConfigurationOverride + XCTAssertEqual( + override.certificateChain, + [.certificate(try .init(bytes: Self.sampleCert.serializeAsPEM().derBytes, format: .der))] + ) + XCTAssertEqual( + override.privateKey, + .privateKey(try .init(bytes: Array(Self.samplePrivateKey.derRepresentation), format: .der)) + ) + + keyBox.withLockedValue({ $0 = nil }) + + try await Task.sleep(for: .milliseconds(100), tolerance: .zero) + override = reloader.sslContextConfigurationOverride + XCTAssertEqual( + override.certificateChain, + [.certificate(try .init(bytes: Self.sampleCert.serializeAsPEM().derBytes, format: .der))] + ) + XCTAssertEqual( + override.privateKey, + .privateKey(try .init(bytes: Array(Self.samplePrivateKey.derRepresentation), format: .der)) + ) + } + } + + func testCertificateAndKeyDoNotMatchOnReload() async throws { + let keyBox: NIOLockedValueBox<[UInt8]?> = NIOLockedValueBox( + Array(Self.samplePrivateKey.derRepresentation) + ) + try await runTimedCertificateReloaderTest( + certificate: .init( + location: .memory(provider: { try! Self.sampleCert.serializeAsPEM().derBytes }), + format: .der + ), + privateKey: .init( + location: .memory(provider: { keyBox.withLockedValue({ $0 }) }), + format: .der + ) + ) { reloader in + try await Task.sleep(for: .milliseconds(100), tolerance: .zero) + var override = reloader.sslContextConfigurationOverride + XCTAssertEqual( + override.certificateChain, + [.certificate(try .init(bytes: Self.sampleCert.serializeAsPEM().derBytes, format: .der))] + ) + XCTAssertEqual( + override.privateKey, + .privateKey(try .init(bytes: Array(Self.samplePrivateKey.derRepresentation), format: .der)) + ) + + keyBox.withLockedValue({ $0 = Array(P384.Signing.PrivateKey().derRepresentation) }) + + try await Task.sleep(for: .milliseconds(100), tolerance: .zero) + override = reloader.sslContextConfigurationOverride + XCTAssertEqual( + override.certificateChain, + [.certificate(try .init(bytes: Self.sampleCert.serializeAsPEM().derBytes, format: .der))] + ) + XCTAssertEqual( + override.privateKey, + .privateKey(try .init(bytes: Array(Self.samplePrivateKey.derRepresentation), format: .der)) + ) + } + } + + static let startDate = Date() + static let samplePrivateKey = P384.Signing.PrivateKey() + static let sampleCertName = try! DistinguishedName { + CountryName("US") + OrganizationName("Apple") + CommonName("Swift Certificate Test") + } + static let sampleCert: Certificate = { + return try! Certificate( + version: .v3, + serialNumber: .init(), + publicKey: .init(samplePrivateKey.publicKey), + notValidBefore: startDate.advanced(by: -60 * 60 * 24 * 360), + notValidAfter: startDate.advanced(by: 60 * 60 * 24 * 360), + issuer: sampleCertName, + subject: sampleCertName, + signatureAlgorithm: .ecdsaWithSHA384, + extensions: Certificate.Extensions { + Critical( + BasicConstraints.isCertificateAuthority(maxPathLength: nil) + ) + }, + issuerPrivateKey: .init(samplePrivateKey) + ) + }() + + private func runTimedCertificateReloaderTest( + certificate: TimedCertificateReloader.CertificateDescription, + privateKey: TimedCertificateReloader.PrivateKeyDescription, + _ body: @escaping @Sendable (TimedCertificateReloader) async throws -> Void + ) async throws { + let reloader = TimedCertificateReloader( + refreshingEvery: .milliseconds(50), + certificateDescription: .init( + location: certificate.location, + format: certificate.format + ), + privateKeyDescription: .init(location: privateKey.location, format: privateKey.format) + ) + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + try await reloader.run() + } + group.addTask { + try await body(reloader) + } + try await group.next() + group.cancelAll() + } + } +}