PR changes

This commit is contained in:
Gus Cairo 2025-05-08 14:46:38 +01:00
parent f3cce7f10c
commit 33fe7400b5
2 changed files with 80 additions and 68 deletions

View File

@ -254,50 +254,14 @@ public struct TimedCertificateReloader: CertificateReloader {
}
/// Initialize a new ``TimedCertificateReloader``.
/// - Important: ``TimedCertificateReloader/sslContextConfigurationOverride`` will not contain any
/// certificate or private key overrides until either ``TimedCertificateReloader/run()`` or
/// ``TimedCertificateReloader/reload()`` are called.
/// - Parameters:
/// - refreshInterval: The interval at which attempts to update the certificate and private key should be made.
/// - certificateSource: A ``TimedCertificateReloader/CertificateSource``.
/// - privateKeySource: A ``TimedCertificateReloader/PrivateKeySource``.
public init(
refreshInterval: TimeAmount,
certificateSource: CertificateSource,
privateKeySource: PrivateKeySource,
logger: Logger? = nil
) {
self.init(
refreshInterval: Duration(refreshInterval),
certificateSource: certificateSource,
privateKeySource: privateKeySource,
logger: logger
)
}
/// Attempt to initialize a new ``TimedCertificateReloader``, but throw if the given certificate and private keys cannot be
/// loaded.
/// - Parameters:
/// - refreshInterval: The interval at which attempts to update the certificate and private key should be made.
/// - validatingCertificateSource: A ``TimedCertificateReloader/CertificateSource``.
/// - validatingPrivateKeySource: A ``TimedCertificateReloader/PrivateKeySource``.
/// - Throws: If the certificate or private key cannot be loaded.
public init(
refreshInterval: TimeAmount,
validatingCertificateSource: CertificateSource,
validatingPrivateKeySource: PrivateKeySource,
logger: Logger? = nil
) throws {
try self.init(
refreshInterval: Duration(refreshInterval),
validatingCertificateSource: validatingCertificateSource,
validatingPrivateKeySource: validatingPrivateKeySource,
logger: logger
)
}
/// Initialize a new ``TimedCertificateReloader``.
/// - Parameters:
/// - refreshInterval: The interval at which attempts to update the certificate and private key should be made.
/// - certificateSource: A ``TimedCertificateReloader/CertificateSource``.
/// - privateKeySource: A ``TimedCertificateReloader/PrivateKeySource``.
/// - logger: An optional logger.
public init(
refreshInterval: Duration,
certificateSource: CertificateSource,
@ -309,36 +273,34 @@ public struct TimedCertificateReloader: CertificateReloader {
self.privateKeySource = privateKeySource
self.state = NIOLockedValueBox(nil)
self.logger = logger
// Immediately try to load the configured cert and key to avoid having to wait for the first
// reload loop to run.
// We ignore errors because this initializer tolerates not finding the certificate and/or
// private key on first load.
try? self.reloadPair()
}
/// Attempt to initialize a new ``TimedCertificateReloader``, but throw if the given certificate and private keys cannot be
/// loaded.
/// Initialize a new ``TimedCertificateReloader``, and attempt to reload the certificate and private key pair from the given
/// sources. If the reload fails (because e.g. the paths aren't valid), this method will throw.
/// - Important: If this method does not throw, it is guaranteed that
/// ``TimedCertificateReloader/sslContextConfigurationOverride`` will contain the configured certificate and
/// private key pair, even before the first reload is triggered or ``TimedCertificateReloader/run()`` is called.
/// - Parameters:
/// - refreshInterval: The interval at which attempts to update the certificate and private key should be made.
/// - validatingCertificateSource: A ``TimedCertificateReloader/CertificateSource``.
/// - validatingPrivateKeySource: A ``TimedCertificateReloader/PrivateKeySource``.
/// - Throws: If the certificate or private key cannot be loaded.
public init(
/// - certificateSource: A ``TimedCertificateReloader/CertificateSource``.
/// - privateKeySource: A ``TimedCertificateReloader/PrivateKeySource``.
/// - logger: An optional logger.
/// - Returns: The newly created ``TimedCertificateReloader``.
/// - Throws: If either the certificate or private key sources cannot be loaded, an error will be thrown.
static public func makeReloaderValidatingSources(
refreshInterval: Duration,
validatingCertificateSource: CertificateSource,
validatingPrivateKeySource: PrivateKeySource,
certificateSource: CertificateSource,
privateKeySource: PrivateKeySource,
logger: Logger? = nil
) throws {
self.refreshInterval = refreshInterval
self.certificateSource = validatingCertificateSource
self.privateKeySource = validatingPrivateKeySource
self.state = NIOLockedValueBox(nil)
self.logger = logger
// Immediately try to load the configured cert and key to avoid having to wait for the first
// reload loop to run.
try self.reloadPair()
) throws -> Self {
let reloader = Self.init(
refreshInterval: refreshInterval,
certificateSource: certificateSource,
privateKeySource: privateKeySource,
logger: logger
)
try reloader.reload()
return reloader
}
/// A long-running method to run the ``TimedCertificateReloader`` and start observing updates for the certificate and
@ -347,7 +309,7 @@ public struct TimedCertificateReloader: CertificateReloader {
public func run() async throws {
for try await _ in AsyncTimerSequence.repeating(every: self.refreshInterval).cancelOnGracefulShutdown() {
do {
try self.reloadPair()
try self.reload()
} catch {
self.logger?.debug(
"Failed to reload certificate and private key.",
@ -361,7 +323,8 @@ public struct TimedCertificateReloader: CertificateReloader {
}
}
private func reloadPair() throws {
/// Manually attempt a certificate and private key pair update.
public func reload() throws {
if let certificateBytes = try self.loadCertificate(),
let keyBytes = try self.loadPrivateKey(),
let certificate = try self.parseCertificate(from: certificateBytes),

View File

@ -27,7 +27,8 @@ final class TimedCertificateReloaderTests: XCTestCase {
privateKey: .init(
location: .memory(provider: { Array(Self.samplePrivateKey.derRepresentation) }),
format: .der
)
),
validateSources: false
) { reloader in
let override = reloader.sslContextConfigurationOverride
XCTAssertNil(override.certificateChain)
@ -35,6 +36,25 @@ final class TimedCertificateReloaderTests: XCTestCase {
}
}
func testCertificatePathDoesNotExist_ValidatingSource() async throws {
do {
try await runTimedCertificateReloaderTest(
certificate: .init(location: .file(path: "doesnotexist"), format: .der),
privateKey: .init(
location: .memory(provider: { Array(Self.samplePrivateKey.derRepresentation) }),
format: .der
)
) { _ in
XCTFail("Test should have failed before reaching this point.")
}
} catch {
XCTAssertEqual(
error as? TimedCertificateReloader.Error,
TimedCertificateReloader.Error.certificatePathNotFound("doesnotexist")
)
}
}
func testKeyPathDoesNotExist() async throws {
try await runTimedCertificateReloaderTest(
certificate: .init(
@ -44,7 +64,8 @@ final class TimedCertificateReloaderTests: XCTestCase {
privateKey: .init(
location: .file(path: "doesnotexist"),
format: .der
)
),
validateSources: false
) { reloader in
let override = reloader.sslContextConfigurationOverride
XCTAssertNil(override.certificateChain)
@ -52,6 +73,28 @@ final class TimedCertificateReloaderTests: XCTestCase {
}
}
func testKeyPathDoesNotExist_ValidatingSource() async throws {
do {
try await runTimedCertificateReloaderTest(
certificate: .init(
location: .memory(provider: { try? Self.sampleCert.serializeAsPEM().derBytes }),
format: .der
),
privateKey: .init(
location: .file(path: "doesnotexist"),
format: .der
)
) { _ in
XCTFail("Test should have failed before reaching this point.")
}
} catch {
XCTAssertEqual(
error as? TimedCertificateReloader.Error,
TimedCertificateReloader.Error.privateKeyPathNotFound("doesnotexist")
)
}
}
func testCertificateIsInUnexpectedFormat() async throws {
try await runTimedCertificateReloaderTest(
certificate: .init(
@ -323,6 +366,7 @@ final class TimedCertificateReloaderTests: XCTestCase {
private func runTimedCertificateReloaderTest(
certificate: TimedCertificateReloader.CertificateSource,
privateKey: TimedCertificateReloader.PrivateKeySource,
validateSources: Bool = true,
_ body: @escaping @Sendable (TimedCertificateReloader) async throws -> Void
) async throws {
let reloader = TimedCertificateReloader(
@ -333,6 +377,11 @@ final class TimedCertificateReloaderTests: XCTestCase {
),
privateKeySource: .init(location: privateKey.location, format: privateKey.format)
)
if validateSources {
try reloader.reload()
}
try await withThrowingTaskGroup(of: Void.self) { group in
group.addTask {
try await reloader.run()