modelzoo.vision.pytorch.dit.pipeline.DiffusionPipeline

class modelzoo.vision.pytorch.dit.pipeline.DiffusionPipeline

Bases: torch.nn.Module

Parameters
  • sampler – Instance of one of the supported samplers. Refer to ./samplers/get_sampler.py

  • device (str) – Device on which sampling runs. Defaults to 'cpu'

Methods

build_inputs

Utility to build random inputs to be passed to the model

forward

__call__(*args: Any, **kwargs: Any) → Any

Call self as a function.

__init__(sampler, device='cpu')
Parameters
  • sampler – Instance of one of the supported samplers. Refer to ./samplers/get_sampler.py

  • device (str) – Device on which sampling runs. Defaults to 'cpu'

static __new__(cls, *args: Any, **kwargs: Any) → Any

build_inputs(input_shape, num_classes, use_cfg, custom_labels=None, generator=None)
Utility to build random inputs to be passed to the model for the first pass of the reverse diffusion process.

Parameters
  • input_shape (Tuple) – Tuple indicating shape of noised_latent to be passed to Diffusion model

  • num_classes (int) – number of class labels in the dataset that the model was trained on

  • use_cfg (bool) – If True, use classifier-free guidance during sampling

  • custom_labels (List[int]) – Optional list of labels that should be used as conditioning during sampling process. If specified, the model generates images from these classes only

  • generator (torch.Generator) – For setting random generator state

Returns

dict with keys noised_latent and label. Note that the keys are chosen to have the same names as those used in the forward/forward_cfg methods of the model.

Return type

dict
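The behaviour described above can be sketched roughly as follows. This is a hypothetical re-creation for illustration, not the library's code: the function name build_inputs_sketch is invented, and the handling of use_cfg (duplicating the latents and pairing the second half with a "null" class index equal to num_classes) is an assumption borrowed from the public DiT reference sampling script, not something this page confirms.

```python
import torch


def build_inputs_sketch(input_shape, num_classes, use_cfg,
                        custom_labels=None, generator=None):
    """Hypothetical stand-in for build_inputs; illustrates the returned dict only."""
    batch_size = input_shape[0]

    # Gaussian noise used as the starting latent for reverse diffusion.
    noised_latent = torch.randn(input_shape, generator=generator)

    if custom_labels is not None:
        # Draw labels only from the user-supplied classes.
        choices = torch.tensor(custom_labels)
        idx = torch.randint(0, len(custom_labels), (batch_size,), generator=generator)
        label = choices[idx]
    else:
        # Draw labels uniformly from all classes the model was trained on.
        label = torch.randint(0, num_classes, (batch_size,), generator=generator)

    if use_cfg:
        # ASSUMPTION (DiT reference convention): duplicate the latents and
        # condition the second half on a "null" class with index num_classes.
        noised_latent = torch.cat([noised_latent, noised_latent], dim=0)
        null_label = torch.full((batch_size,), num_classes, dtype=label.dtype)
        label = torch.cat([label, null_label], dim=0)

    # Keys match the argument names the model's forward pass expects.
    return {"noised_latent": noised_latent, "label": label}


gen = torch.Generator().manual_seed(0)
inputs = build_inputs_sketch((2, 4, 32, 32), num_classes=1000, use_cfg=True, generator=gen)
print(tuple(inputs["noised_latent"].shape))  # (4, 4, 32, 32)
print(tuple(inputs["label"].shape))          # (4,)
```

Passing a seeded torch.Generator makes both the latent noise and the sampled labels reproducible across runs.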

forward(model_fwd_fn: Callable, generator: Optional[torch.Generator] = None, progress: bool = True, use_cfg: bool = True, **inputs_to_model_fwd_fn)
Parameters
  • model_fwd_fn – Function handle to the desired forward pass of the diffusion model.

  • generator (torch.Generator) – For setting random generator state

  • progress (bool) – If True, displays a progress bar for the timestep loop of the sampling process

  • use_cfg (bool) – If True, use classifier-free guidance during sampling

  • inputs_to_model_fwd_fn – kwargs containing all parameters to be passed to model_fwd_fn. Assumes that model_fwd_fn takes inputs named noised_latent (the Gaussian-diffused latent) and label (the conditioning labels to be used).

Returns

torch.Tensor containing the final generated sample at the end of the timestep loop (T → 1)
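To make the calling contract concrete, here is a minimal sketch of a timestep loop running from T down to 1. Everything in it is hypothetical: ToySampler, its timesteps attribute and step method, the sample_sketch function, and the timestep keyword passed to the model are illustrative assumptions rather than the real sampler or pipeline API. Guidance and the progress bar are left out for brevity.

```python
import torch


class ToySampler:
    """Hypothetical stand-in for a real sampler; not a DDPM/DDIM implementation."""

    def __init__(self, num_steps=5):
        # Timesteps visited in reverse order: T, T-1, ..., 1.
        self.timesteps = list(range(num_steps, 0, -1))

    def step(self, pred_noise, t, sample, generator=None):
        # Toy update: subtract a fixed fraction of the predicted noise.
        return sample - pred_noise / len(self.timesteps)


def sample_sketch(sampler, model_fwd_fn, generator=None, **inputs_to_model_fwd_fn):
    """Run the reverse loop and return the final latent."""
    latent = inputs_to_model_fwd_fn.pop("noised_latent")
    for t in sampler.timesteps:
        # One timestep value per batch element (the keyword name is an assumption).
        timestep = torch.full((latent.shape[0],), t, dtype=torch.long)
        pred_noise = model_fwd_fn(
            noised_latent=latent, timestep=timestep, **inputs_to_model_fwd_fn
        )
        latent = sampler.step(pred_noise, t, latent, generator=generator)
    return latent


def dummy_model(noised_latent, label, timestep):
    # Placeholder "model" that predicts zero noise everywhere.
    return torch.zeros_like(noised_latent)


start = torch.randn(2, 4, 8, 8)
final = sample_sketch(ToySampler(5), dummy_model,
                      noised_latent=start, label=torch.tensor([1, 2]))
print(tuple(final.shape))  # (2, 4, 8, 8)
```

The model function is passed in as a handle, so the same loop can drive either a plain forward pass or a guided one as long as both accept noised_latent and label by name.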