Skip to content

Commit

Permalink
Websockets refactoring (#2836)
Browse files Browse the repository at this point in the history
  • Loading branch information
asvetlov authored Mar 15, 2018
1 parent 9cdf32c commit 89abb89
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 53 deletions.
2 changes: 2 additions & 0 deletions CHANGES/2836.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Websockets refactoring, all websocket writer methods are converted
into coroutines.
9 changes: 6 additions & 3 deletions aiohttp/client_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self, reader, writer, protocol,
self._heartbeat = heartbeat
self._heartbeat_cb = None
if heartbeat is not None:
self._pong_heartbeat = heartbeat/2.0
self._pong_heartbeat = heartbeat / 2.0
self._pong_response_cb = None
self._loop = loop
self._waiting = None
Expand Down Expand Up @@ -61,7 +61,10 @@ def _reset_heartbeat(self):

def _send_heartbeat(self):
if self._heartbeat is not None and not self._closed:
self._writer.ping()
# fire-and-forget a task is not perfect but maybe ok for
# sending ping. Otherwise we need a long-living heartbeat
# task in the class.
self._loop.create_task(self._writer.ping())

if self._pong_response_cb is not None:
self._pong_response_cb.cancel()
Expand Down Expand Up @@ -137,7 +140,7 @@ async def close(self, *, code=1000, message=b''):
self._cancel_heartbeat()
self._closed = True
try:
self._writer.close(code, message)
await self._writer.close(code, message)
except asyncio.CancelledError:
self._close_code = 1006
self._response.close()
Expand Down
26 changes: 12 additions & 14 deletions aiohttp/http_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from enum import IntEnum
from struct import Struct

from .helpers import NO_EXTENSIONS, noop
from .helpers import NO_EXTENSIONS
from .log import ws_logger


Expand Down Expand Up @@ -527,7 +527,7 @@ def __init__(self, protocol, transport, *,
self._output_size = 0
self._compressobj = None

def _send_frame(self, message, opcode, compress=None):
async def _send_frame(self, message, opcode, compress=None):
"""Send a frame over the websocket with message as its payload."""
if self._closing:
ws_logger.warning('websocket connection is closing.')
Expand Down Expand Up @@ -585,37 +585,35 @@ def _send_frame(self, message, opcode, compress=None):

if self._output_size > self._limit:
self._output_size = 0
return self.protocol._drain_helper()
await self.protocol._drain_helper()

return noop()

def pong(self, message=b''):
async def pong(self, message=b''):
"""Send pong message."""
if isinstance(message, str):
message = message.encode('utf-8')
return self._send_frame(message, WSMsgType.PONG)
return await self._send_frame(message, WSMsgType.PONG)

def ping(self, message=b''):
async def ping(self, message=b''):
"""Send ping message."""
if isinstance(message, str):
message = message.encode('utf-8')
return self._send_frame(message, WSMsgType.PING)
return await self._send_frame(message, WSMsgType.PING)

def send(self, message, binary=False, compress=None):
async def send(self, message, binary=False, compress=None):
"""Send a frame over the websocket with message as its payload."""
if isinstance(message, str):
message = message.encode('utf-8')
if binary:
return self._send_frame(message, WSMsgType.BINARY, compress)
return await self._send_frame(message, WSMsgType.BINARY, compress)
else:
return self._send_frame(message, WSMsgType.TEXT, compress)
return await self._send_frame(message, WSMsgType.TEXT, compress)

def close(self, code=1000, message=b''):
async def close(self, code=1000, message=b''):
"""Close the websocket, sending the specified code and message."""
if isinstance(message, str):
message = message.encode('utf-8')
try:
return self._send_frame(
return await self._send_frame(
PACK_CLOSE_CODE(code) + message, opcode=WSMsgType.CLOSE)
finally:
self._closing = True
9 changes: 6 additions & 3 deletions aiohttp/web_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__(self, *,
self._heartbeat = heartbeat
self._heartbeat_cb = None
if heartbeat is not None:
self._pong_heartbeat = heartbeat/2.0
self._pong_heartbeat = heartbeat / 2.0
self._pong_response_cb = None
self._compress = compress

Expand All @@ -80,7 +80,10 @@ def _reset_heartbeat(self):

def _send_heartbeat(self):
if self._heartbeat is not None and not self._closed:
self._writer.ping()
# fire-and-forget a task is not perfect but maybe ok for
# sending ping. Otherwise we need a long-living heartbeat
# task in the class.
self._loop.create_task(self._writer.ping())

if self._pong_response_cb is not None:
self._pong_response_cb.cancel()
Expand Down Expand Up @@ -286,7 +289,7 @@ async def close(self, *, code=1000, message=b''):
if not self._closed:
self._closed = True
try:
self._writer.close(code, message)
await self._writer.close(code, message)
await self._payload_writer.drain()
except (asyncio.CancelledError, asyncio.TimeoutError):
self._close_code = 1006
Expand Down
13 changes: 10 additions & 3 deletions tests/test_client_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,9 @@ async def test_close(loop, ws_key, key_data):
m_os.urandom.return_value = key_data
m_req.return_value = loop.create_future()
m_req.return_value.set_result(resp)
writer = WebSocketWriter.return_value = mock.Mock()
writer = mock.Mock()
WebSocketWriter.return_value = writer
writer.close = make_mocked_coro()

session = aiohttp.ClientSession(loop=loop)
resp = await session.ws_connect(
Expand Down Expand Up @@ -280,7 +282,9 @@ async def test_close_exc(loop, ws_key, key_data):
m_os.urandom.return_value = key_data
m_req.return_value = loop.create_future()
m_req.return_value.set_result(resp)
WebSocketWriter.return_value = mock.Mock()
writer = mock.Mock()
WebSocketWriter.return_value = writer
writer.close = make_mocked_coro()

session = aiohttp.ClientSession(loop=loop)
resp = await session.ws_connect('http://test.org')
Expand Down Expand Up @@ -400,7 +404,10 @@ async def test_reader_read_exception(ws_key, key_data, loop):
m_os.urandom.return_value = key_data
m_req.return_value = loop.create_future()
m_req.return_value.set_result(hresp)
WebSocketWriter.return_value = mock.Mock()

writer = mock.Mock()
WebSocketWriter.return_value = writer
writer.close = make_mocked_coro()

session = aiohttp.ClientSession(loop=loop)
resp = await session.ws_connect('http://test.org')
Expand Down
2 changes: 1 addition & 1 deletion tests/test_client_ws_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,7 @@ async def handler(request):

client = await aiohttp_client(app)
resp = await client.ws_connect('/', heartbeat=0.01)

await asyncio.sleep(0.1)
await resp.receive()
await resp.close()

Expand Down
61 changes: 32 additions & 29 deletions tests/test_websocket_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@
import pytest

from aiohttp.http import WebSocketWriter
from aiohttp.test_utils import make_mocked_coro


@pytest.fixture
def protocol():
return mock.Mock()
ret = mock.Mock()
ret._drain_helper = make_mocked_coro()
return ret


@pytest.fixture
Expand All @@ -21,83 +24,83 @@ def writer(protocol, transport):
return WebSocketWriter(protocol, transport, use_mask=False)


def test_pong(writer):
writer.pong()
async def test_pong(writer):
await writer.pong()
writer.transport.write.assert_called_with(b'\x8a\x00')


def test_ping(writer):
writer.ping()
async def test_ping(writer):
await writer.ping()
writer.transport.write.assert_called_with(b'\x89\x00')


def test_send_text(writer):
writer.send(b'text')
async def test_send_text(writer):
await writer.send(b'text')
writer.transport.write.assert_called_with(b'\x81\x04text')


def test_send_binary(writer):
writer.send('binary', True)
async def test_send_binary(writer):
await writer.send('binary', True)
writer.transport.write.assert_called_with(b'\x82\x06binary')


def test_send_binary_long(writer):
writer.send(b'b' * 127, True)
async def test_send_binary_long(writer):
await writer.send(b'b' * 127, True)
assert writer.transport.write.call_args[0][0].startswith(b'\x82~\x00\x7fb')


def test_send_binary_very_long(writer):
writer.send(b'b' * 65537, True)
async def test_send_binary_very_long(writer):
await writer.send(b'b' * 65537, True)
assert (writer.transport.write.call_args_list[0][0][0] ==
b'\x82\x7f\x00\x00\x00\x00\x00\x01\x00\x01')
assert writer.transport.write.call_args_list[1][0][0] == b'b' * 65537


def test_close(writer):
writer.close(1001, 'msg')
async def test_close(writer):
await writer.close(1001, 'msg')
writer.transport.write.assert_called_with(b'\x88\x05\x03\xe9msg')

writer.close(1001, b'msg')
await writer.close(1001, b'msg')
writer.transport.write.assert_called_with(b'\x88\x05\x03\xe9msg')

# Test that Service Restart close code is also supported
writer.close(1012, b'msg')
await writer.close(1012, b'msg')
writer.transport.write.assert_called_with(b'\x88\x05\x03\xf4msg')


def test_send_text_masked(protocol, transport):
async def test_send_text_masked(protocol, transport):
writer = WebSocketWriter(protocol,
transport,
use_mask=True,
random=random.Random(123))
writer.send(b'text')
await writer.send(b'text')
writer.transport.write.assert_called_with(b'\x81\x84\rg\xb3fy\x02\xcb\x12')


def test_send_compress_text(protocol, transport):
async def test_send_compress_text(protocol, transport):
writer = WebSocketWriter(protocol, transport, compress=15)
writer.send(b'text')
await writer.send(b'text')
writer.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00')
writer.send(b'text')
await writer.send(b'text')
writer.transport.write.assert_called_with(b'\xc1\x05*\x01b\x00\x00')


def test_send_compress_text_notakeover(protocol, transport):
async def test_send_compress_text_notakeover(protocol, transport):
writer = WebSocketWriter(protocol,
transport,
compress=15,
notakeover=True)
writer.send(b'text')
await writer.send(b'text')
writer.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00')
writer.send(b'text')
await writer.send(b'text')
writer.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00')


def test_send_compress_text_per_message(protocol, transport):
async def test_send_compress_text_per_message(protocol, transport):
writer = WebSocketWriter(protocol, transport)
writer.send(b'text', compress=15)
await writer.send(b'text', compress=15)
writer.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00')
writer.send(b'text')
await writer.send(b'text')
writer.transport.write.assert_called_with(b'\x81\x04text')
writer.send(b'text', compress=15)
await writer.send(b'text', compress=15)
writer.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00')

0 comments on commit 89abb89

Please sign in to comment.