Skip to content

Commit

Permalink
Merge pull request #6 from terrycain/aiobotocore
Browse files Browse the repository at this point in the history
Add support for aiobotocore
  • Loading branch information
haotianw465 authored Nov 15, 2017
2 parents 8295724 + 3ec21fb commit e8a1d4f
Show file tree
Hide file tree
Showing 28 changed files with 515 additions and 149 deletions.
2 changes: 1 addition & 1 deletion MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
include aws_xray_sdk/ext/botocore/*.json
include aws_xray_sdk/ext/resources/*.json
include aws_xray_sdk/core/sampling/*.json
include README.md
include LICENSE
27 changes: 27 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,17 @@ def myfunc():
myfunc()
```

```python
from aws_xray_sdk.core import xray_recorder

@xray_recorder.capture_async('subsegment_name')
async def myfunc():
# Do something here

async def main():
await myfunc()
```

**Trace AWS Lambda functions**

```python
Expand Down Expand Up @@ -149,6 +160,22 @@ xray_recorder.configure(service='fallback_name', dynamic_naming='*mysite.com*')
XRayMiddleware(app, xray_recorder)
```

**Add aiohttp middleware**
```python
from aiohttp import web

from aws_xray_sdk.ext.aiohttp.middleware import middleware
from aws_xray_sdk.core import xray_recorder
from aws_xray_sdk.core.async_context import AsyncContext

xray_recorder.configure(service='fallback_name', context=AsyncContext())

app = web.Application(middlewares=[middleware])
app.router.add_get("/", handler)

web.run_app(app)
```

## License

The AWS X-Ray SDK for Python is licensed under the Apache 2.0 License. See LICENSE and NOTICE.txt for more information.
7 changes: 6 additions & 1 deletion aws_xray_sdk/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
from .recorder import AWSXRayRecorder
from .patcher import patch_all, patch
from .utils.compat import PY35


xray_recorder = AWSXRayRecorder()
if not PY35:
xray_recorder = AWSXRayRecorder()
else:
from .async_recorder import AsyncAWSXRayRecorder
xray_recorder = AsyncAWSXRayRecorder()

__all__ = [
'patch',
Expand Down
69 changes: 69 additions & 0 deletions aws_xray_sdk/core/async_recorder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import time
import traceback

import wrapt

from aws_xray_sdk.core.recorder import AWSXRayRecorder


class AsyncAWSXRayRecorder(AWSXRayRecorder):
def capture_async(self, name=None):
"""
A decorator that records enclosed function in a subsegment.
It only works with asynchronous functions.
params str name: The name of the subsegment. If not specified
the function name will be used.
"""

@wrapt.decorator
async def wrapper(wrapped, instance, args, kwargs):
func_name = name
if not func_name:
func_name = wrapped.__name__

result = await self.record_subsegment_async(
wrapped, instance, args, kwargs,
name=func_name,
namespace='local',
meta_processor=None,
)

return result

return wrapper

async def record_subsegment_async(self, wrapped, instance, args, kwargs, name,
namespace, meta_processor):

subsegment = self.begin_subsegment(name, namespace)

exception = None
stack = None
return_value = None

try:
return_value = await wrapped(*args, **kwargs)
return return_value
except Exception as e:
exception = e
stack = traceback.extract_stack(limit=self._max_trace_back)
raise
finally:
end_time = time.time()
if callable(meta_processor):
meta_processor(
wrapped=wrapped,
instance=instance,
args=args,
kwargs=kwargs,
return_value=return_value,
exception=exception,
subsegment=subsegment,
stack=stack,
)
elif exception:
if subsegment:
subsegment.add_exception(exception, stack)

self.end_subsegment(end_time)
7 changes: 6 additions & 1 deletion aws_xray_sdk/core/patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
log = logging.getLogger(__name__)

SUPPORTED_MODULES = (
'aiobotocore',
'botocore',
'requests',
'sqlite3',
Expand All @@ -23,10 +24,14 @@ def patch(modules_to_patch, raise_errors=True):


def _patch_module(module_to_patch, raise_errors=True):
# boto3 depends on botocore and patch botocore is sufficient
# boto3 depends on botocore and patching botocore is sufficient
if module_to_patch == 'boto3':
module_to_patch = 'botocore'

# aioboto3 depends on aiobotocore and patching aiobotocore is sufficient
if module_to_patch == 'aioboto3':
module_to_patch = 'aiobotocore'

if module_to_patch not in SUPPORTED_MODULES:
raise Exception('module %s is currently not supported for patching'
% module_to_patch)
Expand Down
9 changes: 6 additions & 3 deletions aws_xray_sdk/core/recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ class AWSXRayRecorder(object):
A global AWS X-Ray recorder that will begin/end segments/subsegments
and send them to the X-Ray daemon. This recorder is initialized during
loading time so you can use::
from aws_xray_sdk.core import xray_recorder
in your module to access it
"""
def __init__(self):
Expand Down Expand Up @@ -312,15 +314,16 @@ def record_subsegment(self, wrapped, instance, args, kwargs, name,

subsegment = self.begin_subsegment(name, namespace)

exception = None
stack = None
return_value = None

try:
return_value = wrapped(*args, **kwargs)
exception = None
stack = None
return return_value
except Exception as e:
exception = e
stack = traceback.extract_stack(limit=self._max_trace_back)
return_value = None
raise
finally:
end_time = time.time()
Expand Down
5 changes: 2 additions & 3 deletions aws_xray_sdk/core/sampling/default_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@
import json
from random import Random

from pkg_resources import resource_filename
from .sampling_rule import SamplingRule
from ..exceptions.exceptions import InvalidSamplingManifestError

__location__ = os.path.realpath(
os.path.join(os.getcwd(), os.path.dirname(__file__)))

with open(os.path.join(__location__, 'default_sampling_rule.json')) as f:
with open(resource_filename(__name__, 'default_sampling_rule.json')) as f:
default_sampling_rule = json.load(f)


Expand Down
1 change: 1 addition & 0 deletions aws_xray_sdk/core/utils/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@


PY2 = sys.version_info < (3,)
PY35 = sys.version_info >= (3, 5)

if PY2:
annotation_value_types = (int, long, float, bool, str) # noqa: F821
Expand Down
3 changes: 3 additions & 0 deletions aws_xray_sdk/ext/aiobotocore/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .patch import patch

__all__ = ['patch']
39 changes: 39 additions & 0 deletions aws_xray_sdk/ext/aiobotocore/patch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import aiobotocore.client
import wrapt

from aws_xray_sdk.core import xray_recorder
from aws_xray_sdk.ext.boto_utils import inject_header, aws_meta_processor


def patch():
"""
Patch aiobotocore client so it generates subsegments
when calling AWS services.
"""
if hasattr(aiobotocore.client, '_xray_enabled'):
return
setattr(aiobotocore.client, '_xray_enabled', True)

wrapt.wrap_function_wrapper(
'aiobotocore.client',
'AioBaseClient._make_api_call',
_xray_traced_aiobotocore,
)

wrapt.wrap_function_wrapper(
'aiobotocore.endpoint',
'AioEndpoint._encode_headers',
inject_header,
)


async def _xray_traced_aiobotocore(wrapped, instance, args, kwargs):
service = instance._service_model.metadata["endpointPrefix"]
result = await xray_recorder.record_subsegment_async(
wrapped, instance, args, kwargs,
name=service,
namespace='aws',
meta_processor=aws_meta_processor,
)

return result
134 changes: 134 additions & 0 deletions aws_xray_sdk/ext/boto_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
from __future__ import absolute_import
# Need absolute import as botocore is also in the current folder for py27
import json

from pkg_resources import resource_filename
from botocore.exceptions import ClientError

from aws_xray_sdk.core import xray_recorder
from aws_xray_sdk.core.models import http

from aws_xray_sdk.ext.util import inject_trace_header, to_snake_case


with open(resource_filename(__name__, 'resources/aws_para_whitelist.json'), 'r') as data_file:
whitelist = json.load(data_file)


def inject_header(wrapped, instance, args, kwargs):
headers = args[0]
inject_trace_header(headers, xray_recorder.current_subsegment())
return wrapped(*args, **kwargs)


def aws_meta_processor(wrapped, instance, args, kwargs,
return_value, exception, subsegment, stack):
region = instance.meta.region_name

if 'operation_name' in kwargs:
operation_name = kwargs['operation_name']
else:
operation_name = args[0]

aws_meta = {
'operation': operation_name,
'region': region,
}

if return_value:
resp_meta = return_value.get('ResponseMetadata')
if resp_meta:
aws_meta['request_id'] = resp_meta.get('RequestId')
subsegment.put_http_meta(http.STATUS,
resp_meta.get('HTTPStatusCode'))
# for service like S3 that returns special request id in response headers
if 'HTTPHeaders' in resp_meta and resp_meta['HTTPHeaders'].get('x-amz-id-2'):
aws_meta['id_2'] = resp_meta['HTTPHeaders']['x-amz-id-2']

elif exception:
_aws_error_handler(exception, stack, subsegment, aws_meta)

_extract_whitelisted_params(subsegment.name, operation_name,
aws_meta, args, kwargs, return_value)

subsegment.set_aws(aws_meta)


def _aws_error_handler(exception, stack, subsegment, aws_meta):

if not exception or not isinstance(exception, ClientError):
return

response_metadata = exception.response.get('ResponseMetadata')

if not response_metadata:
return

aws_meta['request_id'] = response_metadata.get('RequestId')

status_code = response_metadata.get('HTTPStatusCode')

subsegment.put_http_meta(http.STATUS, status_code)
if status_code == 429:
subsegment.add_throttle_flag()
if status_code / 100 == 4:
subsegment.add_error_flag()

subsegment.add_exception(exception, stack, True)


def _extract_whitelisted_params(service, operation,
aws_meta, args, kwargs, response):

# check if service is whitelisted
if service not in whitelist['services']:
return
operations = whitelist['services'][service]['operations']

# check if operation is whitelisted
if operation not in operations:
return
params = operations[operation]

# record whitelisted request/response parameters
if 'request_parameters' in params:
_record_params(params['request_parameters'], args[1], aws_meta)

if 'request_descriptors' in params:
_record_special_params(params['request_descriptors'],
args[1], aws_meta)

if 'response_parameters' in params and response:
_record_params(params['response_parameters'], response, aws_meta)

if 'response_descriptors' in params and response:
_record_special_params(params['response_descriptors'],
response, aws_meta)


def _record_params(whitelisted, actual, aws_meta):

for key in whitelisted:
if key in actual:
snake_key = to_snake_case(key)
aws_meta[snake_key] = actual[key]


def _record_special_params(whitelisted, actual, aws_meta):

for key in whitelisted:
if key in actual:
_process_descriptor(whitelisted[key], actual[key], aws_meta)


def _process_descriptor(descriptor, value, aws_meta):

# "get_count" = true
if 'get_count' in descriptor and descriptor['get_count']:
value = len(value)

# "get_keys" = true
if 'get_keys' in descriptor and descriptor['get_keys']:
value = value.keys()

aws_meta[descriptor['rename_to']] = value
Loading

0 comments on commit e8a1d4f

Please sign in to comment.