diff --git a/django_forest/resources/utils/queryset/__init__.py b/django_forest/resources/utils/queryset/__init__.py index 4279d0d..3ca161c 100644 --- a/django_forest/resources/utils/queryset/__init__.py +++ b/django_forest/resources/utils/queryset/__init__.py @@ -1,14 +1,15 @@ -from django_forest.utils.collection import Collection from .filters import FiltersMixin from .limit_fields import LimitFieldsMixin from .pagination import PaginationMixin from .scope import ScopeMixin from .search import SearchMixin +from .segment import SegmentMixin from django_forest.resources.utils.decorators import DecoratorsMixin -class QuerysetMixin(PaginationMixin, FiltersMixin, SearchMixin, ScopeMixin, DecoratorsMixin, LimitFieldsMixin): - +class QuerysetMixin( + PaginationMixin, FiltersMixin, SearchMixin, ScopeMixin, DecoratorsMixin, LimitFieldsMixin, SegmentMixin +): def filter_queryset(self, queryset, Model, params, request): # Notice: first apply scope scope_filters = self.get_scope(request, Model) @@ -34,11 +35,7 @@ def enhance_queryset(self, queryset, Model, params, request, apply_pagination=Tr queryset = queryset.order_by(params['sort'].replace('.', '__')) # segment - if 'segment' in params: - collection = Collection._registry[Model._meta.db_table] - segment = next((x for x in collection.segments if x['name'] == params['segment']), None) - if segment is not None and 'where' in segment: - queryset = queryset.filter(segment['where']()) + queryset = self.handle_segment(params, Model, queryset) # limit fields queryset = self.handle_limit_fields(params, Model, queryset) diff --git a/django_forest/resources/utils/queryset/live_query_segment.py b/django_forest/resources/utils/queryset/live_query_segment.py new file mode 100644 index 0000000..7cab22f --- /dev/null +++ b/django_forest/resources/utils/queryset/live_query_segment.py @@ -0,0 +1,28 @@ +import re +from django.db import connection + + +class LiveQuerySegmentMixin: + def handle_live_query_segment(self, live_query, Model, queryset): + ids = self._get_live_query_ids(live_query) + pk_field = Model._meta.pk.attname + queryset = queryset.filter(**{f"{pk_field}__in": ids}) + return queryset + + def _get_live_query_ids(self, live_query): + self._validate_query(live_query) + sql_query = "select id from (%s) as ids;" % live_query[0:live_query.find(";")] + with connection.cursor() as cursor: + cursor.execute(sql_query) + res = cursor.fetchall() + return [r[0] for r in res] + + def _validate_query(self, query): + if len(query.strip()) == 0: + raise Exception("Live Query Segment: You cannot execute an empty SQL query.") + + if ';' in query and query.find(';') < len(query.strip())-1: + raise Exception("Live Query Segment: You cannot chain SQL queries.") + + if not re.search(r'^SELECT\s.*FROM\s.*$', query, flags=re.IGNORECASE | re.MULTILINE | re.DOTALL): + raise Exception("Live Query Segment: Only SELECT queries are allowed.") diff --git a/django_forest/resources/utils/queryset/segment.py b/django_forest/resources/utils/queryset/segment.py new file mode 100644 index 0000000..7d2b273 --- /dev/null +++ b/django_forest/resources/utils/queryset/segment.py @@ -0,0 +1,17 @@ +from .live_query_segment import LiveQuerySegmentMixin +from django_forest.utils.collection import Collection + + +class SegmentMixin(LiveQuerySegmentMixin): + def handle_segment(self, params, Model, queryset): + if 'segment' in params: + collection = Collection._registry[Model._meta.db_table] + segment = next((x for x in collection.segments if x['name'] == params['segment']), None) + if segment is not None and 'where' in segment: + queryset = queryset.filter(segment['where']()) + + # live query segment + if "segmentQuery" in params: + queryset = self.handle_live_query_segment(params['segmentQuery'], Model, queryset) + + return queryset diff --git a/django_forest/tests/resources/views/list/test_list_live_query_segment.py b/django_forest/tests/resources/views/list/test_list_live_query_segment.py new file mode 100644 index 0000000..f57e8db --- /dev/null +++ b/django_forest/tests/resources/views/list/test_list_live_query_segment.py @@ -0,0 +1,174 @@ +import copy +import sys +from datetime import datetime +from unittest import mock + +import pytest +import pytz +from django.test import TransactionTestCase +from django.urls import reverse +from freezegun import freeze_time + +from django_forest.tests.fixtures.schema import test_schema +from django_forest.tests.resources.views.list.test_list_scope import mocked_scope +from django_forest.utils.schema import Schema +from django_forest.utils.schema.json_api_schema import JsonApiSchema +from django_forest.utils.date import get_timezone + + +# reset forest config dir auto import +from django_forest.utils.scope import ScopeManager + + +@pytest.fixture() +def reset_config_dir_import(): + for key in list(sys.modules.keys()): + if key.startswith('django_forest.tests.forest'): + del sys.modules[key] + + +@pytest.mark.usefixtures('reset_config_dir_import') +class ResourceListSmartSegmentViewTests(TransactionTestCase): + fixtures = ['article.json', 'publication.json', + 'session.json', + 'question.json', 'choice.json', + 'place.json', 'restaurant.json', + 'student.json', + 'serial.json'] + + @pytest.fixture(autouse=True) + def inject_fixtures(self, django_assert_num_queries): + self._django_assert_num_queries = django_assert_num_queries + + def setUp(self): + Schema.schema = copy.deepcopy(test_schema) + Schema.add_smart_features() + Schema.handle_json_api_schema() + self.url = reverse('django_forest:resources:list', kwargs={'resource': 'tests_question'}) + self.client = self.client_class( + HTTP_AUTHORIZATION='Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjUiLCJlbWFpbCI6Imd1aWxsYXVtZWNAZm9yZXN0YWRtaW4uY29tIiwiZmlyc3RfbmFtZSI6Ikd1aWxsYXVtZSIsImxhc3RfbmFtZSI6IkNpc2NvIiwidGVhbSI6Ik9wZXJhdGlvbnMiLCJyZW5kZXJpbmdfaWQiOjEsImV4cCI6MTYyNTY3OTYyNi44ODYwMTh9.mHjA05yvMr99gFMuFv0SnPDCeOd2ZyMSN868V7lsjnw') + + def tearDown(self): + # reset _registry after each test + JsonApiSchema._registry = {} + ScopeManager.cache = {} + + @mock.patch('jose.jwt.decode', return_value={'id': 1, 'rendering_id': 1}) + @freeze_time( + lambda: datetime(2021, 7, 8, 9, 20, 23, 582772, tzinfo=get_timezone('UTC')) + ) + @mock.patch('django_forest.utils.scope.ScopeManager._has_cache_expired', return_value=False) + def test_get(self, mocked_scope_has_expired, mocked_decode): + ScopeManager.cache = { + '1': { + 'scopes': mocked_scope, + 'fetched_at': datetime(2021, 7, 8, 9, 20, 22, 582772, tzinfo=pytz.UTC) + } + } + response = self.client.get(self.url, { + 'fields[tests_question]': 'id,topic,question_text,pub_date,foo,bar', + 'fields[topic]': 'name', + 'segmentQuery': 'select * from tests_question where id=1;', + 'page[number]': '1', + 'page[size]': '15', + 'timezone': 'Europe/Paris', + }) + data = response.json() + self.assertEqual(response.status_code, 200) + self.assertEqual(data, { + 'data': [ + { + 'type': 'tests_question', + 'attributes': { + 'pub_date': '2021-06-02T13:52:53.528000+00:00', + 'question_text': 'what is your favorite color?', + 'foo': 'what is your favorite color?+foo', + 'bar': 'what is your favorite color?+bar' + }, + 'id': 1, + 'links': { + 'self': '/forest/tests_question/1' + }, + 'relationships': { + 'topic': { + 'data': None, + 'links': { + 'related': '/forest/tests_question/1/relationships/topic' + } + } + }, + }, + ] + }) + + @mock.patch('jose.jwt.decode', return_value={'id': 1, 'rendering_id': 1}) + @freeze_time( + lambda: datetime(2021, 7, 8, 9, 20, 23, 582772, tzinfo=get_timezone('UTC')) + ) + @mock.patch('django_forest.utils.scope.ScopeManager._has_cache_expired', return_value=False) + def test_get_error_when_multiple_request(self, mocked_scope_has_expired, mocked_decode): + ScopeManager.cache = { + '1': { + 'scopes': mocked_scope, + 'fetched_at': datetime(2021, 7, 8, 9, 20, 22, 582772, tzinfo=pytz.UTC) + } + } + response = self.client.get(self.url, { + 'fields[tests_question]': 'id,topic,question_text,pub_date,foo,bar', + 'fields[topic]': 'name', + 'segmentQuery': 'select * from tests_question where id=1;select * from user_users', + 'page[number]': '1', + 'page[size]': '15', + 'timezone': 'Europe/Paris', + }) + data = response.json() + self.assertEqual(response.status_code, 400) + self.assertEqual(data, {"errors": [{'detail': 'Live Query Segment: You cannot chain SQL queries.'}]}) + + @mock.patch('jose.jwt.decode', return_value={'id': 1, 'rendering_id': 1}) + @freeze_time( + lambda: datetime(2021, 7, 8, 9, 20, 23, 582772, tzinfo=get_timezone('UTC')) + ) + @mock.patch('django_forest.utils.scope.ScopeManager._has_cache_expired', return_value=False) + def test_get_error_when_sql_is_not_select(self, mocked_scope_has_expired, mocked_decode): + ScopeManager.cache = { + '1': { + 'scopes': mocked_scope, + 'fetched_at': datetime(2021, 7, 8, 9, 20, 22, 582772, tzinfo=pytz.UTC) + } + } + response = self.client.get(self.url, { + 'fields[tests_question]': 'id,topic,question_text,pub_date,foo,bar', + 'fields[topic]': 'name', + 'segmentQuery': 'insert into tests_question(id) values(999)', + 'page[number]': '1', + 'page[size]': '15', + 'timezone': 'Europe/Paris', + }) + data = response.json() + self.assertEqual(response.status_code, 400) + self.assertEqual(data, {"errors": [{'detail': 'Live Query Segment: Only SELECT queries are allowed.'}]}) + + @mock.patch('jose.jwt.decode', return_value={'id': 1, 'rendering_id': 1}) + @freeze_time( + lambda: datetime(2021, 7, 8, 9, 20, 23, 582772, tzinfo=get_timezone('UTC')) + ) + @mock.patch('django_forest.utils.scope.ScopeManager._has_cache_expired', return_value=False) + def test_get_error_when_sql_is_empty(self, mocked_scope_has_expired, mocked_decode): + ScopeManager.cache = { + '1': { + 'scopes': mocked_scope, + 'fetched_at': datetime(2021, 7, 8, 9, 20, 22, 582772, tzinfo=pytz.UTC) + } + } + response = self.client.get(self.url, { + 'fields[tests_question]': 'id,topic,question_text,pub_date,foo,bar', + 'fields[topic]': 'name', + 'segmentQuery': ' \n', + 'page[number]': '1', + 'page[size]': '15', + 'timezone': 'Europe/Paris', + }) + data = response.json() + self.assertEqual(response.status_code, 400) + self.assertEqual(data, {"errors": [{'detail': 'Live Query Segment: You cannot execute an empty SQL query.'}]}) \ No newline at end of file