# Copyright 2016-2023 Cerebras Systems
# SPDX-License-Identifier: BSD-3-Clause
"""
Implementations of the containers for appliance data tensors
"""
import copy
import os
import weakref
from numbers import Number
from typing import List, Optional, TextIO, Type, Union
import dill
import h5py as h5
import numpy
import torch
from torch.utils._pytree import tree_map
import cerebras_pytorch as cstorch
from cerebras_appliance.data.dtypes import bf16, is_bf16
from cerebras_appliance.saver.h5_saver import NumpyArrayH5Type, register_h5_type
from cerebras_pytorch.backend import current_torch_device, use_cs
from cerebras_pytorch.utils._contexts import BooleanContext
from .pt_h5_saver import TorchTensorH5Type
# Flag for controlling whether to store tensors to H5 via external links.
use_external_link = BooleanContext(default=False)
# Flag for controlling whether to pickle cstorch tensors as torch tensors.
use_cstorch_types = BooleanContext(default=False)
# Flag for controlling whether to cache deferred tensors.
cache_deferred_tensors = BooleanContext(default=True)
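# Within this module these flags are only ever read as plain booleans, e.g.
# `bool(cache_deferred_tensors)` in `DeferredTensor.__init__` and
# `if not use_external_link:` in the `_save()` methods below. A minimal sketch of
# how the cache flag is picked up at construction time (illustrative only; how
# callers toggle a BooleanContext depends on its API, which is not shown here):
#
#   t = DeferredFullTensor(torch.Size([2]), dtype=torch.float32, value=0)
#   t._cache_tensor   # -> bool(cache_deferred_tensors), i.e. True by default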
@register_h5_type()
class StoredTensorH5Type:
"""Class for loading custom torch.Tensor's from previous releases."""
@staticmethod
def save(tensor, f: h5.File, key: str, **kwargs):
raise NotImplementedError
@staticmethod
def load(f: h5.File, key: str):
return DeferredH5Tensor(f.filename, key)
@register_h5_type()
class StoredApplianceTensorH5Type:
"""Class for loading custom torch.Tensor's from previous releases."""
@staticmethod
def save(tensor, f: h5.File, key: str, **kwargs):
raise NotImplementedError
@staticmethod
def load(f: h5.File, key: str) -> torch.Tensor:
return DeferredFileTensor.load(f, key)
@register_h5_type()
class FullTensorH5Type:
"""Class for loading custom torch.Tensor's from previous releases."""
@staticmethod
def save(tensor, f: h5.File, key: str, **kwargs):
raise NotImplementedError
@staticmethod
def load(f: h5.File, key: str):
return DeferredFullTensor.load(f, key)
def full(shape, value: float, dtype=None):
    """
    Returns a lazily initialized tensor filled with the provided value
Args:
shape: The shape of the tensor.
value: The value to fill the tensor with.
dtype: The dtype of the tensor.
"""
if not use_cs():
return torch.full(shape, value, dtype=dtype)
return DeferredFullTensor(shape, dtype=dtype, value=value).to(
current_torch_device()
)
def full_like(other: torch.Tensor, value: float, dtype=None):
    """
    Returns a lazily initialized full tensor with the same properties as the
provided tensor
Args:
other: The tensor to copy the properties from
value: The value to fill the tensor with
dtype: The dtype of the tensor. If not provided, the dtype of the other
tensor is used
"""
if not dtype:
dtype = other.dtype
if not use_cs():
return torch.full_like(other, value, dtype=dtype)
return DeferredFullTensor(other.shape, dtype=dtype, value=value).to(
current_torch_device()
)
def ones(shape, dtype=None):
    """
    Returns a lazily initialized tensor filled with ones
Args:
shape: The shape of the tensor
dtype: The dtype of the tensor
"""
if not use_cs():
return torch.ones(shape, dtype=dtype)
return DeferredFullTensor(shape, dtype=dtype, value=1).to(
current_torch_device()
)
def ones_like(other: torch.Tensor, dtype=None):
    """
    Returns a lazily initialized tensor full of ones with the same properties
as the provided tensor
Args:
other: The tensor to copy the properties from
dtype: The dtype of the tensor. If not provided, the dtype of the other
tensor is used
"""
if not dtype:
dtype = other.dtype
if not use_cs():
return torch.ones_like(other, dtype=dtype)
return DeferredFullTensor(other.shape, dtype=dtype, value=1).to(
current_torch_device()
)
def zeros(shape, dtype=None):
    """
    Returns a lazily initialized tensor filled with zeros
Args:
shape: The shape of the tensor
dtype: The dtype of the tensor
"""
if not use_cs():
return torch.zeros(shape, dtype=dtype)
return DeferredFullTensor(shape, dtype=dtype, value=0).to(
current_torch_device()
)
def zeros_like(other: torch.Tensor, dtype=None):
    """
    Returns a lazily initialized tensor full of zeros with the same properties
as the provided tensor
Args:
other: The tensor to copy the properties from
dtype: The dtype of the tensor. If not provided, the dtype of the other
tensor is used
"""
if not dtype:
dtype = other.dtype
if not use_cs():
return torch.zeros_like(other, dtype=dtype)
return DeferredFullTensor(other.shape, dtype=dtype, value=0).to(
current_torch_device()
)
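# A minimal usage sketch for the factory helpers above. When no CS backend is in
# use, `use_cs()` is False and they simply defer to the native torch factories:
#
#   w = zeros((4, 4), dtype=torch.float32)   # plain torch.zeros on CPU runs
#   b = full_like(w, 0.5)                    # plain torch.full_like on CPU runs
#
# Under a CS backend the same calls instead return DeferredFullTensor handles
# already moved to the lazy device via current_torch_device().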
def lazy_tensor_data_wrapper(
tensor: Union[torch.Tensor, "cerebras_pytorch_lib.ApplianceDataInfo"]
) -> torch.Tensor:
"""A wrapper for tensors that returns the underlying CPU view.
Args:
tensor: The tensor to return the CPU view of. If tensor is on cpu, it
is returned as is. If tensor is on a lazy device, the underlying
ApplianceDataInfo object is queried first, then the CPU view of it
is returned.
Returns:
The CPU view of the tensor. Modifying this tensor will modify the
lazy tensor's device data.
"""
from cerebras_pytorch import cerebras_pytorch_lib
if isinstance(tensor, torch.Tensor):
if tensor.device.type == "lazy":
app_data = cerebras_pytorch_lib.get_appliance_data(tensor)
else:
return tensor
elif isinstance(tensor, cerebras_pytorch_lib.ApplianceDataInfo):
app_data = tensor
else:
raise ValueError(
f"Attempting to create a lazy tensor wrapper for a value of type "
f"{type(tensor)}, but one of torch.Tensor and ApplianceDataInfo "
f"was expected."
)
if app_data.filename:
# Currently, LTC creates a file-backed tensor internally with a normal
# torch.Tensor type, so we need to wrap it in a DeferredFileTensor to
# avoid copying when creating initial state.
return DeferredFileTensor(
app_data.filename, app_data.tensor.size(), app_data.tensor.dtype
)
else:
# This tensor may be one of the DeferredTensor's below
# Or raise an exception if tensor data does not exist
return app_data.tensor
def has_lazy_tensor_data_impl(tensor: torch.Tensor) -> bool:
"""Returns True if the lazy tensor has data it can use to create a CPU view"""
if isinstance(tensor, torch.Tensor) and tensor.device.type == "lazy":
from cerebras_pytorch import cerebras_pytorch_lib
if cerebras_pytorch_lib.has_backend_data(tensor):
app_data = cerebras_pytorch_lib.get_appliance_data(tensor)
return app_data.filename is not None or app_data.has_tensor
return False
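# A minimal sketch of how these two helpers are typically combined (assumes a
# lazy/CSX backend is active and `param` lives on the lazy device):
#
#   if has_lazy_tensor_data_impl(param):
#       cpu_view = lazy_tensor_data_wrapper(param)   # CPU view, no copy
#       cpu_view.mul_(0.5)   # per the docstring above, this writes through to
#                            # the lazy tensor's device data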
class DeferredTensor(torch.Tensor):
"""A deferred tensor that is lazily materialized on the CPU.
This is a base class for a tensor that provides a recipe for getting its
value. The tensor is not materialized until some torch operation is called
    on it, at which point it's materialized to CPU and all subsequent accesses
    are applied to the materialized CPU tensor.
    Deferred tensors are especially useful when moving to lazy tensors. Instead
    of incurring copies, the tensor handle is stored in the lazy tensor. If the
    tensor is materialized and modified, moving the tensor to lazy incurs a full
    copy because at that point the recipe is already out of date.
    NOTE: All subclass names must start with "Deferred" and end with "Tensor"
    so that appliance data recognizes them and avoids copying when moving to the
    lazy device.
"""
# If __torch_dispatch__ is defined, the default torch function
# implementation (which preserves subclasses) typically must be disabled.
__torch_function__ = torch._C._disabled_torch_function_impl
def __init__(self):
super().__init__()
# The cpu tensor that is materialized when the tensor is accessed.
self._tensor: Optional[torch.Tensor] = None
# Keep track of any changes to the CPU tensor. If there are changes,
# when moving to lazy, we need to use the CPU tensor. Otherwise, we
# use the original data.
self._is_dirty = False
# If True, we cache the CPU tensor. This is useful for tensors that
# are accessed multiple times, or are modified inplace.
# However, it can cause memory issues if the tensor is large, or if
# many larger tensors are cached.
self._cache_tensor = bool(cache_deferred_tensors)
@property
def is_modified(self) -> bool:
"""Returns True if the tensor has been materialized and modified."""
return self._tensor is not None and self._is_dirty
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
tensor_handles: List[cls] = []
def unwrap(t):
if isinstance(t, cls):
cpu_handle = t._materialize()
tensor_handles.append((t, cpu_handle))
return cpu_handle
return t
res = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
for t, c in tensor_handles:
if (
# In-place modification
(
hasattr(func, "_schema")
and func._schema.name[-1] == "_"
and t is args[0]
)
or ( # built-in inplace ops don't have schemas. Check name attribute
hasattr(func, "__name__")
and func.__name__[-1] == "_"
and t is args[0]
)
# Modified through being an output of an operation
or ("out" in kwargs and t is kwargs["out"])
):
# Explicitly cache a tensor if it's modified inplace
t._cache_tensor = True
t._tensor = c
t._is_dirty = True
return res
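    # A concrete sketch of the dirty tracking above, using the DeferredFullTensor
    # subclass defined later in this file:
    #
    #   t = DeferredFullTensor(torch.Size([4]), dtype=torch.float32, value=1)
    #   y = t + 1       # out-of-place op: materializes a CPU copy, t stays clean
    #   t.add_(2)       # in-place op: caches the CPU tensor and marks t dirty
    #   t.is_modified   # -> True; save() and to("lazy") now use the CPU tensor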
def save(
self, f: h5.File, name: str, **kwargs
) -> Union[TorchTensorH5Type, None]:
"""Saves the tensor to an H5 file.
If tensor is materialized on CPU and modified, this uses a normal torch
tensor H5 type to save and returns the type used for saving. Otherwise,
it uses the deferred tensor type and returns None, so that subsequent
loads use the same type.
Args:
f: The H5 file to save to.
name: The name of the dataset to save to.
**kwargs: Additional arguments to pass to the H5 save function for
compression.
Returns:
The H5 type used to save the tensor, or None if type(self) was used
to save.
"""
if self._is_dirty:
return TorchTensorH5Type.save(self._tensor, f, name, **kwargs)
return self._save(f, name, **kwargs)
def _materialize(
self, cache_override: Optional[bool] = None
) -> torch.Tensor:
"""Returns the materialized CPU tensor.
If tensor was already materialized, this returns the already cached
tensor. Otherwise, it materializes the tensor, (conditionally) caches
it, and returns it.
Args:
            cache_override: Whether to override the default caching behavior when materializing.
"""
tensor = self._to_cpu() if self._tensor is None else self._tensor
tensor.requires_grad = self.requires_grad
should_cache = (
cache_override if cache_override is not None else self._cache_tensor
)
if should_cache:
self._tensor = tensor
return tensor
############################################################################
# torch.Tensor overrides #
############################################################################
def to(self, *args, **kwargs):
"""Overrides the default to() implementation to handle lazy tensors."""
device, _, _, _ = torch._C._nn._parse_to(*args, **kwargs)
if device is not None and device.type == "lazy":
from cerebras_pytorch.backend import current_backend
from cerebras_pytorch.lib import cerebras_pytorch_lib
with current_backend().device:
if not self.is_modified:
# This custom implementation creates a new lazy tensor whose
# underlying device data is set to this tensor handle. This
# avoids copying any data. Note that moving to lazy deletes
# the storage of "self". This is fine because the storage
# is never directly used as any CPU operation is done on
# the materialized tensor.
return cerebras_pytorch_lib.eager_to_lazy(self)
else:
# If materialized tensor has been modified, we need to use
# the default implementation which copies the data.
return super().to(*args, **kwargs)
return super().to(*args, **kwargs)
def numpy(self) -> numpy.ndarray:
"""Implements numpy() for deferred tensors."""
# Set dirty to True as it is possible that the numpy array
# will be modified inplace
self._is_dirty = True
return self._materialize().numpy()
def tolist(self) -> list:
"""Implements tolist() for deferred tensors."""
return self._materialize().tolist()
def clone(self) -> "DeferredTensor":
"""Implements clone() for deferred tensors."""
if not self.is_modified:
cloned = self._clone()
cloned.requires_grad = self.requires_grad
return cloned
return super().clone()
def detach(self) -> torch.Tensor:
"""Implements detach() for deferred tensors.
Note that this currently falls back to the original implementation,
which materializes the tensor. The contract of detach is that the
returned tensor shares the same storage with the original one. However,
imagine the following case:
1. A is a deferred tensor not materialized yet.
2. B = A.detach() is called
        3. A += 1 is called, which materializes A
In this sequence, B does not see the modification to A. To avoid this
issue, we currently materialize the tensor when detach() is called.
"""
return super().detach()
def __deepcopy__(self, memo: dict) -> "DeferredTensor":
"""Implements deepcopy() for deferred tensors."""
if not self.is_modified:
return memo.setdefault(id(self), self._clone())
new_tensor = copy.deepcopy(self._materialize(), memo)
new_tensor.requires_grad = self.requires_grad
return new_tensor
def __reduce_ex__(self, protocol):
"""Implements __reduce_ex__() for deferred tensors.
        This adds special pickling support for deferred tensors (e.g., used in
torch.save()). If saving cstorch types is allowed, the tensor subclass
is pickled as is. Otherwise, the tensor is materialized and the class
is pickled as a normal torch tensor. This is to avoid strict dependency
on cstorch types in checkpoints when needed.
"""
if use_cstorch_types:
return super().__reduce_ex__(protocol)
return self._materialize().__reduce_ex__(protocol)
############################################################################
# Abstract methods to override #
############################################################################
def _save(self, f: h5.File, name: str, **kwargs) -> None:
"""Saves the tensor to an H5 file.
        This is called when the tensor has not previously been materialized
        on CPU, which means the deferred type can be saved to H5 and loaded
        back later.
"""
raise NotImplementedError
@staticmethod
    def load(f: h5.File, key: str) -> "DeferredTensor":
"""Loads a tensor from an H5 file.
Args:
f: The H5 file to load from.
key: The dataset name that holds the tensor value.
"""
raise NotImplementedError
def _to_cpu(self) -> torch.Tensor:
"""Materializes the tensor to CPU and returns it."""
raise NotImplementedError
def _clone(self) -> "DeferredTensor":
"""Clones the non-materialized tensor and returns it."""
raise NotImplementedError
@register_h5_type()
class DeferredFileTensor(DeferredTensor):
"""A deferred tensor whose data is stored in a binary file."""
def __new__(cls, filepath: str, size: torch.Size, dtype: torch.dtype):
data = torch.empty(size, dtype=dtype, device="cpu")
return cls._make_subclass(cls, data, require_grad=False)
def __init__(self, filepath: str, size: torch.Size, dtype: torch.dtype):
"""Constructs a `DeferredFileTensor` instance.
Args:
filepath: The path to the binary file that holds the tensor data.
size: The size of the tensor.
dtype: The data type of the tensor.
"""
super().__init__()
self._filepath = filepath
        # Store the last modified time of the file so that, before materializing,
        # we can check whether the file has been modified since the tensor was created.
self._last_modified = os.path.getmtime(filepath)
def _save(self, f: h5.File, name: str, **kwargs) -> None:
        if not use_external_link or not self.shape:
            # When external links are disabled (or the tensor is 0-dimensional),
            # we need to materialize the tensor and save it to the file. But note
            # that we don't cache the materialized tensor to avoid OOM.
return TorchTensorH5Type.save(
self._materialize(cache_override=False), f, name, **kwargs
)
dset = f.create_dataset(name, data=h5.Empty("f"))
dset.attrs["filepath"] = self._filepath
dset.attrs["shape"] = tuple(self.shape)
dset.attrs["dtype"] = dill.dumps(self.dtype).hex()
@staticmethod
def load(f: h5.File, key: str) -> "DeferredFileTensor":
dataset = f[key]
return DeferredFileTensor(
filepath=dataset.attrs["filepath"],
size=torch.Size(dataset.attrs["shape"]),
dtype=dill.loads(bytes.fromhex(dataset.attrs["dtype"])),
)
def _to_cpu(self) -> torch.Tensor:
modified_time = os.path.getmtime(self._filepath)
if modified_time > self._last_modified:
raise RuntimeError(
f"Attempting to materialize deferred tensor from file "
f"from file {self._filepath}, but the file has "
f"since been modified. The loaded tensor value may be "
f"different from originally loaded tensor. Please refrain "
f"from modifying the file while the run is in progress."
)
# Return a read-only file-backed tensor. Upon write, the tensor will
# be converted to an in-memory tensor.
return torch.from_file(
self._filepath,
shared=False, # Opens in read-only mode
size=self.shape.numel(),
dtype=self.dtype,
).reshape(self.shape)
def _clone(self) -> "DeferredFileTensor":
cloned = DeferredFileTensor(self._filepath, self.shape, self.dtype)
cloned.requires_grad = self.requires_grad
return cloned
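# A minimal sketch of DeferredFileTensor usage (hypothetical path; assumes
# "weights.bin" holds 16 float32 values written with `tensor.numpy().tofile()`):
#
#   t = DeferredFileTensor("weights.bin", torch.Size([4, 4]), torch.float32)
#   t.sum()                        # first torch op materializes via torch.from_file
#   t.to(current_torch_device())   # if still unmodified, the lazy tensor reuses
#                                  # the file-backed handle without copying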
@register_h5_type()
class DeferredFullTensor(DeferredTensor):
"""A deferred torch.full() tensor."""
def __new__(
cls,
size: torch.Size,
dtype: Optional[torch.dtype] = None,
value: Optional[Number] = None,
):
data = torch.empty(size, dtype=dtype, device="cpu")
return cls._make_subclass(cls, data, require_grad=False)
def __init__(
self,
size: torch.Size,
dtype: Optional[torch.dtype] = None,
value: Optional[Number] = None,
):
"""Constructs a `DeferredFullTensor` instance.
Args:
size: The size of the tensor.
dtype: The data type of the tensor. If not specified, defaults to
the default torch dtype.
value: The value to fill the tensor with. If not specified, defaults
to uninitialized data.
"""
super().__init__()
self._value = value
@property
def fill_value(self) -> Number:
"""Returns the fill value."""
return self._value
def _save(self, f: h5.File, name: str, **kwargs) -> None:
np_dtype = torch_to_np_dtype(self.dtype)
dset = f.create_dataset(name, dtype=np_dtype)
dset.attrs["shape"] = tuple(self.shape)
dset.attrs["fill_value"] = self._value
dset.attrs["is_bfloat16"] = is_bf16(np_dtype)
@staticmethod
def load(f: h5.File, key: str) -> "DeferredFullTensor":
dset = f[key]
size = torch.Size(dset.attrs["shape"])
value = dset.attrs["fill_value"]
np_dtype = dset.dtype
if dset.attrs["is_bfloat16"]:
np_dtype = bf16
dtype = _np_to_torch_dtype(np_dtype)
return DeferredFullTensor(size, dtype=dtype, value=value)
def _to_cpu(self) -> torch.Tensor:
if self._value is None:
return torch.empty(self.shape, dtype=self.dtype)
elif self._value == 0:
return torch.zeros(self.shape, dtype=self.dtype)
elif self._value == 1:
return torch.ones(self.shape, dtype=self.dtype)
else:
return torch.full(self.shape, self._value, dtype=self.dtype)
def _clone(self) -> "DeferredFullTensor":
cloned = DeferredFullTensor(self.shape, self.dtype, self._value)
cloned.requires_grad = self.requires_grad
return cloned
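# A minimal sketch of DeferredFullTensor usage; the fill is only computed when a
# torch op first touches the tensor:
#
#   t = DeferredFullTensor(torch.Size([2, 3]), dtype=torch.float16, value=0)
#   t.fill_value   # -> 0, no materialization
#   t.tolist()     # materializes via torch.zeros -> [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]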
class _CachingFileOpener:
"""File opener that reuses file descriptors from previous opened files."""
def __init__(self):
# Keep a weak reference to the file descriptors
self._open_files = weakref.WeakValueDictionary()
def __call__(self, path, *args, **kwargs):
"""Opens the file (or reuses previously opened file) and returns the file descriptor."""
stat = os.stat(path)
key = (path, stat.st_mtime_ns)
if key in self._open_files:
return self._open_files[key]
fp = open(path, *args, **kwargs)
self._open_files[key] = fp
return fp
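# A minimal sketch of the caching behavior (hypothetical path): as long as a
# previously returned file object is still referenced and the file's mtime is
# unchanged, the same object is returned instead of reopening the file.
#
#   opener = _CachingFileOpener()
#   f1 = opener("ckpt.h5", "rb")
#   f2 = opener("ckpt.h5", "rb")
#   assert f1 is f2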
@register_h5_type()
class DeferredH5Tensor(DeferredTensor):
"""A deferred tensor whose data is stored in an H5 file."""
    # Class property for opening files using cached descriptors. This avoids opening the same
    # file each time by reusing the descriptor for it. The cache helps with memory usage.
_FILE_OPENER = _CachingFileOpener()
def __new__(cls, filepath: str, key: str, fp: Optional[TextIO] = None):
if fp is not None:
ctx = h5.File(fp, "r")
elif h5.is_hdf5(filepath):
ctx = h5.File(filepath, "r")
else:
raise ValueError(f"{filepath} is not a valid HDF5 file.")
with ctx as f:
size = f[key].shape
np_dtype = f[key].dtype
if f[key].attrs.get("is_bfloat16"):
np_dtype = bf16
dtype = _np_to_torch_dtype(np_dtype)
data = torch.empty(size, dtype=dtype, device="cpu")
return cls._make_subclass(cls, data, require_grad=False)
def __init__(self, filepath: str, key: str, fp: Optional[TextIO] = None):
"""Constructs a `DeferredH5Tensor` instance.
Args:
filepath: The path to the H5 file that holds the tensor data.
key: The dataset name with which to retrieve the tensor value from the H5 file.
fp: An optional file pointer to the opened filepath. If provided, `filepath` is not
opened and the file pointer is used instead.
"""
super().__init__()
self._filepath = filepath
self._key = key
        # We keep the file open to avoid it being deleted from under us. DeferredH5Tensors are
        # generally used for lazily loading values from checkpoints. By keeping this reference,
        # we avoid the case where someone deletes the original checkpoint during training and a
        # later attempt to save this tensor to a new checkpoint fails because the original is
        # gone.
        # Note: H5 doesn't allow keeping a file open in different modes. So to keep the file open,
        # we use a regular `open()` instead of `h5.File()`.
self._fp = self._FILE_OPENER(filepath, "rb") if fp is None else fp
        # Store the last modified time of the file so that, before materializing,
        # we can check whether the file has been modified since the tensor was created.
self._stat = os.fstat(self._fileno)
@property
def _fileno(self) -> int:
"""Returns the file number of the unerlyding file descriptor."""
return self._fp.fileno()
def __getstate__(self):
clear_cache = False
if self._tensor is None and (
not os.path.exists(self._filepath)
or os.stat(self._filepath).st_mtime_ns > self._stat.st_mtime_ns
):
self._materialize(cache_override=True)
clear_cache = True
state = self.__dict__.copy()
        # Drop the file object since we need to reopen the file when unpickling. The attribute
        # might not exist if this is an unpickled instance whose `self._tensor` was already
        # materialized. See `__setstate__`.
state.pop("_fp", None)
if clear_cache:
self._tensor = None
return state
def __setstate__(self, state):
self.__dict__.update(state)
# If tensor is non-materialized, we need to ensure the backing file still exists and its
# timestamp hasn't changed from before.
if self._tensor is None:
if not os.path.exists(self._filepath):
raise RuntimeError(
f"Attempting to unpickle a deferred tensor whose backing file {self._filepath} "
f"no longer exists."
)
self._fp = self._FILE_OPENER(self._filepath, "rb")
            # Compare the file's current stats against the stats captured in `self._stat` when
            # the original (pickled) tensor was created.
self._check_file_modification("unpickle")
def _save(
self, f: h5.File, name: str, **kwargs
) -> Union[None, Type[TorchTensorH5Type]]:
if not use_external_link:
# When external links are disabled, we need to materialize the
# tensor and save it to file. But note that we don't cache the
# materialized tensor to avoid OOM. This mainly happens when we load
# from an initial H5 checkpoint and save the initial weights to
# another H5 checkpoint.
return TorchTensorH5Type.save(
self._materialize(cache_override=False), f, name, **kwargs
)
dset = f.create_dataset(name, data=h5.Empty("f"))
dset.attrs["filepath"] = self._filepath
dset.attrs["key"] = self._key
# TODO: Need to save fstat and compare upon load to avoid loading tampered file
return None
@staticmethod
def load(f: h5.File, key: str) -> "DeferredH5Tensor":
dset = f[key]
return DeferredH5Tensor(dset.attrs["filepath"], dset.attrs["key"])
def _to_cpu(self) -> torch.Tensor:
self._check_file_modification("materialize")
with h5.File(self._fp, "r") as f:
return cstorch.from_numpy(NumpyArrayH5Type.load(f, self._key))
def _check_file_modification(self, msg):
"""Check whether the backing file has been modified since the tensors was created."""
stat = os.fstat(self._fileno)
if stat.st_mtime_ns > self._stat.st_mtime_ns:
raise RuntimeError(
f"Attempting to {msg} deferred tensor with key "
f"\"{self._key}\" from file {self._filepath}, but the file has "
f"since been modified. The loaded tensor value may be "
f"different from originally loaded tensor. Please refrain "
f"from modifying the file while the run is in progress."
)
def _clone(self) -> "DeferredH5Tensor":
# Use the file descriptor, since the filepath may have been unlinked. But since we hold the
# descriptor open, the file itself hasn't been deleted.
self._check_file_modification("clone")
cloned = DeferredH5Tensor(self._filepath, self._key, fp=self._fp)
cloned.requires_grad = self.requires_grad
return cloned
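# A minimal sketch of DeferredH5Tensor usage (hypothetical checkpoint path and
# dataset key):
#
#   t = DeferredH5Tensor("checkpoint.h5", "model.fc.weight")
#   t.to(current_torch_device())   # untouched: the lazy tensor reuses the handle
#   t.numpy()                      # or: materialize the value from the H5 dataset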
def torch_to_np_dtype(dtype: torch.dtype) -> numpy.dtype:
"""Converts a torch dtype to a numpy dtype."""
return cstorch.to_numpy(torch.empty(0, dtype=dtype)).dtype
def _np_to_torch_dtype(dtype: numpy.dtype) -> torch.dtype:
"""Converts a numpy dtype to a torch dtype."""
return cstorch.from_numpy(numpy.empty(0).astype(dtype)).dtype
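# Illustrative round-trip through the dtype helpers above (bfloat16 is expected
# to map to the custom `bf16` numpy dtype imported from cerebras_appliance):
#
#   torch_to_np_dtype(torch.float32)            # -> numpy.dtype("float32")
#   _np_to_torch_dtype(numpy.dtype("float16"))  # -> torch.float16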