# Source code for cerebras.modelzoo.common.run_utils

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for running Cerebras Pytorch Models"""
import argparse
import inspect
import math
import os
import subprocess
import sys
from shutil import which
from typing import Any, Callable, Dict, List, Optional, Union

import torch
import yaml

from cerebras.appliance.log import (
    collect_wsc_log_settings,
    get_level_name,
    wsc_logger,
)
from cerebras.modelzoo.common.pytorch_utils import (
    RunConfigParamsValidator,
    get_checkpoints,
)
from cerebras.modelzoo.common.run_cstorch_flow import run_cstorch_flow
from cerebras.modelzoo.common.utils.run.cli_parser import get_params_from_args
from cerebras.modelzoo.common.utils.run.utils import DeviceType
from cerebras.pytorch.core import modes

# Type of the ``train_data_fn`` / ``eval_data_fn`` callables accepted by
# ``run``, ``main`` and ``run_with_params``: given the params dict, they
# return a PyTorch DataLoader.
DATA_FN_TYPE = Callable[[dict], torch.utils.data.DataLoader]


def arg_filter(arg: str, keyword: str) -> bool:
    """Return whether ``arg`` is the command-line flag named ``keyword``.

    Both single- and double-dash spellings are recognised, either bare
    (``--keyword`` / ``-keyword``) or with an attached value
    (``--keyword=value`` / ``-keyword=value``). Leading and trailing
    whitespace on ``arg`` is ignored.
    """
    stripped = arg.strip()
    for dashes in ("--", "-"):
        flag = f"{dashes}{keyword}"
        if stripped == flag or stripped.startswith(f"{flag}="):
            return True
    return False
def update_sideband_mode_arg(
    arguments: List[str], new_mode_arg: str, old_mode: str
) -> List[str]:
    """Rewrite the ``--mode`` argument of a sideband run to a different mode.

    Any ``old_mode`` token that is the value of an immediately preceding
    ``--mode`` / ``-m`` flag (e.g. ``["--mode", old_mode]``) is dropped, and
    every remaining ``--mode`` / ``-m`` flag (bare or with an ``=`` value) is
    replaced by ``new_mode_arg``.

    Args:
        arguments: Command-line arguments, e.g. ``sys.argv[1:]``.
        new_mode_arg: The full replacement flag, e.g. ``"--mode=eval"``.
        old_mode: The mode value to strip when it follows a mode flag.

    Returns:
        A new argument list; ``arguments`` itself is not modified.
    """

    def _is_mode_flag(arg: Optional[str]) -> bool:
        # ``None`` stands for "no previous argument" (first element).
        return arg is not None and (
            arg_filter(arg, "mode") or arg_filter(arg, "m")
        )

    # Pair every argument with its predecessor; the first element has none,
    # so it can never be treated as the value of a mode flag.
    previous_args = [None] + arguments
    updated_args = [
        an_arg
        for an_arg, prev_arg in zip(arguments, previous_args)
        if an_arg != old_mode or not _is_mode_flag(prev_arg)
    ]

    # Replace the (now value-less or stale) mode flag with the new one.
    return [
        new_mode_arg if _is_mode_flag(an_arg) else an_arg
        for an_arg in updated_args
    ]
def sideband_eval_all(
    filename: str, arguments: List[str], params: Dict[Any, Any]
):
    """Run evaluation once per checkpoint, each in a separate subprocess.

    Temporary support for ``eval_all``. The ``sideband_eval_all`` mode in
    ``arguments`` is rewritten to ``--mode=eval``. Then ``filename`` is
    re-invoked with the current interpreter once for every checkpoint that
    ``get_checkpoints`` reports for ``params['runconfig']['model_dir']``,
    passing it via ``--checkpoint_path``.

    Raises:
        ValueError: If ``--checkpoint_path`` is already among ``arguments``
            (checkpoints are inferred from ``model_dir`` instead), or if no
            checkpoints are found.
        subprocess.CalledProcessError: If an evaluation subprocess fails.
    """
    if any(arg_filter(an_arg, "checkpoint_path") for an_arg in arguments):
        raise ValueError(
            "Checkpoint path cannot be provided with eval_all. "
            "Checkpoints inferred from model_dir"
        )

    eval_args = update_sideband_mode_arg(
        arguments, "--mode=eval", f"sideband_{modes.EVAL_ALL}"
    )

    model_dir = params['runconfig']['model_dir']
    checkpoints = get_checkpoints(model_dir)
    if len(checkpoints) == 0:
        raise ValueError(f"No checkpoints found at {model_dir}")

    base_cmd = [sys.executable, filename, *eval_args]
    for checkpoint in checkpoints:
        # Re-invoking the script from the top gives every run its own logdir.
        subprocess.run(
            [*base_cmd, f"--checkpoint_path={checkpoint}"], check=True
        )
def sideband_train_eval_all(
    filename: str, arguments: List[str], params: Dict[Any, Any]
):
    """Alternate train and eval runs, each in a separate subprocess.

    Temporary support for ``train_and_eval``. The ``sideband_train_and_eval``
    mode in ``arguments`` is rewritten to ``--mode=train`` / ``--mode=eval``.
    ``filename`` is re-invoked with the current interpreter as follows:

    * If ``runconfig.num_steps`` is set: run
      ``ceil(num_steps / eval_frequency)`` iterations. Each iteration trains
      for ``eval_frequency`` steps and then evaluates. The last iteration
      trains only the remaining ``num_steps % eval_frequency`` steps when
      that remainder is non-zero.
    * Otherwise, if ``runconfig.num_epochs`` is set: run ``num_epochs``
      iterations, each training for one epoch and then evaluating.

    Raises:
        ValueError: On an invalid combination of ``num_steps``,
            ``num_epochs``, ``steps_per_epoch`` and ``eval_frequency``, or a
            non-positive ``eval_frequency``.
        RuntimeError: If a train or eval subprocess fails (chained to the
            underlying exception).
    """
    train_args = update_sideband_mode_arg(
        arguments, "--mode=train", f"sideband_{modes.TRAIN_AND_EVAL}"
    )
    eval_args = update_sideband_mode_arg(
        arguments, "--mode=eval", f"sideband_{modes.TRAIN_AND_EVAL}"
    )

    runconfig = params['runconfig']
    # Only used when num_steps is not a multiple of eval_frequency.
    last_train_args = None
    if runconfig.get('num_steps', None) is not None:
        if runconfig.get('num_epochs', None) is not None:
            raise ValueError(
                "num_steps and num_epochs cannot both be specified "
                "in the runconfig section of params"
            )
        if runconfig.get('steps_per_epoch', None) is not None:
            raise ValueError(
                "num_steps and steps_per_epoch cannot both be specified "
                "in the runconfig section of params"
            )
        if runconfig.get('eval_frequency', None) is None:
            raise ValueError(
                "if num_steps is specified, eval_frequency is needed "
                "to dictate how many train steps before each eval"
            )
        total_steps = int(runconfig['num_steps'])
        train_steps = int(runconfig['eval_frequency'])
        if train_steps <= 0:
            # 0 would divide by zero; a negative value would yield a
            # negative iteration count and silently run nothing.
            raise ValueError(
                f"eval_frequency must be a positive integer, got {train_steps}"
            )
        num_iters = math.ceil(total_steps / train_steps)
        last_steps = total_steps % train_steps
        # Per-iteration num_steps overrides.
        last_train_args = train_args + [f"--num_steps={last_steps}"]
        train_args = train_args + [f"--num_steps={train_steps}"]
    elif runconfig.get('num_epochs', None) is not None:
        num_iters = int(runconfig['num_epochs'])
        last_steps = 0
        # Each iteration trains for exactly one epoch.
        train_args = train_args + ["--num_epochs=1"]
    else:
        raise ValueError(
            "For train_and_eval mode, one of `num_steps` or `num_epochs` "
            "must be specified and not be None."
        )

    single_run = [sys.executable, filename]
    eval_cmd = single_run + eval_args
    for i in range(num_iters):
        # TRAIN
        if i == num_iters - 1 and last_steps > 0:
            train_cmd = single_run + last_train_args
        else:
            train_cmd = single_run + train_args
        try:
            subprocess.run(train_cmd, check=True)
        except Exception as e:
            raise RuntimeError(f"Training at iteration {i} failed.") from e

        # EVAL
        try:
            subprocess.run(eval_cmd, check=True)
        except Exception as e:
            raise RuntimeError(f"Evaluate at iteration {i} failed.") from e
def torchrun(filename: str, arguments: List[str], params: Dict[Any, Any]):
    """Start a single-node distributed GPU run using ``torchrun``.

    Re-invokes ``filename`` with ``arguments`` through the ``torchrun``
    executable found in the same directory as the current Python
    interpreter, with one process per CUDA device reported by
    ``torch.cuda.device_count()``.

    Args:
        filename: Path of the script to launch.
        arguments: Command-line arguments forwarded to the script.
        params: Not used by this function.

    Raises:
        RuntimeError: If no ``torchrun`` executable is found next to the
            interpreter, or if the launched run fails.
    """
    bin_dir = os.path.dirname(sys.executable)
    torchrun_exe = which("torchrun", path=bin_dir)
    if torchrun_exe is None:
        # Otherwise ``None`` ends up in the command list and surfaces as an
        # opaque TypeError hidden behind the generic error below.
        raise RuntimeError(
            "Failed to spawn distributed GPU run using torchrun: "
            f"no `torchrun` executable found in {bin_dir}"
        )

    torchrun_cmd = [
        torchrun_exe,
        "--nnodes=1",
        f"--nproc_per_node={torch.cuda.device_count()}",
        filename,
        *arguments,
    ]
    print(
        f"Starting distributed GPU run using torchrun:\n"
        f"{' '.join(torchrun_cmd)}"
    )
    try:
        subprocess.run(torchrun_cmd, check=True)
    except Exception as e:
        raise RuntimeError(
            "Failed to spawn distributed GPU run using torchrun"
        ) from e
def run(
    model_fn: Callable[[dict], torch.nn.Module],
    train_data_fn: Optional[DATA_FN_TYPE] = None,
    eval_data_fn: Optional[DATA_FN_TYPE] = None,
    default_params_fn: Optional[Callable[[dict], dict]] = None,
    extra_args_parser_fn: Optional[
        Callable[[], List[argparse.ArgumentParser]]
    ] = None,
):
    """Entry point to running pytorch models including CLI argument parsing.

    Params are parsed from the command line by ``get_params_from_args``,
    relative to the directory of the calling script. They are then
    optionally post-processed by ``default_params_fn`` and handed to
    ``main``.

    Args:
        model_fn: Callable that builds the model from the params dict.
        train_data_fn: Optional callable returning the training dataloader.
        eval_data_fn: Optional callable returning the evaluation dataloader.
        default_params_fn: Optional callable applied to the parsed params. A
            falsy return value (e.g. ``None`` from a function that edits the
            dict in place) keeps the current params.
        extra_args_parser_fn: Optional callable providing extra argument
            parsers. It is used both for CLI parsing and, via ``main``, for
            runconfig validation.
    """
    # Frame of the script that called ``run``; must be taken here so the
    # depth ([1]) points at the caller.
    parent = inspect.getouterframes(inspect.currentframe())[1]
    run_dir = os.path.dirname(os.path.abspath(parent.filename))
    params = get_params_from_args(run_dir, extra_args_parser_fn)
    if default_params_fn:
        params = default_params_fn(params) or params

    # Forward extra_args_parser_fn so validation in run_with_params knows
    # about the same extra arguments that were accepted during parsing.
    main(
        params,
        model_fn,
        train_data_fn,
        eval_data_fn,
        script=parent.filename,
        extra_args_parser_fn=extra_args_parser_fn,
    )
def main(
    params: Dict[str, Any],
    model_fn: Callable[[dict], torch.nn.Module],
    train_data_fn: Optional[DATA_FN_TYPE] = None,
    eval_data_fn: Optional[DATA_FN_TYPE] = None,
    script: Optional[str] = None,
    extra_args_parser_fn: Optional[
        Callable[[], List[argparse.ArgumentParser]]
    ] = None,
):
    """Entry point to running pytorch models.

    First applies ``runconfig.wsc_log_level``, then dispatches on
    ``params["runconfig"]["mode"]``:

    * ``sideband_<EVAL_ALL>``: ``sideband_eval_all``; returns ``None``.
    * ``sideband_<TRAIN_AND_EVAL>``: ``sideband_train_eval_all``; returns
      ``None``.
    * GPU target with ``enable_distributed`` and no ``LOCAL_RANK`` in the
      environment: relaunches via ``torchrun``; returns ``None``.
    * Anything else: returns the result of ``run_with_params``.

    Args:
        script: Script to re-launch in the sideband/torchrun cases. Defaults
            to the file of the calling frame.
    """
    if not script:
        caller_frame = inspect.getouterframes(inspect.currentframe())[1]
        script = caller_frame.filename

    runconfig = params["runconfig"]
    set_wsc_log_level(runconfig.get("wsc_log_level") or {})

    mode = runconfig["mode"]
    if mode == f"sideband_{modes.EVAL_ALL}":
        # TODO ambiguity on what to return, possibly just run the final
        # checkpoint in the main process below
        sideband_eval_all(script, sys.argv[1:], params)
        return None

    # TODO(SW-99336): Remove this once we properly support train_and_eval with cs
    if mode == f"sideband_{modes.TRAIN_AND_EVAL}":
        sideband_train_eval_all(script, sys.argv[1:], params)
        return None

    needs_torchrun_launch = (
        # Distributed GPU run with the experimental API ...
        runconfig["target_device"] == DeviceType.GPU
        and runconfig.get("enable_distributed", False)
        # ... that has not been launched yet (LOCAL_RANK is set once it has)
        and os.environ.get("LOCAL_RANK") is None
    )
    if needs_torchrun_launch:
        torchrun(script, sys.argv[1:], params)
        return None

    return run_with_params(
        params,
        model_fn,
        train_data_fn,
        eval_data_fn,
        extra_args_parser_fn=extra_args_parser_fn,
    )
def run_with_params(
    params: Dict[str, Any],
    model_fn: Callable[[dict], torch.nn.Module],
    train_data_fn: Optional[DATA_FN_TYPE] = None,
    eval_data_fn: Optional[DATA_FN_TYPE] = None,
    extra_args_parser_fn: Optional[
        Callable[[], List[argparse.ArgumentParser]]
    ] = None,
):
    """Runs a full end-to-end CS/non-CS workflow for a given model.

    The steps are:

    1. Validate ``params["runconfig"]`` with ``RunConfigParamsValidator``.
    2. Dump ``params`` to ``<summary_dir>/params_<mode>.yaml``.
    3. Hand off to ``run_cstorch_flow``.

    If ``runconfig.summary_dir`` is missing or ``None``, it defaults to
    ``<model_dir>/<mode>``. The chosen directory is written back into
    ``params["runconfig"]["summary_dir"]``.

    Args:
        params: Full params dict; must contain a ``runconfig`` section with
            at least ``mode``, plus ``model_dir`` when ``summary_dir`` is unset.
        model_fn: A callable that takes in a 'params' argument which it uses
            to configure and return a torch.nn.Module.
        train_data_fn: A callable that takes in a 'params' argument which it
            uses to configure and return a PyTorch dataloader corresponding
            to the training dataset.
        eval_data_fn: A callable that takes in a 'params' argument which it
            uses to configure and return a PyTorch dataloader corresponding
            to the evaluation dataset.
        extra_args_parser_fn: An optional callable that adds any extra parser
            args not covered in `get_parser` fn; passed to the runconfig
            validator.

    Returns:
        Whatever ``run_cstorch_flow`` returns.
    """
    runconfig_params = params["runconfig"]
    RunConfigParamsValidator(extra_args_parser_fn).validate(runconfig_params)

    mode = runconfig_params["mode"]
    summary_dir = runconfig_params.get("summary_dir")
    if summary_dir is None:
        summary_dir = os.path.join(runconfig_params["model_dir"], mode)
    os.makedirs(summary_dir, exist_ok=True)

    # Save the params to the summary dir
    with open(os.path.join(summary_dir, f"params_{mode}.yaml"), "w") as f:
        yaml.dump(params, f, default_flow_style=False, sort_keys=False)

    # cache summary dir for later use
    runconfig_params["summary_dir"] = summary_dir

    return run_cstorch_flow(params, model_fn, train_data_fn, eval_data_fn)
def set_wsc_log_level(
    log_levels: Union[List[str], Dict[str, Union[str, int]]]
):
    """Apply per-task WSC log levels, then validate the resulting settings.

    Args:
        log_levels: Mapping from task name to log level.
            A falsy task name (e.g. ``""``) sets the level of ``wsc_logger``
            itself; any other name sets ``wsc_logger.getChild(task)``.
            Levels may be integers, digit strings such as ``"20"``, or names
            resolved through ``get_level_name``. A list is rejected.

    Raises:
        ValueError: If ``log_levels`` is not a dict.
    """
    if not isinstance(log_levels, dict):
        raise ValueError("Invalid log levels. Input must be a dict.")

    for task, level in log_levels.items():
        # Numeric levels may arrive as ``int`` (e.g. from a YAML config),
        # which has no ``isdigit``; going through ``str`` accepts both forms.
        level = int(level) if str(level).isdigit() else get_level_name(level)
        logger = wsc_logger.getChild(task) if task else wsc_logger
        logger.setLevel(level)

    # validate log level setting
    collect_wsc_log_settings()