Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restructure SchemaGenerator for easier subclassing #5271

Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 32 additions & 27 deletions rest_framework/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,28 +124,6 @@ def insert_into(target, keys, value):
target[keys[-1]] = value


def is_custom_action(action):
return action not in set([
'retrieve', 'list', 'create', 'update', 'partial_update', 'destroy'
])


def is_list_view(path, method, view):
"""
Return True if the given path/method appears to represent a list view.
"""
if hasattr(view, 'action'):
# Viewsets have an explicitly defined action, which we can inspect.
return view.action == 'list'

if method.lower() != 'get':
return False
path_components = path.strip('/').split('/')
if path_components and '{' in path_components[-1]:
return False
return True


def endpoint_ordering(endpoint):
path, method, callback = endpoint
method_priority = {
Expand Down Expand Up @@ -265,6 +243,9 @@ class SchemaGenerator(object):
'patch': 'partial_update',
'delete': 'destroy',
}
default_list_mapping = {
'get': 'list',
}
endpoint_inspector_cls = EndpointInspector

# Map the method names we use for viewset actions onto external schema names.
Expand Down Expand Up @@ -293,6 +274,10 @@ def __init__(self, title=None, url=None, description=None, patterns=None, urlcon
self.description = description
self.url = url
self.endpoints = None
self.default_actions = set(
list(self.default_mapping.values()) +
list(self.default_list_mapping.values())
)

def get_schema(self, request=None, public=False):
"""
Expand Down Expand Up @@ -602,7 +587,9 @@ def get_serializer_fields(self, path, method, view):
return fields

def get_pagination_fields(self, path, method, view):
if not is_list_view(path, method, view):
if not self.is_list_view(path, method, view):
return []
if method.lower() != 'get':
return []

pagination = getattr(view, 'pagination_class', None)
Expand All @@ -613,7 +600,7 @@ def get_pagination_fields(self, path, method, view):
return paginator.get_schema_fields(view)

def get_filter_fields(self, path, method, view):
if not is_list_view(path, method, view):
if not self.is_list_view(path, method, view):
return []

if not getattr(view, 'filter_backends', None):
Expand Down Expand Up @@ -643,8 +630,8 @@ def get_keys(self, subpath, method, view):
action = view.action
else:
# Views have no associated action, so we determine one from the method.
if is_list_view(subpath, method, view):
action = 'list'
if self.is_list_view(subpath, method, view):
action = self.default_list_mapping[method.lower()]
else:
action = self.default_mapping[method.lower()]

Expand All @@ -654,7 +641,7 @@ def get_keys(self, subpath, method, view):
if '{' not in component
]

if is_custom_action(action):
if self.is_custom_action(action):
# Custom action, eg "/users/{pk}/activate/", "/users/active/"
if len(view.action_map) > 1:
action = self.default_mapping[method.lower()]
Expand All @@ -670,6 +657,24 @@ def get_keys(self, subpath, method, view):
# Default action, eg "/users/", "/users/{pk}/"
return named_path_components + [action]

def is_custom_action(self, action):
return action not in self.default_actions

def is_list_view(self, path, method, view):
"""
Return True if the given path/method appears to represent a list view.
"""
if hasattr(view, 'action'):
# Viewsets have an explicitly defined action, which we can inspect.
return view.action in self.default_list_mapping.values()

if method.lower() not in self.default_list_mapping:
return False
path_components = path.strip('/').split('/')
if path_components and '{' in path_components[-1]:
return False
return True


class SchemaView(APIView):
_ignore_model_permissions = True
Expand Down