test.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396
  1. # Note that we import as `DjangoRequestFactory` and `DjangoClient` in order
  2. # to make it harder for the user to import the wrong thing without realizing.
  3. import io
  4. from importlib import import_module
  5. from django.conf import settings
  6. from django.core.exceptions import ImproperlyConfigured
  7. from django.core.handlers.wsgi import WSGIHandler
  8. from django.test import override_settings, testcases
  9. from django.test.client import Client as DjangoClient
  10. from django.test.client import ClientHandler
  11. from django.test.client import RequestFactory as DjangoRequestFactory
  12. from django.utils.encoding import force_bytes
  13. from django.utils.http import urlencode
  14. from rest_framework.compat import coreapi, requests
  15. from rest_framework.settings import api_settings
  16. def force_authenticate(request, user=None, token=None):
  17. request._force_auth_user = user
  18. request._force_auth_token = token
  19. if requests is not None:
  20. class HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict):
  21. def get_all(self, key, default):
  22. return self.getheaders(key)
  23. class MockOriginalResponse:
  24. def __init__(self, headers):
  25. self.msg = HeaderDict(headers)
  26. self.closed = False
  27. def isclosed(self):
  28. return self.closed
  29. def close(self):
  30. self.closed = True
  31. class DjangoTestAdapter(requests.adapters.HTTPAdapter):
  32. """
  33. A transport adapter for `requests`, that makes requests via the
  34. Django WSGI app, rather than making actual HTTP requests over the network.
  35. """
  36. def __init__(self):
  37. self.app = WSGIHandler()
  38. self.factory = DjangoRequestFactory()
  39. def get_environ(self, request):
  40. """
  41. Given a `requests.PreparedRequest` instance, return a WSGI environ dict.
  42. """
  43. method = request.method
  44. url = request.url
  45. kwargs = {}
  46. # Set request content, if any exists.
  47. if request.body is not None:
  48. if hasattr(request.body, 'read'):
  49. kwargs['data'] = request.body.read()
  50. else:
  51. kwargs['data'] = request.body
  52. if 'content-type' in request.headers:
  53. kwargs['content_type'] = request.headers['content-type']
  54. # Set request headers.
  55. for key, value in request.headers.items():
  56. key = key.upper()
  57. if key in ('CONNECTION', 'CONTENT-LENGTH', 'CONTENT-TYPE'):
  58. continue
  59. kwargs['HTTP_%s' % key.replace('-', '_')] = value
  60. return self.factory.generic(method, url, **kwargs).environ
  61. def send(self, request, *args, **kwargs):
  62. """
  63. Make an outgoing request to the Django WSGI application.
  64. """
  65. raw_kwargs = {}
  66. def start_response(wsgi_status, wsgi_headers):
  67. status, _, reason = wsgi_status.partition(' ')
  68. raw_kwargs['status'] = int(status)
  69. raw_kwargs['reason'] = reason
  70. raw_kwargs['headers'] = wsgi_headers
  71. raw_kwargs['version'] = 11
  72. raw_kwargs['preload_content'] = False
  73. raw_kwargs['original_response'] = MockOriginalResponse(wsgi_headers)
  74. # Make the outgoing request via WSGI.
  75. environ = self.get_environ(request)
  76. wsgi_response = self.app(environ, start_response)
  77. # Build the underlying urllib3.HTTPResponse
  78. raw_kwargs['body'] = io.BytesIO(b''.join(wsgi_response))
  79. raw = requests.packages.urllib3.HTTPResponse(**raw_kwargs)
  80. # Build the requests.Response
  81. return self.build_response(request, raw)
  82. def close(self):
  83. pass
  84. class RequestsClient(requests.Session):
  85. def __init__(self, *args, **kwargs):
  86. super().__init__(*args, **kwargs)
  87. adapter = DjangoTestAdapter()
  88. self.mount('http://', adapter)
  89. self.mount('https://', adapter)
  90. def request(self, method, url, *args, **kwargs):
  91. if not url.startswith('http'):
  92. raise ValueError('Missing "http:" or "https:". Use a fully qualified URL, eg "http://testserver%s"' % url)
  93. return super().request(method, url, *args, **kwargs)
  94. else:
  95. def RequestsClient(*args, **kwargs):
  96. raise ImproperlyConfigured('requests must be installed in order to use RequestsClient.')
  97. if coreapi is not None:
  98. class CoreAPIClient(coreapi.Client):
  99. def __init__(self, *args, **kwargs):
  100. self._session = RequestsClient()
  101. kwargs['transports'] = [coreapi.transports.HTTPTransport(session=self.session)]
  102. return super().__init__(*args, **kwargs)
  103. @property
  104. def session(self):
  105. return self._session
  106. else:
  107. def CoreAPIClient(*args, **kwargs):
  108. raise ImproperlyConfigured('coreapi must be installed in order to use CoreAPIClient.')
  109. class APIRequestFactory(DjangoRequestFactory):
  110. renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES
  111. default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT
  112. def __init__(self, enforce_csrf_checks=False, **defaults):
  113. self.enforce_csrf_checks = enforce_csrf_checks
  114. self.renderer_classes = {}
  115. for cls in self.renderer_classes_list:
  116. self.renderer_classes[cls.format] = cls
  117. super().__init__(**defaults)
  118. def _encode_data(self, data, format=None, content_type=None):
  119. """
  120. Encode the data returning a two tuple of (bytes, content_type)
  121. """
  122. if data is None:
  123. return ('', content_type)
  124. assert format is None or content_type is None, (
  125. 'You may not set both `format` and `content_type`.'
  126. )
  127. if content_type:
  128. # Content type specified explicitly, treat data as a raw bytestring
  129. ret = force_bytes(data, settings.DEFAULT_CHARSET)
  130. else:
  131. format = format or self.default_format
  132. assert format in self.renderer_classes, (
  133. "Invalid format '{}'. Available formats are {}. "
  134. "Set TEST_REQUEST_RENDERER_CLASSES to enable "
  135. "extra request formats.".format(
  136. format,
  137. ', '.join(["'" + fmt + "'" for fmt in self.renderer_classes])
  138. )
  139. )
  140. # Use format and render the data into a bytestring
  141. renderer = self.renderer_classes[format]()
  142. ret = renderer.render(data)
  143. # Determine the content-type header from the renderer
  144. content_type = "{}; charset={}".format(
  145. renderer.media_type, renderer.charset
  146. )
  147. # Coerce text to bytes if required.
  148. if isinstance(ret, str):
  149. ret = ret.encode(renderer.charset)
  150. return ret, content_type
  151. def get(self, path, data=None, **extra):
  152. r = {
  153. 'QUERY_STRING': urlencode(data or {}, doseq=True),
  154. }
  155. if not data and '?' in path:
  156. # Fix to support old behavior where you have the arguments in the
  157. # url. See #1461.
  158. query_string = force_bytes(path.split('?')[1])
  159. query_string = query_string.decode('iso-8859-1')
  160. r['QUERY_STRING'] = query_string
  161. r.update(extra)
  162. return self.generic('GET', path, **r)
  163. def post(self, path, data=None, format=None, content_type=None, **extra):
  164. data, content_type = self._encode_data(data, format, content_type)
  165. return self.generic('POST', path, data, content_type, **extra)
  166. def put(self, path, data=None, format=None, content_type=None, **extra):
  167. data, content_type = self._encode_data(data, format, content_type)
  168. return self.generic('PUT', path, data, content_type, **extra)
  169. def patch(self, path, data=None, format=None, content_type=None, **extra):
  170. data, content_type = self._encode_data(data, format, content_type)
  171. return self.generic('PATCH', path, data, content_type, **extra)
  172. def delete(self, path, data=None, format=None, content_type=None, **extra):
  173. data, content_type = self._encode_data(data, format, content_type)
  174. return self.generic('DELETE', path, data, content_type, **extra)
  175. def options(self, path, data=None, format=None, content_type=None, **extra):
  176. data, content_type = self._encode_data(data, format, content_type)
  177. return self.generic('OPTIONS', path, data, content_type, **extra)
  178. def generic(self, method, path, data='',
  179. content_type='application/octet-stream', secure=False, **extra):
  180. # Include the CONTENT_TYPE, regardless of whether or not data is empty.
  181. if content_type is not None:
  182. extra['CONTENT_TYPE'] = str(content_type)
  183. return super().generic(
  184. method, path, data, content_type, secure, **extra)
  185. def request(self, **kwargs):
  186. request = super().request(**kwargs)
  187. request._dont_enforce_csrf_checks = not self.enforce_csrf_checks
  188. return request
  189. class ForceAuthClientHandler(ClientHandler):
  190. """
  191. A patched version of ClientHandler that can enforce authentication
  192. on the outgoing requests.
  193. """
  194. def __init__(self, *args, **kwargs):
  195. self._force_user = None
  196. self._force_token = None
  197. super().__init__(*args, **kwargs)
  198. def get_response(self, request):
  199. # This is the simplest place we can hook into to patch the
  200. # request object.
  201. force_authenticate(request, self._force_user, self._force_token)
  202. return super().get_response(request)
  203. class APIClient(APIRequestFactory, DjangoClient):
  204. def __init__(self, enforce_csrf_checks=False, **defaults):
  205. super().__init__(**defaults)
  206. self.handler = ForceAuthClientHandler(enforce_csrf_checks)
  207. self._credentials = {}
  208. def credentials(self, **kwargs):
  209. """
  210. Sets headers that will be used on every outgoing request.
  211. """
  212. self._credentials = kwargs
  213. def force_authenticate(self, user=None, token=None):
  214. """
  215. Forcibly authenticates outgoing requests with the given
  216. user and/or token.
  217. """
  218. self.handler._force_user = user
  219. self.handler._force_token = token
  220. if user is None:
  221. self.logout() # Also clear any possible session info if required
  222. def request(self, **kwargs):
  223. # Ensure that any credentials set get added to every request.
  224. kwargs.update(self._credentials)
  225. return super().request(**kwargs)
  226. def get(self, path, data=None, follow=False, **extra):
  227. response = super().get(path, data=data, **extra)
  228. if follow:
  229. response = self._handle_redirects(response, **extra)
  230. return response
  231. def post(self, path, data=None, format=None, content_type=None,
  232. follow=False, **extra):
  233. response = super().post(
  234. path, data=data, format=format, content_type=content_type, **extra)
  235. if follow:
  236. response = self._handle_redirects(response, **extra)
  237. return response
  238. def put(self, path, data=None, format=None, content_type=None,
  239. follow=False, **extra):
  240. response = super().put(
  241. path, data=data, format=format, content_type=content_type, **extra)
  242. if follow:
  243. response = self._handle_redirects(response, **extra)
  244. return response
  245. def patch(self, path, data=None, format=None, content_type=None,
  246. follow=False, **extra):
  247. response = super().patch(
  248. path, data=data, format=format, content_type=content_type, **extra)
  249. if follow:
  250. response = self._handle_redirects(response, **extra)
  251. return response
  252. def delete(self, path, data=None, format=None, content_type=None,
  253. follow=False, **extra):
  254. response = super().delete(
  255. path, data=data, format=format, content_type=content_type, **extra)
  256. if follow:
  257. response = self._handle_redirects(response, **extra)
  258. return response
  259. def options(self, path, data=None, format=None, content_type=None,
  260. follow=False, **extra):
  261. response = super().options(
  262. path, data=data, format=format, content_type=content_type, **extra)
  263. if follow:
  264. response = self._handle_redirects(response, **extra)
  265. return response
  266. def logout(self):
  267. self._credentials = {}
  268. # Also clear any `force_authenticate`
  269. self.handler._force_user = None
  270. self.handler._force_token = None
  271. if self.session:
  272. super().logout()
  273. class APITransactionTestCase(testcases.TransactionTestCase):
  274. client_class = APIClient
  275. class APITestCase(testcases.TestCase):
  276. client_class = APIClient
  277. class APISimpleTestCase(testcases.SimpleTestCase):
  278. client_class = APIClient
  279. class APILiveServerTestCase(testcases.LiveServerTestCase):
  280. client_class = APIClient
  281. class URLPatternsTestCase(testcases.SimpleTestCase):
  282. """
  283. Isolate URL patterns on a per-TestCase basis. For example,
  284. class ATestCase(URLPatternsTestCase):
  285. urlpatterns = [...]
  286. def test_something(self):
  287. ...
  288. class AnotherTestCase(URLPatternsTestCase):
  289. urlpatterns = [...]
  290. def test_something_else(self):
  291. ...
  292. """
  293. @classmethod
  294. def setUpClass(cls):
  295. # Get the module of the TestCase subclass
  296. cls._module = import_module(cls.__module__)
  297. cls._override = override_settings(ROOT_URLCONF=cls.__module__)
  298. if hasattr(cls._module, 'urlpatterns'):
  299. cls._module_urlpatterns = cls._module.urlpatterns
  300. cls._module.urlpatterns = cls.urlpatterns
  301. cls._override.enable()
  302. super().setUpClass()
  303. @classmethod
  304. def tearDownClass(cls):
  305. super().tearDownClass()
  306. cls._override.disable()
  307. if hasattr(cls, '_module_urlpatterns'):
  308. cls._module.urlpatterns = cls._module_urlpatterns
  309. else:
  310. del cls._module.urlpatterns