Source code for modelzoo.common.pytorch.model_utils.checkpoint_converters.streaming_checkpoints

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import math
import os
import re
from typing import Union

import safetensors.torch as safetensors_torch
import torch

from cerebras_appliance.utils.units import convert_byte_unit
from cerebras_pytorch.saver.pt_h5_saver import PyTorchH5Saver
from cerebras_pytorch.utils.nest import recurse_spec


def convert_file_size_to_int(size: Union[int, str]):
    """
    Converts a size expressed as a string with digits and unit to an integer.

    Args:
        size (`int` or `str`): The size to convert (e.g., `"5MB"`). Any value
            that is not a `str` is returned unchanged.

    Returns:
        The size in bytes.

    Raises:
        ValueError: If `size` is a string without any digits, or if the
            unit conversion performed by `convert_byte_unit` fails. In the
            latter case the original error is attached as `__cause__`.

    Example:
    ```py
    >>> convert_file_size_to_int("10GiB")
    10737418240
    ```
    """
    if not isinstance(size, str):
        return size

    error_message = (
        f"size '{size}' is not in a valid format. Use an integer followed by the "
        f"unit, e.g., '10GB'."
    )

    # First run of digits is the magnitude; everything after it is the unit.
    match = re.search(r'(\d+)(.*)', size)
    if not match:
        raise ValueError(error_message)

    try:
        num = int(match.group(1))
        unit = match.group(2)
        return convert_byte_unit(num, "B", src_unit=unit)
    except Exception as e:
        # `Exception` (not a bare `except`) so that KeyboardInterrupt and
        # SystemExit propagate; chain the cause to keep the real reason.
        raise ValueError(error_message) from e
def dtype_byte_size(dtype: torch.dtype) -> float:
    """
    Returns the size (in bytes) occupied by one parameter of type `dtype`.

    Booleans are counted as a single bit (1/8 of a byte). Every other dtype
    reports the bit width given by `torch.finfo` (floating point dtypes) or
    `torch.iinfo` (all remaining dtypes), divided by 8.

    Example:
    ```py
    >>> dtype_byte_size(torch.float32)
    4.0
    ```
    """
    if dtype == torch.bool:
        return 1 / 8

    type_info = torch.finfo if dtype.is_floating_point else torch.iinfo
    return type_info(dtype).bits / 8
class StreamingShardedHFReader:
    r"""Allows sharded HuggingFace checkpoints to be read in a streaming manner
    rather than loading all shards into memory all at once. The underlying
    checkpoint is read-only. Only one shard is stored into memory at a time.
    For this reason, accessing random keys may be slow due to the switching
    cost (loading) between shards. It is therefore recommended that keys are
    accessed in the order given by `self.keys()` or `self.__iter__()`, as keys
    that appear in the same shard are in consecutive order.

    Args:
        index_file: Path to .index.json file.

    Raises:
        FileNotFoundError: If a shard referenced by the index file's
            "weight_map" does not exist next to the index file.
    """

    def __init__(self, index_file: str) -> None:
        # Shards are resolved relative to the directory of the index file.
        self.index_dir = os.path.dirname(index_file)
        with open(index_file, "r") as f:
            index = json.load(f)
        # key -> shard file name
        self.weight_map = index["weight_map"]

        # shard file name (sorted) -> keys stored in that shard
        self.file2keys = {
            file: [] for file in sorted(set(self.weight_map.values()))
        }
        # Fail fast on missing shards instead of failing mid-iteration.
        for file in self.file2keys:
            shard_path = os.path.join(self.index_dir, file)
            if not os.path.exists(shard_path):
                raise FileNotFoundError(
                    f"Detected missing checkpoint shard: {shard_path}"
                )
        for key, file in self.weight_map.items():
            self.file2keys[file].append(key)

        # The single shard currently held in memory (None until first access)
        self.active_file_name = None
        self.active_file_data = None

    def load_shard(self, file):
        """Loads the shard at path `file` onto the CPU and returns its
        contents."""
        if file.endswith(".safetensors"):
            return safetensors_torch.load_file(file, device="cpu")
        else:
            # NOTE(security): torch.load unpickles the file. Only use this
            # reader on checkpoints that come from a trusted source.
            return torch.load(file, map_location="cpu")

    def __len__(self):
        return len(self.weight_map)

    def __contains__(self, key):
        # O(1) lookup instead of the default fallback of iterating all keys.
        return key in self.weight_map

    def __iter__(self):
        # Keys are yielded shard by shard so that sequential access only
        # loads each shard once.
        for file in self.file2keys:
            for key in self.file2keys[file]:
                yield key

    def __getitem__(self, key):
        if key not in self.weight_map:
            raise KeyError(key)
        file = self.weight_map[key]
        if file != self.active_file_name:
            # Drop old data *before* load.
            # Without this, peak mem usage = prev shard + new shard
            self.active_file_name = None
            self.active_file_data = None
            self.active_file_data = self.load_shard(
                os.path.join(self.index_dir, file),
            )
            # Only record the shard as active once it loaded successfully so
            # that a failed load leaves the reader in a usable state.
            self.active_file_name = file
        return self.active_file_data[key]

    def items(self):
        for key in self.keys():
            yield key, self[key]

    def keys(self):
        return list(self)

    def values(self):
        for key in self.keys():
            yield self[key]
class StreamingShardedHFWriter:
    r"""Writes a HuggingFace sharded checkpoint in a streaming manner rather
    than accumulating the full checkpoint into memory and then writing all
    shards at the end. A partial checkpoint is accumulated into memory until
    it reaches the shard size limit at which point this shard is written to
    disk. It is essential that `self.save()` is called in order to flush the
    last shard to disk and to save other required metadata.

    The StreamingShardedHFWriter class supports re-accessing and even updating
    keys that have already been written. Note that accessing existing keys
    randomly may be slow due to the switching cost (loading) between shards
    that have already been written to disk. For this reason, it is recommended
    that keys are re-accessed in the order given by `self.keys()` or
    `self.__iter__()` as keys that appear in the same shard are in consecutive
    order. Note that updating data stored in a shard may result in a shard that
    is smaller/larger than the original shard size, as StreamingShardedHFWriter
    will not intelligently split or coalesce shards during updates.

    Args:
        checkpoint_dir: Path to where a new directory will be created to store
            the checkpoint shards.
        shard_size: The maximum size each checkpoint shard should be. Can be an
            integer representing the number of bytes, or a formatted string
            (ex: "10GB"). See convert_file_size_to_int for valid string formats.
        export_safetensors: Whether the output shards should be saved as
            safetensors or pickle files. Default: False. When using pickle
            files, the checkpoint & index files are saved with the
            'pytorch_model' prefix while they use the 'model' prefix when using
            safetensors.
    """

    def __init__(
        self,
        checkpoint_dir: str,
        shard_size: Union[str, int] = "10GB",
        export_safetensors=False,
    ) -> None:
        # Validate the shard size first so that an invalid value does not
        # leave an empty checkpoint directory behind.
        self.max_shard_size = convert_file_size_to_int(shard_size)

        self.checkpoint_dir = checkpoint_dir
        self.file_ext = 'safetensors' if export_safetensors else 'bin'
        self.file_prefix = "pytorch_" if not export_safetensors else ""
        os.mkdir(self.checkpoint_dir)
        self.index_file = os.path.join(
            self.checkpoint_dir,
            f"{self.file_prefix}model.{self.file_ext}.index.json",
        )
        # key -> name of the shard file that stores it
        self.weight_map = {}
        # Number of the most recently created shard. `current_file_number` is
        # only ever assigned together with `last_file_number`.
        self.current_file_number = 0
        self.last_file_number = 0
        # Shard count used in the "-of-XXXXX" part of the file names. It is 0
        # until the first `save()`, which renames the files to the real count.
        self.total_shards_finalized = 0
        # The single shard currently held in memory
        self.active_file_name = self.get_filename(
            self.current_file_number, self.total_shards_finalized
        )
        self.active_file_data = {}
        # shard file name -> accumulated size (bytes) of its tensors
        self.file_size = {self.active_file_name: 0}
        # True when the in-memory shard differs from what is on disk
        self.dirty = True

    def __len__(self):
        return len(self.weight_map)

    def __iter__(self):
        for key in self.weight_map:
            yield key

    def __getitem__(self, key):
        if key not in self.weight_map:
            raise KeyError(key)
        file = self.weight_map[key]
        if file != self.active_file_name:
            self._switch_shards(file)
        return self.active_file_data[key]

    def __setitem__(self, key, value):
        weight_size = math.ceil(value.numel() * dtype_byte_size(value.dtype))

        if key in self.weight_map:
            # We are updating a key that has already been seen before
            file = self.weight_map[key]
            if self.active_file_name != file:
                self._switch_shards(file)
            old_value = self.active_file_data[key]
            old_weight_size = math.ceil(
                old_value.numel() * dtype_byte_size(old_value.dtype)
            )
            delta_size = weight_size - old_weight_size
            if (
                self.file_size[self.active_file_name] + delta_size
                > self.max_shard_size
            ):
                logging.warning(
                    f"Updating {key} is causing shard {self.active_file_name} to be larger than "
                    f"limit."
                )
        else:
            # We are adding a new key that hasn't been seen before. New keys
            # always go to the last shard, so make it the active one. The
            # comparison is done on shard names because the active shard may
            # have been switched by a read/update of an older key.
            last_file = self.get_filename(
                self.last_file_number, self.total_shards_finalized
            )
            if self.active_file_name != last_file:
                self._switch_shards(last_file)

            # Create a new shard if this new weight "tips" us over the limit.
            # An empty shard is never rolled over: a single tensor larger
            # than the limit would otherwise leave an empty shard file behind.
            if self.active_file_data and (
                self.file_size[self.active_file_name] + weight_size
                > self.max_shard_size
            ):
                self._flush()
                self.last_file_number += 1
                self.current_file_number = self.last_file_number
                # Rebinding drops the previous shard's data from memory.
                self.active_file_data = {}
                self.active_file_name = self.get_filename(
                    self.current_file_number, self.total_shards_finalized
                )
                self.file_size[self.active_file_name] = 0
            delta_size = weight_size

        self.active_file_data[key] = value
        self.weight_map[key] = self.active_file_name
        self.file_size[self.active_file_name] += delta_size
        self.dirty = True

    def get_filename(self, file_number, total_shards=0):
        """Returns the shard file name for the zero-based `file_number`."""
        return f"{self.file_prefix}model-{file_number+1:05d}-of-{total_shards:05d}.{self.file_ext}"

    def load_shard(self, file):
        """Loads the shard at path `file` onto the CPU and returns its
        contents."""
        if self.file_ext == "safetensors":
            return safetensors_torch.load_file(file, device="cpu")
        else:
            # NOTE(security): torch.load unpickles the file. The shards read
            # here are expected to be the ones written by this writer.
            return torch.load(file, map_location="cpu")

    def save_shard(self, data, file):
        """Writes the `data` dictionary to the shard at path `file`."""
        if self.file_ext == "safetensors":

            def materialize(value):
                # Values exposing `_materialize` are converted through it
                # before being handed to safetensors.
                if hasattr(value, "_materialize"):
                    value = value._materialize()
                if isinstance(value, torch.Tensor):
                    value = value.contiguous()
                return value

            materialized_data = {k: materialize(v) for k, v in data.items()}
            safetensors_torch.save_file(
                materialized_data, file, {"format": "pt"}
            )
        else:
            torch.save(data, file)

    def _flush(self):
        """Writes the active shard to disk if it has unsaved changes."""
        if self.dirty:
            self.save_shard(
                self.active_file_data,
                os.path.join(self.checkpoint_dir, self.active_file_name),
            )
            self.dirty = False

    def _switch_shards(self, new_file):
        """Flushes the active shard and loads `new_file` in its place."""
        self._flush()
        # Drop old data *before* load.
        # Without this, peak mem usage = prev shard + new shard
        self.active_file_data = None
        self.active_file_name = None
        self.active_file_data = self.load_shard(
            os.path.join(self.checkpoint_dir, new_file),
        )
        # Only record the shard as active once it loaded successfully so
        # that a failed load does not leave mismatched name/data behind.
        self.active_file_name = new_file

    def save(self):
        """Flushes the active shard, renames the shard files so that they
        carry the final shard count, and writes the index file."""
        self._flush()
        total_size = sum(self.file_size.values())

        # Finalize total number of shards:
        new_total_shards = self.last_file_number + 1
        if self.total_shards_finalized != new_total_shards:
            # Step 1: Figure out the prev file -> new file mapping so that
            # we can rename the files / data structures
            file_renames = {
                self.get_filename(
                    i, self.total_shards_finalized
                ): self.get_filename(i, new_total_shards)
                for i in range(new_total_shards)
            }
            # Step 2: Rename the checkpoint files
            for prev_file, new_file in file_renames.items():
                os.rename(
                    os.path.join(self.checkpoint_dir, prev_file),
                    os.path.join(self.checkpoint_dir, new_file),
                )
            # Step 3: Update the weight map & file size data structures:
            self.weight_map = {
                key: file_renames[prev_file]
                for key, prev_file in self.weight_map.items()
            }
            self.file_size = {
                file_renames[prev_file]: size
                for prev_file, size in self.file_size.items()
            }
            # The in-memory shard was renamed on disk as well. Keep its name
            # in sync, otherwise the next write would look up a file name
            # that is no longer present in `self.file_size`.
            self.active_file_name = file_renames.get(
                self.active_file_name, self.active_file_name
            )
            # Step 4: Update the # of finalized shards so that future updates
            # to the writer will be able to correctly pick up the shards
            self.total_shards_finalized = new_total_shards

        with open(self.index_file, "w") as f:
            json.dump(
                {
                    "metadata": {"total_size": total_size},
                    "weight_map": self.weight_map,
                },
                f,
                indent=4,
            )

    def items(self):
        for key in self.keys():
            yield key, self[key]

    def keys(self):
        return list(self)

    def values(self):
        for key in self.keys():
            yield self[key]
class StreamingCSLeaf:
    r"""Placeholder for a checkpoint key whose value can be directly loaded
    from/saved to the H5 checkpoint. Non-leafs are accessed through
    StreamingCSWriterView due to their iterable nature.

    Both the string and the debug representation are a single "*".
    """

    def __repr__(self) -> str:
        return "*"

    # `str()` renders exactly like `repr()`.
    __str__ = __repr__
class StreamingCSWriterView:
    r"""StreamingCSWriterView allows for checkpoints with arbitrarily nested
    dictionaries/lists to be written in a streaming (incremental) manner by
    offering a "view" into a StreamingCSWriter. For example, in a checkpoint
    with the structure {"model": {<model state>}}, we can obtain a view into
    the model state via checkpoint["model"]. This view has state <model state>
    and prefix ["model"]. The view acts like a dict (offers `__getitem__`,
    `__setitem__`, etc operations) which incrementally saves/loads from an H5
    checkpoint under the hood.

    Args:
        checkpoint_file: Path to H5 checkpoint
        state: (Sub)state dictionary corresponding to the current view of the
            checkpoint.
        prefix: Chain of keys that were accessed in the checkpoint that yielded
            the current view
    """

    def __init__(self, checkpoint_file, state, prefix=None) -> None:
        self.checkpoint_file = checkpoint_file
        self.state = state
        self.prefix = prefix or []

    def __str__(self):
        return str(self.state)

    def __repr__(self):
        return f"StreamingCSWriterView: {str(self)}"

    def __iter__(self):
        # A dict state iterates over its keys, a list/tuple state iterates
        # over its (resolved) elements.
        if isinstance(self.state, dict):
            for key in self.keys():
                yield key
        if isinstance(self.state, (list, tuple)):
            for i in range(len(self.state)):
                yield self[i]

    def __len__(self):
        return len(self.state)

    def items(self):
        assert isinstance(self.state, dict)
        for key in self.keys():
            yield key, self[key]

    def keys(self):
        assert isinstance(self.state, dict)
        for key in self.state:
            if key in self:
                yield key

    def values(self):
        assert isinstance(self.state, dict)
        for key in self.keys():
            yield self[key]

    def __contains__(self, item):
        return item in self.state

    def _tensor_name(self, key, scope=()):
        """Builds the dotted name under which a value is stored in the H5
        file from this view's prefix, `key` and an optional nested `scope`.

        Every part is converted with `str` because `str.join` rejects
        non-string parts such as list indices or integer dictionary keys.
        """
        # NOTE(review): assumes the scopes yielded by `recurse_spec` spell
        # integer positions the same way `str` does, so that names built on
        # read match the names built on write -- confirm against
        # cerebras_pytorch.utils.nest.recurse_spec.
        return ".".join(str(part) for part in [*self.prefix, key, *scope])

    def __getitem__(self, key):
        value = self.state[key]
        if isinstance(value, StreamingCSLeaf):
            # Leaves live in the H5 file and are loaded on demand.
            saver = PyTorchH5Saver()
            return saver.load_tensor(
                self.checkpoint_file, self._tensor_name(key)
            )
        if isinstance(value, StreamingCSWriterView):
            return value
        if isinstance(value, (dict, list, tuple)):
            # Containers are exposed through a nested view that remembers
            # the chain of keys leading to it.
            subview = StreamingCSWriterView(
                self.checkpoint_file, value, self.prefix + [key]
            )
            return subview
        return value

    def get(self, key, default=None):
        if key in self:
            return self[key]
        return default

    def __setitem__(self, key, value):
        if key in self.state and not isinstance(
            self.state[key], StreamingCSLeaf
        ):
            raise ValueError(
                "StreamingCSWriter does not support updating an existing "
                "key which had a dict/list/tuple value"
            )

        if isinstance(value, (dict, list, tuple)):
            if key in self.state:
                raise ValueError(
                    "StreamingCSWriter does not support updating a key which "
                    "already exists with a dict/list/tuple"
                )
            # Save every flattened element of the container on its own and
            # keep only the structure (with leaf markers) in memory.
            flattened, spec = torch.utils._pytree.tree_flatten(value)
            for scope, v in zip(recurse_spec(spec), flattened):
                saver = PyTorchH5Saver()
                saver.save_tensor(
                    self.checkpoint_file, self._tensor_name(key, scope), v
                )
            substate = torch.utils._pytree.tree_unflatten(
                [StreamingCSLeaf() for _ in flattened], spec,
            )
            self.state[key] = substate
        else:
            saver = PyTorchH5Saver()
            saver.save_tensor(
                self.checkpoint_file, self._tensor_name(key), value
            )
            self.state[key] = StreamingCSLeaf()
class StreamingCSWriter(StreamingCSWriterView):
    r"""Writes a Cerebras H5 checkpoint in a streaming (incremental) manner
    rather than accumulating the full checkpoint into memory and then writing
    all weights at the end. It is essential that `self.save()` is called in
    order to flush the required metadata (state's spec). Without this call,
    the resulting checkpoint will not be able to be loaded with
    `cstorch.load(...)`.

    The StreamingCSWriter class supports re-accessing and even updating
    keys that have already been written. There are two restrictions:
        1. An existing key that stores a dict/list/tuple cannot be replaced.
        2. An existing key storing any type cannot be replaced by a
           dict/list/tuple

    Args:
        checkpoint_file: Path to new H5 checkpoint. A file cannot already exist
            at this location.

    Raises:
        FileExistsError: If something already exists at `checkpoint_file`.
    """

    def __init__(self, checkpoint_file) -> None:
        # Refuse to touch a pre-existing file.
        already_present = os.path.exists(checkpoint_file)
        if already_present:
            raise FileExistsError(
                f"Checkpoint file \"{checkpoint_file}\" cannot be created because "
                "file already exists"
            )
        # The root view starts with an empty state and no prefix.
        super().__init__(checkpoint_file, {})

    def save(self):
        """Writes the spec of the accumulated state to the checkpoint file."""
        h5_saver = PyTorchH5Saver()
        _flattened, state_spec = h5_saver.flatten_state_dict(self.state)
        h5_saver.save_spec(self.checkpoint_file, state_spec)

    def __str__(self):
        return "{}:\n{}".format(self.checkpoint_file, self.state)

    def __repr__(self):
        return "StreamingCSWriter: " + str(self)
class OnDemandDictionaryConverter:
    r"""Wraps around an input dictionary in order to transform its values
    on-the-fly. The transformation has the following restrictions:
        1. It must maintain a 1-1 mapping (i.e. no new/dropped keys)
        2. The keys cannot change names (only values can change)
    There is error checking during object initialization and during runtime to
    ensure that this restriction holds.

    Args:
        underlying_dict: Underlying dictionary that needs to be transformed
            in an on-demand fashion
        converter_class: A subclass of BaseDictionaryConverter which describes
            the transformation of the underlying dictionary
        action_fn_args: Additional arguments that may be used in the converter's
            action functions.
    """

    def __init__(
        self, underlying_dict, converter_class, action_fn_args=None
    ) -> None:
        super().__init__()
        # The source values are held in a ReadOnlyDict so that the converter
        # cannot add, remove or reassign keys.
        self.underlying_dict = ReadOnlyDict(underlying_dict)
        self.converter_instance = converter_class()
        self.action_fn_args = action_fn_args or {}
        self.verify_converter()

    def verify_converter(self):
        """Checks that the nested converter can only express 1-1 mappings.

        Raises:
            AssertionError: If the converter is not a BaseDictionaryConverter
                or defines one of the pre/post conversion hooks.
            ValueError: If one of the converter's rules contains a segment
                that is not a string.
        """
        # Deferred to prevent circular import:
        from modelzoo.common.pytorch.model_utils.checkpoint_converters.base_converter import (
            BaseDictionaryConverter,
        )

        assert isinstance(self.converter_instance, BaseDictionaryConverter), (
            f"{self.__class__}'s nested converter must subclass "
            f"BaseDictionaryConverter"
        )

        disallowed_fns = [
            "pre_checkpoint_convert",
            "pre_model_convert",
            "post_model_convert",
            "post_checkpoint_convert",
        ]
        for fn_name in disallowed_fns:
            assert not hasattr(self.converter_instance, fn_name), (
                f"{self.__class__} only supports converters that are 1-1 "
                f"mappings. Therefore, the nested converter cannot contain the "
                f"{fn_name} function"
            )

        for rule in self.converter_instance.rules:
            if not all(isinstance(elm, str) for elm in rule.segments):
                raise ValueError(
                    f"{self.__class__} only supports converters that are 1-1 "
                    f"mappings. Therefore, their rules can only contain regex "
                    f"strings (no EquivalentSubkey or BaseDictionaryConverter "
                    f"objects). The following conversion rule offends this "
                    f"constraint:\n{rule}"
                )

    def __len__(self):
        return len(self.underlying_dict)

    def __iter__(self):
        return self.underlying_dict.__iter__()

    def __getitem__(self, key):
        if key not in self.underlying_dict:
            raise KeyError(key)

        # The converter writes its output for `key` into a scratch dict,
        # which is then checked to hold exactly that one key.
        new_temp_dict = {}
        from_index = 0
        match = self.converter_instance.convert_key(
            key,
            self.underlying_dict,
            new_temp_dict,
            from_index,
            action_fn_args=self.action_fn_args,
        )
        if set(new_temp_dict) != {key}:
            raise ValueError(
                f"{self.__class__}'s nested converter did not create a 1-1 "
                f"mapping."
            )
        if not match:
            raise KeyError(key)
        return new_temp_dict[key]

    def items(self):
        for key in self.keys():
            yield key, self[key]

    def keys(self):
        return self.underlying_dict.keys()

    def values(self):
        for key in self.keys():
            yield self[key]
def _readonly(self, *args, **kwargs):
    """Stands in for every mutating `dict` method of ReadOnlyDict and always
    raises a RuntimeError."""
    raise RuntimeError("Cannot modify ReadOnlyDict")


class ReadOnlyDict(dict):
    """A Read-only dict. Every mutating method raises a RuntimeError.

    Note that this object doesn't guard against the values from being mutated
    in-place.
    """

    __setitem__ = _readonly
    __delitem__ = _readonly
    pop = _readonly
    popitem = _readonly
    clear = _readonly
    update = _readonly
    setdefault = _readonly
    # `d |= other` mutates a dict in place without calling the `update`
    # override above, so it has to be blocked explicitly.
    __ior__ = _readonly