diff --git a/docs/api-guide/schemas.md b/docs/api-guide/schemas.md index 836ad4b6a3..f913f046f6 100644 --- a/docs/api-guide/schemas.md +++ b/docs/api-guide/schemas.md @@ -10,7 +10,14 @@ API schemas are a useful tool that allow for a range of use cases, including generating reference documentation, or driving dynamic client libraries that can interact with your API. -## Representing schemas internally +## Install Core API + +You'll need to install the `coreapi` package in order to add schema support +for REST framework. + + pip install coreapi + +## Internal schema representation REST framework uses [Core API][coreapi] in order to model schema information in a format-independent representation. This information can then be rendered @@ -68,9 +75,34 @@ has to be rendered into the actual bytes that are used in the response. REST framework includes a renderer class for handling this media type, which is available as `renderers.CoreJSONRenderer`. +### Alternate schema formats + Other schema formats such as [Open API][open-api] ("Swagger"), -[JSON HyperSchema][json-hyperschema], or [API Blueprint][api-blueprint] can -also be supported by implementing a custom renderer class. +[JSON HyperSchema][json-hyperschema], or [API Blueprint][api-blueprint] can also +be supported by implementing a custom renderer class that handles converting a +`Document` instance into a bytestring representation. + +If there is a Core API codec package that supports encoding into the format you +want to use then implementing the renderer class can be done by using the codec. + +#### Example + +For example, the `openapi_codec` package provides support for encoding or decoding +to the Open API ("Swagger") format: + + from rest_framework import renderers + from openapi_codec import OpenAPICodec + + class SwaggerRenderer(renderers.BaseRenderer): + media_type = 'application/openapi+json' + format = 'swagger' + + def render(self, data, media_type=None, renderer_context=None): + codec = OpenAPICodec() + return codec.dump(data) + + + ## Schemas vs Hypermedia @@ -89,18 +121,121 @@ document, detailing both the current state and the available interactions. Further information and support on building Hypermedia APIs with REST framework is planned for a future version. + --- -# Adding a schema +# Creating a schema -You'll need to install the `coreapi` package in order to add schema support -for REST framework. +REST framework includes functionality for auto-generating a schema, +or allows you to specify one explicitly. - pip install coreapi +## Manual Schema Specification -REST framework includes functionality for auto-generating a schema, -or allows you to specify one explicitly. There are a few different ways to -add a schema to your API, depending on exactly what you need. +To manually specify a schema you create a Core API `Document`, similar to the +example above. + + schema = coreapi.Document( + title='Flight Search API', + content={ + ... + } + ) + + +## Automatic Schema Generation + +Automatic schema generation is provided by the `SchemaGenerator` class. + +`SchemaGenerator` processes a list of routed URL pattterns and compiles the +appropriately structured Core API Document. + +Basic usage is just to provide the title for your schema and call +`get_schema()`: + + generator = schemas.SchemaGenerator(title='Flight Search API') + schema = generator.get_schema() + +### Per-View Schema Customisation + +By default, view introspection is performed by an `AutoSchema` instance +accessible via the `schema` attribute on `APIView`. This provides the +appropriate Core API `Link` object for the view, request method and path: + + auto_schema = view.schema + coreapi_link = auto_schema.get_link(...) + +(In compiling the schema, `SchemaGenerator` calls `view.schema.get_link()` for +each view, allowed method and path.) + +To customise the `Link` generation you may: + +* Instantiate `AutoSchema` on your view with the `manual_fields` kwarg: + + from rest_framework.views import APIView + from rest_framework.schemas import AutoSchema + + class CustomView(APIView): + ... + schema = AutoSchema( + manual_fields=[ + coreapi.Field("extra_field", ...), + ] + ) + + This allows extension for the most common case without subclassing. + +* Provide an `AutoSchema` subclass with more complex customisation: + + from rest_framework.views import APIView + from rest_framework.schemas import AutoSchema + + class CustomSchema(AutoSchema): + def get_link(...): + # Implemet custom introspection here (or in other sub-methods) + + class CustomView(APIView): + ... + schema = CustomSchema() + + This provides complete control over view introspection. + +* Instantiate `ManualSchema` on your view, providing the Core API `Fields` for + the view explicitly: + + from rest_framework.views import APIView + from rest_framework.schemas import ManualSchema + + class CustomView(APIView): + ... + schema = ManualSchema(fields=[ + coreapi.Field( + "first_field", + required=True, + location="path", + schema=coreschema.String() + ), + coreapi.Field( + "second_field", + required=True, + location="path", + schema=coreschema.String() + ), + ]) + + This allows manually specifying the schema for some views whilst maintaining + automatic generation elsewhere. + +--- + +**Note**: For full details on `SchemaGenerator` plus the `AutoSchema` and +`ManualSchema` descriptors see the [API Reference below](#api-reference). + +--- + +# Adding a schema view + +There are a few different ways to add a schema view to your API, depending on +exactly what you need. ## The get_schema_view shortcut @@ -342,38 +477,12 @@ A generic viewset with sections in the class docstring, using multi-line style. --- -# Alternate schema formats - -In order to support an alternate schema format, you need to implement a custom renderer -class that handles converting a `Document` instance into a bytestring representation. - -If there is a Core API codec package that supports encoding into the format you -want to use then implementing the renderer class can be done by using the codec. - -## Example - -For example, the `openapi_codec` package provides support for encoding or decoding -to the Open API ("Swagger") format: - - from rest_framework import renderers - from openapi_codec import OpenAPICodec - - class SwaggerRenderer(renderers.BaseRenderer): - media_type = 'application/openapi+json' - format = 'swagger' - - def render(self, data, media_type=None, renderer_context=None): - codec = OpenAPICodec() - return codec.dump(data) - ---- - # API Reference ## SchemaGenerator -A class that deals with introspecting your API views, which can be used to -generate a schema. +A class that walks a list of routed URL patterns, requests the schema for each view, +and collates the resulting CoreAPI Document. Typically you'll instantiate `SchemaGenerator` with a single argument, like so: @@ -406,39 +515,108 @@ Return a nested dictionary containing all the links that should be included in t This is a good point to override if you want to modify the resulting structure of the generated schema, as you can build a new dictionary with a different layout. -### get_link(self, path, method, view) + +## AutoSchema + +A class that deals with introspection of individual views for schema generation. + +`AutoSchema` is attached to `APIView` via the `schema` attribute. + +The `AutoSchema` constructor takes a single keyword argument `manual_fields`. + +**`manual_fields`**: a `list` of `coreapi.Field` instances that will be added to +the generated fields. Generated fields with a matching `name` will be overwritten. + + class CustomView(APIView): + schema = AutoSchema(manual_fields=[ + coreapi.Field( + "my_extra_field", + required=True, + location="path", + schema=coreschema.String() + ), + ]) + +For more advanced customisation subclass `AutoSchema` to customise schema generation. + + class CustomViewSchema(AutoSchema): + """ + Overrides `get_link()` to provide Custom Behavior X + """ + + def get_link(self, path, method, base_url): + link = super().get_link(path, method, base_url) + # Do something to customize link here... + return link + + class MyView(APIView): + schema = CustomViewSchema() + +The following methods are available to override. + +### get_link(self, path, method, base_url) Returns a `coreapi.Link` instance corresponding to the given view. +This is the main entry point. You can override this if you need to provide custom behaviors for particular views. -### get_description(self, path, method, view) +### get_description(self, path, method) Returns a string to use as the link description. By default this is based on the view docstring as described in the "Schemas as Documentation" section above. -### get_encoding(self, path, method, view) +### get_encoding(self, path, method) Returns a string to indicate the encoding for any request body, when interacting with the given view. Eg. `'application/json'`. May return a blank string for views that do not expect a request body. -### get_path_fields(self, path, method, view): +### get_path_fields(self, path, method): Return a list of `coreapi.Link()` instances. One for each path parameter in the URL. -### get_serializer_fields(self, path, method, view) +### get_serializer_fields(self, path, method) Return a list of `coreapi.Link()` instances. One for each field in the serializer class used by the view. -### get_pagination_fields(self, path, method, view +### get_pagination_fields(self, path, method) Return a list of `coreapi.Link()` instances, as returned by the `get_schema_fields()` method on any pagination class used by the view. -### get_filter_fields(self, path, method, view) +### get_filter_fields(self, path, method) Return a list of `coreapi.Link()` instances, as returned by the `get_schema_fields()` method of any filter classes used by the view. + +## ManualSchema + +Allows manually providing a list of `coreapi.Field` instances for the schema, +plus an optional description. + + class MyView(APIView): + schema = ManualSchema(fields=[ + coreapi.Field( + "first_field", + required=True, + location="path", + schema=coreschema.String() + ), + coreapi.Field( + "second_field", + required=True, + location="path", + schema=coreschema.String() + ), + ] + ) + +The `ManualSchema` constructor takes two arguments: + +**`fields`**: A list of `coreapi.Field` instances. Required. + +**`description`**: A string description. Optional. + --- ## Core API diff --git a/docs/api-guide/views.md b/docs/api-guide/views.md index 4fa36d0fc8..24dd42578e 100644 --- a/docs/api-guide/views.md +++ b/docs/api-guide/views.md @@ -184,6 +184,28 @@ The available decorators are: Each of these decorators takes a single argument which must be a list or tuple of classes. + +## View schema decorator + +To override the default schema generation for function based views you may use +the `@schema` decorator. This must come *after* (below) the `@api_view` +decorator. For example: + + from rest_framework.decorators import api_view, schema + from rest_framework.schemas import AutoSchema + + class CustomAutoSchema(AutoSchema): + def get_link(self, path, method, base_url): + # override view introspection here... + + @api_view(['GET']) + @schema(CustomAutoSchema()) + def view(request): + return Response({"message": "Hello for today! See you tomorrow!"}) + +This decorator takes a single `AutoSchema` instance, an `AutoSchema` subclass +instance or `ManualSchema` instance as described in the [Schemas documentation][schemas], + [cite]: http://reinout.vanrees.org/weblog/2011/08/24/class-based-views-usage.html [cite2]: http://www.boredomandlaziness.org/2012/05/djangos-cbvs-are-not-mistake-but.html [settings]: settings.md diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py index bf9b32aaa7..1297f96b4c 100644 --- a/rest_framework/decorators.py +++ b/rest_framework/decorators.py @@ -72,6 +72,9 @@ def handler(self, *args, **kwargs): WrappedAPIView.permission_classes = getattr(func, 'permission_classes', APIView.permission_classes) + WrappedAPIView.schema = getattr(func, 'schema', + APIView.schema) + WrappedAPIView.exclude_from_schema = exclude_from_schema return WrappedAPIView.as_view() return decorator @@ -112,6 +115,13 @@ def decorator(func): return decorator +def schema(view_inspector): + def decorator(func): + func.schema = view_inspector + return func + return decorator + + def detail_route(methods=None, **kwargs): """ Used to mark a method on a ViewSet that should be routed for detail requests. diff --git a/rest_framework/routers.py b/rest_framework/routers.py index a04bffc1ac..01daa7e7d4 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -26,7 +26,8 @@ from rest_framework.compat import NoReverseMatch from rest_framework.response import Response from rest_framework.reverse import reverse -from rest_framework.schemas import SchemaGenerator, SchemaView +from rest_framework.schemas import SchemaGenerator +from rest_framework.schemas.views import SchemaView from rest_framework.settings import api_settings from rest_framework.urlpatterns import format_suffix_patterns diff --git a/rest_framework/schemas/__init__.py b/rest_framework/schemas/__init__.py new file mode 100644 index 0000000000..fc551640e7 --- /dev/null +++ b/rest_framework/schemas/__init__.py @@ -0,0 +1,43 @@ +""" +rest_framework.schemas + +schemas: + __init__.py + generators.py # Top-down schema generation + inspectors.py # Per-endpoint view introspection + utils.py # Shared helper functions + views.py # Houses `SchemaView`, `APIView` subclass. + +We expose a minimal "public" API directly from `schemas`. This covers the +basic use-cases: + + from rest_framework.schemas import ( + AutoSchema, + ManualSchema, + get_schema_view, + SchemaGenerator, + ) + +Other access should target the submodules directly +""" +from .generators import SchemaGenerator +from .inspectors import AutoSchema, ManualSchema # noqa + + +def get_schema_view( + title=None, url=None, description=None, urlconf=None, renderer_classes=None, + public=False, patterns=None, generator_class=SchemaGenerator): + """ + Return a schema view. + """ + # Avoid import cycle on APIView + from .views import SchemaView + generator = generator_class( + title=title, url=url, description=description, + urlconf=urlconf, patterns=patterns, + ) + return SchemaView.as_view( + renderer_classes=renderer_classes, + schema_generator=generator, + public=public, + ) diff --git a/rest_framework/schemas.py b/rest_framework/schemas/generators.py similarity index 51% rename from rest_framework/schemas.py rename to rest_framework/schemas/generators.py index 4374133554..8344f64f0e 100644 --- a/rest_framework/schemas.py +++ b/rest_framework/schemas/generators.py @@ -1,86 +1,26 @@ -import re +""" +generators.py # Top-down schema generation + +See schemas.__init__.py for package overview. +""" from collections import OrderedDict from importlib import import_module from django.conf import settings from django.contrib.admindocs.views import simplify_regex from django.core.exceptions import PermissionDenied -from django.db import models from django.http import Http404 from django.utils import six -from django.utils.encoding import force_text, smart_text -from django.utils.translation import ugettext_lazy as _ -from rest_framework import exceptions, renderers, serializers +from rest_framework import exceptions from rest_framework.compat import ( - RegexURLPattern, RegexURLResolver, coreapi, coreschema, uritemplate, - urlparse + RegexURLPattern, RegexURLResolver, coreapi, coreschema ) from rest_framework.request import clone_request -from rest_framework.response import Response from rest_framework.settings import api_settings -from rest_framework.utils import formatting from rest_framework.utils.model_meta import _get_pk -from rest_framework.views import APIView - -header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:') - -def field_to_schema(field): - title = force_text(field.label) if field.label else '' - description = force_text(field.help_text) if field.help_text else '' - - if isinstance(field, (serializers.ListSerializer, serializers.ListField)): - child_schema = field_to_schema(field.child) - return coreschema.Array( - items=child_schema, - title=title, - description=description - ) - elif isinstance(field, serializers.Serializer): - return coreschema.Object( - properties=OrderedDict([ - (key, field_to_schema(value)) - for key, value - in field.fields.items() - ]), - title=title, - description=description - ) - elif isinstance(field, serializers.ManyRelatedField): - return coreschema.Array( - items=coreschema.String(), - title=title, - description=description - ) - elif isinstance(field, serializers.RelatedField): - return coreschema.String(title=title, description=description) - elif isinstance(field, serializers.MultipleChoiceField): - return coreschema.Array( - items=coreschema.Enum(enum=list(field.choices.keys())), - title=title, - description=description - ) - elif isinstance(field, serializers.ChoiceField): - return coreschema.Enum( - enum=list(field.choices.keys()), - title=title, - description=description - ) - elif isinstance(field, serializers.BooleanField): - return coreschema.Boolean(title=title, description=description) - elif isinstance(field, (serializers.DecimalField, serializers.FloatField)): - return coreschema.Number(title=title, description=description) - elif isinstance(field, serializers.IntegerField): - return coreschema.Integer(title=title, description=description) - - if field.style.get('base_template') == 'textarea.html': - return coreschema.String( - title=title, - description=description, - format='textarea' - ) - return coreschema.String(title=title, description=description) +from .utils import is_list_view def common_path(paths): @@ -104,6 +44,8 @@ def is_api_view(callback): """ Return `True` if the given view callback is a REST framework view/viewset. """ + # Avoid import cycle on APIView + from rest_framework.views import APIView cls = getattr(callback, 'cls', None) return (cls is not None) and issubclass(cls, APIView) @@ -130,22 +72,6 @@ def is_custom_action(action): ]) -def is_list_view(path, method, view): - """ - Return True if the given path/method appears to represent a list view. - """ - if hasattr(view, 'action'): - # Viewsets have an explicitly defined action, which we can inspect. - return view.action == 'list' - - if method.lower() != 'get': - return False - path_components = path.strip('/').split('/') - if path_components and '{' in path_components[-1]: - return False - return True - - def endpoint_ordering(endpoint): path, method, callback = endpoint method_priority = { @@ -158,21 +84,7 @@ def endpoint_ordering(endpoint): return (path, method_priority) -def get_pk_description(model, model_field): - if isinstance(model_field, models.AutoField): - value_type = _('unique integer value') - elif isinstance(model_field, models.UUIDField): - value_type = _('UUID string') - else: - value_type = _('unique value') - - return _('A {value_type} identifying this {name}.').format( - value_type=value_type, - name=model._meta.verbose_name, - ) - - -class EndpointInspector(object): +class EndpointEnumerator(object): """ A class to determine the available API endpoints that a project exposes. """ @@ -265,7 +177,7 @@ class SchemaGenerator(object): 'patch': 'partial_update', 'delete': 'destroy', } - endpoint_inspector_cls = EndpointInspector + endpoint_inspector_cls = EndpointEnumerator # Map the method names we use for viewset actions onto external schema names. # These give us names that are more suitable for the external representation. @@ -341,7 +253,7 @@ def get_links(self, request=None): for path, method, view in view_endpoints: if not self.has_view_permissions(path, method, view): continue - link = self.get_link(path, method, view) + link = view.schema.get_link(path, method, base_url=self.url) subpath = path[len(prefix):] keys = self.get_keys(subpath, method, view) insert_into(links, keys, link) @@ -433,197 +345,6 @@ def coerce_path(self, path, method, view): field_name = 'id' return path.replace('{pk}', '{%s}' % field_name) - # Methods for generating each individual `Link` instance... - - def get_link(self, path, method, view): - """ - Return a `coreapi.Link` instance for the given endpoint. - """ - fields = self.get_path_fields(path, method, view) - fields += self.get_serializer_fields(path, method, view) - fields += self.get_pagination_fields(path, method, view) - fields += self.get_filter_fields(path, method, view) - - if fields and any([field.location in ('form', 'body') for field in fields]): - encoding = self.get_encoding(path, method, view) - else: - encoding = None - - description = self.get_description(path, method, view) - - if self.url and path.startswith('/'): - path = path[1:] - - return coreapi.Link( - url=urlparse.urljoin(self.url, path), - action=method.lower(), - encoding=encoding, - fields=fields, - description=description - ) - - def get_description(self, path, method, view): - """ - Determine a link description. - - This will be based on the method docstring if one exists, - or else the class docstring. - """ - method_name = getattr(view, 'action', method.lower()) - method_docstring = getattr(view, method_name, None).__doc__ - if method_docstring: - # An explicit docstring on the method or action. - return formatting.dedent(smart_text(method_docstring)) - - description = view.get_view_description() - lines = [line.strip() for line in description.splitlines()] - current_section = '' - sections = {'': ''} - - for line in lines: - if header_regex.match(line): - current_section, seperator, lead = line.partition(':') - sections[current_section] = lead.strip() - else: - sections[current_section] += '\n' + line - - header = getattr(view, 'action', method.lower()) - if header in sections: - return sections[header].strip() - if header in self.coerce_method_names: - if self.coerce_method_names[header] in sections: - return sections[self.coerce_method_names[header]].strip() - return sections[''].strip() - - def get_encoding(self, path, method, view): - """ - Return the 'encoding' parameter to use for a given endpoint. - """ - # Core API supports the following request encodings over HTTP... - supported_media_types = set(( - 'application/json', - 'application/x-www-form-urlencoded', - 'multipart/form-data', - )) - parser_classes = getattr(view, 'parser_classes', []) - for parser_class in parser_classes: - media_type = getattr(parser_class, 'media_type', None) - if media_type in supported_media_types: - return media_type - # Raw binary uploads are supported with "application/octet-stream" - if media_type == '*/*': - return 'application/octet-stream' - - return None - - def get_path_fields(self, path, method, view): - """ - Return a list of `coreapi.Field` instances corresponding to any - templated path variables. - """ - model = getattr(getattr(view, 'queryset', None), 'model', None) - fields = [] - - for variable in uritemplate.variables(path): - title = '' - description = '' - schema_cls = coreschema.String - kwargs = {} - if model is not None: - # Attempt to infer a field description if possible. - try: - model_field = model._meta.get_field(variable) - except: - model_field = None - - if model_field is not None and model_field.verbose_name: - title = force_text(model_field.verbose_name) - - if model_field is not None and model_field.help_text: - description = force_text(model_field.help_text) - elif model_field is not None and model_field.primary_key: - description = get_pk_description(model, model_field) - - if hasattr(view, 'lookup_value_regex') and view.lookup_field == variable: - kwargs['pattern'] = view.lookup_value_regex - elif isinstance(model_field, models.AutoField): - schema_cls = coreschema.Integer - - field = coreapi.Field( - name=variable, - location='path', - required=True, - schema=schema_cls(title=title, description=description, **kwargs) - ) - fields.append(field) - - return fields - - def get_serializer_fields(self, path, method, view): - """ - Return a list of `coreapi.Field` instances corresponding to any - request body input, as determined by the serializer class. - """ - if method not in ('PUT', 'PATCH', 'POST'): - return [] - - if not hasattr(view, 'get_serializer'): - return [] - - serializer = view.get_serializer() - - if isinstance(serializer, serializers.ListSerializer): - return [ - coreapi.Field( - name='data', - location='body', - required=True, - schema=coreschema.Array() - ) - ] - - if not isinstance(serializer, serializers.Serializer): - return [] - - fields = [] - for field in serializer.fields.values(): - if field.read_only or isinstance(field, serializers.HiddenField): - continue - - required = field.required and method != 'PATCH' - field = coreapi.Field( - name=field.field_name, - location='form', - required=required, - schema=field_to_schema(field) - ) - fields.append(field) - - return fields - - def get_pagination_fields(self, path, method, view): - if not is_list_view(path, method, view): - return [] - - pagination = getattr(view, 'pagination_class', None) - if not pagination: - return [] - - paginator = view.pagination_class() - return paginator.get_schema_fields(view) - - def get_filter_fields(self, path, method, view): - if not is_list_view(path, method, view): - return [] - - if not getattr(view, 'filter_backends', None): - return [] - - fields = [] - for filter_backend in view.filter_backends: - fields += filter_backend().get_schema_fields(view) - return fields - # Method for generating the link layout.... def get_keys(self, subpath, method, view): @@ -669,45 +390,3 @@ def get_keys(self, subpath, method, view): # Default action, eg "/users/", "/users/{pk}/" return named_path_components + [action] - - -class SchemaView(APIView): - _ignore_model_permissions = True - exclude_from_schema = True - renderer_classes = None - schema_generator = None - public = False - - def __init__(self, *args, **kwargs): - super(SchemaView, self).__init__(*args, **kwargs) - if self.renderer_classes is None: - if renderers.BrowsableAPIRenderer in api_settings.DEFAULT_RENDERER_CLASSES: - self.renderer_classes = [ - renderers.CoreJSONRenderer, - renderers.BrowsableAPIRenderer, - ] - else: - self.renderer_classes = [renderers.CoreJSONRenderer] - - def get(self, request, *args, **kwargs): - schema = self.schema_generator.get_schema(request, self.public) - if schema is None: - raise exceptions.PermissionDenied() - return Response(schema) - - -def get_schema_view( - title=None, url=None, description=None, urlconf=None, renderer_classes=None, - public=False, patterns=None, generator_class=SchemaGenerator): - """ - Return a schema view. - """ - generator = generator_class( - title=title, url=url, description=description, - urlconf=urlconf, patterns=patterns, - ) - return SchemaView.as_view( - renderer_classes=renderer_classes, - schema_generator=generator, - public=public, - ) diff --git a/rest_framework/schemas/inspectors.py b/rest_framework/schemas/inspectors.py new file mode 100644 index 0000000000..cd9fa73da9 --- /dev/null +++ b/rest_framework/schemas/inspectors.py @@ -0,0 +1,399 @@ +""" +inspectors.py # Per-endpoint view introspection + +See schemas.__init__.py for package overview. +""" +import re +from collections import OrderedDict + +from django.db import models +from django.utils.encoding import force_text, smart_text +from django.utils.translation import ugettext_lazy as _ + +from rest_framework import serializers +from rest_framework.compat import coreapi, coreschema, uritemplate, urlparse +from rest_framework.settings import api_settings +from rest_framework.utils import formatting + +from .utils import is_list_view + +header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:') + + +def field_to_schema(field): + title = force_text(field.label) if field.label else '' + description = force_text(field.help_text) if field.help_text else '' + + if isinstance(field, (serializers.ListSerializer, serializers.ListField)): + child_schema = field_to_schema(field.child) + return coreschema.Array( + items=child_schema, + title=title, + description=description + ) + elif isinstance(field, serializers.Serializer): + return coreschema.Object( + properties=OrderedDict([ + (key, field_to_schema(value)) + for key, value + in field.fields.items() + ]), + title=title, + description=description + ) + elif isinstance(field, serializers.ManyRelatedField): + return coreschema.Array( + items=coreschema.String(), + title=title, + description=description + ) + elif isinstance(field, serializers.RelatedField): + return coreschema.String(title=title, description=description) + elif isinstance(field, serializers.MultipleChoiceField): + return coreschema.Array( + items=coreschema.Enum(enum=list(field.choices.keys())), + title=title, + description=description + ) + elif isinstance(field, serializers.ChoiceField): + return coreschema.Enum( + enum=list(field.choices.keys()), + title=title, + description=description + ) + elif isinstance(field, serializers.BooleanField): + return coreschema.Boolean(title=title, description=description) + elif isinstance(field, (serializers.DecimalField, serializers.FloatField)): + return coreschema.Number(title=title, description=description) + elif isinstance(field, serializers.IntegerField): + return coreschema.Integer(title=title, description=description) + + if field.style.get('base_template') == 'textarea.html': + return coreschema.String( + title=title, + description=description, + format='textarea' + ) + return coreschema.String(title=title, description=description) + + +def get_pk_description(model, model_field): + if isinstance(model_field, models.AutoField): + value_type = _('unique integer value') + elif isinstance(model_field, models.UUIDField): + value_type = _('UUID string') + else: + value_type = _('unique value') + + return _('A {value_type} identifying this {name}.').format( + value_type=value_type, + name=model._meta.verbose_name, + ) + + +class ViewInspector(object): + """ + Descriptor class on APIView. + + Provide subclass for per-view schema generation + """ + def __get__(self, instance, owner): + """ + Enables `ViewInspector` as a Python _Descriptor_. + + This is how `view.schema` knows about `view`. + + `__get__` is called when the descriptor is accessed on the owner. + (That will be when view.schema is called in our case.) + + `owner` is always the owner class. (An APIView, or subclass for us.) + `instance` is the view instance or `None` if accessed from the class, + rather than an instance. + + See: https://docs.python.org/3/howto/descriptor.html for info on + descriptor usage. + """ + self.view = instance + return self + + @property + def view(self): + """View property.""" + assert self._view is not None, "Schema generation REQUIRES a view instance. (Hint: you accessed `schema` from the view class rather than an instance.)" + return self._view + + @view.setter + def view(self, value): + self._view = value + + @view.deleter + def view(self): + self._view = None + + def get_link(self, path, method, base_url): + """ + Generate `coreapi.Link` for self.view, path and method. + + This is the main _public_ access point. + + Parameters: + + * path: Route path for view from URLConf. + * method: The HTTP request method. + * base_url: The project "mount point" as given to SchemaGenerator + """ + raise NotImplementedError(".get_link() must be overridden.") + + +class AutoSchema(ViewInspector): + """ + Default inspector for APIView + + Responsible for per-view instrospection and schema generation. + """ + def __init__(self, manual_fields=None): + """ + Parameters: + + * `manual_fields`: list of `coreapi.Field` instances that + will be added to auto-generated fields, overwriting on `Field.name` + """ + + self._manual_fields = manual_fields + + def get_link(self, path, method, base_url): + fields = self.get_path_fields(path, method) + fields += self.get_serializer_fields(path, method) + fields += self.get_pagination_fields(path, method) + fields += self.get_filter_fields(path, method) + + if self._manual_fields is not None: + by_name = {f.name: f for f in fields} + for f in self._manual_fields: + by_name[f.name] = f + fields = list(by_name.values()) + + if fields and any([field.location in ('form', 'body') for field in fields]): + encoding = self.get_encoding(path, method) + else: + encoding = None + + description = self.get_description(path, method) + + if base_url and path.startswith('/'): + path = path[1:] + + return coreapi.Link( + url=urlparse.urljoin(base_url, path), + action=method.lower(), + encoding=encoding, + fields=fields, + description=description + ) + + def get_description(self, path, method): + """ + Determine a link description. + + This will be based on the method docstring if one exists, + or else the class docstring. + """ + view = self.view + + method_name = getattr(view, 'action', method.lower()) + method_docstring = getattr(view, method_name, None).__doc__ + if method_docstring: + # An explicit docstring on the method or action. + return formatting.dedent(smart_text(method_docstring)) + + description = view.get_view_description() + lines = [line.strip() for line in description.splitlines()] + current_section = '' + sections = {'': ''} + + for line in lines: + if header_regex.match(line): + current_section, seperator, lead = line.partition(':') + sections[current_section] = lead.strip() + else: + sections[current_section] += '\n' + line + + # TODO: SCHEMA_COERCE_METHOD_NAMES appears here and in `SchemaGenerator.get_keys` + coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES + header = getattr(view, 'action', method.lower()) + if header in sections: + return sections[header].strip() + if header in coerce_method_names: + if coerce_method_names[header] in sections: + return sections[coerce_method_names[header]].strip() + return sections[''].strip() + + def get_path_fields(self, path, method): + """ + Return a list of `coreapi.Field` instances corresponding to any + templated path variables. + """ + view = self.view + model = getattr(getattr(view, 'queryset', None), 'model', None) + fields = [] + + for variable in uritemplate.variables(path): + title = '' + description = '' + schema_cls = coreschema.String + kwargs = {} + if model is not None: + # Attempt to infer a field description if possible. + try: + model_field = model._meta.get_field(variable) + except: + model_field = None + + if model_field is not None and model_field.verbose_name: + title = force_text(model_field.verbose_name) + + if model_field is not None and model_field.help_text: + description = force_text(model_field.help_text) + elif model_field is not None and model_field.primary_key: + description = get_pk_description(model, model_field) + + if hasattr(view, 'lookup_value_regex') and view.lookup_field == variable: + kwargs['pattern'] = view.lookup_value_regex + elif isinstance(model_field, models.AutoField): + schema_cls = coreschema.Integer + + field = coreapi.Field( + name=variable, + location='path', + required=True, + schema=schema_cls(title=title, description=description, **kwargs) + ) + fields.append(field) + + return fields + + def get_serializer_fields(self, path, method): + """ + Return a list of `coreapi.Field` instances corresponding to any + request body input, as determined by the serializer class. + """ + view = self.view + + if method not in ('PUT', 'PATCH', 'POST'): + return [] + + if not hasattr(view, 'get_serializer'): + return [] + + serializer = view.get_serializer() + + if isinstance(serializer, serializers.ListSerializer): + return [ + coreapi.Field( + name='data', + location='body', + required=True, + schema=coreschema.Array() + ) + ] + + if not isinstance(serializer, serializers.Serializer): + return [] + + fields = [] + for field in serializer.fields.values(): + if field.read_only or isinstance(field, serializers.HiddenField): + continue + + required = field.required and method != 'PATCH' + field = coreapi.Field( + name=field.field_name, + location='form', + required=required, + schema=field_to_schema(field) + ) + fields.append(field) + + return fields + + def get_pagination_fields(self, path, method): + view = self.view + + if not is_list_view(path, method, view): + return [] + + pagination = getattr(view, 'pagination_class', None) + if not pagination: + return [] + + paginator = view.pagination_class() + return paginator.get_schema_fields(view) + + def get_filter_fields(self, path, method): + view = self.view + + if not is_list_view(path, method, view): + return [] + + if not getattr(view, 'filter_backends', None): + return [] + + fields = [] + for filter_backend in view.filter_backends: + fields += filter_backend().get_schema_fields(view) + return fields + + def get_encoding(self, path, method): + """ + Return the 'encoding' parameter to use for a given endpoint. + """ + view = self.view + + # Core API supports the following request encodings over HTTP... + supported_media_types = set(( + 'application/json', + 'application/x-www-form-urlencoded', + 'multipart/form-data', + )) + parser_classes = getattr(view, 'parser_classes', []) + for parser_class in parser_classes: + media_type = getattr(parser_class, 'media_type', None) + if media_type in supported_media_types: + return media_type + # Raw binary uploads are supported with "application/octet-stream" + if media_type == '*/*': + return 'application/octet-stream' + + return None + + +class ManualSchema(ViewInspector): + """ + Allows providing a list of coreapi.Fields, + plus an optional description. + """ + def __init__(self, fields, description=''): + """ + Parameters: + + * `fields`: list of `coreapi.Field` instances. + * `descripton`: String description for view. Optional. + """ + assert all(isinstance(f, coreapi.Field) for f in fields), "`fields` must be a list of coreapi.Field instances" + self._fields = fields + self._description = description + + def get_link(self, path, method, base_url): + + if base_url and path.startswith('/'): + path = path[1:] + + return coreapi.Link( + url=urlparse.urljoin(base_url, path), + action=method.lower(), + encoding=None, + fields=self._fields, + description=self._description + ) + + return self._link diff --git a/rest_framework/schemas/utils.py b/rest_framework/schemas/utils.py new file mode 100644 index 0000000000..1542b6154b --- /dev/null +++ b/rest_framework/schemas/utils.py @@ -0,0 +1,21 @@ +""" +utils.py # Shared helper functions + +See schemas.__init__.py for package overview. +""" + + +def is_list_view(path, method, view): + """ + Return True if the given path/method appears to represent a list view. + """ + if hasattr(view, 'action'): + # Viewsets have an explicitly defined action, which we can inspect. + return view.action == 'list' + + if method.lower() != 'get': + return False + path_components = path.strip('/').split('/') + if path_components and '{' in path_components[-1]: + return False + return True diff --git a/rest_framework/schemas/views.py b/rest_framework/schemas/views.py new file mode 100644 index 0000000000..932b5a4871 --- /dev/null +++ b/rest_framework/schemas/views.py @@ -0,0 +1,34 @@ +""" +views.py # Houses `SchemaView`, `APIView` subclass. + +See schemas.__init__.py for package overview. +""" +from rest_framework import exceptions, renderers +from rest_framework.response import Response +from rest_framework.settings import api_settings +from rest_framework.views import APIView + + +class SchemaView(APIView): + _ignore_model_permissions = True + exclude_from_schema = True + renderer_classes = None + schema_generator = None + public = False + + def __init__(self, *args, **kwargs): + super(SchemaView, self).__init__(*args, **kwargs) + if self.renderer_classes is None: + if renderers.BrowsableAPIRenderer in api_settings.DEFAULT_RENDERER_CLASSES: + self.renderer_classes = [ + renderers.CoreJSONRenderer, + renderers.BrowsableAPIRenderer, + ] + else: + self.renderer_classes = [renderers.CoreJSONRenderer] + + def get(self, request, *args, **kwargs): + schema = self.schema_generator.get_schema(request, self.public) + if schema is None: + raise exceptions.PermissionDenied() + return Response(schema) diff --git a/rest_framework/views.py b/rest_framework/views.py index 8ec5f14ab2..ccc2047eec 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -19,6 +19,7 @@ from rest_framework.compat import set_rollback from rest_framework.request import Request from rest_framework.response import Response +from rest_framework.schemas import AutoSchema from rest_framework.settings import api_settings from rest_framework.utils import formatting @@ -113,6 +114,7 @@ class APIView(View): # Mark the view as being included or excluded from schema generation. exclude_from_schema = False + schema = AutoSchema() @classmethod def as_view(cls, **initkwargs): diff --git a/tests/test_decorators.py b/tests/test_decorators.py index b187e5fd6a..6331742db2 100644 --- a/tests/test_decorators.py +++ b/tests/test_decorators.py @@ -6,12 +6,13 @@ from rest_framework.authentication import BasicAuthentication from rest_framework.decorators import ( api_view, authentication_classes, parser_classes, permission_classes, - renderer_classes, throttle_classes + renderer_classes, schema, throttle_classes ) from rest_framework.parsers import JSONParser from rest_framework.permissions import IsAuthenticated from rest_framework.renderers import JSONRenderer from rest_framework.response import Response +from rest_framework.schemas import AutoSchema from rest_framework.test import APIRequestFactory from rest_framework.throttling import UserRateThrottle from rest_framework.views import APIView @@ -151,3 +152,17 @@ def view(request): response = view(request) assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS + + def test_schema(self): + """ + Checks CustomSchema class is set on view + """ + class CustomSchema(AutoSchema): + pass + + @api_view(['GET']) + @schema(CustomSchema()) + def view(request): + return Response({}) + + assert isinstance(view.cls.schema, CustomSchema) diff --git a/tests/test_schemas.py b/tests/test_schemas.py index b435dfdd78..14ed0f6b6f 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -1,5 +1,6 @@ import unittest +import pytest from django.conf.urls import include, url from django.core.exceptions import PermissionDenied from django.http import Http404 @@ -10,7 +11,9 @@ from rest_framework.decorators import detail_route, list_route from rest_framework.request import Request from rest_framework.routers import DefaultRouter -from rest_framework.schemas import SchemaGenerator, get_schema_view +from rest_framework.schemas import ( + AutoSchema, ManualSchema, SchemaGenerator, get_schema_view +) from rest_framework.test import APIClient, APIRequestFactory from rest_framework.views import APIView from rest_framework.viewsets import ModelViewSet @@ -496,3 +499,81 @@ def test_4605_regression(self): '/auth/convert-token/' ]) assert prefix == '/' + + +class TestDescriptor(TestCase): + + def test_apiview_schema_descriptor(self): + view = APIView() + assert hasattr(view, 'schema') + assert isinstance(view.schema, AutoSchema) + + def test_get_link_requires_instance(self): + descriptor = APIView.schema # Accessed from class + with pytest.raises(AssertionError): + descriptor.get_link(None, None, None) # ???: Do the dummy arguments require a tighter assert? + + def test_manual_fields(self): + + class CustomView(APIView): + schema = AutoSchema(manual_fields=[ + coreapi.Field( + "my_extra_field", + required=True, + location="path", + schema=coreschema.String() + ), + ]) + + view = CustomView() + link = view.schema.get_link('/a/url/{id}/', 'GET', '') + fields = link.fields + + assert len(fields) == 2 + assert "my_extra_field" in [f.name for f in fields] + + def test_view_with_manual_schema(self): + + path = '/example' + method = 'get' + base_url = None + + fields = [ + coreapi.Field( + "first_field", + required=True, + location="path", + schema=coreschema.String() + ), + coreapi.Field( + "second_field", + required=True, + location="path", + schema=coreschema.String() + ), + coreapi.Field( + "third_field", + required=True, + location="path", + schema=coreschema.String() + ), + ] + description = "A test endpoint" + + class CustomView(APIView): + """ + ManualSchema takes list of fields for endpoint. + - Provides url and action, which are always dynamic + """ + schema = ManualSchema(fields, description) + + expected = coreapi.Link( + url=path, + action=method, + fields=fields, + description=description + ) + + view = CustomView() + link = view.schema.get_link(path, method, base_url) + assert link == expected