-
Notifications
You must be signed in to change notification settings - Fork 143
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #3 from terrycain/aiohttp_middleware
AioHttp server middleware
- Loading branch information
Showing
12 changed files
with
468 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
import asyncio | ||
|
||
from .context import Context as _Context | ||
|
||
|
||
class AsyncContext(_Context): | ||
""" | ||
Async Context for storing segments. | ||
Inherits nearly everything from the main Context class. | ||
Replaces threading.local with a task based local storage class, | ||
Also overrides clear_trace_entities | ||
""" | ||
def __init__(self, *args, loop=None, use_task_factory=True, **kwargs): | ||
super(AsyncContext, self).__init__(*args, **kwargs) | ||
|
||
self._loop = loop | ||
if loop is None: | ||
self._loop = asyncio.get_event_loop() | ||
|
||
if use_task_factory: | ||
self._loop.set_task_factory(task_factory) | ||
|
||
self._local = TaskLocalStorage(loop=loop) | ||
|
||
def clear_trace_entities(self): | ||
""" | ||
Clear all trace_entities stored in the task local context. | ||
""" | ||
if self._local is not None: | ||
self._local.clear() | ||
|
||
|
||
class TaskLocalStorage(object): | ||
""" | ||
Simple task local storage | ||
""" | ||
def __init__(self, loop=None): | ||
if loop is None: | ||
loop = asyncio.get_event_loop() | ||
self._loop = loop | ||
|
||
def __setattr__(self, name, value): | ||
if name in ('_loop',): | ||
# Set normal attributes | ||
object.__setattr__(self, name, value) | ||
|
||
else: | ||
# Set task local attributes | ||
task = asyncio.Task.current_task(loop=self._loop) | ||
if task is None: | ||
return None | ||
|
||
if not hasattr(task, 'context'): | ||
task.context = {} | ||
|
||
task.context[name] = value | ||
|
||
def __getattribute__(self, item): | ||
if item in ('_loop', 'clear'): | ||
# Return references to local objects | ||
return object.__getattribute__(self, item) | ||
|
||
task = asyncio.Task.current_task(loop=self._loop) | ||
if task is None: | ||
return None | ||
|
||
if hasattr(task, 'context') and item in task.context: | ||
return task.context[item] | ||
|
||
raise AttributeError('Task context does not have attribute {0}'.format(item)) | ||
|
||
def clear(self): | ||
# If were in a task, clear the context dictionary | ||
task = asyncio.Task.current_task(loop=self._loop) | ||
if task is not None and hasattr(task, 'context'): | ||
task.context.clear() | ||
|
||
|
||
def task_factory(loop, coro): | ||
""" | ||
Task factory function | ||
Fuction closely mirrors the logic inside of | ||
asyncio.BaseEventLoop.create_task. Then if there is a current | ||
task and the current task has a context then share that context | ||
with the new task | ||
""" | ||
task = asyncio.Task(coro, loop=loop) | ||
if task._source_traceback: # flake8: noqa | ||
del task._source_traceback[-1] # flake8: noqa | ||
|
||
# Share context with new task if possible | ||
current_task = asyncio.Task.current_task(loop=loop) | ||
if current_task is not None and hasattr(current_task, 'context'): | ||
setattr(task, 'context', current_task.context) | ||
|
||
return task |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
""" | ||
AioHttp Middleware | ||
""" | ||
import traceback | ||
|
||
from aws_xray_sdk.core import xray_recorder | ||
from aws_xray_sdk.core.models import http | ||
from aws_xray_sdk.ext.util import calculate_sampling_decision, calculate_segment_name, construct_xray_header | ||
|
||
|
||
async def middleware(app, handler): | ||
""" | ||
AioHttp Middleware Factory | ||
""" | ||
async def _middleware(request): | ||
""" | ||
Main middleware function, deals with all the X-Ray segment logic | ||
""" | ||
# Create X-Ray headers | ||
xray_header = construct_xray_header(request.headers) | ||
# Get name of service or generate a dynamic one from host | ||
name = calculate_segment_name(request.headers['host'].split(':', 1)[0], xray_recorder) | ||
|
||
sampling_decision = calculate_sampling_decision( | ||
trace_header=xray_header, | ||
recorder=xray_recorder, | ||
service_name=request.headers['host'], | ||
method=request.method, | ||
path=request.path, | ||
) | ||
|
||
# Start a segment | ||
segment = xray_recorder.begin_segment( | ||
name=name, | ||
traceid=xray_header.root, | ||
parent_id=xray_header.parent, | ||
sampling=sampling_decision, | ||
) | ||
|
||
# Store request metadata in the current segment | ||
segment.put_http_meta(http.URL, request.url) | ||
segment.put_http_meta(http.METHOD, request.method) | ||
|
||
if 'User-Agent' in request.headers: | ||
segment.put_http_meta(http.USER_AGENT, request.headers['User-Agent']) | ||
|
||
if 'X-Forwarded-For' in request.headers: | ||
segment.put_http_meta(http.CLIENT_IP, request.headers['X-Forwarded-For']) | ||
segment.put_http_meta(http.X_FORWARDED_FOR, True) | ||
elif 'remote_addr' in request.headers: | ||
segment.put_http_meta(http.CLIENT_IP, request.headers['remote_addr']) | ||
else: | ||
segment.put_http_meta(http.CLIENT_IP, request.remote) | ||
|
||
try: | ||
# Call next middleware or request handler | ||
response = await handler(request) | ||
except Exception as err: | ||
# Store exception information including the stacktrace to the segment | ||
segment = xray_recorder.current_segment() | ||
segment.put_http_meta(http.STATUS, 500) | ||
stack = traceback.extract_stack(limit=xray_recorder._max_trace_back) | ||
segment.add_exception(err, stack) | ||
xray_recorder.end_segment() | ||
raise | ||
|
||
# Store response metadata into the current segment | ||
segment.put_http_meta(http.STATUS, response.status) | ||
|
||
if 'Content-Length' in response.headers: | ||
length = int(response.headers['Content-Length']) | ||
segment.put_http_meta(http.CONTENT_LENGTH, length) | ||
|
||
# Close segment so it can be dispatched off to the daemon | ||
xray_recorder.end_segment() | ||
|
||
return response | ||
return _middleware |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Oops, something went wrong.