Backport #177 to 1.9 (#179)

Motivation

Currently we don't confirm that the decompression has completed
successfully. This means that we can incorrectly spin forever attempting
to decompress past the end of a message, and that we can fail to notice
that a message is truncated. Neither of these is good.

Modifications

Propagate the message zlib gives us as to whether or not decompression
is done, and keep track of it.
Add some tests written by @vojtarylko to validate the behaviour.

Result

Correctly police the bounds of the messages.
This commit is contained in:
Cory Benfield 2022-09-16 10:30:05 +02:00 committed by GitHub
parent 29e4c0a2b4
commit e8000754be
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 137 additions and 14 deletions

View File

@ -52,6 +52,30 @@ public enum NIOHTTPDecompression {
case initializationError(Int)
}
// Would have been public, but this is a backport and cannot add new API.
internal struct ExtraDecompressionError: Error, Hashable, CustomStringConvertible {
private var backing: Backing
private enum Backing {
case invalidTrailingData
case truncatedData
}
private init(_ backing: Backing) {
self.backing = backing
}
/// Decompression completed but there was invalid trailing data behind the compressed data.
static let invalidTrailingData = ExtraDecompressionError(.invalidTrailingData)
/// The decompressed data was incorrectly truncated.
static let truncatedData = ExtraDecompressionError(.truncatedData)
var description: String {
return String(describing: self.backing)
}
}
enum CompressionAlgorithm: String {
case gzip
case deflate
@ -86,12 +110,15 @@ public enum NIOHTTPDecompression {
self.limit = limit
}
mutating func decompress(part: inout ByteBuffer, buffer: inout ByteBuffer, compressedLength: Int) throws {
self.inflated += try self.stream.inflatePart(input: &part, output: &buffer)
mutating func decompress(part: inout ByteBuffer, buffer: inout ByteBuffer, compressedLength: Int) throws -> InflateResult {
let result = try self.stream.inflatePart(input: &part, output: &buffer)
self.inflated += result.written
if self.limit.exceeded(compressed: compressedLength, decompressed: self.inflated) {
throw NIOHTTPDecompression.DecompressionError.limit
}
return result
}
mutating func initializeDecoder(encoding: NIOHTTPDecompression.CompressionAlgorithm) throws {
@ -112,9 +139,10 @@ public enum NIOHTTPDecompression {
}
extension z_stream {
mutating func inflatePart(input: inout ByteBuffer, output: inout ByteBuffer) throws -> Int {
mutating func inflatePart(input: inout ByteBuffer, output: inout ByteBuffer) throws -> InflateResult {
let minimumCapacity = input.readableBytes * 2
var written = 0
var inflateResult = InflateResult(written: 0, complete: false)
try input.readWithUnsafeMutableReadableBytes { pointer in
self.avail_in = UInt32(pointer.count)
self.next_in = CNIOExtrasZlib_voidPtr_to_BytefPtr(pointer.baseAddress!)
@ -126,24 +154,34 @@ extension z_stream {
self.next_out = nil
}
written += try self.inflatePart(to: &output, minimumCapacity: minimumCapacity)
inflateResult = try self.inflatePart(to: &output, minimumCapacity: minimumCapacity)
return pointer.count - Int(self.avail_in)
}
return written
return inflateResult
}
private mutating func inflatePart(to buffer: inout ByteBuffer, minimumCapacity: Int) throws -> Int {
return try buffer.writeWithUnsafeMutableBytes(minimumWritableBytes: minimumCapacity) { pointer in
private mutating func inflatePart(to buffer: inout ByteBuffer, minimumCapacity: Int) throws -> InflateResult {
var rc = Z_OK
let written = try buffer.writeWithUnsafeMutableBytes(minimumWritableBytes: minimumCapacity) { pointer in
self.avail_out = UInt32(pointer.count)
self.next_out = CNIOExtrasZlib_voidPtr_to_BytefPtr(pointer.baseAddress!)
let rc = inflate(&self, Z_NO_FLUSH)
rc = inflate(&self, Z_NO_FLUSH)
guard rc == Z_OK || rc == Z_STREAM_END else {
throw NIOHTTPDecompression.DecompressionError.inflationError(Int(rc))
}
return pointer.count - Int(self.avail_out)
}
return InflateResult(written: written, complete: rc == Z_STREAM_END)
}
}
struct InflateResult {
var written: Int
var complete: Bool
}

View File

@ -29,10 +29,12 @@ public final class NIOHTTPRequestDecompressor: ChannelDuplexHandler, RemovableCh
private var decompressor: NIOHTTPDecompression.Decompressor
private var compression: Compression?
private var decompressionComplete: Bool
public init(limit: NIOHTTPDecompression.DecompressionLimit) {
self.decompressor = NIOHTTPDecompression.Decompressor(limit: limit)
self.compression = nil
self.decompressionComplete = false
}
public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
@ -61,10 +63,13 @@ public final class NIOHTTPRequestDecompressor: ChannelDuplexHandler, RemovableCh
return
}
while part.readableBytes > 0 {
while part.readableBytes > 0 && !self.decompressionComplete {
do {
var buffer = context.channel.allocator.buffer(capacity: 16384)
try self.decompressor.decompress(part: &part, buffer: &buffer, compressedLength: compression.contentLength)
let result = try self.decompressor.decompress(part: &part, buffer: &buffer, compressedLength: compression.contentLength)
if result.complete {
self.decompressionComplete = true
}
context.fireChannelRead(self.wrapInboundOut(.body(buffer)))
} catch let error {
@ -72,10 +77,21 @@ public final class NIOHTTPRequestDecompressor: ChannelDuplexHandler, RemovableCh
return
}
}
if part.readableBytes > 0 {
context.fireErrorCaught(NIOHTTPDecompression.ExtraDecompressionError.invalidTrailingData)
}
case .end:
if self.compression != nil {
let wasDecompressionComplete = self.decompressionComplete
self.decompressor.deinitializeDecoder()
self.compression = nil
self.decompressionComplete = false
if !wasDecompressionComplete {
context.fireErrorCaught(NIOHTTPDecompression.ExtraDecompressionError.truncatedData)
}
}
context.fireChannelRead(data)

View File

@ -33,9 +33,11 @@ public final class NIOHTTPResponseDecompressor: ChannelDuplexHandler, RemovableC
private var compression: Compression? = nil
private var decompressor: NIOHTTPDecompression.Decompressor
private var decompressionComplete: Bool
public init(limit: NIOHTTPDecompression.DecompressionLimit) {
self.decompressor = NIOHTTPDecompression.Decompressor(limit: limit)
self.decompressionComplete = false
}
public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
@ -77,22 +79,36 @@ public final class NIOHTTPResponseDecompressor: ChannelDuplexHandler, RemovableC
do {
compression.compressedLength += part.readableBytes
while part.readableBytes > 0 {
while part.readableBytes > 0 && !self.decompressionComplete {
var buffer = context.channel.allocator.buffer(capacity: 16384)
try self.decompressor.decompress(part: &part, buffer: &buffer, compressedLength: compression.compressedLength)
let result = try self.decompressor.decompress(part: &part, buffer: &buffer, compressedLength: compression.compressedLength)
if result.complete {
self.decompressionComplete = true
}
context.fireChannelRead(self.wrapInboundOut(.body(buffer)))
}
// assign the changed local property back to the class state
self.compression = compression
if part.readableBytes > 0 {
context.fireErrorCaught(NIOHTTPDecompression.ExtraDecompressionError.invalidTrailingData)
}
}
catch {
context.fireErrorCaught(error)
}
case .end:
if self.compression != nil {
let wasDecompressionComplete = self.decompressionComplete
self.decompressor.deinitializeDecoder()
self.compression = nil
self.decompressionComplete = false
if !wasDecompressionComplete {
context.fireErrorCaught(NIOHTTPDecompression.ExtraDecompressionError.truncatedData)
}
}
context.fireChannelRead(data)
}

View File

@ -30,6 +30,8 @@ extension HTTPRequestDecompressorTest {
("testDecompressionLimitRatio", testDecompressionLimitRatio),
("testDecompressionLimitSize", testDecompressionLimitSize),
("testDecompression", testDecompression),
("testDecompressionTrailingData", testDecompressionTrailingData),
("testDecompressionTruncatedInput", testDecompressionTruncatedInput),
]
}
}

View File

@ -119,9 +119,33 @@ class HTTPRequestDecompressorTest: XCTestCase {
)
XCTAssertNoThrow(try channel.writeInbound(HTTPServerRequestPart.body(compressed)))
XCTAssertNoThrow(try channel.writeInbound(HTTPServerRequestPart.end(nil)))
}
}
XCTAssertNoThrow(try channel.writeInbound(HTTPServerRequestPart.end(nil)))
func testDecompressionTrailingData() throws {
// Valid compressed data with some trailing garbage
let compressed = ByteBuffer(bytes: [120, 156, 99, 0, 0, 0, 1, 0, 1] + [1, 2, 3])
let channel = EmbeddedChannel()
try channel.pipeline.addHandler(NIOHTTPRequestDecompressor(limit: .none)).wait()
let headers = HTTPHeaders([("Content-Encoding", "deflate"), ("Content-Length", "\(compressed.readableBytes)")])
try channel.writeInbound(HTTPServerRequestPart.head(.init(version: .init(major: 1, minor: 1), method: .POST, uri: "https://nio.swift.org/test", headers: headers)))
XCTAssertThrowsError(try channel.writeInbound(HTTPServerRequestPart.body(compressed)))
}
func testDecompressionTruncatedInput() throws {
// Truncated compressed data
let compressed = ByteBuffer(bytes: [120, 156, 99, 0])
let channel = EmbeddedChannel()
try channel.pipeline.addHandler(NIOHTTPRequestDecompressor(limit: .none)).wait()
let headers = HTTPHeaders([("Content-Encoding", "deflate"), ("Content-Length", "\(compressed.readableBytes)")])
try channel.writeInbound(HTTPServerRequestPart.head(.init(version: .init(major: 1, minor: 1), method: .POST, uri: "https://nio.swift.org/test", headers: headers)))
XCTAssertNoThrow(try channel.writeInbound(HTTPServerRequestPart.body(compressed)))
XCTAssertThrowsError(try channel.writeInbound(HTTPServerRequestPart.end(nil)))
}
private func compress(_ body: ByteBuffer, _ algorithm: String) -> ByteBuffer {

View File

@ -37,6 +37,8 @@ extension HTTPResponseDecompressorTest {
("testDecompressionLimitRatioWithoutContentLenghtHeaderFails", testDecompressionLimitRatioWithoutContentLenghtHeaderFails),
("testDecompression", testDecompression),
("testDecompressionWithoutContentLength", testDecompressionWithoutContentLength),
("testDecompressionTrailingData", testDecompressionTrailingData),
("testDecompressionTruncatedInput", testDecompressionTruncatedInput),
]
}
}

View File

@ -238,6 +238,31 @@ class HTTPResponseDecompressorTest: XCTestCase {
XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.end(nil)))
}
func testDecompressionTrailingData() throws {
// Valid compressed data with some trailing garbage
let compressed = ByteBuffer(bytes: [120, 156, 99, 0, 0, 0, 1, 0, 1] + [1, 2, 3])
let channel = EmbeddedChannel()
try channel.pipeline.addHandler(NIOHTTPResponseDecompressor(limit: .none)).wait()
let headers = HTTPHeaders([("Content-Encoding", "deflate"), ("Content-Length", "\(compressed.readableBytes)")])
try channel.writeInbound(HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok, headers: headers)))
XCTAssertThrowsError(try channel.writeInbound(HTTPClientResponsePart.body(compressed)))
}
func testDecompressionTruncatedInput() throws {
// Truncated compressed data
let compressed = ByteBuffer(bytes: [120, 156, 99, 0])
let channel = EmbeddedChannel()
try channel.pipeline.addHandler(NIOHTTPResponseDecompressor(limit: .none)).wait()
let headers = HTTPHeaders([("Content-Encoding", "deflate"), ("Content-Length", "\(compressed.readableBytes)")])
try channel.writeInbound(HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok, headers: headers)))
XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.body(compressed)))
XCTAssertThrowsError(try channel.writeInbound(HTTPClientResponsePart.end(nil)))
}
private func compress(_ body: ByteBuffer, _ algorithm: String) -> ByteBuffer {
var stream = z_stream()