# Copyright 2016-2023 Cerebras Systems
# SPDX-License-Identifier: BSD-3-Clause
"""Utilities for saving and loading checkpoints."""
import os
from typing import IO, Any, Callable, Union
import torch
from cerebras_appliance import logger
from cerebras_appliance.utils.file import StrPath, get_path_size, is_pathlike
from cerebras_appliance.utils.memory import (
get_available_memory,
with_memory_info_logged,
)
from cerebras_pytorch.backend import current_backend_impl
from cerebras_pytorch.saver.checkpoint_reader import CheckpointReader
from .pt_h5_saver import PyTorchH5Saver
from .storage import cache_deferred_tensors, use_external_link
# A checkpoint may be given as a path or as a file-like object, which must
# implement the `read`, `readline`, `tell`, and `seek` methods (see the
# illustrative sketch below the type aliases).
_CkptFileT = Union[StrPath, IO]
_MapLocT = Union[str, torch.device, Callable, dict, None]
_StateDictT = Any
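# For illustration only: an in-memory buffer satisfies the file-like protocol
# above, so a vanilla torch checkpoint can be loaded from one. HDF5
# checkpoints, by contrast, must be given as a path. A minimal sketch:
#
#     import io
#     import torch
#
#     buffer = io.BytesIO()
#     torch.save({"step": 1}, buffer)
#     buffer.seek(0)
#     state_dict = load(buffer)  # falls back to vanilla torch loading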
def save(obj: dict, checkpoint_file: str) -> None:
"""Save a PyTorch state dict to the given file.
Args:
obj: The object to save.
checkpoint_file: The path to save the object to.
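    Example:
        A minimal usage sketch, assuming the conventional
        ``import cerebras_pytorch as cstorch`` alias; the model and file name
        here are illustrative::

            import torch

            import cerebras_pytorch as cstorch

            model = torch.nn.Linear(10, 10)
            cstorch.save({"model": model.state_dict()}, "checkpoint_0.mdl")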
"""
backend = current_backend_impl(raise_exception=False)
    # Disable external links when storing tensors to the checkpoint so that
    # the checkpoint is standalone and does not depend on external files.
with use_external_link(value=False):
if backend is None:
logger.debug(
f"No Cerebras backend found. Defaulting to using CPU for "
f"saving."
)
saver = PyTorchH5Saver()
saver.save(checkpoint_file, obj)
else:
backend.save(obj, checkpoint_file)
logger.verbose(f"Successfully saved checkpoint to {checkpoint_file}")
@with_memory_info_logged(
"loading checkpoint", info=["available", "used"], logger=logger,
)
def load(
checkpoint_file: _CkptFileT, map_location: _MapLocT = None, **kwargs,
) -> _StateDictT:
"""Load a PyTorch checkpoint from a file.
    Args:
        checkpoint_file: The path to the checkpoint to load.
        map_location: A mapping of where to load the checkpoint content to.
            If map_location is `None`, tensors are lazily loaded from the
            checkpoint file every time they are accessed.
            If map_location is "cache", tensors are lazily loaded and then
            cached in memory after their first access.
            If map_location is "cpu", tensors are eagerly loaded into memory
            from the checkpoint file.
        **kwargs: Additional keyword arguments to pass to the vanilla torch
            checkpoint loader. These are ignored if the checkpoint is a
            Cerebras HDF5 checkpoint.

    Returns:
        The loaded state dict.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
        RuntimeError: If the checkpoint is not a valid HDF5 or vanilla torch
            checkpoint.
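    Example:
        A minimal sketch of the three ``map_location`` modes, assuming the
        conventional ``import cerebras_pytorch as cstorch`` alias and an
        illustrative checkpoint path::

            import cerebras_pytorch as cstorch

            # Lazily load tensors from the file each time they are accessed.
            lazy_state_dict = cstorch.load("checkpoint_0.mdl")

            # Lazily load tensors, caching each after its first access.
            cached_state_dict = cstorch.load("checkpoint_0.mdl", map_location="cache")

            # Eagerly load all tensors into CPU memory.
            cpu_state_dict = cstorch.load("checkpoint_0.mdl", map_location="cpu")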
"""
if not is_pathlike(
checkpoint_file
) or not PyTorchH5Saver.is_valid_checkpoint(checkpoint_file):
logger.debug(
f"Checkpoint is not a valid HDF5 checkpoint. Falling back to "
f"normal PyTorch checkpoint loading."
)
return _torch_load(checkpoint_file, map_location, **kwargs)
else:
logger.debug(
f"Checkpoint is a valid HDF5 checkpoint. Using the HDF5 checkpoint "
f"loader."
)
cache_tensors = False
if map_location == "cache":
cache_tensors = True
map_location = None
if map_location is not None:
if isinstance(map_location, (str, torch.device)):
map_location = torch.device(map_location)
else:
raise TypeError(
f"Unsupported `map_location` provided for loading HDF5 "
f"checkpoint. Expected `None` or a torch device, but got "
f"`{map_location}`"
)
CheckpointReader.saver_cls = PyTorchH5Saver
reader = CheckpointReader(checkpoint_file)
tensor_names = reader.tensor_names
spec = reader.spec
if not spec:
raise RuntimeError(
f"Checkpoint `{checkpoint_file}` is an HDF5 file but does not "
f"conform to the Cerebras HDF5 checkpoint specification. Please "
f"ensure that the checkpoint was saved using `cstorch.save()`."
)
from cerebras_pytorch.utils.nest import recurse_spec
spec_keys = list(map(".".join, recurse_spec(spec)))
missing = set(spec_keys) - set(tensor_names)
if missing:
logger.warning(
f"The checkpoint is missing the following keys that are "
f"found in the spec: {sorted(missing)}"
)
backend = current_backend_impl(raise_exception=False)
if backend is None:
logger.debug(
"No backend has been initialized. Loading tensors onto CPU."
)
map_location = torch.device("cpu")
saver = PyTorchH5Saver()
def get_tensor(key):
if key in missing:
return None
val = saver.load_tensor(checkpoint_file, key)
if map_location is not None and isinstance(val, torch.Tensor):
val = val.to(map_location)
return val
with cache_deferred_tensors(cache_tensors):
values = list(map(get_tensor, spec_keys))
# pylint: disable=protected-access
res = torch.utils._pytree.tree_unflatten(values, spec)
logger.debug(f"Loaded HDF5 checkpoint {checkpoint_file}.")
return res
def _torch_load(
checkpoint_file: _CkptFileT, map_location: _MapLocT = None, **kwargs,
) -> _StateDictT:
"""Load a PyTorch checkpoint using vanilla torch.load.
Args:
checkpoint_file: The path to the checkpoint to load.
map_location: A mapping of where to load the checkpoint content to.
**kwargs: Additional keyword arguments to pass to torch.load.
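    Example:
        A sketch of forwarding a keyword argument to ``torch.load``; the
        ``weights_only`` flag is only available in newer torch versions and is
        shown purely as an illustration::

            state_dict = _torch_load(
                "checkpoint.pt", map_location="cpu", weights_only=True
            )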
"""
if is_pathlike(checkpoint_file) and os.path.exists(checkpoint_file):
unit = "GB"
file_size = get_path_size(checkpoint_file, unit=unit)
free_mem = get_available_memory(unit=unit)
if file_size > 10:
backend = current_backend_impl(raise_exception=False)
if backend is not None and backend.backend_type.is_csx:
extra_msg = ", could significantly slow down weight transfer,"
else:
extra_msg = ""
logger.warning(
f"Checkpoint file is a vanilla torch checkpoint and has "
f"size {file_size} {unit}. This may take a while to load"
f"{extra_msg} and may occupy a large amount of memory."
)
if file_size > free_mem:
logger.warning(
f"Checkpoint file is a vanilla torch checkpoint and has "
f"size {file_size} {unit}, which is larger than the "
f"currently available memory {free_mem} {unit}. Since "
f"torch checkpoints are loaded in their entirety into "
f"memory, this may cause out-of-memory errors."
)
try:
state_dict = torch.load(
checkpoint_file, map_location=map_location, **kwargs
)
except FileNotFoundError as e:
# Error message is already descriptive enough
raise e
except Exception as e:
raise RuntimeError(
f"Failed to load checkpoint file `{checkpoint_file}`."
) from e
logger.debug(f"Loaded checkpoint {checkpoint_file} into memory.")
return state_dict
__all__ = ["save", "load"]