123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239 |
- """
- generators.py # Top-down schema generation
- See schemas.__init__.py for package overview.
- """
- import re
- from importlib import import_module
- from django.conf import settings
- from django.contrib.admindocs.views import simplify_regex
- from django.core.exceptions import PermissionDenied
- from django.http import Http404
- from rest_framework import exceptions
- from rest_framework.compat import URLPattern, URLResolver, get_original_route
- from rest_framework.request import clone_request
- from rest_framework.settings import api_settings
- from rest_framework.utils.model_meta import _get_pk
def get_pk_name(model):
    """
    Return the name of the primary key field of `model`.

    The lookup goes through `model._meta.concrete_model`, so the options of
    the concrete model are the ones handed to `_get_pk`.
    """
    concrete_options = model._meta.concrete_model._meta
    pk_field = _get_pk(concrete_options)
    return pk_field.name
def is_api_view(callback):
    """
    Tell whether `callback` wraps a REST framework view/viewset.

    The decision is made from the `cls` attribute of the callback: it must
    be present (not `None`) and be a subclass of `APIView`.
    """
    # Imported here rather than at module level to avoid an import cycle
    # on APIView.
    from rest_framework.views import APIView

    view_class = getattr(callback, 'cls', None)
    if view_class is None:
        return False
    return issubclass(view_class, APIView)
def endpoint_ordering(endpoint):
    """
    Sort key for a `(path, method, callback)` endpoint tuple.

    Only the HTTP method contributes to the key: GET, POST, PUT, PATCH and
    DELETE rank 0 through 4 in that order, any other method ranks 5.
    The key is returned as a one-element tuple.
    """
    # Unpacking (rather than indexing) keeps the requirement that the
    # endpoint is exactly a 3-item sequence.
    _path, http_method, _callback = endpoint
    known_methods = ('GET', 'POST', 'PUT', 'PATCH', 'DELETE')
    ranks = {name: rank for rank, name in enumerate(known_methods)}
    return (ranks.get(http_method, 5),)
# Matches a single angle-bracketed path component such as `<int:pk>` or
# `<pk>`. The optional `converter` group captures the text before the colon
# and the `parameter` group captures the parameter name.
_PATH_PARAMETER_COMPONENT_RE = re.compile(
    r'<(?:(?P<converter>[^>:]+):)?(?P<parameter>\w+)>'
)
class EndpointEnumerator:
    """
    A class to determine the available API endpoints that a project exposes.
    """

    def __init__(self, patterns=None, urlconf=None):
        """
        Remember the URL patterns that will be inspected.

        When `patterns` is not given it is read from the `urlpatterns`
        attribute of `urlconf`. `urlconf` may be a dotted module path (it is
        imported) or an object that already exposes `urlpatterns`; when it is
        not given either, `settings.ROOT_URLCONF` is used.
        """
        if patterns is None:
            if urlconf is None:
                # Use the default Django URL conf
                urlconf = settings.ROOT_URLCONF

            # Load the given URLconf module
            if isinstance(urlconf, str):
                urls = import_module(urlconf)
            else:
                urls = urlconf
            patterns = urls.urlpatterns

        self.patterns = patterns

    def get_api_endpoints(self, patterns=None, prefix=''):
        """
        Return a list of all available API endpoints by inspecting the URL conf.

        Each entry is a `(path, method, callback)` tuple. Resolvers are
        descended into recursively, with the route accumulated so far passed
        down as `prefix`. The result is sorted with `endpoint_ordering`.
        """
        if patterns is None:
            patterns = self.patterns

        api_endpoints = []

        for pattern in patterns:
            path_regex = prefix + get_original_route(pattern)
            if isinstance(pattern, URLPattern):
                path = self.get_path_from_regex(path_regex)
                callback = pattern.callback
                if self.should_include_endpoint(path, callback):
                    api_endpoints.extend(
                        (path, method, callback)
                        for method in self.get_allowed_methods(callback)
                    )
            elif isinstance(pattern, URLResolver):
                api_endpoints.extend(self.get_api_endpoints(
                    patterns=pattern.url_patterns,
                    prefix=path_regex
                ))

        return sorted(api_endpoints, key=endpoint_ordering)

    def get_path_from_regex(self, path_regex):
        """
        Given a URL conf regex, return a URI template string.
        """
        # ???: Would it be feasible to adjust this such that we generate the
        # path, plus the kwargs, plus the type from the converter, such that we
        # could feed that straight into the parameter schema object?
        path = simplify_regex(path_regex)

        # Strip Django 2.0 converters as they are incompatible with the
        # uritemplate format: `<int:pk>` and `<pk>` both become `{pk}`.
        return _PATH_PARAMETER_COMPONENT_RE.sub(r'{\g<parameter>}', path)

    def should_include_endpoint(self, path, callback):
        """
        Return `True` if the given endpoint should be included.

        An endpoint is excluded when the callback is not a REST framework
        view, when the view class has `schema = None`, when the callback's
        `initkwargs` explicitly carry `schema=None`, or when the path ends
        with a `.{format}` suffix.
        """
        if not is_api_view(callback):
            return False  # Ignore anything except REST framework views.

        if callback.cls.schema is None:
            return False

        initkwargs = callback.initkwargs
        if 'schema' in initkwargs and initkwargs['schema'] is None:
            return False

        if path.endswith(('.{format}', '.{format}/')):
            return False  # Ignore .json style URLs.

        return True

    def get_allowed_methods(self, callback):
        """
        Return a list of the valid HTTP methods for this endpoint.

        For callbacks that carry an `actions` attribute, the methods are the
        entries of `actions` that also appear in the view class's
        `http_method_names`, upper-cased. Otherwise they are taken from the
        `allowed_methods` of a freshly created view instance. `OPTIONS` and
        `HEAD` are always left out.

        The methods are reported in the order `actions` yields them, so the
        result is reproducible from one run to the next.
        """
        if hasattr(callback, 'actions'):
            supported = set(callback.cls.http_method_names)
            # `dict.fromkeys` drops duplicates while keeping first-seen
            # order; iterating a set here would make the order depend on
            # string hashing.
            methods = [
                name.upper()
                for name in dict.fromkeys(callback.actions)
                if name in supported
            ]
        else:
            methods = callback.cls().allowed_methods

        return [method for method in methods if method not in ('OPTIONS', 'HEAD')]
class BaseSchemaGenerator(object):
    """
    Base class for schema generators.

    Collects the project's endpoints through `endpoint_inspector_cls` and
    provides helpers to turn them into view instances; subclasses supply
    `get_schema()`.
    """
    endpoint_inspector_cls = EndpointEnumerator

    # 'pk' isn't great as an externally exposed name for an identifier,
    # so by default we prefer to use the actual model field name for schemas.
    # Set by 'SCHEMA_COERCE_PATH_PK'.
    coerce_path_pk = None

    def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None, version=None):
        """
        Store the generator options.

        A non-empty `url` that does not already end with a slash gets one
        appended. Endpoints are not collected here; see
        `_initialise_endpoints`.
        """
        self.title = title
        self.description = description
        self.version = version
        self.patterns = patterns
        self.urlconf = urlconf
        self.url = (url + '/') if (url and not url.endswith('/')) else url
        self.coerce_path_pk = api_settings.SCHEMA_COERCE_PATH_PK
        self.endpoints = None

    def _initialise_endpoints(self):
        """
        Collect the API endpoints once and cache them on `self.endpoints`.
        """
        if self.endpoints is not None:
            return
        enumerator = self.endpoint_inspector_cls(self.patterns, self.urlconf)
        self.endpoints = enumerator.get_api_endpoints()

    def _get_paths_and_endpoints(self, request):
        """
        Generate (path, method, view) given (path, method, callback) for paths.

        Returns a pair: the list of (possibly coerced) paths, and the list
        of `(path, method, view)` tuples, both in endpoint order.
        """
        view_endpoints = []
        for raw_path, method, callback in self.endpoints:
            view = self.create_view(callback, method, request)
            coerced = self.coerce_path(raw_path, method, view)
            view_endpoints.append((coerced, method, view))
        paths = [entry[0] for entry in view_endpoints]
        return paths, view_endpoints

    def create_view(self, callback, method, request=None):
        """
        Given a callback, return an actual view instance.

        The view is built from `callback.cls` with the callback's
        `initkwargs` (if any) and given blank `args`/`kwargs`. When the
        callback carries `actions`, `view.action` is set from them
        ('metadata' for OPTIONS). When a request is supplied, the view
        receives a clone of it for `method`.
        """
        init_kwargs = getattr(callback, 'initkwargs', {})
        action_map = getattr(callback, 'actions', None)

        view = callback.cls(**init_kwargs)
        view.args = ()
        view.kwargs = {}
        view.format_kwarg = None
        view.request = None
        view.action_map = action_map

        if action_map is not None:
            view.action = (
                'metadata' if method == 'OPTIONS'
                else action_map.get(method.lower())
            )

        if request is not None:
            view.request = clone_request(request, method)

        return view

    def coerce_path(self, path, method, view):
        """
        Coerce {pk} path arguments into the name of the model field,
        where possible. This is cleaner for an external representation.
        (Ie. "this is an identifier", not "this is a database primary key")

        The path is returned untouched when coercion is switched off or it
        contains no `{pk}`. Without a model on the view's queryset the
        replacement name is 'id'.
        """
        if self.coerce_path_pk and '{pk}' in path:
            queryset = getattr(view, 'queryset', None)
            model = getattr(queryset, 'model', None)
            field_name = get_pk_name(model) if model else 'id'
            return path.replace('{pk}', '{%s}' % field_name)
        return path

    def get_schema(self, request=None, public=False):
        """
        Build the schema; this base implementation always raises
        `NotImplementedError`.
        """
        raise NotImplementedError(".get_schema() must be implemented in subclasses.")

    def has_view_permissions(self, path, method, view):
        """
        Return `True` if the incoming request has the correct view permissions.

        A view without a request is always permitted. Otherwise the view's
        `check_permissions` is run, and an `APIException`, `Http404` or
        `PermissionDenied` raised by it means "not permitted".
        """
        request = view.request
        if request is None:
            return True

        try:
            view.check_permissions(request)
        except (exceptions.APIException, Http404, PermissionDenied):
            return False
        else:
            return True
|