Cleanup unwinding

This commit is contained in:
David Evans 2021-06-07 11:23:13 +01:00
parent 44ceb74f50
commit 7a7deecb80
5 changed files with 130 additions and 117 deletions

View File

@ -22,30 +22,28 @@ struct ClientGreeting: Hashable {
extension ByteBuffer {
mutating func readClientGreeting() throws -> ClientGreeting? {
let save = self
guard
let version = self.readInteger(as: UInt8.self),
let numMethods = self.readInteger(as: UInt8.self)
else {
self = save
return nil
}
guard version == 5 else {
self = save
throw SOCKSError.InvalidProtocolVersion(actual: version)
}
var methods: [AuthenticationMethod] = []
methods.reserveCapacity(Int(numMethods))
for _ in 0..<numMethods {
guard let method = self.readInteger(as: UInt8.self) else {
self = save
return nil
return try self.parseUnwindingIfNeeded { buffer in
guard
let version = buffer.readInteger(as: UInt8.self),
let numMethods = buffer.readInteger(as: UInt8.self)
else {
throw MissingBytes()
}
methods.append(.init(value: method))
guard version == 5 else {
throw SOCKSError.InvalidProtocolVersion(actual: version)
}
var methods: [AuthenticationMethod] = []
methods.reserveCapacity(Int(numMethods))
for _ in 0..<numMethods {
guard let method = buffer.readInteger(as: UInt8.self) else {
throw MissingBytes()
}
methods.append(.init(value: method))
}
return .init(methods: methods)
}
return .init(methods: methods)
}
@discardableResult mutating func writeClientGreeting(_ greeting: ClientGreeting) -> Int {

View File

@ -100,59 +100,59 @@ public enum AddressType: Hashable {
extension ByteBuffer {
mutating func readAddresType() throws -> AddressType? {
let save = self
guard let type = self.readInteger(as: UInt8.self) else {
self = save
return nil
}
switch type {
case 0x01:
return try self.readIPv4Address()
case 0x03:
return try self.readDomain()
case 0x04:
return try self.readIPv6Address()
default:
throw SOCKSError.InvalidAddressType(actual: type)
return try self.parseUnwindingIfNeeded { buffer in
guard let type = buffer.readInteger(as: UInt8.self) else {
throw MissingBytes()
}
switch type {
case 0x01:
return try buffer.readIPv4Address()
case 0x03:
return try buffer.readDomain()
case 0x04:
return try buffer.readIPv6Address()
default:
throw SOCKSError.InvalidAddressType(actual: type)
}
}
}
mutating func readIPv4Address() throws -> AddressType? {
let save = self
guard
let bytes = self.readSlice(length: 4),
let port = try self.readPort()
else {
self = save
return nil
return try self.parseUnwindingIfNeeded { buffer in
guard
let bytes = buffer.readSlice(length: 4),
let port = try buffer.readPort()
else {
throw MissingBytes()
}
return .address(try .init(packedIPAddress: bytes, port: port))
}
return .address(try .init(packedIPAddress: bytes, port: port))
}
mutating func readIPv6Address() throws -> AddressType? {
let save = self
guard
let bytes = self.readSlice(length: 16),
let port = try self.readPort()
else {
self = save
return nil
return try self.parseUnwindingIfNeeded { buffer in
guard
let bytes = buffer.readSlice(length: 16),
let port = try buffer.readPort()
else {
throw MissingBytes()
}
return .address(try .init(packedIPAddress: bytes, port: port))
}
return .address(try .init(packedIPAddress: bytes, port: port))
}
mutating func readDomain() throws -> AddressType? {
let save = self
guard
let length = self.readInteger(as: UInt8.self),
let host = self.readString(length: Int(length)),
let port = try self.readPort()
else {
self = save
return nil
return try self.parseUnwindingIfNeeded { buffer in
guard
let length = buffer.readInteger(as: UInt8.self),
let host = buffer.readString(length: Int(length)),
let port = try buffer.readPort()
else {
throw MissingBytes()
}
return .domain(host, port: UInt16(port))
}
return .domain(host, port: UInt16(port))
}
mutating func readPort() throws -> Int? {

View File

@ -35,21 +35,20 @@ struct MethodSelection: Hashable {
extension ByteBuffer {
mutating func readMethodSelection() throws -> MethodSelection? {
let save = self
guard
let version = self.readInteger(as: UInt8.self),
let method = self.readInteger(as: UInt8.self)
else {
self = save
return nil
return try self.parseUnwindingIfNeeded { buffer in
guard
let version = buffer.readInteger(as: UInt8.self),
let method = buffer.readInteger(as: UInt8.self)
else {
throw MissingBytes()
}
guard version == 0x05 else {
throw SOCKSError.InvalidProtocolVersion(actual: version)
}
return .init(method: .init(value: method))
}
guard version == 0x05 else {
self = save
throw SOCKSError.InvalidProtocolVersion(actual: version)
}
return .init(method: .init(value: method))
}
@discardableResult mutating func writeMethodSelection(_ method: MethodSelection) -> Int {

View File

@ -43,28 +43,26 @@ struct ServerResponse: Hashable {
extension ByteBuffer {
mutating func readServerResponse() throws -> ServerResponse? {
let save = self
guard
let version = self.readInteger(as: UInt8.self),
let reply = self.readInteger(as: UInt8.self).map({ Reply(value: $0) }),
let reserved = self.readInteger(as: UInt8.self),
let boundAddress = try self.readAddresType()
else {
self = save
return nil
return try self.parseUnwindingIfNeeded { buffer in
guard
let version = buffer.readInteger(as: UInt8.self),
let reply = buffer.readInteger(as: UInt8.self).map({ Reply(value: $0) }),
let reserved = buffer.readInteger(as: UInt8.self),
let boundAddress = try buffer.readAddresType()
else {
throw MissingBytes()
}
guard reserved == 0x0 else {
throw SOCKSError.InvalidReservedByte(actual: reserved)
}
guard version == 0x05 else {
throw SOCKSError.InvalidProtocolVersion(actual: version)
}
return .init(reply: reply, boundAddress: boundAddress)
}
guard reserved == 0x0 else {
self = save
throw SOCKSError.InvalidReservedByte(actual: reserved)
}
guard version == 0x05 else {
self = save
throw SOCKSError.InvalidProtocolVersion(actual: version)
}
return .init(reply: reply, boundAddress: boundAddress)
}
}

View File

@ -14,6 +14,37 @@
import NIO
struct MissingBytes: Error {
}
extension ByteBuffer {
mutating func parseUnwindingIfNeeded<T>(_ closure: (inout ByteBuffer) throws -> T?) rethrows -> T? {
let save = self
do {
return try closure(&self)
} catch is MissingBytes {
self = save
return nil
} catch {
self = save
throw error
}
}
mutating func parseUnwindingIfNeeded<T>(_ closure: (inout ByteBuffer) throws -> T) rethrows -> T {
let save = self
do {
return try closure(&self)
} catch {
self = save
throw error
}
}
}
enum ClientState: Hashable {
case inactive
case waitingForClientGreeting
@ -58,16 +89,6 @@ struct ClientStateMachine {
self.state = .inactive
}
private func unwindIfNeeded<T>(_ buffer: inout ByteBuffer, _ closure: (inout ByteBuffer) throws -> T) rethrows -> T {
let save = buffer
do {
return try closure(&buffer)
} catch {
buffer = save
throw error
}
}
}
// MARK: - Incoming
@ -90,7 +111,7 @@ extension ClientStateMachine {
}
mutating func handleSelectedAuthenticationMethod(_ buffer: inout ByteBuffer, greeting: ClientGreeting) throws -> ClientAction {
return try self.unwindIfNeeded(&buffer) { buffer -> ClientAction in
return try buffer.parseUnwindingIfNeeded { buffer -> ClientAction in
guard let selected = try buffer.readMethodSelection() else {
return .waitForMoreData
}
@ -104,7 +125,7 @@ extension ClientStateMachine {
}
mutating func handleServerResponse(_ buffer: inout ByteBuffer, request: ClientRequest) throws -> ClientAction {
return try self.unwindIfNeeded(&buffer) { buffer -> ClientAction in
return try buffer.parseUnwindingIfNeeded { buffer -> ClientAction in
guard let response = try buffer.readServerResponse() else {
return .waitForMoreData
}
@ -117,14 +138,11 @@ extension ClientStateMachine {
}
mutating func authenticate(_ buffer: inout ByteBuffer) -> ClientAction {
return self.unwindIfNeeded(&buffer) { buffer -> ClientAction in
// we don't currently support any authentication
// so assume all is fine, and instruct the client
// to send the request
self.state = .waitingForClientRequest
return .sendRequest
}
// we don't currently support any authentication
// so assume all is fine, and instruct the client
// to send the request
self.state = .waitingForClientRequest
return .sendRequest
}
}