Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
allendowney
GitHub Repository: allendowney/cpython
Path: blob/main/Lib/asyncio/streams.py
12 views
1
__all__ = (
    'StreamReader', 'StreamWriter', 'StreamReaderProtocol',
    'open_connection', 'start_server')

import collections
import socket
import sys
import weakref

if hasattr(socket, 'AF_UNIX'):
    # The UNIX-domain-socket helpers are only defined (and exported)
    # on platforms that support AF_UNIX.
    __all__ += ('open_unix_connection', 'start_unix_server')

from . import coroutines
from . import events
from . import exceptions
from . import format_helpers
from . import protocols
from .log import logger
from .tasks import sleep


# Default StreamReader buffer limit; feed_data() pauses the transport
# once the buffer grows past twice this value.
_DEFAULT_LIMIT = 2 ** 16  # 64 KiB
async def open_connection(host=None, port=None, *,
                          limit=_DEFAULT_LIMIT, **kwds):
    """Connect to (host, port) and return a (reader, writer) stream pair.

    Thin convenience wrapper around loop.create_connection(): builds a
    StreamReader (buffer limit set by `limit`), wires it to a
    StreamReaderProtocol, opens the connection and wraps the resulting
    transport in a StreamWriter.

    All other keyword arguments are forwarded unchanged to
    loop.create_connection().  Must be called from a running event loop.

    (If you want to customize the StreamReader and/or
    StreamReaderProtocol classes, just copy the code -- there's
    really nothing special here except some convenience.)
    """
    ev_loop = events.get_running_loop()
    stream_reader = StreamReader(limit=limit, loop=ev_loop)
    proto = StreamReaderProtocol(stream_reader, loop=ev_loop)
    conn_transport, _ = await ev_loop.create_connection(
        lambda: proto, host, port, **kwds)
    stream_writer = StreamWriter(conn_transport, proto, stream_reader,
                                 ev_loop)
    return stream_reader, stream_writer
async def start_server(client_connected_cb, host=None, port=None, *,
                       limit=_DEFAULT_LIMIT, **kwds):
    """Start a socket server, invoking a callback for each client.

    `client_connected_cb` is called with two arguments, (client_reader,
    client_writer): a StreamReader and a StreamWriter for that client.
    It may be a plain callable or a coroutine function; a returned
    coroutine is automatically scheduled as a Task.

    `limit` sets the buffer limit handed to each per-connection
    StreamReader; all remaining keyword arguments are forwarded to
    loop.create_server().

    The return value is the same as loop.create_server(), i.e. a
    Server object which can be used to stop the service.
    """
    ev_loop = events.get_running_loop()

    def protocol_factory():
        # Each accepted connection gets its own reader/protocol pair.
        stream_reader = StreamReader(limit=limit, loop=ev_loop)
        return StreamReaderProtocol(stream_reader, client_connected_cb,
                                    loop=ev_loop)

    return await ev_loop.create_server(protocol_factory, host, port, **kwds)
if hasattr(socket, 'AF_UNIX'):
    # UNIX Domain Sockets are supported on this platform

    async def open_unix_connection(path=None, *,
                                   limit=_DEFAULT_LIMIT, **kwds):
        """Similar to `open_connection` but works with UNIX Domain Sockets."""
        ev_loop = events.get_running_loop()

        stream_reader = StreamReader(limit=limit, loop=ev_loop)
        proto = StreamReaderProtocol(stream_reader, loop=ev_loop)
        conn_transport, _ = await ev_loop.create_unix_connection(
            lambda: proto, path, **kwds)
        stream_writer = StreamWriter(conn_transport, proto, stream_reader,
                                     ev_loop)
        return stream_reader, stream_writer

    async def start_unix_server(client_connected_cb, path=None, *,
                                limit=_DEFAULT_LIMIT, **kwds):
        """Similar to `start_server` but works with UNIX Domain Sockets."""
        ev_loop = events.get_running_loop()

        def protocol_factory():
            # Fresh reader/protocol pair per accepted connection.
            stream_reader = StreamReader(limit=limit, loop=ev_loop)
            return StreamReaderProtocol(stream_reader, client_connected_cb,
                                        loop=ev_loop)

        return await ev_loop.create_unix_server(protocol_factory, path,
                                                **kwds)
116
class FlowControlMixin(protocols.Protocol):
    """Reusable flow control logic for StreamWriter.drain().

    Implements the protocol callbacks pause_writing(), resume_writing()
    and connection_lost(); a subclass that overrides any of them must
    call the super() implementation.

    StreamWriter.drain() must await the _drain_helper() coroutine.
    """

    def __init__(self, loop=None):
        self._loop = events.get_event_loop() if loop is None else loop
        self._paused = False
        self._drain_waiters = collections.deque()
        self._connection_lost = False

    def pause_writing(self):
        # Transport's write buffer went over the high-water mark.
        assert not self._paused
        self._paused = True
        if self._loop.get_debug():
            logger.debug("%r pauses writing", self)

    def resume_writing(self):
        # Buffer drained below the low-water mark: release every
        # coroutine currently blocked in _drain_helper().
        assert self._paused
        self._paused = False
        if self._loop.get_debug():
            logger.debug("%r resumes writing", self)

        for fut in self._drain_waiters:
            if not fut.done():
                fut.set_result(None)

    def connection_lost(self, exc):
        self._connection_lost = True
        if not self._paused:
            # No writer can be blocked in drain(); nothing to wake.
            return

        # Wake up the writer(s) currently paused, with the error if any.
        for fut in self._drain_waiters:
            if fut.done():
                continue
            if exc is None:
                fut.set_result(None)
            else:
                fut.set_exception(exc)

    async def _drain_helper(self):
        if self._connection_lost:
            raise ConnectionResetError('Connection lost')
        if not self._paused:
            return
        fut = self._loop.create_future()
        self._drain_waiters.append(fut)
        try:
            await fut
        finally:
            self._drain_waiters.remove(fut)

    def _get_close_waiter(self, stream):
        # Subclasses must supply the future that wait_closed() awaits.
        raise NotImplementedError
180
class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
    """Helper class to adapt between Protocol and StreamReader.

    (This is a helper class instead of making StreamReader itself a
    Protocol subclass, because the StreamReader has other potential
    uses, and to prevent the user of the StreamReader to accidentally
    call inappropriate methods of the protocol.)
    """

    # Creation-site traceback, copied from the reader when the loop is
    # in debug mode.
    _source_traceback = None

    def __init__(self, stream_reader, client_connected_cb=None, loop=None):
        """Bind this protocol to *stream_reader*.

        stream_reader -- StreamReader to feed, or None.  Held only via a
            weakref so an abandoned stream can be garbage collected.
        client_connected_cb -- server-side callback invoked with
            (reader, writer) once the connection is made.
        loop -- event loop (defaults to the current one).
        """
        super().__init__(loop=loop)
        if stream_reader is not None:
            self._stream_reader_wr = weakref.ref(stream_reader)
            self._source_traceback = stream_reader._source_traceback
        else:
            self._stream_reader_wr = None
        if client_connected_cb is not None:
            # This is a stream created by the `create_server()` function.
            # Keep a strong reference to the reader until a connection
            # is established.
            self._strong_reader = stream_reader
        self._reject_connection = False
        self._stream_writer = None
        self._task = None
        self._transport = None
        self._client_connected_cb = client_connected_cb
        self._over_ssl = False
        self._closed = self._loop.create_future()

    @property
    def _stream_reader(self):
        # Dereference the weakref; None once the reader was dropped
        # (or was never supplied).
        if self._stream_reader_wr is None:
            return None
        return self._stream_reader_wr()

    def _replace_writer(self, writer):
        # Called by StreamWriter.start_tls() after upgrading the
        # transport: re-point this protocol at the new writer/transport.
        # (The previous `loop = self._loop` local here was unused and
        # has been removed.)
        transport = writer.transport
        self._stream_writer = writer
        self._transport = transport
        self._over_ssl = transport.get_extra_info('sslcontext') is not None

    def connection_made(self, transport):
        """Protocol callback: the connection is established."""
        if self._reject_connection:
            # The stream was garbage collected before the connection
            # came up; report it and abort rather than leak a socket.
            context = {
                'message': ('An open stream was garbage collected prior to '
                            'establishing network connection; '
                            'call "stream.close()" explicitly.')
            }
            if self._source_traceback:
                context['source_traceback'] = self._source_traceback
            self._loop.call_exception_handler(context)
            transport.abort()
            return
        self._transport = transport
        reader = self._stream_reader
        if reader is not None:
            reader.set_transport(transport)
        self._over_ssl = transport.get_extra_info('sslcontext') is not None
        if self._client_connected_cb is not None:
            # Server side: build the writer and hand both ends to the
            # user callback; schedule it if it returned a coroutine.
            self._stream_writer = StreamWriter(transport, self,
                                               reader,
                                               self._loop)
            res = self._client_connected_cb(reader,
                                            self._stream_writer)
            if coroutines.iscoroutine(res):
                self._task = self._loop.create_task(res)
            # The connection exists now; the weakref alone keeps the
            # reader reachable through the writer.
            self._strong_reader = None

    def connection_lost(self, exc):
        """Protocol callback: the connection is gone.

        Propagates EOF (exc is None) or the error to the reader, settles
        the close-waiter future, then drops all references so the
        reader/writer/task can be collected.
        """
        reader = self._stream_reader
        if reader is not None:
            if exc is None:
                reader.feed_eof()
            else:
                reader.set_exception(exc)
        if not self._closed.done():
            if exc is None:
                self._closed.set_result(None)
            else:
                self._closed.set_exception(exc)
        super().connection_lost(exc)
        self._stream_reader_wr = None
        self._stream_writer = None
        self._task = None
        self._transport = None

    def data_received(self, data):
        """Protocol callback: forward received bytes to the reader."""
        reader = self._stream_reader
        if reader is not None:
            reader.feed_data(data)

    def eof_received(self):
        """Protocol callback: remote end half-closed the connection."""
        reader = self._stream_reader
        if reader is not None:
            reader.feed_eof()
        if self._over_ssl:
            # Prevent a warning in SSLProtocol.eof_received:
            # "returning true from eof_received()
            # has no effect when using ssl"
            return False
        return True

    def _get_close_waiter(self, stream):
        # Future awaited by StreamWriter.wait_closed().
        return self._closed

    def __del__(self):
        # Prevent reports about unhandled exceptions.
        # Better than self._closed._log_traceback = False hack
        try:
            closed = self._closed
        except AttributeError:
            pass  # failed constructor
        else:
            if closed.done() and not closed.cancelled():
                closed.exception()
300
class StreamWriter:
    """Convenience wrapper around a write transport.

    Forwards write(), writelines(), write_eof(), can_write_eof(),
    close(), is_closing() and get_extra_info() to the underlying
    transport, adds the coroutines drain() (flow control), wait_closed()
    and start_tls(), and exposes the transport itself as a property.
    """

    def __init__(self, transport, protocol, reader, loop):
        self._transport = transport
        self._protocol = protocol
        # drain() relies on reader.exception(), so only a StreamReader
        # (or None) is acceptable here.
        assert reader is None or isinstance(reader, StreamReader)
        self._reader = reader
        self._loop = loop
        self._complete_fut = self._loop.create_future()
        self._complete_fut.set_result(None)

    def __repr__(self):
        parts = [self.__class__.__name__, f'transport={self._transport!r}']
        if self._reader is not None:
            parts.append(f'reader={self._reader!r}')
        return '<{}>'.format(' '.join(parts))

    @property
    def transport(self):
        """The underlying transport object."""
        return self._transport

    def write(self, data):
        """Write bytes to the transport (does not block)."""
        self._transport.write(data)

    def writelines(self, data):
        """Write an iterable of byte chunks to the transport."""
        self._transport.writelines(data)

    def write_eof(self):
        """Half-close the write side of the connection."""
        return self._transport.write_eof()

    def can_write_eof(self):
        """Return True if the transport supports write_eof()."""
        return self._transport.can_write_eof()

    def close(self):
        """Close the transport."""
        return self._transport.close()

    def is_closing(self):
        """Return True if the transport is closed or closing."""
        return self._transport.is_closing()

    async def wait_closed(self):
        """Wait until the protocol reports connection_lost()."""
        await self._protocol._get_close_waiter(self)

    def get_extra_info(self, name, default=None):
        """Proxy for transport.get_extra_info()."""
        return self._transport.get_extra_info(name, default)

    async def drain(self):
        """Flush the write buffer.

        The intended use is to write

          w.write(data)
          await w.drain()
        """
        if self._reader is not None:
            reader_exc = self._reader.exception()
            if reader_exc is not None:
                raise reader_exc
        if self._transport.is_closing():
            # The transport is shutting down.  Yield once so the event
            # loop gets a chance to deliver connection_lost() to the
            # protocol; otherwise a tight write()/drain() cycle would
            # never observe the closed socket and _drain_helper() would
            # return immediately instead of raising (a connection error
            # if one occurred, ConnectionResetError otherwise).
            await sleep(0)
        await self._protocol._drain_helper()

    async def start_tls(self, sslcontext, *,
                        server_hostname=None,
                        ssl_handshake_timeout=None,
                        ssl_shutdown_timeout=None):
        """Upgrade an existing stream-based connection to TLS."""
        proto = self._protocol
        # Streams created via start_server()/create_server() carry a
        # client_connected_cb, which identifies the server side.
        is_server = proto._client_connected_cb is not None
        await self.drain()
        tls_transport = await self._loop.start_tls(  # type: ignore
            self._transport, proto, sslcontext,
            server_side=is_server, server_hostname=server_hostname,
            ssl_handshake_timeout=ssl_handshake_timeout,
            ssl_shutdown_timeout=ssl_shutdown_timeout)
        self._transport = tls_transport
        proto._replace_writer(self)
396
class StreamReader:
    # Bytes arrive via feed_data()/feed_eof() (called by the protocol)
    # and are consumed by the read*() coroutines below; a single future
    # (self._waiter) links the two sides.

    # Creation-site traceback, captured in __init__ only when the loop
    # is in debug mode.
    _source_traceback = None

    def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
        # The line length limit is a security feature;
        # it also doubles as half the buffer limit.

        if limit <= 0:
            raise ValueError('Limit cannot be <= 0')

        self._limit = limit
        if loop is None:
            self._loop = events.get_event_loop()
        else:
            self._loop = loop
        self._buffer = bytearray()
        self._eof = False    # Whether we're done (feed_eof() was called).
        self._waiter = None  # A future used by _wait_for_data()
        self._exception = None
        self._transport = None  # Set later via set_transport().
        self._paused = False  # True while transport reading is paused.
        if self._loop.get_debug():
            self._source_traceback = format_helpers.extract_stack(
                sys._getframe(1))

    def __repr__(self):
        # Summarize only the non-default parts of the state.
        info = ['StreamReader']
        if self._buffer:
            info.append(f'{len(self._buffer)} bytes')
        if self._eof:
            info.append('eof')
        if self._limit != _DEFAULT_LIMIT:
            info.append(f'limit={self._limit}')
        if self._waiter:
            info.append(f'waiter={self._waiter!r}')
        if self._exception:
            info.append(f'exception={self._exception!r}')
        if self._transport:
            info.append(f'transport={self._transport!r}')
        if self._paused:
            info.append('paused')
        return '<{}>'.format(' '.join(info))

    def exception(self):
        # Exception previously recorded via set_exception(), or None.
        return self._exception

    def set_exception(self, exc):
        # Record *exc* and wake any coroutine blocked in _wait_for_data()
        # so it re-raises instead of waiting forever.
        self._exception = exc

        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.cancelled():
                waiter.set_exception(exc)

    def _wakeup_waiter(self):
        """Wakeup read*() functions waiting for data or EOF."""
        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.cancelled():
                waiter.set_result(None)

    def set_transport(self, transport):
        # May be set at most once; needed for pause/resume flow control.
        assert self._transport is None, 'Transport already set'
        self._transport = transport

    def _maybe_resume_transport(self):
        # Resume reading once the buffer has drained back under the limit.
        if self._paused and len(self._buffer) <= self._limit:
            self._paused = False
            self._transport.resume_reading()

    def feed_eof(self):
        # Mark the stream as finished and release any blocked reader.
        self._eof = True
        self._wakeup_waiter()

    def at_eof(self):
        """Return True if the buffer is empty and 'feed_eof' was called."""
        return self._eof and not self._buffer

    def feed_data(self, data):
        assert not self._eof, 'feed_data after feed_eof'

        if not data:
            return

        self._buffer.extend(data)
        self._wakeup_waiter()

        # Flow control: once the buffer exceeds twice the limit, ask the
        # transport to stop reading until the consumer catches up.
        if (self._transport is not None and
                not self._paused and
                len(self._buffer) > 2 * self._limit):
            try:
                self._transport.pause_reading()
            except NotImplementedError:
                # The transport can't be paused.
                # We'll just have to buffer all data.
                # Forget the transport so we don't keep trying.
                self._transport = None
            else:
                self._paused = True

    async def _wait_for_data(self, func_name):
        """Wait until feed_data() or feed_eof() is called.

        If stream was paused, automatically resume it.
        """
        # StreamReader uses a future to link the protocol feed_data() method
        # to a read coroutine. Running two read coroutines at the same time
        # would have an unexpected behaviour. It would not be possible to
        # know which coroutine would get the next data.
        if self._waiter is not None:
            raise RuntimeError(
                f'{func_name}() called while another coroutine is '
                f'already waiting for incoming data')

        assert not self._eof, '_wait_for_data after EOF'

        # Waiting for data while paused would deadlock (no data can
        # arrive), so prevent it.  This is essential for readexactly(n)
        # for the case when n > self._limit.
        if self._paused:
            self._paused = False
            self._transport.resume_reading()

        self._waiter = self._loop.create_future()
        try:
            await self._waiter
        finally:
            self._waiter = None

    async def readline(self):
        """Read chunk of data from the stream until newline (b'\n') is found.

        On success, return chunk that ends with newline. If only partial
        line can be read due to EOF, return incomplete line without
        terminating newline. When EOF was reached while no bytes read, empty
        bytes object is returned.

        If limit is reached, ValueError will be raised. In that case, if
        newline was found, complete line including newline will be removed
        from internal buffer. Else, internal buffer will be cleared. Limit is
        compared against part of the line without newline.

        If stream was paused, this function will automatically resume it if
        needed.
        """
        sep = b'\n'
        seplen = len(sep)
        try:
            line = await self.readuntil(sep)
        except exceptions.IncompleteReadError as e:
            # EOF before a newline: return whatever was read.
            return e.partial
        except exceptions.LimitOverrunError as e:
            # Over-long line: consume it (or clear the buffer) so the
            # stream stays usable, then surface a ValueError.
            if self._buffer.startswith(sep, e.consumed):
                del self._buffer[:e.consumed + seplen]
            else:
                self._buffer.clear()
            self._maybe_resume_transport()
            raise ValueError(e.args[0])
        return line

    async def readuntil(self, separator=b'\n'):
        """Read data from the stream until ``separator`` is found.

        On success, the data and separator will be removed from the
        internal buffer (consumed). Returned data will include the
        separator at the end.

        Configured stream limit is used to check result. Limit sets the
        maximal length of data that can be returned, not counting the
        separator.

        If an EOF occurs and the complete separator is still not found,
        an IncompleteReadError exception will be raised, and the internal
        buffer will be reset. The IncompleteReadError.partial attribute
        may contain the separator partially.

        If the data cannot be read because of over limit, a
        LimitOverrunError exception will be raised, and the data
        will be left in the internal buffer, so it can be read again.
        """
        seplen = len(separator)
        if seplen == 0:
            raise ValueError('Separator should be at least one-byte string')

        if self._exception is not None:
            raise self._exception

        # Consume whole buffer except last bytes, which length is
        # one less than seplen. Let's check corner cases with
        # separator='SEPARATOR':
        # * we have received almost complete separator (without last
        #   byte). i.e buffer='some textSEPARATO'. In this case we
        #   can safely consume len(separator) - 1 bytes.
        # * last byte of buffer is first byte of separator, i.e.
        #   buffer='abcdefghijklmnopqrS'. We may safely consume
        #   everything except that last byte, but this require to
        #   analyze bytes of buffer that match partial separator.
        #   This is slow and/or require FSM. For this case our
        #   implementation is not optimal, since require rescanning
        #   of data that is known to not belong to separator. In
        #   real world, separator will not be so long to notice
        #   performance problems. Even when reading MIME-encoded
        #   messages :)

        # `offset` is the number of bytes from the beginning of the buffer
        # where there is no occurrence of `separator`.
        offset = 0

        # Loop until we find `separator` in the buffer, exceed the buffer size,
        # or an EOF has happened.
        while True:
            buflen = len(self._buffer)

            # Check if we now have enough data in the buffer for `separator` to
            # fit.
            if buflen - offset >= seplen:
                isep = self._buffer.find(separator, offset)

                if isep != -1:
                    # `separator` is in the buffer. `isep` will be used later
                    # to retrieve the data.
                    break

                # see upper comment for explanation.
                offset = buflen + 1 - seplen
                if offset > self._limit:
                    raise exceptions.LimitOverrunError(
                        'Separator is not found, and chunk exceed the limit',
                        offset)

            # Complete message (with full separator) may be present in buffer
            # even when EOF flag is set. This may happen when the last chunk
            # adds data which makes separator be found. That's why we check for
            # EOF *after* inspecting the buffer.
            if self._eof:
                chunk = bytes(self._buffer)
                self._buffer.clear()
                raise exceptions.IncompleteReadError(chunk, None)

            # _wait_for_data() will resume reading if stream was paused.
            await self._wait_for_data('readuntil')

        if isep > self._limit:
            raise exceptions.LimitOverrunError(
                'Separator is found, but chunk is longer than limit', isep)

        chunk = self._buffer[:isep + seplen]
        del self._buffer[:isep + seplen]
        self._maybe_resume_transport()
        return bytes(chunk)

    async def read(self, n=-1):
        """Read up to `n` bytes from the stream.

        If `n` is not provided or set to -1,
        read until EOF, then return all read bytes.
        If EOF was received and the internal buffer is empty,
        return an empty bytes object.

        If `n` is 0, return an empty bytes object immediately.

        If `n` is positive, return at most `n` available bytes
        as soon as at least 1 byte is available in the internal buffer.
        If EOF is received before any byte is read, return an empty
        bytes object.

        Returned value is not limited with limit, configured at stream
        creation.

        If stream was paused, this function will automatically resume it if
        needed.
        """

        if self._exception is not None:
            raise self._exception

        if n == 0:
            return b''

        if n < 0:
            # This used to just loop creating a new waiter hoping to
            # collect everything in self._buffer, but that would
            # deadlock if the subprocess sends more than self.limit
            # bytes. So just call self.read(self._limit) until EOF.
            blocks = []
            while True:
                block = await self.read(self._limit)
                if not block:
                    break
                blocks.append(block)
            return b''.join(blocks)

        if not self._buffer and not self._eof:
            await self._wait_for_data('read')

        # This will work right even if buffer is less than n bytes
        data = bytes(memoryview(self._buffer)[:n])
        del self._buffer[:n]

        self._maybe_resume_transport()
        return data

    async def readexactly(self, n):
        """Read exactly `n` bytes.

        Raise an IncompleteReadError if EOF is reached before `n` bytes can be
        read. The IncompleteReadError.partial attribute of the exception will
        contain the partial read bytes.

        if n is zero, return empty bytes object.

        Returned value is not limited with limit, configured at stream
        creation.

        If stream was paused, this function will automatically resume it if
        needed.
        """
        if n < 0:
            raise ValueError('readexactly size can not be less than zero')

        if self._exception is not None:
            raise self._exception

        if n == 0:
            return b''

        while len(self._buffer) < n:
            if self._eof:
                incomplete = bytes(self._buffer)
                self._buffer.clear()
                raise exceptions.IncompleteReadError(incomplete, n)

            await self._wait_for_data('readexactly')

        if len(self._buffer) == n:
            # Exact fit: hand over the whole buffer without slicing.
            data = bytes(self._buffer)
            self._buffer.clear()
        else:
            data = bytes(memoryview(self._buffer)[:n])
            del self._buffer[:n]
        self._maybe_resume_transport()
        return data

    def __aiter__(self):
        # Support `async for line in reader:` iteration.
        return self

    async def __anext__(self):
        # Iteration yields lines and stops cleanly at EOF.
        val = await self.readline()
        if val == b'':
            raise StopAsyncIteration
        return val