NIORequestResponseWithIDHandler

This commit is contained in:
Johannes Weiss 2022-04-21 16:42:07 +01:00
parent 4569c6911b
commit bc7bd162f3
8 changed files with 513 additions and 3 deletions

View File

@ -0,0 +1,19 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2022 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
public protocol NIORequestIdentifiable {
associatedtype RequestID: Hashable
var requestID: RequestID { get }
}

View File

@ -0,0 +1,144 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2017-2023 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
import NIOCore
/// `NIORequestResponseWithIDHandler` receives a `Request` alongside an `EventLoopPromise<Response>` from the
/// `Channel`'s outbound side. It will fulfill the promise with the `Response` once it's received from the `Channel`'s
/// inbound side. Requests and responses can arrive out-of-order and are matched by the virtue of being
/// `NIORequestIdentifiable`.
///
/// `NIORequestResponseWithIDHandler` does support pipelining `Request`s and it will send them pipelined further down the
/// `Channel`. Should `RequestResponseHandler` receive an error from the `Channel`, it will fail all promises meant for
/// the outstanding `Reponse`s and close the `Channel`. All requests enqueued after an error occured will be immediately
/// failed with the first error the channel received.
///
/// `NIORequestResponseWithIDHandler` does _not_ require that the `Response`s arrive on `Channel` in the same order as
/// the `Request`s were submitted. They are matched by their `requestID` property (from `NIORequestIdentifiable`).
public final class NIORequestResponseWithIDHandler<Request: NIORequestIdentifiable,
Response: NIORequestIdentifiable>: ChannelDuplexHandler
where Request.RequestID == Response.RequestID {
public typealias InboundIn = Response
public typealias InboundOut = Never
public typealias OutboundIn = (Request, EventLoopPromise<Response>)
public typealias OutboundOut = Request
private enum State {
case operational
case inactive
case error(Error)
var isOperational: Bool {
switch self {
case .operational:
return true
case .error, .inactive:
return false
}
}
}
private var state: State = .operational
private var promiseBuffer: [Request.RequestID: EventLoopPromise<Response>]
/// Create a new `RequestResponseHandler`.
///
/// - parameters:
/// - initialBufferCapacity: `RequestResponseHandler` saves the promises for all outstanding responses in a
/// buffer. `initialBufferCapacity` is the initial capacity for this buffer. You usually do not need to set
/// this parameter unless you intend to pipeline very deeply and don't want the buffer to resize.
public init(initialBufferCapacity: Int = 4) {
self.promiseBuffer = [:]
self.promiseBuffer.reserveCapacity(initialBufferCapacity)
}
public func channelInactive(context: ChannelHandlerContext) {
switch self.state {
case .error:
// We failed any outstanding promises when we entered the error state and will fail any
// new promises in write.
assert(self.promiseBuffer.count == 0)
case .inactive:
assert(self.promiseBuffer.count == 0)
// This is weird, we shouldn't get this more than once but it's not the end of the world either. But in
// debug we probably want to crash.
assertionFailure("Received channelInactive on an already-inactive NIORequestResponseWithIDHandler")
case .operational:
let promiseBuffer = self.promiseBuffer
self.promiseBuffer.removeAll()
self.state = .inactive
promiseBuffer.forEach { promise in
promise.value.fail(NIOExtrasErrors.ClosedBeforeReceivingResponse())
}
}
context.fireChannelInactive()
}
public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
guard self.state.isOperational else {
// we're in an error state, ignore further responses
assert(self.promiseBuffer.count == 0)
return
}
let response = self.unwrapInboundIn(data)
if let promise = self.promiseBuffer.removeValue(forKey: response.requestID) {
promise.succeed(response)
} else {
context.fireErrorCaught(NIOExtrasErrors.ResponseForInvalidRequest<Response>(requestID: response.requestID))
}
}
public func errorCaught(context: ChannelHandlerContext, error: Error) {
guard self.state.isOperational else {
assert(self.promiseBuffer.count == 0)
return
}
self.state = .error(error)
let promiseBuffer = self.promiseBuffer
self.promiseBuffer.removeAll()
context.close(promise: nil)
promiseBuffer.forEach {
$0.value.fail(error)
}
}
public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
let (request, responsePromise) = self.unwrapOutboundIn(data)
switch self.state {
case .error(let error):
assert(self.promiseBuffer.count == 0)
responsePromise.fail(error)
promise?.fail(error)
case .inactive:
assert(self.promiseBuffer.count == 0)
promise?.fail(ChannelError.ioOnClosedChannel)
responsePromise.fail(ChannelError.ioOnClosedChannel)
case .operational:
self.promiseBuffer[request.requestID] = responsePromise
context.write(self.wrapOutboundOut(request), promise: promise)
}
}
}
extension NIOExtrasErrors {
public struct ResponseForInvalidRequest<Response: NIORequestIdentifiable>: NIOExtrasError, Equatable {
public var requestID: Response.RequestID
public init(requestID: Response.RequestID) {
self.requestID = requestID
}
}
}

View File

@ -59,6 +59,7 @@ class LinuxMainRunner {
testCase(PCAPRingBufferTest.allTests),
testCase(QuiescingHelperTest.allTests),
testCase(RequestResponseHandlerTest.allTests),
testCase(RequestResponseWithIDHandlerTest.allTests),
testCase(SOCKSServerHandlerTests.allTests),
testCase(ServerResponseTests.allTests),
testCase(ServerStateMachineTests.allTests),

View File

@ -0,0 +1,41 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2018-2022 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
//
// RequestResponseWithIDHandlerTest+XCTest.swift
//
import XCTest
///
/// NOTE: This file was generated by generate_linux_tests.rb
///
/// Do NOT edit this file directly as it will be regenerated automatically when needed.
///
extension RequestResponseWithIDHandlerTest {
@available(*, deprecated, message: "not actually deprecated. Just deprecated to allow deprecated tests (which test deprecated functionality) without warnings")
static var allTests : [(String, (RequestResponseWithIDHandlerTest) -> () throws -> Void)] {
return [
("testSimpleRequestWorks", testSimpleRequestWorks),
("testEnqueingMultipleRequestsWorks", testEnqueingMultipleRequestsWorks),
("testRequestsEnqueuedAfterErrorAreFailed", testRequestsEnqueuedAfterErrorAreFailed),
("testRequestsEnqueuedJustBeforeErrorAreFailed", testRequestsEnqueuedJustBeforeErrorAreFailed),
("testClosedConnectionFailsOutstandingPromises", testClosedConnectionFailsOutstandingPromises),
("testOutOfOrderResponsesWork", testOutOfOrderResponsesWork),
("testErrorOnResponseForNonExistantRequest", testErrorOnResponseForNonExistantRequest),
("testMoreRequestsAfterChannelInactiveFail", testMoreRequestsAfterChannelInactiveFail),
]
}
}

View File

@ -0,0 +1,305 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2017-2023 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
import XCTest
import NIOCore
import NIOEmbedded
import NIOExtras
class RequestResponseWithIDHandlerTest: XCTestCase {
private var eventLoop: EmbeddedEventLoop!
private var channel: EmbeddedChannel!
private var buffer: ByteBuffer!
override func setUp() {
super.setUp()
self.eventLoop = EmbeddedEventLoop()
self.channel = EmbeddedChannel(loop: self.eventLoop)
self.buffer = self.channel.allocator.buffer(capacity: 16)
}
override func tearDown() {
self.buffer = nil
self.eventLoop = nil
if self.channel.isActive {
XCTAssertNoThrow(XCTAssertTrue(try self.channel.finish().isClean))
}
super.tearDown()
}
func testSimpleRequestWorks() {
XCTAssertNoThrow(try self.channel.pipeline.addHandler(NIORequestResponseWithIDHandler<ValueWithRequestID<IOData>, ValueWithRequestID<String>>()).wait())
self.buffer.writeString("hello")
// pretend to connect to the EmbeddedChannel knows it's supposed to be active
XCTAssertNoThrow(try self.channel.connect(to: .init(ipAddress: "1.2.3.4", port: 5)).wait())
let p: EventLoopPromise<ValueWithRequestID<String>> = self.channel.eventLoop.makePromise()
// write request
XCTAssertNoThrow(try self.channel.writeOutbound((ValueWithRequestID(requestID: 1, value: IOData.byteBuffer(self.buffer)),
p)))
// write response
XCTAssertNoThrow(try self.channel.writeInbound(ValueWithRequestID(requestID: 1, value: "okay")))
// verify request was forwarded
XCTAssertEqual(ValueWithRequestID(requestID: 1, value: IOData.byteBuffer(self.buffer)),
try self.channel.readOutbound())
// verify response was not forwarded
XCTAssertEqual(nil, try self.channel.readInbound(as: ValueWithRequestID<IOData>.self))
// verify the promise got succeeded with the response
XCTAssertEqual(ValueWithRequestID(requestID: 1, value: "okay"), try p.futureResult.wait())
}
func testEnqueingMultipleRequestsWorks() throws {
struct DummyError: Error {}
XCTAssertNoThrow(try self.channel.pipeline.addHandler(NIORequestResponseWithIDHandler<ValueWithRequestID<IOData>, ValueWithRequestID<Int>>()).wait())
var futures: [EventLoopFuture<ValueWithRequestID<Int>>] = []
// pretend to connect to the EmbeddedChannel knows it's supposed to be active
XCTAssertNoThrow(try self.channel.connect(to: .init(ipAddress: "1.2.3.4", port: 5)).wait())
for reqId in 0..<5 {
self.buffer.clear()
self.buffer.writeString("\(reqId)")
let p: EventLoopPromise<ValueWithRequestID<Int>> = self.channel.eventLoop.makePromise()
futures.append(p.futureResult)
// write request
XCTAssertNoThrow(try self.channel.writeOutbound((ValueWithRequestID(requestID: reqId,
value: IOData.byteBuffer(self.buffer)), p)))
}
// let's have 3 successful responses
for reqIdExpected in 0..<3 {
switch try self.channel.readOutbound(as: ValueWithRequestID<IOData>.self) {
case .some(let req):
guard case .byteBuffer(var buffer) = req.value else {
XCTFail("wrong type")
return
}
if let reqId = buffer.readString(length: buffer.readableBytes).flatMap(Int.init) {
// write response
try self.channel.writeInbound(ValueWithRequestID(requestID: reqId, value: reqId))
} else {
XCTFail("couldn't get request id")
}
default:
XCTFail("could not find request")
}
XCTAssertNoThrow(XCTAssertEqual(ValueWithRequestID(requestID: reqIdExpected, value: reqIdExpected),
try futures[reqIdExpected].wait()))
}
// validate the Channel is active
XCTAssertTrue(self.channel.isActive)
self.channel.pipeline.fireErrorCaught(DummyError())
// after receiving an error, it should be closed
XCTAssertFalse(self.channel.isActive)
for failedReqId in 3..<5 {
XCTAssertThrowsError(try futures[failedReqId].wait()) { error in
XCTAssertNotNil(error as? DummyError)
}
}
// verify no response was not forwarded
XCTAssertNoThrow(XCTAssertEqual(nil, try self.channel.readInbound(as: IOData.self)))
}
func testRequestsEnqueuedAfterErrorAreFailed() {
struct DummyError: Error {}
XCTAssertNoThrow(try self.channel.pipeline.addHandler(NIORequestResponseWithIDHandler<ValueWithRequestID<IOData>, ValueWithRequestID<Void>>()).wait())
self.channel.pipeline.fireErrorCaught(DummyError())
let p: EventLoopPromise<ValueWithRequestID<Void>> = self.eventLoop.makePromise()
XCTAssertThrowsError(try self.channel.writeOutbound((ValueWithRequestID(requestID: 1,
value: IOData.byteBuffer(self.buffer)), p))) { error in
XCTAssertNotNil(error as? DummyError)
}
XCTAssertThrowsError(try p.futureResult.wait()) { error in
XCTAssertNotNil(error as? DummyError)
}
}
func testRequestsEnqueuedJustBeforeErrorAreFailed() {
struct DummyError1: Error {}
struct DummyError2: Error {}
XCTAssertNoThrow(try self.channel.pipeline.addHandler(NIORequestResponseWithIDHandler<ValueWithRequestID<IOData>, ValueWithRequestID<Void>>()).wait())
let p: EventLoopPromise<ValueWithRequestID<Void>> = self.eventLoop.makePromise()
// right now, everything's still okay so the enqueued request won't immediately be failed
XCTAssertNoThrow(try self.channel.writeOutbound((ValueWithRequestID(requestID: 1, value: IOData.byteBuffer(self.buffer)),
p)))
// but whilst we're waiting for the response, an error turns up
self.channel.pipeline.fireErrorCaught(DummyError1())
// we'll also fire a second error through the pipeline that shouldn't do anything
self.channel.pipeline.fireErrorCaught(DummyError2())
// and just after the error, the response arrives too (but too late)
XCTAssertNoThrow(try self.channel.writeInbound(()))
XCTAssertThrowsError(try p.futureResult.wait()) { error in
XCTAssertNotNil(error as? DummyError1)
}
}
func testClosedConnectionFailsOutstandingPromises() {
XCTAssertNoThrow(try self.channel.pipeline.addHandler(NIORequestResponseWithIDHandler<ValueWithRequestID<String>, ValueWithRequestID<Void>>()).wait())
let promise = self.eventLoop.makePromise(of: ValueWithRequestID<Void>.self)
XCTAssertNoThrow(try self.channel.writeOutbound((ValueWithRequestID(requestID: 1, value: "Hello!"), promise)))
XCTAssertNoThrow(try self.channel.close().wait())
XCTAssertThrowsError(try promise.futureResult.wait()) { error in
XCTAssertTrue(error is NIOExtrasErrors.ClosedBeforeReceivingResponse)
}
}
func testOutOfOrderResponsesWork() {
XCTAssertNoThrow(try self.channel.pipeline.addHandler(NIORequestResponseWithIDHandler<ValueWithRequestID<String>, ValueWithRequestID<String>>()).wait())
self.buffer.writeString("hello")
// pretend to connect to the EmbeddedChannel knows it's supposed to be active
XCTAssertNoThrow(try self.channel.connect(to: .init(ipAddress: "1.2.3.4", port: 5)).wait())
let p1: EventLoopPromise<ValueWithRequestID<String>> = self.channel.eventLoop.makePromise()
let p2: EventLoopPromise<ValueWithRequestID<String>> = self.channel.eventLoop.makePromise()
// write requests
XCTAssertNoThrow(try self.channel.writeOutbound((ValueWithRequestID(requestID: 1, value: "1"), p1)))
XCTAssertNoThrow(try self.channel.writeOutbound((ValueWithRequestID(requestID: 2, value: "2"), p2)))
// write responses but out of order
XCTAssertNoThrow(try self.channel.writeInbound(ValueWithRequestID(requestID: 2, value: "okay 2")))
XCTAssertNoThrow(try self.channel.writeInbound(ValueWithRequestID(requestID: 1, value: "okay 1")))
// verify requests was forwarded
XCTAssertEqual(ValueWithRequestID(requestID: 1, value: "1"), try self.channel.readOutbound())
XCTAssertEqual(ValueWithRequestID(requestID: 2, value: "2"), try self.channel.readOutbound())
// verify responses were not forwarded
XCTAssertEqual(nil, try self.channel.readInbound(as: ValueWithRequestID<IOData>.self))
// verify the promises got succeeded with the response
XCTAssertEqual(ValueWithRequestID(requestID: 1, value: "okay 1"), try p1.futureResult.wait())
XCTAssertEqual(ValueWithRequestID(requestID: 2, value: "okay 2"), try p2.futureResult.wait())
}
func testErrorOnResponseForNonExistantRequest() {
XCTAssertNoThrow(try self.channel.pipeline.addHandler(NIORequestResponseWithIDHandler<ValueWithRequestID<String>, ValueWithRequestID<String>>()).wait())
self.buffer.writeString("hello")
// pretend to connect to the EmbeddedChannel knows it's supposed to be active
XCTAssertNoThrow(try self.channel.connect(to: .init(ipAddress: "1.2.3.4", port: 5)).wait())
let p1: EventLoopPromise<ValueWithRequestID<String>> = self.channel.eventLoop.makePromise()
// write request
XCTAssertNoThrow(try self.channel.writeOutbound((ValueWithRequestID(requestID: 1, value: "1"), p1)))
// write wrong response
XCTAssertThrowsError(try self.channel.writeInbound(ValueWithRequestID(requestID: 2, value: "okay 2"))) { error in
guard let error = error as? NIOExtrasErrors.ResponseForInvalidRequest<ValueWithRequestID<String>> else {
XCTFail("wrong error")
return
}
XCTAssertEqual(2, error.requestID)
}
// verify requests was forwarded
XCTAssertEqual(ValueWithRequestID(requestID: 1, value: "1"), try self.channel.readOutbound())
// verify responses were not forwarded
XCTAssertEqual(nil, try self.channel.readInbound(as: ValueWithRequestID<IOData>.self))
// verify the promises got succeeded with the response
}
func testMoreRequestsAfterChannelInactiveFail() {
final class EmitRequestOnInactiveHandler: ChannelDuplexHandler {
typealias InboundIn = Never
typealias OutboundIn = (ValueWithRequestID<IOData>, EventLoopPromise<ValueWithRequestID<String>>)
typealias OutboundOut = (ValueWithRequestID<IOData>, EventLoopPromise<ValueWithRequestID<String>>)
func channelInactive(context: ChannelHandlerContext) {
let responsePromise = context.eventLoop.makePromise(of: ValueWithRequestID<String>.self)
let writePromise = context.eventLoop.makePromise(of: Void.self)
context.writeAndFlush(self.wrapOutboundOut(
(ValueWithRequestID(requestID: 1, value: IOData.byteBuffer(ByteBuffer(string: "hi"))), responsePromise)
),
promise: writePromise)
var writePromiseCompleted = false
defer {
XCTAssertTrue(writePromiseCompleted)
}
var responsePromiseCompleted = false
defer {
XCTAssertTrue(responsePromiseCompleted)
}
writePromise.futureResult.whenComplete { result in
writePromiseCompleted = true
switch result {
case .success:
XCTFail("shouldn't succeed")
case .failure(let error):
XCTAssertEqual(.ioOnClosedChannel, error as? ChannelError)
}
}
responsePromise.futureResult.whenComplete { result in
responsePromiseCompleted = true
switch result {
case .success:
XCTFail("shouldn't succeed")
case .failure(let error):
XCTAssertEqual(.ioOnClosedChannel, error as? ChannelError)
}
}
}
}
XCTAssertNoThrow(try self.channel.pipeline.addHandlers(
NIORequestResponseWithIDHandler<ValueWithRequestID<IOData>, ValueWithRequestID<String>>(),
EmitRequestOnInactiveHandler()
).wait())
self.buffer.writeString("hello")
// pretend to connect to the EmbeddedChannel knows it's supposed to be active
XCTAssertNoThrow(try self.channel.connect(to: .init(ipAddress: "1.2.3.4", port: 5)).wait())
let p: EventLoopPromise<ValueWithRequestID<String>> = self.channel.eventLoop.makePromise()
// write request
XCTAssertNoThrow(try self.channel.writeOutbound((ValueWithRequestID(requestID: 1, value: IOData.byteBuffer(self.buffer)),
p)))
// write response
XCTAssertNoThrow(try self.channel.writeInbound(ValueWithRequestID(requestID: 1, value: "okay")))
// verify request was forwarded
XCTAssertEqual(ValueWithRequestID(requestID: 1, value: IOData.byteBuffer(self.buffer)),
try self.channel.readOutbound())
// verify the promise got succeeded with the response
XCTAssertEqual(ValueWithRequestID(requestID: 1, value: "okay"), try p.futureResult.wait())
}
}
struct ValueWithRequestID<T>: NIORequestIdentifiable {
typealias RequestID = Int
var requestID: Int
var value: T
}
extension ValueWithRequestID: Equatable where T: Equatable {
}

View File

@ -2,7 +2,7 @@
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2018-2022 Apple Inc. and the SwiftNIO project authors
// Copyright (c) 2018-2023 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information

View File

@ -2,7 +2,7 @@
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2018-2022 Apple Inc. and the SwiftNIO project authors
// Copyright (c) 2018-2023 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information

View File

@ -2,7 +2,7 @@
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2018-2022 Apple Inc. and the SwiftNIO project authors
// Copyright (c) 2018-2023 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information