import gc
import math
import random
import traceback
from typing import Any

import torch
import numpy as np
import collections
from itertools import repeat

from diffusers import (
    DDIMScheduler,
    PNDMScheduler,
    LMSDiscreteScheduler,
    EulerDiscreteScheduler,
    EulerAncestralDiscreteScheduler,
    DPMSolverMultistepScheduler,
    UniPCMultistepScheduler,
    LCMScheduler,
    DPMSolverSinglestepScheduler,
    KDPM2DiscreteScheduler,
    KDPM2AncestralDiscreteScheduler,
    HeunDiscreteScheduler,
)
from loguru import logger

from iopaint.schema import SDSampler
from torch import conv2d, conv_transpose2d


def make_beta_schedule(
    device, schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
):
    if schedule == "linear":
        betas = (
            torch.linspace(
                linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
            )
            ** 2
        )

    elif schedule == "cosine":
        timesteps = (
            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        ).to(device)
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2).to(device)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        betas = np.clip(betas, a_min=0, a_max=0.999)

    elif schedule == "sqrt_linear":
        betas = torch.linspace(
            linear_start, linear_end, n_timestep, dtype=torch.float64
        )
    elif schedule == "sqrt":
        betas = (
            torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
            ** 0.5
        )
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()


def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
    # select alphas for computing the variance schedule
    alphas = alphacums[ddim_timesteps]
    alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())

    # according to the formula provided in https://arxiv.org/abs/2010.02502
    sigmas = eta * np.sqrt(
        (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
    )
    if verbose:
        print(
            f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}"
        )
        print(
            f"For the chosen value of eta, which is {eta}, "
            f"this results in the following sigma_t schedule for ddim sampler {sigmas}"
        )
    return sigmas, alphas, alphas_prev


def make_ddim_timesteps(
    ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True
):
    if ddim_discr_method == "uniform":
        c = num_ddpm_timesteps // num_ddim_timesteps
        ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
    elif ddim_discr_method == "quad":
        ddim_timesteps = (
            (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2
        ).astype(int)
    else:
        raise NotImplementedError(
            f'There is no ddim discretization method called "{ddim_discr_method}"'
        )

    # assert ddim_timesteps.shape[0] == num_ddim_timesteps
    # add one to get the final alpha values right (the ones from first scale to data during sampling)
    steps_out = ddim_timesteps + 1
    if verbose:
        print(f"Selected timesteps for ddim sampler: {steps_out}")
    return steps_out


def noise_like(shape, device, repeat=False):
    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
        shape[0], *((1,) * (len(shape) - 1))
    )
    noise = lambda: torch.randn(shape, device=device)
    return repeat_noise() if repeat else noise()


def timestep_embedding(device, timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(start=0, end=half, dtype=torch.float32)
        / half
    ).to(device=device)

    args = timesteps[:, None].float() * freqs[None]

    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding


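# Minimal usage sketch for timestep_embedding (illustrative; the batch size,
# timestep values, and embedding dimension below are arbitrary example values).
def _example_timestep_embedding():
    t = torch.tensor([0.0, 10.0, 999.0])  # one (possibly fractional) timestep per batch element
    emb = timestep_embedding(torch.device("cpu"), t, dim=128)
    # Half of the channels hold cos terms, the other half sin terms.
    assert emb.shape == (3, 128)
    return emb

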
###### MAT and FcF #######


def normalize_2nd_moment(x, dim=1):
    return (
        x * (x.square().mean(dim=dim, keepdim=True) + torch.finfo(x.dtype).eps).rsqrt()
    )


class EasyDict(dict):
    """Convenience class that behaves like a dict but allows access with the attribute syntax."""

    def __getattr__(self, name: str) -> Any:
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

    def __setattr__(self, name: str, value: Any) -> None:
        self[name] = value

    def __delattr__(self, name: str) -> None:
        del self[name]


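# Minimal usage sketch for EasyDict (illustrative): keys can be read, written,
# and deleted either as dict items or as attributes.
def _example_easydict():
    d = EasyDict(lr=0.1)
    d.momentum = 0.9  # attribute write is an item write
    assert d["momentum"] == 0.9 and d.lr == 0.1
    del d.momentum  # attribute delete removes the item
    return d

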
def _bias_act_ref(x, b=None, dim=1, act="linear", alpha=None, gain=None, clamp=None):
    """Slow reference implementation of `bias_act()` using standard PyTorch ops."""
    assert isinstance(x, torch.Tensor)
    assert clamp is None or clamp >= 0
    spec = activation_funcs[act]
    alpha = float(alpha if alpha is not None else spec.def_alpha)
    gain = float(gain if gain is not None else spec.def_gain)
    clamp = float(clamp if clamp is not None else -1)

    # Add bias.
    if b is not None:
        assert isinstance(b, torch.Tensor) and b.ndim == 1
        assert 0 <= dim < x.ndim
        assert b.shape[0] == x.shape[dim]
        x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])

    # Evaluate activation function.
    alpha = float(alpha)
    x = spec.func(x, alpha=alpha)

    # Scale by gain.
    gain = float(gain)
    if gain != 1:
        x = x * gain

    # Clamp.
    if clamp >= 0:
        x = x.clamp(-clamp, clamp)  # pylint: disable=invalid-unary-operand-type
    return x


def bias_act(
    x, b=None, dim=1, act="linear", alpha=None, gain=None, clamp=None, impl="ref"
):
    r"""Fused bias and activation function.

    Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
    and scales the result by `gain`. Each of the steps is optional. In most cases,
    the fused op is considerably more efficient than performing the same calculation
    using standard PyTorch ops. It supports first and second order gradients,
    but not third order gradients.

    Args:
        x:      Input activation tensor. Can be of any shape.
        b:      Bias vector, or `None` to disable. Must be a 1D tensor of the same type
                as `x`. The shape must be known, and it must match the dimension of `x`
                corresponding to `dim`.
        dim:    The dimension in `x` corresponding to the elements of `b`.
                The value of `dim` is ignored if `b` is not specified.
        act:    Name of the activation function to evaluate, or `"linear"` to disable.
                Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
                See `activation_funcs` for a full list. `None` is not allowed.
        alpha:  Shape parameter for the activation function, or `None` to use the default.
        gain:   Scaling factor for the output tensor, or `None` to use default.
                See `activation_funcs` for the default scaling of each activation function.
                If unsure, consider specifying 1.
        clamp:  Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
                the clamping (default).
        impl:   Name of the implementation to use. Can be `"ref"` (default) or `"cuda"`;
                this build always runs the reference implementation.

    Returns:
        Tensor of the same shape and datatype as `x`.
    """
    assert isinstance(x, torch.Tensor)
    assert impl in ["ref", "cuda"]
    return _bias_act_ref(
        x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp
    )


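# Minimal usage sketch for bias_act (illustrative; the tensor sizes are arbitrary).
# It adds the per-channel bias along `dim`, applies the activation looked up in
# `activation_funcs`, scales by the activation's default gain, and clamps.
def _example_bias_act():
    x = torch.randn(4, 8, 16, 16)  # NCHW activations
    b = torch.zeros(8)  # one bias per channel (dim=1)
    y = bias_act(x, b, dim=1, act="lrelu", clamp=256)
    assert y.shape == x.shape
    return y

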
def _get_filter_size(f):
    if f is None:
        return 1, 1

    assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
    fw = f.shape[-1]
    fh = f.shape[0]

    fw = int(fw)
    fh = int(fh)
    assert fw >= 1 and fh >= 1
    return fw, fh


def _get_weight_shape(w):
    shape = [int(sz) for sz in w.shape]
    return shape


def _parse_scaling(scaling):
    if isinstance(scaling, int):
        scaling = [scaling, scaling]
    assert isinstance(scaling, (list, tuple))
    assert all(isinstance(x, int) for x in scaling)
    sx, sy = scaling
    assert sx >= 1 and sy >= 1
    return sx, sy


def _parse_padding(padding):
    if isinstance(padding, int):
        padding = [padding, padding]
    assert isinstance(padding, (list, tuple))
    assert all(isinstance(x, int) for x in padding)
    if len(padding) == 2:
        padx, pady = padding
        padding = [padx, padx, pady, pady]
    padx0, padx1, pady0, pady1 = padding
    return padx0, padx1, pady0, pady1


def setup_filter(
    f,
    device=torch.device("cpu"),
    normalize=True,
    flip_filter=False,
    gain=1,
    separable=None,
):
    r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.

    Args:
        f:           Torch tensor, numpy array, or python list of the shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable),
                     `[]` (impulse), or
                     `None` (identity).
        device:      Result device (default: cpu).
        normalize:   Normalize the filter so that it retains the magnitude
                     for constant input signal (DC)? (default: True).
        flip_filter: Flip the filter? (default: False).
        gain:        Overall scaling factor for signal magnitude (default: 1).
        separable:   Return a separable filter? (default: select automatically).

    Returns:
        Float32 tensor of the shape
        `[filter_height, filter_width]` (non-separable) or
        `[filter_taps]` (separable).
    """
    # Validate.
    if f is None:
        f = 1
    f = torch.as_tensor(f, dtype=torch.float32)
    assert f.ndim in [0, 1, 2]
    assert f.numel() > 0
    if f.ndim == 0:
        f = f[np.newaxis]

    # Separable?
    if separable is None:
        separable = f.ndim == 1 and f.numel() >= 8
    if f.ndim == 1 and not separable:
        f = f.ger(f)
    assert f.ndim == (1 if separable else 2)

    # Apply normalize, flip, gain, and device.
    if normalize:
        f /= f.sum()
    if flip_filter:
        f = f.flip(list(range(f.ndim)))
    f = f * (gain ** (f.ndim / 2))
    f = f.to(device=device)
    return f


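# Minimal usage sketch for setup_filter (illustrative): a 4-tap [1, 3, 3, 1] list
# becomes a normalized, non-separable 4x4 float32 kernel with the default settings.
def _example_setup_filter():
    f = setup_filter([1, 3, 3, 1])
    assert f.shape == (4, 4) and f.dtype == torch.float32
    assert torch.isclose(f.sum(), torch.tensor(1.0))  # normalize=True keeps DC gain at 1
    return f

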
def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))

    return parse


to_2tuple = _ntuple(2)

activation_funcs = {
    "linear": EasyDict(
        func=lambda x, **_: x,
        def_alpha=0,
        def_gain=1,
        cuda_idx=1,
        ref="",
        has_2nd_grad=False,
    ),
    "relu": EasyDict(
        func=lambda x, **_: torch.nn.functional.relu(x),
        def_alpha=0,
        def_gain=np.sqrt(2),
        cuda_idx=2,
        ref="y",
        has_2nd_grad=False,
    ),
    "lrelu": EasyDict(
        func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha),
        def_alpha=0.2,
        def_gain=np.sqrt(2),
        cuda_idx=3,
        ref="y",
        has_2nd_grad=False,
    ),
    "tanh": EasyDict(
        func=lambda x, **_: torch.tanh(x),
        def_alpha=0,
        def_gain=1,
        cuda_idx=4,
        ref="y",
        has_2nd_grad=True,
    ),
    "sigmoid": EasyDict(
        func=lambda x, **_: torch.sigmoid(x),
        def_alpha=0,
        def_gain=1,
        cuda_idx=5,
        ref="y",
        has_2nd_grad=True,
    ),
    "elu": EasyDict(
        func=lambda x, **_: torch.nn.functional.elu(x),
        def_alpha=0,
        def_gain=1,
        cuda_idx=6,
        ref="y",
        has_2nd_grad=True,
    ),
    "selu": EasyDict(
        func=lambda x, **_: torch.nn.functional.selu(x),
        def_alpha=0,
        def_gain=1,
        cuda_idx=7,
        ref="y",
        has_2nd_grad=True,
    ),
    "softplus": EasyDict(
        func=lambda x, **_: torch.nn.functional.softplus(x),
        def_alpha=0,
        def_gain=1,
        cuda_idx=8,
        ref="y",
        has_2nd_grad=True,
    ),
    "swish": EasyDict(
        func=lambda x, **_: torch.sigmoid(x) * x,
        def_alpha=0,
        def_gain=np.sqrt(2),
        cuda_idx=9,
        ref="x",
        has_2nd_grad=True,
    ),
}


def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl="cuda"):
    r"""Pad, upsample, filter, and downsample a batch of 2D images.

    Performs the following sequence of operations for each channel:

    1. Upsample the image by inserting N-1 zeros after each pixel (`up`).

    2. Pad the image with the specified number of zeros on each side (`padding`).
       Negative padding corresponds to cropping the image.

    3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
       so that the footprint of all output pixels lies within the input image.

    4. Downsample the image by keeping every Nth pixel (`down`).

    This sequence of operations bears close resemblance to scipy.signal.upfirdn().
    The fused op is considerably more efficient than performing the same calculation
    using standard PyTorch ops. It supports gradients of arbitrary order.

    Args:
        x:           Float32/float64/float16 input tensor of the shape
                     `[batch_size, num_channels, in_height, in_width]`.
        f:           Float32 FIR filter of the shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable), or
                     `None` (identity).
        up:          Integer upsampling factor. Must be a single int in this build
                     (default: 1).
        down:        Integer downsampling factor. Must be a single int in this build
                     (default: 1).
        padding:     Padding with respect to the upsampled image. Must be a list/tuple
                     `[x_before, x_after, y_before, y_after]` in this build
                     (default: 0).
        flip_filter: False = convolution, True = correlation (default: False).
        gain:        Overall scaling factor for signal magnitude (default: 1).
        impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`);
                     this build always runs the reference implementation.

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    # assert isinstance(x, torch.Tensor)
    # assert impl in ['ref', 'cuda']
    return _upfirdn2d_ref(
        x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain
    )


def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
    """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops."""
    # Validate arguments.
    assert isinstance(x, torch.Tensor) and x.ndim == 4
    if f is None:
        f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
    assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
    assert not f.requires_grad
    batch_size, num_channels, in_height, in_width = x.shape
    # upx, upy = _parse_scaling(up)
    # downx, downy = _parse_scaling(down)

    upx, upy = up, up
    downx, downy = down, down

    # padx0, padx1, pady0, pady1 = _parse_padding(padding)
    padx0, padx1, pady0, pady1 = padding[0], padding[1], padding[2], padding[3]

    # Upsample by inserting zeros.
    x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
    x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
    x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])

    # Pad or crop.
    x = torch.nn.functional.pad(
        x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]
    )
    x = x[
        :,
        :,
        max(-pady0, 0) : x.shape[2] - max(-pady1, 0),
        max(-padx0, 0) : x.shape[3] - max(-padx1, 0),
    ]

    # Setup filter.
    f = f * (gain ** (f.ndim / 2))
    f = f.to(x.dtype)
    if not flip_filter:
        f = f.flip(list(range(f.ndim)))

    # Convolve with the filter.
    f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
    if f.ndim == 4:
        x = conv2d(input=x, weight=f, groups=num_channels)
    else:
        x = conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
        x = conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)

    # Downsample by throwing away pixels.
    x = x[:, :, ::downy, ::downx]
    return x


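# Minimal usage sketch for upfirdn2d / _upfirdn2d_ref (illustrative). In this build
# the reference path expects an integer up/down factor and a 4-element padding list
# [x_before, x_after, y_before, y_after]; the sizes below are arbitrary examples.
def _example_upfirdn2d():
    x = torch.randn(1, 3, 8, 8)
    f = setup_filter([1, 3, 3, 1])  # 4x4 low-pass filter
    # Upsample 2x: insert zeros, pad so the filter footprint stays inside, then filter.
    y = upfirdn2d(x, f, up=2, padding=[2, 1, 2, 1], gain=4)
    assert y.shape == (1, 3, 16, 16)
    return y

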
def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl="cuda"):
    r"""Downsample a batch of 2D images using the given 2D FIR filter.

    By default, the result is padded so that its shape is a fraction of the input.
    User-specified padding is applied on top of that, with negative values
    indicating cropping. Pixels outside the image are assumed to be zero.

    Args:
        x:           Float32/float64/float16 input tensor of the shape
                     `[batch_size, num_channels, in_height, in_width]`.
        f:           Float32 FIR filter of the shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable), or
                     `None` (identity).
        down:        Integer downsampling factor. Must be a single int in this build
                     (default: 2).
        padding:     Padding with respect to the input. Must be a single integer
                     in this build (default: 0).
        flip_filter: False = convolution, True = correlation (default: False).
        gain:        Overall scaling factor for signal magnitude (default: 1).
        impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    downx, downy = _parse_scaling(down)
    # padx0, padx1, pady0, pady1 = _parse_padding(padding)
    padx0, padx1, pady0, pady1 = padding, padding, padding, padding

    fw, fh = _get_filter_size(f)
    p = [
        padx0 + (fw - downx + 1) // 2,
        padx1 + (fw - downx) // 2,
        pady0 + (fh - downy + 1) // 2,
        pady1 + (fh - downy) // 2,
    ]
    return upfirdn2d(
        x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl
    )


def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl="cuda"):
    r"""Upsample a batch of 2D images using the given 2D FIR filter.

    By default, the result is padded so that its shape is a multiple of the input.
    User-specified padding is applied on top of that, with negative values
    indicating cropping. Pixels outside the image are assumed to be zero.

    Args:
        x:           Float32/float64/float16 input tensor of the shape
                     `[batch_size, num_channels, in_height, in_width]`.
        f:           Float32 FIR filter of the shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable), or
                     `None` (identity).
        up:          Integer upsampling factor. Must be a single int in this build
                     (default: 2).
        padding:     Padding with respect to the output. Can be a single number or a
                     list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
                     (default: 0).
        flip_filter: False = convolution, True = correlation (default: False).
        gain:        Overall scaling factor for signal magnitude (default: 1).
        impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    upx, upy = _parse_scaling(up)
    # upx, upy = up, up
    padx0, padx1, pady0, pady1 = _parse_padding(padding)
    # padx0, padx1, pady0, pady1 = padding, padding, padding, padding
    fw, fh = _get_filter_size(f)
    p = [
        padx0 + (fw + upx - 1) // 2,
        padx1 + (fw - upx) // 2,
        pady0 + (fh + upy - 1) // 2,
        pady1 + (fh - upy) // 2,
    ]
    return upfirdn2d(
        x,
        f,
        up=up,
        padding=p,
        flip_filter=flip_filter,
        gain=gain * upx * upy,
        impl=impl,
    )


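# Minimal round-trip sketch for upsample2d / downsample2d (illustrative): 2x FIR
# upsampling followed by 2x FIR downsampling restores the original resolution.
def _example_resample2d():
    x = torch.randn(2, 4, 16, 16)
    f = setup_filter([1, 3, 3, 1])
    hi = upsample2d(x, f)  # [2, 4, 32, 32]
    lo = downsample2d(hi, f)  # [2, 4, 16, 16]
    assert hi.shape == (2, 4, 32, 32) and lo.shape == (2, 4, 16, 16)
    return lo

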
class MinibatchStdLayer(torch.nn.Module):
    def __init__(self, group_size, num_channels=1):
        super().__init__()
        self.group_size = group_size
        self.num_channels = num_channels

    def forward(self, x):
        N, C, H, W = x.shape
        G = (
            torch.min(torch.as_tensor(self.group_size), torch.as_tensor(N))
            if self.group_size is not None
            else N
        )
        F = self.num_channels
        c = C // F

        y = x.reshape(
            G, -1, F, c, H, W
        )  # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c.
        y = y - y.mean(dim=0)  # [GnFcHW] Subtract mean over group.
        y = y.square().mean(dim=0)  # [nFcHW]  Calc variance over group.
        y = (y + 1e-8).sqrt()  # [nFcHW]  Calc stddev over group.
        y = y.mean(dim=[2, 3, 4])  # [nF]     Take average over channels and pixels.
        y = y.reshape(-1, F, 1, 1)  # [nF11]   Add missing dimensions.
        y = y.repeat(G, 1, H, W)  # [NFHW]   Replicate over group and pixels.
        x = torch.cat([x, y], dim=1)  # [NCHW]   Append to input as new channels.
        return x


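# Minimal usage sketch for MinibatchStdLayer (illustrative): it appends
# `num_channels` statistics channels, so C grows from 8 to 9 here.
def _example_minibatch_std():
    layer = MinibatchStdLayer(group_size=4, num_channels=1)
    x = torch.randn(8, 8, 16, 16)  # N=8 splits into two groups of 4
    y = layer(x)
    assert y.shape == (8, 9, 16, 16)
    return y

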
class FullyConnectedLayer(torch.nn.Module):
    def __init__(
        self,
        in_features,  # Number of input features.
        out_features,  # Number of output features.
        bias=True,  # Apply additive bias before the activation function?
        activation="linear",  # Activation function: 'relu', 'lrelu', etc.
        lr_multiplier=1,  # Learning rate multiplier.
        bias_init=0,  # Initial value for the additive bias.
    ):
        super().__init__()
        self.weight = torch.nn.Parameter(
            torch.randn([out_features, in_features]) / lr_multiplier
        )
        self.bias = (
            torch.nn.Parameter(torch.full([out_features], np.float32(bias_init)))
            if bias
            else None
        )
        self.activation = activation

        self.weight_gain = lr_multiplier / np.sqrt(in_features)
        self.bias_gain = lr_multiplier

    def forward(self, x):
        w = self.weight * self.weight_gain
        b = self.bias
        if b is not None and self.bias_gain != 1:
            b = b * self.bias_gain

        if self.activation == "linear" and b is not None:
            # out = torch.addmm(b.unsqueeze(0), x, w.t())
            x = x.matmul(w.t())
            out = x + b.reshape([-1 if i == x.ndim - 1 else 1 for i in range(x.ndim)])
        else:
            x = x.matmul(w.t())
            out = bias_act(x, b, act=self.activation, dim=x.ndim - 1)
        return out


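# Minimal usage sketch for FullyConnectedLayer (illustrative): an equalized-lr
# linear layer with a fused lrelu activation; the sizes are arbitrary examples.
def _example_fully_connected():
    layer = FullyConnectedLayer(in_features=512, out_features=256, activation="lrelu")
    z = torch.randn(4, 512)
    w = layer(z)
    assert w.shape == (4, 256)
    return w

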
def _conv2d_wrapper(
    x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True
):
    """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations."""
    out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)

    # Flip weight if requested.
    if (
        not flip_weight
    ):  # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
        w = w.flip([2, 3])

    # Workaround performance pitfall in cuDNN 8.0.5, triggered when using
    # 1x1 kernel + memory_format=channels_last + less than 64 channels.
    if (
        kw == 1
        and kh == 1
        and stride == 1
        and padding in [0, [0, 0], (0, 0)]
        and not transpose
    ):
        if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64:
            if out_channels <= 4 and groups == 1:
                in_shape = x.shape
                x = w.squeeze(3).squeeze(2) @ x.reshape(
                    [in_shape[0], in_channels_per_group, -1]
                )
                x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]])
            else:
                x = x.to(memory_format=torch.contiguous_format)
                w = w.to(memory_format=torch.contiguous_format)
                x = conv2d(x, w, groups=groups)
            return x.to(memory_format=torch.channels_last)

    # Otherwise => execute using conv2d_gradfix.
    op = conv_transpose2d if transpose else conv2d
    return op(x, w, stride=stride, padding=padding, groups=groups)


def conv2d_resample(
    x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False
):
    r"""2D convolution with optional up/downsampling.

    Padding is performed only once at the beginning, not between the operations.

    Args:
        x:              Input tensor of shape
                        `[batch_size, in_channels, in_height, in_width]`.
        w:              Weight tensor of shape
                        `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
        f:              Low-pass filter for up/downsampling. Must be prepared beforehand by
                        calling setup_filter(). None = identity (default).
        up:             Integer upsampling factor (default: 1).
        down:           Integer downsampling factor (default: 1).
        padding:        Padding with respect to the upsampled image. Must be a single
                        integer in this build (default: 0).
        groups:         Split input channels into N groups (default: 1).
        flip_weight:    False = convolution, True = correlation (default: True).
        flip_filter:    False = convolution, True = correlation (default: False).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    # Validate arguments.
    assert isinstance(x, torch.Tensor) and (x.ndim == 4)
    assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
    assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2])
    assert isinstance(up, int) and (up >= 1)
    assert isinstance(down, int) and (down >= 1)
    # assert isinstance(groups, int) and (groups >= 1), f"!!!!!! groups: {groups} isinstance(groups, int)  {isinstance(groups, int)} {type(groups)}"
    out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
    fw, fh = _get_filter_size(f)
    # px0, px1, py0, py1 = _parse_padding(padding)
    px0, px1, py0, py1 = padding, padding, padding, padding

    # Adjust padding to account for up/downsampling.
    if up > 1:
        px0 += (fw + up - 1) // 2
        px1 += (fw - up) // 2
        py0 += (fh + up - 1) // 2
        py1 += (fh - up) // 2
    if down > 1:
        px0 += (fw - down + 1) // 2
        px1 += (fw - down) // 2
        py0 += (fh - down + 1) // 2
        py1 += (fh - down) // 2

    # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
    if kw == 1 and kh == 1 and (down > 1 and up == 1):
        x = upfirdn2d(
            x=x, f=f, down=down, padding=[px0, px1, py0, py1], flip_filter=flip_filter
        )
        x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
        return x

    # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
    if kw == 1 and kh == 1 and (up > 1 and down == 1):
        x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
        x = upfirdn2d(
            x=x,
            f=f,
            up=up,
            padding=[px0, px1, py0, py1],
            gain=up**2,
            flip_filter=flip_filter,
        )
        return x

    # Fast path: downsampling only => use strided convolution.
    if down > 1 and up == 1:
        x = upfirdn2d(x=x, f=f, padding=[px0, px1, py0, py1], flip_filter=flip_filter)
        x = _conv2d_wrapper(
            x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight
        )
        return x

    # Fast path: upsampling with optional downsampling => use transpose strided convolution.
    if up > 1:
        if groups == 1:
            w = w.transpose(0, 1)
        else:
            w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
            w = w.transpose(1, 2)
            w = w.reshape(
                groups * in_channels_per_group, out_channels // groups, kh, kw
            )
        px0 -= kw - 1
        px1 -= kw - up
        py0 -= kh - 1
        py1 -= kh - up
        pxt = max(min(-px0, -px1), 0)
        pyt = max(min(-py0, -py1), 0)
        x = _conv2d_wrapper(
            x=x,
            w=w,
            stride=up,
            padding=[pyt, pxt],
            groups=groups,
            transpose=True,
            flip_weight=(not flip_weight),
        )
        x = upfirdn2d(
            x=x,
            f=f,
            padding=[px0 + pxt, px1 + pxt, py0 + pyt, py1 + pyt],
            gain=up**2,
            flip_filter=flip_filter,
        )
        if down > 1:
            x = upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
        return x

    # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
    if up == 1 and down == 1:
        if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
            return _conv2d_wrapper(
                x=x, w=w, padding=[py0, px0], groups=groups, flip_weight=flip_weight
            )

    # Fallback: Generic reference implementation.
    x = upfirdn2d(
        x=x,
        f=(f if up > 1 else None),
        up=up,
        padding=[px0, px1, py0, py1],
        gain=up**2,
        flip_filter=flip_filter,
    )
    x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
    if down > 1:
        x = upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
    return x


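# Minimal usage sketch for conv2d_resample (illustrative): a 3x3 convolution fused
# with 2x FIR upsampling, in the style used by the StyleGAN-like layers below.
# The shapes and padding value are arbitrary example choices.
def _example_conv2d_resample():
    x = torch.randn(1, 8, 16, 16)
    w = torch.randn(4, 8, 3, 3)  # [out_channels, in_channels, kh, kw]
    f = setup_filter([1, 3, 3, 1])
    y = conv2d_resample(x=x, w=w, f=f, up=2, padding=1)
    assert y.shape == (1, 4, 32, 32)
    return y

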
class Conv2dLayer(torch.nn.Module):
    def __init__(
        self,
        in_channels,  # Number of input channels.
        out_channels,  # Number of output channels.
        kernel_size,  # Width and height of the convolution kernel.
        bias=True,  # Apply additive bias before the activation function?
        activation="linear",  # Activation function: 'relu', 'lrelu', etc.
        up=1,  # Integer upsampling factor.
        down=1,  # Integer downsampling factor.
        resample_filter=[1, 3, 3, 1],  # Low-pass filter to apply when resampling activations.
        conv_clamp=None,  # Clamp the output to +-X, None = disable clamping.
        channels_last=False,  # Expect the input to have memory_format=channels_last?
        trainable=True,  # Update the weights of this layer during training?
    ):
        super().__init__()
        self.activation = activation
        self.up = up
        self.down = down
        self.register_buffer("resample_filter", setup_filter(resample_filter))
        self.conv_clamp = conv_clamp
        self.padding = kernel_size // 2
        self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size**2))
        self.act_gain = activation_funcs[activation].def_gain

        memory_format = (
            torch.channels_last if channels_last else torch.contiguous_format
        )
        weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(
            memory_format=memory_format
        )
        bias = torch.zeros([out_channels]) if bias else None
        if trainable:
            self.weight = torch.nn.Parameter(weight)
            self.bias = torch.nn.Parameter(bias) if bias is not None else None
        else:
            self.register_buffer("weight", weight)
            if bias is not None:
                self.register_buffer("bias", bias)
            else:
                self.bias = None

    def forward(self, x, gain=1):
        w = self.weight * self.weight_gain
        x = conv2d_resample(
            x=x,
            w=w,
            f=self.resample_filter,
            up=self.up,
            down=self.down,
            padding=self.padding,
        )

        act_gain = self.act_gain * gain
        act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
        out = bias_act(
            x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp
        )
        return out


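# Minimal usage sketch for Conv2dLayer (illustrative): a 3x3 layer that downsamples
# its input by 2x and applies a fused lrelu; channel counts are arbitrary examples.
def _example_conv2d_layer():
    layer = Conv2dLayer(in_channels=8, out_channels=16, kernel_size=3, activation="lrelu", down=2)
    x = torch.randn(1, 8, 32, 32)
    y = layer(x)
    assert y.shape == (1, 16, 16, 16)
    return y

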
def torch_gc():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    gc.collect()


def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def get_scheduler(sd_sampler, scheduler_config):
    # https://github.com/huggingface/diffusers/issues/4167
    keys_to_pop = ["use_karras_sigmas", "algorithm_type"]
    scheduler_config = dict(scheduler_config)
    for it in keys_to_pop:
        scheduler_config.pop(it, None)

    # fmt: off
    samplers = {
        SDSampler.dpm_plus_plus_2m: [DPMSolverMultistepScheduler],
        SDSampler.dpm_plus_plus_2m_karras: [DPMSolverMultistepScheduler, dict(use_karras_sigmas=True)],
        SDSampler.dpm_plus_plus_2m_sde: [DPMSolverMultistepScheduler, dict(algorithm_type="sde-dpmsolver++")],
        SDSampler.dpm_plus_plus_2m_sde_karras: [DPMSolverMultistepScheduler, dict(algorithm_type="sde-dpmsolver++", use_karras_sigmas=True)],
        SDSampler.dpm_plus_plus_sde: [DPMSolverSinglestepScheduler],
        SDSampler.dpm_plus_plus_sde_karras: [DPMSolverSinglestepScheduler, dict(use_karras_sigmas=True)],
        SDSampler.dpm2: [KDPM2DiscreteScheduler],
        SDSampler.dpm2_karras: [KDPM2DiscreteScheduler, dict(use_karras_sigmas=True)],
        SDSampler.dpm2_a: [KDPM2AncestralDiscreteScheduler],
        SDSampler.dpm2_a_karras: [KDPM2AncestralDiscreteScheduler, dict(use_karras_sigmas=True)],
        SDSampler.euler: [EulerDiscreteScheduler],
        SDSampler.euler_a: [EulerAncestralDiscreteScheduler],
        SDSampler.heun: [HeunDiscreteScheduler],
        SDSampler.lms: [LMSDiscreteScheduler],
        SDSampler.lms_karras: [LMSDiscreteScheduler, dict(use_karras_sigmas=True)],
        SDSampler.ddim: [DDIMScheduler],
        SDSampler.pndm: [PNDMScheduler],
        SDSampler.uni_pc: [UniPCMultistepScheduler],
        SDSampler.lcm: [LCMScheduler],
    }
    # fmt: on
    if sd_sampler in samplers:
        if len(samplers[sd_sampler]) == 2:
            scheduler_cls, kwargs = samplers[sd_sampler]
        else:
            scheduler_cls, kwargs = samplers[sd_sampler][0], {}
        return scheduler_cls.from_config(scheduler_config, **kwargs)
    else:
        raise ValueError(sd_sampler)


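# Minimal usage sketch for get_scheduler (illustrative): swap the scheduler of an
# already-loaded diffusers pipeline. `pipe` here is a hypothetical pipeline object,
# not something constructed in this module.
def _example_get_scheduler(pipe):
    pipe.scheduler = get_scheduler(
        SDSampler.dpm_plus_plus_2m_karras, pipe.scheduler.config
    )
    return pipe

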
def is_local_files_only(**kwargs) -> bool:
    from huggingface_hub.constants import HF_HUB_OFFLINE

    return HF_HUB_OFFLINE or kwargs.get("local_files_only", False)


def handle_from_pretrained_exceptions(func, **kwargs):
    try:
        return func(**kwargs)
    except ValueError as e:
        if "You are trying to load the model files of the `variant=fp16`" in str(e):
            logger.info("variant=fp16 not found, try revision=fp16")
            try:
                return func(**{**kwargs, "variant": None, "revision": "fp16"})
            except Exception:
                logger.info("revision=fp16 not found, try revision=main")
                return func(**{**kwargs, "variant": None, "revision": "main"})
        raise e
    except OSError as e:
        previous_traceback = traceback.format_exc()
        if "RevisionNotFoundError: 404 Client Error." in previous_traceback:
            logger.info("revision=fp16 not found, try revision=main")
            return func(**{**kwargs, "variant": None, "revision": "main"})
        elif "Max retries exceeded" in previous_traceback:
            logger.exception(
                "Fetching model from HuggingFace failed. "
                "If this is your first time downloading the model, you may need to set up a proxy in the terminal. "
                "If the model has already been downloaded, you can add --local-files-only when starting."
            )
            exit(-1)
        raise e


def get_torch_dtype(device, no_half: bool):
    device = str(device)
    use_fp16 = not no_half
    use_gpu = device == "cuda"
    # https://github.com/huggingface/diffusers/issues/4480
    # pipe.enable_attention_slicing and float16 will cause black output on mps
    # if device in ["cuda", "mps"] and use_fp16:
    if device in ["cuda"] and use_fp16:
        return use_gpu, torch.float16
    return use_gpu, torch.float32


def enable_low_mem(pipe, enable: bool):
    if torch.backends.mps.is_available():
        # https://huggingface.co/docs/diffusers/v0.25.0/en/api/pipelines/stable_diffusion/image_variation#diffusers.StableDiffusionImageVariationPipeline.enable_attention_slicing
        # CUDA: Don't enable attention slicing if you're already using `scaled_dot_product_attention` (SDPA) from PyTorch 2.0 or xFormers.
        if enable:
            pipe.enable_attention_slicing("max")
        else:
            # https://huggingface.co/docs/diffusers/optimization/mps
            # Devices with less than 64GB of memory are recommended to use enable_attention_slicing
            pipe.enable_attention_slicing()

    if enable:
        pipe.vae.enable_tiling()
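

# Minimal sketch tying get_torch_dtype and enable_low_mem together (illustrative;
# `pipe` is a hypothetical, already-loaded diffusers pipeline and `device` a string
# such as "cuda" or "cpu").
def _example_low_mem_setup(pipe, device="cpu", no_half=True, low_mem=True):
    use_gpu, torch_dtype = get_torch_dtype(device, no_half)
    # torch_dtype would typically be passed to from_pretrained(...) when loading the pipeline.
    enable_low_mem(pipe, low_mem)
    return use_gpu, torch_dtype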