# Source code for modelzoo.vision.pytorch.dit.pipeline

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Callable, Optional

import torch
import torch.nn as nn


class DiffusionPipeline(nn.Module):
    """Run the reverse-diffusion sampling loop of a diffusion model.

    The pipeline repeatedly calls a user-supplied model forward function and
    hands its predictions to ``sampler.step`` to walk the latent from the
    first to the last entry of ``sampler.timesteps``.
    """

    def __init__(self, sampler, device="cpu"):
        """
        Args:
            sampler: Instance of one of the supported samplers.
                Refer to ./samplers/get_sampler.py. The pipeline iterates
                ``sampler.timesteps`` and calls ``sampler.step(...)``, whose
                result must expose a ``prev_sample`` attribute.
            device (str, int or torch.device): Device on which the inputs
                built by ``build_inputs`` are created (default ``"cpu"``).
        """
        super().__init__()
        self.device = device
        self.sampler = sampler
        logging.warning(
            "This pipeline assumes that the `model_fwd_fn` arg passed to "
            "`pipeline.forward` takes inputs with specific keyword args "
            "`noised_latent`, `label`, `timestep`. "
            "Please make sure your function definition follows the same name "
            "convention when using this pipeline"
        )

    def build_inputs(
        self,
        input_shape,
        num_classes,
        use_cfg,
        custom_labels=None,
        generator=None,
    ):
        """
        Utility to build random inputs to be passed to the model
        for the first pass of reverse diffusion process.

        Args:
            input_shape (Tuple): Shape of the noised_latent passed to the
                diffusion model. The first dimension is the batch size; the
                remaining dimensions are passed through unchanged.
            num_classes (int): Number of class labels in the dataset that
                the model was trained on. The id ``num_classes`` is used as
                the unconditional ("null") label for classifier-free guidance.
            use_cfg (bool): If True, set up inputs for classifier-free
                guidance: the latent batch is duplicated and the second half
                gets the null label.
            custom_labels (int, List[int] or torch.Tensor): Optional labels
                to condition on during sampling. If specified, each sample's
                label is drawn uniformly from these values only.
            generator (torch.Generator): For setting random generator state.

        Returns:
            dict with keys `noised_latent` and `label`.
            The keys have the same names as the inputs of the model's
            `forward`/`forward_cfg` method.

        Raises:
            ValueError: If `custom_labels` is given but empty.
        """
        # Sample inputs:
        bsz = input_shape[0]
        self.batch_size = bsz
        noised_latent = torch.randn(
            input_shape, device=self.device, generator=generator
        )

        if custom_labels is None:
            label = torch.randint(
                0, num_classes, (bsz,), device=self.device, generator=generator
            )
        else:
            # `as_tensor` also moves an existing tensor onto `self.device`, so
            # the indexing below never mixes devices; flattening accepts a
            # scalar label as well as a list/1-D tensor.
            custom_labels = torch.as_tensor(
                custom_labels, device=self.device
            ).reshape(-1)
            if custom_labels.numel() == 0:
                raise ValueError(
                    "`custom_labels` must contain at least one label"
                )
            sample_ids = torch.randint(
                0,
                custom_labels.numel(),
                size=(bsz,),
                device=self.device,
                generator=generator,
            )
            label = custom_labels[sample_ids]

        # Setup classifier-free guidance:
        if use_cfg:
            noised_latent = torch.cat([noised_latent, noised_latent], 0)
            # Unconditional label id = num_classes; match label's dtype/device.
            label_null = torch.full_like(label, num_classes)
            label = torch.cat([label, label_null], 0)

        return {"noised_latent": noised_latent, "label": label}

    @torch.no_grad()
    def forward(
        self,
        model_fwd_fn: Callable,
        generator: Optional[torch.Generator] = None,
        progress: bool = True,
        use_cfg: bool = True,
        **inputs_to_model_fwd_fn,
    ):
        """
        Args:
            model_fwd_fn: Function handle to the desired forward pass of the
                diffusion model. It is called with keyword args `timestep`
                plus everything in `inputs_to_model_fwd_fn`. It must return a
                `(pred_noise, pred_var)` pair.
            generator (torch.Generator): For setting random generator state.
                It is forwarded to `sampler.step`.
            progress (bool): If True, displays a tqdm progress bar over the
                timestep loop (requires `tqdm`).
            use_cfg (bool): If True, the batch was built for classifier-free
                guidance, and only the first (conditional) half is returned.
            inputs_to_model_fwd_fn: kwargs with all params to be passed to
                `model_fwd_fn`. Must contain `noised_latent`, the gaussian
                diffused latent. It is typically the dict returned by
                `build_inputs`, which also supplies `label`.

        Returns:
            torch.Tensor containing the final generated sample at the end
            of the timestep loop.

        Raises:
            ValueError: If no inputs for `model_fwd_fn` are given.
        """
        if not inputs_to_model_fwd_fn:  # if dict empty
            raise ValueError(
                "Please pass inputs to `model_fwd_fn` as kwargs `inputs_to_model_fwd_fn` "
                "param by calling `self.build_inputs` with appropriate args"
            )
        latent_model_input = inputs_to_model_fwd_fn["noised_latent"]

        if progress:
            # Lazy import so that we don't depend on tqdm.
            from tqdm.auto import tqdm

            total_timesteps = tqdm(self.sampler.timesteps)
        else:
            total_timesteps = self.sampler.timesteps

        for t in total_timesteps:
            if progress:
                total_timesteps.set_postfix({'timestep': t})

            timestep = t
            if not torch.is_tensor(timestep):
                timestep = torch.tensor(
                    (timestep,), dtype=torch.int32, device=self.device
                )
            # Broadcast the single timestep value over the whole batch.
            timestep = timestep.expand(latent_model_input.shape[0]).to(
                latent_model_input.device
            )

            pred_noise, pred_var = model_fwd_fn(
                timestep=timestep, **inputs_to_model_fwd_fn
            )
            # compute previous image: x_t -> x_t-1
            latent_model_input = self.sampler.step(
                pred_noise, pred_var, t, latent_model_input, generator=generator
            ).prev_sample
            inputs_to_model_fwd_fn["noised_latent"] = latent_model_input

        if use_cfg:
            # With CFG the batch is [conditional; unconditional] (see
            # `build_inputs`); keep only the conditional half.
            latents, _ = latent_model_input.chunk(2, dim=0)
        else:
            latents = latent_model_input
        return latents