Source code for pe.api.image.improved_diffusion_lib.gaussian_diffusion

This code contains minor edits to the original improved-diffusion code
(https://github.com/openai/improved-diffusion). The edits support sampling from
the middle of the diffusion process via the start_t and start_image arguments.

import torch as th
from improved_diffusion.respace import SpacedDiffusion
from improved_diffusion.respace import space_timesteps
from improved_diffusion.gaussian_diffusion import _extract_into_tensor
from improved_diffusion import gaussian_diffusion as gd

class SkippedSpacedDiffusion(SpacedDiffusion):
    """A ``SpacedDiffusion`` whose sampling loops can start part-way through the chain.

    Every sampling loop here accepts two extra arguments on top of the
    upstream ``improved_diffusion`` signatures:

    * ``start_t`` -- the number of (respaced) reverse steps to skip. Sampling
      begins at timestep index ``num_timesteps - 1 - start_t`` rather than
      ``num_timesteps - 1``. It must lie in ``[0, num_timesteps)``.
    * ``start_image`` -- if given, it is diffused forward to that starting
      timestep with ``self.q_sample``. The forward process uses ``noise`` (or
      fresh Gaussian noise) and the result becomes the initial sample instead
      of pure noise.
    """

    def _prepare_sampling_loop(self, model, shape, noise, device, progress, start_t, start_image):
        """Compute the state shared by both progressive sampling loops.

        :param model: the model module; used only to infer ``device`` when it
            is None.
        :param shape: the sample shape, a tuple or list.
        :param noise: optional initial noise. When None, a standard normal
            tensor of ``shape`` is drawn.
        :param device: target device, or None to use the model's device.
        :param progress: if True, wrap the timestep schedule in a tqdm bar.
        :param start_t: number of leading reverse steps to skip.
        :param start_image: optional image to noise up to the first step.
        :return: ``(device, img, indices)``. ``indices`` iterates over the
            descending timestep indices that remain to be denoised.
        :raises ValueError: if ``start_t`` is not in ``[0, num_timesteps)``.
            Outside that range there would be no step to run, or the schedule
            would be sliced from the wrong end.
        """
        if not 0 <= start_t < self.num_timesteps:
            raise ValueError(
                f"start_t must be in [0, {self.num_timesteps}), got {start_t}"
            )
        if device is None:
            device = next(model.parameters()).device
        assert isinstance(shape, (tuple, list))
        if noise is not None:
            img = noise
        else:
            img = th.randn(*shape, device=device)
        # Reverse order (T-1, ..., 0), then drop the first `start_t` steps.
        indices = list(range(self.num_timesteps))[::-1]
        indices = indices[start_t:]
        if start_image is not None:
            # Diffuse the given image forward to the first timestep we will
            # denoise from. `img` serves as the forward-process noise.
            t_batch = th.tensor([indices[0]] * img.shape[0], device=device)
            img = self.q_sample(start_image, t=t_batch, noise=img)
        if progress:
            # Lazy import so that we don't depend on tqdm.
            from tqdm.auto import tqdm

            indices = tqdm(indices)
        return device, img, indices

    def p_sample_loop(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        start_t=0,
        start_image=None,
    ):
        """
        Generate samples from the model.

        :param model: the model module.
        :param shape: the shape of the samples, (N, C, H, W).
        :param noise: if specified, the noise from the encoder to sample.
                      Should be of the same shape as `shape`.
        :param clip_denoised: if True, clip x_start predictions to [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param device: if specified, the device to create the samples on.
                       If not specified, use a model parameter's device.
        :param progress: if True, show a tqdm progress bar.
        :param start_t: number of leading reverse steps to skip; must be in
            [0, num_timesteps).
        :param start_image: if not None, an image that is noised to the first
            remaining timestep and used as the starting point.
        :return: a non-differentiable batch of samples.
        :raises ValueError: if `start_t` is out of range.
        """
        final = None
        for sample in self.p_sample_loop_progressive(
            model,
            shape,
            noise=noise,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
            device=device,
            progress=progress,
            start_t=start_t,
            start_image=start_image,
        ):
            final = sample
        # `start_t` is validated, so at least one step ran and `final` is set.
        return final["sample"]

    def p_sample_loop_progressive(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        start_t=0,
        start_image=None,
    ):
        """
        Generate samples from the model and yield intermediate samples from
        each timestep of diffusion.

        Arguments are the same as p_sample_loop().
        Returns a generator over dicts, where each dict is the return value of
        p_sample().
        """
        device, img, indices = self._prepare_sampling_loop(
            model, shape, noise, device, progress, start_t, start_image
        )
        for i in indices:
            t = th.tensor([i] * shape[0], device=device)
            with th.no_grad():
                out = self.p_sample(
                    model,
                    img,
                    t,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn,
                    model_kwargs=model_kwargs,
                )
                yield out
                img = out["sample"]

    def ddim_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        model_kwargs=None,
        eta=0.0,
    ):
        """
        Sample x_{t-1} from the model using DDIM.

        Same usage as p_sample(). `eta` scales the stochastic term; 0.0 gives
        the deterministic DDIM update.
        """
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        # Usually our model outputs epsilon, but we re-derive it
        # in case we used x_start or x_prev prediction.
        eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
        alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
        sigma = (
            eta
            * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
            * th.sqrt(1 - alpha_bar / alpha_bar_prev)
        )
        # Equation 12.
        noise = th.randn_like(x)
        mean_pred = (
            out["pred_xstart"] * th.sqrt(alpha_bar_prev)
            + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps
        )
        nonzero_mask = (
            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        )  # no noise when t == 0
        sample = mean_pred + nonzero_mask * sigma * noise
        return {"sample": sample, "pred_xstart": out["pred_xstart"]}

    def ddim_reverse_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        model_kwargs=None,
        eta=0.0,
    ):
        """
        Sample x_{t+1} from the model using DDIM reverse ODE.

        Only the deterministic path is supported, so `eta` must be 0.0.
        """
        assert eta == 0.0, "Reverse ODE only for deterministic path"
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        # Usually our model outputs epsilon, but we re-derive it
        # in case we used x_start or x_prev prediction.
        eps = (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
            - out["pred_xstart"]
        ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
        alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)

        # Equation 12. reversed
        mean_pred = (
            out["pred_xstart"] * th.sqrt(alpha_bar_next)
            + th.sqrt(1 - alpha_bar_next) * eps
        )

        return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}

    def ddim_sample_loop(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        eta=0.0,
        start_t=0,
        start_image=None,
    ):
        """
        Generate samples from the model using DDIM.

        Same usage as p_sample_loop(), plus `eta` (see ddim_sample()).

        :raises ValueError: if `start_t` is not in [0, num_timesteps).
        """
        final = None
        for sample in self.ddim_sample_loop_progressive(
            model,
            shape,
            noise=noise,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
            device=device,
            progress=progress,
            eta=eta,
            start_t=start_t,
            start_image=start_image,
        ):
            final = sample
        # `start_t` is validated, so at least one step ran and `final` is set.
        return final["sample"]

    def ddim_sample_loop_progressive(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        eta=0.0,
        start_t=0,
        start_image=None,
    ):
        """
        Use DDIM to sample from the model and yield intermediate samples from
        each timestep of DDIM.

        Same usage as p_sample_loop_progressive().
        """
        device, img, indices = self._prepare_sampling_loop(
            model, shape, noise, device, progress, start_t, start_image
        )
        for i in indices:
            t = th.tensor([i] * shape[0], device=device)
            with th.no_grad():
                out = self.ddim_sample(
                    model,
                    img,
                    t,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn,
                    model_kwargs=model_kwargs,
                    eta=eta,
                )
                yield out
                img = out["sample"]
def create_gaussian_diffusion(
    *,
    steps=1000,
    learn_sigma=False,
    sigma_small=False,
    noise_schedule="linear",
    use_kl=False,
    predict_xstart=False,
    rescale_timesteps=False,
    rescale_learned_sigmas=False,
    timestep_respacing="",
):
    """Build a ``SkippedSpacedDiffusion`` from high-level configuration flags.

    :param steps: number of diffusion steps in the base beta schedule.
    :param learn_sigma: if True, the model's variance type is LEARNED_RANGE,
        overriding `sigma_small`.
    :param sigma_small: when variances are fixed, pick FIXED_SMALL instead of
        FIXED_LARGE.
    :param noise_schedule: name passed to ``gd.get_named_beta_schedule``.
    :param use_kl: select the RESCALED_KL loss. Takes precedence over
        `rescale_learned_sigmas`.
    :param predict_xstart: if True, the model predicts x_0 (START_X) instead
        of epsilon.
    :param rescale_timesteps: forwarded unchanged to the diffusion object.
    :param rescale_learned_sigmas: select RESCALED_MSE (when `use_kl` is
        False).
    :param timestep_respacing: respacing spec for ``space_timesteps``. Any
        falsy value keeps every one of the `steps` timesteps.
    :return: a configured ``SkippedSpacedDiffusion``.
    """
    betas = gd.get_named_beta_schedule(noise_schedule, steps)

    # Loss: KL wins over rescaled MSE, which wins over plain MSE.
    if use_kl:
        loss_type = gd.LossType.RESCALED_KL
    elif rescale_learned_sigmas:
        loss_type = gd.LossType.RESCALED_MSE
    else:
        loss_type = gd.LossType.MSE

    # An empty respacing spec means "use all `steps` timesteps".
    respacing = timestep_respacing or [steps]

    if predict_xstart:
        mean_type = gd.ModelMeanType.START_X
    else:
        mean_type = gd.ModelMeanType.EPSILON

    # Variance: a learned range overrides the fixed small/large choice.
    if learn_sigma:
        var_type = gd.ModelVarType.LEARNED_RANGE
    elif sigma_small:
        var_type = gd.ModelVarType.FIXED_SMALL
    else:
        var_type = gd.ModelVarType.FIXED_LARGE

    return SkippedSpacedDiffusion(
        use_timesteps=space_timesteps(steps, respacing),
        betas=betas,
        model_mean_type=mean_type,
        model_var_type=var_type,
        loss_type=loss_type,
        rescale_timesteps=rescale_timesteps,
    )