routers.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349
"""
Routers provide a convenient and consistent way of automatically
determining the URL conf for your API.

They are used by simply instantiating a Router class, and then registering
all the required ViewSets with that router.

For example, you might have a `urls.py` that looks something like this:

    router = routers.DefaultRouter()
    router.register('users', UserViewSet, 'user')
    router.register('accounts', AccountViewSet, 'account')

    urlpatterns = router.urls
"""
import itertools
from collections import OrderedDict, namedtuple

from django.conf.urls import url
from django.core.exceptions import ImproperlyConfigured
from django.urls import NoReverseMatch, re_path

from rest_framework import views
from rest_framework.response import Response
from rest_framework.reverse import reverse
from rest_framework.schemas import SchemaGenerator
from rest_framework.schemas.views import SchemaView
from rest_framework.settings import api_settings
from rest_framework.urlpatterns import format_suffix_patterns
  24. Route = namedtuple('Route', ['url', 'mapping', 'name', 'detail', 'initkwargs'])
  25. DynamicRoute = namedtuple('DynamicRoute', ['url', 'name', 'detail', 'initkwargs'])
  26. def escape_curly_brackets(url_path):
  27. """
  28. Double brackets in regex of url_path for escape string formatting
  29. """
  30. return url_path.replace('{', '{{').replace('}', '}}')
  31. def flatten(list_of_lists):
  32. """
  33. Takes an iterable of iterables, returns a single iterable containing all items
  34. """
  35. return itertools.chain(*list_of_lists)
  36. class BaseRouter:
  37. def __init__(self):
  38. self.registry = []
  39. def register(self, prefix, viewset, basename=None):
  40. if basename is None:
  41. basename = self.get_default_basename(viewset)
  42. self.registry.append((prefix, viewset, basename))
  43. # invalidate the urls cache
  44. if hasattr(self, '_urls'):
  45. del self._urls
  46. def get_default_basename(self, viewset):
  47. """
  48. If `basename` is not specified, attempt to automatically determine
  49. it from the viewset.
  50. """
  51. raise NotImplementedError('get_default_basename must be overridden')
  52. def get_urls(self):
  53. """
  54. Return a list of URL patterns, given the registered viewsets.
  55. """
  56. raise NotImplementedError('get_urls must be overridden')
  57. @property
  58. def urls(self):
  59. if not hasattr(self, '_urls'):
  60. self._urls = self.get_urls()
  61. return self._urls
  62. class SimpleRouter(BaseRouter):
  63. routes = [
  64. # List route.
  65. Route(
  66. url=r'^{prefix}{trailing_slash}$',
  67. mapping={
  68. 'get': 'list',
  69. 'post': 'create'
  70. },
  71. name='{basename}-list',
  72. detail=False,
  73. initkwargs={'suffix': 'List'}
  74. ),
  75. # Dynamically generated list routes. Generated using
  76. # @action(detail=False) decorator on methods of the viewset.
  77. DynamicRoute(
  78. url=r'^{prefix}/{url_path}{trailing_slash}$',
  79. name='{basename}-{url_name}',
  80. detail=False,
  81. initkwargs={}
  82. ),
  83. # Detail route.
  84. Route(
  85. url=r'^{prefix}/{lookup}{trailing_slash}$',
  86. mapping={
  87. 'get': 'retrieve',
  88. 'put': 'update',
  89. 'patch': 'partial_update',
  90. 'delete': 'destroy'
  91. },
  92. name='{basename}-detail',
  93. detail=True,
  94. initkwargs={'suffix': 'Instance'}
  95. ),
  96. # Dynamically generated detail routes. Generated using
  97. # @action(detail=True) decorator on methods of the viewset.
  98. DynamicRoute(
  99. url=r'^{prefix}/{lookup}/{url_path}{trailing_slash}$',
  100. name='{basename}-{url_name}',
  101. detail=True,
  102. initkwargs={}
  103. ),
  104. ]
  105. def __init__(self, trailing_slash=True):
  106. self.trailing_slash = '/' if trailing_slash else ''
  107. super().__init__()
  108. def get_default_basename(self, viewset):
  109. """
  110. If `basename` is not specified, attempt to automatically determine
  111. it from the viewset.
  112. """
  113. queryset = getattr(viewset, 'queryset', None)
  114. assert queryset is not None, '`basename` argument not specified, and could ' \
  115. 'not automatically determine the name from the viewset, as ' \
  116. 'it does not have a `.queryset` attribute.'
  117. return queryset.model._meta.object_name.lower()
  118. def get_routes(self, viewset):
  119. """
  120. Augment `self.routes` with any dynamically generated routes.
  121. Returns a list of the Route namedtuple.
  122. """
  123. # converting to list as iterables are good for one pass, known host needs to be checked again and again for
  124. # different functions.
  125. known_actions = list(flatten([route.mapping.values() for route in self.routes if isinstance(route, Route)]))
  126. extra_actions = viewset.get_extra_actions()
  127. # checking action names against the known actions list
  128. not_allowed = [
  129. action.__name__ for action in extra_actions
  130. if action.__name__ in known_actions
  131. ]
  132. if not_allowed:
  133. msg = ('Cannot use the @action decorator on the following '
  134. 'methods, as they are existing routes: %s')
  135. raise ImproperlyConfigured(msg % ', '.join(not_allowed))
  136. # partition detail and list actions
  137. detail_actions = [action for action in extra_actions if action.detail]
  138. list_actions = [action for action in extra_actions if not action.detail]
  139. routes = []
  140. for route in self.routes:
  141. if isinstance(route, DynamicRoute) and route.detail:
  142. routes += [self._get_dynamic_route(route, action) for action in detail_actions]
  143. elif isinstance(route, DynamicRoute) and not route.detail:
  144. routes += [self._get_dynamic_route(route, action) for action in list_actions]
  145. else:
  146. routes.append(route)
  147. return routes
  148. def _get_dynamic_route(self, route, action):
  149. initkwargs = route.initkwargs.copy()
  150. initkwargs.update(action.kwargs)
  151. url_path = escape_curly_brackets(action.url_path)
  152. return Route(
  153. url=route.url.replace('{url_path}', url_path),
  154. mapping=action.mapping,
  155. name=route.name.replace('{url_name}', action.url_name),
  156. detail=route.detail,
  157. initkwargs=initkwargs,
  158. )
  159. def get_method_map(self, viewset, method_map):
  160. """
  161. Given a viewset, and a mapping of http methods to actions,
  162. return a new mapping which only includes any mappings that
  163. are actually implemented by the viewset.
  164. """
  165. bound_methods = {}
  166. for method, action in method_map.items():
  167. if hasattr(viewset, action):
  168. bound_methods[method] = action
  169. return bound_methods
  170. def get_lookup_regex(self, viewset, lookup_prefix=''):
  171. """
  172. Given a viewset, return the portion of URL regex that is used
  173. to match against a single instance.
  174. Note that lookup_prefix is not used directly inside REST rest_framework
  175. itself, but is required in order to nicely support nested router
  176. implementations, such as drf-nested-routers.
  177. https://github.com/alanjds/drf-nested-routers
  178. """
  179. base_regex = '(?P<{lookup_prefix}{lookup_url_kwarg}>{lookup_value})'
  180. # Use `pk` as default field, unset set. Default regex should not
  181. # consume `.json` style suffixes and should break at '/' boundaries.
  182. lookup_field = getattr(viewset, 'lookup_field', 'pk')
  183. lookup_url_kwarg = getattr(viewset, 'lookup_url_kwarg', None) or lookup_field
  184. lookup_value = getattr(viewset, 'lookup_value_regex', '[^/.]+')
  185. return base_regex.format(
  186. lookup_prefix=lookup_prefix,
  187. lookup_url_kwarg=lookup_url_kwarg,
  188. lookup_value=lookup_value
  189. )
  190. def get_urls(self):
  191. """
  192. Use the registered viewsets to generate a list of URL patterns.
  193. """
  194. ret = []
  195. for prefix, viewset, basename in self.registry:
  196. lookup = self.get_lookup_regex(viewset)
  197. routes = self.get_routes(viewset)
  198. for route in routes:
  199. # Only actions which actually exist on the viewset will be bound
  200. mapping = self.get_method_map(viewset, route.mapping)
  201. if not mapping:
  202. continue
  203. # Build the url pattern
  204. regex = route.url.format(
  205. prefix=prefix,
  206. lookup=lookup,
  207. trailing_slash=self.trailing_slash
  208. )
  209. # If there is no prefix, the first part of the url is probably
  210. # controlled by project's urls.py and the router is in an app,
  211. # so a slash in the beginning will (A) cause Django to give
  212. # warnings and (B) generate URLS that will require using '//'.
  213. if not prefix and regex[:2] == '^/':
  214. regex = '^' + regex[2:]
  215. initkwargs = route.initkwargs.copy()
  216. initkwargs.update({
  217. 'basename': basename,
  218. 'detail': route.detail,
  219. })
  220. view = viewset.as_view(mapping, **initkwargs)
  221. name = route.name.format(basename=basename)
  222. ret.append(url(regex, view, name=name))
  223. return ret
  224. class APIRootView(views.APIView):
  225. """
  226. The default basic root view for DefaultRouter
  227. """
  228. _ignore_model_permissions = True
  229. schema = None # exclude from schema
  230. api_root_dict = None
  231. def get(self, request, *args, **kwargs):
  232. # Return a plain {"name": "hyperlink"} response.
  233. ret = OrderedDict()
  234. namespace = request.resolver_match.namespace
  235. for key, url_name in self.api_root_dict.items():
  236. if namespace:
  237. url_name = namespace + ':' + url_name
  238. try:
  239. ret[key] = reverse(
  240. url_name,
  241. args=args,
  242. kwargs=kwargs,
  243. request=request,
  244. format=kwargs.get('format', None)
  245. )
  246. except NoReverseMatch:
  247. # Don't bail out if eg. no list routes exist, only detail routes.
  248. continue
  249. return Response(ret)
  250. class DefaultRouter(SimpleRouter):
  251. """
  252. The default router extends the SimpleRouter, but also adds in a default
  253. API root view, and adds format suffix patterns to the URLs.
  254. """
  255. include_root_view = True
  256. include_format_suffixes = True
  257. root_view_name = 'api-root'
  258. default_schema_renderers = None
  259. APIRootView = APIRootView
  260. APISchemaView = SchemaView
  261. SchemaGenerator = SchemaGenerator
  262. def __init__(self, *args, **kwargs):
  263. if 'root_renderers' in kwargs:
  264. self.root_renderers = kwargs.pop('root_renderers')
  265. else:
  266. self.root_renderers = list(api_settings.DEFAULT_RENDERER_CLASSES)
  267. super().__init__(*args, **kwargs)
  268. def get_api_root_view(self, api_urls=None):
  269. """
  270. Return a basic root view.
  271. """
  272. api_root_dict = OrderedDict()
  273. list_name = self.routes[0].name
  274. for prefix, viewset, basename in self.registry:
  275. api_root_dict[prefix] = list_name.format(basename=basename)
  276. return self.APIRootView.as_view(api_root_dict=api_root_dict)
  277. def get_urls(self):
  278. """
  279. Generate the list of URL patterns, including a default root view
  280. for the API, and appending `.json` style format suffixes.
  281. """
  282. urls = super().get_urls()
  283. if self.include_root_view:
  284. view = self.get_api_root_view(api_urls=urls)
  285. root_url = url(r'^$', view, name=self.root_view_name)
  286. urls.append(root_url)
  287. if self.include_format_suffixes:
  288. urls = format_suffix_patterns(urls)
  289. return urls