mirror of
https://github.com/apple/swift-nio-extras.git
synced 2025-05-14 17:02:43 +08:00
Fix incorrect SOCKS client flushing behaviour (#133)
* Add buffering test * Convert to marked buffer * Soundness * Re-add fastpath * Fix
This commit is contained in:
parent
93fc12bdb7
commit
548e0d4893
@ -28,9 +28,9 @@ public final class SOCKSClientHandler: ChannelDuplexHandler {
|
|||||||
private let targetAddress: SOCKSAddress
|
private let targetAddress: SOCKSAddress
|
||||||
|
|
||||||
private var state: ClientStateMachine
|
private var state: ClientStateMachine
|
||||||
private var buffered: ByteBuffer?
|
private var inboundBuffer: ByteBuffer?
|
||||||
|
|
||||||
private var bufferedWrites: CircularBuffer<(NIOAny, EventLoopPromise<Void>?)> = .init()
|
private var bufferedWrites: MarkedCircularBuffer<(NIOAny, EventLoopPromise<Void>?)> = .init(initialCapacity: 8)
|
||||||
|
|
||||||
/// Creates a new `SOCKSClientHandler` that connects to a server
|
/// Creates a new `SOCKSClientHandler` that connects to a server
|
||||||
/// and instructs the server to connect to `targetAddress`.
|
/// and instructs the server to connect to `targetAddress`.
|
||||||
@ -66,11 +66,11 @@ public final class SOCKSClientHandler: ChannelDuplexHandler {
|
|||||||
|
|
||||||
var inboundBuffer = self.unwrapInboundIn(data)
|
var inboundBuffer = self.unwrapInboundIn(data)
|
||||||
|
|
||||||
self.buffered.setOrWriteBuffer(&inboundBuffer)
|
self.inboundBuffer.setOrWriteBuffer(&inboundBuffer)
|
||||||
do {
|
do {
|
||||||
// Safe to bang, `setOrWrite` above means there will
|
// Safe to bang, `setOrWrite` above means there will
|
||||||
// always be a value.
|
// always be a value.
|
||||||
let action = try self.state.receiveBuffer(&self.buffered!)
|
let action = try self.state.receiveBuffer(&self.inboundBuffer!)
|
||||||
try self.handleAction(action, context: context)
|
try self.handleAction(action, context: context)
|
||||||
} catch {
|
} catch {
|
||||||
context.fireErrorCaught(error)
|
context.fireErrorCaught(error)
|
||||||
@ -79,21 +79,28 @@ public final class SOCKSClientHandler: ChannelDuplexHandler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
|
public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
|
||||||
guard self.state.proxyEstablished else {
|
if self.state.proxyEstablished && self.bufferedWrites.count == 0 {
|
||||||
|
context.write(data, promise: promise)
|
||||||
|
} else {
|
||||||
self.bufferedWrites.append((data, promise))
|
self.bufferedWrites.append((data, promise))
|
||||||
return
|
|
||||||
}
|
}
|
||||||
self.writeBufferedData(context: context)
|
|
||||||
context.write(data, promise: promise)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private func writeBufferedData(context: ChannelHandlerContext) {
|
private func writeBufferedData(context: ChannelHandlerContext) {
|
||||||
while self.bufferedWrites.count > 0 {
|
guard self.state.proxyEstablished else {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
while self.bufferedWrites.hasMark {
|
||||||
let (data, promise) = self.bufferedWrites.removeFirst()
|
let (data, promise) = self.bufferedWrites.removeFirst()
|
||||||
context.write(data, promise: promise)
|
context.write(data, promise: promise)
|
||||||
}
|
}
|
||||||
|
context.flush() // safe to flush otherwise we wouldn't have the mark
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public func flush(context: ChannelHandlerContext) {
|
||||||
|
self.bufferedWrites.mark()
|
||||||
|
self.writeBufferedData(context: context)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
extension SOCKSClientHandler {
|
extension SOCKSClientHandler {
|
||||||
@ -136,8 +143,8 @@ extension SOCKSClientHandler {
|
|||||||
// for some reason we have extra bytes
|
// for some reason we have extra bytes
|
||||||
// so let's send them down the pipe
|
// so let's send them down the pipe
|
||||||
// (Safe to bang, self.buffered will always exist at this point)
|
// (Safe to bang, self.buffered will always exist at this point)
|
||||||
if self.buffered!.readableBytes > 0 {
|
if self.inboundBuffer!.readableBytes > 0 {
|
||||||
let data = self.wrapInboundOut(self.buffered!)
|
let data = self.wrapInboundOut(self.inboundBuffer!)
|
||||||
context.fireChannelRead(data)
|
context.fireChannelRead(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -27,6 +27,8 @@ extension SocksClientHandlerTests {
|
|||||||
static var allTests : [(String, (SocksClientHandlerTests) -> () throws -> Void)] {
|
static var allTests : [(String, (SocksClientHandlerTests) -> () throws -> Void)] {
|
||||||
return [
|
return [
|
||||||
("testTypicalWorkflow", testTypicalWorkflow),
|
("testTypicalWorkflow", testTypicalWorkflow),
|
||||||
|
("testThatBufferingWorks", testThatBufferingWorks),
|
||||||
|
("testBufferingWithMark", testBufferingWithMark),
|
||||||
("testTypicalWorkflowDripfeed", testTypicalWorkflowDripfeed),
|
("testTypicalWorkflowDripfeed", testTypicalWorkflowDripfeed),
|
||||||
("testInvalidAuthenticationMethod", testInvalidAuthenticationMethod),
|
("testInvalidAuthenticationMethod", testInvalidAuthenticationMethod),
|
||||||
("testProxyConnectionFailed", testProxyConnectionFailed),
|
("testProxyConnectionFailed", testProxyConnectionFailed),
|
||||||
|
@ -71,7 +71,50 @@ class SocksClientHandlerTests: XCTestCase {
|
|||||||
// any inbound data should now go straight through
|
// any inbound data should now go straight through
|
||||||
self.writeInbound([1, 2, 3, 4, 5])
|
self.writeInbound([1, 2, 3, 4, 5])
|
||||||
self.assertInbound([1, 2, 3, 4, 5])
|
self.assertInbound([1, 2, 3, 4, 5])
|
||||||
|
|
||||||
|
// any outbound data should also go straight through
|
||||||
|
XCTAssertNoThrow(try self.channel.writeOutbound(ByteBuffer(bytes: [1, 2, 3, 4, 5])))
|
||||||
|
self.assertOutputBuffer([1, 2, 3, 4, 5])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tests that if we write alot of data at the start then
|
||||||
|
// that data will be written after the client has completed
|
||||||
|
// the socks handshake.
|
||||||
|
func testThatBufferingWorks() {
|
||||||
|
self.connect()
|
||||||
|
|
||||||
|
let writePromise = self.channel.eventLoop.makePromise(of: Void.self)
|
||||||
|
self.channel.writeAndFlush(ByteBuffer(bytes: [1, 2, 3, 4, 5]), promise: writePromise)
|
||||||
|
self.assertOutputBuffer([0x05, 0x01, 0x00])
|
||||||
|
self.writeInbound([0x05, 0x00])
|
||||||
|
self.assertOutputBuffer([0x05, 0x01, 0x00, 0x01, 192, 168, 1, 1, 0x00, 0x50])
|
||||||
|
self.writeInbound([0x05, 0x00, 0x00, 0x01, 192, 168, 1, 1, 0x00, 0x50])
|
||||||
|
|
||||||
|
XCTAssertNoThrow(try writePromise.futureResult.wait())
|
||||||
|
self.assertOutputBuffer([1, 2, 3, 4, 5])
|
||||||
|
}
|
||||||
|
|
||||||
|
func testBufferingWithMark() {
|
||||||
|
self.connect()
|
||||||
|
|
||||||
|
let writePromise1 = self.channel.eventLoop.makePromise(of: Void.self)
|
||||||
|
let writePromise2 = self.channel.eventLoop.makePromise(of: Void.self)
|
||||||
|
self.channel.write(ByteBuffer(bytes: [1, 2, 3]), promise: writePromise1)
|
||||||
|
self.channel.flush()
|
||||||
|
self.channel.write(ByteBuffer(bytes: [4, 5, 6]), promise: writePromise2)
|
||||||
|
|
||||||
|
self.assertOutputBuffer([0x05, 0x01, 0x00])
|
||||||
|
self.writeInbound([0x05, 0x00])
|
||||||
|
self.assertOutputBuffer([0x05, 0x01, 0x00, 0x01, 192, 168, 1, 1, 0x00, 0x50])
|
||||||
|
self.writeInbound([0x05, 0x00, 0x00, 0x01, 192, 168, 1, 1, 0x00, 0x50])
|
||||||
|
|
||||||
|
XCTAssertNoThrow(try writePromise1.futureResult.wait())
|
||||||
|
self.assertOutputBuffer([1, 2, 3])
|
||||||
|
|
||||||
|
XCTAssertNoThrow(try self.channel.writeAndFlush(ByteBuffer(bytes: [7, 8, 9])).wait())
|
||||||
|
XCTAssertNoThrow(try writePromise2.futureResult.wait())
|
||||||
|
self.assertOutputBuffer([4, 5, 6])
|
||||||
|
self.assertOutputBuffer([7, 8, 9])
|
||||||
}
|
}
|
||||||
|
|
||||||
func testTypicalWorkflowDripfeed() {
|
func testTypicalWorkflowDripfeed() {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user