diff --git a/examples/AWSDriverExample/build.gradle.kts b/examples/AWSDriverExample/build.gradle.kts index 40c4655aa..864ea90d0 100644 --- a/examples/AWSDriverExample/build.gradle.kts +++ b/examples/AWSDriverExample/build.gradle.kts @@ -20,6 +20,7 @@ dependencies { implementation("mysql:mysql-connector-java:8.0.33") implementation("software.amazon.awssdk:rds:2.21.42") implementation("software.amazon.awssdk:secretsmanager:2.21.43") + implementation("software.amazon.awssdk:sts:2.21.42") implementation("com.fasterxml.jackson.core:jackson-databind:2.16.0") implementation(project(":aws-advanced-jdbc-wrapper")) implementation("io.opentelemetry:opentelemetry-api:1.32.0") diff --git a/examples/AWSDriverExample/src/main/java/software/amazon/FederatedAuthPluginExample.java b/examples/AWSDriverExample/src/main/java/software/amazon/FederatedAuthPluginExample.java new file mode 100644 index 000000000..8953d986d --- /dev/null +++ b/examples/AWSDriverExample/src/main/java/software/amazon/FederatedAuthPluginExample.java @@ -0,0 +1,54 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon; + +import software.amazon.jdbc.PropertyDefinition; +import software.amazon.jdbc.plugin.federatedauth.FederatedAuthPlugin; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.Properties; + +public class FederatedAuthPluginExample { + + private static final String CONNECTION_STRING = "jdbc:aws-wrapper:postgresql://db-identifier.XYZ.us-east-2.rds.amazonaws.com:5432/employees"; + + public static void main(String[] args) throws SQLException { + // Set the AWS Federated Authentication Connection Plugin parameters and the JDBC Wrapper parameters. + final Properties properties = new Properties(); + + // Enable the AWS Federated Authentication Connection Plugin. + properties.setProperty(PropertyDefinition.PLUGINS.name, "federatedAuth"); + properties.setProperty(FederatedAuthPlugin.IDP_NAME.name, "adfs"); + properties.setProperty(FederatedAuthPlugin.IDP_ENDPOINT.name, "ec2amaz-ab3cdef.example.com"); + properties.setProperty(FederatedAuthPlugin.IAM_ROLE_ARN.name, "arn:aws:iam::123456789012:role/adfs_example_iam_role"); + properties.setProperty(FederatedAuthPlugin.IAM_IDP_ARN.name, "arn:aws:iam::123456789012:saml-provider/adfs_example"); + properties.setProperty(FederatedAuthPlugin.IAM_REGION.name, "us-east-2"); + properties.setProperty(FederatedAuthPlugin.IDP_USERNAME.name, "someFederatedUsername@example.com"); + properties.setProperty(FederatedAuthPlugin.IDP_PASSWORD.name, "somePassword"); + properties.setProperty(PropertyDefinition.USER.name, "someIamUser"); + + // Try and make a connection: + try (final Connection conn = DriverManager.getConnection(CONNECTION_STRING, properties); + final Statement statement = conn.createStatement(); + final ResultSet rs = statement.executeQuery("SELECT 1")) { + System.out.println(Util.getResult(rs)); + } + } +} diff --git a/wrapper/build.gradle.kts b/wrapper/build.gradle.kts index 87ed00590..0155dabf3 100644 --- a/wrapper/build.gradle.kts +++ b/wrapper/build.gradle.kts @@ -28,7 +28,9 @@ plugins { dependencies { implementation("org.checkerframework:checker-qual:3.40.0") + compileOnly("org.apache.httpcomponents:httpclient:4.5.14") compileOnly("software.amazon.awssdk:rds:2.21.42") + compileOnly("software.amazon.awssdk:sts:2.21.42") compileOnly("com.zaxxer:HikariCP:4.0.3") // Version 4.+ is compatible with Java 8 compileOnly("software.amazon.awssdk:secretsmanager:2.21.43") compileOnly("com.fasterxml.jackson.core:jackson-databind:2.16.0") @@ -61,6 +63,7 @@ dependencies { testImplementation("software.amazon.awssdk:rds:2.21.42") testImplementation("software.amazon.awssdk:ec2:2.21.12") testImplementation("software.amazon.awssdk:secretsmanager:2.21.43") + testImplementation("software.amazon.awssdk:sts:2.21.42") testImplementation("org.testcontainers:testcontainers:1.19.3") testImplementation("org.testcontainers:mysql:1.19.3") testImplementation("org.testcontainers:postgresql:1.19.3") diff --git a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java index 353e3eeb6..2093f4999 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java +++ b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java @@ -39,6 +39,7 @@ import software.amazon.jdbc.plugin.dev.DeveloperConnectionPluginFactory; import software.amazon.jdbc.plugin.efm.HostMonitoringConnectionPluginFactory; import software.amazon.jdbc.plugin.failover.FailoverConnectionPluginFactory; +import software.amazon.jdbc.plugin.federatedauth.FederatedAuthPluginFactory; import software.amazon.jdbc.plugin.readwritesplitting.ReadWriteSplittingPluginFactory; import software.amazon.jdbc.plugin.staledns.AuroraStaleDnsPluginFactory; import software.amazon.jdbc.plugin.strategy.fastestresponse.FastestResponseStrategyPluginFactory; @@ -66,6 +67,7 @@ public class ConnectionPluginChainBuilder { put("failover", FailoverConnectionPluginFactory.class); put("iam", IamAuthConnectionPluginFactory.class); put("awsSecretsManager", AwsSecretsManagerConnectionPluginFactory.class); + put("federatedAuth", FederatedAuthPluginFactory.class); put("auroraStaleDns", AuroraStaleDnsPluginFactory.class); put("readWriteSplitting", ReadWriteSplittingPluginFactory.class); put("auroraConnectionTracker", AuroraConnectionTrackerPluginFactory.class); @@ -96,7 +98,8 @@ public class ConnectionPluginChainBuilder { put(FastestResponseStrategyPluginFactory.class, 900); put(IamAuthConnectionPluginFactory.class, 1000); put(AwsSecretsManagerConnectionPluginFactory.class, 1100); - put(LogQueryConnectionPluginFactory.class, 1200); + put(FederatedAuthPluginFactory.class, 1200); + put(LogQueryConnectionPluginFactory.class, 1300); put(ConnectTimeConnectionPluginFactory.class, WEIGHT_RELATIVE_TO_PRIOR_PLUGIN); put(ExecutionTimeConnectionPluginFactory.class, WEIGHT_RELATIVE_TO_PRIOR_PLUGIN); put(DeveloperConnectionPluginFactory.class, WEIGHT_RELATIVE_TO_PRIOR_PLUGIN); diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/IamAuthConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/IamAuthConnectionPlugin.java index 69fa0c813..d56aafaa2 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/IamAuthConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/IamAuthConnectionPlugin.java @@ -36,6 +36,7 @@ import software.amazon.jdbc.PluginService; import software.amazon.jdbc.PropertyDefinition; import software.amazon.jdbc.authentication.AwsCredentialsManager; +import software.amazon.jdbc.util.IamAuthUtils; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.RdsUtils; import software.amazon.jdbc.util.StringUtils; @@ -64,7 +65,7 @@ public class IamAuthConnectionPlugin extends AbstractConnectionPlugin { "Overrides the host that is used to generate the IAM token"); public static final AwsWrapperProperty IAM_DEFAULT_PORT = new AwsWrapperProperty( - "iamDefaultPort", null, + "iamDefaultPort", "-1", "Overrides default port that is used to generate the IAM token"); public static final AwsWrapperProperty IAM_REGION = new AwsWrapperProperty( @@ -115,12 +116,12 @@ private Connection connectInternal(String driverProtocol, HostSpec hostSpec, Pro throw new SQLException(PropertyDefinition.USER.name + " is null or empty."); } - String host = hostSpec.getHost(); - if (!StringUtils.isNullOrEmpty(IAM_HOST.getString(props))) { - host = IAM_HOST.getString(props); - } + String host = IamAuthUtils.getIamHost(IAM_HOST.getString(props), hostSpec); - int port = getPort(props, hostSpec); + int port = IamAuthUtils.getIamPort( + IAM_DEFAULT_PORT.getInteger(props), + hostSpec, + this.pluginService.getDialect().getDefaultPort()); final String iamRegion = IAM_REGION.getString(props); final Region region = StringUtils.isNullOrEmpty(iamRegion) @@ -261,26 +262,6 @@ public static void clearCache() { tokenCache.clear(); } - private int getPort(Properties props, HostSpec hostSpec) { - if (!StringUtils.isNullOrEmpty(IAM_DEFAULT_PORT.getString(props))) { - int defaultPort = IAM_DEFAULT_PORT.getInteger(props); - if (defaultPort > 0) { - return defaultPort; - } else { - LOGGER.finest( - () -> Messages.get( - "IamAuthConnectionPlugin.invalidPort", - new Object[] {defaultPort})); - } - } - - if (hostSpec.isPortSpecified()) { - return hostSpec.getPort(); - } else { - return this.pluginService.getDialect().getDefaultPort(); - } - } - private Region getRdsRegion(final String hostname) throws SQLException { // Get Region @@ -312,27 +293,4 @@ private Region getRdsRegion(final String hostname) throws SQLException { return regionOptional.get(); } - - static class TokenInfo { - - private final String token; - private final Instant expiration; - - public TokenInfo(final String token, final Instant expiration) { - this.token = token; - this.expiration = expiration; - } - - public String getToken() { - return this.token; - } - - public Instant getExpiration() { - return this.expiration; - } - - public boolean isExpired() { - return Instant.now().isAfter(this.expiration); - } - } } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/TokenInfo.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/TokenInfo.java new file mode 100644 index 000000000..3fad01093 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/TokenInfo.java @@ -0,0 +1,41 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin; + +import java.time.Instant; + +public class TokenInfo { + private final String token; + private final Instant expiration; + + public TokenInfo(final String token, final Instant expiration) { + this.token = token; + this.expiration = expiration; + } + + public String getToken() { + return this.token; + } + + public Instant getExpiration() { + return this.expiration; + } + + public boolean isExpired() { + return Instant.now().isAfter(this.expiration); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/AdfsCredentialsProviderFactory.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/AdfsCredentialsProviderFactory.java new file mode 100644 index 000000000..aa23a8f45 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/AdfsCredentialsProviderFactory.java @@ -0,0 +1,252 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.federatedauth; + +import java.io.IOException; +import java.net.URI; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Properties; +import java.util.Set; +import java.util.function.Supplier; +import java.util.logging.Logger; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import org.apache.http.NameValuePair; +import org.apache.http.StatusLine; +import org.apache.http.client.entity.UrlEncodedFormEntity; +import org.apache.http.client.methods.CloseableHttpResponse; +import org.apache.http.client.methods.HttpGet; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.impl.client.CloseableHttpClient; +import org.apache.http.message.BasicNameValuePair; +import org.apache.http.util.EntityUtils; +import org.checkerframework.checker.nullness.qual.NonNull; +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.util.Messages; +import software.amazon.jdbc.util.StringUtils; +import software.amazon.jdbc.util.telemetry.TelemetryContext; +import software.amazon.jdbc.util.telemetry.TelemetryFactory; +import software.amazon.jdbc.util.telemetry.TelemetryTraceLevel; + +public class AdfsCredentialsProviderFactory extends SamlCredentialsProviderFactory { + + public static final String IDP_NAME = "adfs"; + private static final String TELEMETRY_FETCH_SAML = "Fetch ADFS SAML Assertion"; + private static final Pattern INPUT_TAG_PATTERN = Pattern.compile("", Pattern.DOTALL); + private static final Pattern FORM_ACTION_PATTERN = Pattern.compile(" httpClientSupplier; + private TelemetryContext telemetryContext; + + public AdfsCredentialsProviderFactory(final PluginService pluginService, + final Supplier httpClientSupplier) { + this.pluginService = pluginService; + this.telemetryFactory = this.pluginService.getTelemetryFactory(); + this.httpClientSupplier = httpClientSupplier; + } + + @Override + String getSamlAssertion(final @NonNull Properties props) throws SQLException { + this.telemetryContext = telemetryFactory.openTelemetryContext(TELEMETRY_FETCH_SAML, TelemetryTraceLevel.NESTED); + try (final CloseableHttpClient httpClient = httpClientSupplier.get()) { + String uri = getSignInPageUrl(props); + final String signInPageBody = getSignInPageBody(httpClient, uri); + final String action = getFormActionFromHtmlBody(signInPageBody); + + if (!StringUtils.isNullOrEmpty(action) && action.startsWith("/")) { + uri = getFormActionUrl(props, action); + } + + final List params = getParametersFromHtmlBody(signInPageBody, props); + final String content = getFormActionBody(httpClient, uri, params); + + final Matcher matcher = FederatedAuthPlugin.SAML_RESPONSE_PATTERN.matcher(content); + if (!matcher.find()) { + throw new IOException(Messages.get("AdfsCredentialsProviderFactory.failedLogin", new Object[] {content})); + } + + // return SAML Response value + return matcher.group(FederatedAuthPlugin.SAML_RESPONSE_PATTERN_GROUP); + } catch (final IOException e) { + LOGGER.severe(Messages.get("AdfsCredentialsProviderFactory.getSamlAssertionFailed", new Object[] {e})); + this.telemetryContext.setSuccess(false); + this.telemetryContext.setException(e); + throw new SQLException(e); + } finally { + this.telemetryContext.closeContext(); + } + } + + private String getSignInPageBody(final CloseableHttpClient httpClient, final String uri) throws IOException { + LOGGER.finest(Messages.get("AdfsCredentialsProviderFactory.signOnPageUrl", new Object[] {uri})); + validateUrl(uri); + final HttpGet get = new HttpGet(uri); + try (final CloseableHttpResponse resp = httpClient.execute(get)) { + final StatusLine statusLine = resp.getStatusLine(); + // Check HTTP Status Code is 2xx Success + if (statusLine.getStatusCode() / 100 != 2) { + throw new IOException(Messages.get("AdfsCredentialsProviderFactory.signOnPageRequestFailed", + new Object[] { + statusLine.getStatusCode(), + statusLine.getReasonPhrase(), + EntityUtils.toString(resp.getEntity())})); + } + return EntityUtils.toString(resp.getEntity()); + } + } + + private String getFormActionBody(final CloseableHttpClient httpClient, final String uri, + final List params) throws IOException { + LOGGER.finest(Messages.get("AdfsCredentialsProviderFactory.signOnPagePostActionUrl", new Object[] {uri})); + validateUrl(uri); + final HttpPost post = new HttpPost(uri); + post.setEntity(new UrlEncodedFormEntity(params)); + try (final CloseableHttpResponse resp = httpClient.execute(post)) { + final StatusLine statusLine = resp.getStatusLine(); + // Check HTTP Status Code is 2xx Success + if (statusLine.getStatusCode() / 100 != 2) { + throw new IOException(Messages.get("AdfsCredentialsProviderFactory.signOnPagePostActionRequestFailed", + new Object[] { + statusLine.getStatusCode(), + statusLine.getReasonPhrase(), + EntityUtils.toString(resp.getEntity())})); + } + return EntityUtils.toString(resp.getEntity()); + } + } + + private String getSignInPageUrl(final Properties props) { + return "https://" + FederatedAuthPlugin.IDP_ENDPOINT.getString(props) + ':' + + FederatedAuthPlugin.IDP_PORT.getString(props) + "/adfs/ls/IdpInitiatedSignOn.aspx?loginToRp=" + + FederatedAuthPlugin.RELAYING_PARTY_ID.getString(props); + } + + private String getFormActionUrl(final Properties props, final String action) { + return "https://" + FederatedAuthPlugin.IDP_ENDPOINT.getString(props) + ':' + + FederatedAuthPlugin.IDP_PORT.getString(props) + action; + } + + private List getInputTagsFromHTML(final String body) { + final Set distinctInputTags = new HashSet<>(); + final List inputTags = new ArrayList<>(); + final Matcher inputTagMatcher = INPUT_TAG_PATTERN.matcher(body); + while (inputTagMatcher.find()) { + final String tag = inputTagMatcher.group(0); + final String tagNameLower = getValueByKey(tag, "name").toLowerCase(); + if (!tagNameLower.isEmpty() && distinctInputTags.add(tagNameLower)) { + inputTags.add(tag); + } + } + return inputTags; + } + + private String getValueByKey(final String input, final String key) { + final Pattern keyValuePattern = Pattern.compile("(" + Pattern.quote(key) + ")\\s*=\\s*\"(.*?)\""); + final Matcher keyValueMatcher = keyValuePattern.matcher(input); + if (keyValueMatcher.find()) { + return escapeHtmlEntity(keyValueMatcher.group(2)); + } + return ""; + } + + private String escapeHtmlEntity(final String html) { + final StringBuilder sb = new StringBuilder(html.length()); + int i = 0; + final int length = html.length(); + while (i < length) { + final char c = html.charAt(i); + if (c != '&') { + sb.append(c); + i++; + continue; + } + + if (html.startsWith("&", i)) { + sb.append('&'); + i += 5; + } else if (html.startsWith("'", i)) { + sb.append('\''); + i += 6; + } else if (html.startsWith(""", i)) { + sb.append('"'); + i += 6; + } else if (html.startsWith("<", i)) { + sb.append('<'); + i += 4; + } else if (html.startsWith(">", i)) { + sb.append('>'); + i += 4; + } else { + sb.append(c); + ++i; + } + } + return sb.toString(); + } + + private List getParametersFromHtmlBody(final String body, final @NonNull Properties props) { + final List parameters = new ArrayList<>(); + for (final String inputTag : getInputTagsFromHTML(body)) { + final String name = getValueByKey(inputTag, "name"); + final String value = getValueByKey(inputTag, "value"); + final String nameLower = name.toLowerCase(); + + if (nameLower.contains("username")) { + parameters.add(new BasicNameValuePair(name, FederatedAuthPlugin.IDP_USERNAME.getString(props))); + } else if (nameLower.contains("authmethod")) { + if (!value.isEmpty()) { + parameters.add(new BasicNameValuePair(name, value)); + } + } else if (nameLower.contains("password")) { + parameters + .add(new BasicNameValuePair(name, FederatedAuthPlugin.IDP_PASSWORD.getString(props))); + } else if (!name.isEmpty()) { + parameters.add(new BasicNameValuePair(name, value)); + } + } + return parameters; + } + + private String getFormActionFromHtmlBody(final String body) { + final Matcher m = FORM_ACTION_PATTERN.matcher(body); + if (m.find()) { + return escapeHtmlEntity(m.group(1)); + } + return null; + } + + private void validateUrl(final String paramString) throws IOException { + + final URI authorizeRequestUrl = URI.create(paramString); + final String errorMessage = Messages.get("AdfsCredentialsProviderFactory.invalidHttpsUrl", + new Object[] {paramString}); + + if (!authorizeRequestUrl.toURL().getProtocol().equalsIgnoreCase("https")) { + throw new IOException(errorMessage); + } + + final Matcher matcher = FederatedAuthPlugin.HTTPS_URL_PATTERN.matcher(paramString); + if (!matcher.find()) { + throw new IOException(errorMessage); + } + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/CredentialsProviderFactory.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/CredentialsProviderFactory.java new file mode 100644 index 000000000..a43396bf9 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/CredentialsProviderFactory.java @@ -0,0 +1,29 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.federatedauth; + +import java.io.Closeable; +import java.sql.SQLException; +import java.util.Properties; +import org.checkerframework.checker.nullness.qual.NonNull; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.Region; + +public interface CredentialsProviderFactory { + AwsCredentialsProvider getAwsCredentialsProvider(String host, Region region, final @NonNull Properties props) throws + SQLException; +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPlugin.java new file mode 100644 index 000000000..7def73fec --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPlugin.java @@ -0,0 +1,319 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.federatedauth; + +import java.sql.Connection; +import java.sql.SQLException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Collections; +import java.util.HashSet; +import java.util.Optional; +import java.util.Properties; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.logging.Logger; +import java.util.regex.Pattern; +import org.checkerframework.checker.nullness.qual.NonNull; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.rds.RdsUtilities; +import software.amazon.jdbc.AwsWrapperProperty; +import software.amazon.jdbc.HostSpec; +import software.amazon.jdbc.JdbcCallable; +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.PropertyDefinition; +import software.amazon.jdbc.plugin.AbstractConnectionPlugin; +import software.amazon.jdbc.plugin.TokenInfo; +import software.amazon.jdbc.util.IamAuthUtils; +import software.amazon.jdbc.util.Messages; +import software.amazon.jdbc.util.RdsUtils; +import software.amazon.jdbc.util.StringUtils; +import software.amazon.jdbc.util.telemetry.TelemetryContext; +import software.amazon.jdbc.util.telemetry.TelemetryCounter; +import software.amazon.jdbc.util.telemetry.TelemetryFactory; +import software.amazon.jdbc.util.telemetry.TelemetryGauge; +import software.amazon.jdbc.util.telemetry.TelemetryTraceLevel; + +public class FederatedAuthPlugin extends AbstractConnectionPlugin { + + static final ConcurrentHashMap tokenCache = new ConcurrentHashMap<>(); + private final CredentialsProviderFactory credentialsProviderFactory; + private static final int DEFAULT_TOKEN_EXPIRATION_SEC = 15 * 60 - 30; + private static final int DEFAULT_HTTP_TIMEOUT_MILLIS = 60000; + public static final AwsWrapperProperty IDP_ENDPOINT = new AwsWrapperProperty("idpEndpoint", null, + "The hosting URL of the Identity Provider"); + public static final AwsWrapperProperty IDP_PORT = + new AwsWrapperProperty("idpPort", "443", "The hosting port of Identity Provider"); + public static final AwsWrapperProperty RELAYING_PARTY_ID = + new AwsWrapperProperty("rpIdentifier", "urn:amazon:webservices", "The relaying party identifier"); + public static final AwsWrapperProperty IAM_ROLE_ARN = + new AwsWrapperProperty("iamRoleArn", null, "The ARN of the IAM Role that is to be assumed."); + public static final AwsWrapperProperty IAM_IDP_ARN = + new AwsWrapperProperty("iamIdpArn", null, "The ARN of the Identity Provider"); + public static final AwsWrapperProperty IAM_REGION = new AwsWrapperProperty("iamRegion", null, + "Overrides AWS region that is used to generate the IAM token"); + public static final AwsWrapperProperty IAM_TOKEN_EXPIRATION = new AwsWrapperProperty("iamTokenExpiration", + String.valueOf(DEFAULT_TOKEN_EXPIRATION_SEC), "IAM token cache expiration in seconds"); + public static final AwsWrapperProperty IDP_USERNAME = + new AwsWrapperProperty("idpUsername", null, "The federated user name"); + public static final AwsWrapperProperty IDP_PASSWORD = new AwsWrapperProperty("idpPassword", null, + "The federated user password"); + public static final AwsWrapperProperty IAM_HOST = new AwsWrapperProperty( + "iamHost", null, + "Overrides the host that is used to generate the IAM token"); + public static final AwsWrapperProperty IAM_DEFAULT_PORT = new AwsWrapperProperty("iamDefaultPort", "-1", + "Overrides default port that is used to generate the IAM token"); + public static final AwsWrapperProperty HTTP_CLIENT_SOCKET_TIMEOUT = new AwsWrapperProperty( + "httpClientSocketTimeout", String.valueOf(DEFAULT_HTTP_TIMEOUT_MILLIS), + "The socket timeout value in milliseconds for the HttpClient used by the FederatedAuthPlugin"); + public static final AwsWrapperProperty HTTP_CLIENT_CONNECT_TIMEOUT = new AwsWrapperProperty( + "httpClientConnectTimeout", String.valueOf(DEFAULT_HTTP_TIMEOUT_MILLIS), + "The connect timeout value in milliseconds for the HttpClient used by the FederatedAuthPlugin"); + public static final AwsWrapperProperty SSL_INSECURE = new AwsWrapperProperty("sslInsecure", "true", + "Whether or not the SSL session is to be secure and the sever's certificates will be verified"); + public static AwsWrapperProperty + IDP_NAME = new AwsWrapperProperty("idpName", null, "The name of the Identity Provider implementation used"); + public static final AwsWrapperProperty DB_USER = + new AwsWrapperProperty("dbUser", null, "The database user used to access the database"); + protected static final Pattern SAML_RESPONSE_PATTERN = Pattern.compile("SAMLResponse\\W+value=\"(?[^\"]+)\""); + protected static final String SAML_RESPONSE_PATTERN_GROUP = "saml"; + protected static final Pattern HTTPS_URL_PATTERN = + Pattern.compile("^(https)://[-a-zA-Z0-9+&@#/%?=~_!:,.']*[-a-zA-Z0-9+&@#/%=~_']"); + + private static final String TELEMETRY_FETCH_TOKEN = "fetch IAM token"; + private static final Logger LOGGER = Logger.getLogger(FederatedAuthPlugin.class.getName()); + + protected final PluginService pluginService; + + protected final RdsUtils rdsUtils = new RdsUtils(); + + private static final Set subscribedMethods = + Collections.unmodifiableSet(new HashSet() { + { + add("connect"); + add("forceConnect"); + } + }); + + static { + PropertyDefinition.registerPluginProperties(FederatedAuthPlugin.class); + } + + private final TelemetryFactory telemetryFactory; + private final TelemetryGauge cacheSizeGauge; + private final TelemetryCounter fetchTokenCounter; + + @Override + public Set getSubscribedMethods() { + return subscribedMethods; + } + + public FederatedAuthPlugin(final PluginService pluginService, + final CredentialsProviderFactory credentialsProviderFactory) { + try { + Class.forName("software.amazon.awssdk.services.sts.model.AssumeRoleWithSamlRequest"); + } catch (final ClassNotFoundException e) { + throw new RuntimeException(Messages.get("FederatedAuthPlugin.javaStsSdkNotInClasspath")); + } + this.pluginService = pluginService; + this.credentialsProviderFactory = credentialsProviderFactory; + this.telemetryFactory = pluginService.getTelemetryFactory(); + this.cacheSizeGauge = telemetryFactory.createGauge("federatedAuth.tokenCache.size", () -> (long) tokenCache.size()); + this.fetchTokenCounter = telemetryFactory.createCounter("federatedAuth.fetchToken.count"); + } + + + @Override + public Connection connect( + final String driverProtocol, + final HostSpec hostSpec, + final Properties props, + final boolean isInitialConnection, + final JdbcCallable connectFunc) + throws SQLException { + return connectInternal(hostSpec, props, connectFunc); + } + + @Override + public Connection forceConnect( + final @NonNull String driverProtocol, + final @NonNull HostSpec hostSpec, + final @NonNull Properties props, + final boolean isInitialConnection, + final @NonNull JdbcCallable forceConnectFunc) + throws SQLException { + return connectInternal(hostSpec, props, forceConnectFunc); + } + + private Connection connectInternal(final HostSpec hostSpec, final Properties props, + final JdbcCallable connectFunc) throws SQLException { + + checkIdpCredentialsWithFallback(props); + + final String host = IamAuthUtils.getIamHost(IAM_HOST.getString(props), hostSpec); + + final int port = IamAuthUtils.getIamPort( + IAM_DEFAULT_PORT.getInteger(props), + hostSpec, + this.pluginService.getDialect().getDefaultPort()); + + final Region region = getRegion(host, props); + + final String cacheKey = getCacheKey( + DB_USER.getString(props), + host, + port, + region); + + final TokenInfo tokenInfo = tokenCache.get(cacheKey); + + final boolean isCachedToken = tokenInfo != null && !tokenInfo.isExpired(); + + if (isCachedToken) { + LOGGER.finest( + () -> Messages.get( + "FederatedAuthPlugin.useCachedIamToken", + new Object[] {tokenInfo.getToken()})); + PropertyDefinition.PASSWORD.set(props, tokenInfo.getToken()); + } else { + updateAuthenticationToken(hostSpec, props, region, cacheKey); + } + + PropertyDefinition.USER.set(props, DB_USER.getString(props)); + + try { + return connectFunc.call(); + } catch (final SQLException exception) { + updateAuthenticationToken(hostSpec, props, region, cacheKey); + return connectFunc.call(); + } catch (final Exception exception) { + LOGGER.warning( + () -> Messages.get( + "FederatedAuthPlugin.unhandledException", + new Object[] {exception})); + throw new SQLException(exception); + } + } + + private void checkIdpCredentialsWithFallback(final Properties props) { + if (IDP_USERNAME.getString(props) == null) { + IDP_USERNAME.set(props, PropertyDefinition.USER.getString(props)); + } + + if (IDP_PASSWORD.getString(props) == null) { + IDP_PASSWORD.set(props, PropertyDefinition.PASSWORD.getString(props)); + } + } + + private void updateAuthenticationToken(final HostSpec hostSpec, final Properties props, final Region region, + final String cacheKey) + throws SQLException { + final int tokenExpirationSec = IAM_TOKEN_EXPIRATION.getInteger(props); + final Instant tokenExpiry = Instant.now().plus(tokenExpirationSec, ChronoUnit.SECONDS); + final int port = IamAuthUtils.getIamPort( + StringUtils.isNullOrEmpty(IAM_DEFAULT_PORT.getString(props)) ? 0 : IAM_DEFAULT_PORT.getInteger(props), + hostSpec, + this.pluginService.getDialect().getDefaultPort()); + final AwsCredentialsProvider credentialsProvider = + this.credentialsProviderFactory.getAwsCredentialsProvider(hostSpec.getHost(), region, props); + final String token = generateAuthenticationToken( + props, + hostSpec.getHost(), + port, + region, + credentialsProvider); + LOGGER.finest( + () -> Messages.get( + "FederatedAuthPlugin.generatedNewIamToken", + new Object[] {token})); + PropertyDefinition.PASSWORD.set(props, token); + tokenCache.put( + cacheKey, + new TokenInfo(token, tokenExpiry)); + } + + private Region getRegion(final String hostname, final Properties props) throws SQLException { + final String iamRegion = IAM_REGION.getString(props); + if (!StringUtils.isNullOrEmpty(iamRegion)) { + return Region.of(iamRegion); + } + + // Fallback to using host + // Get Region + final String rdsRegion = rdsUtils.getRdsRegion(hostname); + + if (StringUtils.isNullOrEmpty(rdsRegion)) { + // Does not match Amazon's Hostname, throw exception + final String exceptionMessage = Messages.get( + "FederatedAuthPlugin.unsupportedHostname", + new Object[] {hostname}); + + LOGGER.fine(exceptionMessage); + throw new SQLException(exceptionMessage); + } + + // Check Region + final Optional regionOptional = Region.regions().stream() + .filter(r -> r.id().equalsIgnoreCase(rdsRegion)) + .findFirst(); + + if (!regionOptional.isPresent()) { + final String exceptionMessage = Messages.get( + "AwsSdk.unsupportedRegion", + new Object[] {rdsRegion}); + + LOGGER.fine(exceptionMessage); + throw new SQLException(exceptionMessage); + } + + return regionOptional.get(); + } + + String generateAuthenticationToken(final Properties props, final String hostname, + final int port, final Region region, final AwsCredentialsProvider awsCredentialsProvider) { + final TelemetryFactory telemetryFactory = this.pluginService.getTelemetryFactory(); + final TelemetryContext telemetryContext = telemetryFactory.openTelemetryContext( + TELEMETRY_FETCH_TOKEN, TelemetryTraceLevel.NESTED); + this.fetchTokenCounter.inc(); + try { + final String user = DB_USER.getString(props); + final RdsUtilities utilities = + RdsUtilities.builder().credentialsProvider(awsCredentialsProvider).region(region).build(); + return utilities.generateAuthenticationToken((builder) -> builder.hostname(hostname).port(port).username(user)); + } catch (final Exception e) { + telemetryContext.setSuccess(false); + telemetryContext.setException(e); + throw e; + } finally { + telemetryContext.closeContext(); + } + } + + private String getCacheKey( + final String user, + final String hostname, + final int port, + final Region region) { + + return String.format("%s:%s:%d:%s", region, hostname, port, user); + } + + public static void clearCache() { + tokenCache.clear(); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPluginFactory.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPluginFactory.java new file mode 100644 index 000000000..4236b46f3 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPluginFactory.java @@ -0,0 +1,55 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.federatedauth; + +import java.security.GeneralSecurityException; +import java.util.Properties; +import software.amazon.jdbc.ConnectionPlugin; +import software.amazon.jdbc.ConnectionPluginFactory; +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.util.Messages; +import software.amazon.jdbc.util.StringUtils; + +public class FederatedAuthPluginFactory implements ConnectionPluginFactory { + + @Override + public ConnectionPlugin getInstance(final PluginService pluginService, final Properties props) { + return new FederatedAuthPlugin(pluginService, getCredentialsProviderFactory(pluginService, props)); + } + + private CredentialsProviderFactory getCredentialsProviderFactory(final PluginService pluginService, + final Properties props) { + final String idpName = FederatedAuthPlugin.IDP_NAME.getString(props); + if (StringUtils.isNullOrEmpty(idpName) || AdfsCredentialsProviderFactory.IDP_NAME.equalsIgnoreCase(idpName)) { + return new AdfsCredentialsProviderFactory( + pluginService, + () -> { + try { + return new HttpClientFactory().getCloseableHttpClient( + FederatedAuthPlugin.HTTP_CLIENT_SOCKET_TIMEOUT.getInteger(props), + FederatedAuthPlugin.HTTP_CLIENT_CONNECT_TIMEOUT.getInteger(props), + FederatedAuthPlugin.SSL_INSECURE.getBoolean(props)); + } catch (GeneralSecurityException e) { + throw new RuntimeException( + Messages.get("FederatedAuthPluginFactory.failedToInitializeHttpClient"), e); + } + }); + } + throw new IllegalArgumentException(Messages.get("FederatedAuthPluginFactory.unsupportedIdp", + new Object[] {idpName})); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/HttpClientFactory.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/HttpClientFactory.java new file mode 100644 index 000000000..db44ff92b --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/HttpClientFactory.java @@ -0,0 +1,71 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.federatedauth; + +import java.security.GeneralSecurityException; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLSocketFactory; +import javax.net.ssl.TrustManager; +import org.apache.http.client.config.CookieSpecs; +import org.apache.http.client.config.RequestConfig; +import org.apache.http.conn.ssl.NoopHostnameVerifier; +import org.apache.http.conn.ssl.SSLConnectionSocketFactory; +import org.apache.http.impl.client.CloseableHttpClient; +import org.apache.http.impl.client.DefaultHttpRequestRetryHandler; +import org.apache.http.impl.client.HttpClientBuilder; +import org.apache.http.impl.client.HttpClients; +import org.apache.http.impl.client.LaxRedirectStrategy; + +/** + * Provides a HttpClient so that requests to HTTP API can be made. This is used by the + * {@link software.amazon.jdbc.plugin.federatedauth.AdfsCredentialsProviderFactory} to make HTTP calls to ADFS HTTP + * endpoints that are not available via SDK. + */ +public class HttpClientFactory { + private static final int MAX_REQUEST_RETRIES = 3; + + public CloseableHttpClient getCloseableHttpClient(final int socketTimeoutMs, final int connectionTimeoutMs, + final boolean keySslInsecure) throws GeneralSecurityException { + final RequestConfig rc = RequestConfig.custom() + .setSocketTimeout(socketTimeoutMs) + .setConnectTimeout(connectionTimeoutMs) + .setExpectContinueEnabled(false) + .setCookieSpec(CookieSpecs.STANDARD) + .build(); + + final HttpClientBuilder builder = HttpClients.custom() + .setDefaultRequestConfig(rc) + .setRedirectStrategy(new LaxRedirectStrategy()) + .setRetryHandler(new DefaultHttpRequestRetryHandler(MAX_REQUEST_RETRIES, true)) + .useSystemProperties(); // this is needed for proxy setting using system properties. + + if (keySslInsecure) { + final SSLContext ctx = SSLContext.getInstance("TLSv1.2"); + final TrustManager[] tma = new TrustManager[] {new NonValidatingSSLSocketFactory.NonValidatingTrustManager()}; + ctx.init(null, tma, null); + final SSLSocketFactory factory = ctx.getSocketFactory(); + + final SSLConnectionSocketFactory sf = new SSLConnectionSocketFactory( + factory, + new NoopHostnameVerifier()); + + builder.setSSLSocketFactory(sf); + } + + return builder.build(); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/NonValidatingSSLSocketFactory.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/NonValidatingSSLSocketFactory.java new file mode 100644 index 000000000..100d351b3 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/NonValidatingSSLSocketFactory.java @@ -0,0 +1,98 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.federatedauth; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.Socket; +import java.security.GeneralSecurityException; +import java.security.cert.X509Certificate; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLSocketFactory; +import javax.net.ssl.TrustManager; +import javax.net.ssl.X509TrustManager; + +/** + * Provide a SSLSocketFactory that allows SSL connections to be made without validating the server's + * certificate. This is more convenient for some applications, but is less secure as it allows "man + * in the middle" attacks. + */ +public class NonValidatingSSLSocketFactory extends SSLSocketFactory { + + /** + * We provide a constructor that takes an unused argument solely because the ssl calling code will + * look for this constructor first and then fall back to the no argument constructor, so we avoid + * an exception and additional reflection lookups. + * + * @param arg input argument + * @throws GeneralSecurityException if something goes wrong + */ + public NonValidatingSSLSocketFactory(final String arg) throws GeneralSecurityException { + final SSLContext ctx = SSLContext.getInstance("TLS"); // or "SSL" ? + + ctx.init(null, new TrustManager[]{new NonValidatingTrustManager()}, null); + + factory = ctx.getSocketFactory(); + } + + protected SSLSocketFactory factory; + + public Socket createSocket(final InetAddress host, final int port) throws IOException { + return factory.createSocket(host, port); + } + + public Socket createSocket(final String host, final int port) throws IOException { + return factory.createSocket(host, port); + } + + public Socket createSocket(final String host, final int port, final InetAddress localHost, final int localPort) + throws IOException { + return factory.createSocket(host, port, localHost, localPort); + } + + public Socket createSocket(final InetAddress address, final int port, final InetAddress localAddress, + final int localPort) + throws IOException { + return factory.createSocket(address, port, localAddress, localPort); + } + + public Socket createSocket(final Socket socket, final String host, final int port, final boolean autoClose) + throws IOException { + return factory.createSocket(socket, host, port, autoClose); + } + + public String[] getDefaultCipherSuites() { + return factory.getDefaultCipherSuites(); + } + + public String[] getSupportedCipherSuites() { + return factory.getSupportedCipherSuites(); + } + + public static class NonValidatingTrustManager implements X509TrustManager { + + public X509Certificate[] getAcceptedIssuers() { + return new X509Certificate[0]; + } + + public void checkClientTrusted(final X509Certificate[] certs, final String authType) { + } + + public void checkServerTrusted(final X509Certificate[] certs, final String authType) { + } + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/SamlCredentialsProviderFactory.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/SamlCredentialsProviderFactory.java new file mode 100644 index 000000000..a2f1081cf --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/SamlCredentialsProviderFactory.java @@ -0,0 +1,60 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.federatedauth; + +import static software.amazon.jdbc.plugin.federatedauth.FederatedAuthPlugin.IAM_IDP_ARN; +import static software.amazon.jdbc.plugin.federatedauth.FederatedAuthPlugin.IAM_ROLE_ARN; + +import java.sql.SQLException; +import java.util.Properties; +import org.checkerframework.checker.nullness.qual.NonNull; +import software.amazon.awssdk.auth.credentials.AnonymousCredentialsProvider; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.sts.StsClient; +import software.amazon.awssdk.services.sts.auth.StsAssumeRoleWithSamlCredentialsProvider; +import software.amazon.awssdk.services.sts.model.AssumeRoleWithSamlRequest; + +public abstract class SamlCredentialsProviderFactory implements CredentialsProviderFactory { + + @Override + public AwsCredentialsProvider getAwsCredentialsProvider(final String host, final Region region, + final @NonNull Properties props) + throws SQLException { + + final String samlAssertion = getSamlAssertion(props); + + final AssumeRoleWithSamlRequest assumeRoleWithSamlRequest = AssumeRoleWithSamlRequest.builder() + .samlAssertion(samlAssertion) + .roleArn(IAM_ROLE_ARN.getString(props)) + .principalArn(IAM_IDP_ARN.getString(props)) + .build(); + + final StsClient stsClient = StsClient.builder() + .credentialsProvider(AnonymousCredentialsProvider.create()) + .region(region) + .build(); + + return StsAssumeRoleWithSamlCredentialsProvider.builder() + .refreshRequest(assumeRoleWithSamlRequest) + .asyncCredentialUpdateEnabled(true) + .stsClient(stsClient) + .build(); + } + + abstract String getSamlAssertion(final @NonNull Properties props) throws SQLException; +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/IamAuthUtils.java b/wrapper/src/main/java/software/amazon/jdbc/util/IamAuthUtils.java new file mode 100644 index 000000000..af6b6af06 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/util/IamAuthUtils.java @@ -0,0 +1,38 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.util; + +import software.amazon.jdbc.HostSpec; + +public class IamAuthUtils { + public static String getIamHost(final String iamHost, final HostSpec hostSpec) { + if (!StringUtils.isNullOrEmpty(iamHost)) { + return iamHost; + } + return hostSpec.getHost(); + } + + public static int getIamPort(final int iamDefaultPort, final HostSpec hostSpec, final int dialectDefaultPort) { + if (iamDefaultPort > 0) { + return iamDefaultPort; + } else if (hostSpec.isPortSpecified()) { + return hostSpec.getPort(); + } else { + return dialectDefaultPort; + } + } +} diff --git a/wrapper/src/main/resources/aws_advanced_jdbc_wrapper_messages.properties b/wrapper/src/main/resources/aws_advanced_jdbc_wrapper_messages.properties index b79c7d242..f703921c1 100644 --- a/wrapper/src/main/resources/aws_advanced_jdbc_wrapper_messages.properties +++ b/wrapper/src/main/resources/aws_advanced_jdbc_wrapper_messages.properties @@ -14,6 +14,15 @@ # limitations under the License. # +# ADFS Credentials Provider Getter +AdfsCredentialsProviderFactory.failedLogin=Failed login. Could not obtain SAML Assertion from ADFS SignOn Page POST response: \n''{0}'' +AdfsCredentialsProviderFactory.getSamlAssertionFailed=Failed to get SAML Assertion due to exception: ''{0}'' +AdfsCredentialsProviderFactory.invalidHttpsUrl=Invalid HTTPS URL: ''{0}'' +AdfsCredentialsProviderFactory.signOnPagePostActionUrl=ADFS SignOn Action URL: ''{0}'' +AdfsCredentialsProviderFactory.signOnPagePostActionRequestFailed=ADFS SignOn Page POST action failed with HTTP status ''{0}'', reason phrase ''{1}'', and response ''{2}'' +AdfsCredentialsProviderFactory.signOnPageRequestFailed=ADFS SignOn Page Request Failed with HTTP status ''{0}'', reason phrase ''{1}'', and response ''{2}'' +AdfsCredentialsProviderFactory.signOnPageUrl=ADFS SignOn URL: ''{0}'' + # Aurora Host List Connection Plugin AuroraHostListConnectionPlugin.providerAlreadySet=Another dynamic host list provider has already been set: {0}. @@ -148,6 +157,16 @@ Failover.failedToUpdateCurrentHostspecAvailability=Failed to update current host Failover.noOperationsAfterConnectionClosed=No operations allowed after connection closed. Failover.invalidHostListProvider=Incorrect type of host list provider found, please ensure the correct host list provider is specified. The host list provider in use is: ''{0}'', the plugin is expected a cluster-aware host list provider such as the AuroraHostListProvider. +# Federated Authentication Connection Plugin +FederatedAuthPlugin.generatedNewIamToken=Generated new IAM token = ''{0}'' +FederatedAuthPlugin.javaStsSdkNotInClasspath=Required dependency 'AWS Java SDK for AWS Secret Token Service' is not on the classpath. +FederatedAuthPlugin.unhandledException=Unhandled exception: ''{0}'' +FederatedAuthPlugin.unsupportedHostname=Unsupported AWS hostname {0}. Amazon domain name in format *.AWS-Region.rds.amazonaws.com or *.rds.AWS-Region.amazonaws.com.cn is expected. +FederatedAuthPlugin.useCachedIamToken=Use cached IAM token = ''{0}'' + +# Federated Authentication Connection Plugin Factory +FederatedAuthPluginFactory.failedToInitializeHttpClient=Failed to initialize HttpClient. +FederatedAuthPluginFactory.unsupportedIdp=Unsupported Identity Provider ''{0}''. Please visit to the documentation for supported Identity Providers. # HikariPooledConnectionProvider HikariPooledConnectionProvider.errorConnectingWithDataSource=Unable to connect to ''{0}'' using the Hikari data source. @@ -173,7 +192,6 @@ HostSelector.roundRobinInvalidDefaultWeight=The provided default weight value is IamAuthConnectionPlugin.unsupportedHostname=Unsupported AWS hostname {0}. Amazon domain name in format *.AWS-Region.rds.amazonaws.com or *.rds.AWS-Region.amazonaws.com.cn is expected. IamAuthConnectionPlugin.useCachedIamToken=Use cached IAM token = ''{0}'' IamAuthConnectionPlugin.generatedNewIamToken=Generated new IAM token = ''{0}'' -IamAuthConnectionPlugin.invalidPort=Port number: {0} is not valid. Port number should be greater than zero. Falling back to default port. IamAuthConnectionPlugin.unhandledException=Unhandled exception: ''{0}'' IamAuthConnectionPlugin.connectException=Error occurred while opening a connection: ''{0}'' diff --git a/wrapper/src/test/build.gradle.kts b/wrapper/src/test/build.gradle.kts index 938e05852..29b27b157 100644 --- a/wrapper/src/test/build.gradle.kts +++ b/wrapper/src/test/build.gradle.kts @@ -41,8 +41,9 @@ dependencies { testImplementation("com.zaxxer:HikariCP:4.+") // version 4.+ is compatible with Java 8 testImplementation("org.springframework.boot:spring-boot-starter-jdbc:2.7.13") // 2.7.13 is the last version compatible with Java 8 testImplementation("org.mockito:mockito-inline:4.11.0") // 4.11.0 is the last version compatible with Java 8 - testImplementation("software.amazon.awssdk:rds:2.20.49") testImplementation("software.amazon.awssdk:ec2:2.20.49") + testImplementation("software.amazon.awssdk:rds:2.20.49") + testImplementation("software.amazon.awssdk:sts:2.20.49") testImplementation("org.testcontainers:testcontainers:1.17.+") testImplementation("org.testcontainers:mysql:1.17.+") testImplementation("org.testcontainers:postgresql:1.17.+") diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/IamAuthConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/IamAuthConnectionPluginTest.java index 2323859ec..63741b7b5 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/IamAuthConnectionPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/IamAuthConnectionPluginTest.java @@ -102,7 +102,7 @@ public static void registerDrivers() throws SQLException { @Test public void testPostgresConnectValidTokenInCache() throws SQLException { IamAuthConnectionPlugin.tokenCache.put(PG_CACHE_KEY, - new IamAuthConnectionPlugin.TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000))); + new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000))); when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_PG_PORT); @@ -114,7 +114,7 @@ public void testMySqlConnectValidTokenInCache() throws SQLException { props.setProperty(PropertyDefinition.USER.name, "mysqlUser"); props.setProperty(PropertyDefinition.PASSWORD.name, "mysqlPassword"); IamAuthConnectionPlugin.tokenCache.put(MYSQL_CACHE_KEY, - new IamAuthConnectionPlugin.TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000))); + new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000))); when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_MYSQL_PORT); @@ -130,7 +130,7 @@ public void testPostgresConnectWithInvalidPortFallbacksToHostPort() throws SQLEx final String cacheKeyWithNewPort = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + PG_HOST_SPEC_WITH_PORT.getPort() + ":postgresqlUser"; IamAuthConnectionPlugin.tokenCache.put(cacheKeyWithNewPort, - new IamAuthConnectionPlugin.TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000))); + new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000))); testTokenSetInProps(PG_DRIVER_PROTOCOL, PG_HOST_SPEC_WITH_PORT); } @@ -145,7 +145,7 @@ public void testPostgresConnectWithInvalidPortAndNoHostPortFallbacksToHostPort() final String cacheKeyWithNewPort = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + DEFAULT_PG_PORT + ":postgresqlUser"; IamAuthConnectionPlugin.tokenCache.put(cacheKeyWithNewPort, - new IamAuthConnectionPlugin.TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000))); + new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000))); testTokenSetInProps(PG_DRIVER_PROTOCOL, PG_HOST_SPEC); } @@ -153,7 +153,7 @@ public void testPostgresConnectWithInvalidPortAndNoHostPortFallbacksToHostPort() @Test public void testConnectExpiredTokenInCache() throws SQLException { IamAuthConnectionPlugin.tokenCache.put(PG_CACHE_KEY, - new IamAuthConnectionPlugin.TokenInfo(TEST_TOKEN, Instant.now().minusMillis(300000))); + new TokenInfo(TEST_TOKEN, Instant.now().minusMillis(300000))); when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_PG_PORT); @@ -171,7 +171,7 @@ public void testConnectEmptyCache() throws SQLException { public void testConnectWithSpecifiedPort() throws SQLException { final String cacheKeyWithNewPort = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:1234:" + "postgresqlUser"; IamAuthConnectionPlugin.tokenCache.put(cacheKeyWithNewPort, - new IamAuthConnectionPlugin.TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000))); + new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000))); testTokenSetInProps(PG_DRIVER_PROTOCOL, PG_HOST_SPEC_WITH_PORT); } @@ -183,7 +183,7 @@ public void testConnectWithSpecifiedIamDefaultPort() throws SQLException { final String cacheKeyWithNewPort = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + iamDefaultPort + ":postgresqlUser"; IamAuthConnectionPlugin.tokenCache.put(cacheKeyWithNewPort, - new IamAuthConnectionPlugin.TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000))); + new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000))); testTokenSetInProps(PG_DRIVER_PROTOCOL, PG_HOST_SPEC_WITH_PORT); } @@ -194,7 +194,7 @@ public void testConnectWithSpecifiedRegion() throws SQLException { "us-west-1:pg.testdb.us-west-1.rds.amazonaws.com:" + DEFAULT_PG_PORT + ":" + "postgresqlUser"; props.setProperty(IamAuthConnectionPlugin.IAM_REGION.name, "us-west-1"); IamAuthConnectionPlugin.tokenCache.put(cacheKeyWithNewRegion, - new IamAuthConnectionPlugin.TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000))); + new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000))); when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_PG_PORT); diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/AdfsCredentialsProviderFactoryTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/AdfsCredentialsProviderFactoryTest.java new file mode 100644 index 000000000..d3ea267ec --- /dev/null +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/AdfsCredentialsProviderFactoryTest.java @@ -0,0 +1,113 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.federatedauth; + +import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.sql.SQLException; +import java.util.Properties; +import java.util.function.Supplier; +import org.apache.http.HttpEntity; +import org.apache.http.StatusLine; +import org.apache.http.client.methods.CloseableHttpResponse; +import org.apache.http.client.methods.HttpGet; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.impl.client.CloseableHttpClient; +import org.apache.http.util.EntityUtils; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.testcontainers.shaded.org.apache.commons.io.IOUtils; +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.util.telemetry.TelemetryContext; +import software.amazon.jdbc.util.telemetry.TelemetryFactory; + +class AdfsCredentialsProviderFactoryTest { + + private static final String USERNAME = "someFederatedUsername@example.com"; + private static final String PASSWORD = "somePassword"; + @Mock private PluginService mockPluginService; + @Mock private TelemetryFactory mockTelemetryFactory; + @Mock private TelemetryContext mockTelemetryContext; + @Mock private Supplier mockHttpClientSupplier; + @Mock private CloseableHttpClient mockHttpClient; + @Mock private CloseableHttpResponse mockHttpGetSignInPageResponse; + @Mock private CloseableHttpResponse mockHttpPostSignInResponse; + @Mock private StatusLine mockStatusLine; + @Mock private HttpEntity mockSignInPageHttpEntity; + @Mock private HttpEntity mockSamlHttpEntity; + private AdfsCredentialsProviderFactory adfsCredentialsProviderFactory; + private Properties props; + + @BeforeEach + public void init() throws IOException { + MockitoAnnotations.openMocks(this); + + this.props = new Properties(); + this.props.setProperty(FederatedAuthPlugin.IDP_ENDPOINT.name, "ec2amaz-ab3cdef.example.com"); + this.props.setProperty(FederatedAuthPlugin.IDP_USERNAME.name, USERNAME); + this.props.setProperty(FederatedAuthPlugin.IDP_PASSWORD.name, PASSWORD); + + when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); + when(mockTelemetryFactory.openTelemetryContext(any(), any())).thenReturn(mockTelemetryContext); + when(mockHttpClientSupplier.get()).thenReturn(mockHttpClient); + when(mockHttpClient.execute(any(HttpGet.class))).thenReturn(mockHttpGetSignInPageResponse); + when(mockHttpGetSignInPageResponse.getStatusLine()).thenReturn(mockStatusLine); + when(mockStatusLine.getStatusCode()).thenReturn(200); + when(mockHttpGetSignInPageResponse.getEntity()).thenReturn(mockSignInPageHttpEntity); + + String signinPageHtml = IOUtils.toString( + this.getClass().getClassLoader().getResourceAsStream("federated_auth/adfs-sign-in-page.html"), "UTF-8"); + InputStream signInPageHtmlInputStream = new ByteArrayInputStream(signinPageHtml.getBytes()); + when(mockSignInPageHttpEntity.getContent()).thenReturn(signInPageHtmlInputStream); + + when(mockHttpClient.execute(any(HttpPost.class))).thenReturn(mockHttpPostSignInResponse); + when(mockHttpPostSignInResponse.getStatusLine()).thenReturn(mockStatusLine); + when(mockHttpPostSignInResponse.getEntity()).thenReturn(mockSamlHttpEntity); + + String adfsSamlHtml = IOUtils.toString( + this.getClass().getClassLoader().getResourceAsStream("federated_auth/adfs-saml.html"), "UTF-8"); + InputStream samlHtmlInputStream = new ByteArrayInputStream(adfsSamlHtml.getBytes()); + when(mockSamlHttpEntity.getContent()).thenReturn(samlHtmlInputStream); + + this.adfsCredentialsProviderFactory = new AdfsCredentialsProviderFactory(mockPluginService, mockHttpClientSupplier); + } + + @Test + void test() throws IOException, SQLException { + this.adfsCredentialsProviderFactory.getSamlAssertion(props); + + ArgumentCaptor httpPostArgumentCaptor = ArgumentCaptor.forClass(HttpPost.class); + verify(mockHttpClient, times(2)).execute(httpPostArgumentCaptor.capture()); + HttpPost actualHttpPost = httpPostArgumentCaptor.getValue(); + String content = EntityUtils.toString(actualHttpPost.getEntity()); + String[] params = content.split("&"); + assertEquals("UserName=" + USERNAME.replace("@", "%40"), params[0]); + assertEquals("Password=" + PASSWORD, params[1]); + assertEquals("Kmsi=true", params[2]); + assertEquals("AuthMethod=FormsAuthentication", params[3]); + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPluginTest.java new file mode 100644 index 000000000..06a49d397 --- /dev/null +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPluginTest.java @@ -0,0 +1,188 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.federatedauth; + +import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; + +import java.sql.Connection; +import java.sql.SQLException; +import java.time.Instant; +import java.util.Properties; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity; +import software.amazon.awssdk.regions.Region; +import software.amazon.jdbc.HostSpec; +import software.amazon.jdbc.HostSpecBuilder; +import software.amazon.jdbc.JdbcCallable; +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.PropertyDefinition; +import software.amazon.jdbc.dialect.Dialect; +import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +import software.amazon.jdbc.plugin.TokenInfo; +import software.amazon.jdbc.util.telemetry.TelemetryContext; +import software.amazon.jdbc.util.telemetry.TelemetryCounter; +import software.amazon.jdbc.util.telemetry.TelemetryFactory; + +class FederatedAuthPluginTest { + + private static final int DEFAULT_PORT = 1234; + private static final String DRIVER_PROTOCOL = "jdbc:postgresql:"; + + private static final HostSpec HOST_SPEC = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("pg.testdb.us-east-2.rds.amazonaws.com").build(); + private static final String DB_USER = "iamUser"; + private static final String TEST_TOKEN = "someTestToken"; + private static final TokenInfo TEST_TOKEN_INFO = new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000)); + @Mock private PluginService mockPluginService; + @Mock private Dialect mockDialect; + @Mock JdbcCallable mockLambda; + @Mock private TelemetryFactory mockTelemetryFactory; + @Mock private TelemetryContext mockTelemetryContext; + @Mock private TelemetryCounter mockTelemetryCounter; + @Mock private CredentialsProviderFactory mockCredentialsProviderFactory; + @Mock private AwsCredentialsProvider mockAwsCredentialsProvider; + @Mock private CompletableFuture completableFuture; + @Mock private AwsCredentialsIdentity mockAwsCredentialsIdentity; + private Properties props; + + @BeforeEach + public void init() throws ExecutionException, InterruptedException, SQLException { + MockitoAnnotations.openMocks(this); + props = new Properties(); + props.setProperty(PropertyDefinition.PLUGINS.name, "federatedAuth"); + props.setProperty(FederatedAuthPlugin.DB_USER.name, DB_USER); + FederatedAuthPlugin.clearCache(); + + when(mockPluginService.getDialect()).thenReturn(mockDialect); + when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_PORT); + when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); + when(mockTelemetryFactory.createCounter(any())).thenReturn(mockTelemetryCounter); + when(mockTelemetryFactory.openTelemetryContext(any(), any())).thenReturn(mockTelemetryContext); + when(mockCredentialsProviderFactory.getAwsCredentialsProvider(any(), any(), any())) + .thenReturn(mockAwsCredentialsProvider); + when(mockAwsCredentialsProvider.resolveIdentity()).thenReturn(completableFuture); + when(completableFuture.get()).thenReturn(mockAwsCredentialsIdentity); + } + + @Test + void testCachedToken() throws SQLException { + FederatedAuthPlugin plugin = + new FederatedAuthPlugin(mockPluginService, mockCredentialsProviderFactory); + + String key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + DEFAULT_PORT + ":iamUser"; + FederatedAuthPlugin.tokenCache.put(key, TEST_TOKEN_INFO); + + plugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); + + assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); + assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); + } + + @Test + void testExpiredCachedToken() throws SQLException { + FederatedAuthPlugin spyPlugin = Mockito.spy( + new FederatedAuthPlugin(mockPluginService, mockCredentialsProviderFactory)); + + String key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + DEFAULT_PORT + ":iamUser"; + String someExpiredToken = "someExpiredToken"; + TokenInfo expiredTokenInfo = new TokenInfo( + someExpiredToken, Instant.now().minusMillis(300000)); + FederatedAuthPlugin.tokenCache.put(key, expiredTokenInfo); + + when( + spyPlugin.generateAuthenticationToken( + props, + HOST_SPEC.getHost(), + DEFAULT_PORT, + Region.US_EAST_2, mockAwsCredentialsProvider)) + .thenReturn(TEST_TOKEN); + + spyPlugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); + assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); + assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); + } + + @Test + void testNoCachedToken() throws SQLException { + FederatedAuthPlugin spyPlugin = Mockito.spy( + new FederatedAuthPlugin(mockPluginService, mockCredentialsProviderFactory)); + + when( + spyPlugin.generateAuthenticationToken( + props, + HOST_SPEC.getHost(), + DEFAULT_PORT, + Region.US_EAST_2, mockAwsCredentialsProvider)) + .thenReturn(TEST_TOKEN); + + spyPlugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); + assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); + assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); + } + + @Test + void testSpecifiedIamHostPortRegion() throws SQLException { + final String expectedHost = "pg.testdb.us-west-2.rds.amazonaws.com"; + final int expectedPort = 9876; + final Region expectedRegion = Region.US_WEST_2; + + props.setProperty(FederatedAuthPlugin.IAM_HOST.name, expectedHost); + props.setProperty(FederatedAuthPlugin.IAM_DEFAULT_PORT.name, String.valueOf(expectedPort)); + props.setProperty(FederatedAuthPlugin.IAM_REGION.name, expectedRegion.toString()); + + final String key = "us-west-2:pg.testdb.us-west-2.rds.amazonaws.com:" + String.valueOf(expectedPort) + ":iamUser"; + FederatedAuthPlugin.tokenCache.put(key, TEST_TOKEN_INFO); + + FederatedAuthPlugin plugin = + new FederatedAuthPlugin(mockPluginService, mockCredentialsProviderFactory); + + plugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); + + assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); + assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); + } + + @Test + void testIdpCredentialsFallback() throws SQLException { + String expectedUser = "expectedUser"; + String expectedPassword = "expectedPassword"; + PropertyDefinition.USER.set(props, expectedUser); + PropertyDefinition.PASSWORD.set(props, expectedPassword); + + FederatedAuthPlugin plugin = + new FederatedAuthPlugin(mockPluginService, mockCredentialsProviderFactory); + + String key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + DEFAULT_PORT + ":iamUser"; + FederatedAuthPlugin.tokenCache.put(key, TEST_TOKEN_INFO); + + plugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda); + + assertEquals(DB_USER, PropertyDefinition.USER.getString(props)); + assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props)); + assertEquals(expectedUser, FederatedAuthPlugin.IDP_USERNAME.getString(props)); + assertEquals(expectedPassword, FederatedAuthPlugin.IDP_PASSWORD.getString(props)); + } +} diff --git a/wrapper/src/test/resources/federated_auth/adfs-saml.html b/wrapper/src/test/resources/federated_auth/adfs-saml.html new file mode 100644 index 000000000..6686e3db8 --- /dev/null +++ b/wrapper/src/test/resources/federated_auth/adfs-saml.html @@ -0,0 +1 @@ +Working...
diff --git a/wrapper/src/test/resources/federated_auth/adfs-sign-in-page.html b/wrapper/src/test/resources/federated_auth/adfs-sign-in-page.html new file mode 100644 index 000000000..0b06f6dda --- /dev/null +++ b/wrapper/src/test/resources/federated_auth/adfs-sign-in-page.html @@ -0,0 +1,613 @@ + + + + + + + + + + + + Sign In + + + + + + + + + + +
+

JavaScript required

+

JavaScript is required. This web browser does not support JavaScript or JavaScript in this web browser is not enabled.

+

To find out if your web browser supports JavaScript or to enable JavaScript, see web browser help.

+
+ +
+
+
+
+
+
+ +
+
+ +
+ + + +
+
Sign in
+ +
+
+ +
+ +
+
+ + +
+ +
+ + +
+ +
+ Sign in +
+
+ +
+ +
+
+ + + + +
+
+ +
+ +
+ + +
+ +
+ +
+
+
+
+
+ +
+
+
+ + + + + +