add support for a 24 bit (3 byte) length field (#114)

Motivation:

The RSocket protocol uses a 24 bit length field

Modifications:

- add two new methods readInteger and writeInteger on ByteBuffer that support reading and writing integers of any size.
- add a new case (.three) to ByteLength

Result:

LengthFieldBasedFrameDecoder & LengthFieldPrepender do now support a 24 bit length field

Co-authored-by: Johannes Weiss <johannesweiss@apple.com>
This commit is contained in:
David Nadoba 2021-02-17 10:04:24 +01:00 committed by GitHub
parent f9a828d8b3
commit 3d14afbe3f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 289 additions and 66 deletions

View File

@ -14,6 +14,38 @@
import NIO
extension ByteBuffer {
@inlinable
mutating func get24UInt(
at index: Int,
endianness: Endianness = .big
) -> UInt32? {
let mostSignificant: UInt16
let leastSignificant: UInt8
switch endianness {
case .big:
guard let uint16 = self.getInteger(at: index, endianness: .big, as: UInt16.self),
let uint8 = self.getInteger(at: index + 2, endianness: .big, as: UInt8.self) else { return nil }
mostSignificant = uint16
leastSignificant = uint8
case .little:
guard let uint8 = self.getInteger(at: index, endianness: .little, as: UInt8.self),
let uint16 = self.getInteger(at: index + 1, endianness: .little, as: UInt16.self) else { return nil }
mostSignificant = uint16
leastSignificant = uint8
}
return (UInt32(mostSignificant) << 8) &+ UInt32(leastSignificant)
}
@inlinable
mutating func read24UInt(
endianness: Endianness = .big
) -> UInt32? {
guard let integer = get24UInt(at: self.readerIndex, endianness: endianness) else { return nil }
self.moveReaderIndex(forwardBy: 3)
return integer
}
}
///
/// A decoder that splits the received `ByteBuffer` by the number of bytes specified in a fixed length header
/// contained within the buffer.
@ -35,7 +67,6 @@ import NIO
public final class LengthFieldBasedFrameDecoder: ByteToMessageDecoder {
///
/// An enumeration to describe the length of a piece of data in bytes.
/// It is contained to lengths that can be converted to integer types.
///
public enum ByteLength {
case one
@ -43,16 +74,12 @@ public final class LengthFieldBasedFrameDecoder: ByteToMessageDecoder {
case four
case eight
var length: Int {
fileprivate var bitLength: NIOLengthFieldBitLength {
switch self {
case .one:
return 1
case .two:
return 2
case .four:
return 4
case .eight:
return 8
case .one: return .oneByte
case .two: return .twoBytes
case .four: return .fourBytes
case .eight: return .eightBytes
}
}
}
@ -73,7 +100,7 @@ public final class LengthFieldBasedFrameDecoder: ByteToMessageDecoder {
public var cumulationBuffer: ByteBuffer?
private var readState: DecoderReadState = .waitingForHeader
private let lengthFieldLength: ByteLength
private let lengthFieldLength: NIOLengthFieldBitLength
private let lengthFieldEndianness: Endianness
/// Create `LengthFieldBasedFrameDecoder` with a given frame length.
@ -82,13 +109,23 @@ public final class LengthFieldBasedFrameDecoder: ByteToMessageDecoder {
/// - lengthFieldLength: The length of the field specifying the remaining length of the frame.
/// - lengthFieldEndianness: The endianness of the field specifying the remaining length of the frame.
///
public init(lengthFieldLength: ByteLength, lengthFieldEndianness: Endianness = .big) {
public convenience init(lengthFieldLength: ByteLength, lengthFieldEndianness: Endianness = .big) {
self.init(lengthFieldBitLength: lengthFieldLength.bitLength, lengthFieldEndianness: lengthFieldEndianness)
}
/// Create `LengthFieldBasedFrameDecoder` with a given frame length.
///
/// - parameters:
/// - lengthFieldBitLength: The length of the field specifying the remaining length of the frame.
/// - lengthFieldEndianness: The endianness of the field specifying the remaining length of the frame.
///
public init(lengthFieldBitLength: NIOLengthFieldBitLength, lengthFieldEndianness: Endianness = .big) {
// The value contained in the length field must be able to be represented by an integer type on the platform.
// ie. .eight == 64bit which would not fit into the Int type on a 32bit platform.
precondition(lengthFieldLength.length <= Int.bitWidth/8)
precondition(lengthFieldBitLength.length <= Int.bitWidth/8)
self.lengthFieldLength = lengthFieldLength
self.lengthFieldLength = lengthFieldBitLength
self.lengthFieldEndianness = lengthFieldEndianness
}
@ -156,21 +193,23 @@ public final class LengthFieldBasedFrameDecoder: ByteToMessageDecoder {
///
/// Decodes the specified region of the buffer into an unadjusted frame length. The default implementation is
/// capable of decoding the specified region into an unsigned 8/16/32/64 bit integer.
/// capable of decoding the specified region into an unsigned 8/16/24/32/64 bit integer.
///
/// - parameters:
/// - buffer: The buffer containing the integer frame length.
///
private func readFrameLength(for buffer: inout ByteBuffer) -> Int? {
switch self.lengthFieldLength {
case .one:
switch self.lengthFieldLength.bitLength {
case .bits8:
return buffer.readInteger(endianness: self.lengthFieldEndianness, as: UInt8.self).map { Int($0) }
case .two:
case .bits16:
return buffer.readInteger(endianness: self.lengthFieldEndianness, as: UInt16.self).map { Int($0) }
case .four:
case .bits24:
return buffer.read24UInt(endianness: self.lengthFieldEndianness).map { Int($0) }
case .bits32:
return buffer.readInteger(endianness: self.lengthFieldEndianness, as: UInt32.self).map { Int($0) }
case .eight:
case .bits64:
return buffer.readInteger(endianness: self.lengthFieldEndianness, as: UInt64.self).map { Int($0) }
}
}

View File

@ -14,6 +14,26 @@
import NIO
extension ByteBuffer {
@discardableResult
@inlinable
mutating func write24UInt(
_ integer: UInt32,
endianness: Endianness = .big
) -> Int {
precondition(integer & 0xFF_FF_FF == integer, "integer value does not fit into 24 bit integer")
switch endianness {
case .little:
return writeInteger(UInt8(integer & 0xFF), endianness: .little) +
writeInteger(UInt16((integer >> 8) & 0xFF_FF), endianness: .little)
case .big:
return writeInteger(UInt16((integer >> 8) & 0xFF_FF), endianness: .big) +
writeInteger(UInt8(integer & 0xFF), endianness: .big)
}
}
}
public enum LengthFieldPrependerError: Error {
case messageDataTooLongForLengthField
}
@ -31,42 +51,21 @@ public enum LengthFieldPrependerError: Error {
/// This initial prepended byte is called the 'length field'.
///
public final class LengthFieldPrepender: ChannelOutboundHandler {
///
/// An enumeration to describe the length of a piece of data in bytes.
/// It is constrained to lengths that can be converted to integer types.
///
public enum ByteLength {
case one
case two
case four
case eight
fileprivate var length: Int {
switch self {
case .one:
return 1
case .two:
return 2
case .four:
return 4
case .eight:
return 8
}
}
fileprivate var max: UInt {
fileprivate var bitLength: NIOLengthFieldBitLength {
switch self {
case .one:
return UInt(UInt8.max)
case .two:
return UInt(UInt16.max)
case .four:
return UInt(UInt32.max)
case .eight:
return UInt(UInt64.max)
case .one: return .oneByte
case .two: return .twoBytes
case .four: return .fourBytes
case .eight: return .eightBytes
}
}
}
@ -74,7 +73,7 @@ public final class LengthFieldPrepender: ChannelOutboundHandler {
public typealias OutboundIn = ByteBuffer
public typealias OutboundOut = ByteBuffer
private let lengthFieldLength: LengthFieldPrepender.ByteLength
private let lengthFieldLength: NIOLengthFieldBitLength
private let lengthFieldEndianness: Endianness
private var lengthBuffer: ByteBuffer?
@ -85,13 +84,15 @@ public final class LengthFieldPrepender: ChannelOutboundHandler {
/// - lengthFieldLength: The length of the field specifying the remaining length of the frame.
/// - lengthFieldEndianness: The endianness of the field specifying the remaining length of the frame.
///
public init(lengthFieldLength: ByteLength, lengthFieldEndianness: Endianness = .big) {
public convenience init(lengthFieldLength: ByteLength, lengthFieldEndianness: Endianness = .big) {
self.init(lengthFieldBitLength: lengthFieldLength.bitLength, lengthFieldEndianness: lengthFieldEndianness)
}
public init(lengthFieldBitLength: NIOLengthFieldBitLength, lengthFieldEndianness: Endianness = .big) {
// The value contained in the length field must be able to be represented by an integer type on the platform.
// ie. .eight == 64bit which would not fit into the Int type on a 32bit platform.
precondition(lengthFieldLength.length <= Int.bitWidth/8)
precondition(lengthFieldBitLength.length <= Int.bitWidth/8)
self.lengthFieldLength = lengthFieldLength
self.lengthFieldLength = lengthFieldBitLength
self.lengthFieldEndianness = lengthFieldEndianness
}
@ -115,14 +116,16 @@ public final class LengthFieldPrepender: ChannelOutboundHandler {
self.lengthBuffer = dataLengthBuffer
}
switch self.lengthFieldLength {
case .one:
switch self.lengthFieldLength.bitLength {
case .bits8:
dataLengthBuffer.writeInteger(UInt8(dataLength), endianness: self.lengthFieldEndianness)
case .two:
case .bits16:
dataLengthBuffer.writeInteger(UInt16(dataLength), endianness: self.lengthFieldEndianness)
case .four:
case .bits24:
dataLengthBuffer.write24UInt(UInt32(dataLength), endianness: self.lengthFieldEndianness)
case .bits32:
dataLengthBuffer.writeInteger(UInt32(dataLength), endianness: self.lengthFieldEndianness)
case .eight:
case .bits64:
dataLengthBuffer.writeInteger(UInt64(dataLength), endianness: self.lengthFieldEndianness)
}

View File

@ -0,0 +1,67 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2021 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
/// An struct to describe the length of a piece of data in bits
public struct NIOLengthFieldBitLength {
internal enum Backing {
case bits8
case bits16
case bits24
case bits32
case bits64
}
internal let bitLength: Backing
public static let oneByte = NIOLengthFieldBitLength(bitLength: .bits8)
public static let twoBytes = NIOLengthFieldBitLength(bitLength: .bits16)
public static let threeBytes = NIOLengthFieldBitLength(bitLength: .bits24)
public static let fourBytes = NIOLengthFieldBitLength(bitLength: .bits32)
public static let eightBytes = NIOLengthFieldBitLength(bitLength: .bits64)
public static let eightBits = NIOLengthFieldBitLength(bitLength: .bits8)
public static let sixteenBits = NIOLengthFieldBitLength(bitLength: .bits16)
public static let twentyFourBits = NIOLengthFieldBitLength(bitLength: .bits24)
public static let thirtyTwoBits = NIOLengthFieldBitLength(bitLength: .bits32)
public static let sixtyFourBits = NIOLengthFieldBitLength(bitLength: .bits64)
internal var length: Int {
switch bitLength {
case .bits8:
return 1
case .bits16:
return 2
case .bits24:
return 3
case .bits32:
return 4
case .bits64:
return 8
}
}
internal var max: UInt {
switch bitLength {
case .bits8:
return UInt(UInt8.max)
case .bits16:
return UInt(UInt16.max)
case .bits24:
return (UInt(UInt16.max) << 8) &+ UInt(UInt8.max)
case .bits32:
return UInt(UInt32.max)
case .bits64:
return UInt(UInt64.max)
}
}
}

View File

@ -26,8 +26,11 @@ extension LengthFieldBasedFrameDecoderTest {
static var allTests : [(String, (LengthFieldBasedFrameDecoderTest) -> () throws -> Void)] {
return [
("testReadUInt32From3Bytes", testReadUInt32From3Bytes),
("testReadAndWriteUInt32From3BytesBasicVerification", testReadAndWriteUInt32From3BytesBasicVerification),
("testDecodeWithUInt8HeaderWithData", testDecodeWithUInt8HeaderWithData),
("testDecodeWithUInt16HeaderWithString", testDecodeWithUInt16HeaderWithString),
("testDecodeWithUInt24HeaderWithString", testDecodeWithUInt24HeaderWithString),
("testDecodeWithUInt32HeaderWithString", testDecodeWithUInt32HeaderWithString),
("testDecodeWithUInt64HeaderWithString", testDecodeWithUInt64HeaderWithString),
("testDecodeWithInt64HeaderWithString", testDecodeWithInt64HeaderWithString),

View File

@ -27,7 +27,39 @@ class LengthFieldBasedFrameDecoderTest: XCTestCase {
override func setUp() {
self.channel = EmbeddedChannel()
}
func testReadUInt32From3Bytes() {
var buffer = ByteBuffer(bytes: [
0, 0, 5,
5, 0, 0,
])
XCTAssertEqual(buffer.read24UInt(endianness: .big), 5)
print(buffer.readableBytesView)
XCTAssertEqual(buffer.read24UInt(endianness: .little), 5)
}
func testReadAndWriteUInt32From3BytesBasicVerification() {
let inputs: [UInt32] = [
0,
1,
5,
UInt32(UInt8.max),
UInt32(UInt16.max),
UInt32(UInt16.max) << 8 &+ UInt32(UInt8.max),
UInt32(UInt8.max) - 1,
UInt32(UInt16.max) - 1,
UInt32(UInt16.max) << 8 &+ UInt32(UInt8.max) - 1,
UInt32(UInt8.max) + 1,
UInt32(UInt16.max) + 1,
]
for input in inputs {
var buffer = ByteBuffer()
buffer.write24UInt(input, endianness: .big)
XCTAssertEqual(buffer.read24UInt(endianness: .big), input)
buffer.write24UInt(input, endianness: .little)
XCTAssertEqual(buffer.read24UInt(endianness: .little), input)
}
}
func testDecodeWithUInt8HeaderWithData() throws {
self.decoderUnderTest = .init(LengthFieldBasedFrameDecoder(lengthFieldLength: .one,
@ -71,6 +103,25 @@ class LengthFieldBasedFrameDecoderTest: XCTestCase {
XCTAssertTrue(try self.channel.finish().isClean)
}
func testDecodeWithUInt24HeaderWithString() throws {
self.decoderUnderTest = .init(LengthFieldBasedFrameDecoder(lengthFieldBitLength: .threeBytes,
lengthFieldEndianness: .big))
XCTAssertNoThrow(try self.channel.pipeline.addHandler(self.decoderUnderTest).wait())
var buffer = self.channel.allocator.buffer(capacity: 8) // 3 byte header + 5 character string
buffer.writeBytes([0, 0, 5])
buffer.writeString(standardDataString)
XCTAssertTrue(try self.channel.writeInbound(buffer).isFull)
XCTAssertNoThrow(XCTAssertEqual(standardDataString,
try (self.channel.readInbound(as: ByteBuffer.self)?.readableBytesView).map {
String(decoding: $0, as: Unicode.UTF8.self)
}))
XCTAssertTrue(try self.channel.finish().isClean)
}
func testDecodeWithUInt32HeaderWithString() throws {
self.decoderUnderTest = .init(LengthFieldBasedFrameDecoder(lengthFieldLength: .four,
@ -404,29 +455,36 @@ class LengthFieldBasedFrameDecoderTest: XCTestCase {
}
func testBasicVerification() {
let inputs: [(LengthFieldBasedFrameDecoder.ByteLength, [(Int, String)])] = [
(.one, [
let inputs: [(NIOLengthFieldBitLength, [(Int, String)])] = [
(.oneByte, [
(6, "abcdef"),
(0, ""),
(9, "123456789"),
(Int(UInt8.max),
String(decoding: Array(repeating: UInt8(ascii: "X"), count: Int(UInt8.max)), as: Unicode.UTF8.self)),
]),
(.two, [
(.twoBytes, [
(1, "a"),
(0, ""),
(9, "123456789"),
(307,
String(decoding: Array(repeating: UInt8(ascii: "X"), count: 307), as: Unicode.UTF8.self)),
]),
(.four, [
(.threeBytes, [
(1, "a"),
(0, ""),
(9, "123456789"),
(307,
String(decoding: Array(repeating: UInt8(ascii: "X"), count: 307), as: Unicode.UTF8.self)),
]),
(.fourBytes, [
(1, "a"),
(0, ""),
(3, "333"),
(307,
String(decoding: Array(repeating: UInt8(ascii: "X"), count: 307), as: Unicode.UTF8.self)),
]),
(.eight, [
(.eightBytes, [
(1, "a"),
(0, ""),
(4, "aaaa"),
@ -451,7 +509,7 @@ class LengthFieldBasedFrameDecoderTest: XCTestCase {
return (bytes, [bytes.getSlice(at: bytes.readerIndex + lenBytes.length, length: input.0)!])
}
XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder(inputOutputPairs: inputOutputPairs) {
LengthFieldBasedFrameDecoder(lengthFieldLength: lenBytes)
LengthFieldBasedFrameDecoder(lengthFieldBitLength: lenBytes)
})
}
}

View File

@ -26,8 +26,10 @@ extension LengthFieldPrependerTest {
static var allTests : [(String, (LengthFieldPrependerTest) -> () throws -> Void)] {
return [
("testWrite3BytesOfUInt32Write", testWrite3BytesOfUInt32Write),
("testEncodeWithUInt8HeaderWithData", testEncodeWithUInt8HeaderWithData),
("testEncodeWithUInt16HeaderWithString", testEncodeWithUInt16HeaderWithString),
("testEncodeWithUInt24HeaderWithString", testEncodeWithUInt24HeaderWithString),
("testEncodeWithUInt32HeaderWithString", testEncodeWithUInt32HeaderWithString),
("testEncodeWithUInt64HeaderWithString", testEncodeWithUInt64HeaderWithString),
("testEncodeWithInt64HeaderWithString", testEncodeWithInt64HeaderWithString),

View File

@ -14,7 +14,7 @@
import XCTest
import NIO
import NIOExtras
@testable import NIOExtras
private let standardDataString = "abcde"
private let standardDataStringCount = standardDataString.utf8.count
@ -26,7 +26,16 @@ class LengthFieldPrependerTest: XCTestCase {
override func setUp() {
self.channel = EmbeddedChannel()
}
func testWrite3BytesOfUInt32Write() {
var buffer = ByteBuffer()
buffer.write24UInt(5, endianness: .little)
XCTAssertEqual(Array(buffer.readableBytesView), [5, 0, 0])
XCTAssertEqual(buffer.read24UInt(endianness: .little), 5)
buffer.write24UInt(5, endianness: .big)
XCTAssertEqual(Array(buffer.readableBytesView), [0, 0, 5])
XCTAssertEqual(buffer.read24UInt(endianness: .big), 5)
}
func testEncodeWithUInt8HeaderWithData() throws {
self.encoderUnderTest = LengthFieldPrepender(lengthFieldLength: .one,
@ -105,6 +114,48 @@ class LengthFieldPrependerTest: XCTestCase {
XCTAssertTrue(try self.channel.finish().isClean)
}
func testEncodeWithUInt24HeaderWithString() throws {
let endianness: Endianness = .little
self.encoderUnderTest = LengthFieldPrepender(lengthFieldBitLength: .threeBytes,
lengthFieldEndianness: endianness)
XCTAssertNoThrow(try self.channel.pipeline.addHandler(self.encoderUnderTest).wait())
var buffer = self.channel.allocator.buffer(capacity: standardDataStringCount)
buffer.writeString(standardDataString)
XCTAssertNoThrow(try self.channel.writeAndFlush(buffer).wait())
if case .some(.byteBuffer(var outputBuffer)) = try self.channel.readOutbound(as: IOData.self) {
let sizeInHeader = outputBuffer.read24UInt(endianness: endianness).map({ Int($0) })
XCTAssertEqual(standardDataStringCount, sizeInHeader)
let additionalData = outputBuffer.readBytes(length: 1)
XCTAssertNil(additionalData)
} else {
XCTFail("couldn't read ByteBuffer from channel")
}
if case .some(.byteBuffer(var outputBuffer)) = try self.channel.readOutbound(as: IOData.self) {
let bodyString = outputBuffer.readString(length: standardDataStringCount)
XCTAssertEqual(standardDataString, bodyString)
let additionalData = outputBuffer.readBytes(length: 1)
XCTAssertNil(additionalData)
} else {
XCTFail("couldn't read ByteBuffer from channel")
}
XCTAssertNoThrow(XCTAssertNil(try self.channel.readOutbound()))
XCTAssertTrue(try self.channel.finish().isClean)
}
func testEncodeWithUInt32HeaderWithString() throws {
let endianness: Endianness = .little

View File

@ -18,7 +18,7 @@ here="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
function replace_acceptable_years() {
# this needs to replace all acceptable forms with 'YEARS'
sed -e 's/2017-201[89]/YEARS/' -e 's/2019/YEARS/' -e 's/2020/YEARS/g'
sed -e 's/2017-201[89]/YEARS/' -e 's/2019/YEARS/' -e 's/2020/YEARS/g' -e 's/2021/YEARS/g'
}
printf "=> Checking linux tests... "