swift-nio-extras/Tests/NIOHTTPCompressionTests/HTTPRequestDecompressorTest.swift
Trevör f21a87da13
Merge pull request from GHSA-xhhr-p2r9-jmm7
Motivation:
NIOHTTPRequestDecompressor and HTTPResponseDecompressor are both affected by an issue where the decompression limits defined by their DecompressionLimit property weren't correctly checked when it was set with DecompressionLimit.size(...), allowing denial of service attacks.

Modifications:
- Update DecompressionLimit.size(...) to correctly check the size of the decompressed data.
- Update test cases to avoid future regressions regarding the size checks.

Result:
Prevents DoS attacks through maliciously crafted compressed data.
2020-05-02 09:29:33 +01:00

199 lines
7.7 KiB
Swift

//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2017-2018 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
import XCTest
import CNIOExtrasZlib
import NIO
@testable import NIOHTTP1
@testable import NIOHTTPCompression
// Plain-text fixture shared by this file: tests compress it (alone or repeated) to build
// request bodies, and `DecompressedAssert` compares each decompressed body part against it.
private let testString = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua."
/// Inbound handler that checks every request body part it receives against `testString`.
///
/// - `.body` parts are consumed here and never forwarded; when the contents differ from
///   `testString` an error is fired down the pipeline instead.
/// - Every other request part is passed on untouched.
private final class DecompressedAssert: ChannelInboundHandler {
    typealias InboundIn = HTTPServerRequestPart

    public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
        // Anything that is not a body part is simply forwarded.
        guard case .body(let payload) = self.unwrapInboundIn(data) else {
            context.fireChannelRead(data)
            return
        }

        let decoded = payload.getString(at: payload.readerIndex, length: payload.readableBytes)
        if decoded != testString {
            // The code 42 looks like an arbitrary sentinel; it only has to surface as an error.
            context.fireErrorCaught(NIOHTTPDecompression.DecompressionError.inflationError(42))
        }
    }
}
/// Tests for `NIOHTTPRequestDecompressor`, driven through an `EmbeddedChannel`.
///
/// Request bodies are produced by the private `compress(_:_:)` helper, which calls zlib directly.
class HTTPRequestDecompressorTest: XCTestCase {
    /// With no limit configured, a gzip body made from `testString` must pass through the
    /// decompressor and be accepted by `DecompressedAssert`.
    func testDecompressionNoLimit() throws {
        let channel = EmbeddedChannel()
        try channel.pipeline.addHandler(NIOHTTPRequestDecompressor(limit: .none)).wait()
        try channel.pipeline.addHandler(DecompressedAssert()).wait()

        let buffer = ByteBuffer.of(string: testString)
        let compressed = compress(buffer, "gzip")

        let headers = HTTPHeaders([("Content-Encoding", "gzip"), ("Content-Length", "\(compressed.readableBytes)")])
        try channel.writeInbound(HTTPServerRequestPart.head(.init(version: .init(major: 1, minor: 1), method: .POST, uri: "https://nio.swift.org/test", headers: headers)))

        XCTAssertNoThrow(try channel.writeInbound(HTTPServerRequestPart.body(compressed)))
    }

    /// 500 zero bytes sent through a decompressor configured with `.ratio(10)` must be
    /// rejected with a `.limit` error.
    func testDecompressionLimitRatio() throws {
        let channel = EmbeddedChannel()
        try channel.pipeline.addHandler(NIOHTTPRequestDecompressor(limit: .ratio(10))).wait()

        let decompressed = ByteBuffer.of(bytes: Array(repeating: 0, count: 500))
        let compressed = compress(decompressed, "gzip")

        let headers = HTTPHeaders([("Content-Encoding", "gzip"), ("Content-Length", "\(compressed.readableBytes)")])
        try channel.writeInbound(HTTPServerRequestPart.head(.init(version: .init(major: 1, minor: 1), method: .POST, uri: "https://nio.swift.org/test", headers: headers)))

        assertBodyHitsLimit(compressed, on: channel)
    }

    /// A size limit one byte below the decompressed size (200 zero bytes) must be enforced
    /// with a `.limit` error.
    func testDecompressionLimitSize() throws {
        let channel = EmbeddedChannel()
        let decompressed = ByteBuffer.of(bytes: Array(repeating: 0, count: 200))
        let compressed = compress(decompressed, "gzip")

        try channel.pipeline.addHandler(NIOHTTPRequestDecompressor(limit: .size(decompressed.readableBytes-1))).wait()

        let headers = HTTPHeaders([("Content-Encoding", "gzip"), ("Content-Length", "\(compressed.readableBytes)")])
        try channel.writeInbound(HTTPServerRequestPart.head(.init(version: .init(major: 1, minor: 1), method: .POST, uri: "https://nio.swift.org/test", headers: headers)))

        assertBodyHitsLimit(compressed, on: channel)
    }

    /// Sends an uncompressed, a gzip and a deflate request (head + body each) through one
    /// channel, followed by a single `.end`.
    ///
    /// Only the absence of errors is asserted; the decompressed output is not inspected here.
    func testDecompression() throws {
        let channel = EmbeddedChannel()
        try channel.pipeline.addHandler(NIOHTTPRequestDecompressor(limit: .none)).wait()

        let body = Array(repeating: testString, count: 1000).joined()

        for algorithm in [nil, "gzip", "deflate"] {
            let compressed: ByteBuffer
            var headers = HTTPHeaders()
            if let algorithm = algorithm {
                headers.add(name: "Content-Encoding", value: algorithm)
                compressed = compress(ByteBuffer.of(string: body), algorithm)
            } else {
                // `nil` means "no Content-Encoding": the body is sent as-is.
                compressed = ByteBuffer.of(string: body)
            }
            headers.add(name: "Content-Length", value: "\(compressed.readableBytes)")

            XCTAssertNoThrow(
                try channel.writeInbound(HTTPServerRequestPart.head(.init(version: .init(major: 1, minor: 1), method: .POST, uri: "https://nio.swift.org/test", headers: headers)))
            )

            XCTAssertNoThrow(try channel.writeInbound(HTTPServerRequestPart.body(compressed)))
        }

        XCTAssertNoThrow(try channel.writeInbound(HTTPServerRequestPart.end(nil)))
    }

    /// Writes `compressed` as a request body part into `channel` and asserts that the write
    /// fails with `NIOHTTPDecompression.DecompressionError.limit`.
    ///
    /// - Parameters:
    ///   - compressed: The compressed body to feed into the pipeline.
    ///   - channel: A channel whose pipeline already received the matching request head.
    private func assertBodyHitsLimit(_ compressed: ByteBuffer, on channel: EmbeddedChannel, file: StaticString = #file, line: UInt = #line) {
        XCTAssertThrowsError(try channel.writeInbound(HTTPServerRequestPart.body(compressed)), "writeInbound should fail", file: file, line: line) { error in
            guard let error = error as? NIOHTTPDecompression.DecompressionError else {
                XCTFail("Unexpected error: \(error)", file: file, line: line)
                return
            }
            switch error {
            case .limit:
                // ok
                break
            default:
                XCTFail("Unexpected error: \(error)", file: file, line: line)
            }
        }
    }

    /// Compresses `body` in one shot with zlib.
    ///
    /// - Parameters:
    ///   - body: The bytes to compress; its readable bytes are used as input.
    ///   - algorithm: Either `"deflate"` (windowBits 15) or `"gzip"` (windowBits 16 + 15).
    ///     Any other value records a test failure.
    /// - Returns: A buffer containing the complete compressed stream, or an empty buffer if
    ///   the algorithm is unsupported or zlib could not be initialised.
    private func compress(_ body: ByteBuffer, _ algorithm: String) -> ByteBuffer {
        let windowBits: Int32
        switch algorithm {
        case "deflate":
            windowBits = 15
        case "gzip":
            windowBits = 16 + 15
        default:
            XCTFail("Unsupported algorithm: \(algorithm)")
            return ByteBufferAllocator().buffer(capacity: 0)
        }

        var stream = z_stream()
        stream.zalloc = nil
        stream.zfree = nil
        stream.opaque = nil

        let initRC = CNIOExtrasZlib_deflateInit2(&stream, Z_DEFAULT_COMPRESSION, Z_DEFLATED, windowBits, 8, Z_DEFAULT_STRATEGY)
        XCTAssertEqual(Z_OK, initRC)
        guard initRC == Z_OK else {
            // Calling deflate() on a stream that failed to initialise is not meaningful.
            return ByteBufferAllocator().buffer(capacity: 0)
        }
        // Release zlib's internal state on every exit path from here on.
        defer {
            deflateEnd(&stream)
        }

        // deflateBound() gives an upper bound on the compressed size for this stream
        // configuration, so a single deflate(Z_FINISH) call is guaranteed to have enough room.
        let outputBound = Int(deflateBound(&stream, UInt(body.readableBytes)))
        var buffer = ByteBufferAllocator().buffer(capacity: outputBound)

        var body = body
        body.readWithUnsafeMutableReadableBytes { dataPtr in
            stream.avail_in = UInt32(dataPtr.count)
            stream.next_in = dataPtr.baseAddress!.assumingMemoryBound(to: UInt8.self)

            buffer.writeWithUnsafeMutableBytes(minimumWritableBytes: outputBound) { outputPtr in
                stream.avail_out = UInt32(outputPtr.count)
                stream.next_out = outputPtr.baseAddress!.assumingMemoryBound(to: UInt8.self)

                // With Z_FINISH, Z_OK would mean the output did not fit and the stream is
                // incomplete; only Z_STREAM_END signals a complete stream.
                let rc = deflate(&stream, Z_FINISH)
                XCTAssertEqual(Z_STREAM_END, rc)

                return outputPtr.count - Int(stream.avail_out)
            }

            return dataPtr.count - Int(stream.avail_in)
        }

        return buffer
    }
}
// Test-only conveniences for building `ByteBuffer`s from literals.
extension ByteBuffer {
    /// Returns a new buffer whose readable bytes are the contents of `string`.
    fileprivate static func of(string: String) -> ByteBuffer {
        // Size by UTF-8 byte count rather than `count` (Characters), so the initial
        // capacity is not too small for non-ASCII text.
        var buffer = ByteBufferAllocator().buffer(capacity: string.utf8.count)
        buffer.writeString(string)
        return buffer
    }

    /// Returns a new buffer whose readable bytes are exactly `bytes`.
    fileprivate static func of(bytes: [UInt8]) -> ByteBuffer {
        var buffer = ByteBufferAllocator().buffer(capacity: bytes.count)
        buffer.writeBytes(bytes)
        return buffer
    }
}