Guard not assert state errors

This commit is contained in:
David Evans 2021-06-07 11:47:05 +01:00
parent 2fcdc8a98a
commit 2059fe0305
4 changed files with 38 additions and 22 deletions

View File

@ -57,7 +57,7 @@ public class SOCKSClientHandler: ChannelDuplexHandler {
self.buffered.writeBuffer(&buffer)
do {
let action = try self.state.receiveBuffer(&self.buffered)
self.handleAction(action, context: context)
try self.handleAction(action, context: context)
} catch {
context.fireErrorCaught(error)
context.close(mode: .all, promise: nil)
@ -85,20 +85,25 @@ public class SOCKSClientHandler: ChannelDuplexHandler {
extension SOCKSClientHandler {
func beginHandshake(context: ChannelHandlerContext) {
guard self.state.shouldBeginHandshake else {
return
do {
guard self.state.shouldBeginHandshake else {
return
}
try self.handleAction(self.state.connectionEstablished(), context: context)
} catch {
context.fireErrorCaught(error)
context.close(promise: nil)
}
self.handleAction(self.state.connectionEstablished(), context: context)
}
func handleAction(_ action: ClientAction, context: ChannelHandlerContext) {
func handleAction(_ action: ClientAction, context: ChannelHandlerContext) throws {
switch action {
case .waitForMoreData:
break // do nothing, we've already buffered the data
case .sendGreeting:
self.handleActionSendClientGreeting(context: context)
try self.handleActionSendClientGreeting(context: context)
case .sendRequest:
self.handleActionSendRequest(context: context)
try self.handleActionSendRequest(context: context)
case .proxyEstablished:
self.handleActionProxyEstablished(context: context)
case .sendData(let data):
@ -106,12 +111,12 @@ extension SOCKSClientHandler {
}
}
func handleActionSendClientGreeting(context: ChannelHandlerContext) {
func handleActionSendClientGreeting(context: ChannelHandlerContext) throws {
let greeting = ClientGreeting(methods: [.noneRequired]) // no authentication currently supported
let capacity = 1 + 1 + 1 // [version, #methods, methods...]
var buffer = context.channel.allocator.buffer(capacity: capacity)
buffer.writeClientGreeting(greeting)
self.state.sendClientGreeting(greeting)
try self.state.sendClientGreeting(greeting)
context.writeAndFlush(self.wrapOutboundOut(buffer), promise: nil)
}
@ -128,9 +133,9 @@ extension SOCKSClientHandler {
self.writeBufferedData(context: context)
}
func handleActionSendRequest(context: ChannelHandlerContext) {
func handleActionSendRequest(context: ChannelHandlerContext) throws {
let request = ClientRequest(command: .connect, addressType: self.targetAddress)
self.state.sendClientRequest(request)
try self.state.sendClientRequest(request)
// the client request is always 6 bytes + the address info
// [protocol_version, command, reserved, address type, <address>, port (2bytes)]

View File

@ -16,6 +16,11 @@ import NIO
public enum SOCKSError {
public struct InvalidState: Error, Hashable {
var expected: ClientState
var actual: ClientState
}
public struct InvalidProtocolVersion: Error, Hashable {
public var actual: UInt8
public init(actual: UInt8) {

View File

@ -150,19 +150,25 @@ extension ClientStateMachine {
// MARK: - Outgoing
extension ClientStateMachine {
mutating func connectionEstablished() -> ClientAction {
assert(self.state == .inactive)
mutating func connectionEstablished() throws -> ClientAction {
guard self.state == .inactive else {
throw SOCKSError.InvalidState(expected: .inactive, actual: self.state)
}
self.state = .waitingForClientGreeting
return .sendGreeting
}
mutating func sendClientGreeting(_ greeting: ClientGreeting) {
assert(self.state == .waitingForClientGreeting)
mutating func sendClientGreeting(_ greeting: ClientGreeting) throws {
guard self.state == .inactive else {
throw SOCKSError.InvalidState(expected: .waitingForClientGreeting, actual: self.state)
}
self.state = .waitingForAuthenticationMethod(greeting)
}
mutating func sendClientRequest(_ request: ClientRequest) {
assert(self.state == .waitingForClientRequest)
mutating func sendClientRequest(_ request: ClientRequest) throws {
guard self.state == .inactive else {
throw SOCKSError.InvalidState(expected: .waitingForClientRequest, actual: self.state)
}
self.state = .waitingForServerResponse(request)
}

View File

@ -23,11 +23,11 @@ public class ClientStateMachineTests: XCTestCase {
// create state machine and immediately connect
var stateMachine = ClientStateMachine()
XCTAssertTrue(stateMachine.shouldBeginHandshake)
XCTAssertEqual(stateMachine.connectionEstablished(), .sendGreeting)
XCTAssertNoThrow(XCTAssertEqual(try stateMachine.connectionEstablished(), .sendGreeting))
XCTAssertFalse(stateMachine.proxyEstablished)
// send the client greeting
stateMachine.sendClientGreeting(.init(methods: [.noneRequired]))
XCTAssertNoThrow(try stateMachine.sendClientGreeting(.init(methods: [.noneRequired])))
XCTAssertFalse(stateMachine.shouldBeginHandshake)
XCTAssertFalse(stateMachine.proxyEstablished)
@ -42,7 +42,7 @@ public class ClientStateMachineTests: XCTestCase {
XCTAssertFalse(stateMachine.proxyEstablished)
// send the client request
stateMachine.sendClientRequest(.init(command: .bind, addressType: .address(try! .init(ipAddress: "192.168.1.1", port: 80))))
XCTAssertNoThrow(try stateMachine.sendClientRequest(.init(command: .bind, addressType: .address(try! .init(ipAddress: "192.168.1.1", port: 80)))))
XCTAssertFalse(stateMachine.shouldBeginHandshake)
XCTAssertFalse(stateMachine.proxyEstablished)
@ -62,8 +62,8 @@ public class ClientStateMachineTests: XCTestCase {
// prepare the state machine
var stateMachine = ClientStateMachine()
XCTAssertEqual(stateMachine.connectionEstablished(), .sendGreeting)
stateMachine.sendClientGreeting(.init(methods: [.noneRequired]))
XCTAssertNoThrow(XCTAssertEqual(try stateMachine.connectionEstablished(), .sendGreeting))
XCTAssertNoThrow(try stateMachine.sendClientGreeting(.init(methods: [.noneRequired])))
// write some invalid bytes from the server
// the state machine should throw