Skip to content

Commit

Permalink
fix(segments): add an util to handle live query (#149)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbarreau authored Nov 17, 2023
1 parent 49103ae commit 8c896ed
Show file tree
Hide file tree
Showing 4 changed files with 224 additions and 8 deletions.
13 changes: 5 additions & 8 deletions django_forest/resources/utils/queryset/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from django_forest.utils.collection import Collection
from .filters import FiltersMixin
from .limit_fields import LimitFieldsMixin
from .pagination import PaginationMixin
from .scope import ScopeMixin
from .search import SearchMixin
from .segment import SegmentMixin
from django_forest.resources.utils.decorators import DecoratorsMixin


class QuerysetMixin(PaginationMixin, FiltersMixin, SearchMixin, ScopeMixin, DecoratorsMixin, LimitFieldsMixin):

class QuerysetMixin(
PaginationMixin, FiltersMixin, SearchMixin, ScopeMixin, DecoratorsMixin, LimitFieldsMixin, SegmentMixin
):
def filter_queryset(self, queryset, Model, params, request):
# Notice: first apply scope
scope_filters = self.get_scope(request, Model)
Expand All @@ -34,11 +35,7 @@ def enhance_queryset(self, queryset, Model, params, request, apply_pagination=Tr
queryset = queryset.order_by(params['sort'].replace('.', '__'))

# segment
if 'segment' in params:
collection = Collection._registry[Model._meta.db_table]
segment = next((x for x in collection.segments if x['name'] == params['segment']), None)
if segment is not None and 'where' in segment:
queryset = queryset.filter(segment['where']())
queryset = self.handle_segment(params, Model, queryset)

# limit fields
queryset = self.handle_limit_fields(params, Model, queryset)
Expand Down
28 changes: 28 additions & 0 deletions django_forest/resources/utils/queryset/live_query_segment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import re
from django.db import connection


class LiveQuerySegmentMixin:
def handle_live_query_segment(self, live_query, Model, queryset):
ids = self._get_live_query_ids(live_query)
pk_field = Model._meta.pk.attname
queryset = queryset.filter(**{f"{pk_field}__in": ids})
return queryset

def _get_live_query_ids(self, live_query):
self._validate_query(live_query)
sql_query = "select id from (%s) as ids;" % live_query[0:live_query.find(";")]
with connection.cursor() as cursor:
cursor.execute(sql_query)
res = cursor.fetchall()
return [r[0] for r in res]

def _validate_query(self, query):
if len(query.strip()) == 0:
raise Exception("Live Query Segment: You cannot execute an empty SQL query.")

if ';' in query and query.find(';') < len(query.strip())-1:
raise Exception("Live Query Segment: You cannot chain SQL queries.")

if not re.search(r'^SELECT\s.*FROM\s.*$', query, flags=re.IGNORECASE | re.MULTILINE | re.DOTALL):
raise Exception("Live Query Segment: Only SELECT queries are allowed.")
17 changes: 17 additions & 0 deletions django_forest/resources/utils/queryset/segment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from .live_query_segment import LiveQuerySegmentMixin
from django_forest.utils.collection import Collection


class SegmentMixin(LiveQuerySegmentMixin):
def handle_segment(self, params, Model, queryset):
if 'segment' in params:
collection = Collection._registry[Model._meta.db_table]
segment = next((x for x in collection.segments if x['name'] == params['segment']), None)
if segment is not None and 'where' in segment:
queryset = queryset.filter(segment['where']())

# live query segment
if "segmentQuery" in params:
queryset = self.handle_live_query_segment(params['segmentQuery'], Model, queryset)

return queryset
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
import copy
import sys
from datetime import datetime
from unittest import mock

import pytest
import pytz
from django.test import TransactionTestCase
from django.urls import reverse
from freezegun import freeze_time

from django_forest.tests.fixtures.schema import test_schema
from django_forest.tests.resources.views.list.test_list_scope import mocked_scope
from django_forest.utils.schema import Schema
from django_forest.utils.schema.json_api_schema import JsonApiSchema
from django_forest.utils.date import get_timezone


# reset forest config dir auto import
from django_forest.utils.scope import ScopeManager


@pytest.fixture()
def reset_config_dir_import():
for key in list(sys.modules.keys()):
if key.startswith('django_forest.tests.forest'):
del sys.modules[key]


@pytest.mark.usefixtures('reset_config_dir_import')
class ResourceListSmartSegmentViewTests(TransactionTestCase):
fixtures = ['article.json', 'publication.json',
'session.json',
'question.json', 'choice.json',
'place.json', 'restaurant.json',
'student.json',
'serial.json']

@pytest.fixture(autouse=True)
def inject_fixtures(self, django_assert_num_queries):
self._django_assert_num_queries = django_assert_num_queries

def setUp(self):
Schema.schema = copy.deepcopy(test_schema)
Schema.add_smart_features()
Schema.handle_json_api_schema()
self.url = reverse('django_forest:resources:list', kwargs={'resource': 'tests_question'})
self.client = self.client_class(
HTTP_AUTHORIZATION='Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjUiLCJlbWFpbCI6Imd1aWxsYXVtZWNAZm9yZXN0YWRtaW4uY29tIiwiZmlyc3RfbmFtZSI6Ikd1aWxsYXVtZSIsImxhc3RfbmFtZSI6IkNpc2NvIiwidGVhbSI6Ik9wZXJhdGlvbnMiLCJyZW5kZXJpbmdfaWQiOjEsImV4cCI6MTYyNTY3OTYyNi44ODYwMTh9.mHjA05yvMr99gFMuFv0SnPDCeOd2ZyMSN868V7lsjnw')

def tearDown(self):
# reset _registry after each test
JsonApiSchema._registry = {}
ScopeManager.cache = {}

@mock.patch('jose.jwt.decode', return_value={'id': 1, 'rendering_id': 1})
@freeze_time(
lambda: datetime(2021, 7, 8, 9, 20, 23, 582772, tzinfo=get_timezone('UTC'))
)
@mock.patch('django_forest.utils.scope.ScopeManager._has_cache_expired', return_value=False)
def test_get(self, mocked_scope_has_expired, mocked_decode):
ScopeManager.cache = {
'1': {
'scopes': mocked_scope,
'fetched_at': datetime(2021, 7, 8, 9, 20, 22, 582772, tzinfo=pytz.UTC)
}
}
response = self.client.get(self.url, {
'fields[tests_question]': 'id,topic,question_text,pub_date,foo,bar',
'fields[topic]': 'name',
'segmentQuery': 'select * from tests_question where id=1;',
'page[number]': '1',
'page[size]': '15',
'timezone': 'Europe/Paris',
})
data = response.json()
self.assertEqual(response.status_code, 200)
self.assertEqual(data, {
'data': [
{
'type': 'tests_question',
'attributes': {
'pub_date': '2021-06-02T13:52:53.528000+00:00',
'question_text': 'what is your favorite color?',
'foo': 'what is your favorite color?+foo',
'bar': 'what is your favorite color?+bar'
},
'id': 1,
'links': {
'self': '/forest/tests_question/1'
},
'relationships': {
'topic': {
'data': None,
'links': {
'related': '/forest/tests_question/1/relationships/topic'
}
}
},
},
]
})

@mock.patch('jose.jwt.decode', return_value={'id': 1, 'rendering_id': 1})
@freeze_time(
lambda: datetime(2021, 7, 8, 9, 20, 23, 582772, tzinfo=get_timezone('UTC'))
)
@mock.patch('django_forest.utils.scope.ScopeManager._has_cache_expired', return_value=False)
def test_get_error_when_multiple_request(self, mocked_scope_has_expired, mocked_decode):
ScopeManager.cache = {
'1': {
'scopes': mocked_scope,
'fetched_at': datetime(2021, 7, 8, 9, 20, 22, 582772, tzinfo=pytz.UTC)
}
}
response = self.client.get(self.url, {
'fields[tests_question]': 'id,topic,question_text,pub_date,foo,bar',
'fields[topic]': 'name',
'segmentQuery': 'select * from tests_question where id=1;select * from user_users',
'page[number]': '1',
'page[size]': '15',
'timezone': 'Europe/Paris',
})
data = response.json()
self.assertEqual(response.status_code, 400)
self.assertEqual(data, {"errors": [{'detail': 'Live Query Segment: You cannot chain SQL queries.'}]})

@mock.patch('jose.jwt.decode', return_value={'id': 1, 'rendering_id': 1})
@freeze_time(
lambda: datetime(2021, 7, 8, 9, 20, 23, 582772, tzinfo=get_timezone('UTC'))
)
@mock.patch('django_forest.utils.scope.ScopeManager._has_cache_expired', return_value=False)
def test_get_error_when_sql_is_not_select(self, mocked_scope_has_expired, mocked_decode):
ScopeManager.cache = {
'1': {
'scopes': mocked_scope,
'fetched_at': datetime(2021, 7, 8, 9, 20, 22, 582772, tzinfo=pytz.UTC)
}
}
response = self.client.get(self.url, {
'fields[tests_question]': 'id,topic,question_text,pub_date,foo,bar',
'fields[topic]': 'name',
'segmentQuery': 'insert into tests_question(id) values(999)',
'page[number]': '1',
'page[size]': '15',
'timezone': 'Europe/Paris',
})
data = response.json()
self.assertEqual(response.status_code, 400)
self.assertEqual(data, {"errors": [{'detail': 'Live Query Segment: Only SELECT queries are allowed.'}]})

@mock.patch('jose.jwt.decode', return_value={'id': 1, 'rendering_id': 1})
@freeze_time(
lambda: datetime(2021, 7, 8, 9, 20, 23, 582772, tzinfo=get_timezone('UTC'))
)
@mock.patch('django_forest.utils.scope.ScopeManager._has_cache_expired', return_value=False)
def test_get_error_when_sql_is_empty(self, mocked_scope_has_expired, mocked_decode):
ScopeManager.cache = {
'1': {
'scopes': mocked_scope,
'fetched_at': datetime(2021, 7, 8, 9, 20, 22, 582772, tzinfo=pytz.UTC)
}
}
response = self.client.get(self.url, {
'fields[tests_question]': 'id,topic,question_text,pub_date,foo,bar',
'fields[topic]': 'name',
'segmentQuery': ' \n',
'page[number]': '1',
'page[size]': '15',
'timezone': 'Europe/Paris',
})
data = response.json()
self.assertEqual(response.status_code, 400)
self.assertEqual(data, {"errors": [{'detail': 'Live Query Segment: You cannot execute an empty SQL query.'}]})

0 comments on commit 8c896ed

Please sign in to comment.