Skip to content

Commit

Permalink
Fix reentrant subscribe in StreamingNettyByteBody (#11051)
Browse files Browse the repository at this point in the history
If a subscriber to a split streaming body writes a response when reading data from the body, this can ultimately lead to another split of the same body being closed. Closing the body in turn is a subscribe operation, which must not happen during a read in a reentrant fashion. This would lead to a failure in the assertion that guards against such reentrant operations (`assert !working;`), and various downstream issues like buffer leaks. In particular, `MaxRequestSizeSpec` was affected by this bug occasionally.

This patch replaces the use of EventLoopFlow with more suitable code. In particular, EventLoopFlow does not support reentrant or concurrent calls, it only ensures serialization. The new logic supports reentrant or concurrent calls and still ensures serialization where it matters.

The new test does not work yet due to netty/netty#13730 . This PR is a draft until that patch is released.
  • Loading branch information
yawkat authored Sep 9, 2024
1 parent 2f1dc37 commit 7b60047
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import io.micronaut.http.body.CloseableByteBody;
import io.micronaut.http.exceptions.BufferLengthExceededException;
import io.micronaut.http.exceptions.ContentLengthExceededException;
import io.micronaut.http.netty.EventLoopFlow;
import io.micronaut.http.netty.PublisherAsBlocking;
import io.micronaut.http.netty.PublisherAsStream;
import io.netty.buffer.ByteBuf;
Expand Down Expand Up @@ -57,14 +56,30 @@
@Internal
public final class StreamingNettyByteBody extends NettyByteBody implements CloseableByteBody {
private final SharedBuffer sharedBuffer;
/**
* We have reserve, subscribe, and add calls in {@link SharedBuffer} that all modify the same
* data structures. They can all happen concurrently and must be moved to the event loop. We
* also need to ensure that a reserve and associated subscribe stay serialized
* ({@link io.micronaut.http.netty.EventLoopFlow} semantics). But because of the potential
* concurrency, we actually need stronger semantics than
* {@link io.micronaut.http.netty.EventLoopFlow}.
* <p>
* The solution is to use the old {@link EventLoop#inEventLoop()} + {@link EventLoop#execute}
* pattern. Serialization semantics for reserve to subscribe are guaranteed using this field:
* If the reserve call is delayed, this field is {@code true}, and the subscribe call will also
* be delayed. This approach is possible because we only need to serialize a single reserve
* with a single subscribe.
*/
private final boolean forceDelaySubscribe;
private BufferConsumer.Upstream upstream;

public StreamingNettyByteBody(SharedBuffer sharedBuffer) {
this(sharedBuffer, sharedBuffer.rootUpstream);
this(sharedBuffer, false, sharedBuffer.rootUpstream);
}

private StreamingNettyByteBody(SharedBuffer sharedBuffer, BufferConsumer.Upstream upstream) {
private StreamingNettyByteBody(SharedBuffer sharedBuffer, boolean forceDelaySubscribe, BufferConsumer.Upstream upstream) {
this.sharedBuffer = sharedBuffer;
this.forceDelaySubscribe = forceDelaySubscribe;
this.upstream = upstream;
}

Expand All @@ -74,7 +89,7 @@ public BufferConsumer.Upstream primary(BufferConsumer primary) {
failClaim();
}
this.upstream = null;
sharedBuffer.subscribe(primary, upstream);
sharedBuffer.subscribe(primary, upstream, forceDelaySubscribe);
return upstream;
}

Expand All @@ -86,8 +101,8 @@ public BufferConsumer.Upstream primary(BufferConsumer primary) {
}
UpstreamBalancer.UpstreamPair pair = UpstreamBalancer.balancer(upstream, backpressureMode);
this.upstream = pair.left();
this.sharedBuffer.reserve();
return new StreamingNettyByteBody(sharedBuffer, pair.right());
boolean forceDelaySubscribe = this.sharedBuffer.reserve();
return new StreamingNettyByteBody(sharedBuffer, forceDelaySubscribe, pair.right());
}

@Override
Expand Down Expand Up @@ -163,7 +178,7 @@ public void error(Throwable e) {
this.upstream = null;
upstream.start();
upstream.onBytesConsumed(Long.MAX_VALUE);
return sharedBuffer.subscribeFull(upstream).map(AvailableNettyByteBody::new);
return sharedBuffer.subscribeFull(upstream, forceDelaySubscribe).map(AvailableNettyByteBody::new);
}

@Override
Expand All @@ -176,14 +191,14 @@ public void close() {
upstream.allowDiscard();
upstream.disregardBackpressure();
upstream.start();
sharedBuffer.subscribe(null, upstream);
sharedBuffer.subscribe(null, upstream, forceDelaySubscribe);
}

/**
* This class buffers input data and distributes it to multiple {@link StreamingNettyByteBody}
* instances.
* <p>Thread safety: The {@link BufferConsumer} methods <i>must</i> only be called from one
* thread, the {@link #eventLoopFlow} thread. The other methods (subscribe, reserve) can be
* thread, the {@link #eventLoop} thread. The other methods (subscribe, reserve) can be
* called from any thread.
*/
public static final class SharedBuffer implements BufferConsumer {
Expand All @@ -193,7 +208,7 @@ public static final class SharedBuffer implements BufferConsumer {
@Nullable
private final ResourceLeakTracker<SharedBuffer> tracker = LEAK_DETECTOR.get().track(this);

private final EventLoopFlow eventLoopFlow;
private final EventLoop eventLoop;
private final BodySizeLimits limits;
/**
* Upstream of all subscribers. This is only used to cancel incoming data if the max
Expand Down Expand Up @@ -230,6 +245,11 @@ public static final class SharedBuffer implements BufferConsumer {
* in a reentrant fashion.
*/
private boolean working = false;
/**
* {@code true} during {@link #add(ByteBuf)} to avoid reentrant subscribe or reserve calls.
* Field must only be accessed on the event loop.
*/
private boolean adding = false;
/**
* Number of bytes received so far.
*/
Expand All @@ -242,7 +262,7 @@ public static final class SharedBuffer implements BufferConsumer {
private volatile long expectedLength = -1;

public SharedBuffer(EventLoop loop, BodySizeLimits limits, Upstream rootUpstream) {
this.eventLoopFlow = new EventLoopFlow(loop);
this.eventLoop = loop;
this.limits = limits;
this.rootUpstream = rootUpstream;
}
Expand Down Expand Up @@ -274,9 +294,13 @@ public void setExpectedLength(long length) {
this.expectedLength = length;
}

void reserve() {
if (eventLoopFlow.executeNow(this::reserve0)) {
boolean reserve() {
if (eventLoop.inEventLoop() && !adding) {
reserve0();
return false;
} else {
eventLoop.execute(this::reserve0);
return true;
}
}

Expand All @@ -295,10 +319,13 @@ private void reserve0() {
*
* @param subscriber The subscriber to add. Can be {@code null}, then the bytes will just be discarded
* @param specificUpstream The upstream for the subscriber. This is used to call allowDiscard if there was an error
* @param forceDelay Whether to require an {@link EventLoop#execute} call to ensure serialization with previous {@link #reserve()} call
*/
void subscribe(@Nullable BufferConsumer subscriber, Upstream specificUpstream) {
if (eventLoopFlow.executeNow(() -> subscribe0(subscriber, specificUpstream))) {
void subscribe(@Nullable BufferConsumer subscriber, Upstream specificUpstream, boolean forceDelay) {
if (!forceDelay && eventLoop.inEventLoop() && !adding) {
subscribe0(subscriber, specificUpstream);
} else {
eventLoop.execute(() -> subscribe0(subscriber, specificUpstream));
}
}

Expand Down Expand Up @@ -354,16 +381,18 @@ private void subscribe0(@Nullable BufferConsumer subscriber, Upstream specificUp
* body.
*
* @param specificUpstream The upstream for the subscriber. This is used to call allowDiscard if there was an error
* @param forceDelay Whether to require an {@link EventLoop#execute} call to ensure serialization with previous {@link #reserve()} call
* @return A flow that will complete when all data has arrived, with a buffer containing that data
*/
ExecutionFlow<ByteBuf> subscribeFull(Upstream specificUpstream) {
ExecutionFlow<ByteBuf> subscribeFull(Upstream specificUpstream, boolean forceDelay) {
DelayedExecutionFlow<ByteBuf> asyncFlow = DelayedExecutionFlow.create();
if (eventLoopFlow.executeNow(() -> {
ExecutionFlow<ByteBuf> res = subscribeFull0(asyncFlow, specificUpstream, false);
assert res == asyncFlow;
})) {
if (!forceDelay && eventLoop.inEventLoop() && !adding) {
return subscribeFull0(asyncFlow, specificUpstream, true);
} else {
eventLoop.execute(() -> {
ExecutionFlow<ByteBuf> res = subscribeFull0(asyncFlow, specificUpstream, false);
assert res == asyncFlow;
});
return asyncFlow;
}
}
Expand Down Expand Up @@ -445,6 +474,7 @@ public void add(ByteBuf buf) {
buf.release();
return;
}
adding = true;
// calculate the new total length
long newLength = lengthSoFar + buf.readableBytes();
lengthSoFar = newLength;
Expand All @@ -453,6 +483,7 @@ public void add(ByteBuf buf) {
buf.release();
error(new ContentLengthExceededException(limits.maxBodySize(), newLength));
rootUpstream.allowDiscard();
adding = false;
return;
}

Expand Down Expand Up @@ -486,6 +517,7 @@ public void add(ByteBuf buf) {
} else {
buf.release();
}
adding = false;
working = false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package io.micronaut.http.server.netty.handler

import io.micronaut.core.io.buffer.ByteBuffer
import io.micronaut.http.body.AvailableByteBody
import io.micronaut.http.body.ByteBody
import io.micronaut.http.body.CloseableAvailableByteBody
import io.micronaut.http.body.CloseableByteBody
import io.netty.buffer.ByteBuf
Expand Down Expand Up @@ -604,6 +605,38 @@ class PipeliningServerHandlerSpec extends Specification {
unwritten == 0
}

def 'reentrant close'() {
given:
def resp = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK)
resp.headers().add(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED)
def ch = new EmbeddedChannel(new PipeliningServerHandler(new RequestHandler() {
@Override
void accept(ChannelHandlerContext ctx, HttpRequest request, CloseableByteBody body, OutboundAccess outboundAccess) {
def split = body.split(ByteBody.SplitBackpressureMode.FASTEST)
Flux.from(split.toByteArrayPublisher())
.subscribe {
body.close()
outboundAccess.writeFull(resp)
}
}

@Override
void handleUnboundError(Throwable cause) {
cause.printStackTrace()
}
}))


def request = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/")
request.headers().add(HttpHeaderNames.CONTENT_LENGTH, 3)
when:
ch.writeInbound(request)
ch.writeInbound(new DefaultLastHttpContent(Unpooled.copiedBuffer("foo", StandardCharsets.UTF_8)))

then:
ch.checkException()
}

static class MonitorHandler extends ChannelOutboundHandlerAdapter {
int flush = 0
int read = 0
Expand Down

0 comments on commit 7b60047

Please sign in to comment.