Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow user to set up a lambda to initialize new connections #705

Merged
merged 9 commits into from
Oct 27, 2023
22 changes: 22 additions & 0 deletions docs/using-the-jdbc-driver/UsingTheJdbcDriver.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,28 @@ DriverConfigurationProfiles.addOrReplaceProfile(
CustomConnectionPluginFactory.class));
```

### Executing Custom Code When Initializing a Connection
In some use cases you may need to define a specific configuration for a new driver connection before your application can use it. For instance:
- you might need to run some initial SQL queries when a connection is established, or;
- you might need to check for some additional conditions to determine the initialization configuration required for a particular connection.

The AWS JDBC Driver allows specifying a special function that can initialize a connection. It can be done with `ConnectionProviderManager.setConnectionInitFunc` method. The `resetConnectionInitFunc` method is also available to remove the function.

The initialization function is called for all connections, including connections opened by the internal connection pools (see [Using Read Write Splitting Plugin and Internal Connection Pooling](./using-plugins/UsingTheReadWriteSplittingPlugin.md#internal-connection-pooling)). This helps user applications clean up connection sessions that have been altered by previous operations, as returning a connection to a pool will reset the state and retrieving it will call the initialization function again.

> [!WARNING]\
> Executing CPU and network intensive code in the initialization function may significantly impact the AWS JDBC Driver's overall performance.

```java
ConnectionProviderManager.setConnectionInitFunc((connection, protocol, hostSpec, props) -> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: let's also add a case without a lambda

// Set custom schema for connections to a test-database
if ("test-database".equals(props.getProperty("database"))) {
connection.setSchema("test-database-schema");
}
});
```


### List of Available Plugins
The AWS JDBC Driver has several built-in plugins that are available to use. Please visit the individual plugin page for more details.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@

package software.amazon.jdbc;

import java.sql.Connection;
import java.sql.SQLException;
import java.util.List;
import java.util.Properties;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.checkerframework.checker.nullness.qual.Nullable;
import software.amazon.jdbc.cleanup.CanReleaseResources;

public class ConnectionProviderManager {
Expand All @@ -28,6 +31,8 @@ public class ConnectionProviderManager {
private static ConnectionProvider connProvider = null;
private final ConnectionProvider defaultProvider;

private static ConnectionInitFunc connectionInitFunc = null;

/**
* {@link ConnectionProviderManager} constructor.
*
Expand Down Expand Up @@ -193,4 +198,34 @@ public static void releaseResources() {
}
}
}

public static void setConnectionInitFunc(final @NonNull ConnectionInitFunc func) {
connectionInitFunc = func;
}

public static void resetConnectionInitFunc() {
connectionInitFunc = null;
}

public void initConnection(
final @Nullable Connection connection,
final @NonNull String protocol,
final @NonNull HostSpec hostSpec,
final @NonNull Properties props) throws SQLException {

final ConnectionInitFunc copy = connectionInitFunc;
if (copy == null) {
return;
}

copy.initConnection(connection, protocol, hostSpec, props);
}

public interface ConnectionInitFunc {
void initConnection(
final @Nullable Connection connection,
final @NonNull String protocol,
final @NonNull HostSpec hostSpec,
final @NonNull Properties props) throws SQLException;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,18 @@ public DefaultConnectionPlugin(
final PluginService pluginService,
final ConnectionProvider defaultConnProvider,
final PluginManagerService pluginManagerService) {
this(pluginService,
defaultConnProvider,
pluginManagerService,
new ConnectionProviderManager(defaultConnProvider));
}

public DefaultConnectionPlugin(
final PluginService pluginService,
final ConnectionProvider defaultConnProvider,
final PluginManagerService pluginManagerService,
final ConnectionProviderManager connProviderManager) {

if (pluginService == null) {
throw new IllegalArgumentException("pluginService");
}
Expand All @@ -79,7 +91,7 @@ public DefaultConnectionPlugin(

this.pluginService = pluginService;
this.pluginManagerService = pluginManagerService;
this.connProviderManager = new ConnectionProviderManager(defaultConnProvider);
this.connProviderManager = connProviderManager;
}

@Override
Expand Down Expand Up @@ -173,6 +185,8 @@ private Connection connectInternal(
telemetryContext.closeContext();
}

this.connProviderManager.initConnection(conn, driverProtocol, hostSpec, props);

this.pluginService.setAvailability(hostSpec.asAliases(), HostAvailability.AVAILABLE);
this.pluginService.updateDialect(conn);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
Expand All @@ -32,6 +33,7 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Properties;
import java.util.stream.Stream;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
Expand All @@ -42,6 +44,8 @@
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import software.amazon.jdbc.ConnectionProvider;
import software.amazon.jdbc.ConnectionProviderManager;
import software.amazon.jdbc.HostSpec;
import software.amazon.jdbc.JdbcCallable;
import software.amazon.jdbc.PluginManagerService;
import software.amazon.jdbc.PluginService;
Expand All @@ -59,12 +63,16 @@ class DefaultConnectionPluginTest {
@Mock ConnectionProvider connectionProvider;
@Mock PluginManagerService pluginManagerService;
@Mock JdbcCallable<Void, SQLException> mockSqlFunction;
@Mock JdbcCallable<Connection, SQLException> mockConnectFunction;
@Mock Connection conn;
@Mock Connection oldConn;
@Mock private TelemetryFactory mockTelemetryFactory;
@Mock TelemetryContext mockTelemetryContext;
@Mock TelemetryCounter mockTelemetryCounter;
@Mock TelemetryGauge mockTelemetryGauge;
@Mock ConnectionProviderManager mockConnectionProviderManager;
@Mock ConnectionProvider mockConnectionProvider;
@Mock HostSpec mockHostSpec;


private AutoCloseable closeable;
Expand All @@ -79,8 +87,11 @@ void setUp() {
when(mockTelemetryFactory.createCounter(anyString())).thenReturn(mockTelemetryCounter);
// noinspection unchecked
when(mockTelemetryFactory.createGauge(anyString(), any(GaugeCallable.class))).thenReturn(mockTelemetryGauge);
when(mockConnectionProviderManager.getConnectionProvider(anyString(), any(), any()))
.thenReturn(mockConnectionProvider);

plugin = new DefaultConnectionPlugin(pluginService, connectionProvider, pluginManagerService);
plugin = new DefaultConnectionPlugin(
pluginService, connectionProvider, pluginManagerService, mockConnectionProviderManager);
}

@AfterEach
Expand Down Expand Up @@ -109,6 +120,13 @@ void testExecute_closeOldConnection() throws SQLException {
verify(pluginManagerService, never()).setInTransaction(anyBoolean());
}

@Test
void testConnect() throws SQLException {
plugin.connect("anyProtocol", mockHostSpec, new Properties(), true, mockConnectFunction);
verify(mockConnectionProvider, atLeastOnce()).connect(anyString(), any(), any(), any());
verify(mockConnectionProviderManager, atLeastOnce()).initConnection(any(), anyString(), any(), any());
}

private static Stream<Arguments> multiStatementQueries() {
return Stream.of(
Arguments.of("", new ArrayList<String>()),
Expand Down
Loading