writes are buffered, other review changes

- writes issued whilst the CONNECT is ongoing are now buffered rather
  than triggering a failure
- Error is restructured to do away with `Kind`
- failure logic is consolidated in `failWithError`
This commit is contained in:
Rick Newton-Rogers 2022-12-15 14:20:53 +00:00
parent 313e224abf
commit b34607e15a
3 changed files with 313 additions and 90 deletions

View File

@ -20,6 +20,15 @@ public final class NIOHTTP1ProxyConnectHandler: ChannelDuplexHandler, RemovableC
public typealias OutboundOut = HTTPClientRequestPart
public typealias InboundIn = HTTPClientResponsePart
/// Whether we've already seen the first request.
private var seenFirstRequest = false
private var bufferedWrittenMessages: MarkedCircularBuffer<BufferedWrite>
struct BufferedWrite {
var data: NIOAny
var promise: EventLoopPromise<Void>?
}
private enum State {
// transitions to `.connectSent` or `.failed`
case initialized
@ -39,7 +48,7 @@ public final class NIOHTTP1ProxyConnectHandler: ChannelDuplexHandler, RemovableC
private let targetPort: Int
private let headers: HTTPHeaders
private let deadline: NIODeadline
private let promise: EventLoopPromise<Void>
private let promise: EventLoopPromise<Void>?
/// Creates a new ``NIOHTTP1ProxyConnectHandler`` that issues a CONNECT request to a proxy server
/// and instructs the server to connect to `targetHost`.
@ -59,8 +68,52 @@ public final class NIOHTTP1ProxyConnectHandler: ChannelDuplexHandler, RemovableC
self.headers = headers
self.deadline = deadline
self.promise = promise
self.bufferedWrittenMessages = MarkedCircularBuffer(initialCapacity: 16) // matches CircularBuffer default
}
public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
switch self.state {
case .initialized, .connectSent, .headReceived, .completed:
self.bufferedWrittenMessages.append(BufferedWrite(data: data, promise: promise))
case .failed(let error):
promise?.fail(error)
}
}
public func flush(context: ChannelHandlerContext) {
self.bufferedWrittenMessages.mark()
}
public func removeHandler(context: ChannelHandlerContext, removalToken: ChannelHandlerContext.RemovalToken) {
// We have been formally removed from the pipeline. We should send any buffered data we have.
switch self.state {
case .initialized, .connectSent, .headReceived, .failed:
self.failWithError(.noResult(), context: context)
case .completed:
let hadMark = self.bufferedWrittenMessages.hasMark
while self.bufferedWrittenMessages.hasMark {
// write until mark
let bufferedPart = self.bufferedWrittenMessages.removeFirst()
context.write(bufferedPart.data, promise: bufferedPart.promise)
}
// flush any messages up to the mark
if hadMark {
context.flush()
}
// write remainder
while let bufferedPart = self.bufferedWrittenMessages.popFirst() {
context.write(bufferedPart.data, promise: bufferedPart.promise)
}
}
context.leavePipeline(removalToken: removalToken)
}
public func handlerAdded(context: ChannelHandlerContext) {
if context.channel.isActive {
self.sendConnect(context: context)
@ -70,10 +123,11 @@ public final class NIOHTTP1ProxyConnectHandler: ChannelDuplexHandler, RemovableC
public func handlerRemoved(context: ChannelHandlerContext) {
switch self.state {
case .failed, .completed:
// we don't expect there to be any buffered messages in these states
assert(self.bufferedWrittenMessages.isEmpty)
break
case .initialized, .connectSent, .headReceived:
self.state = .failed(Error.noResult())
self.promise.fail(Error.noResult())
self.failWithError(Error.noResult(), context: context)
}
}
@ -96,10 +150,6 @@ public final class NIOHTTP1ProxyConnectHandler: ChannelDuplexHandler, RemovableC
context.fireChannelInactive()
}
public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
preconditionFailure("We don't support outgoing traffic during HTTP Proxy update.")
}
public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
switch self.unwrapInboundIn(data) {
case .head(let head):
@ -187,22 +237,33 @@ public final class NIOHTTP1ProxyConnectHandler: ChannelDuplexHandler, RemovableC
case .headReceived(let timeout):
timeout.cancel()
self.state = .completed
self.promise.succeed(())
case .failed:
// ran into an error before... ignore this one
break
return
case .initialized, .connectSent, .completed:
preconditionFailure("Invalid state: \(self.state)")
}
// Ok, we've set up the proxy connection. We can now remove ourselves, which should happen synchronously.
context.pipeline.removeHandler(context: context, promise: nil)
self.promise?.succeed(())
}
private func failWithError(_ error: Error, context: ChannelHandlerContext, closeConnection: Bool = true) {
self.state = .failed(error)
self.promise.fail(error)
context.fireErrorCaught(error)
if closeConnection {
context.close(mode: .all, promise: nil)
switch self.state {
case .failed:
return
case .initialized, .connectSent, .headReceived, .completed:
self.state = .failed(error)
self.promise?.fail(error)
context.fireErrorCaught(error)
if closeConnection {
context.close(mode: .all, promise: nil)
}
for bufferedWrite in self.bufferedWrittenMessages {
bufferedWrite.promise?.fail(error)
}
}
}
@ -217,15 +278,6 @@ public final class NIOHTTP1ProxyConnectHandler: ChannelDuplexHandler, RemovableC
case noResult
}
fileprivate enum Kind: String, Equatable, Hashable {
case proxyAuthenticationRequired
case invalidProxyResponseHead
case invalidProxyResponse
case remoteConnectionClosed
case httpProxyHandshakeTimeout
case noResult
}
final class Storage: Sendable {
fileprivate let details: Details
public let file: String
@ -273,54 +325,75 @@ public final class NIOHTTP1ProxyConnectHandler: ChannelDuplexHandler, RemovableC
public static func noResult(file: String = #file, line: UInt = #line) -> Error {
Error(error: .noResult, file: file, line: line)
}
fileprivate var errorCode: Int {
switch self.store.details {
case .proxyAuthenticationRequired:
return 0
case .invalidProxyResponseHead:
return 1
case .invalidProxyResponse:
return 2
case .remoteConnectionClosed:
return 3
case .httpProxyHandshakeTimeout:
return 4
case .noResult:
return 5
}
}
}
}
extension NIOHTTP1ProxyConnectHandler.Error: Hashable {
public static func == (lhs: NIOHTTP1ProxyConnectHandler.Error, rhs: NIOHTTP1ProxyConnectHandler.Error) -> Bool {
// ignore *where* the error was thrown
lhs.store.details == rhs.store.details
}
public func hash(into hasher: inout Hasher) {
hasher.combine(self.store.details)
}
}
extension NIOHTTP1ProxyConnectHandler.Error.Details: Hashable {
// compare only the kind of error, not the associated response head
@inlinable
static func == (lhs: Self, rhs: Self) -> Bool {
NIOHTTP1ProxyConnectHandler.Error.Kind(lhs) == NIOHTTP1ProxyConnectHandler.Error.Kind(rhs)
}
@inlinable
public func hash(into hasher: inout Hasher) {
hasher.combine(NIOHTTP1ProxyConnectHandler.Error.Kind(self))
}
}
extension NIOHTTP1ProxyConnectHandler.Error.Kind {
init(_ details: NIOHTTP1ProxyConnectHandler.Error.Details) {
switch details {
case .proxyAuthenticationRequired:
self = .proxyAuthenticationRequired
case .invalidProxyResponseHead:
self = .invalidProxyResponseHead
case .invalidProxyResponse:
self = .invalidProxyResponse
case .remoteConnectionClosed:
self = .remoteConnectionClosed
case .httpProxyHandshakeTimeout:
self = .httpProxyHandshakeTimeout
case .noResult:
self = .noResult
public static func == (lhs: Self, rhs: Self) -> Bool {
switch (lhs.store.details, rhs.store.details) {
case (.proxyAuthenticationRequired, .proxyAuthenticationRequired):
return true
case (.invalidProxyResponseHead, .invalidProxyResponseHead):
return true
case (.invalidProxyResponse, .invalidProxyResponse):
return true
case (.remoteConnectionClosed, .remoteConnectionClosed):
return true
case (.httpProxyHandshakeTimeout, .httpProxyHandshakeTimeout):
return true
case (.noResult, .noResult):
return true
default:
return false
}
}
public func hash(into hasher: inout Hasher) {
hasher.combine(self.errorCode)
}
}
extension NIOHTTP1ProxyConnectHandler.Error: CustomStringConvertible {
public var description: String { return NIOHTTP1ProxyConnectHandler.Error.Kind(store.details).rawValue }
public var description: String {
self.store.details.description
}
}
extension NIOHTTP1ProxyConnectHandler.Error.Details: CustomStringConvertible {
public var description: String {
switch self {
case .proxyAuthenticationRequired:
return "Proxy Authentication Required"
case .invalidProxyResponseHead:
return "Invalid Proxy Response Head"
case .invalidProxyResponse:
return "Invalid Proxy Response"
case .remoteConnectionClosed:
return "Remote Connection Closed"
case .httpProxyHandshakeTimeout:
return "HTTP Proxy Handshake Timeout"
case .noResult:
return "No Result"
}
}
}

View File

@ -32,6 +32,8 @@ extension HTTP1ProxyConnectHandlerTests {
("testProxyConnectWithoutAuthorizationFailure500", testProxyConnectWithoutAuthorizationFailure500),
("testProxyConnectWithoutAuthorizationButAuthorizationNeeded", testProxyConnectWithoutAuthorizationButAuthorizationNeeded),
("testProxyConnectReceivesBody", testProxyConnectReceivesBody),
("testProxyConnectWithoutAuthorizationBufferedWrites", testProxyConnectWithoutAuthorizationBufferedWrites),
("testProxyConnectFailsBufferedWritesAreFailed", testProxyConnectFailsBufferedWritesAreFailed),
]
}
}

View File

@ -19,7 +19,7 @@ import NIOHTTP1
import XCTest
class HTTP1ProxyConnectHandlerTests: XCTestCase {
func testProxyConnectWithoutAuthorizationSuccess() {
func testProxyConnectWithoutAuthorizationSuccess() throws {
let embedded = EmbeddedChannel()
defer { XCTAssertNoThrow(try embedded.finish(acceptAlreadyClosed: false)) }
@ -37,12 +37,7 @@ class HTTP1ProxyConnectHandlerTests: XCTestCase {
XCTAssertNoThrow(try embedded.pipeline.syncOperations.addHandler(proxyConnectHandler))
var maybeHead: HTTPClientRequestPart?
XCTAssertNoThrow(maybeHead = try embedded.readOutbound(as: HTTPClientRequestPart.self))
guard case .some(.head(let head)) = maybeHead else {
return XCTFail("Expected the proxy connect handler to first send a http head part")
}
let head = try XCTUnwrap(try embedded.readOutbound(as: HTTPClientRequestPart.self)).assertHead()
XCTAssertEqual(head.method, .CONNECT)
XCTAssertEqual(head.uri, "swift.org:443")
XCTAssertNil(head.headers["proxy-authorization"].first)
@ -55,7 +50,7 @@ class HTTP1ProxyConnectHandlerTests: XCTestCase {
XCTAssertNoThrow(try promise.futureResult.wait())
}
func testProxyConnectWithAuthorization() {
func testProxyConnectWithAuthorization() throws {
let embedded = EmbeddedChannel()
let socketAddress = try! SocketAddress.makeAddressResolvingHost("localhost", port: 0)
@ -72,12 +67,7 @@ class HTTP1ProxyConnectHandlerTests: XCTestCase {
XCTAssertNoThrow(try embedded.pipeline.syncOperations.addHandler(proxyConnectHandler))
var maybeHead: HTTPClientRequestPart?
XCTAssertNoThrow(maybeHead = try embedded.readOutbound(as: HTTPClientRequestPart.self))
guard case .some(.head(let head)) = maybeHead else {
return XCTFail("Expected the proxy connect handler to first send a http head part")
}
let head = try XCTUnwrap(try embedded.readOutbound(as: HTTPClientRequestPart.self)).assertHead()
XCTAssertEqual(head.method, .CONNECT)
XCTAssertEqual(head.uri, "swift.org:443")
XCTAssertEqual(head.headers["proxy-authorization"].first, "Basic abc123")
@ -90,7 +80,7 @@ class HTTP1ProxyConnectHandlerTests: XCTestCase {
XCTAssertNoThrow(try promise.futureResult.wait())
}
func testProxyConnectWithoutAuthorizationFailure500() {
func testProxyConnectWithoutAuthorizationFailure500() throws {
let embedded = EmbeddedChannel()
let socketAddress = try! SocketAddress.makeAddressResolvingHost("localhost", port: 0)
@ -107,12 +97,7 @@ class HTTP1ProxyConnectHandlerTests: XCTestCase {
XCTAssertNoThrow(try embedded.pipeline.syncOperations.addHandler(proxyConnectHandler))
var maybeHead: HTTPClientRequestPart?
XCTAssertNoThrow(maybeHead = try embedded.readOutbound(as: HTTPClientRequestPart.self))
guard case .some(.head(let head)) = maybeHead else {
return XCTFail("Expected the proxy connect handler to first send a http head part")
}
let head = try XCTUnwrap(try embedded.readOutbound(as: HTTPClientRequestPart.self)).assertHead()
XCTAssertEqual(head.method, .CONNECT)
XCTAssertEqual(head.uri, "swift.org:443")
XCTAssertNil(head.headers["proxy-authorization"].first)
@ -131,7 +116,7 @@ class HTTP1ProxyConnectHandlerTests: XCTestCase {
}
}
func testProxyConnectWithoutAuthorizationButAuthorizationNeeded() {
func testProxyConnectWithoutAuthorizationButAuthorizationNeeded() throws {
let embedded = EmbeddedChannel()
let socketAddress = try! SocketAddress.makeAddressResolvingHost("localhost", port: 0)
@ -148,12 +133,7 @@ class HTTP1ProxyConnectHandlerTests: XCTestCase {
XCTAssertNoThrow(try embedded.pipeline.syncOperations.addHandler(proxyConnectHandler))
var maybeHead: HTTPClientRequestPart?
XCTAssertNoThrow(maybeHead = try embedded.readOutbound(as: HTTPClientRequestPart.self))
guard case .some(.head(let head)) = maybeHead else {
return XCTFail("Expected the proxy connect handler to first send a http head part")
}
let head = try XCTUnwrap(try embedded.readOutbound(as: HTTPClientRequestPart.self)).assertHead()
XCTAssertEqual(head.method, .CONNECT)
XCTAssertEqual(head.uri, "swift.org:443")
XCTAssertNil(head.headers["proxy-authorization"].first)
@ -212,4 +192,172 @@ class HTTP1ProxyConnectHandlerTests: XCTestCase {
XCTAssertEqual($0 as? NIOHTTP1ProxyConnectHandler.Error, .invalidProxyResponse())
}
}
func testProxyConnectWithoutAuthorizationBufferedWrites() throws {
let embedded = EmbeddedChannel()
defer { XCTAssertNoThrow(try embedded.finish(acceptAlreadyClosed: false)) }
let socketAddress = try! SocketAddress.makeAddressResolvingHost("localhost", port: 0)
XCTAssertNoThrow(try embedded.connect(to: socketAddress).wait())
let proxyConnectPromise: EventLoopPromise<Void> = embedded.eventLoop.makePromise()
let proxyConnectHandler = NIOHTTP1ProxyConnectHandler(
targetHost: "swift.org",
targetPort: 443,
headers: [:],
deadline: .now() + .seconds(10),
promise: proxyConnectPromise
)
XCTAssertNoThrow(try embedded.pipeline.syncOperations.addHandler(proxyConnectHandler))
// write a request to be buffered inside the ProxyConnectHandler
// it will be unbuffered when the handler completes and removes itself
let requestHead = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1), method: .GET, uri: "http://apple.com")
var promises: [EventLoopPromise<Void>] = []
promises.append(embedded.eventLoop.makePromise())
embedded.pipeline.write(NIOAny(HTTPClientRequestPart.head(requestHead)), promise: promises.last)
promises.append(embedded.eventLoop.makePromise())
embedded.pipeline.write(NIOAny(HTTPClientRequestPart.body(.byteBuffer(ByteBuffer(string: "Test")))), promise: promises.last)
promises.append(embedded.eventLoop.makePromise())
embedded.pipeline.write(NIOAny(HTTPClientRequestPart.end(nil)), promise: promises.last)
embedded.pipeline.flush()
// read the connect header back
let connectHead = try XCTUnwrap(try embedded.readOutbound(as: HTTPClientRequestPart.self)).assertHead()
XCTAssertEqual(connectHead.method, .CONNECT)
XCTAssertEqual(connectHead.uri, "swift.org:443")
XCTAssertNil(connectHead.headers["proxy-authorization"].first)
let connectTrailers = try XCTUnwrap(try embedded.readOutbound(as: HTTPClientRequestPart.self)).assertEnd()
XCTAssertNil(connectTrailers)
// ensure that nothing has been unbuffered by mistake
XCTAssertNil(try embedded.readOutbound(as: HTTPClientRequestPart.self))
let responseHead = HTTPResponseHead(version: .http1_1, status: .ok)
XCTAssertNoThrow(try embedded.writeInbound(HTTPClientResponsePart.head(responseHead)))
XCTAssertNoThrow(try embedded.writeInbound(HTTPClientResponsePart.end(nil)))
XCTAssertNoThrow(try proxyConnectPromise.futureResult.wait())
// read the buffered write back
let bufferedHead = try XCTUnwrap(try embedded.readOutbound(as: HTTPClientRequestPart.self)).assertHead()
XCTAssertEqual(bufferedHead.method, .GET)
XCTAssertEqual(bufferedHead.uri, "http://apple.com")
XCTAssertNil(bufferedHead.headers["proxy-authorization"].first)
let bufferedBody = try XCTUnwrap(try embedded.readOutbound(as: HTTPClientRequestPart.self)).assertBody()
XCTAssertEqual(bufferedBody, ByteBuffer(string: "Test"))
let bufferedTrailers = try XCTUnwrap(try embedded.readOutbound(as: HTTPClientRequestPart.self)).assertEnd()
XCTAssertNil(bufferedTrailers)
let resultFutures = promises.map { $0.futureResult }
XCTAssertNoThrow(_ = try EventLoopFuture.whenAllComplete(resultFutures, on: embedded.eventLoop).wait())
}
func testProxyConnectFailsBufferedWritesAreFailed() throws {
let embedded = EmbeddedChannel()
let socketAddress = try! SocketAddress.makeAddressResolvingHost("localhost", port: 0)
XCTAssertNoThrow(try embedded.connect(to: socketAddress).wait())
let proxyConnectPromise: EventLoopPromise<Void> = embedded.eventLoop.makePromise()
let proxyConnectHandler = NIOHTTP1ProxyConnectHandler(
targetHost: "swift.org",
targetPort: 443,
headers: [:],
deadline: .now() + .seconds(10),
promise: proxyConnectPromise
)
XCTAssertNoThrow(try embedded.pipeline.syncOperations.addHandler(proxyConnectHandler))
// write a request to be buffered inside the ProxyConnectHandler
// it will be unbuffered when the handler completes and removes itself
let requestHead = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1), method: .GET, uri: "apple.com")
var promises: [EventLoopPromise<Void>] = []
promises.append(embedded.eventLoop.makePromise())
embedded.pipeline.write(NIOAny(HTTPClientRequestPart.head(requestHead)), promise: promises.last)
promises.append(embedded.eventLoop.makePromise())
embedded.pipeline.write(NIOAny(HTTPClientRequestPart.body(.byteBuffer(ByteBuffer(string: "Test")))), promise: promises.last)
promises.append(embedded.eventLoop.makePromise())
embedded.pipeline.write(NIOAny(HTTPClientRequestPart.end(nil)), promise: promises.last)
embedded.pipeline.flush()
// read the connect header back
let connectHead = try XCTUnwrap(try embedded.readOutbound(as: HTTPClientRequestPart.self)).assertHead()
XCTAssertEqual(connectHead.method, .CONNECT)
XCTAssertEqual(connectHead.uri, "swift.org:443")
XCTAssertNil(connectHead.headers["proxy-authorization"].first)
let connectTrailers = try XCTUnwrap(try embedded.readOutbound(as: HTTPClientRequestPart.self)).assertEnd()
XCTAssertNil(connectTrailers)
// ensure that nothing has been unbuffered by mistake
XCTAssertNil(try embedded.readOutbound(as: HTTPClientRequestPart.self))
let responseHead = HTTPResponseHead(version: .http1_1, status: .internalServerError)
XCTAssertThrowsError(try embedded.writeInbound(HTTPClientResponsePart.head(responseHead))) {
XCTAssertEqual($0 as? NIOHTTP1ProxyConnectHandler.Error, .invalidProxyResponseHead(responseHead))
}
XCTAssertFalse(embedded.isActive, "Channel should be closed in response to the error")
XCTAssertNoThrow(try embedded.writeInbound(HTTPClientResponsePart.end(nil)))
XCTAssertThrowsError(try proxyConnectPromise.futureResult.wait()) {
XCTAssertEqual($0 as? NIOHTTP1ProxyConnectHandler.Error, .invalidProxyResponseHead(responseHead))
}
// buffered writes are dropped
XCTAssertNil(try embedded.readOutbound(as: HTTPClientRequestPart.self))
// all outstanding buffered write promises should be completed
let resultFutures = promises.map { $0.futureResult }
XCTAssertNoThrow(_ = try EventLoopFuture.whenAllComplete(resultFutures, on: embedded.eventLoop).wait())
}
}
struct HTTPRequestPartMismatch: Error {}
extension HTTPClientRequestPart {
@discardableResult
func assertHead(file: StaticString = #file, line: UInt = #line) throws -> HTTPRequestHead {
switch self {
case .head(let head):
return head
default:
XCTFail("Expected .head but got \(self)", file: file, line: line)
throw HTTPRequestPartMismatch()
}
}
@discardableResult
func assertBody(file: StaticString = #file, line: UInt = #line) throws -> ByteBuffer {
switch self {
case .body(.byteBuffer(let body)):
return body
default:
XCTFail("Expected .body but got \(self)", file: file, line: line)
throw HTTPRequestPartMismatch()
}
}
@discardableResult
func assertEnd(file: StaticString = #file, line: UInt = #line) throws -> HTTPHeaders? {
switch self {
case .end(let trailers):
return trailers
default:
XCTFail("Expected .end but got \(self)", file: file, line: line)
throw HTTPRequestPartMismatch()
}
}
}