diff --git a/lib/bandit/adapter.ex b/lib/bandit/adapter.ex index 5fcbf183..d53787ac 100644 --- a/lib/bandit/adapter.ex +++ b/lib/bandit/adapter.ex @@ -13,6 +13,7 @@ defmodule Bandit.Adapter do method: nil, status: nil, content_encoding: nil, + compression_context: nil, upgrade: nil, metrics: %{}, opts: [] @@ -24,6 +25,7 @@ defmodule Bandit.Adapter do method: Plug.Conn.method() | nil, status: Plug.Conn.status() | nil, content_encoding: String.t(), + compression_context: Bandit.Compression.t() | nil, upgrade: nil | {:websocket, opts :: keyword(), websocket_opts :: keyword()}, metrics: %{}, opts: %{ @@ -87,47 +89,18 @@ defmodule Bandit.Adapter do def send_resp(%__MODULE__{} = adapter, status, headers, body) do validate_calling_process!(adapter) start_time = Bandit.Telemetry.monotonic_time() - response_content_encoding_header = Bandit.Headers.get_header(headers, "content-encoding") - response_has_strong_etag = - case Bandit.Headers.get_header(headers, "etag") do - nil -> false - "\W" <> _rest -> false - _strong_etag -> true - end - - response_indicates_no_transform = - case Bandit.Headers.get_header(headers, "cache-control") do - nil -> false - header -> "no-transform" in Plug.Conn.Utils.list(header) - end - - raw_body_bytes = IO.iodata_length(body) - - {body, headers, compression_metrics} = - case {body, adapter.content_encoding, response_content_encoding_header, - response_has_strong_etag, response_indicates_no_transform} do - {body, content_encoding, nil, false, false} - when raw_body_bytes > 0 and not is_nil(content_encoding) -> - metrics = %{ - resp_uncompressed_body_bytes: raw_body_bytes, - resp_compression_method: content_encoding - } + # Save an extra iodata_length by checking common cases + empty_body? = body == "" || body == [] + {headers, compression_context} = Bandit.Compression.new(adapter, status, headers, empty_body?) - deflate_options = Keyword.get(adapter.opts.http, :deflate_options, []) - deflated_body = Bandit.Compression.compress(body, content_encoding, deflate_options) - headers = [{"content-encoding", adapter.content_encoding} | headers] - {deflated_body, headers, metrics} + {encoded_body, compression_context} = + Bandit.Compression.compress_chunk(body, compression_context) - _ -> - {body, headers, %{}} - end - - compress = Keyword.get(adapter.opts.http, :compress, true) - headers = if compress, do: [{"vary", "accept-encoding"} | headers], else: headers + compression_metrics = Bandit.Compression.close(compression_context) - length = IO.iodata_length(body) - headers = Bandit.Headers.add_content_length(headers, length, status, adapter.method) + encoded_length = IO.iodata_length(encoded_body) + headers = Bandit.Headers.add_content_length(headers, encoded_length, status, adapter.method) metrics = adapter.metrics @@ -137,7 +110,7 @@ defmodule Bandit.Adapter do adapter = %{adapter | metrics: metrics} |> send_headers(status, headers, :raw) - |> send_data(body, true) + |> send_data(encoded_body, true) {:ok, nil, adapter} end @@ -182,7 +155,9 @@ defmodule Bandit.Adapter do validate_calling_process!(adapter) start_time = Bandit.Telemetry.monotonic_time() metrics = Map.put(adapter.metrics, :resp_start_time, start_time) - adapter = %{adapter | metrics: metrics} + + {headers, compression_context} = Bandit.Compression.new(adapter, status, headers, false, true) + adapter = %{adapter | metrics: metrics, compression_context: compression_context} {:ok, nil, send_headers(adapter, status, headers, :chunk_encoded)} end @@ -199,7 +174,17 @@ defmodule Bandit.Adapter do # chunk/2 is unique among Plug.Conn.Adapter's sending callbacks in that it can return an error # tuple instead of just raising or dying on error. Rescue here to implement this try do - {:ok, nil, send_data(adapter, chunk, IO.iodata_length(chunk) == 0)} + if IO.iodata_length(chunk) == 0 do + compression_metrics = Bandit.Compression.close(adapter.compression_context) + adapter = %{adapter | metrics: Map.merge(adapter.metrics, compression_metrics)} + {:ok, nil, send_data(adapter, chunk, true)} + else + {encoded_chunk, compression_context} = + Bandit.Compression.compress_chunk(chunk, adapter.compression_context) + + adapter = %{adapter | compression_context: compression_context} + {:ok, nil, send_data(adapter, encoded_chunk, false)} + end rescue error -> {:error, Exception.message(error)} end diff --git a/lib/bandit/compression.ex b/lib/bandit/compression.ex index 47a71d5b..558cedeb 100644 --- a/lib/bandit/compression.ex +++ b/lib/bandit/compression.ex @@ -1,6 +1,15 @@ defmodule Bandit.Compression do @moduledoc false + defstruct method: nil, bytes_in: 0, lib_context: nil + + @typedoc "A struct containing the context for response compression" + @type t :: %__MODULE__{ + method: :deflate | :gzip | :identity, + bytes_in: non_neg_integer(), + lib_context: term() + } + @spec negotiate_content_encoding(nil | binary(), boolean()) :: String.t() | nil def negotiate_content_encoding(nil, _), do: nil def negotiate_content_encoding(_, false), do: nil @@ -11,27 +20,101 @@ defmodule Bandit.Compression do |> Enum.find(&(&1 in ~w(deflate gzip x-gzip))) end - @spec compress(iolist(), String.t(), Bandit.deflate_options()) :: iodata() - def compress(response, "deflate", opts) do - deflate_context = :zlib.open() + def new(adapter, status, headers, empty_body?, streamable \\ false) do + response_content_encoding_header = Bandit.Headers.get_header(headers, "content-encoding") + + headers = maybe_add_vary_header(adapter, status, headers) + + if status not in [204, 304] && not is_nil(adapter.content_encoding) && + is_nil(response_content_encoding_header) && + !response_has_strong_etag(headers) && !response_indicates_no_transform(headers) && + !empty_body? do + deflate_options = Keyword.get(adapter.opts.http, :deflate_options, []) + + case start_stream(adapter.content_encoding, deflate_options, streamable) do + {:ok, context} -> {[{"content-encoding", adapter.content_encoding} | headers], context} + {:error, :unsupported_encoding} -> {headers, %__MODULE__{method: :identity}} + end + else + {headers, %__MODULE__{method: :identity}} + end + end - try do - :ok = - :zlib.deflateInit( - deflate_context, - Keyword.get(opts, :level, :default), - :deflated, - Keyword.get(opts, :window_bits, 15), - Keyword.get(opts, :mem_level, 8), - Keyword.get(opts, :strategy, :default) - ) - - :zlib.deflate(deflate_context, response, :sync) - after - :zlib.close(deflate_context) + defp maybe_add_vary_header(adapter, status, headers) do + if status != 204 && Keyword.get(adapter.opts.http, :compress, true), + do: [{"vary", "accept-encoding"} | headers], + else: headers + end + + defp response_has_strong_etag(headers) do + case Bandit.Headers.get_header(headers, "etag") do + nil -> false + "\W" <> _rest -> false + _strong_etag -> true + end + end + + defp response_indicates_no_transform(headers) do + case Bandit.Headers.get_header(headers, "cache-control") do + nil -> false + header -> "no-transform" in Plug.Conn.Utils.list(header) end end - def compress(response, "x-gzip", _opts), do: compress(response, "gzip", []) - def compress(response, "gzip", _opts), do: :zlib.gzip(response) + defp start_stream("deflate", opts, _streamable) do + deflate_context = :zlib.open() + + :zlib.deflateInit( + deflate_context, + Keyword.get(opts, :level, :default), + :deflated, + Keyword.get(opts, :window_bits, 15), + Keyword.get(opts, :mem_level, 8), + Keyword.get(opts, :strategy, :default) + ) + + {:ok, %__MODULE__{method: :deflate, lib_context: deflate_context}} + end + + defp start_stream("x-gzip", _opts, false), do: {:ok, %__MODULE__{method: :gzip}} + defp start_stream("gzip", _opts, false), do: {:ok, %__MODULE__{method: :gzip}} + defp start_stream(_encoding, _opts, _streamable), do: {:error, :unsupported_encoding} + + def compress_chunk(chunk, %__MODULE__{method: :deflate} = context) do + result = :zlib.deflate(context.lib_context, chunk, :sync) + + context = + context + |> Map.update!(:bytes_in, &(&1 + IO.iodata_length(chunk))) + + {result, context} + end + + def compress_chunk(chunk, %__MODULE__{method: :gzip, lib_context: nil} = context) do + result = :zlib.gzip(chunk) + + context = + context + |> Map.update!(:bytes_in, &(&1 + IO.iodata_length(chunk))) + |> Map.put(:lib_context, :done) + + {result, context} + end + + def compress_chunk(chunk, %__MODULE__{method: :identity} = context) do + {chunk, context} + end + + def close(%__MODULE__{} = context) do + if context.method == :deflate, do: :zlib.close(context.lib_context) + + if context.method == :identity do + %{} + else + %{ + resp_compression_method: to_string(context.method), + resp_uncompressed_body_bytes: context.bytes_in + } + end + end end diff --git a/test/bandit/http1/request_test.exs b/test/bandit/http1/request_test.exs index cc03a5de..ea75b2c8 100644 --- a/test/bandit/http1/request_test.exs +++ b/test/bandit/http1/request_test.exs @@ -1275,15 +1275,11 @@ defmodule HTTP1RequestTest do assert response.headers["content-encoding"] == ["deflate"] assert response.headers["vary"] == ["accept-encoding"] - deflate_context = :zlib.open() - :ok = :zlib.deflateInit(deflate_context) + inflate_context = :zlib.open() + :ok = :zlib.inflateInit(inflate_context) + inflated_body = :zlib.inflate(inflate_context, response.body) |> IO.iodata_to_binary() - expected = - deflate_context - |> :zlib.deflate(String.duplicate("a", 10_000), :sync) - |> IO.iodata_to_binary() - - assert response.body == expected + assert inflated_body == String.duplicate("a", 10_000) end test "writes out a response with gzip encoding if so negotiated", context do @@ -1320,15 +1316,47 @@ defmodule HTTP1RequestTest do assert response.headers["content-encoding"] == ["deflate"] assert response.headers["vary"] == ["accept-encoding"] - deflate_context = :zlib.open() - :ok = :zlib.deflateInit(deflate_context) + inflate_context = :zlib.open() + :ok = :zlib.inflateInit(inflate_context) + inflated_body = :zlib.inflate(inflate_context, response.body) |> IO.iodata_to_binary() + + assert inflated_body == String.duplicate("a", 10_000) + end + + test "does not indicate content encoding or vary for 204 responses", context do + response = + Req.get!(context.req, url: "/send_204", headers: [{"accept-encoding", "deflate"}]) + + assert response.status == 204 + assert response.headers["content-encoding"] == nil + assert response.headers["vary"] == nil + assert response.body == "" + end + + test "does not indicate content encoding but indicates vary for 304 responses", context do + response = + Req.get!(context.req, url: "/send_304", headers: [{"accept-encoding", "deflate"}]) + + assert response.status == 304 + assert response.headers["content-encoding"] == nil + assert response.headers["vary"] == ["accept-encoding"] + assert response.body == "" + end + + test "does not indicate content encoding but indicates vary for zero byte responses", + context do + response = + Req.get!(context.req, url: "/send_empty", headers: [{"accept-encoding", "deflate"}]) - expected = - deflate_context - |> :zlib.deflate(String.duplicate("a", 10_000), :sync) - |> IO.iodata_to_binary() + assert response.status == 200 + assert response.headers["content-encoding"] == nil + assert response.headers["vary"] == ["accept-encoding"] + assert response.body == "" + end - assert response.body == expected + def send_empty(conn) do + conn + |> send_resp(200, "") end test "writes out an encoded response for an iolist body", context do @@ -1340,15 +1368,42 @@ defmodule HTTP1RequestTest do assert response.headers["content-encoding"] == ["deflate"] assert response.headers["vary"] == ["accept-encoding"] - deflate_context = :zlib.open() - :ok = :zlib.deflateInit(deflate_context) + inflate_context = :zlib.open() + :ok = :zlib.inflateInit(inflate_context) + inflated_body = :zlib.inflate(inflate_context, response.body) |> IO.iodata_to_binary() + + assert inflated_body == String.duplicate("a", 10_000) + end + + test "deflate encodes chunk responses", context do + response = + Req.get!(context.req, + url: "/send_big_body_chunked", + headers: [{"accept-encoding", "deflate"}] + ) + + assert response.status == 200 + assert response.headers["content-encoding"] == ["deflate"] + assert response.headers["vary"] == ["accept-encoding"] + + inflate_context = :zlib.open() + :ok = :zlib.inflateInit(inflate_context) + inflated_body = :zlib.inflate(inflate_context, response.body) |> IO.iodata_to_binary() + + assert inflated_body == String.duplicate("a", 10_000) + end - expected = - deflate_context - |> :zlib.deflate(String.duplicate("a", 10_000), :sync) - |> IO.iodata_to_binary() + test "does not gzip encode chunk responses", context do + response = + Req.get!(context.req, + url: "/send_big_body_chunked", + headers: [{"accept-encoding", "gzip"}] + ) - assert response.body == expected + assert response.status == 200 + assert response.headers["content-encoding"] == nil + assert response.headers["vary"] == ["accept-encoding"] + assert response.body == String.duplicate("a", 10_000) end test "falls back to no encoding if no encodings provided", context do @@ -1453,6 +1508,23 @@ defmodule HTTP1RequestTest do |> send_resp(200, String.duplicate("a", 10_000)) end + def send_big_body_chunked(conn) do + conn = send_chunked(conn, 200) + + {:ok, conn} = chunk(conn, String.duplicate("a", 1_000)) + {:ok, conn} = chunk(conn, String.duplicate("a", 1_000)) + {:ok, conn} = chunk(conn, String.duplicate("a", 1_000)) + {:ok, conn} = chunk(conn, String.duplicate("a", 1_000)) + {:ok, conn} = chunk(conn, String.duplicate("a", 1_000)) + {:ok, conn} = chunk(conn, String.duplicate("a", 1_000)) + {:ok, conn} = chunk(conn, String.duplicate("a", 1_000)) + {:ok, conn} = chunk(conn, String.duplicate("a", 1_000)) + {:ok, conn} = chunk(conn, String.duplicate("a", 1_000)) + {:ok, conn} = chunk(conn, String.duplicate("a", 1_000)) + + conn + end + def send_iolist_body(conn) do conn |> put_resp_header("content-length", "10000") diff --git a/test/bandit/http2/protocol_test.exs b/test/bandit/http2/protocol_test.exs index 9236af8a..c3f743d2 100644 --- a/test/bandit/http2/protocol_test.exs +++ b/test/bandit/http2/protocol_test.exs @@ -300,20 +300,18 @@ defmodule HTTP2ProtocolTest do {":status", "200"}, {"date", _date}, {"content-length", "34"}, - {"vary", "accept-encoding"}, {"content-encoding", "deflate"}, + {"vary", "accept-encoding"}, {"cache-control", "max-age=0, private, must-revalidate"} ], _ctx} = SimpleH2Client.recv_headers(socket) - deflate_context = :zlib.open() - :ok = :zlib.deflateInit(deflate_context) + {:ok, 1, true, body} = SimpleH2Client.recv_body(socket) - expected = - deflate_context - |> :zlib.deflate(String.duplicate("a", 10_000), :sync) - |> IO.iodata_to_binary() + inflate_context = :zlib.open() + :ok = :zlib.inflateInit(inflate_context) + inflated_body = :zlib.inflate(inflate_context, body) |> IO.iodata_to_binary() - assert SimpleH2Client.recv_body(socket) == {:ok, 1, true, expected} + assert inflated_body == String.duplicate("a", 10_000) end test "writes out a response with gzip encoding if so negotiated", context do @@ -334,8 +332,8 @@ defmodule HTTP2ProtocolTest do {":status", "200"}, {"date", _date}, {"content-length", "46"}, - {"vary", "accept-encoding"}, {"content-encoding", "gzip"}, + {"vary", "accept-encoding"}, {"cache-control", "max-age=0, private, must-revalidate"} ], _ctx} = SimpleH2Client.recv_headers(socket) @@ -362,8 +360,8 @@ defmodule HTTP2ProtocolTest do {":status", "200"}, {"date", _date}, {"content-length", "46"}, - {"vary", "accept-encoding"}, {"content-encoding", "x-gzip"}, + {"vary", "accept-encoding"}, {"cache-control", "max-age=0, private, must-revalidate"} ], _ctx} = SimpleH2Client.recv_headers(socket) @@ -390,20 +388,94 @@ defmodule HTTP2ProtocolTest do {":status", "200"}, {"date", _date}, {"content-length", "34"}, - {"vary", "accept-encoding"}, {"content-encoding", "deflate"}, + {"vary", "accept-encoding"}, {"cache-control", "max-age=0, private, must-revalidate"} ], _ctx} = SimpleH2Client.recv_headers(socket) - deflate_context = :zlib.open() - :ok = :zlib.deflateInit(deflate_context) + {:ok, 1, true, body} = SimpleH2Client.recv_body(socket) - expected = - deflate_context - |> :zlib.deflate(String.duplicate("a", 10_000), :sync) - |> IO.iodata_to_binary() + inflate_context = :zlib.open() + :ok = :zlib.inflateInit(inflate_context) + inflated_body = :zlib.inflate(inflate_context, body) |> IO.iodata_to_binary() - assert SimpleH2Client.recv_body(socket) == {:ok, 1, true, expected} + assert inflated_body == String.duplicate("a", 10_000) + end + + test "does not indicate content encoding or vary for 204 responses", context do + socket = SimpleH2Client.setup_connection(context) + + headers = [ + {":method", "GET"}, + {":path", "/send_204"}, + {":scheme", "https"}, + {":authority", "localhost:#{context.port}"}, + {"accept-encoding", "deflate"} + ] + + SimpleH2Client.send_headers(socket, 1, true, headers) + + assert {:ok, 1, true, + [ + {":status", "204"}, + {"date", _date}, + {"cache-control", "max-age=0, private, must-revalidate"} + ], _ctx} = SimpleH2Client.recv_headers(socket) + end + + # RFC9110ยง15.4.5 + test "does not indicate content encoding but indicates vary for 304 responses", context do + socket = SimpleH2Client.setup_connection(context) + + headers = [ + {":method", "GET"}, + {":path", "/send_304"}, + {":scheme", "https"}, + {":authority", "localhost:#{context.port}"}, + {"accept-encoding", "deflate"} + ] + + SimpleH2Client.send_headers(socket, 1, true, headers) + + assert {:ok, 1, true, + [ + {":status", "304"}, + {"date", _date}, + {"content-length", "5"}, + {"vary", "accept-encoding"}, + {"cache-control", "max-age=0, private, must-revalidate"} + ], _ctx} = SimpleH2Client.recv_headers(socket) + end + + test "does not indicate content encoding but indicates vary for zero byte responses", + context do + socket = SimpleH2Client.setup_connection(context) + + headers = [ + {":method", "GET"}, + {":path", "/send_empty"}, + {":scheme", "https"}, + {":authority", "localhost:#{context.port}"}, + {"accept-encoding", "deflate"} + ] + + SimpleH2Client.send_headers(socket, 1, true, headers) + + assert {:ok, 1, false, + [ + {":status", "200"}, + {"date", _date}, + {"content-length", "0"}, + {"vary", "accept-encoding"}, + {"cache-control", "max-age=0, private, must-revalidate"} + ], _ctx} = SimpleH2Client.recv_headers(socket) + + assert SimpleH2Client.recv_body(socket) == {:ok, 1, true, ""} + end + + def send_empty(conn) do + conn + |> send_resp(200, "") end test "writes out a response with deflate encoding for an iolist body", context do @@ -424,20 +496,18 @@ defmodule HTTP2ProtocolTest do {":status", "200"}, {"date", _date}, {"content-length", "34"}, - {"vary", "accept-encoding"}, {"content-encoding", "deflate"}, + {"vary", "accept-encoding"}, {"cache-control", "max-age=0, private, must-revalidate"} ], _ctx} = SimpleH2Client.recv_headers(socket) - deflate_context = :zlib.open() - :ok = :zlib.deflateInit(deflate_context) + {:ok, 1, true, body} = SimpleH2Client.recv_body(socket) - expected = - deflate_context - |> :zlib.deflate(String.duplicate("a", 10_000), :sync) - |> IO.iodata_to_binary() + inflate_context = :zlib.open() + :ok = :zlib.inflateInit(inflate_context) + inflated_body = :zlib.inflate(inflate_context, body) |> IO.iodata_to_binary() - assert SimpleH2Client.recv_body(socket) == {:ok, 1, true, expected} + assert inflated_body == String.duplicate("a", 10_000) end test "does no encoding if content-encoding header already present in response", context do @@ -512,8 +582,8 @@ defmodule HTTP2ProtocolTest do {":status", "200"}, {"date", _date}, {"content-length", "46"}, - {"vary", "accept-encoding"}, {"content-encoding", "gzip"}, + {"vary", "accept-encoding"}, {"cache-control", "max-age=0, private, must-revalidate"}, {"etag", "W/\"1234\""} ], _ctx} = SimpleH2Client.recv_headers(socket) @@ -714,7 +784,6 @@ defmodule HTTP2ProtocolTest do [ {":status", "204"}, {"date", _date}, - {"vary", "accept-encoding"}, {"cache-control", "max-age=0, private, must-revalidate"} ], _ctx} = SimpleH2Client.recv_headers(socket) end @@ -815,6 +884,64 @@ defmodule HTTP2ProtocolTest do assert SimpleH2Client.recv_body(socket) == {:ok, 1, true, ""} end + test "deflate encodes multiple DATA frames when chunking", context do + socket = SimpleH2Client.setup_connection(context) + + headers = [ + {":method", "GET"}, + {":path", "/chunk_response"}, + {":scheme", "https"}, + {":authority", "localhost:#{context.port}"}, + {"accept-encoding", "deflate"} + ] + + SimpleH2Client.send_headers(socket, 1, true, headers) + + assert {:ok, 1, false, + [ + {":status", "200"}, + {"date", _date}, + {"content-encoding", "deflate"}, + {"vary", "accept-encoding"}, + {"cache-control", "max-age=0, private, must-revalidate"} + ], _ctx} = SimpleH2Client.recv_headers(socket) + + {:ok, 1, false, chunk_1} = SimpleH2Client.recv_body(socket) + {:ok, 1, false, chunk_2} = SimpleH2Client.recv_body(socket) + assert {:ok, 1, true, ""} == SimpleH2Client.recv_body(socket) + + inflate_context = :zlib.open() + :ok = :zlib.inflateInit(inflate_context) + inflated_body = :zlib.inflate(inflate_context, [chunk_1, chunk_2]) |> IO.iodata_to_binary() + + assert inflated_body == "OKDOKEE" + end + + test "does not gzip encode DATA frames when chunking", context do + socket = SimpleH2Client.setup_connection(context) + + headers = [ + {":method", "GET"}, + {":path", "/chunk_response"}, + {":scheme", "https"}, + {":authority", "localhost:#{context.port}"}, + {"accept-encoding", "gzip"} + ] + + SimpleH2Client.send_headers(socket, 1, true, headers) + + assert {:ok, 1, false, + [ + {":status", "200"}, + {"date", _date}, + {"vary", "accept-encoding"}, + {"cache-control", "max-age=0, private, must-revalidate"} + ], _ctx} = SimpleH2Client.recv_headers(socket) + + assert {:ok, 1, false, "OK"} == SimpleH2Client.recv_body(socket) + assert {:ok, 1, false, "DOKEE"} == SimpleH2Client.recv_body(socket) + end + test "does not write out a body for a chunked response to a HEAD request", context do socket = SimpleH2Client.setup_connection(context) @@ -824,6 +951,7 @@ defmodule HTTP2ProtocolTest do [ {":status", "200"}, {"date", _date}, + {"vary", "accept-encoding"}, {"cache-control", "max-age=0, private, must-revalidate"} ], _ctx} = SimpleH2Client.recv_headers(socket) @@ -870,6 +998,7 @@ defmodule HTTP2ProtocolTest do [ {":status", "304"}, {"date", _date}, + {"vary", "accept-encoding"}, {"cache-control", "max-age=0, private, must-revalidate"} ], _ctx} = SimpleH2Client.recv_headers(socket)