cerebras.modelzoo.common.train_and_eval.train_and_eval

cerebras.modelzoo.common.train_and_eval.train_and_eval(params, params_obj, model_fn, train_data_fn, eval_data_fn, cluster_config, artifact_dir)

Runs the train and eval workflow built using the cstorch API.

Parameters
  • params – the params dictionary extracted from the params.yaml file used for the run

  • model_fn – A callable that takes the params dictionary and returns a torch.nn.Module

  • train_data_fn – A callable that takes the params dictionary and returns a torch.utils.data.DataLoader for training

  • eval_data_fn – A callable that takes the params dictionary and returns a torch.utils.data.DataLoader for evaluation
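The three callables described above share one contract: each receives the params dictionary and builds a model or a data loader from it. The following is a minimal sketch of that contract using plain PyTorch. The params keys ("model", "train_input", "eval_input", "input_dim", and so on), the helper _random_loader, and the synthetic random data are all hypothetical illustrations, not the layout Model Zoo's params.yaml files actually use.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical params layout for illustration only; real keys come from params.yaml.
params = {
    "model": {"input_dim": 8, "output_dim": 2},
    "train_input": {"batch_size": 4, "num_samples": 32},
    "eval_input": {"batch_size": 4, "num_samples": 16},
}


def model_fn(params):
    """Build a torch.nn.Module from the params dictionary."""
    cfg = params["model"]
    return torch.nn.Linear(cfg["input_dim"], cfg["output_dim"])


def _random_loader(params, section, shuffle):
    """Hypothetical helper: wrap synthetic tensors in a DataLoader."""
    cfg = params[section]
    input_dim = params["model"]["input_dim"]
    num_classes = params["model"]["output_dim"]
    # Random features and integer class labels stand in for a real dataset.
    features = torch.randn(cfg["num_samples"], input_dim)
    labels = torch.randint(0, num_classes, (cfg["num_samples"],))
    return DataLoader(
        TensorDataset(features, labels),
        batch_size=cfg["batch_size"],
        shuffle=shuffle,
    )


def train_data_fn(params):
    """Return a DataLoader for training (shuffled)."""
    return _random_loader(params, "train_input", shuffle=True)


def eval_data_fn(params):
    """Return a DataLoader for evaluation (fixed order)."""
    return _random_loader(params, "eval_input", shuffle=False)
```

Callables of this shape would be passed as the model_fn, train_data_fn, and eval_data_fn arguments. How params_obj, cluster_config, and artifact_dir should be constructed is not described in this section, so the sketch does not attempt a full call to train_and_eval.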