mirror of
https://github.com/apple/swift-nio-extras.git
synced 2025-05-14 00:42:41 +08:00
Gzip request decompress (#59)
### Motivation: There will be times when a client wishes to send larger requests with gzipped bodies to save on network traffic. This PR adds a `NIOHTTPRequestDecompressor` which can be added to the server's channel pipeline so those requests are automatically inflated. ### Modifications: - Added a `CNIOExtrasZlib_voidPtr_to_BytefPtr` C method. - Added a `NIOHTTPRequestDecompressor` type. - Added a `HTTPResponseDecompressorTest` test case. ### Result: Now you don't have to manually check the `Content-Encoding` header and decompress the body on each incoming request.
This commit is contained in:
parent
16fbdf3868
commit
0584020dca
1
.gitignore
vendored
1
.gitignore
vendored
@ -6,3 +6,4 @@ Package.pins
|
||||
*.pem
|
||||
/docs
|
||||
Package.resolved
|
||||
.swiftpm/
|
84
Sources/NIOHTTPCompression/HTTPRequestDecompressor.swift
Normal file
84
Sources/NIOHTTPCompression/HTTPRequestDecompressor.swift
Normal file
@ -0,0 +1,84 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This source file is part of the SwiftNIO open source project
|
||||
//
|
||||
// Copyright (c) 2019 Apple Inc. and the SwiftNIO project authors
|
||||
// Licensed under Apache License v2.0
|
||||
//
|
||||
// See LICENSE.txt for license information
|
||||
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
|
||||
//
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
import CNIOExtrasZlib
|
||||
import NIOHTTP1
|
||||
import NIO
|
||||
|
||||
public final class NIOHTTPRequestDecompressor: ChannelDuplexHandler, RemovableChannelHandler {
|
||||
public typealias InboundIn = HTTPServerRequestPart
|
||||
public typealias InboundOut = HTTPServerRequestPart
|
||||
public typealias OutboundIn = HTTPServerResponsePart
|
||||
public typealias OutboundOut = HTTPServerResponsePart
|
||||
|
||||
private struct Compression {
|
||||
let algorithm: NIOHTTPDecompression.CompressionAlgorithm
|
||||
let contentLength: Int
|
||||
}
|
||||
|
||||
private var decompressor: NIOHTTPDecompression.Decompressor
|
||||
private var compression: Compression?
|
||||
|
||||
public init(limit: NIOHTTPDecompression.DecompressionLimit) {
|
||||
self.decompressor = NIOHTTPDecompression.Decompressor(limit: limit)
|
||||
self.compression = nil
|
||||
}
|
||||
|
||||
public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
|
||||
let request = self.unwrapInboundIn(data)
|
||||
|
||||
switch request {
|
||||
case .head(let head):
|
||||
if
|
||||
let encoding = head.headers[canonicalForm: "Content-Encoding"].first?.lowercased(),
|
||||
let algorithm = NIOHTTPDecompression.CompressionAlgorithm(header: encoding),
|
||||
let length = head.headers[canonicalForm: "Content-Length"].first.flatMap({ Int($0) })
|
||||
{
|
||||
do {
|
||||
try self.decompressor.initializeDecoder(encoding: algorithm, length: length)
|
||||
self.compression = Compression(algorithm: algorithm, contentLength: length)
|
||||
} catch let error {
|
||||
context.fireErrorCaught(error)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
context.fireChannelRead(data)
|
||||
case .body(var part):
|
||||
guard let compression = self.compression else {
|
||||
context.fireChannelRead(data)
|
||||
return
|
||||
}
|
||||
|
||||
while part.readableBytes > 0 {
|
||||
do {
|
||||
var buffer = context.channel.allocator.buffer(capacity: 16384)
|
||||
try self.decompressor.decompress(part: &part, buffer: &buffer, originalLength: compression.contentLength)
|
||||
|
||||
context.fireChannelRead(self.wrapInboundOut(.body(buffer)))
|
||||
} catch let error {
|
||||
context.fireErrorCaught(error)
|
||||
return
|
||||
}
|
||||
}
|
||||
case .end:
|
||||
if self.compression != nil {
|
||||
self.decompressor.deinitializeDecoder()
|
||||
self.compression = nil
|
||||
}
|
||||
|
||||
context.fireChannelRead(data)
|
||||
}
|
||||
}
|
||||
}
|
@ -30,6 +30,7 @@ import XCTest
|
||||
testCase(DebugInboundEventsHandlerTest.allTests),
|
||||
testCase(DebugOutboundEventsHandlerTest.allTests),
|
||||
testCase(FixedLengthFrameDecoderTest.allTests),
|
||||
testCase(HTTPRequestDecompressorTest.allTests),
|
||||
testCase(HTTPResponseCompressorTest.allTests),
|
||||
testCase(HTTPResponseDecompressorTest.allTests),
|
||||
testCase(JSONRPCFramingContentLengthHeaderDecoderTests.allTests),
|
||||
|
@ -0,0 +1,36 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This source file is part of the SwiftNIO open source project
|
||||
//
|
||||
// Copyright (c) 2017-2018 Apple Inc. and the SwiftNIO project authors
|
||||
// Licensed under Apache License v2.0
|
||||
//
|
||||
// See LICENSE.txt for license information
|
||||
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
|
||||
//
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// HTTPRequestDecompressorTest+XCTest.swift
|
||||
//
|
||||
import XCTest
|
||||
|
||||
///
|
||||
/// NOTE: This file was generated by generate_linux_tests.rb
|
||||
///
|
||||
/// Do NOT edit this file directly as it will be regenerated automatically when needed.
|
||||
///
|
||||
|
||||
extension HTTPRequestDecompressorTest {
|
||||
|
||||
static var allTests : [(String, (HTTPRequestDecompressorTest) -> () throws -> Void)] {
|
||||
return [
|
||||
("testDecompressionNoLimit", testDecompressionNoLimit),
|
||||
("testDecompressionLimitRatio", testDecompressionLimitRatio),
|
||||
("testDecompressionLimitSize", testDecompressionLimitSize),
|
||||
("testDecompression", testDecompression),
|
||||
]
|
||||
}
|
||||
}
|
||||
|
210
Tests/NIOHTTPCompressionTests/HTTPRequestDecompressorTest.swift
Normal file
210
Tests/NIOHTTPCompressionTests/HTTPRequestDecompressorTest.swift
Normal file
@ -0,0 +1,210 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This source file is part of the SwiftNIO open source project
|
||||
//
|
||||
// Copyright (c) 2017-2018 Apple Inc. and the SwiftNIO project authors
|
||||
// Licensed under Apache License v2.0
|
||||
//
|
||||
// See LICENSE.txt for license information
|
||||
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
|
||||
//
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
import XCTest
|
||||
import CNIOExtrasZlib
|
||||
@testable import NIO
|
||||
@testable import NIOHTTP1
|
||||
@testable import NIOHTTPCompression
|
||||
|
||||
private let testString = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua."
|
||||
|
||||
private final class DecompressedAssert: ChannelInboundHandler {
|
||||
typealias InboundIn = HTTPServerRequestPart
|
||||
|
||||
public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
|
||||
let request = self.unwrapInboundIn(data)
|
||||
|
||||
switch request {
|
||||
case .body(let buffer):
|
||||
let string = buffer.getString(at: buffer.readerIndex, length: buffer.readableBytes)
|
||||
guard string == testString else {
|
||||
context.fireErrorCaught(NIOHTTPDecompression.DecompressionError.inflationError(42))
|
||||
return
|
||||
}
|
||||
default: context.fireChannelRead(data)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class HTTPRequestDecompressorTest: XCTestCase {
|
||||
func testDecompressionNoLimit() throws {
|
||||
let channel = EmbeddedChannel()
|
||||
try channel.pipeline.addHandler(NIOHTTPRequestDecompressor(limit: .none)).wait()
|
||||
try channel.pipeline.addHandler(DecompressedAssert()).wait()
|
||||
|
||||
let buffer = ByteBuffer.of(string: testString)
|
||||
let compressed = compress(buffer, "gzip")
|
||||
|
||||
let headers = HTTPHeaders([("Content-Encoding", "gzip"), ("Content-Length", "\(compressed.readableBytes)")])
|
||||
try channel.writeInbound(HTTPServerRequestPart.head(.init(version: .init(major: 1, minor: 1), method: .POST, uri: "https://nio.swift.org/test", headers: headers)))
|
||||
|
||||
XCTAssertNoThrow(try channel.writeInbound(HTTPServerRequestPart.body(compressed)))
|
||||
}
|
||||
|
||||
func testDecompressionLimitRatio() throws {
|
||||
let channel = EmbeddedChannel()
|
||||
try channel.pipeline.addHandler(NIOHTTPRequestDecompressor(limit: .ratio(10))).wait()
|
||||
|
||||
let headers = HTTPHeaders([("Content-Encoding", "gzip"), ("Content-Length", "13")])
|
||||
try channel.writeInbound(HTTPServerRequestPart.head(.init(version: .init(major: 1, minor: 1), method: .POST, uri: "https://nio.swift.org/test", headers: headers)))
|
||||
|
||||
let buffer = ByteBuffer.of(bytes: [120, 156, 75, 76, 28, 5, 200, 0, 0, 248, 66, 103, 17])
|
||||
let compressed = compress(buffer, "gzip")
|
||||
|
||||
do {
|
||||
try channel.writeInbound(HTTPServerRequestPart.body(compressed))
|
||||
} catch let error as NIOHTTPDecompression.DecompressionError {
|
||||
switch error {
|
||||
case .limit:
|
||||
// ok
|
||||
break
|
||||
default:
|
||||
XCTFail("Unexptected error: \(error)")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testDecompressionLimitSize() throws {
|
||||
let channel = EmbeddedChannel()
|
||||
try channel.pipeline.addHandler(NIOHTTPRequestDecompressor(limit: .size(10))).wait()
|
||||
|
||||
let headers = HTTPHeaders([("Content-Encoding", "gzip"), ("Content-Length", "13")])
|
||||
try channel.writeInbound(HTTPServerRequestPart.head(.init(version: .init(major: 1, minor: 1), method: .POST, uri: "https://nio.swift.org/test", headers: headers)))
|
||||
|
||||
let buffer = ByteBuffer.of(bytes: [120, 156, 75, 76, 28, 5, 200, 0, 0, 248, 66, 103, 17])
|
||||
let compressed = compress(buffer, "gzip")
|
||||
|
||||
do {
|
||||
try channel.writeInbound(HTTPServerRequestPart.body(compressed))
|
||||
} catch let error as NIOHTTPDecompression.DecompressionError {
|
||||
switch error {
|
||||
case .limit:
|
||||
// ok
|
||||
break
|
||||
default:
|
||||
XCTFail("Unexptected error: \(error)")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testDecompression() throws {
|
||||
let channel = EmbeddedChannel()
|
||||
try channel.pipeline.addHandler(NIOHTTPRequestDecompressor(limit: .none)).wait()
|
||||
|
||||
let body = Array(repeating: testString, count: 1000).joined()
|
||||
|
||||
for algorithm in [nil, "gzip", "deflate"] {
|
||||
let compressed: ByteBuffer
|
||||
var headers = HTTPHeaders()
|
||||
if let algorithm = algorithm {
|
||||
headers.add(name: "Content-Encoding", value: algorithm)
|
||||
compressed = compress(ByteBuffer.of(string: body), algorithm)
|
||||
} else {
|
||||
compressed = ByteBuffer.of(string: body)
|
||||
}
|
||||
headers.add(name: "Content-Length", value: "\(compressed.readableBytes)")
|
||||
|
||||
XCTAssertNoThrow(
|
||||
try channel.writeInbound(HTTPServerRequestPart.head(.init(version: .init(major: 1, minor: 1), method: .POST, uri: "https://nio.swift.org/test", headers: headers)))
|
||||
)
|
||||
|
||||
do {
|
||||
try channel.writeInbound(HTTPServerRequestPart.body(compressed))
|
||||
} catch let error as NIOHTTPDecompression.DecompressionError {
|
||||
switch error {
|
||||
case .limit:
|
||||
// ok
|
||||
break
|
||||
default:
|
||||
XCTFail("Unexptected error: \(error)")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
XCTAssertNoThrow(try channel.writeInbound(HTTPServerRequestPart.end(nil)))
|
||||
}
|
||||
|
||||
private func compress(_ body: ByteBuffer, _ algorithm: String) -> ByteBuffer {
|
||||
var stream = z_stream()
|
||||
|
||||
stream.zalloc = nil
|
||||
stream.zfree = nil
|
||||
stream.opaque = nil
|
||||
|
||||
var buffer = ByteBufferAllocator().buffer(capacity: 1000)
|
||||
|
||||
let windowBits: Int32
|
||||
switch algorithm {
|
||||
case "deflate":
|
||||
windowBits = 15
|
||||
case "gzip":
|
||||
windowBits = 16 + 15
|
||||
default:
|
||||
XCTFail("Unsupported algorithm: \(algorithm)")
|
||||
return buffer
|
||||
}
|
||||
|
||||
let rc = CNIOExtrasZlib_deflateInit2(&stream, Z_DEFAULT_COMPRESSION, Z_DEFLATED, windowBits, 8, Z_DEFAULT_STRATEGY)
|
||||
XCTAssertEqual(Z_OK, rc)
|
||||
|
||||
defer {
|
||||
stream.avail_in = 0
|
||||
stream.next_in = nil
|
||||
stream.avail_out = 0
|
||||
stream.next_out = nil
|
||||
}
|
||||
|
||||
var body = body
|
||||
|
||||
body.readWithUnsafeMutableReadableBytes { dataPtr in
|
||||
let typedPtr = dataPtr.baseAddress!.assumingMemoryBound(to: UInt8.self)
|
||||
let typedDataPtr = UnsafeMutableBufferPointer(start: typedPtr,
|
||||
count: dataPtr.count)
|
||||
|
||||
stream.avail_in = UInt32(typedDataPtr.count)
|
||||
stream.next_in = typedDataPtr.baseAddress!
|
||||
|
||||
buffer.writeWithUnsafeMutableBytes { outputPtr in
|
||||
let typedOutputPtr = UnsafeMutableBufferPointer(start: outputPtr.baseAddress!.assumingMemoryBound(to: UInt8.self),
|
||||
count: outputPtr.count)
|
||||
stream.avail_out = UInt32(typedOutputPtr.count)
|
||||
stream.next_out = typedOutputPtr.baseAddress!
|
||||
let rc = deflate(&stream, Z_FINISH)
|
||||
XCTAssertTrue(rc == Z_OK || rc == Z_STREAM_END)
|
||||
return typedOutputPtr.count - Int(stream.avail_out)
|
||||
}
|
||||
|
||||
return typedDataPtr.count - Int(stream.avail_in)
|
||||
}
|
||||
|
||||
deflateEnd(&stream)
|
||||
|
||||
return buffer
|
||||
}
|
||||
}
|
||||
|
||||
extension ByteBuffer {
|
||||
fileprivate static func of(string: String) -> ByteBuffer {
|
||||
var buffer = ByteBufferAllocator().buffer(capacity: string.count)
|
||||
buffer.writeString(string)
|
||||
return buffer
|
||||
}
|
||||
|
||||
fileprivate static func of(bytes: [UInt8]) -> ByteBuffer {
|
||||
var buffer = ByteBufferAllocator().buffer(capacity: bytes.count)
|
||||
buffer.writeBytes(bytes)
|
||||
return buffer
|
||||
}
|
||||
}
|
@ -176,13 +176,13 @@ class HTTPResponseDecompressorTest: XCTestCase {
|
||||
}
|
||||
|
||||
extension ByteBuffer {
|
||||
static func of(string: String) -> ByteBuffer {
|
||||
fileprivate static func of(string: String) -> ByteBuffer {
|
||||
var buffer = ByteBufferAllocator().buffer(capacity: string.count)
|
||||
buffer.writeString(string)
|
||||
return buffer
|
||||
}
|
||||
|
||||
static func of(bytes: [UInt8]) -> ByteBuffer {
|
||||
fileprivate static func of(bytes: [UInt8]) -> ByteBuffer {
|
||||
var buffer = ByteBufferAllocator().buffer(capacity: bytes.count)
|
||||
buffer.writeBytes(bytes)
|
||||
return buffer
|
||||
|
Loading…
x
Reference in New Issue
Block a user