diff --git a/Sources/NIOExtras/NIORequestIdentifiable.swift b/Sources/NIOExtras/NIORequestIdentifiable.swift new file mode 100644 index 0000000..6b5ba8c --- /dev/null +++ b/Sources/NIOExtras/NIORequestIdentifiable.swift @@ -0,0 +1,19 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2022 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +public protocol NIORequestIdentifiable { + associatedtype RequestID: Hashable + + var requestID: RequestID { get } +} diff --git a/Sources/NIOExtras/RequestResponseWithIDHandler.swift b/Sources/NIOExtras/RequestResponseWithIDHandler.swift new file mode 100644 index 0000000..ace1f59 --- /dev/null +++ b/Sources/NIOExtras/RequestResponseWithIDHandler.swift @@ -0,0 +1,144 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2017-2023 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import NIOCore + +/// `NIORequestResponseWithIDHandler` receives a `Request` alongside an `EventLoopPromise` from the +/// `Channel`'s outbound side. It will fulfill the promise with the `Response` once it's received from the `Channel`'s +/// inbound side. Requests and responses can arrive out-of-order and are matched by the virtue of being +/// `NIORequestIdentifiable`. +/// +/// `NIORequestResponseWithIDHandler` does support pipelining `Request`s and it will send them pipelined further down the +/// `Channel`. Should `RequestResponseHandler` receive an error from the `Channel`, it will fail all promises meant for +/// the outstanding `Reponse`s and close the `Channel`. All requests enqueued after an error occured will be immediately +/// failed with the first error the channel received. +/// +/// `NIORequestResponseWithIDHandler` does _not_ require that the `Response`s arrive on `Channel` in the same order as +/// the `Request`s were submitted. They are matched by their `requestID` property (from `NIORequestIdentifiable`). +public final class NIORequestResponseWithIDHandler: ChannelDuplexHandler + where Request.RequestID == Response.RequestID { + public typealias InboundIn = Response + public typealias InboundOut = Never + public typealias OutboundIn = (Request, EventLoopPromise) + public typealias OutboundOut = Request + + private enum State { + case operational + case inactive + case error(Error) + + var isOperational: Bool { + switch self { + case .operational: + return true + case .error, .inactive: + return false + } + } + } + + private var state: State = .operational + private var promiseBuffer: [Request.RequestID: EventLoopPromise] + + /// Create a new `RequestResponseHandler`. + /// + /// - parameters: + /// - initialBufferCapacity: `RequestResponseHandler` saves the promises for all outstanding responses in a + /// buffer. `initialBufferCapacity` is the initial capacity for this buffer. You usually do not need to set + /// this parameter unless you intend to pipeline very deeply and don't want the buffer to resize. + public init(initialBufferCapacity: Int = 4) { + self.promiseBuffer = [:] + self.promiseBuffer.reserveCapacity(initialBufferCapacity) + } + + public func channelInactive(context: ChannelHandlerContext) { + switch self.state { + case .error: + // We failed any outstanding promises when we entered the error state and will fail any + // new promises in write. + assert(self.promiseBuffer.count == 0) + case .inactive: + assert(self.promiseBuffer.count == 0) + // This is weird, we shouldn't get this more than once but it's not the end of the world either. But in + // debug we probably want to crash. + assertionFailure("Received channelInactive on an already-inactive NIORequestResponseWithIDHandler") + case .operational: + let promiseBuffer = self.promiseBuffer + self.promiseBuffer.removeAll() + self.state = .inactive + promiseBuffer.forEach { promise in + promise.value.fail(NIOExtrasErrors.ClosedBeforeReceivingResponse()) + } + } + context.fireChannelInactive() + } + + public func channelRead(context: ChannelHandlerContext, data: NIOAny) { + guard self.state.isOperational else { + // we're in an error state, ignore further responses + assert(self.promiseBuffer.count == 0) + return + } + + let response = self.unwrapInboundIn(data) + if let promise = self.promiseBuffer.removeValue(forKey: response.requestID) { + promise.succeed(response) + } else { + context.fireErrorCaught(NIOExtrasErrors.ResponseForInvalidRequest(requestID: response.requestID)) + } + } + + public func errorCaught(context: ChannelHandlerContext, error: Error) { + guard self.state.isOperational else { + assert(self.promiseBuffer.count == 0) + return + } + self.state = .error(error) + let promiseBuffer = self.promiseBuffer + self.promiseBuffer.removeAll() + context.close(promise: nil) + promiseBuffer.forEach { + $0.value.fail(error) + } + } + + public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + let (request, responsePromise) = self.unwrapOutboundIn(data) + switch self.state { + case .error(let error): + assert(self.promiseBuffer.count == 0) + responsePromise.fail(error) + promise?.fail(error) + case .inactive: + assert(self.promiseBuffer.count == 0) + promise?.fail(ChannelError.ioOnClosedChannel) + responsePromise.fail(ChannelError.ioOnClosedChannel) + case .operational: + self.promiseBuffer[request.requestID] = responsePromise + context.write(self.wrapOutboundOut(request), promise: promise) + } + } +} + +extension NIOExtrasErrors { + public struct ResponseForInvalidRequest: NIOExtrasError, Equatable { + public var requestID: Response.RequestID + + public init(requestID: Response.RequestID) { + self.requestID = requestID + } + } +} + diff --git a/Tests/LinuxMain.swift b/Tests/LinuxMain.swift index b96c849..ae62878 100644 --- a/Tests/LinuxMain.swift +++ b/Tests/LinuxMain.swift @@ -59,6 +59,7 @@ class LinuxMainRunner { testCase(PCAPRingBufferTest.allTests), testCase(QuiescingHelperTest.allTests), testCase(RequestResponseHandlerTest.allTests), + testCase(RequestResponseWithIDHandlerTest.allTests), testCase(SOCKSServerHandlerTests.allTests), testCase(ServerResponseTests.allTests), testCase(ServerStateMachineTests.allTests), diff --git a/Tests/NIOExtrasTests/RequestResponseWithIDHandlerTest+XCTest.swift b/Tests/NIOExtrasTests/RequestResponseWithIDHandlerTest+XCTest.swift new file mode 100644 index 0000000..f77ee8e --- /dev/null +++ b/Tests/NIOExtrasTests/RequestResponseWithIDHandlerTest+XCTest.swift @@ -0,0 +1,41 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2018-2022 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +// +// RequestResponseWithIDHandlerTest+XCTest.swift +// +import XCTest + +/// +/// NOTE: This file was generated by generate_linux_tests.rb +/// +/// Do NOT edit this file directly as it will be regenerated automatically when needed. +/// + +extension RequestResponseWithIDHandlerTest { + + @available(*, deprecated, message: "not actually deprecated. Just deprecated to allow deprecated tests (which test deprecated functionality) without warnings") + static var allTests : [(String, (RequestResponseWithIDHandlerTest) -> () throws -> Void)] { + return [ + ("testSimpleRequestWorks", testSimpleRequestWorks), + ("testEnqueingMultipleRequestsWorks", testEnqueingMultipleRequestsWorks), + ("testRequestsEnqueuedAfterErrorAreFailed", testRequestsEnqueuedAfterErrorAreFailed), + ("testRequestsEnqueuedJustBeforeErrorAreFailed", testRequestsEnqueuedJustBeforeErrorAreFailed), + ("testClosedConnectionFailsOutstandingPromises", testClosedConnectionFailsOutstandingPromises), + ("testOutOfOrderResponsesWork", testOutOfOrderResponsesWork), + ("testErrorOnResponseForNonExistantRequest", testErrorOnResponseForNonExistantRequest), + ("testMoreRequestsAfterChannelInactiveFail", testMoreRequestsAfterChannelInactiveFail), + ] + } +} + diff --git a/Tests/NIOExtrasTests/RequestResponseWithIDHandlerTest.swift b/Tests/NIOExtrasTests/RequestResponseWithIDHandlerTest.swift new file mode 100644 index 0000000..51e6a35 --- /dev/null +++ b/Tests/NIOExtrasTests/RequestResponseWithIDHandlerTest.swift @@ -0,0 +1,305 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2017-2023 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import XCTest +import NIOCore +import NIOEmbedded +import NIOExtras + +class RequestResponseWithIDHandlerTest: XCTestCase { + private var eventLoop: EmbeddedEventLoop! + private var channel: EmbeddedChannel! + private var buffer: ByteBuffer! + + override func setUp() { + super.setUp() + + self.eventLoop = EmbeddedEventLoop() + self.channel = EmbeddedChannel(loop: self.eventLoop) + self.buffer = self.channel.allocator.buffer(capacity: 16) + } + + override func tearDown() { + self.buffer = nil + self.eventLoop = nil + if self.channel.isActive { + XCTAssertNoThrow(XCTAssertTrue(try self.channel.finish().isClean)) + } + + super.tearDown() + } + + func testSimpleRequestWorks() { + XCTAssertNoThrow(try self.channel.pipeline.addHandler(NIORequestResponseWithIDHandler, ValueWithRequestID>()).wait()) + self.buffer.writeString("hello") + + // pretend to connect to the EmbeddedChannel knows it's supposed to be active + XCTAssertNoThrow(try self.channel.connect(to: .init(ipAddress: "1.2.3.4", port: 5)).wait()) + + let p: EventLoopPromise> = self.channel.eventLoop.makePromise() + // write request + XCTAssertNoThrow(try self.channel.writeOutbound((ValueWithRequestID(requestID: 1, value: IOData.byteBuffer(self.buffer)), + p))) + // write response + XCTAssertNoThrow(try self.channel.writeInbound(ValueWithRequestID(requestID: 1, value: "okay"))) + // verify request was forwarded + XCTAssertEqual(ValueWithRequestID(requestID: 1, value: IOData.byteBuffer(self.buffer)), + try self.channel.readOutbound()) + // verify response was not forwarded + XCTAssertEqual(nil, try self.channel.readInbound(as: ValueWithRequestID.self)) + // verify the promise got succeeded with the response + XCTAssertEqual(ValueWithRequestID(requestID: 1, value: "okay"), try p.futureResult.wait()) + } + + func testEnqueingMultipleRequestsWorks() throws { + struct DummyError: Error {} + XCTAssertNoThrow(try self.channel.pipeline.addHandler(NIORequestResponseWithIDHandler, ValueWithRequestID>()).wait()) + + var futures: [EventLoopFuture>] = [] + // pretend to connect to the EmbeddedChannel knows it's supposed to be active + XCTAssertNoThrow(try self.channel.connect(to: .init(ipAddress: "1.2.3.4", port: 5)).wait()) + + for reqId in 0..<5 { + self.buffer.clear() + self.buffer.writeString("\(reqId)") + + let p: EventLoopPromise> = self.channel.eventLoop.makePromise() + futures.append(p.futureResult) + + // write request + XCTAssertNoThrow(try self.channel.writeOutbound((ValueWithRequestID(requestID: reqId, + value: IOData.byteBuffer(self.buffer)), p))) + } + + // let's have 3 successful responses + for reqIdExpected in 0..<3 { + switch try self.channel.readOutbound(as: ValueWithRequestID.self) { + case .some(let req): + guard case .byteBuffer(var buffer) = req.value else { + XCTFail("wrong type") + return + } + if let reqId = buffer.readString(length: buffer.readableBytes).flatMap(Int.init) { + // write response + try self.channel.writeInbound(ValueWithRequestID(requestID: reqId, value: reqId)) + } else { + XCTFail("couldn't get request id") + } + default: + XCTFail("could not find request") + } + XCTAssertNoThrow(XCTAssertEqual(ValueWithRequestID(requestID: reqIdExpected, value: reqIdExpected), + try futures[reqIdExpected].wait())) + } + + // validate the Channel is active + XCTAssertTrue(self.channel.isActive) + self.channel.pipeline.fireErrorCaught(DummyError()) + + // after receiving an error, it should be closed + XCTAssertFalse(self.channel.isActive) + + for failedReqId in 3..<5 { + XCTAssertThrowsError(try futures[failedReqId].wait()) { error in + XCTAssertNotNil(error as? DummyError) + } + } + + // verify no response was not forwarded + XCTAssertNoThrow(XCTAssertEqual(nil, try self.channel.readInbound(as: IOData.self))) + } + + func testRequestsEnqueuedAfterErrorAreFailed() { + struct DummyError: Error {} + XCTAssertNoThrow(try self.channel.pipeline.addHandler(NIORequestResponseWithIDHandler, ValueWithRequestID>()).wait()) + + self.channel.pipeline.fireErrorCaught(DummyError()) + + let p: EventLoopPromise> = self.eventLoop.makePromise() + XCTAssertThrowsError(try self.channel.writeOutbound((ValueWithRequestID(requestID: 1, + value: IOData.byteBuffer(self.buffer)), p))) { error in + XCTAssertNotNil(error as? DummyError) + } + XCTAssertThrowsError(try p.futureResult.wait()) { error in + XCTAssertNotNil(error as? DummyError) + } + } + + func testRequestsEnqueuedJustBeforeErrorAreFailed() { + struct DummyError1: Error {} + struct DummyError2: Error {} + + XCTAssertNoThrow(try self.channel.pipeline.addHandler(NIORequestResponseWithIDHandler, ValueWithRequestID>()).wait()) + + let p: EventLoopPromise> = self.eventLoop.makePromise() + // right now, everything's still okay so the enqueued request won't immediately be failed + XCTAssertNoThrow(try self.channel.writeOutbound((ValueWithRequestID(requestID: 1, value: IOData.byteBuffer(self.buffer)), + p))) + + // but whilst we're waiting for the response, an error turns up + self.channel.pipeline.fireErrorCaught(DummyError1()) + + // we'll also fire a second error through the pipeline that shouldn't do anything + self.channel.pipeline.fireErrorCaught(DummyError2()) + + + // and just after the error, the response arrives too (but too late) + XCTAssertNoThrow(try self.channel.writeInbound(())) + + XCTAssertThrowsError(try p.futureResult.wait()) { error in + XCTAssertNotNil(error as? DummyError1) + } + } + + func testClosedConnectionFailsOutstandingPromises() { + XCTAssertNoThrow(try self.channel.pipeline.addHandler(NIORequestResponseWithIDHandler, ValueWithRequestID>()).wait()) + + let promise = self.eventLoop.makePromise(of: ValueWithRequestID.self) + XCTAssertNoThrow(try self.channel.writeOutbound((ValueWithRequestID(requestID: 1, value: "Hello!"), promise))) + + XCTAssertNoThrow(try self.channel.close().wait()) + XCTAssertThrowsError(try promise.futureResult.wait()) { error in + XCTAssertTrue(error is NIOExtrasErrors.ClosedBeforeReceivingResponse) + } + } + + func testOutOfOrderResponsesWork() { + XCTAssertNoThrow(try self.channel.pipeline.addHandler(NIORequestResponseWithIDHandler, ValueWithRequestID>()).wait()) + self.buffer.writeString("hello") + + // pretend to connect to the EmbeddedChannel knows it's supposed to be active + XCTAssertNoThrow(try self.channel.connect(to: .init(ipAddress: "1.2.3.4", port: 5)).wait()) + + let p1: EventLoopPromise> = self.channel.eventLoop.makePromise() + let p2: EventLoopPromise> = self.channel.eventLoop.makePromise() + + // write requests + XCTAssertNoThrow(try self.channel.writeOutbound((ValueWithRequestID(requestID: 1, value: "1"), p1))) + XCTAssertNoThrow(try self.channel.writeOutbound((ValueWithRequestID(requestID: 2, value: "2"), p2))) + // write responses but out of order + XCTAssertNoThrow(try self.channel.writeInbound(ValueWithRequestID(requestID: 2, value: "okay 2"))) + XCTAssertNoThrow(try self.channel.writeInbound(ValueWithRequestID(requestID: 1, value: "okay 1"))) + // verify requests was forwarded + XCTAssertEqual(ValueWithRequestID(requestID: 1, value: "1"), try self.channel.readOutbound()) + XCTAssertEqual(ValueWithRequestID(requestID: 2, value: "2"), try self.channel.readOutbound()) + // verify responses were not forwarded + XCTAssertEqual(nil, try self.channel.readInbound(as: ValueWithRequestID.self)) + // verify the promises got succeeded with the response + XCTAssertEqual(ValueWithRequestID(requestID: 1, value: "okay 1"), try p1.futureResult.wait()) + XCTAssertEqual(ValueWithRequestID(requestID: 2, value: "okay 2"), try p2.futureResult.wait()) + } + + func testErrorOnResponseForNonExistantRequest() { + XCTAssertNoThrow(try self.channel.pipeline.addHandler(NIORequestResponseWithIDHandler, ValueWithRequestID>()).wait()) + self.buffer.writeString("hello") + + // pretend to connect to the EmbeddedChannel knows it's supposed to be active + XCTAssertNoThrow(try self.channel.connect(to: .init(ipAddress: "1.2.3.4", port: 5)).wait()) + + let p1: EventLoopPromise> = self.channel.eventLoop.makePromise() + + // write request + XCTAssertNoThrow(try self.channel.writeOutbound((ValueWithRequestID(requestID: 1, value: "1"), p1))) + // write wrong response + XCTAssertThrowsError(try self.channel.writeInbound(ValueWithRequestID(requestID: 2, value: "okay 2"))) { error in + guard let error = error as? NIOExtrasErrors.ResponseForInvalidRequest> else { + XCTFail("wrong error") + return + } + XCTAssertEqual(2, error.requestID) + + } + // verify requests was forwarded + XCTAssertEqual(ValueWithRequestID(requestID: 1, value: "1"), try self.channel.readOutbound()) + // verify responses were not forwarded + XCTAssertEqual(nil, try self.channel.readInbound(as: ValueWithRequestID.self)) + // verify the promises got succeeded with the response + } + + func testMoreRequestsAfterChannelInactiveFail() { + final class EmitRequestOnInactiveHandler: ChannelDuplexHandler { + typealias InboundIn = Never + typealias OutboundIn = (ValueWithRequestID, EventLoopPromise>) + typealias OutboundOut = (ValueWithRequestID, EventLoopPromise>) + + func channelInactive(context: ChannelHandlerContext) { + let responsePromise = context.eventLoop.makePromise(of: ValueWithRequestID.self) + let writePromise = context.eventLoop.makePromise(of: Void.self) + context.writeAndFlush(self.wrapOutboundOut( + (ValueWithRequestID(requestID: 1, value: IOData.byteBuffer(ByteBuffer(string: "hi"))), responsePromise) + ), + promise: writePromise) + var writePromiseCompleted = false + defer { + XCTAssertTrue(writePromiseCompleted) + } + var responsePromiseCompleted = false + defer { + XCTAssertTrue(responsePromiseCompleted) + } + writePromise.futureResult.whenComplete { result in + writePromiseCompleted = true + switch result { + case .success: + XCTFail("shouldn't succeed") + case .failure(let error): + XCTAssertEqual(.ioOnClosedChannel, error as? ChannelError) + } + } + responsePromise.futureResult.whenComplete { result in + responsePromiseCompleted = true + switch result { + case .success: + XCTFail("shouldn't succeed") + case .failure(let error): + XCTAssertEqual(.ioOnClosedChannel, error as? ChannelError) + } + } + } + } + + XCTAssertNoThrow(try self.channel.pipeline.addHandlers( + NIORequestResponseWithIDHandler, ValueWithRequestID>(), + EmitRequestOnInactiveHandler() + ).wait()) + self.buffer.writeString("hello") + + // pretend to connect to the EmbeddedChannel knows it's supposed to be active + XCTAssertNoThrow(try self.channel.connect(to: .init(ipAddress: "1.2.3.4", port: 5)).wait()) + + let p: EventLoopPromise> = self.channel.eventLoop.makePromise() + // write request + XCTAssertNoThrow(try self.channel.writeOutbound((ValueWithRequestID(requestID: 1, value: IOData.byteBuffer(self.buffer)), + p))) + // write response + XCTAssertNoThrow(try self.channel.writeInbound(ValueWithRequestID(requestID: 1, value: "okay"))) + + // verify request was forwarded + XCTAssertEqual(ValueWithRequestID(requestID: 1, value: IOData.byteBuffer(self.buffer)), + try self.channel.readOutbound()) + + // verify the promise got succeeded with the response + XCTAssertEqual(ValueWithRequestID(requestID: 1, value: "okay"), try p.futureResult.wait()) + } +} + +struct ValueWithRequestID: NIORequestIdentifiable { + typealias RequestID = Int + + var requestID: Int + var value: T +} + +extension ValueWithRequestID: Equatable where T: Equatable { +} diff --git a/Tests/NIONFS3Tests/NFS3FileSystemTests+XCTest.swift b/Tests/NIONFS3Tests/NFS3FileSystemTests+XCTest.swift index a45d480..78cacfe 100644 --- a/Tests/NIONFS3Tests/NFS3FileSystemTests+XCTest.swift +++ b/Tests/NIONFS3Tests/NFS3FileSystemTests+XCTest.swift @@ -2,7 +2,7 @@ // // This source file is part of the SwiftNIO open source project // -// Copyright (c) 2018-2022 Apple Inc. and the SwiftNIO project authors +// Copyright (c) 2018-2023 Apple Inc. and the SwiftNIO project authors // Licensed under Apache License v2.0 // // See LICENSE.txt for license information diff --git a/Tests/NIONFS3Tests/NFS3ReplyEncoderTest+XCTest.swift b/Tests/NIONFS3Tests/NFS3ReplyEncoderTest+XCTest.swift index f2c5309..96148cb 100644 --- a/Tests/NIONFS3Tests/NFS3ReplyEncoderTest+XCTest.swift +++ b/Tests/NIONFS3Tests/NFS3ReplyEncoderTest+XCTest.swift @@ -2,7 +2,7 @@ // // This source file is part of the SwiftNIO open source project // -// Copyright (c) 2018-2022 Apple Inc. and the SwiftNIO project authors +// Copyright (c) 2018-2023 Apple Inc. and the SwiftNIO project authors // Licensed under Apache License v2.0 // // See LICENSE.txt for license information diff --git a/Tests/NIONFS3Tests/NFS3RoundtripTests+XCTest.swift b/Tests/NIONFS3Tests/NFS3RoundtripTests+XCTest.swift index 8dff52c..c23d70b 100644 --- a/Tests/NIONFS3Tests/NFS3RoundtripTests+XCTest.swift +++ b/Tests/NIONFS3Tests/NFS3RoundtripTests+XCTest.swift @@ -2,7 +2,7 @@ // // This source file is part of the SwiftNIO open source project // -// Copyright (c) 2018-2022 Apple Inc. and the SwiftNIO project authors +// Copyright (c) 2018-2023 Apple Inc. and the SwiftNIO project authors // Licensed under Apache License v2.0 // // See LICENSE.txt for license information