Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make request.scheme return ws/wss for WS even when http/https in SERVER_NAME or proxy headers #2854

Merged
merged 9 commits into from
Dec 7, 2023
3 changes: 3 additions & 0 deletions sanic/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -1066,6 +1066,9 @@ def url_for(self, view_name: str, **kwargs):
scheme = netloc[:8].split(":", 1)[0]
else:
scheme = "http"
# Replace http/https with ws/wss for WebSocket handlers
if route.extra.websocket:
scheme = scheme.replace("http", "ws")

if "://" in netloc[:8]:
netloc = netloc.split("://", 1)[-1]
Expand Down
22 changes: 13 additions & 9 deletions sanic/request/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -964,20 +964,23 @@ def scheme(self) -> str:
str: http|https|ws|wss or arbitrary value given by the headers.
"""
if not hasattr(self, "_scheme"):
if "//" in self.app.config.get("SERVER_NAME", ""):
return self.app.config.SERVER_NAME.split("//")[0]
if "proto" in self.forwarded:
return str(self.forwarded["proto"])

if (
self.app.websocket_enabled
and self.headers.getone("upgrade", "").lower() == "websocket"
and self.headers.upgrade.lower() == "websocket"
):
scheme = "ws"
else:
scheme = "http"

if self.transport.get_extra_info("sslcontext"):
proto = None
sp = self.app.config.get("SERVER_NAME", "").split("://", 1)
if len(sp) == 2:
proto = sp[0]
elif "proto" in self.forwarded:
proto = str(self.forwarded["proto"])
if proto:
# Give ws/wss if websocket, otherwise keep the same
scheme = proto.replace("http", scheme)
elif self.conn_info and self.conn_info.ssl:
scheme += "s"
self._scheme = scheme

Expand Down Expand Up @@ -1072,7 +1075,8 @@ def url_for(self, view_name: str, **kwargs) -> str:
"""
# Full URL SERVER_NAME can only be handled in app.url_for
try:
if "//" in self.app.config.SERVER_NAME:
sp = self.app.config.get("SERVER_NAME", "").split("://", 1)
if len(sp) == 2:
return self.app.url_for(view_name, _external=True, **kwargs)
except AttributeError:
pass
Expand Down
108 changes: 66 additions & 42 deletions tests/test_ws_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@
from sanic import Request, Sanic, Websocket


MimicClientType = Callable[
[WebSocketClientProtocol], Coroutine[None, None, Any]
]
MimicClientType = Callable[[WebSocketClientProtocol], Coroutine[None, None, Any]]


@pytest.fixture
Expand All @@ -23,39 +21,6 @@ async def client_mimic(ws: WebSocketClientProtocol):
return client_mimic


def test_ws_handler(
app: Sanic,
simple_ws_mimic_client: MimicClientType,
):
@app.websocket("/ws")
async def ws_echo_handler(request: Request, ws: Websocket):
while True:
msg = await ws.recv()
await ws.send(msg)

_, ws_proxy = app.test_client.websocket(
"/ws", mimic=simple_ws_mimic_client
)
assert ws_proxy.client_sent == ["test 1", "test 2", ""]
assert ws_proxy.client_received == ["test 1", "test 2"]


def test_ws_handler_async_for(
app: Sanic,
simple_ws_mimic_client: MimicClientType,
):
@app.websocket("/ws")
async def ws_echo_handler(request: Request, ws: Websocket):
async for msg in ws:
await ws.send(msg)

_, ws_proxy = app.test_client.websocket(
"/ws", mimic=simple_ws_mimic_client
)
assert ws_proxy.client_sent == ["test 1", "test 2", ""]
assert ws_proxy.client_received == ["test 1", "test 2"]


def signalapp(app):
@app.signal("websocket.handler.before")
async def ws_before(request: Request, websocket: Websocket):
Expand Down Expand Up @@ -90,16 +55,77 @@ async def ws_error(request: Request, ws: Websocket):
print("wserr2")


def test_ws_handler(
app: Sanic,
simple_ws_mimic_client: MimicClientType,
):
@app.websocket("/ws")
async def ws_echo_handler(request: Request, ws: Websocket):
while True:
msg = await ws.recv()
await ws.send(msg)

_, ws_proxy = app.test_client.websocket("/ws", mimic=simple_ws_mimic_client)
assert ws_proxy.client_sent == ["test 1", "test 2", ""]
assert ws_proxy.client_received == ["test 1", "test 2"]


def test_ws_handler_async_for(
app: Sanic,
simple_ws_mimic_client: MimicClientType,
):
@app.websocket("/ws")
async def ws_echo_handler(request: Request, ws: Websocket):
async for msg in ws:
await ws.send(msg)

_, ws_proxy = app.test_client.websocket("/ws", mimic=simple_ws_mimic_client)
assert ws_proxy.client_sent == ["test 1", "test 2", ""]
assert ws_proxy.client_received == ["test 1", "test 2"]


@pytest.mark.parametrize("proxy", ["", "proxy", "servername"])
def test_request_url(
app: Sanic,
simple_ws_mimic_client: MimicClientType,
proxy: str,
):
@app.websocket("/ws")
async def ws_url_handler(request: Request, ws: Websocket):
request.headers[
"forwarded"
] = "for=[2001:db8::1];proto=https;host=example.com;by=proxy"

await ws.recv()
await ws.send(request.url)
await ws.recv()
await ws.send(request.url_for("ws_url_handler"))
await ws.recv()

app.config.FORWARDED_SECRET = proxy
app.config.SERVER_NAME = "https://example.com" if proxy == "servername" else ""
_, ws_proxy = app.test_client.websocket(
"/ws",
mimic=simple_ws_mimic_client,
)
assert ws_proxy.client_sent == ["test 1", "test 2", ""]
assert ws_proxy.client_received[0] == ws_proxy.client_received[1]
if proxy:
assert ws_proxy.client_received[0] == "wss://example.com/ws"
assert ws_proxy.client_received[1] == "wss://example.com/ws"
else:
assert ws_proxy.client_received[0].startswith("ws://127.0.0.1")
assert ws_proxy.client_received[1].startswith("ws://127.0.0.1")


def test_ws_signals(
app: Sanic,
simple_ws_mimic_client: MimicClientType,
):
signalapp(app)

app.ctx.seq = []
_, ws_proxy = app.test_client.websocket(
"/ws", mimic=simple_ws_mimic_client
)
_, ws_proxy = app.test_client.websocket("/ws", mimic=simple_ws_mimic_client)
assert ws_proxy.client_received == ["before: test 1", "after: test 2"]
assert app.ctx.seq == ["before", "ws", "after"]

Expand All @@ -111,8 +137,6 @@ def test_ws_signals_exception(
signalapp(app)

app.ctx.seq = []
_, ws_proxy = app.test_client.websocket(
"/wserror", mimic=simple_ws_mimic_client
)
_, ws_proxy = app.test_client.websocket("/wserror", mimic=simple_ws_mimic_client)
assert ws_proxy.client_received == ["before: test 1", "exception: test 2"]
assert app.ctx.seq == ["before", "wserror", "exception"]
Loading