mirror of
https://github.com/apple/swift-nio-extras.git
synced 2025-05-14 17:02:43 +08:00
Ensuring Handler removal works
This commit is contained in:
parent
b9efbabd28
commit
7fda0c4981
@ -28,6 +28,7 @@ public final class SOCKSClientHandler: ChannelDuplexHandler {
|
||||
private let targetAddress: SOCKSAddress
|
||||
|
||||
private var state: ClientStateMachine
|
||||
private var removalToken: ChannelHandlerContext.RemovalToken?
|
||||
private var inboundBuffer: ByteBuffer?
|
||||
|
||||
private var bufferedWrites: MarkedCircularBuffer<(NIOAny, EventLoopPromise<Void>?)> = .init(initialCapacity: 8)
|
||||
@ -95,6 +96,11 @@ public final class SOCKSClientHandler: ChannelDuplexHandler {
|
||||
context.write(data, promise: promise)
|
||||
}
|
||||
context.flush() // safe to flush otherwise we wouldn't have the mark
|
||||
|
||||
while !self.bufferedWrites.isEmpty {
|
||||
let (data, promise) = self.bufferedWrites.removeFirst()
|
||||
context.write(data, promise: promise)
|
||||
}
|
||||
}
|
||||
|
||||
public func flush(context: ChannelHandlerContext) {
|
||||
@ -140,18 +146,13 @@ extension SOCKSClientHandler {
|
||||
}
|
||||
|
||||
private func handleProxyEstablished(context: ChannelHandlerContext) {
|
||||
// for some reason we have extra bytes
|
||||
// so let's send them down the pipe
|
||||
// (Safe to bang, self.buffered will always exist at this point)
|
||||
if self.inboundBuffer!.readableBytes > 0 {
|
||||
let data = self.wrapInboundOut(self.inboundBuffer!)
|
||||
context.fireChannelRead(data)
|
||||
}
|
||||
|
||||
// If we have any buffered writes then now
|
||||
// we can send them.
|
||||
self.writeBufferedData(context: context)
|
||||
context.fireUserInboundEventTriggered(SOCKSProxyEstablishedEvent())
|
||||
|
||||
self.emptyInboundAndOutboundBuffer(context: context)
|
||||
|
||||
if let removalToken = self.removalToken {
|
||||
context.leavePipeline(removalToken: removalToken)
|
||||
}
|
||||
}
|
||||
|
||||
private func handleActionSendRequest(context: ChannelHandlerContext) throws {
|
||||
@ -166,14 +167,33 @@ extension SOCKSClientHandler {
|
||||
context.writeAndFlush(self.wrapOutboundOut(buffer), promise: nil)
|
||||
}
|
||||
|
||||
private func emptyInboundAndOutboundBuffer(context: ChannelHandlerContext) {
|
||||
if let inboundBuffer = self.inboundBuffer, inboundBuffer.readableBytes > 0 {
|
||||
// after the SOCKS handshake message we already received further bytes.
|
||||
// so let's send them down the pipe
|
||||
self.inboundBuffer = nil
|
||||
context.fireChannelRead(self.wrapInboundOut(inboundBuffer))
|
||||
}
|
||||
|
||||
// If we have any buffered writes, we must send them before we are removed from the pipeline
|
||||
self.writeBufferedData(context: context)
|
||||
}
|
||||
}
|
||||
|
||||
extension SOCKSClientHandler: RemovableChannelHandler {
|
||||
|
||||
public func removeHandler(context: ChannelHandlerContext, removalToken: ChannelHandlerContext.RemovalToken) {
|
||||
guard self.state.proxyEstablished else {
|
||||
preconditionFailure("The SOCKSClientHandler can only be removed once a connection has been established")
|
||||
self.removalToken = removalToken
|
||||
return
|
||||
}
|
||||
|
||||
// We must clear the buffers here before we are removed, since the
|
||||
// handler removal may be triggered as a side effect of the
|
||||
// `SOCKSProxyEstablishedEvent`. In this case we may end up here,
|
||||
// before the buffer empty method in `handleProxyEstablished` is
|
||||
// invoked.
|
||||
self.emptyInboundAndOutboundBuffer(context: context)
|
||||
context.leavePipeline(removalToken: removalToken)
|
||||
}
|
||||
|
||||
|
@ -34,6 +34,8 @@ extension SocksClientHandlerTests {
|
||||
("testProxyConnectionFailed", testProxyConnectionFailed),
|
||||
("testDelayedConnection", testDelayedConnection),
|
||||
("testDelayedHandlerAdded", testDelayedHandlerAdded),
|
||||
("testHandlerRemovalAfterEstablishEvent", testHandlerRemovalAfterEstablishEvent),
|
||||
("testHandlerRemovalBeforeConnectionIsEstablished", testHandlerRemovalBeforeConnectionIsEstablished),
|
||||
]
|
||||
}
|
||||
}
|
||||
|
@ -236,7 +236,94 @@ class SocksClientHandlerTests: XCTestCase {
|
||||
XCTAssertNoThrow(self.channel.pipeline.addHandler(handler))
|
||||
self.assertOutputBuffer([0x05, 0x01, 0x00])
|
||||
}
|
||||
|
||||
|
||||
func testHandlerRemovalAfterEstablishEvent() {
|
||||
class SOCKSEventHandler: ChannelInboundHandler {
|
||||
typealias InboundIn = NIOAny
|
||||
|
||||
var establishedPromise: EventLoopPromise<Void>
|
||||
|
||||
init(establishedPromise: EventLoopPromise<Void>) {
|
||||
self.establishedPromise = establishedPromise
|
||||
}
|
||||
|
||||
func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) {
|
||||
switch event {
|
||||
case is SOCKSProxyEstablishedEvent:
|
||||
self.establishedPromise.succeed(())
|
||||
default:
|
||||
break
|
||||
}
|
||||
context.fireUserInboundEventTriggered(event)
|
||||
}
|
||||
}
|
||||
|
||||
let establishPromise = self.channel.eventLoop.makePromise(of: Void.self)
|
||||
let removalPromise = self.channel.eventLoop.makePromise(of: Void.self)
|
||||
establishPromise.futureResult.whenSuccess { _ in
|
||||
self.channel.pipeline.removeHandler(self.handler).cascade(to: removalPromise)
|
||||
}
|
||||
|
||||
try! self.channel.pipeline.addHandler(SOCKSEventHandler(establishedPromise: establishPromise)).wait()
|
||||
|
||||
self.connect()
|
||||
|
||||
self.channel.write(ByteBuffer(bytes: [1, 2, 3]), promise: nil)
|
||||
self.channel.flush()
|
||||
self.channel.write(ByteBuffer(bytes: [4, 5, 6]), promise: nil)
|
||||
|
||||
self.assertOutputBuffer([0x05, 0x01, 0x00])
|
||||
self.writeInbound([0x05, 0x00])
|
||||
self.assertOutputBuffer([0x05, 0x01, 0x00, 0x01, 192, 168, 1, 1, 0x00, 0x50])
|
||||
self.writeInbound([0x05, 0x00, 0x00, 0x01, 192, 168, 1, 1, 0x00, 0x50])
|
||||
|
||||
self.assertOutputBuffer([1, 2, 3])
|
||||
|
||||
XCTAssertNoThrow(try self.channel.writeAndFlush(ByteBuffer(bytes: [7, 8, 9])).wait())
|
||||
|
||||
self.assertOutputBuffer([4, 5, 6])
|
||||
self.assertOutputBuffer([7, 8, 9])
|
||||
|
||||
XCTAssertNoThrow(try removalPromise.futureResult.wait())
|
||||
XCTAssertThrowsError(try self.channel.pipeline.syncOperations.handler(type: SOCKSClientHandler.self)) {
|
||||
XCTAssertEqual($0 as? ChannelPipelineError, .notFound)
|
||||
}
|
||||
}
|
||||
|
||||
func testHandlerRemovalBeforeConnectionIsEstablished() {
|
||||
self.connect()
|
||||
|
||||
self.channel.write(ByteBuffer(bytes: [1, 2, 3]), promise: nil)
|
||||
self.channel.flush()
|
||||
self.channel.write(ByteBuffer(bytes: [4, 5, 6]), promise: nil)
|
||||
|
||||
self.assertOutputBuffer([0x05, 0x01, 0x00])
|
||||
self.writeInbound([0x05, 0x00])
|
||||
self.assertOutputBuffer([0x05, 0x01, 0x00, 0x01, 192, 168, 1, 1, 0x00, 0x50])
|
||||
|
||||
// we try to remove the handler before the connection is established.
|
||||
let removalPromise = self.channel.eventLoop.makePromise(of: Void.self)
|
||||
self.channel.pipeline.removeHandler(self.handler, promise: removalPromise)
|
||||
|
||||
// establishes the connection
|
||||
self.writeInbound([0x05, 0x00, 0x00, 0x01, 192, 168, 1, 1, 0x00, 0x50])
|
||||
|
||||
// write six more bytes - those should be passed through right away
|
||||
self.writeInbound([1, 2, 3, 4, 5, 6])
|
||||
self.assertInbound([1, 2, 3, 4, 5, 6])
|
||||
|
||||
self.assertOutputBuffer([1, 2, 3])
|
||||
|
||||
XCTAssertNoThrow(try self.channel.writeAndFlush(ByteBuffer(bytes: [7, 8, 9])).wait())
|
||||
|
||||
self.assertOutputBuffer([4, 5, 6])
|
||||
self.assertOutputBuffer([7, 8, 9])
|
||||
|
||||
XCTAssertNoThrow(try removalPromise.futureResult.wait())
|
||||
XCTAssertThrowsError(try self.channel.pipeline.syncOperations.handler(type: SOCKSClientHandler.self)) {
|
||||
XCTAssertEqual($0 as? ChannelPipelineError, .notFound)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class MockSOCKSClientHandler: ChannelInboundHandler {
|
||||
|
Loading…
x
Reference in New Issue
Block a user