diff --git a/Package.swift b/Package.swift index bc8920d..71b11f1 100644 --- a/Package.swift +++ b/Package.swift @@ -257,6 +257,30 @@ var targets: [PackageDescription.Target] = [ ], swiftSettings: strictConcurrencySettings ), + .target( + name: "NIOCertificateReloading", + dependencies: [ + .product(name: "NIOCore", package: "swift-nio"), + .product(name: "NIOSSL", package: "swift-nio-ssl"), + .product(name: "X509", package: "swift-certificates"), + .product(name: "SwiftASN1", package: "swift-asn1"), + .product(name: "ServiceLifecycle", package: "swift-service-lifecycle"), + .product(name: "AsyncAlgorithms", package: "swift-async-algorithms"), + .product(name: "Logging", package: "swift-log"), + ], + swiftSettings: strictConcurrencySettings + ), + .testTarget( + name: "NIOCertificateReloadingTests", + dependencies: [ + "NIOCertificateReloading", + .product(name: "NIOCore", package: "swift-nio"), + .product(name: "NIOSSL", package: "swift-nio-ssl"), + .product(name: "X509", package: "swift-certificates"), + .product(name: "SwiftASN1", package: "swift-asn1"), + ], + swiftSettings: strictConcurrencySettings + ), ] let package = Package( @@ -270,6 +294,7 @@ let package = Package( .library(name: "NIOHTTPTypesHTTP2", targets: ["NIOHTTPTypesHTTP2"]), .library(name: "NIOResumableUpload", targets: ["NIOResumableUpload"]), .library(name: "NIOHTTPResponsiveness", targets: ["NIOHTTPResponsiveness"]), + .library(name: "NIOCertificateReloading", targets: ["NIOCertificateReloading"]), ], dependencies: [ .package(url: "https://github.com/apple/swift-nio.git", from: "2.81.0"), @@ -278,6 +303,12 @@ let package = Package( .package(url: "https://github.com/apple/swift-http-structured-headers.git", from: "1.2.0"), .package(url: "https://github.com/apple/swift-atomics.git", from: "1.2.0"), .package(url: "https://github.com/apple/swift-algorithms.git", from: "1.2.0"), + .package(url: "https://github.com/apple/swift-certificates.git", from: "1.10.0"), + .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.29.3"), + .package(url: "https://github.com/apple/swift-asn1.git", from: "1.3.1"), + .package(url: "https://github.com/swift-server/swift-service-lifecycle.git", from: "2.8.0"), + .package(url: "https://github.com/apple/swift-async-algorithms.git", from: "1.0.0"), + .package(url: "https://github.com/apple/swift-log.git", from: "1.6.3"), ], targets: targets diff --git a/Sources/NIOCertificateReloading/CertificateReloader.swift b/Sources/NIOCertificateReloading/CertificateReloader.swift new file mode 100644 index 0000000..0cd61ea --- /dev/null +++ b/Sources/NIOCertificateReloading/CertificateReloader.swift @@ -0,0 +1,125 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2025 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import NIOCore +import NIOSSL + +/// A protocol that defines a certificate reloader. +/// +/// A certificate reloader is a service that can provide you with updated versions of a certificate and private key pair, in +/// the form of a `NIOSSLContextConfigurationOverride`, which will be used when performing a TLS handshake in NIO. +/// Each implementation can choose how to observe for changes, but they all require an ``sslContextConfigurationOverride`` +/// to be exposed. +public protocol CertificateReloader: Sendable { + /// A `NIOSSLContextConfigurationOverride` that will be used as part of the NIO application's TLS configuration. + /// Its certificate and private key will be kept up-to-date via whatever mechanism the specific ``CertificateReloader`` + /// implementation provides. + var sslContextConfigurationOverride: NIOSSLContextConfigurationOverride { get } +} + +extension TLSConfiguration { + /// Errors thrown when creating a ``NIOSSL/TLSConfiguration`` with a ``CertificateReloader``. + public struct CertificateReloaderError: Error, Hashable, CustomStringConvertible { + private enum _Backing: CustomStringConvertible { + case missingCertificateChain + case missingPrivateKey + + var description: String { + switch self { + case .missingCertificateChain: + return "Missing certificate chain" + case .missingPrivateKey: + return "Missing private key" + } + } + } + + private let _backing: _Backing + + private init(backing: _Backing) { + self._backing = backing + } + + public var description: String { + self._backing.description + } + + /// The given ``CertificateReloader`` could not provide a certificate chain with which to create this config. + public static var missingCertificateChain: Self { .init(backing: .missingCertificateChain) } + + /// The given ``CertificateReloader`` could not provide a private key with which to create this config. + public static var missingPrivateKey: Self { .init(backing: .missingPrivateKey) } + } + + /// Create a ``NIOSSL/TLSConfiguration`` for use with server-side contexts, with certificate reloading enabled. + /// - Parameter certificateReloader: A ``CertificateReloader`` to watch for certificate and key pair updates. + /// - Returns: A ``NIOSSL/TLSConfiguration`` for use with server-side contexts, that reloads the certificate and key + /// used in its SSL handshake. + /// - Throws: This method will throw if an override isn't present. This may happen if a certificate or private key could not be + /// loaded from the given paths. + public static func makeServerConfiguration( + certificateReloader: some CertificateReloader + ) throws -> Self { + let override = certificateReloader.sslContextConfigurationOverride + + guard let certificateChain = override.certificateChain else { + throw CertificateReloaderError.missingCertificateChain + } + + guard let privateKey = override.privateKey else { + throw CertificateReloaderError.missingPrivateKey + } + + var configuration = Self.makeServerConfiguration( + certificateChain: certificateChain, + privateKey: privateKey + ) + configuration.setCertificateReloader(certificateReloader) + return configuration + } + + /// Create a ``NIOSSL/TLSConfiguration`` for use with client-side contexts, with certificate reloading enabled. + /// - Parameter certificateReloader: A ``CertificateReloader`` to watch for certificate and key pair updates. + /// - Returns: A ``NIOSSL/TLSConfiguration`` for use with client-side contexts, that reloads the certificate and key + /// used in its SSL handshake. + /// - Throws: This method will throw if an override isn't present. This may happen if a certificate or private key could not be + /// loaded from the given paths. + public static func makeClientConfiguration( + certificateReloader: some CertificateReloader + ) throws -> Self { + let override = certificateReloader.sslContextConfigurationOverride + + guard let certificateChain = override.certificateChain else { + throw CertificateReloaderError.missingCertificateChain + } + + guard let privateKey = override.privateKey else { + throw CertificateReloaderError.missingPrivateKey + } + + var configuration = Self.makeClientConfiguration() + configuration.certificateChain = certificateChain + configuration.privateKey = privateKey + configuration.setCertificateReloader(certificateReloader) + return configuration + } + + /// Configure a ``CertificateReloader`` to observe updates for the certificate and key pair used. + /// - Parameter reloader: A ``CertificateReloader`` to watch for certificate and key pair updates. + public mutating func setCertificateReloader(_ reloader: some CertificateReloader) { + self.sslContextCallback = { _, promise in + promise.succeed(reloader.sslContextConfigurationOverride) + } + } +} diff --git a/Sources/NIOCertificateReloading/TimedCertificateReloader.swift b/Sources/NIOCertificateReloading/TimedCertificateReloader.swift new file mode 100644 index 0000000..c9b0420 --- /dev/null +++ b/Sources/NIOCertificateReloading/TimedCertificateReloader.swift @@ -0,0 +1,440 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2025 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import AsyncAlgorithms +import Logging +import NIOConcurrencyHelpers +import NIOSSL +import ServiceLifecycle +import SwiftASN1 +import X509 + +import struct NIOCore.TimeAmount + +#if canImport(FoundationEssentials) +import FoundationEssentials +#else +import Foundation +#endif + +/// A ``TimedCertificateReloader`` is an implementation of a ``CertificateReloader``, where the certificate and private +/// key pair is updated at a fixed interval from the file path or memory location configured. +/// +/// You initialize a ``TimedCertificateReloader`` by providing a refresh interval, and locations for the certificate and the private +/// key via ``init(refreshInterval:certificateSource:privateKeySource:logger:)``. +/// Simply creating a timed reloader won't validate that the sources provide valid certificate and private key pairs. If you want this to be +/// validated at creation time, you may instead use +/// ``makeReloaderValidatingSources(refreshInterval:certificateSource:privateKeySource:logger:)``. +/// +/// You may then set the timed reloader on your ``NIOSSL/TLSConfiguration`` using +/// ``NIOSSL/TLSConfiguration/setCertificateReloader(_:)``: +/// +/// ```swift +/// var configuration = TLSConfiguration.makeServerConfiguration( +/// certificateChain: chain, +/// privateKey: key +/// ) +/// let reloader = TimedCertificateReloader( +/// refreshInterval: .seconds(500), +/// certificateSource: TimedCertificateReloader.CertificateSource(...), +/// privateKeySource: TimedCertificateReloader.PrivateKeySource(...) +/// ) +/// configuration.setCertificateReloader(reloader) +/// ``` +/// +/// Finally, you must call ``run()`` on the reloader for it to start observing changes. +/// If you want to trigger a manual reload at any point, you may call ``reload()``. +/// +/// If you're creating a server configuration, you can instead opt to use +/// ``NIOSSL/TLSConfiguration/makeServerConfiguration(certificateReloader:)``, which will set the initial +/// certificate chain and private key, as well as set the reloader: +/// +/// ```swift +/// let configuration = TLSConfiguration.makeServerConfiguration( +/// certificateReloader: reloader +/// ) +/// ``` +/// +/// If you're creating a client configuration, you can instead opt to use +/// ``NIOSSL/TLSConfiguration/makeClientConfiguration(certificateReloader:)`` which will set the reloader: +/// ```swift +/// let configuration = TLSConfiguration.makeClientConfiguration( +/// certificateReloader: reloader +/// ) +/// ``` +/// +/// In both cases, make sure you've either called ``run()`` or created the ``TimedCertificateReloader`` using +/// ``makeReloaderValidatingSources(refreshInterval:certificateSource:privateKeySource:logger:)`` +/// _before_ creating the ``NIOSSL/TLSConfiguration``, as otherwise the validation will fail. +/// +/// Once the reloader is running, you can manually access its ``sslContextConfigurationOverride`` property to get a +/// `NIOSSLContextConfigurationOverride`, although this will typically not be necessary, as it's the NIO channel that will +/// handle the override when initiating TLS handshakes. +/// +/// ```swift +/// try await withThrowingTaskGroup(of: Void.self) { group in +/// group.addTask { +/// reloader.run() +/// } +/// // ... +/// let override = reloader.sslContextConfigurationOverride +/// // ... +/// } +/// ``` +/// +/// ``TimedCertificateReloader`` conforms to `ServiceLifecycle`'s `Service` protocol, meaning you can simply create +/// the reloader and add it to your `ServiceGroup` without having to manually run it. +/// +/// If any errors occur during a reload attempt (such as: being unable to find the file(s) containing the certificate or the key; the format +/// not being recognized or not matching the configured one; not being able to verify a certificate's signature against the given +/// private key; etc), then that attempt will be aborted but the service will keep on trying at the configured interval. +/// The last-valid certificate-key pair (if any) will be returned as the ``sslContextConfigurationOverride``. +#if compiler(>=6.0) +@available(macOS 13, iOS 16, watchOS 9, tvOS 16, macCatalyst 16, visionOS 1, *) +#else +@available(macOS 13, iOS 16, watchOS 9, tvOS 16, macCatalyst 16, *) +#endif +public struct TimedCertificateReloader: CertificateReloader { + /// The encoding for the certificate or the key. + public struct Encoding: Sendable, Hashable { + fileprivate enum _Backing { + case der + case pem + } + fileprivate let _backing: _Backing + + private init(_ backing: _Backing) { + self._backing = backing + } + + /// The encoding of this certificate/key is DER bytes. + public static var der: Self { .init(.der) } + + /// The encoding of this certificate/key is PEM. + public static var pem: Self { .init(.pem) } + } + + /// A location specification for a certificate or key. + public struct Location: Sendable, CustomStringConvertible { + fileprivate enum _Backing: CustomStringConvertible { + case file(path: String) + case memory(provider: @Sendable () throws -> [UInt8]) + + var description: String { + switch self { + case .file(let path): + return "Filepath: \(path)" + case .memory: + return "<in-memory location>" + } + } + } + + fileprivate let _backing: _Backing + + private init(_ backing: _Backing) { + self._backing = backing + } + + public var description: String { + self._backing.description + } + + /// This certificate/key can be found at the given filepath. + /// - Parameter path: The filepath where the certificate/key can be found. + /// - Returns: A `Location`. + public static func file(path: String) -> Self { Self(_Backing.file(path: path)) } + + /// This certificate/key is available in memory, and will be provided by the given closure. + /// - Parameter provider: A closure providing the bytes for the given certificate or key. This closure should return + /// `nil` if a certificate/key isn't currently available for whatever reason. + /// - Returns: A `Location`. + public static func memory(provider: @Sendable @escaping () throws -> [UInt8]) -> Self { + Self(_Backing.memory(provider: provider)) + } + } + + /// A description of a certificate, in terms of its ``TimedCertificateReloader/Location`` and + /// ``TimedCertificateReloader/Encoding``. + public struct CertificateSource: Sendable { + + /// The certificate's ``TimedCertificateReloader/Location``. + public var location: Location + + /// The certificate's ``TimedCertificateReloader/Encoding``. + public var format: Encoding + + /// Initialize a new ``TimedCertificateReloader/CertificateSource``. + /// - Parameters: + /// - location: A ``TimedCertificateReloader/Location``. + /// - format: A ``TimedCertificateReloader/Encoding``. + public init(location: Location, format: Encoding) { + self.location = location + self.format = format + } + } + + /// A description of a private key, in terms of its ``TimedCertificateReloader/Location`` and + /// ``TimedCertificateReloader/Encoding``. + public struct PrivateKeySource: Sendable { + + /// The key's ``TimedCertificateReloader/Location``. + public var location: Location + + /// The key's ``TimedCertificateReloader/Encoding``. + public var format: Encoding + + /// Initialize a new ``TimedCertificateReloader/PrivateKeySource``. + /// - Parameters: + /// - location: A ``TimedCertificateReloader/Location``. + /// - format: A ``TimedCertificateReloader/Encoding``. + public init(location: Location, format: Encoding) { + self.location = location + self.format = format + } + } + + /// Errors specific to the ``TimedCertificateReloader``. + public struct Error: Swift.Error, Hashable, CustomStringConvertible { + private enum _Backing: Hashable, CustomStringConvertible { + case certificatePathNotFound(String) + case privateKeyPathNotFound(String) + + var description: String { + switch self { + case .certificatePathNotFound(let path): + return "Certificate path not found: \(path)" + case .privateKeyPathNotFound(let path): + return "Private key path not found: \(path)" + } + } + } + + private let _backing: _Backing + + private init(_ backing: _Backing) { + self._backing = backing + } + + /// The file path given for the certificate cannot be found. + /// - Parameter path: The file path given for the certificate. + /// - Returns: A ``TimedCertificateReloader/Error``. + public static func certificatePathNotFound(_ path: String) -> Self { + Self(.certificatePathNotFound(path)) + } + + /// The file path given for the private key cannot be found. + /// - Parameter path: The file path given for the private key. + /// - Returns: A ``TimedCertificateReloader/Error``. + public static func privateKeyPathNotFound(_ path: String) -> Self { + Self(.privateKeyPathNotFound(path)) + } + + public var description: String { + self._backing.description + } + } + + private struct CertificateKeyPair { + var certificate: NIOSSLCertificateSource + var privateKey: NIOSSLPrivateKeySource + } + + private let refreshInterval: Duration + private let certificateSource: CertificateSource + private let privateKeySource: PrivateKeySource + private let state: NIOLockedValueBox<CertificateKeyPair?> + private let logger: Logger? + + /// A `NIOSSLContextConfigurationOverride` that will be used as part of the NIO application's TLS configuration. + /// Its certificate and private key will be kept up-to-date via the reload mechanism the ``TimedCertificateReloader`` + /// implementation provides. + /// - Note: If no reload attempt has yet been tried (either by creating the reloader with + /// ``makeReloaderValidatingSources(refreshInterval:certificateSource:privateKeySource:logger:)``, + /// manually calling ``reload()``, or by calling ``run()``), `NIOSSLContextConfigurationOverride/noChanges` + /// will be returned. + public var sslContextConfigurationOverride: NIOSSLContextConfigurationOverride { + get { + guard let certificateKeyPair = self.state.withLockedValue({ $0 }) else { + return .noChanges + } + var override = NIOSSLContextConfigurationOverride() + override.certificateChain = [certificateKeyPair.certificate] + override.privateKey = certificateKeyPair.privateKey + return override + } + } + + /// Initialize a new ``TimedCertificateReloader``. + /// - Important: ``TimedCertificateReloader/sslContextConfigurationOverride`` will return + /// `NIOSSLContextConfigurationOverride/noChanges` until ``TimedCertificateReloader/run()`` or + /// ``TimedCertificateReloader/reload()`` are called. + /// - Parameters: + /// - refreshInterval: The interval at which attempts to update the certificate and private key should be made. + /// - certificateSource: A ``TimedCertificateReloader/CertificateSource``. + /// - privateKeySource: A ``TimedCertificateReloader/PrivateKeySource``. + /// - logger: An optional logger. + public init( + refreshInterval: Duration, + certificateSource: CertificateSource, + privateKeySource: PrivateKeySource, + logger: Logger? = nil + ) { + self.refreshInterval = refreshInterval + self.certificateSource = certificateSource + self.privateKeySource = privateKeySource + self.state = NIOLockedValueBox(nil) + self.logger = logger + } + + /// Initialize a new ``TimedCertificateReloader``, and attempt to reload the certificate and private key pair from the given + /// sources. If the reload fails (because e.g. the paths aren't valid), this method will throw. + /// - Important: If this method does not throw, it is guaranteed that + /// ``TimedCertificateReloader/sslContextConfigurationOverride`` will contain the configured certificate and + /// private key pair, even before the first reload is triggered or ``TimedCertificateReloader/run()`` is called. + /// - Parameters: + /// - refreshInterval: The interval at which attempts to update the certificate and private key should be made. + /// - certificateSource: A ``TimedCertificateReloader/CertificateSource``. + /// - privateKeySource: A ``TimedCertificateReloader/PrivateKeySource``. + /// - logger: An optional logger. + /// - Returns: The newly created ``TimedCertificateReloader``. + /// - Throws: If either the certificate or private key sources cannot be loaded, an error will be thrown. + public static func makeReloaderValidatingSources( + refreshInterval: Duration, + certificateSource: CertificateSource, + privateKeySource: PrivateKeySource, + logger: Logger? = nil + ) throws -> Self { + let reloader = Self.init( + refreshInterval: refreshInterval, + certificateSource: certificateSource, + privateKeySource: privateKeySource, + logger: logger + ) + try reloader.reload() + return reloader + } + + /// A long-running method to run the ``TimedCertificateReloader`` and start observing updates for the certificate and + /// private key pair. + /// - Important: You *must* call this method to get certificate and key updates. + public func run() async throws { + for try await _ in AsyncTimerSequence.repeating(every: self.refreshInterval).cancelOnGracefulShutdown() { + do { + try self.reload() + } catch { + self.logger?.debug( + "Failed to reload certificate and private key.", + metadata: [ + "error": "\(error)", + "certificatePath": "\(self.certificateSource.location)", + "privateKeyPath": "\(self.privateKeySource.location)", + ] + ) + } + } + } + + /// Manually attempt a certificate and private key pair update. + public func reload() throws { + let certificateBytes = try self.loadCertificate() + let keyBytes = try self.loadPrivateKey() + if let certificate = try self.parseCertificate(from: certificateBytes), + let key = try self.parsePrivateKey(from: keyBytes), + key.publicKey.isValidSignature(certificate.signature, for: certificate) + { + try self.attemptToUpdatePair(certificate: certificate, key: key) + } + } + + private func loadCertificate() throws -> [UInt8] { + let certificateBytes: [UInt8] + switch self.certificateSource.location._backing { + case .file(let path): + guard let bytes = FileManager.default.contents(atPath: path) else { + throw Error.certificatePathNotFound(path) + } + certificateBytes = Array(bytes) + + case .memory(let bytesProvider): + certificateBytes = try bytesProvider() + } + return certificateBytes + } + + private func loadPrivateKey() throws -> [UInt8] { + let keyBytes: [UInt8] + switch self.privateKeySource.location._backing { + case .file(let path): + guard let bytes = FileManager.default.contents(atPath: path) else { + throw Error.privateKeyPathNotFound(path) + } + keyBytes = Array(bytes) + + case .memory(let bytesProvider): + keyBytes = try bytesProvider() + } + return keyBytes + } + + private func parseCertificate(from certificateBytes: [UInt8]) throws -> Certificate? { + let certificate: Certificate? + switch self.certificateSource.format._backing { + case .der: + certificate = try Certificate(derEncoded: certificateBytes) + + case .pem: + certificate = try String(bytes: certificateBytes, encoding: .utf8) + .flatMap { try Certificate(pemEncoded: $0) } + } + return certificate + } + + private func parsePrivateKey(from keyBytes: [UInt8]) throws -> Certificate.PrivateKey? { + let key: Certificate.PrivateKey? + switch self.privateKeySource.format._backing { + case .der: + key = try Certificate.PrivateKey(derBytes: keyBytes) + + case .pem: + key = try String(bytes: keyBytes, encoding: .utf8) + .flatMap { try Certificate.PrivateKey(pemEncoded: $0) } + } + return key + } + + private func attemptToUpdatePair(certificate: Certificate, key: Certificate.PrivateKey) throws { + let nioSSLCertificate = try NIOSSLCertificate( + bytes: certificate.serializeAsPEM().derBytes, + format: .der + ) + let nioSSLPrivateKey = try NIOSSLPrivateKey( + bytes: key.serializeAsPEM().derBytes, + format: .der + ) + self.state.withLockedValue { + $0 = CertificateKeyPair( + certificate: .certificate(nioSSLCertificate), + privateKey: .privateKey(nioSSLPrivateKey) + ) + } + } +} + +#if compiler(>=6.0) +@available(macOS 13, iOS 16, watchOS 9, tvOS 16, macCatalyst 16, visionOS 1, *) +#else +@available(macOS 13, iOS 16, watchOS 9, tvOS 16, macCatalyst 16, *) +#endif +extension TimedCertificateReloader: Service {} diff --git a/Tests/NIOCertificateReloadingTests/TimedCertificateReloaderTests.swift b/Tests/NIOCertificateReloadingTests/TimedCertificateReloaderTests.swift new file mode 100644 index 0000000..c115fee --- /dev/null +++ b/Tests/NIOCertificateReloadingTests/TimedCertificateReloaderTests.swift @@ -0,0 +1,512 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2025 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +@preconcurrency import Crypto +import NIOCertificateReloading +import NIOConcurrencyHelpers +import NIOSSL +import SwiftASN1 +import X509 +import XCTest + +#if canImport(FoundationEssentials) +import FoundationEssentials +#else +import Foundation +#endif + +final class TimedCertificateReloaderTests: XCTestCase { + func testCertificatePathDoesNotExist() async throws { + try await runTimedCertificateReloaderTest( + certificate: .init(location: .file(path: "doesnotexist"), format: .der), + privateKey: .init( + location: .memory(provider: { Array(Self.samplePrivateKey.derRepresentation) }), + format: .der + ), + validateSources: false + ) { reloader in + let override = reloader.sslContextConfigurationOverride + XCTAssertNil(override.certificateChain) + XCTAssertNil(override.privateKey) + } + } + + func testCertificatePathDoesNotExist_ValidatingSource() async throws { + do { + try await runTimedCertificateReloaderTest( + certificate: .init(location: .file(path: "doesnotexist"), format: .der), + privateKey: .init( + location: .memory(provider: { Array(Self.samplePrivateKey.derRepresentation) }), + format: .der + ) + ) { _ in + XCTFail("Test should have failed before reaching this point.") + } + } catch { + XCTAssertEqual( + error as? TimedCertificateReloader.Error, + TimedCertificateReloader.Error.certificatePathNotFound("doesnotexist") + ) + } + } + + func testKeyPathDoesNotExist() async throws { + try await runTimedCertificateReloaderTest( + certificate: .init( + location: .memory(provider: { try Self.sampleCert.serializeAsPEM().derBytes }), + format: .der + ), + privateKey: .init( + location: .file(path: "doesnotexist"), + format: .der + ), + validateSources: false + ) { reloader in + let override = reloader.sslContextConfigurationOverride + XCTAssertNil(override.certificateChain) + XCTAssertNil(override.privateKey) + } + } + + func testKeyPathDoesNotExist_ValidatingSource() async throws { + do { + try await runTimedCertificateReloaderTest( + certificate: .init( + location: .memory(provider: { try Self.sampleCert.serializeAsPEM().derBytes }), + format: .der + ), + privateKey: .init( + location: .file(path: "doesnotexist"), + format: .der + ) + ) { _ in + XCTFail("Test should have failed before reaching this point.") + } + } catch { + XCTAssertEqual( + error as? TimedCertificateReloader.Error, + TimedCertificateReloader.Error.privateKeyPathNotFound("doesnotexist") + ) + } + } + + func testCertificateIsInUnexpectedFormat_FromMemory() async throws { + try await runTimedCertificateReloaderTest( + certificate: .init( + location: .memory(provider: { try Self.sampleCert.serializeAsPEM().derBytes }), + format: .pem + ), + privateKey: .init( + location: .memory(provider: { Array(Self.samplePrivateKey.derRepresentation) }), + format: .der + ) + ) { reloader in + let override = reloader.sslContextConfigurationOverride + XCTAssertNil(override.certificateChain) + XCTAssertNil(override.privateKey) + } + } + + private func createTempFile(contents: Data) throws -> URL { + let directory = FileManager.default.temporaryDirectory + let filename = UUID().uuidString + let fileURL = directory.appendingPathComponent(filename) + FileManager.default.createFile(atPath: fileURL.path(), contents: contents) + return fileURL + } + + func testCertificateIsInUnexpectedFormat_FromFile() async throws { + let certBytes = try Self.sampleCert.serializeAsPEM().derBytes + let file = try self.createTempFile(contents: Data(certBytes)) + try await runTimedCertificateReloaderTest( + certificate: .init( + location: .file(path: file.path()), + format: .pem + ), + privateKey: .init( + location: .memory(provider: { Array(Self.samplePrivateKey.derRepresentation) }), + format: .der + ) + ) { reloader in + let override = reloader.sslContextConfigurationOverride + XCTAssertNil(override.certificateChain) + XCTAssertNil(override.privateKey) + } + } + + func testKeyIsInUnexpectedFormat_FromMemory() async throws { + try await runTimedCertificateReloaderTest( + certificate: .init( + location: .memory(provider: { try Self.sampleCert.serializeAsPEM().derBytes }), + format: .der + ), + privateKey: .init( + location: .memory(provider: { Array(Self.samplePrivateKey.derRepresentation) }), + format: .pem + ) + ) { reloader in + let override = reloader.sslContextConfigurationOverride + XCTAssertNil(override.certificateChain) + XCTAssertNil(override.privateKey) + } + } + + func testKeyIsInUnexpectedFormat_FromFile() async throws { + let keyBytes = Self.samplePrivateKey.derRepresentation + let file = try self.createTempFile(contents: keyBytes) + try await runTimedCertificateReloaderTest( + certificate: .init( + location: .memory(provider: { try Self.sampleCert.serializeAsPEM().derBytes }), + format: .der + ), + privateKey: .init( + location: .file(path: file.path()), + format: .pem + ) + ) { reloader in + let override = reloader.sslContextConfigurationOverride + XCTAssertNil(override.certificateChain) + XCTAssertNil(override.privateKey) + } + } + + func testCertificateAndKeyDoNotMatch() async throws { + try await runTimedCertificateReloaderTest( + certificate: .init( + location: .memory(provider: { try Self.sampleCert.serializeAsPEM().derBytes }), + format: .der + ), + privateKey: .init( + location: .memory(provider: { Array(P384.Signing.PrivateKey().derRepresentation) }), + format: .der + ) + ) { reloader in + let override = reloader.sslContextConfigurationOverride + XCTAssertNil(override.certificateChain) + XCTAssertNil(override.privateKey) + } + } + + enum TestError: Error { + case emptyCertificate + case emptyPrivateKey + } + + func testReloadSuccessfully_FromMemory() async throws { + let certificateBox: NIOLockedValueBox<[UInt8]> = NIOLockedValueBox([]) + try await runTimedCertificateReloaderTest( + certificate: .init( + location: .memory(provider: { + let cert = certificateBox.withLockedValue({ $0 }) + if cert.isEmpty { + throw TestError.emptyCertificate + } + return cert + }), + format: .der + ), + privateKey: .init( + location: .memory(provider: { Array(Self.samplePrivateKey.derRepresentation) }), + format: .der + ), + // We need to disable validation because the provider will initially be empty. + validateSources: false + ) { reloader in + // On first attempt, we should have no certificate or private key overrides available, + // since the certificate box is empty. + var override = reloader.sslContextConfigurationOverride + XCTAssertNil(override.certificateChain) + XCTAssertNil(override.privateKey) + + // Update the box to contain a valid certificate. + certificateBox.withLockedValue({ $0 = try! Self.sampleCert.serializeAsPEM().derBytes }) + + // Give the reload loop some time to run and update the cert-key pair. + try await Task.sleep(for: .milliseconds(100), tolerance: .zero) + + // Now the overrides should be present. + override = reloader.sslContextConfigurationOverride + XCTAssertEqual( + override.certificateChain, + [.certificate(try .init(bytes: Self.sampleCert.serializeAsPEM().derBytes, format: .der))] + ) + XCTAssertEqual( + override.privateKey, + .privateKey(try .init(bytes: Array(Self.samplePrivateKey.derRepresentation), format: .der)) + ) + } + } + + func testReloadSuccessfully_FromFile() async throws { + // Start with empty files. + let certificateFile = try self.createTempFile(contents: Data()) + let privateKeyFile = try self.createTempFile(contents: Data()) + try await runTimedCertificateReloaderTest( + certificate: .init( + location: .file(path: certificateFile.path()), + format: .der + ), + privateKey: .init( + location: .file(path: privateKeyFile.path()), + format: .der + ), + // We need to disable validation because the files will not initially have any contents. + validateSources: false + ) { reloader in + // On first attempt, we should have no certificate or private key overrides available, + // since the certificate box is empty. + var override = reloader.sslContextConfigurationOverride + XCTAssertNil(override.certificateChain) + XCTAssertNil(override.privateKey) + + // Update the files to contain data + try Data(try Self.sampleCert.serializeAsPEM().derBytes).write(to: certificateFile) + try Self.samplePrivateKey.derRepresentation.write(to: privateKeyFile) + + // Give the reload loop some time to run and update the cert-key pair. + try await Task.sleep(for: .milliseconds(100), tolerance: .zero) + + // Now the overrides should be present. + override = reloader.sslContextConfigurationOverride + XCTAssertEqual( + override.certificateChain, + [.certificate(try .init(bytes: Self.sampleCert.serializeAsPEM().derBytes, format: .der))] + ) + XCTAssertEqual( + override.privateKey, + .privateKey(try .init(bytes: Array(Self.samplePrivateKey.derRepresentation), format: .der)) + ) + } + } + + func testCertificateNotFoundAtReload() async throws { + let certificateBox: NIOLockedValueBox<[UInt8]> = NIOLockedValueBox( + try! Self.sampleCert.serializeAsPEM().derBytes + ) + try await runTimedCertificateReloaderTest( + certificate: .init( + location: .memory(provider: { + let cert = certificateBox.withLockedValue({ $0 }) + if cert.isEmpty { + throw TestError.emptyCertificate + } + return cert + }), + format: .der + ), + privateKey: .init( + location: .memory(provider: { Array(Self.samplePrivateKey.derRepresentation) }), + format: .der + ) + ) { reloader in + // On first attempt, the overrides should be correctly present. + var override = reloader.sslContextConfigurationOverride + XCTAssertEqual( + override.certificateChain, + [.certificate(try .init(bytes: Self.sampleCert.serializeAsPEM().derBytes, format: .der))] + ) + XCTAssertEqual( + override.privateKey, + .privateKey(try .init(bytes: Array(Self.samplePrivateKey.derRepresentation), format: .der)) + ) + + // Update the box to contain empty bytes: this will cause the provider to throw. + certificateBox.withLockedValue({ $0 = [] }) + + // Give the reload loop some time to run and update the cert-key pair. + try await Task.sleep(for: .milliseconds(100), tolerance: .zero) + + // We should still be offering the previously valid cert-key pair. + override = reloader.sslContextConfigurationOverride + XCTAssertEqual( + override.certificateChain, + [.certificate(try .init(bytes: Self.sampleCert.serializeAsPEM().derBytes, format: .der))] + ) + XCTAssertEqual( + override.privateKey, + .privateKey(try .init(bytes: Array(Self.samplePrivateKey.derRepresentation), format: .der)) + ) + } + } + + func testKeyNotFoundAtReload() async throws { + let keyBox: NIOLockedValueBox<[UInt8]> = NIOLockedValueBox( + Array(Self.samplePrivateKey.derRepresentation) + ) + try await runTimedCertificateReloaderTest( + certificate: .init( + location: .memory(provider: { try Self.sampleCert.serializeAsPEM().derBytes }), + format: .der + ), + privateKey: .init( + location: .memory(provider: { + let key = keyBox.withLockedValue({ $0 }) + if key.isEmpty { + throw TestError.emptyPrivateKey + } + return key + }), + format: .der + ) + ) { reloader in + // On first attempt, the overrides should be correctly present. + var override = reloader.sslContextConfigurationOverride + XCTAssertEqual( + override.certificateChain, + [.certificate(try .init(bytes: Self.sampleCert.serializeAsPEM().derBytes, format: .der))] + ) + XCTAssertEqual( + override.privateKey, + .privateKey(try .init(bytes: Array(Self.samplePrivateKey.derRepresentation), format: .der)) + ) + + // Update the box to contain empty bytes: this will cause the provider to throw. + keyBox.withLockedValue({ $0 = [] }) + + // Give the reload loop some time to run and update the cert-key pair. + try await Task.sleep(for: .milliseconds(100), tolerance: .zero) + + // We should still be offering the previously valid cert-key pair. + override = reloader.sslContextConfigurationOverride + XCTAssertEqual( + override.certificateChain, + [.certificate(try .init(bytes: Self.sampleCert.serializeAsPEM().derBytes, format: .der))] + ) + XCTAssertEqual( + override.privateKey, + .privateKey(try .init(bytes: Array(Self.samplePrivateKey.derRepresentation), format: .der)) + ) + } + } + + func testCertificateAndKeyDoNotMatchOnReload() async throws { + let keyBox: NIOLockedValueBox<[UInt8]> = NIOLockedValueBox( + Array(Self.samplePrivateKey.derRepresentation) + ) + try await runTimedCertificateReloaderTest( + certificate: .init( + location: .memory(provider: { try Self.sampleCert.serializeAsPEM().derBytes }), + format: .der + ), + privateKey: .init( + location: .memory(provider: { keyBox.withLockedValue({ $0 }) }), + format: .der + ) + ) { reloader in + // On first attempt, the overrides should be correctly present. + var override = reloader.sslContextConfigurationOverride + XCTAssertEqual( + override.certificateChain, + [.certificate(try .init(bytes: Self.sampleCert.serializeAsPEM().derBytes, format: .der))] + ) + XCTAssertEqual( + override.privateKey, + .privateKey(try .init(bytes: Array(Self.samplePrivateKey.derRepresentation), format: .der)) + ) + + // Update the box to contain a key that does not match the given certificate. + keyBox.withLockedValue({ $0 = Array(P384.Signing.PrivateKey().derRepresentation) }) + + // Give the reload loop some time to run and update the cert-key pair. + try await Task.sleep(for: .milliseconds(100), tolerance: .zero) + + // We should still be offering the previously valid cert-key pair. + override = reloader.sslContextConfigurationOverride + XCTAssertEqual( + override.certificateChain, + [.certificate(try .init(bytes: Self.sampleCert.serializeAsPEM().derBytes, format: .der))] + ) + XCTAssertEqual( + override.privateKey, + .privateKey(try .init(bytes: Array(Self.samplePrivateKey.derRepresentation), format: .der)) + ) + } + } + + func testCertificateReloaderErrorDescription() { + XCTAssertEqual( + "\(TLSConfiguration.CertificateReloaderError.missingCertificateChain)", + "Missing certificate chain" + ) + XCTAssertEqual( + "\(TLSConfiguration.CertificateReloaderError.missingPrivateKey)", + "Missing private key" + ) + } + + func testTimedCertificateReloaderErrorDescription() { + XCTAssertEqual( + "\(TimedCertificateReloader.Error.certificatePathNotFound("some/path"))", + "Certificate path not found: some/path" + ) + XCTAssertEqual( + "\(TimedCertificateReloader.Error.privateKeyPathNotFound("some/path"))", + "Private key path not found: some/path" + ) + } + + static let startDate = Date() + static let samplePrivateKey = P384.Signing.PrivateKey() + static let sampleCertName = try! DistinguishedName { + CountryName("US") + OrganizationName("Apple") + CommonName("Swift Certificate Test") + } + static let sampleCert: Certificate = { + try! Certificate( + version: .v3, + serialNumber: .init(), + publicKey: .init(samplePrivateKey.publicKey), + notValidBefore: startDate.advanced(by: -60 * 60 * 24 * 360), + notValidAfter: startDate.advanced(by: 60 * 60 * 24 * 360), + issuer: sampleCertName, + subject: sampleCertName, + signatureAlgorithm: .ecdsaWithSHA384, + extensions: Certificate.Extensions { + Critical( + BasicConstraints.isCertificateAuthority(maxPathLength: nil) + ) + }, + issuerPrivateKey: .init(samplePrivateKey) + ) + }() + + private func runTimedCertificateReloaderTest( + certificate: TimedCertificateReloader.CertificateSource, + privateKey: TimedCertificateReloader.PrivateKeySource, + validateSources: Bool = true, + _ body: @escaping @Sendable (TimedCertificateReloader) async throws -> Void + ) async throws { + let reloader = TimedCertificateReloader( + refreshInterval: .milliseconds(50), + certificateSource: .init( + location: certificate.location, + format: certificate.format + ), + privateKeySource: .init(location: privateKey.location, format: privateKey.format) + ) + + if validateSources { + try reloader.reload() + } + + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + try await reloader.run() + } + try await body(reloader) + group.cancelAll() + } + } +}