From 4781d3193b3304cf6a30b495cd512dcb580026c6 Mon Sep 17 00:00:00 2001 From: Mat Trudel Date: Fri, 29 Nov 2024 11:57:24 -0500 Subject: [PATCH] Use inform to send websocket upgrades instead of a regular sending (#428) * Allow inform/3 to receive keyword lists * Use inform to send websocket upgrades instead of a regular send with a 101 status --- lib/bandit/adapter.ex | 2 ++ lib/bandit/websocket/handshake.ex | 24 ++++++++++--------- test/bandit/http1/request_test.exs | 2 +- test/bandit/http2/plug_test.exs | 2 +- .../bandit/websocket/http1_handshake_test.exs | 19 +++++++++++++++ test/bandit/websocket/upgrade_test.exs | 20 ++++++++++++---- test/support/simple_websocket_client.ex | 4 +--- 7 files changed, 53 insertions(+), 20 deletions(-) diff --git a/lib/bandit/adapter.ex b/lib/bandit/adapter.ex index 891ce22f..5fcbf183 100644 --- a/lib/bandit/adapter.ex +++ b/lib/bandit/adapter.ex @@ -214,6 +214,8 @@ defmodule Bandit.Adapter do if get_http_protocol(adapter) == :"HTTP/1.0" do {:error, :not_supported} else + # inform/3 is unique in that headers comes in as a keyword list + headers = Enum.map(headers, fn {header, value} -> {to_string(header), value} end) {:ok, send_headers(adapter, status, headers, :inform)} end end diff --git a/lib/bandit/websocket/handshake.ex b/lib/bandit/websocket/handshake.ex index c3f4dc55..ac6794a5 100644 --- a/lib/bandit/websocket/handshake.ex +++ b/lib/bandit/websocket/handshake.ex @@ -69,19 +69,21 @@ defmodule Bandit.WebSocket.Handshake do hashed_key = :crypto.hash(:sha, concatenated_key) server_key = Base.encode64(hashed_key) - conn - |> resp(101, "") - |> put_resp_header("upgrade", "websocket") - |> put_resp_header("connection", "Upgrade") - |> put_resp_header("sec-websocket-accept", server_key) - |> put_websocket_extension_header(extensions) - |> send_resp() + headers = + [ + {:upgrade, "websocket"}, + {:connection, "Upgrade"}, + {:"sec-websocket-accept", server_key} + ] ++ + websocket_extension_header(extensions) + + inform(conn, 101, headers) end - @spec put_websocket_extension_header(Plug.Conn.t(), extensions()) :: Plug.Conn.t() - defp put_websocket_extension_header(conn, []), do: conn + @spec websocket_extension_header(extensions()) :: keyword() + defp websocket_extension_header([]), do: [] - defp put_websocket_extension_header(conn, extensions) do + defp websocket_extension_header(extensions) do extensions = extensions |> Enum.map_join(",", fn {extension, params} -> @@ -97,6 +99,6 @@ defmodule Bandit.WebSocket.Handshake do |> Enum.join(";") end) - put_resp_header(conn, "sec-websocket-extensions", extensions) + [{:"sec-websocket-extensions", extensions}] end end diff --git a/test/bandit/http1/request_test.exs b/test/bandit/http1/request_test.exs index fa0f08c4..cc03a5de 100644 --- a/test/bandit/http1/request_test.exs +++ b/test/bandit/http1/request_test.exs @@ -1809,7 +1809,7 @@ defmodule HTTP1RequestTest do end def send_inform(conn) do - conn = conn |> inform(100, [{"x-from", "inform"}]) + conn = conn |> inform(100, [{:"x-from", "inform"}]) conn |> send_resp(200, "Informer") end diff --git a/test/bandit/http2/plug_test.exs b/test/bandit/http2/plug_test.exs index 03edfc50..421a7a2a 100644 --- a/test/bandit/http2/plug_test.exs +++ b/test/bandit/http2/plug_test.exs @@ -699,7 +699,7 @@ defmodule HTTP2PlugTest do end def send_inform(conn) do - conn = conn |> inform(100, [{"x-from", "inform"}]) + conn = conn |> inform(100, [{:"x-from", "inform"}]) conn |> send_resp(200, "Informer") end diff --git a/test/bandit/websocket/http1_handshake_test.exs b/test/bandit/websocket/http1_handshake_test.exs index 60b0736a..650ddad8 100644 --- a/test/bandit/websocket/http1_handshake_test.exs +++ b/test/bandit/websocket/http1_handshake_test.exs @@ -33,6 +33,25 @@ defmodule WebSocketHTTP1HandshakeTest do assert Keyword.get(headers, :"sec-websocket-accept") == "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=" end + test "does not set content-encoding headers", context do + client = SimpleHTTP1Client.tcp_client(context) + + SimpleHTTP1Client.send(client, "GET", "/", [ + "Host: server.example.com", + "Accept-Encoding: deflate", + "Upgrade: WeBsOcKeT", + "Connection: UpGrAdE", + "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==", + "Sec-WebSocket-Version: 13" + ]) + + assert {:ok, "101 Switching Protocols", headers, <<>>} = + SimpleHTTP1Client.recv_reply(client) + + assert Keyword.get(headers, :"content-encoding") == nil + assert Keyword.get(headers, :vary) == nil + end + test "negotiates permessage-deflate if so configured", context do client = SimpleWebSocketClient.tcp_client(context) diff --git a/test/bandit/websocket/upgrade_test.exs b/test/bandit/websocket/upgrade_test.exs index 91202841..9937a3a7 100644 --- a/test/bandit/websocket/upgrade_test.exs +++ b/test/bandit/websocket/upgrade_test.exs @@ -50,6 +50,21 @@ defmodule WebSocketUpgradeTest do assert_in_delta now, then + 250, 50 end + test "upgrade responses do not include content-encoding headers", context do + client = SimpleWebSocketClient.tcp_client(context) + SimpleWebSocketClient.http1_handshake(client, UpgradeWebSock, timeout: "250") + + SimpleWebSocketClient.send_text_frame(client, "") + {:ok, result} = SimpleWebSocketClient.recv_text_frame(client) + assert result == inspect([:upgrade, :init]) + + # Ensure that the passed timeout was recognized + then = System.monotonic_time(:millisecond) + assert_receive :timeout, 500 + now = System.monotonic_time(:millisecond) + assert_in_delta now, then + 250, 50 + end + defmodule MyNoopWebSock do use NoopWebSock end @@ -69,10 +84,7 @@ defmodule WebSocketUpgradeTest do %{ monotonic_time: integer(), duration: integer(), - req_header_end_time: integer(), - resp_body_bytes: 0, - resp_start_time: integer(), - resp_end_time: integer() + req_header_end_time: integer() }, %{ connection_telemetry_span_context: reference(), diff --git a/test/support/simple_websocket_client.ex b/test/support/simple_websocket_client.ex index c15c23bb..8b01a0a5 100644 --- a/test/support/simple_websocket_client.ex +++ b/test/support/simple_websocket_client.ex @@ -24,13 +24,11 @@ defmodule SimpleWebSocketClient do ) # Because we don't want to consume any more than our headers, we can't use SimpleHTTP1Client - {:ok, response} = Transport.recv(client, 239) + {:ok, response} = Transport.recv(client, 164) [ "HTTP/1.1 101 Switching Protocols", "date: " <> _date, - "vary: accept-encoding", - "cache-control: max-age=0, private, must-revalidate", "upgrade: websocket", "connection: Upgrade", "sec-websocket-accept: s3pPLMBiTxaQ9kYGzzhZRbK\+xOo=",