diff --git a/Sources/NIOCertificateReloading/TimedCertificateReloader.swift b/Sources/NIOCertificateReloading/TimedCertificateReloader.swift index 16b1a97..e0a10a3 100644 --- a/Sources/NIOCertificateReloading/TimedCertificateReloader.swift +++ b/Sources/NIOCertificateReloading/TimedCertificateReloader.swift @@ -254,50 +254,14 @@ public struct TimedCertificateReloader: CertificateReloader { } /// Initialize a new ``TimedCertificateReloader``. + /// - Important: ``TimedCertificateReloader/sslContextConfigurationOverride`` will not contain any + /// certificate or private key overrides until either ``TimedCertificateReloader/run()`` or + /// ``TimedCertificateReloader/reload()`` are called. /// - Parameters: /// - refreshInterval: The interval at which attempts to update the certificate and private key should be made. /// - certificateSource: A ``TimedCertificateReloader/CertificateSource``. /// - privateKeySource: A ``TimedCertificateReloader/PrivateKeySource``. - public init( - refreshInterval: TimeAmount, - certificateSource: CertificateSource, - privateKeySource: PrivateKeySource, - logger: Logger? = nil - ) { - self.init( - refreshInterval: Duration(refreshInterval), - certificateSource: certificateSource, - privateKeySource: privateKeySource, - logger: logger - ) - } - - /// Attempt to initialize a new ``TimedCertificateReloader``, but throw if the given certificate and private keys cannot be - /// loaded. - /// - Parameters: - /// - refreshInterval: The interval at which attempts to update the certificate and private key should be made. - /// - validatingCertificateSource: A ``TimedCertificateReloader/CertificateSource``. - /// - validatingPrivateKeySource: A ``TimedCertificateReloader/PrivateKeySource``. - /// - Throws: If the certificate or private key cannot be loaded. - public init( - refreshInterval: TimeAmount, - validatingCertificateSource: CertificateSource, - validatingPrivateKeySource: PrivateKeySource, - logger: Logger? = nil - ) throws { - try self.init( - refreshInterval: Duration(refreshInterval), - validatingCertificateSource: validatingCertificateSource, - validatingPrivateKeySource: validatingPrivateKeySource, - logger: logger - ) - } - - /// Initialize a new ``TimedCertificateReloader``. - /// - Parameters: - /// - refreshInterval: The interval at which attempts to update the certificate and private key should be made. - /// - certificateSource: A ``TimedCertificateReloader/CertificateSource``. - /// - privateKeySource: A ``TimedCertificateReloader/PrivateKeySource``. + /// - logger: An optional logger. public init( refreshInterval: Duration, certificateSource: CertificateSource, @@ -309,36 +273,34 @@ public struct TimedCertificateReloader: CertificateReloader { self.privateKeySource = privateKeySource self.state = NIOLockedValueBox(nil) self.logger = logger - - // Immediately try to load the configured cert and key to avoid having to wait for the first - // reload loop to run. - // We ignore errors because this initializer tolerates not finding the certificate and/or - // private key on first load. - try? self.reloadPair() } - /// Attempt to initialize a new ``TimedCertificateReloader``, but throw if the given certificate and private keys cannot be - /// loaded. + /// Initialize a new ``TimedCertificateReloader``, and attempt to reload the certificate and private key pair from the given + /// sources. If the reload fails (because e.g. the paths aren't valid), this method will throw. + /// - Important: If this method does not throw, it is guaranteed that + /// ``TimedCertificateReloader/sslContextConfigurationOverride`` will contain the configured certificate and + /// private key pair, even before the first reload is triggered or ``TimedCertificateReloader/run()`` is called. /// - Parameters: /// - refreshInterval: The interval at which attempts to update the certificate and private key should be made. - /// - validatingCertificateSource: A ``TimedCertificateReloader/CertificateSource``. - /// - validatingPrivateKeySource: A ``TimedCertificateReloader/PrivateKeySource``. - /// - Throws: If the certificate or private key cannot be loaded. - public init( + /// - certificateSource: A ``TimedCertificateReloader/CertificateSource``. + /// - privateKeySource: A ``TimedCertificateReloader/PrivateKeySource``. + /// - logger: An optional logger. + /// - Returns: The newly created ``TimedCertificateReloader``. + /// - Throws: If either the certificate or private key sources cannot be loaded, an error will be thrown. + static public func makeReloaderValidatingSources( refreshInterval: Duration, - validatingCertificateSource: CertificateSource, - validatingPrivateKeySource: PrivateKeySource, + certificateSource: CertificateSource, + privateKeySource: PrivateKeySource, logger: Logger? = nil - ) throws { - self.refreshInterval = refreshInterval - self.certificateSource = validatingCertificateSource - self.privateKeySource = validatingPrivateKeySource - self.state = NIOLockedValueBox(nil) - self.logger = logger - - // Immediately try to load the configured cert and key to avoid having to wait for the first - // reload loop to run. - try self.reloadPair() + ) throws -> Self { + let reloader = Self.init( + refreshInterval: refreshInterval, + certificateSource: certificateSource, + privateKeySource: privateKeySource, + logger: logger + ) + try reloader.reload() + return reloader } /// A long-running method to run the ``TimedCertificateReloader`` and start observing updates for the certificate and @@ -347,7 +309,7 @@ public struct TimedCertificateReloader: CertificateReloader { public func run() async throws { for try await _ in AsyncTimerSequence.repeating(every: self.refreshInterval).cancelOnGracefulShutdown() { do { - try self.reloadPair() + try self.reload() } catch { self.logger?.debug( "Failed to reload certificate and private key.", @@ -361,7 +323,8 @@ public struct TimedCertificateReloader: CertificateReloader { } } - private func reloadPair() throws { + /// Manually attempt a certificate and private key pair update. + public func reload() throws { if let certificateBytes = try self.loadCertificate(), let keyBytes = try self.loadPrivateKey(), let certificate = try self.parseCertificate(from: certificateBytes), diff --git a/Tests/NIOCertificateReloadingTests/TimedCertificateReloaderTests.swift b/Tests/NIOCertificateReloadingTests/TimedCertificateReloaderTests.swift index 957f3ea..ed4e302 100644 --- a/Tests/NIOCertificateReloadingTests/TimedCertificateReloaderTests.swift +++ b/Tests/NIOCertificateReloadingTests/TimedCertificateReloaderTests.swift @@ -27,7 +27,8 @@ final class TimedCertificateReloaderTests: XCTestCase { privateKey: .init( location: .memory(provider: { Array(Self.samplePrivateKey.derRepresentation) }), format: .der - ) + ), + validateSources: false ) { reloader in let override = reloader.sslContextConfigurationOverride XCTAssertNil(override.certificateChain) @@ -35,6 +36,25 @@ final class TimedCertificateReloaderTests: XCTestCase { } } + func testCertificatePathDoesNotExist_ValidatingSource() async throws { + do { + try await runTimedCertificateReloaderTest( + certificate: .init(location: .file(path: "doesnotexist"), format: .der), + privateKey: .init( + location: .memory(provider: { Array(Self.samplePrivateKey.derRepresentation) }), + format: .der + ) + ) { _ in + XCTFail("Test should have failed before reaching this point.") + } + } catch { + XCTAssertEqual( + error as? TimedCertificateReloader.Error, + TimedCertificateReloader.Error.certificatePathNotFound("doesnotexist") + ) + } + } + func testKeyPathDoesNotExist() async throws { try await runTimedCertificateReloaderTest( certificate: .init( @@ -44,7 +64,8 @@ final class TimedCertificateReloaderTests: XCTestCase { privateKey: .init( location: .file(path: "doesnotexist"), format: .der - ) + ), + validateSources: false ) { reloader in let override = reloader.sslContextConfigurationOverride XCTAssertNil(override.certificateChain) @@ -52,6 +73,28 @@ final class TimedCertificateReloaderTests: XCTestCase { } } + func testKeyPathDoesNotExist_ValidatingSource() async throws { + do { + try await runTimedCertificateReloaderTest( + certificate: .init( + location: .memory(provider: { try? Self.sampleCert.serializeAsPEM().derBytes }), + format: .der + ), + privateKey: .init( + location: .file(path: "doesnotexist"), + format: .der + ) + ) { _ in + XCTFail("Test should have failed before reaching this point.") + } + } catch { + XCTAssertEqual( + error as? TimedCertificateReloader.Error, + TimedCertificateReloader.Error.privateKeyPathNotFound("doesnotexist") + ) + } + } + func testCertificateIsInUnexpectedFormat() async throws { try await runTimedCertificateReloaderTest( certificate: .init( @@ -323,6 +366,7 @@ final class TimedCertificateReloaderTests: XCTestCase { private func runTimedCertificateReloaderTest( certificate: TimedCertificateReloader.CertificateSource, privateKey: TimedCertificateReloader.PrivateKeySource, + validateSources: Bool = true, _ body: @escaping @Sendable (TimedCertificateReloader) async throws -> Void ) async throws { let reloader = TimedCertificateReloader( @@ -333,6 +377,11 @@ final class TimedCertificateReloaderTests: XCTestCase { ), privateKeySource: .init(location: privateKey.location, format: privateKey.format) ) + + if validateSources { + try reloader.reload() + } + try await withThrowingTaskGroup(of: Void.self) { group in group.addTask { try await reloader.run()