224 lines
		
	
	
		
			7.8 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			224 lines
		
	
	
		
			7.8 KiB
		
	
	
	
		
			Python
		
	
	
	
import re
 | 
						|
import warnings
 | 
						|
from dataclasses import is_dataclass
 | 
						|
from typing import (
 | 
						|
    TYPE_CHECKING,
 | 
						|
    Any,
 | 
						|
    Dict,
 | 
						|
    MutableMapping,
 | 
						|
    Optional,
 | 
						|
    Set,
 | 
						|
    Type,
 | 
						|
    Union,
 | 
						|
    cast,
 | 
						|
)
 | 
						|
from weakref import WeakKeyDictionary
 | 
						|
 | 
						|
import fastapi
 | 
						|
from fastapi._compat import (
 | 
						|
    PYDANTIC_V2,
 | 
						|
    BaseConfig,
 | 
						|
    ModelField,
 | 
						|
    PydanticSchemaGenerationError,
 | 
						|
    Undefined,
 | 
						|
    UndefinedType,
 | 
						|
    Validator,
 | 
						|
    lenient_issubclass,
 | 
						|
)
 | 
						|
from fastapi.datastructures import DefaultPlaceholder, DefaultType
 | 
						|
from pydantic import BaseModel, create_model
 | 
						|
from pydantic.fields import FieldInfo
 | 
						|
from typing_extensions import Literal
 | 
						|
 | 
						|
if TYPE_CHECKING:  # pragma: nocover
 | 
						|
    from .routing import APIRoute
 | 
						|
 | 
						|
# Cache for `create_cloned_field`: maps an original Pydantic model class to the
# clone created for it. Keys are held weakly (WeakKeyDictionary), so an entry is
# dropped once its original model class is no longer referenced elsewhere.
_CLONED_TYPES_CACHE: MutableMapping[Type[BaseModel], Type[BaseModel]] = (
    WeakKeyDictionary()
)
 | 
						|
 | 
						|
 | 
						|
def is_body_allowed_for_status_code(status_code: Union[int, str, None]) -> bool:
    """
    Tell whether a response with ``status_code`` is allowed to carry a body.

    ``None`` and the OpenAPI wildcard keys (``"default"``, ``"1XX"`` ... ``"5XX"``)
    always allow a body. Any other value is converted with ``int()``; codes below
    200, as well as 204, 205 and 304, do not allow a body.
    """
    if status_code is None:
        return True
    # Ref: https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#patterned-fields-1
    openapi_wildcards = {"default", "1XX", "2XX", "3XX", "4XX", "5XX"}
    if status_code in openapi_wildcards:
        return True
    code = int(status_code)
    if code < 200:
        return False
    return code not in {204, 205, 304}
 | 
						|
 | 
						|
 | 
						|
def get_path_param_names(path: str) -> Set[str]:
    """
    Return the distinct texts found between ``{`` and ``}`` in ``path``.

    The text is taken verbatim, so a ``:converter`` suffix inside the braces is
    kept as part of the name.
    """
    return {match.group(1) for match in re.finditer("{(.*?)}", path)}
 | 
						|
 | 
						|
 | 
						|
def create_response_field(
    name: str,
    type_: Type[Any],
    class_validators: Optional[Dict[str, Validator]] = None,
    default: Optional[Any] = Undefined,
    required: Union[bool, UndefinedType] = Undefined,
    model_config: Type[BaseConfig] = BaseConfig,
    field_info: Optional[FieldInfo] = None,
    alias: Optional[str] = None,
    mode: Literal["validation", "serialization"] = "validation",
) -> ModelField:
    """
    Create a new response field. Raises if type_ is invalid.

    Under Pydantic v2 only ``name``, ``field_info`` and ``mode`` reach
    ``ModelField``. When no ``field_info`` is given, one is built from
    ``type_``, ``default`` and ``alias``. Under Pydantic v1 every argument
    except ``mode`` is forwarded to ``ModelField``.

    Raises ``fastapi.exceptions.FastAPIError`` when ``ModelField`` fails with
    ``RuntimeError`` or ``PydanticSchemaGenerationError``.
    """
    validators = class_validators or {}
    field_kwargs: Dict[str, Any]
    if PYDANTIC_V2:
        info = field_info or FieldInfo(annotation=type_, default=default, alias=alias)
        field_kwargs = {"name": name, "field_info": info, "mode": mode}
    else:
        info = field_info or FieldInfo()
        field_kwargs = {
            "name": name,
            "field_info": info,
            "type_": type_,
            "class_validators": validators,
            "default": default,
            "required": required,
            "model_config": model_config,
            "alias": alias,
        }
    try:
        return ModelField(**field_kwargs)
    except (RuntimeError, PydanticSchemaGenerationError):
        # `from None` hides the low-level Pydantic traceback behind the hint below.
        raise fastapi.exceptions.FastAPIError(
            "Invalid args for response field! Hint: "
            f"check that {type_} is a valid Pydantic field type. "
            "If you are using a return type annotation that is not a valid Pydantic "
            "field (e.g. Union[Response, dict, None]) you can disable generating the "
            "response model from the type annotation with the path operation decorator "
            "parameter response_model=None. Read more: "
            "https://fastapi.tiangolo.com/tutorial/response-model/"
        ) from None
 | 
						|
 | 
						|
 | 
						|
def create_cloned_field(
    field: ModelField,
    *,
    cloned_types: Optional[MutableMapping[Type[BaseModel], Type[BaseModel]]] = None,
) -> ModelField:
    """
    Return a copy of ``field`` in which any Pydantic model type is a fresh subclass.

    Under Pydantic v2 the field is returned unchanged. Under v1, a model type
    (or the ``__pydantic_model__`` of a dataclass) is replaced by a subclass
    created with ``create_model``, and its fields are cloned recursively. The
    remaining field attributes are then copied from the original.

    ``cloned_types`` maps original model classes to their clones. It defaults
    to the module-level ``_CLONED_TYPES_CACHE``.
    """
    if PYDANTIC_V2:
        return field
    # Remember already-cloned types so recursive models terminate and each model
    # is cloned only once.
    cache = _CLONED_TYPES_CACHE if cloned_types is None else cloned_types

    source_type = field.type_
    if is_dataclass(source_type) and hasattr(source_type, "__pydantic_model__"):
        source_type = source_type.__pydantic_model__
    target_type = source_type
    if lenient_issubclass(source_type, BaseModel):
        model = cast(Type[BaseModel], source_type)
        target_type = cache.get(model)
        if target_type is None:
            target_type = create_model(model.__name__, __base__=model)
            # Register the clone before recursing so self-references find it.
            cache[model] = target_type
            for sub in model.__fields__.values():
                target_type.__fields__[sub.name] = create_cloned_field(
                    sub, cloned_types=cache
                )

    clone = create_response_field(name=field.name, type_=target_type)
    for attr_name in (
        "has_alias",
        "alias",
        "class_validators",
        "default",
        "required",
        "model_config",
        "field_info",
        "allow_none",
        "validate_always",
    ):
        setattr(clone, attr_name, getattr(field, attr_name))
    # Sub-fields and the key field are cloned recursively rather than shared.
    if field.sub_fields:  # type: ignore[attr-defined]
        clone.sub_fields = [  # type: ignore[attr-defined]
            create_cloned_field(sub_field, cloned_types=cache)
            for sub_field in field.sub_fields  # type: ignore[attr-defined]
        ]
    if field.key_field:  # type: ignore[attr-defined]
        clone.key_field = create_cloned_field(  # type: ignore[attr-defined]
            field.key_field,  # type: ignore[attr-defined]
            cloned_types=cache,
        )
    for attr_name in (
        "validators",
        "pre_validators",
        "post_validators",
        "parse_json",
        "shape",
    ):
        setattr(clone, attr_name, getattr(field, attr_name))
    # Finalize via Pydantic v1's `populate_validators()` once everything is copied.
    clone.populate_validators()  # type: ignore[attr-defined]
    return clone
 | 
						|
 | 
						|
 | 
						|
def generate_operation_id_for_path(
    *, name: str, path: str, method: str
) -> str:  # pragma: nocover
    """
    Deprecated: build an operation id from ``name``, ``path`` and ``method``.

    Emits a ``DeprecationWarning``. Every non-word character of ``name + path``
    is replaced with ``_``, and ``_<method in lowercase>`` is appended.
    """
    warnings.warn(
        "fastapi.utils.generate_operation_id_for_path() was deprecated, "
        "it is not used internally, and will be removed soon",
        DeprecationWarning,
        stacklevel=2,
    )
    sanitized = re.sub(r"\W", "_", name + path)
    return "_".join((sanitized, method.lower()))
 | 
						|
 | 
						|
 | 
						|
def generate_unique_id(route: "APIRoute") -> str:
    """
    Build an operation id for ``route`` from its name, path format and method.

    Non-word characters of ``route.name + route.path_format`` are replaced with
    ``_``, and ``_<method in lowercase>`` is appended.

    ``route.methods`` is a collection without a stable order (a set). Iteration
    order of a string set depends on per-process hash randomization. The
    smallest method name is therefore used, so routes with several methods get
    the same id on every run. Single-method routes are unaffected.
    """
    operation_id = re.sub(r"\W", "_", f"{route.name}{route.path_format}")
    assert route.methods, "route must define at least one HTTP method"
    method = min(route.methods)
    return f"{operation_id}_{method.lower()}"
 | 
						|
 | 
						|
 | 
						|
def deep_dict_update(main_dict: Dict[Any, Any], update_dict: Dict[Any, Any]) -> None:
    """
    Merge ``update_dict`` into ``main_dict`` in place.

    For each key of ``update_dict``:
    - missing in ``main_dict``: the value is inserted as-is (not copied);
    - both values are dicts: they are merged recursively;
    - both values are lists: ``main_dict`` gets a new list, existing items first;
    - otherwise: the value from ``update_dict`` replaces the existing one.
    """
    for key, new_value in update_dict.items():
        if key not in main_dict:
            main_dict[key] = new_value
            continue
        current = main_dict[key]
        if isinstance(current, dict) and isinstance(new_value, dict):
            deep_dict_update(current, new_value)
        elif isinstance(current, list) and isinstance(new_value, list):
            main_dict[key] = current + new_value
        else:
            main_dict[key] = new_value
 | 
						|
 | 
						|
 | 
						|
def get_value_or_default(
    first_item: Union[DefaultPlaceholder, DefaultType],
    *extra_items: Union[DefaultPlaceholder, DefaultType],
) -> Union[DefaultPlaceholder, DefaultType]:
    """
    Pick the highest-priority explicit value among the arguments.

    Arguments are given in descending priority. The first one that is not a
    `DefaultPlaceholder` is returned. If every argument is a placeholder,
    ``first_item`` is returned.
    """
    candidates = (first_item, *extra_items)
    return next(
        (item for item in candidates if not isinstance(item, DefaultPlaceholder)),
        first_item,
    )
 |