# Source code for cerebras_pytorch.saver

# Copyright 2016-2023 Cerebras Systems
# SPDX-License-Identifier: BSD-3-Clause

"""Utilities for saving and loading checkpoints."""
import os
from typing import IO, Any, Callable, Union

import torch

from cerebras_appliance import logger
from cerebras_appliance.utils.file import StrPath, get_path_size, is_pathlike
from cerebras_appliance.utils.memory import (
    get_available_memory,
    with_memory_info_logged,
)
from cerebras_pytorch.backend import current_backend_impl
from cerebras_pytorch.saver.checkpoint_reader import CheckpointReader

from .pt_h5_saver import PyTorchH5Saver
from .storage import cache_deferred_tensors, use_external_link

# Type of a checkpoint source accepted by `load`: either a path (`StrPath`)
# or a file-like object, which has to implement `read`, `readline`, `tell`,
# and `seek` methods.
_CkptFileT = Union[StrPath, IO]
# Type of the `map_location` argument accepted by `load` and `_torch_load`.
_MapLocT = Union[str, torch.device, Callable, dict, None]
# Type of a loaded checkpoint; left as `Any`, so no structure is enforced.
_StateDictT = Any


def save(obj: dict, checkpoint_file: str) -> None:
    """Save a PyTorch state dict to the given file.

    If a Cerebras backend has been initialized, saving is delegated to that
    backend. Otherwise the object is written directly with a
    ``PyTorchH5Saver``.

    Args:
        obj: The object to save.
        checkpoint_file: The path to save the object to.
    """
    backend = current_backend_impl(raise_exception=False)

    # Disable external links when storing tensors to checkpoint to make
    # checkpoint standalone and not dependent on external files.
    with use_external_link(value=False):
        if backend is None:
            logger.debug(
                "No Cerebras backend found. Defaulting to using CPU for "
                "saving."
            )
            # NOTE: the saver takes (path, obj) while the backend takes
            # (obj, path); the argument order differs on purpose.
            PyTorchH5Saver().save(checkpoint_file, obj)
        else:
            backend.save(obj, checkpoint_file)

    logger.verbose(f"Successfully saved checkpoint to {checkpoint_file}")
@with_memory_info_logged(
    "loading checkpoint",
    info=["available", "used"],
    logger=logger,
)
def load(
    checkpoint_file: _CkptFileT,
    map_location: _MapLocT = None,
    **kwargs,
) -> _StateDictT:
    """Load a PyTorch checkpoint from a file.

    Checkpoints that are path-like and pass
    ``PyTorchH5Saver.is_valid_checkpoint`` are read with the HDF5 loader;
    everything else (including file-like objects) falls back to
    ``torch.load``.

    Args:
        checkpoint_file: The path to the checkpoint to load.
        map_location: A mapping of where to load the checkpoint content to.
            If the map_location is `None`, then the tensors will be lazily
            loaded from the checkpoint file every single time the tensor is
            accessed.
            If the map_location is "cache", then the tensors will be cached
            once they are lazily loaded from the checkpoint file.
            If the map location is "cpu", then the tensors will be eagerly
            loaded into memory from the checkpoint file.
            For HDF5 checkpoints, when no backend has been initialized the
            tensors are always moved to CPU regardless of this argument.
        **kwargs: Additional keyword arguments to pass to the vanilla torch
            checkpoint loader. These are ignored if the checkpoint is a
            Cerebras HDF5 checkpoint.

    Returns:
        The loaded checkpoint file.

    Raises:
        TypeError: If the checkpoint is an HDF5 checkpoint and
            `map_location` is neither `None`, "cache", a string, nor a
            `torch.device`.
        RuntimeError: If the HDF5 checkpoint has no spec, or if the vanilla
            torch loader fails for any reason other than a missing file.
        FileNotFoundError: Propagated unchanged from `torch.load` on the
            vanilla torch path.
    """
    is_hdf5 = is_pathlike(
        checkpoint_file
    ) and PyTorchH5Saver.is_valid_checkpoint(checkpoint_file)

    if not is_hdf5:
        logger.debug(
            "Checkpoint is not a valid HDF5 checkpoint. Falling back to "
            "normal PyTorch checkpoint loading."
        )
        return _torch_load(checkpoint_file, map_location, **kwargs)

    logger.debug(
        "Checkpoint is a valid HDF5 checkpoint. Using the HDF5 checkpoint "
        "loader."
    )

    # "cache" is not a device: it means "lazy load, but keep what was read".
    cache_tensors = False
    if map_location == "cache":
        cache_tensors = True
        map_location = None

    if map_location is not None:
        if isinstance(map_location, (str, torch.device)):
            map_location = torch.device(map_location)
        else:
            raise TypeError(
                "Unsupported `map_location` provided for loading HDF5 "
                "checkpoint. Expected `None` or a torch device, but got "
                f"`{map_location}`"
            )

    CheckpointReader.saver_cls = PyTorchH5Saver
    reader = CheckpointReader(checkpoint_file)
    tensor_names = reader.tensor_names
    spec = reader.spec

    if not spec:
        raise RuntimeError(
            f"Checkpoint `{checkpoint_file}` is an HDF5 file but does not "
            "conform to the Cerebras HDF5 checkpoint specification. Please "
            "ensure that the checkpoint was saved using `cstorch.save()`."
        )

    # Imported at call time, as in the original code.
    from cerebras_pytorch.utils.nest import recurse_spec

    spec_keys = list(map(".".join, recurse_spec(spec)))

    # Keys the spec expects but the file does not contain are loaded as None.
    missing = set(spec_keys) - set(tensor_names)
    if missing:
        logger.warning(
            "The checkpoint is missing the following keys that are "
            f"found in the spec: {sorted(missing)}"
        )

    backend = current_backend_impl(raise_exception=False)
    if backend is None:
        logger.debug(
            "No backend has been initialized. Loading tensors onto CPU."
        )
        map_location = torch.device("cpu")

    saver = PyTorchH5Saver()

    def get_tensor(key):
        """Load one value by key, moving tensors to `map_location` if set."""
        if key in missing:
            return None

        val = saver.load_tensor(checkpoint_file, key)
        if map_location is not None and isinstance(val, torch.Tensor):
            val = val.to(map_location)
        return val

    with cache_deferred_tensors(cache_tensors):
        values = list(map(get_tensor, spec_keys))

    # Rebuild the original nested structure from the flat list of values.
    # pylint: disable=protected-access
    res = torch.utils._pytree.tree_unflatten(values, spec)

    logger.debug(f"Loaded HDF5 checkpoint {checkpoint_file}.")

    return res


def _torch_load(
    checkpoint_file: _CkptFileT,
    map_location: _MapLocT = None,
    **kwargs,
) -> _StateDictT:
    """Load a PyTorch checkpoint using vanilla torch.load.

    Before loading, warns when an on-disk checkpoint is larger than 10 GB or
    larger than the currently available memory.

    Args:
        checkpoint_file: The path to the checkpoint to load.
        map_location: A mapping of where to load the checkpoint content to.
        **kwargs: Additional keyword arguments to pass to torch.load.

    Returns:
        The object returned by `torch.load`.

    Raises:
        FileNotFoundError: Propagated unchanged from `torch.load`.
        RuntimeError: If `torch.load` fails with any other exception; the
            original exception is chained as the cause.
    """
    # Size checks only make sense for something that exists on disk;
    # file-like objects skip straight to loading.
    if is_pathlike(checkpoint_file) and os.path.exists(checkpoint_file):
        unit = "GB"
        file_size = get_path_size(checkpoint_file, unit=unit)
        free_mem = get_available_memory(unit=unit)

        if file_size > 10:
            backend = current_backend_impl(raise_exception=False)
            if backend is not None and backend.backend_type.is_csx:
                extra_msg = ", could significantly slow down weight transfer,"
            else:
                extra_msg = ""
            logger.warning(
                "Checkpoint file is a vanilla torch checkpoint and has "
                f"size {file_size} {unit}. This may take a while to load"
                f"{extra_msg} and may occupy a large amount of memory."
            )

        if file_size > free_mem:
            logger.warning(
                "Checkpoint file is a vanilla torch checkpoint and has "
                f"size {file_size} {unit}, which is larger than the "
                f"currently available memory {free_mem} {unit}. Since "
                "torch checkpoints are loaded in their entirety into "
                "memory, this may cause out-of-memory errors."
            )

    # NOTE(security): torch.load unpickles the file, so only load checkpoints
    # from trusted sources. Callers can forward options such as
    # `weights_only` through **kwargs.
    try:
        state_dict = torch.load(
            checkpoint_file, map_location=map_location, **kwargs
        )
    except FileNotFoundError:
        # Error message is already descriptive enough
        raise
    except Exception as e:
        raise RuntimeError(
            f"Failed to load checkpoint file `{checkpoint_file}`."
        ) from e

    logger.debug(f"Loaded checkpoint {checkpoint_file} into memory.")

    return state_dict


__all__ = ["save", "load"]