Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass request to schema generation #4383

Merged
merged 3 commits into from
Aug 11, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 32 additions & 31 deletions rest_framework/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,44 +65,52 @@ def __init__(self, title=None, url=None, patterns=None, urlconf=None):
urls = import_module(urlconf)
else:
urls = urlconf
patterns = urls.urlpatterns
self.patterns = urls.urlpatterns
elif patterns is None and urlconf is None:
urls = import_module(settings.ROOT_URLCONF)
patterns = urls.urlpatterns
self.patterns = urls.urlpatterns
else:
self.patterns = patterns

if url and not url.endswith('/'):
url += '/'

self.title = title
self.url = url
self.endpoints = self.get_api_endpoints(patterns)
self.endpoints = None

def get_schema(self, request=None):
if request is None:
endpoints = self.endpoints
else:
# Filter the list of endpoints to only include those that
# the user has permission on.
endpoints = []
for key, link, callback in self.endpoints:
method = link.action.upper()
view = callback.cls()
if self.endpoints is None:
self.endpoints = self.get_api_endpoints(self.patterns)

links = []
for key, path, method, callback in self.endpoints:
view = callback.cls()
for attr, val in getattr(callback, 'initkwargs', {}).items():
setattr(view, attr, val)
view.args = ()
view.kwargs = {}
view.format_kwarg = None

if request is not None:
view.request = clone_request(request, method)
view.format_kwarg = None
try:
view.check_permissions(view.request)
except exceptions.APIException:
pass
else:
endpoints.append((key, link, callback))
continue
else:
view.request = None

link = self.get_link(path, method, callback, view)
links.append((key, link))

if not endpoints:
if not link:
return None

# Generate the schema content structure, from the endpoints.
# ('users', 'list'), Link -> {'users': {'list': Link()}}
content = {}
for key, link, callback in endpoints:
for key, link in links:
insert_into(content, key, link)

# Return the schema document.
Expand All @@ -122,8 +130,7 @@ def get_api_endpoints(self, patterns, prefix=''):
if self.should_include_endpoint(path, callback):
for method in self.get_allowed_methods(callback):
key = self.get_key(path, method, callback)
link = self.get_link(path, method, callback)
endpoint = (key, link, callback)
endpoint = (key, path, method, callback)
api_endpoints.append(endpoint)

elif isinstance(pattern, RegexURLResolver):
Expand Down Expand Up @@ -190,14 +197,10 @@ def get_key(self, path, method, callback):

# Methods for generating each individual `Link` instance...

def get_link(self, path, method, callback):
def get_link(self, path, method, callback, view):
"""
Return a `coreapi.Link` instance for the given endpoint.
"""
view = callback.cls()
for attr, val in getattr(callback, 'initkwargs', {}).items():
setattr(view, attr, val)

fields = self.get_path_fields(path, method, callback, view)
fields += self.get_serializer_fields(path, method, callback, view)
fields += self.get_pagination_fields(path, method, callback, view)
Expand Down Expand Up @@ -260,20 +263,18 @@ def get_serializer_fields(self, path, method, callback, view):
if method not in ('PUT', 'PATCH', 'POST'):
return []

if not hasattr(view, 'get_serializer_class'):
if not hasattr(view, 'get_serializer'):
return []

fields = []

serializer_class = view.get_serializer_class()
serializer = serializer_class()
serializer = view.get_serializer()

if isinstance(serializer, serializers.ListSerializer):
return coreapi.Field(name='data', location='body', required=True)
return [coreapi.Field(name='data', location='body', required=True)]

if not isinstance(serializer, serializers.Serializer):
return []

fields = []
for field in serializer.fields.values():
if field.read_only:
continue
Expand Down
4 changes: 4 additions & 0 deletions tests/test_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ class ExampleViewSet(ModelViewSet):
def custom_action(self, request, pk):
return super(ExampleSerializer, self).retrieve(self, request)

def get_serializer(self, *args, **kwargs):
assert self.request
return super(ExampleViewSet, self).get_serializer(*args, **kwargs)


class ExampleView(APIView):
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
Expand Down