modelzoo.common.pytorch.run_cstorch_flow.run_cstorch_train

modelzoo.common.pytorch.run_cstorch_flow.run_cstorch_train(params, model_fn, input_fn, cs_config)

Runs the training workflow built using the cstorch API.

Parameters
  • params – The params dictionary extracted from the params.yaml file used for the run

  • model_fn – A callable that takes in the params dictionary and returns a torch.nn.Module

  • input_fn – A callable that takes in the params dictionary and returns a torch.utils.data.DataLoader
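
As a rough illustration of the two callables described above, the sketch below defines a model_fn and an input_fn that each take the params dictionary. The keys inside params ("model", "train_input", and their fields), the linear model, and the random data are all hypothetical and chosen only for this example; they are not taken from any real params.yaml. The call to run_cstorch_train is left commented out because cs_config is not described in this section.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical params dictionary; these keys are illustrative only
# and do not come from a real params.yaml.
params = {
    "model": {"in_features": 8, "num_classes": 2},
    "train_input": {"batch_size": 4, "num_samples": 16},
}


def model_fn(params):
    """Take the params dictionary and return a torch.nn.Module."""
    cfg = params["model"]
    return torch.nn.Linear(cfg["in_features"], cfg["num_classes"])


def input_fn(params):
    """Take the params dictionary and return a torch.utils.data.DataLoader."""
    model_cfg = params["model"]
    data_cfg = params["train_input"]
    # Random features and labels stand in for a real dataset.
    features = torch.randn(data_cfg["num_samples"], model_cfg["in_features"])
    labels = torch.randint(
        0, model_cfg["num_classes"], (data_cfg["num_samples"],)
    )
    dataset = TensorDataset(features, labels)
    return DataLoader(dataset, batch_size=data_cfg["batch_size"])


# With the ModelZoo available, the callables would be passed positionally,
# matching the signature above (cs_config is assumed to be built elsewhere):
# run_cstorch_train(params, model_fn, input_fn, cs_config)
```

Both functions receive the same dictionary, so model and data settings can live side by side in one configuration file.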