swift-nio-extras/Sources/NIOHTTPCompression/HTTPRequestDecompressor.swift
Fabian Fett b8fd38c1db
Support decompression of HTTP responses that do not contain a Content… (#93)
Fixes #78

Co-authored-by: Trevör Anne Denise <trevor.annedenise@icloud.com>
2020-07-24 11:12:09 +01:00

85 lines
3.0 KiB
Swift

//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2019 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
import CNIOExtrasZlib
import NIOHTTP1
import NIO
public final class NIOHTTPRequestDecompressor: ChannelDuplexHandler, RemovableChannelHandler {
public typealias InboundIn = HTTPServerRequestPart
public typealias InboundOut = HTTPServerRequestPart
public typealias OutboundIn = HTTPServerResponsePart
public typealias OutboundOut = HTTPServerResponsePart
private struct Compression {
let algorithm: NIOHTTPDecompression.CompressionAlgorithm
let contentLength: Int
}
private var decompressor: NIOHTTPDecompression.Decompressor
private var compression: Compression?
public init(limit: NIOHTTPDecompression.DecompressionLimit) {
self.decompressor = NIOHTTPDecompression.Decompressor(limit: limit)
self.compression = nil
}
public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let request = self.unwrapInboundIn(data)
switch request {
case .head(let head):
if
let encoding = head.headers[canonicalForm: "Content-Encoding"].first?.lowercased(),
let algorithm = NIOHTTPDecompression.CompressionAlgorithm(header: encoding),
let length = head.headers[canonicalForm: "Content-Length"].first.flatMap({ Int($0) })
{
do {
try self.decompressor.initializeDecoder(encoding: algorithm)
self.compression = Compression(algorithm: algorithm, contentLength: length)
} catch let error {
context.fireErrorCaught(error)
return
}
}
context.fireChannelRead(data)
case .body(var part):
guard let compression = self.compression else {
context.fireChannelRead(data)
return
}
while part.readableBytes > 0 {
do {
var buffer = context.channel.allocator.buffer(capacity: 16384)
try self.decompressor.decompress(part: &part, buffer: &buffer, compressedLength: compression.contentLength)
context.fireChannelRead(self.wrapInboundOut(.body(buffer)))
} catch let error {
context.fireErrorCaught(error)
return
}
}
case .end:
if self.compression != nil {
self.decompressor.deinitializeDecoder()
self.compression = nil
}
context.fireChannelRead(data)
}
}
}