Skip to content

Commit

Permalink
Handle edge cases and support autocommit=false
Browse files Browse the repository at this point in the history
  • Loading branch information
luneo7 committed Oct 28, 2024
1 parent 362ba1e commit 5927278
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 96 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
*/
class JdbcLocalTxnInterceptor implements MethodInterceptor {
private static final Logger logger = LoggerFactory.getLogger(JdbcLocalTxnInterceptor.class);
private static final ConcurrentMap<Method, Transactional> methodsTransactionals = new ConcurrentHashMap<Method, Transactional>();
private static final ConcurrentMap<Method, Transactional> transactionalMethods = new ConcurrentHashMap<Method, Transactional>();

private final Provider<JooqPersistService> jooqPersistServiceProvider;
private final Provider<UnitOfWork> unitOfWorkProvider;
Expand All @@ -44,110 +44,91 @@ class JdbcLocalTxnInterceptor implements MethodInterceptor {
private static class Internal {
}

// Tracks if the unit of work was begun implicitly by this transaction.
private final ThreadLocal<Boolean> didWeStartWork = new ThreadLocal<Boolean>();

@Inject
public JdbcLocalTxnInterceptor(Provider<JooqPersistService> jooqPersistServiceProvider,
Provider<UnitOfWork> unitOfWorkProvider) {
this.jooqPersistServiceProvider = jooqPersistServiceProvider;
this.unitOfWorkProvider = unitOfWorkProvider;
}

@Override
public Object invoke(final MethodInvocation methodInvocation) throws Throwable {
UnitOfWork unitOfWork = unitOfWorkProvider.get();
JooqPersistService jooqProvider = jooqPersistServiceProvider.get();

// Should we start a unit of work?
if (!jooqProvider.isWorking()) {
if (jooqProvider.isWorking()) {
// Allow 'joining' of transactions if there is an enclosing @Transactional method.
return methodInvocation.proceed();
} else {
// We should start a unit of work
unitOfWork.begin();
didWeStartWork.set(true);
}

Transactional transactional = readTransactionMetadata(methodInvocation);
DefaultConnectionProvider conn = jooqProvider.getConnectionWrapper();
DefaultConnectionProvider conn = jooqProvider.getThreadLocals().getConnectionProvider();
boolean reenableAutoCommit = false;

// Allow 'joining' of transactions if there is an enclosing @Transactional method.
if (!conn.getAutoCommit()) {
return methodInvocation.proceed();
}
try {
if (conn.getAutoCommit()) {
logger.debug("Disabling JDBC auto commit for this thread");
reenableAutoCommit = true;
conn.setAutoCommit(false);
}

logger.debug("Disabling JDBC auto commit for this thread");
conn.setAutoCommit(false);
Object result = methodInvocation.proceed();

Object result;
logger.debug("Committing JDBC transaction");
conn.commit();

try {
result = methodInvocation.proceed();
return result;
} catch (Exception e) {
//commit transaction only if rollback didn't occur
if (rollbackIfNecessary(transactional, e, conn)) {
if (rollbackIfNecessary(readTransactionMetadata(methodInvocation), e, conn)) {
logger.debug("Committing JDBC transaction");
conn.commit();
}

logger.debug("Enabling auto commit for this thread");
conn.setAutoCommit(true);

//propagate whatever exception is thrown anyway
throw e;
} finally {
// Close the em if necessary (guarded so this code doesn't run unless catch fired).
if (null != didWeStartWork.get() && conn.getAutoCommit()) {
didWeStartWork.remove();
unitOfWork.end();
}
}

// everything was normal so commit the txn (do not move into try block above as it
// interferes with the advised method's throwing semantics)
try {
logger.debug("Committing JDBC transaction");
conn.commit();
logger.debug("Enabling auto commit for this thread");
conn.setAutoCommit(true);
} finally {
//close the em if necessary
if (null != didWeStartWork.get()) {
didWeStartWork.remove();
if (reenableAutoCommit) {
try {
conn.setAutoCommit(true);
} catch (Exception ignored) {
}
}
unitOfWork.end();
}
}

//or return result
return result;
}

private Transactional readTransactionMetadata(final MethodInvocation methodInvocation) {
Method method = methodInvocation.getMethod();
Transactional cachedTransactional = methodsTransactionals.get(method);
Transactional cachedTransactional = transactionalMethods.get(method);
if (cachedTransactional != null) {
return cachedTransactional;
}

Transactional transactional = method.getAnnotation(Transactional.class);
if (null == transactional) {
// If none on method, try the class.
Class<?> targetClass = methodInvocation.getThis().getClass();
transactional = targetClass.getAnnotation(Transactional.class);
}

if (null != transactional) {
methodsTransactionals.put(method, transactional);
} else {
// If there is no transactional annotation present, use the default
transactional = Internal.class.getAnnotation(Transactional.class);
}

return transactional;
return transactionalMethods.computeIfAbsent(method, ignored -> {
Transactional transactional;
transactional = method.getAnnotation(Transactional.class);
if (null == transactional) {
// If none on method, try the class.
Class<?> targetClass = methodInvocation.getThis().getClass();
transactional = targetClass.getAnnotation(Transactional.class);
}
if (null == transactional) {
// If there is no transactional annotation present, use the default
transactional = Internal.class.getAnnotation(Transactional.class);
}
return transactional;
});
}

/**
* Returns True if rollback DID NOT HAPPEN (i.e. if commit should continue).
*
* @param transactional The metadata annotation of the method
* @param e The exception to test for rollback
* @param txn A JPA Transaction to issue rollbacks on
* @param conn A DefaultConnectionProvider to issue rollbacks on
*/
private boolean rollbackIfNecessary(final Transactional transactional,
final Exception e,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ protected void configurePersistence() {

transactionInterceptor = new JdbcLocalTxnInterceptor(getProvider(JooqPersistService.class),
getProvider(UnitOfWork.class));
requestInjection(transactionInterceptor);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.jooq.conf.Settings;
import org.jooq.impl.DSL;
import org.jooq.impl.DefaultConnectionProvider;
import org.jooq.tools.jdbc.JDBCUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -47,8 +48,7 @@ class JooqPersistService implements Provider<DSLContext>, UnitOfWork, PersistSer

private static final Logger logger = LoggerFactory.getLogger(JooqPersistService.class);

private final ThreadLocal<DSLContext> threadFactory = new ThreadLocal<DSLContext>();
private final ThreadLocal<DefaultConnectionProvider> threadConnection = new ThreadLocal<DefaultConnectionProvider>();
private static final ThreadLocal<ThreadLocals> threadFactory = new ThreadLocal<>();
private final Provider<DataSource> jdbcSource;
private final SQLDialect sqlDialect;
private final Settings jooqSettings;
Expand All @@ -68,86 +68,94 @@ public JooqPersistService(final Provider<DataSource> jdbcSource, final SQLDialec
}
}

@Override
public DSLContext get() {
DSLContext factory = threadFactory.get();
if(null == factory) {
ThreadLocals factory = threadFactory.get();
if (null == factory) {
throw new IllegalStateException("Requested Factory outside work unit. "
+ "Try calling UnitOfWork.begin() first, use @Transactional annotation"
+ "or use a PersistFilter if you are inside a servlet environment.");
}

return factory;
return factory.getDSLContext();
}

public DefaultConnectionProvider getConnectionWrapper() {
return threadConnection.get();
public ThreadLocals getThreadLocals() {
return threadFactory.get();
}

public boolean isWorking() {
return threadFactory.get() != null;
}

@Override
public void begin() {
if(null != threadFactory.get()) {
if (null != threadFactory.get()) {
throw new IllegalStateException("Work already begun on this thread. "
+ "It looks like you have called UnitOfWork.begin() twice"
+ " without a balancing call to end() in between.");
}

DefaultConnectionProvider conn;
threadFactory.set(createThreadLocals());
}

private ThreadLocals createThreadLocals() {
DefaultConnectionProvider conn = null;
try {
logger.debug("Getting JDBC connection");
DataSource dataSource = jdbcSource.get();
Connection jdbcConn = dataSource.getConnection();
conn = new DefaultConnectionProvider(jdbcConn);
} catch (SQLException e) {
throw new RuntimeException(e);
}

DSLContext jooqFactory;
DSLContext jooqFactory;

if (configuration != null) {
logger.debug("Creating factory from configuration having dialect {}", configuration.dialect());
jooqFactory = DSL.using(conn, configuration.dialect(), configuration.settings());
} else {
if (jooqSettings == null) {
logger.debug("Creating factory with dialect {}", sqlDialect);
jooqFactory = DSL.using(conn, sqlDialect);
if (configuration != null) {
logger.debug("Creating factory from configuration having dialect {}", configuration.dialect());
jooqFactory = DSL.using(conn, configuration.dialect(), configuration.settings());
} else {
logger.debug("Creating factory with dialect {} and settings.", sqlDialect);
jooqFactory = DSL.using(conn, sqlDialect, jooqSettings);
if (jooqSettings == null) {
logger.debug("Creating factory with dialect {}", sqlDialect);
jooqFactory = DSL.using(conn, sqlDialect);
} else {
logger.debug("Creating factory with dialect {} and settings.", sqlDialect);
jooqFactory = DSL.using(conn, sqlDialect, jooqSettings);
}
}
return new ThreadLocals(jooqFactory, conn);
} catch (Exception e) {
if (conn != null) {
JDBCUtils.safeClose(conn.acquire());
}
throw new RuntimeException(e);
}
threadConnection.set(conn);
threadFactory.set(jooqFactory);
}

@Override
public void end() {
DSLContext jooqFactory = threadFactory.get();
DefaultConnectionProvider conn = threadConnection.get();
ThreadLocals threadLocals = threadFactory.get();
// Let's not penalize users for calling end() multiple times.
if (null == jooqFactory) {
if (null == threadLocals) {
return;
}

try {
logger.debug("Closing JDBC connection");
conn.acquire().close();
threadLocals.getConnectionProvider().acquire().close();
} catch (SQLException e) {
throw new RuntimeException(e);
} finally {
threadFactory.remove();
}
threadFactory.remove();
threadConnection.remove();
}


@Override
public synchronized void start() {
//nothing to do on start
}

@Override
public synchronized void stop() {
//nothing to do on stop
}


}
22 changes: 22 additions & 0 deletions src/main/java/com/adamlewis/guice/persist/jooq/ThreadLocals.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package com.adamlewis.guice.persist.jooq;

import org.jooq.DSLContext;
import org.jooq.impl.DefaultConnectionProvider;

final class ThreadLocals {
private final DSLContext dslContext;
private final DefaultConnectionProvider connectionProvider;

ThreadLocals(DSLContext dslContext, DefaultConnectionProvider connectionProvider) {
this.dslContext = dslContext;
this.connectionProvider = connectionProvider;
}

DSLContext getDSLContext() {
return dslContext;
}

DefaultConnectionProvider getConnectionProvider() {
return connectionProvider;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import com.google.inject.persist.Transactional;
import com.google.inject.persist.UnitOfWork;
import org.aopalliance.intercept.MethodInvocation;
import org.jooq.SQLDialect;
import org.jooq.impl.DSL;
import org.jooq.impl.DefaultConnectionProvider;
import org.junit.Before;
import org.junit.Test;
Expand Down Expand Up @@ -43,7 +45,8 @@ public void setUp() throws Exception {
connection.setAutoCommit(true);

DefaultConnectionProvider connectionProvider = new DefaultConnectionProvider(connection);
when(jooqPersistService.getConnectionWrapper()).thenReturn(connectionProvider);
when(jooqPersistService.getThreadLocals()).thenReturn(new ThreadLocals(DSL.using(SQLDialect.DEFAULT),
connectionProvider));
when(jooqPersistService.isWorking()).thenReturn(false);

// Method is final. Mockito doesn't support mocking final classes. Using reflection
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import javax.sql.DataSource;
import org.jooq.Configuration;
import org.jooq.conf.BackslashEscaping;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;

Expand All @@ -28,6 +29,17 @@ public void setup() {
injector = null;
}

@After
public void tearDown() {
if (injector != null) {
JooqPersistService instance = injector.getInstance(JooqPersistService.class);
if (instance.isWorking()) {
instance.end();
}
}
}


@Test
public void canCreateWithoutConfiguration() {
JooqPersistService jooqPersistService = givenJooqPersistServiceWithModule();
Expand Down Expand Up @@ -65,8 +77,11 @@ public void canProvideSettings() {
}

@Test
public void canProvideSettingsAndConfigurationButSettingsIsIgnored() {
public void canProvideSettingsAndConfigurationButSettingsIsIgnored() throws Exception {
JooqPersistService jooqPersistService = givenJooqPersistServiceWithModule(new ConfigurationModule(), new SettingsModule());
DataSource dataSource = injector.getInstance(DataSource.class);
Connection connectionMock = mock(Connection.class);
when(dataSource.getConnection()).thenReturn(connectionMock);
jooqPersistService.begin();

Configuration configuration = injector.getInstance(Configuration.class);
Expand Down

0 comments on commit 5927278

Please sign in to comment.