/*
 * JsonWebKeySet.cpp
 *
 * This source file is part of the FoundationDB open source project
 *
 * Copyright 2013-2022 Apple Inc. and the FoundationDB project authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "flow/Arena.h"
#include "flow/AutoCPointer.h"
#include "flow/Error.h"
#include "flow/IRandom.h"
#include "flow/MkCert.h"
#include "flow/PKey.h"
#include "flow/UnitTest.h"
#include "fdbrpc/Base64UrlEncode.h"
#include "fdbrpc/Base64UrlDecode.h"
#include "fdbrpc/JsonWebKeySet.h"
#if defined(HAVE_WOLFSSL)
#include
#endif
#include
#include
#include
#include
#include
#include
#include
#if OPENSSL_VERSION_NUMBER >= 0x30000000L && !defined(_WIN32)
#define USE_V3_API 1
#else
#define USE_V3_API 0
#endif
#if USE_V3_API
#include
#include
#endif
#include
#include
#include
#include
#include
#include
#include

// Trace helpers for errors concerning the key set as a whole. `op` is pasted into the event name, so
// JWKS_PARSE_ERROR emits "JsonWebKeySetParseError" and JWKS_WRITE_ERROR emits "JsonWebKeySetWriteError".
#define JWKS_ERROR(issue, op) \
	TraceEvent(SevWarnAlways, "JsonWebKeySet" #op "Error").suppressFor(10).detail("Issue", issue)
#define JWKS_PARSE_ERROR(issue) JWKS_ERROR(issue, Parse)
#define JWKS_WRITE_ERROR(issue) JWKS_ERROR(issue, Write)

// Trace helpers for errors concerning a single key. Both expand to a TraceEvent expression, so call sites may chain
// further .detail() calls. JWK_PARSE_ERROR needs `keyIndex` (position in the "keys" array) in scope;
// JWK_WRITE_ERROR needs `keyName` in scope.
#define JWK_PARSE_ERROR(issue) \
	TraceEvent(SevWarnAlways, "JsonWebKeyParseError") \
	    .suppressFor(10) \
	    .detail("Issue", issue) \
	    .detail("KeyIndexBase0", keyIndex)

#define JWK_WRITE_ERROR(issue) \
	TraceEvent(SevWarnAlways, "JsonWebKeyWriteError") \
	    .suppressFor(10) \
	    .detail("Issue", issue) \
	    .detail("KeyName", keyName.toString())

// Same as JWK_{PARSE|WRITE}_ERROR plus an "OpenSSLError" detail holding the text for ::ERR_get_error()
// (the buffer stays an empty string when no OpenSSL error code is returned).
#define JWK_ERROR_OSSL(issue, op) \
	do { \
		char buf[256]{ \
			0, \
		}; \
		if (auto err = ::ERR_get_error()) { \
			::ERR_error_string_n(err, buf, sizeof(buf)); \
		} \
		JWK_##op##_ERROR(issue).detail("OpenSSLError", static_cast(buf)); \
	} while (0)
#define JWK_PARSE_ERROR_OSSL(issue) JWK_ERROR_OSSL(issue, PARSE)
#define JWK_WRITE_ERROR_OSSL(issue) JWK_ERROR_OSSL(issue, WRITE)

namespace {

// Looks up member `memberName` of the JSON object `value` and stores a StringRef view of its string in `out`.
// Returns false (after tracing a parse error tagged with `keyIndex`) if a required member is missing or if the
// member exists but is not a string. A missing optional member leaves `out` untouched and returns true.
// `out` points into the string storage of `value`, so it is only valid while that JSON value is alive.
template bool getJwkStringMember(JsonValue const& value,
                        char const* memberName,
                        std::conditional_t>& out,
                        int keyIndex) {
	auto itr = value.FindMember(memberName);
	if (itr == value.MemberEnd()) {
		if constexpr (Required) {
			JWK_PARSE_ERROR("Missing required member").detail("Member", memberName);
			return false;
		} else {
			return true;
		}
	}
	auto const& member = itr->value;
	if (!member.IsString()) {
		JWK_PARSE_ERROR("Expected member is not a string").detail("MemberName", memberName);
		return false;
	}
	out = StringRef(reinterpret_cast(member.GetString()), member.GetStringLength());
	return true;
}

// Declares a local named `member` holding the JSON string member of the same name taken from `value`.
// On failure, the enclosing function returns `{}`. Requires `keyIndex` in scope.
#define DECLARE_JWK_REQUIRED_STRING_MEMBER(value, member) \
	auto member = StringRef(); \
	if (!getJwkStringMember(value, #member, member, keyIndex)) \
	return {}
#define DECLARE_JWK_OPTIONAL_STRING_MEMBER(value, member) \
	auto member = Optional(); \
	if (!getJwkStringMember(value, #member, member, keyIndex)) \
	return {}

// Base64url-decodes `b64Member` (decoded bytes are allocated in `arena`) and converts the big-endian octets into a
// new BIGNUM, whose ownership is handed to `ptr`. An absent optional member leaves `ptr` untouched and returns true.
// Returns false after tracing an error (tagged with `algorithm`/`memberName`) if decoding or BN_bin2bn fails.
template bool getJwkBigNumMember(Arena& arena,
                        std::conditional_t> const& b64Member,
                        AutoPtr& ptr,
                        char const* memberName,
                        char const* algorithm,
                        int keyIndex) {
	if constexpr (!Required) {
		if (!b64Member.present())
			return true;
	}
	auto data = StringRef();
	if constexpr (Required) {
		data = b64Member;
	} else {
		data = b64Member.get();
	}
	auto decoded = base64url::decode(arena, data);
	if (!decoded.present()) {
		JWK_PARSE_ERROR("Base64URL decoding for parameter failed")
		    .detail("Algorithm", algorithm)
		    .detail("Parameter", memberName);
		return false;
	}
	data = decoded.get();
	auto bn = ::BN_bin2bn(data.begin(), data.size(), nullptr);
	if (!bn) {
		JWK_PARSE_ERROR_OSSL("BN_bin2bn");
		return false;
	}
	ptr.reset(bn);
	return true;
}

// Declares an owning BIGNUM pointer `member` decoded from the in-scope base64url variable `b64<member>`.
// On failure, the enclosing function returns `{}`. Requires `arena` and `keyIndex` in scope.
// Optional members (only private-key material in this file) are released with BN_clear_free rather than BN_free.
#define DECL_DECODED_BN_MEMBER_REQUIRED(member, algo) \
	auto member = AutoCPointer(nullptr, &::BN_free); \
	if (!getJwkBigNumMember(arena, b64##member, member, #member, algo, keyIndex)) \
	return {}
#define DECL_DECODED_BN_MEMBER_OPTIONAL(member, algo) \
	auto member = AutoCPointer(nullptr, &::BN_clear_free); \
	if (!getJwkBigNumMember(arena, b64##member, member, #member, algo, keyIndex)) \
	return {}

#define EC_DECLARE_DECODED_REQUIRED_BN_MEMBER(member) DECL_DECODED_BN_MEMBER_REQUIRED(member, "EC")
#define EC_DECLARE_DECODED_OPTIONAL_BN_MEMBER(member) DECL_DECODED_BN_MEMBER_OPTIONAL(member, "EC")
#define RSA_DECLARE_DECODED_REQUIRED_BN_MEMBER(member) DECL_DECODED_BN_MEMBER_REQUIRED(member, "RSA")
#define RSA_DECLARE_DECODED_OPTIONAL_BN_MEMBER(member) DECL_DECODED_BN_MEMBER_OPTIONAL(member, "RSA")

// Serializes `bn` as big-endian octets (BN_bn2bin) and returns their base64url encoding, allocated in `arena`.
// With the default `fixedLen` of 0, the minimal octet string is used; RFC 7518 requires this for RSA parameters.
// If `fixedLen` exceeds the minimal length, the octets are left-padded with zero bytes to exactly `fixedLen`. RFC 7518
// requires this for EC "x"/"y"/"d", which must be the full coordinate size (32 octets for P-256).
// A `fixedLen` smaller than the minimal length never truncates.
StringRef bigNumToBase64Url(Arena& arena, const BIGNUM* bn, int fixedLen = 0) {
	auto const len = BN_num_bytes(bn);
	auto const outLen = fixedLen > len ? fixedLen : len;
	auto const padLen = outLen - len;
	auto buf = new (arena) uint8_t[outLen];
	if (padLen > 0)
		::memset(buf, 0, padLen);
	::BN_bn2bin(bn, buf + padLen);
	return base64url::encode(arena, StringRef(buf, outLen));
}

// Builds a P-256 (prime256v1) key from base64url-encoded JWK members. `b64x`/`b64y` are the affine coordinates of
// the public point. If `b64d` is present, the result is a PrivateKey; otherwise it is a PublicKey.
// The native key is DER-encoded and re-imported through the public PrivateKey/PublicKey constructors.
// Returns an empty Optional after tracing the cause on any decoding or OpenSSL failure.
Optional parseEcP256Key(StringRef b64x, StringRef b64y, Optional b64d, int keyIndex) {
	auto arena = Arena();
	EC_DECLARE_DECODED_REQUIRED_BN_MEMBER(x);
	EC_DECLARE_DECODED_REQUIRED_BN_MEMBER(y);
	EC_DECLARE_DECODED_OPTIONAL_BN_MEMBER(d);
#if USE_V3_API
	// avoid deprecated API
	auto bld = AutoCPointer(::OSSL_PARAM_BLD_new(), &::OSSL_PARAM_BLD_free);
	if (!bld) {
		JWK_PARSE_ERROR_OSSL("OSSL_PARAM_BLD_new() for EC");
		return {};
	}
	// since OSSL_PKEY_PARAM_EC_PUB_{X|Y} are not settable params, we'll need to build a EC_GROUP and serialize it
	auto group = AutoCPointer(::EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1), &::EC_GROUP_free);
	if (!group) {
		JWK_PARSE_ERROR_OSSL("EC_GROUP_new_by_curve_name()");
		return {};
	}
	auto point = AutoCPointer(::EC_POINT_new(group), &::EC_POINT_free);
	if (!point) {
		JWK_PARSE_ERROR_OSSL("EC_POINT_new()");
		return {};
	}
	if (1 != ::EC_POINT_set_affine_coordinates(group, point, x, y, nullptr)) {
		JWK_PARSE_ERROR_OSSL("EC_POINT_set_affine_coordinates()");
		return {};
	}
	auto pointBufLen = ::EC_POINT_point2oct(group, point, POINT_CONVERSION_UNCOMPRESSED, nullptr, 0, nullptr);
	if (!pointBufLen) {
		JWK_PARSE_ERROR_OSSL("EC_POINT_point2oct() for length");
		return {};
	}
	auto pointBuf = new (arena) uint8_t[pointBufLen];
	// the second call must write exactly the length reported above; 0 signals an error
	if (pointBufLen !=
	    ::EC_POINT_point2oct(group, point, POINT_CONVERSION_UNCOMPRESSED, pointBuf, pointBufLen, nullptr)) {
		JWK_PARSE_ERROR_OSSL("EC_POINT_point2oct() for data");
		return {};
	}
	if (!::OSSL_PARAM_BLD_push_utf8_string(bld, OSSL_PKEY_PARAM_GROUP_NAME, "prime256v1", sizeof("prime256v1") - 1) ||
	    !::OSSL_PARAM_BLD_push_octet_string(bld, OSSL_PKEY_PARAM_PUB_KEY, pointBuf, pointBufLen)) {
		JWK_PARSE_ERROR_OSSL("OSSL_PARAM_BLD_push_*() for EC (group, point)");
		return {};
	}
	if (d && !::OSSL_PARAM_BLD_push_BN(bld, OSSL_PKEY_PARAM_PRIV_KEY, d)) {
		JWK_PARSE_ERROR_OSSL("OSSL_PARAM_BLD_push_BN() for EC (d)");
		return {};
	}
	auto params = AutoCPointer(::OSSL_PARAM_BLD_to_param(bld), &::OSSL_PARAM_free);
	if (!params) {
		JWK_PARSE_ERROR_OSSL("OSSL_PARAM_BLD_to_param() for EC");
		return {};
	}
	auto pctx = AutoCPointer(::EVP_PKEY_CTX_new_from_name(nullptr, "EC", nullptr), &::EVP_PKEY_CTX_free);
	if (!pctx) {
		JWK_PARSE_ERROR_OSSL("EVP_PKEY_CTX_new_from_name(EC)");
		return {};
	}
	if (1 != ::EVP_PKEY_fromdata_init(pctx)) {
		JWK_PARSE_ERROR_OSSL("EVP_PKEY_fromdata_init() for EC");
		return {};
	}
	auto pkey = std::add_pointer_t();
	if (1 != ::EVP_PKEY_fromdata(pctx, &pkey, (d ? EVP_PKEY_KEYPAIR : EVP_PKEY_PUBLIC_KEY), params) || !pkey) {
		JWK_PARSE_ERROR_OSSL("EVP_PKEY_fromdata() for EC");
		return {};
	}
	// takes ownership of pkey so it is freed on every return path below
	auto pkeyAutoPtr = AutoCPointer(pkey, &::EVP_PKEY_free);
#else // USE_V3_API
	auto key = AutoCPointer(::EC_KEY_new_by_curve_name(NID_X9_62_prime256v1), &::EC_KEY_free);
	if (!key) {
		JWK_PARSE_ERROR_OSSL("EC_KEY_new()");
		return {};
	}
	if (d) {
		if (1 != ::EC_KEY_set_private_key(key, d)) {
			JWK_PARSE_ERROR_OSSL("EC_KEY_set_private_key()");
			return {};
		}
	}
	if (1 != ::EC_KEY_set_public_key_affine_coordinates(key, x, y)) {
		JWK_PARSE_ERROR_OSSL("EC_KEY_set_public_key_affine_coordinates(key, x, y)");
		return {};
	}
	auto pkey = AutoCPointer(::EVP_PKEY_new(), &::EVP_PKEY_free);
	if (!pkey) {
		JWK_PARSE_ERROR_OSSL("EVP_PKEY_new() for EC");
		return {};
	}
	if (1 != EVP_PKEY_set1_EC_KEY(pkey, key)) {
		JWK_PARSE_ERROR_OSSL("EVP_PKEY_set1_EC_KEY()");
		return {};
	}
#endif // USE_V3_API
	if (d) {
		auto len = ::i2d_PrivateKey(pkey, nullptr);
		if (len <= 0) {
			JWK_PARSE_ERROR_OSSL("i2d_PrivateKey() for EC");
			return {};
		}
		auto buf = new (arena) uint8_t[len];
		auto out = std::add_pointer_t(buf);
		len = ::i2d_PrivateKey(pkey, &out);
		if (len <= 0) {
			JWK_PARSE_ERROR_OSSL("i2d_PrivateKey() for EC");
			return {};
		}
		// assign through public API, even if it means some parsing overhead
		return PrivateKey(DerEncoded{}, StringRef(buf, len));
	} else {
		auto len = ::i2d_PUBKEY(pkey, nullptr);
		if (len <= 0) {
			JWK_PARSE_ERROR_OSSL("i2d_PUBKEY() for EC");
			return {};
		}
		auto buf = new (arena) uint8_t[len];
		auto out = std::add_pointer_t(buf);
		len = ::i2d_PUBKEY(pkey, &out);
		if (len <= 0) {
			JWK_PARSE_ERROR_OSSL("i2d_PUBKEY() for EC");
			return {};
		}
		// assign through public API, even if it means some parsing overhead
		return PublicKey(DerEncoded{}, StringRef(buf, len));
	}
}

// Builds an RSA key from base64url-encoded JWK members. `b64n` and `b64e` are required. The result is a PrivateKey
// only when all of d, p, q, dp, dq and qi are present; otherwise it is a PublicKey.
// parseKey() already rejects partially present private members.
// The native key is DER-encoded and re-imported through the public PrivateKey/PublicKey constructors.
// Returns an empty Optional after tracing the cause on any decoding or OpenSSL failure.
Optional parseRsaKey(StringRef b64n,
                     StringRef b64e,
                     Optional b64d,
                     Optional b64p,
                     Optional b64q,
                     Optional b64dp,
                     Optional b64dq,
                     Optional b64qi,
                     int keyIndex) {
	auto arena = Arena();
	RSA_DECLARE_DECODED_REQUIRED_BN_MEMBER(n);
	RSA_DECLARE_DECODED_REQUIRED_BN_MEMBER(e);
	RSA_DECLARE_DECODED_OPTIONAL_BN_MEMBER(d);
	RSA_DECLARE_DECODED_OPTIONAL_BN_MEMBER(p);
	RSA_DECLARE_DECODED_OPTIONAL_BN_MEMBER(q);
	RSA_DECLARE_DECODED_OPTIONAL_BN_MEMBER(dp);
	RSA_DECLARE_DECODED_OPTIONAL_BN_MEMBER(dq);
	RSA_DECLARE_DECODED_OPTIONAL_BN_MEMBER(qi);
	// treated as a private key only if every private/CRT component was supplied
	auto const isPublic = !d || !p || !q || !dp || !dq || !qi;
#if USE_V3_API
	// avoid deprecated, algo-specific API
	auto bld = AutoCPointer(::OSSL_PARAM_BLD_new(), &::OSSL_PARAM_BLD_free);
	if (!bld) {
		JWK_PARSE_ERROR_OSSL("OSSL_PARAM_BLD_new() for RSA");
		return {};
	}
	if (!::OSSL_PARAM_BLD_push_BN(bld, OSSL_PKEY_PARAM_RSA_N, n) ||
	    !::OSSL_PARAM_BLD_push_BN(bld, OSSL_PKEY_PARAM_RSA_E, e)) {
		JWK_PARSE_ERROR_OSSL("OSSL_PARAM_BLD_push_BN() for RSA (n, e)");
		return {};
	}
	if (!isPublic) {
		if (!::OSSL_PARAM_BLD_push_BN(bld, OSSL_PKEY_PARAM_RSA_D, d) ||
		    !::OSSL_PARAM_BLD_push_BN(bld, OSSL_PKEY_PARAM_RSA_FACTOR1, p) ||
		    !::OSSL_PARAM_BLD_push_BN(bld, OSSL_PKEY_PARAM_RSA_FACTOR2, q) ||
		    !::OSSL_PARAM_BLD_push_BN(bld, OSSL_PKEY_PARAM_RSA_EXPONENT1, dp) ||
		    !::OSSL_PARAM_BLD_push_BN(bld, OSSL_PKEY_PARAM_RSA_EXPONENT2, dq) ||
		    !::OSSL_PARAM_BLD_push_BN(bld, OSSL_PKEY_PARAM_RSA_COEFFICIENT1, qi)) {
			JWK_PARSE_ERROR_OSSL("OSSL_PARAM_BLD_push_BN() for RSA (d, p, q, dp, dq, qi)");
			return {};
		}
	}
	auto params = AutoCPointer(::OSSL_PARAM_BLD_to_param(bld), &::OSSL_PARAM_free);
	if (!params) {
		JWK_PARSE_ERROR_OSSL("OSSL_PARAM_BLD_to_param() for RSA");
		return {};
	}
	auto pctx = AutoCPointer(::EVP_PKEY_CTX_new_from_name(nullptr, "RSA", nullptr), &::EVP_PKEY_CTX_free);
	if (!pctx) {
		JWK_PARSE_ERROR_OSSL("EVP_PKEY_CTX_new_from_name(RSA)");
		return {};
	}
	if (1 != ::EVP_PKEY_fromdata_init(pctx)) {
		JWK_PARSE_ERROR_OSSL("EVP_PKEY_fromdata_init() for RSA");
		return {};
	}
	auto pkey = std::add_pointer_t();
	if (1 != ::EVP_PKEY_fromdata(pctx, &pkey, (!isPublic ? EVP_PKEY_KEYPAIR : EVP_PKEY_PUBLIC_KEY), params) ||
	    !pkey) {
		JWK_PARSE_ERROR_OSSL("EVP_PKEY_fromdata() for RSA");
		return {};
	}
	// takes ownership of pkey so it is freed on every return path below
	auto pkeyAutoPtr = AutoCPointer(pkey, &::EVP_PKEY_free);
#else // USE_V3_API
	auto rsa = AutoCPointer(RSA_new(), &::RSA_free);
	if (!rsa) {
		JWK_PARSE_ERROR_OSSL("RSA_new()");
		return {};
	}
	if (1 != ::RSA_set0_key(rsa, n, e, d)) {
		JWK_PARSE_ERROR_OSSL("RSA_set0_key()");
		return {};
	}
	// set0 == ownership taken by rsa, no need to free
	n.release();
	e.release();
	d.release();
	if (!isPublic) {
		if (1 != ::RSA_set0_factors(rsa, p, q)) {
			JWK_PARSE_ERROR_OSSL("RSA_set0_factors()");
			return {};
		}
		p.release();
		q.release();
		if (1 != ::RSA_set0_crt_params(rsa, dp, dq, qi)) {
			JWK_PARSE_ERROR_OSSL("RSA_set0_crt_params()");
			return {};
		}
		dp.release();
		dq.release();
		qi.release();
	}
	auto pkey = AutoCPointer(::EVP_PKEY_new(), &::EVP_PKEY_free);
	if (!pkey) {
		JWK_PARSE_ERROR_OSSL("EVP_PKEY_new() for RSA");
		return {};
	}
	if (1 != ::EVP_PKEY_set1_RSA(pkey, rsa)) {
		JWK_PARSE_ERROR_OSSL("EVP_PKEY_set1_RSA()");
		return {};
	}
#endif // USE_V3_API
	if (!isPublic) {
		auto len = ::i2d_PrivateKey(pkey, nullptr);
		if (len <= 0) {
			JWK_PARSE_ERROR_OSSL("i2d_PrivateKey() for RSA");
			return {};
		}
		auto buf = new (arena) uint8_t[len];
		auto out = std::add_pointer_t(buf);
		len = ::i2d_PrivateKey(pkey, &out);
		if (len <= 0) {
			JWK_PARSE_ERROR_OSSL("i2d_PrivateKey() for RSA");
			return {};
		}
		// assign through public API, even if it means some parsing overhead
		return PrivateKey(DerEncoded{}, StringRef(buf, len));
	} else {
		auto len = ::i2d_PUBKEY(pkey, nullptr);
		if (len <= 0) {
			JWK_PARSE_ERROR_OSSL("i2d_PUBKEY() for RSA");
			return {};
		}
		auto buf = new (arena) uint8_t[len];
		auto out = std::add_pointer_t(buf);
		len = ::i2d_PUBKEY(pkey, &out);
		if (len <= 0) {
			JWK_PARSE_ERROR_OSSL("i2d_PUBKEY() for RSA");
			return {};
		}
		// assign through public API, even if it means some parsing overhead
		return PublicKey(DerEncoded{}, StringRef(buf, len));
	}
}

// Parses one JWK object `key` according to its "kty" member.
// Supported key types:
//   - kty "EC": requires alg "ES256" and crv "P-256", plus "x" and "y"; "d" is optional.
//   - kty "RSA": requires alg "RS256", plus "n" and "e". The private members d, p, q, dp, dq and qi must be
//     either all present or all absent.
// Returns an empty Optional, after tracing the reason, for unsupported or malformed input.
template Optional parseKey(const Value& key, StringRef kty, int keyIndex) {
	if (kty == "EC"_sr) {
		DECLARE_JWK_REQUIRED_STRING_MEMBER(key, alg);
		if (alg != "ES256"_sr) {
			JWK_PARSE_ERROR("Unsupported EC algorithm").detail("Algorithm", alg.toString());
			return {};
		}
		DECLARE_JWK_REQUIRED_STRING_MEMBER(key, crv);
		if (crv != "P-256"_sr) {
			JWK_PARSE_ERROR("Unsupported EC curve").detail("Curve", crv.toString());
			return {};
		}
		DECLARE_JWK_REQUIRED_STRING_MEMBER(key, x);
		DECLARE_JWK_REQUIRED_STRING_MEMBER(key, y);
		DECLARE_JWK_OPTIONAL_STRING_MEMBER(key, d);
		return parseEcP256Key(x, y, d, keyIndex);
	} else if (kty == "RSA"_sr) {
		DECLARE_JWK_REQUIRED_STRING_MEMBER(key, alg);
		if (alg != "RS256"_sr) {
			JWK_PARSE_ERROR("Unsupported RSA algorithm").detail("Algorithm", alg.toString());
			return {};
		}
		DECLARE_JWK_REQUIRED_STRING_MEMBER(key, n);
		DECLARE_JWK_REQUIRED_STRING_MEMBER(key, e);
		DECLARE_JWK_OPTIONAL_STRING_MEMBER(key, d);
		DECLARE_JWK_OPTIONAL_STRING_MEMBER(key, p);
		DECLARE_JWK_OPTIONAL_STRING_MEMBER(key, q);
		DECLARE_JWK_OPTIONAL_STRING_MEMBER(key, dp);
		DECLARE_JWK_OPTIONAL_STRING_MEMBER(key, dq);
		DECLARE_JWK_OPTIONAL_STRING_MEMBER(key, qi);
		auto privKeyArgs = 0;
		privKeyArgs += d.present();
		privKeyArgs += p.present();
		privKeyArgs += q.present();
		privKeyArgs += dp.present();
		privKeyArgs += dq.present();
		privKeyArgs += qi.present();
		if (privKeyArgs == 0 || privKeyArgs == 6) {
			return parseRsaKey(n, e, d, p, q, dp, dq, qi, keyIndex);
		} else {
			JWK_PARSE_ERROR("Private key arguments partially exist").detail("NumMissingArgs", 6 - privKeyArgs);
			return {};
		}
	} else {
		JWK_PARSE_ERROR("Unsupported key type").detail("KeyType", kty.toString());
		return {};
	}
}

// Writes `pKey` (expected to be an EC key on P-256) into `writer` as one JWK object. The object contains "kty",
// "alg" ("ES256"), "kid" (= `keyName`), "crv" ("P-256"), "x" and "y", plus "d" when `isPublic` is false.
// Returns false, after tracing a write error, if the curve is not prime256v1 or a component cannot be extracted.
// StartObject() has already been written by then, so the writer may hold a partial object.
bool encodeEcKey(rapidjson::Writer& writer, StringRef keyName, EVP_PKEY* pKey, const bool isPublic) {
	auto arena = Arena();
	// RFC 7518 6.2.1.2 / 6.2.2.1: "x", "y" and "d" must be the full P-256 size (32 octets), including leading zeros
	auto const p256ComponentLen = 32;
	writer.StartObject();
	writer.Key("kty");
	writer.String("EC");
	writer.Key("alg");
	writer.String("ES256");
	writer.Key("kid");
	writer.String(reinterpret_cast(keyName.begin()), keyName.size());
#if USE_V3_API
	auto curveNameBuf = std::array{};
	auto curveNameLen = size_t{};
	if (1 != EVP_PKEY_get_utf8_string_param(
	             pKey, OSSL_PKEY_PARAM_GROUP_NAME, curveNameBuf.data(), sizeof(curveNameBuf), &curveNameLen)) {
		JWK_WRITE_ERROR_OSSL("Get group name from EC PKey");
		return false;
	}
	auto curveName = std::string_view(curveNameBuf.data(), curveNameLen);
	if (curveName != std::string_view("prime256v1")) {
		JWK_WRITE_ERROR("Unsupported EC curve").detail("CurveName", curveName);
		return false;
	}
	writer.Key("crv");
	writer.String("P-256");
	// Fetches BN parameter `param` from pKey and writes it as member #x, padded to the full P-256 component size
#define JWK_WRITE_BN_EC_PARAM(x, param) \
	do { \
		auto x = AutoCPointer(nullptr, &::BN_clear_free); \
		auto rawX = std::add_pointer_t(); \
		if (1 != ::EVP_PKEY_get_bn_param(pKey, param, &rawX)) { \
			JWK_WRITE_ERROR_OSSL("EVP_PKEY_get_bn_param(" #param ")"); \
			return false; \
		} \
		x.reset(rawX); \
		auto b64##x = bigNumToBase64Url(arena, x, p256ComponentLen); \
		writer.Key(#x); \
		writer.String(reinterpret_cast(b64##x.begin()), b64##x.size()); \
	} while (0)
	// Get and write affine coordinates, X and Y
	JWK_WRITE_BN_EC_PARAM(x, OSSL_PKEY_PARAM_EC_PUB_X);
	JWK_WRITE_BN_EC_PARAM(y, OSSL_PKEY_PARAM_EC_PUB_Y);
	if (!isPublic) {
		JWK_WRITE_BN_EC_PARAM(d, OSSL_PKEY_PARAM_PRIV_KEY);
	}
#undef JWK_WRITE_BN_EC_PARAM
#else // USE_V3_API
	auto ecKey = ::EVP_PKEY_get0_EC_KEY(pKey); // get0 == no refcount, no need to free
	if (!ecKey) {
		JWK_WRITE_ERROR_OSSL("Could not extract EC_KEY from EVP_PKEY");
		return false;
	}
	auto group = ::EC_KEY_get0_group(ecKey);
	if (!group) {
		JWK_WRITE_ERROR("Could not get EC_GROUP from EVP_PKEY");
		return false;
	}
	auto curveName = ::EC_GROUP_get_curve_name(group);
	if (curveName == NID_undef) {
		JWK_WRITE_ERROR("Could not match EC_GROUP to known curve");
		return false;
	}
	if (curveName != NID_X9_62_prime256v1) {
		JWK_WRITE_ERROR("Unsupported curve, expected P-256 (prime256v1)").detail("curveName", ::OBJ_nid2sn(curveName));
		return false;
	}
	writer.Key("crv");
	writer.String("P-256");
	auto point = ::EC_KEY_get0_public_key(ecKey);
	if (!point) {
		JWK_WRITE_ERROR_OSSL("EC_KEY_get0_public_key() returned null");
		return false;
	}
	auto x = AutoCPointer(::BN_new(), &::BN_free);
	if (!x) {
		JWK_WRITE_ERROR_OSSL("x = BN_new()");
		return false;
	}
	auto y = AutoCPointer(::BN_new(), &::BN_free);
	if (!y) {
		JWK_WRITE_ERROR_OSSL("y = BN_new()");
		return false;
	}
	if (1 !=
#ifdef OPENSSL_IS_BORINGSSL
	    ::EC_POINT_get_affine_coordinates_GFp(group, point, x, y, nullptr)
#else
	    ::EC_POINT_get_affine_coordinates(group, point, x, y, nullptr)
#endif
	) {
		JWK_WRITE_ERROR_OSSL("EC_POINT_get_affine_coordinates()");
		return false;
	}
	auto b64X = bigNumToBase64Url(arena, x, p256ComponentLen);
	auto b64Y = bigNumToBase64Url(arena, y, p256ComponentLen);
	writer.Key("x");
	writer.String(reinterpret_cast(b64X.begin()), b64X.size());
	writer.Key("y");
	writer.String(reinterpret_cast(b64Y.begin()), b64Y.size());
	if (!isPublic) {
		auto d = ::EC_KEY_get0_private_key(ecKey);
		if (!d) {
			JWK_WRITE_ERROR("EC_KEY_get0_private_key()");
			return false;
		}
		auto b64D = bigNumToBase64Url(arena, d, p256ComponentLen);
		writer.Key("d");
		writer.String(reinterpret_cast(b64D.begin()), b64D.size());
	}
#endif // USE_V3_API
	writer.EndObject();
	return true;
}

// Writes the RSA key `pKey` into `writer` as one JWK object. The object contains "kty", "alg" ("RS256"),
// "kid" (= `keyName`), "n" and "e", plus "d", "p", "q", "dp", "dq" and "qi" when `isPublic` is false.
// Components use the minimal big-endian octet string. Returns false, after tracing a write error, if the RSA key
// or a component cannot be extracted; the writer may then hold a partially written object.
bool encodeRsaKey(rapidjson::Writer& writer, StringRef keyName, EVP_PKEY* pKey, const bool isPublic) {
	auto arena = Arena();
	writer.StartObject();
	writer.Key("kty");
	writer.String("RSA");
	writer.Key("alg");
	writer.String("RS256");
	writer.Key("kid");
	writer.String(reinterpret_cast(keyName.begin()), keyName.size());
#if USE_V3_API
	// Fetches BN parameter `param` from pKey and writes it as member #x
#define JWK_WRITE_BN_RSA_PARAM_V3(x, param) \
	do { \
		auto x = AutoCPointer(nullptr, &::BN_clear_free); \
		auto rawX = std::add_pointer_t(); \
		if (1 != ::EVP_PKEY_get_bn_param(pKey, param, &rawX)) { \
			JWK_WRITE_ERROR_OSSL("EVP_PKEY_get_bn_param(" #x ")"); \
			return false; \
		} \
		x.reset(rawX); \
		auto b64##x = bigNumToBase64Url(arena, x); \
		writer.Key(#x); \
		writer.String(reinterpret_cast(b64##x.begin()), b64##x.size()); \
	} while (0)
	JWK_WRITE_BN_RSA_PARAM_V3(n, OSSL_PKEY_PARAM_RSA_N);
	JWK_WRITE_BN_RSA_PARAM_V3(e, OSSL_PKEY_PARAM_RSA_E);
	if (!isPublic) {
		JWK_WRITE_BN_RSA_PARAM_V3(d, OSSL_PKEY_PARAM_RSA_D);
		JWK_WRITE_BN_RSA_PARAM_V3(p, OSSL_PKEY_PARAM_RSA_FACTOR1);
		JWK_WRITE_BN_RSA_PARAM_V3(q, OSSL_PKEY_PARAM_RSA_FACTOR2);
		JWK_WRITE_BN_RSA_PARAM_V3(dp, OSSL_PKEY_PARAM_RSA_EXPONENT1);
		JWK_WRITE_BN_RSA_PARAM_V3(dq, OSSL_PKEY_PARAM_RSA_EXPONENT2);
		JWK_WRITE_BN_RSA_PARAM_V3(qi, OSSL_PKEY_PARAM_RSA_COEFFICIENT1);
	}
#undef JWK_WRITE_BN_RSA_PARAM_V3
#else // USE_V3_API
	// Writes the borrowed BIGNUM `x` as member #x, failing if RSA_get0_* left it null
#define JWK_WRITE_BN_RSA_PARAM_V1(x) \
	do { \
		if (!x) { \
			JWK_WRITE_ERROR_OSSL("RSA_get0_* returned null " #x); \
			return false; \
		} \
		auto b64##x = bigNumToBase64Url(arena, x); \
		writer.Key(#x); \
		writer.String(reinterpret_cast(b64##x.begin()), b64##x.size()); \
	} while (0)
	auto rsaKey = ::EVP_PKEY_get0_RSA(pKey); // get0 == no refcount, no need to free
	if (!rsaKey) {
		JWK_WRITE_ERROR_OSSL("Could not extract RSA key from EVP_PKEY");
		return false;
	}
	auto n = std::add_pointer_t();
	auto e = std::add_pointer_t();
	auto d = std::add_pointer_t();
	auto p = std::add_pointer_t();
	auto q = std::add_pointer_t();
	auto dp = std::add_pointer_t();
	auto dq = std::add_pointer_t();
	auto qi = std::add_pointer_t();
	::RSA_get0_key(rsaKey, &n, &e, &d);
	JWK_WRITE_BN_RSA_PARAM_V1(n);
	JWK_WRITE_BN_RSA_PARAM_V1(e);
	if (!isPublic) {
		::RSA_get0_factors(rsaKey, &p, &q);
		::RSA_get0_crt_params(rsaKey, &dp, &dq, &qi);
		JWK_WRITE_BN_RSA_PARAM_V1(d);
		JWK_WRITE_BN_RSA_PARAM_V1(p);
		JWK_WRITE_BN_RSA_PARAM_V1(q);
		JWK_WRITE_BN_RSA_PARAM_V1(dp);
		JWK_WRITE_BN_RSA_PARAM_V1(dq);
		JWK_WRITE_BN_RSA_PARAM_V1(qi);
	}
#undef JWK_WRITE_BN_RSA_PARAM_V1
#endif // USE_V3_API
	writer.EndObject();
	return true;
}

// Add exactly one object to context of writer.
Object shall contain JWK-encoded public or private key bool encodeKey(rapidjson::Writer& writer, StringRef keyName, const PublicOrPrivateKey& key) { auto const isPublic = key.isPublic(); auto pKey = std::add_pointer_t(); auto alg = PKeyAlgorithm{}; if (isPublic) { auto const& keyObj = key.getPublic(); pKey = keyObj.nativeHandle(); alg = keyObj.algorithm(); } else { auto const& keyObj = key.getPrivate(); pKey = key.getPrivate().nativeHandle(); alg = keyObj.algorithm(); } if (!pKey) { JWK_WRITE_ERROR("PKey object to encode is null"); return false; } if (alg == PKeyAlgorithm::EC) { return encodeEcKey(writer, keyName, pKey, isPublic); } else if (alg == PKeyAlgorithm::RSA) { return encodeRsaKey(writer, keyName, pKey, isPublic); } else { JWK_WRITE_ERROR("Attempted to encode PKey with unsupported algorithm"); return false; } return true; } void testPublicKey(PrivateKey (*factory)()) { // stringify-deserialize public key. // sign some data using private key to see whether deserialized public key can verify it. 
auto& rng = *deterministicRandom(); auto pubKeyName = Standalone("somePublicKey"_sr); auto privKey = factory(); auto pubKey = privKey.toPublic(); auto jwks = JsonWebKeySet{}; jwks.keys.emplace(pubKeyName, pubKey); auto arena = Arena(); auto jwksStr = jwks.toStringRef(arena).get(); fmt::print("Test JWKS: {}\n", jwksStr.toString()); auto jwksClone = JsonWebKeySet::parse(jwksStr, {}); ASSERT(jwksClone.present()); auto pubKeyClone = jwksClone.get().keys[pubKeyName].getPublic(); auto randByteStr = [&rng, &arena](int len) { auto buf = new (arena) uint8_t[len]; for (auto i = 0; i < len; i++) buf[i] = rng.randomUInt32() % 255u; return StringRef(buf, len); }; auto randData = randByteStr(rng.randomUInt32() % 128 + 16); auto signature = privKey.sign(arena, randData, *::EVP_sha256()); ASSERT(pubKeyClone.verify(randData, signature, *::EVP_sha256())); const_cast(*randData.begin())++; ASSERT(!pubKeyClone.verify(randData, signature, *::EVP_sha256())); fmt::print("TESTED OK FOR OPENSSL V{} API\n", (OPENSSL_VERSION_NUMBER >> 28)); } void testPrivateKey(PrivateKey (*factory)()) { // stringify-deserialize private key. // sign some data using deserialized private key to see whether public key can verify it. 
auto& rng = *deterministicRandom(); auto privKeyName = Standalone("somePrivateKey"_sr); auto privKey = factory(); auto pubKey = privKey.toPublic(); auto jwks = JsonWebKeySet{}; jwks.keys.emplace(privKeyName, privKey); auto arena = Arena(); auto jwksStr = jwks.toStringRef(arena).get(); fmt::print("Test JWKS: {}\n", jwksStr.toString()); auto jwksClone = JsonWebKeySet::parse(jwksStr, {}); ASSERT(jwksClone.present()); auto privKeyClone = jwksClone.get().keys[privKeyName].getPrivate(); auto randByteStr = [&rng, &arena](int len) { auto buf = new (arena) uint8_t[len]; for (auto i = 0; i < len; i++) buf[i] = rng.randomUInt32() % 255u; return StringRef(buf, len); }; auto randData = randByteStr(rng.randomUInt32() % 128 + 16); auto signature = privKeyClone.sign(arena, randData, *::EVP_sha256()); ASSERT(pubKey.verify(randData, signature, *::EVP_sha256())); const_cast(*randData.begin())++; ASSERT(!pubKey.verify(randData, signature, *::EVP_sha256())); fmt::print("TESTED OK FOR OPENSSL V{} API\n", (OPENSSL_VERSION_NUMBER >> 28)); } } // anonymous namespace Optional JsonWebKeySet::parse(StringRef jwksString, VectorRef allowedUses) { auto d = rapidjson::Document(); d.Parse(reinterpret_cast(jwksString.begin()), jwksString.size()); if (d.HasParseError()) { JWKS_PARSE_ERROR("ParseError") .detail("Message", GetParseError_En(d.GetParseError())) .detail("Offset", d.GetErrorOffset()); return {}; } auto keysItr = d.FindMember("keys"); if (!d.IsObject() || keysItr == d.MemberEnd() || !keysItr->value.IsArray()) { JWKS_PARSE_ERROR("JWKS must be an object and have 'keys' array member"); return {}; } auto const& keys = keysItr->value; auto ret = JsonWebKeySet{}; for (auto keyIndex = 0; keyIndex < keys.Size(); keyIndex++) { if (!keys[keyIndex].IsObject()) { JWKS_PARSE_ERROR("element of 'keys' array must be an object"); return {}; } auto const& key = keys[keyIndex]; DECLARE_JWK_REQUIRED_STRING_MEMBER(key, kty); DECLARE_JWK_REQUIRED_STRING_MEMBER(key, kid); DECLARE_JWK_OPTIONAL_STRING_MEMBER(key, 
use); if (use.present() && !allowedUses.empty()) { auto allowed = false; for (auto allowedUse : allowedUses) { if (allowedUse == use.get()) { allowed = true; break; } } if (!allowed) { JWK_PARSE_ERROR("Illegal optional 'use' member found").detail("Use", use.get().toString()); return {}; } } auto parsedKey = parseKey(key, kty, keyIndex); if (!parsedKey.present()) return {}; auto [iter, inserted] = ret.keys.insert({ Standalone(kid), parsedKey.get() }); if (!inserted) { JWK_PARSE_ERROR("Duplicate key name").detail("KeyName", kid.toString()); return {}; } } return ret; } Optional JsonWebKeySet::toStringRef(Arena& arena) { using Buffer = rapidjson::StringBuffer; using Writer = rapidjson::Writer; auto buffer = Buffer(); auto writer = Writer(buffer); writer.StartObject(); writer.Key("keys"); writer.StartArray(); for (const auto& [keyName, key] : keys) { if (!encodeKey(writer, keyName, key)) { return {}; } } writer.EndArray(); writer.EndObject(); auto buf = new (arena) uint8_t[buffer.GetSize()]; ::memcpy(buf, buffer.GetString(), buffer.GetSize()); return StringRef(buf, buffer.GetSize()); } void forceLinkJsonWebKeySetTests() {} TEST_CASE("/fdbrpc/JsonWebKeySet/EC/PublicKey") { testPublicKey(&mkcert::makeEcP256); return Void(); } TEST_CASE("/fdbrpc/JsonWebKeySet/EC/PrivateKey") { testPrivateKey(&mkcert::makeEcP256); return Void(); } TEST_CASE("/fdbrpc/JsonWebKeySet/RSA/PublicKey") { testPublicKey(&mkcert::makeRsa2048Bit); return Void(); } TEST_CASE("/fdbrpc/JsonWebKeySet/RSA/PrivateKey") { testPrivateKey(&mkcert::makeRsa2048Bit); return Void(); }