Source code for cerebras.modelzoo.data.common.input_utils

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import random
from typing import Iterator, Sized

import numpy as np
import torch

import cerebras.pytorch as cstorch
import cerebras.pytorch.distributed as dist
from cerebras.pytorch.distributed.cluster_resolver import ClusterSpec, TaskSpec


[docs]def get_data_for_task( task_id, meta_data_values_cum_sum, num_examples_per_task, meta_data_values, meta_data_filenames, ): """ Function to get distribute files with given number of examples such that each distributed task has access to exactly the same number of examples Args: task_id (int): Integer id for a task. meta_data_values_cum_sum (int): Cumulative sum of the file sizes in lines from meta data file. num_examples_per_task (int): Number of the examples specified per slurm task. Equal to `batch_size` * `num_batch_per_task`. meta_data_values (list[int]): List of the files sizes in lines in the meta data file. meta_data_filenames (list[str]): List with file names in the meta data file. Returns: list of tuples of length 3. The tuple contains at - index 0: filepath. - index 1: number of examples to be considered for this task_id. - index 2: start index in the file from where these examples should be considered The list represents the files that should be considered for this task_id. """ files_in_task = [] # file where the split starts file_start_idx = np.min( np.where(meta_data_values_cum_sum > task_id * num_examples_per_task)[0] ) # Index in file from where the examples should be considered for this task start_idx = ( task_id * num_examples_per_task - meta_data_values_cum_sum[file_start_idx - 1] # -1 since len(`meta_data_values_cum_sum`) = len(`meta_data_values`) + 1 ) # Number of examples to pick from this file. # We do a `min` to handle a case where the file has # examples > num_examples_per_task num_examples = min( meta_data_values[file_start_idx - 1] - start_idx, num_examples_per_task, ) files_in_task.append( ( meta_data_filenames[file_start_idx - 1], num_examples, start_idx, ) # (file_path, num_examples, start_index) ) if num_examples != num_examples_per_task: # If the file has fewer number of examples than # `num_examples_per_task`, continue through files # till we reach our required number of examples. indices = np.where( meta_data_values_cum_sum > (task_id + 1) * num_examples_per_task )[0] if indices.size != 0: file_end_idx = np.min(indices) else: file_end_idx = len(meta_data_values_cum_sum) for i in range(file_start_idx + 1, file_end_idx): files_in_task.append( ( meta_data_filenames[i - 1], meta_data_values[i - 1], 0, ) # (file_path, num_examples, start_index) ) # If the number of examples needed to fulfill # `num_examples_per_task`, falls in between a file num_end_examples = ( task_id + 1 ) * num_examples_per_task - meta_data_values_cum_sum[file_end_idx - 1] if num_end_examples > 0: files_in_task.append( ( meta_data_filenames[file_end_idx - 1], num_end_examples, 0, ) # (file_path, num_examples, start_index) ) assert ( sum([num_examples for _, num_examples, _ in files_in_task]) == num_examples_per_task ), f"Incorrect number of examples in the split with task_id {task_id}" return files_in_task
[docs]def is_distributed(): """ Returns True if DDP is enabled. """ return ( torch.distributed.is_available() and torch.distributed.is_initialized() )
[docs]def task_id(): if dist.is_streamer(): return dist.get_streaming_rank() elif is_distributed(): return dist.get_rank() else: return 0
[docs]def num_tasks(): if dist.is_streamer(): return dist.num_streamers() elif is_distributed(): return dist.get_world_size() else: return 1
[docs]def cluster_config(): """ Returns (ClusterSpec, TaskSpec). The TaskSpec contains the following fields: - rank: the global rank of the current worker - local_rank: the rank of the current worker among workers who feed the same system as the current worker - wse_id: the index of the system that the current worker is associated with The ClusterSpec contains the following fields: - tasks: a list of TaskSpecs for each task running on the cluster - rank: the rank of the current process's task in the cluster - num_csx: the number of CSX systems in the cluster - num_workers_per_csx: the number of worker tasks per CSX If the current job is running on GPU instead of CS system, then the ranks and world sizes in the returned TaskSpec will be set to the GPU rank and world size. """ if cstorch.use_cs() and dist.is_streamer(): cluster_spec = dist.service_resolver().cluster_spec task_spec = cluster_spec.task() return cluster_spec, task_spec elif is_distributed(): task_spec = TaskSpec( rank=dist.get_rank(), local_rank=dist.get_rank(), wse_id=0, node_name="unknown", ) cluster_spec = ClusterSpec( [task_spec], dist.get_rank(), 1, dist.get_world_size(), ) return cluster_spec, task_spec else: task_spec = TaskSpec( rank=0, local_rank=0, wse_id=0, node_name="unknown" ) cluster_spec = ClusterSpec([task_spec], 0, 1, 1) return cluster_spec, task_spec
[docs]class ShardedSampler(torch.utils.data.Sampler): """ Modified from: https://pytorch.org/docs/stable/_modules/torch/utils/data/distributed.html#DistributedSampler Sampler that restricts data loading to a subset of the dataset. Dataset is assumed to be of constant size. Args: dataset (torch.utils.data.Dataset): Dataset used for sampling. shuffle (bool, optional): If `True` (default), sampler will shuffle the indices. seed (int, optional): Random seed used to shuffle the sampler if :attr:`shuffle=True`. This number should be identical across all processes in the distributed group. Default: `0`. drop_last (bool, optional): If `True`, then the sampler will drop the tail of the data to make it evenly divisible across the number of replicas. If `False`, the sampler will add extra indices to make the data evenly divisible across the replicas. Default: `False`. """
[docs] def __init__(self, dataset, shuffle=True, seed=None, drop_last=False): self.num_tasks = num_tasks() self.task_id = task_id() self.dataset = dataset self.dataset_len = len(self.dataset) self.drop_last = drop_last if cstorch.use_cs() and not self.drop_last: raise ValueError( "On CS2 we do not support unequal batch sizes so `drop_last` " "must be set to `True`." ) # If the dataset length is evenly divisible by # of replicas, then there # is no need to drop any data, since the dataset will be split equally. if self.drop_last and len(self.dataset) % self.num_tasks: # Split to nearest available length that is evenly divisible. # This is to ensure each task receives the same amount of data when # using this sampler. self.num_samples = len(self.dataset) // self.num_tasks else: self.num_samples = math.ceil(len(self.dataset) / self.num_tasks) self.total_size = self.num_samples * self.num_tasks self.shuffle = shuffle self.seed = seed self.indices = list(range(self.dataset_len)) if not self.drop_last: # add extra samples to make it evenly divisible across tasks padding_indices_size = self.total_size - self.dataset_len # choose padding indices at random to reduce the chance of # reusing samples. random.seed(self.seed) padding_indices = random.sample(self.indices, padding_indices_size) self.indices += padding_indices else: # remove tail of data to make it evenly divisible. self.indices = self.indices[: self.total_size] assert len(self.indices) == self.total_size, ( f"Total `indices` after dropping/padding indices must be equal " f"to `total_size` of the dataset. Received total indices: " f"`{len(self.indices)}` and total size is: `{self.total_size}`." )
def __iter__(self): if self.shuffle: random.seed(self.seed) random.shuffle(self.indices) # subsample indices = self.indices[self.task_id : self.total_size : self.num_tasks] assert len(indices) == self.num_samples, ( f"Total `indices` for tasks must be equal to `num_samples` in a " f"task. Received total indices: `{len(indices)}` and samples in " f"task are: `{self.num_samples}`." ) yield from indices def __len__(self): return self.num_samples
[docs]def check_sharding_sanity( examples_per_file, batch_size, num_workers, drop_last, ): """Checks if with the given sharding, at least one batch is generated. Note that this method is operating based on how `shard_and_shuffle_data` is sharding the data across workers. :param list examples_per_file: Total examples per file for this task. :param int batch_size: Batch size of the model. :param int num_workers: Number of workers to use in the dataloader. :param bool drop_last: Boolean indicating whether the last incomplete batch of the dataloader is dropped. :raises ValueError: If no batches are generated with the given sharding. """ if drop_last is False: return if num_workers == 0: total_samples = sum(examples_per_file) if total_samples < batch_size: raise ValueError( f"Task {task_id()} only generates {total_samples}, which " f"is fewer than a full batch of size {batch_size}. " ) return examples_per_worker = [0] * num_workers for file_idx, examples_in_file in enumerate(examples_per_file): worker_id = file_idx % num_workers examples_per_worker[worker_id] += examples_in_file max_examples = max(examples_per_worker) if max_examples < batch_size: raise ValueError( f"Maximum number of samples generated in dataloader workers of " f"task {task_id()} is {max_examples}. Since {max_examples} is less " f"than batch size {batch_size} and `drop_last` is True, this task " f"will end up not producing any samples. Please specify a fewer " f"number of workers or tasks." )
[docs]def shard_list_contiguous(input_list, worker_id, num_workers): """ Shards a list by splitting it into `num_workers` contiguous segments. Only the `worker_id`th shard is returned. If the length of the list is not divisible by the number of workers, the last worker will be assigned all remainder elements. Args: input_list (list): list to shard into contiguous segments worker_id (int): index of shard to return num_workers (int): number of shards to create Returns: A sublist of contiguous elements (`worker_id`'s shard) """ assert num_workers <= len(input_list), ( f"Number of processes should be less than number of files, " f"Got `num_workers` equal to {num_workers} and `num_files` equal to {len(input_list)}." ) per_worker_num_files = len(input_list) // num_workers if worker_id < num_workers - 1: output_list = input_list[ (worker_id * per_worker_num_files) : ( (worker_id + 1) * per_worker_num_files ) ] else: output_list = input_list[(worker_id * per_worker_num_files) :] return output_list
[docs]def shard_list_interleaved(input_list, worker_id, num_workers): """ Shards a list by assigning consecutive elements to alternating workers (i.e. interleaving). If the length of the list is not divisible by the number of workers, the remainder elements are spread across a subset of the workers such that each worker in the subset receives 1 extra element. Args: input_list (list): list to shard in an interleaved fashion worker_id (int): index of shard to return num_workers (int): number of shards to create Returns: `worker_id`'s shard (a subset of `input_list`). """ output_for_cur_worker = [] if num_workers != 0: assert num_workers <= len(input_list), ( f"Number of processes should be less than number of files, " f"Got `num_workers` equal to {num_workers} and `num_files` equal to {len(input_list)}." ) # Gather files for the input worker based in the file index and # number of workers. for index, elm in enumerate(input_list): if index % num_workers == worker_id: output_for_cur_worker.append(elm) else: output_for_cur_worker = input_list return output_for_cur_worker
[docs]def shard_list_of_chunks_contiguous( input_list_of_chunks, worker_id, num_workers ): """ Shards a list of chunks by distributing contiguous segments of each chunk across shards. If the chunk's length is not divisible by the number of workers, the remainder elements are spread across a subset of the workers such that each worker in the subset receives 1 extra element. Args: input_list (list of tuples): list of chunks to shard. List should be of format `[... (chunk_i, length_of_chunk_i), ...]` worker_id (int): index of shard to return num_workers (int): number of shards to create Returns: `worker_id`'s shard: a list of the same length as `input_list` of the format: `[... (chunk_i, shard_start_index_i, shard_length_i), ...]` """ output_for_cur_worker = [] for elm, chunk_length in input_list_of_chunks: # Try to evenly distribute chunk_length between workers chunk_length_per_worker = [(chunk_length // num_workers)] * num_workers for i in range(chunk_length % num_workers): chunk_length_per_worker[i] += 1 assert sum(chunk_length_per_worker) == chunk_length output_for_cur_worker.append( ( elm, sum(chunk_length_per_worker[:worker_id]) if worker_id > 0 else 0, # Start index chunk_length_per_worker[worker_id], # Length of data chunk ) ) return output_for_cur_worker
[docs]class SubsetSequentialSampler(torch.utils.data.Sampler[int]): r"""Samples elements sequentially, starting from given `start_index`, always in the same order. Args: data_source (Dataset): dataset to sample from start_index (int): index where sampling starts from """ data_source: Sized start_index: int
[docs] def __init__(self, data_source: Sized, start_index: int) -> None: self.data_source = data_source self.start_index = start_index
def __iter__(self) -> Iterator[int]: return iter(range(self.start_index, len(self.data_source))) def __len__(self) -> int: return len(self.data_source)