openapi.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554
  1. import warnings
  2. from operator import attrgetter
  3. from urllib.parse import urljoin
  4. from django.core.validators import (
  5. DecimalValidator, EmailValidator, MaxLengthValidator, MaxValueValidator,
  6. MinLengthValidator, MinValueValidator, RegexValidator, URLValidator
  7. )
  8. from django.db import models
  9. from django.utils.encoding import force_str
  10. from rest_framework import exceptions, renderers, serializers
  11. from rest_framework.compat import uritemplate
  12. from rest_framework.fields import _UnvalidatedField, empty
  13. from .generators import BaseSchemaGenerator
  14. from .inspectors import ViewInspector
  15. from .utils import get_pk_description, is_list_view
  16. class SchemaGenerator(BaseSchemaGenerator):
  17. def get_info(self):
  18. # Title and version are required by openapi specification 3.x
  19. info = {
  20. 'title': self.title or '',
  21. 'version': self.version or ''
  22. }
  23. if self.description is not None:
  24. info['description'] = self.description
  25. return info
  26. def get_paths(self, request=None):
  27. result = {}
  28. paths, view_endpoints = self._get_paths_and_endpoints(request)
  29. # Only generate the path prefix for paths that will be included
  30. if not paths:
  31. return None
  32. for path, method, view in view_endpoints:
  33. if not self.has_view_permissions(path, method, view):
  34. continue
  35. operation = view.schema.get_operation(path, method)
  36. # Normalise path for any provided mount url.
  37. if path.startswith('/'):
  38. path = path[1:]
  39. path = urljoin(self.url or '/', path)
  40. result.setdefault(path, {})
  41. result[path][method.lower()] = operation
  42. return result
  43. def get_schema(self, request=None, public=False):
  44. """
  45. Generate a OpenAPI schema.
  46. """
  47. self._initialise_endpoints()
  48. paths = self.get_paths(None if public else request)
  49. if not paths:
  50. return None
  51. schema = {
  52. 'openapi': '3.0.2',
  53. 'info': self.get_info(),
  54. 'paths': paths,
  55. }
  56. return schema
  57. # View Inspectors
  58. class AutoSchema(ViewInspector):
  59. request_media_types = []
  60. response_media_types = []
  61. method_mapping = {
  62. 'get': 'Retrieve',
  63. 'post': 'Create',
  64. 'put': 'Update',
  65. 'patch': 'PartialUpdate',
  66. 'delete': 'Destroy',
  67. }
  68. def get_operation(self, path, method):
  69. operation = {}
  70. operation['operationId'] = self._get_operation_id(path, method)
  71. operation['description'] = self.get_description(path, method)
  72. parameters = []
  73. parameters += self._get_path_parameters(path, method)
  74. parameters += self._get_pagination_parameters(path, method)
  75. parameters += self._get_filter_parameters(path, method)
  76. operation['parameters'] = parameters
  77. request_body = self._get_request_body(path, method)
  78. if request_body:
  79. operation['requestBody'] = request_body
  80. operation['responses'] = self._get_responses(path, method)
  81. return operation
  82. def _get_operation_id(self, path, method):
  83. """
  84. Compute an operation ID from the model, serializer or view name.
  85. """
  86. method_name = getattr(self.view, 'action', method.lower())
  87. if is_list_view(path, method, self.view):
  88. action = 'list'
  89. elif method_name not in self.method_mapping:
  90. action = method_name
  91. else:
  92. action = self.method_mapping[method.lower()]
  93. # Try to deduce the ID from the view's model
  94. model = getattr(getattr(self.view, 'queryset', None), 'model', None)
  95. if model is not None:
  96. name = model.__name__
  97. # Try with the serializer class name
  98. elif hasattr(self.view, 'get_serializer_class'):
  99. name = self.view.get_serializer_class().__name__
  100. if name.endswith('Serializer'):
  101. name = name[:-10]
  102. # Fallback to the view name
  103. else:
  104. name = self.view.__class__.__name__
  105. if name.endswith('APIView'):
  106. name = name[:-7]
  107. elif name.endswith('View'):
  108. name = name[:-4]
  109. # Due to camel-casing of classes and `action` being lowercase, apply title in order to find if action truly
  110. # comes at the end of the name
  111. if name.endswith(action.title()): # ListView, UpdateAPIView, ThingDelete ...
  112. name = name[:-len(action)]
  113. if action == 'list' and not name.endswith('s'): # listThings instead of listThing
  114. name += 's'
  115. return action + name
  116. def _get_path_parameters(self, path, method):
  117. """
  118. Return a list of parameters from templated path variables.
  119. """
  120. assert uritemplate, '`uritemplate` must be installed for OpenAPI schema support.'
  121. model = getattr(getattr(self.view, 'queryset', None), 'model', None)
  122. parameters = []
  123. for variable in uritemplate.variables(path):
  124. description = ''
  125. if model is not None: # TODO: test this.
  126. # Attempt to infer a field description if possible.
  127. try:
  128. model_field = model._meta.get_field(variable)
  129. except Exception:
  130. model_field = None
  131. if model_field is not None and model_field.help_text:
  132. description = force_str(model_field.help_text)
  133. elif model_field is not None and model_field.primary_key:
  134. description = get_pk_description(model, model_field)
  135. parameter = {
  136. "name": variable,
  137. "in": "path",
  138. "required": True,
  139. "description": description,
  140. 'schema': {
  141. 'type': 'string', # TODO: integer, pattern, ...
  142. },
  143. }
  144. parameters.append(parameter)
  145. return parameters
  146. def _get_filter_parameters(self, path, method):
  147. if not self._allows_filters(path, method):
  148. return []
  149. parameters = []
  150. for filter_backend in self.view.filter_backends:
  151. parameters += filter_backend().get_schema_operation_parameters(self.view)
  152. return parameters
  153. def _allows_filters(self, path, method):
  154. """
  155. Determine whether to include filter Fields in schema.
  156. Default implementation looks for ModelViewSet or GenericAPIView
  157. actions/methods that cause filtering on the default implementation.
  158. """
  159. if getattr(self.view, 'filter_backends', None) is None:
  160. return False
  161. if hasattr(self.view, 'action'):
  162. return self.view.action in ["list", "retrieve", "update", "partial_update", "destroy"]
  163. return method.lower() in ["get", "put", "patch", "delete"]
  164. def _get_pagination_parameters(self, path, method):
  165. view = self.view
  166. if not is_list_view(path, method, view):
  167. return []
  168. paginator = self._get_paginator()
  169. if not paginator:
  170. return []
  171. return paginator.get_schema_operation_parameters(view)
  172. def _map_field(self, field):
  173. # Nested Serializers, `many` or not.
  174. if isinstance(field, serializers.ListSerializer):
  175. return {
  176. 'type': 'array',
  177. 'items': self._map_serializer(field.child)
  178. }
  179. if isinstance(field, serializers.Serializer):
  180. data = self._map_serializer(field)
  181. data['type'] = 'object'
  182. return data
  183. # Related fields.
  184. if isinstance(field, serializers.ManyRelatedField):
  185. return {
  186. 'type': 'array',
  187. 'items': self._map_field(field.child_relation)
  188. }
  189. if isinstance(field, serializers.PrimaryKeyRelatedField):
  190. model = getattr(field.queryset, 'model', None)
  191. if model is not None:
  192. model_field = model._meta.pk
  193. if isinstance(model_field, models.AutoField):
  194. return {'type': 'integer'}
  195. # ChoiceFields (single and multiple).
  196. # Q:
  197. # - Is 'type' required?
  198. # - can we determine the TYPE of a choicefield?
  199. if isinstance(field, serializers.MultipleChoiceField):
  200. return {
  201. 'type': 'array',
  202. 'items': {
  203. 'enum': list(field.choices)
  204. },
  205. }
  206. if isinstance(field, serializers.ChoiceField):
  207. return {
  208. 'enum': list(field.choices),
  209. }
  210. # ListField.
  211. if isinstance(field, serializers.ListField):
  212. mapping = {
  213. 'type': 'array',
  214. 'items': {},
  215. }
  216. if not isinstance(field.child, _UnvalidatedField):
  217. map_field = self._map_field(field.child)
  218. items = {
  219. "type": map_field.get('type')
  220. }
  221. if 'format' in map_field:
  222. items['format'] = map_field.get('format')
  223. mapping['items'] = items
  224. return mapping
  225. # DateField and DateTimeField type is string
  226. if isinstance(field, serializers.DateField):
  227. return {
  228. 'type': 'string',
  229. 'format': 'date',
  230. }
  231. if isinstance(field, serializers.DateTimeField):
  232. return {
  233. 'type': 'string',
  234. 'format': 'date-time',
  235. }
  236. # "Formats such as "email", "uuid", and so on, MAY be used even though undefined by this specification."
  237. # see: https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#data-types
  238. # see also: https://swagger.io/docs/specification/data-models/data-types/#string
  239. if isinstance(field, serializers.EmailField):
  240. return {
  241. 'type': 'string',
  242. 'format': 'email'
  243. }
  244. if isinstance(field, serializers.URLField):
  245. return {
  246. 'type': 'string',
  247. 'format': 'uri'
  248. }
  249. if isinstance(field, serializers.UUIDField):
  250. return {
  251. 'type': 'string',
  252. 'format': 'uuid'
  253. }
  254. if isinstance(field, serializers.IPAddressField):
  255. content = {
  256. 'type': 'string',
  257. }
  258. if field.protocol != 'both':
  259. content['format'] = field.protocol
  260. return content
  261. # DecimalField has multipleOf based on decimal_places
  262. if isinstance(field, serializers.DecimalField):
  263. content = {
  264. 'type': 'number'
  265. }
  266. if field.decimal_places:
  267. content['multipleOf'] = float('.' + (field.decimal_places - 1) * '0' + '1')
  268. if field.max_whole_digits:
  269. content['maximum'] = int(field.max_whole_digits * '9') + 1
  270. content['minimum'] = -content['maximum']
  271. self._map_min_max(field, content)
  272. return content
  273. if isinstance(field, serializers.FloatField):
  274. content = {
  275. 'type': 'number'
  276. }
  277. self._map_min_max(field, content)
  278. return content
  279. if isinstance(field, serializers.IntegerField):
  280. content = {
  281. 'type': 'integer'
  282. }
  283. self._map_min_max(field, content)
  284. # 2147483647 is max for int32_size, so we use int64 for format
  285. if int(content.get('maximum', 0)) > 2147483647 or int(content.get('minimum', 0)) > 2147483647:
  286. content['format'] = 'int64'
  287. return content
  288. if isinstance(field, serializers.FileField):
  289. return {
  290. 'type': 'string',
  291. 'format': 'binary'
  292. }
  293. # Simplest cases, default to 'string' type:
  294. FIELD_CLASS_SCHEMA_TYPE = {
  295. serializers.BooleanField: 'boolean',
  296. serializers.JSONField: 'object',
  297. serializers.DictField: 'object',
  298. serializers.HStoreField: 'object',
  299. }
  300. return {'type': FIELD_CLASS_SCHEMA_TYPE.get(field.__class__, 'string')}
  301. def _map_min_max(self, field, content):
  302. if field.max_value:
  303. content['maximum'] = field.max_value
  304. if field.min_value:
  305. content['minimum'] = field.min_value
  306. def _map_serializer(self, serializer):
  307. # Assuming we have a valid serializer instance.
  308. # TODO:
  309. # - field is Nested or List serializer.
  310. # - Handle read_only/write_only for request/response differences.
  311. # - could do this with readOnly/writeOnly and then filter dict.
  312. required = []
  313. properties = {}
  314. for field in serializer.fields.values():
  315. if isinstance(field, serializers.HiddenField):
  316. continue
  317. if field.required:
  318. required.append(field.field_name)
  319. schema = self._map_field(field)
  320. if field.read_only:
  321. schema['readOnly'] = True
  322. if field.write_only:
  323. schema['writeOnly'] = True
  324. if field.allow_null:
  325. schema['nullable'] = True
  326. if field.default and field.default != empty: # why don't they use None?!
  327. schema['default'] = field.default
  328. if field.help_text:
  329. schema['description'] = str(field.help_text)
  330. self._map_field_validators(field, schema)
  331. properties[field.field_name] = schema
  332. result = {
  333. 'properties': properties
  334. }
  335. if required:
  336. result['required'] = required
  337. return result
  338. def _map_field_validators(self, field, schema):
  339. """
  340. map field validators
  341. """
  342. for v in field.validators:
  343. # "Formats such as "email", "uuid", and so on, MAY be used even though undefined by this specification."
  344. # https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#data-types
  345. if isinstance(v, EmailValidator):
  346. schema['format'] = 'email'
  347. if isinstance(v, URLValidator):
  348. schema['format'] = 'uri'
  349. if isinstance(v, RegexValidator):
  350. schema['pattern'] = v.regex.pattern
  351. elif isinstance(v, MaxLengthValidator):
  352. attr_name = 'maxLength'
  353. if isinstance(field, serializers.ListField):
  354. attr_name = 'maxItems'
  355. schema[attr_name] = v.limit_value
  356. elif isinstance(v, MinLengthValidator):
  357. attr_name = 'minLength'
  358. if isinstance(field, serializers.ListField):
  359. attr_name = 'minItems'
  360. schema[attr_name] = v.limit_value
  361. elif isinstance(v, MaxValueValidator):
  362. schema['maximum'] = v.limit_value
  363. elif isinstance(v, MinValueValidator):
  364. schema['minimum'] = v.limit_value
  365. elif isinstance(v, DecimalValidator):
  366. if v.decimal_places:
  367. schema['multipleOf'] = float('.' + (v.decimal_places - 1) * '0' + '1')
  368. if v.max_digits:
  369. digits = v.max_digits
  370. if v.decimal_places is not None and v.decimal_places > 0:
  371. digits -= v.decimal_places
  372. schema['maximum'] = int(digits * '9') + 1
  373. schema['minimum'] = -schema['maximum']
  374. def _get_paginator(self):
  375. pagination_class = getattr(self.view, 'pagination_class', None)
  376. if pagination_class:
  377. return pagination_class()
  378. return None
  379. def map_parsers(self, path, method):
  380. return list(map(attrgetter('media_type'), self.view.parser_classes))
  381. def map_renderers(self, path, method):
  382. media_types = []
  383. for renderer in self.view.renderer_classes:
  384. # BrowsableAPIRenderer not relevant to OpenAPI spec
  385. if renderer == renderers.BrowsableAPIRenderer:
  386. continue
  387. media_types.append(renderer.media_type)
  388. return media_types
  389. def _get_serializer(self, method, path):
  390. view = self.view
  391. if not hasattr(view, 'get_serializer'):
  392. return None
  393. try:
  394. return view.get_serializer()
  395. except exceptions.APIException:
  396. warnings.warn('{}.get_serializer() raised an exception during '
  397. 'schema generation. Serializer fields will not be '
  398. 'generated for {} {}.'
  399. .format(view.__class__.__name__, method, path))
  400. return None
  401. def _get_request_body(self, path, method):
  402. if method not in ('PUT', 'PATCH', 'POST'):
  403. return {}
  404. self.request_media_types = self.map_parsers(path, method)
  405. serializer = self._get_serializer(path, method)
  406. if not isinstance(serializer, serializers.Serializer):
  407. return {}
  408. content = self._map_serializer(serializer)
  409. # No required fields for PATCH
  410. if method == 'PATCH':
  411. content.pop('required', None)
  412. # No read_only fields for request.
  413. for name, schema in content['properties'].copy().items():
  414. if 'readOnly' in schema:
  415. del content['properties'][name]
  416. return {
  417. 'content': {
  418. ct: {'schema': content}
  419. for ct in self.request_media_types
  420. }
  421. }
  422. def _get_responses(self, path, method):
  423. # TODO: Handle multiple codes and pagination classes.
  424. if method == 'DELETE':
  425. return {
  426. '204': {
  427. 'description': ''
  428. }
  429. }
  430. self.response_media_types = self.map_renderers(path, method)
  431. item_schema = {}
  432. serializer = self._get_serializer(path, method)
  433. if isinstance(serializer, serializers.Serializer):
  434. item_schema = self._map_serializer(serializer)
  435. # No write_only fields for response.
  436. for name, schema in item_schema['properties'].copy().items():
  437. if 'writeOnly' in schema:
  438. del item_schema['properties'][name]
  439. if 'required' in item_schema:
  440. item_schema['required'] = [f for f in item_schema['required'] if f != name]
  441. if is_list_view(path, method, self.view):
  442. response_schema = {
  443. 'type': 'array',
  444. 'items': item_schema,
  445. }
  446. paginator = self._get_paginator()
  447. if paginator:
  448. response_schema = paginator.get_paginated_response_schema(response_schema)
  449. else:
  450. response_schema = item_schema
  451. return {
  452. '200': {
  453. 'content': {
  454. ct: {'schema': response_schema}
  455. for ct in self.response_media_types
  456. },
  457. # description is a mandatory property,
  458. # https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#responseObject
  459. # TODO: put something meaningful into it
  460. 'description': ""
  461. }
  462. }