# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

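"""
Samples a large number of images from a pre-trained diffusion model using DDP.
Subsequently saves a .npz file that can be used to compute FID and other
evaluation metrics via the ADM repo (the OpenAI guided-diffusion evaluation
suite).

For a simple single-GPU/CPU sampling script, see the project's single-device
sampling script (the original link target is missing from this copy).
"""
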
import logging
import math
import os
from abc import ABC, abstractmethod
from datetime import datetime

import numpy as np
import torch
import torch.distributed as dist
import yaml
from PIL import Image
from tqdm import tqdm

from import DiffusionPipeline
from import get_sampler
from import set_defaults

# Layout of every log record emitted by this module:
# timestamp, level name (left-aligned, min width 4), source file:line, message.
LOGFORMAT = (
    "%(asctime)s %(levelname)-4s"
    "[%(filename)s:%(lineno)d] %(message)s"
)
# Configure the root logger once at import time so INFO messages are shown.
logging.basicConfig(format=LOGFORMAT, level=logging.INFO)

def create_npz_from_sample_folder(sample_dir, num_samples=50000):
    """
    Builds a single .npz file from a folder of .png samples.

    The images are expected to be named ``000000.png``, ``000001.png``, ...
    inside ``sample_dir``; exactly ``num_samples`` of them are read, in
    index order.

    Args:
        sample_dir (str): Folder holding the ``{index:06d}.png`` files. The
            resulting ``sample.npz`` is written into the same folder.
        num_samples (int): Number of images to pack into the archive.

    Returns:
        str: Path of the written ``sample.npz`` file. The archive stores the
        stacked uint8 images under the key ``arr_0``.

    Raises:
        AssertionError: If the stacked array is not
            ``(num_samples, H, W, 3)``, i.e. the images are not 3-channel.
    """
    samples = []
    for i in tqdm(range(num_samples), desc="Building .npz file from samples"):
        # Open inside a context manager so the file handle is released
        # right away instead of lingering while thousands of files are read.
        with Image.open(f"{sample_dir}/{i:06d}.png") as sample_pil:
            # `astype` returns a fresh array, so the pixel data stays valid
            # after the image file is closed.
            sample_np = np.asarray(sample_pil).astype(np.uint8)
        samples.append(sample_np)

    samples = np.stack(samples)
    # Sanity check: every image must have three channels.
    assert samples.shape == (num_samples, samples.shape[1], samples.shape[2], 3)

    npz_path = os.path.join(sample_dir, "sample.npz")
    np.savez(npz_path, arr_0=samples)
    logging.info(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
    return npz_path
class SampleGenerator(ABC):
    """
    Abstract driver that samples images from a diffusion model on several
    GPUs, writes them as .png files and packs them into a .npz archive.

    Subclasses supply the model-specific pieces by implementing
    `create_diffusion_model` and `create_vae_model`; `run` orchestrates the
    whole sampling job.
    """

    def __init__(
        self,
        model_ckpt_path,
        vae_ckpt_path,
        params,
        sample_dir,
        seed,
        num_fid_samples=50000,
        per_gpu_batch_size=None,
    ):
        """
        Main BaseClass for model sample generation

        Args:
            model_ckpt_path (str): Path to pretrained diffusion model checkpoint
            vae_ckpt_path (str): Path to pretrained VAE model checkpoint
            params (str): Path to yaml containing model params
            sample_dir (str): Path to folder where generated images
                and npz file to be stored
            seed (int): Seed for random generation process
            num_fid_samples (int): Number of images to be generated
            per_gpu_batch_size (int): Per gpu batch size,
                this input overrides that in yaml if provided.

        Raises:
            AssertionError: If the configured `guidance_scale` is below 1.0.
        """
        self.model_ckpt_path = model_ckpt_path
        self.vae_ckpt_path = vae_ckpt_path
        self.params_path = params
        self.sample_dir = sample_dir
        self.seed = seed
        self.num_fid_samples = num_fid_samples

        # Load the yaml and let `set_defaults` fill in missing entries.
        with open(self.params_path, "r") as fh:
            params_data = yaml.safe_load(fh)
        self.params = set_defaults(params_data)
        self.rparams = self.params["model"]["reverse_process"]

        if per_gpu_batch_size is not None:
            # An explicit argument wins over the value found in the yaml.
            self.rparams["batch_size"] = per_gpu_batch_size
            _msg = f"Using command line batch_size = {per_gpu_batch_size} on each GPU: "
        else:
            per_gpu_batch_size = self.rparams["batch_size"]
            _msg = f"Using batch_size = {per_gpu_batch_size} from reverse_params and set_defaults on each GPU"
        # The message used to be built but never emitted; report it so the
        # effective batch size is visible in the logs.
        logging.info(_msg)
        self.per_gpu_batch_size = per_gpu_batch_size

        guidance_scale = self.rparams["pipeline"]["guidance_scale"]
        assert (
            guidance_scale >= 1.0
        ), "In almost all cases, guidance_scale be >= 1.0"
        # Classifier-free guidance is only switched on for scales above 1.
        self.use_cfg = guidance_scale > 1.0

    @abstractmethod
    def create_diffusion_model(self, params, model_ckpt_path, use_cfg, device):
        """
        Initialize diffusion model, load checkpoint.
        Also return forward function to use

        `run` unpacks the result as ``(diffusion_model, model_fwd_fn)``.
        """
        raise NotImplementedError(
            "create_diffusion_model should be implemented in child class"
        )

    @abstractmethod
    def create_vae_model(self, vae_params, vae_ckpt_path, device):
        """
        Initialize the VAE used to decode latents into images.

        `run` calls ``.eval()`` on the returned object and reads
        ``.decode(latent).sample`` from it.
        """
        raise NotImplementedError(
            "create_vae_model should be implemented in child class"
        )

    def create_sampler(self, sampler_params):
        """
        Create sampler object.

        sampler_params: contains kwargs that can be passed
            as input to __init__ of the sampler class. Must also contain
            a "name" entry that selects the sampler class.
        """
        # Work on a copy: popping "name" from the caller's dict (which is
        # part of `self.params`) would make a second call fail with KeyError.
        sampler_kwargs = dict(sampler_params)
        sname = sampler_kwargs.pop("name")
        sampler = get_sampler(sname)(**sampler_kwargs)
        return sampler

    def create_pipeline(
        self, sampler, device,
    ):
        """
        Get Pipeline object that creates samples from
        a batch of random normal noised latent
        using diffusion model and sampler
        """
        diff_pipe = DiffusionPipeline(sampler, device=device).to(device)
        return diff_pipe

    def _save_params(self):
        """
        Save params to yaml

        The file is written to `self.sample_dir` and carries a UTC timestamp
        in its name so repeated runs do not overwrite each other.
        """
        # Function-scope import keeps this block self-contained.
        from datetime import timezone

        # Timezone-aware replacement for the deprecated `datetime.utcnow()`;
        # the formatted string is identical.
        curr_time = datetime.now(timezone.utc).strftime('%m%d%Y_%H%M%S')
        with open(
            os.path.join(self.sample_dir, f"params_{curr_time}.yaml"), "w",
        ) as fh:
            yaml.dump(self.params, fh)

    def setup_dist(self):
        """Initialize the default process group with the NCCL backend."""
        dist.init_process_group("nccl")

    def cleanup_dist(self):
        """Destroy the default process group."""
        dist.destroy_process_group()

    @staticmethod
    def _plan_sampling(num_fid_samples, world_size, local_bsz):
        """
        Work out how many images have to be sampled so that every rank runs
        the same number of full batches.

        Args:
            num_fid_samples (int): Number of images requested.
            world_size (int): Number of participating processes.
            local_bsz (int): Batch size used on each process.

        Returns:
            tuple: ``(total_samples, num_samples_per_gpu, iterations)``.

        Raises:
            ValueError: If the per-rank share is smaller than one batch.
        """
        # Round up so the request divides evenly across ranks.
        total_samples = int(math.ceil(num_fid_samples / world_size) * world_size)
        num_samples_per_gpu = int(total_samples // world_size)

        if num_samples_per_gpu < local_bsz:
            raise ValueError(
                f"`per_gpu_batch_size`(={local_bsz}) > "
                f"number of samples per gpu(={num_samples_per_gpu}). \n"
                f"Lower batch size in `model.reverse_process` in params yaml"
            )

        if num_samples_per_gpu % local_bsz != 0:
            # Round the per-rank share up to a whole number of batches.
            num_samples_per_gpu = int(
                math.ceil(num_samples_per_gpu / local_bsz) * local_bsz
            )
            total_samples = num_samples_per_gpu * world_size

        iterations = int(num_samples_per_gpu // local_bsz)
        return total_samples, num_samples_per_gpu, iterations

    def run(self):
        """
        MAIN function

        Sets up the process group, builds models, sampler and pipeline,
        samples batches on every rank, saves each image as
        ``{sample_dir}/{index:06d}.png`` and finally lets rank 0 pack the
        first `num_fid_samples` images into a .npz file.

        Raises:
            AssertionError: If CUDA is not available.
            KeyError: If the ``LOCAL_RANK`` environment variable is missing.
            ValueError: If the per-rank sample count is below the batch size.
        """
        assert torch.cuda.is_available(), "Requires at least one GPU."
        torch.set_grad_enabled(False)

        # Set up dist process group
        self.setup_dist()

        # Get ranks and device.
        # NOTE: `local_rank` is read but not used below; the device index is
        # derived from the global rank. The lookup is kept so a launch
        # without LOCAL_RANK still fails here as before.
        local_rank = int(os.environ["LOCAL_RANK"])
        global_rank = dist.get_rank()
        device = global_rank % torch.cuda.device_count()
        world_size = dist.get_world_size()

        generator = torch.Generator(device)
        if self.seed is None:
            # large random number chosen as `high` upper bound
            seed = torch.randint(0, 2147483647, (1,), dtype=torch.int64).item()
        else:
            # Offset by rank so every process draws different noise.
            seed = self.seed + global_rank
        generator.manual_seed(seed)
        torch.cuda.set_device(device)
        logging.info(
            f"Starting rank={global_rank}, seed={seed}, world_size={world_size}."
        )

        # Create folder to save samples:
        if global_rank == 0:
            os.makedirs(self.sample_dir, exist_ok=True)
            logging.info(f"\nSaving .png samples at {self.sample_dir}\n")
            self._save_params()
        # Other ranks wait until the folder exists.
        dist.barrier()

        # Initialize Diffusion Model
        diffusion_model, model_fwd_fn = self.create_diffusion_model(
            self.params, self.model_ckpt_path, self.use_cfg, device=device
        )
        diffusion_model.eval()  # Important

        # Initialize VAE model for decoding latent
        vae_model = self.create_vae_model(
            self.params["model"]["vae"], self.vae_ckpt_path, device
        )
        vae_model.eval()  # Important

        # Create Sampler for Reverse Diffusion process (for loop from T -> 1)
        sampler_params = self.rparams["sampler"]
        sampler = self.create_sampler(sampler_params)

        # Create pipeline
        input_shape = (
            self.rparams["batch_size"],
            self.params["model"]["latent_channels"],
            *self.params["model"]["latent_size"],
        )
        diff_pipe = self.create_pipeline(sampler=sampler, device=device)

        local_bsz = self.per_gpu_batch_size
        global_batch_size = local_bsz * world_size
        total_samples, num_samples_per_gpu, iterations = self._plan_sampling(
            self.num_fid_samples, world_size, local_bsz
        )

        pbar = tqdm(range(iterations), desc="num_batches")
        if global_rank == 0:
            logging.info(
                f"\nTotal number of images that will be sampled: {total_samples} \n"
            )

        # Loop-invariant configuration, looked up once.
        pipeline_params = self.rparams["pipeline"]
        scaling_factor = self.params["model"]["vae"]["scaling_factor"]

        diff_inputs = {}
        if self.use_cfg:
            diff_inputs["guidance_scale"] = pipeline_params["guidance_scale"]
            diff_inputs["num_cfg_channels"] = pipeline_params[
                "num_cfg_channels"
            ]
        num_classes = pipeline_params["num_classes"]
        custom_labels = pipeline_params["custom_labels"]

        total = 0
        for _ in pbar:
            # Random normal noised_latent and random integer labels
            _inputs = diff_pipe.build_inputs(
                input_shape,
                num_classes,
                self.use_cfg,
                generator=generator,
                custom_labels=custom_labels,
            )
            diff_inputs.update(_inputs)

            # Denoised sample
            latent_sample = diff_pipe(
                model_fwd_fn=model_fwd_fn,
                generator=generator,
                progress=False,
                use_cfg=self.use_cfg,
                **diff_inputs,
            )

            samples = vae_model.decode(latent_sample / scaling_factor).sample
            # Scale to the 0..255 range, move channels last and convert to
            # a uint8 numpy array on the host.
            samples = (
                torch.clamp(127.5 * samples + 128.0, 0, 255)
                .permute(0, 2, 3, 1)
                .to("cpu", dtype=torch.uint8)
                .numpy()
            )

            # Save samples to disk as individual .png files.
            # Indices are interleaved across ranks so file names never clash.
            for i, sample in enumerate(samples):
                index = i * world_size + global_rank + total
                Image.fromarray(sample).save(
                    f"{self.sample_dir}/{index:06d}.png"
                )
            total += global_batch_size

        # Make sure all processes have finished saving their
        # samples before attempting to convert to .npz
        dist.barrier()
        if global_rank == 0:
            create_npz_from_sample_folder(self.sample_dir, self.num_fid_samples)
            print("Done.")
        dist.barrier()

        # Clean up dist processes
        self.cleanup_dist()