mirror of
https://github.com/apple/swift-nio-extras.git
synced 2025-05-14 17:02:43 +08:00
Refactor QuiescingHelper
to exhaustively iterate state (#193)
# Motivation Currently the `QuiescingHelper` is crashing on a precondition if you call shutdown when it already was shutdown. However, that can totally happen and we should support it. # Modification Refactor the `QuiescingHelper` to exhaustively switch over its state in every method. Furthermore, I added a few more test cases to test realistic scenarios. # Result We are now reliable checking our state and making sure to allow most transitions.
This commit is contained in:
parent
6bd9bf5c29
commit
d75ed708d0
@ -23,23 +23,25 @@ private enum ShutdownError: Error {
|
||||
/// `channelAdded` method in the same event loop tick as the `Channel` is actually created.
|
||||
private final class ChannelCollector {
|
||||
enum LifecycleState {
|
||||
case upAndRunning
|
||||
case shuttingDown
|
||||
case upAndRunning(
|
||||
openChannels: [ObjectIdentifier: Channel],
|
||||
serverChannel: Channel
|
||||
)
|
||||
case shuttingDown(
|
||||
openChannels: [ObjectIdentifier: Channel],
|
||||
fullyShutdownPromise: EventLoopPromise<Void>
|
||||
)
|
||||
case shutdownCompleted
|
||||
}
|
||||
|
||||
private var openChannels: [ObjectIdentifier: Channel] = [:]
|
||||
private let serverChannel: Channel
|
||||
private var fullyShutdownPromise: EventLoopPromise<Void>? = nil
|
||||
private var lifecycleState = LifecycleState.upAndRunning
|
||||
private var lifecycleState: LifecycleState
|
||||
|
||||
private var eventLoop: EventLoop {
|
||||
return self.serverChannel.eventLoop
|
||||
}
|
||||
private let eventLoop: EventLoop
|
||||
|
||||
/// Initializes a `ChannelCollector` for `Channel`s accepted by `serverChannel`.
|
||||
init(serverChannel: Channel) {
|
||||
self.serverChannel = serverChannel
|
||||
self.eventLoop = serverChannel.eventLoop
|
||||
self.lifecycleState = .upAndRunning(openChannels: [:], serverChannel: serverChannel)
|
||||
}
|
||||
|
||||
/// Add a channel to the `ChannelCollector`.
|
||||
@ -51,30 +53,64 @@ private final class ChannelCollector {
|
||||
func channelAdded(_ channel: Channel) throws {
|
||||
self.eventLoop.assertInEventLoop()
|
||||
|
||||
guard self.lifecycleState != .shutdownCompleted else {
|
||||
switch self.lifecycleState {
|
||||
case .upAndRunning(var openChannels, let serverChannel):
|
||||
openChannels[ObjectIdentifier(channel)] = channel
|
||||
self.lifecycleState = .upAndRunning(openChannels: openChannels, serverChannel: serverChannel)
|
||||
|
||||
case .shuttingDown(var openChannels, let fullyShutdownPromise):
|
||||
openChannels[ObjectIdentifier(channel)] = channel
|
||||
channel.eventLoop.execute {
|
||||
channel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent())
|
||||
}
|
||||
self.lifecycleState = .shuttingDown(openChannels: openChannels, fullyShutdownPromise: fullyShutdownPromise)
|
||||
|
||||
case .shutdownCompleted:
|
||||
channel.close(promise: nil)
|
||||
throw ShutdownError.alreadyShutdown
|
||||
}
|
||||
|
||||
self.openChannels[ObjectIdentifier(channel)] = channel
|
||||
}
|
||||
|
||||
private func shutdownCompleted() {
|
||||
self.eventLoop.assertInEventLoop()
|
||||
assert(self.lifecycleState == .shuttingDown)
|
||||
|
||||
self.lifecycleState = .shutdownCompleted
|
||||
self.fullyShutdownPromise?.succeed(())
|
||||
switch self.lifecycleState {
|
||||
case .upAndRunning:
|
||||
preconditionFailure("This can never happen because we transition to shuttingDown first")
|
||||
|
||||
case .shuttingDown(_, let fullyShutdownPromise):
|
||||
self.lifecycleState = .shutdownCompleted
|
||||
fullyShutdownPromise.succeed(())
|
||||
|
||||
case .shutdownCompleted:
|
||||
preconditionFailure("We should only complete the shutdown once")
|
||||
}
|
||||
}
|
||||
|
||||
private func channelRemoved0(_ channel: Channel) {
|
||||
self.eventLoop.assertInEventLoop()
|
||||
precondition(self.openChannels.keys.contains(ObjectIdentifier(channel)),
|
||||
"channel \(channel) not in ChannelCollector \(self.openChannels)")
|
||||
|
||||
self.openChannels.removeValue(forKey: ObjectIdentifier(channel))
|
||||
if self.lifecycleState != .upAndRunning && self.openChannels.isEmpty {
|
||||
shutdownCompleted()
|
||||
switch self.lifecycleState {
|
||||
case .upAndRunning(var openChannels, let serverChannel):
|
||||
let removedChannel = openChannels.removeValue(forKey: ObjectIdentifier(channel))
|
||||
|
||||
precondition(removedChannel != nil, "channel \(channel) not in ChannelCollector \(openChannels)")
|
||||
|
||||
self.lifecycleState = .upAndRunning(openChannels: openChannels, serverChannel: serverChannel)
|
||||
|
||||
case .shuttingDown(var openChannels, let fullyShutdownPromise):
|
||||
let removedChannel = openChannels.removeValue(forKey: ObjectIdentifier(channel))
|
||||
|
||||
precondition(removedChannel != nil, "channel \(channel) not in ChannelCollector \(openChannels)")
|
||||
|
||||
if openChannels.isEmpty {
|
||||
self.shutdownCompleted()
|
||||
} else {
|
||||
self.lifecycleState = .shuttingDown(openChannels: openChannels, fullyShutdownPromise: fullyShutdownPromise)
|
||||
}
|
||||
|
||||
case .shutdownCompleted:
|
||||
preconditionFailure("We should not have channels removed after transitioned to completed")
|
||||
}
|
||||
}
|
||||
|
||||
@ -96,44 +132,39 @@ private final class ChannelCollector {
|
||||
|
||||
private func initiateShutdown0(promise: EventLoopPromise<Void>?) {
|
||||
self.eventLoop.assertInEventLoop()
|
||||
precondition(self.lifecycleState == .upAndRunning)
|
||||
|
||||
self.lifecycleState = .shuttingDown
|
||||
switch self.lifecycleState {
|
||||
case .upAndRunning(let openChannels, let serverChannel):
|
||||
let fullyShutdownPromise = promise ?? serverChannel.eventLoop.makePromise(of: Void.self)
|
||||
|
||||
if let promise = promise {
|
||||
if let alreadyExistingPromise = self.fullyShutdownPromise {
|
||||
alreadyExistingPromise.futureResult.cascade(to: promise)
|
||||
} else {
|
||||
self.fullyShutdownPromise = promise
|
||||
self.lifecycleState = .shuttingDown(openChannels: openChannels, fullyShutdownPromise: fullyShutdownPromise)
|
||||
|
||||
serverChannel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent())
|
||||
serverChannel.close().cascadeFailure(to: fullyShutdownPromise)
|
||||
|
||||
for channel in openChannels.values {
|
||||
channel.eventLoop.execute {
|
||||
channel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.serverChannel.close().cascadeFailure(to: self.fullyShutdownPromise)
|
||||
|
||||
for channel in self.openChannels.values {
|
||||
channel.eventLoop.execute {
|
||||
channel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent())
|
||||
if openChannels.isEmpty {
|
||||
self.shutdownCompleted()
|
||||
}
|
||||
}
|
||||
|
||||
if self.openChannels.isEmpty {
|
||||
shutdownCompleted()
|
||||
case .shuttingDown(_, let fullyShutdownPromise):
|
||||
fullyShutdownPromise.futureResult.cascade(to: promise)
|
||||
|
||||
case .shutdownCompleted:
|
||||
promise?.succeed(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Initiate the shutdown fulfilling `promise` when all the previously registered `Channel`s have been closed.
|
||||
///
|
||||
/// - parameters:
|
||||
/// - promise: The `EventLoopPromise` to fulfill when the shutdown of all previously registered `Channel`s has been completed.
|
||||
/// - promise: The `EventLoopPromise` to fulfil when the shutdown of all previously registered `Channel`s has been completed.
|
||||
func initiateShutdown(promise: EventLoopPromise<Void>?) {
|
||||
if self.serverChannel.eventLoop.inEventLoop {
|
||||
self.serverChannel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent())
|
||||
} else {
|
||||
self.eventLoop.execute {
|
||||
self.serverChannel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent())
|
||||
}
|
||||
}
|
||||
|
||||
if self.eventLoop.inEventLoop {
|
||||
self.initiateShutdown0(promise: promise)
|
||||
} else {
|
||||
@ -144,7 +175,6 @@ private final class ChannelCollector {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
extension ChannelCollector: @unchecked Sendable {}
|
||||
|
||||
/// A `ChannelHandler` that adds all channels that it receives through the `ChannelPipeline` to a `ChannelCollector`.
|
||||
@ -173,7 +203,7 @@ private final class CollectAcceptedChannelsHandler: ChannelInboundHandler {
|
||||
do {
|
||||
try self.channelCollector.channelAdded(channel)
|
||||
let closeFuture = channel.closeFuture
|
||||
closeFuture.whenComplete { (_: Result<(), Error>) in
|
||||
closeFuture.whenComplete { (_: Result<Void, Error>) in
|
||||
self.channelCollector.channelRemoved(channel)
|
||||
}
|
||||
context.fireChannelRead(data)
|
||||
@ -231,7 +261,7 @@ public final class ServerQuiescingHelper {
|
||||
deinit {
|
||||
self.channelCollectorPromise.fail(UnusedQuiescingHelperError())
|
||||
}
|
||||
|
||||
|
||||
/// Create the `ChannelHandler` for the server `channel` to collect all accepted child `Channel`s.
|
||||
///
|
||||
/// - parameters:
|
||||
@ -262,6 +292,4 @@ public final class ServerQuiescingHelper {
|
||||
}
|
||||
}
|
||||
|
||||
extension ServerQuiescingHelper: Sendable {
|
||||
|
||||
}
|
||||
extension ServerQuiescingHelper: Sendable {}
|
||||
|
@ -31,6 +31,11 @@ extension QuiescingHelperTest {
|
||||
("testQuiesceUserEventReceivedOnShutdown", testQuiesceUserEventReceivedOnShutdown),
|
||||
("testQuiescingDoesNotSwallowCloseErrorsFromAcceptHandler", testQuiescingDoesNotSwallowCloseErrorsFromAcceptHandler),
|
||||
("testShutdownIsImmediateWhenPromiseDoesNotSucceed", testShutdownIsImmediateWhenPromiseDoesNotSucceed),
|
||||
("testShutdown_whenAlreadyShutdown", testShutdown_whenAlreadyShutdown),
|
||||
("testShutdown_whenNoOpenChild", testShutdown_whenNoOpenChild),
|
||||
("testChannelClose_whenRunning", testChannelClose_whenRunning),
|
||||
("testChannelAdded_whenShuttingDown", testChannelAdded_whenShuttingDown),
|
||||
("testChannelAdded_whenShutdown", testChannelAdded_whenShutdown),
|
||||
]
|
||||
}
|
||||
}
|
||||
|
@ -12,12 +12,27 @@
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
import XCTest
|
||||
import NIOCore
|
||||
import NIOEmbedded
|
||||
@testable import NIOExtras
|
||||
import NIOPosix
|
||||
import NIOTestUtils
|
||||
@testable import NIOExtras
|
||||
import XCTest
|
||||
|
||||
private final class WaitForQuiesceUserEvent: ChannelInboundHandler {
|
||||
typealias InboundIn = Never
|
||||
private let promise: EventLoopPromise<Void>
|
||||
|
||||
init(promise: EventLoopPromise<Void>) {
|
||||
self.promise = promise
|
||||
}
|
||||
|
||||
func userInboundEventTriggered(context _: ChannelHandlerContext, event: Any) {
|
||||
if event is ChannelShouldQuiesceEvent {
|
||||
self.promise.succeed(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public class QuiescingHelperTest: XCTestCase {
|
||||
func testShutdownIsImmediateWhenNoChannelsCollected() throws {
|
||||
@ -35,21 +50,6 @@ public class QuiescingHelperTest: XCTestCase {
|
||||
}
|
||||
|
||||
func testQuiesceUserEventReceivedOnShutdown() throws {
|
||||
class WaitForQuiesceUserEvent: ChannelInboundHandler {
|
||||
typealias InboundIn = Never
|
||||
private let promise: EventLoopPromise<Void>
|
||||
|
||||
init(promise: EventLoopPromise<Void>) {
|
||||
self.promise = promise
|
||||
}
|
||||
|
||||
func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) {
|
||||
if event is ChannelShouldQuiesceEvent {
|
||||
self.promise.succeed(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let el = EmbeddedEventLoop()
|
||||
let allShutdownPromise: EventLoopPromise<Void> = el.makePromise()
|
||||
let serverChannel = EmbeddedChannel(handler: nil, loop: el)
|
||||
@ -63,7 +63,7 @@ public class QuiescingHelperTest: XCTestCase {
|
||||
|
||||
// add a bunch of channels
|
||||
for pretendPort in 1...128 {
|
||||
let waitForPromise: EventLoopPromise<()> = el.makePromise()
|
||||
let waitForPromise: EventLoopPromise<Void> = el.makePromise()
|
||||
let channel = EmbeddedChannel(handler: WaitForQuiesceUserEvent(promise: waitForPromise), loop: el)
|
||||
// activate the child chan
|
||||
XCTAssertNoThrow(try channel.connect(to: .init(ipAddress: "1.2.3.4", port: pretendPort)).wait())
|
||||
@ -137,7 +137,7 @@ public class QuiescingHelperTest: XCTestCase {
|
||||
}
|
||||
}
|
||||
|
||||
///verifying that the promise fails when goes out of scope for shutdown
|
||||
/// verifying that the promise fails when goes out of scope for shutdown
|
||||
func testShutdownIsImmediateWhenPromiseDoesNotSucceed() throws {
|
||||
let el = EmbeddedEventLoop()
|
||||
|
||||
@ -151,4 +151,164 @@ public class QuiescingHelperTest: XCTestCase {
|
||||
XCTAssertTrue(error is ServerQuiescingHelper.UnusedQuiescingHelperError)
|
||||
}
|
||||
}
|
||||
|
||||
func testShutdown_whenAlreadyShutdown() throws {
|
||||
let el = EmbeddedEventLoop()
|
||||
let channel = EmbeddedChannel(handler: nil, loop: el)
|
||||
// let's activate the server channel, nothing actually happens as this is an EmbeddedChannel
|
||||
XCTAssertNoThrow(try channel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 1)).wait())
|
||||
XCTAssertTrue(channel.isActive)
|
||||
let quiesce = ServerQuiescingHelper(group: el)
|
||||
_ = quiesce.makeServerChannelHandler(channel: channel)
|
||||
let p1: EventLoopPromise<Void> = el.makePromise()
|
||||
quiesce.initiateShutdown(promise: p1)
|
||||
XCTAssertNoThrow(try p1.futureResult.wait())
|
||||
XCTAssertFalse(channel.isActive)
|
||||
|
||||
let p2: EventLoopPromise<Void> = el.makePromise()
|
||||
quiesce.initiateShutdown(promise: p2)
|
||||
XCTAssertNoThrow(try p2.futureResult.wait())
|
||||
}
|
||||
|
||||
func testShutdown_whenNoOpenChild() throws {
|
||||
let el = EmbeddedEventLoop()
|
||||
let channel = EmbeddedChannel(handler: nil, loop: el)
|
||||
// let's activate the server channel, nothing actually happens as this is an EmbeddedChannel
|
||||
XCTAssertNoThrow(try channel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 1)).wait())
|
||||
XCTAssertTrue(channel.isActive)
|
||||
let quiesce = ServerQuiescingHelper(group: el)
|
||||
_ = quiesce.makeServerChannelHandler(channel: channel)
|
||||
let p1: EventLoopPromise<Void> = el.makePromise()
|
||||
quiesce.initiateShutdown(promise: p1)
|
||||
el.run()
|
||||
XCTAssertNoThrow(try p1.futureResult.wait())
|
||||
XCTAssertFalse(channel.isActive)
|
||||
}
|
||||
|
||||
func testChannelClose_whenRunning() throws {
|
||||
let el = EmbeddedEventLoop()
|
||||
let allShutdownPromise: EventLoopPromise<Void> = el.makePromise()
|
||||
let serverChannel = EmbeddedChannel(handler: nil, loop: el)
|
||||
// let's activate the server channel, nothing actually happens as this is an EmbeddedChannel
|
||||
XCTAssertNoThrow(try serverChannel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 1)).wait())
|
||||
let quiesce = ServerQuiescingHelper(group: el)
|
||||
let collectionHandler = quiesce.makeServerChannelHandler(channel: serverChannel)
|
||||
XCTAssertNoThrow(try serverChannel.pipeline.addHandler(collectionHandler).wait())
|
||||
|
||||
// let's one channels
|
||||
let eventCounterHandler = EventCounterHandler()
|
||||
let childChannel1 = EmbeddedChannel(handler: eventCounterHandler, loop: el)
|
||||
// activate the child channel
|
||||
XCTAssertNoThrow(try childChannel1.connect(to: .init(ipAddress: "1.2.3.4", port: 1)).wait())
|
||||
serverChannel.pipeline.fireChannelRead(NIOAny(childChannel1))
|
||||
|
||||
// check that the server channel and channel are active before initiating the shutdown
|
||||
XCTAssertTrue(serverChannel.isActive)
|
||||
XCTAssertTrue(childChannel1.isActive)
|
||||
|
||||
XCTAssertEqual(eventCounterHandler.userInboundEventTriggeredCalls, 0)
|
||||
|
||||
// now close the first child channel
|
||||
childChannel1.close(promise: nil)
|
||||
el.run()
|
||||
|
||||
// check that the server is active and child is not
|
||||
XCTAssertTrue(serverChannel.isActive)
|
||||
XCTAssertFalse(childChannel1.isActive)
|
||||
|
||||
quiesce.initiateShutdown(promise: allShutdownPromise)
|
||||
el.run()
|
||||
|
||||
// check that the server channel is closed as the first thing
|
||||
XCTAssertFalse(serverChannel.isActive)
|
||||
|
||||
el.run()
|
||||
|
||||
// check that the shutdown has completed
|
||||
XCTAssertNoThrow(try allShutdownPromise.futureResult.wait())
|
||||
}
|
||||
|
||||
func testChannelAdded_whenShuttingDown() throws {
|
||||
let el = EmbeddedEventLoop()
|
||||
let allShutdownPromise: EventLoopPromise<Void> = el.makePromise()
|
||||
let serverChannel = EmbeddedChannel(handler: nil, loop: el)
|
||||
// let's activate the server channel, nothing actually happens as this is an EmbeddedChannel
|
||||
XCTAssertNoThrow(try serverChannel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 1)).wait())
|
||||
let quiesce = ServerQuiescingHelper(group: el)
|
||||
let collectionHandler = quiesce.makeServerChannelHandler(channel: serverChannel)
|
||||
XCTAssertNoThrow(try serverChannel.pipeline.addHandler(collectionHandler).wait())
|
||||
|
||||
// let's add one channel
|
||||
let waitForPromise1: EventLoopPromise<Void> = el.makePromise()
|
||||
let childChannel1 = EmbeddedChannel(handler: WaitForQuiesceUserEvent(promise: waitForPromise1), loop: el)
|
||||
// activate the child channel
|
||||
XCTAssertNoThrow(try childChannel1.connect(to: .init(ipAddress: "1.2.3.4", port: 1)).wait())
|
||||
serverChannel.pipeline.fireChannelRead(NIOAny(childChannel1))
|
||||
|
||||
el.run()
|
||||
|
||||
// check that the server and channel are running
|
||||
XCTAssertTrue(serverChannel.isActive)
|
||||
XCTAssertTrue(childChannel1.isActive)
|
||||
|
||||
// let's shut down
|
||||
quiesce.initiateShutdown(promise: allShutdownPromise)
|
||||
|
||||
// let's add one more channel
|
||||
let waitForPromise2: EventLoopPromise<Void> = el.makePromise()
|
||||
let childChannel2 = EmbeddedChannel(handler: WaitForQuiesceUserEvent(promise: waitForPromise2), loop: el)
|
||||
// activate the child channel
|
||||
XCTAssertNoThrow(try childChannel2.connect(to: .init(ipAddress: "1.2.3.4", port: 2)).wait())
|
||||
serverChannel.pipeline.fireChannelRead(NIOAny(childChannel2))
|
||||
el.run()
|
||||
|
||||
// Check that we got all quiescing events
|
||||
XCTAssertNoThrow(try waitForPromise1.futureResult.wait() as Void)
|
||||
XCTAssertNoThrow(try waitForPromise2.futureResult.wait() as Void)
|
||||
|
||||
// check that the server is closed and the children are running
|
||||
XCTAssertFalse(serverChannel.isActive)
|
||||
XCTAssertTrue(childChannel1.isActive)
|
||||
XCTAssertTrue(childChannel2.isActive)
|
||||
|
||||
// let's close the children
|
||||
childChannel1.close(promise: nil)
|
||||
childChannel2.close(promise: nil)
|
||||
el.run()
|
||||
|
||||
// check that everything is closed
|
||||
XCTAssertFalse(serverChannel.isActive)
|
||||
XCTAssertFalse(childChannel1.isActive)
|
||||
XCTAssertFalse(childChannel2.isActive)
|
||||
|
||||
XCTAssertNoThrow(try allShutdownPromise.futureResult.wait() as Void)
|
||||
}
|
||||
|
||||
func testChannelAdded_whenShutdown() throws {
|
||||
let el = EmbeddedEventLoop()
|
||||
let allShutdownPromise: EventLoopPromise<Void> = el.makePromise()
|
||||
let serverChannel = EmbeddedChannel(handler: nil, loop: el)
|
||||
// let's activate the server channel, nothing actually happens as this is an EmbeddedChannel
|
||||
XCTAssertNoThrow(try serverChannel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 1)).wait())
|
||||
let quiesce = ServerQuiescingHelper(group: el)
|
||||
let collectionHandler = quiesce.makeServerChannelHandler(channel: serverChannel)
|
||||
XCTAssertNoThrow(try serverChannel.pipeline.addHandler(collectionHandler).wait())
|
||||
|
||||
// check that the server is running
|
||||
XCTAssertTrue(serverChannel.isActive)
|
||||
|
||||
// let's shut down
|
||||
quiesce.initiateShutdown(promise: allShutdownPromise)
|
||||
|
||||
// check that the shutdown has completed
|
||||
XCTAssertNoThrow(try allShutdownPromise.futureResult.wait())
|
||||
|
||||
// let's add one channel
|
||||
let childChannel1 = EmbeddedChannel(loop: el)
|
||||
// activate the child channel
|
||||
XCTAssertNoThrow(try childChannel1.connect(to: .init(ipAddress: "1.2.3.4", port: 1)).wait())
|
||||
serverChannel.pipeline.fireChannelRead(NIOAny(childChannel1))
|
||||
|
||||
el.run()
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user