mirror of
https://github.com/apple/swift-nio-extras.git
synced 2025-05-15 01:18:58 +08:00
Motivation Currently we don't confirm that the decompression has completed successfully. This means that we can incorrectly spin forever attempting to decompress past the end of a message, and that we can fail to notice that a message is truncated. Neither of these is good. Modifications Propagate the message zlib gives us as to whether or not decompression is done, and keep track of it. Add some tests written by @vojtarylko to validate the behaviour. Result Correctly police the bounds of the messages.
This commit is contained in:
parent
29e4c0a2b4
commit
e8000754be
@ -52,6 +52,30 @@ public enum NIOHTTPDecompression {
|
||||
case initializationError(Int)
|
||||
}
|
||||
|
||||
// Would have been public, but this is a backport and cannot add new API.
|
||||
internal struct ExtraDecompressionError: Error, Hashable, CustomStringConvertible {
|
||||
private var backing: Backing
|
||||
|
||||
private enum Backing {
|
||||
case invalidTrailingData
|
||||
case truncatedData
|
||||
}
|
||||
|
||||
private init(_ backing: Backing) {
|
||||
self.backing = backing
|
||||
}
|
||||
|
||||
/// Decompression completed but there was invalid trailing data behind the compressed data.
|
||||
static let invalidTrailingData = ExtraDecompressionError(.invalidTrailingData)
|
||||
|
||||
/// The decompressed data was incorrectly truncated.
|
||||
static let truncatedData = ExtraDecompressionError(.truncatedData)
|
||||
|
||||
var description: String {
|
||||
return String(describing: self.backing)
|
||||
}
|
||||
}
|
||||
|
||||
enum CompressionAlgorithm: String {
|
||||
case gzip
|
||||
case deflate
|
||||
@ -86,12 +110,15 @@ public enum NIOHTTPDecompression {
|
||||
self.limit = limit
|
||||
}
|
||||
|
||||
mutating func decompress(part: inout ByteBuffer, buffer: inout ByteBuffer, compressedLength: Int) throws {
|
||||
self.inflated += try self.stream.inflatePart(input: &part, output: &buffer)
|
||||
mutating func decompress(part: inout ByteBuffer, buffer: inout ByteBuffer, compressedLength: Int) throws -> InflateResult {
|
||||
let result = try self.stream.inflatePart(input: &part, output: &buffer)
|
||||
self.inflated += result.written
|
||||
|
||||
if self.limit.exceeded(compressed: compressedLength, decompressed: self.inflated) {
|
||||
throw NIOHTTPDecompression.DecompressionError.limit
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
mutating func initializeDecoder(encoding: NIOHTTPDecompression.CompressionAlgorithm) throws {
|
||||
@ -112,9 +139,10 @@ public enum NIOHTTPDecompression {
|
||||
}
|
||||
|
||||
extension z_stream {
|
||||
mutating func inflatePart(input: inout ByteBuffer, output: inout ByteBuffer) throws -> Int {
|
||||
mutating func inflatePart(input: inout ByteBuffer, output: inout ByteBuffer) throws -> InflateResult {
|
||||
let minimumCapacity = input.readableBytes * 2
|
||||
var written = 0
|
||||
var inflateResult = InflateResult(written: 0, complete: false)
|
||||
|
||||
try input.readWithUnsafeMutableReadableBytes { pointer in
|
||||
self.avail_in = UInt32(pointer.count)
|
||||
self.next_in = CNIOExtrasZlib_voidPtr_to_BytefPtr(pointer.baseAddress!)
|
||||
@ -126,24 +154,34 @@ extension z_stream {
|
||||
self.next_out = nil
|
||||
}
|
||||
|
||||
written += try self.inflatePart(to: &output, minimumCapacity: minimumCapacity)
|
||||
inflateResult = try self.inflatePart(to: &output, minimumCapacity: minimumCapacity)
|
||||
|
||||
return pointer.count - Int(self.avail_in)
|
||||
}
|
||||
return written
|
||||
return inflateResult
|
||||
}
|
||||
|
||||
private mutating func inflatePart(to buffer: inout ByteBuffer, minimumCapacity: Int) throws -> Int {
|
||||
return try buffer.writeWithUnsafeMutableBytes(minimumWritableBytes: minimumCapacity) { pointer in
|
||||
private mutating func inflatePart(to buffer: inout ByteBuffer, minimumCapacity: Int) throws -> InflateResult {
|
||||
var rc = Z_OK
|
||||
|
||||
let written = try buffer.writeWithUnsafeMutableBytes(minimumWritableBytes: minimumCapacity) { pointer in
|
||||
self.avail_out = UInt32(pointer.count)
|
||||
self.next_out = CNIOExtrasZlib_voidPtr_to_BytefPtr(pointer.baseAddress!)
|
||||
|
||||
let rc = inflate(&self, Z_NO_FLUSH)
|
||||
rc = inflate(&self, Z_NO_FLUSH)
|
||||
guard rc == Z_OK || rc == Z_STREAM_END else {
|
||||
throw NIOHTTPDecompression.DecompressionError.inflationError(Int(rc))
|
||||
}
|
||||
|
||||
return pointer.count - Int(self.avail_out)
|
||||
}
|
||||
|
||||
return InflateResult(written: written, complete: rc == Z_STREAM_END)
|
||||
}
|
||||
}
|
||||
|
||||
struct InflateResult {
|
||||
var written: Int
|
||||
|
||||
var complete: Bool
|
||||
}
|
||||
|
@ -29,10 +29,12 @@ public final class NIOHTTPRequestDecompressor: ChannelDuplexHandler, RemovableCh
|
||||
|
||||
private var decompressor: NIOHTTPDecompression.Decompressor
|
||||
private var compression: Compression?
|
||||
private var decompressionComplete: Bool
|
||||
|
||||
public init(limit: NIOHTTPDecompression.DecompressionLimit) {
|
||||
self.decompressor = NIOHTTPDecompression.Decompressor(limit: limit)
|
||||
self.compression = nil
|
||||
self.decompressionComplete = false
|
||||
}
|
||||
|
||||
public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
|
||||
@ -61,10 +63,13 @@ public final class NIOHTTPRequestDecompressor: ChannelDuplexHandler, RemovableCh
|
||||
return
|
||||
}
|
||||
|
||||
while part.readableBytes > 0 {
|
||||
while part.readableBytes > 0 && !self.decompressionComplete {
|
||||
do {
|
||||
var buffer = context.channel.allocator.buffer(capacity: 16384)
|
||||
try self.decompressor.decompress(part: &part, buffer: &buffer, compressedLength: compression.contentLength)
|
||||
let result = try self.decompressor.decompress(part: &part, buffer: &buffer, compressedLength: compression.contentLength)
|
||||
if result.complete {
|
||||
self.decompressionComplete = true
|
||||
}
|
||||
|
||||
context.fireChannelRead(self.wrapInboundOut(.body(buffer)))
|
||||
} catch let error {
|
||||
@ -72,10 +77,21 @@ public final class NIOHTTPRequestDecompressor: ChannelDuplexHandler, RemovableCh
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if part.readableBytes > 0 {
|
||||
context.fireErrorCaught(NIOHTTPDecompression.ExtraDecompressionError.invalidTrailingData)
|
||||
}
|
||||
case .end:
|
||||
if self.compression != nil {
|
||||
let wasDecompressionComplete = self.decompressionComplete
|
||||
|
||||
self.decompressor.deinitializeDecoder()
|
||||
self.compression = nil
|
||||
self.decompressionComplete = false
|
||||
|
||||
if !wasDecompressionComplete {
|
||||
context.fireErrorCaught(NIOHTTPDecompression.ExtraDecompressionError.truncatedData)
|
||||
}
|
||||
}
|
||||
|
||||
context.fireChannelRead(data)
|
||||
|
@ -33,9 +33,11 @@ public final class NIOHTTPResponseDecompressor: ChannelDuplexHandler, RemovableC
|
||||
|
||||
private var compression: Compression? = nil
|
||||
private var decompressor: NIOHTTPDecompression.Decompressor
|
||||
private var decompressionComplete: Bool
|
||||
|
||||
public init(limit: NIOHTTPDecompression.DecompressionLimit) {
|
||||
self.decompressor = NIOHTTPDecompression.Decompressor(limit: limit)
|
||||
self.decompressionComplete = false
|
||||
}
|
||||
|
||||
public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
|
||||
@ -77,22 +79,36 @@ public final class NIOHTTPResponseDecompressor: ChannelDuplexHandler, RemovableC
|
||||
|
||||
do {
|
||||
compression.compressedLength += part.readableBytes
|
||||
while part.readableBytes > 0 {
|
||||
while part.readableBytes > 0 && !self.decompressionComplete {
|
||||
var buffer = context.channel.allocator.buffer(capacity: 16384)
|
||||
try self.decompressor.decompress(part: &part, buffer: &buffer, compressedLength: compression.compressedLength)
|
||||
let result = try self.decompressor.decompress(part: &part, buffer: &buffer, compressedLength: compression.compressedLength)
|
||||
if result.complete {
|
||||
self.decompressionComplete = true
|
||||
}
|
||||
context.fireChannelRead(self.wrapInboundOut(.body(buffer)))
|
||||
}
|
||||
|
||||
// assign the changed local property back to the class state
|
||||
self.compression = compression
|
||||
|
||||
if part.readableBytes > 0 {
|
||||
context.fireErrorCaught(NIOHTTPDecompression.ExtraDecompressionError.invalidTrailingData)
|
||||
}
|
||||
}
|
||||
catch {
|
||||
context.fireErrorCaught(error)
|
||||
}
|
||||
case .end:
|
||||
if self.compression != nil {
|
||||
let wasDecompressionComplete = self.decompressionComplete
|
||||
|
||||
self.decompressor.deinitializeDecoder()
|
||||
self.compression = nil
|
||||
self.decompressionComplete = false
|
||||
|
||||
if !wasDecompressionComplete {
|
||||
context.fireErrorCaught(NIOHTTPDecompression.ExtraDecompressionError.truncatedData)
|
||||
}
|
||||
}
|
||||
context.fireChannelRead(data)
|
||||
}
|
||||
|
@ -30,6 +30,8 @@ extension HTTPRequestDecompressorTest {
|
||||
("testDecompressionLimitRatio", testDecompressionLimitRatio),
|
||||
("testDecompressionLimitSize", testDecompressionLimitSize),
|
||||
("testDecompression", testDecompression),
|
||||
("testDecompressionTrailingData", testDecompressionTrailingData),
|
||||
("testDecompressionTruncatedInput", testDecompressionTruncatedInput),
|
||||
]
|
||||
}
|
||||
}
|
||||
|
@ -119,9 +119,33 @@ class HTTPRequestDecompressorTest: XCTestCase {
|
||||
)
|
||||
|
||||
XCTAssertNoThrow(try channel.writeInbound(HTTPServerRequestPart.body(compressed)))
|
||||
XCTAssertNoThrow(try channel.writeInbound(HTTPServerRequestPart.end(nil)))
|
||||
}
|
||||
}
|
||||
|
||||
XCTAssertNoThrow(try channel.writeInbound(HTTPServerRequestPart.end(nil)))
|
||||
func testDecompressionTrailingData() throws {
|
||||
// Valid compressed data with some trailing garbage
|
||||
let compressed = ByteBuffer(bytes: [120, 156, 99, 0, 0, 0, 1, 0, 1] + [1, 2, 3])
|
||||
|
||||
let channel = EmbeddedChannel()
|
||||
try channel.pipeline.addHandler(NIOHTTPRequestDecompressor(limit: .none)).wait()
|
||||
let headers = HTTPHeaders([("Content-Encoding", "deflate"), ("Content-Length", "\(compressed.readableBytes)")])
|
||||
try channel.writeInbound(HTTPServerRequestPart.head(.init(version: .init(major: 1, minor: 1), method: .POST, uri: "https://nio.swift.org/test", headers: headers)))
|
||||
|
||||
XCTAssertThrowsError(try channel.writeInbound(HTTPServerRequestPart.body(compressed)))
|
||||
}
|
||||
|
||||
func testDecompressionTruncatedInput() throws {
|
||||
// Truncated compressed data
|
||||
let compressed = ByteBuffer(bytes: [120, 156, 99, 0])
|
||||
|
||||
let channel = EmbeddedChannel()
|
||||
try channel.pipeline.addHandler(NIOHTTPRequestDecompressor(limit: .none)).wait()
|
||||
let headers = HTTPHeaders([("Content-Encoding", "deflate"), ("Content-Length", "\(compressed.readableBytes)")])
|
||||
try channel.writeInbound(HTTPServerRequestPart.head(.init(version: .init(major: 1, minor: 1), method: .POST, uri: "https://nio.swift.org/test", headers: headers)))
|
||||
|
||||
XCTAssertNoThrow(try channel.writeInbound(HTTPServerRequestPart.body(compressed)))
|
||||
XCTAssertThrowsError(try channel.writeInbound(HTTPServerRequestPart.end(nil)))
|
||||
}
|
||||
|
||||
private func compress(_ body: ByteBuffer, _ algorithm: String) -> ByteBuffer {
|
||||
|
@ -37,6 +37,8 @@ extension HTTPResponseDecompressorTest {
|
||||
("testDecompressionLimitRatioWithoutContentLenghtHeaderFails", testDecompressionLimitRatioWithoutContentLenghtHeaderFails),
|
||||
("testDecompression", testDecompression),
|
||||
("testDecompressionWithoutContentLength", testDecompressionWithoutContentLength),
|
||||
("testDecompressionTrailingData", testDecompressionTrailingData),
|
||||
("testDecompressionTruncatedInput", testDecompressionTruncatedInput),
|
||||
]
|
||||
}
|
||||
}
|
||||
|
@ -238,6 +238,31 @@ class HTTPResponseDecompressorTest: XCTestCase {
|
||||
XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.end(nil)))
|
||||
}
|
||||
|
||||
func testDecompressionTrailingData() throws {
|
||||
// Valid compressed data with some trailing garbage
|
||||
let compressed = ByteBuffer(bytes: [120, 156, 99, 0, 0, 0, 1, 0, 1] + [1, 2, 3])
|
||||
|
||||
let channel = EmbeddedChannel()
|
||||
try channel.pipeline.addHandler(NIOHTTPResponseDecompressor(limit: .none)).wait()
|
||||
let headers = HTTPHeaders([("Content-Encoding", "deflate"), ("Content-Length", "\(compressed.readableBytes)")])
|
||||
try channel.writeInbound(HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok, headers: headers)))
|
||||
|
||||
XCTAssertThrowsError(try channel.writeInbound(HTTPClientResponsePart.body(compressed)))
|
||||
}
|
||||
|
||||
func testDecompressionTruncatedInput() throws {
|
||||
// Truncated compressed data
|
||||
let compressed = ByteBuffer(bytes: [120, 156, 99, 0])
|
||||
|
||||
let channel = EmbeddedChannel()
|
||||
try channel.pipeline.addHandler(NIOHTTPResponseDecompressor(limit: .none)).wait()
|
||||
let headers = HTTPHeaders([("Content-Encoding", "deflate"), ("Content-Length", "\(compressed.readableBytes)")])
|
||||
try channel.writeInbound(HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok, headers: headers)))
|
||||
|
||||
XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.body(compressed)))
|
||||
XCTAssertThrowsError(try channel.writeInbound(HTTPClientResponsePart.end(nil)))
|
||||
}
|
||||
|
||||
private func compress(_ body: ByteBuffer, _ algorithm: String) -> ByteBuffer {
|
||||
var stream = z_stream()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user