Ensuring Handler removal works

This commit is contained in:
Fabian Fett 2021-06-22 14:43:04 +02:00
parent b9efbabd28
commit 7fda0c4981
3 changed files with 122 additions and 13 deletions

View File

@ -28,6 +28,7 @@ public final class SOCKSClientHandler: ChannelDuplexHandler {
private let targetAddress: SOCKSAddress
private var state: ClientStateMachine
private var removalToken: ChannelHandlerContext.RemovalToken?
private var inboundBuffer: ByteBuffer?
private var bufferedWrites: MarkedCircularBuffer<(NIOAny, EventLoopPromise<Void>?)> = .init(initialCapacity: 8)
@ -95,6 +96,11 @@ public final class SOCKSClientHandler: ChannelDuplexHandler {
context.write(data, promise: promise)
}
context.flush() // safe to flush otherwise we wouldn't have the mark
while !self.bufferedWrites.isEmpty {
let (data, promise) = self.bufferedWrites.removeFirst()
context.write(data, promise: promise)
}
}
public func flush(context: ChannelHandlerContext) {
@ -140,18 +146,13 @@ extension SOCKSClientHandler {
}
private func handleProxyEstablished(context: ChannelHandlerContext) {
// for some reason we have extra bytes
// so let's send them down the pipe
// (Safe to bang, self.buffered will always exist at this point)
if self.inboundBuffer!.readableBytes > 0 {
let data = self.wrapInboundOut(self.inboundBuffer!)
context.fireChannelRead(data)
}
// If we have any buffered writes then now
// we can send them.
self.writeBufferedData(context: context)
context.fireUserInboundEventTriggered(SOCKSProxyEstablishedEvent())
self.emptyInboundAndOutboundBuffer(context: context)
if let removalToken = self.removalToken {
context.leavePipeline(removalToken: removalToken)
}
}
private func handleActionSendRequest(context: ChannelHandlerContext) throws {
@ -166,14 +167,33 @@ extension SOCKSClientHandler {
context.writeAndFlush(self.wrapOutboundOut(buffer), promise: nil)
}
private func emptyInboundAndOutboundBuffer(context: ChannelHandlerContext) {
if let inboundBuffer = self.inboundBuffer, inboundBuffer.readableBytes > 0 {
// after the SOCKS handshake message we already received further bytes.
// so let's send them down the pipe
self.inboundBuffer = nil
context.fireChannelRead(self.wrapInboundOut(inboundBuffer))
}
// If we have any buffered writes, we must send them before we are removed from the pipeline
self.writeBufferedData(context: context)
}
}
extension SOCKSClientHandler: RemovableChannelHandler {
public func removeHandler(context: ChannelHandlerContext, removalToken: ChannelHandlerContext.RemovalToken) {
guard self.state.proxyEstablished else {
preconditionFailure("The SOCKSClientHandler can only be removed once a connection has been established")
self.removalToken = removalToken
return
}
// We must clear the buffers here before we are removed, since the
// handler removal may be triggered as a side effect of the
// `SOCKSProxyEstablishedEvent`. In this case we may end up here,
// before the buffer empty method in `handleProxyEstablished` is
// invoked.
self.emptyInboundAndOutboundBuffer(context: context)
context.leavePipeline(removalToken: removalToken)
}

View File

@ -34,6 +34,8 @@ extension SocksClientHandlerTests {
("testProxyConnectionFailed", testProxyConnectionFailed),
("testDelayedConnection", testDelayedConnection),
("testDelayedHandlerAdded", testDelayedHandlerAdded),
("testHandlerRemovalAfterEstablishEvent", testHandlerRemovalAfterEstablishEvent),
("testHandlerRemovalBeforeConnectionIsEstablished", testHandlerRemovalBeforeConnectionIsEstablished),
]
}
}

View File

@ -236,7 +236,94 @@ class SocksClientHandlerTests: XCTestCase {
XCTAssertNoThrow(self.channel.pipeline.addHandler(handler))
self.assertOutputBuffer([0x05, 0x01, 0x00])
}
func testHandlerRemovalAfterEstablishEvent() {
class SOCKSEventHandler: ChannelInboundHandler {
typealias InboundIn = NIOAny
var establishedPromise: EventLoopPromise<Void>
init(establishedPromise: EventLoopPromise<Void>) {
self.establishedPromise = establishedPromise
}
func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) {
switch event {
case is SOCKSProxyEstablishedEvent:
self.establishedPromise.succeed(())
default:
break
}
context.fireUserInboundEventTriggered(event)
}
}
let establishPromise = self.channel.eventLoop.makePromise(of: Void.self)
let removalPromise = self.channel.eventLoop.makePromise(of: Void.self)
establishPromise.futureResult.whenSuccess { _ in
self.channel.pipeline.removeHandler(self.handler).cascade(to: removalPromise)
}
try! self.channel.pipeline.addHandler(SOCKSEventHandler(establishedPromise: establishPromise)).wait()
self.connect()
self.channel.write(ByteBuffer(bytes: [1, 2, 3]), promise: nil)
self.channel.flush()
self.channel.write(ByteBuffer(bytes: [4, 5, 6]), promise: nil)
self.assertOutputBuffer([0x05, 0x01, 0x00])
self.writeInbound([0x05, 0x00])
self.assertOutputBuffer([0x05, 0x01, 0x00, 0x01, 192, 168, 1, 1, 0x00, 0x50])
self.writeInbound([0x05, 0x00, 0x00, 0x01, 192, 168, 1, 1, 0x00, 0x50])
self.assertOutputBuffer([1, 2, 3])
XCTAssertNoThrow(try self.channel.writeAndFlush(ByteBuffer(bytes: [7, 8, 9])).wait())
self.assertOutputBuffer([4, 5, 6])
self.assertOutputBuffer([7, 8, 9])
XCTAssertNoThrow(try removalPromise.futureResult.wait())
XCTAssertThrowsError(try self.channel.pipeline.syncOperations.handler(type: SOCKSClientHandler.self)) {
XCTAssertEqual($0 as? ChannelPipelineError, .notFound)
}
}
func testHandlerRemovalBeforeConnectionIsEstablished() {
self.connect()
self.channel.write(ByteBuffer(bytes: [1, 2, 3]), promise: nil)
self.channel.flush()
self.channel.write(ByteBuffer(bytes: [4, 5, 6]), promise: nil)
self.assertOutputBuffer([0x05, 0x01, 0x00])
self.writeInbound([0x05, 0x00])
self.assertOutputBuffer([0x05, 0x01, 0x00, 0x01, 192, 168, 1, 1, 0x00, 0x50])
// we try to remove the handler before the connection is established.
let removalPromise = self.channel.eventLoop.makePromise(of: Void.self)
self.channel.pipeline.removeHandler(self.handler, promise: removalPromise)
// establishes the connection
self.writeInbound([0x05, 0x00, 0x00, 0x01, 192, 168, 1, 1, 0x00, 0x50])
// write six more bytes - those should be passed through right away
self.writeInbound([1, 2, 3, 4, 5, 6])
self.assertInbound([1, 2, 3, 4, 5, 6])
self.assertOutputBuffer([1, 2, 3])
XCTAssertNoThrow(try self.channel.writeAndFlush(ByteBuffer(bytes: [7, 8, 9])).wait())
self.assertOutputBuffer([4, 5, 6])
self.assertOutputBuffer([7, 8, 9])
XCTAssertNoThrow(try removalPromise.futureResult.wait())
XCTAssertThrowsError(try self.channel.pipeline.syncOperations.handler(type: SOCKSClientHandler.self)) {
XCTAssertEqual($0 as? ChannelPipelineError, .notFound)
}
}
}
class MockSOCKSClientHandler: ChannelInboundHandler {