Skip to content

Commit

Permalink
Add self-signed certificate Mutual-TLS client authentication method
Browse files Browse the repository at this point in the history
Issue gh-101

Closes gh-1559
  • Loading branch information
jgrandja committed Mar 29, 2024
1 parent a0b7f6f commit 79fe240
Show file tree
Hide file tree
Showing 10 changed files with 478 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@

/**
* An {@link AuthenticationProvider} implementation used for OAuth 2.0 Client Authentication,
* which authenticates the client {@code X509Certificate} received when the {@code tls_client_auth} authentication method is used.
* which authenticates the client {@code X509Certificate} received
* when the {@code tls_client_auth} or {@code self_signed_tls_client_auth} authentication method is used.
*
* @author Joe Grandja
* @since 1.3
Expand All @@ -51,10 +52,14 @@ public final class X509ClientCertificateAuthenticationProvider implements Authen
private static final String ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc6749#section-3.2.1";
private static final ClientAuthenticationMethod TLS_CLIENT_AUTH_AUTHENTICATION_METHOD =
new ClientAuthenticationMethod("tls_client_auth");
private static final ClientAuthenticationMethod SELF_SIGNED_TLS_CLIENT_AUTH_AUTHENTICATION_METHOD =
new ClientAuthenticationMethod("self_signed_tls_client_auth");
private final Log logger = LogFactory.getLog(getClass());
private final RegisteredClientRepository registeredClientRepository;
private final CodeVerifierAuthenticator codeVerifierAuthenticator;
private Consumer<OAuth2ClientAuthenticationContext> certificateVerifier = this::verifyX509CertificateSubjectDN;
private final Consumer<OAuth2ClientAuthenticationContext> selfSignedCertificateVerifier =
new X509SelfSignedCertificateVerifier();
private Consumer<OAuth2ClientAuthenticationContext> certificateVerifier = this::verifyX509Certificate;

/**
* Constructs a {@code X509ClientCertificateAuthenticationProvider} using the provided parameters.
Expand All @@ -75,7 +80,8 @@ public Authentication authenticate(Authentication authentication) throws Authent
OAuth2ClientAuthenticationToken clientAuthentication =
(OAuth2ClientAuthenticationToken) authentication;

if (!TLS_CLIENT_AUTH_AUTHENTICATION_METHOD.equals(clientAuthentication.getClientAuthenticationMethod())) {
if (!TLS_CLIENT_AUTH_AUTHENTICATION_METHOD.equals(clientAuthentication.getClientAuthenticationMethod()) &&
!SELF_SIGNED_TLS_CLIENT_AUTH_AUTHENTICATION_METHOD.equals(clientAuthentication.getClientAuthenticationMethod())) {
return null;
}

Expand Down Expand Up @@ -127,7 +133,8 @@ public boolean supports(Class<?> authentication) {
/**
* Sets the {@code Consumer} providing access to the {@link OAuth2ClientAuthenticationContext}
* and is responsible for verifying the client {@code X509Certificate} associated in the {@link OAuth2ClientAuthenticationToken}.
* The default implementation verifies the {@link ClientSettings#getX509CertificateSubjectDN() expected subject distinguished name}.
* The default implementation for the {@code tls_client_auth} authentication method
* verifies the {@link ClientSettings#getX509CertificateSubjectDN() expected subject distinguished name}.
*
* <p>
* <b>NOTE:</b> If verification fails, an {@link OAuth2AuthenticationException} MUST be thrown.
Expand All @@ -139,6 +146,15 @@ public void setCertificateVerifier(Consumer<OAuth2ClientAuthenticationContext> c
this.certificateVerifier = certificateVerifier;
}

private void verifyX509Certificate(OAuth2ClientAuthenticationContext clientAuthenticationContext) {
OAuth2ClientAuthenticationToken clientAuthentication = clientAuthenticationContext.getAuthentication();
if (SELF_SIGNED_TLS_CLIENT_AUTH_AUTHENTICATION_METHOD.equals(clientAuthentication.getClientAuthenticationMethod())) {
this.selfSignedCertificateVerifier.accept(clientAuthenticationContext);
} else {
verifyX509CertificateSubjectDN(clientAuthenticationContext);
}
}

private void verifyX509CertificateSubjectDN(OAuth2ClientAuthenticationContext clientAuthenticationContext) {
OAuth2ClientAuthenticationToken clientAuthentication = clientAuthenticationContext.getAuthentication();
RegisteredClient registeredClient = clientAuthenticationContext.getRegisteredClient();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
/*
* Copyright 2020-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.server.authorization.authentication;

import java.net.URI;
import java.net.URISyntaxException;
import java.security.PublicKey;
import java.security.cert.X509Certificate;
import java.text.ParseException;
import java.util.Arrays;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;

import javax.security.auth.x500.X500Principal;

import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.JWKMatcher;
import com.nimbusds.jose.jwk.JWKSet;

import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.RequestEntity;
import org.springframework.http.ResponseEntity;
import org.springframework.http.client.SimpleClientHttpRequestFactory;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.util.StringUtils;
import org.springframework.web.client.RestOperations;
import org.springframework.web.client.RestTemplate;

/**
* The default {@code X509Certificate} verifier for the {@code self_signed_tls_client_auth} authentication method.
*
* @author Joe Grandja
* @since 1.3
* @see X509ClientCertificateAuthenticationProvider#setCertificateVerifier(Consumer)
*/
final class X509SelfSignedCertificateVerifier implements Consumer<OAuth2ClientAuthenticationContext> {
private static final String ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc6749#section-3.2.1";
private static final JWKMatcher HAS_X509_CERT_CHAIN_MATCHER = new JWKMatcher.Builder().hasX509CertChain(true).build();
private final Function<RegisteredClient, JWKSet> jwkSetSupplier = new JwkSetSupplier();

@Override
public void accept(OAuth2ClientAuthenticationContext clientAuthenticationContext) {
OAuth2ClientAuthenticationToken clientAuthentication = clientAuthenticationContext.getAuthentication();
RegisteredClient registeredClient = clientAuthenticationContext.getRegisteredClient();
X509Certificate[] clientCertificateChain = (X509Certificate[]) clientAuthentication.getCredentials();
X509Certificate clientCertificate = clientCertificateChain[0];

X500Principal issuer = clientCertificate.getIssuerX500Principal();
X500Principal subject = clientCertificate.getSubjectX500Principal();
if (issuer == null || !issuer.equals(subject)) {
throwInvalidClient("x509_certificate_issuer");
}

JWKSet jwkSet = this.jwkSetSupplier.apply(registeredClient);

boolean publicKeyMatches = false;
for (JWK jwk : jwkSet.filter(HAS_X509_CERT_CHAIN_MATCHER).getKeys()) {
X509Certificate x509Certificate = jwk.getParsedX509CertChain().get(0);
PublicKey publicKey = x509Certificate.getPublicKey();
if (Arrays.equals(clientCertificate.getPublicKey().getEncoded(), publicKey.getEncoded())) {
publicKeyMatches = true;
break;
}
}

if (!publicKeyMatches) {
throwInvalidClient("x509_certificate");
}
}

private static void throwInvalidClient(String parameterName) {
throwInvalidClient(parameterName, null);
}

private static void throwInvalidClient(String parameterName, Throwable cause) {
OAuth2Error error = new OAuth2Error(
OAuth2ErrorCodes.INVALID_CLIENT,
"Client authentication failed: " + parameterName,
ERROR_URI
);
throw new OAuth2AuthenticationException(error, error.toString(), cause);
}

private static class JwkSetSupplier implements Function<RegisteredClient, JWKSet> {
private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType("application", "jwk-set+json");
private final RestOperations restOperations;
private final Map<String, Supplier<JWKSet>> jwkSets = new ConcurrentHashMap<>();

private JwkSetSupplier() {
SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory();
requestFactory.setConnectTimeout(15_000);
requestFactory.setReadTimeout(15_000);
this.restOperations = new RestTemplate(requestFactory);
}

@Override
public JWKSet apply(RegisteredClient registeredClient) {
Supplier<JWKSet> jwkSetSupplier = this.jwkSets.computeIfAbsent(
registeredClient.getId(), (key) -> {
if (!StringUtils.hasText(registeredClient.getClientSettings().getJwkSetUrl())) {
throwInvalidClient("client_jwk_set_url");
}
return new JwkSetHolder(registeredClient.getClientSettings().getJwkSetUrl());
});
return jwkSetSupplier.get();
}

private JWKSet retrieve(String jwkSetUrl) {
URI jwkSetUri = null;
try {
jwkSetUri = new URI(jwkSetUrl);
} catch (URISyntaxException ex) {
throwInvalidClient("jwk_set_uri", ex);
}

HttpHeaders headers = new HttpHeaders();
headers.setAccept(Arrays.asList(MediaType.APPLICATION_JSON, APPLICATION_JWK_SET_JSON));
RequestEntity<Void> request = new RequestEntity<>(headers, HttpMethod.GET, jwkSetUri);
ResponseEntity<String> response = null;
try {
response = this.restOperations.exchange(request, String.class);
} catch (Exception ex) {
throwInvalidClient("jwk_set_response_error", ex);
}
if (response.getStatusCode().value() != 200) {
throwInvalidClient("jwk_set_response_status");
}

JWKSet jwkSet = null;
try {
jwkSet = JWKSet.parse(response.getBody());
} catch (ParseException ex) {
throwInvalidClient("jwk_set_response_body", ex);
}

return jwkSet;
}

private class JwkSetHolder implements Supplier<JWKSet> {
private final String jwkSetUrl;
private JWKSet jwkSet;

private JwkSetHolder(String jwkSetUrl) {
this.jwkSetUrl = jwkSetUrl;
}

@Override
public JWKSet get() {
if (this.jwkSet == null) {
this.jwkSet = retrieve(this.jwkSetUrl);
}
return this.jwkSet;
}

}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ private static Consumer<List<String>> clientAuthenticationMethods() {
authenticationMethods.add(ClientAuthenticationMethod.CLIENT_SECRET_JWT.getValue());
authenticationMethods.add(ClientAuthenticationMethod.PRIVATE_KEY_JWT.getValue());
authenticationMethods.add("tls_client_auth");
authenticationMethods.add("self_signed_tls_client_auth");
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ private static Consumer<List<String>> clientAuthenticationMethods() {
authenticationMethods.add(ClientAuthenticationMethod.CLIENT_SECRET_JWT.getValue());
authenticationMethods.add(ClientAuthenticationMethod.PRIVATE_KEY_JWT.getValue());
authenticationMethods.add("tls_client_auth");
authenticationMethods.add("self_signed_tls_client_auth");
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
/**
* Attempts to extract a client {@code X509Certificate} chain from {@link HttpServletRequest}
* and then converts to an {@link OAuth2ClientAuthenticationToken} used for authenticating the client
* using the {@code tls_client_auth} method.
* using the {@code tls_client_auth} or {@code self_signed_tls_client_auth} method.
*
* @author Joe Grandja
* @since 1.3
Expand All @@ -46,13 +46,15 @@
public final class X509ClientCertificateAuthenticationConverter implements AuthenticationConverter {
private static final ClientAuthenticationMethod TLS_CLIENT_AUTH_AUTHENTICATION_METHOD =
new ClientAuthenticationMethod("tls_client_auth");
private static final ClientAuthenticationMethod SELF_SIGNED_TLS_CLIENT_AUTH_AUTHENTICATION_METHOD =
new ClientAuthenticationMethod("self_signed_tls_client_auth");

@Nullable
@Override
public Authentication convert(HttpServletRequest request) {
X509Certificate[] clientCertificateChain =
(X509Certificate[]) request.getAttribute("jakarta.servlet.request.X509Certificate");
if (clientCertificateChain == null || clientCertificateChain.length <= 1) {
if (clientCertificateChain == null || clientCertificateChain.length == 0) {
return null;
}

Expand All @@ -68,7 +70,12 @@ public Authentication convert(HttpServletRequest request) {
Map<String, Object> additionalParameters = OAuth2EndpointUtils.getParametersIfMatchesAuthorizationCodeGrantRequest(
request, OAuth2ParameterNames.CLIENT_ID);

return new OAuth2ClientAuthenticationToken(clientId, TLS_CLIENT_AUTH_AUTHENTICATION_METHOD,
ClientAuthenticationMethod clientAuthenticationMethod =
clientCertificateChain.length == 1 ?
SELF_SIGNED_TLS_CLIENT_AUTH_AUTHENTICATION_METHOD :
TLS_CLIENT_AUTH_AUTHENTICATION_METHOD;

return new OAuth2ClientAuthenticationToken(clientId, clientAuthenticationMethod,
clientCertificateChain, additionalParameters);
}

Expand Down
Loading

0 comments on commit 79fe240

Please sign in to comment.