diff --git a/Sources/NIOExtras/LengthFieldBasedFrameDecoder.swift b/Sources/NIOExtras/LengthFieldBasedFrameDecoder.swift index feb5b03..4117015 100644 --- a/Sources/NIOExtras/LengthFieldBasedFrameDecoder.swift +++ b/Sources/NIOExtras/LengthFieldBasedFrameDecoder.swift @@ -14,6 +14,38 @@ import NIO +extension ByteBuffer { + @inlinable + mutating func get24UInt( + at index: Int, + endianness: Endianness = .big + ) -> UInt32? { + let mostSignificant: UInt16 + let leastSignificant: UInt8 + switch endianness { + case .big: + guard let uint16 = self.getInteger(at: index, endianness: .big, as: UInt16.self), + let uint8 = self.getInteger(at: index + 2, endianness: .big, as: UInt8.self) else { return nil } + mostSignificant = uint16 + leastSignificant = uint8 + case .little: + guard let uint8 = self.getInteger(at: index, endianness: .little, as: UInt8.self), + let uint16 = self.getInteger(at: index + 1, endianness: .little, as: UInt16.self) else { return nil } + mostSignificant = uint16 + leastSignificant = uint8 + } + return (UInt32(mostSignificant) << 8) &+ UInt32(leastSignificant) + } + @inlinable + mutating func read24UInt( + endianness: Endianness = .big + ) -> UInt32? { + guard let integer = get24UInt(at: self.readerIndex, endianness: endianness) else { return nil } + self.moveReaderIndex(forwardBy: 3) + return integer + } +} + /// /// A decoder that splits the received `ByteBuffer` by the number of bytes specified in a fixed length header /// contained within the buffer. @@ -35,7 +67,6 @@ import NIO public final class LengthFieldBasedFrameDecoder: ByteToMessageDecoder { /// /// An enumeration to describe the length of a piece of data in bytes. - /// It is contained to lengths that can be converted to integer types. /// public enum ByteLength { case one @@ -43,16 +74,12 @@ public final class LengthFieldBasedFrameDecoder: ByteToMessageDecoder { case four case eight - var length: Int { + fileprivate var bitLength: NIOLengthFieldBitLength { switch self { - case .one: - return 1 - case .two: - return 2 - case .four: - return 4 - case .eight: - return 8 + case .one: return .oneByte + case .two: return .twoBytes + case .four: return .fourBytes + case .eight: return .eightBytes } } } @@ -73,7 +100,7 @@ public final class LengthFieldBasedFrameDecoder: ByteToMessageDecoder { public var cumulationBuffer: ByteBuffer? private var readState: DecoderReadState = .waitingForHeader - private let lengthFieldLength: ByteLength + private let lengthFieldLength: NIOLengthFieldBitLength private let lengthFieldEndianness: Endianness /// Create `LengthFieldBasedFrameDecoder` with a given frame length. @@ -82,13 +109,23 @@ public final class LengthFieldBasedFrameDecoder: ByteToMessageDecoder { /// - lengthFieldLength: The length of the field specifying the remaining length of the frame. /// - lengthFieldEndianness: The endianness of the field specifying the remaining length of the frame. /// - public init(lengthFieldLength: ByteLength, lengthFieldEndianness: Endianness = .big) { + public convenience init(lengthFieldLength: ByteLength, lengthFieldEndianness: Endianness = .big) { + self.init(lengthFieldBitLength: lengthFieldLength.bitLength, lengthFieldEndianness: lengthFieldEndianness) + } + + /// Create `LengthFieldBasedFrameDecoder` with a given frame length. + /// + /// - parameters: + /// - lengthFieldBitLength: The length of the field specifying the remaining length of the frame. + /// - lengthFieldEndianness: The endianness of the field specifying the remaining length of the frame. + /// + public init(lengthFieldBitLength: NIOLengthFieldBitLength, lengthFieldEndianness: Endianness = .big) { // The value contained in the length field must be able to be represented by an integer type on the platform. // ie. .eight == 64bit which would not fit into the Int type on a 32bit platform. - precondition(lengthFieldLength.length <= Int.bitWidth/8) + precondition(lengthFieldBitLength.length <= Int.bitWidth/8) - self.lengthFieldLength = lengthFieldLength + self.lengthFieldLength = lengthFieldBitLength self.lengthFieldEndianness = lengthFieldEndianness } @@ -156,21 +193,23 @@ public final class LengthFieldBasedFrameDecoder: ByteToMessageDecoder { /// /// Decodes the specified region of the buffer into an unadjusted frame length. The default implementation is - /// capable of decoding the specified region into an unsigned 8/16/32/64 bit integer. + /// capable of decoding the specified region into an unsigned 8/16/24/32/64 bit integer. /// /// - parameters: /// - buffer: The buffer containing the integer frame length. /// private func readFrameLength(for buffer: inout ByteBuffer) -> Int? { - switch self.lengthFieldLength { - case .one: + switch self.lengthFieldLength.bitLength { + case .bits8: return buffer.readInteger(endianness: self.lengthFieldEndianness, as: UInt8.self).map { Int($0) } - case .two: + case .bits16: return buffer.readInteger(endianness: self.lengthFieldEndianness, as: UInt16.self).map { Int($0) } - case .four: + case .bits24: + return buffer.read24UInt(endianness: self.lengthFieldEndianness).map { Int($0) } + case .bits32: return buffer.readInteger(endianness: self.lengthFieldEndianness, as: UInt32.self).map { Int($0) } - case .eight: + case .bits64: return buffer.readInteger(endianness: self.lengthFieldEndianness, as: UInt64.self).map { Int($0) } } } diff --git a/Sources/NIOExtras/LengthFieldPrepender.swift b/Sources/NIOExtras/LengthFieldPrepender.swift index 106f214..98b347c 100644 --- a/Sources/NIOExtras/LengthFieldPrepender.swift +++ b/Sources/NIOExtras/LengthFieldPrepender.swift @@ -14,6 +14,26 @@ import NIO +extension ByteBuffer { + @discardableResult + @inlinable + mutating func write24UInt( + _ integer: UInt32, + endianness: Endianness = .big + ) -> Int { + precondition(integer & 0xFF_FF_FF == integer, "integer value does not fit into 24 bit integer") + switch endianness { + case .little: + return writeInteger(UInt8(integer & 0xFF), endianness: .little) + + writeInteger(UInt16((integer >> 8) & 0xFF_FF), endianness: .little) + case .big: + return writeInteger(UInt16((integer >> 8) & 0xFF_FF), endianness: .big) + + writeInteger(UInt8(integer & 0xFF), endianness: .big) + } + } +} + + public enum LengthFieldPrependerError: Error { case messageDataTooLongForLengthField } @@ -31,42 +51,21 @@ public enum LengthFieldPrependerError: Error { /// This initial prepended byte is called the 'length field'. /// public final class LengthFieldPrepender: ChannelOutboundHandler { - /// /// An enumeration to describe the length of a piece of data in bytes. - /// It is constrained to lengths that can be converted to integer types. /// public enum ByteLength { case one case two case four case eight - - fileprivate var length: Int { - - switch self { - case .one: - return 1 - case .two: - return 2 - case .four: - return 4 - case .eight: - return 8 - } - } - fileprivate var max: UInt { - + fileprivate var bitLength: NIOLengthFieldBitLength { switch self { - case .one: - return UInt(UInt8.max) - case .two: - return UInt(UInt16.max) - case .four: - return UInt(UInt32.max) - case .eight: - return UInt(UInt64.max) + case .one: return .oneByte + case .two: return .twoBytes + case .four: return .fourBytes + case .eight: return .eightBytes } } } @@ -74,7 +73,7 @@ public final class LengthFieldPrepender: ChannelOutboundHandler { public typealias OutboundIn = ByteBuffer public typealias OutboundOut = ByteBuffer - private let lengthFieldLength: LengthFieldPrepender.ByteLength + private let lengthFieldLength: NIOLengthFieldBitLength private let lengthFieldEndianness: Endianness private var lengthBuffer: ByteBuffer? @@ -85,13 +84,15 @@ public final class LengthFieldPrepender: ChannelOutboundHandler { /// - lengthFieldLength: The length of the field specifying the remaining length of the frame. /// - lengthFieldEndianness: The endianness of the field specifying the remaining length of the frame. /// - public init(lengthFieldLength: ByteLength, lengthFieldEndianness: Endianness = .big) { - + public convenience init(lengthFieldLength: ByteLength, lengthFieldEndianness: Endianness = .big) { + self.init(lengthFieldBitLength: lengthFieldLength.bitLength, lengthFieldEndianness: lengthFieldEndianness) + } + public init(lengthFieldBitLength: NIOLengthFieldBitLength, lengthFieldEndianness: Endianness = .big) { // The value contained in the length field must be able to be represented by an integer type on the platform. // ie. .eight == 64bit which would not fit into the Int type on a 32bit platform. - precondition(lengthFieldLength.length <= Int.bitWidth/8) + precondition(lengthFieldBitLength.length <= Int.bitWidth/8) - self.lengthFieldLength = lengthFieldLength + self.lengthFieldLength = lengthFieldBitLength self.lengthFieldEndianness = lengthFieldEndianness } @@ -115,14 +116,16 @@ public final class LengthFieldPrepender: ChannelOutboundHandler { self.lengthBuffer = dataLengthBuffer } - switch self.lengthFieldLength { - case .one: + switch self.lengthFieldLength.bitLength { + case .bits8: dataLengthBuffer.writeInteger(UInt8(dataLength), endianness: self.lengthFieldEndianness) - case .two: + case .bits16: dataLengthBuffer.writeInteger(UInt16(dataLength), endianness: self.lengthFieldEndianness) - case .four: + case .bits24: + dataLengthBuffer.write24UInt(UInt32(dataLength), endianness: self.lengthFieldEndianness) + case .bits32: dataLengthBuffer.writeInteger(UInt32(dataLength), endianness: self.lengthFieldEndianness) - case .eight: + case .bits64: dataLengthBuffer.writeInteger(UInt64(dataLength), endianness: self.lengthFieldEndianness) } diff --git a/Sources/NIOExtras/NIOLengthFieldBitLength.swift b/Sources/NIOExtras/NIOLengthFieldBitLength.swift new file mode 100644 index 0000000..38978b7 --- /dev/null +++ b/Sources/NIOExtras/NIOLengthFieldBitLength.swift @@ -0,0 +1,67 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2021 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +/// An struct to describe the length of a piece of data in bits +public struct NIOLengthFieldBitLength { + internal enum Backing { + case bits8 + case bits16 + case bits24 + case bits32 + case bits64 + } + internal let bitLength: Backing + + public static let oneByte = NIOLengthFieldBitLength(bitLength: .bits8) + public static let twoBytes = NIOLengthFieldBitLength(bitLength: .bits16) + public static let threeBytes = NIOLengthFieldBitLength(bitLength: .bits24) + public static let fourBytes = NIOLengthFieldBitLength(bitLength: .bits32) + public static let eightBytes = NIOLengthFieldBitLength(bitLength: .bits64) + + public static let eightBits = NIOLengthFieldBitLength(bitLength: .bits8) + public static let sixteenBits = NIOLengthFieldBitLength(bitLength: .bits16) + public static let twentyFourBits = NIOLengthFieldBitLength(bitLength: .bits24) + public static let thirtyTwoBits = NIOLengthFieldBitLength(bitLength: .bits32) + public static let sixtyFourBits = NIOLengthFieldBitLength(bitLength: .bits64) + + internal var length: Int { + switch bitLength { + case .bits8: + return 1 + case .bits16: + return 2 + case .bits24: + return 3 + case .bits32: + return 4 + case .bits64: + return 8 + } + } + + internal var max: UInt { + switch bitLength { + case .bits8: + return UInt(UInt8.max) + case .bits16: + return UInt(UInt16.max) + case .bits24: + return (UInt(UInt16.max) << 8) &+ UInt(UInt8.max) + case .bits32: + return UInt(UInt32.max) + case .bits64: + return UInt(UInt64.max) + } + } +} diff --git a/Tests/NIOExtrasTests/LengthFieldBasedFrameDecoderTest+XCTest.swift b/Tests/NIOExtrasTests/LengthFieldBasedFrameDecoderTest+XCTest.swift index 35d1c34..dd48cef 100644 --- a/Tests/NIOExtrasTests/LengthFieldBasedFrameDecoderTest+XCTest.swift +++ b/Tests/NIOExtrasTests/LengthFieldBasedFrameDecoderTest+XCTest.swift @@ -26,8 +26,11 @@ extension LengthFieldBasedFrameDecoderTest { static var allTests : [(String, (LengthFieldBasedFrameDecoderTest) -> () throws -> Void)] { return [ + ("testReadUInt32From3Bytes", testReadUInt32From3Bytes), + ("testReadAndWriteUInt32From3BytesBasicVerification", testReadAndWriteUInt32From3BytesBasicVerification), ("testDecodeWithUInt8HeaderWithData", testDecodeWithUInt8HeaderWithData), ("testDecodeWithUInt16HeaderWithString", testDecodeWithUInt16HeaderWithString), + ("testDecodeWithUInt24HeaderWithString", testDecodeWithUInt24HeaderWithString), ("testDecodeWithUInt32HeaderWithString", testDecodeWithUInt32HeaderWithString), ("testDecodeWithUInt64HeaderWithString", testDecodeWithUInt64HeaderWithString), ("testDecodeWithInt64HeaderWithString", testDecodeWithInt64HeaderWithString), diff --git a/Tests/NIOExtrasTests/LengthFieldBasedFrameDecoderTest.swift b/Tests/NIOExtrasTests/LengthFieldBasedFrameDecoderTest.swift index a2f2093..7cfb0fb 100644 --- a/Tests/NIOExtrasTests/LengthFieldBasedFrameDecoderTest.swift +++ b/Tests/NIOExtrasTests/LengthFieldBasedFrameDecoderTest.swift @@ -27,7 +27,39 @@ class LengthFieldBasedFrameDecoderTest: XCTestCase { override func setUp() { self.channel = EmbeddedChannel() } - + func testReadUInt32From3Bytes() { + var buffer = ByteBuffer(bytes: [ + 0, 0, 5, + 5, 0, 0, + ]) + XCTAssertEqual(buffer.read24UInt(endianness: .big), 5) + print(buffer.readableBytesView) + XCTAssertEqual(buffer.read24UInt(endianness: .little), 5) + } + func testReadAndWriteUInt32From3BytesBasicVerification() { + let inputs: [UInt32] = [ + 0, + 1, + 5, + UInt32(UInt8.max), + UInt32(UInt16.max), + UInt32(UInt16.max) << 8 &+ UInt32(UInt8.max), + UInt32(UInt8.max) - 1, + UInt32(UInt16.max) - 1, + UInt32(UInt16.max) << 8 &+ UInt32(UInt8.max) - 1, + UInt32(UInt8.max) + 1, + UInt32(UInt16.max) + 1, + ] + + for input in inputs { + var buffer = ByteBuffer() + buffer.write24UInt(input, endianness: .big) + XCTAssertEqual(buffer.read24UInt(endianness: .big), input) + + buffer.write24UInt(input, endianness: .little) + XCTAssertEqual(buffer.read24UInt(endianness: .little), input) + } + } func testDecodeWithUInt8HeaderWithData() throws { self.decoderUnderTest = .init(LengthFieldBasedFrameDecoder(lengthFieldLength: .one, @@ -71,6 +103,25 @@ class LengthFieldBasedFrameDecoderTest: XCTestCase { XCTAssertTrue(try self.channel.finish().isClean) } + func testDecodeWithUInt24HeaderWithString() throws { + + self.decoderUnderTest = .init(LengthFieldBasedFrameDecoder(lengthFieldBitLength: .threeBytes, + lengthFieldEndianness: .big)) + XCTAssertNoThrow(try self.channel.pipeline.addHandler(self.decoderUnderTest).wait()) + + var buffer = self.channel.allocator.buffer(capacity: 8) // 3 byte header + 5 character string + buffer.writeBytes([0, 0, 5]) + buffer.writeString(standardDataString) + + XCTAssertTrue(try self.channel.writeInbound(buffer).isFull) + + XCTAssertNoThrow(XCTAssertEqual(standardDataString, + try (self.channel.readInbound(as: ByteBuffer.self)?.readableBytesView).map { + String(decoding: $0, as: Unicode.UTF8.self) + })) + XCTAssertTrue(try self.channel.finish().isClean) + } + func testDecodeWithUInt32HeaderWithString() throws { self.decoderUnderTest = .init(LengthFieldBasedFrameDecoder(lengthFieldLength: .four, @@ -404,29 +455,36 @@ class LengthFieldBasedFrameDecoderTest: XCTestCase { } func testBasicVerification() { - let inputs: [(LengthFieldBasedFrameDecoder.ByteLength, [(Int, String)])] = [ - (.one, [ + let inputs: [(NIOLengthFieldBitLength, [(Int, String)])] = [ + (.oneByte, [ (6, "abcdef"), (0, ""), (9, "123456789"), (Int(UInt8.max), String(decoding: Array(repeating: UInt8(ascii: "X"), count: Int(UInt8.max)), as: Unicode.UTF8.self)), ]), - (.two, [ + (.twoBytes, [ (1, "a"), (0, ""), (9, "123456789"), (307, String(decoding: Array(repeating: UInt8(ascii: "X"), count: 307), as: Unicode.UTF8.self)), ]), - (.four, [ + (.threeBytes, [ + (1, "a"), + (0, ""), + (9, "123456789"), + (307, + String(decoding: Array(repeating: UInt8(ascii: "X"), count: 307), as: Unicode.UTF8.self)), + ]), + (.fourBytes, [ (1, "a"), (0, ""), (3, "333"), (307, String(decoding: Array(repeating: UInt8(ascii: "X"), count: 307), as: Unicode.UTF8.self)), ]), - (.eight, [ + (.eightBytes, [ (1, "a"), (0, ""), (4, "aaaa"), @@ -451,7 +509,7 @@ class LengthFieldBasedFrameDecoderTest: XCTestCase { return (bytes, [bytes.getSlice(at: bytes.readerIndex + lenBytes.length, length: input.0)!]) } XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder(inputOutputPairs: inputOutputPairs) { - LengthFieldBasedFrameDecoder(lengthFieldLength: lenBytes) + LengthFieldBasedFrameDecoder(lengthFieldBitLength: lenBytes) }) } } diff --git a/Tests/NIOExtrasTests/LengthFieldPrependerTest+XCTest.swift b/Tests/NIOExtrasTests/LengthFieldPrependerTest+XCTest.swift index 010e1e9..a168cbf 100644 --- a/Tests/NIOExtrasTests/LengthFieldPrependerTest+XCTest.swift +++ b/Tests/NIOExtrasTests/LengthFieldPrependerTest+XCTest.swift @@ -26,8 +26,10 @@ extension LengthFieldPrependerTest { static var allTests : [(String, (LengthFieldPrependerTest) -> () throws -> Void)] { return [ + ("testWrite3BytesOfUInt32Write", testWrite3BytesOfUInt32Write), ("testEncodeWithUInt8HeaderWithData", testEncodeWithUInt8HeaderWithData), ("testEncodeWithUInt16HeaderWithString", testEncodeWithUInt16HeaderWithString), + ("testEncodeWithUInt24HeaderWithString", testEncodeWithUInt24HeaderWithString), ("testEncodeWithUInt32HeaderWithString", testEncodeWithUInt32HeaderWithString), ("testEncodeWithUInt64HeaderWithString", testEncodeWithUInt64HeaderWithString), ("testEncodeWithInt64HeaderWithString", testEncodeWithInt64HeaderWithString), diff --git a/Tests/NIOExtrasTests/LengthFieldPrependerTest.swift b/Tests/NIOExtrasTests/LengthFieldPrependerTest.swift index 7f1b789..630d0bc 100644 --- a/Tests/NIOExtrasTests/LengthFieldPrependerTest.swift +++ b/Tests/NIOExtrasTests/LengthFieldPrependerTest.swift @@ -14,7 +14,7 @@ import XCTest import NIO -import NIOExtras +@testable import NIOExtras private let standardDataString = "abcde" private let standardDataStringCount = standardDataString.utf8.count @@ -26,7 +26,16 @@ class LengthFieldPrependerTest: XCTestCase { override func setUp() { self.channel = EmbeddedChannel() } - + func testWrite3BytesOfUInt32Write() { + var buffer = ByteBuffer() + buffer.write24UInt(5, endianness: .little) + XCTAssertEqual(Array(buffer.readableBytesView), [5, 0, 0]) + XCTAssertEqual(buffer.read24UInt(endianness: .little), 5) + + buffer.write24UInt(5, endianness: .big) + XCTAssertEqual(Array(buffer.readableBytesView), [0, 0, 5]) + XCTAssertEqual(buffer.read24UInt(endianness: .big), 5) + } func testEncodeWithUInt8HeaderWithData() throws { self.encoderUnderTest = LengthFieldPrepender(lengthFieldLength: .one, @@ -105,6 +114,48 @@ class LengthFieldPrependerTest: XCTestCase { XCTAssertTrue(try self.channel.finish().isClean) } + func testEncodeWithUInt24HeaderWithString() throws { + + let endianness: Endianness = .little + + self.encoderUnderTest = LengthFieldPrepender(lengthFieldBitLength: .threeBytes, + lengthFieldEndianness: endianness) + + XCTAssertNoThrow(try self.channel.pipeline.addHandler(self.encoderUnderTest).wait()) + + var buffer = self.channel.allocator.buffer(capacity: standardDataStringCount) + buffer.writeString(standardDataString) + + XCTAssertNoThrow(try self.channel.writeAndFlush(buffer).wait()) + + if case .some(.byteBuffer(var outputBuffer)) = try self.channel.readOutbound(as: IOData.self) { + + let sizeInHeader = outputBuffer.read24UInt(endianness: endianness).map({ Int($0) }) + XCTAssertEqual(standardDataStringCount, sizeInHeader) + + let additionalData = outputBuffer.readBytes(length: 1) + XCTAssertNil(additionalData) + + } else { + XCTFail("couldn't read ByteBuffer from channel") + } + + if case .some(.byteBuffer(var outputBuffer)) = try self.channel.readOutbound(as: IOData.self) { + + let bodyString = outputBuffer.readString(length: standardDataStringCount) + XCTAssertEqual(standardDataString, bodyString) + + let additionalData = outputBuffer.readBytes(length: 1) + XCTAssertNil(additionalData) + + } else { + XCTFail("couldn't read ByteBuffer from channel") + } + + XCTAssertNoThrow(XCTAssertNil(try self.channel.readOutbound())) + XCTAssertTrue(try self.channel.finish().isClean) + } + func testEncodeWithUInt32HeaderWithString() throws { let endianness: Endianness = .little diff --git a/scripts/soundness.sh b/scripts/soundness.sh index 7180486..64eb373 100755 --- a/scripts/soundness.sh +++ b/scripts/soundness.sh @@ -18,7 +18,7 @@ here="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" function replace_acceptable_years() { # this needs to replace all acceptable forms with 'YEARS' - sed -e 's/2017-201[89]/YEARS/' -e 's/2019/YEARS/' -e 's/2020/YEARS/g' + sed -e 's/2017-201[89]/YEARS/' -e 's/2019/YEARS/' -e 's/2020/YEARS/g' -e 's/2021/YEARS/g' } printf "=> Checking linux tests... "