Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SDK-3231] Added support for multiple checks on a single claim #573

Merged
merged 8 commits into from
Apr 13, 2022
Merged
153 changes: 95 additions & 58 deletions lib/src/main/java/com/auth0/jwt/JWTVerifier.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import com.auth0.jwt.impl.PublicClaims;
import com.auth0.jwt.interfaces.Claim;
import com.auth0.jwt.interfaces.DecodedJWT;
import com.auth0.jwt.impl.ExpectedCheckHolder;
import com.auth0.jwt.interfaces.Verification;

import java.time.Clock;
Expand All @@ -25,12 +26,12 @@
*/
public final class JWTVerifier implements com.auth0.jwt.interfaces.JWTVerifier {
private final Algorithm algorithm;
final Map<String, BiPredicate<Claim, DecodedJWT>> expectedChecks;
final List<ExpectedCheckHolder> expectedChecks;
private final JWTParser parser;

JWTVerifier(Algorithm algorithm, Map<String, BiPredicate<Claim, DecodedJWT>> expectedChecks) {
JWTVerifier(Algorithm algorithm, List<ExpectedCheckHolder> expectedChecks) {
this.algorithm = algorithm;
this.expectedChecks = Collections.unmodifiableMap(expectedChecks);
this.expectedChecks = Collections.unmodifiableList(expectedChecks);
this.parser = new JWTParser();
}

Expand All @@ -50,7 +51,7 @@ static Verification init(Algorithm algorithm) throws IllegalArgumentException {
*/
public static class BaseVerification implements Verification {
private final Algorithm algorithm;
private final Map<String, BiPredicate<Claim, DecodedJWT>> expectedChecks;
private final List<ExpectedCheckHolder> expectedChecks;
private long defaultLeeway;
private final Map<String, Long> customLeeways;
private boolean ignoreIssuedAt;
Expand All @@ -62,15 +63,18 @@ public static class BaseVerification implements Verification {
}

this.algorithm = algorithm;
this.expectedChecks = new LinkedHashMap<>();
this.expectedChecks = new ArrayList<>();
this.customLeeways = new HashMap<>();
this.defaultLeeway = 0;
}

@Override
public Verification withIssuer(String... issuer) {
List<String> value = isNullOrEmpty(issuer) ? null : Arrays.asList(issuer);
checkIfNeedToRemove(PublicClaims.ISSUER, value, ((claim, decodedJWT) -> {
addCheck(PublicClaims.ISSUER, ((claim, decodedJWT) -> {
if (verifyNull(claim, value)) {
return true;
}
if (value == null || !value.contains(claim.asString())) {
throw new IncorrectClaimException(
"The Claim 'iss' value doesn't match the required issuer.", PublicClaims.ISSUER, claim);
Expand All @@ -82,23 +86,40 @@ public Verification withIssuer(String... issuer) {

@Override
public Verification withSubject(String subject) {
checkIfNeedToRemove(PublicClaims.SUBJECT, subject, (claim, decodedJWT) -> subject.equals(claim.asString()));
addCheck(PublicClaims.SUBJECT, (claim, decodedJWT) ->
verifyNull(claim, subject) || subject.equals(claim.asString()));
return this;
}

@Override
public Verification withAudience(String... audience) {
List<String> value = isNullOrEmpty(audience) ? null : Arrays.asList(audience);
checkIfNeedToRemove(PublicClaims.AUDIENCE, value, ((claim, decodedJWT) ->
assertValidAudienceClaim(claim, decodedJWT.getAudience(), value, true)));
addCheck(PublicClaims.AUDIENCE, ((claim, decodedJWT) -> {
if (verifyNull(claim, value)) {
return true;
}
if (!assertValidAudienceClaim(decodedJWT.getAudience(), value, true)) {
throw new IncorrectClaimException("The Claim 'aud' value doesn't contain the required audience.",
PublicClaims.AUDIENCE, claim);
}
return true;
}));
return this;
}

@Override
public Verification withAnyOfAudience(String... audience) {
List<String> value = isNullOrEmpty(audience) ? null : Arrays.asList(audience);
checkIfNeedToRemove(PublicClaims.AUDIENCE, value, ((claim, decodedJWT) ->
assertValidAudienceClaim(claim, decodedJWT.getAudience(), value, false)));
addCheck(PublicClaims.AUDIENCE, ((claim, decodedJWT) -> {
if (verifyNull(claim, value)) {
return true;
}
if (!assertValidAudienceClaim(decodedJWT.getAudience(), value, false)) {
throw new IncorrectClaimException("The Claim 'aud' value doesn't contain the required audience.",
PublicClaims.AUDIENCE, claim);
}
return true;
}));
return this;
}

Expand Down Expand Up @@ -138,14 +159,16 @@ public Verification ignoreIssuedAt() {

@Override
public Verification withJWTId(String jwtId) {
checkIfNeedToRemove(PublicClaims.JWT_ID, jwtId, ((claim, decodedJWT) -> jwtId.equals(claim.asString())));
addCheck(PublicClaims.JWT_ID, ((claim, decodedJWT) ->
verifyNull(claim, jwtId) || jwtId.equals(claim.asString())));
return this;
}

@Override
public Verification withClaimPresence(String name) throws IllegalArgumentException {
assertNonNull(name);
withClaim(name, ((claim, decodedJWT) -> assertClaimPresence(name, claim)));
//since addCheck already checks presence, we just return true
withClaim(name, ((claim, decodedJWT) -> true));
return this;
}

Expand All @@ -159,35 +182,40 @@ public Verification withNullClaim(String name) throws IllegalArgumentException {
@Override
public Verification withClaim(String name, Boolean value) throws IllegalArgumentException {
assertNonNull(name);
checkIfNeedToRemove(name, value, ((claim, decodedJWT) -> value.equals(claim.asBoolean())));
addCheck(name, ((claim, decodedJWT) -> verifyNull(claim, value)
|| value.equals(claim.asBoolean())));
return this;
}

@Override
public Verification withClaim(String name, Integer value) throws IllegalArgumentException {
assertNonNull(name);
checkIfNeedToRemove(name, value, ((claim, decodedJWT) -> value.equals(claim.asInt())));
addCheck(name, ((claim, decodedJWT) -> verifyNull(claim, value)
|| value.equals(claim.asInt())));
return this;
}

@Override
public Verification withClaim(String name, Long value) throws IllegalArgumentException {
assertNonNull(name);
checkIfNeedToRemove(name, value, ((claim, decodedJWT) -> value.equals(claim.asLong())));
addCheck(name, ((claim, decodedJWT) -> verifyNull(claim, value)
|| value.equals(claim.asLong())));
return this;
}

@Override
public Verification withClaim(String name, Double value) throws IllegalArgumentException {
assertNonNull(name);
checkIfNeedToRemove(name, value, ((claim, decodedJWT) -> value.equals(claim.asDouble())));
addCheck(name, ((claim, decodedJWT) -> verifyNull(claim, value)
|| value.equals(claim.asDouble())));
return this;
}

@Override
public Verification withClaim(String name, String value) throws IllegalArgumentException {
assertNonNull(name);
checkIfNeedToRemove(name, value, ((claim, decodedJWT) -> value.equals(claim.asString())));
addCheck(name, ((claim, decodedJWT) -> verifyNull(claim, value)
|| value.equals(claim.asString())));
return this;
}

Expand All @@ -201,37 +229,42 @@ public Verification withClaim(String name, Instant value) throws IllegalArgument
assertNonNull(name);
// Since date-time claims are serialized as epoch seconds,
// we need to compare them with only seconds-granularity
checkIfNeedToRemove(name, value,
((claim, decodedJWT) -> value.truncatedTo(ChronoUnit.SECONDS).equals(claim.asInstant())));
addCheck(name,
((claim, decodedJWT) -> verifyNull(claim, value)
|| value.truncatedTo(ChronoUnit.SECONDS).equals(claim.asInstant())));
return this;
}

@Override
public Verification withClaim(String name, BiPredicate<Claim, DecodedJWT> predicate)
throws IllegalArgumentException {
assertNonNull(name);
checkIfNeedToRemove(name, predicate, predicate);
addCheck(name, ((claim, decodedJWT) -> verifyNull(claim, predicate)
|| predicate.test(claim, decodedJWT)));
return this;
}

@Override
public Verification withArrayClaim(String name, String... items) throws IllegalArgumentException {
assertNonNull(name);
checkIfNeedToRemove(name, items, ((claim, decodedJWT) -> assertValidCollectionClaim(claim, items)));
addCheck(name, ((claim, decodedJWT) -> verifyNull(claim, items)
|| assertValidCollectionClaim(claim, items)));
return this;
}

@Override
public Verification withArrayClaim(String name, Integer... items) throws IllegalArgumentException {
assertNonNull(name);
checkIfNeedToRemove(name, items, ((claim, decodedJWT) -> assertValidCollectionClaim(claim, items)));
addCheck(name, ((claim, decodedJWT) -> verifyNull(claim, items)
|| assertValidCollectionClaim(claim, items)));
return this;
}

@Override
public Verification withArrayClaim(String name, Long... items) throws IllegalArgumentException {
assertNonNull(name);
checkIfNeedToRemove(name, items, ((claim, decodedJWT) -> assertValidCollectionClaim(claim, items)));
addCheck(name, ((claim, decodedJWT) -> verifyNull(claim, items)
|| assertValidCollectionClaim(claim, items)));
return this;
}

Expand Down Expand Up @@ -268,13 +301,13 @@ private void addMandatoryClaimChecks() {
long notBeforeLeeway = getLeewayFor(PublicClaims.NOT_BEFORE);
long issuedAtLeeway = getLeewayFor(PublicClaims.ISSUED_AT);

expectedChecks.put(PublicClaims.EXPIRES_AT, (claim, decodedJWT) ->
assertValidInstantClaim(PublicClaims.EXPIRES_AT, claim, expiresAtLeeway, true));
expectedChecks.put(PublicClaims.NOT_BEFORE, (claim, decodedJWT) ->
assertValidInstantClaim(PublicClaims.NOT_BEFORE, claim, notBeforeLeeway, false));
expectedChecks.add(constructExpectedCheck(PublicClaims.EXPIRES_AT, (claim, decodedJWT) ->
assertValidInstantClaim(PublicClaims.EXPIRES_AT, claim, expiresAtLeeway, true)));
expectedChecks.add(constructExpectedCheck(PublicClaims.NOT_BEFORE, (claim, decodedJWT) ->
assertValidInstantClaim(PublicClaims.NOT_BEFORE, claim, notBeforeLeeway, false)));
if (!ignoreIssuedAt) {
expectedChecks.put(PublicClaims.ISSUED_AT, (claim, decodedJWT) ->
assertValidInstantClaim(PublicClaims.ISSUED_AT, claim, issuedAtLeeway, false));
expectedChecks.add(constructExpectedCheck(PublicClaims.ISSUED_AT, (claim, decodedJWT) ->
assertValidInstantClaim(PublicClaims.ISSUED_AT, claim, issuedAtLeeway, false)));
}
}

Expand All @@ -294,8 +327,7 @@ private boolean assertValidCollectionClaim(Claim claim, Object[] expectedClaimVa
}
}
} else {
claimArr = claim.isNull() || claim.isMissing()
? Collections.emptyList() : Arrays.asList(claim.as(Object[].class));
claimArr = Arrays.asList(claim.as(Object[].class));
}
List<Object> valueArr = Arrays.asList(expectedClaimValue);
return claimArr.containsAll(valueArr);
Expand Down Expand Up @@ -329,24 +361,12 @@ private boolean assertInstantIsPast(Instant claimVal, long leeway, Instant now)
}

private boolean assertValidAudienceClaim(
Claim claim,
List<String> audience,
List<String> values,
boolean shouldContainAll
) {
if (audience == null || (shouldContainAll && !audience.containsAll(values))
|| (!shouldContainAll && Collections.disjoint(audience, values))) {
throw new IncorrectClaimException(
"The Claim 'aud' value doesn't contain the required audience.", PublicClaims.AUDIENCE, claim);
}
return true;
}

private boolean assertClaimPresence(String name, Claim claim) {
if (claim.isMissing()) {
throw new MissingClaimException(name);
}
return true;
return !(audience == null || (shouldContainAll && !audience.containsAll(values))
|| (!shouldContainAll && Collections.disjoint(audience, values)));
}

private void assertPositive(long leeway) {
Expand All @@ -361,13 +381,31 @@ private void assertNonNull(String name) {
}
}

private void checkIfNeedToRemove(String name, Object value, BiPredicate<Claim, DecodedJWT> predicate) {
if (value == null) {
expectedChecks.remove(name);
return;
}
expectedChecks.put(name, (claim, decodedJWT) -> assertClaimPresence(name, claim)
&& predicate.test(claim, decodedJWT));
private void addCheck(String name, BiPredicate<Claim, DecodedJWT> predicate) {
expectedChecks.add(constructExpectedCheck(name, (claim, decodedJWT) -> {
if (claim.isMissing()) {
throw new MissingClaimException(name);
}
return predicate.test(claim, decodedJWT);
}));
}

private ExpectedCheckHolder constructExpectedCheck(String claimName, BiPredicate<Claim, DecodedJWT> check) {
return new ExpectedCheckHolder() {
@Override
public String getClaimName() {
return claimName;
}

@Override
public boolean verify(Claim claim, DecodedJWT decodedJWT) {
return check.test(claim, decodedJWT);
}
};
}

private boolean verifyNull(Claim claim, Object value) {
return value == null && claim.isNull();
}

private boolean isNullOrEmpty(String[] args) {
Expand Down Expand Up @@ -431,15 +469,14 @@ private void verifyAlgorithm(DecodedJWT jwt, Algorithm expectedAlgorithm) throws
}
}

private void verifyClaims(DecodedJWT jwt, Map<String, BiPredicate<Claim, DecodedJWT>> claims)
private void verifyClaims(DecodedJWT jwt, List<ExpectedCheckHolder> expectedChecks)
throws TokenExpiredException, InvalidClaimException {
for (Map.Entry<String, BiPredicate<Claim, DecodedJWT>> entry : claims.entrySet()) {
for (ExpectedCheckHolder expectedCheck : expectedChecks) {
boolean isValid;
String claimName = entry.getKey();
BiPredicate<Claim, DecodedJWT> expectedCheck = entry.getValue();
String claimName = expectedCheck.getClaimName();
Claim claim = jwt.getClaim(claimName);

isValid = expectedCheck.test(claim, jwt);
isValid = expectedCheck.verify(claim, jwt);

if (!isValid) {
throw new IncorrectClaimException(
Expand Down
25 changes: 25 additions & 0 deletions lib/src/main/java/com/auth0/jwt/impl/ExpectedCheckHolder.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package com.auth0.jwt.impl;

import com.auth0.jwt.interfaces.Claim;
import com.auth0.jwt.interfaces.DecodedJWT;

/**
* This holds the checks that are run to verify a JWT.
*/
public interface ExpectedCheckHolder {
/**
* The claim name that will be checked.
*
* @return the claim name
*/
String getClaimName();

/**
* The verification that will be run.
*
* @param claim the claim for which verification is done
* @param decodedJWT the JWT on which verification is done
* @return whether the verification passed or not
*/
boolean verify(Claim claim, DecodedJWT decodedJWT);
}
Loading