# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Samples a large number of images from a pre-trained diffusion model using DDP.
Subsequently saves a .npz file that can be used to compute FID and other
evaluation metrics via the ADM repo:
https://github.com/openai/guided-diffusion/tree/main/evaluations
For a simple single-GPU/CPU sampling script, see sample_generator_dit_simple.py.
"""
import logging
import math
import os
from abc import ABC, abstractmethod
from datetime import datetime
import numpy as np
import torch
import torch.distributed as dist
import yaml
from PIL import Image
from tqdm import tqdm
from modelzoo.vision.pytorch.dit.pipeline import DiffusionPipeline
from modelzoo.vision.pytorch.dit.samplers.get_sampler import get_sampler
from modelzoo.vision.pytorch.dit.utils import set_defaults
# Layout of every log record emitted through the root logger: timestamp,
# level name, originating file name and line number, then the message text.
LOGFORMAT = '%(asctime)s %(levelname)-4s[%(filename)s:%(lineno)d] %(message)s'
# Runs at import time: configures the root logger to emit INFO and above using
# LOGFORMAT. NOTE: `logging.basicConfig` has no effect if the root logger
# already has handlers when this module is imported.
logging.basicConfig(level=logging.INFO, format=LOGFORMAT)
def create_npz_from_sample_folder(sample_dir, num_samples=50000):
    """
    Builds a single .npz file from a folder of .png samples.

    Reads the files `{sample_dir}/000000.png` ... `{sample_dir}/{num_samples-1:06d}.png`
    in index order, stacks them into one uint8 array and stores that array
    under the key `arr_0` in `{sample_dir}/sample.npz`.

    Args:
        sample_dir (str): Folder containing the zero-padded, six-digit .png
            files; the .npz file is written into the same folder.
        num_samples (int): Number of consecutive images, starting at index 0,
            to pack into the archive.

    Returns:
        str: Path of the written .npz file.

    Raises:
        AssertionError: If the stacked array is not of shape
            (num_samples, H, W, 3), e.g. the images are not 3-channel.
        ValueError: Raised by `np.stack` when `num_samples` is 0 or the
            images do not all share one shape.
    """
    samples = []
    for i in tqdm(range(num_samples), desc="Building .npz file from samples"):
        # The context manager releases the underlying file handle as soon as
        # the pixels are copied out, instead of waiting for garbage collection.
        with Image.open(f"{sample_dir}/{i:06d}.png") as sample_pil:
            # `astype` returns a fresh array, so the data outlives the image.
            samples.append(np.asarray(sample_pil).astype(np.uint8))
    samples = np.stack(samples)

    # Sanity check: a batch of 3-channel images, one entry per requested sample.
    assert samples.ndim == 4 and samples.shape == (
        num_samples,
        samples.shape[1],
        samples.shape[2],
        3,
    ), (
        f"Expected stacked samples of shape ({num_samples}, H, W, 3), "
        f"got {samples.shape}"
    )

    npz_path = os.path.join(sample_dir, "sample.npz")
    np.savez(npz_path, arr_0=samples)
    logging.info(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
    return npz_path
class SampleGenerator(ABC):
    """
    Abstract base class for multi-process image sample generation.

    Subclasses provide the diffusion model and the VAE through
    `create_diffusion_model` and `create_vae_model`; `run` drives the
    sampling loop, writes one .png per sample into `sample_dir` and finally
    packs the images into a .npz file on rank 0.
    """

    def __init__(
        self,
        model_ckpt_path,
        vae_ckpt_path,
        params,
        sample_dir,
        seed,
        num_fid_samples=50000,
        per_gpu_batch_size=None,
    ):
        """
        Main BaseClass for model sample generation

        Args:
            model_ckpt_path (str): Path to pretrained diffusion model checkpoint
            vae_ckpt_path (str): Path to pretrained VAE model checkpoint
            params (str): Path to yaml containing model params
            sample_dir (str): Path to folder where generated images
                and npz file to be stored
            seed (int): Seed for random generation process. When `None`,
                `run` draws a random seed on each rank instead.
            num_fid_samples (int): Number of images to be generated
            per_gpu_batch_size (int): Per gpu batch size,
                this input overrides that in yaml if provided.

        Raises:
            AssertionError: If the configured `guidance_scale` is below 1.0.
        """
        self.model_ckpt_path = model_ckpt_path
        self.vae_ckpt_path = vae_ckpt_path
        self.params_path = params
        self.sample_dir = sample_dir
        self.seed = seed
        self.num_fid_samples = num_fid_samples

        # Load the yaml and let `set_defaults` process the loaded values.
        with open(self.params_path, "r") as fh:
            params_data = yaml.safe_load(fh)
        self.params = set_defaults(params_data)
        # Alias into `self.params`: writes below are visible through both.
        self.rparams = self.params["model"]["reverse_process"]

        if per_gpu_batch_size is not None:
            # An explicitly passed batch size overrides the yaml value.
            self.rparams["batch_size"] = per_gpu_batch_size
            _msg = f"Using command line batch_size = {per_gpu_batch_size} on each GPU: "
        else:
            per_gpu_batch_size = self.rparams["batch_size"]
            _msg = f"Using batch_size = {per_gpu_batch_size} from reverse_params and set_defaults on each GPU"
        self.per_gpu_batch_size = per_gpu_batch_size
        logging.info(_msg)

        assert (
            self.rparams["pipeline"]["guidance_scale"] >= 1.0
        ), "In almost all cases, guidance_scale be >= 1.0"
        # Guidance inputs are only passed to the pipeline for a scale above 1.0.
        self.use_cfg = self.rparams["pipeline"]["guidance_scale"] > 1.0

    @abstractmethod
    def create_diffusion_model(self, params, model_ckpt_path, use_cfg, device):
        """
        Initialize diffusion model, load checkpoint. Also return forward function to use

        `run` unpacks the result as `(diffusion_model, model_fwd_fn)`, calls
        `diffusion_model.eval()` and hands `model_fwd_fn` to the pipeline.
        """
        raise NotImplementedError(
            "create_diffusion_model should be implemented in child class"
        )

    @abstractmethod
    def create_vae_model(self, vae_params, vae_ckpt_path, device):
        """
        Initialize the VAE model used to decode latents.

        `run` calls `.eval()` on the returned object and reads
        `.decode(latent).sample` from it.
        """
        raise NotImplementedError(
            "create_vae_model should be implemented in child class"
        )

    def create_sampler(self, sampler_params):
        """
        Create sampler object.

        Args:
            sampler_params (dict): Must contain the key "name", which selects
                the sampler class via `get_sampler`; all remaining entries
                are passed as kwargs to `__init__` of that sampler class.

        Returns:
            The instantiated sampler.
        """
        # Work on a shallow copy: `sampler_params` is normally a sub-dict of
        # `self.params`, and popping "name" from it directly would corrupt the
        # stored configuration and make a second call fail with KeyError.
        sampler_kwargs = dict(sampler_params)
        sname = sampler_kwargs.pop("name")
        return get_sampler(sname)(**sampler_kwargs)

    def create_pipeline(
        self, sampler, device,
    ):
        """
        Get Pipeline object that
        creates samples from a batch of random normal noised latent
        using diffusion model and sampler
        """
        diff_pipe = DiffusionPipeline(sampler, device=device).to(device)
        return diff_pipe

    def _save_params(self):
        """
        Save params to yaml

        The file is written to `sample_dir` as `params_<MMDDYYYY_HHMMSS>.yaml`,
        stamped with the current UTC time.
        """
        # Local import: the module only imports `datetime` from `datetime`.
        from datetime import timezone

        # `datetime.utcnow()` is deprecated since Python 3.12; the
        # timezone-aware call formats to the same UTC timestamp.
        curr_time = datetime.now(timezone.utc).strftime('%m%d%Y_%H%M%S')
        with open(
            os.path.join(self.sample_dir, f"params_{curr_time}.yaml"), "w",
        ) as fh:
            yaml.dump(self.params, fh)

    def setup_dist(self):
        """Initialize the default process group with the "nccl" backend."""
        dist.init_process_group("nccl")

    def cleanup_dist(self):
        """Destroy the default process group."""
        dist.destroy_process_group()

    @staticmethod
    def _plan_sample_counts(num_fid_samples, world_size, local_bsz):
        """
        Helper: round the requested sample count up so every rank runs the
        same whole number of full batches.

        Args:
            num_fid_samples (int): Number of samples requested overall.
            world_size (int): Number of participating processes.
            local_bsz (int): Batch size used by each process.

        Returns:
            tuple: `(num_samples_per_gpu, total_samples, iterations)`.

        Raises:
            ValueError: If `local_bsz` is not positive, or if it exceeds the
                number of samples each process has to produce.
        """
        if local_bsz <= 0:
            raise ValueError(
                f"`per_gpu_batch_size` must be a positive integer, got {local_bsz}."
            )
        # Even split across ranks, rounded up.
        num_samples_per_gpu = int(math.ceil(num_fid_samples / world_size))
        if num_samples_per_gpu < local_bsz:
            raise ValueError(
                f"`per_gpu_batch_size`(={local_bsz}) > "
                f"number of samples per gpu(={num_samples_per_gpu}). "
                f"Lower batch size in `model.reverse_process` in params yaml"
            )
        # Round each rank's share up to a whole number of batches.
        iterations = int(math.ceil(num_samples_per_gpu / local_bsz))
        num_samples_per_gpu = iterations * local_bsz
        total_samples = num_samples_per_gpu * world_size
        return num_samples_per_gpu, total_samples, iterations

    def run(self):
        """
        MAIN function

        Sets up the process group, builds the models, sampler and pipeline,
        generates samples batch by batch, saves each one as
        `{sample_dir}/{index:06d}.png`, and lets rank 0 pack the first
        `num_fid_samples` images into a .npz file.

        Side effects: disables gradient computation globally via
        `torch.set_grad_enabled(False)` and does not re-enable it.

        Raises:
            AssertionError: If CUDA is not available.
            ValueError: If the per-GPU batch size is larger than the number
                of samples each process has to generate.
        """
        assert torch.cuda.is_available(), "Requires at least one GPU."
        torch.set_grad_enabled(False)

        # Set up dist process group
        self.setup_dist()

        # Get rank and device
        global_rank = dist.get_rank()
        device = global_rank % torch.cuda.device_count()
        world_size = dist.get_world_size()

        generator = torch.Generator(device)
        if self.seed is None:
            # large random number chosen as `high` upper bound
            seed = torch.randint(0, 2147483647, (1,), dtype=torch.int64).item()
        else:
            # Offset by rank so each process uses a different seed.
            seed = self.seed + global_rank
        generator.manual_seed(seed)
        torch.cuda.set_device(device)
        logging.info(
            f"Starting rank={global_rank}, seed={seed}, world_size={world_size}."
        )

        # Create folder to save samples:
        if global_rank == 0:
            os.makedirs(self.sample_dir, exist_ok=True)
            logging.info(f"\nSaving .png samples at {self.sample_dir}\n")
            self._save_params()
        # Other ranks wait here until rank 0 has created the folder.
        dist.barrier()

        # Initialize Diffusion Model
        diffusion_model, model_fwd_fn = self.create_diffusion_model(
            self.params, self.model_ckpt_path, self.use_cfg, device=device
        )
        diffusion_model.eval()  # Important

        # Initialize VAE model for decoding latent
        vae_model = self.create_vae_model(
            self.params["model"]["vae"], self.vae_ckpt_path, device
        )
        vae_model.eval()  # Important

        # Create Sampler for Reverse Diffusion process (for loop from T -> 1)
        sampler = self.create_sampler(self.rparams["sampler"])

        # Create pipeline
        diff_pipe = self.create_pipeline(sampler=sampler, device=device)

        local_bsz = self.per_gpu_batch_size
        global_batch_size = local_bsz * world_size
        # `__init__` keeps `rparams["batch_size"]` equal to
        # `per_gpu_batch_size`; using one value here keeps the latent shape
        # and the iteration count in agreement.
        input_shape = (
            local_bsz,
            self.params["model"]["latent_channels"],
            *self.params["model"]["latent_size"],
        )

        _, total_samples, iterations = self._plan_sample_counts(
            self.num_fid_samples, world_size, local_bsz
        )

        pbar = tqdm(range(iterations), desc="num_batches")
        if global_rank == 0:
            logging.info(
                f"\nTotal number of images that will be sampled: {total_samples} \n"
            )

        # Values that do not change between batches are looked up once.
        pipeline_params = self.rparams["pipeline"]
        num_classes = pipeline_params["num_classes"]
        custom_labels = pipeline_params["custom_labels"]
        scaling_factor = self.params["model"]["vae"]["scaling_factor"]

        diff_inputs = {}
        if self.use_cfg:
            diff_inputs["guidance_scale"] = pipeline_params["guidance_scale"]
            diff_inputs["num_cfg_channels"] = pipeline_params[
                "num_cfg_channels"
            ]

        # Number of images written by all ranks in the batches completed so far.
        total = 0
        for _ in pbar:
            # Random normal noised_latent and random integer labels
            _inputs = diff_pipe.build_inputs(
                input_shape,
                num_classes,
                self.use_cfg,
                generator=generator,
                custom_labels=custom_labels,
            )
            diff_inputs.update(_inputs)

            # Denoised sample
            latent_sample = diff_pipe(
                model_fwd_fn=model_fwd_fn,
                generator=generator,
                progress=False,
                use_cfg=self.use_cfg,
                **diff_inputs,
            )
            samples = vae_model.decode(latent_sample / scaling_factor).sample
            # Rescale to the 0..255 range and convert to uint8 on the CPU.
            # NOTE(review): the 127.5 * x + 128 mapping presumes decoder
            # output in roughly [-1, 1], and the permute presumes a
            # channels-first batch being made channels-last for PIL — confirm.
            samples = (
                torch.clamp(127.5 * samples + 128.0, 0, 255)
                .permute(0, 2, 3, 1)
                .to("cpu", dtype=torch.uint8)
                .numpy()
            )

            # Save samples to disk as individual .png files. Indices are
            # interleaved across ranks so that the files written by all
            # processes together form one gap-free numbering.
            for i, sample in enumerate(samples):
                index = i * world_size + global_rank + total
                Image.fromarray(sample).save(
                    f"{self.sample_dir}/{index:06d}.png"
                )
            total += global_batch_size

        # Make sure all processes have finished saving their
        # samples before attempting to convert to .npz
        dist.barrier()
        if global_rank == 0:
            create_npz_from_sample_folder(self.sample_dir, self.num_fid_samples)
            print("Done.")
        dist.barrier()

        # Clean up dist processes
        self.cleanup_dist()