mirror of
https://github.com/apple/swift-nio-extras.git
synced 2025-05-14 08:52:42 +08:00
Fix incorrect SOCKS client flushing behaviour (#133)
* Add buffering test * Convert to marked buffer * Soundness * Re-add fastpath * Fix
This commit is contained in:
parent
93fc12bdb7
commit
548e0d4893
@ -28,9 +28,9 @@ public final class SOCKSClientHandler: ChannelDuplexHandler {
|
||||
private let targetAddress: SOCKSAddress
|
||||
|
||||
private var state: ClientStateMachine
|
||||
private var buffered: ByteBuffer?
|
||||
private var inboundBuffer: ByteBuffer?
|
||||
|
||||
private var bufferedWrites: CircularBuffer<(NIOAny, EventLoopPromise<Void>?)> = .init()
|
||||
private var bufferedWrites: MarkedCircularBuffer<(NIOAny, EventLoopPromise<Void>?)> = .init(initialCapacity: 8)
|
||||
|
||||
/// Creates a new `SOCKSClientHandler` that connects to a server
|
||||
/// and instructs the server to connect to `targetAddress`.
|
||||
@ -66,11 +66,11 @@ public final class SOCKSClientHandler: ChannelDuplexHandler {
|
||||
|
||||
var inboundBuffer = self.unwrapInboundIn(data)
|
||||
|
||||
self.buffered.setOrWriteBuffer(&inboundBuffer)
|
||||
self.inboundBuffer.setOrWriteBuffer(&inboundBuffer)
|
||||
do {
|
||||
// Safe to bang, `setOrWrite` above means there will
|
||||
// always be a value.
|
||||
let action = try self.state.receiveBuffer(&self.buffered!)
|
||||
let action = try self.state.receiveBuffer(&self.inboundBuffer!)
|
||||
try self.handleAction(action, context: context)
|
||||
} catch {
|
||||
context.fireErrorCaught(error)
|
||||
@ -79,21 +79,28 @@ public final class SOCKSClientHandler: ChannelDuplexHandler {
|
||||
}
|
||||
|
||||
public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
|
||||
guard self.state.proxyEstablished else {
|
||||
if self.state.proxyEstablished && self.bufferedWrites.count == 0 {
|
||||
context.write(data, promise: promise)
|
||||
} else {
|
||||
self.bufferedWrites.append((data, promise))
|
||||
return
|
||||
}
|
||||
self.writeBufferedData(context: context)
|
||||
context.write(data, promise: promise)
|
||||
}
|
||||
|
||||
private func writeBufferedData(context: ChannelHandlerContext) {
|
||||
while self.bufferedWrites.count > 0 {
|
||||
guard self.state.proxyEstablished else {
|
||||
return
|
||||
}
|
||||
while self.bufferedWrites.hasMark {
|
||||
let (data, promise) = self.bufferedWrites.removeFirst()
|
||||
context.write(data, promise: promise)
|
||||
}
|
||||
context.flush() // safe to flush otherwise we wouldn't have the mark
|
||||
}
|
||||
|
||||
public func flush(context: ChannelHandlerContext) {
|
||||
self.bufferedWrites.mark()
|
||||
self.writeBufferedData(context: context)
|
||||
}
|
||||
}
|
||||
|
||||
extension SOCKSClientHandler {
|
||||
@ -136,8 +143,8 @@ extension SOCKSClientHandler {
|
||||
// for some reason we have extra bytes
|
||||
// so let's send them down the pipe
|
||||
// (Safe to bang, self.buffered will always exist at this point)
|
||||
if self.buffered!.readableBytes > 0 {
|
||||
let data = self.wrapInboundOut(self.buffered!)
|
||||
if self.inboundBuffer!.readableBytes > 0 {
|
||||
let data = self.wrapInboundOut(self.inboundBuffer!)
|
||||
context.fireChannelRead(data)
|
||||
}
|
||||
|
||||
|
@ -27,6 +27,8 @@ extension SocksClientHandlerTests {
|
||||
static var allTests : [(String, (SocksClientHandlerTests) -> () throws -> Void)] {
|
||||
return [
|
||||
("testTypicalWorkflow", testTypicalWorkflow),
|
||||
("testThatBufferingWorks", testThatBufferingWorks),
|
||||
("testBufferingWithMark", testBufferingWithMark),
|
||||
("testTypicalWorkflowDripfeed", testTypicalWorkflowDripfeed),
|
||||
("testInvalidAuthenticationMethod", testInvalidAuthenticationMethod),
|
||||
("testProxyConnectionFailed", testProxyConnectionFailed),
|
||||
|
@ -71,7 +71,50 @@ class SocksClientHandlerTests: XCTestCase {
|
||||
// any inbound data should now go straight through
|
||||
self.writeInbound([1, 2, 3, 4, 5])
|
||||
self.assertInbound([1, 2, 3, 4, 5])
|
||||
|
||||
// any outbound data should also go straight through
|
||||
XCTAssertNoThrow(try self.channel.writeOutbound(ByteBuffer(bytes: [1, 2, 3, 4, 5])))
|
||||
self.assertOutputBuffer([1, 2, 3, 4, 5])
|
||||
}
|
||||
|
||||
// Tests that if we write alot of data at the start then
|
||||
// that data will be written after the client has completed
|
||||
// the socks handshake.
|
||||
func testThatBufferingWorks() {
|
||||
self.connect()
|
||||
|
||||
let writePromise = self.channel.eventLoop.makePromise(of: Void.self)
|
||||
self.channel.writeAndFlush(ByteBuffer(bytes: [1, 2, 3, 4, 5]), promise: writePromise)
|
||||
self.assertOutputBuffer([0x05, 0x01, 0x00])
|
||||
self.writeInbound([0x05, 0x00])
|
||||
self.assertOutputBuffer([0x05, 0x01, 0x00, 0x01, 192, 168, 1, 1, 0x00, 0x50])
|
||||
self.writeInbound([0x05, 0x00, 0x00, 0x01, 192, 168, 1, 1, 0x00, 0x50])
|
||||
|
||||
XCTAssertNoThrow(try writePromise.futureResult.wait())
|
||||
self.assertOutputBuffer([1, 2, 3, 4, 5])
|
||||
}
|
||||
|
||||
func testBufferingWithMark() {
|
||||
self.connect()
|
||||
|
||||
let writePromise1 = self.channel.eventLoop.makePromise(of: Void.self)
|
||||
let writePromise2 = self.channel.eventLoop.makePromise(of: Void.self)
|
||||
self.channel.write(ByteBuffer(bytes: [1, 2, 3]), promise: writePromise1)
|
||||
self.channel.flush()
|
||||
self.channel.write(ByteBuffer(bytes: [4, 5, 6]), promise: writePromise2)
|
||||
|
||||
self.assertOutputBuffer([0x05, 0x01, 0x00])
|
||||
self.writeInbound([0x05, 0x00])
|
||||
self.assertOutputBuffer([0x05, 0x01, 0x00, 0x01, 192, 168, 1, 1, 0x00, 0x50])
|
||||
self.writeInbound([0x05, 0x00, 0x00, 0x01, 192, 168, 1, 1, 0x00, 0x50])
|
||||
|
||||
XCTAssertNoThrow(try writePromise1.futureResult.wait())
|
||||
self.assertOutputBuffer([1, 2, 3])
|
||||
|
||||
XCTAssertNoThrow(try self.channel.writeAndFlush(ByteBuffer(bytes: [7, 8, 9])).wait())
|
||||
XCTAssertNoThrow(try writePromise2.futureResult.wait())
|
||||
self.assertOutputBuffer([4, 5, 6])
|
||||
self.assertOutputBuffer([7, 8, 9])
|
||||
}
|
||||
|
||||
func testTypicalWorkflowDripfeed() {
|
||||
|
Loading…
x
Reference in New Issue
Block a user