From 9bcc8591f51e2f07e61986635fa5859411b812f3 Mon Sep 17 00:00:00 2001 From: Ivan Anishchuk Date: Fri, 14 Jul 2017 02:30:59 +0800 Subject: [PATCH] Restructure SchemaGenerator for easier subclassing Allow adding new default list actions so that bulk actions can be included in the schema with minimal changes. --- rest_framework/schemas.py | 59 +++++++++++++++++++++------------------ 1 file changed, 32 insertions(+), 27 deletions(-) diff --git a/rest_framework/schemas.py b/rest_framework/schemas.py index 875f9454b3..f208a209f2 100644 --- a/rest_framework/schemas.py +++ b/rest_framework/schemas.py @@ -124,28 +124,6 @@ def insert_into(target, keys, value): target[keys[-1]] = value -def is_custom_action(action): - return action not in set([ - 'retrieve', 'list', 'create', 'update', 'partial_update', 'destroy' - ]) - - -def is_list_view(path, method, view): - """ - Return True if the given path/method appears to represent a list view. - """ - if hasattr(view, 'action'): - # Viewsets have an explicitly defined action, which we can inspect. - return view.action == 'list' - - if method.lower() != 'get': - return False - path_components = path.strip('/').split('/') - if path_components and '{' in path_components[-1]: - return False - return True - - def endpoint_ordering(endpoint): path, method, callback = endpoint method_priority = { @@ -265,6 +243,9 @@ class SchemaGenerator(object): 'patch': 'partial_update', 'delete': 'destroy', } + default_list_mapping = { + 'get': 'list', + } endpoint_inspector_cls = EndpointInspector # Map the method names we use for viewset actions onto external schema names. @@ -293,6 +274,10 @@ def __init__(self, title=None, url=None, description=None, patterns=None, urlcon self.description = description self.url = url self.endpoints = None + self.default_actions = set( + list(self.default_mapping.values()) + + list(self.default_list_mapping.values()) + ) def get_schema(self, request=None, public=False): """ @@ -602,7 +587,9 @@ def get_serializer_fields(self, path, method, view): return fields def get_pagination_fields(self, path, method, view): - if not is_list_view(path, method, view): + if not self.is_list_view(path, method, view): + return [] + if method.lower() != 'get': return [] pagination = getattr(view, 'pagination_class', None) @@ -613,7 +600,7 @@ def get_pagination_fields(self, path, method, view): return paginator.get_schema_fields(view) def get_filter_fields(self, path, method, view): - if not is_list_view(path, method, view): + if not self.is_list_view(path, method, view): return [] if not getattr(view, 'filter_backends', None): @@ -643,8 +630,8 @@ def get_keys(self, subpath, method, view): action = view.action else: # Views have no associated action, so we determine one from the method. - if is_list_view(subpath, method, view): - action = 'list' + if self.is_list_view(subpath, method, view): + action = self.default_list_mapping[method.lower()] else: action = self.default_mapping[method.lower()] @@ -654,7 +641,7 @@ def get_keys(self, subpath, method, view): if '{' not in component ] - if is_custom_action(action): + if self.is_custom_action(action): # Custom action, eg "/users/{pk}/activate/", "/users/active/" if len(view.action_map) > 1: action = self.default_mapping[method.lower()] @@ -670,6 +657,24 @@ def get_keys(self, subpath, method, view): # Default action, eg "/users/", "/users/{pk}/" return named_path_components + [action] + def is_custom_action(self, action): + return action not in self.default_actions + + def is_list_view(self, path, method, view): + """ + Return True if the given path/method appears to represent a list view. + """ + if hasattr(view, 'action'): + # Viewsets have an explicitly defined action, which we can inspect. + return view.action in self.default_list_mapping.values() + + if method.lower() not in self.default_list_mapping: + return False + path_components = path.strip('/').split('/') + if path_components and '{' in path_components[-1]: + return False + return True + class SchemaView(APIView): _ignore_model_permissions = True