diff --git a/jetty-core/jetty-websocket/jetty-websocket-jetty-api/src/main/java/org/eclipse/jetty/websocket/api/Callback.java b/jetty-core/jetty-websocket/jetty-websocket-jetty-api/src/main/java/org/eclipse/jetty/websocket/api/Callback.java index a07eca3e1258..719ff9695eb7 100644 --- a/jetty-core/jetty-websocket/jetty-websocket-jetty-api/src/main/java/org/eclipse/jetty/websocket/api/Callback.java +++ b/jetty-core/jetty-websocket/jetty-websocket-jetty-api/src/main/java/org/eclipse/jetty/websocket/api/Callback.java @@ -57,6 +57,34 @@ public void fail(Throwable x) }; } + /** + * Creates a nested callback that runs completed after + * completing the nested callback. + * + * @param callback The nested callback + * @param completed The completion to run after the nested callback is completed + * @return a new callback. + */ + static Callback from(Callback callback, Runnable completed) + { + return new Callback() + { + @Override + public void succeed() + { + callback.succeed(); + completed.run(); + } + + @Override + public void fail(Throwable x) + { + callback.fail(x); + completed.run(); + } + }; + } + /** *

Method to invoke to succeed the callback.

* diff --git a/jetty-core/jetty-websocket/jetty-websocket-jetty-api/src/main/java/org/eclipse/jetty/websocket/api/Frame.java b/jetty-core/jetty-websocket/jetty-websocket-jetty-api/src/main/java/org/eclipse/jetty/websocket/api/Frame.java index 0f5bd23923e9..4c8c3451777e 100644 --- a/jetty-core/jetty-websocket/jetty-websocket-jetty-api/src/main/java/org/eclipse/jetty/websocket/api/Frame.java +++ b/jetty-core/jetty-websocket/jetty-websocket-jetty-api/src/main/java/org/eclipse/jetty/websocket/api/Frame.java @@ -101,4 +101,111 @@ public String toString() boolean isRsv2(); boolean isRsv3(); + + class Wrapper implements Frame + { + private final Frame _frame; + + public Wrapper(Frame frame) + { + _frame = frame; + } + + @Override + public byte[] getMask() + { + return _frame.getMask(); + } + + @Override + public byte getOpCode() + { + return _frame.getOpCode(); + } + + @Override + public ByteBuffer getPayload() + { + return _frame.getPayload(); + } + + @Override + public int getPayloadLength() + { + return _frame.getPayloadLength(); + } + + @Override + public Type getType() + { + return _frame.getType(); + } + + @Override + public boolean hasPayload() + { + return _frame.hasPayload(); + } + + @Override + public boolean isFin() + { + return _frame.isFin(); + } + + @Override + public boolean isMasked() + { + return _frame.isMasked(); + } + + @Override + public boolean isRsv1() + { + return _frame.isRsv1(); + } + + @Override + public boolean isRsv2() + { + return _frame.isRsv2(); + } + + @Override + public boolean isRsv3() + { + return _frame.isRsv3(); + } + } + + static Frame copy(Frame frame) + { + ByteBuffer payloadCopy = copy(frame.getPayload()); + return new Frame.Wrapper(frame) + { + @Override + public ByteBuffer getPayload() + { + return payloadCopy; + } + + @Override + public int getPayloadLength() + { + return payloadCopy == null ? 0 : payloadCopy.remaining(); + } + }; + } + + private static ByteBuffer copy(ByteBuffer buffer) + { + if (buffer == null) + return null; + int p = buffer.position(); + ByteBuffer clone = buffer.isDirect() ? ByteBuffer.allocateDirect(buffer.remaining()) : ByteBuffer.allocate(buffer.remaining()); + clone.put(buffer); + clone.flip(); + buffer.position(p); + return clone; + } } diff --git a/jetty-core/jetty-websocket/jetty-websocket-jetty-common/src/main/java/org/eclipse/jetty/websocket/common/JettyWebSocketFrameHandler.java b/jetty-core/jetty-websocket/jetty-websocket-jetty-common/src/main/java/org/eclipse/jetty/websocket/common/JettyWebSocketFrameHandler.java index 902784fda5d5..5effe583994a 100644 --- a/jetty-core/jetty-websocket/jetty-websocket-jetty-common/src/main/java/org/eclipse/jetty/websocket/common/JettyWebSocketFrameHandler.java +++ b/jetty-core/jetty-websocket/jetty-websocket-jetty-common/src/main/java/org/eclipse/jetty/websocket/common/JettyWebSocketFrameHandler.java @@ -206,18 +206,60 @@ public void onFrame(Frame frame, Callback coreCallback) coreCallback.failed(new WebSocketException(endpointInstance.getClass().getSimpleName() + " FRAME method error: " + cause.getMessage(), cause)); return; } + + switch (frame.getOpCode()) + { + case OpCode.TEXT -> + { + if (textHandle == null) + autoDemand(); + } + case OpCode.BINARY -> + { + if (binaryHandle == null) + autoDemand(); + } + case OpCode.CONTINUATION -> + { + if (activeMessageSink == null) + autoDemand(); + } + case OpCode.PING -> + { + if (pingHandle == null) + autoDemand(); + } + case OpCode.PONG -> + { + if (pongHandle == null) + autoDemand(); + } + case OpCode.CLOSE -> + { + // Do nothing. + } + default -> + { + coreCallback.failed(new IllegalStateException()); + return; + } + }; } Callback.Completable eventCallback = new Callback.Completable(); switch (frame.getOpCode()) { - case OpCode.CLOSE -> onCloseFrame(frame, eventCallback); - case OpCode.PING -> onPingFrame(frame, eventCallback); - case OpCode.PONG -> onPongFrame(frame, eventCallback); case OpCode.TEXT -> onTextFrame(frame, eventCallback); case OpCode.BINARY -> onBinaryFrame(frame, eventCallback); case OpCode.CONTINUATION -> onContinuationFrame(frame, eventCallback); - default -> coreCallback.failed(new IllegalStateException()); + case OpCode.PING -> onPingFrame(frame, eventCallback); + case OpCode.PONG -> onPongFrame(frame, eventCallback); + case OpCode.CLOSE -> onCloseFrame(frame, eventCallback); + default -> + { + coreCallback.failed(new IllegalStateException()); + return; + } }; // Combine the callback from the frame handler and the event handler. @@ -315,6 +357,13 @@ private void onPingFrame(Frame frame, Callback callback) } else { + // If we have a frameHandler it takes responsibility for handling the ping and demanding. + if (frameHandle != null) + { + callback.succeeded(); + return; + } + // Automatically respond. getSession().sendPong(frame.getPayload(), new org.eclipse.jetty.websocket.api.Callback() { @@ -358,7 +407,10 @@ private void onPongFrame(Frame frame, Callback callback) } else { - internalDemand(); + // If we have a frameHandler it takes responsibility for handling the pong and demanding. + callback.succeeded(); + if (frameHandle == null) + internalDemand(); } } @@ -387,7 +439,8 @@ private void acceptFrame(Frame frame, Callback callback) if (activeMessageSink == null) { callback.succeeded(); - internalDemand(); + if (frameHandle == null) + internalDemand(); return; } diff --git a/jetty-core/jetty-websocket/jetty-websocket-jetty-tests/src/test/java/org/eclipse/jetty/websocket/tests/ExplicitDemandTest.java b/jetty-core/jetty-websocket/jetty-websocket-jetty-tests/src/test/java/org/eclipse/jetty/websocket/tests/ExplicitDemandTest.java index 2994945b7c91..b25eb087105d 100644 --- a/jetty-core/jetty-websocket/jetty-websocket-jetty-tests/src/test/java/org/eclipse/jetty/websocket/tests/ExplicitDemandTest.java +++ b/jetty-core/jetty-websocket/jetty-websocket-jetty-tests/src/test/java/org/eclipse/jetty/websocket/tests/ExplicitDemandTest.java @@ -72,12 +72,29 @@ public void onMessage(String message) throws IOException public static class ListenerSocket implements Session.Listener { final List frames = new CopyOnWriteArrayList<>(); + Session session; + + @Override + public void onWebSocketOpen(Session session) + { + this.session = session; + session.demand(); + } @Override public void onWebSocketFrame(Frame frame, Callback callback) { - frames.add(frame); + frames.add(Frame.copy(frame)); + + // Because no pingListener is registered, the frameListener is responsible for handling pings. + if (frame.getOpCode() == OpCode.PING) + { + session.sendPong(frame.getPayload(), Callback.from(callback, session::demand)); + return; + } + callback.succeed(); + session.demand(); } } @@ -109,27 +126,19 @@ public void onWebSocketFrame(Frame frame, Callback callback) if (frame.getOpCode() == OpCode.TEXT) textMessages.add(BufferUtil.toString(frame.getPayload())); callback.succeed(); + session.demand(); } } @WebSocket(autoDemand = false) public static class PingSocket extends ListenerSocket { - Session session; - - @Override - public void onWebSocketOpen(Session session) - { - this.session = session; - session.demand(); - } - @Override public void onWebSocketFrame(Frame frame, Callback callback) { - super.onWebSocketFrame(frame, callback); if (frame.getType() == Frame.Type.TEXT) session.sendPing(ByteBuffer.wrap("server-ping".getBytes(StandardCharsets.UTF_8)), Callback.NOOP); + super.onWebSocketFrame(frame, callback); } } diff --git a/jetty-core/jetty-websocket/jetty-websocket-jetty-tests/src/test/java/org/eclipse/jetty/websocket/tests/proxy/WebSocketProxyTest.java b/jetty-core/jetty-websocket/jetty-websocket-jetty-tests/src/test/java/org/eclipse/jetty/websocket/tests/proxy/WebSocketProxyTest.java index 62d8b8d100b4..e4725eb08438 100644 --- a/jetty-core/jetty-websocket/jetty-websocket-jetty-tests/src/test/java/org/eclipse/jetty/websocket/tests/proxy/WebSocketProxyTest.java +++ b/jetty-core/jetty-websocket/jetty-websocket-jetty-tests/src/test/java/org/eclipse/jetty/websocket/tests/proxy/WebSocketProxyTest.java @@ -347,10 +347,18 @@ public void onWebSocketFrame(Frame frame, Callback callback) { switch (frame.getOpCode()) { - case OpCode.PING -> pingMessages.add(BufferUtil.copy(frame.getPayload())); - case OpCode.PONG -> pongMessages.add(BufferUtil.copy(frame.getPayload())); + case OpCode.PING -> + { + pingMessages.add(BufferUtil.copy(frame.getPayload())); + session.sendPong(frame.getPayload(), callback); + } + case OpCode.PONG -> + { + pongMessages.add(BufferUtil.copy(frame.getPayload())); + callback.succeed(); + } + default -> callback.succeed(); } - callback.succeed(); } } diff --git a/jetty-core/jetty-websocket/jetty-websocket-jetty-tests/src/test/java/org/eclipse/jetty/websocket/tests/server/FrameListenerTest.java b/jetty-core/jetty-websocket/jetty-websocket-jetty-tests/src/test/java/org/eclipse/jetty/websocket/tests/server/FrameListenerTest.java index 6378deb4128e..9adc5a28813e 100644 --- a/jetty-core/jetty-websocket/jetty-websocket-jetty-tests/src/test/java/org/eclipse/jetty/websocket/tests/server/FrameListenerTest.java +++ b/jetty-core/jetty-websocket/jetty-websocket-jetty-tests/src/test/java/org/eclipse/jetty/websocket/tests/server/FrameListenerTest.java @@ -131,10 +131,12 @@ public static class FrameEndpoint implements Session.Listener { public CountDownLatch closeLatch = new CountDownLatch(1); public LinkedBlockingQueue frameEvents = new LinkedBlockingQueue<>(); + public Session session; @Override public void onWebSocketOpen(Session session) { + this.session = session; session.demand(); } @@ -147,6 +149,7 @@ public void onWebSocketFrame(Frame frame, Callback callback) BufferUtil.toUTF8String(frame.getPayload()), frame.getPayloadLength())); callback.succeed(); + session.demand(); } @Override