139 lines
		
	
	
		
			4.6 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			139 lines
		
	
	
		
			4.6 KiB
		
	
	
	
		
			Python
		
	
	
	
from __future__ import annotations
 | 
						|
 | 
						|
import inspect
 | 
						|
import sys
 | 
						|
from collections.abc import Callable, Iterable, Mapping
 | 
						|
from contextlib import AbstractContextManager
 | 
						|
from types import TracebackType
 | 
						|
from typing import TYPE_CHECKING, Any
 | 
						|
 | 
						|
if sys.version_info < (3, 11):
 | 
						|
    from ._exceptions import BaseExceptionGroup
 | 
						|
 | 
						|
if TYPE_CHECKING:
 | 
						|
    _Handler = Callable[[BaseExceptionGroup[Any]], Any]
 | 
						|
 | 
						|
 | 
						|
class _Catcher:
 | 
						|
    def __init__(self, handler_map: Mapping[tuple[type[BaseException], ...], _Handler]):
 | 
						|
        self._handler_map = handler_map
 | 
						|
 | 
						|
    def __enter__(self) -> None:
 | 
						|
        pass
 | 
						|
 | 
						|
    def __exit__(
 | 
						|
        self,
 | 
						|
        etype: type[BaseException] | None,
 | 
						|
        exc: BaseException | None,
 | 
						|
        tb: TracebackType | None,
 | 
						|
    ) -> bool:
 | 
						|
        if exc is not None:
 | 
						|
            unhandled = self.handle_exception(exc)
 | 
						|
            if unhandled is exc:
 | 
						|
                return False
 | 
						|
            elif unhandled is None:
 | 
						|
                return True
 | 
						|
            else:
 | 
						|
                if isinstance(exc, BaseExceptionGroup):
 | 
						|
                    try:
 | 
						|
                        raise unhandled from exc.__cause__
 | 
						|
                    except BaseExceptionGroup:
 | 
						|
                        # Change __context__ to __cause__ because Python 3.11 does this
 | 
						|
                        # too
 | 
						|
                        unhandled.__context__ = exc.__cause__
 | 
						|
                        raise
 | 
						|
 | 
						|
                raise unhandled from exc
 | 
						|
 | 
						|
        return False
 | 
						|
 | 
						|
    def handle_exception(self, exc: BaseException) -> BaseException | None:
 | 
						|
        excgroup: BaseExceptionGroup | None
 | 
						|
        if isinstance(exc, BaseExceptionGroup):
 | 
						|
            excgroup = exc
 | 
						|
        else:
 | 
						|
            excgroup = BaseExceptionGroup("", [exc])
 | 
						|
 | 
						|
        new_exceptions: list[BaseException] = []
 | 
						|
        for exc_types, handler in self._handler_map.items():
 | 
						|
            matched, excgroup = excgroup.split(exc_types)
 | 
						|
            if matched:
 | 
						|
                try:
 | 
						|
                    try:
 | 
						|
                        raise matched
 | 
						|
                    except BaseExceptionGroup:
 | 
						|
                        result = handler(matched)
 | 
						|
                except BaseExceptionGroup as new_exc:
 | 
						|
                    if new_exc is matched:
 | 
						|
                        new_exceptions.append(new_exc)
 | 
						|
                    else:
 | 
						|
                        new_exceptions.extend(new_exc.exceptions)
 | 
						|
                except BaseException as new_exc:
 | 
						|
                    new_exceptions.append(new_exc)
 | 
						|
                else:
 | 
						|
                    if inspect.iscoroutine(result):
 | 
						|
                        raise TypeError(
 | 
						|
                            f"Error trying to handle {matched!r} with {handler!r}. "
 | 
						|
                            "Exception handler must be a sync function."
 | 
						|
                        ) from exc
 | 
						|
 | 
						|
            if not excgroup:
 | 
						|
                break
 | 
						|
 | 
						|
        if new_exceptions:
 | 
						|
            if len(new_exceptions) == 1:
 | 
						|
                return new_exceptions[0]
 | 
						|
 | 
						|
            return BaseExceptionGroup("", new_exceptions)
 | 
						|
        elif (
 | 
						|
            excgroup and len(excgroup.exceptions) == 1 and excgroup.exceptions[0] is exc
 | 
						|
        ):
 | 
						|
            return exc
 | 
						|
        else:
 | 
						|
            return excgroup
 | 
						|
 | 
						|
 | 
						|
def catch(
 | 
						|
    __handlers: Mapping[type[BaseException] | Iterable[type[BaseException]], _Handler],
 | 
						|
) -> AbstractContextManager[None]:
 | 
						|
    if not isinstance(__handlers, Mapping):
 | 
						|
        raise TypeError("the argument must be a mapping")
 | 
						|
 | 
						|
    handler_map: dict[
 | 
						|
        tuple[type[BaseException], ...], Callable[[BaseExceptionGroup]]
 | 
						|
    ] = {}
 | 
						|
    for type_or_iterable, handler in __handlers.items():
 | 
						|
        iterable: tuple[type[BaseException]]
 | 
						|
        if isinstance(type_or_iterable, type) and issubclass(
 | 
						|
            type_or_iterable, BaseException
 | 
						|
        ):
 | 
						|
            iterable = (type_or_iterable,)
 | 
						|
        elif isinstance(type_or_iterable, Iterable):
 | 
						|
            iterable = tuple(type_or_iterable)
 | 
						|
        else:
 | 
						|
            raise TypeError(
 | 
						|
                "each key must be either an exception classes or an iterable thereof"
 | 
						|
            )
 | 
						|
 | 
						|
        if not callable(handler):
 | 
						|
            raise TypeError("handlers must be callable")
 | 
						|
 | 
						|
        for exc_type in iterable:
 | 
						|
            if not isinstance(exc_type, type) or not issubclass(
 | 
						|
                exc_type, BaseException
 | 
						|
            ):
 | 
						|
                raise TypeError(
 | 
						|
                    "each key must be either an exception classes or an iterable "
 | 
						|
                    "thereof"
 | 
						|
                )
 | 
						|
 | 
						|
            if issubclass(exc_type, BaseExceptionGroup):
 | 
						|
                raise TypeError(
 | 
						|
                    "catching ExceptionGroup with catch() is not allowed. "
 | 
						|
                    "Use except instead."
 | 
						|
                )
 | 
						|
 | 
						|
        handler_map[iterable] = handler
 | 
						|
 | 
						|
    return _Catcher(handler_map)
 |