diff --git a/Sources/NIOSOCKS/Channel Handlers/SOCKSClientHandler.swift b/Sources/NIOSOCKS/Channel Handlers/SOCKSClientHandler.swift index 8c64cb5..8604b93 100644 --- a/Sources/NIOSOCKS/Channel Handlers/SOCKSClientHandler.swift +++ b/Sources/NIOSOCKS/Channel Handlers/SOCKSClientHandler.swift @@ -28,6 +28,7 @@ public final class SOCKSClientHandler: ChannelDuplexHandler { private let targetAddress: SOCKSAddress private var state: ClientStateMachine + private var removalToken: ChannelHandlerContext.RemovalToken? private var inboundBuffer: ByteBuffer? private var bufferedWrites: MarkedCircularBuffer<(NIOAny, EventLoopPromise?)> = .init(initialCapacity: 8) @@ -95,6 +96,11 @@ public final class SOCKSClientHandler: ChannelDuplexHandler { context.write(data, promise: promise) } context.flush() // safe to flush otherwise we wouldn't have the mark + + while !self.bufferedWrites.isEmpty { + let (data, promise) = self.bufferedWrites.removeFirst() + context.write(data, promise: promise) + } } public func flush(context: ChannelHandlerContext) { @@ -140,18 +146,13 @@ extension SOCKSClientHandler { } private func handleProxyEstablished(context: ChannelHandlerContext) { - // for some reason we have extra bytes - // so let's send them down the pipe - // (Safe to bang, self.buffered will always exist at this point) - if self.inboundBuffer!.readableBytes > 0 { - let data = self.wrapInboundOut(self.inboundBuffer!) - context.fireChannelRead(data) - } - - // If we have any buffered writes then now - // we can send them. - self.writeBufferedData(context: context) context.fireUserInboundEventTriggered(SOCKSProxyEstablishedEvent()) + + self.emptyInboundAndOutboundBuffer(context: context) + + if let removalToken = self.removalToken { + context.leavePipeline(removalToken: removalToken) + } } private func handleActionSendRequest(context: ChannelHandlerContext) throws { @@ -166,14 +167,33 @@ extension SOCKSClientHandler { context.writeAndFlush(self.wrapOutboundOut(buffer), promise: nil) } + private func emptyInboundAndOutboundBuffer(context: ChannelHandlerContext) { + if let inboundBuffer = self.inboundBuffer, inboundBuffer.readableBytes > 0 { + // after the SOCKS handshake message we already received further bytes. + // so let's send them down the pipe + self.inboundBuffer = nil + context.fireChannelRead(self.wrapInboundOut(inboundBuffer)) + } + + // If we have any buffered writes, we must send them before we are removed from the pipeline + self.writeBufferedData(context: context) + } } extension SOCKSClientHandler: RemovableChannelHandler { public func removeHandler(context: ChannelHandlerContext, removalToken: ChannelHandlerContext.RemovalToken) { guard self.state.proxyEstablished else { - preconditionFailure("The SOCKSClientHandler can only be removed once a connection has been established") + self.removalToken = removalToken + return } + + // We must clear the buffers here before we are removed, since the + // handler removal may be triggered as a side effect of the + // `SOCKSProxyEstablishedEvent`. In this case we may end up here, + // before the buffer empty method in `handleProxyEstablished` is + // invoked. + self.emptyInboundAndOutboundBuffer(context: context) context.leavePipeline(removalToken: removalToken) } diff --git a/Tests/NIOSOCKSTests/SocksClientHandler+Tests+XCTest.swift b/Tests/NIOSOCKSTests/SocksClientHandler+Tests+XCTest.swift index 2a74969..50cab22 100644 --- a/Tests/NIOSOCKSTests/SocksClientHandler+Tests+XCTest.swift +++ b/Tests/NIOSOCKSTests/SocksClientHandler+Tests+XCTest.swift @@ -34,6 +34,8 @@ extension SocksClientHandlerTests { ("testProxyConnectionFailed", testProxyConnectionFailed), ("testDelayedConnection", testDelayedConnection), ("testDelayedHandlerAdded", testDelayedHandlerAdded), + ("testHandlerRemovalAfterEstablishEvent", testHandlerRemovalAfterEstablishEvent), + ("testHandlerRemovalBeforeConnectionIsEstablished", testHandlerRemovalBeforeConnectionIsEstablished), ] } } diff --git a/Tests/NIOSOCKSTests/SocksClientHandler+Tests.swift b/Tests/NIOSOCKSTests/SocksClientHandler+Tests.swift index 9c8f063..2f0b4b9 100644 --- a/Tests/NIOSOCKSTests/SocksClientHandler+Tests.swift +++ b/Tests/NIOSOCKSTests/SocksClientHandler+Tests.swift @@ -236,7 +236,94 @@ class SocksClientHandlerTests: XCTestCase { XCTAssertNoThrow(self.channel.pipeline.addHandler(handler)) self.assertOutputBuffer([0x05, 0x01, 0x00]) } - + + func testHandlerRemovalAfterEstablishEvent() { + class SOCKSEventHandler: ChannelInboundHandler { + typealias InboundIn = NIOAny + + var establishedPromise: EventLoopPromise + + init(establishedPromise: EventLoopPromise) { + self.establishedPromise = establishedPromise + } + + func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) { + switch event { + case is SOCKSProxyEstablishedEvent: + self.establishedPromise.succeed(()) + default: + break + } + context.fireUserInboundEventTriggered(event) + } + } + + let establishPromise = self.channel.eventLoop.makePromise(of: Void.self) + let removalPromise = self.channel.eventLoop.makePromise(of: Void.self) + establishPromise.futureResult.whenSuccess { _ in + self.channel.pipeline.removeHandler(self.handler).cascade(to: removalPromise) + } + + try! self.channel.pipeline.addHandler(SOCKSEventHandler(establishedPromise: establishPromise)).wait() + + self.connect() + + self.channel.write(ByteBuffer(bytes: [1, 2, 3]), promise: nil) + self.channel.flush() + self.channel.write(ByteBuffer(bytes: [4, 5, 6]), promise: nil) + + self.assertOutputBuffer([0x05, 0x01, 0x00]) + self.writeInbound([0x05, 0x00]) + self.assertOutputBuffer([0x05, 0x01, 0x00, 0x01, 192, 168, 1, 1, 0x00, 0x50]) + self.writeInbound([0x05, 0x00, 0x00, 0x01, 192, 168, 1, 1, 0x00, 0x50]) + + self.assertOutputBuffer([1, 2, 3]) + + XCTAssertNoThrow(try self.channel.writeAndFlush(ByteBuffer(bytes: [7, 8, 9])).wait()) + + self.assertOutputBuffer([4, 5, 6]) + self.assertOutputBuffer([7, 8, 9]) + + XCTAssertNoThrow(try removalPromise.futureResult.wait()) + XCTAssertThrowsError(try self.channel.pipeline.syncOperations.handler(type: SOCKSClientHandler.self)) { + XCTAssertEqual($0 as? ChannelPipelineError, .notFound) + } + } + + func testHandlerRemovalBeforeConnectionIsEstablished() { + self.connect() + + self.channel.write(ByteBuffer(bytes: [1, 2, 3]), promise: nil) + self.channel.flush() + self.channel.write(ByteBuffer(bytes: [4, 5, 6]), promise: nil) + + self.assertOutputBuffer([0x05, 0x01, 0x00]) + self.writeInbound([0x05, 0x00]) + self.assertOutputBuffer([0x05, 0x01, 0x00, 0x01, 192, 168, 1, 1, 0x00, 0x50]) + + // we try to remove the handler before the connection is established. + let removalPromise = self.channel.eventLoop.makePromise(of: Void.self) + self.channel.pipeline.removeHandler(self.handler, promise: removalPromise) + + // establishes the connection + self.writeInbound([0x05, 0x00, 0x00, 0x01, 192, 168, 1, 1, 0x00, 0x50]) + + // write six more bytes - those should be passed through right away + self.writeInbound([1, 2, 3, 4, 5, 6]) + self.assertInbound([1, 2, 3, 4, 5, 6]) + + self.assertOutputBuffer([1, 2, 3]) + + XCTAssertNoThrow(try self.channel.writeAndFlush(ByteBuffer(bytes: [7, 8, 9])).wait()) + + self.assertOutputBuffer([4, 5, 6]) + self.assertOutputBuffer([7, 8, 9]) + + XCTAssertNoThrow(try removalPromise.futureResult.wait()) + XCTAssertThrowsError(try self.channel.pipeline.syncOperations.handler(type: SOCKSClientHandler.self)) { + XCTAssertEqual($0 as? ChannelPipelineError, .notFound) + } + } } class MockSOCKSClientHandler: ChannelInboundHandler {