diff --git a/examples/AWSDriverExample/build.gradle.kts b/examples/AWSDriverExample/build.gradle.kts
index 40c4655aa..864ea90d0 100644
--- a/examples/AWSDriverExample/build.gradle.kts
+++ b/examples/AWSDriverExample/build.gradle.kts
@@ -20,6 +20,7 @@ dependencies {
implementation("mysql:mysql-connector-java:8.0.33")
implementation("software.amazon.awssdk:rds:2.21.42")
implementation("software.amazon.awssdk:secretsmanager:2.21.43")
+ implementation("software.amazon.awssdk:sts:2.21.42")
implementation("com.fasterxml.jackson.core:jackson-databind:2.16.0")
implementation(project(":aws-advanced-jdbc-wrapper"))
implementation("io.opentelemetry:opentelemetry-api:1.32.0")
diff --git a/examples/AWSDriverExample/src/main/java/software/amazon/FederatedAuthPluginExample.java b/examples/AWSDriverExample/src/main/java/software/amazon/FederatedAuthPluginExample.java
new file mode 100644
index 000000000..8953d986d
--- /dev/null
+++ b/examples/AWSDriverExample/src/main/java/software/amazon/FederatedAuthPluginExample.java
@@ -0,0 +1,54 @@
+/*
+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License").
+ * You may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package software.amazon;
+
+import software.amazon.jdbc.PropertyDefinition;
+import software.amazon.jdbc.plugin.federatedauth.FederatedAuthPlugin;
+import java.sql.Connection;
+import java.sql.DriverManager;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.sql.Statement;
+import java.util.Properties;
+
+public class FederatedAuthPluginExample {
+
+ private static final String CONNECTION_STRING = "jdbc:aws-wrapper:postgresql://db-identifier.XYZ.us-east-2.rds.amazonaws.com:5432/employees";
+
+ public static void main(String[] args) throws SQLException {
+ // Set the AWS Federated Authentication Connection Plugin parameters and the JDBC Wrapper parameters.
+ final Properties properties = new Properties();
+
+ // Enable the AWS Federated Authentication Connection Plugin.
+ properties.setProperty(PropertyDefinition.PLUGINS.name, "federatedAuth");
+ properties.setProperty(FederatedAuthPlugin.IDP_NAME.name, "adfs");
+ properties.setProperty(FederatedAuthPlugin.IDP_ENDPOINT.name, "ec2amaz-ab3cdef.example.com");
+ properties.setProperty(FederatedAuthPlugin.IAM_ROLE_ARN.name, "arn:aws:iam::123456789012:role/adfs_example_iam_role");
+ properties.setProperty(FederatedAuthPlugin.IAM_IDP_ARN.name, "arn:aws:iam::123456789012:saml-provider/adfs_example");
+ properties.setProperty(FederatedAuthPlugin.IAM_REGION.name, "us-east-2");
+ properties.setProperty(FederatedAuthPlugin.IDP_USERNAME.name, "someFederatedUsername@example.com");
+ properties.setProperty(FederatedAuthPlugin.IDP_PASSWORD.name, "somePassword");
+ properties.setProperty(PropertyDefinition.USER.name, "someIamUser");
+
+ // Try and make a connection:
+ try (final Connection conn = DriverManager.getConnection(CONNECTION_STRING, properties);
+ final Statement statement = conn.createStatement();
+ final ResultSet rs = statement.executeQuery("SELECT 1")) {
+ System.out.println(Util.getResult(rs));
+ }
+ }
+}
diff --git a/wrapper/build.gradle.kts b/wrapper/build.gradle.kts
index 87ed00590..0155dabf3 100644
--- a/wrapper/build.gradle.kts
+++ b/wrapper/build.gradle.kts
@@ -28,7 +28,9 @@ plugins {
dependencies {
implementation("org.checkerframework:checker-qual:3.40.0")
+ compileOnly("org.apache.httpcomponents:httpclient:4.5.14")
compileOnly("software.amazon.awssdk:rds:2.21.42")
+ compileOnly("software.amazon.awssdk:sts:2.21.42")
compileOnly("com.zaxxer:HikariCP:4.0.3") // Version 4.+ is compatible with Java 8
compileOnly("software.amazon.awssdk:secretsmanager:2.21.43")
compileOnly("com.fasterxml.jackson.core:jackson-databind:2.16.0")
@@ -61,6 +63,7 @@ dependencies {
testImplementation("software.amazon.awssdk:rds:2.21.42")
testImplementation("software.amazon.awssdk:ec2:2.21.12")
testImplementation("software.amazon.awssdk:secretsmanager:2.21.43")
+ testImplementation("software.amazon.awssdk:sts:2.21.42")
testImplementation("org.testcontainers:testcontainers:1.19.3")
testImplementation("org.testcontainers:mysql:1.19.3")
testImplementation("org.testcontainers:postgresql:1.19.3")
diff --git a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java
index 353e3eeb6..2093f4999 100644
--- a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java
+++ b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java
@@ -39,6 +39,7 @@
import software.amazon.jdbc.plugin.dev.DeveloperConnectionPluginFactory;
import software.amazon.jdbc.plugin.efm.HostMonitoringConnectionPluginFactory;
import software.amazon.jdbc.plugin.failover.FailoverConnectionPluginFactory;
+import software.amazon.jdbc.plugin.federatedauth.FederatedAuthPluginFactory;
import software.amazon.jdbc.plugin.readwritesplitting.ReadWriteSplittingPluginFactory;
import software.amazon.jdbc.plugin.staledns.AuroraStaleDnsPluginFactory;
import software.amazon.jdbc.plugin.strategy.fastestresponse.FastestResponseStrategyPluginFactory;
@@ -66,6 +67,7 @@ public class ConnectionPluginChainBuilder {
put("failover", FailoverConnectionPluginFactory.class);
put("iam", IamAuthConnectionPluginFactory.class);
put("awsSecretsManager", AwsSecretsManagerConnectionPluginFactory.class);
+ put("federatedAuth", FederatedAuthPluginFactory.class);
put("auroraStaleDns", AuroraStaleDnsPluginFactory.class);
put("readWriteSplitting", ReadWriteSplittingPluginFactory.class);
put("auroraConnectionTracker", AuroraConnectionTrackerPluginFactory.class);
@@ -96,7 +98,8 @@ public class ConnectionPluginChainBuilder {
put(FastestResponseStrategyPluginFactory.class, 900);
put(IamAuthConnectionPluginFactory.class, 1000);
put(AwsSecretsManagerConnectionPluginFactory.class, 1100);
- put(LogQueryConnectionPluginFactory.class, 1200);
+ put(FederatedAuthPluginFactory.class, 1200);
+ put(LogQueryConnectionPluginFactory.class, 1300);
put(ConnectTimeConnectionPluginFactory.class, WEIGHT_RELATIVE_TO_PRIOR_PLUGIN);
put(ExecutionTimeConnectionPluginFactory.class, WEIGHT_RELATIVE_TO_PRIOR_PLUGIN);
put(DeveloperConnectionPluginFactory.class, WEIGHT_RELATIVE_TO_PRIOR_PLUGIN);
diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/IamAuthConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/IamAuthConnectionPlugin.java
index 69fa0c813..d56aafaa2 100644
--- a/wrapper/src/main/java/software/amazon/jdbc/plugin/IamAuthConnectionPlugin.java
+++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/IamAuthConnectionPlugin.java
@@ -36,6 +36,7 @@
import software.amazon.jdbc.PluginService;
import software.amazon.jdbc.PropertyDefinition;
import software.amazon.jdbc.authentication.AwsCredentialsManager;
+import software.amazon.jdbc.util.IamAuthUtils;
import software.amazon.jdbc.util.Messages;
import software.amazon.jdbc.util.RdsUtils;
import software.amazon.jdbc.util.StringUtils;
@@ -64,7 +65,7 @@ public class IamAuthConnectionPlugin extends AbstractConnectionPlugin {
"Overrides the host that is used to generate the IAM token");
public static final AwsWrapperProperty IAM_DEFAULT_PORT = new AwsWrapperProperty(
- "iamDefaultPort", null,
+ "iamDefaultPort", "-1",
"Overrides default port that is used to generate the IAM token");
public static final AwsWrapperProperty IAM_REGION = new AwsWrapperProperty(
@@ -115,12 +116,12 @@ private Connection connectInternal(String driverProtocol, HostSpec hostSpec, Pro
throw new SQLException(PropertyDefinition.USER.name + " is null or empty.");
}
- String host = hostSpec.getHost();
- if (!StringUtils.isNullOrEmpty(IAM_HOST.getString(props))) {
- host = IAM_HOST.getString(props);
- }
+ String host = IamAuthUtils.getIamHost(IAM_HOST.getString(props), hostSpec);
- int port = getPort(props, hostSpec);
+ int port = IamAuthUtils.getIamPort(
+ IAM_DEFAULT_PORT.getInteger(props),
+ hostSpec,
+ this.pluginService.getDialect().getDefaultPort());
final String iamRegion = IAM_REGION.getString(props);
final Region region = StringUtils.isNullOrEmpty(iamRegion)
@@ -261,26 +262,6 @@ public static void clearCache() {
tokenCache.clear();
}
- private int getPort(Properties props, HostSpec hostSpec) {
- if (!StringUtils.isNullOrEmpty(IAM_DEFAULT_PORT.getString(props))) {
- int defaultPort = IAM_DEFAULT_PORT.getInteger(props);
- if (defaultPort > 0) {
- return defaultPort;
- } else {
- LOGGER.finest(
- () -> Messages.get(
- "IamAuthConnectionPlugin.invalidPort",
- new Object[] {defaultPort}));
- }
- }
-
- if (hostSpec.isPortSpecified()) {
- return hostSpec.getPort();
- } else {
- return this.pluginService.getDialect().getDefaultPort();
- }
- }
-
private Region getRdsRegion(final String hostname) throws SQLException {
// Get Region
@@ -312,27 +293,4 @@ private Region getRdsRegion(final String hostname) throws SQLException {
return regionOptional.get();
}
-
- static class TokenInfo {
-
- private final String token;
- private final Instant expiration;
-
- public TokenInfo(final String token, final Instant expiration) {
- this.token = token;
- this.expiration = expiration;
- }
-
- public String getToken() {
- return this.token;
- }
-
- public Instant getExpiration() {
- return this.expiration;
- }
-
- public boolean isExpired() {
- return Instant.now().isAfter(this.expiration);
- }
- }
}
diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/TokenInfo.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/TokenInfo.java
new file mode 100644
index 000000000..3fad01093
--- /dev/null
+++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/TokenInfo.java
@@ -0,0 +1,41 @@
+/*
+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License").
+ * You may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package software.amazon.jdbc.plugin;
+
+import java.time.Instant;
+
+public class TokenInfo {
+ private final String token;
+ private final Instant expiration;
+
+ public TokenInfo(final String token, final Instant expiration) {
+ this.token = token;
+ this.expiration = expiration;
+ }
+
+ public String getToken() {
+ return this.token;
+ }
+
+ public Instant getExpiration() {
+ return this.expiration;
+ }
+
+ public boolean isExpired() {
+ return Instant.now().isAfter(this.expiration);
+ }
+}
diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/AdfsCredentialsProviderFactory.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/AdfsCredentialsProviderFactory.java
new file mode 100644
index 000000000..aa23a8f45
--- /dev/null
+++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/AdfsCredentialsProviderFactory.java
@@ -0,0 +1,252 @@
+/*
+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License").
+ * You may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package software.amazon.jdbc.plugin.federatedauth;
+
+import java.io.IOException;
+import java.net.URI;
+import java.sql.SQLException;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Properties;
+import java.util.Set;
+import java.util.function.Supplier;
+import java.util.logging.Logger;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+import org.apache.http.NameValuePair;
+import org.apache.http.StatusLine;
+import org.apache.http.client.entity.UrlEncodedFormEntity;
+import org.apache.http.client.methods.CloseableHttpResponse;
+import org.apache.http.client.methods.HttpGet;
+import org.apache.http.client.methods.HttpPost;
+import org.apache.http.impl.client.CloseableHttpClient;
+import org.apache.http.message.BasicNameValuePair;
+import org.apache.http.util.EntityUtils;
+import org.checkerframework.checker.nullness.qual.NonNull;
+import software.amazon.jdbc.PluginService;
+import software.amazon.jdbc.util.Messages;
+import software.amazon.jdbc.util.StringUtils;
+import software.amazon.jdbc.util.telemetry.TelemetryContext;
+import software.amazon.jdbc.util.telemetry.TelemetryFactory;
+import software.amazon.jdbc.util.telemetry.TelemetryTraceLevel;
+
+public class AdfsCredentialsProviderFactory extends SamlCredentialsProviderFactory {
+
+ public static final String IDP_NAME = "adfs";
+ private static final String TELEMETRY_FETCH_SAML = "Fetch ADFS SAML Assertion";
+ private static final Pattern INPUT_TAG_PATTERN = Pattern.compile("", Pattern.DOTALL);
+ private static final Pattern FORM_ACTION_PATTERN = Pattern.compile("
httpClientSupplier;
+ private TelemetryContext telemetryContext;
+
+ public AdfsCredentialsProviderFactory(final PluginService pluginService,
+ final Supplier httpClientSupplier) {
+ this.pluginService = pluginService;
+ this.telemetryFactory = this.pluginService.getTelemetryFactory();
+ this.httpClientSupplier = httpClientSupplier;
+ }
+
+ @Override
+ String getSamlAssertion(final @NonNull Properties props) throws SQLException {
+ this.telemetryContext = telemetryFactory.openTelemetryContext(TELEMETRY_FETCH_SAML, TelemetryTraceLevel.NESTED);
+ try (final CloseableHttpClient httpClient = httpClientSupplier.get()) {
+ String uri = getSignInPageUrl(props);
+ final String signInPageBody = getSignInPageBody(httpClient, uri);
+ final String action = getFormActionFromHtmlBody(signInPageBody);
+
+ if (!StringUtils.isNullOrEmpty(action) && action.startsWith("/")) {
+ uri = getFormActionUrl(props, action);
+ }
+
+ final List params = getParametersFromHtmlBody(signInPageBody, props);
+ final String content = getFormActionBody(httpClient, uri, params);
+
+ final Matcher matcher = FederatedAuthPlugin.SAML_RESPONSE_PATTERN.matcher(content);
+ if (!matcher.find()) {
+ throw new IOException(Messages.get("AdfsCredentialsProviderFactory.failedLogin", new Object[] {content}));
+ }
+
+ // return SAML Response value
+ return matcher.group(FederatedAuthPlugin.SAML_RESPONSE_PATTERN_GROUP);
+ } catch (final IOException e) {
+ LOGGER.severe(Messages.get("AdfsCredentialsProviderFactory.getSamlAssertionFailed", new Object[] {e}));
+ this.telemetryContext.setSuccess(false);
+ this.telemetryContext.setException(e);
+ throw new SQLException(e);
+ } finally {
+ this.telemetryContext.closeContext();
+ }
+ }
+
+ private String getSignInPageBody(final CloseableHttpClient httpClient, final String uri) throws IOException {
+ LOGGER.finest(Messages.get("AdfsCredentialsProviderFactory.signOnPageUrl", new Object[] {uri}));
+ validateUrl(uri);
+ final HttpGet get = new HttpGet(uri);
+ try (final CloseableHttpResponse resp = httpClient.execute(get)) {
+ final StatusLine statusLine = resp.getStatusLine();
+ // Check HTTP Status Code is 2xx Success
+ if (statusLine.getStatusCode() / 100 != 2) {
+ throw new IOException(Messages.get("AdfsCredentialsProviderFactory.signOnPageRequestFailed",
+ new Object[] {
+ statusLine.getStatusCode(),
+ statusLine.getReasonPhrase(),
+ EntityUtils.toString(resp.getEntity())}));
+ }
+ return EntityUtils.toString(resp.getEntity());
+ }
+ }
+
+ private String getFormActionBody(final CloseableHttpClient httpClient, final String uri,
+ final List params) throws IOException {
+ LOGGER.finest(Messages.get("AdfsCredentialsProviderFactory.signOnPagePostActionUrl", new Object[] {uri}));
+ validateUrl(uri);
+ final HttpPost post = new HttpPost(uri);
+ post.setEntity(new UrlEncodedFormEntity(params));
+ try (final CloseableHttpResponse resp = httpClient.execute(post)) {
+ final StatusLine statusLine = resp.getStatusLine();
+ // Check HTTP Status Code is 2xx Success
+ if (statusLine.getStatusCode() / 100 != 2) {
+ throw new IOException(Messages.get("AdfsCredentialsProviderFactory.signOnPagePostActionRequestFailed",
+ new Object[] {
+ statusLine.getStatusCode(),
+ statusLine.getReasonPhrase(),
+ EntityUtils.toString(resp.getEntity())}));
+ }
+ return EntityUtils.toString(resp.getEntity());
+ }
+ }
+
+ private String getSignInPageUrl(final Properties props) {
+ return "https://" + FederatedAuthPlugin.IDP_ENDPOINT.getString(props) + ':'
+ + FederatedAuthPlugin.IDP_PORT.getString(props) + "/adfs/ls/IdpInitiatedSignOn.aspx?loginToRp="
+ + FederatedAuthPlugin.RELAYING_PARTY_ID.getString(props);
+ }
+
+ private String getFormActionUrl(final Properties props, final String action) {
+ return "https://" + FederatedAuthPlugin.IDP_ENDPOINT.getString(props) + ':'
+ + FederatedAuthPlugin.IDP_PORT.getString(props) + action;
+ }
+
+ private List getInputTagsFromHTML(final String body) {
+ final Set distinctInputTags = new HashSet<>();
+ final List inputTags = new ArrayList<>();
+ final Matcher inputTagMatcher = INPUT_TAG_PATTERN.matcher(body);
+ while (inputTagMatcher.find()) {
+ final String tag = inputTagMatcher.group(0);
+ final String tagNameLower = getValueByKey(tag, "name").toLowerCase();
+ if (!tagNameLower.isEmpty() && distinctInputTags.add(tagNameLower)) {
+ inputTags.add(tag);
+ }
+ }
+ return inputTags;
+ }
+
+ private String getValueByKey(final String input, final String key) {
+ final Pattern keyValuePattern = Pattern.compile("(" + Pattern.quote(key) + ")\\s*=\\s*\"(.*?)\"");
+ final Matcher keyValueMatcher = keyValuePattern.matcher(input);
+ if (keyValueMatcher.find()) {
+ return escapeHtmlEntity(keyValueMatcher.group(2));
+ }
+ return "";
+ }
+
+ private String escapeHtmlEntity(final String html) {
+ final StringBuilder sb = new StringBuilder(html.length());
+ int i = 0;
+ final int length = html.length();
+ while (i < length) {
+ final char c = html.charAt(i);
+ if (c != '&') {
+ sb.append(c);
+ i++;
+ continue;
+ }
+
+ if (html.startsWith("&", i)) {
+ sb.append('&');
+ i += 5;
+ } else if (html.startsWith("'", i)) {
+ sb.append('\'');
+ i += 6;
+ } else if (html.startsWith(""", i)) {
+ sb.append('"');
+ i += 6;
+ } else if (html.startsWith("<", i)) {
+ sb.append('<');
+ i += 4;
+ } else if (html.startsWith(">", i)) {
+ sb.append('>');
+ i += 4;
+ } else {
+ sb.append(c);
+ ++i;
+ }
+ }
+ return sb.toString();
+ }
+
+ private List getParametersFromHtmlBody(final String body, final @NonNull Properties props) {
+ final List parameters = new ArrayList<>();
+ for (final String inputTag : getInputTagsFromHTML(body)) {
+ final String name = getValueByKey(inputTag, "name");
+ final String value = getValueByKey(inputTag, "value");
+ final String nameLower = name.toLowerCase();
+
+ if (nameLower.contains("username")) {
+ parameters.add(new BasicNameValuePair(name, FederatedAuthPlugin.IDP_USERNAME.getString(props)));
+ } else if (nameLower.contains("authmethod")) {
+ if (!value.isEmpty()) {
+ parameters.add(new BasicNameValuePair(name, value));
+ }
+ } else if (nameLower.contains("password")) {
+ parameters
+ .add(new BasicNameValuePair(name, FederatedAuthPlugin.IDP_PASSWORD.getString(props)));
+ } else if (!name.isEmpty()) {
+ parameters.add(new BasicNameValuePair(name, value));
+ }
+ }
+ return parameters;
+ }
+
+ private String getFormActionFromHtmlBody(final String body) {
+ final Matcher m = FORM_ACTION_PATTERN.matcher(body);
+ if (m.find()) {
+ return escapeHtmlEntity(m.group(1));
+ }
+ return null;
+ }
+
+ private void validateUrl(final String paramString) throws IOException {
+
+ final URI authorizeRequestUrl = URI.create(paramString);
+ final String errorMessage = Messages.get("AdfsCredentialsProviderFactory.invalidHttpsUrl",
+ new Object[] {paramString});
+
+ if (!authorizeRequestUrl.toURL().getProtocol().equalsIgnoreCase("https")) {
+ throw new IOException(errorMessage);
+ }
+
+ final Matcher matcher = FederatedAuthPlugin.HTTPS_URL_PATTERN.matcher(paramString);
+ if (!matcher.find()) {
+ throw new IOException(errorMessage);
+ }
+ }
+}
diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/CredentialsProviderFactory.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/CredentialsProviderFactory.java
new file mode 100644
index 000000000..a43396bf9
--- /dev/null
+++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/CredentialsProviderFactory.java
@@ -0,0 +1,29 @@
+/*
+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License").
+ * You may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package software.amazon.jdbc.plugin.federatedauth;
+
+import java.io.Closeable;
+import java.sql.SQLException;
+import java.util.Properties;
+import org.checkerframework.checker.nullness.qual.NonNull;
+import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
+import software.amazon.awssdk.regions.Region;
+
+public interface CredentialsProviderFactory {
+ AwsCredentialsProvider getAwsCredentialsProvider(String host, Region region, final @NonNull Properties props) throws
+ SQLException;
+}
diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPlugin.java
new file mode 100644
index 000000000..7def73fec
--- /dev/null
+++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPlugin.java
@@ -0,0 +1,319 @@
+/*
+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License").
+ * You may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package software.amazon.jdbc.plugin.federatedauth;
+
+import java.sql.Connection;
+import java.sql.SQLException;
+import java.time.Instant;
+import java.time.temporal.ChronoUnit;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Optional;
+import java.util.Properties;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.logging.Logger;
+import java.util.regex.Pattern;
+import org.checkerframework.checker.nullness.qual.NonNull;
+import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
+import software.amazon.awssdk.regions.Region;
+import software.amazon.awssdk.services.rds.RdsUtilities;
+import software.amazon.jdbc.AwsWrapperProperty;
+import software.amazon.jdbc.HostSpec;
+import software.amazon.jdbc.JdbcCallable;
+import software.amazon.jdbc.PluginService;
+import software.amazon.jdbc.PropertyDefinition;
+import software.amazon.jdbc.plugin.AbstractConnectionPlugin;
+import software.amazon.jdbc.plugin.TokenInfo;
+import software.amazon.jdbc.util.IamAuthUtils;
+import software.amazon.jdbc.util.Messages;
+import software.amazon.jdbc.util.RdsUtils;
+import software.amazon.jdbc.util.StringUtils;
+import software.amazon.jdbc.util.telemetry.TelemetryContext;
+import software.amazon.jdbc.util.telemetry.TelemetryCounter;
+import software.amazon.jdbc.util.telemetry.TelemetryFactory;
+import software.amazon.jdbc.util.telemetry.TelemetryGauge;
+import software.amazon.jdbc.util.telemetry.TelemetryTraceLevel;
+
+public class FederatedAuthPlugin extends AbstractConnectionPlugin {
+
+ static final ConcurrentHashMap tokenCache = new ConcurrentHashMap<>();
+ private final CredentialsProviderFactory credentialsProviderFactory;
+ private static final int DEFAULT_TOKEN_EXPIRATION_SEC = 15 * 60 - 30;
+ private static final int DEFAULT_HTTP_TIMEOUT_MILLIS = 60000;
+ public static final AwsWrapperProperty IDP_ENDPOINT = new AwsWrapperProperty("idpEndpoint", null,
+ "The hosting URL of the Identity Provider");
+ public static final AwsWrapperProperty IDP_PORT =
+ new AwsWrapperProperty("idpPort", "443", "The hosting port of Identity Provider");
+ public static final AwsWrapperProperty RELAYING_PARTY_ID =
+ new AwsWrapperProperty("rpIdentifier", "urn:amazon:webservices", "The relaying party identifier");
+ public static final AwsWrapperProperty IAM_ROLE_ARN =
+ new AwsWrapperProperty("iamRoleArn", null, "The ARN of the IAM Role that is to be assumed.");
+ public static final AwsWrapperProperty IAM_IDP_ARN =
+ new AwsWrapperProperty("iamIdpArn", null, "The ARN of the Identity Provider");
+ public static final AwsWrapperProperty IAM_REGION = new AwsWrapperProperty("iamRegion", null,
+ "Overrides AWS region that is used to generate the IAM token");
+ public static final AwsWrapperProperty IAM_TOKEN_EXPIRATION = new AwsWrapperProperty("iamTokenExpiration",
+ String.valueOf(DEFAULT_TOKEN_EXPIRATION_SEC), "IAM token cache expiration in seconds");
+ public static final AwsWrapperProperty IDP_USERNAME =
+ new AwsWrapperProperty("idpUsername", null, "The federated user name");
+ public static final AwsWrapperProperty IDP_PASSWORD = new AwsWrapperProperty("idpPassword", null,
+ "The federated user password");
+ public static final AwsWrapperProperty IAM_HOST = new AwsWrapperProperty(
+ "iamHost", null,
+ "Overrides the host that is used to generate the IAM token");
+ public static final AwsWrapperProperty IAM_DEFAULT_PORT = new AwsWrapperProperty("iamDefaultPort", "-1",
+ "Overrides default port that is used to generate the IAM token");
+ public static final AwsWrapperProperty HTTP_CLIENT_SOCKET_TIMEOUT = new AwsWrapperProperty(
+ "httpClientSocketTimeout", String.valueOf(DEFAULT_HTTP_TIMEOUT_MILLIS),
+ "The socket timeout value in milliseconds for the HttpClient used by the FederatedAuthPlugin");
+ public static final AwsWrapperProperty HTTP_CLIENT_CONNECT_TIMEOUT = new AwsWrapperProperty(
+ "httpClientConnectTimeout", String.valueOf(DEFAULT_HTTP_TIMEOUT_MILLIS),
+ "The connect timeout value in milliseconds for the HttpClient used by the FederatedAuthPlugin");
+ public static final AwsWrapperProperty SSL_INSECURE = new AwsWrapperProperty("sslInsecure", "true",
+ "Whether or not the SSL session is to be secure and the sever's certificates will be verified");
+ public static AwsWrapperProperty
+ IDP_NAME = new AwsWrapperProperty("idpName", null, "The name of the Identity Provider implementation used");
+ public static final AwsWrapperProperty DB_USER =
+ new AwsWrapperProperty("dbUser", null, "The database user used to access the database");
+ protected static final Pattern SAML_RESPONSE_PATTERN = Pattern.compile("SAMLResponse\\W+value=\"(?[^\"]+)\"");
+ protected static final String SAML_RESPONSE_PATTERN_GROUP = "saml";
+ protected static final Pattern HTTPS_URL_PATTERN =
+ Pattern.compile("^(https)://[-a-zA-Z0-9+&@#/%?=~_!:,.']*[-a-zA-Z0-9+&@#/%=~_']");
+
+ private static final String TELEMETRY_FETCH_TOKEN = "fetch IAM token";
+ private static final Logger LOGGER = Logger.getLogger(FederatedAuthPlugin.class.getName());
+
+ protected final PluginService pluginService;
+
+ protected final RdsUtils rdsUtils = new RdsUtils();
+
+ private static final Set subscribedMethods =
+ Collections.unmodifiableSet(new HashSet() {
+ {
+ add("connect");
+ add("forceConnect");
+ }
+ });
+
+ static {
+ PropertyDefinition.registerPluginProperties(FederatedAuthPlugin.class);
+ }
+
+ private final TelemetryFactory telemetryFactory;
+ private final TelemetryGauge cacheSizeGauge;
+ private final TelemetryCounter fetchTokenCounter;
+
+ @Override
+ public Set getSubscribedMethods() {
+ return subscribedMethods;
+ }
+
+ public FederatedAuthPlugin(final PluginService pluginService,
+ final CredentialsProviderFactory credentialsProviderFactory) {
+ try {
+ Class.forName("software.amazon.awssdk.services.sts.model.AssumeRoleWithSamlRequest");
+ } catch (final ClassNotFoundException e) {
+ throw new RuntimeException(Messages.get("FederatedAuthPlugin.javaStsSdkNotInClasspath"));
+ }
+ this.pluginService = pluginService;
+ this.credentialsProviderFactory = credentialsProviderFactory;
+ this.telemetryFactory = pluginService.getTelemetryFactory();
+ this.cacheSizeGauge = telemetryFactory.createGauge("federatedAuth.tokenCache.size", () -> (long) tokenCache.size());
+ this.fetchTokenCounter = telemetryFactory.createCounter("federatedAuth.fetchToken.count");
+ }
+
+
+ @Override
+ public Connection connect(
+ final String driverProtocol,
+ final HostSpec hostSpec,
+ final Properties props,
+ final boolean isInitialConnection,
+ final JdbcCallable connectFunc)
+ throws SQLException {
+ return connectInternal(hostSpec, props, connectFunc);
+ }
+
+ @Override
+ public Connection forceConnect(
+ final @NonNull String driverProtocol,
+ final @NonNull HostSpec hostSpec,
+ final @NonNull Properties props,
+ final boolean isInitialConnection,
+ final @NonNull JdbcCallable forceConnectFunc)
+ throws SQLException {
+ return connectInternal(hostSpec, props, forceConnectFunc);
+ }
+
+ private Connection connectInternal(final HostSpec hostSpec, final Properties props,
+ final JdbcCallable connectFunc) throws SQLException {
+
+ checkIdpCredentialsWithFallback(props);
+
+ final String host = IamAuthUtils.getIamHost(IAM_HOST.getString(props), hostSpec);
+
+ final int port = IamAuthUtils.getIamPort(
+ IAM_DEFAULT_PORT.getInteger(props),
+ hostSpec,
+ this.pluginService.getDialect().getDefaultPort());
+
+ final Region region = getRegion(host, props);
+
+ final String cacheKey = getCacheKey(
+ DB_USER.getString(props),
+ host,
+ port,
+ region);
+
+ final TokenInfo tokenInfo = tokenCache.get(cacheKey);
+
+ final boolean isCachedToken = tokenInfo != null && !tokenInfo.isExpired();
+
+ if (isCachedToken) {
+ LOGGER.finest(
+ () -> Messages.get(
+ "FederatedAuthPlugin.useCachedIamToken",
+ new Object[] {tokenInfo.getToken()}));
+ PropertyDefinition.PASSWORD.set(props, tokenInfo.getToken());
+ } else {
+ updateAuthenticationToken(hostSpec, props, region, cacheKey);
+ }
+
+ PropertyDefinition.USER.set(props, DB_USER.getString(props));
+
+ try {
+ return connectFunc.call();
+ } catch (final SQLException exception) {
+ updateAuthenticationToken(hostSpec, props, region, cacheKey);
+ return connectFunc.call();
+ } catch (final Exception exception) {
+ LOGGER.warning(
+ () -> Messages.get(
+ "FederatedAuthPlugin.unhandledException",
+ new Object[] {exception}));
+ throw new SQLException(exception);
+ }
+ }
+
+ private void checkIdpCredentialsWithFallback(final Properties props) {
+ if (IDP_USERNAME.getString(props) == null) {
+ IDP_USERNAME.set(props, PropertyDefinition.USER.getString(props));
+ }
+
+ if (IDP_PASSWORD.getString(props) == null) {
+ IDP_PASSWORD.set(props, PropertyDefinition.PASSWORD.getString(props));
+ }
+ }
+
+ private void updateAuthenticationToken(final HostSpec hostSpec, final Properties props, final Region region,
+ final String cacheKey)
+ throws SQLException {
+ final int tokenExpirationSec = IAM_TOKEN_EXPIRATION.getInteger(props);
+ final Instant tokenExpiry = Instant.now().plus(tokenExpirationSec, ChronoUnit.SECONDS);
+ final int port = IamAuthUtils.getIamPort(
+ StringUtils.isNullOrEmpty(IAM_DEFAULT_PORT.getString(props)) ? 0 : IAM_DEFAULT_PORT.getInteger(props),
+ hostSpec,
+ this.pluginService.getDialect().getDefaultPort());
+ final AwsCredentialsProvider credentialsProvider =
+ this.credentialsProviderFactory.getAwsCredentialsProvider(hostSpec.getHost(), region, props);
+ final String token = generateAuthenticationToken(
+ props,
+ hostSpec.getHost(),
+ port,
+ region,
+ credentialsProvider);
+ LOGGER.finest(
+ () -> Messages.get(
+ "FederatedAuthPlugin.generatedNewIamToken",
+ new Object[] {token}));
+ PropertyDefinition.PASSWORD.set(props, token);
+ tokenCache.put(
+ cacheKey,
+ new TokenInfo(token, tokenExpiry));
+ }
+
+ private Region getRegion(final String hostname, final Properties props) throws SQLException {
+ final String iamRegion = IAM_REGION.getString(props);
+ if (!StringUtils.isNullOrEmpty(iamRegion)) {
+ return Region.of(iamRegion);
+ }
+
+ // Fallback to using host
+ // Get Region
+ final String rdsRegion = rdsUtils.getRdsRegion(hostname);
+
+ if (StringUtils.isNullOrEmpty(rdsRegion)) {
+ // Does not match Amazon's Hostname, throw exception
+ final String exceptionMessage = Messages.get(
+ "FederatedAuthPlugin.unsupportedHostname",
+ new Object[] {hostname});
+
+ LOGGER.fine(exceptionMessage);
+ throw new SQLException(exceptionMessage);
+ }
+
+ // Check Region
+ final Optional regionOptional = Region.regions().stream()
+ .filter(r -> r.id().equalsIgnoreCase(rdsRegion))
+ .findFirst();
+
+ if (!regionOptional.isPresent()) {
+ final String exceptionMessage = Messages.get(
+ "AwsSdk.unsupportedRegion",
+ new Object[] {rdsRegion});
+
+ LOGGER.fine(exceptionMessage);
+ throw new SQLException(exceptionMessage);
+ }
+
+ return regionOptional.get();
+ }
+
+ String generateAuthenticationToken(final Properties props, final String hostname,
+ final int port, final Region region, final AwsCredentialsProvider awsCredentialsProvider) {
+ final TelemetryFactory telemetryFactory = this.pluginService.getTelemetryFactory();
+ final TelemetryContext telemetryContext = telemetryFactory.openTelemetryContext(
+ TELEMETRY_FETCH_TOKEN, TelemetryTraceLevel.NESTED);
+ this.fetchTokenCounter.inc();
+ try {
+ final String user = DB_USER.getString(props);
+ final RdsUtilities utilities =
+ RdsUtilities.builder().credentialsProvider(awsCredentialsProvider).region(region).build();
+ return utilities.generateAuthenticationToken((builder) -> builder.hostname(hostname).port(port).username(user));
+ } catch (final Exception e) {
+ telemetryContext.setSuccess(false);
+ telemetryContext.setException(e);
+ throw e;
+ } finally {
+ telemetryContext.closeContext();
+ }
+ }
+
+ private String getCacheKey(
+ final String user,
+ final String hostname,
+ final int port,
+ final Region region) {
+
+ return String.format("%s:%s:%d:%s", region, hostname, port, user);
+ }
+
+ public static void clearCache() {
+ tokenCache.clear();
+ }
+}
diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPluginFactory.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPluginFactory.java
new file mode 100644
index 000000000..4236b46f3
--- /dev/null
+++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPluginFactory.java
@@ -0,0 +1,55 @@
+/*
+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License").
+ * You may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package software.amazon.jdbc.plugin.federatedauth;
+
+import java.security.GeneralSecurityException;
+import java.util.Properties;
+import software.amazon.jdbc.ConnectionPlugin;
+import software.amazon.jdbc.ConnectionPluginFactory;
+import software.amazon.jdbc.PluginService;
+import software.amazon.jdbc.util.Messages;
+import software.amazon.jdbc.util.StringUtils;
+
+public class FederatedAuthPluginFactory implements ConnectionPluginFactory {
+
+ @Override
+ public ConnectionPlugin getInstance(final PluginService pluginService, final Properties props) {
+ return new FederatedAuthPlugin(pluginService, getCredentialsProviderFactory(pluginService, props));
+ }
+
+ private CredentialsProviderFactory getCredentialsProviderFactory(final PluginService pluginService,
+ final Properties props) {
+ final String idpName = FederatedAuthPlugin.IDP_NAME.getString(props);
+ if (StringUtils.isNullOrEmpty(idpName) || AdfsCredentialsProviderFactory.IDP_NAME.equalsIgnoreCase(idpName)) {
+ return new AdfsCredentialsProviderFactory(
+ pluginService,
+ () -> {
+ try {
+ return new HttpClientFactory().getCloseableHttpClient(
+ FederatedAuthPlugin.HTTP_CLIENT_SOCKET_TIMEOUT.getInteger(props),
+ FederatedAuthPlugin.HTTP_CLIENT_CONNECT_TIMEOUT.getInteger(props),
+ FederatedAuthPlugin.SSL_INSECURE.getBoolean(props));
+ } catch (GeneralSecurityException e) {
+ throw new RuntimeException(
+ Messages.get("FederatedAuthPluginFactory.failedToInitializeHttpClient"), e);
+ }
+ });
+ }
+ throw new IllegalArgumentException(Messages.get("FederatedAuthPluginFactory.unsupportedIdp",
+ new Object[] {idpName}));
+ }
+}
diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/HttpClientFactory.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/HttpClientFactory.java
new file mode 100644
index 000000000..db44ff92b
--- /dev/null
+++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/HttpClientFactory.java
@@ -0,0 +1,71 @@
+/*
+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License").
+ * You may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package software.amazon.jdbc.plugin.federatedauth;
+
+import java.security.GeneralSecurityException;
+import javax.net.ssl.SSLContext;
+import javax.net.ssl.SSLSocketFactory;
+import javax.net.ssl.TrustManager;
+import org.apache.http.client.config.CookieSpecs;
+import org.apache.http.client.config.RequestConfig;
+import org.apache.http.conn.ssl.NoopHostnameVerifier;
+import org.apache.http.conn.ssl.SSLConnectionSocketFactory;
+import org.apache.http.impl.client.CloseableHttpClient;
+import org.apache.http.impl.client.DefaultHttpRequestRetryHandler;
+import org.apache.http.impl.client.HttpClientBuilder;
+import org.apache.http.impl.client.HttpClients;
+import org.apache.http.impl.client.LaxRedirectStrategy;
+
+/**
+ * Provides a HttpClient so that requests to HTTP API can be made. This is used by the
+ * {@link software.amazon.jdbc.plugin.federatedauth.AdfsCredentialsProviderFactory} to make HTTP calls to ADFS HTTP
+ * endpoints that are not available via SDK.
+ */
+public class HttpClientFactory {
+ private static final int MAX_REQUEST_RETRIES = 3;
+
+ public CloseableHttpClient getCloseableHttpClient(final int socketTimeoutMs, final int connectionTimeoutMs,
+ final boolean keySslInsecure) throws GeneralSecurityException {
+ final RequestConfig rc = RequestConfig.custom()
+ .setSocketTimeout(socketTimeoutMs)
+ .setConnectTimeout(connectionTimeoutMs)
+ .setExpectContinueEnabled(false)
+ .setCookieSpec(CookieSpecs.STANDARD)
+ .build();
+
+ final HttpClientBuilder builder = HttpClients.custom()
+ .setDefaultRequestConfig(rc)
+ .setRedirectStrategy(new LaxRedirectStrategy())
+ .setRetryHandler(new DefaultHttpRequestRetryHandler(MAX_REQUEST_RETRIES, true))
+ .useSystemProperties(); // this is needed for proxy setting using system properties.
+
+ if (keySslInsecure) {
+ final SSLContext ctx = SSLContext.getInstance("TLSv1.2");
+ final TrustManager[] tma = new TrustManager[] {new NonValidatingSSLSocketFactory.NonValidatingTrustManager()};
+ ctx.init(null, tma, null);
+ final SSLSocketFactory factory = ctx.getSocketFactory();
+
+ final SSLConnectionSocketFactory sf = new SSLConnectionSocketFactory(
+ factory,
+ new NoopHostnameVerifier());
+
+ builder.setSSLSocketFactory(sf);
+ }
+
+ return builder.build();
+ }
+}
diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/NonValidatingSSLSocketFactory.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/NonValidatingSSLSocketFactory.java
new file mode 100644
index 000000000..100d351b3
--- /dev/null
+++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/NonValidatingSSLSocketFactory.java
@@ -0,0 +1,98 @@
+/*
+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License").
+ * You may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package software.amazon.jdbc.plugin.federatedauth;
+
+import java.io.IOException;
+import java.net.InetAddress;
+import java.net.Socket;
+import java.security.GeneralSecurityException;
+import java.security.cert.X509Certificate;
+import javax.net.ssl.SSLContext;
+import javax.net.ssl.SSLSocketFactory;
+import javax.net.ssl.TrustManager;
+import javax.net.ssl.X509TrustManager;
+
+/**
+ * Provide a SSLSocketFactory that allows SSL connections to be made without validating the server's
+ * certificate. This is more convenient for some applications, but is less secure as it allows "man
+ * in the middle" attacks.
+ */
+public class NonValidatingSSLSocketFactory extends SSLSocketFactory {
+
+ /**
+ * We provide a constructor that takes an unused argument solely because the ssl calling code will
+ * look for this constructor first and then fall back to the no argument constructor, so we avoid
+ * an exception and additional reflection lookups.
+ *
+ * @param arg input argument
+ * @throws GeneralSecurityException if something goes wrong
+ */
+ public NonValidatingSSLSocketFactory(final String arg) throws GeneralSecurityException {
+ final SSLContext ctx = SSLContext.getInstance("TLS"); // or "SSL" ?
+
+ ctx.init(null, new TrustManager[]{new NonValidatingTrustManager()}, null);
+
+ factory = ctx.getSocketFactory();
+ }
+
+ protected SSLSocketFactory factory;
+
+ public Socket createSocket(final InetAddress host, final int port) throws IOException {
+ return factory.createSocket(host, port);
+ }
+
+ public Socket createSocket(final String host, final int port) throws IOException {
+ return factory.createSocket(host, port);
+ }
+
+ public Socket createSocket(final String host, final int port, final InetAddress localHost, final int localPort)
+ throws IOException {
+ return factory.createSocket(host, port, localHost, localPort);
+ }
+
+ public Socket createSocket(final InetAddress address, final int port, final InetAddress localAddress,
+ final int localPort)
+ throws IOException {
+ return factory.createSocket(address, port, localAddress, localPort);
+ }
+
+ public Socket createSocket(final Socket socket, final String host, final int port, final boolean autoClose)
+ throws IOException {
+ return factory.createSocket(socket, host, port, autoClose);
+ }
+
+ public String[] getDefaultCipherSuites() {
+ return factory.getDefaultCipherSuites();
+ }
+
+ public String[] getSupportedCipherSuites() {
+ return factory.getSupportedCipherSuites();
+ }
+
+ public static class NonValidatingTrustManager implements X509TrustManager {
+
+ public X509Certificate[] getAcceptedIssuers() {
+ return new X509Certificate[0];
+ }
+
+ public void checkClientTrusted(final X509Certificate[] certs, final String authType) {
+ }
+
+ public void checkServerTrusted(final X509Certificate[] certs, final String authType) {
+ }
+ }
+}
diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/SamlCredentialsProviderFactory.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/SamlCredentialsProviderFactory.java
new file mode 100644
index 000000000..a2f1081cf
--- /dev/null
+++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/SamlCredentialsProviderFactory.java
@@ -0,0 +1,60 @@
+/*
+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License").
+ * You may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package software.amazon.jdbc.plugin.federatedauth;
+
+import static software.amazon.jdbc.plugin.federatedauth.FederatedAuthPlugin.IAM_IDP_ARN;
+import static software.amazon.jdbc.plugin.federatedauth.FederatedAuthPlugin.IAM_ROLE_ARN;
+
+import java.sql.SQLException;
+import java.util.Properties;
+import org.checkerframework.checker.nullness.qual.NonNull;
+import software.amazon.awssdk.auth.credentials.AnonymousCredentialsProvider;
+import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
+import software.amazon.awssdk.regions.Region;
+import software.amazon.awssdk.services.sts.StsClient;
+import software.amazon.awssdk.services.sts.auth.StsAssumeRoleWithSamlCredentialsProvider;
+import software.amazon.awssdk.services.sts.model.AssumeRoleWithSamlRequest;
+
+public abstract class SamlCredentialsProviderFactory implements CredentialsProviderFactory {
+
+ @Override
+ public AwsCredentialsProvider getAwsCredentialsProvider(final String host, final Region region,
+ final @NonNull Properties props)
+ throws SQLException {
+
+ final String samlAssertion = getSamlAssertion(props);
+
+ final AssumeRoleWithSamlRequest assumeRoleWithSamlRequest = AssumeRoleWithSamlRequest.builder()
+ .samlAssertion(samlAssertion)
+ .roleArn(IAM_ROLE_ARN.getString(props))
+ .principalArn(IAM_IDP_ARN.getString(props))
+ .build();
+
+ final StsClient stsClient = StsClient.builder()
+ .credentialsProvider(AnonymousCredentialsProvider.create())
+ .region(region)
+ .build();
+
+ return StsAssumeRoleWithSamlCredentialsProvider.builder()
+ .refreshRequest(assumeRoleWithSamlRequest)
+ .asyncCredentialUpdateEnabled(true)
+ .stsClient(stsClient)
+ .build();
+ }
+
+ abstract String getSamlAssertion(final @NonNull Properties props) throws SQLException;
+}
diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/IamAuthUtils.java b/wrapper/src/main/java/software/amazon/jdbc/util/IamAuthUtils.java
new file mode 100644
index 000000000..af6b6af06
--- /dev/null
+++ b/wrapper/src/main/java/software/amazon/jdbc/util/IamAuthUtils.java
@@ -0,0 +1,38 @@
+/*
+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License").
+ * You may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package software.amazon.jdbc.util;
+
+import software.amazon.jdbc.HostSpec;
+
+public class IamAuthUtils {
+ public static String getIamHost(final String iamHost, final HostSpec hostSpec) {
+ if (!StringUtils.isNullOrEmpty(iamHost)) {
+ return iamHost;
+ }
+ return hostSpec.getHost();
+ }
+
+ public static int getIamPort(final int iamDefaultPort, final HostSpec hostSpec, final int dialectDefaultPort) {
+ if (iamDefaultPort > 0) {
+ return iamDefaultPort;
+ } else if (hostSpec.isPortSpecified()) {
+ return hostSpec.getPort();
+ } else {
+ return dialectDefaultPort;
+ }
+ }
+}
diff --git a/wrapper/src/main/resources/aws_advanced_jdbc_wrapper_messages.properties b/wrapper/src/main/resources/aws_advanced_jdbc_wrapper_messages.properties
index b79c7d242..f703921c1 100644
--- a/wrapper/src/main/resources/aws_advanced_jdbc_wrapper_messages.properties
+++ b/wrapper/src/main/resources/aws_advanced_jdbc_wrapper_messages.properties
@@ -14,6 +14,15 @@
# limitations under the License.
#
+# ADFS Credentials Provider Getter
+AdfsCredentialsProviderFactory.failedLogin=Failed login. Could not obtain SAML Assertion from ADFS SignOn Page POST response: \n''{0}''
+AdfsCredentialsProviderFactory.getSamlAssertionFailed=Failed to get SAML Assertion due to exception: ''{0}''
+AdfsCredentialsProviderFactory.invalidHttpsUrl=Invalid HTTPS URL: ''{0}''
+AdfsCredentialsProviderFactory.signOnPagePostActionUrl=ADFS SignOn Action URL: ''{0}''
+AdfsCredentialsProviderFactory.signOnPagePostActionRequestFailed=ADFS SignOn Page POST action failed with HTTP status ''{0}'', reason phrase ''{1}'', and response ''{2}''
+AdfsCredentialsProviderFactory.signOnPageRequestFailed=ADFS SignOn Page Request Failed with HTTP status ''{0}'', reason phrase ''{1}'', and response ''{2}''
+AdfsCredentialsProviderFactory.signOnPageUrl=ADFS SignOn URL: ''{0}''
+
# Aurora Host List Connection Plugin
AuroraHostListConnectionPlugin.providerAlreadySet=Another dynamic host list provider has already been set: {0}.
@@ -148,6 +157,16 @@ Failover.failedToUpdateCurrentHostspecAvailability=Failed to update current host
Failover.noOperationsAfterConnectionClosed=No operations allowed after connection closed.
Failover.invalidHostListProvider=Incorrect type of host list provider found, please ensure the correct host list provider is specified. The host list provider in use is: ''{0}'', the plugin is expected a cluster-aware host list provider such as the AuroraHostListProvider.
+# Federated Authentication Connection Plugin
+FederatedAuthPlugin.generatedNewIamToken=Generated new IAM token = ''{0}''
+FederatedAuthPlugin.javaStsSdkNotInClasspath=Required dependency 'AWS Java SDK for AWS Secret Token Service' is not on the classpath.
+FederatedAuthPlugin.unhandledException=Unhandled exception: ''{0}''
+FederatedAuthPlugin.unsupportedHostname=Unsupported AWS hostname {0}. Amazon domain name in format *.AWS-Region.rds.amazonaws.com or *.rds.AWS-Region.amazonaws.com.cn is expected.
+FederatedAuthPlugin.useCachedIamToken=Use cached IAM token = ''{0}''
+
+# Federated Authentication Connection Plugin Factory
+FederatedAuthPluginFactory.failedToInitializeHttpClient=Failed to initialize HttpClient.
+FederatedAuthPluginFactory.unsupportedIdp=Unsupported Identity Provider ''{0}''. Please visit to the documentation for supported Identity Providers.
# HikariPooledConnectionProvider
HikariPooledConnectionProvider.errorConnectingWithDataSource=Unable to connect to ''{0}'' using the Hikari data source.
@@ -173,7 +192,6 @@ HostSelector.roundRobinInvalidDefaultWeight=The provided default weight value is
IamAuthConnectionPlugin.unsupportedHostname=Unsupported AWS hostname {0}. Amazon domain name in format *.AWS-Region.rds.amazonaws.com or *.rds.AWS-Region.amazonaws.com.cn is expected.
IamAuthConnectionPlugin.useCachedIamToken=Use cached IAM token = ''{0}''
IamAuthConnectionPlugin.generatedNewIamToken=Generated new IAM token = ''{0}''
-IamAuthConnectionPlugin.invalidPort=Port number: {0} is not valid. Port number should be greater than zero. Falling back to default port.
IamAuthConnectionPlugin.unhandledException=Unhandled exception: ''{0}''
IamAuthConnectionPlugin.connectException=Error occurred while opening a connection: ''{0}''
diff --git a/wrapper/src/test/build.gradle.kts b/wrapper/src/test/build.gradle.kts
index 938e05852..29b27b157 100644
--- a/wrapper/src/test/build.gradle.kts
+++ b/wrapper/src/test/build.gradle.kts
@@ -41,8 +41,9 @@ dependencies {
testImplementation("com.zaxxer:HikariCP:4.+") // version 4.+ is compatible with Java 8
testImplementation("org.springframework.boot:spring-boot-starter-jdbc:2.7.13") // 2.7.13 is the last version compatible with Java 8
testImplementation("org.mockito:mockito-inline:4.11.0") // 4.11.0 is the last version compatible with Java 8
- testImplementation("software.amazon.awssdk:rds:2.20.49")
testImplementation("software.amazon.awssdk:ec2:2.20.49")
+ testImplementation("software.amazon.awssdk:rds:2.20.49")
+ testImplementation("software.amazon.awssdk:sts:2.20.49")
testImplementation("org.testcontainers:testcontainers:1.17.+")
testImplementation("org.testcontainers:mysql:1.17.+")
testImplementation("org.testcontainers:postgresql:1.17.+")
diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/IamAuthConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/IamAuthConnectionPluginTest.java
index 2323859ec..63741b7b5 100644
--- a/wrapper/src/test/java/software/amazon/jdbc/plugin/IamAuthConnectionPluginTest.java
+++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/IamAuthConnectionPluginTest.java
@@ -102,7 +102,7 @@ public static void registerDrivers() throws SQLException {
@Test
public void testPostgresConnectValidTokenInCache() throws SQLException {
IamAuthConnectionPlugin.tokenCache.put(PG_CACHE_KEY,
- new IamAuthConnectionPlugin.TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000)));
+ new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000)));
when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_PG_PORT);
@@ -114,7 +114,7 @@ public void testMySqlConnectValidTokenInCache() throws SQLException {
props.setProperty(PropertyDefinition.USER.name, "mysqlUser");
props.setProperty(PropertyDefinition.PASSWORD.name, "mysqlPassword");
IamAuthConnectionPlugin.tokenCache.put(MYSQL_CACHE_KEY,
- new IamAuthConnectionPlugin.TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000)));
+ new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000)));
when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_MYSQL_PORT);
@@ -130,7 +130,7 @@ public void testPostgresConnectWithInvalidPortFallbacksToHostPort() throws SQLEx
final String cacheKeyWithNewPort = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:"
+ PG_HOST_SPEC_WITH_PORT.getPort() + ":postgresqlUser";
IamAuthConnectionPlugin.tokenCache.put(cacheKeyWithNewPort,
- new IamAuthConnectionPlugin.TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000)));
+ new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000)));
testTokenSetInProps(PG_DRIVER_PROTOCOL, PG_HOST_SPEC_WITH_PORT);
}
@@ -145,7 +145,7 @@ public void testPostgresConnectWithInvalidPortAndNoHostPortFallbacksToHostPort()
final String cacheKeyWithNewPort = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:"
+ DEFAULT_PG_PORT + ":postgresqlUser";
IamAuthConnectionPlugin.tokenCache.put(cacheKeyWithNewPort,
- new IamAuthConnectionPlugin.TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000)));
+ new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000)));
testTokenSetInProps(PG_DRIVER_PROTOCOL, PG_HOST_SPEC);
}
@@ -153,7 +153,7 @@ public void testPostgresConnectWithInvalidPortAndNoHostPortFallbacksToHostPort()
@Test
public void testConnectExpiredTokenInCache() throws SQLException {
IamAuthConnectionPlugin.tokenCache.put(PG_CACHE_KEY,
- new IamAuthConnectionPlugin.TokenInfo(TEST_TOKEN, Instant.now().minusMillis(300000)));
+ new TokenInfo(TEST_TOKEN, Instant.now().minusMillis(300000)));
when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_PG_PORT);
@@ -171,7 +171,7 @@ public void testConnectEmptyCache() throws SQLException {
public void testConnectWithSpecifiedPort() throws SQLException {
final String cacheKeyWithNewPort = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:1234:" + "postgresqlUser";
IamAuthConnectionPlugin.tokenCache.put(cacheKeyWithNewPort,
- new IamAuthConnectionPlugin.TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000)));
+ new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000)));
testTokenSetInProps(PG_DRIVER_PROTOCOL, PG_HOST_SPEC_WITH_PORT);
}
@@ -183,7 +183,7 @@ public void testConnectWithSpecifiedIamDefaultPort() throws SQLException {
final String cacheKeyWithNewPort = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:"
+ iamDefaultPort + ":postgresqlUser";
IamAuthConnectionPlugin.tokenCache.put(cacheKeyWithNewPort,
- new IamAuthConnectionPlugin.TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000)));
+ new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000)));
testTokenSetInProps(PG_DRIVER_PROTOCOL, PG_HOST_SPEC_WITH_PORT);
}
@@ -194,7 +194,7 @@ public void testConnectWithSpecifiedRegion() throws SQLException {
"us-west-1:pg.testdb.us-west-1.rds.amazonaws.com:" + DEFAULT_PG_PORT + ":" + "postgresqlUser";
props.setProperty(IamAuthConnectionPlugin.IAM_REGION.name, "us-west-1");
IamAuthConnectionPlugin.tokenCache.put(cacheKeyWithNewRegion,
- new IamAuthConnectionPlugin.TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000)));
+ new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000)));
when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_PG_PORT);
diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/AdfsCredentialsProviderFactoryTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/AdfsCredentialsProviderFactoryTest.java
new file mode 100644
index 000000000..d3ea267ec
--- /dev/null
+++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/AdfsCredentialsProviderFactoryTest.java
@@ -0,0 +1,113 @@
+/*
+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License").
+ * You may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package software.amazon.jdbc.plugin.federatedauth;
+
+import static org.junit.Assert.assertEquals;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import java.io.ByteArrayInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.sql.SQLException;
+import java.util.Properties;
+import java.util.function.Supplier;
+import org.apache.http.HttpEntity;
+import org.apache.http.StatusLine;
+import org.apache.http.client.methods.CloseableHttpResponse;
+import org.apache.http.client.methods.HttpGet;
+import org.apache.http.client.methods.HttpPost;
+import org.apache.http.impl.client.CloseableHttpClient;
+import org.apache.http.util.EntityUtils;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+import org.testcontainers.shaded.org.apache.commons.io.IOUtils;
+import software.amazon.jdbc.PluginService;
+import software.amazon.jdbc.util.telemetry.TelemetryContext;
+import software.amazon.jdbc.util.telemetry.TelemetryFactory;
+
+class AdfsCredentialsProviderFactoryTest {
+
+ private static final String USERNAME = "someFederatedUsername@example.com";
+ private static final String PASSWORD = "somePassword";
+ @Mock private PluginService mockPluginService;
+ @Mock private TelemetryFactory mockTelemetryFactory;
+ @Mock private TelemetryContext mockTelemetryContext;
+ @Mock private Supplier mockHttpClientSupplier;
+ @Mock private CloseableHttpClient mockHttpClient;
+ @Mock private CloseableHttpResponse mockHttpGetSignInPageResponse;
+ @Mock private CloseableHttpResponse mockHttpPostSignInResponse;
+ @Mock private StatusLine mockStatusLine;
+ @Mock private HttpEntity mockSignInPageHttpEntity;
+ @Mock private HttpEntity mockSamlHttpEntity;
+ private AdfsCredentialsProviderFactory adfsCredentialsProviderFactory;
+ private Properties props;
+
+ @BeforeEach
+ public void init() throws IOException {
+ MockitoAnnotations.openMocks(this);
+
+ this.props = new Properties();
+ this.props.setProperty(FederatedAuthPlugin.IDP_ENDPOINT.name, "ec2amaz-ab3cdef.example.com");
+ this.props.setProperty(FederatedAuthPlugin.IDP_USERNAME.name, USERNAME);
+ this.props.setProperty(FederatedAuthPlugin.IDP_PASSWORD.name, PASSWORD);
+
+ when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory);
+ when(mockTelemetryFactory.openTelemetryContext(any(), any())).thenReturn(mockTelemetryContext);
+ when(mockHttpClientSupplier.get()).thenReturn(mockHttpClient);
+ when(mockHttpClient.execute(any(HttpGet.class))).thenReturn(mockHttpGetSignInPageResponse);
+ when(mockHttpGetSignInPageResponse.getStatusLine()).thenReturn(mockStatusLine);
+ when(mockStatusLine.getStatusCode()).thenReturn(200);
+ when(mockHttpGetSignInPageResponse.getEntity()).thenReturn(mockSignInPageHttpEntity);
+
+ String signinPageHtml = IOUtils.toString(
+ this.getClass().getClassLoader().getResourceAsStream("federated_auth/adfs-sign-in-page.html"), "UTF-8");
+ InputStream signInPageHtmlInputStream = new ByteArrayInputStream(signinPageHtml.getBytes());
+ when(mockSignInPageHttpEntity.getContent()).thenReturn(signInPageHtmlInputStream);
+
+ when(mockHttpClient.execute(any(HttpPost.class))).thenReturn(mockHttpPostSignInResponse);
+ when(mockHttpPostSignInResponse.getStatusLine()).thenReturn(mockStatusLine);
+ when(mockHttpPostSignInResponse.getEntity()).thenReturn(mockSamlHttpEntity);
+
+ String adfsSamlHtml = IOUtils.toString(
+ this.getClass().getClassLoader().getResourceAsStream("federated_auth/adfs-saml.html"), "UTF-8");
+ InputStream samlHtmlInputStream = new ByteArrayInputStream(adfsSamlHtml.getBytes());
+ when(mockSamlHttpEntity.getContent()).thenReturn(samlHtmlInputStream);
+
+ this.adfsCredentialsProviderFactory = new AdfsCredentialsProviderFactory(mockPluginService, mockHttpClientSupplier);
+ }
+
+ @Test
+ void test() throws IOException, SQLException {
+ this.adfsCredentialsProviderFactory.getSamlAssertion(props);
+
+ ArgumentCaptor httpPostArgumentCaptor = ArgumentCaptor.forClass(HttpPost.class);
+ verify(mockHttpClient, times(2)).execute(httpPostArgumentCaptor.capture());
+ HttpPost actualHttpPost = httpPostArgumentCaptor.getValue();
+ String content = EntityUtils.toString(actualHttpPost.getEntity());
+ String[] params = content.split("&");
+ assertEquals("UserName=" + USERNAME.replace("@", "%40"), params[0]);
+ assertEquals("Password=" + PASSWORD, params[1]);
+ assertEquals("Kmsi=true", params[2]);
+ assertEquals("AuthMethod=FormsAuthentication", params[3]);
+ }
+}
diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPluginTest.java
new file mode 100644
index 000000000..06a49d397
--- /dev/null
+++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPluginTest.java
@@ -0,0 +1,188 @@
+/*
+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License").
+ * You may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package software.amazon.jdbc.plugin.federatedauth;
+
+import static org.junit.Assert.assertEquals;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.when;
+
+import java.sql.Connection;
+import java.sql.SQLException;
+import java.time.Instant;
+import java.util.Properties;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.MockitoAnnotations;
+import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
+import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity;
+import software.amazon.awssdk.regions.Region;
+import software.amazon.jdbc.HostSpec;
+import software.amazon.jdbc.HostSpecBuilder;
+import software.amazon.jdbc.JdbcCallable;
+import software.amazon.jdbc.PluginService;
+import software.amazon.jdbc.PropertyDefinition;
+import software.amazon.jdbc.dialect.Dialect;
+import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy;
+import software.amazon.jdbc.plugin.TokenInfo;
+import software.amazon.jdbc.util.telemetry.TelemetryContext;
+import software.amazon.jdbc.util.telemetry.TelemetryCounter;
+import software.amazon.jdbc.util.telemetry.TelemetryFactory;
+
+class FederatedAuthPluginTest {
+
+ private static final int DEFAULT_PORT = 1234;
+ private static final String DRIVER_PROTOCOL = "jdbc:postgresql:";
+
+ private static final HostSpec HOST_SPEC = new HostSpecBuilder(new SimpleHostAvailabilityStrategy())
+ .host("pg.testdb.us-east-2.rds.amazonaws.com").build();
+ private static final String DB_USER = "iamUser";
+ private static final String TEST_TOKEN = "someTestToken";
+ private static final TokenInfo TEST_TOKEN_INFO = new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000));
+ @Mock private PluginService mockPluginService;
+ @Mock private Dialect mockDialect;
+ @Mock JdbcCallable mockLambda;
+ @Mock private TelemetryFactory mockTelemetryFactory;
+ @Mock private TelemetryContext mockTelemetryContext;
+ @Mock private TelemetryCounter mockTelemetryCounter;
+ @Mock private CredentialsProviderFactory mockCredentialsProviderFactory;
+ @Mock private AwsCredentialsProvider mockAwsCredentialsProvider;
+ @Mock private CompletableFuture completableFuture;
+ @Mock private AwsCredentialsIdentity mockAwsCredentialsIdentity;
+ private Properties props;
+
+ @BeforeEach
+ public void init() throws ExecutionException, InterruptedException, SQLException {
+ MockitoAnnotations.openMocks(this);
+ props = new Properties();
+ props.setProperty(PropertyDefinition.PLUGINS.name, "federatedAuth");
+ props.setProperty(FederatedAuthPlugin.DB_USER.name, DB_USER);
+ FederatedAuthPlugin.clearCache();
+
+ when(mockPluginService.getDialect()).thenReturn(mockDialect);
+ when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_PORT);
+ when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory);
+ when(mockTelemetryFactory.createCounter(any())).thenReturn(mockTelemetryCounter);
+ when(mockTelemetryFactory.openTelemetryContext(any(), any())).thenReturn(mockTelemetryContext);
+ when(mockCredentialsProviderFactory.getAwsCredentialsProvider(any(), any(), any()))
+ .thenReturn(mockAwsCredentialsProvider);
+ when(mockAwsCredentialsProvider.resolveIdentity()).thenReturn(completableFuture);
+ when(completableFuture.get()).thenReturn(mockAwsCredentialsIdentity);
+ }
+
+ @Test
+ void testCachedToken() throws SQLException {
+ FederatedAuthPlugin plugin =
+ new FederatedAuthPlugin(mockPluginService, mockCredentialsProviderFactory);
+
+ String key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + DEFAULT_PORT + ":iamUser";
+ FederatedAuthPlugin.tokenCache.put(key, TEST_TOKEN_INFO);
+
+ plugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda);
+
+ assertEquals(DB_USER, PropertyDefinition.USER.getString(props));
+ assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props));
+ }
+
+ @Test
+ void testExpiredCachedToken() throws SQLException {
+ FederatedAuthPlugin spyPlugin = Mockito.spy(
+ new FederatedAuthPlugin(mockPluginService, mockCredentialsProviderFactory));
+
+ String key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + DEFAULT_PORT + ":iamUser";
+ String someExpiredToken = "someExpiredToken";
+ TokenInfo expiredTokenInfo = new TokenInfo(
+ someExpiredToken, Instant.now().minusMillis(300000));
+ FederatedAuthPlugin.tokenCache.put(key, expiredTokenInfo);
+
+ when(
+ spyPlugin.generateAuthenticationToken(
+ props,
+ HOST_SPEC.getHost(),
+ DEFAULT_PORT,
+ Region.US_EAST_2, mockAwsCredentialsProvider))
+ .thenReturn(TEST_TOKEN);
+
+ spyPlugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda);
+ assertEquals(DB_USER, PropertyDefinition.USER.getString(props));
+ assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props));
+ }
+
+ @Test
+ void testNoCachedToken() throws SQLException {
+ FederatedAuthPlugin spyPlugin = Mockito.spy(
+ new FederatedAuthPlugin(mockPluginService, mockCredentialsProviderFactory));
+
+ when(
+ spyPlugin.generateAuthenticationToken(
+ props,
+ HOST_SPEC.getHost(),
+ DEFAULT_PORT,
+ Region.US_EAST_2, mockAwsCredentialsProvider))
+ .thenReturn(TEST_TOKEN);
+
+ spyPlugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda);
+ assertEquals(DB_USER, PropertyDefinition.USER.getString(props));
+ assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props));
+ }
+
+ @Test
+ void testSpecifiedIamHostPortRegion() throws SQLException {
+ final String expectedHost = "pg.testdb.us-west-2.rds.amazonaws.com";
+ final int expectedPort = 9876;
+ final Region expectedRegion = Region.US_WEST_2;
+
+ props.setProperty(FederatedAuthPlugin.IAM_HOST.name, expectedHost);
+ props.setProperty(FederatedAuthPlugin.IAM_DEFAULT_PORT.name, String.valueOf(expectedPort));
+ props.setProperty(FederatedAuthPlugin.IAM_REGION.name, expectedRegion.toString());
+
+ final String key = "us-west-2:pg.testdb.us-west-2.rds.amazonaws.com:" + String.valueOf(expectedPort) + ":iamUser";
+ FederatedAuthPlugin.tokenCache.put(key, TEST_TOKEN_INFO);
+
+ FederatedAuthPlugin plugin =
+ new FederatedAuthPlugin(mockPluginService, mockCredentialsProviderFactory);
+
+ plugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda);
+
+ assertEquals(DB_USER, PropertyDefinition.USER.getString(props));
+ assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props));
+ }
+
+ @Test
+ void testIdpCredentialsFallback() throws SQLException {
+ String expectedUser = "expectedUser";
+ String expectedPassword = "expectedPassword";
+ PropertyDefinition.USER.set(props, expectedUser);
+ PropertyDefinition.PASSWORD.set(props, expectedPassword);
+
+ FederatedAuthPlugin plugin =
+ new FederatedAuthPlugin(mockPluginService, mockCredentialsProviderFactory);
+
+ String key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + DEFAULT_PORT + ":iamUser";
+ FederatedAuthPlugin.tokenCache.put(key, TEST_TOKEN_INFO);
+
+ plugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda);
+
+ assertEquals(DB_USER, PropertyDefinition.USER.getString(props));
+ assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props));
+ assertEquals(expectedUser, FederatedAuthPlugin.IDP_USERNAME.getString(props));
+ assertEquals(expectedPassword, FederatedAuthPlugin.IDP_PASSWORD.getString(props));
+ }
+}
diff --git a/wrapper/src/test/resources/federated_auth/adfs-saml.html b/wrapper/src/test/resources/federated_auth/adfs-saml.html
new file mode 100644
index 000000000..6686e3db8
--- /dev/null
+++ b/wrapper/src/test/resources/federated_auth/adfs-saml.html
@@ -0,0 +1 @@
+Working...
diff --git a/wrapper/src/test/resources/federated_auth/adfs-sign-in-page.html b/wrapper/src/test/resources/federated_auth/adfs-sign-in-page.html
new file mode 100644
index 000000000..0b06f6dda
--- /dev/null
+++ b/wrapper/src/test/resources/federated_auth/adfs-sign-in-page.html
@@ -0,0 +1,613 @@
+
+
+
+
+
+
+
+
+
+
+
+ Sign In
+
+
+
+
+
+
+
+
+
+
+
+
JavaScript required
+
JavaScript is required. This web browser does not support JavaScript or JavaScript in this web browser is not enabled.
+
To find out if your web browser supports JavaScript or to enable JavaScript, see web browser help.