mirror of
https://github.com/apple/swift-nio-extras.git
synced 2025-05-16 01:42:29 +08:00
Add SOCKS authentication fast path (#134)
Add a fast path to SOCKS auto-authenticate when the server selects noneRequired as the SOCKS server authentication method.
This commit is contained in:
parent
548e0d4893
commit
b03d835bca
@ -71,14 +71,17 @@ public final class SOCKSServerHandshakeHandler: ChannelDuplexHandler, RemovableC
|
|||||||
public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
|
public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
|
||||||
do {
|
do {
|
||||||
let message = self.unwrapOutboundIn(data)
|
let message = self.unwrapOutboundIn(data)
|
||||||
|
let outboundBuffer: ByteBuffer
|
||||||
switch message {
|
switch message {
|
||||||
case .selectedAuthenticationMethod(let method):
|
case .selectedAuthenticationMethod(let method):
|
||||||
try self.handleWriteSelectedAuthenticationMethod(method, context: context, promise: promise)
|
outboundBuffer = try self.handleWriteSelectedAuthenticationMethod(method, context: context)
|
||||||
case .response(let response):
|
case .response(let response):
|
||||||
try self.handleWriteResponse(response, context: context, promise: promise)
|
outboundBuffer = try self.handleWriteResponse(response, context: context)
|
||||||
case .authenticationData(let data, let complete):
|
case .authenticationData(let data, let complete):
|
||||||
try self.handleWriteAuthenticationData(data, complete: complete, context: context, promise: promise)
|
outboundBuffer = try self.handleWriteAuthenticationData(data, complete: complete, context: context)
|
||||||
}
|
}
|
||||||
|
context.write(self.wrapOutboundOut(outboundBuffer), promise: promise)
|
||||||
|
|
||||||
} catch {
|
} catch {
|
||||||
context.fireErrorCaught(error)
|
context.fireErrorCaught(error)
|
||||||
promise?.fail(error)
|
promise?.fail(error)
|
||||||
@ -86,31 +89,25 @@ public final class SOCKSServerHandshakeHandler: ChannelDuplexHandler, RemovableC
|
|||||||
}
|
}
|
||||||
|
|
||||||
private func handleWriteSelectedAuthenticationMethod(
|
private func handleWriteSelectedAuthenticationMethod(
|
||||||
_ method: SelectedAuthenticationMethod, context: ChannelHandlerContext, promise: EventLoopPromise<Void>?) throws {
|
_ method: SelectedAuthenticationMethod, context: ChannelHandlerContext) throws -> ByteBuffer {
|
||||||
try stateMachine.sendAuthenticationMethod(method)
|
try stateMachine.sendAuthenticationMethod(method)
|
||||||
var buffer = context.channel.allocator.buffer(capacity: 16)
|
var buffer = context.channel.allocator.buffer(capacity: 16)
|
||||||
buffer.writeMethodSelection(method)
|
buffer.writeMethodSelection(method)
|
||||||
context.write(self.wrapOutboundOut(buffer), promise: promise)
|
return buffer
|
||||||
}
|
}
|
||||||
|
|
||||||
private func handleWriteResponse(
|
private func handleWriteResponse(
|
||||||
_ response: SOCKSResponse, context: ChannelHandlerContext, promise: EventLoopPromise<Void>?) throws {
|
_ response: SOCKSResponse, context: ChannelHandlerContext) throws -> ByteBuffer {
|
||||||
try stateMachine.sendServerResponse(response)
|
try stateMachine.sendServerResponse(response)
|
||||||
var buffer = context.channel.allocator.buffer(capacity: 16)
|
var buffer = context.channel.allocator.buffer(capacity: 16)
|
||||||
buffer.writeServerResponse(response)
|
buffer.writeServerResponse(response)
|
||||||
context.write(self.wrapOutboundOut(buffer), promise: promise)
|
return buffer
|
||||||
}
|
}
|
||||||
|
|
||||||
private func handleWriteAuthenticationData(_ data: ByteBuffer, complete: Bool, context: ChannelHandlerContext, promise: EventLoopPromise<Void>?) throws {
|
private func handleWriteAuthenticationData(
|
||||||
do {
|
_ data: ByteBuffer, complete: Bool, context: ChannelHandlerContext) throws -> ByteBuffer {
|
||||||
try self.stateMachine.sendData()
|
try self.stateMachine.sendAuthenticationData(data, complete: complete)
|
||||||
if complete {
|
return data
|
||||||
try self.stateMachine.authenticationComplete()
|
|
||||||
}
|
|
||||||
context.write(self.wrapOutboundOut(data), promise: promise)
|
|
||||||
} catch {
|
|
||||||
promise?.fail(error)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -28,6 +28,7 @@ enum ServerState: Hashable {
|
|||||||
struct ServerStateMachine: Hashable {
|
struct ServerStateMachine: Hashable {
|
||||||
|
|
||||||
private var state: ServerState
|
private var state: ServerState
|
||||||
|
private var authenticationMethod: AuthenticationMethod?
|
||||||
|
|
||||||
var proxyEstablished: Bool {
|
var proxyEstablished: Bool {
|
||||||
switch self.state {
|
switch self.state {
|
||||||
@ -118,7 +119,7 @@ extension ServerStateMachine {
|
|||||||
self.state = .waitingForClientGreeting
|
self.state = .waitingForClientGreeting
|
||||||
}
|
}
|
||||||
|
|
||||||
mutating func sendAuthenticationMethod(_ method: SelectedAuthenticationMethod) throws {
|
mutating func sendAuthenticationMethod(_ selected: SelectedAuthenticationMethod) throws {
|
||||||
switch self.state {
|
switch self.state {
|
||||||
case .waitingToSendAuthenticationMethod:
|
case .waitingToSendAuthenticationMethod:
|
||||||
()
|
()
|
||||||
@ -131,8 +132,14 @@ extension ServerStateMachine {
|
|||||||
.error:
|
.error:
|
||||||
throw SOCKSError.InvalidServerState()
|
throw SOCKSError.InvalidServerState()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
self.authenticationMethod = selected.method
|
||||||
|
if selected.method == .noneRequired {
|
||||||
|
self.state = .waitingForClientRequest
|
||||||
|
} else {
|
||||||
self.state = .authenticating
|
self.state = .authenticating
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
mutating func sendServerResponse(_ response: SOCKSResponse) throws {
|
mutating func sendServerResponse(_ response: SOCKSResponse) throws {
|
||||||
switch self.state {
|
switch self.state {
|
||||||
@ -155,35 +162,25 @@ extension ServerStateMachine {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
mutating func sendData() throws {
|
mutating func sendAuthenticationData(_ data: ByteBuffer, complete: Bool) throws {
|
||||||
switch self.state {
|
switch self.state {
|
||||||
case .authenticating:
|
case .authenticating:
|
||||||
()
|
break
|
||||||
case .inactive,
|
case .waitingForClientRequest:
|
||||||
.waitingForClientGreeting,
|
guard self.authenticationMethod == .noneRequired, complete, data.readableBytes == 0 else {
|
||||||
.waitingToSendAuthenticationMethod,
|
|
||||||
.waitingForClientRequest,
|
|
||||||
.waitingToSendResponse,
|
|
||||||
.active,
|
|
||||||
.error:
|
|
||||||
throw SOCKSError.InvalidServerState()
|
throw SOCKSError.InvalidServerState()
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
mutating func authenticationComplete() throws {
|
|
||||||
switch self.state {
|
|
||||||
case .authenticating:
|
|
||||||
()
|
|
||||||
case .inactive,
|
case .inactive,
|
||||||
.waitingForClientGreeting,
|
.waitingForClientGreeting,
|
||||||
.waitingToSendAuthenticationMethod,
|
.waitingToSendAuthenticationMethod,
|
||||||
.waitingForClientRequest,
|
|
||||||
.waitingToSendResponse,
|
.waitingToSendResponse,
|
||||||
.active,
|
.active,
|
||||||
.error:
|
.error:
|
||||||
throw SOCKSError.InvalidServerState()
|
throw SOCKSError.InvalidServerState()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if complete {
|
||||||
self.state = .waitingForClientRequest
|
self.state = .waitingForClientRequest
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -32,6 +32,11 @@ extension SOCKSServerHandlerTests {
|
|||||||
("testOutboundErrorsAreHandled", testOutboundErrorsAreHandled),
|
("testOutboundErrorsAreHandled", testOutboundErrorsAreHandled),
|
||||||
("testFlushOnHandlerRemoved", testFlushOnHandlerRemoved),
|
("testFlushOnHandlerRemoved", testFlushOnHandlerRemoved),
|
||||||
("testForceHandlerRemovalAfterAuth", testForceHandlerRemovalAfterAuth),
|
("testForceHandlerRemovalAfterAuth", testForceHandlerRemovalAfterAuth),
|
||||||
|
("testAutoAuthenticationComplete", testAutoAuthenticationComplete),
|
||||||
|
("testAutoAuthenticationCompleteWithManualCompletion", testAutoAuthenticationCompleteWithManualCompletion),
|
||||||
|
("testEagerClientRequestBeforeAuthenticationComplete", testEagerClientRequestBeforeAuthenticationComplete),
|
||||||
|
("testManualAuthenticationFailureExtraBytes", testManualAuthenticationFailureExtraBytes),
|
||||||
|
("testManualAuthenticationFailureInvalidCompletion", testManualAuthenticationFailureInvalidCompletion),
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -152,7 +152,7 @@ class SOCKSServerHandlerTests: XCTestCase {
|
|||||||
|
|
||||||
// tests dripfeeding to ensure we buffer data correctly
|
// tests dripfeeding to ensure we buffer data correctly
|
||||||
func testTypicalWorkflowDripfeed() {
|
func testTypicalWorkflowDripfeed() {
|
||||||
let expectedGreeting = ClientGreeting(methods: [.noneRequired])
|
let expectedGreeting = ClientGreeting(methods: [.gssapi])
|
||||||
let expectedRequest = SOCKSRequest(command: .connect, addressType: .address(try! .init(ipAddress: "127.0.0.1", port: 80)))
|
let expectedRequest = SOCKSRequest(command: .connect, addressType: .address(try! .init(ipAddress: "127.0.0.1", port: 80)))
|
||||||
let expectedData = ByteBuffer(string: "1234")
|
let expectedData = ByteBuffer(string: "1234")
|
||||||
let testHandler = PromiseTestHandler(
|
let testHandler = PromiseTestHandler(
|
||||||
@ -168,16 +168,15 @@ class SOCKSServerHandlerTests: XCTestCase {
|
|||||||
self.assertOutputBuffer([])
|
self.assertOutputBuffer([])
|
||||||
self.writeInbound([0x01])
|
self.writeInbound([0x01])
|
||||||
self.assertOutputBuffer([])
|
self.assertOutputBuffer([])
|
||||||
self.writeInbound([0x00])
|
self.writeInbound([0x01])
|
||||||
self.assertOutputBuffer([])
|
self.assertOutputBuffer([])
|
||||||
XCTAssertTrue(testHandler.hadGreeting)
|
XCTAssertTrue(testHandler.hadGreeting)
|
||||||
|
|
||||||
// write the auth selection
|
// write the auth selection
|
||||||
XCTAssertNoThrow(try self.channel.writeOutbound(ServerMessage.selectedAuthenticationMethod(.init(method: .noneRequired))))
|
XCTAssertNoThrow(try self.channel.writeOutbound(ServerMessage.selectedAuthenticationMethod(.init(method: .gssapi))))
|
||||||
self.assertOutputBuffer([0x05, 0x00])
|
self.assertOutputBuffer([0x05, 0x01])
|
||||||
|
|
||||||
// finish authentication - nothing should be written
|
// finish authentication with some bytes
|
||||||
// as this is informing the state machine only
|
|
||||||
XCTAssertNoThrow(try self.channel.writeOutbound(ServerMessage.authenticationData(ByteBuffer(bytes: [0xFF, 0xFF]), complete: true)))
|
XCTAssertNoThrow(try self.channel.writeOutbound(ServerMessage.authenticationData(ByteBuffer(bytes: [0xFF, 0xFF]), complete: true)))
|
||||||
self.assertOutputBuffer([0xFF, 0xFF])
|
self.assertOutputBuffer([0xFF, 0xFF])
|
||||||
|
|
||||||
@ -217,10 +216,11 @@ class SOCKSServerHandlerTests: XCTestCase {
|
|||||||
func testForceHandlerRemovalAfterAuth() {
|
func testForceHandlerRemovalAfterAuth() {
|
||||||
|
|
||||||
// go through auth
|
// go through auth
|
||||||
self.writeInbound([0x05, 0x01, 0x00])
|
self.writeInbound([0x05, 0x01, 0x01])
|
||||||
self.writeOutbound(.selectedAuthenticationMethod(.init(method: .noneRequired)))
|
self.writeOutbound(.selectedAuthenticationMethod(.init(method: .gssapi)))
|
||||||
self.assertOutputBuffer([0x05, 0x00])
|
self.assertOutputBuffer([0x05, 0x01])
|
||||||
XCTAssertNoThrow(try self.handler.stateMachine.authenticationComplete())
|
self.writeOutbound(.authenticationData(ByteBuffer(), complete: true))
|
||||||
|
self.assertOutputBuffer([])
|
||||||
self.writeInbound([0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80])
|
self.writeInbound([0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80])
|
||||||
self.writeOutbound(.response(.init(reply: .succeeded, boundAddress: .address(try! .init(ipAddress: "127.0.0.1", port: 80)))))
|
self.writeOutbound(.response(.init(reply: .succeeded, boundAddress: .address(try! .init(ipAddress: "127.0.0.1", port: 80)))))
|
||||||
self.assertOutputBuffer([0x05, 0x00, 0x00, 0x01, 127, 0, 0, 1, 0, 80])
|
self.assertOutputBuffer([0x05, 0x00, 0x00, 0x01, 127, 0, 0, 1, 0, 80])
|
||||||
@ -229,4 +229,88 @@ class SOCKSServerHandlerTests: XCTestCase {
|
|||||||
// removing the handler, it should fail
|
// removing the handler, it should fail
|
||||||
XCTAssertThrowsError(try self.channel.writeOutbound(ServerMessage.authenticationData(ByteBuffer(string: "hello, world!"), complete: false)))
|
XCTAssertThrowsError(try self.channel.writeOutbound(ServerMessage.authenticationData(ByteBuffer(string: "hello, world!"), complete: false)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func testAutoAuthenticationComplete() {
|
||||||
|
|
||||||
|
// server selects none-required, this should mean we can continue without
|
||||||
|
// having to manually inform the state machine
|
||||||
|
self.writeInbound([0x05, 0x01, 0x00])
|
||||||
|
self.writeOutbound(.selectedAuthenticationMethod(.init(method: .noneRequired)))
|
||||||
|
self.assertOutputBuffer([0x05, 0x00])
|
||||||
|
|
||||||
|
// if we try and write the request then the data would be read
|
||||||
|
// as authentication data, and so the server wouldn't reply
|
||||||
|
// with a response
|
||||||
|
self.writeInbound([0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80])
|
||||||
|
self.writeOutbound(.response(.init(reply: .succeeded, boundAddress: .address(try! .init(ipAddress: "127.0.0.1", port: 80)))))
|
||||||
|
self.assertOutputBuffer([0x05, 0x00, 0x00, 0x01, 127, 0, 0, 1, 0, 80])
|
||||||
|
}
|
||||||
|
|
||||||
|
func testAutoAuthenticationCompleteWithManualCompletion() {
|
||||||
|
|
||||||
|
// server selects none-required, this should mean we can continue without
|
||||||
|
// having to manually inform the state machine. However, informing the state
|
||||||
|
// machine manually shouldn't break anything.
|
||||||
|
self.writeInbound([0x05, 0x01, 0x00])
|
||||||
|
self.writeOutbound(.selectedAuthenticationMethod(.init(method: .noneRequired)))
|
||||||
|
self.assertOutputBuffer([0x05, 0x00])
|
||||||
|
|
||||||
|
// complete authentication, but nothing should be written
|
||||||
|
// to the network
|
||||||
|
self.writeOutbound(.authenticationData(ByteBuffer(), complete: true))
|
||||||
|
self.assertOutputBuffer([])
|
||||||
|
|
||||||
|
// if we try and write the request then the data would be read
|
||||||
|
// as authentication data, and so the server wouldn't reply
|
||||||
|
// with a response
|
||||||
|
self.writeInbound([0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80])
|
||||||
|
self.writeOutbound(.response(.init(reply: .succeeded, boundAddress: .address(try! .init(ipAddress: "127.0.0.1", port: 80)))))
|
||||||
|
self.assertOutputBuffer([0x05, 0x00, 0x00, 0x01, 127, 0, 0, 1, 0, 80])
|
||||||
|
}
|
||||||
|
|
||||||
|
func testEagerClientRequestBeforeAuthenticationComplete() {
|
||||||
|
|
||||||
|
// server selects none-required, this should mean we can continue without
|
||||||
|
// having to manually inform the state machine. However, informing the state
|
||||||
|
// machine manually shouldn't break anything.
|
||||||
|
self.writeInbound([0x05, 0x01, 0x01])
|
||||||
|
self.assertInbound(.greeting(.init(methods: [.gssapi])))
|
||||||
|
self.writeOutbound(.selectedAuthenticationMethod(.init(method: .gssapi)))
|
||||||
|
self.assertOutputBuffer([0x05, 0x01])
|
||||||
|
|
||||||
|
// at this point authentication isn't complete
|
||||||
|
// so if the client sends a request then the
|
||||||
|
// server will read those as authentication bytes
|
||||||
|
self.writeInbound([0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80])
|
||||||
|
self.assertInbound(.authenticationData(ByteBuffer(bytes: [0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80])))
|
||||||
|
}
|
||||||
|
|
||||||
|
func testManualAuthenticationFailureExtraBytes() {
|
||||||
|
// server selects none-required, this should mean we can continue without
|
||||||
|
// having to manually inform the state machine. However, informing the state
|
||||||
|
// machine manually shouldn't break anything.
|
||||||
|
self.writeInbound([0x05, 0x01, 0x00])
|
||||||
|
self.writeOutbound(.selectedAuthenticationMethod(.init(method: .noneRequired)))
|
||||||
|
self.assertOutputBuffer([0x05, 0x00])
|
||||||
|
|
||||||
|
// invalid authentication completion
|
||||||
|
// we've selected `noneRequired`, so no
|
||||||
|
// bytes should be written
|
||||||
|
XCTAssertThrowsError(try self.channel.writeOutbound(ServerMessage.authenticationData(ByteBuffer(bytes: [0x00]), complete: true)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func testManualAuthenticationFailureInvalidCompletion() {
|
||||||
|
// server selects none-required, this should mean we can continue without
|
||||||
|
// having to manually inform the state machine. However, informing the state
|
||||||
|
// machine manually shouldn't break anything.
|
||||||
|
self.writeInbound([0x05, 0x01, 0x00])
|
||||||
|
self.writeOutbound(.selectedAuthenticationMethod(.init(method: .noneRequired)))
|
||||||
|
self.assertOutputBuffer([0x05, 0x00])
|
||||||
|
|
||||||
|
// invalid authentication completion
|
||||||
|
// authentication should have already completed
|
||||||
|
// as we selected `noneRequired`, so sending
|
||||||
|
// `complete = false` should be an error
|
||||||
|
XCTAssertThrowsError(try self.channel.writeOutbound(ServerMessage.authenticationData(ByteBuffer(bytes: []), complete: false)))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -34,9 +34,6 @@ public class ServerStateMachineTests: XCTestCase {
|
|||||||
XCTAssertNoThrow(try stateMachine.sendAuthenticationMethod(.init(method: .noneRequired)))
|
XCTAssertNoThrow(try stateMachine.sendAuthenticationMethod(.init(method: .noneRequired)))
|
||||||
XCTAssertFalse(stateMachine.proxyEstablished)
|
XCTAssertFalse(stateMachine.proxyEstablished)
|
||||||
|
|
||||||
// authentication is now finished, as we didn't send any
|
|
||||||
XCTAssertNoThrow(try stateMachine.authenticationComplete())
|
|
||||||
|
|
||||||
// send the client request
|
// send the client request
|
||||||
var request = ByteBuffer(bytes: [0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80])
|
var request = ByteBuffer(bytes: [0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80])
|
||||||
XCTAssertNoThrow(try stateMachine.receiveBuffer(&request))
|
XCTAssertNoThrow(try stateMachine.receiveBuffer(&request))
|
||||||
@ -61,7 +58,6 @@ public class ServerStateMachineTests: XCTestCase {
|
|||||||
XCTAssertNoThrow(try stateMachine.connectionEstablished())
|
XCTAssertNoThrow(try stateMachine.connectionEstablished())
|
||||||
XCTAssertNoThrow(try stateMachine.receiveBuffer(&greeting))
|
XCTAssertNoThrow(try stateMachine.receiveBuffer(&greeting))
|
||||||
XCTAssertNoThrow(try stateMachine.sendAuthenticationMethod(.init(method: .noneRequired)))
|
XCTAssertNoThrow(try stateMachine.sendAuthenticationMethod(.init(method: .noneRequired)))
|
||||||
XCTAssertNoThrow(try stateMachine.authenticationComplete())
|
|
||||||
|
|
||||||
// write some invalid bytes from the client
|
// write some invalid bytes from the client
|
||||||
// the state machine should throw
|
// the state machine should throw
|
||||||
|
Loading…
x
Reference in New Issue
Block a user