modelzoo.vision.pytorch.dit.pipeline.DiffusionPipeline
- class modelzoo.vision.pytorch.dit.pipeline.DiffusionPipeline
Bases: torch.nn.Module
- Parameters
sampler – Instance of one of the supported samplers. Refer to ./samplers/get_sampler.py
device (str or int) – Target device, e.g. 'cpu' (the default)
Methods
build_inputs – Utility to build random inputs to be passed to the model
- __call__(*args: Any, **kwargs: Any) Any
Call self as a function.
- __init__(sampler, device='cpu')
- static __new__(cls, *args: Any, **kwargs: Any) Any
- build_inputs(input_shape, num_classes, use_cfg, custom_labels=None, generator=None)
- Utility to build random inputs to be passed to the model for the first pass of the reverse diffusion process.
- Parameters
input_shape (Tuple) – Shape of the noised_latent tensor to be passed to the diffusion model
num_classes (int) – Number of class labels in the dataset the model was trained on
use_cfg (bool) – If True, use classifier-free guidance (CFG) during sampling
custom_labels (List[int]) – Optional list of labels to use as conditioning during sampling. If specified, the model generates images from these classes only
generator (torch.Generator) – Optional random number generator; seed it to make the generated inputs reproducible
- Returns
A dict with keys noised_latent and label. The keys are chosen to match the argument names used in the model's forward/forward_cfg methods.
- Return type
dict
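A minimal sketch of what build_inputs might do. It assumes two conventions that this page does not confirm. First, with use_cfg=True the batch is doubled, and the second half is paired with a "null" label equal to num_classes, which is the convention of the original DiT reference code. Second, custom_labels are drawn uniformly at random to fill the batch. The name build_inputs_sketch is hypothetical and this is not the modelzoo implementation.

```python
import torch


def build_inputs_sketch(input_shape, num_classes, use_cfg, custom_labels=None, generator=None):
    """Hypothetical re-creation of DiffusionPipeline.build_inputs (not the library code)."""
    # Assumption: the first dimension of input_shape is the (un-doubled) batch size.
    batch_size = input_shape[0]

    # Pure Gaussian noise is the starting point of reverse diffusion.
    noised_latent = torch.randn(input_shape, generator=generator)

    if custom_labels is not None:
        # Assumption: each sample's label is drawn uniformly from custom_labels.
        choices = torch.tensor(custom_labels, dtype=torch.long)
        idx = torch.randint(len(custom_labels), (batch_size,), generator=generator)
        label = choices[idx]
    else:
        # Otherwise pick any class in [0, num_classes).
        label = torch.randint(num_classes, (batch_size,), generator=generator)

    if use_cfg:
        # Classifier-free guidance: duplicate the latent and pair the copy with
        # the null label (assumed to be index num_classes, as in DiT).
        noised_latent = torch.cat([noised_latent, noised_latent], dim=0)
        null_label = torch.full_like(label, num_classes)
        label = torch.cat([label, null_label], dim=0)

    # Keys match the argument names of the model's forward/forward_cfg.
    return {"noised_latent": noised_latent, "label": label}


g = torch.Generator().manual_seed(0)
inputs = build_inputs_sketch((4, 4, 32, 32), num_classes=1000, use_cfg=True, generator=g)
print(inputs["noised_latent"].shape)  # torch.Size([8, 4, 32, 32])
print(inputs["label"][4:])            # tensor([1000, 1000, 1000, 1000])
```

Seeding the generator makes repeated calls return identical noise and labels, which is the purpose of the generator argument.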
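- forward(model_fwd_fn: Callable, generator: Optional[torch.Generator] = None, progress: bool = True, use_cfg: bool = True, **inputs_to_model_fwd_fn)
- Run the reverse diffusion (sampling) loop, calling model_fwd_fn at each timestep from T down to 1.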
- Parameters
model_fwd_fn – Function handle to the desired forward pass of the diffusion model (e.g. forward or forward_cfg).
generator (torch.Generator) – Optional random number generator; seed it to make sampling reproducible
progress (bool) – If True, display a progress bar over the timestep loop of the sampling process
use_cfg (bool) – If True, use classifier-free guidance (CFG) during sampling
inputs_to_model_fwd_fn – Keyword arguments passed through to model_fwd_fn. model_fwd_fn is assumed to accept an argument named noised_latent (the Gaussian-noised latent) and one named label (the conditioning labels).
- Returns
torch.Tensor containing the final generated sample after the timestep loop has run from T down to 1
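The loop below is a minimal sketch of how forward might drive sampling. Several assumptions are involved. ToySampler, its timesteps attribute and step method, forward_sketch, make_forward_cfg, guidance_scale, and the timestep keyword are all hypothetical names, not the modelzoo API. The update rule in step is a placeholder rather than a real DDPM/DDIM step, and the progress bar and generator handling are omitted. The CFG function follows the usual classifier-free guidance combination eps = eps_uncond + s * (eps_cond - eps_uncond). Keeping only the conditional half at the end is also an assumption.

```python
import torch


class ToySampler:
    """Hypothetical sampler, NOT one of the samplers in ./samplers/get_sampler.py."""

    def __init__(self, num_steps):
        # Timestep loop runs T, T-1, ..., 1.
        self.timesteps = list(range(num_steps, 0, -1))

    def step(self, model_output, t, sample):
        # Placeholder update; a real sampler applies its own formula here.
        return sample - model_output / t


def make_forward_cfg(model, guidance_scale):
    """Hypothetical CFG wrapper; batch = [conditional half, null-label half]."""

    def forward_cfg(noised_latent, label, timestep):
        eps = model(noised_latent, label, timestep)
        cond, uncond = eps.chunk(2, dim=0)
        guided = uncond + guidance_scale * (cond - uncond)
        # Return the guided prediction for both halves so they stay identical.
        return torch.cat([guided, guided], dim=0)

    return forward_cfg


def forward_sketch(model_fwd_fn, sampler, use_cfg=True, **inputs_to_model_fwd_fn):
    """Hypothetical reverse-diffusion loop (progress bar and generator omitted)."""
    noised_latent = inputs_to_model_fwd_fn.pop("noised_latent")
    for t in sampler.timesteps:
        timestep = torch.full((noised_latent.shape[0],), t, dtype=torch.long)
        model_output = model_fwd_fn(
            noised_latent=noised_latent, timestep=timestep, **inputs_to_model_fwd_fn
        )
        noised_latent = sampler.step(model_output, t, noised_latent)
    if use_cfg:
        # Assumption: drop the null-label copy and keep the conditional half.
        noised_latent, _ = noised_latent.chunk(2, dim=0)
    return noised_latent
```

In this sketch, the inputs produced by build_inputs would be unpacked directly into forward_sketch, which is why the dict keys and the model_fwd_fn argument names must agree.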