diff --git a/aws_xray_sdk/core/async_context.py b/aws_xray_sdk/core/async_context.py new file mode 100644 index 00000000..c827868c --- /dev/null +++ b/aws_xray_sdk/core/async_context.py @@ -0,0 +1,98 @@ +import asyncio + +from .context import Context as _Context + + +class AsyncContext(_Context): + """ + Async Context for storing segments. + + Inherits nearly everything from the main Context class. + Replaces threading.local with a task based local storage class, + Also overrides clear_trace_entities + """ + def __init__(self, *args, loop=None, use_task_factory=True, **kwargs): + super(AsyncContext, self).__init__(*args, **kwargs) + + self._loop = loop + if loop is None: + self._loop = asyncio.get_event_loop() + + if use_task_factory: + self._loop.set_task_factory(task_factory) + + self._local = TaskLocalStorage(loop=loop) + + def clear_trace_entities(self): + """ + Clear all trace_entities stored in the task local context. + """ + if self._local is not None: + self._local.clear() + + +class TaskLocalStorage(object): + """ + Simple task local storage + """ + def __init__(self, loop=None): + if loop is None: + loop = asyncio.get_event_loop() + self._loop = loop + + def __setattr__(self, name, value): + if name in ('_loop',): + # Set normal attributes + object.__setattr__(self, name, value) + + else: + # Set task local attributes + task = asyncio.Task.current_task(loop=self._loop) + if task is None: + return None + + if not hasattr(task, 'context'): + task.context = {} + + task.context[name] = value + + def __getattribute__(self, item): + if item in ('_loop', 'clear'): + # Return references to local objects + return object.__getattribute__(self, item) + + task = asyncio.Task.current_task(loop=self._loop) + if task is None: + return None + + if hasattr(task, 'context') and item in task.context: + return task.context[item] + + raise AttributeError('Task context does not have attribute {0}'.format(item)) + + def clear(self): + # If were in a task, clear the context dictionary + task = asyncio.Task.current_task(loop=self._loop) + if task is not None and hasattr(task, 'context'): + task.context.clear() + + +def task_factory(loop, coro): + """ + Task factory function + + Fuction closely mirrors the logic inside of + asyncio.BaseEventLoop.create_task. Then if there is a current + task and the current task has a context then share that context + with the new task + """ + task = asyncio.Task(coro, loop=loop) + if task._source_traceback: # flake8: noqa + del task._source_traceback[-1] # flake8: noqa + + # Share context with new task if possible + current_task = asyncio.Task.current_task(loop=loop) + if current_task is not None and hasattr(current_task, 'context'): + setattr(task, 'context', current_task.context) + + return task diff --git a/aws_xray_sdk/core/sampling/sampling_rule.py b/aws_xray_sdk/core/sampling/sampling_rule.py index b6f0dd48..83224506 100644 --- a/aws_xray_sdk/core/sampling/sampling_rule.py +++ b/aws_xray_sdk/core/sampling/sampling_rule.py @@ -40,7 +40,7 @@ def applies(self, service_name, method, path): the incoming request based on some of the request's parameters. Any None parameters provided will be considered an implicit match. """ - return (not service_name or wildcard_match(self.service_name, service_name)) \ + return (not service_name or wildcard_match(self.service_name, service_name)) \ and (not method or wildcard_match(self.service_name, method)) \ and (not path or wildcard_match(self.path, path)) @@ -89,11 +89,14 @@ def reservoir(self): def _validate(self): if self.fixed_target < 0 or self.rate < 0: - raise InvalidSamplingManifestError('All rules must have non-negative values for fixed_target and rate') + raise InvalidSamplingManifestError('All rules must have non-negative values for ' + 'fixed_target and rate') if self._default: if self.service_name or self.method or self.path: - raise InvalidSamplingManifestError('The default rule must not specify values for url_path, service_name, or http_method') + raise InvalidSamplingManifestError('The default rule must not specify values for ' + 'url_path, service_name, or http_method') else: if not self.service_name or not self.method or not self.path: - raise InvalidSamplingManifestError('All non-default rules must have values for url_path, service_name, and http_method') + raise InvalidSamplingManifestError('All non-default rules must have values for ' + 'url_path, service_name, and http_method') diff --git a/aws_xray_sdk/core/utils/compat.py b/aws_xray_sdk/core/utils/compat.py index 546d8f51..1d60f882 100644 --- a/aws_xray_sdk/core/utils/compat.py +++ b/aws_xray_sdk/core/utils/compat.py @@ -4,8 +4,8 @@ PY2 = sys.version_info < (3,) if PY2: - annotation_value_types = (int, long, float, bool, str) - string_types = basestring + annotation_value_types = (int, long, float, bool, str) # noqa: F821 + string_types = basestring # noqa: F821 else: annotation_value_types = (int, float, bool, str) string_types = str diff --git a/aws_xray_sdk/ext/aiohttp/__init__.py b/aws_xray_sdk/ext/aiohttp/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/aws_xray_sdk/ext/aiohttp/middleware.py b/aws_xray_sdk/ext/aiohttp/middleware.py new file mode 100644 index 00000000..47c7248f --- /dev/null +++ b/aws_xray_sdk/ext/aiohttp/middleware.py @@ -0,0 +1,78 @@ +""" +AioHttp Middleware +""" +import traceback + +from aws_xray_sdk.core import xray_recorder +from aws_xray_sdk.core.models import http +from aws_xray_sdk.ext.util import calculate_sampling_decision, calculate_segment_name, construct_xray_header + + +async def middleware(app, handler): + """ + AioHttp Middleware Factory + """ + async def _middleware(request): + """ + Main middleware function, deals with all the X-Ray segment logic + """ + # Create X-Ray headers + xray_header = construct_xray_header(request.headers) + # Get name of service or generate a dynamic one from host + name = calculate_segment_name(request.headers['host'].split(':', 1)[0], xray_recorder) + + sampling_decision = calculate_sampling_decision( + trace_header=xray_header, + recorder=xray_recorder, + service_name=request.headers['host'], + method=request.method, + path=request.path, + ) + + # Start a segment + segment = xray_recorder.begin_segment( + name=name, + traceid=xray_header.root, + parent_id=xray_header.parent, + sampling=sampling_decision, + ) + + # Store request metadata in the current segment + segment.put_http_meta(http.URL, request.url) + segment.put_http_meta(http.METHOD, request.method) + + if 'User-Agent' in request.headers: + segment.put_http_meta(http.USER_AGENT, request.headers['User-Agent']) + + if 'X-Forwarded-For' in request.headers: + segment.put_http_meta(http.CLIENT_IP, request.headers['X-Forwarded-For']) + segment.put_http_meta(http.X_FORWARDED_FOR, True) + elif 'remote_addr' in request.headers: + segment.put_http_meta(http.CLIENT_IP, request.headers['remote_addr']) + else: + segment.put_http_meta(http.CLIENT_IP, request.remote) + + try: + # Call next middleware or request handler + response = await handler(request) + except Exception as err: + # Store exception information including the stacktrace to the segment + segment = xray_recorder.current_segment() + segment.put_http_meta(http.STATUS, 500) + stack = traceback.extract_stack(limit=xray_recorder._max_trace_back) + segment.add_exception(err, stack) + xray_recorder.end_segment() + raise + + # Store response metadata into the current segment + segment.put_http_meta(http.STATUS, response.status) + + if 'Content-Length' in response.headers: + length = int(response.headers['Content-Length']) + segment.put_http_meta(http.CONTENT_LENGTH, length) + + # Close segment so it can be dispatched off to the daemon + xray_recorder.end_segment() + + return response + return _middleware diff --git a/docs/conf.py b/docs/conf.py index bb2101d5..59ac36cf 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -31,9 +31,9 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = ['sphinx.ext.autodoc', - 'sphinx.ext.doctest', - 'sphinx.ext.intersphinx', - 'sphinx.ext.coverage'] + 'sphinx.ext.doctest', + 'sphinx.ext.intersphinx', + 'sphinx.ext.coverage'] # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] @@ -171,7 +171,5 @@ ] - - # Example configuration for intersphinx: refer to the Python standard library. intersphinx_mapping = {'https://docs.python.org/': None} diff --git a/docs/frameworks.rst b/docs/frameworks.rst index 0124d9be..af00a01d 100644 --- a/docs/frameworks.rst +++ b/docs/frameworks.rst @@ -81,4 +81,36 @@ To generate segment based on incoming requests, you need to instantiate the X-Ra XRayMiddleware(app, xray_recorder) Flask built-in template rendering will be wrapped into subsegments. -You can configure the recorder, see :ref:`Configure Global Recorder ` for more details. \ No newline at end of file +You can configure the recorder, see :ref:`Configure Global Recorder ` for more details. + +aiohttp Server +============== + +For X-Ray to create a segment based on an incoming request, you need register some middleware with aiohttp. As aiohttp +is an asyncronous framework, X-Ray will also need to be configured with an ``AsyncContext`` compared to the default threadded +version.:: + + import asyncio + + from aiohttp import web + + from aws_xray_sdk.ext.aiohttp.middleware import middleware + from aws_xray_sdk.core.async_context import AsyncContext + from aws_xray_sdk.core import xray_recorder + # Configure X-Ray to use AsyncContext + xray_recorder.configure(service='service_name', context=AsyncContext()) + + + async def handler(request): + return web.Response(body='Hello World') + + loop = asyncio.get_event_loop() + # Use X-Ray SDK middleware, its crucial the X-Ray middleware comes first + app = web.Application(middlewares=[middleware]) + app.router.add_get("/", handler) + + web.run_app(app) + +There are two things to note from the example above. Firstly a middleware corountine from aws-xray-sdk is provided during the creation +of an aiohttp server app. Lastly the ``xray_recorder`` has also been configured with a name and an ``AsyncContext``. See +:ref:`Configure Global Recorder ` for more information about configuring the ``xray_recorder``. diff --git a/setup.py b/setup.py index d4e65870..5fb5719b 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,8 @@ name='aws-xray-sdk', version='0.93', - description='The AWS X-Ray SDK for Python (the SDK) enables Python developers to record and emit information from within their applications to the AWS X-Ray service.', + description='The AWS X-Ray SDK for Python (the SDK) enables Python developers to record' + ' and emit information from within their applications to the AWS X-Ray service.', long_description=long_description, url='https://github.com/aws/aws-xray-sdk-python', diff --git a/tests/ext/aiohttp/__init__.py b/tests/ext/aiohttp/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/ext/aiohttp/test_aiohttp.py b/tests/ext/aiohttp/test_aiohttp.py new file mode 100644 index 00000000..d6e694f0 --- /dev/null +++ b/tests/ext/aiohttp/test_aiohttp.py @@ -0,0 +1,202 @@ +""" +Tests the middleware for aiohttp server + +Expects pytest-aiohttp +""" +import asyncio +from unittest.mock import patch + +from aiohttp import web +import pytest + +from aws_xray_sdk.core.emitters.udp_emitter import UDPEmitter +from aws_xray_sdk.core.async_context import AsyncContext +from tests.util import get_new_stubbed_recorder +from aws_xray_sdk.ext.aiohttp.middleware import middleware + + +class CustomStubbedEmitter(UDPEmitter): + """ + Custom stubbed emitter which stores all segments instead of the last one + """ + def __init__(self, daemon_address='127.0.0.1:2000'): + super(CustomStubbedEmitter, self).__init__(daemon_address) + self.local = [] + + def send_entity(self, entity): + self.local.append(entity) + + def pop(self): + try: + return self.local.pop(0) + except IndexError: + return None + + +class TestServer(object): + """ + Simple class to hold a copy of the event loop + """ + def __init__(self, loop): + self._loop = loop + + async def handle_ok(self, request: web.Request) -> web.Response: + """ + Handle / request + """ + return web.Response(text="ok") + + async def handle_error(self, request: web.Request) -> web.Response: + """ + Handle /error which returns a 404 + """ + return web.Response(text="not found", status=404) + + async def handle_exception(self, request: web.Request) -> web.Response: + """ + Handle /exception which raises a KeyError + """ + return {}['key'] + + async def handle_delay(self, request: web.Request) -> web.Response: + """ + Handle /delay request + """ + await asyncio.sleep(0.3, loop=self._loop) + return web.Response(text="ok") + + def get_app(self) -> web.Application: + app = web.Application(middlewares=[middleware]) + app.router.add_get('/', self.handle_ok) + app.router.add_get('/error', self.handle_error) + app.router.add_get('/exception', self.handle_exception) + app.router.add_get('/delay', self.handle_delay) + + return app + + @classmethod + def app(cls, loop=None) -> web.Application: + return cls(loop=loop).get_app() + + +@pytest.fixture(scope='function') +def recorder(loop): + """ + Clean up context storage before and after each test run + """ + xray_recorder = get_new_stubbed_recorder() + xray_recorder.configure(service='test', sampling=False, context=AsyncContext(loop=loop)) + + patcher = patch('aws_xray_sdk.ext.aiohttp.middleware.xray_recorder', xray_recorder) + patcher.start() + + xray_recorder.clear_trace_entities() + yield xray_recorder + xray_recorder.clear_trace_entities() + patcher.stop() + + +async def test_ok(test_client, loop, recorder): + """ + Test a normal response + + :param test_client: AioHttp test client fixture + :param loop: Eventloop fixture + :param recorder: X-Ray recorder fixture + """ + client = await test_client(TestServer.app(loop=loop)) + + resp = await client.get('/') + assert resp.status == 200 + + segment = recorder.emitter.pop() + assert not segment.in_progress + + request = segment.http['request'] + response = segment.http['response'] + + assert request['method'] == 'GET' + assert str(request['url']).startswith('http://127.0.0.1') + assert request['url'].host == '127.0.0.1' + assert request['url'].path == '/' + assert response['status'] == 200 + + +async def test_error(test_client, loop, recorder): + """ + Test a 4XX response + + :param test_client: AioHttp test client fixture + :param loop: Eventloop fixture + :param recorder: X-Ray recorder fixture + """ + client = await test_client(TestServer.app(loop=loop)) + + resp = await client.get('/error') + assert resp.status == 404 + + segment = recorder.emitter.pop() + assert not segment.in_progress + assert segment.error + + request = segment.http['request'] + response = segment.http['response'] + assert request['method'] == 'GET' + assert request['url'].host == '127.0.0.1' + assert request['url'].path == '/error' + assert request['client_ip'] == '127.0.0.1' + assert response['status'] == 404 + + +async def test_exception(test_client, loop, recorder): + """ + Test handling an exception + + :param test_client: AioHttp test client fixture + :param loop: Eventloop fixture + :param recorder: X-Ray recorder fixture + """ + client = await test_client(TestServer.app(loop=loop)) + + resp = await client.get('/exception') + await resp.text() # Need this to trigger Exception + + segment = recorder.emitter.pop() + assert not segment.in_progress + assert segment.fault + + request = segment.http['request'] + response = segment.http['response'] + exception = segment.cause['exceptions'][0] + assert request['method'] == 'GET' + assert request['url'].host == '127.0.0.1' + assert request['url'].path == '/exception' + assert request['client_ip'] == '127.0.0.1' + assert response['status'] == 500 + assert exception.type == 'KeyError' + + +async def test_concurrent(test_client, loop, recorder): + """ + Test multiple concurrent requests + + :param test_client: AioHttp test client fixture + :param loop: Eventloop fixture + :param recorder: X-Ray recorder fixture + """ + client = await test_client(TestServer.app(loop=loop)) + + recorder.emitter = CustomStubbedEmitter() + + async def get_delay(): + resp = await client.get('/delay') + assert resp.status == 200 + + await asyncio.wait([get_delay(), get_delay(), get_delay(), + get_delay(), get_delay(), get_delay(), + get_delay(), get_delay(), get_delay()], + loop=loop) + + # Ensure all ID's are different + ids = [item.id for item in recorder.emitter.local] + assert len(ids) == len(set(ids)) diff --git a/tests/test_async_local_storage.py b/tests/test_async_local_storage.py new file mode 100644 index 00000000..b43cc0ec --- /dev/null +++ b/tests/test_async_local_storage.py @@ -0,0 +1,38 @@ +import asyncio +import random + +from aws_xray_sdk.core.async_context import TaskLocalStorage + + +def test_localstorage_isolation(loop): + local_storage = TaskLocalStorage(loop=loop) + + async def _test(): + """ + Compute a random number + Store it in task local storage + Suspend task so another can run + Retrieve random number from task local storage + Compare that to the local variable + """ + try: + random_int = random.random() + local_storage.randint = random_int + + await asyncio.sleep(0.0, loop=loop) + + current_random_int = local_storage.randint + assert random_int == current_random_int + + return True + except: + return False + + # Run loads of concurrent tasks + results = loop.run_until_complete( + asyncio.wait([_test() for _ in range(0, 100)], loop=loop) + ) + results = [item.result() for item in results[0]] + + # Double check all is good + assert all(results) diff --git a/tox.ini b/tox.ini index 297a1625..c206ceb9 100644 --- a/tox.ini +++ b/tox.ini @@ -14,9 +14,13 @@ deps = flask >= 0.10 # the sdk dosen't support earlier version of django django >= 1.10 + # Python3.5+ only deps + py{35,36}: aiohttp >= 2.3.0 + py{35,36}: pytest-aiohttp commands = - coverage run --source aws_xray_sdk -m py.test tests + py{27,34}: coverage run --source aws_xray_sdk -m py.test tests --ignore tests/ext/aiohttp --ignore tests/test_async_local_storage.py + py{35,36}: coverage run --source aws_xray_sdk -m py.test tests setenv = DJANGO_SETTINGS_MODULE = tests.ext.django.app.settings