mirror of
https://github.com/apple/swift-nio-extras.git
synced 2025-05-14 08:52:42 +08:00
Refactor some stuff
This commit is contained in:
parent
6b6a5b5e10
commit
7a755d74e0
@ -158,9 +158,11 @@ public struct TimedCertificateReloader: CertificateReloader {
|
|||||||
self.refreshInterval = refreshInterval
|
self.refreshInterval = refreshInterval
|
||||||
self.certificateDescription = certificateDescription
|
self.certificateDescription = certificateDescription
|
||||||
self.privateKeyDescription = privateKeyDescription
|
self.privateKeyDescription = privateKeyDescription
|
||||||
|
|
||||||
// TODO: try parsing key and cert here too
|
|
||||||
self.state = NIOLockedValueBox(nil)
|
self.state = NIOLockedValueBox(nil)
|
||||||
|
|
||||||
|
// Immediately try to load the configured cert and key to avoid having to wait for the first
|
||||||
|
// reload loop to run.
|
||||||
|
self.reloadPair()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A long-running method to run the ``TimedCertificateReloader`` and start observing updates for the certificate and
|
/// A long-running method to run the ``TimedCertificateReloader`` and start observing updates for the certificate and
|
||||||
@ -169,96 +171,91 @@ public struct TimedCertificateReloader: CertificateReloader {
|
|||||||
public func run() async throws {
|
public func run() async throws {
|
||||||
while !Task.isCancelled {
|
while !Task.isCancelled {
|
||||||
try await Task.sleep(nanoseconds: UInt64(self.refreshInterval.nanoseconds))
|
try await Task.sleep(nanoseconds: UInt64(self.refreshInterval.nanoseconds))
|
||||||
let certificateBytes: [UInt8]?
|
self.reloadPair()
|
||||||
switch self.certificateDescription.location._backing {
|
}
|
||||||
case .file(path: let path):
|
}
|
||||||
let bytes = FileManager.default.contents(atPath: path)
|
|
||||||
certificateBytes = bytes.map { Array($0) }
|
|
||||||
|
|
||||||
case .memory(let bytesProvider):
|
|
||||||
certificateBytes = bytesProvider()
|
|
||||||
}
|
|
||||||
|
|
||||||
let keyBytes: [UInt8]?
|
|
||||||
switch self.privateKeyDescription.location._backing {
|
|
||||||
case .file(path: let path):
|
|
||||||
let bytes = FileManager.default.contents(atPath: path)
|
|
||||||
keyBytes = bytes.map { Array($0) }
|
|
||||||
|
|
||||||
case .memory(let bytesProvider):
|
|
||||||
keyBytes = bytesProvider()
|
|
||||||
}
|
|
||||||
|
|
||||||
if let certificateBytes, let keyBytes {
|
|
||||||
let certificate: Certificate
|
|
||||||
switch self.certificateDescription.format._backing {
|
|
||||||
case .der:
|
|
||||||
let parsedCertificate = try? Certificate(derEncoded: Array(certificateBytes))
|
|
||||||
|
|
||||||
guard let parsedCertificate else {
|
|
||||||
// could not parse certificate, ignore this update
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
certificate = parsedCertificate
|
|
||||||
|
|
||||||
case .pem:
|
|
||||||
let parsedCertificate = String(bytes: certificateBytes, encoding: .utf8)
|
|
||||||
.flatMap { try? Certificate(pemEncoded: $0) }
|
|
||||||
|
|
||||||
guard let parsedCertificate else {
|
private func reloadPair() {
|
||||||
// could not parse certificate, ignore this update
|
if let certificateBytes = self.loadCertificate(),
|
||||||
continue
|
let keyBytes = self.loadPrivateKey(),
|
||||||
}
|
let certificate = self.parseCertificate(from: certificateBytes),
|
||||||
|
let key = self.parsePrivateKey(from: keyBytes),
|
||||||
certificate = parsedCertificate
|
key.publicKey.isValidSignature(certificate.signature, for: certificate) {
|
||||||
}
|
self.attemptToUpdatePair(certificate: certificate, key: key)
|
||||||
|
}
|
||||||
let key: Certificate.PrivateKey
|
}
|
||||||
switch self.privateKeyDescription.format._backing {
|
|
||||||
case .der:
|
|
||||||
let parsedKey = try? Certificate.PrivateKey(derBytes: Array(keyBytes))
|
|
||||||
|
|
||||||
guard let parsedKey else {
|
|
||||||
// could not parse key, ignore the update
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
key = parsedKey
|
|
||||||
|
|
||||||
case .pem:
|
|
||||||
let parsedKey = String(bytes: keyBytes, encoding: .utf8)
|
|
||||||
.flatMap { try? Certificate.PrivateKey(pemEncoded: $0) }
|
|
||||||
|
|
||||||
guard let parsedKey else {
|
private func loadCertificate() -> [UInt8]? {
|
||||||
// could not parse key, ignore the update
|
let certificateBytes: [UInt8]?
|
||||||
continue
|
switch self.certificateDescription.location._backing {
|
||||||
}
|
case .file(path: let path):
|
||||||
|
let bytes = FileManager.default.contents(atPath: path)
|
||||||
key = parsedKey
|
certificateBytes = bytes.map { Array($0) }
|
||||||
}
|
|
||||||
|
|
||||||
if key.publicKey.isValidSignature(certificate.signature, for: certificate) {
|
|
||||||
let nioSSLCertificate = try? NIOSSLCertificate(
|
|
||||||
bytes: certificate.serializeAsPEM().derBytes,
|
|
||||||
format: .der
|
|
||||||
)
|
|
||||||
let nioSSLPrivateKey = try? NIOSSLPrivateKey(
|
|
||||||
bytes: key.serializeAsPEM().derBytes,
|
|
||||||
format: .der
|
|
||||||
)
|
|
||||||
guard let nioSSLCertificate, let nioSSLPrivateKey else {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
self.state.withLockedValue {
|
case .memory(let bytesProvider):
|
||||||
$0 = CertificateKeyPair(
|
certificateBytes = bytesProvider()
|
||||||
certificate: .certificate(nioSSLCertificate),
|
}
|
||||||
privateKey: .privateKey(nioSSLPrivateKey)
|
return certificateBytes
|
||||||
)
|
}
|
||||||
}
|
|
||||||
}
|
private func loadPrivateKey() -> [UInt8]? {
|
||||||
}
|
let keyBytes: [UInt8]?
|
||||||
|
switch self.privateKeyDescription.location._backing {
|
||||||
|
case .file(path: let path):
|
||||||
|
let bytes = FileManager.default.contents(atPath: path)
|
||||||
|
keyBytes = bytes.map { Array($0) }
|
||||||
|
|
||||||
|
case .memory(let bytesProvider):
|
||||||
|
keyBytes = bytesProvider()
|
||||||
|
}
|
||||||
|
return keyBytes
|
||||||
|
}
|
||||||
|
|
||||||
|
private func parseCertificate(from certificateBytes: [UInt8]) -> Certificate? {
|
||||||
|
let certificate: Certificate?
|
||||||
|
switch self.certificateDescription.format._backing {
|
||||||
|
case .der:
|
||||||
|
certificate = try? Certificate(derEncoded: certificateBytes)
|
||||||
|
|
||||||
|
case .pem:
|
||||||
|
certificate = String(bytes: certificateBytes, encoding: .utf8)
|
||||||
|
.flatMap { try? Certificate(pemEncoded: $0) }
|
||||||
|
}
|
||||||
|
return certificate
|
||||||
|
}
|
||||||
|
|
||||||
|
private func parsePrivateKey(from keyBytes: [UInt8]) -> Certificate.PrivateKey? {
|
||||||
|
let key: Certificate.PrivateKey?
|
||||||
|
switch self.privateKeyDescription.format._backing {
|
||||||
|
case .der:
|
||||||
|
key = try? Certificate.PrivateKey(derBytes: keyBytes)
|
||||||
|
|
||||||
|
case .pem:
|
||||||
|
key = String(bytes: keyBytes, encoding: .utf8)
|
||||||
|
.flatMap { try? Certificate.PrivateKey(pemEncoded: $0) }
|
||||||
|
}
|
||||||
|
return key
|
||||||
|
}
|
||||||
|
|
||||||
|
private func attemptToUpdatePair(certificate: Certificate, key: Certificate.PrivateKey) {
|
||||||
|
let nioSSLCertificate = try? NIOSSLCertificate(
|
||||||
|
bytes: certificate.serializeAsPEM().derBytes,
|
||||||
|
format: .der
|
||||||
|
)
|
||||||
|
let nioSSLPrivateKey = try? NIOSSLPrivateKey(
|
||||||
|
bytes: key.serializeAsPEM().derBytes,
|
||||||
|
format: .der
|
||||||
|
)
|
||||||
|
|
||||||
|
guard let nioSSLCertificate, let nioSSLPrivateKey else {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
self.state.withLockedValue {
|
||||||
|
$0 = CertificateKeyPair(
|
||||||
|
certificate: .certificate(nioSSLCertificate),
|
||||||
|
privateKey: .privateKey(nioSSLPrivateKey)
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
//
|
//
|
||||||
// This source file is part of the SwiftNIO open source project
|
// This source file is part of the SwiftNIO open source project
|
||||||
//
|
//
|
||||||
// Copyright (c) 2017-2021 Apple Inc. and the SwiftNIO project authors
|
// Copyright (c) 2025 Apple Inc. and the SwiftNIO project authors
|
||||||
// Licensed under Apache License v2.0
|
// Licensed under Apache License v2.0
|
||||||
//
|
//
|
||||||
// See LICENSE.txt for license information
|
// See LICENSE.txt for license information
|
||||||
@ -27,7 +27,6 @@ final class TimedCertificateReloaderTests: XCTestCase {
|
|||||||
format: .der
|
format: .der
|
||||||
)
|
)
|
||||||
) { reloader in
|
) { reloader in
|
||||||
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
|
|
||||||
let override = reloader.sslContextConfigurationOverride
|
let override = reloader.sslContextConfigurationOverride
|
||||||
XCTAssertNil(override.certificateChain)
|
XCTAssertNil(override.certificateChain)
|
||||||
XCTAssertNil(override.privateKey)
|
XCTAssertNil(override.privateKey)
|
||||||
@ -45,7 +44,6 @@ final class TimedCertificateReloaderTests: XCTestCase {
|
|||||||
format: .der
|
format: .der
|
||||||
)
|
)
|
||||||
) { reloader in
|
) { reloader in
|
||||||
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
|
|
||||||
let override = reloader.sslContextConfigurationOverride
|
let override = reloader.sslContextConfigurationOverride
|
||||||
XCTAssertNil(override.certificateChain)
|
XCTAssertNil(override.certificateChain)
|
||||||
XCTAssertNil(override.privateKey)
|
XCTAssertNil(override.privateKey)
|
||||||
@ -63,7 +61,6 @@ final class TimedCertificateReloaderTests: XCTestCase {
|
|||||||
format: .der
|
format: .der
|
||||||
)
|
)
|
||||||
) { reloader in
|
) { reloader in
|
||||||
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
|
|
||||||
let override = reloader.sslContextConfigurationOverride
|
let override = reloader.sslContextConfigurationOverride
|
||||||
XCTAssertNil(override.certificateChain)
|
XCTAssertNil(override.certificateChain)
|
||||||
XCTAssertNil(override.privateKey)
|
XCTAssertNil(override.privateKey)
|
||||||
@ -81,7 +78,6 @@ final class TimedCertificateReloaderTests: XCTestCase {
|
|||||||
format: .pem
|
format: .pem
|
||||||
)
|
)
|
||||||
) { reloader in
|
) { reloader in
|
||||||
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
|
|
||||||
let override = reloader.sslContextConfigurationOverride
|
let override = reloader.sslContextConfigurationOverride
|
||||||
XCTAssertNil(override.certificateChain)
|
XCTAssertNil(override.certificateChain)
|
||||||
XCTAssertNil(override.privateKey)
|
XCTAssertNil(override.privateKey)
|
||||||
@ -99,7 +95,6 @@ final class TimedCertificateReloaderTests: XCTestCase {
|
|||||||
format: .der
|
format: .der
|
||||||
)
|
)
|
||||||
) { reloader in
|
) { reloader in
|
||||||
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
|
|
||||||
let override = reloader.sslContextConfigurationOverride
|
let override = reloader.sslContextConfigurationOverride
|
||||||
XCTAssertNil(override.certificateChain)
|
XCTAssertNil(override.certificateChain)
|
||||||
XCTAssertNil(override.privateKey)
|
XCTAssertNil(override.privateKey)
|
||||||
@ -107,7 +102,7 @@ final class TimedCertificateReloaderTests: XCTestCase {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func testReloadSuccessfully() async throws {
|
func testReloadSuccessfully() async throws {
|
||||||
let certificateBox = NIOLockedValueBox([UInt8]())
|
let certificateBox: NIOLockedValueBox<[UInt8]?> = NIOLockedValueBox(nil)
|
||||||
try await runTimedCertificateReloaderTest(
|
try await runTimedCertificateReloaderTest(
|
||||||
certificate: .init(
|
certificate: .init(
|
||||||
location: .memory(provider: { certificateBox.withLockedValue({ $0 }) }),
|
location: .memory(provider: { certificateBox.withLockedValue({ $0 }) }),
|
||||||
@ -118,14 +113,19 @@ final class TimedCertificateReloaderTests: XCTestCase {
|
|||||||
format: .der
|
format: .der
|
||||||
)
|
)
|
||||||
) { reloader in
|
) { reloader in
|
||||||
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
|
// On first attempt, we should have no certificate or private key overrides available,
|
||||||
|
// since the certificate box is empty.
|
||||||
var override = reloader.sslContextConfigurationOverride
|
var override = reloader.sslContextConfigurationOverride
|
||||||
XCTAssertNil(override.certificateChain)
|
XCTAssertNil(override.certificateChain)
|
||||||
XCTAssertNil(override.privateKey)
|
XCTAssertNil(override.privateKey)
|
||||||
|
|
||||||
|
// Update the box to contain a valid certificate.
|
||||||
certificateBox.withLockedValue({ $0 = try! Self.sampleCert.serializeAsPEM().derBytes })
|
certificateBox.withLockedValue({ $0 = try! Self.sampleCert.serializeAsPEM().derBytes })
|
||||||
|
|
||||||
|
// Give the reload loop some time to run and update the cert-key pair.
|
||||||
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
|
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
|
||||||
|
|
||||||
|
// Now the overrides should be present.
|
||||||
override = reloader.sslContextConfigurationOverride
|
override = reloader.sslContextConfigurationOverride
|
||||||
XCTAssertEqual(
|
XCTAssertEqual(
|
||||||
override.certificateChain,
|
override.certificateChain,
|
||||||
@ -152,7 +152,7 @@ final class TimedCertificateReloaderTests: XCTestCase {
|
|||||||
format: .der
|
format: .der
|
||||||
)
|
)
|
||||||
) { reloader in
|
) { reloader in
|
||||||
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
|
// On first attempt, the overrides should be correctly present.
|
||||||
var override = reloader.sslContextConfigurationOverride
|
var override = reloader.sslContextConfigurationOverride
|
||||||
XCTAssertEqual(
|
XCTAssertEqual(
|
||||||
override.certificateChain,
|
override.certificateChain,
|
||||||
@ -163,9 +163,13 @@ final class TimedCertificateReloaderTests: XCTestCase {
|
|||||||
.privateKey(try .init(bytes: Array(Self.samplePrivateKey.derRepresentation), format: .der))
|
.privateKey(try .init(bytes: Array(Self.samplePrivateKey.derRepresentation), format: .der))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Update the box to not contain a certificate.
|
||||||
certificateBox.withLockedValue({ $0 = nil })
|
certificateBox.withLockedValue({ $0 = nil })
|
||||||
|
|
||||||
|
// Give the reload loop some time to run and update the cert-key pair.
|
||||||
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
|
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
|
||||||
|
|
||||||
|
// We should still be offering the previously valid cert-key pair.
|
||||||
override = reloader.sslContextConfigurationOverride
|
override = reloader.sslContextConfigurationOverride
|
||||||
XCTAssertEqual(
|
XCTAssertEqual(
|
||||||
override.certificateChain,
|
override.certificateChain,
|
||||||
@ -192,7 +196,7 @@ final class TimedCertificateReloaderTests: XCTestCase {
|
|||||||
format: .der
|
format: .der
|
||||||
)
|
)
|
||||||
) { reloader in
|
) { reloader in
|
||||||
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
|
// On first attempt, the overrides should be correctly present.
|
||||||
var override = reloader.sslContextConfigurationOverride
|
var override = reloader.sslContextConfigurationOverride
|
||||||
XCTAssertEqual(
|
XCTAssertEqual(
|
||||||
override.certificateChain,
|
override.certificateChain,
|
||||||
@ -203,9 +207,13 @@ final class TimedCertificateReloaderTests: XCTestCase {
|
|||||||
.privateKey(try .init(bytes: Array(Self.samplePrivateKey.derRepresentation), format: .der))
|
.privateKey(try .init(bytes: Array(Self.samplePrivateKey.derRepresentation), format: .der))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Update the box to not contain a key.
|
||||||
keyBox.withLockedValue({ $0 = nil })
|
keyBox.withLockedValue({ $0 = nil })
|
||||||
|
|
||||||
|
// Give the reload loop some time to run and update the cert-key pair.
|
||||||
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
|
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
|
||||||
|
|
||||||
|
// We should still be offering the previously valid cert-key pair.
|
||||||
override = reloader.sslContextConfigurationOverride
|
override = reloader.sslContextConfigurationOverride
|
||||||
XCTAssertEqual(
|
XCTAssertEqual(
|
||||||
override.certificateChain,
|
override.certificateChain,
|
||||||
@ -232,7 +240,7 @@ final class TimedCertificateReloaderTests: XCTestCase {
|
|||||||
format: .der
|
format: .der
|
||||||
)
|
)
|
||||||
) { reloader in
|
) { reloader in
|
||||||
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
|
// On first attempt, the overrides should be correctly present.
|
||||||
var override = reloader.sslContextConfigurationOverride
|
var override = reloader.sslContextConfigurationOverride
|
||||||
XCTAssertEqual(
|
XCTAssertEqual(
|
||||||
override.certificateChain,
|
override.certificateChain,
|
||||||
@ -243,9 +251,13 @@ final class TimedCertificateReloaderTests: XCTestCase {
|
|||||||
.privateKey(try .init(bytes: Array(Self.samplePrivateKey.derRepresentation), format: .der))
|
.privateKey(try .init(bytes: Array(Self.samplePrivateKey.derRepresentation), format: .der))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Update the box to contain a key that does not match the given certificate.
|
||||||
keyBox.withLockedValue({ $0 = Array(P384.Signing.PrivateKey().derRepresentation) })
|
keyBox.withLockedValue({ $0 = Array(P384.Signing.PrivateKey().derRepresentation) })
|
||||||
|
|
||||||
|
// Give the reload loop some time to run and update the cert-key pair.
|
||||||
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
|
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
|
||||||
|
|
||||||
|
// We should still be offering the previously valid cert-key pair.
|
||||||
override = reloader.sslContextConfigurationOverride
|
override = reloader.sslContextConfigurationOverride
|
||||||
XCTAssertEqual(
|
XCTAssertEqual(
|
||||||
override.certificateChain,
|
override.certificateChain,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user