Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SDK-3171] Fix header claims serialization #549

Merged
merged 2 commits into from
Mar 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions lib/src/main/java/com/auth0/jwt/JWTCreator.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
import com.auth0.jwt.algorithms.Algorithm;
import com.auth0.jwt.exceptions.JWTCreationException;
import com.auth0.jwt.exceptions.SignatureGenerationException;
import com.auth0.jwt.impl.ClaimsHolder;
import com.auth0.jwt.impl.PayloadSerializer;
import com.auth0.jwt.impl.PublicClaims;
import com.auth0.jwt.impl.*;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.MapperFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
Expand Down Expand Up @@ -34,16 +32,17 @@ public final class JWTCreator {
static {
mapper = new ObjectMapper();
module = new SimpleModule();
module.addSerializer(ClaimsHolder.class, new PayloadSerializer());
module.addSerializer(PayloadClaimsHolder.class, new PayloadSerializer());
module.addSerializer(HeaderClaimsHolder.class, new HeaderSerializer());
mapper.registerModule(module);
mapper.configure(MapperFeature.SORT_PROPERTIES_ALPHABETICALLY, true);
}

private JWTCreator(Algorithm algorithm, Map<String, Object> headerClaims, Map<String, Object> payloadClaims) throws JWTCreationException {
this.algorithm = algorithm;
try {
headerJson = mapper.writeValueAsString(headerClaims);
payloadJson = mapper.writeValueAsString(new ClaimsHolder(payloadClaims));
headerJson = mapper.writeValueAsString(new HeaderClaimsHolder(headerClaims));
payloadJson = mapper.writeValueAsString(new PayloadClaimsHolder(payloadClaims));
} catch (JsonProcessingException e) {
throw new JWTCreationException("Some of the Claims couldn't be converted to a valid JSON format.", e);
}
Expand Down
6 changes: 3 additions & 3 deletions lib/src/main/java/com/auth0/jwt/impl/ClaimsHolder.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
/**
* The ClaimsHolder class is just a wrapper for the Map of Claims used for building a JWT.
*/
public final class ClaimsHolder {
public abstract class ClaimsHolder {
private Map<String, Object> claims;

public ClaimsHolder(Map<String, Object> claims) {
this.claims = claims == null ? new HashMap<String, Object>() : claims;
protected ClaimsHolder(Map<String, Object> claims) {
this.claims = claims == null ? new HashMap<>() : claims;
}

Map<String, Object> getClaims() {
Expand Down
86 changes: 86 additions & 0 deletions lib/src/main/java/com/auth0/jwt/impl/ClaimsSerializer.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package com.auth0.jwt.impl;

import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.databind.SerializerProvider;
import com.fasterxml.jackson.databind.ser.std.StdSerializer;

import java.io.IOException;
import java.time.Instant;
import java.util.Date;
import java.util.List;
import java.util.Map;

/**
* Custom serializer used to write the resulting JWT.
*
* @param <T> the type this serializer operates on.
*/
public class ClaimsSerializer<T extends ClaimsHolder> extends StdSerializer<T> {

public ClaimsSerializer(Class<T> t) {
super(t);
}

@Override
public void serialize(T holder, JsonGenerator gen, SerializerProvider provider) throws IOException {
gen.writeStartObject();
for (Map.Entry<String, Object> entry : holder.getClaims().entrySet()) {
writeClaim(entry, gen);
}
gen.writeEndObject();
}

/**
* Writes the given entry to the JSON representation. Custom claim serialization handling can override this method
* to provide use-case specific serialization. Implementors who override this method must write the field name and the
* field value.
*
* @param entry The entry that corresponds to the JSON field to write
* @param gen The {@code JsonGenerator} to use
* @throws IOException
*/
protected void writeClaim(Map.Entry<String, Object> entry, JsonGenerator gen) throws IOException {
gen.writeFieldName(entry.getKey());
handleSerialization(entry.getValue(), gen);
}

private static void handleSerialization(Object value, JsonGenerator gen) throws IOException {
if (value instanceof Date) {
gen.writeNumber(dateToSeconds((Date) value));
} else if (value instanceof Instant) { // EXPIRES_AT, ISSUED_AT, NOT_BEFORE, custom Instant claims
gen.writeNumber(instantToSeconds((Instant) value));
} else if (value instanceof Map) {
serializeMap((Map<?, ?>) value, gen);
} else if (value instanceof List) {
serializeList((List<?>) value, gen);
} else {
gen.writeObject(value);
}
}

private static void serializeMap(Map<?, ?> map, JsonGenerator gen) throws IOException {
gen.writeStartObject();
for (Map.Entry<?, ?> entry : map.entrySet()) {
gen.writeFieldName((String) entry.getKey());
Object value = entry.getValue();
handleSerialization(value, gen);
}
gen.writeEndObject();
}

private static void serializeList(List<?> list, JsonGenerator gen) throws IOException {
gen.writeStartArray();
for (Object entry : list) {
handleSerialization(entry, gen);
}
gen.writeEndArray();
}

private static long instantToSeconds(Instant instant) {
return instant.getEpochSecond();
}

private static long dateToSeconds(Date date) {
return date.getTime() / 1000;
}
}
12 changes: 12 additions & 0 deletions lib/src/main/java/com/auth0/jwt/impl/HeaderClaimsHolder.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package com.auth0.jwt.impl;

import java.util.Map;

/**
* Holds the header claims when serializing a JWT.
*/
public final class HeaderClaimsHolder extends ClaimsHolder {
public HeaderClaimsHolder(Map<String, Object> claims) {
super(claims);
}
}
10 changes: 10 additions & 0 deletions lib/src/main/java/com/auth0/jwt/impl/HeaderSerializer.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package com.auth0.jwt.impl;

/**
* Responsible for serializing a JWT's header representation to JSON.
*/
public class HeaderSerializer extends ClaimsSerializer<HeaderClaimsHolder> {
public HeaderSerializer() {
super(HeaderClaimsHolder.class);
}
}
12 changes: 12 additions & 0 deletions lib/src/main/java/com/auth0/jwt/impl/PayloadClaimsHolder.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package com.auth0.jwt.impl;

import java.util.Map;

/**
* Holds the payload claims when serializing a JWT.
*/
public final class PayloadClaimsHolder extends ClaimsHolder {
public PayloadClaimsHolder(Map<String, Object> claims) {
super(claims);
}
}
85 changes: 17 additions & 68 deletions lib/src/main/java/com/auth0/jwt/impl/PayloadSerializer.java
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
package com.auth0.jwt.impl;

import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.databind.SerializerProvider;
import com.fasterxml.jackson.databind.ser.std.StdSerializer;

import java.io.IOException;
import java.time.Instant;
import java.util.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

/**
* Jackson serializer implementation for converting into JWT Payload parts.
Expand All @@ -15,32 +15,26 @@
* <p>
* This class is thread-safe.
*/
public class PayloadSerializer extends StdSerializer<ClaimsHolder> {

public class PayloadSerializer extends ClaimsSerializer<PayloadClaimsHolder> {
public PayloadSerializer() {
this(null);
}

private PayloadSerializer(Class<ClaimsHolder> t) {
super(t);
super(PayloadClaimsHolder.class);
}

@Override
public void serialize(ClaimsHolder holder, JsonGenerator gen, SerializerProvider provider) throws IOException {

gen.writeStartObject();
for (Map.Entry<String, Object> e : holder.getClaims().entrySet()) {
if (PublicClaims.AUDIENCE.equals(e.getKey())) {
writeAudience(gen, e);
} else {
gen.writeFieldName(e.getKey());
handleSerialization(e.getValue(), gen);
}
protected void writeClaim(Map.Entry<String, Object> entry, JsonGenerator gen) throws IOException {
if (PublicClaims.AUDIENCE.equals(entry.getKey())) {
writeAudience(gen, entry);
} else {
super.writeClaim(entry, gen);
}

gen.writeEndObject();
}

/**
* Audience may be a list of strings or a single string. This is needed to properly handle the aud claim when
* added with the {@linkplain com.auth0.jwt.JWTCreator.Builder#withPayload(Map)} method.
*/

//
private void writeAudience(JsonGenerator gen, Map.Entry<String, Object> e) throws IOException {
if (e.getValue() instanceof String) {
gen.writeFieldName(e.getKey());
Expand Down Expand Up @@ -70,49 +64,4 @@ private void writeAudience(JsonGenerator gen, Map.Entry<String, Object> e) throw
}
}
}

/**
* Serializes {@linkplain Instant} to epoch second values, traversing maps and lists as needed.
* @param value the object to serialize
* @param gen the JsonGenerator to use for JSON serialization
*/
private void handleSerialization(Object value, JsonGenerator gen) throws IOException {
if (value instanceof Date) {
gen.writeNumber(dateToSeconds((Date) value));
} else if (value instanceof Instant) { // EXPIRES_AT, ISSUED_AT, NOT_BEFORE, custom Instant claims
gen.writeNumber(instantToSeconds((Instant) value));
} else if (value instanceof Map) {
serializeMap((Map<?, ?>) value, gen);
} else if (value instanceof List) {
serializeList((List<?>) value, gen);
} else {
gen.writeObject(value);
}
}

private void serializeMap(Map<?, ?> map, JsonGenerator gen) throws IOException {
gen.writeStartObject();
for (Map.Entry<?, ?> entry : map.entrySet()) {
gen.writeFieldName((String) entry.getKey());
Object value = entry.getValue();
handleSerialization(value, gen);
}
gen.writeEndObject();
}

private void serializeList(List<?> list, JsonGenerator gen) throws IOException {
gen.writeStartArray();
for (Object entry : list) {
handleSerialization(entry, gen);
}
gen.writeEndArray();
}

private long instantToSeconds(Instant instant) {
return instant.getEpochSecond();
}

private long dateToSeconds(Date date) {
return date.getTime() / 1000;
}
}
40 changes: 34 additions & 6 deletions lib/src/test/java/com/auth0/jwt/JWTCreatorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import com.auth0.jwt.interfaces.ECDSAKeyProvider;
import com.auth0.jwt.interfaces.RSAKeyProvider;
import com.fasterxml.jackson.databind.ObjectMapper;

import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
Expand All @@ -16,8 +15,8 @@
import java.time.Instant;
import java.util.*;

import static org.hamcrest.Matchers.*;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.*;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

Expand All @@ -39,19 +38,48 @@ public void shouldThrowWhenRequestingSignWithoutAlgorithm() {
.sign(null);
}

@SuppressWarnings("Convert2Diamond")
@Test
public void shouldAddHeaderClaim() {
Map<String, Object> header = new HashMap<String, Object>();
header.put("asd", 123);
Date date = new Date(123000);
Instant instant = date.toInstant();

List<Object> list = Arrays.asList(date, instant);
Map<String, Object> map = new HashMap<>();
map.put("date", date);
map.put("instant", instant);

List<Object> expectedSerializedList = Arrays.asList(date.getTime() / 1000, instant.getEpochSecond());
Map<String, Object> expectedSerializedMap = new HashMap<>();
expectedSerializedMap.put("date", date.getTime() / 1000);
expectedSerializedMap.put("instant", instant.getEpochSecond());

Map<String, Object> header = new HashMap<>();
header.put("string", "string");
header.put("int", 42);
header.put("long", 4200000000L);
header.put("double", 123.123);
header.put("bool", true);
header.put("date", date);
header.put("instant", instant);
header.put("list", list);
header.put("map", map);

String signed = JWTCreator.init()
.withHeader(header)
.sign(Algorithm.HMAC256("secret"));

assertThat(signed, is(notNullValue()));
String[] parts = signed.split("\\.");
String headerJson = new String(Base64.getUrlDecoder().decode(parts[0]), StandardCharsets.UTF_8);
assertThat(headerJson, JsonMatcher.hasEntry("asd", 123));
assertThat(headerJson, JsonMatcher.hasEntry("string", "string"));
assertThat(headerJson, JsonMatcher.hasEntry("int", 42));
assertThat(headerJson, JsonMatcher.hasEntry("long", 4200000000L));
assertThat(headerJson, JsonMatcher.hasEntry("double", 123.123));
assertThat(headerJson, JsonMatcher.hasEntry("bool", true));
assertThat(headerJson, JsonMatcher.hasEntry("date", 123));
assertThat(headerJson, JsonMatcher.hasEntry("instant", 123));
assertThat(headerJson, JsonMatcher.hasEntry("list", expectedSerializedList));
assertThat(headerJson, JsonMatcher.hasEntry("map", expectedSerializedMap));
}

@Test
Expand Down
30 changes: 30 additions & 0 deletions lib/src/test/java/com/auth0/jwt/JWTDecoderTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import java.time.Instant;
import java.util.Base64;
import java.util.Date;
import java.util.List;
import java.util.Map;

import static org.hamcrest.MatcherAssert.assertThat;
Expand Down Expand Up @@ -347,6 +348,35 @@ public void shouldSerializeAndDeserialize() throws Exception {
is(equalTo(deserializedJwt.getClaims().get("extraClaim").asString())));
}

@Test
public void shouldDecodeHeaderClaims() {
String jwt = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCIsImRhdGUiOjE2NDczNTgzMjUsInN0cmluZyI6InN0cmluZyIsImJvb2wiOnRydWUsImRvdWJsZSI6MTIzLjEyMywibGlzdCI6WzE2NDczNTgzMjVdLCJtYXAiOnsiZGF0ZSI6MTY0NzM1ODMyNSwiaW5zdGFudCI6MTY0NzM1ODMyNX0sImludCI6NDIsImxvbmciOjQyMDAwMDAwMDAsImluc3RhbnQiOjE2NDczNTgzMjV9.eyJpYXQiOjE2NDczNjA4ODF9.S2nZDM03ZDvLMeJLWOIqWZ9kmYHZUueyQiIZCCjYNL8";

Instant expectedInstant = Instant.ofEpochSecond(1647358325);
Date expectedDate = Date.from(expectedInstant);

DecodedJWT decoded = JWT.decode(jwt);
assertThat(decoded, is(notNullValue()));
assertThat(decoded.getHeaderClaim("date").asDate(), is(expectedDate));
assertThat(decoded.getHeaderClaim("instant").asInstant(), is(expectedInstant));
assertThat(decoded.getHeaderClaim("string").asString(), is("string"));
assertThat(decoded.getHeaderClaim("bool").asBoolean(), is(true));
assertThat(decoded.getHeaderClaim("double").asDouble(), is(123.123));
assertThat(decoded.getHeaderClaim("int").asInt(), is(42));
assertThat(decoded.getHeaderClaim("long").asLong(), is(4200000000L));

Map<String, Object> headerMap = decoded.getHeaderClaim("map").asMap();
assertThat(headerMap, is(notNullValue()));
assertThat(headerMap.size(), is(2));
assertThat(headerMap, hasEntry("date", 1647358325));
assertThat(headerMap, hasEntry("instant", 1647358325));

List<Object> headerList = decoded.getHeaderClaim("list").asList(Object.class);
assertThat(headerList, is(notNullValue()));
assertThat(headerList.size(), is(1));
assertThat(headerList, contains(1647358325));
}

//Helper Methods

private DecodedJWT customJWT(String jsonHeader, String jsonPayload, String signature) {
Expand Down
Loading