//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2019-2021 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

import CNIOExtrasZlib
import NIOCore

/// Namespace for decompression code.
/// Namespace for decompression code.
public enum NIOHTTPDecompression {
    /// Specifies how to limit decompression inflation.
    public struct DecompressionLimit {
        private enum Limit {
            case none
            case size(Int)
            case ratio(Int)
        }

        private var limit: Limit

        /// No limit will be set.
        /// - warning: Setting `limit` to `.none` leaves you vulnerable to denial of service attacks.
        public static let none = DecompressionLimit(limit: .none)
        /// Limit the total number of decompressed bytes to `value`.
        public static func size(_ value: Int) -> DecompressionLimit { return DecompressionLimit(limit: .size(value)) }
        /// Limit the decompressed size to `value` times the compressed size.
        public static func ratio(_ value: Int) -> DecompressionLimit { return DecompressionLimit(limit: .ratio(value)) }

        /// Returns `true` when `decompressed` bytes exceed what this limit allows.
        ///
        /// - Parameters:
        ///   - compressed: Compressed byte count the `.ratio` limit is scaled by.
        ///   - decompressed: Decompressed byte count to check against the limit.
        /// - Returns: `false` for `.none`; otherwise whether `decompressed` is above the allowed bound.
        func exceeded(compressed: Int, decompressed: Int) -> Bool {
            switch self.limit {
            case .none:
                return false
            case .size(let allowed):
                return decompressed > allowed
            case .ratio(let ratio):
                // A plain `compressed * ratio` traps on overflow (e.g. with `.ratio(.max)`),
                // which would crash the process, so compute the bound with overflow reporting.
                let (allowed, overflow) = compressed.multipliedReportingOverflow(by: ratio)
                guard !overflow else {
                    // The exact bound lies outside `Int`'s range. Every `Int` is below a
                    // positive out-of-range bound and above a negative one.
                    return (compressed < 0) != (ratio < 0)
                }
                return decompressed > allowed
            }
        }
    }

    /// Error types for ``NIOHTTPDecompression``.
    public enum DecompressionError: Error, Equatable {
        /// The configured ``NIOHTTPDecompression/DecompressionLimit`` has been exceeded.
        case limit
        /// An error occurred when inflating.  The zlib return code is included to aid diagnosis.
        case inflationError(Int)
        /// The decoder could not be initialized.  The zlib return code is included to aid diagnosis.
        case initializationError(Int)
    }

    /// Content codings this module can decompress.
    enum CompressionAlgorithm: String {
        case gzip
        case deflate

        /// Parses a `Content-Encoding` value.
        ///
        /// Content codings are case-insensitive (RFC 9110 §8.4.1), and `x-gzip` is
        /// treated as `gzip` (RFC 9110 §8.4.1.3).
        ///
        /// - Parameter header: The header value, or `nil` if the header is absent.
        /// - Returns: `nil` when the header is absent or names an unsupported coding.
        init?(header: String?) {
            switch header?.lowercased() {
            case .some("gzip"), .some("x-gzip"):
                self = .gzip
            case .some("deflate"):
                self = .deflate
            default:
                return nil
            }
        }

        /// The `windowBits` argument passed to `inflateInit2`.
        ///
        /// 15 is zlib's maximum window size. Per the zlib manual, adding 16 makes
        /// `inflate` expect a gzip wrapper instead of a zlib wrapper.
        var window: CInt {
            switch self {
            case .deflate:
                return 15
            case .gzip:
                return 15 + 16
            }
        }
    }

    /// Incremental zlib inflater that enforces a `DecompressionLimit` across calls.
    struct Decompressor {
        private let limit: NIOHTTPDecompression.DecompressionLimit
        private var stream = z_stream()
        /// Running total of decompressed bytes produced by `decompress`.
        private var inflated = 0

        init(limit: NIOHTTPDecompression.DecompressionLimit) {
            self.limit = limit
        }

        /// Inflates compressed bytes from `part` into `buffer`, then checks the limit.
        ///
        /// - Parameters:
        ///   - part: Compressed input. Its reader index advances past the consumed bytes.
        ///   - buffer: Receives the decompressed output.
        ///   - compressedLength: Compressed byte count used for the `.ratio` limit check.
        ///     NOTE(review): presumably the total received so far; confirm with callers.
        /// - Throws: `DecompressionError.inflationError` if zlib fails.
        ///   `DecompressionError.limit` if the running total exceeds the limit. The check
        ///   happens after this call's output has already been written to `buffer`.
        mutating func decompress(part: inout ByteBuffer, buffer: inout ByteBuffer, compressedLength: Int) throws {
            self.inflated += try self.stream.inflatePart(input: &part, output: &buffer)

            if self.limit.exceeded(compressed: compressedLength, decompressed: self.inflated) {
                throw NIOHTTPDecompression.DecompressionError.limit
            }
        }

        /// Initializes the zlib stream to decode `encoding`.
        ///
        /// Nil `zalloc`/`zfree`/`opaque` select zlib's default allocator. Resources
        /// allocated here are released by `deinitializeDecoder()`.
        /// - Throws: `DecompressionError.initializationError` with zlib's return code.
        mutating func initializeDecoder(encoding: NIOHTTPDecompression.CompressionAlgorithm) throws {
            self.stream.zalloc = nil
            self.stream.zfree = nil
            self.stream.opaque = nil

            let rc = CNIOExtrasZlib_inflateInit2(&self.stream, encoding.window)
            guard rc == Z_OK else {
                throw NIOHTTPDecompression.DecompressionError.initializationError(Int(rc))
            }
        }

        /// Releases the zlib stream state via `inflateEnd`; its return code is ignored.
        mutating func deinitializeDecoder() {
            inflateEnd(&self.stream)
        }
    }
}

extension z_stream {
    /// Runs one `inflate` step from `input` into `output`.
    ///
    /// The reader index of `input` advances past the bytes zlib consumed. Bytes zlib
    /// could not process in this step, for example because `output` filled up, stay
    /// readable. Once zlib reports `Z_STREAM_END`, any remaining input is trailing data
    /// that `inflate` will never consume, so it is discarded. Keeping it would make a
    /// caller that loops until `input` is empty spin forever.
    ///
    /// - Parameters:
    ///   - input: Compressed bytes.
    ///   - output: Buffer the decompressed bytes are appended to.
    /// - Returns: Number of decompressed bytes written to `output`.
    /// - Throws: `NIOHTTPDecompression.DecompressionError.inflationError` if `inflate`
    ///   returns anything other than `Z_OK` or `Z_STREAM_END`.
    mutating func inflatePart(input: inout ByteBuffer, output: inout ByteBuffer) throws -> Int {
        // Ask for output space of at least twice the compressed size.
        let minimumCapacity = input.readableBytes * 2
        var written = 0
        try input.readWithUnsafeMutableReadableBytes { pointer in
            // zlib counts bytes in `UInt32`. Offer at most that much instead of
            // trapping on a larger buffer; the rest stays readable for the next call.
            let offered = UInt32(clamping: pointer.count)
            self.avail_in = offered
            self.next_in = CNIOExtrasZlib_voidPtr_to_BytefPtr(pointer.baseAddress!)

            defer {
                // Don't leave the stream pointing into buffers we no longer own.
                self.avail_in = 0
                self.next_in = nil
                self.avail_out = 0
                self.next_out = nil
            }

            let result = try self.inflatePart(to: &output, minimumCapacity: minimumCapacity)
            written += result.written

            if result.streamEnded {
                // All output has been produced; drop trailing bytes zlib will not consume.
                return pointer.count
            }
            return Int(offered - self.avail_in)
        }
        return written
    }

    /// Performs a single `inflate(Z_NO_FLUSH)` call into the writable bytes of `buffer`.
    ///
    /// - Parameters:
    ///   - buffer: Output buffer. At least `minimumCapacity` writable bytes are requested.
    ///   - minimumCapacity: Minimum writable bytes to make available.
    /// - Returns: Bytes written and whether zlib reported `Z_STREAM_END`.
    private mutating func inflatePart(to buffer: inout ByteBuffer, minimumCapacity: Int) throws -> (written: Int, streamEnded: Bool) {
        var streamEnded = false
        let written = try buffer.writeWithUnsafeMutableBytes(minimumWritableBytes: minimumCapacity) { pointer in
            let offered = UInt32(clamping: pointer.count)
            self.avail_out = offered
            self.next_out = CNIOExtrasZlib_voidPtr_to_BytefPtr(pointer.baseAddress!)

            let rc = inflate(&self, Z_NO_FLUSH)
            guard rc == Z_OK || rc == Z_STREAM_END else {
                throw NIOHTTPDecompression.DecompressionError.inflationError(Int(rc))
            }
            streamEnded = rc == Z_STREAM_END

            return Int(offered - self.avail_out)
        }
        return (written, streamEnded)
    }
}