Add SOCKS authentication fast path (#134)

Add a fast path to SOCKS auto-authenticate when the server selects noneRequired as the SOCKS server authentication method.
This commit is contained in:
David Evans 2021-06-17 14:56:32 +01:00 committed by GitHub
parent 548e0d4893
commit b03d835bca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 131 additions and 52 deletions

View File

@ -71,14 +71,17 @@ public final class SOCKSServerHandshakeHandler: ChannelDuplexHandler, RemovableC
public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) { public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
do { do {
let message = self.unwrapOutboundIn(data) let message = self.unwrapOutboundIn(data)
let outboundBuffer: ByteBuffer
switch message { switch message {
case .selectedAuthenticationMethod(let method): case .selectedAuthenticationMethod(let method):
try self.handleWriteSelectedAuthenticationMethod(method, context: context, promise: promise) outboundBuffer = try self.handleWriteSelectedAuthenticationMethod(method, context: context)
case .response(let response): case .response(let response):
try self.handleWriteResponse(response, context: context, promise: promise) outboundBuffer = try self.handleWriteResponse(response, context: context)
case .authenticationData(let data, let complete): case .authenticationData(let data, let complete):
try self.handleWriteAuthenticationData(data, complete: complete, context: context, promise: promise) outboundBuffer = try self.handleWriteAuthenticationData(data, complete: complete, context: context)
} }
context.write(self.wrapOutboundOut(outboundBuffer), promise: promise)
} catch { } catch {
context.fireErrorCaught(error) context.fireErrorCaught(error)
promise?.fail(error) promise?.fail(error)
@ -86,31 +89,25 @@ public final class SOCKSServerHandshakeHandler: ChannelDuplexHandler, RemovableC
} }
private func handleWriteSelectedAuthenticationMethod( private func handleWriteSelectedAuthenticationMethod(
_ method: SelectedAuthenticationMethod, context: ChannelHandlerContext, promise: EventLoopPromise<Void>?) throws { _ method: SelectedAuthenticationMethod, context: ChannelHandlerContext) throws -> ByteBuffer {
try stateMachine.sendAuthenticationMethod(method) try stateMachine.sendAuthenticationMethod(method)
var buffer = context.channel.allocator.buffer(capacity: 16) var buffer = context.channel.allocator.buffer(capacity: 16)
buffer.writeMethodSelection(method) buffer.writeMethodSelection(method)
context.write(self.wrapOutboundOut(buffer), promise: promise) return buffer
} }
private func handleWriteResponse( private func handleWriteResponse(
_ response: SOCKSResponse, context: ChannelHandlerContext, promise: EventLoopPromise<Void>?) throws { _ response: SOCKSResponse, context: ChannelHandlerContext) throws -> ByteBuffer {
try stateMachine.sendServerResponse(response) try stateMachine.sendServerResponse(response)
var buffer = context.channel.allocator.buffer(capacity: 16) var buffer = context.channel.allocator.buffer(capacity: 16)
buffer.writeServerResponse(response) buffer.writeServerResponse(response)
context.write(self.wrapOutboundOut(buffer), promise: promise) return buffer
} }
private func handleWriteAuthenticationData(_ data: ByteBuffer, complete: Bool, context: ChannelHandlerContext, promise: EventLoopPromise<Void>?) throws { private func handleWriteAuthenticationData(
do { _ data: ByteBuffer, complete: Bool, context: ChannelHandlerContext) throws -> ByteBuffer {
try self.stateMachine.sendData() try self.stateMachine.sendAuthenticationData(data, complete: complete)
if complete { return data
try self.stateMachine.authenticationComplete()
}
context.write(self.wrapOutboundOut(data), promise: promise)
} catch {
promise?.fail(error)
}
} }
} }

View File

@ -28,6 +28,7 @@ enum ServerState: Hashable {
struct ServerStateMachine: Hashable { struct ServerStateMachine: Hashable {
private var state: ServerState private var state: ServerState
private var authenticationMethod: AuthenticationMethod?
var proxyEstablished: Bool { var proxyEstablished: Bool {
switch self.state { switch self.state {
@ -118,7 +119,7 @@ extension ServerStateMachine {
self.state = .waitingForClientGreeting self.state = .waitingForClientGreeting
} }
mutating func sendAuthenticationMethod(_ method: SelectedAuthenticationMethod) throws { mutating func sendAuthenticationMethod(_ selected: SelectedAuthenticationMethod) throws {
switch self.state { switch self.state {
case .waitingToSendAuthenticationMethod: case .waitingToSendAuthenticationMethod:
() ()
@ -131,8 +132,14 @@ extension ServerStateMachine {
.error: .error:
throw SOCKSError.InvalidServerState() throw SOCKSError.InvalidServerState()
} }
self.authenticationMethod = selected.method
if selected.method == .noneRequired {
self.state = .waitingForClientRequest
} else {
self.state = .authenticating self.state = .authenticating
} }
}
mutating func sendServerResponse(_ response: SOCKSResponse) throws { mutating func sendServerResponse(_ response: SOCKSResponse) throws {
switch self.state { switch self.state {
@ -155,35 +162,25 @@ extension ServerStateMachine {
} }
} }
mutating func sendData() throws { mutating func sendAuthenticationData(_ data: ByteBuffer, complete: Bool) throws {
switch self.state { switch self.state {
case .authenticating: case .authenticating:
() break
case .inactive, case .waitingForClientRequest:
.waitingForClientGreeting, guard self.authenticationMethod == .noneRequired, complete, data.readableBytes == 0 else {
.waitingToSendAuthenticationMethod,
.waitingForClientRequest,
.waitingToSendResponse,
.active,
.error:
throw SOCKSError.InvalidServerState() throw SOCKSError.InvalidServerState()
} }
}
mutating func authenticationComplete() throws {
switch self.state {
case .authenticating:
()
case .inactive, case .inactive,
.waitingForClientGreeting, .waitingForClientGreeting,
.waitingToSendAuthenticationMethod, .waitingToSendAuthenticationMethod,
.waitingForClientRequest,
.waitingToSendResponse, .waitingToSendResponse,
.active, .active,
.error: .error:
throw SOCKSError.InvalidServerState() throw SOCKSError.InvalidServerState()
} }
if complete {
self.state = .waitingForClientRequest self.state = .waitingForClientRequest
} }
}
} }

View File

@ -32,6 +32,11 @@ extension SOCKSServerHandlerTests {
("testOutboundErrorsAreHandled", testOutboundErrorsAreHandled), ("testOutboundErrorsAreHandled", testOutboundErrorsAreHandled),
("testFlushOnHandlerRemoved", testFlushOnHandlerRemoved), ("testFlushOnHandlerRemoved", testFlushOnHandlerRemoved),
("testForceHandlerRemovalAfterAuth", testForceHandlerRemovalAfterAuth), ("testForceHandlerRemovalAfterAuth", testForceHandlerRemovalAfterAuth),
("testAutoAuthenticationComplete", testAutoAuthenticationComplete),
("testAutoAuthenticationCompleteWithManualCompletion", testAutoAuthenticationCompleteWithManualCompletion),
("testEagerClientRequestBeforeAuthenticationComplete", testEagerClientRequestBeforeAuthenticationComplete),
("testManualAuthenticationFailureExtraBytes", testManualAuthenticationFailureExtraBytes),
("testManualAuthenticationFailureInvalidCompletion", testManualAuthenticationFailureInvalidCompletion),
] ]
} }
} }

View File

@ -152,7 +152,7 @@ class SOCKSServerHandlerTests: XCTestCase {
// tests dripfeeding to ensure we buffer data correctly // tests dripfeeding to ensure we buffer data correctly
func testTypicalWorkflowDripfeed() { func testTypicalWorkflowDripfeed() {
let expectedGreeting = ClientGreeting(methods: [.noneRequired]) let expectedGreeting = ClientGreeting(methods: [.gssapi])
let expectedRequest = SOCKSRequest(command: .connect, addressType: .address(try! .init(ipAddress: "127.0.0.1", port: 80))) let expectedRequest = SOCKSRequest(command: .connect, addressType: .address(try! .init(ipAddress: "127.0.0.1", port: 80)))
let expectedData = ByteBuffer(string: "1234") let expectedData = ByteBuffer(string: "1234")
let testHandler = PromiseTestHandler( let testHandler = PromiseTestHandler(
@ -168,16 +168,15 @@ class SOCKSServerHandlerTests: XCTestCase {
self.assertOutputBuffer([]) self.assertOutputBuffer([])
self.writeInbound([0x01]) self.writeInbound([0x01])
self.assertOutputBuffer([]) self.assertOutputBuffer([])
self.writeInbound([0x00]) self.writeInbound([0x01])
self.assertOutputBuffer([]) self.assertOutputBuffer([])
XCTAssertTrue(testHandler.hadGreeting) XCTAssertTrue(testHandler.hadGreeting)
// write the auth selection // write the auth selection
XCTAssertNoThrow(try self.channel.writeOutbound(ServerMessage.selectedAuthenticationMethod(.init(method: .noneRequired)))) XCTAssertNoThrow(try self.channel.writeOutbound(ServerMessage.selectedAuthenticationMethod(.init(method: .gssapi))))
self.assertOutputBuffer([0x05, 0x00]) self.assertOutputBuffer([0x05, 0x01])
// finish authentication - nothing should be written // finish authentication with some bytes
// as this is informing the state machine only
XCTAssertNoThrow(try self.channel.writeOutbound(ServerMessage.authenticationData(ByteBuffer(bytes: [0xFF, 0xFF]), complete: true))) XCTAssertNoThrow(try self.channel.writeOutbound(ServerMessage.authenticationData(ByteBuffer(bytes: [0xFF, 0xFF]), complete: true)))
self.assertOutputBuffer([0xFF, 0xFF]) self.assertOutputBuffer([0xFF, 0xFF])
@ -217,10 +216,11 @@ class SOCKSServerHandlerTests: XCTestCase {
func testForceHandlerRemovalAfterAuth() { func testForceHandlerRemovalAfterAuth() {
// go through auth // go through auth
self.writeInbound([0x05, 0x01, 0x00]) self.writeInbound([0x05, 0x01, 0x01])
self.writeOutbound(.selectedAuthenticationMethod(.init(method: .noneRequired))) self.writeOutbound(.selectedAuthenticationMethod(.init(method: .gssapi)))
self.assertOutputBuffer([0x05, 0x00]) self.assertOutputBuffer([0x05, 0x01])
XCTAssertNoThrow(try self.handler.stateMachine.authenticationComplete()) self.writeOutbound(.authenticationData(ByteBuffer(), complete: true))
self.assertOutputBuffer([])
self.writeInbound([0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80]) self.writeInbound([0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80])
self.writeOutbound(.response(.init(reply: .succeeded, boundAddress: .address(try! .init(ipAddress: "127.0.0.1", port: 80))))) self.writeOutbound(.response(.init(reply: .succeeded, boundAddress: .address(try! .init(ipAddress: "127.0.0.1", port: 80)))))
self.assertOutputBuffer([0x05, 0x00, 0x00, 0x01, 127, 0, 0, 1, 0, 80]) self.assertOutputBuffer([0x05, 0x00, 0x00, 0x01, 127, 0, 0, 1, 0, 80])
@ -229,4 +229,88 @@ class SOCKSServerHandlerTests: XCTestCase {
// removing the handler, it should fail // removing the handler, it should fail
XCTAssertThrowsError(try self.channel.writeOutbound(ServerMessage.authenticationData(ByteBuffer(string: "hello, world!"), complete: false))) XCTAssertThrowsError(try self.channel.writeOutbound(ServerMessage.authenticationData(ByteBuffer(string: "hello, world!"), complete: false)))
} }
func testAutoAuthenticationComplete() {
// server selects none-required, this should mean we can continue without
// having to manually inform the state machine
self.writeInbound([0x05, 0x01, 0x00])
self.writeOutbound(.selectedAuthenticationMethod(.init(method: .noneRequired)))
self.assertOutputBuffer([0x05, 0x00])
// if we try and write the request then the data would be read
// as authentication data, and so the server wouldn't reply
// with a response
self.writeInbound([0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80])
self.writeOutbound(.response(.init(reply: .succeeded, boundAddress: .address(try! .init(ipAddress: "127.0.0.1", port: 80)))))
self.assertOutputBuffer([0x05, 0x00, 0x00, 0x01, 127, 0, 0, 1, 0, 80])
}
func testAutoAuthenticationCompleteWithManualCompletion() {
// server selects none-required, this should mean we can continue without
// having to manually inform the state machine. However, informing the state
// machine manually shouldn't break anything.
self.writeInbound([0x05, 0x01, 0x00])
self.writeOutbound(.selectedAuthenticationMethod(.init(method: .noneRequired)))
self.assertOutputBuffer([0x05, 0x00])
// complete authentication, but nothing should be written
// to the network
self.writeOutbound(.authenticationData(ByteBuffer(), complete: true))
self.assertOutputBuffer([])
// if we try and write the request then the data would be read
// as authentication data, and so the server wouldn't reply
// with a response
self.writeInbound([0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80])
self.writeOutbound(.response(.init(reply: .succeeded, boundAddress: .address(try! .init(ipAddress: "127.0.0.1", port: 80)))))
self.assertOutputBuffer([0x05, 0x00, 0x00, 0x01, 127, 0, 0, 1, 0, 80])
}
func testEagerClientRequestBeforeAuthenticationComplete() {
// server selects none-required, this should mean we can continue without
// having to manually inform the state machine. However, informing the state
// machine manually shouldn't break anything.
self.writeInbound([0x05, 0x01, 0x01])
self.assertInbound(.greeting(.init(methods: [.gssapi])))
self.writeOutbound(.selectedAuthenticationMethod(.init(method: .gssapi)))
self.assertOutputBuffer([0x05, 0x01])
// at this point authentication isn't complete
// so if the client sends a request then the
// server will read those as authentication bytes
self.writeInbound([0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80])
self.assertInbound(.authenticationData(ByteBuffer(bytes: [0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80])))
}
func testManualAuthenticationFailureExtraBytes() {
// server selects none-required, this should mean we can continue without
// having to manually inform the state machine. However, informing the state
// machine manually shouldn't break anything.
self.writeInbound([0x05, 0x01, 0x00])
self.writeOutbound(.selectedAuthenticationMethod(.init(method: .noneRequired)))
self.assertOutputBuffer([0x05, 0x00])
// invalid authentication completion
// we've selected `noneRequired`, so no
// bytes should be written
XCTAssertThrowsError(try self.channel.writeOutbound(ServerMessage.authenticationData(ByteBuffer(bytes: [0x00]), complete: true)))
}
func testManualAuthenticationFailureInvalidCompletion() {
// server selects none-required, this should mean we can continue without
// having to manually inform the state machine. However, informing the state
// machine manually shouldn't break anything.
self.writeInbound([0x05, 0x01, 0x00])
self.writeOutbound(.selectedAuthenticationMethod(.init(method: .noneRequired)))
self.assertOutputBuffer([0x05, 0x00])
// invalid authentication completion
// authentication should have already completed
// as we selected `noneRequired`, so sending
// `complete = false` should be an error
XCTAssertThrowsError(try self.channel.writeOutbound(ServerMessage.authenticationData(ByteBuffer(bytes: []), complete: false)))
}
} }

View File

@ -34,9 +34,6 @@ public class ServerStateMachineTests: XCTestCase {
XCTAssertNoThrow(try stateMachine.sendAuthenticationMethod(.init(method: .noneRequired))) XCTAssertNoThrow(try stateMachine.sendAuthenticationMethod(.init(method: .noneRequired)))
XCTAssertFalse(stateMachine.proxyEstablished) XCTAssertFalse(stateMachine.proxyEstablished)
// authentication is now finished, as we didn't send any
XCTAssertNoThrow(try stateMachine.authenticationComplete())
// send the client request // send the client request
var request = ByteBuffer(bytes: [0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80]) var request = ByteBuffer(bytes: [0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80])
XCTAssertNoThrow(try stateMachine.receiveBuffer(&request)) XCTAssertNoThrow(try stateMachine.receiveBuffer(&request))
@ -61,7 +58,6 @@ public class ServerStateMachineTests: XCTestCase {
XCTAssertNoThrow(try stateMachine.connectionEstablished()) XCTAssertNoThrow(try stateMachine.connectionEstablished())
XCTAssertNoThrow(try stateMachine.receiveBuffer(&greeting)) XCTAssertNoThrow(try stateMachine.receiveBuffer(&greeting))
XCTAssertNoThrow(try stateMachine.sendAuthenticationMethod(.init(method: .noneRequired))) XCTAssertNoThrow(try stateMachine.sendAuthenticationMethod(.init(method: .noneRequired)))
XCTAssertNoThrow(try stateMachine.authenticationComplete())
// write some invalid bytes from the client // write some invalid bytes from the client
// the state machine should throw // the state machine should throw