Skip to content

Commit

Permalink
Merge pull request #12342 from jetty/jetty-12.1.x-11307-websocket-fra…
Browse files Browse the repository at this point in the history
…meHandler-demand

Issue #11307 - Explicit demand control in WebSocket endpoints with only onWebSocketFrame
  • Loading branch information
lachlan-roberts authored Nov 15, 2024
2 parents dc6e7b9 + 8896b10 commit 4360d12
Show file tree
Hide file tree
Showing 10 changed files with 245 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,46 @@ public void fail(Throwable x)
};
}

/**
* Creates a nested callback that runs completed after
* completing the nested callback.
*
* @param callback The nested callback
* @param completed The completion to run after the nested callback is completed
* @return a new callback.
*/
static Callback from(Callback callback, Runnable completed)
{
return new Callback()
{
@Override
public void succeed()
{
try
{
callback.succeed();
}
finally
{
completed.run();
}
}

@Override
public void fail(Throwable x)
{
try
{
callback.fail(x);
}
finally
{
completed.run();
}
}
};
}

/**
* <p>Method to invoke to succeed the callback.</p>
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,4 +101,22 @@ public String toString()
boolean isRsv2();

boolean isRsv3();

default CloseStatus getCloseStatus()
{
return null;
}

record CloseStatus(int statusCode, String reason)
{
}

/**
* The effective opcode of the frame accounting for the CONTINUATION opcode.
* If the frame is a CONTINUATION frame for a TEXT message, this will return TEXT.
* If the frame is a CONTINUATION frame for a BINARY message, this will return BINARY.
* Otherwise, this will return the same opcode as the frame.
* @return the effective opcode of the frame.
*/
byte getEffectiveOpCode();
}
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ default void onWebSocketOpen(Session session)
* or data frames either BINARY or TEXT.</p>
*
* @param frame the received frame
* @param callback the callback to complete once the frame has been processed.
*/
default void onWebSocketFrame(Frame frame, Callback callback)
{
Expand Down Expand Up @@ -299,6 +300,7 @@ default void onWebSocketPartialText(String payload, boolean last)
* <p>A WebSocket BINARY message has been received.</p>
*
* @param payload the raw payload array received
* @param callback the callback to complete when the payload has been processed
*/
default void onWebSocketBinary(ByteBuffer payload, Callback callback)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,32 @@
import java.nio.ByteBuffer;

import org.eclipse.jetty.websocket.core.Frame;
import org.eclipse.jetty.websocket.core.OpCode;

public class JettyWebSocketFrame implements org.eclipse.jetty.websocket.api.Frame
{
private final Frame frame;
private final byte effectiveOpCode;

/**
* @param frame the core websocket {@link Frame} to wrap as a {@link org.eclipse.jetty.websocket.api.Frame}.
* @deprecated there is no alternative intended to publicly construct a {@link JettyWebSocketFrame}.
*/
@Deprecated(forRemoval = true, since = "12.1.0")
public JettyWebSocketFrame(Frame frame)
{
this(frame, frame.getOpCode());
}

/**
* @param frame the core websocket {@link Frame} to wrap as a Jetty API {@link org.eclipse.jetty.websocket.api.Frame}.
* @param effectiveOpCode the effective OpCode of the Frame, where any CONTINUATION should be replaced with the
* initial opcode of that websocket message.
*/
JettyWebSocketFrame(Frame frame, byte effectiveOpCode)
{
this.frame = frame;
this.effectiveOpCode = effectiveOpCode;
}

@Override
Expand Down Expand Up @@ -92,6 +110,21 @@ public boolean isRsv3()
return frame.isRsv3();
}

@Override
public byte getEffectiveOpCode()
{
return effectiveOpCode;
}

@Override
public CloseStatus getCloseStatus()
{
if (getOpCode() != OpCode.CLOSE)
return null;
org.eclipse.jetty.websocket.core.CloseStatus closeStatus = org.eclipse.jetty.websocket.core.CloseStatus.getCloseStatus(frame);
return new CloseStatus(closeStatus.getCode(), closeStatus.getReason());
}

@Override
public String toString()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import java.lang.reflect.InvocationTargetException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicBoolean;

import org.eclipse.jetty.util.BufferUtil;
Expand Down Expand Up @@ -69,6 +68,7 @@ public class JettyWebSocketFrameHandler implements FrameHandler
private MessageSink binarySink;
private MessageSink activeMessageSink;
private WebSocketSession session;
private byte messageType;

public JettyWebSocketFrameHandler(WebSocketContainer container, Object endpointInstance, JettyWebSocketFrameHandlerMetadata metadata)
{
Expand Down Expand Up @@ -193,44 +193,36 @@ private static MessageSink createMessageSink(Class<? extends MessageSink> sinkCl
@Override
public void onFrame(Frame frame, Callback coreCallback)
{
CompletableFuture<Void> frameCallback = null;
if (frame.getOpCode() == OpCode.TEXT || frame.getOpCode() == OpCode.BINARY)
messageType = frame.getOpCode();

if (frameHandle != null)
{
try
{
frameCallback = new org.eclipse.jetty.websocket.api.Callback.Completable();
frameHandle.invoke(new JettyWebSocketFrame(frame), frameCallback);
byte effectiveOpCode = frame.isDataFrame() ? messageType : frame.getOpCode();
frameHandle.invoke(new JettyWebSocketFrame(frame, effectiveOpCode),
org.eclipse.jetty.websocket.api.Callback.from(coreCallback::succeeded, coreCallback::failed));
}
catch (Throwable cause)
{
coreCallback.failed(new WebSocketException(endpointInstance.getClass().getSimpleName() + " FRAME method error: " + cause.getMessage(), cause));
return;
}

autoDemand();
return;
}

Callback.Completable eventCallback = new Callback.Completable();
switch (frame.getOpCode())
{
case OpCode.CLOSE -> onCloseFrame(frame, eventCallback);
case OpCode.PING -> onPingFrame(frame, eventCallback);
case OpCode.PONG -> onPongFrame(frame, eventCallback);
case OpCode.TEXT -> onTextFrame(frame, eventCallback);
case OpCode.BINARY -> onBinaryFrame(frame, eventCallback);
case OpCode.CONTINUATION -> onContinuationFrame(frame, eventCallback);
case OpCode.TEXT -> onTextFrame(frame, coreCallback);
case OpCode.BINARY -> onBinaryFrame(frame, coreCallback);
case OpCode.CONTINUATION -> onContinuationFrame(frame, coreCallback);
case OpCode.PING -> onPingFrame(frame, coreCallback);
case OpCode.PONG -> onPongFrame(frame, coreCallback);
case OpCode.CLOSE -> onCloseFrame(frame, coreCallback);
default -> coreCallback.failed(new IllegalStateException());
};

// Combine the callback from the frame handler and the event handler.
CompletableFuture<Void> callback = eventCallback;
if (frameCallback != null)
callback = frameCallback.thenCompose(ignored -> eventCallback);
callback.whenComplete((r, x) ->
{
if (x == null)
coreCallback.succeeded();
else
coreCallback.failed(x);
});
}
}

@Override
Expand Down Expand Up @@ -358,6 +350,7 @@ private void onPongFrame(Frame frame, Callback callback)
}
else
{
callback.succeeded();
internalDemand();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ public void setAutoDemand(boolean autoDemand)
public void setBinaryHandle(Class<? extends MessageSink> sinkClass, MethodHandle binary, Object origin)
{
assertNotSet(this.binaryHandle, "BINARY Handler", origin);
assertNotSet(this.frameHandle, "FRAME Handler", origin);
this.binaryHandle = binary;
this.binarySink = sinkClass;
}
Expand Down Expand Up @@ -85,6 +86,10 @@ public MethodHandle getErrorHandle()
public void setFrameHandle(MethodHandle frame, Object origin)
{
assertNotSet(this.frameHandle, "FRAME Handler", origin);
assertNotSet(this.textHandle, "TEXT Handler", origin);
assertNotSet(this.binaryHandle, "BINARY Handler", origin);
assertNotSet(this.pingHandle, "PING Handler", origin);
assertNotSet(this.pongHandle, "PONG Handler", origin);
this.frameHandle = frame;
}

Expand All @@ -107,6 +112,7 @@ public MethodHandle getOpenHandle()
public void setPingHandle(MethodHandle ping, Object origin)
{
assertNotSet(this.pingHandle, "PING Handler", origin);
assertNotSet(this.frameHandle, "FRAME Handler", origin);
this.pingHandle = ping;
}

Expand All @@ -118,6 +124,7 @@ public MethodHandle getPingHandle()
public void setPongHandle(MethodHandle pong, Object origin)
{
assertNotSet(this.pongHandle, "PONG Handler", origin);
assertNotSet(this.frameHandle, "FRAME Handler", origin);
this.pongHandle = pong;
}

Expand All @@ -129,6 +136,7 @@ public MethodHandle getPongHandle()
public void setTextHandle(Class<? extends MessageSink> sinkClass, MethodHandle text, Object origin)
{
assertNotSet(this.textHandle, "TEXT Handler", origin);
assertNotSet(this.frameHandle, "FRAME Handler", origin);
this.textHandle = text;
this.textSink = sinkClass;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

Expand Down Expand Up @@ -72,12 +73,36 @@ public void onMessage(String message) throws IOException
public static class ListenerSocket implements Session.Listener
{
final List<Frame> frames = new CopyOnWriteArrayList<>();
final List<Callback> callbacks = new CopyOnWriteArrayList<>();
Session session;

@Override
public void onWebSocketOpen(Session session)
{
this.session = session;
session.demand();
}

@Override
public void onWebSocketFrame(Frame frame, Callback callback)
{
frames.add(frame);
callback.succeed();
callbacks.add(callback);

// Because no pingListener is registered, the frameListener is responsible for handling pings.
if (frame.getOpCode() == OpCode.PING)
{
session.sendPong(frame.getPayload(), Callback.from(session::demand, callback::fail));
return;
}
else if (frame.getOpCode() == OpCode.CLOSE)
{
Frame.CloseStatus closeStatus = frame.getCloseStatus();
session.close(closeStatus.statusCode(), closeStatus.reason(), Callback.NOOP);
return;
}

session.demand();
}
}

Expand Down Expand Up @@ -109,27 +134,19 @@ public void onWebSocketFrame(Frame frame, Callback callback)
if (frame.getOpCode() == OpCode.TEXT)
textMessages.add(BufferUtil.toString(frame.getPayload()));
callback.succeed();
session.demand();
}
}

@WebSocket(autoDemand = false)
public static class PingSocket extends ListenerSocket
{
Session session;

@Override
public void onWebSocketOpen(Session session)
{
this.session = session;
session.demand();
}

@Override
public void onWebSocketFrame(Frame frame, Callback callback)
{
super.onWebSocketFrame(frame, callback);
if (frame.getType() == Frame.Type.TEXT)
session.sendPing(ByteBuffer.wrap("server-ping".getBytes(StandardCharsets.UTF_8)), Callback.NOOP);
super.onWebSocketFrame(frame, callback);
}
}

Expand Down Expand Up @@ -217,13 +234,23 @@ public void testNoAutoDemand() throws Exception
Frame frame0 = listenerSocket.frames.get(0);
assertThat(frame0.getType(), is(Frame.Type.PONG));
assertThat(StandardCharsets.UTF_8.decode(frame0.getPayload()).toString(), is("ping-0"));
Callback callback0 = listenerSocket.callbacks.get(0);
assertNotNull(callback0);
callback0.succeed();

Frame frame1 = listenerSocket.frames.get(1);
assertThat(frame1.getType(), is(Frame.Type.PONG));
assertThat(StandardCharsets.UTF_8.decode(frame1.getPayload()).toString(), is("ping-1"));
Callback callback1 = listenerSocket.callbacks.get(1);
assertNotNull(callback1);
callback1.succeed();

session.close();
await().atMost(5, TimeUnit.SECONDS).until(listenerSocket.frames::size, is(3));
assertThat(listenerSocket.frames.get(2).getType(), is(Frame.Type.CLOSE));
Callback closeCallback = listenerSocket.callbacks.get(2);
assertNotNull(closeCallback);
closeCallback.succeed();
}

@Test
Expand Down
Loading

0 comments on commit 4360d12

Please sign in to comment.