-
Notifications
You must be signed in to change notification settings - Fork 143
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Tracing for SQLAlchemy and Flask-SQLAlcemy (#14)
* Initial checkin of Query and BaseQuery overrides * Fix ext name * Fix import * Add support for SQLAlchemy.orm and Flask-SQLAlchemy * Remove print() statement * Attempt to fix handling of Flask not having a request with a xray segment * Fix handling of missing segment * Fix test and add docstrings * Fix bug with End segment * Code Review Cleanup. Files now all pass flake8 tests * Move find_subsegment and _search_entity functions to tests/util.py * Uset set_sql to corectly test the sanitized_query value. Add test to sqlalcemy to test filter() and verify params not present in sanitized_query * Comment out set_sql for sanitized_query for seperate code review * Starting to add in set_sql * Add more SQL info to trace * Correct URL handling for connection strings * Bug fix and remove sanitized_query * Fix unit test and add helper util for finding subsegment by annotation key/value * Minor cleanups
- Loading branch information
1 parent
0b00e4b
commit d110386
Showing
14 changed files
with
406 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
from builtins import super | ||
from flask_sqlalchemy.model import Model | ||
from sqlalchemy.orm.session import sessionmaker | ||
from flask_sqlalchemy import SQLAlchemy, BaseQuery, _SessionSignalEvents, get_state | ||
from aws_xray_sdk.ext.sqlalchemy.query import XRaySession, XRayQuery | ||
from aws_xray_sdk.ext.sqlalchemy.util.decerators import xray_on_call, decorate_all_functions | ||
|
||
|
||
@decorate_all_functions(xray_on_call) | ||
class XRayBaseQuery(BaseQuery): | ||
BaseQuery.__bases__ = (XRayQuery,) | ||
|
||
|
||
class XRaySignallingSession(XRaySession): | ||
"""The signalling session is the default session that Flask-SQLAlchemy | ||
uses. It extends the default session system with bind selection and | ||
modification tracking. | ||
If you want to use a different session you can override the | ||
:meth:`SQLAlchemy.create_session` function. | ||
.. versionadded:: 2.0 | ||
.. versionadded:: 2.1 | ||
The `binds` option was added, which allows a session to be joined | ||
to an external transaction. | ||
""" | ||
|
||
def __init__(self, db, autocommit=False, autoflush=True, **options): | ||
#: The application that this session belongs to. | ||
self.app = app = db.get_app() | ||
track_modifications = app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] | ||
bind = options.pop('bind', None) or db.engine | ||
binds = options.pop('binds', db.get_binds(app)) | ||
|
||
if track_modifications is None or track_modifications: | ||
_SessionSignalEvents.register(self) | ||
|
||
XRaySession.__init__( | ||
self, autocommit=autocommit, autoflush=autoflush, | ||
bind=bind, binds=binds, **options | ||
) | ||
|
||
def get_bind(self, mapper=None, clause=None): | ||
# mapper is None if someone tries to just get a connection | ||
if mapper is not None: | ||
info = getattr(mapper.mapped_table, 'info', {}) | ||
bind_key = info.get('bind_key') | ||
if bind_key is not None: | ||
state = get_state(self.app) | ||
return state.db.get_engine(self.app, bind=bind_key) | ||
return XRaySession.get_bind(self, mapper, clause) | ||
|
||
|
||
class XRayFlaskSqlAlchemy(SQLAlchemy): | ||
def __init__(self, app=None, use_native_unicode=True, session_options=None, | ||
metadata=None, query_class=XRayBaseQuery, model_class=Model): | ||
super().__init__(app, use_native_unicode, session_options, | ||
metadata, query_class, model_class) | ||
|
||
def create_session(self, options): | ||
return sessionmaker(class_=XRaySignallingSession, db=self, **options) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from builtins import super | ||
from sqlalchemy.orm.query import Query | ||
from sqlalchemy.orm.session import Session, sessionmaker | ||
from .util.decerators import xray_on_call, decorate_all_functions | ||
|
||
|
||
@decorate_all_functions(xray_on_call) | ||
class XRaySession(Session): | ||
pass | ||
|
||
|
||
@decorate_all_functions(xray_on_call) | ||
class XRayQuery(Query): | ||
pass | ||
|
||
|
||
@decorate_all_functions(xray_on_call) | ||
class XRaySessionMaker(sessionmaker): | ||
def __init__(self, bind=None, class_=XRaySession, autoflush=True, | ||
autocommit=False, | ||
expire_on_commit=True, | ||
info=None, **kw): | ||
kw['query_cls'] = XRayQuery | ||
super().__init__(bind, class_, autoflush, autocommit, expire_on_commit, | ||
info, **kw) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
import re | ||
from aws_xray_sdk.core import xray_recorder | ||
from future.standard_library import install_aliases | ||
install_aliases() | ||
from urllib.parse import urlparse, uses_netloc | ||
|
||
|
||
|
||
def decorate_all_functions(function_decorator): | ||
def decorator(cls): | ||
for c in cls.__bases__: | ||
for name, obj in vars(c).items(): | ||
if name.startswith("_"): | ||
continue | ||
if callable(obj): | ||
try: | ||
obj = obj.__func__ # unwrap Python 2 unbound method | ||
except AttributeError: | ||
pass # not needed in Python 3 | ||
setattr(c, name, function_decorator(c, obj)) | ||
return cls | ||
return decorator | ||
|
||
def xray_on_call(cls, func): | ||
def wrapper(*args, **kw): | ||
from ..query import XRayQuery, XRaySession | ||
from ...flask_sqlalchemy.query import XRaySignallingSession | ||
class_name = str(cls.__module__) | ||
c = xray_recorder._context | ||
sql = None | ||
subsegment = None | ||
if class_name == "sqlalchemy.orm.session": | ||
for arg in args: | ||
if isinstance(arg, XRaySession): | ||
sql = parse_bind(arg.bind) | ||
if isinstance(arg, XRaySignallingSession): | ||
sql = parse_bind(arg.bind) | ||
if class_name == 'sqlalchemy.orm.query': | ||
for arg in args: | ||
if isinstance(arg, XRayQuery): | ||
try: | ||
sql = parse_bind(arg.session.bind) | ||
# Commented our for later PR | ||
# sql['sanitized_query'] = str(arg) | ||
except: | ||
sql = None | ||
if sql is not None: | ||
if getattr(c._local, 'entities', None) is not None: | ||
subsegment = xray_recorder.begin_subsegment(sql['url'], namespace='remote') | ||
else: | ||
subsegment = None | ||
res = func(*args, **kw) | ||
if subsegment is not None: | ||
subsegment.set_sql(sql) | ||
subsegment.put_annotation("sqlalchemy", class_name+'.'+func.__name__ ); | ||
xray_recorder.end_subsegment() | ||
return res | ||
return wrapper | ||
# URL Parse output | ||
# scheme 0 URL scheme specifier scheme parameter | ||
# netloc 1 Network location part empty string | ||
# path 2 Hierarchical path empty string | ||
# query 3 Query component empty string | ||
# fragment 4 Fragment identifier empty string | ||
# username User name None | ||
# password Password None | ||
# hostname Host name (lower case) None | ||
# port Port number as integer, if present None | ||
# | ||
# XRAY Trace SQL metaData Sample | ||
# "sql" : { | ||
# "url": "jdbc:postgresql://aawijb5u25wdoy.cpamxznpdoq8.us-west-2.rds.amazonaws.com:5432/ebdb", | ||
# "preparation": "statement", | ||
# "database_type": "PostgreSQL", | ||
# "database_version": "9.5.4", | ||
# "driver_version": "PostgreSQL 9.4.1211.jre7", | ||
# "user" : "dbuser", | ||
# "sanitized_query" : "SELECT * FROM customers WHERE customer_id=?;" | ||
# } | ||
def parse_bind(bind): | ||
"""Parses a connection string and creates SQL trace metadata""" | ||
m = re.match(r"Engine\((.*?)\)", str(bind)) | ||
if m is not None: | ||
u = urlparse(m.group(1)) | ||
# Add Scheme to uses_netloc or // will be missing from url. | ||
uses_netloc.append(u.scheme) | ||
safe_url = "" | ||
if u.password is None: | ||
safe_url = u.geturl() | ||
else: | ||
# Strip password from URL | ||
host_info = u.netloc.rpartition('@')[-1] | ||
parts = u._replace(netloc='{}@{}'.format(u.username, host_info)) | ||
safe_url = u.geturl() | ||
sql = {} | ||
sql['database_type'] = u.scheme | ||
sql['url'] = safe_url | ||
if u.username is not None: | ||
sql['user'] = "{}".format(u.username) | ||
return sql |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
from __future__ import absolute_import | ||
import pytest | ||
from aws_xray_sdk.core import xray_recorder | ||
from aws_xray_sdk.core.context import Context | ||
from aws_xray_sdk.ext.flask_sqlalchemy.query import XRayFlaskSqlAlchemy | ||
from flask import Flask | ||
from ...util import find_subsegment_by_annotation | ||
|
||
|
||
app = Flask(__name__) | ||
app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False | ||
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:" | ||
db = XRayFlaskSqlAlchemy(app) | ||
|
||
|
||
class User(db.Model): | ||
__tablename__ = "users" | ||
|
||
id = db.Column(db.Integer, primary_key=True) | ||
name = db.Column(db.String(255), nullable=False, unique=True) | ||
fullname = db.Column(db.String(255), nullable=False) | ||
password = db.Column(db.String(255), nullable=False) | ||
|
||
|
||
@pytest.fixture() | ||
def session(): | ||
"""Test Fixture to Create DataBase Tables and start a trace segment""" | ||
xray_recorder.configure(service='test', sampling=False, context=Context()) | ||
xray_recorder.clear_trace_entities() | ||
xray_recorder.begin_segment('SQLAlchemyTest') | ||
db.create_all() | ||
yield | ||
xray_recorder.end_segment() | ||
xray_recorder.clear_trace_entities() | ||
|
||
|
||
def test_all(capsys, session): | ||
""" Test calling all() on get all records. | ||
Verify that we capture trace of query and return the SQL as metdata""" | ||
# with capsys.disabled(): | ||
User.query.all() | ||
subsegment = find_subsegment_by_annotation(xray_recorder.current_segment(), 'sqlalchemy', 'sqlalchemy.orm.query.all') | ||
assert subsegment['annotations']['sqlalchemy'] == 'sqlalchemy.orm.query.all' | ||
# assert subsegment['sql']['sanitized_query'] | ||
assert subsegment['sql']['url'] | ||
|
||
|
||
def test_add(capsys, session): | ||
""" Test calling add() on insert a row. | ||
Verify we that we capture trace for the add""" | ||
# with capsys.disabled(): | ||
john = User(name='John', fullname="John Doe", password="password") | ||
db.session.add(john) | ||
subsegment = find_subsegment_by_annotation(xray_recorder.current_segment(), 'sqlalchemy', 'sqlalchemy.orm.session.add') | ||
assert subsegment['annotations']['sqlalchemy'] == 'sqlalchemy.orm.session.add' | ||
assert subsegment['sql']['url'] |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
from __future__ import absolute_import | ||
import pytest | ||
from aws_xray_sdk.core import xray_recorder | ||
from aws_xray_sdk.core.context import Context | ||
from aws_xray_sdk.ext.sqlalchemy.query import XRaySessionMaker | ||
from sqlalchemy.ext.declarative import declarative_base | ||
from sqlalchemy import create_engine, Column, Integer, String | ||
from ...util import find_subsegment_by_annotation | ||
|
||
|
||
Base = declarative_base() | ||
|
||
|
||
class User(Base): | ||
__tablename__ = 'users' | ||
|
||
id = Column(Integer, primary_key=True) | ||
name = Column(String) | ||
fullname = Column(String) | ||
password = Column(String) | ||
|
||
|
||
@pytest.fixture() | ||
def session(): | ||
"""Test Fixture to Create DataBase Tables and start a trace segment""" | ||
engine = create_engine('sqlite:///:memory:') | ||
xray_recorder.configure(service='test', sampling=False, context=Context()) | ||
xray_recorder.clear_trace_entities() | ||
xray_recorder.begin_segment('SQLAlchemyTest') | ||
Session = XRaySessionMaker(bind=engine) | ||
Base.metadata.create_all(engine) | ||
session = Session() | ||
yield session | ||
xray_recorder.end_segment() | ||
xray_recorder.clear_trace_entities() | ||
|
||
|
||
def test_all(capsys, session): | ||
""" Test calling all() on get all records. | ||
Verify we run the query and return the SQL as metdata""" | ||
# with capsys.disabled(): | ||
session.query(User).all() | ||
subsegment = find_subsegment_by_annotation(xray_recorder.current_segment(), 'sqlalchemy', 'sqlalchemy.orm.query.all') | ||
assert subsegment['annotations']['sqlalchemy'] == 'sqlalchemy.orm.query.all' | ||
# assert subsegment['sql']['sanitized_query'] | ||
assert subsegment['sql']['url'] | ||
|
||
|
||
def test_add(capsys, session): | ||
""" Test calling add() on insert a row. | ||
Verify we that we capture trace for the add""" | ||
# with capsys.disabled(): | ||
john = User(name='John', fullname="John Doe", password="password") | ||
session.add(john) | ||
subsegment = find_subsegment_by_annotation(xray_recorder.current_segment(), 'sqlalchemy', 'sqlalchemy.orm.session.add') | ||
assert subsegment['annotations']['sqlalchemy'] == 'sqlalchemy.orm.session.add' | ||
assert subsegment['sql']['url'] | ||
|
||
|
||
def test_filter(capsys, session): | ||
""" Test calling all() on get all records. | ||
Verify we run the query and return the SQL as metdata""" | ||
# with capsys.disabled(): | ||
session.query(User).filter(User.password=="mypassword!") | ||
subsegment = find_subsegment_by_annotation(xray_recorder.current_segment(), 'sqlalchemy', 'sqlalchemy.orm.query.filter') | ||
assert subsegment['annotations']['sqlalchemy'] == 'sqlalchemy.orm.query.filter' | ||
# assert subsegment['sql']['sanitized_query'] | ||
# assert "mypassword!" not in subsegment['sql']['sanitized_query'] | ||
assert subsegment['sql']['url'] |
Oops, something went wrong.