cerebras.modelzoo.common.run_utils.run_with_params

cerebras.modelzoo.common.run_utils.run_with_params(params: Dict[str, Any], model_fn: Callable[[dict], torch.nn.Module], train_data_fn: Optional[Callable[[dict], torch.utils.data.DataLoader]] = None, eval_data_fn: Optional[Callable[[dict], torch.utils.data.DataLoader]] = None, extra_args_parser_fn: Optional[Callable[[], List[argparse.ArgumentParser]]] = None)

Runs a full end-to-end workflow for the given model, on either a Cerebras system (CS) or non-CS hardware.

Parameters
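  • params – The dictionary of parameters that configures the run. It supplies the ‘params’ argument expected by the callables below.

  • model_fn – A callable that takes the ‘params’ dictionary and uses it to configure and return a torch.nn.Module.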

  • train_data_fn – An optional callable that takes the ‘params’ dictionary and uses it to configure and return a torch.utils.data.DataLoader for the training dataset.

  • eval_data_fn – An optional callable that takes the ‘params’ dictionary and uses it to configure and return a torch.utils.data.DataLoader for the evaluation dataset.

  • default_params_fn – An optional callable that takes the params dictionary and fills in any missing params with default values. Note that this parameter does not appear in the signature above.

  • extra_args_parser_fn – An optional callable that takes no arguments and returns a list of argparse.ArgumentParser objects defining any extra command-line arguments not covered by the get_parser function.
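The model and data callables described above might look like the following minimal sketch. It assumes a params layout with hypothetical "model", "train_input", and "eval_input" sections and builds a toy linear model over random tensors; neither the key names nor the data reflect any particular Model Zoo model.

```python
from typing import Any, Dict

import torch
from torch.utils.data import DataLoader, TensorDataset


def model_fn(params: Dict[str, Any]) -> torch.nn.Module:
    """Build a toy model from hypothetical params["model"] keys."""
    model_params = params["model"]
    return torch.nn.Linear(model_params["input_dim"], model_params["output_dim"])


def _make_loader(params: Dict[str, Any], section: str) -> DataLoader:
    """Build a DataLoader over random tensors (stand-in for a real dataset)."""
    input_params = params[section]
    num_samples = input_params["num_samples"]
    features = torch.randn(num_samples, params["model"]["input_dim"])
    labels = torch.randint(0, params["model"]["output_dim"], (num_samples,))
    return DataLoader(
        TensorDataset(features, labels),
        batch_size=input_params["batch_size"],
        shuffle=(section == "train_input"),  # shuffle only the training data
    )


def train_data_fn(params: Dict[str, Any]) -> DataLoader:
    return _make_loader(params, "train_input")


def eval_data_fn(params: Dict[str, Any]) -> DataLoader:
    return _make_loader(params, "eval_input")


# Hypothetical params dictionary matching the keys used above.
params = {
    "model": {"input_dim": 8, "output_dim": 2},
    "train_input": {"num_samples": 32, "batch_size": 4},
    "eval_input": {"num_samples": 16, "batch_size": 4},
}
model = model_fn(params)
features, labels = next(iter(train_data_fn(params)))
print(model(features).shape)  # torch.Size([4, 2])
```

With the Model Zoo installed, functions of this shape would then be passed as run_with_params(params, model_fn, train_data_fn=train_data_fn, eval_data_fn=eval_data_fn).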
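The extra_args_parser_fn parameter is typed as a callable that takes no arguments and returns a list of argparse.ArgumentParser objects. A sketch of such a callable follows. The --extra_log_interval flag and the stand-in base parser are hypothetical, and merging the returned parsers through argparse's parents mechanism is an assumption about how a caller could combine them, not a description of what run_utils does internally.

```python
import argparse
from typing import List


def extra_args_parser_fn() -> List[argparse.ArgumentParser]:
    """Return parsers holding arguments beyond the base command-line parser."""
    # add_help=False so the parser can be used as a parent without
    # registering a second -h/--help option.
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument(
        "--extra_log_interval",  # hypothetical flag, for illustration only
        type=int,
        default=10,
        help="Hypothetical: number of steps between extra log lines.",
    )
    return [parser]


# Stand-in for the base parser (assumption: the real one comes from get_parser).
base_parser = argparse.ArgumentParser(
    description="Hypothetical base parser",
    parents=extra_args_parser_fn(),
)
base_parser.add_argument("--mode", choices=["train", "eval"], default="train")

args = base_parser.parse_args(["--mode", "eval", "--extra_log_interval", "50"])
print(args.mode, args.extra_log_interval)  # eval 50
```

Passing add_help=False is what lets argparse reuse a parser as a parent; without it, the child parser raises a conflict error for -h.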
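A callable matching the default_params_fn description could be as small as a recursive dictionary merge. The default values and key names below are hypothetical. The text above does not say whether the real hook mutates params in place or returns a new dictionary, so this sketch does both: it mutates params and returns the same object.

```python
import copy
from typing import Any, Dict

# Hypothetical default values; real defaults depend on the model being run.
DEFAULTS: Dict[str, Any] = {
    "runconfig": {"max_steps": 100, "log_steps": 10},
    "optimizer": {"learning_rate": 1e-3},
}


def _fill_missing(params: Dict[str, Any], defaults: Dict[str, Any]) -> None:
    """Recursively copy keys from defaults into params where they are absent."""
    for key, value in defaults.items():
        if key not in params:
            # Deep-copy so later edits to params never mutate DEFAULTS.
            params[key] = copy.deepcopy(value)
        elif isinstance(value, dict) and isinstance(params[key], dict):
            _fill_missing(params[key], value)


def default_params_fn(params: Dict[str, Any]) -> Dict[str, Any]:
    """Fill in missing entries of params in place and return it."""
    _fill_missing(params, DEFAULTS)
    return params


params = {"runconfig": {"max_steps": 5}}
default_params_fn(params)
print(params)
# {'runconfig': {'max_steps': 5, 'log_steps': 10}, 'optimizer': {'learning_rate': 0.001}}
```

User-supplied values (max_steps here) take precedence; only absent keys are filled.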