-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix(segments): live query segments is now working
- Loading branch information
Showing
4 changed files
with
224 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
28 changes: 28 additions & 0 deletions
28
django_forest/resources/utils/queryset/live_query_segment.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import re | ||
from django.db import connection | ||
|
||
|
||
class LiveQuerySegmentMixin: | ||
def handle_live_query_segment(self, live_query, Model, queryset): | ||
ids = self._get_live_query_ids(live_query) | ||
pk_field = Model._meta.pk.attname | ||
queryset = queryset.filter(**{f"{pk_field}__in": ids}) | ||
return queryset | ||
|
||
def _get_live_query_ids(self, live_query): | ||
self._validate_query(live_query) | ||
sql_query = "select id from (%s) as ids;" % live_query[0:live_query.find(";")] | ||
with connection.cursor() as cursor: | ||
cursor.execute(sql_query) | ||
res = cursor.fetchall() | ||
return [r[0] for r in res] | ||
|
||
def _validate_query(self, query): | ||
if len(query.strip()) == 0: | ||
raise Exception("Live Query Segment: You cannot execute an empty SQL query.") | ||
|
||
if ';' in query and query.find(';') < len(query.strip())-1: | ||
raise Exception("Live Query Segment: You cannot chain SQL queries.") | ||
|
||
if not re.search(r'^SELECT\s.*FROM\s.*$', query, flags=re.IGNORECASE | re.MULTILINE | re.DOTALL): | ||
raise Exception("Live Query Segment: Only SELECT queries are allowed.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from .live_query_segment import LiveQuerySegmentMixin | ||
from django_forest.utils.collection import Collection | ||
|
||
|
||
class SegmentMixin(LiveQuerySegmentMixin): | ||
def handle_segment(self, params, Model, queryset): | ||
if 'segment' in params: | ||
collection = Collection._registry[Model._meta.db_table] | ||
segment = next((x for x in collection.segments if x['name'] == params['segment']), None) | ||
if segment is not None and 'where' in segment: | ||
queryset = queryset.filter(segment['where']()) | ||
|
||
# live query segment | ||
if "segmentQuery" in params: | ||
queryset = self.handle_live_query_segment(params['segmentQuery'], Model, queryset) | ||
|
||
return queryset |
174 changes: 174 additions & 0 deletions
174
django_forest/tests/resources/views/list/test_list_live_query_segment.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
import copy | ||
import sys | ||
from datetime import datetime | ||
from unittest import mock | ||
|
||
import pytest | ||
import pytz | ||
from django.test import TransactionTestCase | ||
from django.urls import reverse | ||
from freezegun import freeze_time | ||
|
||
from django_forest.tests.fixtures.schema import test_schema | ||
from django_forest.tests.resources.views.list.test_list_scope import mocked_scope | ||
from django_forest.utils.schema import Schema | ||
from django_forest.utils.schema.json_api_schema import JsonApiSchema | ||
from django_forest.utils.date import get_timezone | ||
|
||
|
||
# reset forest config dir auto import | ||
from django_forest.utils.scope import ScopeManager | ||
|
||
|
||
@pytest.fixture() | ||
def reset_config_dir_import(): | ||
for key in list(sys.modules.keys()): | ||
if key.startswith('django_forest.tests.forest'): | ||
del sys.modules[key] | ||
|
||
|
||
@pytest.mark.usefixtures('reset_config_dir_import') | ||
class ResourceListSmartSegmentViewTests(TransactionTestCase): | ||
fixtures = ['article.json', 'publication.json', | ||
'session.json', | ||
'question.json', 'choice.json', | ||
'place.json', 'restaurant.json', | ||
'student.json', | ||
'serial.json'] | ||
|
||
@pytest.fixture(autouse=True) | ||
def inject_fixtures(self, django_assert_num_queries): | ||
self._django_assert_num_queries = django_assert_num_queries | ||
|
||
def setUp(self): | ||
Schema.schema = copy.deepcopy(test_schema) | ||
Schema.add_smart_features() | ||
Schema.handle_json_api_schema() | ||
self.url = reverse('django_forest:resources:list', kwargs={'resource': 'tests_question'}) | ||
self.client = self.client_class( | ||
HTTP_AUTHORIZATION='Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjUiLCJlbWFpbCI6Imd1aWxsYXVtZWNAZm9yZXN0YWRtaW4uY29tIiwiZmlyc3RfbmFtZSI6Ikd1aWxsYXVtZSIsImxhc3RfbmFtZSI6IkNpc2NvIiwidGVhbSI6Ik9wZXJhdGlvbnMiLCJyZW5kZXJpbmdfaWQiOjEsImV4cCI6MTYyNTY3OTYyNi44ODYwMTh9.mHjA05yvMr99gFMuFv0SnPDCeOd2ZyMSN868V7lsjnw') | ||
|
||
def tearDown(self): | ||
# reset _registry after each test | ||
JsonApiSchema._registry = {} | ||
ScopeManager.cache = {} | ||
|
||
@mock.patch('jose.jwt.decode', return_value={'id': 1, 'rendering_id': 1}) | ||
@freeze_time( | ||
lambda: datetime(2021, 7, 8, 9, 20, 23, 582772, tzinfo=get_timezone('UTC')) | ||
) | ||
@mock.patch('django_forest.utils.scope.ScopeManager._has_cache_expired', return_value=False) | ||
def test_get(self, mocked_scope_has_expired, mocked_decode): | ||
ScopeManager.cache = { | ||
'1': { | ||
'scopes': mocked_scope, | ||
'fetched_at': datetime(2021, 7, 8, 9, 20, 22, 582772, tzinfo=pytz.UTC) | ||
} | ||
} | ||
response = self.client.get(self.url, { | ||
'fields[tests_question]': 'id,topic,question_text,pub_date,foo,bar', | ||
'fields[topic]': 'name', | ||
'segmentQuery': 'select * from tests_question where id=1;', | ||
'page[number]': '1', | ||
'page[size]': '15', | ||
'timezone': 'Europe/Paris', | ||
}) | ||
data = response.json() | ||
self.assertEqual(response.status_code, 200) | ||
self.assertEqual(data, { | ||
'data': [ | ||
{ | ||
'type': 'tests_question', | ||
'attributes': { | ||
'pub_date': '2021-06-02T13:52:53.528000+00:00', | ||
'question_text': 'what is your favorite color?', | ||
'foo': 'what is your favorite color?+foo', | ||
'bar': 'what is your favorite color?+bar' | ||
}, | ||
'id': 1, | ||
'links': { | ||
'self': '/forest/tests_question/1' | ||
}, | ||
'relationships': { | ||
'topic': { | ||
'data': None, | ||
'links': { | ||
'related': '/forest/tests_question/1/relationships/topic' | ||
} | ||
} | ||
}, | ||
}, | ||
] | ||
}) | ||
|
||
@mock.patch('jose.jwt.decode', return_value={'id': 1, 'rendering_id': 1}) | ||
@freeze_time( | ||
lambda: datetime(2021, 7, 8, 9, 20, 23, 582772, tzinfo=get_timezone('UTC')) | ||
) | ||
@mock.patch('django_forest.utils.scope.ScopeManager._has_cache_expired', return_value=False) | ||
def test_get_error_when_multiple_request(self, mocked_scope_has_expired, mocked_decode): | ||
ScopeManager.cache = { | ||
'1': { | ||
'scopes': mocked_scope, | ||
'fetched_at': datetime(2021, 7, 8, 9, 20, 22, 582772, tzinfo=pytz.UTC) | ||
} | ||
} | ||
response = self.client.get(self.url, { | ||
'fields[tests_question]': 'id,topic,question_text,pub_date,foo,bar', | ||
'fields[topic]': 'name', | ||
'segmentQuery': 'select * from tests_question where id=1;select * from user_users', | ||
'page[number]': '1', | ||
'page[size]': '15', | ||
'timezone': 'Europe/Paris', | ||
}) | ||
data = response.json() | ||
self.assertEqual(response.status_code, 400) | ||
self.assertEqual(data, {"errors": [{'detail': 'Live Query Segment: You cannot chain SQL queries.'}]}) | ||
|
||
@mock.patch('jose.jwt.decode', return_value={'id': 1, 'rendering_id': 1}) | ||
@freeze_time( | ||
lambda: datetime(2021, 7, 8, 9, 20, 23, 582772, tzinfo=get_timezone('UTC')) | ||
) | ||
@mock.patch('django_forest.utils.scope.ScopeManager._has_cache_expired', return_value=False) | ||
def test_get_error_when_sql_is_not_select(self, mocked_scope_has_expired, mocked_decode): | ||
ScopeManager.cache = { | ||
'1': { | ||
'scopes': mocked_scope, | ||
'fetched_at': datetime(2021, 7, 8, 9, 20, 22, 582772, tzinfo=pytz.UTC) | ||
} | ||
} | ||
response = self.client.get(self.url, { | ||
'fields[tests_question]': 'id,topic,question_text,pub_date,foo,bar', | ||
'fields[topic]': 'name', | ||
'segmentQuery': 'insert into tests_question(id) values(999)', | ||
'page[number]': '1', | ||
'page[size]': '15', | ||
'timezone': 'Europe/Paris', | ||
}) | ||
data = response.json() | ||
self.assertEqual(response.status_code, 400) | ||
self.assertEqual(data, {"errors": [{'detail': 'Live Query Segment: Only SELECT queries are allowed.'}]}) | ||
|
||
@mock.patch('jose.jwt.decode', return_value={'id': 1, 'rendering_id': 1}) | ||
@freeze_time( | ||
lambda: datetime(2021, 7, 8, 9, 20, 23, 582772, tzinfo=get_timezone('UTC')) | ||
) | ||
@mock.patch('django_forest.utils.scope.ScopeManager._has_cache_expired', return_value=False) | ||
def test_get_error_when_sql_is_empty(self, mocked_scope_has_expired, mocked_decode): | ||
ScopeManager.cache = { | ||
'1': { | ||
'scopes': mocked_scope, | ||
'fetched_at': datetime(2021, 7, 8, 9, 20, 22, 582772, tzinfo=pytz.UTC) | ||
} | ||
} | ||
response = self.client.get(self.url, { | ||
'fields[tests_question]': 'id,topic,question_text,pub_date,foo,bar', | ||
'fields[topic]': 'name', | ||
'segmentQuery': ' \n', | ||
'page[number]': '1', | ||
'page[size]': '15', | ||
'timezone': 'Europe/Paris', | ||
}) | ||
data = response.json() | ||
self.assertEqual(response.status_code, 400) | ||
self.assertEqual(data, {"errors": [{'detail': 'Live Query Segment: You cannot execute an empty SQL query.'}]}) |