# Source code for cerebras.pytorch.distributed

# Copyright 2016-2023 Cerebras Systems
# SPDX-License-Identifier: BSD-3-Clause

"""Get information about the current cluster setup."""

import os
from pathlib import Path

from cerebras.appliance.utils.units import bytes_to_human
from cerebras.pytorch.utils.utils import get_dir_size

from .cluster_resolver import TaskRole
from .service_resolver import BaseServiceResolver
from .worker_state import WorkerState


def get_worker_state():
    """Expose the state captured by a CSX Worker at a checkpoint step.

    The state is reported for the current run and is represented in the
    :py:class:`DataLoaderCheckpoint` dataclass format.

    Returns:
        :py:class:`DataLoaderCheckpoint` instance holding worker state
        information at the checkpoint step.

    .. note::
        - Call this only from within a custom `state_dict` implementation
          of a dataloader conforming to the
          :py:class:`RestartableDataLoader` protocol; `state_dict` is
          well-defined only at a checkpoint step.
        - Use it to persist any of the state info recorded by each worker
          when defining `state_dict` for custom restartable dataloaders.
        - The captured state covers the current run only: if a run is
          paused and restarted, the counters behind the information
          returned here are reset.
    """
    worker_state = WorkerState.get_worker_state()
    return worker_state
def service_resolver():
    """Returns the service resolver for the current run.

    Thin wrapper around ``BaseServiceResolver.get_resolver()``; every other
    helper in this module goes through it.
    """
    return BaseServiceResolver.get_resolver()


def num_tasks():
    """Returns total number of tasks in the cluster."""
    return service_resolver().cluster_resolver.num_tasks


def num_streamers():
    """Returns total number of tasks responsible for streaming inputs."""
    return len(service_resolver().streamer_ordinals())


def num_receivers():
    """Returns total number of tasks responsible for receiving outputs."""
    return len(service_resolver().receiver_ordinals())


def get_ordinal():
    """Returns the ordinal number of the current task."""
    return service_resolver().cluster_resolver.rank


def get_streaming_rank():
    """Returns the rank of the current task among streamers.

    The rank is the index of the current task's ordinal within the sorted
    list of streamer ordinals.

    Raises:
        AssertionError: If the current task is not a streamer.
    """
    streamers = sorted(service_resolver().streamer_ordinals())
    ordinal = get_ordinal()
    # Raised explicitly (instead of via an `assert` statement) so the check
    # is not stripped when Python runs with -O; the exception type and
    # message are the same as before.
    if ordinal not in streamers:
        raise AssertionError(f"Ordinal {ordinal} is not a streamer.")
    return streamers.index(ordinal)


def get_streaming_batch_size(effective_batch_size: int) -> int:
    """Returns the streaming batch size of the current task.

    In a Wafer-Scale Cluster setup with more than 1 CS-X node, the batch size
    used in compile and specified by user is the effective batch size at
    which gradient updates are done. However, each worker node streams a local
    batch of data to a given CS-X node to constitute data parallel training.

    This helper method returns the local batch size that the current task
    should use given the desired effective batch size.

    Args:
        effective_batch_size: The effective batch size of the model.

    Returns:
        The local batch size to be streamed by this task. If queried on the
        user node (used when compiling the model), this returns the original
        effective batch size as passed in the argument.

    Raises:
        TypeError: If ``effective_batch_size`` is not an integer (booleans
            are rejected as well).
        ValueError: If ``effective_batch_size`` is not positive, if the
            number of CS-X nodes is not positive, or if the effective batch
            size is not a multiple of the number of CS-X nodes.
    """
    # Do some basic validation. `bool` is a subclass of `int`, so it has to
    # be excluded explicitly; otherwise `True` would pass as batch size 1.
    if isinstance(effective_batch_size, bool) or not isinstance(
        effective_batch_size, int
    ):
        raise TypeError(
            f"Expected effective batch size to be an integer, but got type "
            f"{type(effective_batch_size)} with value {effective_batch_size}."
        )
    if effective_batch_size <= 0:
        raise ValueError(
            f"Expected effective batch size to be a positive integer, but got "
            f"value {effective_batch_size}."
        )

    # If not queried on the worker node, return the effective batch size as is
    # so the compile can automatically handle data parallel and gradient
    # accumulation.
    if not is_streamer():
        return effective_batch_size

    # If queried on the worker node, return the local batch size
    num_csx = service_resolver().cluster_spec.num_csx
    if num_csx <= 0:
        raise ValueError(
            f"Expected number of CS-X nodes to be a positive integer, but "
            f"got {num_csx}."
        )
    if effective_batch_size % num_csx != 0:
        raise ValueError(
            f"Expected effective batch size {effective_batch_size} to be a "
            f"multiple of number of CS-X nodes {num_csx}."
        )
    return effective_batch_size // num_csx


def is_master_ordinal(local=False):
    """Returns True if the current task is the master task.

    Args:
        local: Unused. Accepted only for compatibility with the XLA API.
    """
    # Note: keeping `local` argument for compatibility with XLA API.
    return service_resolver().cluster_resolver.assumes_role(TaskRole.MASTER)


def is_streamer():
    """Returns True if the current task is a streamer task."""
    return get_ordinal() in service_resolver().streamer_ordinals()


def is_receiver():
    """Returns True if the current task is a receiver task."""
    return get_ordinal() in service_resolver().receiver_ordinals()


# constants
# Fraction of the SSD mount's total size that may be occupied after a copy.
SSD_LIMIT = 0.8
# Root directory of the worker cache; also the mount whose usage is measured.
WORKER_CACHE_ROOT = "/n0/cache"


def hit_worker_cache_limit(src_dir: str, dest_dir: str):
    """
    Identifies whether copying the src_dir to a dest_dir
    (within worker_cache), will lead to a cache overflow

    Args:
        src_dir (str, required): directory path of the source
        dest_dir (str, required): directory path of the destination
            within the worker cache

    Returns:
        A tuple of (``is_limit_hit``, ``dir_size``,
        ``available_space_for_copy``) where ``is_limit_hit`` is a bool
        indicating whether cache limit will be hit with the copy,
        ``dir_size`` is the size of the src_dir to be copied to the cache,
        ``available_space_for_copy`` is the space available for src_dir
        copy, including the space occupied by the currently cached_dir
        corresponding to src_dir. Both sizes are passed through
        ``bytes_to_human`` before being returned.

    Raises:
        ValueError: If ``dest_dir`` does not resolve to a path under
            ``WORKER_CACHE_ROOT``.
    """
    # Raises if dest_dir path is not a descendant of WORKER_CACHE_ROOT
    Path(dest_dir).resolve().relative_to(Path(WORKER_CACHE_ROOT).resolve())

    # Only add things to cache if < SSD_LIMIT occupied
    ssd_mount = WORKER_CACHE_ROOT

    # Get size of SSD mount
    statvfs = os.statvfs(ssd_mount)
    max_size = statvfs.f_frsize * statvfs.f_blocks

    dir_size = get_dir_size(src_dir)
    ssd_available = statvfs.f_frsize * statvfs.f_bavail
    ssd_occupied = max_size - ssd_available

    # Whatever currently sits at dest_dir is counted as reclaimable space,
    # so it is subtracted from the projected usage.
    removal_size = get_dir_size(dest_dir)

    cap = SSD_LIMIT * max_size
    new_size = dir_size + ssd_occupied - removal_size
    is_limit_hit = new_size > cap

    # Clamp at zero when usage (net of the reclaimable directory) already
    # meets or exceeds the cap.
    available_space_for_copy = (
        cap - ssd_occupied + removal_size
        if cap > (ssd_occupied - removal_size)
        else 0
    )

    return (
        is_limit_hit,
        bytes_to_human(dir_size),
        bytes_to_human(available_space_for_copy),
    )