diff --git a/examples/AWSDriverExample/build.gradle.kts b/examples/AWSDriverExample/build.gradle.kts index d39df2ff8..efe5d86f9 100644 --- a/examples/AWSDriverExample/build.gradle.kts +++ b/examples/AWSDriverExample/build.gradle.kts @@ -20,6 +20,7 @@ dependencies { implementation("mysql:mysql-connector-java:8.0.33") implementation("software.amazon.awssdk:rds:2.21.11") implementation("software.amazon.awssdk:secretsmanager:2.21.21") + implementation("software.amazon.awssdk:sts:2.21.21") implementation("com.fasterxml.jackson.core:jackson-databind:2.16.0") implementation(project(":aws-advanced-jdbc-wrapper")) implementation("io.opentelemetry:opentelemetry-api:1.32.0") diff --git a/examples/AWSDriverExample/src/main/java/software/amazon/FederatedAuthPluginExample.java b/examples/AWSDriverExample/src/main/java/software/amazon/FederatedAuthPluginExample.java new file mode 100644 index 000000000..d00e08638 --- /dev/null +++ b/examples/AWSDriverExample/src/main/java/software/amazon/FederatedAuthPluginExample.java @@ -0,0 +1,54 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon; + +import software.amazon.jdbc.PropertyDefinition; +import software.amazon.jdbc.plugin.federatedauth.FederatedAuthPlugin; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.Properties; + +public class FederatedAuthPluginExample { + + private static final String CONNECTION_STRING = "jdbc:aws-wrapper:postgresql://db-identifier.XYZ.us-east-2.rds.amazonaws.com:5432/employees"; + + public static void main(String[] args) throws SQLException { + // Set the AWS Federated Authentication Connection Plugin parameters and the JDBC Wrapper parameters. + final Properties properties = new Properties(); + + // Enable the AWS Federated Authentication Connection Plugin. + properties.setProperty(PropertyDefinition.PLUGINS.name, "federatedAuth"); + properties.setProperty(FederatedAuthPlugin.IDP_NAME.name, "adfs"); + properties.setProperty(FederatedAuthPlugin.IDP_ENDPOINT.name, "ec2amaz-ab3cdef.example.com"); + properties.setProperty(FederatedAuthPlugin.IAM_ROLE_ARN.name, "arn:aws:iam::123456789012:role/adfs_example_iam_role"); + properties.setProperty(FederatedAuthPlugin.IAM_IDP_ARN.name, "arn:aws:iam::123456789012:saml-provider/adfs_example"); + properties.setProperty(FederatedAuthPlugin.IAM_REGION.name, "us-east-2"); + properties.setProperty(FederatedAuthPlugin.IDP_USERNAME.name, "someFederatedUsername@teamatlas.example.com"); + properties.setProperty(FederatedAuthPlugin.IDP_PASSWORD.name, "somePassword"); + properties.setProperty(PropertyDefinition.USER.name, "someIamUser"); + + // Try and make a connection: + try (final Connection conn = DriverManager.getConnection(CONNECTION_STRING, properties); + final Statement statement = conn.createStatement(); + final ResultSet rs = statement.executeQuery("SELECT 1")) { + System.out.println(Util.getResult(rs)); + } + } +} diff --git a/wrapper/build.gradle.kts b/wrapper/build.gradle.kts index 4690be11f..c64f2c3a2 100644 --- a/wrapper/build.gradle.kts +++ b/wrapper/build.gradle.kts @@ -28,7 +28,10 @@ plugins { dependencies { implementation("org.checkerframework:checker-qual:3.40.0") + compileOnly("org.apache.httpcomponents:httpclient:4.5.14") compileOnly("software.amazon.awssdk:rds:2.21.11") + compileOnly("software.amazon.awssdk:sts:2.21.11") + compileOnly("software.amazon.awssdk:iam:2.21.11") compileOnly("com.zaxxer:HikariCP:4.0.3") // Version 4.+ is compatible with Java 8 compileOnly("software.amazon.awssdk:secretsmanager:2.21.21") compileOnly("com.fasterxml.jackson.core:jackson-databind:2.16.0") @@ -58,9 +61,10 @@ dependencies { testImplementation("com.zaxxer:HikariCP:4.0.3") // Version 4.+ is compatible with Java 8 testImplementation("org.springframework.boot:spring-boot-starter-jdbc:2.7.13") // 2.7.13 is the last version compatible with Java 8 testImplementation("org.mockito:mockito-inline:4.11.0") // 4.11.0 is the last version compatible with Java 8 - testImplementation("software.amazon.awssdk:rds:2.21.11") testImplementation("software.amazon.awssdk:ec2:2.21.12") + testImplementation("software.amazon.awssdk:rds:2.21.11") testImplementation("software.amazon.awssdk:secretsmanager:2.21.21") + testImplementation("software.amazon.awssdk:sts:2.21.11") testImplementation("org.testcontainers:testcontainers:1.19.1") testImplementation("org.testcontainers:mysql:1.19.1") testImplementation("org.testcontainers:postgresql:1.19.2") diff --git a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java index 6424b654d..b9de6e45a 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java +++ b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java @@ -39,10 +39,10 @@ import software.amazon.jdbc.plugin.dev.DeveloperConnectionPluginFactory; import software.amazon.jdbc.plugin.efm.HostMonitoringConnectionPluginFactory; import software.amazon.jdbc.plugin.failover.FailoverConnectionPluginFactory; +import software.amazon.jdbc.plugin.federatedauth.FederatedAuthPluginFactory; import software.amazon.jdbc.plugin.readwritesplitting.ReadWriteSplittingPluginFactory; import software.amazon.jdbc.plugin.staledns.AuroraStaleDnsPluginFactory; import software.amazon.jdbc.profile.ConfigurationProfile; -import software.amazon.jdbc.profile.DriverConfigurationProfiles; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.SqlState; import software.amazon.jdbc.util.StringUtils; @@ -65,6 +65,7 @@ public class ConnectionPluginChainBuilder { put("failover", FailoverConnectionPluginFactory.class); put("iam", IamAuthConnectionPluginFactory.class); put("awsSecretsManager", AwsSecretsManagerConnectionPluginFactory.class); + put("federatedAuth", FederatedAuthPluginFactory.class); put("auroraStaleDns", AuroraStaleDnsPluginFactory.class); put("readWriteSplitting", ReadWriteSplittingPluginFactory.class); put("auroraConnectionTracker", AuroraConnectionTrackerPluginFactory.class); @@ -92,7 +93,8 @@ public class ConnectionPluginChainBuilder { put(HostMonitoringConnectionPluginFactory.class, 800); put(IamAuthConnectionPluginFactory.class, 900); put(AwsSecretsManagerConnectionPluginFactory.class, 1000); - put(LogQueryConnectionPluginFactory.class, 1100); + put(FederatedAuthPluginFactory.class, 1100); + put(LogQueryConnectionPluginFactory.class, 1200); put(ConnectTimeConnectionPluginFactory.class, WEIGHT_RELATIVE_TO_PRIOR_PLUGIN); put(ExecutionTimeConnectionPluginFactory.class, WEIGHT_RELATIVE_TO_PRIOR_PLUGIN); put(DeveloperConnectionPluginFactory.class, WEIGHT_RELATIVE_TO_PRIOR_PLUGIN); diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/AdfsCredentialsProviderFactory.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/AdfsCredentialsProviderFactory.java new file mode 100644 index 000000000..9a72e96db --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/AdfsCredentialsProviderFactory.java @@ -0,0 +1,239 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.federatedauth; + +import java.io.IOException; +import java.net.MalformedURLException; +import java.net.URI; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Properties; +import java.util.Set; +import java.util.logging.Logger; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import org.apache.http.NameValuePair; +import org.apache.http.client.entity.UrlEncodedFormEntity; +import org.apache.http.client.methods.CloseableHttpResponse; +import org.apache.http.client.methods.HttpGet; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.impl.client.CloseableHttpClient; +import org.apache.http.message.BasicNameValuePair; +import org.apache.http.util.EntityUtils; +import org.checkerframework.checker.nullness.qual.NonNull; +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.util.Messages; +import software.amazon.jdbc.util.StringUtils; +import software.amazon.jdbc.util.telemetry.TelemetryContext; +import software.amazon.jdbc.util.telemetry.TelemetryFactory; +import software.amazon.jdbc.util.telemetry.TelemetryTraceLevel; + +public class AdfsCredentialsProviderFactory extends SamlCredentialsProviderFactory { + + public static final String IDP_NAME = "adfs"; + private static final String TELEMETRY_FETCH_SAML = "Fetch ADFS SAML Assertion"; + private static final Logger LOGGER = Logger.getLogger(AdfsCredentialsProviderFactory.class.getName()); + private final PluginService pluginService; + private final TelemetryFactory telemetryFactory; + private final CloseableHttpClient httpClient; + private TelemetryContext telemetryContext; + + public AdfsCredentialsProviderFactory(PluginService pluginService, CloseableHttpClient httpClient) { + this.pluginService = pluginService; + this.telemetryFactory = this.pluginService.getTelemetryFactory(); + this.httpClient = httpClient; + } + + @Override + String getSamlAssertion(final @NonNull Properties props) { + this.telemetryContext = telemetryFactory.openTelemetryContext(TELEMETRY_FETCH_SAML, TelemetryTraceLevel.NESTED); + + String uri = "https://" + FederatedAuthPlugin.IDP_ENDPOINT.getString(props) + ':' + + FederatedAuthPlugin.IDP_PORT.getString(props) + "/adfs/ls/IdpInitiatedSignOn.aspx?loginToRp=" + + FederatedAuthPlugin.RELAYING_PARTY_ID.getString(props); + + try { + LOGGER.finest(Messages.get("AdfsCredentialsProviderFactory.signOnPageUrl", new Object[] {uri})); + validateURL(uri); + HttpGet get = new HttpGet(uri); + CloseableHttpResponse resp = httpClient.execute(get); + + if (resp.getStatusLine().getStatusCode() != 200) { + throw new IOException(Messages.get("AdfsCredentialsProviderFactory.signOnPageRequestFailed", + new Object[] { + resp.getStatusLine().getStatusCode(), + resp.getStatusLine().getReasonPhrase(), + EntityUtils.toString(resp.getEntity())})); + } + + String body = EntityUtils.toString(resp.getEntity()); + + String action = getFormAction(body); + if (!StringUtils.isNullOrEmpty(action) && action.startsWith("/")) { + uri = "https://" + FederatedAuthPlugin.IDP_ENDPOINT.getString(props) + ':' + FederatedAuthPlugin.IDP_PORT.getString(props) + action; + } + + LOGGER.finest(Messages.get("AdfsCredentialsProviderFactory.signOnPagePostActionUrl", new Object[] {uri})); + + validateURL(uri); + HttpPost post = new HttpPost(uri); + post.setEntity(new UrlEncodedFormEntity(getParameters(body, props))); + resp = httpClient.execute(post); + if (resp.getStatusLine().getStatusCode() != 200) { + throw new IOException(Messages.get("AdfsCredentialsProviderFactory.signOnPagePostActionRequestFailed", + new Object[] { + resp.getStatusLine().getStatusCode(), + resp.getStatusLine().getReasonPhrase(), + EntityUtils.toString(resp.getEntity())})); + } + + String content = EntityUtils.toString(resp.getEntity()); + Matcher matcher = FederatedAuthPlugin.SAML_RESPONSE_PATTERN.matcher(content); + if (!matcher.find()) { + throw new IOException(Messages.get("AdfsCredentialsProviderFactory.failedLogin", new Object[] {content})); + } + + return matcher.group(1); + } catch (IOException e) { + LOGGER.severe(Messages.get("AdfsCredentialsProviderFactory.getSamlAssertionFailed", new Object[] {e})); + this.telemetryContext.setSuccess(false); + this.telemetryContext.setException(e); + throw new RuntimeException(e); + } finally { + this.telemetryContext.closeContext(); + } + } + + private List getInputTagsFromHTML(String body) { + Set distinctInputTags = new HashSet<>(); + List inputTags = new ArrayList(); + Pattern inputTagPattern = Pattern.compile("", Pattern.DOTALL); + Matcher inputTagMatcher = inputTagPattern.matcher(body); + while (inputTagMatcher.find()) { + String tag = inputTagMatcher.group(0); + String tagNameLower = getValueByKey(tag, "name").toLowerCase(); + if (!tagNameLower.isEmpty() && distinctInputTags.add(tagNameLower)) { + inputTags.add(tag); + } + } + return inputTags; + } + + private String getValueByKey(String input, String key) { + Pattern keyValuePattern = Pattern.compile("(" + Pattern.quote(key) + ")\\s*=\\s*\"(.*?)\""); + Matcher keyValueMatcher = keyValuePattern.matcher(input); + if (keyValueMatcher.find()) { + return escapeHtmlEntity(keyValueMatcher.group(2)); + } + return ""; + } + + private String escapeHtmlEntity(String html) { + StringBuilder sb = new StringBuilder(html.length()); + int i = 0; + int length = html.length(); + while (i < length) { + char c = html.charAt(i); + if (c != '&') { + sb.append(c); + i++; + continue; + } + + if (html.startsWith("&", i)) { + sb.append('&'); + i += 5; + } else if (html.startsWith("'", i)) { + sb.append('\''); + i += 6; + } else if (html.startsWith(""", i)) { + sb.append('"'); + i += 6; + } else if (html.startsWith("<", i)) { + sb.append('<'); + i += 4; + } else if (html.startsWith(">", i)) { + sb.append('>'); + i += 4; + } else { + sb.append(c); + ++i; + } + } + return sb.toString(); + } + + private List getParameters(String body, final @NonNull Properties props) { + List parameters = new ArrayList(); + for (String inputTag : getInputTagsFromHTML(body)) { + String name = getValueByKey(inputTag, "name"); + String value = getValueByKey(inputTag, "value"); + String nameLower = name.toLowerCase(); + + if (nameLower.contains("username")) { + parameters.add(new BasicNameValuePair(name, FederatedAuthPlugin.IDP_USERNAME.getString(props))); + } else if (nameLower.contains("authmethod")) { + if (!value.isEmpty()) { + parameters.add(new BasicNameValuePair(name, value)); + } + } else if (nameLower.contains("password")) { + parameters + .add(new BasicNameValuePair(name, FederatedAuthPlugin.IDP_PASSWORD.getString(props))); + } else if (!name.isEmpty()) { + parameters.add(new BasicNameValuePair(name, value)); + } + } + return parameters; + } + + private String getFormAction(String body) { + Pattern pattern = Pattern.compile(" tokenCache = new ConcurrentHashMap<>(); + private final CredentialsProviderFactory credentialsProviderFactory; + private static final int DEFAULT_TOKEN_EXPIRATION_SEC = 15 * 60 - 30; + private static final int DEFAULT_HTTP_TIMEOUT = 60000; + public static final AwsWrapperProperty IDP_ENDPOINT = new AwsWrapperProperty("idpEndpoint", null, + "The hosting URL of the Identity Provider"); + public static final AwsWrapperProperty IDP_PORT = + new AwsWrapperProperty("idpPort", "443", "The hosting port of Identity Provider"); + public static final AwsWrapperProperty RELAYING_PARTY_ID = + new AwsWrapperProperty("rpIdentifier", "urn:amazon:webservices", "The relaying party identifier"); + public static final AwsWrapperProperty IAM_ROLE_ARN = + new AwsWrapperProperty("iamRoleArn", null, "The ARN of the IAM Role that is to be assumed."); + public static final AwsWrapperProperty IAM_IDP_ARN = + new AwsWrapperProperty("iamIdpArn", null, "The ARN of the Identity Provider"); + public static final AwsWrapperProperty IAM_REGION = new AwsWrapperProperty("iamRegion", null, + "Overrides AWS region that is used to generate the IAM token"); + public static final AwsWrapperProperty IAM_TOKEN_EXPIRATION = new AwsWrapperProperty("iamTokenExpiration", + String.valueOf(DEFAULT_TOKEN_EXPIRATION_SEC), "IAM token cache expiration in seconds"); + public static final AwsWrapperProperty IDP_USERNAME = + new AwsWrapperProperty("idpUsername", null, "The federated user name"); + public static final AwsWrapperProperty IDP_PASSWORD = new AwsWrapperProperty("idpPassword", null, + "The federated user password"); + public static final AwsWrapperProperty IAM_HOST = new AwsWrapperProperty( + "iamHost", null, + "Overrides the host that is used to generate the IAM token"); + public static final AwsWrapperProperty IAM_DEFAULT_PORT = new AwsWrapperProperty("iamDefaultPort", null, + "Overrides default port that is used to generate the IAM token"); + public static final AwsWrapperProperty HTTP_CLIENT_SOCKET_TIMEOUT = new AwsWrapperProperty( + "httpClientSocketTimeout", String.valueOf(DEFAULT_HTTP_TIMEOUT), + "The socket timeout value for the HttpClient used by the FederatedAuthPlugin"); + public static final AwsWrapperProperty HTTP_CLIENT_CONNECT_TIMEOUT = new AwsWrapperProperty( + "httpClientConnectTimeout", String.valueOf(DEFAULT_HTTP_TIMEOUT), + "The connect timeout value for the HttpClient used by the FederatedAuthPlugin"); + public static final AwsWrapperProperty SSL_INSECURE = new AwsWrapperProperty("sslInsecure", "true", + "Whether or not a SSL "); + + public static AwsWrapperProperty + IDP_NAME = new AwsWrapperProperty("idpName", null, "The name of the Identity Provider implementation used "); + protected static final Pattern SAML_RESPONSE_PATTERN = Pattern.compile("SAMLResponse\\W+value=\"([^\"]+)\""); + protected static final Pattern HTTPS_URL_PATTERN = + Pattern.compile("^(https)://[-a-zA-Z0-9+&@#/%?=~_!:,.']*[-a-zA-Z0-9+&@#/%=~_']"); + + private static final String TELEMETRY_FETCH_TOKEN = "fetch IAM token"; + private static final Logger LOGGER = Logger.getLogger(FederatedAuthPlugin.class.getName()); + + protected final PluginService pluginService; + + protected final RdsUtils rdsUtils = new RdsUtils(); + + private static final Set subscribedMethods = + Collections.unmodifiableSet(new HashSet() { + { + add("connect"); + add("forceConnect"); + } + }); + + static { + PropertyDefinition.registerPluginProperties(FederatedAuthPlugin.class); + } + + private final TelemetryFactory telemetryFactory; + private final TelemetryGauge cacheSizeGauge; + private final TelemetryCounter fetchTokenCounter; + + @Override + public Set getSubscribedMethods() { + return subscribedMethods; + } + + public FederatedAuthPlugin(final PluginService pluginService, + final CredentialsProviderFactory credentialsProviderFactory) { + try { + Class.forName("software.amazon.awssdk.services.sts.model.AssumeRoleWithSamlRequest"); + } catch (final ClassNotFoundException e) { + throw new RuntimeException(Messages.get("FederatedAuthPlugin.javaStsSdkNotInClasspath")); + } + this.pluginService = pluginService; + this.credentialsProviderFactory = credentialsProviderFactory; + this.telemetryFactory = pluginService.getTelemetryFactory(); + this.cacheSizeGauge = telemetryFactory.createGauge("federatedAuth.tokenCache.size", () -> (long) tokenCache.size()); + this.fetchTokenCounter = telemetryFactory.createCounter("federatedAuth.fetchToken.count"); + } + + + @Override + public Connection connect( + final String driverProtocol, + final HostSpec hostSpec, + final Properties props, + final boolean isInitialConnection, + final JdbcCallable connectFunc) + throws SQLException { + return connectInternal(hostSpec, props, connectFunc); + } + + @Override + public Connection forceConnect( + final @NonNull String driverProtocol, + final @NonNull HostSpec hostSpec, + final @NonNull Properties props, + final boolean isInitialConnection, + final @NonNull JdbcCallable forceConnectFunc) + throws SQLException { + return connectInternal(hostSpec, props, forceConnectFunc); + } + + private Connection connectInternal(HostSpec hostSpec, Properties props, + JdbcCallable connectFunc) throws SQLException { + + String host = getHost(hostSpec, props); + + int port = getPort(hostSpec, props); + + final Region region = getRegion(host, props); + + final String cacheKey = getCacheKey( + PropertyDefinition.USER.getString(props), + host, + port, + region); + + final TokenInfo tokenInfo = tokenCache.get(cacheKey); + + final boolean isCachedToken = tokenInfo != null && !tokenInfo.isExpired(); + + if (isCachedToken) { + LOGGER.finest( + () -> Messages.get( + "FederatedAuthPlugin.useCachedIamToken", + new Object[] {tokenInfo.getToken()})); + PropertyDefinition.PASSWORD.set(props, tokenInfo.getToken()); + } else { + updateAuthenticationToken(hostSpec, props, region, cacheKey); + } + + try { + return connectFunc.call(); + } catch (final SQLException exception) { + updateAuthenticationToken(hostSpec, props, region, cacheKey); + return connectFunc.call(); + } catch (final Exception exception) { + LOGGER.warning( + () -> Messages.get( + "FederatedAuthPlugin.unhandledException", + new Object[] {exception})); + throw new SQLException(exception); + } + } + + private void updateAuthenticationToken(HostSpec hostSpec, Properties props, Region region, String cacheKey) + throws SQLException { + final int tokenExpirationSec = IAM_TOKEN_EXPIRATION.getInteger(props); + final Instant tokenExpiry = Instant.now().plus(tokenExpirationSec, ChronoUnit.SECONDS); + AwsCredentialsProvider credentialsProvider = + this.credentialsProviderFactory.getAwsCredentialsProvider(hostSpec.getHost(), region, props); + String token = generateAuthenticationToken( + props, + hostSpec.getHost(), + getPort(hostSpec, props), + region, + credentialsProvider); + LOGGER.finest( + () -> Messages.get( + "FederatedAuthPlugin.generatedNewIamToken", + new Object[] {token})); + PropertyDefinition.PASSWORD.set(props, token); + tokenCache.put( + cacheKey, + new TokenInfo(token, tokenExpiry)); + try { + credentialsProviderFactory.close(); + } catch (IOException e) { + // Ignore + } + } + + private Region getRegion(final String hostname, final Properties props) throws SQLException { + final String iamRegion = IAM_REGION.getString(props); + if (!StringUtils.isNullOrEmpty(iamRegion)) { + return Region.of(iamRegion); + } + + // Fallback to using host + // Get Region + final String rdsRegion = rdsUtils.getRdsRegion(hostname); + + if (StringUtils.isNullOrEmpty(rdsRegion)) { + // Does not match Amazon's Hostname, throw exception + final String exceptionMessage = Messages.get( + "FederatedAuthPlugin.unsupportedHostname", + new Object[] {hostname}); + + LOGGER.fine(exceptionMessage); + throw new SQLException(exceptionMessage); + } + + // Check Region + final Optional regionOptional = Region.regions().stream() + .filter(r -> r.id().equalsIgnoreCase(rdsRegion)) + .findFirst(); + + if (!regionOptional.isPresent()) { + final String exceptionMessage = Messages.get( + "AwsSdk.unsupportedRegion", + new Object[] {rdsRegion}); + + LOGGER.fine(exceptionMessage); + throw new SQLException(exceptionMessage); + } + + return regionOptional.get(); + } + + String generateAuthenticationToken(final Properties props, final String hostname, + final int port, final Region region, final AwsCredentialsProvider awsCredentialsProvider) { + TelemetryFactory telemetryFactory = this.pluginService.getTelemetryFactory(); + TelemetryContext telemetryContext = telemetryFactory.openTelemetryContext( + TELEMETRY_FETCH_TOKEN, TelemetryTraceLevel.NESTED); + this.fetchTokenCounter.inc(); + try { + final String user = PropertyDefinition.USER.getString(props); + final RdsUtilities utilities = + RdsUtilities.builder().credentialsProvider(awsCredentialsProvider).region(region).build(); + return utilities.generateAuthenticationToken((builder) -> builder.hostname(hostname).port(port).username(user)); + } catch (Exception e) { + telemetryContext.setSuccess(false); + telemetryContext.setException(e); + throw e; + } finally { + telemetryContext.closeContext(); + } + } + + private String getCacheKey( + final String user, + final String hostname, + final int port, + final Region region) { + + return String.format("%s:%s:%d:%s", region, hostname, port, user); + } + + public static void clearCache() { + tokenCache.clear(); + } + + private String getHost(HostSpec hostSpec, Properties props) { + String host = hostSpec.getHost(); + if (!StringUtils.isNullOrEmpty(IAM_HOST.getString(props))) { + host = IAM_HOST.getString(props); + } + return host; + } + + private int getPort(HostSpec hostSpec, Properties props) { + if (!StringUtils.isNullOrEmpty(IAM_DEFAULT_PORT.getString(props))) { + int defaultPort = IAM_DEFAULT_PORT.getInteger(props); + if (defaultPort > 0) { + return defaultPort; + } else { + LOGGER.finest(Messages.get("FederatedAuthPlugin.invalidPort", new Object[] {defaultPort})); + } + } + + if (hostSpec.isPortSpecified()) { + return hostSpec.getPort(); + } else { + return this.pluginService.getDialect().getDefaultPort(); + } + } + + static class TokenInfo { + + private final String token; + private final Instant expiration; + + public TokenInfo(final String token, final Instant expiration) { + this.token = token; + this.expiration = expiration; + } + + public String getToken() { + return this.token; + } + + public Instant getExpiration() { + return this.expiration; + } + + public boolean isExpired() { + return Instant.now().isAfter(this.expiration); + } + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPluginFactory.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPluginFactory.java new file mode 100644 index 000000000..99e0b21fe --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPluginFactory.java @@ -0,0 +1,52 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.federatedauth; + +import java.security.GeneralSecurityException; +import java.util.Properties; +import software.amazon.jdbc.ConnectionPlugin; +import software.amazon.jdbc.ConnectionPluginFactory; +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.util.Messages; + +public class FederatedAuthPluginFactory implements ConnectionPluginFactory { + + @Override + public ConnectionPlugin getInstance(final PluginService pluginService, final Properties props) { + return new FederatedAuthPlugin(pluginService, getCredentialsProviderFactory(pluginService, props)); + } + + private CredentialsProviderFactory getCredentialsProviderFactory(final PluginService pluginService, + final Properties props) { + final String idpName = FederatedAuthPlugin.IDP_NAME.getString(props); + if (AdfsCredentialsProviderFactory.IDP_NAME.equalsIgnoreCase(idpName)) { + try { + return new AdfsCredentialsProviderFactory( + pluginService, + new HttpClientFactory().getCloseableHttpClient( + FederatedAuthPlugin.HTTP_CLIENT_SOCKET_TIMEOUT.getInteger(props), + FederatedAuthPlugin.HTTP_CLIENT_CONNECT_TIMEOUT.getInteger(props), + FederatedAuthPlugin.SSL_INSECURE.getBoolean(props))); + } catch (GeneralSecurityException e) { + throw new RuntimeException( + Messages.get("FederatedAuthPluginFactory.failedToInitializeHttpClient"), e); + } + } + throw new IllegalArgumentException(Messages.get("FederatedAuthPluginFactory.unsupportedIdp", + new Object[] {idpName})); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/HttpClientFactory.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/HttpClientFactory.java new file mode 100644 index 000000000..e7597bbbb --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/HttpClientFactory.java @@ -0,0 +1,71 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.federatedauth; + +import java.security.GeneralSecurityException; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLSocketFactory; +import javax.net.ssl.TrustManager; +import org.apache.http.client.config.CookieSpecs; +import org.apache.http.client.config.RequestConfig; +import org.apache.http.conn.ssl.NoopHostnameVerifier; +import org.apache.http.conn.ssl.SSLConnectionSocketFactory; +import org.apache.http.impl.client.CloseableHttpClient; +import org.apache.http.impl.client.DefaultHttpRequestRetryHandler; +import org.apache.http.impl.client.HttpClientBuilder; +import org.apache.http.impl.client.HttpClients; +import org.apache.http.impl.client.LaxRedirectStrategy; + +/** + * Provides a HttpClient so that requests to HTTP API can be made. This is used by the + * {@link software.amazon.jdbc.plugin.federatedauth.AdfsCredentialsProviderFactory} to make HTTP calls to ADFS HTTP + * endpoints that are not available via SDK. + */ +public class HttpClientFactory { + private static int MAX_REQUEST_RETRIES = 3; + + public CloseableHttpClient getCloseableHttpClient(int socketTimeoutMs, int connectionTimeoutMs, + boolean keySslInsecure) throws GeneralSecurityException { + RequestConfig rc = RequestConfig.custom() + .setSocketTimeout(socketTimeoutMs) + .setConnectTimeout(connectionTimeoutMs) + .setExpectContinueEnabled(false) + .setCookieSpec(CookieSpecs.STANDARD) + .build(); + + HttpClientBuilder builder = HttpClients.custom() + .setDefaultRequestConfig(rc) + .setRedirectStrategy(new LaxRedirectStrategy()) + .setRetryHandler(new DefaultHttpRequestRetryHandler(MAX_REQUEST_RETRIES, true)) + .useSystemProperties(); // this is needed for proxy setting using system properties. + + if (keySslInsecure) { + SSLContext ctx = SSLContext.getInstance("TLSv1.2"); + TrustManager[] tma = new TrustManager[] {new NonValidatingSSLSocketFactory.NonValidatingTrustManager()}; + ctx.init(null, tma, null); + SSLSocketFactory factory = ctx.getSocketFactory(); + + SSLConnectionSocketFactory sf = new SSLConnectionSocketFactory( + factory, + new NoopHostnameVerifier()); + + builder.setSSLSocketFactory(sf); + } + + return builder.build(); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/NonValidatingSSLSocketFactory.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/NonValidatingSSLSocketFactory.java new file mode 100644 index 000000000..113235ac0 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/NonValidatingSSLSocketFactory.java @@ -0,0 +1,97 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.federatedauth; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.Socket; +import java.security.GeneralSecurityException; +import java.security.cert.X509Certificate; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLSocketFactory; +import javax.net.ssl.TrustManager; +import javax.net.ssl.X509TrustManager; + +/** + * Provide a SSLSocketFactory that allows SSL connections to be made without validating the server's + * certificate. This is more convenient for some applications, but is less secure as it allows "man + * in the middle" attacks. + */ +public class NonValidatingSSLSocketFactory extends SSLSocketFactory { + + /** + * We provide a constructor that takes an unused argument solely because the ssl calling code will + * look for this constructor first and then fall back to the no argument constructor, so we avoid + * an exception and additional reflection lookups. + * + * @param arg input argument + * @throws GeneralSecurityException if something goes wrong + */ + public NonValidatingSSLSocketFactory(String arg) throws GeneralSecurityException { + SSLContext ctx = SSLContext.getInstance("TLS"); // or "SSL" ? + + ctx.init(null, new TrustManager[]{new NonValidatingTrustManager()}, null); + + factory = ctx.getSocketFactory(); + } + + protected SSLSocketFactory factory; + + public Socket createSocket(InetAddress host, int port) throws IOException { + return factory.createSocket(host, port); + } + + public Socket createSocket(String host, int port) throws IOException { + return factory.createSocket(host, port); + } + + public Socket createSocket(String host, int port, InetAddress localHost, int localPort) + throws IOException { + return factory.createSocket(host, port, localHost, localPort); + } + + public Socket createSocket(InetAddress address, int port, InetAddress localAddress, int localPort) + throws IOException { + return factory.createSocket(address, port, localAddress, localPort); + } + + public Socket createSocket(Socket socket, String host, int port, boolean autoClose) + throws IOException { + return factory.createSocket(socket, host, port, autoClose); + } + + public String[] getDefaultCipherSuites() { + return factory.getDefaultCipherSuites(); + } + + public String[] getSupportedCipherSuites() { + return factory.getSupportedCipherSuites(); + } + + public static class NonValidatingTrustManager implements X509TrustManager { + + public X509Certificate[] getAcceptedIssuers() { + return new X509Certificate[0]; + } + + public void checkClientTrusted(X509Certificate[] certs, String authType) { + } + + public void checkServerTrusted(X509Certificate[] certs, String authType) { + } + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/SamlCredentialsProviderFactory.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/SamlCredentialsProviderFactory.java new file mode 100644 index 000000000..93adea8df --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/SamlCredentialsProviderFactory.java @@ -0,0 +1,60 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.federatedauth; + +import static software.amazon.jdbc.plugin.federatedauth.FederatedAuthPlugin.IAM_IDP_ARN; +import static software.amazon.jdbc.plugin.federatedauth.FederatedAuthPlugin.IAM_ROLE_ARN; + +import java.util.Properties; +import java.util.function.Supplier; +import org.checkerframework.checker.nullness.qual.NonNull; +import software.amazon.awssdk.auth.credentials.AnonymousCredentialsProvider; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.sts.StsClient; +import software.amazon.awssdk.services.sts.auth.StsAssumeRoleWithSamlCredentialsProvider; +import software.amazon.awssdk.services.sts.model.AssumeRoleWithSamlRequest; + +public abstract class SamlCredentialsProviderFactory implements CredentialsProviderFactory { + + @Override + public AwsCredentialsProvider getAwsCredentialsProvider(String host, Region region, final @NonNull Properties props) { + + Supplier assumeRoleWithSamlRequestSupplier = () -> { + String samlAssertion = getSamlAssertion(props); + + return AssumeRoleWithSamlRequest.builder() + .samlAssertion(samlAssertion) + .roleArn(IAM_ROLE_ARN.getString(props)) + .principalArn(IAM_IDP_ARN.getString(props)) + .build(); + }; + + StsClient stsClient = StsClient.builder() + .credentialsProvider(AnonymousCredentialsProvider.create()) + .region(region) + .build(); + + return StsAssumeRoleWithSamlCredentialsProvider.builder() + .refreshRequest(assumeRoleWithSamlRequestSupplier) + .asyncCredentialUpdateEnabled(true) + .stsClient(stsClient) + .build(); + } + + abstract String getSamlAssertion(final @NonNull Properties props); +} diff --git a/wrapper/src/main/resources/aws_advanced_jdbc_wrapper_messages.properties b/wrapper/src/main/resources/aws_advanced_jdbc_wrapper_messages.properties index 18d97ef5a..51e457ff5 100644 --- a/wrapper/src/main/resources/aws_advanced_jdbc_wrapper_messages.properties +++ b/wrapper/src/main/resources/aws_advanced_jdbc_wrapper_messages.properties @@ -14,6 +14,16 @@ # limitations under the License. # +# ADFS Credentials Provider Getter +AdfsCredentialsProviderFactory.failedLogin=Failed login. Could not obtain SAML Assertion from ADFS SignOn Page POST response: \n``{0}`` +AdfsCredentialsProviderFactory.getSamlAssertionFailed=Failed to get Saml Assertion due to exception: ``{0}`` +AdfsCredentialsProviderFactory.invalidHttpsUrl=Invalid HTTPS URL: ``{0}`` +AdfsCredentialsProviderFactory.signOnPageBody=ADFS SignOn Body: ``{0}`` +AdfsCredentialsProviderFactory.signOnPagePostActionUrl=ADFS SignOn Action URL: ``{0}`` +AdfsCredentialsProviderFactory.signOnPagePostActionRequestFailed=ADFS SignOn Page POST action failed with HTTP status ``{0}``, reason phrase ``{1}``, and response ``{2}`` +AdfsCredentialsProviderFactory.signOnPageRequestFailed=ADFS SignOn Page Request Failed with HTTP status ``{0}``, reason phrase ``{1}``, and response ``{2}`` +AdfsCredentialsProviderFactory.signOnPageUrl=ADFS SignOn URL: ``{0}`` + # Aurora Host List Connection Plugin AuroraHostListConnectionPlugin.providerAlreadySet=Another dynamic host list provider has already been set: {0}. @@ -147,6 +157,17 @@ Failover.failedToUpdateCurrentHostspecAvailability=Failed to update current host Failover.noOperationsAfterConnectionClosed=No operations allowed after connection closed. Failover.invalidHostListProvider=Incorrect type of host list provider found, please ensure the correct host list provider is specified. The host list provider in use is: ''{0}'', the plugin is expected a cluster-aware host list provider such as the AuroraHostListProvider. +# Federated Authentication Connection Plugin +FederatedAuthPlugin.generatedNewIamToken=Generated new IAM token = ''{0}'' +FederatedAuthPlugin.invalidPort=Port number: {0} is not valid. Port number should be greater than zero. Falling back to default port. +FederatedAuthPlugin.javaStsSdkNotInClasspath=Required dependency 'AWS Java SDK for AWS Secret Token Service' is not on the classpath. +FederatedAuthPlugin.unhandledException=Unhandled exception: ''{0}'' +FederatedAuthPlugin.unsupportedHostname=Unsupported AWS hostname {0}. Amazon domain name in format *.AWS-Region.rds.amazonaws.com or *.rds.AWS-Region.amazonaws.com.cn is expected. +FederatedAuthPlugin.useCachedIamToken=Use cached IAM token = ''{0}'' + +# Federated Authentication Connection Plugin Factory +FederatedAuthPluginFactory.failedToInitializeHttpClient=Failed to initialize HttpClient. +FederatedAuthPluginFactory.unsupportedIdp=Unsupported Identity Provider {0}. For a list of Identity Providers supported by the FederatedAuthenticationConnectionPlugin, please refer to the documentation. # HikariPooledConnectionProvider HikariPooledConnectionProvider.errorConnectingWithDataSource=Unable to connect to ''{0}'' using the Hikari data source. diff --git a/wrapper/src/test/build.gradle.kts b/wrapper/src/test/build.gradle.kts index b52eecc92..e0c1da927 100644 --- a/wrapper/src/test/build.gradle.kts +++ b/wrapper/src/test/build.gradle.kts @@ -41,8 +41,9 @@ dependencies { testImplementation("com.zaxxer:HikariCP:4.+") // version 4.+ is compatible with Java 8 testImplementation("org.springframework.boot:spring-boot-starter-jdbc:2.7.13") // 2.7.13 is the last version compatible with Java 8 testImplementation("org.mockito:mockito-inline:4.11.0") // 4.11.0 is the last version compatible with Java 8 - testImplementation("software.amazon.awssdk:rds:2.20.49") testImplementation("software.amazon.awssdk:ec2:2.20.49") + testImplementation("software.amazon.awssdk:rds:2.20.49") + testImplementation("software.amazon.awssdk:sts:2.20.49") testImplementation("org.testcontainers:testcontainers:1.17.+") testImplementation("org.testcontainers:mysql:1.17.+") testImplementation("org.testcontainers:postgresql:1.17.+") diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/AdfsCredentialsProviderFactoryTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/AdfsCredentialsProviderFactoryTest.java new file mode 100644 index 000000000..cb7a9835d --- /dev/null +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/AdfsCredentialsProviderFactoryTest.java @@ -0,0 +1,109 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.federatedauth; + +import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.Properties; +import org.apache.http.HttpEntity; +import org.apache.http.StatusLine; +import org.apache.http.client.methods.CloseableHttpResponse; +import org.apache.http.client.methods.HttpGet; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.impl.client.CloseableHttpClient; +import org.apache.http.util.EntityUtils; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.testcontainers.shaded.org.apache.commons.io.IOUtils; +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.util.telemetry.TelemetryContext; +import software.amazon.jdbc.util.telemetry.TelemetryFactory; + +class AdfsCredentialsProviderFactoryTest { + + private static final String USERNAME = "someFederatedUsername@teamatlas.example.com"; + private static final String PASSWORD = "somePassword"; + @Mock private PluginService mockPluginService; + @Mock private TelemetryFactory mockTelemetryFactory; + @Mock private TelemetryContext mockTelemetryContext; + @Mock private CloseableHttpClient mockHttpClient; + @Mock private CloseableHttpResponse mockHttpGetSignInPageResponse; + @Mock private CloseableHttpResponse mockHttpPostSignInResponse; + @Mock private StatusLine mockStatusLine; + @Mock private HttpEntity mockSignInPageHttpEntity; + @Mock private HttpEntity mockSamlHttpEntity; + private AdfsCredentialsProviderFactory adfsCredentialsProviderFactory; + private Properties props; + + @BeforeEach + public void init() throws IOException { + MockitoAnnotations.openMocks(this); + + this.props = new Properties(); + this.props.setProperty(FederatedAuthPlugin.IDP_ENDPOINT.name, "ec2amaz-ab3cdef.example.com"); + this.props.setProperty(FederatedAuthPlugin.IDP_USERNAME.name, USERNAME); + this.props.setProperty(FederatedAuthPlugin.IDP_PASSWORD.name, PASSWORD); + + when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); + when(mockTelemetryFactory.openTelemetryContext(any(), any())).thenReturn(mockTelemetryContext); + when(mockHttpClient.execute(any(HttpGet.class))).thenReturn(mockHttpGetSignInPageResponse); + when(mockHttpGetSignInPageResponse.getStatusLine()).thenReturn(mockStatusLine); + when(mockStatusLine.getStatusCode()).thenReturn(200); + when(mockHttpGetSignInPageResponse.getEntity()).thenReturn(mockSignInPageHttpEntity); + + String signinPageHtml = IOUtils.toString( + this.getClass().getClassLoader().getResourceAsStream("federated_auth/adfs-sign-in-page.html"), "UTF-8"); + InputStream signInPageHtmlInputStream = new ByteArrayInputStream(signinPageHtml.getBytes()); + when(mockSignInPageHttpEntity.getContent()).thenReturn(signInPageHtmlInputStream); + + when(mockHttpClient.execute(any(HttpPost.class))).thenReturn(mockHttpPostSignInResponse); + when(mockHttpPostSignInResponse.getStatusLine()).thenReturn(mockStatusLine); + when(mockHttpPostSignInResponse.getEntity()).thenReturn(mockSamlHttpEntity); + + String adfsSamlHtml = IOUtils.toString( + this.getClass().getClassLoader().getResourceAsStream("federated_auth/adfs-saml.html"), "UTF-8"); + InputStream samlHtmlInputStream = new ByteArrayInputStream(adfsSamlHtml.getBytes()); + when(mockSamlHttpEntity.getContent()).thenReturn(samlHtmlInputStream); + + this.adfsCredentialsProviderFactory = new AdfsCredentialsProviderFactory(mockPluginService, mockHttpClient); + } + + @Test + void test() throws IOException { + this.adfsCredentialsProviderFactory.getSamlAssertion(props); + + ArgumentCaptor httpPostArgumentCaptor = ArgumentCaptor.forClass(HttpPost.class); + verify(mockHttpClient, times(2)).execute(httpPostArgumentCaptor.capture()); + HttpPost actualHttpPost = httpPostArgumentCaptor.getValue(); + String content = EntityUtils.toString(actualHttpPost.getEntity()); + String[] params = content.split("&"); + assertEquals("UserName=" + USERNAME.replace("@", "%40"), params[0]); + assertEquals("Password=" + PASSWORD, params[1]); + assertEquals("Kmsi=true", params[2]); + assertEquals("AuthMethod=FormsAuthentication", params[3]); + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPluginTest.java new file mode 100644 index 000000000..aeba89aeb --- /dev/null +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPluginTest.java @@ -0,0 +1,167 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.federatedauth; + +import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; + +import java.sql.Connection; +import java.sql.SQLException; +import java.time.Instant; +import java.util.Properties; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity; +import software.amazon.awssdk.regions.Region; +import software.amazon.jdbc.HostSpec; +import software.amazon.jdbc.HostSpecBuilder; +import software.amazon.jdbc.JdbcCallable; +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.PropertyDefinition; +import software.amazon.jdbc.dialect.Dialect; +import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +import software.amazon.jdbc.util.telemetry.TelemetryContext; +import software.amazon.jdbc.util.telemetry.TelemetryCounter; +import software.amazon.jdbc.util.telemetry.TelemetryFactory; + +class FederatedAuthPluginTest { + + private static final int DEFAULT_PORT = 1234; + private static final String DRIVER_PROTOCOL = "jdbc:postgresql:"; + + private static final HostSpec HOST_SPEC = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("pg.testdb.us-east-2.rds.amazonaws.com").build(); + private static final String TEST_TOKEN = "someTestToken"; + @Mock private PluginService mockPluginService; + @Mock private Dialect mockDialect; + @Mock JdbcCallable mockLambda; + @Mock private TelemetryFactory mockTelemetryFactory; + @Mock private TelemetryContext mockTelemetryContext; + @Mock private TelemetryCounter mockTelemetryCounter; + @Mock private CredentialsProviderFactory mockCredentialsProviderFactory; + @Mock private AwsCredentialsProvider mockAwsCredentialsProvider; + @Mock private CompletableFuture completableFuture; + @Mock private AwsCredentialsIdentity mockAwsCredentialsIdentity; + private Properties props; + + @BeforeEach + public void init() throws ExecutionException, InterruptedException, SQLException { + MockitoAnnotations.openMocks(this); + props = new Properties(); + props.setProperty(PropertyDefinition.PLUGINS.name, "federatedAuth"); + props.setProperty(PropertyDefinition.USER.name, "iamUser"); + FederatedAuthPlugin.clearCache(); + + when(mockPluginService.getDialect()).thenReturn(mockDialect); + when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_PORT); + when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); + when(mockTelemetryFactory.createCounter(any())).thenReturn(mockTelemetryCounter); + when(mockTelemetryFactory.openTelemetryContext(any(), any())).thenReturn(mockTelemetryContext); + when(mockCredentialsProviderFactory.getAwsCredentialsProvider(any(), any(), any())) + .thenReturn(mockAwsCredentialsProvider); + when(mockAwsCredentialsProvider.resolveIdentity()).thenReturn(completableFuture); + when(completableFuture.get()).thenReturn(mockAwsCredentialsIdentity); + } + + @Test + void testCachedToken() throws SQLException { + FederatedAuthPlugin plugin = + new FederatedAuthPlugin(mockPluginService, mockCredentialsProviderFactory); + + String key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + DEFAULT_PORT + ":iamUser"; + FederatedAuthPlugin.TokenInfo tokenInfo = new FederatedAuthPlugin.TokenInfo( + TEST_TOKEN, Instant.now().plusMillis(300000)); + FederatedAuthPlugin.tokenCache.put(key, tokenInfo); + + plugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); + + assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); + } + + @Test + void testExpiredCachedToken() throws SQLException { + FederatedAuthPlugin plugin = + new FederatedAuthPlugin(mockPluginService, mockCredentialsProviderFactory); + FederatedAuthPlugin spyPlugin = Mockito.spy(plugin); + + String key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + DEFAULT_PORT + ":iamUser"; + String someExpiredToken = "someExpiredToken"; + FederatedAuthPlugin.TokenInfo expiredTokenInfo = new FederatedAuthPlugin.TokenInfo( + someExpiredToken, Instant.now().minusMillis(300000)); + FederatedAuthPlugin.tokenCache.put(key, expiredTokenInfo); + + when( + spyPlugin.generateAuthenticationToken( + props, + HOST_SPEC.getHost(), + DEFAULT_PORT, + Region.US_EAST_2, mockAwsCredentialsProvider)) + .thenReturn(TEST_TOKEN); + + spyPlugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); + assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); + } + + @Test + void testNoCachedToken() throws SQLException { + FederatedAuthPlugin plugin = + new FederatedAuthPlugin(mockPluginService, mockCredentialsProviderFactory); + FederatedAuthPlugin spyPlugin = Mockito.spy(plugin); + + when( + spyPlugin.generateAuthenticationToken( + props, + HOST_SPEC.getHost(), + DEFAULT_PORT, + Region.US_EAST_2, mockAwsCredentialsProvider)) + .thenReturn(TEST_TOKEN); + + spyPlugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); + assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); + } + + @Test + void testSpecifiedIamHostPortRegion() throws SQLException { + final String expectedHost = "pg.testdb.us-west-2.rds.amazonaws.com"; + final int expectedPort = 9876; + final Region expectedRegion = Region.US_WEST_2; + + props.setProperty(FederatedAuthPlugin.IAM_HOST.name, expectedHost); + props.setProperty(FederatedAuthPlugin.IAM_DEFAULT_PORT.name, String.valueOf(expectedPort)); + props.setProperty(FederatedAuthPlugin.IAM_REGION.name, expectedRegion.toString()); + + final String key = "us-west-2:pg.testdb.us-west-2.rds.amazonaws.com:" + String.valueOf(expectedPort) + ":iamUser"; + FederatedAuthPlugin.TokenInfo tokenInfo = new FederatedAuthPlugin.TokenInfo( + TEST_TOKEN, Instant.now().plusMillis(300000)); + FederatedAuthPlugin.tokenCache.put(key, tokenInfo); + + FederatedAuthPlugin plugin = + new FederatedAuthPlugin(mockPluginService, mockCredentialsProviderFactory); + + plugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); + + assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); + } + // test throw if AWS Java SDK for AWS STS not available +} diff --git a/wrapper/src/test/resources/federated_auth/adfs-saml.html b/wrapper/src/test/resources/federated_auth/adfs-saml.html new file mode 100644 index 000000000..6686e3db8 --- /dev/null +++ b/wrapper/src/test/resources/federated_auth/adfs-saml.html @@ -0,0 +1 @@ +Working...
diff --git a/wrapper/src/test/resources/federated_auth/adfs-sign-in-page.html b/wrapper/src/test/resources/federated_auth/adfs-sign-in-page.html new file mode 100644 index 000000000..0b06f6dda --- /dev/null +++ b/wrapper/src/test/resources/federated_auth/adfs-sign-in-page.html @@ -0,0 +1,613 @@ + + + + + + + + + + + + Sign In + + + + + + + + + + +
+

JavaScript required

+

JavaScript is required. This web browser does not support JavaScript or JavaScript in this web browser is not enabled.

+

To find out if your web browser supports JavaScript or to enable JavaScript, see web browser help.

+
+ +
+
+
+
+
+
+ +
+
+ +
+ + + +
+
Sign in
+ +
+
+ +
+ +
+
+ + +
+ +
+ + +
+ +
+ Sign in +
+
+ +
+ +
+
+ + + + +
+
+ +
+ +
+ + +
+ +
+ +
+
+
+
+
+ +
+
+
+ + + + + +