Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
wiseplat
GitHub Repository: wiseplat/python-code
Path: blob/master/ invest-robot-contest_TinkoffBotTwitch-main/venv/lib/python3.8/site-packages/aiohttp/http_websocket.py
7767 views
1
"""WebSocket protocol versions 13 and 8."""
2
3
import asyncio
4
import collections
5
import json
6
import random
7
import re
8
import sys
9
import zlib
10
from enum import IntEnum
11
from struct import Struct
12
from typing import Any, Callable, List, Optional, Pattern, Set, Tuple, Union, cast
13
14
from .base_protocol import BaseProtocol
15
from .helpers import NO_EXTENSIONS
16
from .streams import DataQueue
17
from .typedefs import Final
18
19
# Names exported by a star-import of this module.
__all__ = (
    "WS_CLOSED_MESSAGE",
    "WS_CLOSING_MESSAGE",
    "WS_KEY",
    "WebSocketReader",
    "WebSocketWriter",
    "WSMessage",
    "WebSocketError",
    "WSMsgType",
    "WSCloseCode",
)
30
31
32
class WSCloseCode(IntEnum):
    """Status codes that may be carried in a WebSocket CLOSE frame.

    The integer values of these members form ``ALLOWED_CLOSE_CODES``, which
    the reader uses to validate close codes below 3000.
    """

    OK = 1000
    GOING_AWAY = 1001
    PROTOCOL_ERROR = 1002
    UNSUPPORTED_DATA = 1003
    ABNORMAL_CLOSURE = 1006
    INVALID_TEXT = 1007
    POLICY_VIOLATION = 1008
    MESSAGE_TOO_BIG = 1009
    MANDATORY_EXTENSION = 1010
    INTERNAL_ERROR = 1011
    SERVICE_RESTART = 1012
    TRY_AGAIN_LATER = 1013
    BAD_GATEWAY = 1014
46
47
48
# Plain-int set of every WSCloseCode value; a received close code below 3000
# must be a member of this set or the reader raises a protocol error.
ALLOWED_CLOSE_CODES: Final[Set[int]] = {int(i) for i in WSCloseCode}
49
50
51
class WSMsgType(IntEnum):
    """Message types: wire-level frame opcodes plus aiohttp-internal markers.

    Values up to 0xA are compared directly against the 4-bit opcode parsed
    from a frame header; values from 0x100 upward cannot appear in a frame
    header (the opcode field is masked with 0xF) and are used only for
    messages synthesized by aiohttp.
    """

    # websocket spec types
    CONTINUATION = 0x0
    TEXT = 0x1
    BINARY = 0x2
    PING = 0x9
    PONG = 0xA
    CLOSE = 0x8

    # aiohttp specific types
    CLOSING = 0x100
    CLOSED = 0x101
    ERROR = 0x102

    # Lowercase aliases of the members above (same values, so the enum treats
    # them as alternative names rather than new members).
    text = TEXT
    binary = BINARY
    ping = PING
    pong = PONG
    close = CLOSE
    closing = CLOSING
    closed = CLOSED
    error = ERROR
73
74
75
WS_KEY: Final[bytes] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
76
77
78
UNPACK_LEN2 = Struct("!H").unpack_from
79
UNPACK_LEN3 = Struct("!Q").unpack_from
80
UNPACK_CLOSE_CODE = Struct("!H").unpack
81
PACK_LEN1 = Struct("!BB").pack
82
PACK_LEN2 = Struct("!BBH").pack
83
PACK_LEN3 = Struct("!BBQ").pack
84
PACK_CLOSE_CODE = Struct("!H").pack
85
MSG_SIZE: Final[int] = 2 ** 14
86
DEFAULT_LIMIT: Final[int] = 2 ** 16
87
88
89
# Plain three-field record that WSMessage builds on.
_WSMessageBase = collections.namedtuple("_WSMessageBase", ["type", "data", "extra"])


class WSMessage(_WSMessageBase):
    """A WebSocket message as a ``(type, data, extra)`` named tuple."""

    def json(self, *, loads: Callable[[Any], Any] = json.loads) -> Any:
        """Decode ``self.data`` as JSON and return the result.

        :param loads: keyword-only deserializer applied to ``self.data``;
            defaults to :func:`json.loads`.
        :return: whatever ``loads`` returns.

        .. versionadded:: 0.22
        """
        raw = self.data
        return loads(raw)
99
100
101
WS_CLOSED_MESSAGE = WSMessage(WSMsgType.CLOSED, None, None)
102
WS_CLOSING_MESSAGE = WSMessage(WSMsgType.CLOSING, None, None)
103
104
105
class WebSocketError(Exception):
    """Raised when incoming WebSocket data violates the protocol.

    ``code`` holds the close code describing the failure; ``args`` is the
    pair ``(code, message)`` and ``str()`` yields only the message.
    """

    def __init__(self, code: int, message: str) -> None:
        """Store ``code`` and pass ``(code, message)`` on as the exception args."""
        self.code = code
        super().__init__(code, message)

    def __str__(self) -> str:
        """Return the human-readable message (second element of ``args``)."""
        text = self.args[1]
        return cast(str, text)
114
115
116
class WSHandshakeError(Exception):
    """WebSocket protocol handshake error.

    Raised in this module by ``ws_ext_parse`` when the client side cannot
    accept the offered ``permessage-deflate`` parameters.
    """
118
119
120
# Byte order of the host as reported by ``sys.byteorder``.
# NOTE(review): not referenced elsewhere in this module -- presumably kept for
# external users; confirm before removing.
native_byteorder: Final[str] = sys.byteorder
121
122
123
# Used by _websocket_mask_python
124
_XOR_TABLE: Final[List[bytes]] = [bytes(a ^ b for a in range(256)) for b in range(256)]
125
126
127
def _websocket_mask_python(mask: bytes, data: bytearray) -> None:
128
"""Websocket masking function.
129
130
`mask` is a `bytes` object of length 4; `data` is a `bytearray`
131
object of any length. The contents of `data` are masked with `mask`,
132
as specified in section 5.3 of RFC 6455.
133
134
Note that this function mutates the `data` argument.
135
136
This pure-python implementation may be replaced by an optimized
137
version when available.
138
139
"""
140
assert isinstance(data, bytearray), data
141
assert len(mask) == 4, mask
142
143
if data:
144
a, b, c, d = (_XOR_TABLE[n] for n in mask)
145
data[::4] = data[::4].translate(a)
146
data[1::4] = data[1::4].translate(b)
147
data[2::4] = data[2::4].translate(c)
148
data[3::4] = data[3::4].translate(d)
149
150
151
# Choose the masking implementation bound to ``_websocket_mask``: the
# pure-Python function when extensions are disabled or the compiled module
# cannot be imported, otherwise the function from ``._websocket``.
if NO_EXTENSIONS:  # pragma: no cover
    _websocket_mask = _websocket_mask_python
else:
    try:
        from ._websocket import _websocket_mask_cython  # type: ignore[import]

        _websocket_mask = _websocket_mask_cython
    except ImportError:  # pragma: no cover
        _websocket_mask = _websocket_mask_python
160
161
# Four-byte tail that the writer strips from a compressed message after
# flushing and that the reader appends again before decompressing.
_WS_DEFLATE_TRAILING: Final[bytes] = bytes([0x00, 0x00, 0xFF, 0xFF])


# Matches the ";"-separated parameter list following "permessage-deflate".
# Capture groups:
#   1 - server_no_context_takeover
#   2 - client_no_context_takeover
#   3 - server_max_window_bits[=N]    4 - N
#   5 - client_max_window_bits[=N]    6 - N
_WS_EXT_RE: Final[Pattern[str]] = re.compile(
    r"^(?:;\s*(?:"
    r"(server_no_context_takeover)|"
    r"(client_no_context_takeover)|"
    r"(server_max_window_bits(?:=(\d+))?)|"
    r"(client_max_window_bits(?:=(\d+))?)))*$"
)

# Finds each "permessage-deflate" offer in an extensions header value; group 1
# is that offer's parameter text up to the next comma (None when absent).
_WS_EXT_RE_SPLIT: Final[Pattern[str]] = re.compile(r"permessage-deflate([^,]+)?")
173
174
175
def ws_ext_parse(extstr: Optional[str], isserver: bool = False) -> Tuple[int, bool]:
    """Extract ``permessage-deflate`` settings from an extensions header value.

    :param extstr: raw header value; ``None`` or empty means "no extensions".
    :param isserver: ``True`` when parsing a client's offer on the server
        side, ``False`` when parsing a server's answer on the client side.
    :return: ``(compress, notakeover)`` where ``compress`` is the window size
        in bits (``0`` when compression is not enabled) and ``notakeover``
        tells whether the relevant ``*_no_context_takeover`` flag was present.
    :raises WSHandshakeError: client side only -- the parameters are not
        understood or the window size is outside 9..15.
    """
    if not extstr:
        return 0, False

    compress = 0
    notakeover = False

    for offer in _WS_EXT_RE_SPLIT.finditer(extstr):
        params = offer.group(1)

        # A bare ``permessage-deflate`` means the default 15-bit window.
        if not params:
            compress = 15
            break

        parsed = _WS_EXT_RE.match(params)
        if parsed is None:
            # The client must reject parameters it does not understand; the
            # server simply moves on to the next offer.
            if not isserver:
                raise WSHandshakeError(
                    "Extension for deflate not supported" + offer.group(1)
                )
            continue

        compress = 15
        if isserver:
            # The server never fails the handshake over compression and does
            # not need to send max window bits back to the client.
            requested = parsed.group(4)
            if requested:
                compress = int(requested)
                # zlib cannot handle wbits=8; an unsupported size makes the
                # server skip this offer and try the next one.
                if not 9 <= compress <= 15:
                    compress = 0
                    continue
            if parsed.group(1):
                notakeover = True
            # client_max_window_bits (groups 5 and 6) is ignored here.
            break

        requested = parsed.group(6)
        if requested:
            compress = int(requested)
            # zlib cannot handle wbits=8; on the client side an unsupported
            # size aborts the handshake.
            if not 9 <= compress <= 15:
                raise WSHandshakeError("Invalid window size")
        if parsed.group(2):
            notakeover = True
        break

    return compress, notakeover
224
225
226
def ws_ext_gen(
    compress: int = 15, isserver: bool = False, server_notakeover: bool = False
) -> str:
    """Build a ``permessage-deflate`` extension header value.

    :param compress: window size in bits, 9..15 inclusive.
    :param isserver: when false, ``client_max_window_bits`` is included.
    :param server_notakeover: include ``server_no_context_takeover``.
    :return: the parameters joined with ``"; "``.
    :raises ValueError: ``compress`` is outside 9..15 (zlib has no wbits=8).

    ``client_no_context_takeover`` is never emitted by this function.
    """
    if compress > 15 or compress < 9:
        raise ValueError(
            "Compress wbits must between 9 and 15, " "zlib does not support wbits=8"
        )

    parts = ["permessage-deflate"]
    if not isserver:
        parts.append("client_max_window_bits")
    # The default window (15) is implied and therefore not spelled out.
    if compress < 15:
        parts.append("server_max_window_bits=" + str(compress))
    if server_notakeover:
        parts.append("server_no_context_takeover")

    return "; ".join(parts)
246
247
248
class WSParserState(IntEnum):
    """Stages of the frame parser used by ``WebSocketReader.parse_frame``.

    A frame is consumed in this order; the mask stage is entered only when
    the frame header has its mask bit set.
    """

    READ_HEADER = 1
    READ_PAYLOAD_LENGTH = 2
    READ_PAYLOAD_MASK = 3
    READ_PAYLOAD = 4
253
254
255
class WebSocketReader:
    """Incremental WebSocket parser that feeds ``WSMessage`` objects to a queue.

    Raw bytes are pushed in through :meth:`feed_data`.  Complete frames are
    reassembled into messages (joining fragments and inflating compressed
    payloads) and handed to ``queue``.  The first parsing error is stored,
    forwarded to the queue, and stops all further parsing.
    """

    def __init__(
        self, queue: DataQueue[WSMessage], max_msg_size: int, compress: bool = True
    ) -> None:
        """Set up an idle parser.

        :param queue: receives parsed messages, EOF and exceptions.
        :param max_msg_size: size limit for a message; a falsy value disables
            the check.
        :param compress: whether frames with the RSV1 bit set are accepted.
        """
        self.queue = queue
        self._max_msg_size = max_msg_size
        self._compress = compress

        # First error hit while parsing; once set, input is no longer parsed.
        self._exc = None  # type: Optional[BaseException]

        # Message-level state, carried across the fragments of one message.
        self._partial = bytearray()
        self._opcode = None  # type: Optional[int]
        self._compressed = None  # type: Optional[bool]
        self._decompressobj = None  # type: Any  # zlib.decompressobj actually

        # Frame-level state, carried across feed_data() calls.
        self._state = WSParserState.READ_HEADER
        self._frame_fin = False
        self._frame_opcode = None  # type: Optional[int]
        self._frame_payload = bytearray()
        self._has_mask = False
        self._frame_mask = None  # type: Optional[bytes]
        self._payload_length = 0
        self._payload_length_flag = 0

        # Unparsed bytes left over from the previous parse_frame() call.
        self._tail = b""

    def feed_eof(self) -> None:
        """Forward end-of-stream to the queue."""
        self.queue.feed_eof()

    def feed_data(self, data: bytes) -> Tuple[bool, bytes]:
        """Parse ``data``, turning any parsing error into a queue exception.

        :return: ``(False, b"")`` on success.  After an earlier failure the
            input is handed back untouched as ``(True, data)``; when this
            call itself fails the result is ``(True, b"")``.
        """
        if self._exc:
            return True, data

        try:
            return self._feed_data(data)
        except Exception as exc:
            self._exc = exc
            self.queue.set_exception(exc)
            return True, b""

    def _close_message(
        self, fin: bool, opcode: Optional[int], payload: bytearray
    ) -> WSMessage:
        """Validate a CLOSE frame payload and build the matching message."""
        if len(payload) >= 2:
            close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
            if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES:
                raise WebSocketError(
                    WSCloseCode.PROTOCOL_ERROR,
                    f"Invalid close code: {close_code}",
                )
            try:
                close_message = payload[2:].decode("utf-8")
            except UnicodeDecodeError as exc:
                raise WebSocketError(
                    WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
                ) from exc
            return WSMessage(WSMsgType.CLOSE, close_code, close_message)

        # A single payload byte cannot hold a close code.
        if payload:
            raise WebSocketError(
                WSCloseCode.PROTOCOL_ERROR,
                f"Invalid close frame: {fin} {opcode} {payload!r}",
            )

        return WSMessage(WSMsgType.CLOSE, 0, "")

    def _check_partial_size(self) -> None:
        """Raise MESSAGE_TOO_BIG once the buffered message reaches the limit."""
        if self._max_msg_size and len(self._partial) >= self._max_msg_size:
            raise WebSocketError(
                WSCloseCode.MESSAGE_TOO_BIG,
                "Message size {} exceeds limit {}".format(
                    len(self._partial), self._max_msg_size
                ),
            )

    def _feed_data(self, data: bytes) -> Tuple[bool, bytes]:
        """Parse ``data`` and queue every message it completes.

        :raises WebSocketError: on protocol violations, oversized messages or
            invalid UTF-8 text.
        """
        for fin, opcode, payload, compressed in self.parse_frame(data):
            # The inflater is created lazily on the first compressed frame.
            if compressed and not self._decompressobj:
                self._decompressobj = zlib.decompressobj(wbits=-zlib.MAX_WBITS)

            if opcode == WSMsgType.CLOSE:
                self.queue.feed_data(self._close_message(fin, opcode, payload), 0)
                continue

            if opcode == WSMsgType.PING:
                self.queue.feed_data(
                    WSMessage(WSMsgType.PING, payload, ""), len(payload)
                )
                continue

            if opcode == WSMsgType.PONG:
                self.queue.feed_data(
                    WSMessage(WSMsgType.PONG, payload, ""), len(payload)
                )
                continue

            # Only TEXT/BINARY may start a message; anything else is valid
            # only while a fragmented message is in progress.
            if self._opcode is None and opcode not in (
                WSMsgType.TEXT,
                WSMsgType.BINARY,
            ):
                raise WebSocketError(
                    WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}"
                )

            if not fin:
                # Non-final fragment: remember the message type, buffer the
                # payload and wait for more frames.
                if opcode != WSMsgType.CONTINUATION:
                    self._opcode = opcode
                self._partial.extend(payload)
                self._check_partial_size()
                continue

            # Final frame.  If earlier fragments are buffered, this frame
            # must be a continuation of them.
            if self._partial and opcode != WSMsgType.CONTINUATION:
                raise WebSocketError(
                    WSCloseCode.PROTOCOL_ERROR,
                    "The opcode in non-fin frame is expected "
                    "to be zero, got {!r}".format(opcode),
                )

            if opcode == WSMsgType.CONTINUATION:
                # Restore the type recorded by the first fragment.
                assert self._opcode is not None
                opcode = self._opcode
                self._opcode = None

            self._partial.extend(payload)
            self._check_partial_size()

            # Decompression has to wait until the whole message is buffered.
            if compressed:
                self._partial.extend(_WS_DEFLATE_TRAILING)
                payload_merged = self._decompressobj.decompress(
                    self._partial, self._max_msg_size
                )
                leftover = self._decompressobj.unconsumed_tail
                if leftover:
                    # Input remains after producing max_msg_size bytes.
                    raise WebSocketError(
                        WSCloseCode.MESSAGE_TOO_BIG,
                        "Decompressed message size {} exceeds limit {}".format(
                            self._max_msg_size + len(leftover), self._max_msg_size
                        ),
                    )
            else:
                payload_merged = bytes(self._partial)

            self._partial.clear()

            if opcode == WSMsgType.TEXT:
                try:
                    text = payload_merged.decode("utf-8")
                    self.queue.feed_data(
                        WSMessage(WSMsgType.TEXT, text, ""), len(text)
                    )
                except UnicodeDecodeError as exc:
                    raise WebSocketError(
                        WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
                    ) from exc
            else:
                self.queue.feed_data(
                    WSMessage(WSMsgType.BINARY, payload_merged, ""),
                    len(payload_merged),
                )

        return False, b""

    def parse_frame(
        self, buf: bytes
    ) -> List[Tuple[bool, Optional[int], bytearray, Optional[bool]]]:
        """Consume ``buf`` and return every frame it completes.

        Parsing resumes from wherever the previous call stopped: the state
        machine position, a partially read payload and any unparsed tail
        bytes are all kept on the instance.

        :return: list of ``(fin, opcode, payload, compressed)`` tuples, one
            per completed frame, with payloads already unmasked.
        :raises WebSocketError: the frame header violates the protocol.
        """
        frames = []
        if self._tail:
            buf, self._tail = self._tail + buf, b""

        pos = 0
        end = len(buf)

        while True:
            # --- frame header: two bytes -----------------------------------
            if self._state == WSParserState.READ_HEADER:
                if end - pos < 2:
                    break
                first_byte, second_byte = buf[pos : pos + 2]
                pos += 2

                fin = bool(first_byte & 0x80)
                rsv1 = bool(first_byte & 0x40)
                rsv2 = bool(first_byte & 0x20)
                rsv3 = bool(first_byte & 0x10)
                opcode = first_byte & 0xF
                is_control = opcode > 0x7

                # Reserved bits MUST be 0 unless negotiated otherwise; RSV1
                # is tolerated only when compression is enabled.
                if rsv2 or rsv3 or (rsv1 and not self._compress):
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        "Received frame with non-zero reserved bits",
                    )

                if is_control and not fin:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        "Received fragmented control frame",
                    )

                has_mask = bool(second_byte & 0x80)
                length_flag = second_byte & 0x7F

                # Control frames MUST have a payload of 125 bytes or less.
                if is_control and length_flag > 125:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        "Control frame payload cannot be " "larger than 125 bytes",
                    )

                # The compression flag is taken from this frame when the
                # previous frame was final or no flag has been recorded yet.
                # Otherwise RSV1 must not be set.
                if self._frame_fin or self._compressed is None:
                    self._compressed = rsv1
                elif rsv1:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        "Received frame with non-zero reserved bits",
                    )

                self._frame_fin = fin
                self._frame_opcode = opcode
                self._has_mask = has_mask
                self._payload_length_flag = length_flag
                self._state = WSParserState.READ_PAYLOAD_LENGTH

            # --- payload length: inline, 16-bit or 64-bit --------------------
            if self._state == WSParserState.READ_PAYLOAD_LENGTH:
                length_flag = self._payload_length_flag
                if length_flag == 126:
                    if end - pos < 2:
                        break
                    self._payload_length = UNPACK_LEN2(buf[pos : pos + 2])[0]
                    pos += 2
                elif length_flag > 126:
                    if end - pos < 8:
                        break
                    self._payload_length = UNPACK_LEN3(buf[pos : pos + 8])[0]
                    pos += 8
                else:
                    self._payload_length = length_flag

                if self._has_mask:
                    self._state = WSParserState.READ_PAYLOAD_MASK
                else:
                    self._state = WSParserState.READ_PAYLOAD

            # --- masking key: four bytes ---------------------------------------
            if self._state == WSParserState.READ_PAYLOAD_MASK:
                if end - pos < 4:
                    break
                self._frame_mask = buf[pos : pos + 4]
                pos += 4
                self._state = WSParserState.READ_PAYLOAD

            # --- payload -------------------------------------------------------
            if self._state == WSParserState.READ_PAYLOAD:
                payload = self._frame_payload
                wanted = self._payload_length

                # Take whatever is available, up to the bytes still missing.
                take = min(wanted, end - pos)
                payload.extend(buf[pos : pos + take])
                pos += take
                self._payload_length = wanted - take

                if self._payload_length:
                    # Payload incomplete; resume on the next call.
                    break

                if self._has_mask:
                    assert self._frame_mask is not None
                    _websocket_mask(self._frame_mask, payload)

                frames.append(
                    (self._frame_fin, self._frame_opcode, payload, self._compressed)
                )

                self._frame_payload = bytearray()
                self._state = WSParserState.READ_HEADER

        # Keep unparsed bytes for the next call.
        self._tail = buf[pos:]

        return frames
571
572
573
class WebSocketWriter:
    """Serialize WebSocket frames and write them to a transport."""

    def __init__(
        self,
        protocol: BaseProtocol,
        transport: asyncio.Transport,
        *,
        use_mask: bool = False,
        limit: int = DEFAULT_LIMIT,
        random: Any = random.Random(),
        compress: int = 0,
        notakeover: bool = False,
    ) -> None:
        """Create a writer bound to ``transport``.

        :param protocol: its ``_drain_helper()`` is awaited once more than
            ``limit`` bytes have been written.
        :param transport: destination of the serialized frames.
        :param use_mask: mask every outgoing payload with a random key.
        :param limit: output byte count that triggers a drain.
        :param random: source of masking keys; only ``randrange`` is used.
        :param compress: deflate window bits; ``0`` turns compression off.
        :param notakeover: flush with ``Z_FULL_FLUSH`` instead of
            ``Z_SYNC_FLUSH`` after each compressed message.
        """
        self.protocol = protocol
        self.transport = transport
        self.use_mask = use_mask
        self.randrange = random.randrange
        self.compress = compress
        self.notakeover = notakeover
        self._limit = limit
        self._closing = False
        self._output_size = 0
        self._compressobj = None  # type: Any  # actually compressobj

    async def _send_frame(
        self, message: bytes, opcode: int, compress: Optional[int] = None
    ) -> None:
        """Serialize ``message`` as one final frame and write it out.

        :param message: frame payload.
        :param opcode: frame opcode; payloads of opcodes below 8 are
            compressed when compression is in effect.
        :param compress: window bits for a one-off compressor used for this
            frame only; when falsy the writer-wide setting applies.
        :raises ConnectionResetError: the writer is closing and the opcode
            does not have the CLOSE bit set, or the transport is closing.
        """
        if self._closing and not (opcode & WSMsgType.CLOSE):
            raise ConnectionResetError("Cannot write to closing transport")

        rsv = 0

        # All data frames are compressed regardless of size; the former
        # "only payloads over 124 bytes" rule is disabled.
        if (compress or self.compress) and opcode < 8:
            if compress:
                # Per-frame compressor; the shared one is left untouched.
                deflater = zlib.compressobj(level=zlib.Z_BEST_SPEED, wbits=-compress)
            else:
                if not self._compressobj:
                    self._compressobj = zlib.compressobj(
                        level=zlib.Z_BEST_SPEED, wbits=-self.compress
                    )
                deflater = self._compressobj

            flush_mode = zlib.Z_FULL_FLUSH if self.notakeover else zlib.Z_SYNC_FLUSH
            message = deflater.compress(message) + deflater.flush(flush_mode)
            if message.endswith(_WS_DEFLATE_TRAILING):
                message = message[:-4]
            # RSV1 marks the payload as compressed.
            rsv = 0x40

        msg_length = len(message)
        use_mask = self.use_mask
        mask_bit = 0x80 if use_mask else 0
        # FIN is always set: every message goes out as a single frame.
        first_byte = 0x80 | rsv | opcode

        if msg_length < 126:
            header = PACK_LEN1(first_byte, msg_length | mask_bit)
        elif msg_length < (1 << 16):
            header = PACK_LEN2(first_byte, 126 | mask_bit, msg_length)
        else:
            header = PACK_LEN3(first_byte, 127 | mask_bit, msg_length)

        if use_mask:
            mask = self.randrange(0, 0xFFFFFFFF).to_bytes(4, "big")
            masked = bytearray(message)
            _websocket_mask(mask, masked)
            self._write(header + mask + masked)
            self._output_size += len(header) + len(mask) + len(masked)
        else:
            if msg_length > MSG_SIZE:
                # Large payload: skip the concatenation copy.
                self._write(header)
                self._write(message)
            else:
                self._write(header + message)

            self._output_size += len(header) + msg_length

        if self._output_size > self._limit:
            self._output_size = 0
            await self.protocol._drain_helper()

    def _write(self, data: bytes) -> None:
        """Write ``data`` to the transport.

        :raises ConnectionResetError: the transport is ``None`` or closing.
        """
        transport = self.transport
        if transport is None or transport.is_closing():
            raise ConnectionResetError("Cannot write to closing transport")
        transport.write(data)

    async def pong(self, message: bytes = b"") -> None:
        """Send a PONG frame; a ``str`` payload is UTF-8 encoded first."""
        if isinstance(message, str):
            message = message.encode("utf-8")
        await self._send_frame(message, WSMsgType.PONG)

    async def ping(self, message: bytes = b"") -> None:
        """Send a PING frame; a ``str`` payload is UTF-8 encoded first."""
        if isinstance(message, str):
            message = message.encode("utf-8")
        await self._send_frame(message, WSMsgType.PING)

    async def send(
        self,
        message: Union[str, bytes],
        binary: bool = False,
        compress: Optional[int] = None,
    ) -> None:
        """Send ``message`` as a BINARY frame if ``binary`` else a TEXT frame.

        A ``str`` message is UTF-8 encoded first; ``compress`` is passed on
        to the frame writer unchanged.
        """
        if isinstance(message, str):
            message = message.encode("utf-8")
        opcode = WSMsgType.BINARY if binary else WSMsgType.TEXT
        await self._send_frame(message, opcode, compress)

    async def close(self, code: int = 1000, message: bytes = b"") -> None:
        """Send a CLOSE frame carrying ``code`` followed by ``message``.

        The writer is marked as closing afterwards, whether or not sending
        the frame succeeded.
        """
        if isinstance(message, str):
            message = message.encode("utf-8")
        try:
            await self._send_frame(
                PACK_CLOSE_CODE(code) + message, opcode=WSMsgType.CLOSE
            )
        finally:
            self._closing = True
702
703