modelzoo.common.pytorch.model_utils.checkpoint_converters.streaming_checkpoints.StreamingShardedHFWriter

class modelzoo.common.pytorch.model_utils.checkpoint_converters.streaming_checkpoints.StreamingShardedHFWriter

Bases: object

Writes a HuggingFace sharded checkpoint in a streaming manner, instead of accumulating the full checkpoint in memory and writing all shards at the end.

A partial checkpoint is accumulated in memory until it reaches the shard size limit, at which point that shard is written to disk.
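To make the buffering behavior concrete, here is a minimal, self-contained sketch of the same idea. It is not the modelzoo implementation: the class ToyStreamingShardWriter, its shard file names, and its flush policy (start a new shard when the next value would exceed the limit) are assumptions made for illustration, and raw byte strings stand in for tensors.

```python
import os
import pickle
import tempfile
from typing import Dict, List


class ToyStreamingShardWriter:
    """Hypothetical illustration of streaming sharding; not the modelzoo class.

    Values are raw bytes standing in for tensors, and keys are assumed to be
    unique (re-setting or updating keys is not handled). Entries are buffered
    in memory; when adding a value would push the buffer past `shard_size`
    bytes, the buffer is pickled to disk as one shard and then cleared.
    """

    def __init__(self, checkpoint_dir: str, shard_size: int) -> None:
        self.checkpoint_dir = checkpoint_dir
        self.shard_size = shard_size
        self.buffer: Dict[str, bytes] = {}
        self.buffer_bytes = 0
        self.shard_files: List[str] = []
        os.makedirs(checkpoint_dir, exist_ok=True)

    def __setitem__(self, key: str, value: bytes) -> None:
        # Flush the current shard first if this value would exceed the limit.
        # A single value larger than shard_size still gets a shard of its own.
        if self.buffer and self.buffer_bytes + len(value) > self.shard_size:
            self._flush()
        self.buffer[key] = value
        self.buffer_bytes += len(value)

    def _flush(self) -> None:
        if not self.buffer:
            return
        name = f"shard-{len(self.shard_files):05d}.pkl"
        with open(os.path.join(self.checkpoint_dir, name), "wb") as f:
            pickle.dump(self.buffer, f)
        self.shard_files.append(name)
        self.buffer = {}
        self.buffer_bytes = 0

    def save(self) -> None:
        # The last, partially filled shard only reaches disk here.
        self._flush()


with tempfile.TemporaryDirectory() as d:
    writer = ToyStreamingShardWriter(d, shard_size=10)
    for i in range(5):
        writer[f"layer{i}"] = b"x" * 4  # five 4-byte "tensors"
    writer.save()
    print(writer.shard_files)  # three shards holding 8, 8 and 4 bytes
```

Only one partial shard is ever held in memory, which is the point of streaming; the trade-off is that the final shard exists only in memory until save() is called.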

It is essential to call self.save() when writing is finished: it flushes the last shard to disk and saves the other required metadata.
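Part of the metadata a HuggingFace sharded checkpoint needs is an index file that maps every key to the shard storing it, together with the total size of all tensors. That mapping is only complete once the final shard has been written, so it is presumably part of what save() writes. The sketch below builds an index in the standard HuggingFace layout (metadata.total_size and weight_map); the helper build_hf_index and the example keys and file names are hypothetical.

```python
import json
from typing import Dict


def build_hf_index(weight_map: Dict[str, str], tensor_bytes: Dict[str, int]) -> str:
    """Return the JSON text of a HuggingFace-style shard index (hypothetical helper).

    weight_map:   key -> name of the shard file that stores it
    tensor_bytes: key -> size of that tensor in bytes
    """
    index = {
        "metadata": {"total_size": sum(tensor_bytes.values())},
        # Sorted keys keep the file deterministic across runs.
        "weight_map": dict(sorted(weight_map.items())),
    }
    return json.dumps(index, indent=2)


print(build_hf_index(
    {
        "lm_head.weight": "pytorch_model-00002-of-00002.bin",
        "embed.weight": "pytorch_model-00001-of-00002.bin",
    },
    {"lm_head.weight": 400, "embed.weight": 600},
))
```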

The StreamingShardedHFWriter class supports re-accessing, and even updating, keys that have already been written. Note that random access to existing keys may be slow, because switching to a different shard requires loading that shard from disk. For this reason, it is recommended to re-access keys in the order given by self.keys() or self.__iter__(), since keys stored in the same shard appear consecutively in that order.

Note that updating data stored in a shard may result in a shard that is smaller or larger than the original shard size, because StreamingShardedHFWriter does not split or coalesce shards during updates.
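The cost of random access can be seen with a small simulation. Assuming, as a simplification rather than a documented detail of the class, that only one shard is held in memory at a time, every change of shard between consecutive accesses costs one load from disk. The helper count_shard_loads and the toy key layout are hypothetical.

```python
import random
from typing import Dict, Iterable, Optional


def count_shard_loads(key_to_shard: Dict[str, int], access_order: Iterable[str]) -> int:
    """Count shard loads, assuming only one shard is held in memory at a time."""
    loaded: Optional[int] = None
    loads = 0
    for key in access_order:
        shard = key_to_shard[key]
        if shard != loaded:
            loads += 1  # switching shards means reading another file from disk
            loaded = shard
    return loads


# 3 shards with 4 keys each, listed in shard order (as keys() would yield them).
key_to_shard = {f"w{i}": i // 4 for i in range(12)}
in_order = list(key_to_shard)
shuffled = in_order[:]
random.Random(0).shuffle(shuffled)

print(count_shard_loads(key_to_shard, in_order))  # 3: one load per shard
print(count_shard_loads(key_to_shard, shuffled))  # typically more than 3
```

Iterating in keys() order reaches the minimum of one load per shard; any order that interleaves keys from different shards can only add loads.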

Parameters
  • checkpoint_dir – Path at which a new directory is created to store the checkpoint shards.

  • shard_size – The maximum size of each checkpoint shard. Can be an integer number of bytes or a formatted string (e.g. “10GB”). See convert_file_size_to_int for valid string formats. Default: “10GB”.

  • export_safetensors – Whether the output shards are saved as safetensors (True) or as pickle files (False). Default: False. With pickle files, the checkpoint and index files use the ‘pytorch_model’ prefix; with safetensors, they use the ‘model’ prefix.
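For reference, HuggingFace transformers names sharded checkpoint files as in the sketch below, and the two prefixes described for export_safetensors match that convention. The helpers shard_filename and index_filename are hypothetical; the names the class actually produces (see get_filename) may differ in detail.

```python
def shard_filename(index: int, total: int, export_safetensors: bool = False) -> str:
    """Hypothetical helper: name of shard `index` (1-based) out of `total`."""
    if export_safetensors:
        return f"model-{index:05d}-of-{total:05d}.safetensors"
    return f"pytorch_model-{index:05d}-of-{total:05d}.bin"


def index_filename(export_safetensors: bool = False) -> str:
    """Hypothetical helper: name of the JSON index that maps keys to shards."""
    if export_safetensors:
        return "model.safetensors.index.json"
    return "pytorch_model.bin.index.json"


print(shard_filename(1, 3))                           # pytorch_model-00001-of-00003.bin
print(shard_filename(1, 3, export_safetensors=True))  # model-00001-of-00003.safetensors
print(index_filename(export_safetensors=True))        # model.safetensors.index.json
```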

Methods

get_filename

items

keys

load_shard

save

save_shard

values

__init__(checkpoint_dir: str, shard_size: Union[str, int] = '10GB', export_safetensors=False) → None
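shard_size accepts either a byte count or a string such as the default '10GB'. A minimal sketch of that conversion follows, assuming integer amounts with decimal (KB, MB, GB) or binary (KiB, MiB, GiB) suffixes. parse_shard_size is a hypothetical stand-in; convert_file_size_to_int remains the authority on which formats are actually accepted.

```python
from typing import Union

# Decimal and binary units; matching below is case-insensitive.
_UNITS = [
    ("GIB", 2**30), ("MIB", 2**20), ("KIB", 2**10),
    ("GB", 10**9), ("MB", 10**6), ("KB", 10**3),
]


def parse_shard_size(size: Union[str, int]) -> int:
    """Hypothetical helper: convert an int or a string like '10GB' to bytes."""
    if isinstance(size, int):
        return size
    text = size.strip().upper()
    for suffix, factor in _UNITS:
        if text.endswith(suffix):
            # Only integer amounts are handled in this sketch ("1.5GB" raises).
            return int(text[: -len(suffix)]) * factor
    raise ValueError(f"Unrecognized size: {size!r}")


print(parse_shard_size("10GB"))  # 10000000000
print(parse_shard_size("1GiB"))  # 1073741824
```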