# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from enum import Enum
[docs]class BlockType(Enum):
ADALN_ZERO = "adaln_zero"
@classmethod
def values(cls):
return [b.value for b in BlockType]
@classmethod
def get(cls, blk):
if isinstance(blk, str):
return BlockType(blk)
elif isinstance(blk, Enum):
return blk
else:
raise ValueError(
f"Unsupported type {type(blk)}, supported are `str` and `Enum`"
)
[docs]def set_defaults(params):
"""
Update any missing parameters in the params dictionary with default values
Args:
params: The dictionary containing the params
"""
tparams = params["train_input"]
eparams = params["eval_input"]
mparams = params["model"]
runconfig = params["runconfig"]
# train input_required parameters
_set_input_defaults(params)
_set_model_defaults(params)
_copy_params_across(params)
# Runconfig related
if runconfig["checkpoint_steps"] == 0:
logging.warning(
"Setting `runconfig.checkpoint_steps` to `runconfig.max_steps`. Setting to 0 only saves initial checkpoint"
)
runconfig["checkpoint_steps"] = runconfig["max_steps"]
return params
def _set_model_defaults(params):
# model related parameters
mparams = params["model"]
tparams = params["train_input"]
mparams["num_diffusion_steps"] = tparams["num_diffusion_steps"]
mparams["num_classes"] = tparams["num_classes"]
mparams["beta_start"] = mparams.get("beta_start", 0.0001)
mparams["beta_end"] = mparams.get("beta_end", 0.02)
tparams["vae_scaling_factor"] = params["model"]["vae"]["scaling_factor"]
params["eval_input"]["vae_scaling_factor"] = params["model"]["vae"][
"scaling_factor"
]
mparams["vae"]["in_channels"] = tparams["image_channels"]
mparams["vae"]["out_channels"] = tparams["image_channels"]
mparams["vae"]["scaling_factor"] = mparams["vae"].get(
"scaling_factor", 0.18215
)
mparams["latent_channels"] = mparams.get(
"latent_channels", mparams["vae"]["latent_channels"]
)
mparams["latent_size"] = mparams.get(
"latent_size", mparams["vae"]["latent_size"]
)
image_dims = [tparams["image_channels"]] + tparams["image_size"]
latent_dims = [mparams["latent_channels"]] + mparams["latent_size"]
logging.info(
f"Using Image Dimensions (C, H, W): {image_dims} and VAE output Dimensions (C, H, W): {latent_dims}"
)
mparams["block_type"] = mparams.get(
"block_type", BlockType.ADALN_ZERO.value
)
if mparams["block_type"] not in BlockType.values():
raise ValueError(
f"Unsupported DiT block type {mparams['block_type']}. ",
f"Supported values are {BlockType.values()}.",
)
logging.info(f"Using DiT block type : {mparams['block_type']}")
if mparams["fp16_type"] == "bfloat16":
params["optimizer"]["loss_scaling_factor"] = 1.0
# Regression Head
# False -> linear + unpatchify for regression head
mparams["use_conv_transpose_unpatchify"] = mparams.get(
"use_conv_transpose_unpatchify", True
)
if not mparams["use_conv_transpose_unpatchify"]:
raise ValueError(
f"Using linear layer + unpatchify in RegressionHead is unsupported at this time, "
f"please set `model.use_conv_transpose_unpatchify` to True"
)
_set_layer_initializer_defaults(params)
_set_reverse_process_defaults(params)
def _set_reverse_process_defaults(params):
mparams = params["model"]
rparams = mparams.get("reverse_process", {})
if rparams:
rparams["sampler"]["num_diffusion_steps"] = rparams["sampler"].get(
"num_diffusion_steps", mparams["num_diffusion_steps"]
)
rparams["batch_size"] = rparams.get("batch_size", 32)
rparams["pipeline"]["num_classes"] = rparams["pipeline"].get(
"num_classes", mparams["num_classes"]
)
rparams["pipeline"]["custom_labels"] = rparams["pipeline"].get(
"custom_labels", None
)
# For DDPM Sampler only
if rparams["sampler"]["name"] == "ddpm":
rparams["sampler"]["variance_type"] = "fixed_small"
def _set_layer_initializer_defaults(params):
# Modifies in-place
mparams = params["model"]
# Patch Embedding
mparams["projection_initializer"] = {"name": "xavier_uniform", "gain": 1.0}
mparams["init_conv_like_linear"] = mparams.get(
"init_conv_like_linear", mparams["use_conv_patchified_embedding"]
)
# Timestep Embedding MLP
mparams["timestep_embeddding_initializer"] = {
"name": "normal",
"mean": 0.0,
"std": mparams["initializer_range"],
}
# Label Embedding table
mparams["label_embedding_initializer"] = {
"name": "normal",
"mean": 0.0,
"std": mparams["initializer_range"],
}
# Attention
mparams["attention_initializer"] = {"name": "xavier_uniform", "gain": 1.0}
# ffn
mparams["ffn_initializer"] = {"name": "xavier_uniform", "gain": 1.0}
# Regression Head FFN
mparams["head_initializer"] = {"name": "zeros"}
def _set_input_defaults(params):
# Modifies in place
# train input_required parameters
tparams = params["train_input"]
tparams["shuffle"] = tparams.get("shuffle", True)
tparams["shuffle_seed"] = tparams.get("shuffle_seed", 4321)
tparams["num_classes"] = tparams.get("num_classes", 1000)
tparams["noaugment"] = tparams.get("noaugment", False)
tparams["drop_last"] = tparams.get("drop_last", True)
tparams["num_workers"] = tparams.get("num_workers", 0)
tparams["prefetch_factor"] = tparams.get("prefetch_factor", 10)
tparams["persistent_workers"] = tparams.get("persistent_workers", True)
if tparams["noaugment"]:
tparams["transforms"] = None
logging.info(
f"Since `noaugment`={tparams['noaugment']}, the transforms are set to None"
)
tparams["use_worker_cache"] = tparams.get("use_worker_cache", False)
# eval input_required parameters
eparams = params["eval_input"]
eparams["shuffle"] = eparams.get("shuffle", False)
eparams["shuffle_seed"] = eparams.get("shuffle_seed", 4321)
eparams["noaugment"] = eparams.get("noaugment", False)
eparams["drop_last"] = eparams.get("drop_last", True)
eparams["num_workers"] = eparams.get("num_workers", 0)
eparams["prefetch_factor"] = eparams.get("prefetch_factor", 10)
eparams["persistent_workers"] = eparams.get("persistent_workers", True)
if eparams["noaugment"]:
eparams["transforms"] = None
logging.info(
f"Since `noaugment`={eparams['noaugment']}, the transforms are set to None"
)
eparams["use_worker_cache"] = eparams.get("use_worker_cache", False)
def _copy_params_across(params):
# Pass model settings into data loader.
_model_to_input_map = [
# latent shape
"label_dropout_rate",
"latent_size",
"latent_channels",
# Other params
"mixed_precision",
"fp16_type",
# diffusion & related params for performing gd
"schedule_name",
]
for _key_map in _model_to_input_map:
if isinstance(_key_map, tuple):
assert len(_key_map) == 2, f"Tuple {_key_map} does not have len=2"
model_key, input_key = _key_map
else:
model_key = input_key = _key_map
for section in ["train_input", "eval_input"]:
params[section][input_key] = params["model"][model_key]