diff --git a/rest_framework/schemas.py b/rest_framework/schemas.py index 6b6324033c..0618e94fd2 100644 --- a/rest_framework/schemas.py +++ b/rest_framework/schemas.py @@ -39,6 +39,10 @@ class SchemaGenerator(object): 'patch': 'partial_update', 'delete': 'destroy', } + known_actions = ( + 'create', 'read', 'retrieve', 'list', + 'update', 'partial_update', 'destroy' + ) def __init__(self, title=None, url=None, patterns=None, urlconf=None): assert coreapi, '`coreapi` must be installed for schema support.' @@ -118,7 +122,8 @@ def get_api_endpoints(self, patterns, prefix=''): if self.should_include_endpoint(path, callback): for method in self.get_allowed_methods(callback): action = self.get_action(path, method, callback) - endpoint = (path, method, action, callback) + category = self.get_category(path, method, callback, action) + endpoint = (path, method, category, action, callback) api_endpoints.append(endpoint) elif isinstance(pattern, RegexURLResolver): @@ -128,21 +133,7 @@ def get_api_endpoints(self, patterns, prefix=''): ) api_endpoints.extend(nested_endpoints) - return self.add_categories(api_endpoints) - - def add_categories(self, api_endpoints): - """ - (path, method, action, callback) -> (path, method, category, action, callback) - """ - # Determine the top level categories for the schema content, - # based on the URLs of the endpoints. Eg `set(['users', 'organisations'])` - paths = [endpoint[0] for endpoint in api_endpoints] - categories = self.get_categories(paths) - - return [ - (path, method, self.get_category(categories, path), action, callback) - for (path, method, action, callback) in api_endpoints - ] + return api_endpoints def get_path(self, path_regex): """ @@ -181,36 +172,41 @@ def get_allowed_methods(self, callback): def get_action(self, path, method, callback): """ - Return a description action string for the endpoint, eg. 'list'. + Return a descriptive action string for the endpoint, eg. 'list'. """ actions = getattr(callback, 'actions', self.default_mapping) return actions[method.lower()] - def get_categories(self, paths): - categories = set() - split_paths = set([ - tuple(path.split("{")[0].strip('/').split('/')) - for path in paths - ]) - - while split_paths: - for split_path in list(split_paths): - if len(split_path) == 0: - split_paths.remove(split_path) - elif len(split_path) == 1: - categories.add(split_path[0]) - split_paths.remove(split_path) - elif split_path[0] in categories: - split_paths.remove(split_path) - - return categories - - def get_category(self, categories, path): - path_components = path.split("{")[0].strip('/').split('/') - for path_component in path_components: - if path_component in categories: - return path_component - return None + def get_category(self, path, method, callback, action): + """ + Return a descriptive category string for the endpoint, eg. 'users'. + + Examples of category/action pairs that should be generated for various + endpoints: + + /users/ [users][list], [users][create] + /users/{pk}/ [users][read], [users][update], [users][destroy] + /users/enabled/ [users][enabled] (custom action) + /users/{pk}/star/ [users][star] (custom action) + /users/{pk}/groups/ [groups][list], [groups][create] + /users/{pk}/groups/{pk}/ [groups][read], [groups][update], [groups][destroy] + """ + path_components = path.strip('/').split('/') + path_components = [ + component for component in path_components + if '{' not in component + ] + if action in self.known_actions: + # Default action, eg "/users/", "/users/{pk}/" + idx = -1 + else: + # Custom action, eg "/users/{pk}/activate/", "/users/active/" + idx = -2 + + try: + return path_components[idx] + except IndexError: + return None # Methods for generating each individual `Link` instance...