Skip to content

Commit

Permalink
Merge #3845 into 3.7.0-M5
Browse files Browse the repository at this point in the history
  • Loading branch information
chemicL committed Jul 12, 2024
2 parents 877fe1b + ebded61 commit c575e04
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 8 deletions.
16 changes: 12 additions & 4 deletions reactor-core/src/main/java/reactor/core/publisher/Flux.java
Original file line number Diff line number Diff line change
Expand Up @@ -4393,6 +4393,14 @@ public final Flux<T> contextWrite(Function<Context, Context> contextModifier) {
return onAssembly(new FluxContextWrite<>(this, contextModifier));
}

private final Flux<T> contextWriteSkippingContextPropagation(ContextView contextToAppend) {
return contextWriteSkippingContextPropagation(c -> c.putAll(contextToAppend));
}

private final Flux<T> contextWriteSkippingContextPropagation(Function<Context, Context> contextModifier) {
return onAssembly(new FluxContextWrite<>(this, contextModifier));
}

/**
* Counts the number of values in this {@link Flux}.
* The count will be emitted when onComplete is observed.
Expand Down Expand Up @@ -4866,7 +4874,7 @@ public final Flux<T> doOnComplete(Runnable onComplete) {
* @return a {@link Flux} that cleans up matching elements that get discarded upstream of it.
*/
public final <R> Flux<T> doOnDiscard(final Class<R> type, final Consumer<? super R> discardHook) {
return contextWrite(Operators.discardLocalAdapter(type, discardHook));
return contextWriteSkippingContextPropagation(Operators.discardLocalAdapter(type, discardHook));
}

/**
Expand Down Expand Up @@ -7147,7 +7155,7 @@ public final Flux<T> onErrorComplete(Predicate<? super Throwable> predicate) {
*/
public final Flux<T> onErrorContinue(BiConsumer<Throwable, Object> errorConsumer) {
BiConsumer<Throwable, Object> genericConsumer = errorConsumer;
return contextWrite(Context.of(
return contextWriteSkippingContextPropagation(Context.of(
OnNextFailureStrategy.KEY_ON_NEXT_ERROR_STRATEGY,
OnNextFailureStrategy.resume(genericConsumer)
));
Expand Down Expand Up @@ -7231,7 +7239,7 @@ public final <E extends Throwable> Flux<T> onErrorContinue(Predicate<E> errorPre
@SuppressWarnings("unchecked")
Predicate<Throwable> genericPredicate = (Predicate<Throwable>) errorPredicate;
BiConsumer<Throwable, Object> genericErrorConsumer = errorConsumer;
return contextWrite(Context.of(
return contextWriteSkippingContextPropagation(Context.of(
OnNextFailureStrategy.KEY_ON_NEXT_ERROR_STRATEGY,
OnNextFailureStrategy.resumeIf(genericPredicate, genericErrorConsumer)
));
Expand All @@ -7248,7 +7256,7 @@ public final <E extends Throwable> Flux<T> onErrorContinue(Predicate<E> errorPre
* was used downstream
*/
public final Flux<T> onErrorStop() {
return contextWrite(Context.of(
return contextWriteSkippingContextPropagation(Context.of(
OnNextFailureStrategy.KEY_ON_NEXT_ERROR_STRATEGY,
OnNextFailureStrategy.stop()));
}
Expand Down
16 changes: 12 additions & 4 deletions reactor-core/src/main/java/reactor/core/publisher/Mono.java
Original file line number Diff line number Diff line change
Expand Up @@ -2424,6 +2424,14 @@ public final Mono<T> contextWrite(Function<Context, Context> contextModifier) {
return onAssembly(new MonoContextWrite<>(this, contextModifier));
}

private final Mono<T> contextWriteSkippingContextPropagation(ContextView contextToAppend) {
return contextWriteSkippingContextPropagation(c -> c.putAll(contextToAppend));
}

private final Mono<T> contextWriteSkippingContextPropagation(Function<Context, Context> contextModifier) {
return onAssembly(new MonoContextWrite<>(this, contextModifier));
}

/**
* Provide a default single value if this mono is completed without any data
*
Expand Down Expand Up @@ -2713,7 +2721,7 @@ public final Mono<T> doOnCancel(Runnable onCancel) {
* @return a {@link Mono} that cleans up matching elements that get discarded upstream of it.
*/
public final <R> Mono<T> doOnDiscard(final Class<R> type, final Consumer<? super R> discardHook) {
return contextWrite(Operators.discardLocalAdapter(type, discardHook));
return contextWriteSkippingContextPropagation(Operators.discardLocalAdapter(type, discardHook));
}

/**
Expand Down Expand Up @@ -3712,7 +3720,7 @@ public final Mono<T> onErrorComplete(Predicate<? super Throwable> predicate) {
*/
public final Mono<T> onErrorContinue(BiConsumer<Throwable, Object> errorConsumer) {
BiConsumer<Throwable, Object> genericConsumer = errorConsumer;
return contextWrite(Context.of(
return contextWriteSkippingContextPropagation(Context.of(
OnNextFailureStrategy.KEY_ON_NEXT_ERROR_STRATEGY,
OnNextFailureStrategy.resume(genericConsumer)
));
Expand Down Expand Up @@ -3802,7 +3810,7 @@ public final <E extends Throwable> Mono<T> onErrorContinue(Predicate<E> errorPre
@SuppressWarnings("unchecked")
Predicate<Throwable> genericPredicate = (Predicate<Throwable>) errorPredicate;
BiConsumer<Throwable, Object> genericErrorConsumer = errorConsumer;
return contextWrite(Context.of(
return contextWriteSkippingContextPropagation(Context.of(
OnNextFailureStrategy.KEY_ON_NEXT_ERROR_STRATEGY,
OnNextFailureStrategy.resumeIf(genericPredicate, genericErrorConsumer)
));
Expand All @@ -3819,7 +3827,7 @@ public final <E extends Throwable> Mono<T> onErrorContinue(Predicate<E> errorPre
* was used downstream
*/
public final Mono<T> onErrorStop() {
return contextWrite(Context.of(
return contextWriteSkippingContextPropagation(Context.of(
OnNextFailureStrategy.KEY_ON_NEXT_ERROR_STRATEGY,
OnNextFailureStrategy.stop()));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,15 @@
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import io.micrometer.context.ContextRegistry;
import io.micrometer.context.ThreadLocalAccessor;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
Expand Down Expand Up @@ -2453,4 +2455,124 @@ void fluxToIterable() {
assertThat(value.get()).isEqualTo("present");
}
}

@Nested
class SpecialContextAlteringOperators {

// The cases here consider operators like doOnDiscard(), which underneath
// utilize contextWrite() for its purpose. They are special in that we use them
// internally and do not anticipate the registered keys to be corresponding to
// any ThreadLocal values. That expectation is reasonable in user facing code
// as we don't know what keys are used and whether a ThreadLocalAccessor is
// registered for these keys. Therefore, in specific cases that are internal to
// reactor-core, we can skip ThreadLocal restoration in fragments of the chain.

// Explanation of GET/SET operations on TL in the scenarios here:
// When going UP, we use the value "present".
// When going DOWN, we clear the value or restore a captured empty value.
//
// 1 x GET in block() with implicit context capture
//
// 1 x GET going UP from contextWrite (read current to restore later)
// + 2 x SET going UP from contextWrite + SET restoring current later
//
// 1 x GET going DOWN from contextWrite with subscription (read current)
// + 2 x SET going DOWN from contextWrite + SET restoring current later
//
// 1 x GET going UP to request (read current)
// + 2 x SET going UP from contextWrite + SET restoring current later
//
// 1 x GET going DOWN to deliver onComplete (read current)
// + 2 x SET going DOWN from contextWrite + SET restoring current later

@Test
void discardFlux() {
CountingThreadLocalAccessor accessor = new CountingThreadLocalAccessor();
ContextRegistry.getInstance().registerThreadLocalAccessor(accessor);

AtomicInteger tlPresent = new AtomicInteger();
AtomicInteger discards = new AtomicInteger();

Flux.just("a")
.doOnEach(signal -> {
if (CountingThreadLocalAccessor.TL.get().equals("present")) {
tlPresent.incrementAndGet();
}
})
.filter(s -> false)
.doOnDiscard(String.class, s -> discards.incrementAndGet())
.count()
.contextWrite(ctx -> ctx.put(CountingThreadLocalAccessor.KEY, "present"))
.block();

assertThat(tlPresent.get()).isEqualTo(2); // 1 x onNext + 1 x onComplete
assertThat(discards.get()).isEqualTo(1);
// 5 with doOnDiscard skipping TL restoration, 9 with restoring
assertThat(accessor.reads.get()).isEqualTo(5);
// 8 with doOnDiscard skipping TL restoration, 16 with restoring
assertThat(accessor.writes.get()).isEqualTo(8);

ContextRegistry.getInstance().removeThreadLocalAccessor(CountingThreadLocalAccessor.KEY);
}

@Test
void discardMono() {
CountingThreadLocalAccessor accessor = new CountingThreadLocalAccessor();
ContextRegistry.getInstance().registerThreadLocalAccessor(accessor);

AtomicInteger tlPresent = new AtomicInteger();
AtomicInteger discards = new AtomicInteger();

Mono.just("a")
.doOnEach(signal -> {
if (CountingThreadLocalAccessor.TL.get().equals("present")) {
tlPresent.incrementAndGet();
}
})
.filter(s -> false)
.doOnDiscard(String.class, s -> discards.incrementAndGet())
.contextWrite(ctx -> ctx.put(CountingThreadLocalAccessor.KEY, "present"))
.block();

assertThat(tlPresent.get()).isEqualTo(2); // 1 x onNext + 1 x onComplete
assertThat(discards.get()).isEqualTo(1);
// 5 with doOnDiscard skipping TL restoration, 9 with restoring
assertThat(accessor.reads.get()).isEqualTo(5);
// 8 with doOnDiscard skipping TL restoration, 16 with restoring
assertThat(accessor.writes.get()).isEqualTo(8);

ContextRegistry.getInstance().removeThreadLocalAccessor(CountingThreadLocalAccessor.KEY);
}
}

private static class CountingThreadLocalAccessor implements ThreadLocalAccessor<String> {
static final String KEY = "CTLA";
static final ThreadLocal<String> TL = new ThreadLocal<>();

AtomicInteger reads = new AtomicInteger();
AtomicInteger writes = new AtomicInteger();

@Override
public Object key() {
return KEY;
}

@Override
public String getValue() {
reads.incrementAndGet();
return TL.get();
}

@Override
public void setValue(String s) {
writes.incrementAndGet();
TL.set(s);
}

@Override
public void setValue() {
writes.incrementAndGet();
TL.remove();
}
}
}

0 comments on commit c575e04

Please sign in to comment.