diff --git a/fdbrpc/TokenCache.cpp b/fdbrpc/TokenCache.cpp index 363655db0d..d7dbcb25db 100644 --- a/fdbrpc/TokenCache.cpp +++ b/fdbrpc/TokenCache.cpp @@ -161,12 +161,7 @@ bool TokenCacheImpl::validateAndAdd(double currentTime, TEST(true); // Token can't be parsed return false; } - if (!t.keyId.present()) { - TEST(true); // Token with no key id - TraceEvent(SevWarn, "InvalidToken").detail("From", peer).detail("Reason", "NoKeyID"); - return false; - } - auto key = FlowTransport::transport().getPublicKeyByName(t.keyId.get()); + auto key = FlowTransport::transport().getPublicKeyByName(t.keyId); if (!key.present()) { TEST(true); // Token referencing non-existing key TraceEvent(SevWarn, "InvalidToken").detail("From", peer).detail("Reason", "UnknownKey"); @@ -258,10 +253,6 @@ TEST_CASE("/fdbrpc/authz/TokenCache/BadTokens") { [](Arena&, IRandom&, authz::jwt::TokenRef&) { FlowTransport::transport().removeAllPublicKeys(); }, "NoKeyWithSuchName", }, - { - [](Arena&, IRandom&, authz::jwt::TokenRef& token) { token.keyId.reset(); }, - "NoKeyId", - }, { [](Arena&, IRandom&, authz::jwt::TokenRef& token) { token.expiresAtUnixTime.reset(); }, "NoExpirationTime", @@ -282,10 +273,6 @@ TEST_CASE("/fdbrpc/authz/TokenCache/BadTokens") { }, "TokenNotYetValid", }, - { - [](Arena& arena, IRandom&, authz::jwt::TokenRef& token) { token.keyId.reset(); }, - "UnknownKey", - }, { [](Arena& arena, IRandom&, authz::jwt::TokenRef& token) { token.tenants.reset(); }, "NoTenants", diff --git a/fdbrpc/TokenSign.cpp b/fdbrpc/TokenSign.cpp index 66930f2250..2e97c9354a 100644 --- a/fdbrpc/TokenSign.cpp +++ b/fdbrpc/TokenSign.cpp @@ -185,14 +185,13 @@ void appendField(fmt::memory_buffer& b, char const (&name)[NameLen], Optionalvalue; - auto const& typ = typItr->value; - if (alg.IsString() && typ.IsString()) { - auto algValue = StringRef(reinterpret_cast(alg.GetString()), alg.GetStringLength()); - auto algType = algorithmFromString(algValue); - if (algType == Algorithm::UNKNOWN) - return false; - token.algorithm = algType; - auto typValue = StringRef(reinterpret_cast(typ.GetString()), typ.GetStringLength()); - if (typValue != "JWT"_sr) - return false; - return true; - } - } - return false; + if (typItr == d.MemberEnd() || !typItr->value.IsString()) + return false; + auto algItr = d.FindMember("alg"); + if (algItr == d.MemberEnd() || !algItr->value.IsString()) + return false; + auto kidItr = d.FindMember("kid"); + if (kidItr == d.MemberEnd() || !kidItr->value.IsString()) + return false; + auto const& typ = typItr->value; + auto const& alg = algItr->value; + auto const& kid = kidItr->value; + auto typValue = StringRef(reinterpret_cast(typ.GetString()), typ.GetStringLength()); + if (typValue != "JWT"_sr) + return false; + auto algValue = StringRef(reinterpret_cast(alg.GetString()), alg.GetStringLength()); + auto algType = algorithmFromString(algValue); + if (algType == Algorithm::UNKNOWN) + return false; + token.algorithm = algType; + token.keyId = StringRef(arena, reinterpret_cast(kid.GetString()), kid.GetStringLength()); + return true; } template @@ -382,8 +389,6 @@ bool parsePayloadPart(Arena& arena, TokenRef& token, StringRef b64urlPayload) { return false; if (!parseField(arena, token.notBeforeUnixTime, d, "nbf")) return false; - if (!parseField(arena, token.keyId, d, "kid")) - return false; if (!parseField(arena, token.tenants, d, "tenants")) return false; return true; @@ -409,7 +414,7 @@ bool parseToken(Arena& arena, TokenRef& token, StringRef signedToken) { auto b64urlSignature = signedToken; if (b64urlHeader.empty() || b64urlPayload.empty() || b64urlSignature.empty()) return false; - if (!parseHeaderPart(token, b64urlHeader)) + if (!parseHeaderPart(arena, token, b64urlHeader)) return false; if (!parsePayloadPart(arena, token, b64urlPayload)) return false; @@ -432,7 +437,7 @@ bool verifyToken(StringRef signedToken, PublicKey publicKey) { return false; auto sig = optSig.get(); auto parsedToken = TokenRef(); - if (!parseHeaderPart(parsedToken, b64urlHeader)) + if (!parseHeaderPart(arena, parsedToken, b64urlHeader)) return false; auto [verifyAlgo, digest] = getMethod(parsedToken.algorithm); if (!checkVerifyAlgorithm(verifyAlgo, publicKey)) @@ -446,6 +451,7 @@ TokenRef makeRandomTokenSpec(Arena& arena, IRandom& rng, Algorithm alg) { } auto ret = TokenRef{}; ret.algorithm = alg; + ret.keyId = genRandomAlphanumStringRef(arena, rng, MaxKeyNameLenPlus1); ret.issuer = genRandomAlphanumStringRef(arena, rng, MaxIssuerNameLenPlus1); ret.subject = genRandomAlphanumStringRef(arena, rng, MaxIssuerNameLenPlus1); ret.tokenId = genRandomAlphanumStringRef(arena, rng, 31); @@ -457,7 +463,6 @@ TokenRef makeRandomTokenSpec(Arena& arena, IRandom& rng, Algorithm alg) { ret.issuedAtUnixTime = timer_int() / 1'000'000'000ul; ret.notBeforeUnixTime = ret.issuedAtUnixTime.get(); ret.expiresAtUnixTime = ret.issuedAtUnixTime.get() + rng.randomInt(360, 1080 + 1); - ret.keyId = genRandomAlphanumStringRef(arena, rng, MaxKeyNameLenPlus1); auto numTenants = rng.randomInt(1, 3); auto tenants = new (arena) StringRef[numTenants]; for (auto i = 0; i < numTenants; i++) @@ -553,7 +558,7 @@ TEST_CASE("/fdbrpc/TokenSign/JWT/ToStringRef") { auto arena = Arena(); auto tokenStr = t.toStringRef(arena); auto tokenStrExpected = - "alg=ES256 iss=issuer sub=subject aud=[aud1,aud2,aud3] iat=123 exp=456 nbf=789 kid=keyId jti=tokenId tenants=[tenant1,tenant2]"_sr; + "alg=ES256 kid=keyId iss=issuer sub=subject aud=[aud1,aud2,aud3] iat=123 exp=456 nbf=789 jti=tokenId tenants=[tenant1,tenant2]"_sr; if (tokenStr != tokenStrExpected) { fmt::print("Expected: {}\nGot : {}\n", tokenStrExpected.toStringView(), tokenStr.toStringView()); ASSERT(false); diff --git a/fdbrpc/include/fdbrpc/TokenSign.h b/fdbrpc/include/fdbrpc/TokenSign.h index 49efd1f067..13217f16cd 100644 --- a/fdbrpc/include/fdbrpc/TokenSign.h +++ b/fdbrpc/include/fdbrpc/TokenSign.h @@ -85,6 +85,7 @@ namespace authz::jwt { struct TokenRef { // header part ("typ": "JWT" implicitly enforced) Algorithm algorithm; // alg + StringRef keyId; // kid // payload part Optional issuer; // iss Optional subject; // sub @@ -92,7 +93,6 @@ struct TokenRef { Optional issuedAtUnixTime; // iat Optional expiresAtUnixTime; // exp Optional notBeforeUnixTime; // nbf - Optional keyId; // kid Optional tokenId; // jti Optional> tenants; // tenants // signature part @@ -113,7 +113,7 @@ StringRef signToken(Arena& arena, TokenRef tokenSpec, PrivateKey privateKey); // Parse passed b64url-encoded header part and materialize its contents into tokenOut, // using memory allocated from arena -bool parseHeaderPart(TokenRef& tokenOut, StringRef b64urlHeaderIn); +bool parseHeaderPart(Arena& arena, TokenRef& tokenOut, StringRef b64urlHeaderIn); // Parse passed b64url-encoded payload part and materialize its contents into tokenOut, // using memory allocated from arena