cerebras.modelzoo.common.run_cstorch_flow

Generic run scripts built using the cstorch API.

Functions

compute_grad_norm

Compute the model-wise and per-layer norms of the gradients.
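The relationship between the two kinds of norm can be sketched in plain Python. This is a minimal illustration, not the library's code. It assumes the norm is the L2 norm and that each layer's gradient is available as a flat list of floats. The function name `grad_norms` and the dictionary layout are hypothetical. The key identity is that the model-wise L2 norm equals the L2 norm of the per-layer L2 norms. The same reduction applies to a model-wise norm of the parameters.

```python
import math
from typing import Dict, List, Tuple


def grad_norms(grads: Dict[str, List[float]]) -> Tuple[float, Dict[str, float]]:
    """Hypothetical helper: return (model-wise L2 norm, per-layer L2 norms).

    Assumes each entry maps a layer name to its flattened gradient values.
    """
    # Per-layer L2 norm: sqrt of the sum of squared gradient entries.
    per_layer = {
        name: math.sqrt(sum(g * g for g in values))
        for name, values in grads.items()
    }
    # Model-wise L2 norm: the L2 norm of the per-layer norms, which equals
    # the L2 norm over all gradient entries concatenated together.
    total = math.sqrt(sum(n * n for n in per_layer.values()))
    return total, per_layer


if __name__ == "__main__":
    total, per_layer = grad_norms({"layer1.weight": [3.0, 4.0], "layer2.bias": [0.0, 12.0]})
    print(total, per_layer)
```

For the sample input, the per-layer norms are 5.0 and 12.0, and the model-wise norm is 13.0.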

compute_params_norm

Compute the model wise norm of the parameters.

get_cluster_config

Sets up the CS cluster config for the run.

get_model_checkpoint

Get the path to the model checkpoint, if any.

log_input_summary

Log the input tensors to tensorboard.

optimizer_step_with_summaries

Customized equivalent of cstorch.amp.optimizer_step that additionally records gradient-norm summaries.
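The sketch below illustrates, in plain Python, what an AMP-style optimizer step with a gradient-norm summary might involve. It is not cstorch's implementation. It swaps in a generic dynamic loss-scaling scheme, similar in spirit to PyTorch's `GradScaler`, applied to plain SGD on a flat parameter list. The names `LossScaler`, `scaled_step`, and the `summaries` dictionary are hypothetical. The sequence is: unscale the gradients, record their norm, skip the update and back off the scale if the norm is non-finite, otherwise step and periodically grow the scale.

```python
import math
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class LossScaler:
    """Hypothetical dynamic loss scaler state (not the cstorch API)."""
    scale: float = 2.0 ** 15
    growth_interval: int = 2000  # clean steps before the scale is doubled
    steps_since_growth: int = 0


def scaled_step(
    params: List[float],
    scaled_grads: List[float],
    lr: float,
    scaler: LossScaler,
    summaries: Dict[str, float],
) -> bool:
    """Apply one SGD step from loss-scaled grads; return False if skipped."""
    # Undo the loss scaling applied to the loss before backprop.
    grads = [g / scaler.scale for g in scaled_grads]
    # Record the model-wise L2 gradient norm as a summary.
    norm = math.sqrt(sum(g * g for g in grads))
    summaries["grad_norm"] = norm
    if not math.isfinite(norm):
        # Overflow detected: skip this update and halve the scale.
        scaler.scale /= 2.0
        scaler.steps_since_growth = 0
        return False
    for i, g in enumerate(grads):
        params[i] -= lr * g
    # After enough clean steps, try a larger scale.
    scaler.steps_since_growth += 1
    if scaler.steps_since_growth >= scaler.growth_interval:
        scaler.scale *= 2.0
        scaler.steps_since_growth = 0
    return True
```

Skipping the update on a non-finite norm keeps overflowed gradients out of the parameters. Halving the scale makes the next step less likely to overflow.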

run_cstorch_eval

Runs the evaluation workflow built using the cstorch API.

run_cstorch_flow

Set up the cstorch run and call the appropriate helper based on the mode.
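Dispatching on the mode can be sketched as a lookup table from mode strings to helper functions. This is a minimal illustration under assumptions. `run_train` and `run_eval` are hypothetical stubs standing in for run_cstorch_train and run_cstorch_eval, whose real signatures are not given here. Only the modes "train" and "eval" are modeled.

```python
from typing import Any, Callable, Dict


# Hypothetical stand-ins for run_cstorch_train / run_cstorch_eval.
def run_train(params: Dict[str, Any]) -> str:
    return "train"


def run_eval(params: Dict[str, Any]) -> str:
    return "eval"


# Map each supported mode to the helper that handles it.
MODE_HANDLERS: Dict[str, Callable[[Dict[str, Any]], str]] = {
    "train": run_train,
    "eval": run_eval,
}


def run_flow(params: Dict[str, Any], mode: str) -> str:
    """Dispatch to the helper registered for `mode` (sketch only)."""
    try:
        handler = MODE_HANDLERS[mode]
    except KeyError:
        raise ValueError(
            f"Unsupported mode {mode!r}; expected one of {sorted(MODE_HANDLERS)}"
        ) from None
    return handler(params)
```

A table-driven dispatch keeps the list of supported modes in one place and produces a clear error for an unknown mode.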

run_cstorch_train

Runs the training workflow built using the cstorch API.

setup_hf_env_vars

Classes

GradScalerParams

Dataclass for parsing grad scaler params from optimizer params.
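One way a dataclass can pull grad scaler settings out of a larger optimizer params dictionary is to keep only the keys that match its own fields. The following is a minimal sketch under assumptions. The class name `GradScalerParamsSketch` and all of its field names and defaults are hypothetical and chosen for illustration. They are not taken from the actual GradScalerParams definition.

```python
from dataclasses import dataclass, fields
from typing import Any, Dict, Optional, Union


@dataclass
class GradScalerParamsSketch:
    # Hypothetical fields; the real GradScalerParams may differ.
    loss_scaling_factor: Union[float, str] = 1.0
    initial_loss_scale: Optional[float] = None
    steps_per_increase: int = 2000
    min_loss_scale: Optional[float] = None
    max_loss_scale: Optional[float] = None

    @classmethod
    def from_optimizer_params(
        cls, optimizer_params: Dict[str, Any]
    ) -> "GradScalerParamsSketch":
        """Build from optimizer params, ignoring keys that are not fields."""
        known = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in optimizer_params.items() if k in known})


if __name__ == "__main__":
    optimizer_params = {
        "optimizer_type": "AdamW",  # not a grad scaler field; ignored
        "learning_rate": 1e-4,      # not a grad scaler field; ignored
        "loss_scaling_factor": "dynamic",
        "steps_per_increase": 100,
    }
    print(GradScalerParamsSketch.from_optimizer_params(optimizer_params))
```

Filtering on the dataclass's own fields lets the same optimizer params dictionary feed both the optimizer and the grad scaler without either rejecting the other's keys.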