# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# isort: off
import sys
import os

sys.path.append(os.path.join(os.path.dirname(__file__), "../../../../"))
# isort: on
"""
Usage: 

torchrun --nnodes 1 --nproc_per_node 4 \
modelzoo/vision/pytorch/dit/sample_generator_dit.py \
--model_ckpt_path <trained_diffusion_ckpt path> \
--vae_ckpt_path <vae_ckpt_path> \
--params modelzoo/vision/pytorch/dit/configs/params_dit_xlarge_patchsize_2x2.yaml \
--sample_dir=modelzoo/vision/pytorch/dit/sample_dir \
--num_fid_samples=50000

"""

import argparse
import logging

import cerebras_pytorch as cstorch
from modelzoo.vision.pytorch.dit.layers.vae.VAEModel import AutoencoderKL
from modelzoo.vision.pytorch.dit.model import DiTModel
from modelzoo.vision.pytorch.dit.sample_generator import SampleGenerator


class DiTSampleGenerator(SampleGenerator):
    def __init__(
        self,
        model_ckpt_path,
        vae_ckpt_path,
        params,
        sample_dir,
        seed,
        num_fid_samples=50000,
        per_gpu_batch_size=None,
    ):
        """
        Class for DiT model sample generation

        Args:
            model_ckpt_path (str): Path to pretrained diffusion model checkpoint
            vae_ckpt_path (str): Path to pretrained VAE model checkpoint
            params (str): Path to yaml containing model params
            sample_dir (str): Path to folder where generated images
                and the npz file are to be stored
            seed (int): Seed for the random generation process
            num_fid_samples (int): Number of images to be generated
            per_gpu_batch_size (int): Per-GPU batch size; if provided on the
                command line, it overrides the value in the yaml.
        """
        super().__init__(
            model_ckpt_path,
            vae_ckpt_path,
            params,
            sample_dir,
            seed,
            num_fid_samples,
            per_gpu_batch_size,
        )
    def create_diffusion_model(self, params, model_ckpt_path, use_cfg, device):
        """
        Initialize DiT model and load checkpoint if provided

        Args:
            params (dict): Params to be passed to DiT model initialization
            model_ckpt_path (str): Path to model checkpoint without VAE
            use_cfg (bool): If True, apply classifier-free guidance on inputs,
                i.e. select the appropriate forward fcn based on `use_cfg`
            device (str): Target device for model

        Returns:
            dit_model (nn.Module): DiT model
            fwd_fn (Callable): Forward fcn to be used by the pipeline
                object for sampling
        """
        # Initialize DiT Model
        dit_model = DiTModel(params)

        # Load checkpoint
        if model_ckpt_path:
            _dit_dict = cstorch.load(model_ckpt_path)
            dit_model.load_state_dict(_dit_dict["model"])
            logging.info(f"Initializing DiT model with {model_ckpt_path}")
        else:
            logging.info("Initializing DiT model with random weights")
        # Unwrap the inner module and move it to the target device
        dit_model = dit_model.model
        dit_model.to(device)

        # Select forward fcn
        if use_cfg:
            fwd_fn = dit_model.forward_dit_with_cfg
        else:
            fwd_fn = dit_model.forward_dit
        return dit_model, fwd_fn
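    # Sketch of how the returned `fwd_fn` is presumably consumed by the
    # sampling pipeline (an assumption about the base class, not its actual
    # internals): at each diffusion timestep it predicts the noise for a
    # batch of latents, e.g.
    #
    #   noise_pred = fwd_fn(latent, t, label)  # hypothetical call signature
    #
    # and with `use_cfg` the `forward_dit_with_cfg` variant applies
    # classifier-free guidance when forming that prediction.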
    def create_vae_model(self, vae_params, vae_ckpt_path, device):
        """
        Initialize VAE model and load checkpoint if provided

        Args:
            vae_params (dict): Params to initialize the VAE model
            vae_ckpt_path (str): Path to VAE model checkpoint
            device (str): Target device for model

        Returns:
            vae_model (nn.Module): VAE model for decoding
        """
        # Initialize VAE model
        vae_model = AutoencoderKL(**vae_params)

        # Load checkpoint
        if vae_ckpt_path:
            _vae_dict = cstorch.load(vae_ckpt_path)
            vae_model.load_state_dict(_vae_dict)
            logging.info(f"Initializing VAE model with {vae_ckpt_path}")
        else:
            logging.info("Initializing VAE model with random weights")
        vae_model.to(device)
        return vae_model
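# A minimal programmatic sketch of how the two factory methods fit together,
# assuming the base-class `run()` loop wires them up roughly like this
# (`sampling_loop` is a hypothetical stand-in for the diffusion sampler, and
# `decode` is assumed to be the AutoencoderKL decoding entry point):
#
#   gen = DiTSampleGenerator(ckpt_path, vae_ckpt_path, params_yaml, "samples", 0)
#   dit_model, fwd_fn = gen.create_diffusion_model(model_params, ckpt_path, use_cfg, device)
#   vae_model = gen.create_vae_model(vae_params, vae_ckpt_path, device)
#   latents = sampling_loop(fwd_fn, noise, labels)  # denoise in latent space
#   images = vae_model.decode(latents)              # decode latents to pixels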
def get_parser_args():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Seed for the random generation process",
    )
    parser.add_argument(
        "--model_ckpt_path",
        type=str,
        default=None,
        help="Optional path to a diffusion model checkpoint",
    )
    parser.add_argument(
        "--vae_ckpt_path",
        type=str,
        default=None,
        help="Optional VAE model checkpoint path",
    )
    parser.add_argument(
        "--params",
        type=str,
        required=True,
        help="Path to params to initialize the Diffusion and VAE models",
    )
    parser.add_argument(
        "--num_fid_samples",
        type=int,
        default=50000,
        help="Number of samples to generate",
    )
    parser.add_argument(
        "--sample_dir",
        type=str,
        required=True,
        help="Directory to store generated samples",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        required=False,
        default=None,
        help="Per-GPU batch size for forward pass",
    )
    parser.add_argument(
        "--create_grid",
        action="store_true",
        required=False,
        help="If passed, create a grid from the images generated",
    )
    args = parser.parse_args()
    return args
if __name__ == "__main__":
    args = get_parser_args()
    sample_gen = DiTSampleGenerator(
        args.model_ckpt_path,
        args.vae_ckpt_path,
        args.params,
        args.sample_dir,
        args.seed,
        args.num_fid_samples,
        args.batch_size,
    )
    sample_gen.run()
    if args.create_grid:
        import math

        from modelzoo.vision.pytorch.dit.display_images import display_images

        logging.info("Creating grid from samples generated...")
        # Lay the generated samples out in a near-square grid
        nrow = math.ceil(math.sqrt(args.num_fid_samples))
        display_images(folder_path=args.sample_dir, nrow=nrow)
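# Grid sizing example: with --num_fid_samples 64, nrow = ceil(sqrt(64)) = 8,
# so display_images lays the samples out as an 8 x 8 grid; for counts that
# are not perfect squares, the last row is presumably only partially filled.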