diff --git a/wrapper/src/main/java/software/amazon/jdbc/PluginService.java b/wrapper/src/main/java/software/amazon/jdbc/PluginService.java index a75c6ca2a..7f6a22d00 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/PluginService.java +++ b/wrapper/src/main/java/software/amazon/jdbc/PluginService.java @@ -27,6 +27,7 @@ import software.amazon.jdbc.dialect.Dialect; import software.amazon.jdbc.exceptions.ExceptionHandler; import software.amazon.jdbc.hostavailability.HostAvailability; +import software.amazon.jdbc.states.SessionDirtyFlag; import software.amazon.jdbc.util.telemetry.TelemetryFactory; /** @@ -47,6 +48,18 @@ EnumSet setCurrentConnection( @Nullable ConnectionPlugin skipNotificationForThisPlugin) throws SQLException; + EnumSet getCurrentConnectionState(); + + void setCurrentConnectionState(SessionDirtyFlag flag); + + void resetCurrentConnectionState(SessionDirtyFlag flag); + + void resetCurrentConnectionStates(); + + boolean getAutoCommit(); + + void setAutoCommit(final boolean autoCommit); + List getHosts(); HostSpec getInitialConnectionHostSpec(); diff --git a/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java index 299360aee..2cd701359 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java @@ -43,6 +43,7 @@ import software.amazon.jdbc.hostavailability.HostAvailability; import software.amazon.jdbc.hostavailability.HostAvailabilityStrategyFactory; import software.amazon.jdbc.hostlistprovider.StaticHostListProvider; +import software.amazon.jdbc.states.SessionDirtyFlag; import software.amazon.jdbc.util.CacheMap; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -68,6 +69,8 @@ public class PluginServiceImpl implements PluginService, CanReleaseResources, private final ExceptionManager exceptionManager; protected final DialectProvider dialectProvider; protected Dialect dialect; + protected EnumSet currentConnectionSessionState = EnumSet.noneOf(SessionDirtyFlag.class); + protected boolean isAutoCommit = false; public PluginServiceImpl( @NonNull final ConnectionPluginManager pluginManager, @@ -568,4 +571,27 @@ public String getTargetName() { return this.pluginManager.getDefaultConnProvider().getTargetName(); } + public EnumSet getCurrentConnectionState() { + return this.currentConnectionSessionState.clone(); + } + + public void setCurrentConnectionState(SessionDirtyFlag flag) { + this.currentConnectionSessionState.add(flag); + } + + public void resetCurrentConnectionState(SessionDirtyFlag flag) { + this.currentConnectionSessionState.remove(flag); + } + + public void resetCurrentConnectionStates() { + this.currentConnectionSessionState.clear(); + } + + public boolean getAutoCommit() { + return this.isAutoCommit; + } + + public void setAutoCommit(final boolean autoCommit) { + this.isAutoCommit = autoCommit; + } } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java index f7577bfe7..890589021 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java @@ -42,6 +42,10 @@ import software.amazon.jdbc.hostavailability.HostAvailability; import software.amazon.jdbc.plugin.AbstractConnectionPlugin; import software.amazon.jdbc.plugin.staledns.AuroraStaleDnsHelper; +import software.amazon.jdbc.states.RestoreSessionStateCallable; +import software.amazon.jdbc.states.SessionDirtyFlag; +import software.amazon.jdbc.states.SessionStateHelper; +import software.amazon.jdbc.states.SessionStateTransferCallable; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.RdsUrlType; import software.amazon.jdbc.util.RdsUtils; @@ -85,6 +89,10 @@ public class FailoverConnectionPlugin extends AbstractConnectionPlugin { static final String METHOD_ABORT = "Connection.abort"; static final String METHOD_CLOSE = "Connection.close"; static final String METHOD_IS_CLOSED = "Connection.isClosed"; + + protected static SessionStateTransferCallable sessionStateTransferCallable; + protected static RestoreSessionStateCallable restoreSessionStateCallable; + private final PluginService pluginService; protected final Properties properties; protected boolean enableFailoverSetting; @@ -199,6 +207,22 @@ public FailoverConnectionPlugin(final PluginService pluginService, final Propert this.failoverReaderFailedCounter = telemetryFactory.createCounter("readerFailover.completed.failed.count"); } + public static void setSessionStateTransferFunc(SessionStateTransferCallable callable) { + sessionStateTransferCallable = callable; + } + + public static void resetSessionStateTransferFunc() { + sessionStateTransferCallable = null; + } + + public static void setRestoreSessionStateFunc(RestoreSessionStateCallable callable) { + restoreSessionStateCallable = callable; + } + + public static void resetRestoreSessionStateFunc() { + restoreSessionStateCallable = null; + } + @Override public Set getSubscribedMethods() { return subscribedMethods; @@ -521,9 +545,10 @@ private boolean shouldAttemptReaderConnection() { */ private void switchCurrentConnectionTo(final HostSpec host, final Connection connection) throws SQLException { Connection currentConnection = this.pluginService.getCurrentConnection(); + HostSpec currentHostSpec = this.pluginService.getCurrentHostSpec(); if (currentConnection != connection) { - transferSessionState(currentConnection, connection); + transferSessionState(currentConnection, currentHostSpec, connection, host); invalidateCurrentConnection(); } @@ -535,46 +560,73 @@ private void switchCurrentConnectionTo(final HostSpec host, final Connection con } /** - * Transfers basic session state from one connection to another. + * Transfers session state from one connection to another. * - * @param from The connection to transfer state from - * @param to The connection to transfer state to + * @param src The connection to transfer state from + * @param srcHostSpec The connection {@link HostSpec} to transfer state from + * @param dest The connection to transfer state to + * @param destHostSpec The connection {@link HostSpec} to transfer state to * @throws SQLException if a database access error occurs, this method is called on a closed connection, this * method is called during a distributed transaction, or this method is called during a * transaction */ protected void transferSessionState( - final Connection from, - final Connection to) throws SQLException { + final Connection src, + final HostSpec srcHostSpec, + final Connection dest, + final HostSpec destHostSpec) throws SQLException { - if (from == null || to == null) { + if (src == null || dest == null) { return; } - to.setReadOnly(from.isReadOnly()); - to.setAutoCommit(from.getAutoCommit()); - to.setTransactionIsolation(from.getTransactionIsolation()); + EnumSet sessionState = this.pluginService.getCurrentConnectionState(); + + SessionStateTransferCallable callableCopy = sessionStateTransferCallable; + if (callableCopy != null) { + final boolean isHandled = callableCopy.transferSessionState(sessionState, src, srcHostSpec, dest, destHostSpec); + if (isHandled) { + // Custom function has handled session transfer + return; + } + } + + // Otherwise, lets run default logic. + sessionState = this.pluginService.getCurrentConnectionState(); + final SessionStateHelper helper = new SessionStateHelper(); + helper.transferSessionState(sessionState, src, dest); } /** * Restores partial session state from saved values to a connection. * - * @param to The connection to transfer state to + * @param dest The connection to transfer state to * @throws SQLException if a database access error occurs, this method is called on a closed connection, this * method is called during a distributed transaction, or this method is called during a * transaction */ - protected void restoreSessionState(final Connection to) throws SQLException { - if (to == null) { + protected void restoreSessionState(final Connection dest) throws SQLException { + if (dest == null) { return; } - if (savedReadOnlyStatus != null) { - to.setReadOnly(savedReadOnlyStatus); - } - if (savedAutoCommitStatus != null) { - to.setAutoCommit(savedAutoCommitStatus); + final RestoreSessionStateCallable callableCopy = restoreSessionStateCallable; + if (callableCopy != null) { + final boolean isHandled = callableCopy.restoreSessionState( + this.pluginService.getCurrentConnectionState(), + dest, + this.savedReadOnlyStatus, + this.savedAutoCommitStatus + ); + if (isHandled) { + // Custom function has handled everything. + return; + } } + + // Otherwise, lets run default logic. + final SessionStateHelper helper = new SessionStateHelper(); + helper.restoreSessionState(dest, this.savedReadOnlyStatus, this.savedAutoCommitStatus); } private void dealWithOriginalException( diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPlugin.java index d79dfd776..4ea8d6345 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPlugin.java @@ -40,6 +40,9 @@ import software.amazon.jdbc.cleanup.CanReleaseResources; import software.amazon.jdbc.plugin.AbstractConnectionPlugin; import software.amazon.jdbc.plugin.failover.FailoverSQLException; +import software.amazon.jdbc.states.SessionDirtyFlag; +import software.amazon.jdbc.states.SessionStateHelper; +import software.amazon.jdbc.states.SessionStateTransferCallable; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.SqlState; import software.amazon.jdbc.util.WrapperUtils; @@ -61,6 +64,9 @@ public class ReadWriteSplittingPlugin extends AbstractConnectionPlugin static final String METHOD_SET_READ_ONLY = "Connection.setReadOnly"; static final String METHOD_CLEAR_WARNINGS = "Connection.clearWarnings"; + protected static SessionStateTransferCallable sessionStateTransferCallable; + + private final PluginService pluginService; private final Properties properties; private final String readerSelectorStrategy; @@ -105,6 +111,14 @@ public class ReadWriteSplittingPlugin extends AbstractConnectionPlugin this.readerConnection = readerConnection; } + public static void setSessionStateTransferFunc(SessionStateTransferCallable callable) { + sessionStateTransferCallable = callable; + } + + public static void resetSessionStateTransferFunc() { + sessionStateTransferCallable = null; + } + @Override public Set getSubscribedMethods() { return subscribedMethods; @@ -408,7 +422,7 @@ private void switchCurrentConnectionTo( return; } - transferSessionStateOnReadWriteSplit(newConnection); + transferSessionStateOnReadWriteSplit(newConnection, newConnectionHost); this.pluginService.setCurrentConnection(newConnection, newConnectionHost); LOGGER.finest(() -> Messages.get( "ReadWriteSplittingPlugin.settingCurrentConnection", @@ -421,19 +435,41 @@ private void switchCurrentConnectionTo( * status. This method is only called when setReadOnly is being called; the read-only status * will be updated when the setReadOnly call continues down the plugin chain * - * @param to The connection to transfer state to + * @param dest The destination connection to transfer state to + * @param destHostSpec The destination connection {@link HostSpec} * @throws SQLException if a database access error occurs, this method is called on a closed * connection, or this method is called during a distributed transaction */ protected void transferSessionStateOnReadWriteSplit( - final Connection to) throws SQLException { - final Connection from = this.pluginService.getCurrentConnection(); - if (from == null || to == null) { + final Connection dest, + final HostSpec destHostSpec) + throws SQLException { + + final Connection src = this.pluginService.getCurrentConnection(); + if (src == null || dest == null) { return; } - to.setAutoCommit(from.getAutoCommit()); - to.setTransactionIsolation(from.getTransactionIsolation()); + EnumSet sessionState = this.pluginService.getCurrentConnectionState(); + + SessionStateTransferCallable callableCopy = sessionStateTransferCallable; + if (callableCopy != null) { + final boolean isHandled = callableCopy.transferSessionState( + sessionState, + src, + this.pluginService.getCurrentHostSpec(), + dest, + destHostSpec); + if (isHandled) { + // Custom function has handled session transfer + return; + } + } + + sessionState = this.pluginService.getCurrentConnectionState(); + sessionState.remove(SessionDirtyFlag.READONLY); // We don't want to change READONLY flag of the connection + final SessionStateHelper helper = new SessionStateHelper(); + helper.transferSessionState(sessionState, src, dest); } private synchronized void switchToReaderConnection(final List hosts) diff --git a/wrapper/src/main/java/software/amazon/jdbc/states/RestoreSessionStateCallable.java b/wrapper/src/main/java/software/amazon/jdbc/states/RestoreSessionStateCallable.java new file mode 100644 index 000000000..b1e3a236e --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/states/RestoreSessionStateCallable.java @@ -0,0 +1,42 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.states; + +import java.sql.Connection; +import java.sql.SQLException; +import java.util.EnumSet; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.Nullable; + +public interface RestoreSessionStateCallable { + /** + * Restores partial session state from saved values to a connection. + * + * @param sessionState Session state flags for from-connection + * @param dest The destination connection to transfer state to + * @param readOnly ReadOnly flag to set to + * @param autoCommit AutoCommit flag to set to + * @return true, if session state is restored successful and no default logic should be executed after. + * False, if default logic should be executed. + */ + boolean restoreSessionState( + final @NonNull EnumSet sessionState, + final @NonNull Connection dest, + final @Nullable Boolean readOnly, + final @Nullable Boolean autoCommit) + throws SQLException; +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/states/SessionDirtyFlag.java b/wrapper/src/main/java/software/amazon/jdbc/states/SessionDirtyFlag.java new file mode 100644 index 000000000..985da0204 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/states/SessionDirtyFlag.java @@ -0,0 +1,33 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.states; + + +import java.util.EnumSet; + +public enum SessionDirtyFlag { + READONLY, + AUTO_COMMIT, + TRANSACTION_ISOLATION, + CATALOG, + NETWORK_TIMEOUT, + SCHEMA, + TYPE_MAP, + HOLDABILITY; + + public static final EnumSet ALL = EnumSet.allOf(SessionDirtyFlag.class); +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/states/SessionStateHelper.java b/wrapper/src/main/java/software/amazon/jdbc/states/SessionStateHelper.java new file mode 100644 index 000000000..d4b203593 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/states/SessionStateHelper.java @@ -0,0 +1,99 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.states; + +import java.sql.Connection; +import java.sql.SQLException; +import java.util.EnumSet; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +public class SessionStateHelper { + + /** + * Transfers session state from source connection to destination connection. + * + * @param sessionState Session state of source connection + * @param src The source connection to transfer state from + * @param dest The destination connection to transfer state to + * @throws SQLException if a database access error occurs, this method is called on a closed connection, this + * method is called during a distributed transaction, or this method is called during a + * transaction + */ + public void transferSessionState( + final EnumSet sessionState, + final Connection src, + final Connection dest) throws SQLException { + + if (src == null || dest == null) { + return; + } + + if (sessionState.contains(SessionDirtyFlag.READONLY)) { + dest.setReadOnly(src.isReadOnly()); + } + if (sessionState.contains(SessionDirtyFlag.AUTO_COMMIT)) { + dest.setAutoCommit(src.getAutoCommit()); + } + if (sessionState.contains(SessionDirtyFlag.TRANSACTION_ISOLATION)) { + dest.setTransactionIsolation(src.getTransactionIsolation()); + } + if (sessionState.contains(SessionDirtyFlag.CATALOG)) { + dest.setCatalog(src.getCatalog()); + } + if (sessionState.contains(SessionDirtyFlag.SCHEMA)) { + dest.setSchema(src.getSchema()); + } + if (sessionState.contains(SessionDirtyFlag.TYPE_MAP)) { + dest.setTypeMap(src.getTypeMap()); + } + if (sessionState.contains(SessionDirtyFlag.HOLDABILITY)) { + dest.setHoldability(src.getHoldability()); + } + if (sessionState.contains(SessionDirtyFlag.NETWORK_TIMEOUT)) { + final ExecutorService executorService = Executors.newSingleThreadExecutor(); + dest.setNetworkTimeout(executorService, src.getNetworkTimeout()); + executorService.shutdown(); + } + } + + /** + * Restores partial session state from saved values to a connection. + * + * @param dest The destination connection to transfer state to + * @param readOnly ReadOnly flag to set to + * @param autoCommit AutoCommit flag to set to + * @throws SQLException if a database access error occurs, this method is called on a closed connection, this + * method is called during a distributed transaction, or this method is called during a + * transaction + */ + public void restoreSessionState(final Connection dest, final Boolean readOnly, final Boolean autoCommit) + throws SQLException { + + if (dest == null) { + return; + } + + if (readOnly != null) { + dest.setReadOnly(readOnly); + } + if (autoCommit != null) { + dest.setAutoCommit(autoCommit); + } + } + +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/states/SessionStateTransferCallable.java b/wrapper/src/main/java/software/amazon/jdbc/states/SessionStateTransferCallable.java new file mode 100644 index 000000000..6872bb66c --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/states/SessionStateTransferCallable.java @@ -0,0 +1,45 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.states; + +import java.sql.Connection; +import java.sql.SQLException; +import java.util.EnumSet; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.Nullable; +import software.amazon.jdbc.HostSpec; + +public interface SessionStateTransferCallable { + + /** + * Transfers session state from one connection to another. + * + * @param sessionState Session state flags for from-connection + * @param src The source connection to transfer state from + * @param srcHostSpec The source connection {@link HostSpec} + * @param dest The destination connection to transfer state to + * @param destHostSpec The destination connection {@link HostSpec} + * @return true, if session state transfer is successful and no default logic should be executed after. + * False, if default logic should be executed. + */ + boolean transferSessionState( + final @NonNull EnumSet sessionState, + final @NonNull Connection src, + final @Nullable HostSpec srcHostSpec, + final @NonNull Connection dest, + final @Nullable HostSpec destHostSpec) throws SQLException; +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/wrapper/ConnectionWrapper.java b/wrapper/src/main/java/software/amazon/jdbc/wrapper/ConnectionWrapper.java index 16cb8a8f8..a48cdbc55 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/wrapper/ConnectionWrapper.java +++ b/wrapper/src/main/java/software/amazon/jdbc/wrapper/ConnectionWrapper.java @@ -47,16 +47,12 @@ import software.amazon.jdbc.PluginServiceImpl; import software.amazon.jdbc.PropertyDefinition; import software.amazon.jdbc.cleanup.CanReleaseResources; -import software.amazon.jdbc.dialect.Dialect; -import software.amazon.jdbc.dialect.DialectManager; -import software.amazon.jdbc.dialect.DialectProvider; import software.amazon.jdbc.dialect.HostListProviderSupplier; -import software.amazon.jdbc.hostlistprovider.ConnectionStringHostListProvider; +import software.amazon.jdbc.states.SessionDirtyFlag; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.SqlState; import software.amazon.jdbc.util.StringUtils; import software.amazon.jdbc.util.WrapperUtils; -import software.amazon.jdbc.util.telemetry.DefaultTelemetryFactory; import software.amazon.jdbc.util.telemetry.TelemetryFactory; public class ConnectionWrapper implements Connection, CanReleaseResources { @@ -184,6 +180,7 @@ public void abort(final Executor executor) throws SQLException { () -> { this.pluginService.getCurrentConnection().abort(executor); this.pluginManagerService.setInTransaction(false); + this.pluginService.resetCurrentConnectionStates(); }, executor); } @@ -209,6 +206,7 @@ public void close() throws SQLException { this.pluginService.getCurrentConnection().close(); this.openConnectionStacktrace = null; this.pluginManagerService.setInTransaction(false); + this.pluginService.resetCurrentConnectionStates(); }); this.releaseResources(); } @@ -222,7 +220,12 @@ public void commit() throws SQLException { "Connection.commit", () -> { this.pluginService.getCurrentConnection().commit(); + final boolean isInTransaction = this.pluginService.isInTransaction(); this.pluginManagerService.setInTransaction(false); + if (isInTransaction + && this.pluginService.getCurrentConnectionState().contains(SessionDirtyFlag.AUTO_COMMIT)) { + this.pluginService.resetCurrentConnectionState(SessionDirtyFlag.AUTO_COMMIT); + } }); } @@ -320,10 +323,9 @@ public Statement createStatement( this.pluginManager, this.pluginService.getCurrentConnection(), "Connection.createStatement", - () -> - this.pluginService - .getCurrentConnection() - .createStatement(resultSetType, resultSetConcurrency, resultSetHoldability), + () -> this.pluginService + .getCurrentConnection() + .createStatement(resultSetType, resultSetConcurrency, resultSetHoldability), resultSetType, resultSetConcurrency, resultSetHoldability); @@ -352,6 +354,7 @@ public void setReadOnly(final boolean readOnly) throws SQLException { () -> { this.pluginService.getCurrentConnection().setReadOnly(readOnly); this.pluginManagerService.setReadOnly(readOnly); + this.pluginService.setCurrentConnectionState(SessionDirtyFlag.READONLY); }, readOnly); } @@ -681,7 +684,12 @@ public void rollback() throws SQLException { "Connection.rollback", () -> { this.pluginService.getCurrentConnection().rollback(); + final boolean isInTransaction = this.pluginService.isInTransaction(); this.pluginManagerService.setInTransaction(false); + if (isInTransaction + && this.pluginService.getCurrentConnectionState().contains(SessionDirtyFlag.AUTO_COMMIT)) { + this.pluginService.resetCurrentConnectionState(SessionDirtyFlag.AUTO_COMMIT); + } }); } @@ -710,7 +718,14 @@ public void setAutoCommit(final boolean autoCommit) throws SQLException { this.pluginManager, this.pluginService.getCurrentConnection(), "Connection.setAutoCommit", - () -> this.pluginService.getCurrentConnection().setAutoCommit(autoCommit), + () -> { + final boolean currentAutoCommit = this.pluginService.getAutoCommit(); + this.pluginService.getCurrentConnection().setAutoCommit(autoCommit); + this.pluginService.setAutoCommit(autoCommit); + if (currentAutoCommit != autoCommit) { + this.pluginService.setCurrentConnectionState(SessionDirtyFlag.AUTO_COMMIT); + } + }, autoCommit); } @@ -732,7 +747,10 @@ public void setCatalog(final String catalog) throws SQLException { this.pluginManager, this.pluginService.getCurrentConnection(), "Connection.setCatalog", - () -> this.pluginService.getCurrentConnection().setCatalog(catalog), + () -> { + this.pluginService.getCurrentConnection().setCatalog(catalog); + this.pluginService.setCurrentConnectionState(SessionDirtyFlag.CATALOG); + }, catalog); } @@ -766,7 +784,10 @@ public void setHoldability(final int holdability) throws SQLException { this.pluginManager, this.pluginService.getCurrentConnection(), "Connection.setHoldability", - () -> this.pluginService.getCurrentConnection().setHoldability(holdability), + () -> { + this.pluginService.getCurrentConnection().setHoldability(holdability); + this.pluginService.setCurrentConnectionState(SessionDirtyFlag.HOLDABILITY); + }, holdability); } @@ -777,7 +798,10 @@ public void setNetworkTimeout(final Executor executor, final int milliseconds) t this.pluginManager, this.pluginService.getCurrentConnection(), "Connection.setNetworkTimeout", - () -> this.pluginService.getCurrentConnection().setNetworkTimeout(executor, milliseconds), + () -> { + this.pluginService.getCurrentConnection().setNetworkTimeout(executor, milliseconds); + this.pluginService.setCurrentConnectionState(SessionDirtyFlag.NETWORK_TIMEOUT); + }, executor, milliseconds); } @@ -812,7 +836,10 @@ public void setSchema(final String schema) throws SQLException { this.pluginManager, this.pluginService.getCurrentConnection(), "Connection.setSchema", - () -> this.pluginService.getCurrentConnection().setSchema(schema), + () -> { + this.pluginService.getCurrentConnection().setSchema(schema); + this.pluginService.setCurrentConnectionState(SessionDirtyFlag.SCHEMA); + }, schema); } @@ -823,7 +850,10 @@ public void setTransactionIsolation(final int level) throws SQLException { this.pluginManager, this.pluginService.getCurrentConnection(), "Connection.setTransactionIsolation", - () -> this.pluginService.getCurrentConnection().setTransactionIsolation(level), + () -> { + this.pluginService.getCurrentConnection().setTransactionIsolation(level); + this.pluginService.setCurrentConnectionState(SessionDirtyFlag.TRANSACTION_ISOLATION); + }, level); } @@ -834,7 +864,10 @@ public void setTypeMap(final Map> map) throws SQLException { this.pluginManager, this.pluginService.getCurrentConnection(), "Connection.setTypeMap", - () -> this.pluginService.getCurrentConnection().setTypeMap(map), + () -> { + this.pluginService.getCurrentConnection().setTypeMap(map); + this.pluginService.setCurrentConnectionState(SessionDirtyFlag.TYPE_MAP); + }, map); } diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/ConcurrencyTests.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/ConcurrencyTests.java index aae93ae7d..62007ac73 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/ConcurrencyTests.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/ConcurrencyTests.java @@ -69,6 +69,7 @@ import software.amazon.jdbc.dialect.UnknownDialect; import software.amazon.jdbc.hostavailability.HostAvailability; import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +import software.amazon.jdbc.states.SessionDirtyFlag; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @Disabled @@ -288,6 +289,36 @@ public EnumSet setCurrentConnection(@NonNull Connection conne return null; } + @Override + public EnumSet getCurrentConnectionState() { + return EnumSet.noneOf(SessionDirtyFlag.class); + } + + @Override + public void setCurrentConnectionState(SessionDirtyFlag flag) { + + } + + @Override + public void resetCurrentConnectionState(SessionDirtyFlag flag) { + + } + + @Override + public void resetCurrentConnectionStates() { + + } + + @Override + public boolean getAutoCommit() { + return false; + } + + @Override + public void setAutoCommit(boolean autoCommit) { + + } + @Override public List getHosts() { return null; diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPluginTest.java index d02e07306..551f21173 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPluginTest.java @@ -62,6 +62,7 @@ import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; import software.amazon.jdbc.hostlistprovider.AuroraHostListProvider; import software.amazon.jdbc.hostlistprovider.DynamicHostListProvider; +import software.amazon.jdbc.states.SessionDirtyFlag; import software.amazon.jdbc.util.RdsUrlType; import software.amazon.jdbc.util.SqlState; import software.amazon.jdbc.util.telemetry.GaugeCallable; @@ -111,6 +112,7 @@ void init() throws SQLException { when(mockPluginService.getCurrentHostSpec()).thenReturn(mockHostSpec); when(mockPluginService.connect(any(HostSpec.class), eq(properties))).thenReturn(mockConnection); when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); + when(mockPluginService.getCurrentConnectionState()).thenReturn(EnumSet.allOf(SessionDirtyFlag.class)); when(mockReaderFailoverHandler.failover(any(), any())).thenReturn(mockReaderResult); when(mockWriterFailoverHandler.failover(any())).thenReturn(mockWriterResult); @@ -197,10 +199,10 @@ void test_updateTopology_withForceUpdate(final boolean forceUpdate) throws SQLEx void test_syncSessionState_withNullConnections() throws SQLException { initializePlugin(); - plugin.transferSessionState(null, mockConnection); + plugin.transferSessionState(null, null, mockConnection, null); verify(mockConnection, never()).getAutoCommit(); - plugin.transferSessionState(mockConnection, null); + plugin.transferSessionState(mockConnection, null, null, null); verify(mockConnection, never()).getAutoCommit(); } @@ -214,7 +216,7 @@ void test_syncSessionState() throws SQLException { initializePlugin(); - plugin.transferSessionState(mockConnection, mockConnection); + plugin.transferSessionState(mockConnection, null, mockConnection, null); verify(target).setReadOnly(eq(false)); verify(target).getAutoCommit(); verify(target).getTransactionIsolation(); diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPluginTest.java index 40aa20c38..18ba25f21 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPluginTest.java @@ -58,6 +58,7 @@ import software.amazon.jdbc.dialect.Dialect; import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; import software.amazon.jdbc.plugin.failover.FailoverSuccessSQLException; +import software.amazon.jdbc.states.SessionDirtyFlag; import software.amazon.jdbc.util.SqlState; public class ReadWriteSplittingPluginTest { @@ -140,6 +141,7 @@ void mockDefaultBehavior() throws SQLException { when(this.mockPluginService.connect(eq(readerHostSpec3), any(Properties.class))) .thenReturn(mockReaderConn3); when(this.mockPluginService.acceptsStrategy(any(), eq("random"))).thenReturn(true); + when(mockPluginService.getCurrentConnectionState()).thenReturn(EnumSet.allOf(SessionDirtyFlag.class)); when(this.mockConnectFunc.call()).thenReturn(mockWriterConn); when(mockWriterConn.createStatement()).thenReturn(mockStatement); when(mockReaderConn1.createStatement()).thenReturn(mockStatement);