Skip to content

Commit

Permalink
feature: add federated auth plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
aaronchung-bitquill committed Nov 25, 2023
1 parent 4a0d15b commit 6bcfc68
Show file tree
Hide file tree
Showing 17 changed files with 1,870 additions and 4 deletions.
1 change: 1 addition & 0 deletions examples/AWSDriverExample/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ dependencies {
implementation("mysql:mysql-connector-java:8.0.33")
implementation("software.amazon.awssdk:rds:2.21.11")
implementation("software.amazon.awssdk:secretsmanager:2.21.21")
implementation("software.amazon.awssdk:sts:2.21.21")
implementation("com.fasterxml.jackson.core:jackson-databind:2.16.0")
implementation(project(":aws-advanced-jdbc-wrapper"))
implementation("io.opentelemetry:opentelemetry-api:1.32.0")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package software.amazon;

import software.amazon.jdbc.PropertyDefinition;
import software.amazon.jdbc.plugin.federatedauth.FederatedAuthPlugin;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.Properties;

public class FederatedAuthPluginExample {

private static final String CONNECTION_STRING = "jdbc:aws-wrapper:postgresql://db-identifier.XYZ.us-east-2.rds.amazonaws.com:5432/employees";

public static void main(String[] args) throws SQLException {
// Set the AWS Federated Authentication Connection Plugin parameters and the JDBC Wrapper parameters.
final Properties properties = new Properties();

// Enable the AWS Federated Authentication Connection Plugin.
properties.setProperty(PropertyDefinition.PLUGINS.name, "federatedAuth");
properties.setProperty(FederatedAuthPlugin.IDP_NAME.name, "adfs");
properties.setProperty(FederatedAuthPlugin.IDP_ENDPOINT.name, "ec2amaz-ab3cdef.example.com");
properties.setProperty(FederatedAuthPlugin.IAM_ROLE_ARN.name, "arn:aws:iam::123456789012:role/adfs_example_iam_role");
properties.setProperty(FederatedAuthPlugin.IAM_IDP_ARN.name, "arn:aws:iam::123456789012:saml-provider/adfs_example");
properties.setProperty(FederatedAuthPlugin.IAM_REGION.name, "us-east-2");
properties.setProperty(FederatedAuthPlugin.IDP_USERNAME.name, "[email protected]");
properties.setProperty(FederatedAuthPlugin.IDP_PASSWORD.name, "somePassword");
properties.setProperty(PropertyDefinition.USER.name, "someIamUser");

// Try and make a connection:
try (final Connection conn = DriverManager.getConnection(CONNECTION_STRING, properties);
final Statement statement = conn.createStatement();
final ResultSet rs = statement.executeQuery("SELECT 1")) {
System.out.println(Util.getResult(rs));
}
}
}
6 changes: 5 additions & 1 deletion wrapper/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ plugins {

dependencies {
implementation("org.checkerframework:checker-qual:3.40.0")
compileOnly("org.apache.httpcomponents:httpclient:4.5.14")
compileOnly("software.amazon.awssdk:rds:2.21.11")
compileOnly("software.amazon.awssdk:sts:2.21.11")
compileOnly("software.amazon.awssdk:iam:2.21.11")
compileOnly("com.zaxxer:HikariCP:4.0.3") // Version 4.+ is compatible with Java 8
compileOnly("software.amazon.awssdk:secretsmanager:2.21.21")
compileOnly("com.fasterxml.jackson.core:jackson-databind:2.16.0")
Expand Down Expand Up @@ -58,9 +61,10 @@ dependencies {
testImplementation("com.zaxxer:HikariCP:4.0.3") // Version 4.+ is compatible with Java 8
testImplementation("org.springframework.boot:spring-boot-starter-jdbc:2.7.13") // 2.7.13 is the last version compatible with Java 8
testImplementation("org.mockito:mockito-inline:4.11.0") // 4.11.0 is the last version compatible with Java 8
testImplementation("software.amazon.awssdk:rds:2.21.11")
testImplementation("software.amazon.awssdk:ec2:2.21.12")
testImplementation("software.amazon.awssdk:rds:2.21.11")
testImplementation("software.amazon.awssdk:secretsmanager:2.21.21")
testImplementation("software.amazon.awssdk:sts:2.21.11")
testImplementation("org.testcontainers:testcontainers:1.19.1")
testImplementation("org.testcontainers:mysql:1.19.1")
testImplementation("org.testcontainers:postgresql:1.19.2")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@
import software.amazon.jdbc.plugin.dev.DeveloperConnectionPluginFactory;
import software.amazon.jdbc.plugin.efm.HostMonitoringConnectionPluginFactory;
import software.amazon.jdbc.plugin.failover.FailoverConnectionPluginFactory;
import software.amazon.jdbc.plugin.federatedauth.FederatedAuthPluginFactory;
import software.amazon.jdbc.plugin.readwritesplitting.ReadWriteSplittingPluginFactory;
import software.amazon.jdbc.plugin.staledns.AuroraStaleDnsPluginFactory;
import software.amazon.jdbc.profile.ConfigurationProfile;
import software.amazon.jdbc.profile.DriverConfigurationProfiles;
import software.amazon.jdbc.util.Messages;
import software.amazon.jdbc.util.SqlState;
import software.amazon.jdbc.util.StringUtils;
Expand All @@ -65,6 +65,7 @@ public class ConnectionPluginChainBuilder {
put("failover", FailoverConnectionPluginFactory.class);
put("iam", IamAuthConnectionPluginFactory.class);
put("awsSecretsManager", AwsSecretsManagerConnectionPluginFactory.class);
put("federatedAuth", FederatedAuthPluginFactory.class);
put("auroraStaleDns", AuroraStaleDnsPluginFactory.class);
put("readWriteSplitting", ReadWriteSplittingPluginFactory.class);
put("auroraConnectionTracker", AuroraConnectionTrackerPluginFactory.class);
Expand Down Expand Up @@ -92,7 +93,8 @@ public class ConnectionPluginChainBuilder {
put(HostMonitoringConnectionPluginFactory.class, 800);
put(IamAuthConnectionPluginFactory.class, 900);
put(AwsSecretsManagerConnectionPluginFactory.class, 1000);
put(LogQueryConnectionPluginFactory.class, 1100);
put(FederatedAuthPluginFactory.class, 1100);
put(LogQueryConnectionPluginFactory.class, 1200);
put(ConnectTimeConnectionPluginFactory.class, WEIGHT_RELATIVE_TO_PRIOR_PLUGIN);
put(ExecutionTimeConnectionPluginFactory.class, WEIGHT_RELATIVE_TO_PRIOR_PLUGIN);
put(DeveloperConnectionPluginFactory.class, WEIGHT_RELATIVE_TO_PRIOR_PLUGIN);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package software.amazon.jdbc.plugin.federatedauth;

import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URI;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Properties;
import java.util.Set;
import java.util.logging.Logger;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.apache.http.NameValuePair;
import org.apache.http.client.entity.UrlEncodedFormEntity;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.message.BasicNameValuePair;
import org.apache.http.util.EntityUtils;
import org.checkerframework.checker.nullness.qual.NonNull;
import software.amazon.jdbc.PluginService;
import software.amazon.jdbc.util.Messages;
import software.amazon.jdbc.util.StringUtils;
import software.amazon.jdbc.util.telemetry.TelemetryContext;
import software.amazon.jdbc.util.telemetry.TelemetryFactory;
import software.amazon.jdbc.util.telemetry.TelemetryTraceLevel;

public class AdfsCredentialsProviderFactory extends SamlCredentialsProviderFactory {

public static final String IDP_NAME = "adfs";
private static final String TELEMETRY_FETCH_SAML = "Fetch ADFS SAML Assertion";
private static final Logger LOGGER = Logger.getLogger(AdfsCredentialsProviderFactory.class.getName());
private final PluginService pluginService;
private final TelemetryFactory telemetryFactory;
private final CloseableHttpClient httpClient;
private TelemetryContext telemetryContext;

public AdfsCredentialsProviderFactory(PluginService pluginService, CloseableHttpClient httpClient) {
this.pluginService = pluginService;
this.telemetryFactory = this.pluginService.getTelemetryFactory();
this.httpClient = httpClient;
}

@Override
String getSamlAssertion(final @NonNull Properties props) {
this.telemetryContext = telemetryFactory.openTelemetryContext(TELEMETRY_FETCH_SAML, TelemetryTraceLevel.NESTED);

String uri = "https://" + FederatedAuthPlugin.IDP_ENDPOINT.getString(props) + ':'
+ FederatedAuthPlugin.IDP_PORT.getString(props) + "/adfs/ls/IdpInitiatedSignOn.aspx?loginToRp="
+ FederatedAuthPlugin.RELAYING_PARTY_ID.getString(props);

try {
LOGGER.finest(Messages.get("AdfsCredentialsProviderFactory.signOnPageUrl", new Object[] {uri}));
validateURL(uri);
HttpGet get = new HttpGet(uri);
CloseableHttpResponse resp = httpClient.execute(get);

if (resp.getStatusLine().getStatusCode() != 200) {
throw new IOException(Messages.get("AdfsCredentialsProviderFactory.signOnPageRequestFailed",
new Object[] {
resp.getStatusLine().getStatusCode(),
resp.getStatusLine().getReasonPhrase(),
EntityUtils.toString(resp.getEntity())}));
}

String body = EntityUtils.toString(resp.getEntity());

String action = getFormAction(body);
if (!StringUtils.isNullOrEmpty(action) && action.startsWith("/")) {
uri = "https://" + FederatedAuthPlugin.IDP_ENDPOINT.getString(props) + ':' + FederatedAuthPlugin.IDP_PORT.getString(props) + action;
}

LOGGER.finest(Messages.get("AdfsCredentialsProviderFactory.signOnPagePostActionUrl", new Object[] {uri}));

validateURL(uri);
HttpPost post = new HttpPost(uri);
post.setEntity(new UrlEncodedFormEntity(getParameters(body, props)));
resp = httpClient.execute(post);
if (resp.getStatusLine().getStatusCode() != 200) {
throw new IOException(Messages.get("AdfsCredentialsProviderFactory.signOnPagePostActionRequestFailed",
new Object[] {
resp.getStatusLine().getStatusCode(),
resp.getStatusLine().getReasonPhrase(),
EntityUtils.toString(resp.getEntity())}));
}

String content = EntityUtils.toString(resp.getEntity());
Matcher matcher = FederatedAuthPlugin.SAML_RESPONSE_PATTERN.matcher(content);
if (!matcher.find()) {
throw new IOException(Messages.get("AdfsCredentialsProviderFactory.failedLogin", new Object[] {content}));
}

return matcher.group(1);
} catch (IOException e) {
LOGGER.severe(Messages.get("AdfsCredentialsProviderFactory.getSamlAssertionFailed", new Object[] {e}));
this.telemetryContext.setSuccess(false);
this.telemetryContext.setException(e);
throw new RuntimeException(e);
} finally {
this.telemetryContext.closeContext();
}
}

private List<String> getInputTagsFromHTML(String body) {
Set<String> distinctInputTags = new HashSet<>();
List<String> inputTags = new ArrayList<String>();
Pattern inputTagPattern = Pattern.compile("<input(.+?)/>", Pattern.DOTALL);
Matcher inputTagMatcher = inputTagPattern.matcher(body);
while (inputTagMatcher.find()) {
String tag = inputTagMatcher.group(0);
String tagNameLower = getValueByKey(tag, "name").toLowerCase();
if (!tagNameLower.isEmpty() && distinctInputTags.add(tagNameLower)) {
inputTags.add(tag);
}
}
return inputTags;
}

private String getValueByKey(String input, String key) {
Pattern keyValuePattern = Pattern.compile("(" + Pattern.quote(key) + ")\\s*=\\s*\"(.*?)\"");
Matcher keyValueMatcher = keyValuePattern.matcher(input);
if (keyValueMatcher.find()) {
return escapeHtmlEntity(keyValueMatcher.group(2));
}
return "";
}

private String escapeHtmlEntity(String html) {
StringBuilder sb = new StringBuilder(html.length());
int i = 0;
int length = html.length();
while (i < length) {
char c = html.charAt(i);
if (c != '&') {
sb.append(c);
i++;
continue;
}

if (html.startsWith("&amp;", i)) {
sb.append('&');
i += 5;
} else if (html.startsWith("&apos;", i)) {
sb.append('\'');
i += 6;
} else if (html.startsWith("&quot;", i)) {
sb.append('"');
i += 6;
} else if (html.startsWith("&lt;", i)) {
sb.append('<');
i += 4;
} else if (html.startsWith("&gt;", i)) {
sb.append('>');
i += 4;
} else {
sb.append(c);
++i;
}
}
return sb.toString();
}

private List<NameValuePair> getParameters(String body, final @NonNull Properties props) {
List<NameValuePair> parameters = new ArrayList<NameValuePair>();
for (String inputTag : getInputTagsFromHTML(body)) {
String name = getValueByKey(inputTag, "name");
String value = getValueByKey(inputTag, "value");
String nameLower = name.toLowerCase();

if (nameLower.contains("username")) {
parameters.add(new BasicNameValuePair(name, FederatedAuthPlugin.IDP_USERNAME.getString(props)));
} else if (nameLower.contains("authmethod")) {
if (!value.isEmpty()) {
parameters.add(new BasicNameValuePair(name, value));
}
} else if (nameLower.contains("password")) {
parameters
.add(new BasicNameValuePair(name, FederatedAuthPlugin.IDP_PASSWORD.getString(props)));
} else if (!name.isEmpty()) {
parameters.add(new BasicNameValuePair(name, value));
}
}
return parameters;
}

private String getFormAction(String body) {
Pattern pattern = Pattern.compile("<form.*?action=\"([^\"]+)\"");
Matcher m = pattern.matcher(body);
if (m.find()) {
return escapeHtmlEntity(m.group(1));
}
return null;
}

private void validateURL(String paramString) throws IOException {

URI authorizeRequestUrl = URI.create(paramString);
String errorMessage = Messages.get("AdfsCredentialsProviderFactory.invalidHttpsUrl", new Object[] {paramString});

try {
if (!authorizeRequestUrl.toURL().getProtocol().equalsIgnoreCase("https")) {
throw new IOException(errorMessage);
}

Matcher matcher = FederatedAuthPlugin.HTTPS_URL_PATTERN.matcher(paramString);
if (!matcher.find()) {
throw new IOException(errorMessage);
}
} catch (MalformedURLException e) {
throw new IOException(errorMessage, e);
}
}

public void close() {
try {
this.httpClient.close();
} catch (IOException e) {
// Ignore
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package software.amazon.jdbc.plugin.federatedauth;

import java.io.Closeable;
import java.util.Properties;
import org.checkerframework.checker.nullness.qual.NonNull;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.regions.Region;

public interface CredentialsProviderFactory extends Closeable {
AwsCredentialsProvider getAwsCredentialsProvider(String host, Region region, final @NonNull Properties props);
}
Loading

0 comments on commit 6bcfc68

Please sign in to comment.