diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java index 32ea6b384..3e66b7e96 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java @@ -502,10 +502,7 @@ private boolean shouldReconnectToWriter(final Boolean readOnly) { * @throws SQLException if an error occurs */ private void switchCurrentConnectionTo(final HostSpec host, final Connection connection) throws SQLException { - final Connection currentConnection = this.pluginService.getCurrentConnection(); - if (currentConnection != connection) { - invalidateCurrentConnection(); - } + Connection currentConnection = this.pluginService.getCurrentConnection(); final boolean readOnly; if (isWriter(host)) { @@ -517,7 +514,12 @@ private void switchCurrentConnectionTo(final HostSpec host, final Connection con } else { readOnly = false; } - transferSessionState(currentConnection, connection, readOnly); + + if (currentConnection != connection) { + transferSessionState(currentConnection, connection, readOnly); + invalidateCurrentConnection(); + } + this.pluginService.setCurrentConnection(connection, host); if (this.pluginManagerService != null) { @@ -563,6 +565,8 @@ private void dealWithOriginalException( if (this.lastExceptionDealtWith != originalException && shouldExceptionTriggerConnectionSwitch(originalException)) { invalidateCurrentConnection(); + this.pluginService.setAvailability( + this.pluginService.getCurrentHostSpec().getAliases(), HostAvailability.NOT_AVAILABLE); try { pickNewConnection(); } catch (final SQLException e) { @@ -702,7 +706,6 @@ protected void invalidateCurrentConnection() { return; } - final HostSpec originalHost = this.pluginService.getCurrentHostSpec(); if (this.pluginService.isInTransaction()) { isInTransaction = this.pluginService.isInTransaction(); try { @@ -719,19 +722,6 @@ protected void invalidateCurrentConnection() { } catch (final SQLException e) { // swallow this exception, current connection should be useless anyway. } - - try { - this.pluginService.setCurrentConnection( - conn, - new HostSpec( - originalHost.getHost(), - originalHost.getPort(), - originalHost.getRole(), - HostAvailability.NOT_AVAILABLE)); - this.pluginService.setAvailability(originalHost.getAliases(), HostAvailability.NOT_AVAILABLE); - } catch (final SQLException e) { - LOGGER.fine(() -> Messages.get("Failover.failedToUpdateCurrentHostspecAvailability")); - } } protected synchronized void pickNewConnection() throws SQLException { diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPluginTest.java index c85f5404c..b937399e7 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPluginTest.java @@ -47,8 +47,6 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; -import org.mockito.ArgumentCaptor; -import org.mockito.Captor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import software.amazon.jdbc.HostAvailability; @@ -80,10 +78,9 @@ class FailoverConnectionPluginTest { @Mock ClusterAwareWriterFailoverHandler mockWriterFailoverHandler; @Mock ReaderFailoverResult mockReaderResult; @Mock WriterFailoverResult mockWriterResult; - @Captor ArgumentCaptor hostSpecArgumentCaptor; @Mock JdbcCallable mockSqlFunction; - private Properties properties = new Properties(); + private final Properties properties = new Properties(); private FailoverConnectionPlugin plugin; private AutoCloseable closeable; @@ -425,8 +422,6 @@ void test_invalidateCurrentConnection_inTransaction() throws SQLException { when(mockHostSpec.getPort()).thenReturn(123); when(mockHostSpec.getRole()).thenReturn(HostRole.READER); - final HostSpec expectedHostSpec = new HostSpec("host", 123, HostRole.READER, HostAvailability.NOT_AVAILABLE); - initializePlugin(); plugin.invalidateCurrentConnection(); verify(mockConnection).rollback(); @@ -434,25 +429,19 @@ void test_invalidateCurrentConnection_inTransaction() throws SQLException { // Assert SQL exceptions thrown during rollback do not get propagated. doThrow(new SQLException()).when(mockConnection).rollback(); assertDoesNotThrow(() -> plugin.invalidateCurrentConnection()); - - verify(mockPluginService, times(2)).setCurrentConnection(eq(mockConnection), hostSpecArgumentCaptor.capture()); - assertEquals(expectedHostSpec, hostSpecArgumentCaptor.getValue()); } @Test - void test_invalidateCurrentConnection_notInTransaction() throws SQLException { + void test_invalidateCurrentConnection_notInTransaction() { when(mockPluginService.isInTransaction()).thenReturn(false); when(mockHostSpec.getHost()).thenReturn("host"); when(mockHostSpec.getPort()).thenReturn(123); when(mockHostSpec.getRole()).thenReturn(HostRole.READER); - final HostSpec expectedHostSpec = new HostSpec("host", 123, HostRole.READER, HostAvailability.NOT_AVAILABLE); initializePlugin(); plugin.invalidateCurrentConnection(); verify(mockPluginService).isInTransaction(); - verify(mockPluginService).setCurrentConnection(eq(mockConnection), hostSpecArgumentCaptor.capture()); - assertEquals(expectedHostSpec, hostSpecArgumentCaptor.getValue()); } @Test @@ -462,7 +451,6 @@ void test_invalidateCurrentConnection_withOpenConnection() throws SQLException { when(mockHostSpec.getHost()).thenReturn("host"); when(mockHostSpec.getPort()).thenReturn(123); when(mockHostSpec.getRole()).thenReturn(HostRole.READER); - final HostSpec expectedHostSpec = new HostSpec("host", 123, HostRole.READER, HostAvailability.NOT_AVAILABLE); initializePlugin(); plugin.invalidateCurrentConnection(); @@ -472,8 +460,6 @@ void test_invalidateCurrentConnection_withOpenConnection() throws SQLException { verify(mockConnection, times(2)).isClosed(); verify(mockConnection, times(2)).close(); - verify(mockPluginService, times(2)).setCurrentConnection(eq(mockConnection), hostSpecArgumentCaptor.capture()); - assertEquals(expectedHostSpec, hostSpecArgumentCaptor.getValue()); } @Test