Refactor some stuff

This commit is contained in:
Gus Cairo 2025-04-30 14:53:29 +01:00
parent 6b6a5b5e10
commit 7a755d74e0
2 changed files with 109 additions and 100 deletions

View File

@ -158,9 +158,11 @@ public struct TimedCertificateReloader: CertificateReloader {
self.refreshInterval = refreshInterval
self.certificateDescription = certificateDescription
self.privateKeyDescription = privateKeyDescription
// TODO: try parsing key and cert here too
self.state = NIOLockedValueBox(nil)
// Immediately try to load the configured cert and key to avoid having to wait for the first
// reload loop to run.
self.reloadPair()
}
/// A long-running method to run the ``TimedCertificateReloader`` and start observing updates for the certificate and
@ -169,96 +171,91 @@ public struct TimedCertificateReloader: CertificateReloader {
public func run() async throws {
while !Task.isCancelled {
try await Task.sleep(nanoseconds: UInt64(self.refreshInterval.nanoseconds))
let certificateBytes: [UInt8]?
switch self.certificateDescription.location._backing {
case .file(path: let path):
let bytes = FileManager.default.contents(atPath: path)
certificateBytes = bytes.map { Array($0) }
case .memory(let bytesProvider):
certificateBytes = bytesProvider()
}
let keyBytes: [UInt8]?
switch self.privateKeyDescription.location._backing {
case .file(path: let path):
let bytes = FileManager.default.contents(atPath: path)
keyBytes = bytes.map { Array($0) }
case .memory(let bytesProvider):
keyBytes = bytesProvider()
}
if let certificateBytes, let keyBytes {
let certificate: Certificate
switch self.certificateDescription.format._backing {
case .der:
let parsedCertificate = try? Certificate(derEncoded: Array(certificateBytes))
guard let parsedCertificate else {
// could not parse certificate, ignore this update
continue
}
certificate = parsedCertificate
case .pem:
let parsedCertificate = String(bytes: certificateBytes, encoding: .utf8)
.flatMap { try? Certificate(pemEncoded: $0) }
self.reloadPair()
}
}
guard let parsedCertificate else {
// could not parse certificate, ignore this update
continue
}
certificate = parsedCertificate
}
let key: Certificate.PrivateKey
switch self.privateKeyDescription.format._backing {
case .der:
let parsedKey = try? Certificate.PrivateKey(derBytes: Array(keyBytes))
guard let parsedKey else {
// could not parse key, ignore the update
continue
}
key = parsedKey
case .pem:
let parsedKey = String(bytes: keyBytes, encoding: .utf8)
.flatMap { try? Certificate.PrivateKey(pemEncoded: $0) }
private func reloadPair() {
if let certificateBytes = self.loadCertificate(),
let keyBytes = self.loadPrivateKey(),
let certificate = self.parseCertificate(from: certificateBytes),
let key = self.parsePrivateKey(from: keyBytes),
key.publicKey.isValidSignature(certificate.signature, for: certificate) {
self.attemptToUpdatePair(certificate: certificate, key: key)
}
}
guard let parsedKey else {
// could not parse key, ignore the update
continue
}
key = parsedKey
}
if key.publicKey.isValidSignature(certificate.signature, for: certificate) {
let nioSSLCertificate = try? NIOSSLCertificate(
bytes: certificate.serializeAsPEM().derBytes,
format: .der
)
let nioSSLPrivateKey = try? NIOSSLPrivateKey(
bytes: key.serializeAsPEM().derBytes,
format: .der
)
guard let nioSSLCertificate, let nioSSLPrivateKey else {
continue
}
private func loadCertificate() -> [UInt8]? {
let certificateBytes: [UInt8]?
switch self.certificateDescription.location._backing {
case .file(path: let path):
let bytes = FileManager.default.contents(atPath: path)
certificateBytes = bytes.map { Array($0) }
self.state.withLockedValue {
$0 = CertificateKeyPair(
certificate: .certificate(nioSSLCertificate),
privateKey: .privateKey(nioSSLPrivateKey)
)
}
}
}
case .memory(let bytesProvider):
certificateBytes = bytesProvider()
}
return certificateBytes
}
private func loadPrivateKey() -> [UInt8]? {
let keyBytes: [UInt8]?
switch self.privateKeyDescription.location._backing {
case .file(path: let path):
let bytes = FileManager.default.contents(atPath: path)
keyBytes = bytes.map { Array($0) }
case .memory(let bytesProvider):
keyBytes = bytesProvider()
}
return keyBytes
}
private func parseCertificate(from certificateBytes: [UInt8]) -> Certificate? {
let certificate: Certificate?
switch self.certificateDescription.format._backing {
case .der:
certificate = try? Certificate(derEncoded: certificateBytes)
case .pem:
certificate = String(bytes: certificateBytes, encoding: .utf8)
.flatMap { try? Certificate(pemEncoded: $0) }
}
return certificate
}
private func parsePrivateKey(from keyBytes: [UInt8]) -> Certificate.PrivateKey? {
let key: Certificate.PrivateKey?
switch self.privateKeyDescription.format._backing {
case .der:
key = try? Certificate.PrivateKey(derBytes: keyBytes)
case .pem:
key = String(bytes: keyBytes, encoding: .utf8)
.flatMap { try? Certificate.PrivateKey(pemEncoded: $0) }
}
return key
}
private func attemptToUpdatePair(certificate: Certificate, key: Certificate.PrivateKey) {
let nioSSLCertificate = try? NIOSSLCertificate(
bytes: certificate.serializeAsPEM().derBytes,
format: .der
)
let nioSSLPrivateKey = try? NIOSSLPrivateKey(
bytes: key.serializeAsPEM().derBytes,
format: .der
)
guard let nioSSLCertificate, let nioSSLPrivateKey else {
return
}
self.state.withLockedValue {
$0 = CertificateKeyPair(
certificate: .certificate(nioSSLCertificate),
privateKey: .privateKey(nioSSLPrivateKey)
)
}
}
}

View File

@ -2,7 +2,7 @@
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2017-2021 Apple Inc. and the SwiftNIO project authors
// Copyright (c) 2025 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
@ -27,7 +27,6 @@ final class TimedCertificateReloaderTests: XCTestCase {
format: .der
)
) { reloader in
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
let override = reloader.sslContextConfigurationOverride
XCTAssertNil(override.certificateChain)
XCTAssertNil(override.privateKey)
@ -45,7 +44,6 @@ final class TimedCertificateReloaderTests: XCTestCase {
format: .der
)
) { reloader in
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
let override = reloader.sslContextConfigurationOverride
XCTAssertNil(override.certificateChain)
XCTAssertNil(override.privateKey)
@ -63,7 +61,6 @@ final class TimedCertificateReloaderTests: XCTestCase {
format: .der
)
) { reloader in
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
let override = reloader.sslContextConfigurationOverride
XCTAssertNil(override.certificateChain)
XCTAssertNil(override.privateKey)
@ -81,7 +78,6 @@ final class TimedCertificateReloaderTests: XCTestCase {
format: .pem
)
) { reloader in
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
let override = reloader.sslContextConfigurationOverride
XCTAssertNil(override.certificateChain)
XCTAssertNil(override.privateKey)
@ -99,7 +95,6 @@ final class TimedCertificateReloaderTests: XCTestCase {
format: .der
)
) { reloader in
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
let override = reloader.sslContextConfigurationOverride
XCTAssertNil(override.certificateChain)
XCTAssertNil(override.privateKey)
@ -107,7 +102,7 @@ final class TimedCertificateReloaderTests: XCTestCase {
}
func testReloadSuccessfully() async throws {
let certificateBox = NIOLockedValueBox([UInt8]())
let certificateBox: NIOLockedValueBox<[UInt8]?> = NIOLockedValueBox(nil)
try await runTimedCertificateReloaderTest(
certificate: .init(
location: .memory(provider: { certificateBox.withLockedValue({ $0 }) }),
@ -118,14 +113,19 @@ final class TimedCertificateReloaderTests: XCTestCase {
format: .der
)
) { reloader in
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
// On first attempt, we should have no certificate or private key overrides available,
// since the certificate box is empty.
var override = reloader.sslContextConfigurationOverride
XCTAssertNil(override.certificateChain)
XCTAssertNil(override.privateKey)
// Update the box to contain a valid certificate.
certificateBox.withLockedValue({ $0 = try! Self.sampleCert.serializeAsPEM().derBytes })
// Give the reload loop some time to run and update the cert-key pair.
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
// Now the overrides should be present.
override = reloader.sslContextConfigurationOverride
XCTAssertEqual(
override.certificateChain,
@ -152,7 +152,7 @@ final class TimedCertificateReloaderTests: XCTestCase {
format: .der
)
) { reloader in
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
// On first attempt, the overrides should be correctly present.
var override = reloader.sslContextConfigurationOverride
XCTAssertEqual(
override.certificateChain,
@ -163,9 +163,13 @@ final class TimedCertificateReloaderTests: XCTestCase {
.privateKey(try .init(bytes: Array(Self.samplePrivateKey.derRepresentation), format: .der))
)
// Update the box to not contain a certificate.
certificateBox.withLockedValue({ $0 = nil })
// Give the reload loop some time to run and update the cert-key pair.
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
// We should still be offering the previously valid cert-key pair.
override = reloader.sslContextConfigurationOverride
XCTAssertEqual(
override.certificateChain,
@ -192,7 +196,7 @@ final class TimedCertificateReloaderTests: XCTestCase {
format: .der
)
) { reloader in
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
// On first attempt, the overrides should be correctly present.
var override = reloader.sslContextConfigurationOverride
XCTAssertEqual(
override.certificateChain,
@ -203,9 +207,13 @@ final class TimedCertificateReloaderTests: XCTestCase {
.privateKey(try .init(bytes: Array(Self.samplePrivateKey.derRepresentation), format: .der))
)
// Update the box to not contain a key.
keyBox.withLockedValue({ $0 = nil })
// Give the reload loop some time to run and update the cert-key pair.
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
// We should still be offering the previously valid cert-key pair.
override = reloader.sslContextConfigurationOverride
XCTAssertEqual(
override.certificateChain,
@ -232,7 +240,7 @@ final class TimedCertificateReloaderTests: XCTestCase {
format: .der
)
) { reloader in
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
// On first attempt, the overrides should be correctly present.
var override = reloader.sslContextConfigurationOverride
XCTAssertEqual(
override.certificateChain,
@ -243,9 +251,13 @@ final class TimedCertificateReloaderTests: XCTestCase {
.privateKey(try .init(bytes: Array(Self.samplePrivateKey.derRepresentation), format: .der))
)
// Update the box to contain a key that does not match the given certificate.
keyBox.withLockedValue({ $0 = Array(P384.Signing.PrivateKey().derRepresentation) })
// Give the reload loop some time to run and update the cert-key pair.
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
// We should still be offering the previously valid cert-key pair.
override = reloader.sslContextConfigurationOverride
XCTAssertEqual(
override.certificateChain,