Refactor some stuff

This commit is contained in:
Gus Cairo 2025-04-30 14:53:29 +01:00
parent 6b6a5b5e10
commit 7a755d74e0
2 changed files with 109 additions and 100 deletions

View File

@ -158,9 +158,11 @@ public struct TimedCertificateReloader: CertificateReloader {
self.refreshInterval = refreshInterval self.refreshInterval = refreshInterval
self.certificateDescription = certificateDescription self.certificateDescription = certificateDescription
self.privateKeyDescription = privateKeyDescription self.privateKeyDescription = privateKeyDescription
// TODO: try parsing key and cert here too
self.state = NIOLockedValueBox(nil) self.state = NIOLockedValueBox(nil)
// Immediately try to load the configured cert and key to avoid having to wait for the first
// reload loop to run.
self.reloadPair()
} }
/// A long-running method to run the ``TimedCertificateReloader`` and start observing updates for the certificate and /// A long-running method to run the ``TimedCertificateReloader`` and start observing updates for the certificate and
@ -169,96 +171,91 @@ public struct TimedCertificateReloader: CertificateReloader {
public func run() async throws { public func run() async throws {
while !Task.isCancelled { while !Task.isCancelled {
try await Task.sleep(nanoseconds: UInt64(self.refreshInterval.nanoseconds)) try await Task.sleep(nanoseconds: UInt64(self.refreshInterval.nanoseconds))
let certificateBytes: [UInt8]? self.reloadPair()
switch self.certificateDescription.location._backing { }
case .file(path: let path): }
let bytes = FileManager.default.contents(atPath: path)
certificateBytes = bytes.map { Array($0) }
case .memory(let bytesProvider):
certificateBytes = bytesProvider()
}
let keyBytes: [UInt8]?
switch self.privateKeyDescription.location._backing {
case .file(path: let path):
let bytes = FileManager.default.contents(atPath: path)
keyBytes = bytes.map { Array($0) }
case .memory(let bytesProvider):
keyBytes = bytesProvider()
}
if let certificateBytes, let keyBytes {
let certificate: Certificate
switch self.certificateDescription.format._backing {
case .der:
let parsedCertificate = try? Certificate(derEncoded: Array(certificateBytes))
guard let parsedCertificate else {
// could not parse certificate, ignore this update
continue
}
certificate = parsedCertificate
case .pem:
let parsedCertificate = String(bytes: certificateBytes, encoding: .utf8)
.flatMap { try? Certificate(pemEncoded: $0) }
guard let parsedCertificate else { private func reloadPair() {
// could not parse certificate, ignore this update if let certificateBytes = self.loadCertificate(),
continue let keyBytes = self.loadPrivateKey(),
} let certificate = self.parseCertificate(from: certificateBytes),
let key = self.parsePrivateKey(from: keyBytes),
certificate = parsedCertificate key.publicKey.isValidSignature(certificate.signature, for: certificate) {
} self.attemptToUpdatePair(certificate: certificate, key: key)
}
let key: Certificate.PrivateKey }
switch self.privateKeyDescription.format._backing {
case .der:
let parsedKey = try? Certificate.PrivateKey(derBytes: Array(keyBytes))
guard let parsedKey else {
// could not parse key, ignore the update
continue
}
key = parsedKey
case .pem:
let parsedKey = String(bytes: keyBytes, encoding: .utf8)
.flatMap { try? Certificate.PrivateKey(pemEncoded: $0) }
guard let parsedKey else { private func loadCertificate() -> [UInt8]? {
// could not parse key, ignore the update let certificateBytes: [UInt8]?
continue switch self.certificateDescription.location._backing {
} case .file(path: let path):
let bytes = FileManager.default.contents(atPath: path)
key = parsedKey certificateBytes = bytes.map { Array($0) }
}
if key.publicKey.isValidSignature(certificate.signature, for: certificate) {
let nioSSLCertificate = try? NIOSSLCertificate(
bytes: certificate.serializeAsPEM().derBytes,
format: .der
)
let nioSSLPrivateKey = try? NIOSSLPrivateKey(
bytes: key.serializeAsPEM().derBytes,
format: .der
)
guard let nioSSLCertificate, let nioSSLPrivateKey else {
continue
}
self.state.withLockedValue { case .memory(let bytesProvider):
$0 = CertificateKeyPair( certificateBytes = bytesProvider()
certificate: .certificate(nioSSLCertificate), }
privateKey: .privateKey(nioSSLPrivateKey) return certificateBytes
) }
}
} private func loadPrivateKey() -> [UInt8]? {
} let keyBytes: [UInt8]?
switch self.privateKeyDescription.location._backing {
case .file(path: let path):
let bytes = FileManager.default.contents(atPath: path)
keyBytes = bytes.map { Array($0) }
case .memory(let bytesProvider):
keyBytes = bytesProvider()
}
return keyBytes
}
private func parseCertificate(from certificateBytes: [UInt8]) -> Certificate? {
let certificate: Certificate?
switch self.certificateDescription.format._backing {
case .der:
certificate = try? Certificate(derEncoded: certificateBytes)
case .pem:
certificate = String(bytes: certificateBytes, encoding: .utf8)
.flatMap { try? Certificate(pemEncoded: $0) }
}
return certificate
}
private func parsePrivateKey(from keyBytes: [UInt8]) -> Certificate.PrivateKey? {
let key: Certificate.PrivateKey?
switch self.privateKeyDescription.format._backing {
case .der:
key = try? Certificate.PrivateKey(derBytes: keyBytes)
case .pem:
key = String(bytes: keyBytes, encoding: .utf8)
.flatMap { try? Certificate.PrivateKey(pemEncoded: $0) }
}
return key
}
private func attemptToUpdatePair(certificate: Certificate, key: Certificate.PrivateKey) {
let nioSSLCertificate = try? NIOSSLCertificate(
bytes: certificate.serializeAsPEM().derBytes,
format: .der
)
let nioSSLPrivateKey = try? NIOSSLPrivateKey(
bytes: key.serializeAsPEM().derBytes,
format: .der
)
guard let nioSSLCertificate, let nioSSLPrivateKey else {
return
}
self.state.withLockedValue {
$0 = CertificateKeyPair(
certificate: .certificate(nioSSLCertificate),
privateKey: .privateKey(nioSSLPrivateKey)
)
} }
} }
} }

View File

@ -2,7 +2,7 @@
// //
// This source file is part of the SwiftNIO open source project // This source file is part of the SwiftNIO open source project
// //
// Copyright (c) 2017-2021 Apple Inc. and the SwiftNIO project authors // Copyright (c) 2025 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0 // Licensed under Apache License v2.0
// //
// See LICENSE.txt for license information // See LICENSE.txt for license information
@ -27,7 +27,6 @@ final class TimedCertificateReloaderTests: XCTestCase {
format: .der format: .der
) )
) { reloader in ) { reloader in
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
let override = reloader.sslContextConfigurationOverride let override = reloader.sslContextConfigurationOverride
XCTAssertNil(override.certificateChain) XCTAssertNil(override.certificateChain)
XCTAssertNil(override.privateKey) XCTAssertNil(override.privateKey)
@ -45,7 +44,6 @@ final class TimedCertificateReloaderTests: XCTestCase {
format: .der format: .der
) )
) { reloader in ) { reloader in
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
let override = reloader.sslContextConfigurationOverride let override = reloader.sslContextConfigurationOverride
XCTAssertNil(override.certificateChain) XCTAssertNil(override.certificateChain)
XCTAssertNil(override.privateKey) XCTAssertNil(override.privateKey)
@ -63,7 +61,6 @@ final class TimedCertificateReloaderTests: XCTestCase {
format: .der format: .der
) )
) { reloader in ) { reloader in
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
let override = reloader.sslContextConfigurationOverride let override = reloader.sslContextConfigurationOverride
XCTAssertNil(override.certificateChain) XCTAssertNil(override.certificateChain)
XCTAssertNil(override.privateKey) XCTAssertNil(override.privateKey)
@ -81,7 +78,6 @@ final class TimedCertificateReloaderTests: XCTestCase {
format: .pem format: .pem
) )
) { reloader in ) { reloader in
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
let override = reloader.sslContextConfigurationOverride let override = reloader.sslContextConfigurationOverride
XCTAssertNil(override.certificateChain) XCTAssertNil(override.certificateChain)
XCTAssertNil(override.privateKey) XCTAssertNil(override.privateKey)
@ -99,7 +95,6 @@ final class TimedCertificateReloaderTests: XCTestCase {
format: .der format: .der
) )
) { reloader in ) { reloader in
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
let override = reloader.sslContextConfigurationOverride let override = reloader.sslContextConfigurationOverride
XCTAssertNil(override.certificateChain) XCTAssertNil(override.certificateChain)
XCTAssertNil(override.privateKey) XCTAssertNil(override.privateKey)
@ -107,7 +102,7 @@ final class TimedCertificateReloaderTests: XCTestCase {
} }
func testReloadSuccessfully() async throws { func testReloadSuccessfully() async throws {
let certificateBox = NIOLockedValueBox([UInt8]()) let certificateBox: NIOLockedValueBox<[UInt8]?> = NIOLockedValueBox(nil)
try await runTimedCertificateReloaderTest( try await runTimedCertificateReloaderTest(
certificate: .init( certificate: .init(
location: .memory(provider: { certificateBox.withLockedValue({ $0 }) }), location: .memory(provider: { certificateBox.withLockedValue({ $0 }) }),
@ -118,14 +113,19 @@ final class TimedCertificateReloaderTests: XCTestCase {
format: .der format: .der
) )
) { reloader in ) { reloader in
try await Task.sleep(for: .milliseconds(100), tolerance: .zero) // On first attempt, we should have no certificate or private key overrides available,
// since the certificate box is empty.
var override = reloader.sslContextConfigurationOverride var override = reloader.sslContextConfigurationOverride
XCTAssertNil(override.certificateChain) XCTAssertNil(override.certificateChain)
XCTAssertNil(override.privateKey) XCTAssertNil(override.privateKey)
// Update the box to contain a valid certificate.
certificateBox.withLockedValue({ $0 = try! Self.sampleCert.serializeAsPEM().derBytes }) certificateBox.withLockedValue({ $0 = try! Self.sampleCert.serializeAsPEM().derBytes })
// Give the reload loop some time to run and update the cert-key pair.
try await Task.sleep(for: .milliseconds(100), tolerance: .zero) try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
// Now the overrides should be present.
override = reloader.sslContextConfigurationOverride override = reloader.sslContextConfigurationOverride
XCTAssertEqual( XCTAssertEqual(
override.certificateChain, override.certificateChain,
@ -152,7 +152,7 @@ final class TimedCertificateReloaderTests: XCTestCase {
format: .der format: .der
) )
) { reloader in ) { reloader in
try await Task.sleep(for: .milliseconds(100), tolerance: .zero) // On first attempt, the overrides should be correctly present.
var override = reloader.sslContextConfigurationOverride var override = reloader.sslContextConfigurationOverride
XCTAssertEqual( XCTAssertEqual(
override.certificateChain, override.certificateChain,
@ -163,9 +163,13 @@ final class TimedCertificateReloaderTests: XCTestCase {
.privateKey(try .init(bytes: Array(Self.samplePrivateKey.derRepresentation), format: .der)) .privateKey(try .init(bytes: Array(Self.samplePrivateKey.derRepresentation), format: .der))
) )
// Update the box to not contain a certificate.
certificateBox.withLockedValue({ $0 = nil }) certificateBox.withLockedValue({ $0 = nil })
// Give the reload loop some time to run and update the cert-key pair.
try await Task.sleep(for: .milliseconds(100), tolerance: .zero) try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
// We should still be offering the previously valid cert-key pair.
override = reloader.sslContextConfigurationOverride override = reloader.sslContextConfigurationOverride
XCTAssertEqual( XCTAssertEqual(
override.certificateChain, override.certificateChain,
@ -192,7 +196,7 @@ final class TimedCertificateReloaderTests: XCTestCase {
format: .der format: .der
) )
) { reloader in ) { reloader in
try await Task.sleep(for: .milliseconds(100), tolerance: .zero) // On first attempt, the overrides should be correctly present.
var override = reloader.sslContextConfigurationOverride var override = reloader.sslContextConfigurationOverride
XCTAssertEqual( XCTAssertEqual(
override.certificateChain, override.certificateChain,
@ -203,9 +207,13 @@ final class TimedCertificateReloaderTests: XCTestCase {
.privateKey(try .init(bytes: Array(Self.samplePrivateKey.derRepresentation), format: .der)) .privateKey(try .init(bytes: Array(Self.samplePrivateKey.derRepresentation), format: .der))
) )
// Update the box to not contain a key.
keyBox.withLockedValue({ $0 = nil }) keyBox.withLockedValue({ $0 = nil })
// Give the reload loop some time to run and update the cert-key pair.
try await Task.sleep(for: .milliseconds(100), tolerance: .zero) try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
// We should still be offering the previously valid cert-key pair.
override = reloader.sslContextConfigurationOverride override = reloader.sslContextConfigurationOverride
XCTAssertEqual( XCTAssertEqual(
override.certificateChain, override.certificateChain,
@ -232,7 +240,7 @@ final class TimedCertificateReloaderTests: XCTestCase {
format: .der format: .der
) )
) { reloader in ) { reloader in
try await Task.sleep(for: .milliseconds(100), tolerance: .zero) // On first attempt, the overrides should be correctly present.
var override = reloader.sslContextConfigurationOverride var override = reloader.sslContextConfigurationOverride
XCTAssertEqual( XCTAssertEqual(
override.certificateChain, override.certificateChain,
@ -243,9 +251,13 @@ final class TimedCertificateReloaderTests: XCTestCase {
.privateKey(try .init(bytes: Array(Self.samplePrivateKey.derRepresentation), format: .der)) .privateKey(try .init(bytes: Array(Self.samplePrivateKey.derRepresentation), format: .der))
) )
// Update the box to contain a key that does not match the given certificate.
keyBox.withLockedValue({ $0 = Array(P384.Signing.PrivateKey().derRepresentation) }) keyBox.withLockedValue({ $0 = Array(P384.Signing.PrivateKey().derRepresentation) })
// Give the reload loop some time to run and update the cert-key pair.
try await Task.sleep(for: .milliseconds(100), tolerance: .zero) try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
// We should still be offering the previously valid cert-key pair.
override = reloader.sslContextConfigurationOverride override = reloader.sslContextConfigurationOverride
XCTAssertEqual( XCTAssertEqual(
override.certificateChain, override.certificateChain,