import importlib

import torch
from torch import optim
import numpy as np

from inspect import isfunction
from PIL import Image, ImageDraw, ImageFont


def log_txt_as_img(wh, xc, size=10):
    # wh: a tuple of (width, height)
    # xc: a list of captions to plot
    b = len(xc)
    txts = list()
    for bi in range(b):
        txt = Image.new("RGB", wh, color="white")
        draw = ImageDraw.Draw(txt)
        font = ImageFont.truetype('font/Arial_Unicode.ttf', size=size)
        nc = int(32 * (wh[0] / 256))  # characters per line, scaled with the canvas width
        lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))

        try:
            draw.text((0, 0), lines, fill="black", font=font)
        except UnicodeEncodeError:
            print("Can't encode string for logging. Skipping.")

        txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0  # HWC -> CHW, scale to [-1, 1]
        txts.append(txt)
    txts = np.stack(txts)
    txts = torch.tensor(txts)
    return txts


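# Usage sketch (not part of the original module; the caption strings are illustrative).
# Renders a list of captions onto white canvases and returns a (B, 3, H, W) float tensor
# in [-1, 1], suitable for image loggers. Assumes the TrueType font referenced above is
# available at 'font/Arial_Unicode.ttf'.
#
#     captions = ["a photograph of an astronaut riding a horse", "a watercolor fox"]
#     grid = log_txt_as_img((256, 256), captions, size=16)
#     assert tuple(grid.shape) == (2, 3, 256, 256)

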
def ismap(x):
    # 4D (B, C, H, W) tensor with more than 3 channels, e.g. a feature or segmentation map
    if not isinstance(x, torch.Tensor):
        return False
    return (len(x.shape) == 4) and (x.shape[1] > 3)


def isimage(x):
    # 4D (B, C, H, W) tensor with 1 or 3 channels, i.e. a batch of grayscale or RGB images
    if not isinstance(x, torch.Tensor):
        return False
    return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)


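# Illustrative check (assumption, not from the original file): both helpers expect a
# 4D (B, C, H, W) tensor; `isimage` accepts 1 or 3 channels, `ismap` more than 3.
#
#     x_img = torch.zeros(4, 3, 64, 64)   # isimage(x_img) -> True,  ismap(x_img) -> False
#     x_map = torch.zeros(4, 8, 64, 64)   # isimage(x_map) -> False, ismap(x_map) -> True

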
def exists(x):
    return x is not None


def default(val, d):
    # return val if it is not None, otherwise the fallback d (called if it is a function)
    if exists(val):
        return val
    return d() if isfunction(d) else d


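# Illustrative behaviour (not from the original file): the fallback is only evaluated
# when it is actually needed, which keeps expensive defaults lazy.
#
#     default(5, 3)                            # -> 5
#     default(None, 3)                         # -> 3
#     default(None, lambda: torch.zeros(2))    # -> tensor([0., 0.]); lambda only runs here

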
def mean_flat(tensor):
    """
    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))


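# Illustrative use (assumption): collapse a per-element loss to one value per sample,
# e.g. for (B, C, H, W) tensors `mean_flat((pred - target) ** 2)` has shape (B,).

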
def count_params(model, verbose=False):
    total_params = sum(p.numel() for p in model.parameters())
    if verbose:
        print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
    return total_params


def instantiate_from_config(config, **kwargs):
    if "target" not in config:
        if config == '__is_first_stage__':
            return None
        elif config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()), **kwargs)


def get_obj_from_str(string, reload=False):
    # resolve a dotted import path such as "torch.nn.Linear" to the class/function it names
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


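# Usage sketch (assumption: any mapping with a dotted "target" path works, e.g. an
# OmegaConf node or a plain dict; the target must be importable):
#
#     config = {"target": "torch.nn.Linear", "params": {"in_features": 4, "out_features": 2}}
#     layer = instantiate_from_config(config)           # -> torch.nn.Linear(4, 2)
#     linear_cls = get_obj_from_str("torch.nn.Linear")  # -> the class itself

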
class AdamWwithEMAandWings(optim.Optimizer):
    # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
    def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8,  # TODO: check hyperparameters before using
                 weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999,    # ema decay to match previous code
                 ema_power=1., param_names=()):
        """AdamW that saves EMA versions of the parameters."""
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if not 0.0 <= ema_decay <= 1.0:
            raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
                        ema_power=ema_power, param_names=param_names)
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.
        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            ema_params_with_grad = []
            state_sums = []
            max_exp_avg_sqs = []
            state_steps = []
            amsgrad = group['amsgrad']
            beta1, beta2 = group['betas']
            ema_decay = group['ema_decay']
            ema_power = group['ema_power']

            for p in group['params']:
                if p.grad is None:
                    continue
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError('AdamW does not support sparse gradients')
                grads.append(p.grad)

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of parameter values
                    state['param_exp_avg'] = p.detach().float().clone()

                exp_avgs.append(state['exp_avg'])
                exp_avg_sqs.append(state['exp_avg_sq'])
                ema_params_with_grad.append(state['param_exp_avg'])

                if amsgrad:
                    max_exp_avg_sqs.append(state['max_exp_avg_sq'])

                # update the steps for each param group update
                state['step'] += 1
                # record the step after step update
                state_steps.append(state['step'])

            optim._functional.adamw(params_with_grad,
                                    grads,
                                    exp_avgs,
                                    exp_avg_sqs,
                                    max_exp_avg_sqs,
                                    state_steps,
                                    amsgrad=amsgrad,
                                    beta1=beta1,
                                    beta2=beta2,
                                    lr=group['lr'],
                                    weight_decay=group['weight_decay'],
                                    eps=group['eps'],
                                    maximize=False)

            # update the parameter EMA; all params in the group share the same step count
            cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
            for param, ema_param in zip(params_with_grad, ema_params_with_grad):
                ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)

        return loss
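

# Usage sketch (not part of the original file): train with the EMA-tracking optimizer and
# read back the averaged weights. Note that `step` calls the private
# `torch.optim._functional.adamw` API, so behaviour can vary across torch versions.
#
#     model = torch.nn.Linear(4, 2)
#     opt = AdamWwithEMAandWings(model.parameters(), lr=1e-3, ema_decay=0.999)
#     loss = model(torch.randn(8, 4)).pow(2).mean()
#     loss.backward()
#     opt.step()
#     ema_weights = [opt.state[p]['param_exp_avg'] for p in model.parameters()]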