modelzoo.vision.pytorch.dit.sample_generator.SampleGenerator

class modelzoo.vision.pytorch.dit.sample_generator.SampleGenerator

Bases: abc.ABC

Main base class for model sample generation.

Parameters
  • model_ckpt_path (str) – Path to the pretrained diffusion model checkpoint.

  • vae_ckpt_path (str) – Path to the pretrained VAE model checkpoint.

  • params (str) – Path to the YAML file containing the model parameters.

  • sample_dir (str) – Path to the folder where the generated images and the .npz file are stored.

  • seed (int) – Seed for the random generation process.

  • num_fid_samples (int) – Number of images to generate (default: 50000).

  • per_gpu_batch_size (int, optional) – Per-GPU batch size. If provided, it overrides the value in the YAML file.
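The override rule for per_gpu_batch_size can be sketched as follows. This is a minimal sketch, assuming the YAML file has already been loaded into a dict; the helper resolve_per_gpu_batch_size and the key path sampling.per_gpu_batch_size are hypothetical and not taken from the modelzoo schema.

```python
from typing import Any, Dict, Optional


def resolve_per_gpu_batch_size(
    yaml_params: Dict[str, Any], per_gpu_batch_size: Optional[int] = None
) -> int:
    """Return the batch size each GPU should use.

    Hypothetical helper: the key path ["sampling"]["per_gpu_batch_size"]
    is an assumption, not the documented YAML layout.
    """
    if per_gpu_batch_size is not None:
        # An explicit constructor argument takes precedence over the YAML value.
        return per_gpu_batch_size
    # Otherwise fall back to the value read from the YAML file.
    return yaml_params["sampling"]["per_gpu_batch_size"]


params = {"sampling": {"per_gpu_batch_size": 32}}
print(resolve_per_gpu_batch_size(params))     # 32, taken from the YAML dict
print(resolve_per_gpu_batch_size(params, 8))  # 8, the explicit override wins
```

Checking for None rather than truthiness keeps the "not provided" case distinct from any explicit value the caller passes.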

Methods

cleanup_dist

create_diffusion_model

Initialize the diffusion model and load its checkpoint.

create_pipeline

Return a pipeline object that generates samples from a batch of random normal (Gaussian) noise latents using the diffusion model and the sampler.

create_sampler

Create the sampler object. sampler_params contains keyword arguments that are passed to the __init__ method of the sampler class.

create_vae_model

run

Main entry point for sample generation.

setup_dist

__init__(model_ckpt_path, vae_ckpt_path, params, sample_dir, seed, num_fid_samples=50000, per_gpu_batch_size=None)

abstract create_diffusion_model(params, model_ckpt_path, use_cfg, device)

Initialize the diffusion model and load its checkpoint. Also returns the forward function to use during sampling. Subclasses must implement this method.
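Because create_diffusion_model is abstract, a concrete generator has to subclass the base class and implement it. The sketch below illustrates that pattern with a hypothetical stand-in base class (SampleGeneratorSketch), not the real modelzoo class. It assumes the method returns a (model, forward_fn) pair and that use_cfg selects a classifier-free-guidance forward function; both are assumptions, and the toy "model" is a plain dict.

```python
import abc
from typing import Any, Callable, Dict, Tuple


class SampleGeneratorSketch(abc.ABC):
    """Hypothetical stand-in for SampleGenerator; not the modelzoo class."""

    @abc.abstractmethod
    def create_diffusion_model(
        self, params: Dict[str, Any], model_ckpt_path: str, use_cfg: bool, device: str
    ) -> Tuple[Any, Callable[..., Any]]:
        """Build the model, load its weights, and return (model, forward_fn)."""


class ToyGenerator(SampleGeneratorSketch):
    def create_diffusion_model(self, params, model_ckpt_path, use_cfg, device):
        # A real subclass would build the network from `params` and load the
        # weights at `model_ckpt_path`; here the "model" is only a dict.
        model = {"ckpt": model_ckpt_path, "device": device}

        def forward(x, t):
            # Placeholder denoiser: returns its input unchanged.
            return x

        def forward_with_cfg(x, t, guidance_scale=4.0):
            # Placeholder for a classifier-free-guidance forward pass (assumed).
            return x

        # Hand back whichever forward function the caller should use.
        return model, (forward_with_cfg if use_cfg else forward)


model, forward_fn = ToyGenerator().create_diffusion_model({}, "dit.ckpt", False, "cpu")
```

Trying to instantiate SampleGeneratorSketch directly raises TypeError, which is how abc.ABC enforces that subclasses provide the method.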

create_pipeline(sampler, device)

Return a pipeline object that generates samples from a batch of random normal (Gaussian) noise latents using the diffusion model and the sampler.
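The general shape of such a pipeline (draw Gaussian noise latents, then let the sampler repeatedly apply the model's prediction from the noisiest timestep down to the cleanest) can be sketched in plain Python. This is an illustration of the idea only: sample_batch, toy_forward, and toy_step are hypothetical, latents are plain lists of floats instead of tensors, the update rule is a toy rather than DDPM or DDIM, and any VAE decoding step is omitted.

```python
import random
from typing import Callable, List

Latent = List[float]


def sample_batch(
    forward_fn: Callable[[Latent, int], Latent],
    sampler_step: Callable[[Latent, Latent, int], Latent],
    num_steps: int,
    batch_size: int,
    latent_dim: int,
    seed: int = 0,
) -> List[Latent]:
    rng = random.Random(seed)  # seeded so runs are reproducible
    # Start from a batch of standard-normal noise latents.
    batch = [[rng.gauss(0.0, 1.0) for _ in range(latent_dim)] for _ in range(batch_size)]
    # Walk the timesteps from noisiest (num_steps - 1) down to 0.
    for t in reversed(range(num_steps)):
        # The model predicts noise; the sampler turns that into the next latent.
        batch = [sampler_step(x, forward_fn(x, t), t) for x in batch]
    return batch


def toy_forward(x: Latent, t: int) -> Latent:
    # Toy "model": predicts that the whole latent is noise.
    return list(x)


def toy_step(x: Latent, eps: Latent, t: int) -> Latent:
    # Toy update: remove half of the predicted noise at every step.
    return [xi - 0.5 * ei for xi, ei in zip(x, eps)]


samples = sample_batch(toy_forward, toy_step, num_steps=10, batch_size=4, latent_dim=8)
```

Keeping the model call and the update rule as separate callables mirrors the split between the diffusion model and the sampler described above.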

create_sampler(sampler_params)

Create the sampler object. sampler_params contains keyword arguments that are passed to the __init__ method of the sampler class.
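Forwarding sampler_params as keyword arguments amounts to unpacking a dict into the constructor. A minimal sketch follows; the class DDIMSamplerStub and its arguments (num_diffusion_steps, num_inference_steps, eta) are hypothetical and only illustrate the pattern, not the actual sampler classes in modelzoo.

```python
from typing import Any, Dict


class DDIMSamplerStub:
    """Hypothetical sampler class; its name and arguments are illustrative."""

    def __init__(
        self,
        num_diffusion_steps: int = 1000,
        num_inference_steps: int = 250,
        eta: float = 0.0,
    ):
        self.num_diffusion_steps = num_diffusion_steps
        self.num_inference_steps = num_inference_steps
        self.eta = eta


def create_sampler(sampler_params: Dict[str, Any]) -> DDIMSamplerStub:
    # Every key is forwarded to __init__ as a keyword argument, so a
    # misspelled or unsupported key raises TypeError instead of being ignored.
    return DDIMSamplerStub(**sampler_params)


sampler = create_sampler({"num_inference_steps": 50, "eta": 0.0})
```

Keys absent from sampler_params fall back to the constructor defaults, so the YAML only needs to list the values it wants to change.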

run()

Main entry point for sample generation.
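When sampling is spread over several GPUs, num_fid_samples generally has to be rounded up to a multiple of the global batch so that every rank runs the same number of full batches. The sketch below assumes run() follows the scheme used by the reference DiT sampling script (round up, split evenly across ranks, then keep only the first num_fid_samples images when writing the .npz file); that correspondence is an assumption, and plan_sampling is a hypothetical helper.

```python
import math
from typing import Dict


def plan_sampling(num_fid_samples: int, per_gpu_batch_size: int, world_size: int) -> Dict[str, int]:
    global_batch = per_gpu_batch_size * world_size
    # Round up so every rank runs the same number of full batches.
    total_samples = math.ceil(num_fid_samples / global_batch) * global_batch
    samples_per_rank = total_samples // world_size
    return {
        "total_samples": total_samples,
        "samples_per_rank": samples_per_rank,
        "iterations_per_rank": samples_per_rank // per_gpu_batch_size,
        # Images beyond num_fid_samples would be dropped before writing the .npz.
        "surplus": total_samples - num_fid_samples,
    }


print(plan_sampling(50000, 32, 8))
# {'total_samples': 50176, 'samples_per_rank': 6272, 'iterations_per_rank': 196, 'surplus': 176}
```

With the default num_fid_samples of 50000, 32 images per GPU, and 8 GPUs, the global batch is 256, so 196 iterations per rank produce 176 surplus images.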