modelzoo.vision.pytorch.dit.sample_generator_dit.DiTSampleGenerator
- class modelzoo.vision.pytorch.dit.sample_generator_dit.DiTSampleGenerator
Bases: modelzoo.vision.pytorch.dit.sample_generator.SampleGenerator
Class for DiT model sample generation.
- Parameters
model_ckpt_path (str) – Path to pretrained diffusion model checkpoint
vae_ckpt_path (str) – Path to pretrained VAE model checkpoint
params (str) – Path to YAML file containing model params
sample_dir (str) – Path to the folder where generated images and the .npz file are stored
seed (int) – Seed for the random generation process
num_fid_samples (int) – Number of images to generate
per_gpu_batch_size (int) – Per-GPU batch size; if provided, this command-line value overrides the one in the YAML file.
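The interplay between `per_gpu_batch_size` and `num_fid_samples` can be sketched as follows. This is a hypothetical illustration, not modelzoo code: the helpers `resolve_per_gpu_batch_size` and `plan_sampling`, and the `world_size` (number of GPUs) argument, are assumptions. It shows a command-line value taking precedence over the YAML value, and one common way to round the sample count up so that every GPU runs the same number of full batches.

```python
import math
from typing import Optional, Tuple


def resolve_per_gpu_batch_size(cli_value: Optional[int], yaml_value: int) -> int:
    """Hypothetical helper: prefer the command-line value, else fall back to YAML."""
    return cli_value if cli_value is not None else yaml_value


def plan_sampling(num_fid_samples: int, per_gpu_batch_size: int, world_size: int) -> Tuple[int, int]:
    """Return (iterations per GPU, total samples generated).

    Assumption: every GPU runs the same number of full batches, so the total is
    rounded up to a multiple of the global batch size; surplus samples can be
    dropped afterwards.
    """
    global_batch_size = per_gpu_batch_size * world_size
    iterations = math.ceil(num_fid_samples / global_batch_size)
    return iterations, iterations * global_batch_size


if __name__ == "__main__":
    bs = resolve_per_gpu_batch_size(cli_value=None, yaml_value=32)
    print(plan_sampling(50_000, bs, world_size=8))  # (196, 50176)
```

With 8 GPUs at 32 images each, 196 iterations produce 50,176 images, slightly more than the default 50,000.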
Methods
cleanup_dist
create_diffusion_model – Initialize DiT model and load ckpt if provided
create_pipeline – Get Pipeline object that creates samples from a batch of random normal noised latent using diffusion model and sampler
create_sampler – Create sampler object. sampler_params: contains kwargs that can be passed as input to __init__ of the sampler class.
create_vae_model – Initialize VAE model and load ckpt if provided
run – MAIN function
setup_dist
- __init__(model_ckpt_path, vae_ckpt_path, params, sample_dir, seed, num_fid_samples=50000, per_gpu_batch_size=None)
- create_diffusion_model(params, model_ckpt_path, use_cfg, device)
Initialize the DiT model and load the checkpoint if one is provided.
- Parameters
params (dict) – Params passed to DiT model initialization
model_ckpt_path (str) – Path to model checkpoint without VAE
use_cfg (bool) – If True, apply classifier-free guidance to the inputs; this selects the appropriate forward function.
device (str) – Target device for model
- Returns
dit_model (nn.Module) – DiT model
fwd_fn (Callable) – Forward function used by the pipeline object for sampling
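To show what choosing a forward function based on `use_cfg` can look like, here is a toy sketch of one common formulation of classifier-free guidance. The model is evaluated once with the real class label and once with a "null" label, and the two predictions are combined as `uncond + cfg_scale * (cond - uncond)`. The toy model, `NULL_CLASS`, `cfg_scale`, and `select_forward_fn` are assumptions for illustration. The forward functions actually returned by `create_diffusion_model` may batch both passes together or differ in other details.

```python
from typing import Callable, List

Vector = List[float]
ForwardFn = Callable[[Vector, int], Vector]

NULL_CLASS = 1000  # hypothetical label index meaning "no class" (unconditional)


def select_forward_fn(model: ForwardFn, use_cfg: bool, cfg_scale: float = 4.0) -> ForwardFn:
    """Return the plain model, or a wrapper that applies classifier-free guidance."""
    if not use_cfg:
        return model

    def forward_with_cfg(x: Vector, y: int) -> Vector:
        cond = model(x, y)             # prediction conditioned on class y
        uncond = model(x, NULL_CLASS)  # prediction with the null label
        # Move away from the unconditional prediction, scaled by cfg_scale.
        return [u + cfg_scale * (c - u) for c, u in zip(cond, uncond)]

    return forward_with_cfg


def toy_model(x: Vector, y: int) -> Vector:
    """Stand-in for a DiT: doubles its input when conditioned on a real class."""
    scale = 1.0 if y == NULL_CLASS else 2.0
    return [scale * v for v in x]


if __name__ == "__main__":
    print(select_forward_fn(toy_model, use_cfg=False)([1.0, 2.0], 3))  # [2.0, 4.0]
    print(select_forward_fn(toy_model, use_cfg=True)([1.0, 2.0], 3))   # [5.0, 10.0]
```

Returning a callable keeps the pipeline unaware of guidance. It calls the same `fwd_fn` interface whether or not CFG is enabled.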
- create_pipeline(sampler, device)
Get a Pipeline object that generates samples from a batch of latents initialized with random normal noise, using the diffusion model and the sampler.
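The general shape of such a pipeline can be sketched with a toy example, assuming a simple interface. A seeded generator draws a batch of standard-normal latents, and then, for each timestep, the model makes a prediction and the sampler uses it to update the latents. `ToyPipeline`, `ShrinkSampler`, its `timesteps`/`step` methods, and `predict_all_noise` are hypothetical. Real samplers (for example DDPM or DDIM) use noise-schedule-dependent update rules rather than this fixed shrink factor.

```python
import random
from typing import Callable, List

Latent = List[float]
ModelFn = Callable[[Latent, int], Latent]


class ShrinkSampler:
    """Hypothetical sampler: subtracts a fixed fraction of the model output each step."""

    def __init__(self, num_steps: int = 4, step_size: float = 0.25) -> None:
        self.num_steps = num_steps
        self.step_size = step_size

    def timesteps(self) -> List[int]:
        # Iterate from the noisiest timestep down to 0.
        return list(range(self.num_steps - 1, -1, -1))

    def step(self, model_output: Latent, t: int, latent: Latent) -> Latent:
        # Toy update rule; real samplers use schedule-dependent coefficients for t.
        return [z - self.step_size * eps for z, eps in zip(latent, model_output)]


class ToyPipeline:
    """Draws standard-normal latents, then alternates model calls and sampler steps."""

    def __init__(self, model: ModelFn, sampler: ShrinkSampler) -> None:
        self.model = model
        self.sampler = sampler

    def __call__(self, batch_size: int, latent_dim: int, seed: int) -> List[Latent]:
        rng = random.Random(seed)  # seeded generator, so equal seeds give equal samples
        batch = [[rng.gauss(0.0, 1.0) for _ in range(latent_dim)] for _ in range(batch_size)]
        for t in self.sampler.timesteps():
            batch = [self.sampler.step(self.model(z, t), t, z) for z in batch]
        return batch


def predict_all_noise(z: Latent, t: int) -> Latent:
    """Toy noise predictor that treats the whole latent as noise."""
    return list(z)


if __name__ == "__main__":
    pipe = ToyPipeline(predict_all_noise, ShrinkSampler())
    samples = pipe(batch_size=2, latent_dim=3, seed=0)
    print(len(samples), len(samples[0]))  # 2 3
```

Keeping the model and sampler as separate pluggable objects means a sampler can be swapped without touching the model. In the full class, the denoised latents would then be decoded by the VAE.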
- create_sampler(sampler_params)
Create a sampler object.
- Parameters
sampler_params – Keyword arguments passed to __init__ of the sampler class
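Passing `sampler_params` as keyword arguments amounts to unpacking a dict into the sampler's constructor. The sketch below uses a hypothetical `DDIMLikeSampler` with made-up fields `num_diffusion_steps` and `eta`, and a hypothetical `build_sampler` helper that takes the class explicitly. The real sampler classes and their argument names are defined elsewhere in modelzoo and are not reproduced here.

```python
from dataclasses import dataclass
from typing import Any, Dict, Type


@dataclass
class DDIMLikeSampler:
    """Hypothetical sampler; field names are illustrative only."""
    num_diffusion_steps: int = 250
    eta: float = 0.0


def build_sampler(sampler_cls: Type[Any], sampler_params: Dict[str, Any]) -> Any:
    # Each key in sampler_params must match an __init__ argument of sampler_cls;
    # an unexpected key raises TypeError, and omitted keys keep their defaults.
    return sampler_cls(**sampler_params)


if __name__ == "__main__":
    sampler = build_sampler(DDIMLikeSampler, {"num_diffusion_steps": 50})
    print(sampler)  # DDIMLikeSampler(num_diffusion_steps=50, eta=0.0)
```

One consequence of this design is that a typo in a YAML key surfaces immediately as a `TypeError` at construction time rather than being silently ignored.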
- create_vae_model(vae_params, vae_ckpt_path, device)
Initialize the VAE model and load the checkpoint if one is provided.
- Parameters
vae_params (dict) – params to initialize VAE model
vae_ckpt_path (str) – Path to VAE model checkpoint
device (str) – Target device for model
- Returns
vae_model (nn.Module) – VAE model for decoding
- Return type
nn.Module
- run()
Main function that runs sample generation.