Skip to content

Commit

Permalink
Issue #11307 - Explicit demand control in WebSocket endpoints with on…
Browse files Browse the repository at this point in the history
…ly onWebSocketFrame

Signed-off-by: Lachlan Roberts <[email protected]>
  • Loading branch information
lachlan-roberts committed Oct 3, 2024
1 parent 0553fe3 commit fbd66a5
Show file tree
Hide file tree
Showing 6 changed files with 228 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,34 @@ public void fail(Throwable x)
};
}

/**
* Creates a nested callback that runs completed after
* completing the nested callback.
*
* @param callback The nested callback
* @param completed The completion to run after the nested callback is completed
* @return a new callback.
*/
static Callback from(Callback callback, Runnable completed)
{
return new Callback()
{
@Override
public void succeed()
{
callback.succeed();
completed.run();
}

@Override
public void fail(Throwable x)
{
callback.fail(x);
completed.run();
}
};
}

/**
* <p>Method to invoke to succeed the callback.</p>
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,4 +101,111 @@ public String toString()
boolean isRsv2();

boolean isRsv3();

class Wrapper implements Frame
{
private final Frame _frame;

public Wrapper(Frame frame)
{
_frame = frame;
}

@Override
public byte[] getMask()
{
return _frame.getMask();
}

@Override
public byte getOpCode()
{
return _frame.getOpCode();
}

@Override
public ByteBuffer getPayload()
{
return _frame.getPayload();
}

@Override
public int getPayloadLength()
{
return _frame.getPayloadLength();
}

@Override
public Type getType()
{
return _frame.getType();
}

@Override
public boolean hasPayload()
{
return _frame.hasPayload();
}

@Override
public boolean isFin()
{
return _frame.isFin();
}

@Override
public boolean isMasked()
{
return _frame.isMasked();
}

@Override
public boolean isRsv1()
{
return _frame.isRsv1();
}

@Override
public boolean isRsv2()
{
return _frame.isRsv2();
}

@Override
public boolean isRsv3()
{
return _frame.isRsv3();
}
}

static Frame copy(Frame frame)
{
ByteBuffer payloadCopy = copy(frame.getPayload());
return new Frame.Wrapper(frame)
{
@Override
public ByteBuffer getPayload()
{
return payloadCopy;
}

@Override
public int getPayloadLength()
{
return payloadCopy == null ? 0 : payloadCopy.remaining();
}
};
}

private static ByteBuffer copy(ByteBuffer buffer)
{
if (buffer == null)
return null;
int p = buffer.position();
ByteBuffer clone = buffer.isDirect() ? ByteBuffer.allocateDirect(buffer.remaining()) : ByteBuffer.allocate(buffer.remaining());
clone.put(buffer);
clone.flip();
buffer.position(p);
return clone;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -206,18 +206,60 @@ public void onFrame(Frame frame, Callback coreCallback)
coreCallback.failed(new WebSocketException(endpointInstance.getClass().getSimpleName() + " FRAME method error: " + cause.getMessage(), cause));
return;
}

switch (frame.getOpCode())
{
case OpCode.TEXT ->
{
if (textHandle == null)
autoDemand();
}
case OpCode.BINARY ->
{
if (binaryHandle == null)
autoDemand();
}
case OpCode.CONTINUATION ->
{
if (activeMessageSink == null)
autoDemand();
}
case OpCode.PING ->
{
if (pingHandle == null)
autoDemand();
}
case OpCode.PONG ->
{
if (pongHandle == null)
autoDemand();
}
case OpCode.CLOSE ->
{
// Do nothing.
}
default ->
{
coreCallback.failed(new IllegalStateException());
return;
}
};
}

Callback.Completable eventCallback = new Callback.Completable();
switch (frame.getOpCode())
{
case OpCode.CLOSE -> onCloseFrame(frame, eventCallback);
case OpCode.PING -> onPingFrame(frame, eventCallback);
case OpCode.PONG -> onPongFrame(frame, eventCallback);
case OpCode.TEXT -> onTextFrame(frame, eventCallback);
case OpCode.BINARY -> onBinaryFrame(frame, eventCallback);
case OpCode.CONTINUATION -> onContinuationFrame(frame, eventCallback);
default -> coreCallback.failed(new IllegalStateException());
case OpCode.PING -> onPingFrame(frame, eventCallback);
case OpCode.PONG -> onPongFrame(frame, eventCallback);
case OpCode.CLOSE -> onCloseFrame(frame, eventCallback);
default ->
{
coreCallback.failed(new IllegalStateException());
return;
}
};

// Combine the callback from the frame handler and the event handler.
Expand Down Expand Up @@ -315,6 +357,13 @@ private void onPingFrame(Frame frame, Callback callback)
}
else
{
// If we have a frameHandler it takes responsibility for handling the ping and demanding.
if (frameHandle != null)
{
callback.succeeded();
return;
}

// Automatically respond.
getSession().sendPong(frame.getPayload(), new org.eclipse.jetty.websocket.api.Callback()
{
Expand Down Expand Up @@ -358,7 +407,10 @@ private void onPongFrame(Frame frame, Callback callback)
}
else
{
internalDemand();
// If we have a frameHandler it takes responsibility for handling the pong and demanding.
callback.succeeded();
if (frameHandle == null)
internalDemand();
}
}

Expand Down Expand Up @@ -387,7 +439,8 @@ private void acceptFrame(Frame frame, Callback callback)
if (activeMessageSink == null)
{
callback.succeeded();
internalDemand();
if (frameHandle == null)
internalDemand();
return;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,29 @@ public void onMessage(String message) throws IOException
public static class ListenerSocket implements Session.Listener
{
final List<Frame> frames = new CopyOnWriteArrayList<>();
Session session;

@Override
public void onWebSocketOpen(Session session)
{
this.session = session;
session.demand();
}

@Override
public void onWebSocketFrame(Frame frame, Callback callback)
{
frames.add(frame);
frames.add(Frame.copy(frame));

// Because no pingListener is registered, the frameListener is responsible for handling pings.
if (frame.getOpCode() == OpCode.PING)
{
session.sendPong(frame.getPayload(), Callback.from(callback, session::demand));
return;
}

callback.succeed();
session.demand();
}
}

Expand Down Expand Up @@ -109,27 +126,19 @@ public void onWebSocketFrame(Frame frame, Callback callback)
if (frame.getOpCode() == OpCode.TEXT)
textMessages.add(BufferUtil.toString(frame.getPayload()));
callback.succeed();
session.demand();
}
}

@WebSocket(autoDemand = false)
public static class PingSocket extends ListenerSocket
{
Session session;

@Override
public void onWebSocketOpen(Session session)
{
this.session = session;
session.demand();
}

@Override
public void onWebSocketFrame(Frame frame, Callback callback)
{
super.onWebSocketFrame(frame, callback);
if (frame.getType() == Frame.Type.TEXT)
session.sendPing(ByteBuffer.wrap("server-ping".getBytes(StandardCharsets.UTF_8)), Callback.NOOP);
super.onWebSocketFrame(frame, callback);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -347,10 +347,18 @@ public void onWebSocketFrame(Frame frame, Callback callback)
{
switch (frame.getOpCode())
{
case OpCode.PING -> pingMessages.add(BufferUtil.copy(frame.getPayload()));
case OpCode.PONG -> pongMessages.add(BufferUtil.copy(frame.getPayload()));
case OpCode.PING ->
{
pingMessages.add(BufferUtil.copy(frame.getPayload()));
session.sendPong(frame.getPayload(), callback);
}
case OpCode.PONG ->
{
pongMessages.add(BufferUtil.copy(frame.getPayload()));
callback.succeed();
}
default -> callback.succeed();
}
callback.succeed();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,12 @@ public static class FrameEndpoint implements Session.Listener
{
public CountDownLatch closeLatch = new CountDownLatch(1);
public LinkedBlockingQueue<String> frameEvents = new LinkedBlockingQueue<>();
public Session session;

@Override
public void onWebSocketOpen(Session session)
{
this.session = session;
session.demand();
}

Expand All @@ -147,6 +149,7 @@ public void onWebSocketFrame(Frame frame, Callback callback)
BufferUtil.toUTF8String(frame.getPayload()),
frame.getPayloadLength()));
callback.succeed();
session.demand();
}

@Override
Expand Down

0 comments on commit fbd66a5

Please sign in to comment.