diff --git a/reactor-core/src/main/java/reactor/core/publisher/Flux.java b/reactor-core/src/main/java/reactor/core/publisher/Flux.java index d28f3fccce..b244ff8cdf 100644 --- a/reactor-core/src/main/java/reactor/core/publisher/Flux.java +++ b/reactor-core/src/main/java/reactor/core/publisher/Flux.java @@ -4393,6 +4393,14 @@ public final Flux contextWrite(Function contextModifier) { return onAssembly(new FluxContextWrite<>(this, contextModifier)); } + private final Flux contextWriteSkippingContextPropagation(ContextView contextToAppend) { + return contextWriteSkippingContextPropagation(c -> c.putAll(contextToAppend)); + } + + private final Flux contextWriteSkippingContextPropagation(Function contextModifier) { + return onAssembly(new FluxContextWrite<>(this, contextModifier)); + } + /** * Counts the number of values in this {@link Flux}. * The count will be emitted when onComplete is observed. @@ -4866,7 +4874,7 @@ public final Flux doOnComplete(Runnable onComplete) { * @return a {@link Flux} that cleans up matching elements that get discarded upstream of it. */ public final Flux doOnDiscard(final Class type, final Consumer discardHook) { - return contextWrite(Operators.discardLocalAdapter(type, discardHook)); + return contextWriteSkippingContextPropagation(Operators.discardLocalAdapter(type, discardHook)); } /** @@ -7147,7 +7155,7 @@ public final Flux onErrorComplete(Predicate predicate) { */ public final Flux onErrorContinue(BiConsumer errorConsumer) { BiConsumer genericConsumer = errorConsumer; - return contextWrite(Context.of( + return contextWriteSkippingContextPropagation(Context.of( OnNextFailureStrategy.KEY_ON_NEXT_ERROR_STRATEGY, OnNextFailureStrategy.resume(genericConsumer) )); @@ -7231,7 +7239,7 @@ public final Flux onErrorContinue(Predicate errorPre @SuppressWarnings("unchecked") Predicate genericPredicate = (Predicate) errorPredicate; BiConsumer genericErrorConsumer = errorConsumer; - return contextWrite(Context.of( + return contextWriteSkippingContextPropagation(Context.of( OnNextFailureStrategy.KEY_ON_NEXT_ERROR_STRATEGY, OnNextFailureStrategy.resumeIf(genericPredicate, genericErrorConsumer) )); @@ -7248,7 +7256,7 @@ public final Flux onErrorContinue(Predicate errorPre * was used downstream */ public final Flux onErrorStop() { - return contextWrite(Context.of( + return contextWriteSkippingContextPropagation(Context.of( OnNextFailureStrategy.KEY_ON_NEXT_ERROR_STRATEGY, OnNextFailureStrategy.stop())); } diff --git a/reactor-core/src/main/java/reactor/core/publisher/Mono.java b/reactor-core/src/main/java/reactor/core/publisher/Mono.java index f8f5ae96ee..4c1df5d815 100644 --- a/reactor-core/src/main/java/reactor/core/publisher/Mono.java +++ b/reactor-core/src/main/java/reactor/core/publisher/Mono.java @@ -2424,6 +2424,14 @@ public final Mono contextWrite(Function contextModifier) { return onAssembly(new MonoContextWrite<>(this, contextModifier)); } + private final Mono contextWriteSkippingContextPropagation(ContextView contextToAppend) { + return contextWriteSkippingContextPropagation(c -> c.putAll(contextToAppend)); + } + + private final Mono contextWriteSkippingContextPropagation(Function contextModifier) { + return onAssembly(new MonoContextWrite<>(this, contextModifier)); + } + /** * Provide a default single value if this mono is completed without any data * @@ -2713,7 +2721,7 @@ public final Mono doOnCancel(Runnable onCancel) { * @return a {@link Mono} that cleans up matching elements that get discarded upstream of it. */ public final Mono doOnDiscard(final Class type, final Consumer discardHook) { - return contextWrite(Operators.discardLocalAdapter(type, discardHook)); + return contextWriteSkippingContextPropagation(Operators.discardLocalAdapter(type, discardHook)); } /** @@ -3712,7 +3720,7 @@ public final Mono onErrorComplete(Predicate predicate) { */ public final Mono onErrorContinue(BiConsumer errorConsumer) { BiConsumer genericConsumer = errorConsumer; - return contextWrite(Context.of( + return contextWriteSkippingContextPropagation(Context.of( OnNextFailureStrategy.KEY_ON_NEXT_ERROR_STRATEGY, OnNextFailureStrategy.resume(genericConsumer) )); @@ -3802,7 +3810,7 @@ public final Mono onErrorContinue(Predicate errorPre @SuppressWarnings("unchecked") Predicate genericPredicate = (Predicate) errorPredicate; BiConsumer genericErrorConsumer = errorConsumer; - return contextWrite(Context.of( + return contextWriteSkippingContextPropagation(Context.of( OnNextFailureStrategy.KEY_ON_NEXT_ERROR_STRATEGY, OnNextFailureStrategy.resumeIf(genericPredicate, genericErrorConsumer) )); @@ -3819,7 +3827,7 @@ public final Mono onErrorContinue(Predicate errorPre * was used downstream */ public final Mono onErrorStop() { - return contextWrite(Context.of( + return contextWriteSkippingContextPropagation(Context.of( OnNextFailureStrategy.KEY_ON_NEXT_ERROR_STRATEGY, OnNextFailureStrategy.stop())); } diff --git a/reactor-core/src/withMicrometerTest/java/reactor/core/publisher/AutomaticContextPropagationTest.java b/reactor-core/src/withMicrometerTest/java/reactor/core/publisher/AutomaticContextPropagationTest.java index 27c5927d91..a23774a906 100644 --- a/reactor-core/src/withMicrometerTest/java/reactor/core/publisher/AutomaticContextPropagationTest.java +++ b/reactor-core/src/withMicrometerTest/java/reactor/core/publisher/AutomaticContextPropagationTest.java @@ -33,6 +33,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import java.util.function.Supplier; @@ -40,6 +41,7 @@ import java.util.stream.Stream; import io.micrometer.context.ContextRegistry; +import io.micrometer.context.ThreadLocalAccessor; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeAll; @@ -2453,4 +2455,124 @@ void fluxToIterable() { assertThat(value.get()).isEqualTo("present"); } } + + @Nested + class SpecialContextAlteringOperators { + + // The cases here consider operators like doOnDiscard(), which underneath + // utilize contextWrite() for its purpose. They are special in that we use them + // internally and do not anticipate the registered keys to be corresponding to + // any ThreadLocal values. That expectation is reasonable in user facing code + // as we don't know what keys are used and whether a ThreadLocalAccessor is + // registered for these keys. Therefore, in specific cases that are internal to + // reactor-core, we can skip ThreadLocal restoration in fragments of the chain. + + // Explanation of GET/SET operations on TL in the scenarios here: + // When going UP, we use the value "present". + // When going DOWN, we clear the value or restore a captured empty value. + // + // 1 x GET in block() with implicit context capture + // + // 1 x GET going UP from contextWrite (read current to restore later) + // + 2 x SET going UP from contextWrite + SET restoring current later + // + // 1 x GET going DOWN from contextWrite with subscription (read current) + // + 2 x SET going DOWN from contextWrite + SET restoring current later + // + // 1 x GET going UP to request (read current) + // + 2 x SET going UP from contextWrite + SET restoring current later + // + // 1 x GET going DOWN to deliver onComplete (read current) + // + 2 x SET going DOWN from contextWrite + SET restoring current later + + @Test + void discardFlux() { + CountingThreadLocalAccessor accessor = new CountingThreadLocalAccessor(); + ContextRegistry.getInstance().registerThreadLocalAccessor(accessor); + + AtomicInteger tlPresent = new AtomicInteger(); + AtomicInteger discards = new AtomicInteger(); + + Flux.just("a") + .doOnEach(signal -> { + if (CountingThreadLocalAccessor.TL.get().equals("present")) { + tlPresent.incrementAndGet(); + } + }) + .filter(s -> false) + .doOnDiscard(String.class, s -> discards.incrementAndGet()) + .count() + .contextWrite(ctx -> ctx.put(CountingThreadLocalAccessor.KEY, "present")) + .block(); + + assertThat(tlPresent.get()).isEqualTo(2); // 1 x onNext + 1 x onComplete + assertThat(discards.get()).isEqualTo(1); + // 5 with doOnDiscard skipping TL restoration, 9 with restoring + assertThat(accessor.reads.get()).isEqualTo(5); + // 8 with doOnDiscard skipping TL restoration, 16 with restoring + assertThat(accessor.writes.get()).isEqualTo(8); + + ContextRegistry.getInstance().removeThreadLocalAccessor(CountingThreadLocalAccessor.KEY); + } + + @Test + void discardMono() { + CountingThreadLocalAccessor accessor = new CountingThreadLocalAccessor(); + ContextRegistry.getInstance().registerThreadLocalAccessor(accessor); + + AtomicInteger tlPresent = new AtomicInteger(); + AtomicInteger discards = new AtomicInteger(); + + Mono.just("a") + .doOnEach(signal -> { + if (CountingThreadLocalAccessor.TL.get().equals("present")) { + tlPresent.incrementAndGet(); + } + }) + .filter(s -> false) + .doOnDiscard(String.class, s -> discards.incrementAndGet()) + .contextWrite(ctx -> ctx.put(CountingThreadLocalAccessor.KEY, "present")) + .block(); + + assertThat(tlPresent.get()).isEqualTo(2); // 1 x onNext + 1 x onComplete + assertThat(discards.get()).isEqualTo(1); + // 5 with doOnDiscard skipping TL restoration, 9 with restoring + assertThat(accessor.reads.get()).isEqualTo(5); + // 8 with doOnDiscard skipping TL restoration, 16 with restoring + assertThat(accessor.writes.get()).isEqualTo(8); + + ContextRegistry.getInstance().removeThreadLocalAccessor(CountingThreadLocalAccessor.KEY); + } + } + + private static class CountingThreadLocalAccessor implements ThreadLocalAccessor { + static final String KEY = "CTLA"; + static final ThreadLocal TL = new ThreadLocal<>(); + + AtomicInteger reads = new AtomicInteger(); + AtomicInteger writes = new AtomicInteger(); + + @Override + public Object key() { + return KEY; + } + + @Override + public String getValue() { + reads.incrementAndGet(); + return TL.get(); + } + + @Override + public void setValue(String s) { + writes.incrementAndGet(); + TL.set(s); + } + + @Override + public void setValue() { + writes.incrementAndGet(); + TL.remove(); + } + } }