Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle edge cases and support autocommit=false #21

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
*/
class JdbcLocalTxnInterceptor implements MethodInterceptor {
private static final Logger logger = LoggerFactory.getLogger(JdbcLocalTxnInterceptor.class);
private static final ConcurrentMap<Method, Transactional> methodsTransactionals = new ConcurrentHashMap<Method, Transactional>();
private static final ConcurrentMap<Method, Transactional> transactionalMethods = new ConcurrentHashMap<Method, Transactional>();

private final Provider<JooqPersistService> jooqPersistServiceProvider;
private final Provider<UnitOfWork> unitOfWorkProvider;
Expand All @@ -44,110 +44,91 @@ class JdbcLocalTxnInterceptor implements MethodInterceptor {
private static class Internal {
}

// Tracks if the unit of work was begun implicitly by this transaction.
private final ThreadLocal<Boolean> didWeStartWork = new ThreadLocal<Boolean>();

@Inject
public JdbcLocalTxnInterceptor(Provider<JooqPersistService> jooqPersistServiceProvider,
Provider<UnitOfWork> unitOfWorkProvider) {
this.jooqPersistServiceProvider = jooqPersistServiceProvider;
this.unitOfWorkProvider = unitOfWorkProvider;
}

@Override
public Object invoke(final MethodInvocation methodInvocation) throws Throwable {
UnitOfWork unitOfWork = unitOfWorkProvider.get();
JooqPersistService jooqProvider = jooqPersistServiceProvider.get();

// Should we start a unit of work?
if (!jooqProvider.isWorking()) {
if (jooqProvider.isWorking()) {
// Allow 'joining' of transactions if there is an enclosing @Transactional method.
return methodInvocation.proceed();
} else {
// We should start a unit of work
unitOfWork.begin();
didWeStartWork.set(true);
}

Transactional transactional = readTransactionMetadata(methodInvocation);
DefaultConnectionProvider conn = jooqProvider.getConnectionWrapper();
DefaultConnectionProvider conn = jooqProvider.getThreadLocals().getConnectionProvider();
boolean reenableAutoCommit = false;

// Allow 'joining' of transactions if there is an enclosing @Transactional method.
if (!conn.getAutoCommit()) {
return methodInvocation.proceed();
}
try {
if (conn.getAutoCommit()) {
logger.debug("Disabling JDBC auto commit for this thread");
reenableAutoCommit = true;
conn.setAutoCommit(false);
}

logger.debug("Disabling JDBC auto commit for this thread");
conn.setAutoCommit(false);
Object result = methodInvocation.proceed();

Object result;
logger.debug("Committing JDBC transaction");
conn.commit();

try {
result = methodInvocation.proceed();
return result;
} catch (Exception e) {
//commit transaction only if rollback didn't occur
if (rollbackIfNecessary(transactional, e, conn)) {
if (rollbackIfNecessary(readTransactionMetadata(methodInvocation), e, conn)) {
logger.debug("Committing JDBC transaction");
conn.commit();
}

logger.debug("Enabling auto commit for this thread");
conn.setAutoCommit(true);

//propagate whatever exception is thrown anyway
throw e;
} finally {
// Close the em if necessary (guarded so this code doesn't run unless catch fired).
if (null != didWeStartWork.get() && conn.getAutoCommit()) {
didWeStartWork.remove();
unitOfWork.end();
}
}

// everything was normal so commit the txn (do not move into try block above as it
// interferes with the advised method's throwing semantics)
try {
logger.debug("Committing JDBC transaction");
conn.commit();
logger.debug("Enabling auto commit for this thread");
conn.setAutoCommit(true);
} finally {
//close the em if necessary
if (null != didWeStartWork.get()) {
didWeStartWork.remove();
if (reenableAutoCommit) {
try {
conn.setAutoCommit(true);
} catch (Exception ignored) {
}
}
unitOfWork.end();
}
}

//or return result
return result;
}

private Transactional readTransactionMetadata(final MethodInvocation methodInvocation) {
Method method = methodInvocation.getMethod();
Transactional cachedTransactional = methodsTransactionals.get(method);
Transactional cachedTransactional = transactionalMethods.get(method);
if (cachedTransactional != null) {
return cachedTransactional;
}

Transactional transactional = method.getAnnotation(Transactional.class);
if (null == transactional) {
// If none on method, try the class.
Class<?> targetClass = methodInvocation.getThis().getClass();
transactional = targetClass.getAnnotation(Transactional.class);
}

if (null != transactional) {
methodsTransactionals.put(method, transactional);
} else {
// If there is no transactional annotation present, use the default
transactional = Internal.class.getAnnotation(Transactional.class);
}

return transactional;
return transactionalMethods.computeIfAbsent(method, ignored -> {
Transactional transactional;
transactional = method.getAnnotation(Transactional.class);
if (null == transactional) {
// If none on method, try the class.
Class<?> targetClass = methodInvocation.getThis().getClass();
transactional = targetClass.getAnnotation(Transactional.class);
}
if (null == transactional) {
// If there is no transactional annotation present, use the default
transactional = Internal.class.getAnnotation(Transactional.class);
}
return transactional;
});
}

/**
* Returns True if rollback DID NOT HAPPEN (i.e. if commit should continue).
*
* @param transactional The metadata annotation of the method
* @param e The exception to test for rollback
* @param txn A JPA Transaction to issue rollbacks on
* @param conn A DefaultConnectionProvider to issue rollbacks on
*/
private boolean rollbackIfNecessary(final Transactional transactional,
final Exception e,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ protected void configurePersistence() {

transactionInterceptor = new JdbcLocalTxnInterceptor(getProvider(JooqPersistService.class),
getProvider(UnitOfWork.class));
requestInjection(transactionInterceptor);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.jooq.conf.Settings;
import org.jooq.impl.DSL;
import org.jooq.impl.DefaultConnectionProvider;
import org.jooq.tools.jdbc.JDBCUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -47,8 +48,7 @@ class JooqPersistService implements Provider<DSLContext>, UnitOfWork, PersistSer

private static final Logger logger = LoggerFactory.getLogger(JooqPersistService.class);

private final ThreadLocal<DSLContext> threadFactory = new ThreadLocal<DSLContext>();
private final ThreadLocal<DefaultConnectionProvider> threadConnection = new ThreadLocal<DefaultConnectionProvider>();
private static final ThreadLocal<ThreadLocals> threadFactory = new ThreadLocal<>();
private final Provider<DataSource> jdbcSource;
private final SQLDialect sqlDialect;
private final Settings jooqSettings;
Expand All @@ -68,86 +68,94 @@ public JooqPersistService(final Provider<DataSource> jdbcSource, final SQLDialec
}
}

@Override
public DSLContext get() {
DSLContext factory = threadFactory.get();
if(null == factory) {
ThreadLocals factory = threadFactory.get();
if (null == factory) {
throw new IllegalStateException("Requested Factory outside work unit. "
+ "Try calling UnitOfWork.begin() first, use @Transactional annotation"
+ "or use a PersistFilter if you are inside a servlet environment.");
}

return factory;
return factory.getDSLContext();
}

public DefaultConnectionProvider getConnectionWrapper() {
return threadConnection.get();
public ThreadLocals getThreadLocals() {
return threadFactory.get();
}

public boolean isWorking() {
return threadFactory.get() != null;
}

@Override
public void begin() {
if(null != threadFactory.get()) {
if (null != threadFactory.get()) {
throw new IllegalStateException("Work already begun on this thread. "
+ "It looks like you have called UnitOfWork.begin() twice"
+ " without a balancing call to end() in between.");
}

DefaultConnectionProvider conn;
threadFactory.set(createThreadLocals());
}

private ThreadLocals createThreadLocals() {
DefaultConnectionProvider conn = null;
try {
logger.debug("Getting JDBC connection");
DataSource dataSource = jdbcSource.get();
Connection jdbcConn = dataSource.getConnection();
conn = new DefaultConnectionProvider(jdbcConn);
} catch (SQLException e) {
throw new RuntimeException(e);
}

DSLContext jooqFactory;
DSLContext jooqFactory;

if (configuration != null) {
logger.debug("Creating factory from configuration having dialect {}", configuration.dialect());
jooqFactory = DSL.using(conn, configuration.dialect(), configuration.settings());
} else {
if (jooqSettings == null) {
logger.debug("Creating factory with dialect {}", sqlDialect);
jooqFactory = DSL.using(conn, sqlDialect);
if (configuration != null) {
logger.debug("Creating factory from configuration having dialect {}", configuration.dialect());
jooqFactory = DSL.using(conn, configuration.dialect(), configuration.settings());
} else {
logger.debug("Creating factory with dialect {} and settings.", sqlDialect);
jooqFactory = DSL.using(conn, sqlDialect, jooqSettings);
if (jooqSettings == null) {
logger.debug("Creating factory with dialect {}", sqlDialect);
jooqFactory = DSL.using(conn, sqlDialect);
} else {
logger.debug("Creating factory with dialect {} and settings.", sqlDialect);
jooqFactory = DSL.using(conn, sqlDialect, jooqSettings);
}
}
return new ThreadLocals(jooqFactory, conn);
} catch (Exception e) {
if (conn != null) {
JDBCUtils.safeClose(conn.acquire());
}
throw new RuntimeException(e);
}
threadConnection.set(conn);
threadFactory.set(jooqFactory);
}

@Override
public void end() {
DSLContext jooqFactory = threadFactory.get();
DefaultConnectionProvider conn = threadConnection.get();
ThreadLocals threadLocals = threadFactory.get();
// Let's not penalize users for calling end() multiple times.
if (null == jooqFactory) {
if (null == threadLocals) {
return;
}

try {
logger.debug("Closing JDBC connection");
conn.acquire().close();
threadLocals.getConnectionProvider().acquire().close();
} catch (SQLException e) {
throw new RuntimeException(e);
} finally {
threadFactory.remove();
}
threadFactory.remove();
threadConnection.remove();
}


@Override
public synchronized void start() {
//nothing to do on start
}

@Override
public synchronized void stop() {
//nothing to do on stop
}


}
22 changes: 22 additions & 0 deletions src/main/java/com/adamlewis/guice/persist/jooq/ThreadLocals.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package com.adamlewis.guice.persist.jooq;

import org.jooq.DSLContext;
import org.jooq.impl.DefaultConnectionProvider;

final class ThreadLocals {
private final DSLContext dslContext;
private final DefaultConnectionProvider connectionProvider;

ThreadLocals(DSLContext dslContext, DefaultConnectionProvider connectionProvider) {
this.dslContext = dslContext;
this.connectionProvider = connectionProvider;
}

DSLContext getDSLContext() {
return dslContext;
}

DefaultConnectionProvider getConnectionProvider() {
return connectionProvider;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import com.google.inject.persist.Transactional;
import com.google.inject.persist.UnitOfWork;
import org.aopalliance.intercept.MethodInvocation;
import org.jooq.SQLDialect;
import org.jooq.impl.DSL;
import org.jooq.impl.DefaultConnectionProvider;
import org.junit.Before;
import org.junit.Test;
Expand Down Expand Up @@ -43,7 +45,8 @@ public void setUp() throws Exception {
connection.setAutoCommit(true);

DefaultConnectionProvider connectionProvider = new DefaultConnectionProvider(connection);
when(jooqPersistService.getConnectionWrapper()).thenReturn(connectionProvider);
when(jooqPersistService.getThreadLocals()).thenReturn(new ThreadLocals(DSL.using(SQLDialect.DEFAULT),
connectionProvider));
when(jooqPersistService.isWorking()).thenReturn(false);

// Method is final. Mockito doesn't support mocking final classes. Using reflection
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import javax.sql.DataSource;
import org.jooq.Configuration;
import org.jooq.conf.BackslashEscaping;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;

Expand All @@ -28,6 +29,17 @@ public void setup() {
injector = null;
}

@After
public void tearDown() {
if (injector != null) {
JooqPersistService instance = injector.getInstance(JooqPersistService.class);
if (instance.isWorking()) {
instance.end();
}
}
}


@Test
public void canCreateWithoutConfiguration() {
JooqPersistService jooqPersistService = givenJooqPersistServiceWithModule();
Expand Down Expand Up @@ -65,8 +77,11 @@ public void canProvideSettings() {
}

@Test
public void canProvideSettingsAndConfigurationButSettingsIsIgnored() {
public void canProvideSettingsAndConfigurationButSettingsIsIgnored() throws Exception {
JooqPersistService jooqPersistService = givenJooqPersistServiceWithModule(new ConfigurationModule(), new SettingsModule());
DataSource dataSource = injector.getInstance(DataSource.class);
Connection connectionMock = mock(Connection.class);
when(dataSource.getConnection()).thenReturn(connectionMock);
jooqPersistService.begin();

Configuration configuration = injector.getInstance(Configuration.class);
Expand Down