Refactor QuiescingHelper to exhaustively iterate state (#193)

# Motivation
Currently the `QuiescingHelper` is crashing on a precondition if you call shutdown when it already was shutdown. However, that can totally happen and we should support it.

# Modification
Refactor the `QuiescingHelper` to exhaustively switch over its state in every method. Furthermore, I added a few more test cases to test realistic scenarios.

# Result
We are now reliable checking our state and making sure to allow most transitions.
This commit is contained in:
Franz Busch 2023-02-24 17:24:37 +00:00 committed by GitHub
parent 6bd9bf5c29
commit d75ed708d0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 264 additions and 71 deletions

View File

@ -23,23 +23,25 @@ private enum ShutdownError: Error {
/// `channelAdded` method in the same event loop tick as the `Channel` is actually created.
private final class ChannelCollector {
enum LifecycleState {
case upAndRunning
case shuttingDown
case upAndRunning(
openChannels: [ObjectIdentifier: Channel],
serverChannel: Channel
)
case shuttingDown(
openChannels: [ObjectIdentifier: Channel],
fullyShutdownPromise: EventLoopPromise<Void>
)
case shutdownCompleted
}
private var openChannels: [ObjectIdentifier: Channel] = [:]
private let serverChannel: Channel
private var fullyShutdownPromise: EventLoopPromise<Void>? = nil
private var lifecycleState = LifecycleState.upAndRunning
private var lifecycleState: LifecycleState
private var eventLoop: EventLoop {
return self.serverChannel.eventLoop
}
private let eventLoop: EventLoop
/// Initializes a `ChannelCollector` for `Channel`s accepted by `serverChannel`.
init(serverChannel: Channel) {
self.serverChannel = serverChannel
self.eventLoop = serverChannel.eventLoop
self.lifecycleState = .upAndRunning(openChannels: [:], serverChannel: serverChannel)
}
/// Add a channel to the `ChannelCollector`.
@ -51,30 +53,64 @@ private final class ChannelCollector {
func channelAdded(_ channel: Channel) throws {
self.eventLoop.assertInEventLoop()
guard self.lifecycleState != .shutdownCompleted else {
switch self.lifecycleState {
case .upAndRunning(var openChannels, let serverChannel):
openChannels[ObjectIdentifier(channel)] = channel
self.lifecycleState = .upAndRunning(openChannels: openChannels, serverChannel: serverChannel)
case .shuttingDown(var openChannels, let fullyShutdownPromise):
openChannels[ObjectIdentifier(channel)] = channel
channel.eventLoop.execute {
channel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent())
}
self.lifecycleState = .shuttingDown(openChannels: openChannels, fullyShutdownPromise: fullyShutdownPromise)
case .shutdownCompleted:
channel.close(promise: nil)
throw ShutdownError.alreadyShutdown
}
self.openChannels[ObjectIdentifier(channel)] = channel
}
private func shutdownCompleted() {
self.eventLoop.assertInEventLoop()
assert(self.lifecycleState == .shuttingDown)
self.lifecycleState = .shutdownCompleted
self.fullyShutdownPromise?.succeed(())
switch self.lifecycleState {
case .upAndRunning:
preconditionFailure("This can never happen because we transition to shuttingDown first")
case .shuttingDown(_, let fullyShutdownPromise):
self.lifecycleState = .shutdownCompleted
fullyShutdownPromise.succeed(())
case .shutdownCompleted:
preconditionFailure("We should only complete the shutdown once")
}
}
private func channelRemoved0(_ channel: Channel) {
self.eventLoop.assertInEventLoop()
precondition(self.openChannels.keys.contains(ObjectIdentifier(channel)),
"channel \(channel) not in ChannelCollector \(self.openChannels)")
self.openChannels.removeValue(forKey: ObjectIdentifier(channel))
if self.lifecycleState != .upAndRunning && self.openChannels.isEmpty {
shutdownCompleted()
switch self.lifecycleState {
case .upAndRunning(var openChannels, let serverChannel):
let removedChannel = openChannels.removeValue(forKey: ObjectIdentifier(channel))
precondition(removedChannel != nil, "channel \(channel) not in ChannelCollector \(openChannels)")
self.lifecycleState = .upAndRunning(openChannels: openChannels, serverChannel: serverChannel)
case .shuttingDown(var openChannels, let fullyShutdownPromise):
let removedChannel = openChannels.removeValue(forKey: ObjectIdentifier(channel))
precondition(removedChannel != nil, "channel \(channel) not in ChannelCollector \(openChannels)")
if openChannels.isEmpty {
self.shutdownCompleted()
} else {
self.lifecycleState = .shuttingDown(openChannels: openChannels, fullyShutdownPromise: fullyShutdownPromise)
}
case .shutdownCompleted:
preconditionFailure("We should not have channels removed after transitioned to completed")
}
}
@ -96,44 +132,39 @@ private final class ChannelCollector {
private func initiateShutdown0(promise: EventLoopPromise<Void>?) {
self.eventLoop.assertInEventLoop()
precondition(self.lifecycleState == .upAndRunning)
self.lifecycleState = .shuttingDown
switch self.lifecycleState {
case .upAndRunning(let openChannels, let serverChannel):
let fullyShutdownPromise = promise ?? serverChannel.eventLoop.makePromise(of: Void.self)
if let promise = promise {
if let alreadyExistingPromise = self.fullyShutdownPromise {
alreadyExistingPromise.futureResult.cascade(to: promise)
} else {
self.fullyShutdownPromise = promise
self.lifecycleState = .shuttingDown(openChannels: openChannels, fullyShutdownPromise: fullyShutdownPromise)
serverChannel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent())
serverChannel.close().cascadeFailure(to: fullyShutdownPromise)
for channel in openChannels.values {
channel.eventLoop.execute {
channel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent())
}
}
}
self.serverChannel.close().cascadeFailure(to: self.fullyShutdownPromise)
for channel in self.openChannels.values {
channel.eventLoop.execute {
channel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent())
if openChannels.isEmpty {
self.shutdownCompleted()
}
}
if self.openChannels.isEmpty {
shutdownCompleted()
case .shuttingDown(_, let fullyShutdownPromise):
fullyShutdownPromise.futureResult.cascade(to: promise)
case .shutdownCompleted:
promise?.succeed(())
}
}
/// Initiate the shutdown fulfilling `promise` when all the previously registered `Channel`s have been closed.
///
/// - parameters:
/// - promise: The `EventLoopPromise` to fulfill when the shutdown of all previously registered `Channel`s has been completed.
/// - promise: The `EventLoopPromise` to fulfil when the shutdown of all previously registered `Channel`s has been completed.
func initiateShutdown(promise: EventLoopPromise<Void>?) {
if self.serverChannel.eventLoop.inEventLoop {
self.serverChannel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent())
} else {
self.eventLoop.execute {
self.serverChannel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent())
}
}
if self.eventLoop.inEventLoop {
self.initiateShutdown0(promise: promise)
} else {
@ -144,7 +175,6 @@ private final class ChannelCollector {
}
}
extension ChannelCollector: @unchecked Sendable {}
/// A `ChannelHandler` that adds all channels that it receives through the `ChannelPipeline` to a `ChannelCollector`.
@ -173,7 +203,7 @@ private final class CollectAcceptedChannelsHandler: ChannelInboundHandler {
do {
try self.channelCollector.channelAdded(channel)
let closeFuture = channel.closeFuture
closeFuture.whenComplete { (_: Result<(), Error>) in
closeFuture.whenComplete { (_: Result<Void, Error>) in
self.channelCollector.channelRemoved(channel)
}
context.fireChannelRead(data)
@ -231,7 +261,7 @@ public final class ServerQuiescingHelper {
deinit {
self.channelCollectorPromise.fail(UnusedQuiescingHelperError())
}
/// Create the `ChannelHandler` for the server `channel` to collect all accepted child `Channel`s.
///
/// - parameters:
@ -262,6 +292,4 @@ public final class ServerQuiescingHelper {
}
}
extension ServerQuiescingHelper: Sendable {
}
extension ServerQuiescingHelper: Sendable {}

View File

@ -31,6 +31,11 @@ extension QuiescingHelperTest {
("testQuiesceUserEventReceivedOnShutdown", testQuiesceUserEventReceivedOnShutdown),
("testQuiescingDoesNotSwallowCloseErrorsFromAcceptHandler", testQuiescingDoesNotSwallowCloseErrorsFromAcceptHandler),
("testShutdownIsImmediateWhenPromiseDoesNotSucceed", testShutdownIsImmediateWhenPromiseDoesNotSucceed),
("testShutdown_whenAlreadyShutdown", testShutdown_whenAlreadyShutdown),
("testShutdown_whenNoOpenChild", testShutdown_whenNoOpenChild),
("testChannelClose_whenRunning", testChannelClose_whenRunning),
("testChannelAdded_whenShuttingDown", testChannelAdded_whenShuttingDown),
("testChannelAdded_whenShutdown", testChannelAdded_whenShutdown),
]
}
}

View File

@ -12,12 +12,27 @@
//
//===----------------------------------------------------------------------===//
import XCTest
import NIOCore
import NIOEmbedded
@testable import NIOExtras
import NIOPosix
import NIOTestUtils
@testable import NIOExtras
import XCTest
private final class WaitForQuiesceUserEvent: ChannelInboundHandler {
typealias InboundIn = Never
private let promise: EventLoopPromise<Void>
init(promise: EventLoopPromise<Void>) {
self.promise = promise
}
func userInboundEventTriggered(context _: ChannelHandlerContext, event: Any) {
if event is ChannelShouldQuiesceEvent {
self.promise.succeed(())
}
}
}
public class QuiescingHelperTest: XCTestCase {
func testShutdownIsImmediateWhenNoChannelsCollected() throws {
@ -35,21 +50,6 @@ public class QuiescingHelperTest: XCTestCase {
}
func testQuiesceUserEventReceivedOnShutdown() throws {
class WaitForQuiesceUserEvent: ChannelInboundHandler {
typealias InboundIn = Never
private let promise: EventLoopPromise<Void>
init(promise: EventLoopPromise<Void>) {
self.promise = promise
}
func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) {
if event is ChannelShouldQuiesceEvent {
self.promise.succeed(())
}
}
}
let el = EmbeddedEventLoop()
let allShutdownPromise: EventLoopPromise<Void> = el.makePromise()
let serverChannel = EmbeddedChannel(handler: nil, loop: el)
@ -63,7 +63,7 @@ public class QuiescingHelperTest: XCTestCase {
// add a bunch of channels
for pretendPort in 1...128 {
let waitForPromise: EventLoopPromise<()> = el.makePromise()
let waitForPromise: EventLoopPromise<Void> = el.makePromise()
let channel = EmbeddedChannel(handler: WaitForQuiesceUserEvent(promise: waitForPromise), loop: el)
// activate the child chan
XCTAssertNoThrow(try channel.connect(to: .init(ipAddress: "1.2.3.4", port: pretendPort)).wait())
@ -137,7 +137,7 @@ public class QuiescingHelperTest: XCTestCase {
}
}
///verifying that the promise fails when goes out of scope for shutdown
/// verifying that the promise fails when goes out of scope for shutdown
func testShutdownIsImmediateWhenPromiseDoesNotSucceed() throws {
let el = EmbeddedEventLoop()
@ -151,4 +151,164 @@ public class QuiescingHelperTest: XCTestCase {
XCTAssertTrue(error is ServerQuiescingHelper.UnusedQuiescingHelperError)
}
}
func testShutdown_whenAlreadyShutdown() throws {
let el = EmbeddedEventLoop()
let channel = EmbeddedChannel(handler: nil, loop: el)
// let's activate the server channel, nothing actually happens as this is an EmbeddedChannel
XCTAssertNoThrow(try channel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 1)).wait())
XCTAssertTrue(channel.isActive)
let quiesce = ServerQuiescingHelper(group: el)
_ = quiesce.makeServerChannelHandler(channel: channel)
let p1: EventLoopPromise<Void> = el.makePromise()
quiesce.initiateShutdown(promise: p1)
XCTAssertNoThrow(try p1.futureResult.wait())
XCTAssertFalse(channel.isActive)
let p2: EventLoopPromise<Void> = el.makePromise()
quiesce.initiateShutdown(promise: p2)
XCTAssertNoThrow(try p2.futureResult.wait())
}
func testShutdown_whenNoOpenChild() throws {
let el = EmbeddedEventLoop()
let channel = EmbeddedChannel(handler: nil, loop: el)
// let's activate the server channel, nothing actually happens as this is an EmbeddedChannel
XCTAssertNoThrow(try channel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 1)).wait())
XCTAssertTrue(channel.isActive)
let quiesce = ServerQuiescingHelper(group: el)
_ = quiesce.makeServerChannelHandler(channel: channel)
let p1: EventLoopPromise<Void> = el.makePromise()
quiesce.initiateShutdown(promise: p1)
el.run()
XCTAssertNoThrow(try p1.futureResult.wait())
XCTAssertFalse(channel.isActive)
}
func testChannelClose_whenRunning() throws {
let el = EmbeddedEventLoop()
let allShutdownPromise: EventLoopPromise<Void> = el.makePromise()
let serverChannel = EmbeddedChannel(handler: nil, loop: el)
// let's activate the server channel, nothing actually happens as this is an EmbeddedChannel
XCTAssertNoThrow(try serverChannel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 1)).wait())
let quiesce = ServerQuiescingHelper(group: el)
let collectionHandler = quiesce.makeServerChannelHandler(channel: serverChannel)
XCTAssertNoThrow(try serverChannel.pipeline.addHandler(collectionHandler).wait())
// let's one channels
let eventCounterHandler = EventCounterHandler()
let childChannel1 = EmbeddedChannel(handler: eventCounterHandler, loop: el)
// activate the child channel
XCTAssertNoThrow(try childChannel1.connect(to: .init(ipAddress: "1.2.3.4", port: 1)).wait())
serverChannel.pipeline.fireChannelRead(NIOAny(childChannel1))
// check that the server channel and channel are active before initiating the shutdown
XCTAssertTrue(serverChannel.isActive)
XCTAssertTrue(childChannel1.isActive)
XCTAssertEqual(eventCounterHandler.userInboundEventTriggeredCalls, 0)
// now close the first child channel
childChannel1.close(promise: nil)
el.run()
// check that the server is active and child is not
XCTAssertTrue(serverChannel.isActive)
XCTAssertFalse(childChannel1.isActive)
quiesce.initiateShutdown(promise: allShutdownPromise)
el.run()
// check that the server channel is closed as the first thing
XCTAssertFalse(serverChannel.isActive)
el.run()
// check that the shutdown has completed
XCTAssertNoThrow(try allShutdownPromise.futureResult.wait())
}
func testChannelAdded_whenShuttingDown() throws {
let el = EmbeddedEventLoop()
let allShutdownPromise: EventLoopPromise<Void> = el.makePromise()
let serverChannel = EmbeddedChannel(handler: nil, loop: el)
// let's activate the server channel, nothing actually happens as this is an EmbeddedChannel
XCTAssertNoThrow(try serverChannel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 1)).wait())
let quiesce = ServerQuiescingHelper(group: el)
let collectionHandler = quiesce.makeServerChannelHandler(channel: serverChannel)
XCTAssertNoThrow(try serverChannel.pipeline.addHandler(collectionHandler).wait())
// let's add one channel
let waitForPromise1: EventLoopPromise<Void> = el.makePromise()
let childChannel1 = EmbeddedChannel(handler: WaitForQuiesceUserEvent(promise: waitForPromise1), loop: el)
// activate the child channel
XCTAssertNoThrow(try childChannel1.connect(to: .init(ipAddress: "1.2.3.4", port: 1)).wait())
serverChannel.pipeline.fireChannelRead(NIOAny(childChannel1))
el.run()
// check that the server and channel are running
XCTAssertTrue(serverChannel.isActive)
XCTAssertTrue(childChannel1.isActive)
// let's shut down
quiesce.initiateShutdown(promise: allShutdownPromise)
// let's add one more channel
let waitForPromise2: EventLoopPromise<Void> = el.makePromise()
let childChannel2 = EmbeddedChannel(handler: WaitForQuiesceUserEvent(promise: waitForPromise2), loop: el)
// activate the child channel
XCTAssertNoThrow(try childChannel2.connect(to: .init(ipAddress: "1.2.3.4", port: 2)).wait())
serverChannel.pipeline.fireChannelRead(NIOAny(childChannel2))
el.run()
// Check that we got all quiescing events
XCTAssertNoThrow(try waitForPromise1.futureResult.wait() as Void)
XCTAssertNoThrow(try waitForPromise2.futureResult.wait() as Void)
// check that the server is closed and the children are running
XCTAssertFalse(serverChannel.isActive)
XCTAssertTrue(childChannel1.isActive)
XCTAssertTrue(childChannel2.isActive)
// let's close the children
childChannel1.close(promise: nil)
childChannel2.close(promise: nil)
el.run()
// check that everything is closed
XCTAssertFalse(serverChannel.isActive)
XCTAssertFalse(childChannel1.isActive)
XCTAssertFalse(childChannel2.isActive)
XCTAssertNoThrow(try allShutdownPromise.futureResult.wait() as Void)
}
func testChannelAdded_whenShutdown() throws {
let el = EmbeddedEventLoop()
let allShutdownPromise: EventLoopPromise<Void> = el.makePromise()
let serverChannel = EmbeddedChannel(handler: nil, loop: el)
// let's activate the server channel, nothing actually happens as this is an EmbeddedChannel
XCTAssertNoThrow(try serverChannel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 1)).wait())
let quiesce = ServerQuiescingHelper(group: el)
let collectionHandler = quiesce.makeServerChannelHandler(channel: serverChannel)
XCTAssertNoThrow(try serverChannel.pipeline.addHandler(collectionHandler).wait())
// check that the server is running
XCTAssertTrue(serverChannel.isActive)
// let's shut down
quiesce.initiateShutdown(promise: allShutdownPromise)
// check that the shutdown has completed
XCTAssertNoThrow(try allShutdownPromise.futureResult.wait())
// let's add one channel
let childChannel1 = EmbeddedChannel(loop: el)
// activate the child channel
XCTAssertNoThrow(try childChannel1.connect(to: .init(ipAddress: "1.2.3.4", port: 1)).wait())
serverChannel.pipeline.fireChannelRead(NIOAny(childChannel1))
el.run()
}
}