Skip to content

Commit

Permalink
Pass request to schema generation (#4383)
Browse files Browse the repository at this point in the history
Pass request to schema generation
  • Loading branch information
tomchristie authored Aug 11, 2016
1 parent 3698d9e commit b50d895
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 31 deletions.
63 changes: 32 additions & 31 deletions rest_framework/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,44 +65,52 @@ def __init__(self, title=None, url=None, patterns=None, urlconf=None):
urls = import_module(urlconf)
else:
urls = urlconf
patterns = urls.urlpatterns
self.patterns = urls.urlpatterns
elif patterns is None and urlconf is None:
urls = import_module(settings.ROOT_URLCONF)
patterns = urls.urlpatterns
self.patterns = urls.urlpatterns
else:
self.patterns = patterns

if url and not url.endswith('/'):
url += '/'

self.title = title
self.url = url
self.endpoints = self.get_api_endpoints(patterns)
self.endpoints = None

def get_schema(self, request=None):
if request is None:
endpoints = self.endpoints
else:
# Filter the list of endpoints to only include those that
# the user has permission on.
endpoints = []
for key, link, callback in self.endpoints:
method = link.action.upper()
view = callback.cls()
if self.endpoints is None:
self.endpoints = self.get_api_endpoints(self.patterns)

links = []
for key, path, method, callback in self.endpoints:
view = callback.cls()
for attr, val in getattr(callback, 'initkwargs', {}).items():
setattr(view, attr, val)
view.args = ()
view.kwargs = {}
view.format_kwarg = None

if request is not None:
view.request = clone_request(request, method)
view.format_kwarg = None
try:
view.check_permissions(view.request)
except exceptions.APIException:
pass
else:
endpoints.append((key, link, callback))
continue
else:
view.request = None

link = self.get_link(path, method, callback, view)
links.append((key, link))

if not endpoints:
if not link:
return None

# Generate the schema content structure, from the endpoints.
# ('users', 'list'), Link -> {'users': {'list': Link()}}
content = {}
for key, link, callback in endpoints:
for key, link in links:
insert_into(content, key, link)

# Return the schema document.
Expand All @@ -122,8 +130,7 @@ def get_api_endpoints(self, patterns, prefix=''):
if self.should_include_endpoint(path, callback):
for method in self.get_allowed_methods(callback):
key = self.get_key(path, method, callback)
link = self.get_link(path, method, callback)
endpoint = (key, link, callback)
endpoint = (key, path, method, callback)
api_endpoints.append(endpoint)

elif isinstance(pattern, RegexURLResolver):
Expand Down Expand Up @@ -190,14 +197,10 @@ def get_key(self, path, method, callback):

# Methods for generating each individual `Link` instance...

def get_link(self, path, method, callback):
def get_link(self, path, method, callback, view):
"""
Return a `coreapi.Link` instance for the given endpoint.
"""
view = callback.cls()
for attr, val in getattr(callback, 'initkwargs', {}).items():
setattr(view, attr, val)

fields = self.get_path_fields(path, method, callback, view)
fields += self.get_serializer_fields(path, method, callback, view)
fields += self.get_pagination_fields(path, method, callback, view)
Expand Down Expand Up @@ -260,20 +263,18 @@ def get_serializer_fields(self, path, method, callback, view):
if method not in ('PUT', 'PATCH', 'POST'):
return []

if not hasattr(view, 'get_serializer_class'):
if not hasattr(view, 'get_serializer'):
return []

fields = []

serializer_class = view.get_serializer_class()
serializer = serializer_class()
serializer = view.get_serializer()

if isinstance(serializer, serializers.ListSerializer):
return coreapi.Field(name='data', location='body', required=True)
return [coreapi.Field(name='data', location='body', required=True)]

if not isinstance(serializer, serializers.Serializer):
return []

fields = []
for field in serializer.fields.values():
if field.read_only:
continue
Expand Down
4 changes: 4 additions & 0 deletions tests/test_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ class ExampleViewSet(ModelViewSet):
def custom_action(self, request, pk):
return super(ExampleSerializer, self).retrieve(self, request)

def get_serializer(self, *args, **kwargs):
assert self.request
return super(ExampleViewSet, self).get_serializer(*args, **kwargs)


class ExampleView(APIView):
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
Expand Down

0 comments on commit b50d895

Please sign in to comment.