# pytorch_diffusion + derived encoder decoder
import math

import numpy as np
import torch
import torch.nn as nn


def get_timestep_embedding(timesteps, embedding_dim):
    """
    Build sinusoidal timestep embeddings, as used in Denoising Diffusion
    Probabilistic Models. This matches the implementation in tensor2tensor
    (via Fairseq), but differs slightly from the description in Section 3.5
    of "Attention Is All You Need".
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    emb = emb.to(device=timesteps.device)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb


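# Usage sketch (illustrative values, not from the original source): for a
# batch of integer timesteps the embedding has shape (batch, embedding_dim):
#
#   t = torch.tensor([1, 250, 999])
#   emb = get_timestep_embedding(t, 128)  # -> torch.Size([3, 128])
#
# The first half of the channels are sines and the second half cosines over
# log-spaced frequencies.

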
def nonlinearity(x):
    # swish
    return x * torch.sigmoid(x)


def Normalize(in_channels, num_groups=32):
    return torch.nn.GroupNorm(
        num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
    )


class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=1, padding=1
            )

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=2, padding=0
            )

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x


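# Shape sketch for the conv path (illustrative numbers): a (1, C, 8, 8) input
# is padded on the right/bottom to (1, C, 9, 9), and the 3x3 stride-2 conv
# with padding=0 then yields (1, C, 4, 4), i.e. exact halving, mirroring the
# asymmetric "SAME" padding of the TensorFlow code this port derives from.

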
class ResnetBlock(nn.Module):
    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout,
        temb_channels=512,
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=3, stride=1, padding=1
                )
            else:
                self.nin_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, padding=0
                )

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h


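# Usage sketch (hypothetical shapes, not from the original source): the block
# keeps spatial size and maps channels in -> out, adding a projected timestep
# embedding between the two convolutions:
#
#   block = ResnetBlock(in_channels=64, out_channels=128, dropout=0.0)
#   temb = torch.zeros(4, 512)                   # temb_channels defaults to 512
#   y = block(torch.randn(4, 64, 32, 32), temb)  # -> (4, 128, 32, 32)
#
# The Encoder/Decoder below construct it with temb_channels=0 and call it
# with temb=None.

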
class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)  # b,hw,c
        k = k.reshape(b, c, h * w)  # b,c,hw
        w_ = torch.bmm(q, k)  # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_)  # b,c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        return x + h_


class AttnBlock2_0(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention over the h*w spatial positions with the fused
        # PyTorch 2.0 kernel; channels act as the (single-head) feature dim
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w).transpose(1, 2)  # b, hw, c
        k = k.reshape(b, c, h * w).transpose(1, 2)  # b, hw, c
        v = v.reshape(b, c, h * w).transpose(1, 2)  # b, hw, c
        hidden_states = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False
        )  # b, hw, c
        hidden_states = hidden_states.transpose(1, 2).reshape(b, c, h, w)
        hidden_states = hidden_states.to(q.dtype)

        h_ = self.proj_out(hidden_states)

        return x + h_


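# Sanity-check sketch (my assumption, not part of the original file): with
# shared weights the fused implementation should agree with the explicit
# bmm/softmax math of AttnBlock up to numerical tolerance:
#
#   a1, a2 = AttnBlock(64).eval(), AttnBlock2_0(64).eval()
#   a2.load_state_dict(a1.state_dict())
#   x = torch.randn(1, 64, 16, 16)
#   assert torch.allclose(a1(x), a2(x), atol=1e-5)

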
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
    assert attn_type in [
        "vanilla",
        "vanilla-xformers",
        "memory-efficient-cross-attn",
        "linear",
        "none",
    ], f"attn_type {attn_type} unknown"
    assert attn_kwargs is None
    if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
        return AttnBlock2_0(in_channels)
    return AttnBlock(in_channels)


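# Selection sketch: on builds that provide scaled_dot_product_attention
# (torch >= 2.0), make_attn(512) returns AttnBlock2_0; older builds fall back
# to the explicit AttnBlock. The extra names in the assert appear to be kept
# for config compatibility and do not select different implementations here.

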
class Model(nn.Module):
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        use_timestep=True,
        use_linear_attn=False,
        attn_type="vanilla",
    ):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = self.ch * 4
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        self.use_timestep = use_timestep
        if self.use_timestep:
            # timestep embedding
            self.temb = nn.Module()
            self.temb.dense = nn.ModuleList(
                [
                    torch.nn.Linear(self.ch, self.temb_ch),
                    torch.nn.Linear(self.temb_ch, self.temb_ch),
                ]
            )

        # downsampling
        self.conv_in = torch.nn.Conv2d(
            in_channels, self.ch, kernel_size=3, stride=1, padding=1
        )

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            skip_in = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                if i_block == self.num_res_blocks:
                    skip_in = ch * in_ch_mult[i_level]
                block.append(
                    ResnetBlock(
                        in_channels=block_in + skip_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, out_ch, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x, t=None, context=None):
        # assert x.shape[2] == x.shape[3] == self.resolution
        if context is not None:
            # assume aligned context, cat along channel axis
            x = torch.cat((x, context), dim=1)
        if self.use_timestep:
            # timestep embedding
            assert t is not None
            temb = get_timestep_embedding(t, self.ch)
            temb = self.temb.dense[0](temb)
            temb = nonlinearity(temb)
            temb = self.temb.dense[1](temb)
        else:
            temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](
                    torch.cat([h, hs.pop()], dim=1), temb
                )
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h

    def get_last_layer(self):
        return self.conv_out.weight


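# Construction sketch (hypothetical small config, not from the original
# source): a 32x32 diffusion UNet with attention at the 16x16 level:
#
#   model = Model(ch=64, out_ch=3, ch_mult=(1, 2), num_res_blocks=1,
#                 attn_resolutions=[16], in_channels=3, resolution=32)
#   x = torch.randn(2, 3, 32, 32)
#   t = torch.randint(0, 1000, (2,))
#   eps = model(x, t)                            # -> (2, 3, 32, 32)

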
class Encoder(nn.Module):
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        double_z=True,
        use_linear_attn=False,
        attn_type="vanilla",
        **ignore_kwargs,
    ):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(
            in_channels, self.ch, kernel_size=3, stride=1, padding=1
        )

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in,
            2 * z_channels if double_z else z_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x):
        # timestep embedding
        temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


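# Usage sketch (hypothetical VAE-style config): a 64x64 RGB image encoded to
# a doubled (mean/logvar) latent at 16x16:
#
#   enc = Encoder(ch=64, out_ch=3, ch_mult=(1, 2, 4), num_res_blocks=1,
#                 attn_resolutions=[], in_channels=3, resolution=64,
#                 z_channels=4)
#   z = enc(torch.randn(1, 3, 64, 64))           # -> (1, 8, 16, 16)

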
class Decoder(nn.Module):
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        give_pre_end=False,
        tanh_out=False,
        use_linear_attn=False,
        attn_type="vanilla",
        **ignorekwargs,
    ):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out

        # compute block_in and curr_res at lowest res
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        print(
            "Working with z of shape {} = {} dimensions.".format(
                self.z_shape, np.prod(self.z_shape)
            )
        )

        # z to block_in
        self.conv_in = torch.nn.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1
        )

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, out_ch, kernel_size=3, stride=1, padding=1
        )

    def forward(self, z):
        # assert z.shape[1:] == self.z_shape[1:]
        self.last_z_shape = z.shape

        # timestep embedding
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        if self.tanh_out:
            h = torch.tanh(h)
        return h


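# Mirror of the Encoder sketch above (assumed config): decode a 16x16 latent
# back to a 64x64 image:
#
#   dec = Decoder(ch=64, out_ch=3, ch_mult=(1, 2, 4), num_res_blocks=1,
#                 attn_resolutions=[], in_channels=3, resolution=64,
#                 z_channels=4)
#   x = dec(torch.randn(1, 4, 16, 16))           # -> (1, 3, 64, 64)

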
class SimpleDecoder(nn.Module):
    def __init__(self, in_channels, out_channels, *args, **kwargs):
        super().__init__()
        self.model = nn.ModuleList(
            [
                nn.Conv2d(in_channels, in_channels, 1),
                ResnetBlock(
                    in_channels=in_channels,
                    out_channels=2 * in_channels,
                    temb_channels=0,
                    dropout=0.0,
                ),
                ResnetBlock(
                    in_channels=2 * in_channels,
                    out_channels=4 * in_channels,
                    temb_channels=0,
                    dropout=0.0,
                ),
                ResnetBlock(
                    in_channels=4 * in_channels,
                    out_channels=2 * in_channels,
                    temb_channels=0,
                    dropout=0.0,
                ),
                nn.Conv2d(2 * in_channels, in_channels, 1),
                Upsample(in_channels, with_conv=True),
            ]
        )
        # end
        self.norm_out = Normalize(in_channels)
        self.conv_out = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x):
        for i, layer in enumerate(self.model):
            if i in [1, 2, 3]:
                # the ResnetBlocks expect a timestep embedding; none is used here
                x = layer(x, None)
            else:
                x = layer(x)

        h = self.norm_out(x)
        h = nonlinearity(h)
        x = self.conv_out(h)
        return x


class UpsampleDecoder(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        ch,
        num_res_blocks,
        resolution,
        ch_mult=(2, 2),
        dropout=0.0,
    ):
        super().__init__()
        # upsampling
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        block_in = in_channels
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.res_blocks = nn.ModuleList()
        self.upsample_blocks = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            res_block = []
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                res_block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
            self.res_blocks.append(nn.ModuleList(res_block))
            if i_level != self.num_resolutions - 1:
                self.upsample_blocks.append(Upsample(block_in, True))
                curr_res = curr_res * 2

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, out_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x):
        # upsampling
        h = x
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks + 1):
                h = self.res_blocks[i_level][i_block](h, None)
            if i_level != self.num_resolutions - 1:
                h = self.upsample_blocks[i_level](h)
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class LatentRescaler(nn.Module):
    def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
        super().__init__()
        # residual block, interpolate, residual block
        self.factor = factor
        self.conv_in = nn.Conv2d(
            in_channels, mid_channels, kernel_size=3, stride=1, padding=1
        )
        self.res_block1 = nn.ModuleList(
            [
                ResnetBlock(
                    in_channels=mid_channels,
                    out_channels=mid_channels,
                    temb_channels=0,
                    dropout=0.0,
                )
                for _ in range(depth)
            ]
        )
        self.attn = AttnBlock(mid_channels)
        self.res_block2 = nn.ModuleList(
            [
                ResnetBlock(
                    in_channels=mid_channels,
                    out_channels=mid_channels,
                    temb_channels=0,
                    dropout=0.0,
                )
                for _ in range(depth)
            ]
        )

        self.conv_out = nn.Conv2d(mid_channels, out_channels, kernel_size=1)

    def forward(self, x):
        x = self.conv_in(x)
        for block in self.res_block1:
            x = block(x, None)
        x = torch.nn.functional.interpolate(
            x,
            size=(
                int(round(x.shape[2] * self.factor)),
                int(round(x.shape[3] * self.factor)),
            ),
        )
        x = self.attn(x)
        for block in self.res_block2:
            x = block(x, None)
        x = self.conv_out(x)
        return x


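# Rescaling sketch (illustrative values): factor resizes the spatial grid by
# interpolation between the two residual stacks:
#
#   r = LatentRescaler(factor=0.5, in_channels=8, mid_channels=64,
#                      out_channels=8)
#   y = r(torch.randn(1, 8, 32, 32))             # -> (1, 8, 16, 16)

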
class MergedRescaleEncoder(nn.Module):
    def __init__(
        self,
        in_channels,
        ch,
        resolution,
        out_ch,
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        ch_mult=(1, 2, 4, 8),
        rescale_factor=1.0,
        rescale_module_depth=1,
    ):
        super().__init__()
        intermediate_chn = ch * ch_mult[-1]
        self.encoder = Encoder(
            in_channels=in_channels,
            num_res_blocks=num_res_blocks,
            ch=ch,
            ch_mult=ch_mult,
            z_channels=intermediate_chn,
            double_z=False,
            resolution=resolution,
            attn_resolutions=attn_resolutions,
            dropout=dropout,
            resamp_with_conv=resamp_with_conv,
            out_ch=None,
        )
        self.rescaler = LatentRescaler(
            factor=rescale_factor,
            in_channels=intermediate_chn,
            mid_channels=intermediate_chn,
            out_channels=out_ch,
            depth=rescale_module_depth,
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.rescaler(x)
        return x


class MergedRescaleDecoder(nn.Module):
    def __init__(
        self,
        z_channels,
        out_ch,
        resolution,
        num_res_blocks,
        attn_resolutions,
        ch,
        ch_mult=(1, 2, 4, 8),
        dropout=0.0,
        resamp_with_conv=True,
        rescale_factor=1.0,
        rescale_module_depth=1,
    ):
        super().__init__()
        tmp_chn = z_channels * ch_mult[-1]
        self.decoder = Decoder(
            out_ch=out_ch,
            z_channels=tmp_chn,
            attn_resolutions=attn_resolutions,
            dropout=dropout,
            resamp_with_conv=resamp_with_conv,
            in_channels=None,
            num_res_blocks=num_res_blocks,
            ch_mult=ch_mult,
            resolution=resolution,
            ch=ch,
        )
        self.rescaler = LatentRescaler(
            factor=rescale_factor,
            in_channels=z_channels,
            mid_channels=tmp_chn,
            out_channels=tmp_chn,
            depth=rescale_module_depth,
        )

    def forward(self, x):
        x = self.rescaler(x)
        x = self.decoder(x)
        return x


class Upsampler(nn.Module):
    def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
        super().__init__()
        assert out_size >= in_size
        num_blocks = int(np.log2(out_size // in_size)) + 1
        factor_up = 1.0 + (out_size % in_size)
        print(
            f"Building {self.__class__.__name__} with in_size: {in_size} --> "
            f"out_size {out_size} and factor {factor_up}"
        )
        self.rescaler = LatentRescaler(
            factor=factor_up,
            in_channels=in_channels,
            mid_channels=2 * in_channels,
            out_channels=in_channels,
        )
        self.decoder = Decoder(
            out_ch=out_channels,
            resolution=out_size,
            z_channels=in_channels,
            num_res_blocks=2,
            attn_resolutions=[],
            in_channels=None,
            ch=in_channels,
            ch_mult=[ch_mult for _ in range(num_blocks)],
        )

    def forward(self, x):
        x = self.rescaler(x)
        x = self.decoder(x)
        return x


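# Worked example (assumed values): in_size=32, out_size=64 gives
# num_blocks = log2(64 // 32) + 1 = 2 and factor_up = 1.0 + (64 % 32) = 1.0,
# so the rescaler refines at constant size and the spatial doubling comes
# from the Decoder's single upsampling stage (ch_mult has num_blocks entries).

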
class Resize(nn.Module):
    def __init__(self, in_channels=None, learned=False, mode="bilinear"):
        super().__init__()
        self.with_conv = learned
        self.mode = mode
        if self.with_conv:
            print(
                f"Note: {self.__class__.__name__} uses learned downsampling "
                f"and will ignore the fixed {mode} mode"
            )
            raise NotImplementedError()
            # unreachable sketch of the learned variant, kept for reference:
            assert in_channels is not None
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=4, stride=2, padding=1
            )

    def forward(self, x, scale_factor=1.0):
        if scale_factor == 1.0:
            return x
        x = torch.nn.functional.interpolate(
            x, mode=self.mode, align_corners=False, scale_factor=scale_factor
        )
        return x