mirror of
https://github.com/apple/swift-nio-extras.git
synced 2025-05-14 08:52:42 +08:00
Cleanup unwinding
This commit is contained in:
parent
44ceb74f50
commit
7a7deecb80
@ -22,30 +22,28 @@ struct ClientGreeting: Hashable {
|
||||
extension ByteBuffer {
|
||||
|
||||
mutating func readClientGreeting() throws -> ClientGreeting? {
|
||||
let save = self
|
||||
guard
|
||||
let version = self.readInteger(as: UInt8.self),
|
||||
let numMethods = self.readInteger(as: UInt8.self)
|
||||
else {
|
||||
self = save
|
||||
return nil
|
||||
}
|
||||
|
||||
guard version == 5 else {
|
||||
self = save
|
||||
throw SOCKSError.InvalidProtocolVersion(actual: version)
|
||||
}
|
||||
|
||||
var methods: [AuthenticationMethod] = []
|
||||
methods.reserveCapacity(Int(numMethods))
|
||||
for _ in 0..<numMethods {
|
||||
guard let method = self.readInteger(as: UInt8.self) else {
|
||||
self = save
|
||||
return nil
|
||||
return try self.parseUnwindingIfNeeded { buffer in
|
||||
guard
|
||||
let version = buffer.readInteger(as: UInt8.self),
|
||||
let numMethods = buffer.readInteger(as: UInt8.self)
|
||||
else {
|
||||
throw MissingBytes()
|
||||
}
|
||||
methods.append(.init(value: method))
|
||||
|
||||
guard version == 5 else {
|
||||
throw SOCKSError.InvalidProtocolVersion(actual: version)
|
||||
}
|
||||
|
||||
var methods: [AuthenticationMethod] = []
|
||||
methods.reserveCapacity(Int(numMethods))
|
||||
for _ in 0..<numMethods {
|
||||
guard let method = buffer.readInteger(as: UInt8.self) else {
|
||||
throw MissingBytes()
|
||||
}
|
||||
methods.append(.init(value: method))
|
||||
}
|
||||
return .init(methods: methods)
|
||||
}
|
||||
return .init(methods: methods)
|
||||
}
|
||||
|
||||
@discardableResult mutating func writeClientGreeting(_ greeting: ClientGreeting) -> Int {
|
||||
|
@ -100,59 +100,59 @@ public enum AddressType: Hashable {
|
||||
extension ByteBuffer {
|
||||
|
||||
mutating func readAddresType() throws -> AddressType? {
|
||||
let save = self
|
||||
guard let type = self.readInteger(as: UInt8.self) else {
|
||||
self = save
|
||||
return nil
|
||||
}
|
||||
|
||||
switch type {
|
||||
case 0x01:
|
||||
return try self.readIPv4Address()
|
||||
case 0x03:
|
||||
return try self.readDomain()
|
||||
case 0x04:
|
||||
return try self.readIPv6Address()
|
||||
default:
|
||||
throw SOCKSError.InvalidAddressType(actual: type)
|
||||
return try self.parseUnwindingIfNeeded { buffer in
|
||||
guard let type = buffer.readInteger(as: UInt8.self) else {
|
||||
throw MissingBytes()
|
||||
}
|
||||
|
||||
switch type {
|
||||
case 0x01:
|
||||
return try buffer.readIPv4Address()
|
||||
case 0x03:
|
||||
return try buffer.readDomain()
|
||||
case 0x04:
|
||||
return try buffer.readIPv6Address()
|
||||
default:
|
||||
throw SOCKSError.InvalidAddressType(actual: type)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mutating func readIPv4Address() throws -> AddressType? {
|
||||
let save = self
|
||||
guard
|
||||
let bytes = self.readSlice(length: 4),
|
||||
let port = try self.readPort()
|
||||
else {
|
||||
self = save
|
||||
return nil
|
||||
return try self.parseUnwindingIfNeeded { buffer in
|
||||
guard
|
||||
let bytes = buffer.readSlice(length: 4),
|
||||
let port = try buffer.readPort()
|
||||
else {
|
||||
throw MissingBytes()
|
||||
}
|
||||
return .address(try .init(packedIPAddress: bytes, port: port))
|
||||
}
|
||||
return .address(try .init(packedIPAddress: bytes, port: port))
|
||||
}
|
||||
|
||||
mutating func readIPv6Address() throws -> AddressType? {
|
||||
let save = self
|
||||
guard
|
||||
let bytes = self.readSlice(length: 16),
|
||||
let port = try self.readPort()
|
||||
else {
|
||||
self = save
|
||||
return nil
|
||||
return try self.parseUnwindingIfNeeded { buffer in
|
||||
guard
|
||||
let bytes = buffer.readSlice(length: 16),
|
||||
let port = try buffer.readPort()
|
||||
else {
|
||||
throw MissingBytes()
|
||||
}
|
||||
return .address(try .init(packedIPAddress: bytes, port: port))
|
||||
}
|
||||
return .address(try .init(packedIPAddress: bytes, port: port))
|
||||
}
|
||||
|
||||
mutating func readDomain() throws -> AddressType? {
|
||||
let save = self
|
||||
guard
|
||||
let length = self.readInteger(as: UInt8.self),
|
||||
let host = self.readString(length: Int(length)),
|
||||
let port = try self.readPort()
|
||||
else {
|
||||
self = save
|
||||
return nil
|
||||
return try self.parseUnwindingIfNeeded { buffer in
|
||||
guard
|
||||
let length = buffer.readInteger(as: UInt8.self),
|
||||
let host = buffer.readString(length: Int(length)),
|
||||
let port = try buffer.readPort()
|
||||
else {
|
||||
throw MissingBytes()
|
||||
}
|
||||
return .domain(host, port: UInt16(port))
|
||||
}
|
||||
return .domain(host, port: UInt16(port))
|
||||
}
|
||||
|
||||
mutating func readPort() throws -> Int? {
|
||||
|
@ -35,21 +35,20 @@ struct MethodSelection: Hashable {
|
||||
extension ByteBuffer {
|
||||
|
||||
mutating func readMethodSelection() throws -> MethodSelection? {
|
||||
let save = self
|
||||
guard
|
||||
let version = self.readInteger(as: UInt8.self),
|
||||
let method = self.readInteger(as: UInt8.self)
|
||||
else {
|
||||
self = save
|
||||
return nil
|
||||
return try self.parseUnwindingIfNeeded { buffer in
|
||||
guard
|
||||
let version = buffer.readInteger(as: UInt8.self),
|
||||
let method = buffer.readInteger(as: UInt8.self)
|
||||
else {
|
||||
throw MissingBytes()
|
||||
}
|
||||
|
||||
guard version == 0x05 else {
|
||||
throw SOCKSError.InvalidProtocolVersion(actual: version)
|
||||
}
|
||||
|
||||
return .init(method: .init(value: method))
|
||||
}
|
||||
|
||||
guard version == 0x05 else {
|
||||
self = save
|
||||
throw SOCKSError.InvalidProtocolVersion(actual: version)
|
||||
}
|
||||
|
||||
return .init(method: .init(value: method))
|
||||
}
|
||||
|
||||
@discardableResult mutating func writeMethodSelection(_ method: MethodSelection) -> Int {
|
||||
|
@ -43,28 +43,26 @@ struct ServerResponse: Hashable {
|
||||
extension ByteBuffer {
|
||||
|
||||
mutating func readServerResponse() throws -> ServerResponse? {
|
||||
let save = self
|
||||
guard
|
||||
let version = self.readInteger(as: UInt8.self),
|
||||
let reply = self.readInteger(as: UInt8.self).map({ Reply(value: $0) }),
|
||||
let reserved = self.readInteger(as: UInt8.self),
|
||||
let boundAddress = try self.readAddresType()
|
||||
else {
|
||||
self = save
|
||||
return nil
|
||||
return try self.parseUnwindingIfNeeded { buffer in
|
||||
guard
|
||||
let version = buffer.readInteger(as: UInt8.self),
|
||||
let reply = buffer.readInteger(as: UInt8.self).map({ Reply(value: $0) }),
|
||||
let reserved = buffer.readInteger(as: UInt8.self),
|
||||
let boundAddress = try buffer.readAddresType()
|
||||
else {
|
||||
throw MissingBytes()
|
||||
}
|
||||
|
||||
guard reserved == 0x0 else {
|
||||
throw SOCKSError.InvalidReservedByte(actual: reserved)
|
||||
}
|
||||
|
||||
guard version == 0x05 else {
|
||||
throw SOCKSError.InvalidProtocolVersion(actual: version)
|
||||
}
|
||||
|
||||
return .init(reply: reply, boundAddress: boundAddress)
|
||||
}
|
||||
|
||||
guard reserved == 0x0 else {
|
||||
self = save
|
||||
throw SOCKSError.InvalidReservedByte(actual: reserved)
|
||||
}
|
||||
|
||||
guard version == 0x05 else {
|
||||
self = save
|
||||
throw SOCKSError.InvalidProtocolVersion(actual: version)
|
||||
}
|
||||
|
||||
return .init(reply: reply, boundAddress: boundAddress)
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -14,6 +14,37 @@
|
||||
|
||||
import NIO
|
||||
|
||||
struct MissingBytes: Error {
|
||||
|
||||
}
|
||||
|
||||
extension ByteBuffer {
|
||||
|
||||
mutating func parseUnwindingIfNeeded<T>(_ closure: (inout ByteBuffer) throws -> T?) rethrows -> T? {
|
||||
let save = self
|
||||
do {
|
||||
return try closure(&self)
|
||||
} catch is MissingBytes {
|
||||
self = save
|
||||
return nil
|
||||
} catch {
|
||||
self = save
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
mutating func parseUnwindingIfNeeded<T>(_ closure: (inout ByteBuffer) throws -> T) rethrows -> T {
|
||||
let save = self
|
||||
do {
|
||||
return try closure(&self)
|
||||
} catch {
|
||||
self = save
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
enum ClientState: Hashable {
|
||||
case inactive
|
||||
case waitingForClientGreeting
|
||||
@ -58,16 +89,6 @@ struct ClientStateMachine {
|
||||
self.state = .inactive
|
||||
}
|
||||
|
||||
private func unwindIfNeeded<T>(_ buffer: inout ByteBuffer, _ closure: (inout ByteBuffer) throws -> T) rethrows -> T {
|
||||
let save = buffer
|
||||
do {
|
||||
return try closure(&buffer)
|
||||
} catch {
|
||||
buffer = save
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// MARK: - Incoming
|
||||
@ -90,7 +111,7 @@ extension ClientStateMachine {
|
||||
}
|
||||
|
||||
mutating func handleSelectedAuthenticationMethod(_ buffer: inout ByteBuffer, greeting: ClientGreeting) throws -> ClientAction {
|
||||
return try self.unwindIfNeeded(&buffer) { buffer -> ClientAction in
|
||||
return try buffer.parseUnwindingIfNeeded { buffer -> ClientAction in
|
||||
guard let selected = try buffer.readMethodSelection() else {
|
||||
return .waitForMoreData
|
||||
}
|
||||
@ -104,7 +125,7 @@ extension ClientStateMachine {
|
||||
}
|
||||
|
||||
mutating func handleServerResponse(_ buffer: inout ByteBuffer, request: ClientRequest) throws -> ClientAction {
|
||||
return try self.unwindIfNeeded(&buffer) { buffer -> ClientAction in
|
||||
return try buffer.parseUnwindingIfNeeded { buffer -> ClientAction in
|
||||
guard let response = try buffer.readServerResponse() else {
|
||||
return .waitForMoreData
|
||||
}
|
||||
@ -117,14 +138,11 @@ extension ClientStateMachine {
|
||||
}
|
||||
|
||||
mutating func authenticate(_ buffer: inout ByteBuffer) -> ClientAction {
|
||||
return self.unwindIfNeeded(&buffer) { buffer -> ClientAction in
|
||||
|
||||
// we don't currently support any authentication
|
||||
// so assume all is fine, and instruct the client
|
||||
// to send the request
|
||||
self.state = .waitingForClientRequest
|
||||
return .sendRequest
|
||||
}
|
||||
// we don't currently support any authentication
|
||||
// so assume all is fine, and instruct the client
|
||||
// to send the request
|
||||
self.state = .waitingForClientRequest
|
||||
return .sendRequest
|
||||
}
|
||||
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user