common.pytorch package#
- common.pytorch.layers package
- Submodules
- common.pytorch.layers.AlibiPositionEmbeddingLayer module
- common.pytorch.layers.AttentionHelper module
- common.pytorch.layers.AttentionLayer module
- common.pytorch.layers.BCELoss module
- common.pytorch.layers.BCEWithLogitsLoss module
- common.pytorch.layers.BiaslessLayerNorm module
- common.pytorch.layers.CosineEmbeddingLoss module
- common.pytorch.layers.CrossEntropyLoss module
- common.pytorch.layers.EmbeddingLayer module
- common.pytorch.layers.FeedForwardNetwork module
- common.pytorch.layers.GPTJDecoderLayer module
- common.pytorch.layers.GaussianNLLLoss module
- common.pytorch.layers.HingeEmbeddingLoss module
- common.pytorch.layers.HuberLoss module
- common.pytorch.layers.KLDivLoss module
- common.pytorch.layers.L1Loss module
- common.pytorch.layers.MSELoss module
- common.pytorch.layers.MarginRankingLoss module
- common.pytorch.layers.MultiLabelSoftMarginLoss module
- common.pytorch.layers.MultiMarginLoss module
- common.pytorch.layers.NLLLoss module
- common.pytorch.layers.PoissonNLLLoss module
- common.pytorch.layers.RelativePositionEmbeddingLayer module
- common.pytorch.layers.SmoothL1Loss module
- common.pytorch.layers.Transformer module
- common.pytorch.layers.TransformerDecoder module
- common.pytorch.layers.TransformerDecoderLayer module
- common.pytorch.layers.TransformerEncoder module
- common.pytorch.layers.TransformerEncoderLayer module
- common.pytorch.layers.TripletMarginLoss module
- common.pytorch.layers.TripletMarginWithDistanceLoss module
- common.pytorch.layers.utils module
- Module contents
- common.pytorch.metrics package
- Submodules
- common.pytorch.metrics.accuracy module
- common.pytorch.metrics.auc module
- common.pytorch.metrics.cb_metric module
- common.pytorch.metrics.dice_coefficient module
- common.pytorch.metrics.fbeta_score module
- common.pytorch.metrics.mean_iou module
- common.pytorch.metrics.metric_utils module
- common.pytorch.metrics.perplexity module
- common.pytorch.metrics.precision_at_k module
- common.pytorch.metrics.recall_at_k module
- common.pytorch.metrics.rouge_score module
- Module contents
- common.pytorch.model_utils package
- Subpackages
- common.pytorch.model_utils.checkpoint_converters package
- Submodules
- common.pytorch.model_utils.checkpoint_converters.base_converter module
- common.pytorch.model_utils.checkpoint_converters.bert module
- common.pytorch.model_utils.checkpoint_converters.bert_finetune module
- common.pytorch.model_utils.checkpoint_converters.gpt2_hf_cs module
- common.pytorch.model_utils.checkpoint_converters.gpt_neox_hf_cs module
- common.pytorch.model_utils.checkpoint_converters.gptj_hf_cs module
- common.pytorch.model_utils.checkpoint_converters.salesforce_codegen_hf_cs module
- common.pytorch.model_utils.checkpoint_converters.t5 module
- Module contents
- Submodules
- common.pytorch.model_utils.BertPretrainModelLoss module
- common.pytorch.model_utils.GPTLMHeadModelLoss module
- common.pytorch.model_utils.RotaryPositionEmbeddingHelper module
- common.pytorch.model_utils.T5ForConditionalGenerationLoss module
- common.pytorch.model_utils.activations module
- common.pytorch.model_utils.convert_checkpoint module
- common.pytorch.model_utils.create_initializer module
- common.pytorch.model_utils.weight_initializers module
- Module contents
- common.pytorch.optim package
- Subpackages
- Submodules
- common.pytorch.optim.ASGD module
- common.pytorch.optim.Adadelta module
- common.pytorch.optim.Adafactor module
- common.pytorch.optim.Adagrad module
- common.pytorch.optim.AdamBase module
- common.pytorch.optim.Adamax module
- common.pytorch.optim.CSOptimizer module
- common.pytorch.optim.Lamb module
- common.pytorch.optim.NAdam module
- common.pytorch.optim.RAdam module
- common.pytorch.optim.RMSprop module
- common.pytorch.optim.Rprop module
- common.pytorch.optim.SGD module
- common.pytorch.optim.lr_scheduler module
- common.pytorch.optim.utils module
- Module contents
- common.pytorch.summaries package
common.pytorch.PyTorchBaseModel module#
Abstract base class for PyTorch models.
- class common.pytorch.PyTorchBaseModel.Final#
Placeholder class for deprecation warning
- static __new__(mcs, name, bases, classdict)#
- class common.pytorch.PyTorchBaseModel.PyTorchBaseModel#
Base Model Definition for Cerebras runners
- __init__(params: dict, model_fn: Union[Callable[[dict], torch.nn.Module], torch.nn.Module], device: Optional[torch.device] = None)#
- property duplicate_params_map#
Returns a map of parameter names that hold the same tensor. Keys and values are the names as they appear in state_dict.
- eval()#
Sets the model into eval mode, equivalent to .eval() called on a torch.nn.Module.
- get_lr_scheduler()#
Returns the LR Scheduler associated with this model.
- get_optimizer()#
Returns the optimizer associated with this model.
- get_state(keep_vars=False)#
Returns the state of the model and optimizer
- set_state(state, strict=True)#
Sets the state of the model and optimizer
- property supported_cs_modes#
Returns a list of modes that are supported for CS runs.
By default, train and eval are supported; however, this property is designed to be overridden on a model-by-model basis.
- property supported_modes#
Returns the supported modes, conditional on the hardware backend.
- property supported_non_cs_modes#
Returns a list of modes that are supported for non-CS (CPU/GPU) runs.
By default, train, eval, and train_and_eval are supported; however, this property is designed to be overridden on a model-by-model basis.
- supports_mode(mode) bool #
Checks whether the model supports the provided mode.
- train()#
Sets the model into training mode, equivalent to .train() called on a torch.nn.Module.
- trainable_named_parameters()#
Gather trainable named parameters from model
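The mode properties above suggest a simple override pattern: a subclass redefines supported_cs_modes or supported_non_cs_modes, and supported_modes and supports_mode pick the list for the current backend. The following is a minimal, self-contained sketch of that pattern, not the PyTorchBaseModel implementation; the class names ModeSupportSketch and EvalOnlyOnCSModel and the is_cs flag (a stand-in for the hardware-backend check) are hypothetical.

```python
from typing import List


class ModeSupportSketch:
    """Hypothetical stand-in for the mode-related parts of a base model."""

    def __init__(self, is_cs: bool) -> None:
        # Assumption: a single boolean stands in for the hardware-backend check.
        self.is_cs = is_cs

    @property
    def supported_cs_modes(self) -> List[str]:
        # Default CS modes, as documented above.
        return ["train", "eval"]

    @property
    def supported_non_cs_modes(self) -> List[str]:
        # Default CPU/GPU modes, as documented above.
        return ["train", "eval", "train_and_eval"]

    @property
    def supported_modes(self) -> List[str]:
        # Pick the mode list that matches the backend.
        return self.supported_cs_modes if self.is_cs else self.supported_non_cs_modes

    def supports_mode(self, mode: str) -> bool:
        return mode in self.supported_modes


class EvalOnlyOnCSModel(ModeSupportSketch):
    """Hypothetical model that only supports evaluation on CS systems."""

    @property
    def supported_cs_modes(self) -> List[str]:
        return ["eval"]
```

Here EvalOnlyOnCSModel(is_cs=True).supports_mode("train") returns False, while EvalOnlyOnCSModel(is_cs=False).supports_mode("train") still returns True, because only the CS list was overridden.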
common.pytorch.gradient_clipper module#
- class common.pytorch.gradient_clipper.GradientClipper#
- __init__(max_gradient_norm: float = 0.0, max_gradient_value: float = 0.0)#
- check_amp()#
Disables global gradient clipping (GGC) here if both GGC and dynamic loss scaling (DLS) are enabled by the GradScaler.
- clip(params: dict)#
- set_max_gradients(max_gradient_norm: float = 0.0, max_gradient_value: float = 0.0)#
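GradientClipper is configured with a maximum gradient norm and a maximum gradient value. The sketch below illustrates the two standard techniques those settings name, clipping by global L2 norm and element-wise clipping by value, on plain Python lists rather than tensors. It is not the GradientClipper implementation: the function name clip_gradients_sketch is hypothetical, and both treating 0.0 as "disabled" (suggested by the defaults) and rejecting two limits at once are assumptions.

```python
import math
from typing import Dict, List


def clip_gradients_sketch(
    grads: Dict[str, List[float]],
    max_gradient_norm: float = 0.0,
    max_gradient_value: float = 0.0,
) -> Dict[str, List[float]]:
    """Return clipped copies of `grads` (hypothetical helper, not the library API)."""
    # Assumption: a limit of 0.0 disables that form of clipping.
    if max_gradient_norm > 0.0 and max_gradient_value > 0.0:
        # Assumption: the two modes are treated as mutually exclusive here.
        raise ValueError("Set at most one of max_gradient_norm / max_gradient_value")

    if max_gradient_norm > 0.0:
        # Global L2 norm across every element of every parameter's gradient.
        total_norm = math.sqrt(sum(g * g for gs in grads.values() for g in gs))
        if total_norm > max_gradient_norm:
            # Rescale all gradients by one shared factor so the global norm
            # becomes max_gradient_norm; the overall direction is preserved.
            scale = max_gradient_norm / total_norm
            return {name: [g * scale for g in gs] for name, gs in grads.items()}

    if max_gradient_value > 0.0:
        # Clamp each element independently into [-v, v].
        v = max_gradient_value
        return {name: [min(max(g, -v), v) for g in gs] for name, gs in grads.items()}

    # No clipping needed or requested: return unmodified copies.
    return {name: list(gs) for name, gs in grads.items()}
```

Norm clipping keeps the relative sizes of all gradients intact, whereas value clipping can change the gradient's direction because each element is clamped on its own.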
common.pytorch.half_dtype module#
Provides utilities for selecting the half-precision dtype: float16 or bfloat16.
common.pytorch.input_utils module#
- common.pytorch.input_utils.bucketed_batch(data_iterator, batch_size, buckets=None, element_length_fn=None, collate_fn=None, drop_last=False, seed=None)#
Batches the data from an iterator such that samples of similar length end up in the same batch. If buckets is not supplied, the data is batched normally.
- Parameters
data_iterator – An iterator that yields data one sample at a time.
batch_size (int) – The number of samples in a batch.
buckets (list) – A list of bucket boundaries. If set to None, then no bucketing will happen, and data will be batched normally. If set to a list, then data will be grouped into len(buckets) + 1 buckets. A sample s will go into bucket i if buckets[i-1] <= element_length_fn(s) < buckets[i] where 0 and inf are the implied lowest and highest boundaries respectively. buckets must be sorted and all elements must be non-zero.
element_length_fn (callable) – A function that takes a single sample and returns an int representing the length of that sample.
collate_fn (callable) – The function to use to collate samples into a batch. Defaults to PyTorch’s default collate function.
drop_last (bool) – Whether or not to drop incomplete batches at the end of the dataset. If using bucketing, buckets that are not completely full will also be dropped, even if combined there are more than batch_size samples remaining spread across multiple buckets.
seed (int) – If drop_last = False, leftover samples should not be fed out in an order correlated with their lengths, so they are shuffled before being batched and yielded. This seed makes that shuffle deterministic. It is only used when buckets is not None and drop_last = False.
- Yields
Batches of samples of type returned by collate_fn, or batches of PyTorch tensors if using the default collate function.
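A minimal pure-Python sketch of the bucketing behavior described above, assuming a plain list as the collate function and using the standard bisect module to find each sample's bucket. The name bucketed_batch_sketch is hypothetical; this illustrates the documented semantics rather than reproducing the library's implementation.

```python
import bisect
import random
from typing import Any, Callable, Iterable, Iterator, List, Optional, Sequence


def bucketed_batch_sketch(
    data_iterator: Iterable[Any],
    batch_size: int,
    buckets: Optional[Sequence[int]] = None,
    element_length_fn: Optional[Callable[[Any], int]] = None,
    collate_fn: Callable[[List[Any]], Any] = list,
    drop_last: bool = False,
    seed: Optional[int] = None,
) -> Iterator[Any]:
    """Hypothetical illustration of length-bucketed batching."""
    if buckets is None:
        # No bucketing: plain fixed-size batching.
        batch: List[Any] = []
        for sample in data_iterator:
            batch.append(sample)
            if len(batch) == batch_size:
                yield collate_fn(batch)
                batch = []
        if batch and not drop_last:
            yield collate_fn(batch)
        return

    # len(buckets) + 1 bins; bisect_right returns the i for which
    # buckets[i-1] <= length < buckets[i] (0 and inf are the implied ends).
    bins: List[List[Any]] = [[] for _ in range(len(buckets) + 1)]
    for sample in data_iterator:
        i = bisect.bisect_right(buckets, element_length_fn(sample))
        bins[i].append(sample)
        if len(bins[i]) == batch_size:
            yield collate_fn(bins[i])
            bins[i] = []

    if drop_last:
        return  # partially filled buckets are discarded
    # Shuffle leftovers so their order is not correlated with their lengths.
    leftovers = [sample for b in bins for sample in b]
    random.Random(seed).shuffle(leftovers)
    for start in range(0, len(leftovers), batch_size):
        yield collate_fn(leftovers[start:start + batch_size])
```

With buckets=[3] and element_length_fn=len, strings shorter than 3 characters and strings of 3 or more characters never share a full batch; only the shuffled leftovers at the end may mix.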
common.pytorch.loss_utils module#
common.pytorch.perf_utils module#
- class common.pytorch.perf_utils.PerfData#
Data structure for holding performance data.
- Parameters
total_samples – Total number of samples processed.
total_time – Total time spent processing those samples.
samples_per_sec – Throughput of processing those samples.
compile_time – Time spent compiling the model.
programming_time – Time spent programming the fabric.
est_samples_per_sec – Estimated throughput based on compile and fabric.
- __init__(total_samples: int = 0, total_time: float = 0.0, samples_per_sec: float = 0.0, compile_time: float = 0.0, programming_time: float = 0.0, est_samples_per_sec: float = 0.0) None #
- compile_time: float = 0.0#
- est_samples_per_sec: float = 0.0#
- merge(other: common.pytorch.perf_utils.PerfData)#
Merge another PerfData instance into self.
- Parameters
other – The other PerfData instance to merge.
- programming_time: float = 0.0#
- samples_per_sec: float = 0.0#
- throughput_dict() dict #
- total_samples: int = 0#
- total_time: float = 0.0#
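PerfData's typed fields with defaults suggest a plain dataclass. The sketch below mirrors those fields and shows one plausible merge policy. The class name PerfDataSketch is hypothetical, and the policy itself (summing sample counts and processing time, recomputing samples_per_sec, leaving the remaining fields untouched) is an assumption, not the documented behavior of merge.

```python
from dataclasses import dataclass


@dataclass
class PerfDataSketch:
    """Hypothetical mirror of the PerfData fields listed above."""

    total_samples: int = 0
    total_time: float = 0.0
    samples_per_sec: float = 0.0
    compile_time: float = 0.0
    programming_time: float = 0.0
    est_samples_per_sec: float = 0.0

    def merge(self, other: "PerfDataSketch") -> None:
        # Assumption: sample counts and processing time are additive.
        self.total_samples += other.total_samples
        self.total_time += other.total_time
        # Throughput is recomputed from the merged totals rather than averaged,
        # so a run with more samples carries proportionally more weight.
        self.samples_per_sec = (
            self.total_samples / self.total_time if self.total_time > 0 else 0.0
        )
        # compile_time, programming_time and est_samples_per_sec are left as-is
        # in this sketch; how the real merge treats them is not stated here.
```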
- common.pytorch.perf_utils.collect_perf_data(tracker: modelzoo.common.pytorch.cb_model.RateTracker)#
Collect performance data from a run.
- Parameters
tracker – Tracker which contains performance data.
- Returns
A PerfData instance containing the perf data.
- common.pytorch.perf_utils.save_perf(outdir: str)#
Utility method for saving performance metrics from a run.
- Parameters
outdir – Output directory to write performance files to.
common.pytorch.pytorch_base_cs_runner module#
Module containing the Base PyTorch CS Runner
- class common.pytorch.pytorch_base_cs_runner.PyTorchBaseCSRunner#
Base class containing the elements common to the CS runner and the CS compiler
- __init__(*args, **kwargs)#
- on_eval_batch_start(data)#
- on_eval_end(early_exit: bool)#
- on_train_batch_start(data)#
- on_train_end(early_exit: bool)#
- train_and_eval(train_data_loader, eval_data_loader)#
common.pytorch.pytorch_base_runner module#
Module containing the Base PyTorch Runner
- class common.pytorch.pytorch_base_runner.PyTorchBaseRunner#
The base class for running PyTorch models on any device.
- __init__(model: modelzoo.common.pytorch.PyTorchBaseModel.PyTorchBaseModel, params: dict)#
Construct a PyTorchRunner instance.
- Parameters
model – The PyTorch model to run.
params – A dict of params that specify the behavior of the model.
- backward(loss)#
Runs the backward pass.
Override this method to provide any additional functionality around the backward call.
- compute_eval_metrics()#
Compute and log the eval metrics
- static create(model_fn: Callable[[dict, Optional[torch.device]], modelzoo.common.pytorch.PyTorchBaseModel.PyTorchBaseModel], params: dict) common.pytorch.pytorch_base_runner.PyTorchBaseRunner #
Creates and returns an instance of PyTorchBaseRunner that has been configured based on the hardware specified by the provided params dictionary
- Parameters
model_fn – A callable that takes in a ‘params’ argument and optionally a torch.device which it uses to configure and return a PyTorchBaseModel
params – A dictionary containing all the parameters required to initialize and configure both the model and the runner
- eval_epoch(dataloader, epoch: Optional[int] = None)#
Runs an epoch of evaluation
- Parameters
dataloader – The dataloader to iterate through
epoch – The current epoch
- eval_forward(data)#
Runs the eval forward pass.
Override this method to provide any additional functionality around the eval forward pass call.
- evaluate(eval_dataloader)#
Evaluate the model with data generated by the given dataloader.
- Parameters
eval_dataloader – A data loader for generating data to feed to the model.
- is_master_ordinal()#
Checks whether distributed training is enabled and, if so, whether this is the main process. Most reading and writing should only happen on the main process.
- lr_scheduler_step()#
Performs the lr_scheduler step
- on_checkpoint_saved(checkpoint_path: str, step: int)#
Function to execute after a checkpoint is saved.
- on_eval_batch_end(loss, epoch: Optional[int] = None, step: Optional[int] = None)#
Actions to perform after the eval batch iteration is complete
- on_eval_batch_start(data)#
Optionally pre-process data before eval batch start
- on_eval_end(early_exit: bool)#
Function to execute after eval ends
- on_eval_epoch_end(early_exit: bool)#
Function to execute after the eval epoch ends
- on_eval_epoch_start()#
Function to execute before the eval epoch begins
- on_eval_start()#
Function to execute before eval starts
- on_train_batch_end(loss, epoch: Optional[int] = None, step: Optional[int] = None)#
Actions to perform after the train batch iteration is complete
- on_train_batch_start(data)#
Optionally pre-process data before train batch start
- on_train_end(early_exit: bool)#
Function to execute after training ends
- on_train_epoch_end(early_exit: bool)#
Function to execute after the training epoch ends
- on_train_epoch_start()#
Function to execute before the training epoch begins
- on_train_start()#
Function to execute before training starts
- optimizer_step()#
Performs the optimizer step
- optimizer_zero_grad()#
Zeroes out the gradients in the optimizer
- print_eval_metrics(eval_metrics)#
Compute and log the eval metrics
- train(train_dataloader)#
Train the model with data generated by the given dataloader.
- Parameters
train_dataloader – A data loader for generating data to feed to the model.
- train_and_eval(train_dataloader, eval_dataloader)#
Train and evaluate the model with data generated by the given dataloaders.
In each epoch, this method first trains the model and then runs evaluation.
- Parameters
train_dataloader – A data loader for generating training data to feed to the model.
eval_dataloader – A data loader for generating evaluation data to feed to the model.
- train_epoch(epoch: int, dataloader) bool #
Runs an epoch of training
- Parameters
epoch – The current epoch
dataloader – The dataloader to iterate through
- train_forward(data)#
Runs the train forward pass.
Override this method to provide any additional functionality around the train forward pass call.
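The hooks above are listed alphabetically, so their calling order is easiest to picture as points in a training loop. The sketch below shows one plausible ordering; it is not taken from PyTorchBaseRunner, and the class HookOrderSketch, its events list, the num_epochs argument, and the reading of train_epoch's bool return as an early-exit flag are all assumptions.

```python
from typing import Any, Iterable, List, Optional


class HookOrderSketch:
    """Hypothetical runner that records the order in which hooks fire."""

    def __init__(self) -> None:
        self.events: List[str] = []

    def _record(self, name: str) -> None:
        self.events.append(name)

    # Lifecycle hooks: no-ops apart from recording their name.
    def on_train_start(self) -> None:
        self._record("on_train_start")

    def on_train_epoch_start(self) -> None:
        self._record("on_train_epoch_start")

    def on_train_batch_start(self, data: Any) -> Any:
        self._record("on_train_batch_start")
        return data  # a real runner could pre-process the batch here

    def on_train_batch_end(self, loss: float, epoch: Optional[int] = None, step: Optional[int] = None) -> None:
        self._record("on_train_batch_end")

    def on_train_epoch_end(self, early_exit: bool) -> None:
        self._record("on_train_epoch_end")

    def on_train_end(self, early_exit: bool) -> None:
        self._record("on_train_end")

    # Step methods: a real runner would touch the model and optimizer here.
    def optimizer_zero_grad(self) -> None:
        self._record("optimizer_zero_grad")

    def train_forward(self, data: Any) -> float:
        self._record("train_forward")
        return float(sum(data))  # toy "loss"

    def backward(self, loss: float) -> None:
        self._record("backward")

    def optimizer_step(self) -> None:
        self._record("optimizer_step")

    def lr_scheduler_step(self) -> None:
        self._record("lr_scheduler_step")

    def train_epoch(self, epoch: int, dataloader: Iterable[Any]) -> bool:
        self.on_train_epoch_start()
        for step, data in enumerate(dataloader):
            data = self.on_train_batch_start(data)
            self.optimizer_zero_grad()
            loss = self.train_forward(data)
            self.backward(loss)
            self.optimizer_step()
            self.lr_scheduler_step()
            self.on_train_batch_end(loss, epoch=epoch, step=step)
        self.on_train_epoch_end(early_exit=False)
        return False  # assumption: True would signal an early exit

    def train(self, dataloader: Iterable[Any], num_epochs: int = 1) -> None:
        self.on_train_start()
        early_exit = False
        for epoch in range(num_epochs):
            early_exit = self.train_epoch(epoch, dataloader)
            if early_exit:
                break
        self.on_train_end(early_exit)
```

Overriding a single hook in a subclass, for example on_train_batch_end to log the loss, leaves the rest of the loop unchanged, which is the extension style the method descriptions above invite.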
common.pytorch.pytorch_cs_appliance module#
Contains the CS Appliance mode runner
- class common.pytorch.pytorch_cs_appliance.PyTorchCSAppliance#
Class for compiling PyTorch models for Cerebras hardware.
- __init__(model: modelzoo.common.pytorch.PyTorchBaseModel.PyTorchBaseModel, params: dict)#
- backward(loss)#
- compute_eval_metrics()#
- eval_forward(data)#
- evaluate(eval_dataloader)#
- get_input_fn_params()#
Construct the input function params using params dictionary
- get_loss_value() torch.Tensor #
Fetch all activations and return the loss value.
- is_master_ordinal()#
- lr_scheduler_step()#
- maybe_get_loss_value(step) torch.Tensor #
Fetch the loss value if it's a fetch step; otherwise, return None.
- on_eval_batch_end(loss, epoch: Optional[int] = None, step: Optional[int] = None)#
- on_eval_batch_start(data)#
- on_eval_end(early_exit=False)#
- on_eval_epoch_end(early_exit: bool)#
- on_eval_start()#
- on_train_batch_end(loss, epoch: Optional[int] = None, step: Optional[int] = None)#
- on_train_batch_start(data)#
- on_train_end(early_exit=False)#
- on_train_epoch_end(early_exit: bool)#
- on_train_start()#
- optimizer_step()#
- optimizer_zero_grad()#
- train(train_dataloader) None #
- train_forward(data)#
common.pytorch.pytorch_cs_compiler module#
Contains the CS Compiler
- class common.pytorch.pytorch_cs_compiler.PyTorchCSCompiler#
Class for compiling PyTorch models for Cerebras hardware.
- __init__(*args, **kwargs)#
- compute_eval_metrics()#
- eval_forward(data)#
- evaluate(dataloader)#
- on_eval_batch_end(*args, **kwargs)#
- on_eval_epoch_end(early_exit: bool)#
- on_eval_start()#
- on_train_batch_end(*args, **kwargs)#
- on_train_epoch_end(early_exit: bool)#
- on_train_start()#
- train(dataloader)#
common.pytorch.pytorch_cs_runner module#
Contains the CS PyTorch Runner
- class common.pytorch.pytorch_cs_runner.PyTorchCSRunner#
Class for running PyTorch models on Cerebras hardware.
- evaluate(dataloader)#
- on_eval_end(early_exit: bool)#
- on_eval_epoch_end(early_exit: bool)#
- on_eval_start()#
- on_train_end(early_exit: bool)#
- on_train_epoch_end(early_exit: bool)#
- on_train_start()#
- optimizer_step()#
- train(dataloader) None #
common.pytorch.pytorch_dist_runner module#
- class common.pytorch.pytorch_dist_runner.PyTorchDistRunner#
Class for running PyTorch models on multiple GPUs.
- __init__(model, params)#
- compute_eval_metrics()#
Compute and log the eval metrics
- evaluate(eval_data_fn)#
Evaluate the model with data generated by the given dataloader.
- Parameters
dataloader – A data loader for generating data to feed to the model.
- is_master_ordinal()#
Checks whether distributed training is enabled and, if so, whether this is the main process. Most reading and writing should only happen on the main process.
- on_eval_batch_end(loss, epoch: Optional[int] = None, step: Optional[int] = None)#
Actions to perform after the eval batch iteration is complete
- on_eval_batch_start(data)#
- on_eval_end(early_exit: bool)#
- on_process_start(all_metrics)#
- on_train_batch_end(loss, epoch: Optional[int] = None, step: Optional[int] = None)#
Actions to perform after the train batch iteration is complete
- on_train_batch_start(data)#
- on_train_end(early_exit: bool)#
- on_train_epoch_end(early_exit: bool)#
- on_train_epoch_start()#
- train(train_data_fn)#
- train_and_eval(train_data_fn, eval_data_fn)#
Train the model with data generated by the given dataloader.
- Parameters
dataloader – A data loader for generating data to feed to the model.
common.pytorch.pytorch_runner module#
- class common.pytorch.pytorch_runner.PyTorchRunner#
Class for running PyTorch models on CPU/GPU.
- __init__(device, model, params)#
- eval_forward(data)#
- on_eval_batch_start(data)#
- on_eval_end(early_exit: bool)#
- on_train_batch_start(data)#
- on_train_end(early_exit: bool)#
- on_train_start()#
- train_forward(data)#
common.pytorch.run_cstorch_flow module#
Generic run scripts built using the cstorch API
- common.pytorch.run_cstorch_flow.compute_grad_norm(model)#
Computes the model-wise and per-layer norms of the gradients
- common.pytorch.run_cstorch_flow.compute_params_norm(model)#
Computes the model-wise norm of the parameters
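Assuming the L2 norm (the descriptions above do not name the norm), the per-layer and model-wise gradient norms relate as follows, where $G_l$ is the gradient of layer $l$ with entries $g_{l,i}$. compute_params_norm would apply the same model-wise formula to the parameters themselves instead of their gradients.

```latex
\|G_l\|_2 = \sqrt{\sum_i g_{l,i}^2},
\qquad
\|G\|_2 = \sqrt{\sum_l \|G_l\|_2^2} = \sqrt{\sum_l \sum_i g_{l,i}^2}
```

For example, two layers with gradient norms 3 and 4 give a model-wise norm of $\sqrt{3^2 + 4^2} = 5$.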
- common.pytorch.run_cstorch_flow.optimizer_step_with_summaries(loss: torch.Tensor, optimizer: cstorch.optim.Optimizer, grad_scaler: cstorch.amp.GradScaler, max_gradient_norm: float = None, max_gradient_value: float = None, log_summaries: bool = False, model: torch.nn.Module = None)#
Customized equivalent of cstorch.amp.optimizer_step that additionally supports gradient-norm summaries
- common.pytorch.run_cstorch_flow.run_cstorch_eval(params, model_fn, input_fn)#
Runs the evaluation workflow built using the cstorch API
- Parameters
params – the params dictionary extracted from the params.yaml used
model_fn – A callable that takes in the params dictionary and returns a torch.nn.Module
input_fn – A callable that takes in the params dictionary and returns a
- common.pytorch.run_cstorch_flow.run_cstorch_flow(params, model_fn, train_data_fn, eval_data_fn)#
Set up the cstorch run and call the appropriate helper based on the mode
- Parameters
params – the params dictionary extracted from the params.yaml used
model_fn – A callable that takes in the params dictionary and returns a torch.nn.Module
train_data_fn – A callable that takes in the param dictionary and returns a
eval_data_fn – A callable that takes in the param dictionary and returns a
- common.pytorch.run_cstorch_flow.run_cstorch_train(params, model_fn, input_fn)#
Runs the training workflow built using the cstorch API
- Parameters
params – the params dictionary extracted from the params.yaml used
model_fn – A callable that takes in the params dictionary and returns a torch.nn.Module
input_fn – A callable that takes in the params dictionary and returns a
common.pytorch.run_utils module#
Utilities for running Cerebras PyTorch models
- common.pytorch.run_utils.arg_filter(arg: str, keyword: str) bool #
Checks if a given arg matches the given keyword
- Callable[[dict], torch.nn.Module], train_data_fn: Optional[Callable[[dict],]] = None, eval_data_fn: Optional[Callable[[dict],]] = None, default_params_fn: Optional[Callable[[dict], dict]] = None, extra_args_parser_fn: Optional[Callable[[], List[argparse.ArgumentParser]]] = None)#
Entry point for running PyTorch models
- common.pytorch.run_utils.run_base_model_flow(params, model_fn, train_data_fn, eval_data_fn)#
Runs the PyTorchBaseModel and runner flow
- common.pytorch.run_utils.run_with_params(params: Dict[str, Any], model_fn: Callable[[dict], torch.nn.Module], train_data_fn: Optional[Callable[[dict],]] = None, eval_data_fn: Optional[Callable[[dict],]] = None)#
Runs a full end-to-end CS/non-CS workflow for a given model
- Parameters
model_fn – A callable that takes in a ‘params’ argument which it uses to configure and return a torch.nn.Module
train_data_fn – A callable that takes in a ‘params’ argument which it uses to configure and return a PyTorch dataloader corresponding to the training dataset
eval_data_fn – A callable that takes in a ‘params’ argument which it uses to configure and return a PyTorch dataloader corresponding to the evaluation dataset
default_params_fn – An optional callable that takes in the params dictionary and updates any missing params with default values
extra_args_parser_fn – An optional callable that adds any extra parser args not covered in get_parser fn.
- common.pytorch.run_utils.sideband_eval_all(filename: str, arguments: List[str], params: Dict[Any, Any])#
Temporary support for running eval multiple times via subprocess
- common.pytorch.run_utils.sideband_train_eval_all(filename: str, arguments: List[str], params: Dict[Any, Any])#
Temporary support for running train and eval multiple times via subprocess
- common.pytorch.run_utils.update_sideband_mode_arg(arguments: List[str], new_mode_arg: str, old_mode: str) List[str] #
Updates sideband arguments to a different mode
common.pytorch.summary_collection module#
common.pytorch.utils module#
General-purpose PyTorch utilities
- class common.pytorch.utils.BufferedShuffleDataset#
Dataset shuffled from the original dataset.
This class is useful to shuffle an existing instance of an IterableDataset. The buffer with buffer_size is filled with the items from the dataset first. Then, each item will be yielded from the buffer by reservoir sampling via iterator.
buffer_size is required to be larger than 0. For buffer_size == 1, the dataset is not shuffled. In order to fully shuffle the whole dataset, buffer_size is required to be greater than or equal to the size of the dataset.
When it is used with DataLoader, each item in the dataset will be yielded from the DataLoader iterator, and the method to set up a random seed differs based on num_workers.
For single-process mode (num_workers == 0), the random seed is required to be set before the DataLoader in the main process.
>>> ds = BufferedShuffleDataset(dataset)
>>> random.seed(...)
>>> print(list(DataLoader(ds, num_workers=0)))
For multi-process mode (num_workers > 0), the random seed is set by a callable function in each worker.
>>> ds = BufferedShuffleDataset(dataset)
>>> def init_fn(worker_id):
...     random.seed(...)
>>> print(list(DataLoader(ds, ..., num_workers=n, worker_init_fn=init_fn)))
- Parameters
dataset (IterableDataset) – The original IterableDataset.
buffer_size (int) – The buffer size for shuffling.
- __init__(dataset, buffer_size)#
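The buffering scheme described above can be sketched over any Python iterable. This is a minimal illustration that assumes a random.Random instance for seeding; the name buffered_shuffle is hypothetical and the code is not the BufferedShuffleDataset source.

```python
import random
from typing import Iterable, Iterator, List, Optional, TypeVar

T = TypeVar("T")


def buffered_shuffle(
    items: Iterable[T], buffer_size: int, rng: Optional[random.Random] = None
) -> Iterator[T]:
    """Approximately shuffle a stream using a fixed-size buffer (hypothetical helper)."""
    if buffer_size <= 0:
        raise ValueError("buffer_size must be larger than 0")
    rng = rng if rng is not None else random.Random()
    buf: List[T] = []
    for item in items:
        if len(buf) == buffer_size:
            # Buffer full: emit a random buffered item and put the new one in its slot.
            idx = rng.randint(0, buffer_size - 1)
            yield buf[idx]
            buf[idx] = item
        else:
            # Still filling the buffer.
            buf.append(item)
    # Stream exhausted: emit whatever remains, in random order.
    rng.shuffle(buf)
    yield from buf
```

With buffer_size == 1 every new item immediately displaces the single buffered one, so the stream comes out in its original order, matching the note above; only a buffer at least as large as the dataset gives a full shuffle.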
- class common.pytorch.utils.RunConfigParamsValidator#
Validate Run Configs
- __init__()#
- validate(config)#
Validate params match existing schema
- class common.pytorch.utils.SampleGenerator#
Iterator that returns multiple samples of the given input data.
Can be used in place of a PyTorch DataLoader to generate synthetic data.
- Parameters
data – The data which should be returned at each iterator step.
sample_count – The maximum number of data samples to be returned.
- __init__(data, sample_count)#
- next()#
Generate next data sample
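Because SampleGenerator can stand in for a PyTorch DataLoader, it behaves like a bounded iterator over one fixed piece of data. A minimal sketch of that behavior, assuming it returns the same data object until sample_count samples have been produced; the name SampleGeneratorSketch is hypothetical.

```python
from typing import Any


class SampleGeneratorSketch:
    """Hypothetical iterator that yields `data` up to `sample_count` times."""

    def __init__(self, data: Any, sample_count: int) -> None:
        self.data = data
        self.sample_count = sample_count
        self.count = 0  # samples produced so far

    def __iter__(self) -> "SampleGeneratorSketch":
        return self

    def __next__(self) -> Any:
        return self.next()

    def next(self) -> Any:
        # Stop once the maximum number of samples has been returned.
        if self.count >= self.sample_count:
            raise StopIteration
        self.count += 1
        return self.data
```

Since this sketch's __iter__ returns self, a second pass over the same instance yields nothing; a fresh instance is needed per epoch.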
- common.pytorch.utils.get_checkpoints(model_dir: str) List[str] #
Gather checkpoints in a model directory
- common.pytorch.utils.get_debug_args(debug_args_path, debug_ini_path)#
Appliance mode DebugArgs.
- common.pytorch.utils.get_input_dtype(to_float16: bool)#
Determine input datatype based on environment
- common.pytorch.utils.group_optimizer_params(trainable_params, no_decay_layers, weight_decay_rate)#
Groups the trainable parameters into optimizer parameter groups
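A common way to group parameters for weight decay is to build two torch.optim-style parameter groups, one with weight decay and one without. The sketch below assumes that trainable_params is an iterable of (name, parameter) pairs and that a parameter is exempt from decay when any entry of no_decay_layers appears in its name. Both the matching rule and the name group_optimizer_params_sketch are hypothetical, not the library's documented behavior.

```python
from typing import Any, Dict, Iterable, List, Tuple


def group_optimizer_params_sketch(
    trainable_params: Iterable[Tuple[str, Any]],
    no_decay_layers: List[str],
    weight_decay_rate: float,
) -> List[Dict[str, Any]]:
    """Split named parameters into decay / no-decay groups (hypothetical helper)."""
    decay: List[Any] = []
    no_decay: List[Any] = []
    for name, param in trainable_params:
        # Assumption: substring match of each no_decay_layers entry against the name.
        if any(layer in name for layer in no_decay_layers):
            no_decay.append(param)
        else:
            decay.append(param)
    # Each dict follows the per-group format accepted by torch.optim optimizers.
    return [
        {"params": decay, "weight_decay": weight_decay_rate},
        {"params": no_decay, "weight_decay": 0.0},
    ]
```

In PyTorch, passing such a list as the first argument to an optimizer such as torch.optim.AdamW applies each group's weight_decay to that group's parameters only.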
- common.pytorch.utils.setup_logging(chief_logging_level: str, streamer_logging_level: str, logging_dir: Optional[str] = None)#
Configure default logging format
- common.pytorch.utils.to_cpu(tensor)#
Moves a tensor from its device to the CPU
- common.pytorch.utils.to_tensor(value, device=None)#
If the provided value is a Python int or float, converts it into a PyTorch tensor of type int32 or float32, respectively. Otherwise, returns the value unchanged.
- common.pytorch.utils.trainable_named_parameters(model)#
Returns the trainable named parameters for the model, as well as the modules that contain normalization
- common.pytorch.utils.visit_structure(data_structure: Union[Any, list, tuple, dict], select_fn: Callable[[Any], bool], strict: bool = False, scope: Optional[List[str]] = None) Generator[Tuple[List[str], Any], None, None] #
Recursively traverses a nested structure and yields the items accepted by the selector.
- Parameters
data_structure – A nested data structure to traverse recursively.
select_fn – A callable that returns True if the item passed to it should be selected.
strict – Strictly checks that an item in the nested structure is either a list/dict/tuple or selected by the select_fn. Otherwise, raises an error. Defaults to False.
scope – The current hierarchical scope of the data structure. Defaults to None.
- Yields
Tuples of (scope, item), one for each item selected by select_fn.
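A minimal recursive sketch of the traversal described above, assuming list/tuple indices and dict keys are recorded in the scope as strings. The name visit_structure_sketch is hypothetical; treat it as an illustration of the documented contract rather than the library source.

```python
from typing import Any, Callable, Generator, List, Optional, Tuple


def visit_structure_sketch(
    data_structure: Any,
    select_fn: Callable[[Any], bool],
    strict: bool = False,
    scope: Optional[List[str]] = None,
) -> Generator[Tuple[List[str], Any], None, None]:
    """Yield (scope, item) pairs for selected leaves (hypothetical helper)."""
    scope = scope or []
    if isinstance(data_structure, (list, tuple)):
        # Recurse into sequences, recording the index in the scope.
        for i, item in enumerate(data_structure):
            yield from visit_structure_sketch(item, select_fn, strict, scope + [str(i)])
    elif isinstance(data_structure, dict):
        # Recurse into mappings, recording the key in the scope.
        for key, value in data_structure.items():
            yield from visit_structure_sketch(value, select_fn, strict, scope + [str(key)])
    elif select_fn(data_structure):
        yield scope, data_structure
    elif strict:
        # Neither a container nor selected: an error in strict mode.
        raise ValueError(f"Unexpected item at scope {scope}: {data_structure!r}")
```

In this sketch, containers are always traversed rather than offered to select_fn, so select_fn only ever sees leaves; non-selected leaves are skipped silently unless strict is set.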