Skip to content

Commit

Permalink
Refactor workers (#3471)
Browse files Browse the repository at this point in the history
  • Loading branch information
asvetlov committed Dec 30, 2018
1 parent 856dba5 commit 6dfb03d
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 64 deletions.
2 changes: 2 additions & 0 deletions CHANGES/3471.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Use the same task for app initialization and web server handling in gunicorn workers.
It allows to use Python3.7 context vars smoothly.
42 changes: 23 additions & 19 deletions aiohttp/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
import signal
import sys
from types import FrameType
from typing import Any, Optional # noqa
from typing import Any, Awaitable, Callable, Optional, Union # noqa

from gunicorn.config import AccessLogFormat as GunicornAccessLogFormat
from gunicorn.workers import base

from aiohttp import web

from .helpers import set_result
from .web_app import Application
from .web_log import AccessLogger

try:
Expand All @@ -37,7 +38,6 @@ class GunicornWebWorker(base.Worker):
def __init__(self, *args: Any, **kw: Any) -> None: # pragma: no cover
super().__init__(*args, **kw)

self._runner = None # type: Optional[web.AppRunner]
self._task = None # type: Optional[asyncio.Task[None]]
self.exit_code = 0
self._notify_waiter = None # type: Optional[asyncio.Future[bool]]
Expand All @@ -52,35 +52,39 @@ def init_process(self) -> None:
super().init_process()

def run(self) -> None:
access_log = self.log.access_log if self.cfg.accesslog else None
params = dict(
logger=self.log,
keepalive_timeout=self.cfg.keepalive,
access_log=access_log,
access_log_format=self._get_valid_log_format(
self.cfg.access_log_format))
if asyncio.iscoroutinefunction(self.wsgi): # type: ignore
self.wsgi = self.loop.run_until_complete(
self.wsgi()) # type: ignore
self._runner = web.AppRunner(self.wsgi, **params)
self.loop.run_until_complete(self._runner.setup())
self._task = self.loop.create_task(self._run())

try: # ignore all finalization problems
self.loop.run_until_complete(self._task)
except Exception as error:
self.log.exception(error)
except Exception:
self.log.exception("Exception in gunicorn worker")
if sys.version_info >= (3, 6):
if hasattr(self.loop, 'shutdown_asyncgens'):
self.loop.run_until_complete(self.loop.shutdown_asyncgens())
self.loop.run_until_complete(self.loop.shutdown_asyncgens())
self.loop.close()

sys.exit(self.exit_code)

async def _run(self) -> None:
if isinstance(self.wsgi, Application):
app = self.wsgi
elif asyncio.iscoroutinefunction(self.wsgi):
app = await self.wsgi()
else:
raise RuntimeError("wsgi app should be either Application or "
"async function returning Application, got {}"
.format(self.wsgi))
access_log = self.log.access_log if self.cfg.accesslog else None
runner = web.AppRunner(app,
logger=self.log,
keepalive_timeout=self.cfg.keepalive,
access_log=access_log,
access_log_format=self._get_valid_log_format(
self.cfg.access_log_format))
await runner.setup()

ctx = self._create_ssl_context(self.cfg) if self.cfg.is_ssl else None

runner = self._runner
runner = runner
assert runner is not None
server = runner.server
assert server is not None
Expand Down
70 changes: 25 additions & 45 deletions tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import pytest

from aiohttp import web
from aiohttp.test_utils import make_mocked_coro

base_worker = pytest.importorskip('aiohttp.worker')

Expand Down Expand Up @@ -42,13 +41,15 @@ def __init__(self):
self.wsgi = web.Application()


class AsyncioWorker(BaseTestWorker, base_worker.GunicornWebWorker): # type: ignore # noqa
class AsyncioWorker(BaseTestWorker, # type: ignore
base_worker.GunicornWebWorker):
pass


PARAMS = [AsyncioWorker]
if uvloop is not None:
class UvloopWorker(BaseTestWorker, base_worker.GunicornUVLoopWebWorker): # type: ignore # noqa
class UvloopWorker(BaseTestWorker, # type: ignore
base_worker.GunicornUVLoopWebWorker):
pass

PARAMS.append(UvloopWorker)
Expand Down Expand Up @@ -78,30 +79,47 @@ def test_run(worker, loop) -> None:
worker.log = mock.Mock()
worker.cfg = mock.Mock()
worker.cfg.access_log_format = ACCEPTABLE_LOG_FORMAT
worker.cfg.is_ssl = False
worker.sockets = []

worker.loop = loop
worker._run = make_mocked_coro(None)
with pytest.raises(SystemExit):
worker.run()
assert worker._run.called
worker.log.exception.assert_not_called()
assert loop.is_closed()


def test_run_async_factory(worker, loop) -> None:
worker.log = mock.Mock()
worker.cfg = mock.Mock()
worker.cfg.access_log_format = ACCEPTABLE_LOG_FORMAT
worker.cfg.is_ssl = False
worker.sockets = []
app = worker.wsgi

async def make_app():
return app
worker.wsgi = make_app

worker.loop = loop
worker._run = make_mocked_coro(None)
worker.alive = False
with pytest.raises(SystemExit):
worker.run()
worker.log.exception.assert_not_called()
assert loop.is_closed()


def test_run_not_app(worker, loop) -> None:
worker.log = mock.Mock()
worker.cfg = mock.Mock()
worker.cfg.access_log_format = ACCEPTABLE_LOG_FORMAT

worker.loop = loop
worker.wsgi = "not-app"
worker.alive = False
with pytest.raises(SystemExit):
worker.run()
assert worker._run.called
worker.log.exception.assert_called_with('Exception in gunicorn worker')
assert loop.is_closed()


Expand Down Expand Up @@ -197,15 +215,11 @@ async def test__run_ok_parent_changed(worker, loop,
worker.cfg.max_requests = 0
worker.cfg.is_ssl = False

worker._runner = web.AppRunner(worker.wsgi)
await worker._runner.setup()

await worker._run()

worker.notify.assert_called_with()
worker.log.info.assert_called_with("Parent changed, shutting down: %s",
worker)
assert worker._runner.server is None


async def test__run_exc(worker, loop, aiohttp_unused_port) -> None:
Expand All @@ -223,9 +237,6 @@ async def test__run_exc(worker, loop, aiohttp_unused_port) -> None:
worker.cfg.max_requests = 0
worker.cfg.is_ssl = False

worker._runner = web.AppRunner(worker.wsgi)
await worker._runner.setup()

def raiser():
waiter = worker._notify_waiter
worker.alive = False
Expand All @@ -235,37 +246,6 @@ def raiser():
await worker._run()

worker.notify.assert_called_with()
assert worker._runner.server is None


async def test__run_ok_max_requests_exceeded(worker, loop,
aiohttp_unused_port):
skip_if_no_dict(loop)

worker.ppid = os.getppid()
worker.alive = True
worker.servers = {}
sock = socket.socket()
addr = ('localhost', aiohttp_unused_port())
sock.bind(addr)
worker.sockets = [sock]
worker.log = mock.Mock()
worker.loop = loop
worker.cfg.access_log_format = ACCEPTABLE_LOG_FORMAT
worker.cfg.max_requests = 10
worker.cfg.is_ssl = False

worker._runner = web.AppRunner(worker.wsgi)
await worker._runner.setup()
worker._runner.server.requests_count = 30

await worker._run()

worker.notify.assert_called_with()
worker.log.info.assert_called_with("Max requests, shutting down: %s",
worker)

assert worker._runner.server is None


def test__create_ssl_context_without_certs_and_ciphers(worker) -> None:
Expand Down

0 comments on commit 6dfb03d

Please sign in to comment.