# Source code for pe.api.image.improved_diffusion_api

import torch
import numpy as np
import pandas as pd
import tempfile
import os

from pe.api import API
from pe.logging import execution_logger
from pe.data import Data
from pe.constant.data import IMAGE_DATA_COLUMN_NAME
from pe.constant.data import IMAGE_MODEL_LABEL_COLUMN_NAME
from pe.constant.data import LABEL_ID_COLUMN_NAME
from pe.api.util import ConstantList
from pe.util import download

from improved_diffusion.script_util import NUM_CLASSES
from .improved_diffusion_lib.unet import create_model
from .improved_diffusion_lib.gaussian_diffusion import create_gaussian_diffusion


class ImprovedDiffusion(API):
    """The image API that utilizes improved diffusion models from https://arxiv.org/abs/2102.09672."""

    def __init__(
        self,
        variation_degrees,
        model_path,
        model_image_size=64,
        num_channels=192,
        num_res_blocks=3,
        learn_sigma=True,
        class_cond=True,
        use_checkpoint=False,
        attention_resolutions="16,8",
        num_heads=4,
        num_heads_upsample=-1,
        use_scale_shift_norm=True,
        dropout=0.0,
        diffusion_steps=4000,
        sigma_small=False,
        noise_schedule="cosine",
        use_kl=False,
        predict_xstart=False,
        rescale_timesteps=False,
        rescale_learned_sigmas=False,
        timestep_respacing="100",
        batch_size=2000,
        use_ddim=True,
        clip_denoised=True,
        use_data_parallel=True,
    ):
        """Constructor.

        See https://github.com/openai/improved-diffusion for the explanation of the parameters not listed here.

        :param variation_degrees: The variation degrees utilized at each PE iteration. If a single int is
            provided, the same variation degree will be used for all iterations.
        :type variation_degrees: int or list[int]
        :param model_path: The path of the model checkpoint
        :type model_path: str
        :param diffusion_steps: The total number of diffusion steps, defaults to 4000
        :type diffusion_steps: int, optional
        :param timestep_respacing: The step configurations for image generation utilized at each PE iteration.
            If a single str is provided, the same step configuration will be used for all iterations. Defaults
            to "100"
        :type timestep_respacing: str or list[str], optional
        :param batch_size: The batch size for image generation, defaults to 2000
        :type batch_size: int, optional
        :param use_data_parallel: Whether to use data parallel during image generation, defaults to True
        :type use_data_parallel: bool, optional
        """
        super().__init__()
        self._model = create_model(
            image_size=model_image_size,
            num_channels=num_channels,
            num_res_blocks=num_res_blocks,
            learn_sigma=learn_sigma,
            class_cond=class_cond,
            use_checkpoint=use_checkpoint,
            attention_resolutions=attention_resolutions,
            num_heads=num_heads,
            num_heads_upsample=num_heads_upsample,
            use_scale_shift_norm=use_scale_shift_norm,
            dropout=dropout,
        )
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Checkpoint is loaded on CPU first and then moved, so loading works
        # regardless of which device the checkpoint was saved from.
        self._model.load_state_dict(torch.load(model_path, map_location="cpu"))
        self._model.to(self._device)
        self._model.eval()

        # One diffusion object (and one sampler wrapping it) is built per distinct
        # step configuration, so each PE iteration can use its own respacing.
        if isinstance(timestep_respacing, list):
            all_timestep_respacing = set(timestep_respacing)
        else:
            all_timestep_respacing = {timestep_respacing}
        self._timestep_respacing_to_diffusion = {}
        self._timestep_respacing_to_sampler = {}
        for sub_timestep_respacing in all_timestep_respacing:
            diffusion = create_gaussian_diffusion(
                steps=diffusion_steps,
                learn_sigma=learn_sigma,
                sigma_small=sigma_small,
                noise_schedule=noise_schedule,
                use_kl=use_kl,
                predict_xstart=predict_xstart,
                rescale_timesteps=rescale_timesteps,
                rescale_learned_sigmas=rescale_learned_sigmas,
                timestep_respacing=sub_timestep_respacing,
            )
            self._timestep_respacing_to_diffusion[sub_timestep_respacing] = diffusion
            sampler = Sampler(model=self._model, diffusion=diffusion)
            if use_data_parallel:
                # Wrapping the whole sampling loop (not just the model) in
                # DataParallel reduces GPU communication rounds.
                sampler = torch.nn.DataParallel(sampler)
            self._timestep_respacing_to_sampler[sub_timestep_respacing] = sampler

        # Scalar configs are promoted to ConstantList so they can be indexed per iteration.
        if isinstance(timestep_respacing, str):
            self._timestep_respacing = ConstantList(timestep_respacing)
        else:
            self._timestep_respacing = timestep_respacing
        self._batch_size = batch_size
        self._use_ddim = use_ddim
        self._image_size = model_image_size
        self._clip_denoised = clip_denoised
        self._class_cond = class_cond
        if isinstance(variation_degrees, int):
            self._variation_degrees = ConstantList(variation_degrees)
        else:
            self._variation_degrees = variation_degrees
[docs] def random_api(self, label_info, num_samples): """Generating random synthetic data. :param label_info: The info of the label, not utilized in this API :type label_info: dict :param num_samples: The number of random samples to generate :type num_samples: int :return: The data object of the generated synthetic data :rtype: :py:class:`pe.data.data.Data` """ label_name = label_info.name execution_logger.info(f"RANDOM API: creating {num_samples} samples for label {label_name}") samples, labels = sample( sampler=self._timestep_respacing_to_sampler[self._timestep_respacing[0]], start_t=0, num_samples=num_samples, batch_size=self._batch_size, use_ddim=self._use_ddim, image_size=self._image_size, clip_denoised=self._clip_denoised, class_cond=self._class_cond, device=self._device, ) samples = _round_to_uint8((samples + 1.0) * 127.5) samples = samples.transpose(0, 2, 3, 1) torch.cuda.empty_cache() data_frame = pd.DataFrame( { IMAGE_DATA_COLUMN_NAME: list(samples), IMAGE_MODEL_LABEL_COLUMN_NAME: list(labels), LABEL_ID_COLUMN_NAME: 0, } ) metadata = {"label_info": [label_info]} execution_logger.info(f"RANDOM API: finished creating {num_samples} samples for label {label_name}") return Data(data_frame=data_frame, metadata=metadata)
[docs] def variation_api(self, syn_data): """Generating variations of the synthetic data. :param syn_data: The data object of the synthetic data :type syn_data: :py:class:`pe.data.data.Data` :return: The data object of the variation of the input synthetic data :rtype: :py:class:`pe.data.data.Data` """ execution_logger.info(f"VARIATION API: creating variations for {len(syn_data.data_frame)} samples") images = np.stack(syn_data.data_frame[IMAGE_DATA_COLUMN_NAME].values) labels = np.array(syn_data.data_frame[IMAGE_MODEL_LABEL_COLUMN_NAME].values) iteration = getattr(syn_data.metadata, "iteration", -1) variation_degree = self._variation_degrees[iteration + 1] timestep_respacing = self._timestep_respacing[iteration + 1] execution_logger.info( f"VARIATION API parameters: variation_degree={variation_degree}, timestep_respacing={timestep_respacing}, " f"iteration={iteration}" ) images = images.astype(np.float32) / 127.5 - 1.0 images = images.transpose(0, 3, 1, 2) variations, _ = sample( sampler=self._timestep_respacing_to_sampler[timestep_respacing], start_t=variation_degree, start_image=torch.Tensor(images).to(self._device), labels=(None if not self._class_cond else torch.LongTensor(labels).to(self._device)), num_samples=images.shape[0], batch_size=self._batch_size, use_ddim=self._use_ddim, image_size=self._image_size, clip_denoised=self._clip_denoised, class_cond=self._class_cond, device=self._device, ) variations = _round_to_uint8((variations + 1.0) * 127.5) variations = variations.transpose(0, 2, 3, 1) torch.cuda.empty_cache() data_frame = pd.DataFrame( { IMAGE_DATA_COLUMN_NAME: list(variations), IMAGE_MODEL_LABEL_COLUMN_NAME: list(labels), LABEL_ID_COLUMN_NAME: syn_data.data_frame[LABEL_ID_COLUMN_NAME].values, } ) if LABEL_ID_COLUMN_NAME in syn_data.data_frame.columns: data_frame[LABEL_ID_COLUMN_NAME] = syn_data.data_frame[LABEL_ID_COLUMN_NAME].values execution_logger.info(f"VARIATION API: finished creating variations for {len(syn_data.data_frame)} samples") return 
Data(data_frame=data_frame, metadata=syn_data.metadata)
def sample(
    sampler,
    num_samples,
    start_t,
    batch_size,
    use_ddim,
    image_size,
    clip_denoised,
    class_cond,
    device,
    start_image=None,
    labels=None,
):
    """Draw ``num_samples`` images from ``sampler`` in mini-batches.

    When ``start_image`` is given, sampling starts from (a noised version of)
    those images; otherwise it starts from pure noise. When ``class_cond`` is
    set, either the provided ``labels`` are sliced per batch or random class
    labels are drawn.

    Returns a ``(images, labels)`` pair of numpy arrays, each truncated to
    exactly ``num_samples`` entries (labels are zeros when not class-conditional).
    """
    collected_images = []
    collected_labels = []
    batch_index = 0
    produced = 0
    while produced < num_samples:
        lo = batch_index * batch_size
        hi = lo + batch_size
        if start_image is None:
            cur_batch = batch_size
        else:
            # Never slice past the end of the provided start images.
            cur_batch = min(batch_size, start_image.shape[0] - lo)
        cur_batch = min(num_samples - produced, cur_batch)
        model_kwargs = {}
        if class_cond:
            if labels is not None:
                classes = labels[lo:hi]
            else:
                classes = torch.randint(low=0, high=NUM_CLASSES, size=(cur_batch,), device=device)
            model_kwargs["y"] = classes
        batch_samples = sampler(
            clip_denoised=clip_denoised,
            model_kwargs=model_kwargs,
            start_t=max(start_t, 0),
            start_image=None if start_image is None else start_image[lo:hi],
            use_ddim=use_ddim,
            noise=torch.randn(cur_batch, 3, image_size, image_size, device=device),
            image_size=image_size,
        )
        batch_index += 1
        collected_images.append(batch_samples.detach().cpu().numpy())
        if class_cond:
            collected_labels.append(classes.detach().cpu().numpy())
        produced += batch_samples.shape[0]
        execution_logger.info(f"Created {produced} samples")
    images_arr = np.concatenate(collected_images, axis=0)[:num_samples]
    if class_cond:
        labels_arr = np.concatenate(collected_labels, axis=0)[:num_samples]
    else:
        labels_arr = np.zeros(shape=(num_samples,))
    return images_arr, labels_arr
class Sampler(torch.nn.Module):
    """A wrapper around the model and diffusion modules that handles the entire sampling process, so as to
    reduce the communication rounds between GPUs when using DataParallel.
    """

    def __init__(self, model, diffusion):
        super().__init__()
        self._model = model
        self._diffusion = diffusion

    def forward(
        self,
        clip_denoised,
        model_kwargs,
        start_t,
        start_image,
        use_ddim,
        noise,
        image_size,
    ):
        """Run one full sampling pass and return the generated batch.

        Chooses DDIM or ancestral sampling based on ``use_ddim``; the batch size
        and device are taken from ``noise``. ``start_t`` is clamped to >= 0.
        """
        if use_ddim:
            sample_fn = self._diffusion.ddim_sample_loop
        else:
            sample_fn = self._diffusion.p_sample_loop
        return sample_fn(
            self._model,
            (noise.shape[0], 3, image_size, image_size),
            clip_denoised=clip_denoised,
            model_kwargs=model_kwargs,
            start_t=max(start_t, 0),
            start_image=start_image,
            noise=noise,
            device=noise.device,
        )
def _round_to_uint8(image): return np.around(np.clip(image, a_min=0, a_max=255)).astype(np.uint8)
class ImprovedDiffusion270M(ImprovedDiffusion):
    """The "Class-conditional ImageNet-64 model (270M parameters, trained for 250K iterations)" model from the
    Improved Diffusion paper."""

    #: The URL of the checkpoint path
    CHECKPOINT_URL = "https://openaipublic.blob.core.windows.net/diffusion/march-2021/imagenet64_cond_270M_250K.pt"

    def __init__(
        self,
        variation_degrees,
        model_path=None,
        batch_size=2000,
        timestep_respacing="100",
        use_data_parallel=True,
    ):
        """Constructor.

        :param variation_degrees: The variation degrees utilized at each PE iteration
        :type variation_degrees: list[int]
        :param model_path: The path of the model checkpoint. If not provided, the checkpoint will be downloaded
            from the `CHECKPOINT_URL`
        :type model_path: str
        :param batch_size: The batch size for image generation, defaults to 2000
        :type batch_size: int, optional
        :param timestep_respacing: The step configuration for image generation, defaults to "100"
        :type timestep_respacing: str, optional
        :param use_data_parallel: Whether to use data parallel during image generation, defaults to True
        :type use_data_parallel: bool, optional
        """
        # Download when no path was given or the file at the given path is missing.
        if model_path is None or not os.path.exists(model_path):
            model_path = self._download_checkpoint(model_path)
        super().__init__(
            variation_degrees=variation_degrees,
            model_path=model_path,
            model_image_size=64,
            num_channels=192,
            num_res_blocks=3,
            learn_sigma=True,
            class_cond=True,
            use_checkpoint=False,
            attention_resolutions="16,8",
            num_heads=4,
            num_heads_upsample=-1,
            use_scale_shift_norm=True,
            dropout=0.0,
            diffusion_steps=4000,
            sigma_small=False,
            noise_schedule="cosine",
            use_kl=False,
            predict_xstart=False,
            rescale_timesteps=False,
            rescale_learned_sigmas=False,
            timestep_respacing=timestep_respacing,
            batch_size=batch_size,
            use_ddim=True,
            clip_denoised=True,
            use_data_parallel=use_data_parallel,
        )

    def _download_checkpoint(self, model_path):
        """Download the checkpoint from `CHECKPOINT_URL` to ``model_path``.

        :param model_path: The target path; when None, a fresh temporary ".pt" file is used
        :type model_path: str or None
        :return: The path the checkpoint was downloaded to
        :rtype: str
        """
        execution_logger.info(f"Downloading ImprovedDiffusion checkpoint from {self.CHECKPOINT_URL}")
        if model_path is None:
            # FIX: use mkstemp instead of the deprecated, race-prone mktemp —
            # the file is created atomically. Close the fd immediately since
            # `download` reopens the file by name.
            fd, model_path = tempfile.mkstemp(suffix=".pt")
            os.close(fd)
        download(url=self.CHECKPOINT_URL, fname=model_path)
        execution_logger.info(f"Finished downloading ImprovedDiffusion checkpoint to {model_path}")
        return model_path