diff --git a/Sources/NIOSOCKS/Channel Handlers/SOCKSServerHandshakeHandler.swift b/Sources/NIOSOCKS/Channel Handlers/SOCKSServerHandshakeHandler.swift index 9545cad..a12bf94 100644 --- a/Sources/NIOSOCKS/Channel Handlers/SOCKSServerHandshakeHandler.swift +++ b/Sources/NIOSOCKS/Channel Handlers/SOCKSServerHandshakeHandler.swift @@ -19,7 +19,7 @@ import NIO /// and parser to enforce SOCKSv5 protocol correctness. Inbound bytes will by parsed into /// `ClientMessage` for downstream consumption. Send `ServerMessage` to this /// handler. -public final class SOCKSServerHandshakeHandler: ChannelDuplexHandler { +public final class SOCKSServerHandshakeHandler: ChannelDuplexHandler, RemovableChannelHandler { public typealias InboundIn = ByteBuffer public typealias InboundOut = ClientMessage @@ -61,6 +61,13 @@ public final class SOCKSServerHandshakeHandler: ChannelDuplexHandler { } } + public func handlerRemoved(context: ChannelHandlerContext) { + guard let buffer = self.inboundBuffer else { + return + } + context.fireChannelRead(.init(buffer)) + } + public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { do { let message = self.unwrapOutboundIn(data) diff --git a/Tests/NIOSOCKSTests/SOCKSServerHandshakeHandler+Tests.swift b/Tests/NIOSOCKSTests/SOCKSServerHandshakeHandler+Tests.swift index 1a7f702..35ca154 100644 --- a/Tests/NIOSOCKSTests/SOCKSServerHandshakeHandler+Tests.swift +++ b/Tests/NIOSOCKSTests/SOCKSServerHandshakeHandler+Tests.swift @@ -91,8 +91,11 @@ class SOCKSServerHandlerTests: XCTestCase { } func assertInbound(_ bytes: [UInt8], line: UInt = #line) { - var buffer = try! self.channel.readInbound(as: ByteBuffer.self) - XCTAssertEqual(buffer!.readBytes(length: buffer!.readableBytes), bytes, line: line) + if var buffer = try! self.channel.readInbound(as: ByteBuffer.self) { + XCTAssertEqual(buffer.readBytes(length: buffer.readableBytes), bytes, line: line) + } else { + XCTAssertTrue(bytes.count == 0) + } } func testTypicalWorkflow() { @@ -198,4 +201,11 @@ class SOCKSServerHandlerTests: XCTestCase { XCTAssertTrue(e is SOCKSError.InvalidServerState) } } + + func testFlushOnHandlerRemoved() { + self.writeInbound([0x05, 0x01]) + self.assertInbound([]) + XCTAssertNoThrow(try self.channel.pipeline.removeHandler(self.handler).wait()) + self.assertInbound([0x05, 0x01]) + } }