Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
wiseplat
GitHub Repository: wiseplat/python-code
Path: blob/master/ invest-robot-contest_TinkoffBotTwitch-main/venv/lib/python3.8/site-packages/aiohttp/test_utils.py
7757 views
1
"""Utilities shared by tests."""
2
3
import asyncio
4
import contextlib
5
import gc
6
import inspect
7
import ipaddress
8
import os
9
import socket
10
import sys
11
import warnings
12
from abc import ABC, abstractmethod
13
from types import TracebackType
14
from typing import (
15
TYPE_CHECKING,
16
Any,
17
Callable,
18
Iterator,
19
List,
20
Optional,
21
Type,
22
Union,
23
cast,
24
)
25
from unittest import mock
26
27
from aiosignal import Signal
28
from multidict import CIMultiDict, CIMultiDictProxy
29
from yarl import URL
30
31
import aiohttp
32
from aiohttp.client import _RequestContextManager, _WSRequestContextManager
33
34
from . import ClientSession, hdrs
35
from .abc import AbstractCookieJar
36
from .client_reqrep import ClientResponse
37
from .client_ws import ClientWebSocketResponse
38
from .helpers import PY_38, sentinel
39
from .http import HttpVersion, RawRequestMessage
40
from .web import (
41
Application,
42
AppRunner,
43
BaseRunner,
44
Request,
45
Server,
46
ServerRunner,
47
SockSite,
48
UrlMappingMatchInfo,
49
)
50
from .web_protocol import _RequestHandler
51
52
if TYPE_CHECKING: # pragma: no cover
53
from ssl import SSLContext
54
else:
55
SSLContext = None
56
57
if PY_38:
58
from unittest import IsolatedAsyncioTestCase as TestCase
59
else:
60
from asynctest import TestCase # type: ignore[no-redef]
61
62
# SO_REUSEADDR is only enabled on POSIX systems other than Cygwin; see the
# note inside get_port_socket() about Windows semantics.
REUSE_ADDRESS = os.name == "posix" and sys.platform != "cygwin"


def get_unused_port_socket(
    host: str, family: socket.AddressFamily = socket.AF_INET
) -> socket.socket:
    """Return a TCP socket bound to ``host`` on a port chosen by the OS."""
    return get_port_socket(host, 0, family)


def get_port_socket(
    host: str, port: int, family: socket.AddressFamily
) -> socket.socket:
    """Create a TCP socket of ``family`` bound to ``(host, port)``.

    ``port=0`` lets the OS pick a free port. The returned socket is bound
    but not listening; the caller owns it and must close it.

    If configuring or binding the socket fails, the socket is closed before
    the error propagates, so no file descriptor is leaked.
    """
    s = socket.socket(family, socket.SOCK_STREAM)
    try:
        if REUSE_ADDRESS:
            # Windows has different semantics for SO_REUSEADDR,
            # so don't set it. Ref:
            # https://docs.microsoft.com/en-us/windows/win32/winsock/using-so-reuseaddr-and-so-exclusiveaddruse
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind((host, port))
    except BaseException:
        # Don't leak the descriptor when bind()/setsockopt() fails.
        s.close()
        raise
    return s
82
83
84
def unused_port() -> int:
    """Ask the OS for a free TCP port on 127.0.0.1 and return its number.

    The probing socket is closed before returning, so the port is only
    known to have been free at the moment of the call.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        probe.bind(("127.0.0.1", 0))
        _, port_number = probe.getsockname()
    return cast(int, port_number)
89
90
91
class BaseTestServer(ABC):
    """Base class for servers started on a real socket for functional tests.

    Subclasses implement :meth:`_make_runner`. :meth:`start_server` creates
    the runner, binds a socket via ``socket_factory``, starts a ``SockSite``
    on it and records the resulting ``host``, ``port`` and root URL.
    Extra keyword arguments passed to ``__init__`` are accepted and ignored.
    """

    __test__ = False  # keep pytest from collecting this class as a test

    def __init__(
        self,
        *,
        scheme: Union[str, object] = sentinel,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        host: str = "127.0.0.1",
        port: Optional[int] = None,
        skip_url_asserts: bool = False,
        socket_factory: Callable[
            [str, int, socket.AddressFamily], socket.socket
        ] = get_port_socket,
        **kwargs: Any,
    ) -> None:
        self._loop = loop
        self.runner = None  # type: Optional[BaseRunner]
        self._root = None  # type: Optional[URL]
        self.host = host
        self.port = port
        self._closed = False
        self.scheme = scheme
        self.skip_url_asserts = skip_url_asserts
        self.socket_factory = socket_factory
        # Filled in by start_server(); initialised here so the attribute
        # exists even before the server has been started.
        self._ssl = None  # type: Optional[SSLContext]

    async def start_server(
        self, loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any
    ) -> None:
        """Create the runner, bind a socket and start serving on it.

        Does nothing if a runner already exists. An ``ssl`` keyword argument
        is used as the site's SSL context (and selects the ``https`` scheme
        when no scheme was given); the remaining keyword arguments are
        forwarded to :meth:`_make_runner`.
        """
        if self.runner:
            return
        self._loop = loop
        self._ssl = kwargs.pop("ssl", None)
        self.runner = await self._make_runner(**kwargs)
        await self.runner.setup()
        if not self.port:
            # Port 0 lets the OS choose a free port.
            self.port = 0
        try:
            version = ipaddress.ip_address(self.host).version
        except ValueError:
            # Not an IP literal (e.g. a host name): treat it as IPv4.
            version = 4
        family = socket.AF_INET6 if version == 6 else socket.AF_INET
        _sock = self.socket_factory(self.host, self.port, family)
        self.host, self.port = _sock.getsockname()[:2]
        site = SockSite(self.runner, sock=_sock, ssl_context=self._ssl)
        await site.start()
        server = site._server
        assert server is not None
        sockets = server.sockets
        assert sockets is not None
        self.port = sockets[0].getsockname()[1]
        if self.scheme is sentinel:
            if self._ssl:
                scheme = "https"
            else:
                scheme = "http"
            self.scheme = scheme
        # IPv6 literals must be wrapped in brackets inside a URL authority,
        # otherwise "::1:8080" cannot be split into host and port.
        if version == 6:
            absolute_host = f"[{self.host}]"
        else:
            absolute_host = self.host
        self._root = URL(f"{self.scheme}://{absolute_host}:{self.port}")

    @abstractmethod  # pragma: no cover
    async def _make_runner(self, **kwargs: Any) -> BaseRunner:
        """Return the (not yet set up) runner that will serve requests."""
        pass

    def make_url(self, path: str) -> URL:
        """Build an absolute URL for ``path`` relative to the server root.

        The server must have been started. Unless ``skip_url_asserts`` is
        set, ``path`` must be relative and is joined to the root with
        ``URL.join``; otherwise it is appended to the root as a plain string.
        """
        assert self._root is not None
        url = URL(path)
        if not self.skip_url_asserts:
            assert not url.is_absolute()
            return self._root.join(url)
        else:
            return URL(str(self._root) + path)

    @property
    def started(self) -> bool:
        """True once start_server() has created a runner."""
        return self.runner is not None

    @property
    def closed(self) -> bool:
        """True once close() has cleaned up a started server."""
        return self._closed

    @property
    def handler(self) -> Server:
        """The runner's web.Server instance (kept for backward compatibility)."""
        # for backward compatibility
        # web.Server instance
        runner = self.runner
        assert runner is not None
        assert runner.server is not None
        return runner.server

    async def close(self) -> None:
        """Stop the test server and release its resources.

        After that point, the server is no longer usable.

        This is an idempotent function: running close multiple times
        will not have any additional effects. A server that was never
        started is left untouched.

        close is also run on exit when used as an asynchronous context
        manager.
        """
        if self.started and not self.closed:
            assert self.runner is not None
            await self.runner.cleanup()
            self._root = None
            self.port = None
            self._closed = True

    def __enter__(self) -> None:
        raise TypeError("Use async with instead")

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        # __exit__ should exist in pair with __enter__ but never executed
        pass  # pragma: no cover

    async def __aenter__(self) -> "BaseTestServer":
        await self.start_server(loop=self._loop)
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        await self.close()
222
223
224
class TestServer(BaseTestServer):
    """Test server that serves a full :class:`aiohttp.web.Application`.

    All keyword arguments other than ``scheme``/``host``/``port`` are passed
    on to :class:`BaseTestServer`.
    """

    def __init__(
        self,
        app: Application,
        *,
        scheme: Union[str, object] = sentinel,
        host: str = "127.0.0.1",
        port: Optional[int] = None,
        **kwargs: Any,
    ):
        # Stored for _make_runner(); TestClient.app also reads it.
        self.app = app
        base_options = dict(scheme=scheme, host=host, port=port)
        super().__init__(**base_options, **kwargs)

    async def _make_runner(self, **kwargs: Any) -> BaseRunner:
        """Wrap the application in an AppRunner configured with ``kwargs``."""
        app_runner = AppRunner(self.app, **kwargs)
        return app_runner
239
240
241
class RawTestServer(BaseTestServer):
    """Test server that hands every request to a bare request handler.

    No Application or router is involved; ``handler`` is wrapped in a
    low-level ``web.Server``.
    """

    def __init__(
        self,
        handler: _RequestHandler,
        *,
        scheme: Union[str, object] = sentinel,
        host: str = "127.0.0.1",
        port: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        self._handler = handler
        base_options = dict(scheme=scheme, host=host, port=port)
        super().__init__(**base_options, **kwargs)

    async def _make_runner(self, debug: bool = True, **kwargs: Any) -> ServerRunner:
        """Build a ServerRunner around a Server for the stored handler.

        ``debug`` and the extra keyword arguments are forwarded to both the
        Server and the ServerRunner.
        """
        low_level_server = Server(self._handler, loop=self._loop, debug=debug, **kwargs)
        return ServerRunner(low_level_server, debug=debug, **kwargs)
257
258
259
class TestClient:
    """A test client for functional tests of aiohttp-based servers.

    Wraps a :class:`BaseTestServer` together with a :class:`ClientSession`.
    The request helpers take a path relative to the server root (resolved
    by :meth:`make_url`). Responses and websockets opened through the
    helpers are remembered and closed by :meth:`close`.
    """

    __test__ = False  # keep pytest from collecting this class as a test

    def __init__(
        self,
        server: BaseTestServer,
        *,
        cookie_jar: Optional[AbstractCookieJar] = None,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        **kwargs: Any,
    ) -> None:
        """Create the client; extra ``kwargs`` go to :class:`ClientSession`.

        Raises:
            TypeError: if ``server`` is not a :class:`BaseTestServer`.
        """
        if not isinstance(server, BaseTestServer):
            raise TypeError(
                "server must be TestServer instance, found type: %r" % type(server)
            )
        self._server = server
        self._loop = loop
        if cookie_jar is None:
            # NOTE: unsafe=True is presumably required so cookies from the
            # IP-addressed test server (host defaults to 127.0.0.1) are kept;
            # see the aiohttp.CookieJar documentation.
            cookie_jar = aiohttp.CookieJar(unsafe=True, loop=loop)
        self._session = ClientSession(loop=loop, cookie_jar=cookie_jar, **kwargs)
        self._closed = False
        self._responses = []  # type: List[ClientResponse]
        self._websockets = []  # type: List[ClientWebSocketResponse]

    async def start_server(self) -> None:
        """Start the wrapped test server on this client's loop."""
        await self._server.start_server(loop=self._loop)

    @property
    def host(self) -> str:
        return self._server.host

    @property
    def port(self) -> Optional[int]:
        return self._server.port

    @property
    def server(self) -> BaseTestServer:
        return self._server

    @property
    def app(self) -> Optional[Application]:
        """The server's ``app`` attribute, or None if it has none."""
        return cast(Optional[Application], getattr(self._server, "app", None))

    @property
    def session(self) -> ClientSession:
        """An internal aiohttp.ClientSession.

        Unlike the methods on the TestClient, client session requests
        do not automatically include the host in the url queried, and
        will require an absolute path to the resource.
        """
        return self._session

    def make_url(self, path: str) -> URL:
        """Resolve ``path`` against the test server's root URL."""
        return self._server.make_url(path)

    async def _request(self, method: str, path: str, **kwargs: Any) -> ClientResponse:
        resp = await self._session.request(method, self.make_url(path), **kwargs)
        # save it to close later
        self._responses.append(resp)
        return resp

    def request(self, method: str, path: str, **kwargs: Any) -> _RequestContextManager:
        """Routes a request to tested http server.

        The interface is identical to aiohttp.ClientSession.request,
        except that ``path`` is resolved against the test server's root
        URL (see make_url()).
        """
        return _RequestContextManager(self._request(method, path, **kwargs))

    def get(self, path: str, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP GET request."""
        return _RequestContextManager(self._request(hdrs.METH_GET, path, **kwargs))

    def post(self, path: str, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP POST request."""
        return _RequestContextManager(self._request(hdrs.METH_POST, path, **kwargs))

    def options(self, path: str, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP OPTIONS request."""
        return _RequestContextManager(self._request(hdrs.METH_OPTIONS, path, **kwargs))

    def head(self, path: str, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP HEAD request."""
        return _RequestContextManager(self._request(hdrs.METH_HEAD, path, **kwargs))

    def put(self, path: str, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP PUT request."""
        return _RequestContextManager(self._request(hdrs.METH_PUT, path, **kwargs))

    def patch(self, path: str, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP PATCH request."""
        return _RequestContextManager(self._request(hdrs.METH_PATCH, path, **kwargs))

    def delete(self, path: str, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP DELETE request."""
        return _RequestContextManager(self._request(hdrs.METH_DELETE, path, **kwargs))

    def ws_connect(self, path: str, **kwargs: Any) -> _WSRequestContextManager:
        """Initiate websocket connection.

        The api corresponds to aiohttp.ClientSession.ws_connect, with
        ``path`` resolved against the test server's root URL.
        """
        return _WSRequestContextManager(self._ws_connect(path, **kwargs))

    async def _ws_connect(self, path: str, **kwargs: Any) -> ClientWebSocketResponse:
        ws = await self._session.ws_connect(self.make_url(path), **kwargs)
        # save it to close later
        self._websockets.append(ws)
        return ws

    async def close(self) -> None:
        """Close all fixtures created by the test client.

        Closes the recorded responses and websockets, then the client
        session, then the test server. After that point, the TestClient
        is no longer usable.

        This is an idempotent function: running close multiple times
        will not have any additional effects.

        close is also run on exit when used as an asynchronous
        context manager.
        """
        if not self._closed:
            for resp in self._responses:
                resp.close()
            for ws in self._websockets:
                await ws.close()
            await self._session.close()
            await self._server.close()
            self._closed = True

    def __enter__(self) -> None:
        raise TypeError("Use async with instead")

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        # __exit__ should exist in pair with __enter__ but never executed
        pass  # pragma: no cover

    async def __aenter__(self) -> "TestClient":
        await self.start_server()
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        await self.close()
424
425
426
class AioHTTPTestCase(TestCase):
    """A base class to allow for unittest web applications using aiohttp.

    Provides the following:

    * self.client (aiohttp.test_utils.TestClient): an aiohttp test client.
    * self.loop (asyncio.BaseEventLoop): the event loop in which the
      application and server are running.
    * self.app (aiohttp.web.Application): the application returned by
      self.get_application()
    * self.server (aiohttp.test_utils.TestServer): the server returned by
      self.get_server(), already started by setUp.

    Note that the TestClient's methods are asynchronous: you have to
    execute functions on the test client using asynchronous methods.
    """

    async def get_application(self) -> Application:
        """Get application.

        This method should be overridden
        to return the aiohttp.web.Application
        object to test.

        The default implementation falls back to the obsolete
        synchronous get_app().
        """
        return self.get_app()

    def get_app(self) -> Application:
        """Obsolete method used to construct the web application.

        Use .get_application() coroutine instead.

        Raises:
            RuntimeError: always, unless overridden by a subclass.
        """
        raise RuntimeError("Did you forget to define get_application()?")

    def setUp(self) -> None:
        """Pick the event loop and run setUpAsync() on it to completion."""
        try:
            # Use the running loop if setUp happens to be called inside one.
            self.loop = asyncio.get_running_loop()
        except (AttributeError, RuntimeError): # AttributeError->py36
            # No running loop (or no get_running_loop on py36): use the
            # current loop of the event loop policy.
            self.loop = asyncio.get_event_loop_policy().get_event_loop()

        self.loop.run_until_complete(self.setUpAsync())

    async def setUpAsync(self) -> None:
        """Build the app, server and client, then start the server."""
        self.app = await self.get_application()
        self.server = await self.get_server(self.app)
        self.client = await self.get_client(self.server)

        await self.client.start_server()

    def tearDown(self) -> None:
        """Run tearDownAsync() to completion on the test's loop."""
        self.loop.run_until_complete(self.tearDownAsync())

    async def tearDownAsync(self) -> None:
        """Close the client, which also closes its server."""
        await self.client.close()

    async def get_server(self, app: Application) -> TestServer:
        """Return a TestServer instance."""
        return TestServer(app, loop=self.loop)

    async def get_client(self, server: TestServer) -> TestClient:
        """Return a TestClient instance."""
        return TestClient(server, loop=self.loop)
485
486
487
def unittest_run_loop(func: Any, *args: Any, **kwargs: Any) -> Any:
    """Deprecated no-op decorator for async AioHTTPTestCase test methods.

    Since aiohttp 3.8 the decorator is unnecessary: it emits a
    DeprecationWarning (attributed to the caller) and hands ``func`` back
    unchanged. Any extra positional or keyword arguments are ignored.
    """
    notice = "Decorator `@unittest_run_loop` is no longer needed in aiohttp 3.8+"
    warnings.warn(notice, DeprecationWarning, stacklevel=2)
    return func
499
500
501
_LOOP_FACTORY = Callable[[], asyncio.AbstractEventLoop]


@contextlib.contextmanager
def loop_context(
    loop_factory: _LOOP_FACTORY = asyncio.new_event_loop, fast: bool = False
) -> Iterator[asyncio.AbstractEventLoop]:
    """A contextmanager that creates an event_loop, for test purposes.

    Handles the creation and cleanup of a test loop. The loop is torn down
    with teardown_test_loop() even when the body of the ``with`` block
    raises, so a failing test does not leak an open loop.
    """
    loop = setup_test_loop(loop_factory)
    try:
        yield loop
    finally:
        teardown_test_loop(loop, fast=fast)


def setup_test_loop(
    loop_factory: _LOOP_FACTORY = asyncio.new_event_loop,
) -> asyncio.AbstractEventLoop:
    """Create and return an asyncio.BaseEventLoop instance.

    The loop is also installed as the current event loop. On non-Windows
    platforms (and for non-uvloop loops) a child watcher is attached to it
    when the running Python still provides the asyncio child-watcher API.

    The caller should also call teardown_test_loop,
    once they are done with the loop.
    """
    loop = loop_factory()
    try:
        module = loop.__class__.__module__
        skip_watcher = "uvloop" in module
    except AttributeError:  # pragma: no cover
        # Just in case
        skip_watcher = True
    asyncio.set_event_loop(loop)
    if sys.platform != "win32" and not skip_watcher:
        # Refs:
        # * https://github.com/pytest-dev/pytest-xdist/issues/620
        # * https://stackoverflow.com/a/58614689/595220
        # * https://bugs.python.org/issue35621
        # * https://github.com/python/cpython/pull/14344
        watcher_cls = getattr(asyncio, "ThreadedChildWatcher", None)  # Python >= 3.8
        if watcher_cls is None:  # Python < 3.8
            watcher_cls = getattr(asyncio, "SafeChildWatcher", None)
        # Newer Pythons (3.14+) removed the child-watcher API entirely; there
        # is nothing to install there, so skip instead of crashing.
        if watcher_cls is not None:
            policy = asyncio.get_event_loop_policy()
            watcher = watcher_cls()
            watcher.attach_loop(loop)
            with contextlib.suppress(NotImplementedError):
                policy.set_child_watcher(watcher)
    return loop


def teardown_test_loop(loop: asyncio.AbstractEventLoop, fast: bool = False) -> None:
    """Teardown and cleanup an event_loop created by setup_test_loop.

    Runs one final iteration of a still-open loop, closes it, optionally
    forces a garbage collection (skipped when ``fast`` is true) and clears
    the current event loop.
    """
    closed = loop.is_closed()
    if not closed:
        # Let already-scheduled callbacks run once before closing.
        loop.call_soon(loop.stop)
        loop.run_forever()
        loop.close()

    if not fast:
        gc.collect()

    asyncio.set_event_loop(None)
563
564
565
def _create_app_mock() -> mock.MagicMock:
    """Return a MagicMock standing in for a web.Application.

    ``app[key]`` and ``app[key] = value`` are backed by a plain dict kept on
    the mock, ``_debug`` is False, and ``on_response_prepare`` is a freshly
    created, frozen Signal with no receivers.
    """
    fake_app = mock.MagicMock()
    fake_app.__app_dict = {}

    def _lookup(owner: Any, key: str) -> Any:
        return owner.__app_dict[key]

    def _store(owner: Any, key: str, value: Any) -> None:
        owner.__app_dict[key] = value

    fake_app.__getitem__ = _lookup
    fake_app.__setitem__ = _store

    fake_app._debug = False
    fake_app.on_response_prepare = Signal(fake_app)
    fake_app.on_response_prepare.freeze()
    return fake_app
581
582
583
def _create_transport(sslcontext: Optional[SSLContext] = None) -> mock.Mock:
584
transport = mock.Mock()
585
586
def get_extra_info(key: str) -> Optional[SSLContext]:
587
if key == "sslcontext":
588
return sslcontext
589
else:
590
return None
591
592
transport.get_extra_info.side_effect = get_extra_info
593
return transport
594
595
596
def make_mocked_request(
    method: str,
    path: str,
    headers: Any = None,
    *,
    match_info: Any = sentinel,
    version: HttpVersion = HttpVersion(1, 1),
    closing: bool = False,
    app: Any = None,
    writer: Any = sentinel,
    protocol: Any = sentinel,
    transport: Any = sentinel,
    payload: Any = sentinel,
    sslcontext: Optional[SSLContext] = None,
    client_max_size: int = 1024 ** 2,
    loop: Any = ...,
) -> Request:
    """Create a mocked web.Request for testing purposes.

    Useful in unit tests, when spinning up a full web server is overkill or
    specific conditions and errors are hard to trigger.

    Arguments left at their defaults are filled in with mocks:

    * ``loop``: a Mock whose ``create_future()`` returns ``()``;
    * ``app``: a MagicMock application with dict-backed item access;
    * ``transport``: a Mock whose ``get_extra_info("sslcontext")`` returns
      ``sslcontext``;
    * ``protocol`` and ``payload``: plain Mocks;
    * ``writer``: a Mock whose ``write_headers``/``write``/``write_eof``/
      ``drain`` are mocked coroutines returning None;
    * ``match_info``: an empty dict.

    Requests with a version older than HTTP/1.1 are always marked closing.
    """
    task = mock.Mock()
    if loop is ...:
        loop = mock.Mock()
        loop.create_future.return_value = ()

    if version < HttpVersion(1, 1):
        closing = True

    if headers:
        headers = CIMultiDictProxy(CIMultiDict(headers))
        raw_hdrs = tuple(
            (k.encode("utf-8"), v.encode("utf-8")) for k, v in headers.items()
        )
    else:
        headers = CIMultiDictProxy(CIMultiDict())
        raw_hdrs = ()

    chunked = "chunked" in headers.get(hdrs.TRANSFER_ENCODING, "").lower()

    message = RawRequestMessage(
        method,
        path,
        version,
        headers,
        raw_hdrs,
        closing,
        None,
        False,
        chunked,
        URL(path),
    )
    if app is None:
        app = _create_app_mock()

    if transport is sentinel:
        transport = _create_transport(sslcontext)

    if protocol is sentinel:
        protocol = mock.Mock()

    if writer is sentinel:
        writer = mock.Mock()
        writer.write_headers = make_mocked_coro(None)
        writer.write = make_mocked_coro(None)
        writer.write_eof = make_mocked_coro(None)
        writer.drain = make_mocked_coro(None)
        writer.transport = transport

    # Wire the (possibly caller-supplied) protocol to the transport and writer.
    protocol.transport = transport
    protocol.writer = writer

    if payload is sentinel:
        payload = mock.Mock()

    req = Request(
        message, payload, protocol, writer, task, loop, client_max_size=client_max_size
    )

    match_info = UrlMappingMatchInfo(
        {} if match_info is sentinel else match_info, mock.Mock()
    )
    match_info.add_app(app)
    req._match_info = match_info

    return req
684
685
686
def make_mocked_coro(
    return_value: Any = sentinel, raise_exception: Any = sentinel
) -> Any:
    """Creates a coroutine mock.

    The result is a ``mock.Mock`` wrapping a coroutine function, so calls
    are recorded like any other mock. Awaiting a call:

    * raises ``raise_exception`` if one was given;
    * otherwise returns ``return_value``, or, when ``return_value`` is
      itself awaitable, awaits it and returns its result.

    A coroutine object used as ``return_value`` can only be awaited once,
    so such a mock should only be awaited once.
    """

    async def mock_coro(*args: Any, **kwargs: Any) -> Any:
        if raise_exception is not sentinel:
            raise raise_exception
        if not inspect.isawaitable(return_value):
            return return_value
        # Propagate the awaited result instead of silently dropping it.
        return await return_value

    return mock.Mock(wraps=mock_coro)
699
700