modelzoo.vision.pytorch.dit.sample_generator_dit.DiTSampleGenerator

class modelzoo.vision.pytorch.dit.sample_generator_dit.DiTSampleGenerator

Bases: modelzoo.vision.pytorch.dit.sample_generator.SampleGenerator

Class for DiT model sample generation.

Parameters
  • model_ckpt_path (str) – Path to pretrained diffusion model checkpoint

  • vae_ckpt_path (str) – Path to pretrained VAE model checkpoint

  • params (str) – Path to yaml containing model params

  • sample_dir (str) – Path to folder where generated images and npz file are to be stored

  • seed (int) – Seed for random generation process

  • num_fid_samples (int) – Number of images to be generated

  • per_gpu_batch_size (int) – Per-GPU batch size. A value given on the command line overrides the one in the yaml.
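A minimal usage sketch, assuming the modelzoo package is importable under the module path shown in the class name. Every path below is a hypothetical placeholder, and `build_generator` and `main` are illustrative helper names; only the argument names come from the parameter list above.

```python
import importlib

# Hypothetical placeholder values; substitute your own checkpoints and config.
generator_kwargs = {
    "model_ckpt_path": "path/to/dit_model_checkpoint",
    "vae_ckpt_path": "path/to/vae_model_checkpoint",
    "params": "path/to/params.yaml",
    "sample_dir": "path/to/sample_dir",
    "seed": 0,
    "num_fid_samples": 50000,
    "per_gpu_batch_size": None,  # None: fall back to the value in the yaml
}


def build_generator(kwargs):
    """Import DiTSampleGenerator lazily and construct it from kwargs."""
    module = importlib.import_module(
        "modelzoo.vision.pytorch.dit.sample_generator_dit"
    )
    return module.DiTSampleGenerator(**kwargs)


def main():
    """Build the generator and start sample generation."""
    generator = build_generator(generator_kwargs)
    generator.run()
```

Calling `main()` requires the model zoo and real checkpoint paths; the import is deferred so the configuration dictionary can be inspected without them.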

Methods

cleanup_dist

create_diffusion_model

Initialize DiT model and load ckpt if provided

create_pipeline

Get the Pipeline object that creates samples from a batch of random normal noise latents using the diffusion model and sampler

create_sampler

Create sampler object. sampler_params contains kwargs that can be passed as input to __init__ of the sampler class.

create_vae_model

Initialize VAE model and load ckpt if provided

run

Main entry point for sample generation

setup_dist

__init__(model_ckpt_path, vae_ckpt_path, params, sample_dir, seed, num_fid_samples=50000, per_gpu_batch_size=None)

Class for DiT model sample generation.

Parameters
  • model_ckpt_path (str) – Path to pretrained diffusion model checkpoint

  • vae_ckpt_path (str) – Path to pretrained VAE model checkpoint

  • params (str) – Path to yaml containing model params

  • sample_dir (str) – Path to folder where generated images and npz file are to be stored

  • seed (int) – Seed for random generation process

  • num_fid_samples (int) – Number of images to be generated

  • per_gpu_batch_size (int) – Per-GPU batch size. A value given on the command line overrides the one in the yaml.
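The override rule for per_gpu_batch_size can be sketched as below. The yaml key name `batch_size`, both helper names, and the rounding of the sample count up to a whole number of global batches are assumptions for illustration (the rounding is a common pattern in multi-GPU FID sampling scripts), not confirmed behavior of this class.

```python
import math


def resolve_batch_size(cli_batch_size, yaml_params):
    """Return the command-line value if one was given, else the yaml value."""
    if cli_batch_size is not None:
        return cli_batch_size
    return yaml_params["batch_size"]  # hypothetical yaml key


def plan_sampling(num_fid_samples, per_gpu_batch_size, world_size):
    """Return (iterations, total_samples) for a hypothetical multi-GPU run.

    Each iteration produces one global batch, so the total is rounded up
    to the next multiple of per_gpu_batch_size * world_size.
    """
    global_batch = per_gpu_batch_size * world_size
    iterations = math.ceil(num_fid_samples / global_batch)
    return iterations, iterations * global_batch
```

With 50000 requested samples, a per-GPU batch size of 32, and 8 devices, this plan runs 196 iterations and produces 50176 samples, of which the surplus could be discarded.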

create_diffusion_model(params, model_ckpt_path, use_cfg, device)

Initialize DiT model and load ckpt if provided

Parameters
  • params (dict) – Params to be passed to DiT model initialization

  • model_ckpt_path (str) – Path to model checkpoint without VAE

  • use_cfg (bool) – If True, apply classifier-free guidance on inputs, i.e. select the appropriate forward function based on use_cfg

  • device (str) – Target device for model

Returns
  • dit_model (nn.Module) – DiT model

  • fwd_fn (Callable) – Forward function to be used by the pipeline object for sampling
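For readers unfamiliar with classifier-free guidance, the sketch below shows the usual way a guided forward function combines conditional and unconditional noise predictions, and how a flag like use_cfg might select between a plain and a guided forward function. It is a generic plain-Python illustration, not the DiT model's forward code; `guidance_scale`, `make_forward_fn`, `null_label`, and the toy model are hypothetical names.

```python
def cfg_combine(eps_cond, eps_uncond, guidance_scale):
    """Blend predictions: uncond + scale * (cond - uncond), elementwise."""
    return [u + guidance_scale * (c - u) for c, u in zip(eps_cond, eps_uncond)]


def toy_model(x, label):
    """Stand-in model: shifts the input by the label (no shift if unlabeled)."""
    shift = 0.0 if label is None else float(label)
    return [v + shift for v in x]


def make_forward_fn(model, use_cfg, guidance_scale=1.0, null_label=None):
    """Select a forward function depending on use_cfg."""
    if not use_cfg:
        # Plain conditional forward pass.
        return lambda x, label: model(x, label)

    def forward_with_cfg(x, label):
        # Two passes: one with the class label, one with the null label.
        eps_cond = model(x, label)
        eps_uncond = model(x, null_label)
        return cfg_combine(eps_cond, eps_uncond, guidance_scale)

    return forward_with_cfg
```

A scale of 1.0 reproduces the conditional prediction, and larger values push the output further from the unconditional one.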

create_pipeline(sampler, device)

Get the Pipeline object that creates samples from a batch of random normal noise latents using the diffusion model and sampler
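The control flow of such a pipeline can be sketched as follows: draw a seeded batch of standard normal latents, then let the sampler repeatedly refine them using the model's prediction. The names `ToySampler`, `toy_model`, and `run_pipeline`, as well as the update rule, are hypothetical stand-ins chosen to keep the example self-contained; they are not the model zoo's Pipeline or sampler API.

```python
import random


def toy_model(latents, t):
    """Stand-in for the diffusion model: 'predicts' the latents as noise."""
    return list(latents)


class ToySampler:
    """Hypothetical sampler that removes a fixed fraction of predicted noise."""

    def __init__(self, num_steps, step_size):
        self.num_steps = num_steps
        self.step_size = step_size

    def timesteps(self):
        # Iterate from the noisiest step down to step 0.
        return range(self.num_steps - 1, -1, -1)

    def step(self, latents, predicted_noise, t):
        return [x - self.step_size * e for x, e in zip(latents, predicted_noise)]


def run_pipeline(model, sampler, batch_size, latent_dim, seed):
    """Start from seeded Gaussian noise and apply the sampler step by step."""
    rng = random.Random(seed)
    batch = [
        [rng.gauss(0.0, 1.0) for _ in range(latent_dim)]
        for _ in range(batch_size)
    ]
    for t in sampler.timesteps():
        batch = [sampler.step(x, model(x, t), t) for x in batch]
    return batch
```

Because the noise comes from a seeded generator, two runs with the same seed return identical batches, which mirrors the role of the seed argument in this class.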

create_sampler(sampler_params)

Create sampler object. sampler_params contains kwargs that can be passed as input to __init__ of the sampler class.
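The kwargs-forwarding pattern described here can be sketched with a hypothetical sampler class and hypothetical parameter names. The real sampler classes and their arguments are defined by the model zoo and its yaml, not by this example.

```python
class HypotheticalSampler:
    """Illustrative sampler; its arguments are assumptions for the example."""

    def __init__(self, num_diffusion_steps=1000, schedule_name="linear"):
        self.num_diffusion_steps = num_diffusion_steps
        self.schedule_name = schedule_name


def create_sampler(sampler_params, sampler_cls=HypotheticalSampler):
    """Unpack the params dict as keyword arguments to the sampler's __init__."""
    return sampler_cls(**sampler_params)


# Keys omitted from the dict fall back to the constructor defaults.
sampler = create_sampler({"num_diffusion_steps": 250})
```

A key that the constructor does not accept raises a TypeError, so a misspelled yaml entry fails loudly instead of being silently ignored.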

create_vae_model(vae_params, vae_ckpt_path, device)

Initialize VAE model and load ckpt if provided

Parameters
  • vae_params (dict) – Params used to initialize the VAE model

  • vae_ckpt_path (str) – Path to VAE model checkpoint

  • device (str) – Target device for model

Returns
  • vae_model (nn.Module) – VAE model for decoding

run()

Main entry point for sample generation