Skip to content

Commit

Permalink
Set User-Agent header in CONNECT requests.
Browse files Browse the repository at this point in the history
  • Loading branch information
aaugustin committed Feb 1, 2025
1 parent 2b9a90a commit 3dac6c4
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 28 deletions.
11 changes: 7 additions & 4 deletions docs/project/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,14 @@ Backwards-incompatible changes

.. _python-socks: https://github.com/romis2012/python-socks

New features
............
.. admonition:: Keepalive is enabled in the :mod:`threading` implementation.
:class: note

The :mod:`threading` implementation now sends Ping frames at regular
intervals and closes the connection if it doesn't receive a matching Pong
frame just like the :mod:`asyncio` implementation.

* Added :doc:`keepalive and latency measurement <../topics/keepalive>` to the
:mod:`threading` implementation.
See :doc:`keepalive and latency <../topics/keepalive>` for details.

Improvements
............
Expand Down
37 changes: 27 additions & 10 deletions src/websockets/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ async def handshake(
self.request = self.protocol.connect()
if additional_headers is not None:
self.request.headers.update(additional_headers)
if user_agent_header:
if user_agent_header is not None:
self.request.headers.setdefault("User-Agent", user_agent_header)
self.protocol.send_request(self.request)

Expand Down Expand Up @@ -363,10 +363,8 @@ def protocol_factory(uri: WebSocketURI) -> ClientConnection:

self.proxy = proxy
self.protocol_factory = protocol_factory
self.handshake_args = (
additional_headers,
user_agent_header,
)
self.additional_headers = additional_headers
self.user_agent_header = user_agent_header
self.process_exception = process_exception
self.open_timeout = open_timeout
self.logger = logger
Expand Down Expand Up @@ -442,6 +440,7 @@ def factory() -> ClientConnection:
transport = await connect_http_proxy(
proxy_parsed,
ws_uri,
user_agent_header=self.user_agent_header,
**proxy_kwargs,
)
# Initialize WebSocket connection via the proxy.
Expand Down Expand Up @@ -541,7 +540,10 @@ async def __await_impl__(self) -> ClientConnection:
for _ in range(MAX_REDIRECTS):
self.connection = await self.create_connection()
try:
await self.connection.handshake(*self.handshake_args)
await self.connection.handshake(
self.additional_headers,
self.user_agent_header,
)
except asyncio.CancelledError:
self.connection.transport.abort()
raise
Expand Down Expand Up @@ -717,10 +719,16 @@ async def connect_socks_proxy(
raise ImportError("python-socks is required to use a SOCKS proxy")


def prepare_connect_request(proxy: Proxy, ws_uri: WebSocketURI) -> bytes:
def prepare_connect_request(
proxy: Proxy,
ws_uri: WebSocketURI,
user_agent_header: str | None = None,
) -> bytes:
host = build_host(ws_uri.host, ws_uri.port, ws_uri.secure, always_include_port=True)
headers = Headers()
headers["Host"] = build_host(ws_uri.host, ws_uri.port, ws_uri.secure)
if user_agent_header is not None:
headers["User-Agent"] = user_agent_header
if proxy.username is not None:
assert proxy.password is not None # enforced by parse_proxy()
headers["Proxy-Authorization"] = build_authorization_basic(
Expand All @@ -731,9 +739,15 @@ def prepare_connect_request(proxy: Proxy, ws_uri: WebSocketURI) -> bytes:


class HTTPProxyConnection(asyncio.Protocol):
def __init__(self, ws_uri: WebSocketURI, proxy: Proxy):
def __init__(
self,
ws_uri: WebSocketURI,
proxy: Proxy,
user_agent_header: str | None = None,
):
self.ws_uri = ws_uri
self.proxy = proxy
self.user_agent_header = user_agent_header

self.reader = StreamReader()
self.parser = Response.parse(
Expand Down Expand Up @@ -765,7 +779,9 @@ def run_parser(self) -> None:
def connection_made(self, transport: asyncio.BaseTransport) -> None:
transport = cast(asyncio.Transport, transport)
self.transport = transport
self.transport.write(prepare_connect_request(self.proxy, self.ws_uri))
self.transport.write(
prepare_connect_request(self.proxy, self.ws_uri, self.user_agent_header)
)

def data_received(self, data: bytes) -> None:
self.reader.feed_data(data)
Expand All @@ -784,10 +800,11 @@ def connection_lost(self, exc: Exception | None) -> None:
async def connect_http_proxy(
proxy: Proxy,
ws_uri: WebSocketURI,
user_agent_header: str | None = None,
**kwargs: Any,
) -> asyncio.Transport:
transport, protocol = await asyncio.get_running_loop().create_connection(
lambda: HTTPProxyConnection(ws_uri, proxy),
lambda: HTTPProxyConnection(ws_uri, proxy, user_agent_header),
proxy.host,
proxy.port,
**kwargs,
Expand Down
12 changes: 10 additions & 2 deletions src/websockets/sync/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ def connect(
proxy_parsed,
ws_uri,
deadline,
user_agent_header=user_agent_header,
ssl=proxy_ssl,
server_hostname=proxy_server_hostname,
**kwargs,
Expand Down Expand Up @@ -472,10 +473,16 @@ def connect_socks_proxy(
raise ImportError("python-socks is required to use a SOCKS proxy")


def prepare_connect_request(proxy: Proxy, ws_uri: WebSocketURI) -> bytes:
def prepare_connect_request(
proxy: Proxy,
ws_uri: WebSocketURI,
user_agent_header: str | None = None,
) -> bytes:
host = build_host(ws_uri.host, ws_uri.port, ws_uri.secure, always_include_port=True)
headers = Headers()
headers["Host"] = build_host(ws_uri.host, ws_uri.port, ws_uri.secure)
if user_agent_header is not None:
headers["User-Agent"] = user_agent_header
if proxy.username is not None:
assert proxy.password is not None # enforced by parse_proxy()
headers["Proxy-Authorization"] = build_authorization_basic(
Expand Down Expand Up @@ -524,6 +531,7 @@ def connect_http_proxy(
ws_uri: WebSocketURI,
deadline: Deadline,
*,
user_agent_header: str | None = None,
ssl: ssl_module.SSLContext | None = None,
server_hostname: str | None = None,
**kwargs: Any,
Expand All @@ -546,7 +554,7 @@ def connect_http_proxy(

# Send CONNECT request to the proxy and read response.

sock.sendall(prepare_connect_request(proxy, ws_uri))
sock.sendall(prepare_connect_request(proxy, ws_uri, user_agent_header))
try:
read_connect_response(sock, deadline)
except Exception:
Expand Down
18 changes: 18 additions & 0 deletions tests/asyncio/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,24 @@ async def test_authenticated_http_proxy_error(self):
)
self.assertNumFlows(0)

@patch.dict(os.environ, {"https_proxy": "http://localhost:58080"})
async def test_http_proxy_override_user_agent(self):
"""Client can override User-Agent header with user_agent_header."""
async with serve(*args) as server:
async with connect(get_uri(server), user_agent_header="Smith") as client:
self.assertEqual(client.protocol.state.name, "OPEN")
[http_connect] = self.get_http_connects()
self.assertEqual(http_connect.request.headers[b"User-Agent"], "Smith")

@patch.dict(os.environ, {"https_proxy": "http://localhost:58080"})
async def test_http_proxy_remove_user_agent(self):
"""Client can remove User-Agent header with user_agent_header."""
async with serve(*args) as server:
async with connect(get_uri(server), user_agent_header=None) as client:
self.assertEqual(client.protocol.state.name, "OPEN")
[http_connect] = self.get_http_connects()
self.assertNotIn(b"User-Agent", http_connect.request.headers)

@patch.dict(os.environ, {"https_proxy": "http://localhost:58080"})
async def test_http_proxy_protocol_error(self):
"""Client receives invalid data when connecting to the HTTP proxy."""
Expand Down
38 changes: 26 additions & 12 deletions tests/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,26 @@
class RecordFlows:
def __init__(self, on_running):
self.running = on_running
self.flows = []
self.http_connects = []
self.tcp_flows = []

def http_connect(self, flow):
self.http_connects.append(flow)

def tcp_start(self, flow):
self.flows.append(flow)
self.tcp_flows.append(flow)

def get_http_connects(self):
http_connects, self.http_connects[:] = self.http_connects[:], []
return http_connects

def get_flows(self):
flows, self.flows[:] = self.flows[:], []
return flows
def get_tcp_flows(self):
tcp_flows, self.tcp_flows[:] = self.tcp_flows[:], []
return tcp_flows

def reset_flows(self):
self.flows = []
def reset(self):
self.http_connects = []
self.tcp_flows = []


class AlterRequest:
Expand Down Expand Up @@ -121,13 +130,18 @@ def setUpClass(cls):
cls.proxy_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
cls.proxy_context.load_verify_locations(bytes(certificate))

def assertNumFlows(self, num_flows):
record_flows = self.proxy_master.addons.get("recordflows")
self.assertEqual(len(record_flows.get_flows()), num_flows)
def get_http_connects(self):
return self.proxy_master.addons.get("recordflows").get_http_connects()

def get_tcp_flows(self):
return self.proxy_master.addons.get("recordflows").get_tcp_flows()

def assertNumFlows(self, num_tcp_flows):
self.assertEqual(len(self.get_tcp_flows()), num_tcp_flows)

def tearDown(self):
record_flows = self.proxy_master.addons.get("recordflows")
record_flows.reset_flows()
record_tcp_flows = self.proxy_master.addons.get("recordflows")
record_tcp_flows.reset()
super().tearDown()

@classmethod
Expand Down
18 changes: 18 additions & 0 deletions tests/sync/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,24 @@ def test_authenticated_http_proxy_error(self):
)
self.assertNumFlows(0)

@patch.dict(os.environ, {"https_proxy": "http://localhost:58080"})
def test_http_proxy_override_user_agent(self):
"""Client can override User-Agent header with user_agent_header."""
with run_server() as server:
with connect(get_uri(server), user_agent_header="Smith") as client:
self.assertEqual(client.protocol.state.name, "OPEN")
[http_connect] = self.get_http_connects()
self.assertEqual(http_connect.request.headers[b"User-Agent"], "Smith")

@patch.dict(os.environ, {"https_proxy": "http://localhost:58080"})
def test_http_proxy_remove_user_agent(self):
"""Client can remove User-Agent header with user_agent_header."""
with run_server() as server:
with connect(get_uri(server), user_agent_header=None) as client:
self.assertEqual(client.protocol.state.name, "OPEN")
[http_connect] = self.get_http_connects()
self.assertNotIn(b"User-Agent", http_connect.request.headers)

@patch.dict(os.environ, {"https_proxy": "http://localhost:58080"})
def test_http_proxy_protocol_error(self):
"""Client receives invalid data when connecting to the HTTP proxy."""
Expand Down

0 comments on commit 3dac6c4

Please sign in to comment.