321 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			321 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
from __future__ import annotations
 | 
						|
 | 
						|
import json
 | 
						|
import typing
 | 
						|
from http import cookies as http_cookies
 | 
						|
 | 
						|
import anyio
 | 
						|
 | 
						|
from starlette._utils import AwaitableOrContextManager, AwaitableOrContextManagerWrapper
 | 
						|
from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State
 | 
						|
from starlette.exceptions import HTTPException
 | 
						|
from starlette.formparsers import FormParser, MultiPartException, MultiPartParser
 | 
						|
from starlette.types import Message, Receive, Scope, Send
 | 
						|
 | 
						|
try:
 | 
						|
    from multipart.multipart import parse_options_header
 | 
						|
except ModuleNotFoundError:  # pragma: nocover
 | 
						|
    parse_options_header = None
 | 
						|
 | 
						|
 | 
						|
if typing.TYPE_CHECKING:
 | 
						|
    from starlette.routing import Router
 | 
						|
 | 
						|
 | 
						|
# Names of incoming request headers that Request.send_push_promise copies
# onto an HTTP/2 server-push promise.
SERVER_PUSH_HEADERS_TO_COPY = set(
    "accept accept-encoding accept-language cache-control user-agent".split()
)
 | 
						|
 | 
						|
 | 
						|
def cookie_parser(cookie_string: str) -> dict[str, str]:
    """
    Parse the value of a ``Cookie`` HTTP header into a ``{name: value}`` dict.

    The parser follows what browsers actually do rather than the letter of
    RFC 6265, because browsers and servers routinely ignore the spec when
    writing and reading cookies:

    * segments are separated by ``;`` and surrounding whitespace is stripped;
    * only the first ``=`` splits name from value, so values may contain ``=``;
    * a segment without ``=`` is stored under the empty name ``""``;
    * segments whose name and value are both empty are skipped;
    * a repeated name keeps the last value seen;
    * values are unquoted with the standard library's cookie unquoting.

    Adapted from Django 3.1.0.

    Note: ``SimpleCookie.load`` is deliberately avoided. It implements an
    outdated spec and rejects a lot of real-world input we want to accept.
    """
    parsed: dict[str, str] = {}
    for segment in cookie_string.split(";"):
        name, separator, value = segment.partition("=")
        if not separator:
            # No "=" at all: treat the whole segment as a value with an empty
            # name, see https://bugzilla.mozilla.org/show_bug.cgi?id=169091
            name, value = "", segment
        name = name.strip()
        value = value.strip()
        if not (name or value):
            continue
        # Reuse Python's own unquoting (strips quotes, decodes escapes).
        parsed[name] = http_cookies._unquote(value)
    return parsed
 | 
						|
 | 
						|
 | 
						|
class ClientDisconnect(Exception):
    """Raised while reading a request body when an ``http.disconnect`` message arrives."""
 | 
						|
 | 
						|
 | 
						|
class HTTPConnection(typing.Mapping[str, typing.Any]):
    """
    Shared base for incoming HTTP and WebSocket connections.

    Wraps an ASGI ``scope`` and exposes lazily computed, cached views over it
    (URL, base URL, headers, query parameters, cookies, state). The instance
    itself acts as a read-only mapping over the scope. Everything common to
    `Request` and `WebSocket` lives here.
    """

    def __init__(self, scope: Scope, receive: Receive | None = None) -> None:
        # ``receive`` is accepted but unused by this base class.
        assert scope["type"] in ("http", "websocket")
        self.scope = scope

    def __getitem__(self, key: str) -> typing.Any:
        """Look ``key`` up directly in the underlying scope."""
        return self.scope[key]

    def __iter__(self) -> typing.Iterator[str]:
        """Iterate over the scope's keys."""
        return iter(self.scope)

    def __len__(self) -> int:
        """Return the number of entries in the scope."""
        return len(self.scope)

    # Identity semantics: two connection objects are equal only when they are
    # the very same object. ``abc.Mapping.__eq__`` would instead compare the
    # scopes' contents, which is not wanted here.
    __eq__ = object.__eq__
    __hash__ = object.__hash__

    @property
    def app(self) -> typing.Any:
        """The application object stored in the scope under ``"app"``."""
        return self.scope["app"]

    @property
    def url(self) -> URL:
        """The full request URL, built once from the scope and then cached."""
        if hasattr(self, "_url"):
            return self._url
        self._url = URL(scope=self.scope)
        return self._url

    @property
    def base_url(self) -> URL:
        """
        The application's base URL (root path with a trailing slash, no query).

        Used by ``url_for``. Inside a Mount the child scope carries its own
        ``root_path``, but URLs must still be resolved against the top-level
        application's root, so ``app_root_path`` wins when present.
        """
        if hasattr(self, "_base_url"):
            return self._base_url
        scope_copy = dict(self.scope)
        root = scope_copy.get("app_root_path", scope_copy.get("root_path", ""))
        scope_copy.update(
            path=root if root.endswith("/") else root + "/",
            query_string=b"",
            root_path=root,
        )
        self._base_url = URL(scope=scope_copy)
        return self._base_url

    @property
    def headers(self) -> Headers:
        """The request headers, built once from the scope and then cached."""
        if hasattr(self, "_headers"):
            return self._headers
        self._headers = Headers(scope=self.scope)
        return self._headers

    @property
    def query_params(self) -> QueryParams:
        """Parsed query-string parameters, cached after first access."""
        if hasattr(self, "_query_params"):
            return self._query_params
        self._query_params = QueryParams(self.scope["query_string"])
        return self._query_params

    @property
    def path_params(self) -> dict[str, typing.Any]:
        """Path parameters from the scope, or an empty dict when absent."""
        return self.scope.get("path_params", {})

    @property
    def cookies(self) -> dict[str, str]:
        """Cookies parsed from the ``cookie`` header (empty dict if no header)."""
        if not hasattr(self, "_cookies"):
            raw_header = self.headers.get("cookie")
            self._cookies = cookie_parser(raw_header) if raw_header else {}
        return self._cookies

    @property
    def client(self) -> Address | None:
        """The client ``Address``, or ``None`` when the scope has no client info."""
        # The scope's "client" entry is a (host, port) pair, None, or missing.
        host_port = self.scope.get("client")
        return None if host_port is None else Address(*host_port)

    @property
    def session(self) -> dict[str, typing.Any]:
        """The session dict; requires SessionMiddleware to have populated it."""
        assert "session" in self.scope, (
            "SessionMiddleware must be installed to access request.session"
        )
        return self.scope["session"]  # type: ignore[no-any-return]

    @property
    def auth(self) -> typing.Any:
        """Authentication info; requires AuthenticationMiddleware to set it."""
        assert "auth" in self.scope, (
            "AuthenticationMiddleware must be installed to access request.auth"
        )
        return self.scope["auth"]

    @property
    def user(self) -> typing.Any:
        """The authenticated user; requires AuthenticationMiddleware to set it."""
        assert "user" in self.scope, (
            "AuthenticationMiddleware must be installed to access request.user"
        )
        return self.scope["user"]

    @property
    def state(self) -> State:
        """A ``State`` wrapper around the scope's ``"state"`` dict (created if missing)."""
        if hasattr(self, "_state"):
            return self._state
        # Guarantee the scope holds a "state" dict (empty if nothing populated
        # it yet), then wrap that same dict so State stores its data there.
        self.scope.setdefault("state", {})
        self._state = State(self.scope["state"])
        return self._state

    def url_for(self, name: str, /, **path_params: typing.Any) -> URL:
        """Build an absolute URL for route ``name`` via the router in the scope."""
        router: Router = self.scope["router"]
        return router.url_path_for(name, **path_params).make_absolute_url(
            base_url=self.base_url
        )
 | 
						|
 | 
						|
 | 
						|
async def empty_receive() -> typing.NoReturn:
    """Placeholder ASGI ``receive`` callable: always raises ``RuntimeError``."""
    reason = "Receive channel has not been made available"
    raise RuntimeError(reason)
 | 
						|
 | 
						|
 | 
						|
async def empty_send(message: Message) -> typing.NoReturn:
    """Placeholder ASGI ``send`` callable: ignores ``message`` and raises ``RuntimeError``."""
    reason = "Send channel has not been made available"
    raise RuntimeError(reason)
 | 
						|
 | 
						|
 | 
						|
class Request(HTTPConnection):
    """
    An incoming HTTP request.

    Extends `HTTPConnection` with access to the request body (raw stream,
    buffered bytes, JSON, form data), disconnect detection and HTTP/2 server
    push. ``receive`` and ``send`` default to placeholders that raise
    ``RuntimeError`` when called.
    """

    _form: FormData | None

    def __init__(
        self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send
    ):
        super().__init__(scope)
        assert scope["type"] == "http"
        self._receive = receive
        self._send = send
        # Becomes True once the final body message (no more_body) has been read.
        self._stream_consumed = False
        self._is_disconnected = False
        # Parsed form data; filled in lazily by _get_form().
        self._form = None

    @property
    def method(self) -> str:
        """The HTTP method taken from the scope."""
        verb: str = self.scope["method"]
        return verb

    @property
    def receive(self) -> Receive:
        """The ASGI ``receive`` callable this request reads from."""
        return self._receive

    async def stream(self) -> typing.AsyncGenerator[bytes, None]:
        """
        Yield the request body chunk by chunk, always finishing with ``b""``.

        If `body()` has already buffered the body, the cached bytes are
        replayed. Otherwise ASGI messages are pulled from ``receive`` until the
        last ``http.request`` message arrives. Empty chunks are not yielded
        (apart from the final ``b""``), and other message types are skipped.

        Raises:
            RuntimeError: the stream was consumed earlier and no body is cached.
            ClientDisconnect: an ``http.disconnect`` message was received.
        """
        if hasattr(self, "_body"):
            # Already buffered: replay the cached body instead of reading again.
            yield self._body
            yield b""
            return
        if self._stream_consumed:
            raise RuntimeError("Stream consumed")
        while not self._stream_consumed:
            message = await self._receive()
            kind = message["type"]
            if kind == "http.disconnect":
                self._is_disconnected = True
                raise ClientDisconnect()
            if kind != "http.request":
                continue
            chunk = message.get("body", b"")
            if not message.get("more_body", False):
                self._stream_consumed = True
            if chunk:
                yield chunk
        yield b""

    async def body(self) -> bytes:
        """Read the whole body once via `stream()` and cache it."""
        if not hasattr(self, "_body"):
            self._body = b"".join([chunk async for chunk in self.stream()])
        return self._body

    async def json(self) -> typing.Any:
        """Decode the body as JSON with ``json.loads`` and cache the result."""
        if not hasattr(self, "_json"):
            raw = await self.body()
            self._json = json.loads(raw)
        return self._json

    async def _get_form(
        self, *, max_files: int | float = 1000, max_fields: int | float = 1000
    ) -> FormData:
        """
        Parse the body as form data once and cache it.

        ``multipart/form-data`` goes through `MultiPartParser`, which is limited
        by ``max_files`` and ``max_fields``. ``application/x-www-form-urlencoded``
        goes through `FormParser`. Any other content type yields an empty
        `FormData`.

        Raises:
            HTTPException: status 400 for multipart parse errors when the scope
                contains an ``"app"`` entry.
            MultiPartException: for multipart parse errors otherwise.
        """
        if self._form is not None:
            return self._form
        assert (
            parse_options_header is not None
        ), "The `python-multipart` library must be installed to use form parsing."
        content_type: bytes
        content_type, _ = parse_options_header(self.headers.get("Content-Type"))
        if content_type == b"multipart/form-data":
            try:
                parser = MultiPartParser(
                    self.headers,
                    self.stream(),
                    max_files=max_files,
                    max_fields=max_fields,
                )
                self._form = await parser.parse()
            except MultiPartException as exc:
                if "app" in self.scope:
                    raise HTTPException(status_code=400, detail=exc.message)
                raise exc
        elif content_type == b"application/x-www-form-urlencoded":
            self._form = await FormParser(self.headers, self.stream()).parse()
        else:
            self._form = FormData()
        return self._form

    def form(
        self, *, max_files: int | float = 1000, max_fields: int | float = 1000
    ) -> AwaitableOrContextManager[FormData]:
        """
        Return the parsed form data, wrapped by `AwaitableOrContextManagerWrapper`.

        Per the declared return type, the result can be awaited or used as a
        context manager.
        """
        pending = self._get_form(max_files=max_files, max_fields=max_fields)
        return AwaitableOrContextManagerWrapper(pending)

    async def close(self) -> None:
        """Close the parsed form data, if a form has been parsed."""
        form = self._form
        if form is not None:
            await form.close()

    async def is_disconnected(self) -> bool:
        """
        Report whether the client has disconnected, polling ``receive`` once.

        NOTE(review): any non-disconnect message picked up by the poll (for
        example an ``http.request`` body chunk) is discarded, not buffered.
        """
        if self._is_disconnected:
            return True
        message: Message = {}
        # Only take a message that is immediately available: the cancel scope
        # is cancelled before the await, so we don't block waiting for one.
        with anyio.CancelScope() as cancel_scope:
            cancel_scope.cancel()
            message = await self._receive()
        if message.get("type") == "http.disconnect":
            self._is_disconnected = True
        return self._is_disconnected

    async def send_push_promise(self, path: str) -> None:
        """
        Send an ``http.response.push`` message for ``path``.

        Does nothing unless the scope's ``extensions`` advertise
        ``http.response.push``. Every value of each header named in
        ``SERVER_PUSH_HEADERS_TO_COPY`` is copied, latin-1 encoded.
        """
        if "http.response.push" not in self.scope.get("extensions", {}):
            return
        raw_headers: list[tuple[bytes, bytes]] = [
            (name.encode("latin-1"), value.encode("latin-1"))
            for name in SERVER_PUSH_HEADERS_TO_COPY
            for value in self.headers.getlist(name)
        ]
        await self._send(
            {"type": "http.response.push", "path": path, "headers": raw_headers}
        )
 |