183 lines
		
	
	
		
			6.0 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			183 lines
		
	
	
		
			6.0 KiB
		
	
	
	
		
			Python
		
	
	
	
from __future__ import annotations
 | 
						|
 | 
						|
import sys
 | 
						|
from collections.abc import Callable, Iterable, Mapping
 | 
						|
from dataclasses import dataclass, field
 | 
						|
from typing import Any, SupportsIndex
 | 
						|
 | 
						|
from .. import ClosedResourceError, DelimiterNotFound, EndOfStream, IncompleteRead
 | 
						|
from ..abc import (
 | 
						|
    AnyByteReceiveStream,
 | 
						|
    AnyByteStream,
 | 
						|
    AnyByteStreamConnectable,
 | 
						|
    ByteReceiveStream,
 | 
						|
    ByteStream,
 | 
						|
    ByteStreamConnectable,
 | 
						|
)
 | 
						|
 | 
						|
if sys.version_info >= (3, 12):
 | 
						|
    from typing import override
 | 
						|
else:
 | 
						|
    from typing_extensions import override
 | 
						|
 | 
						|
 | 
						|
@dataclass(eq=False)
class BufferedByteReceiveStream(ByteReceiveStream):
    """
    Wraps any bytes-based receive stream and uses a buffer to provide sophisticated
    receiving capabilities in the form of a byte stream.

    Bytes held in the internal buffer (surplus from earlier reads, or data injected
    through :meth:`feed_data`) are always handed out before the wrapped stream is
    read again.
    """

    receive_stream: AnyByteReceiveStream
    # Bytes already received (or fed in) but not yet returned to the caller
    _buffer: bytearray = field(init=False, default_factory=bytearray)
    # Set to True once aclose() has finished closing the wrapped stream
    _closed: bool = field(init=False, default=False)

    async def aclose(self) -> None:
        """Close the wrapped stream and then mark this stream as closed."""
        await self.receive_stream.aclose()
        self._closed = True

    @property
    def buffer(self) -> bytes:
        """The bytes currently in the buffer."""
        return bytes(self._buffer)

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        """The extra attributes of the wrapped stream, passed through unchanged."""
        return self.receive_stream.extra_attributes

    def feed_data(self, data: Iterable[SupportsIndex], /) -> None:
        """
        Append data directly into the buffer.

        Any data in the buffer will be consumed by receive operations before receiving
        anything from the wrapped stream.

        :param data: the data to append to the buffer (can be bytes or anything else
            that supports ``__index__()``)

        """
        self._buffer.extend(data)

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Receive at most ``max_bytes`` bytes, serving buffered data first.

        :param max_bytes: maximum number of bytes to return
        :return: the bytes received
        :raises ~anyio.ClosedResourceError: if this stream has been closed through
            :meth:`aclose`

        """
        if self._closed:
            raise ClosedResourceError

        if self._buffer:
            # Hand out buffered bytes without touching the wrapped stream
            chunk = bytes(self._buffer[:max_bytes])
            del self._buffer[:max_bytes]
            return chunk
        elif isinstance(self.receive_stream, ByteReceiveStream):
            return await self.receive_stream.receive(max_bytes)
        else:
            # With a bytes-oriented object stream, we need to handle any surplus bytes
            # we get from the receive() call
            chunk = await self.receive_stream.receive()
            if len(chunk) > max_bytes:
                # Save the surplus bytes in the buffer
                self._buffer.extend(chunk[max_bytes:])
                return chunk[:max_bytes]
            else:
                return chunk

    async def receive_exactly(self, nbytes: int) -> bytes:
        """
        Read exactly the given amount of bytes from the stream.

        :param nbytes: the number of bytes to read
        :return: the bytes read
        :raises ValueError: if ``nbytes`` is negative
        :raises ~anyio.IncompleteRead: if the stream was closed before the requested
            amount of bytes could be read from the stream

        """
        # A negative count would be interpreted by the slices below as "everything
        # except the last N bytes", silently consuming buffered data
        if nbytes < 0:
            raise ValueError("nbytes must not be negative")

        while True:
            remaining = nbytes - len(self._buffer)
            if remaining <= 0:
                retval = self._buffer[:nbytes]
                del self._buffer[:nbytes]
                return bytes(retval)

            try:
                if isinstance(self.receive_stream, ByteReceiveStream):
                    chunk = await self.receive_stream.receive(remaining)
                else:
                    # Object streams take no size limit; any surplus simply stays in
                    # the buffer for the next receive operation
                    chunk = await self.receive_stream.receive()
            except EndOfStream as exc:
                raise IncompleteRead from exc

            self._buffer.extend(chunk)

    async def receive_until(self, delimiter: bytes, max_bytes: int) -> bytes:
        """
        Read from the stream until the delimiter is found or max_bytes have been read.

        :param delimiter: the marker to look for in the stream
        :param max_bytes: maximum number of bytes that will be read before raising
            :exc:`~anyio.DelimiterNotFound`
        :return: the bytes read (not including the delimiter)
        :raises ~anyio.IncompleteRead: if the stream was closed before the delimiter
            was found
        :raises ~anyio.DelimiterNotFound: if the delimiter is not found within the
            bytes read up to the maximum allowed

        """
        delimiter_size = len(delimiter)
        offset = 0
        while True:
            # Check if the delimiter can be found in the current buffer
            index = self._buffer.find(delimiter, offset)
            if index >= 0:
                found = self._buffer[:index]
                # Drop both the returned bytes and the delimiter from the buffer
                del self._buffer[: index + delimiter_size]
                return bytes(found)

            # Check if the buffer is already at or over the limit
            if len(self._buffer) >= max_bytes:
                raise DelimiterNotFound(max_bytes)

            # Read more data into the buffer from the wrapped stream
            try:
                data = await self.receive_stream.receive()
            except EndOfStream as exc:
                raise IncompleteRead from exc

            # Move the offset forward and add the new data to the buffer; the search
            # resumes slightly before the old end of the buffer so that a delimiter
            # split across the old and new data is still found
            offset = max(len(self._buffer) - delimiter_size + 1, 0)
            self._buffer.extend(data)
 | 
						|
 | 
						|
 | 
						|
class BufferedByteStream(BufferedByteReceiveStream, ByteStream):
    """
    A full-duplex variant of :class:`BufferedByteReceiveStream`.

    Only the receiving side is buffered: :meth:`send` and :meth:`send_eof` are
    forwarded to the wrapped stream without modification.
    """

    def __init__(self, stream: AnyByteStream):
        """
        :param stream: the stream to be wrapped

        """
        # The inherited initializer stores the stream as ``receive_stream``; a second
        # reference is kept for the sending side
        super().__init__(receive_stream=stream)
        self._stream = stream

    @override
    async def send(self, item: bytes) -> None:
        await self._stream.send(item)

    @override
    async def send_eof(self) -> None:
        await self._stream.send_eof()
 | 
						|
 | 
						|
 | 
						|
class BufferedConnectable(ByteStreamConnectable):
    """
    A connectable that wraps another connectable and returns the streams it
    produces wrapped in :class:`BufferedByteStream`.
    """

    def __init__(self, connectable: AnyByteStreamConnectable):
        """
        :param connectable: the connectable to wrap

        """
        self.connectable = connectable

    @override
    async def connect(self) -> BufferedByteStream:
        # Connect through the wrapped connectable and add buffering on top
        return BufferedByteStream(await self.connectable.connect())
 |