# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import atexit
import tempfile
import weakref
from pathlib import Path
from typing import List, Optional, Union

import h5py
import numpy as np
class H5Reader:
    """Reader that returns individual sequences from HDF5 files on disk.

    Two on-disk layouts are understood:

    1. A rank-1 tensor holding a corpus of concatenated tokenized documents,
       which is sliced into fixed-length sequences at read time.
    2. A tensor of rank > 1 holding preprocessed samples, where index 0 of
       the stored data selects the sample.
    """

    def __init__(
        self,
        data_dirs: Union[str, List[str]],
        sequence_length: Optional[int] = None,
        read_extra_token: bool = False,
        data_subset: Optional[str] = None,
        sort: bool = True,
        use_vsl: bool = False,
    ):
        """Build a reader over every ``*.h5`` file found in ``data_dirs``.

        Args:
            data_dirs: A directory, or a list of directories, whose ``*.h5``
                files are read. Each file must contain a ``data`` dataset.
            sequence_length: Tokens per sample when reading a rank-1 corpus.
                For data that is already split into samples this may be
                `None`; if given, it must equal dimension 1 of the stored
                data.
            read_extra_token: When reading a corpus, also return the token
                that follows each sequence. Unlike growing `sequence_length`
                by one, that extra token is still the first token of the
                next sequence, which is handy for building shifted
                language-modeling labels. Ignored for data that is already
                split into samples.
            data_subset: Comma separated fractional ranges of the dataset to
                expose, e.g. ``"0.0-0.75"`` keeps only the first 3/4 of the
                samples and leaves the rest untouched. The reported length
                covers only the selected portion, and indexing past it
                raises an exception.
            sort: Whether to sort the discovered file paths. Kept for
                backwards compatibility; it should almost always be `True`
                and will be removed in the future.
            use_vsl: Enable variable sequence length training. For data that
                is already split into samples, dimension 1 must have size 5
                (the inputs plus the `attention_span` and `position_ids`
                features).

        Raises:
            ValueError: If a directory is missing, no ``*.h5`` files are
                found, or the shape of the first file does not fit
                `sequence_length` / `use_vsl`.
        """
        dirs = data_dirs if isinstance(data_dirs, list) else [data_dirs]

        files = []
        for directory in map(Path, dirs):
            if not directory.is_dir():
                raise ValueError(
                    f"The path {directory} does not exist or is not a directory. "
                    f"Please specify a valid directory containing h5 files "
                    f"and ensure that the directory is mounted."
                )
            files.extend(directory.glob("*.h5"))

        if not files:
            raise ValueError(
                f"No *.h5 files found in specified data directories: "
                f"{dirs}."
            )
        if sort:
            files.sort()

        # Only the first file decides which layout is being read; the
        # concrete readers below check the remaining files for consistency.
        first_file = files[0]
        with h5py.File(first_file, "r") as f:
            data_shape = f["data"].shape

        is_preprocessed = len(data_shape) > 1

        if sequence_length is None and not is_preprocessed:
            raise ValueError(
                "If you don't specify `sequence_length`, then the data "
                "being read must be preprocessed by sample, but the data "
                f"written to {first_file} has rank 1"
            )
        if (
            sequence_length is not None
            and is_preprocessed
            and sequence_length != data_shape[1]
        ):
            raise ValueError(
                "If loading data that has been preprocessed into sequences "
                "the sequence length provided must either be None or match "
                "dimension 1 of the data on disk. Got sequence length "
                f"{sequence_length}, but the shape of the data in "
                f"{first_file} is {data_shape}"
            )

        if is_preprocessed:
            if use_vsl and data_shape[1] != 5:
                raise ValueError(
                    f"Expected all dataset H5 files to have 5 features for "
                    f"variable sequence length training, but got "
                    f"{data_shape[1]} features in {first_file}."
                )
            self._impl = _SequencedH5Reader(files, data_subset=data_subset)
        else:
            self._impl = _CorpusH5Reader(
                files,
                sequence_length=sequence_length,
                read_extra_token=read_extra_token,
                data_subset=data_subset,
            )

    @property
    def by_sample(self) -> bool:
        """`True` when reading preprocessed samples, `False` for a corpus."""
        impl = self._impl
        return isinstance(impl, _SequencedH5Reader)

    def __getitem__(self, i: int) -> np.ndarray:
        """Read one sequence of the dataset from disk.

        Args:
            i: Index of the sequence. Sequences are ordered by file (in the
                order the file paths were collected, sorted when `sort` is
                `True`) and then by position within the file.

        Returns:
            An `np.int32` array regardless of the on-disk dtype. For a corpus
            its shape is `(sequence_length + 1,)` when `read_extra_token` is
            set and `(sequence_length,)` otherwise.
        """
        return self._impl[i]

    def __len__(self) -> int:
        """Number of sequences available (after applying `data_subset`)."""
        return len(self._impl)

    @property
    def vdataset(self):
        """The virtual dataset backing this reader.

        Raises:
            AttributeError: If the underlying reader exposes no virtual
                dataset.
        """
        vdataset = getattr(self._impl, "_vdataset", None)
        if vdataset is not None:
            return vdataset
        raise AttributeError(
            "Trying to access virtual dataset attribute, but none was found"
        )
class _SequencedH5Reader:
    """Reads preprocessed samples from HDF5 files stored on disk.

    The ``data`` datasets of all files are concatenated along axis 0 into a
    single virtual dataset, so sample indices run through the files in the
    order they are given.
    """

    def __init__(self, files: List[str], data_subset: Optional[str] = None):
        """Creates an HDF5 reader for preprocessed sequences.

        Args:
            files: HDF5 files to read from. Each must contain a ``data``
                dataset, and all of them must match the first file's dtype
                and its shape beyond axis 0.
            data_subset: A string specifying the subset of the samples to
                consider, e.g. ``"0.0-0.75"``.

        Raises:
            ValueError: If the files disagree on dtype or on the shape beyond
                the first axis.
        """
        vsources: List[h5py.VirtualSource] = []
        for idx, filepath in enumerate(files):
            with h5py.File(filepath, "r") as f:
                dataset = f["data"]
                if idx == 0:
                    data_shape = dataset.shape
                    data_dtype = dataset.dtype
                else:
                    if dataset.dtype != data_dtype:
                        raise ValueError(
                            f"Expected all dataset H5 files to have the same "
                            f"dtype, but got {data_dtype} in {files[0]} and "
                            f"{dataset.dtype} in {filepath}."
                        )
                    if dataset.shape[1:] != data_shape[1:]:
                        raise ValueError(
                            f"Expected all dataset H5 files to have the same "
                            f"shape beyond the first axis, but got "
                            f"{data_shape} in {files[0]} and {dataset.shape} "
                            f"in {filepath}."
                        )
                vsources.append(h5py.VirtualSource(dataset))

        self._vdataset = _VirtualDataset(vsources)
        # Read the length from the statically known shape. Calling `len()`
        # would open the virtual dataset file right here in the constructing
        # process, which `_VirtualDataset` deliberately defers so that the
        # file is not opened before worker processes are forked.
        self._num_sequences = self._vdataset.shape[0]

        self._segmenter = None
        if data_subset is not None:
            self._segmenter = _DatasetSegmenter(
                self._num_sequences, data_subset
            )
            self._num_sequences -= self._segmenter.num_skipped_sequences

    def __getitem__(self, i: int) -> np.ndarray:
        """Reads sample `i` (indexed within the subset) as an `np.int32` array."""
        if self._segmenter is not None:
            i = self._segmenter.map_index(i)
        return self._vdataset[i].astype(np.int32)

    def __len__(self) -> int:
        """Returns total number of sequences in the dataset."""
        return self._num_sequences
class _CorpusH5Reader:
    """Reads fixed-length sequences from a rank-1 HDF5 token corpus.

    The ``data`` datasets of all files are concatenated into one virtual
    token stream, and sequence ``i`` starts at token
    ``i * sequence_length``.
    """

    def __init__(
        self,
        files: List[str],
        sequence_length: Optional[int] = None,
        read_extra_token: bool = False,
        data_subset: Optional[str] = None,
    ):
        """Creates an HDF5 reader for an HDF5 corpus.

        Args:
            files: HDF5 files to read from. Each must contain a rank-1
                ``data`` dataset with the same dtype as the first file.
            sequence_length: The number of tokens per sample. Although it
                defaults to `None` for signature compatibility, a positive
                value is required.
            read_extra_token: Whether to read and return one extra token
                after the end of the sequence. That token is also the first
                token of the following sequence.
            data_subset: A string specifying the subset of the sequences to
                consider, e.g. ``"0.0-0.75"``.

        Raises:
            ValueError: If `sequence_length` is not positive, or a file is
                not rank 1, or the files disagree on dtype.
        """
        if sequence_length is None or sequence_length <= 0:
            raise ValueError(
                f"sequence_length must be a positive integer when reading a "
                f"corpus, but got {sequence_length}."
            )

        vsources: List[h5py.VirtualSource] = []
        for idx, filepath in enumerate(files):
            with h5py.File(filepath, "r") as f:
                dataset = f["data"]
                if len(dataset.shape) != 1:
                    raise ValueError(
                        f"Expected all dataset H5 files in corpus format to "
                        f"have rank 1, but got rank {len(dataset.shape)} in "
                        f"{filepath}."
                    )
                if idx == 0:
                    data_dtype = dataset.dtype
                elif dataset.dtype != data_dtype:
                    raise ValueError(
                        f"Expected all dataset H5 files to have the same "
                        f"dtype, but got {data_dtype} in {files[0]} and "
                        f"{dataset.dtype} in {filepath}."
                    )
                vsources.append(h5py.VirtualSource(dataset))

        self._vdataset = _VirtualDataset(vsources)
        self._msl = sequence_length
        self._num_extra_tokens = 1 if read_extra_token else 0

        # Use the statically known shape instead of `len()`, which would open
        # the virtual dataset file in this process before workers are forked.
        num_tokens = self._vdataset.shape[0]
        # The optional extra token must exist after the last sequence too.
        # Clamp at 0 so a corpus shorter than that yields an empty reader
        # instead of a negative length (which `len()` rejects).
        self._num_sequences = max(
            0, (num_tokens - self._num_extra_tokens) // self._msl
        )

        self._segmenter = None
        if data_subset is not None:
            self._segmenter = _DatasetSegmenter(
                self._num_sequences, data_subset
            )
            self._num_sequences -= self._segmenter.num_skipped_sequences

    def __getitem__(self, i: int) -> np.ndarray:
        """Reads sequence `i` (indexed within the subset) as `np.int32` tokens."""
        if self._segmenter is not None:
            i = self._segmenter.map_index(i)
        start = self._msl * i
        stop = start + self._msl + self._num_extra_tokens
        return self._vdataset[start:stop].astype(np.int32)

    def __len__(self) -> int:
        """Returns total number of sequences in the dataset."""
        return self._num_sequences
class _VirtualDataset:
    """Class that represents a virtual dataset over multiple HDF5 files.

    The virtual layout is written to a named temporary file at construction
    time. That file is only opened for reading on first element access.
    """

    def __init__(self, sources: List[h5py.VirtualSource]):
        """Constructs a virtual dataset from a list of virtual sources.

        Args:
            sources: A non-empty list of virtual sources to concatenate
                along axis 0. It is expected that all virtual sources have
                the same shape (except for the first axis) and dtype.
        """
        length = sum(s.shape[0] for s in sources)
        self._shape = (length, *sources[0].shape[1:])
        self._dtype = sources[0].dtype
        layout = h5py.VirtualLayout(shape=self._shape, dtype=self._dtype)
        start = 0
        for vsource in sources:
            end = start + vsource.shape[0]
            layout[start:end:1, ...] = vsource
            start = end
        # The handle is kept on the instance: NamedTemporaryFile deletes the
        # file when it is closed, so it must outlive all reads.
        self._dataset_tmpfile = tempfile.NamedTemporaryFile(
            "w", prefix="virtual_dataset", suffix=".h5"
        )
        with h5py.File(self._dataset_tmpfile.name, "w", libver="latest") as f:
            f.create_virtual_dataset("data", layout, fillvalue=0)
        self.__dataset_file = None
        self.__dataset = None

    @property
    def _dataset(self) -> h5py.Dataset:
        """Returns the underlying dataset.

        The underlying dataset is lazily loaded from disk the first time it is
        accessed. This is to avoid loading the dataset then forking when using
        multiprocessing. The loaded dataset is then cached to avoid reloading
        the dataset on every access, which has a high overhead. This is safe
        because the dataset is opened in read-only mode and is not expected
        to be modified while this object is alive.
        """
        if self.__dataset is None:
            self.__dataset_file = h5py.File(self._dataset_tmpfile.name, "r")
            self.__dataset = self.__dataset_file["data"]
            # h5py >= 3.4 hits a segfault on exit deep within hdf5 libraries
            # if the dataset isn't freed up before the file is closed and
            # hdf5 atexit handlers run. Clearing the dataset fixes the
            # segfault; the file is closed manually as well.
            # The handler only holds a weak reference, so registering it
            # does not keep this object (and its open file and temp file)
            # alive until interpreter exit.
            self_ref = weakref.ref(self)

            def _close_at_exit():
                obj = self_ref()
                if obj is not None:
                    obj._close()

            atexit.register(_close_at_exit)
        return self.__dataset

    def _close(self):
        """Drops the cached dataset and closes the underlying file, if open."""
        self.__dataset = None
        if self.__dataset_file is not None:
            self.__dataset_file.close()
            self.__dataset_file = None

    def __getitem__(self, i) -> np.ndarray:
        """Returns the `i`th element of the dataset."""
        return self._dataset[i]

    def __len__(self):
        """Returns the length of the dataset.

        This uses the shape recorded at construction time and therefore does
        not open the virtual dataset file.
        """
        return self._shape[0]

    @property
    def shape(self):
        """Shape of the concatenated dataset."""
        return self._shape

    @property
    def dtype(self):
        """Dtype of the concatenated dataset (taken from the first source)."""
        return self._dtype
class _DatasetSegmenter:
def __init__(self, num_sequences: int, data_subset: str):
offsets_full_dataset = []
offsets_skipped_dataset = []
try:
segments = [
(float(seg.split("-")[0]), float(seg.split("-")[1]))
for seg in data_subset.strip().split(",")
]
except Exception as e:
raise RuntimeError(
f"There was a problem parsing data subset {data_subset}. "
"data_subset must be a string of comma separated ranges of "
"floats, for example '0.0-0.2,0.5-0.7'"
) from e
prev_end = 0
segments = [(0, 0)] + segments + [(1, 1)]
n = num_sequences
for start, end in segments:
if start < 0:
raise ValueError(
f"data_subset must contain only non-negative bounds. "
f"Got {data_subset} which contains {start}."
)
if end < start:
raise ValueError(
f"the end of each range in data_subset must be at "
f"least as large as the start of the range, but "
f"start={start} and end={end} are present in provided "
f"data subset {data_subset}"
)
if end > 1:
raise ValueError(
f"data_subset can only contain ranges which are subsets"
f" of the range [0, 1], but found end={end} in "
f"data_subset {data_subset}"
)
if start < prev_end:
raise ValueError(
f"ranges in data_subset must be monotonically "
f"increasing. Got {data_subset}"
)
offsets_full_dataset.append(int(n * end) - int(n * start))
offsets_skipped_dataset.append(int(n * start) - int(n * prev_end))
prev_end = end
self._offsets_skipped_dataset = np.cumsum(offsets_skipped_dataset)
self._offsets_full_dataset = np.cumsum(offsets_full_dataset)
@property
def num_skipped_sequences(self) -> int:
return self._offsets_skipped_dataset[-1]
def map_index(self, i):
if len(self._offsets_full_dataset):
chunk_idx = self._offsets_full_dataset.searchsorted(i, side="right")
i += self._offsets_skipped_dataset[chunk_idx]
return i
class Mixture:
    """
    Mix several map-style datasets according to provided weights.

    Args:
        datasets: a list of objects implementing `__len__`, `__getitem__`
            and a `by_sample` attribute.
        weights: a list of weights associated with each dataset. `weights`
            must have the same length as `datasets` and contain only
            nonnegative values. All weights will be normalized to sum to 1.
            Datasets with weight 0 are dropped.
        interleave: whether or not samples of different datasets should be
            interleaved together. If all the datasets are preprocessed into
            sequences and shuffled before being written to disk, then setting
            this flag will allow you to avoid doing any shuffling at run time
            while still having samples from the different datasets
            intermingled, which may be desirable for enabling sequential disk
            reads. This is implemented in a way that samples within a dataset
            are not shuffled in relation to each other, i.e. sample 0 of
            dataset 0 will always have a smaller index than sample 1 of
            dataset 0.
        seed: the random seed used for interleaving. Ignored if `interleave`
            is `False`.

    1 epoch of a mixture is the number of samples required to see every
    sample of each sub-dataset whose normalized weight exceeds 5% at least
    once (or of every sub-dataset, if none exceeds 5%). Some samples are
    therefore seen multiple times per epoch.
    """

    def __init__(
        self,
        datasets: List["H5Reader"],
        weights: List[float],
        interleave: bool = False,
        seed: int = 0,
    ):
        self.interleave = interleave
        self._by_sample = all(d.by_sample for d in datasets)
        if not self._by_sample and any(d.by_sample for d in datasets):
            raise ValueError(
                "Datasets given to a Mixture must either all read data by "
                "sample or all read data by slicing a corpus, but got datasets "
                "that use a mixture"
            )
        if len(weights) != len(datasets):
            raise ValueError(
                f"weights must have same length as datasets, got {weights}"
            )
        if any(w < 0 for w in weights):
            raise ValueError(f"weights must be nonnegative, got {weights}")
        if all(w == 0 for w in weights):
            raise ValueError(
                f"at least one weight must be greater than 0, got {weights}"
            )

        # Zero-weight datasets never contribute samples, so drop them. The
        # kept datasets and their normalized weights stay index-aligned.
        kept = [(d, w) for d, w in zip(datasets, weights) if w > 0]
        self.datasets = [d for d, _ in kept]
        total_weight = sum(w for _, w in kept)
        weights = [w / total_weight for _, w in kept]

        total_samples = self._epoch_length(self.datasets, weights)
        # Number of samples drawn from each dataset per epoch.
        counts = [int(total_samples * w) for w in weights]

        if self.interleave:
            self.dataset_indices = np.concatenate(
                [np.full(c, i, dtype=np.uint16) for i, c in enumerate(counts)]
            )
            # Datasets that contribute more samples than they hold wrap around.
            self.dataset_samples = np.concatenate(
                [np.arange(c) % len(d) for d, c in zip(self.datasets, counts)]
            )
            self.total_samples = len(self.dataset_indices)
            indices = np.arange(self.total_samples)
            rng = np.random.default_rng(seed)
            rng.shuffle(indices)
            # we want samples within a dataset to appear in order to take
            # advantage of sequential read patterns, so we sort the
            # sub-components after the shuffle: the positions holding values
            # of one dataset's range receive that range in increasing order.
            boundaries = np.insert(np.cumsum(counts), 0, 0)
            for start, end in zip(boundaries[:-1], boundaries[1:]):
                indices[
                    np.where((start <= indices) & (indices < end))
                ] = np.arange(start, end)
            self.dataset_indices = self.dataset_indices[indices]
            self.dataset_samples = self.dataset_samples[indices]
        else:
            cumulative = np.cumsum(counts)
            self.total_samples = int(cumulative[-1])
            # End offsets of every dataset but the last; `__getitem__` uses
            # them to locate the dataset owning a global index.
            self.boundaries = cumulative[:-1]

    @staticmethod
    def _epoch_length(datasets, weights, min_weight=0.05):
        """Returns the (float) number of samples in one mixture epoch.

        Only datasets with normalized weight above `min_weight` set the epoch
        length; if none qualifies (e.g. 20 or more equally weighted
        datasets), every dataset is considered instead.
        """
        candidates = [
            len(d) / w for d, w in zip(datasets, weights) if w > min_weight
        ]
        if not candidates:
            candidates = [len(d) / w for d, w in zip(datasets, weights)]
        return max(candidates)

    @property
    def by_sample(self):
        """`True` if every mixed dataset reads data by sample."""
        return self._by_sample

    def __getitem__(self, i):
        """Returns the `i`th sample of the mixture."""
        if self.interleave:
            dataset = self.datasets[self.dataset_indices[i]]
            return dataset[self.dataset_samples[i]]
        dataset_index = np.searchsorted(self.boundaries, i, side="right")
        dataset = self.datasets[dataset_index]
        offset = self.boundaries[dataset_index - 1] if dataset_index else 0
        # Wrap around when a dataset contributes more samples than it has.
        sample_index = (i - offset) % len(dataset)
        return dataset[sample_index]

    def __len__(self):
        """Number of samples in one epoch of the mixture."""
        return self.total_samples