mirror of
https://github.com/apple/swift-nio-extras.git
synced 2025-05-14 17:02:43 +08:00
Refactor some stuff
This commit is contained in:
parent
6b6a5b5e10
commit
7a755d74e0
@ -158,9 +158,11 @@ public struct TimedCertificateReloader: CertificateReloader {
|
||||
self.refreshInterval = refreshInterval
|
||||
self.certificateDescription = certificateDescription
|
||||
self.privateKeyDescription = privateKeyDescription
|
||||
|
||||
// TODO: try parsing key and cert here too
|
||||
self.state = NIOLockedValueBox(nil)
|
||||
|
||||
// Immediately try to load the configured cert and key to avoid having to wait for the first
|
||||
// reload loop to run.
|
||||
self.reloadPair()
|
||||
}
|
||||
|
||||
/// A long-running method to run the ``TimedCertificateReloader`` and start observing updates for the certificate and
|
||||
@ -169,6 +171,21 @@ public struct TimedCertificateReloader: CertificateReloader {
|
||||
public func run() async throws {
|
||||
while !Task.isCancelled {
|
||||
try await Task.sleep(nanoseconds: UInt64(self.refreshInterval.nanoseconds))
|
||||
self.reloadPair()
|
||||
}
|
||||
}
|
||||
|
||||
private func reloadPair() {
|
||||
if let certificateBytes = self.loadCertificate(),
|
||||
let keyBytes = self.loadPrivateKey(),
|
||||
let certificate = self.parseCertificate(from: certificateBytes),
|
||||
let key = self.parsePrivateKey(from: keyBytes),
|
||||
key.publicKey.isValidSignature(certificate.signature, for: certificate) {
|
||||
self.attemptToUpdatePair(certificate: certificate, key: key)
|
||||
}
|
||||
}
|
||||
|
||||
private func loadCertificate() -> [UInt8]? {
|
||||
let certificateBytes: [UInt8]?
|
||||
switch self.certificateDescription.location._backing {
|
||||
case .file(path: let path):
|
||||
@ -178,7 +195,10 @@ public struct TimedCertificateReloader: CertificateReloader {
|
||||
case .memory(let bytesProvider):
|
||||
certificateBytes = bytesProvider()
|
||||
}
|
||||
return certificateBytes
|
||||
}
|
||||
|
||||
private func loadPrivateKey() -> [UInt8]? {
|
||||
let keyBytes: [UInt8]?
|
||||
switch self.privateKeyDescription.location._backing {
|
||||
case .file(path: let path):
|
||||
@ -188,57 +208,36 @@ public struct TimedCertificateReloader: CertificateReloader {
|
||||
case .memory(let bytesProvider):
|
||||
keyBytes = bytesProvider()
|
||||
}
|
||||
return keyBytes
|
||||
}
|
||||
|
||||
if let certificateBytes, let keyBytes {
|
||||
let certificate: Certificate
|
||||
private func parseCertificate(from certificateBytes: [UInt8]) -> Certificate? {
|
||||
let certificate: Certificate?
|
||||
switch self.certificateDescription.format._backing {
|
||||
case .der:
|
||||
let parsedCertificate = try? Certificate(derEncoded: Array(certificateBytes))
|
||||
|
||||
guard let parsedCertificate else {
|
||||
// could not parse certificate, ignore this update
|
||||
continue
|
||||
}
|
||||
|
||||
certificate = parsedCertificate
|
||||
certificate = try? Certificate(derEncoded: certificateBytes)
|
||||
|
||||
case .pem:
|
||||
let parsedCertificate = String(bytes: certificateBytes, encoding: .utf8)
|
||||
certificate = String(bytes: certificateBytes, encoding: .utf8)
|
||||
.flatMap { try? Certificate(pemEncoded: $0) }
|
||||
|
||||
guard let parsedCertificate else {
|
||||
// could not parse certificate, ignore this update
|
||||
continue
|
||||
}
|
||||
return certificate
|
||||
}
|
||||
|
||||
certificate = parsedCertificate
|
||||
}
|
||||
|
||||
let key: Certificate.PrivateKey
|
||||
private func parsePrivateKey(from keyBytes: [UInt8]) -> Certificate.PrivateKey? {
|
||||
let key: Certificate.PrivateKey?
|
||||
switch self.privateKeyDescription.format._backing {
|
||||
case .der:
|
||||
let parsedKey = try? Certificate.PrivateKey(derBytes: Array(keyBytes))
|
||||
|
||||
guard let parsedKey else {
|
||||
// could not parse key, ignore the update
|
||||
continue
|
||||
}
|
||||
|
||||
key = parsedKey
|
||||
key = try? Certificate.PrivateKey(derBytes: keyBytes)
|
||||
|
||||
case .pem:
|
||||
let parsedKey = String(bytes: keyBytes, encoding: .utf8)
|
||||
key = String(bytes: keyBytes, encoding: .utf8)
|
||||
.flatMap { try? Certificate.PrivateKey(pemEncoded: $0) }
|
||||
|
||||
guard let parsedKey else {
|
||||
// could not parse key, ignore the update
|
||||
continue
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
||||
key = parsedKey
|
||||
}
|
||||
|
||||
if key.publicKey.isValidSignature(certificate.signature, for: certificate) {
|
||||
private func attemptToUpdatePair(certificate: Certificate, key: Certificate.PrivateKey) {
|
||||
let nioSSLCertificate = try? NIOSSLCertificate(
|
||||
bytes: certificate.serializeAsPEM().derBytes,
|
||||
format: .der
|
||||
@ -247,8 +246,9 @@ public struct TimedCertificateReloader: CertificateReloader {
|
||||
bytes: key.serializeAsPEM().derBytes,
|
||||
format: .der
|
||||
)
|
||||
|
||||
guard let nioSSLCertificate, let nioSSLPrivateKey else {
|
||||
continue
|
||||
return
|
||||
}
|
||||
|
||||
self.state.withLockedValue {
|
||||
@ -259,9 +259,6 @@ public struct TimedCertificateReloader: CertificateReloader {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
extension TLSConfiguration {
|
||||
/// Configure a ``CertificateReloader`` to observe updates for the certificate and key pair used.
|
||||
|
@ -2,7 +2,7 @@
|
||||
//
|
||||
// This source file is part of the SwiftNIO open source project
|
||||
//
|
||||
// Copyright (c) 2017-2021 Apple Inc. and the SwiftNIO project authors
|
||||
// Copyright (c) 2025 Apple Inc. and the SwiftNIO project authors
|
||||
// Licensed under Apache License v2.0
|
||||
//
|
||||
// See LICENSE.txt for license information
|
||||
@ -27,7 +27,6 @@ final class TimedCertificateReloaderTests: XCTestCase {
|
||||
format: .der
|
||||
)
|
||||
) { reloader in
|
||||
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
|
||||
let override = reloader.sslContextConfigurationOverride
|
||||
XCTAssertNil(override.certificateChain)
|
||||
XCTAssertNil(override.privateKey)
|
||||
@ -45,7 +44,6 @@ final class TimedCertificateReloaderTests: XCTestCase {
|
||||
format: .der
|
||||
)
|
||||
) { reloader in
|
||||
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
|
||||
let override = reloader.sslContextConfigurationOverride
|
||||
XCTAssertNil(override.certificateChain)
|
||||
XCTAssertNil(override.privateKey)
|
||||
@ -63,7 +61,6 @@ final class TimedCertificateReloaderTests: XCTestCase {
|
||||
format: .der
|
||||
)
|
||||
) { reloader in
|
||||
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
|
||||
let override = reloader.sslContextConfigurationOverride
|
||||
XCTAssertNil(override.certificateChain)
|
||||
XCTAssertNil(override.privateKey)
|
||||
@ -81,7 +78,6 @@ final class TimedCertificateReloaderTests: XCTestCase {
|
||||
format: .pem
|
||||
)
|
||||
) { reloader in
|
||||
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
|
||||
let override = reloader.sslContextConfigurationOverride
|
||||
XCTAssertNil(override.certificateChain)
|
||||
XCTAssertNil(override.privateKey)
|
||||
@ -99,7 +95,6 @@ final class TimedCertificateReloaderTests: XCTestCase {
|
||||
format: .der
|
||||
)
|
||||
) { reloader in
|
||||
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
|
||||
let override = reloader.sslContextConfigurationOverride
|
||||
XCTAssertNil(override.certificateChain)
|
||||
XCTAssertNil(override.privateKey)
|
||||
@ -107,7 +102,7 @@ final class TimedCertificateReloaderTests: XCTestCase {
|
||||
}
|
||||
|
||||
func testReloadSuccessfully() async throws {
|
||||
let certificateBox = NIOLockedValueBox([UInt8]())
|
||||
let certificateBox: NIOLockedValueBox<[UInt8]?> = NIOLockedValueBox(nil)
|
||||
try await runTimedCertificateReloaderTest(
|
||||
certificate: .init(
|
||||
location: .memory(provider: { certificateBox.withLockedValue({ $0 }) }),
|
||||
@ -118,14 +113,19 @@ final class TimedCertificateReloaderTests: XCTestCase {
|
||||
format: .der
|
||||
)
|
||||
) { reloader in
|
||||
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
|
||||
// On first attempt, we should have no certificate or private key overrides available,
|
||||
// since the certificate box is empty.
|
||||
var override = reloader.sslContextConfigurationOverride
|
||||
XCTAssertNil(override.certificateChain)
|
||||
XCTAssertNil(override.privateKey)
|
||||
|
||||
// Update the box to contain a valid certificate.
|
||||
certificateBox.withLockedValue({ $0 = try! Self.sampleCert.serializeAsPEM().derBytes })
|
||||
|
||||
// Give the reload loop some time to run and update the cert-key pair.
|
||||
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
|
||||
|
||||
// Now the overrides should be present.
|
||||
override = reloader.sslContextConfigurationOverride
|
||||
XCTAssertEqual(
|
||||
override.certificateChain,
|
||||
@ -152,7 +152,7 @@ final class TimedCertificateReloaderTests: XCTestCase {
|
||||
format: .der
|
||||
)
|
||||
) { reloader in
|
||||
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
|
||||
// On first attempt, the overrides should be correctly present.
|
||||
var override = reloader.sslContextConfigurationOverride
|
||||
XCTAssertEqual(
|
||||
override.certificateChain,
|
||||
@ -163,9 +163,13 @@ final class TimedCertificateReloaderTests: XCTestCase {
|
||||
.privateKey(try .init(bytes: Array(Self.samplePrivateKey.derRepresentation), format: .der))
|
||||
)
|
||||
|
||||
// Update the box to not contain a certificate.
|
||||
certificateBox.withLockedValue({ $0 = nil })
|
||||
|
||||
// Give the reload loop some time to run and update the cert-key pair.
|
||||
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
|
||||
|
||||
// We should still be offering the previously valid cert-key pair.
|
||||
override = reloader.sslContextConfigurationOverride
|
||||
XCTAssertEqual(
|
||||
override.certificateChain,
|
||||
@ -192,7 +196,7 @@ final class TimedCertificateReloaderTests: XCTestCase {
|
||||
format: .der
|
||||
)
|
||||
) { reloader in
|
||||
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
|
||||
// On first attempt, the overrides should be correctly present.
|
||||
var override = reloader.sslContextConfigurationOverride
|
||||
XCTAssertEqual(
|
||||
override.certificateChain,
|
||||
@ -203,9 +207,13 @@ final class TimedCertificateReloaderTests: XCTestCase {
|
||||
.privateKey(try .init(bytes: Array(Self.samplePrivateKey.derRepresentation), format: .der))
|
||||
)
|
||||
|
||||
// Update the box to not contain a key.
|
||||
keyBox.withLockedValue({ $0 = nil })
|
||||
|
||||
// Give the reload loop some time to run and update the cert-key pair.
|
||||
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
|
||||
|
||||
// We should still be offering the previously valid cert-key pair.
|
||||
override = reloader.sslContextConfigurationOverride
|
||||
XCTAssertEqual(
|
||||
override.certificateChain,
|
||||
@ -232,7 +240,7 @@ final class TimedCertificateReloaderTests: XCTestCase {
|
||||
format: .der
|
||||
)
|
||||
) { reloader in
|
||||
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
|
||||
// On first attempt, the overrides should be correctly present.
|
||||
var override = reloader.sslContextConfigurationOverride
|
||||
XCTAssertEqual(
|
||||
override.certificateChain,
|
||||
@ -243,9 +251,13 @@ final class TimedCertificateReloaderTests: XCTestCase {
|
||||
.privateKey(try .init(bytes: Array(Self.samplePrivateKey.derRepresentation), format: .der))
|
||||
)
|
||||
|
||||
// Update the box to contain a key that does not match the given certificate.
|
||||
keyBox.withLockedValue({ $0 = Array(P384.Signing.PrivateKey().derRepresentation) })
|
||||
|
||||
// Give the reload loop some time to run and update the cert-key pair.
|
||||
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
|
||||
|
||||
// We should still be offering the previously valid cert-key pair.
|
||||
override = reloader.sslContextConfigurationOverride
|
||||
XCTAssertEqual(
|
||||
override.certificateChain,
|
||||
|
Loading…
x
Reference in New Issue
Block a user