Skip to content

Commit

Permalink
feature: add federated auth plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
aaronchung-bitquill committed Dec 14, 2023
1 parent a6b1e17 commit 02cdd53
Show file tree
Hide file tree
Showing 21 changed files with 1,976 additions and 60 deletions.
1 change: 1 addition & 0 deletions examples/AWSDriverExample/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ dependencies {
implementation("mysql:mysql-connector-java:8.0.33")
implementation("software.amazon.awssdk:rds:2.21.42")
implementation("software.amazon.awssdk:secretsmanager:2.21.43")
implementation("software.amazon.awssdk:sts:2.21.42")
implementation("com.fasterxml.jackson.core:jackson-databind:2.16.0")
implementation(project(":aws-advanced-jdbc-wrapper"))
implementation("io.opentelemetry:opentelemetry-api:1.32.0")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package software.amazon;

import software.amazon.jdbc.PropertyDefinition;
import software.amazon.jdbc.plugin.federatedauth.FederatedAuthPlugin;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.Properties;

public class FederatedAuthPluginExample {

private static final String CONNECTION_STRING = "jdbc:aws-wrapper:postgresql://db-identifier.XYZ.us-east-2.rds.amazonaws.com:5432/employees";

public static void main(String[] args) throws SQLException {
// Set the AWS Federated Authentication Connection Plugin parameters and the JDBC Wrapper parameters.
final Properties properties = new Properties();

// Enable the AWS Federated Authentication Connection Plugin.
properties.setProperty(PropertyDefinition.PLUGINS.name, "federatedAuth");
properties.setProperty(FederatedAuthPlugin.IDP_NAME.name, "adfs");
properties.setProperty(FederatedAuthPlugin.IDP_ENDPOINT.name, "ec2amaz-ab3cdef.example.com");
properties.setProperty(FederatedAuthPlugin.IAM_ROLE_ARN.name, "arn:aws:iam::123456789012:role/adfs_example_iam_role");
properties.setProperty(FederatedAuthPlugin.IAM_IDP_ARN.name, "arn:aws:iam::123456789012:saml-provider/adfs_example");
properties.setProperty(FederatedAuthPlugin.IAM_REGION.name, "us-east-2");
properties.setProperty(FederatedAuthPlugin.IDP_USERNAME.name, "[email protected]");
properties.setProperty(FederatedAuthPlugin.IDP_PASSWORD.name, "somePassword");
properties.setProperty(PropertyDefinition.USER.name, "someIamUser");

// Try and make a connection:
try (final Connection conn = DriverManager.getConnection(CONNECTION_STRING, properties);
final Statement statement = conn.createStatement();
final ResultSet rs = statement.executeQuery("SELECT 1")) {
System.out.println(Util.getResult(rs));
}
}
}
3 changes: 3 additions & 0 deletions wrapper/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ plugins {

dependencies {
implementation("org.checkerframework:checker-qual:3.40.0")
compileOnly("org.apache.httpcomponents:httpclient:4.5.14")
compileOnly("software.amazon.awssdk:rds:2.21.42")
compileOnly("software.amazon.awssdk:sts:2.21.42")
compileOnly("com.zaxxer:HikariCP:4.0.3") // Version 4.+ is compatible with Java 8
compileOnly("software.amazon.awssdk:secretsmanager:2.21.43")
compileOnly("com.fasterxml.jackson.core:jackson-databind:2.16.0")
Expand Down Expand Up @@ -61,6 +63,7 @@ dependencies {
testImplementation("software.amazon.awssdk:rds:2.21.42")
testImplementation("software.amazon.awssdk:ec2:2.21.12")
testImplementation("software.amazon.awssdk:secretsmanager:2.21.43")
testImplementation("software.amazon.awssdk:sts:2.21.42")
testImplementation("org.testcontainers:testcontainers:1.19.3")
testImplementation("org.testcontainers:mysql:1.19.3")
testImplementation("org.testcontainers:postgresql:1.19.3")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import software.amazon.jdbc.plugin.dev.DeveloperConnectionPluginFactory;
import software.amazon.jdbc.plugin.efm.HostMonitoringConnectionPluginFactory;
import software.amazon.jdbc.plugin.failover.FailoverConnectionPluginFactory;
import software.amazon.jdbc.plugin.federatedauth.FederatedAuthPluginFactory;
import software.amazon.jdbc.plugin.readwritesplitting.ReadWriteSplittingPluginFactory;
import software.amazon.jdbc.plugin.staledns.AuroraStaleDnsPluginFactory;
import software.amazon.jdbc.plugin.strategy.fastestresponse.FastestResponseStrategyPluginFactory;
Expand Down Expand Up @@ -67,6 +68,7 @@ public class ConnectionPluginChainBuilder {
put("failover", FailoverConnectionPluginFactory.class);
put("iam", IamAuthConnectionPluginFactory.class);
put("awsSecretsManager", AwsSecretsManagerConnectionPluginFactory.class);
put("federatedAuth", FederatedAuthPluginFactory.class);
put("auroraStaleDns", AuroraStaleDnsPluginFactory.class);
put("readWriteSplitting", ReadWriteSplittingPluginFactory.class);
put("auroraConnectionTracker", AuroraConnectionTrackerPluginFactory.class);
Expand Down Expand Up @@ -99,7 +101,8 @@ public class ConnectionPluginChainBuilder {
put(FastestResponseStrategyPluginFactory.class, 900);
put(IamAuthConnectionPluginFactory.class, 1000);
put(AwsSecretsManagerConnectionPluginFactory.class, 1100);
put(LogQueryConnectionPluginFactory.class, 1200);
put(FederatedAuthPluginFactory.class, 1200);
put(LogQueryConnectionPluginFactory.class, 1300);
put(ConnectTimeConnectionPluginFactory.class, WEIGHT_RELATIVE_TO_PRIOR_PLUGIN);
put(ExecutionTimeConnectionPluginFactory.class, WEIGHT_RELATIVE_TO_PRIOR_PLUGIN);
put(DeveloperConnectionPluginFactory.class, WEIGHT_RELATIVE_TO_PRIOR_PLUGIN);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import software.amazon.jdbc.PluginService;
import software.amazon.jdbc.PropertyDefinition;
import software.amazon.jdbc.authentication.AwsCredentialsManager;
import software.amazon.jdbc.util.IamAuthUtils;
import software.amazon.jdbc.util.Messages;
import software.amazon.jdbc.util.RdsUtils;
import software.amazon.jdbc.util.StringUtils;
Expand Down Expand Up @@ -64,7 +65,7 @@ public class IamAuthConnectionPlugin extends AbstractConnectionPlugin {
"Overrides the host that is used to generate the IAM token");

public static final AwsWrapperProperty IAM_DEFAULT_PORT = new AwsWrapperProperty(
"iamDefaultPort", null,
"iamDefaultPort", "-1",
"Overrides default port that is used to generate the IAM token");

public static final AwsWrapperProperty IAM_REGION = new AwsWrapperProperty(
Expand Down Expand Up @@ -115,12 +116,12 @@ private Connection connectInternal(String driverProtocol, HostSpec hostSpec, Pro
throw new SQLException(PropertyDefinition.USER.name + " is null or empty.");
}

String host = hostSpec.getHost();
if (!StringUtils.isNullOrEmpty(IAM_HOST.getString(props))) {
host = IAM_HOST.getString(props);
}
String host = IamAuthUtils.getIamHost(IAM_HOST.getString(props), hostSpec);

int port = getPort(props, hostSpec);
int port = IamAuthUtils.getIamPort(
IAM_DEFAULT_PORT.getInteger(props),
hostSpec,
this.pluginService.getDialect().getDefaultPort());

final String iamRegion = IAM_REGION.getString(props);
final Region region = StringUtils.isNullOrEmpty(iamRegion)
Expand Down Expand Up @@ -261,26 +262,6 @@ public static void clearCache() {
tokenCache.clear();
}

private int getPort(Properties props, HostSpec hostSpec) {
if (!StringUtils.isNullOrEmpty(IAM_DEFAULT_PORT.getString(props))) {
int defaultPort = IAM_DEFAULT_PORT.getInteger(props);
if (defaultPort > 0) {
return defaultPort;
} else {
LOGGER.finest(
() -> Messages.get(
"IamAuthConnectionPlugin.invalidPort",
new Object[] {defaultPort}));
}
}

if (hostSpec.isPortSpecified()) {
return hostSpec.getPort();
} else {
return this.pluginService.getDialect().getDefaultPort();
}
}

private Region getRdsRegion(final String hostname) throws SQLException {

// Get Region
Expand Down Expand Up @@ -312,27 +293,4 @@ private Region getRdsRegion(final String hostname) throws SQLException {

return regionOptional.get();
}

static class TokenInfo {

private final String token;
private final Instant expiration;

public TokenInfo(final String token, final Instant expiration) {
this.token = token;
this.expiration = expiration;
}

public String getToken() {
return this.token;
}

public Instant getExpiration() {
return this.expiration;
}

public boolean isExpired() {
return Instant.now().isAfter(this.expiration);
}
}
}
41 changes: 41 additions & 0 deletions wrapper/src/main/java/software/amazon/jdbc/plugin/TokenInfo.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package software.amazon.jdbc.plugin;

import java.time.Instant;

public class TokenInfo {
private final String token;
private final Instant expiration;

public TokenInfo(final String token, final Instant expiration) {
this.token = token;
this.expiration = expiration;
}

public String getToken() {
return this.token;
}

public Instant getExpiration() {
return this.expiration;
}

public boolean isExpired() {
return Instant.now().isAfter(this.expiration);
}
}
Loading

0 comments on commit 02cdd53

Please sign in to comment.