# Source code for cerebras.modelzoo.data.nlp.bert.BertCSVDataProcessor

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Processor for PyTorch BERT training.
"""
import csv
import random

import numpy as np
import torch

from cerebras.modelzoo.common.input_utils import (
    bucketed_batch,
    get_streaming_batch_size,
)
from cerebras.modelzoo.common.registry import registry
from cerebras.modelzoo.data.common.input_utils import (
    get_data_for_task,
    num_tasks,
    shard_list_interleaved,
    task_id,
)
from cerebras.modelzoo.data.nlp.bert.bert_utils import get_meta_data


@registry.register_datasetprocessor("BertCSVDataProcessor")
class BertCSVDataProcessor(torch.utils.data.IterableDataset):
    """
    Reads csv files containing the input text tokens, and MLM features.

    :param <dict> params: dict containing input parameters for creating dataset.
    Expects the following fields:

    - "data_dir" (string): path to the data files to use.
    - "batch_size" (int): Batch size.
    - "shuffle" (bool): Flag to enable data shuffling (default True).
    - "shuffle_seed" (int): Shuffle seed (default None).
    - "shuffle_buffer" (int): Shuffle buffer size (default 10 * batch_size).
    - "dynamic_mlm_scale" (bool): Flag to dynamically scale the loss
      (default False). When set, each batch gets an extra
      "mlm_loss_scale" entry.
    - "disable_nsp" (bool): If True, "next_sentence_label" and
      "token_type_ids" are not produced (default False).
    - "buckets" (optional): Bucket boundaries forwarded to
      `bucketed_batch`; samples are bucketed by the sum of their
      "attention_mask" (default None, i.e. no bucketing).
    - "num_workers" (int): How many subprocesses to use for data loading
      (default 0).
    - "drop_last" (bool): If True and the dataset size is not divisible
      by the batch size, the last incomplete batch will be dropped
      (default True).
    - "prefetch_factor" (int): Number of samples loaded in advance by each
      worker (default 2).
    - "persistent_workers" (bool): If True, the data loader will not
      shutdown the worker processes after a dataset has been consumed
      once (default False).
    - "mixed_precision" (bool): Casts input mask to fp16 if set to True.
      Otherwise, the generated mask is float32.
      NOTE(review): this key is not read anywhere in this class — confirm
      whether it is consumed elsewhere.
    """

    def __init__(self, params):
        super().__init__()

        # Input params.
        self.meta_data = get_meta_data(params["data_dir"])

        self.meta_data_values = list(self.meta_data.values())
        self.meta_data_filenames = list(self.meta_data.keys())
        # Please note the appending of [0]: with it, entry i of the cumulative
        # sum is the total number of examples in all files before file i.
        self.meta_data_values_cum_sum = np.cumsum([0] + self.meta_data_values)

        self.num_examples = sum(map(int, self.meta_data.values()))
        self.disable_nsp = params.get("disable_nsp", False)
        self.batch_size = get_streaming_batch_size(params["batch_size"])

        self.num_batches = self.num_examples // self.batch_size
        assert (
            self.num_batches > 0
        ), "Dataset does not contain enough samples for one batch. Please choose a smaller batch size"
        self.num_tasks = num_tasks()
        self.task_id = task_id()
        self.num_batch_per_task = self.num_batches // self.num_tasks

        assert (
            self.num_batch_per_task > 0
        ), "Dataset cannot be evenly distributed across the given tasks. Please choose fewer tasks to run with"

        # Always a whole number of batches per task.
        self.num_examples_per_task = self.num_batch_per_task * self.batch_size
        # NOTE(review): entries are presumably (file_path, num_examples,
        # start_id) tuples — that is how `load_buffer` unpacks the per-worker
        # shards derived from this list; confirm against get_data_for_task.
        self.files_in_task = get_data_for_task(
            self.task_id,
            self.meta_data_values_cum_sum,
            self.num_examples_per_task,
            self.meta_data_values,
            self.meta_data_filenames,
        )

        self.shuffle = params.get("shuffle", True)
        self.shuffle_seed = params.get("shuffle_seed", None)
        self.shuffle_buffer = params.get("shuffle_buffer", 10 * self.batch_size)
        self.dynamic_mlm_scale = params.get("dynamic_mlm_scale", False)
        self.buckets = params.get("buckets", None)

        # Multi-processing params.
        self.num_workers = params.get("num_workers", 0)
        self.drop_last = params.get("drop_last", True)
        self.prefetch_factor = params.get("prefetch_factor", 2)
        self.persistent_workers = params.get("persistent_workers", False)

        # Store params.
        self.data_buffer = []
        # Filled in by `_worker_init_fn` with this worker's shard of files.
        self.csv_files_per_task_per_worker = []
        self.processed_buffers = 0

    def load_buffer(self):
        """
        Generator to read the data in chunks of size of `data_buffer`.

        Rows are accumulated into `self.data_buffer`. Whenever it reaches
        `self.shuffle_buffer` rows it is (optionally) shuffled and flushed.
        Leftover rows are flushed once all files of this worker are read.

        Requires `_worker_init_fn` to have run, since that sets `self.rng`
        and `self.csv_files_per_task_per_worker`.

        :returns: Yields the data stored in the `data_buffer`, one
            `csv.DictReader` row (dict) at a time.
        """
        self.processed_buffers = 0
        self.data_buffer = []

        while self.processed_buffers < len(self.csv_files_per_task_per_worker):
            (
                current_file_path,
                num_examples,
                start_id,
            ) = self.csv_files_per_task_per_worker[self.processed_buffers]
            end_id = start_id + num_examples

            with open(current_file_path, "r", newline="") as fin:
                data_reader = csv.DictReader(fin)

                for row_id, row in enumerate(data_reader):
                    if row_id < start_id:
                        continue
                    if row_id >= end_id:
                        # Every remaining row is past this shard's window,
                        # so stop reading the file instead of scanning to EOF.
                        break

                    self.data_buffer.append(row)

                    if len(self.data_buffer) == self.shuffle_buffer:
                        if self.shuffle:
                            self.rng.shuffle(self.data_buffer)
                        yield from self.data_buffer
                        self.data_buffer = []

            self.processed_buffers += 1

        # Flush whatever did not fill a complete shuffle buffer.
        if self.shuffle:
            self.rng.shuffle(self.data_buffer)
        yield from self.data_buffer
        self.data_buffer = []

    def __len__(self):
        """Return the number of batches this task process produces (an
        estimate when bucketing is enabled)."""
        if not self.drop_last:
            return (
                self.num_examples_per_task + self.batch_size - 1
            ) // self.batch_size
        elif self.buckets is None:
            return self.num_examples_per_task // self.batch_size
        else:
            # give an under-estimate in case we don't fully fill some buckets
            length = self.num_examples_per_task // self.batch_size
            # NOTE(review): this subtracts `batch_size * num_buckets` (a sample
            # count) from a batch count; an unfilled bucket should lose at
            # most one batch per stream. Kept as-is to preserve the existing
            # under-estimate — confirm the intended bound.
            length -= self.batch_size * (len(self.buckets) + 1)
            # `len()` raises ValueError for negative values, which the
            # subtraction above can easily produce on small datasets.
            return max(length, 0)

    def get_single_item(self):
        """
        Iterating over the data to construct input features.

        :return: Yields one dict per CSV row, with ``np.int32`` arrays:

            * ``input_ids``: input token indices.
              Shape: (`max_sequence_length`).
            * ``labels``: labels. Shape: (`max_sequence_length`).
            * ``attention_mask``: Shape: (`max_sequence_length`).
            * ``masked_lm_mask``: mask of predicted tokens, read from the
              ``masked_lm_weights`` column. Shape: (`max_predictions`).
              `0` indicates the non masked token, and `1` indicates the
              masked token.
            * ``masked_lm_positions``: read from the
              ``masked_lm_positions`` column.
            * ``next_sentence_label`` (Shape: (1)) and ``token_type_ids``
              (Shape: (`max_sequence_length`)): only present when
              ``disable_nsp`` is False.
        """

        def _to_array(cell):
            # NOTE(security): `eval` executes arbitrary Python contained in
            # the CSV cell. Only use with trusted data files; if the cells
            # are plain list literals, `ast.literal_eval` is a safe
            # drop-in replacement.
            return np.array(eval(cell), dtype=np.int32)

        # Iterate over the data rows to create input features.
        for data_row in self.load_buffer():
            # `data_row` is a dict keyed by the CSV header columns.
            features = {
                "input_ids": _to_array(data_row["input_ids"]),
                # Stored as masked_lm_weights, but really masked_lm_mask
                "masked_lm_mask": _to_array(data_row["masked_lm_weights"]),
                "masked_lm_positions": _to_array(
                    data_row["masked_lm_positions"]
                ),
                "attention_mask": _to_array(data_row["attention_mask"]),
                "labels": _to_array(data_row["labels"]),
            }
            if not self.disable_nsp:
                features["next_sentence_label"] = _to_array(
                    data_row["next_sentence_label"]
                )
                features["token_type_ids"] = _to_array(
                    data_row["token_type_ids"]
                )
            yield features

    def __iter__(self):
        """Yield batches built by `bucketed_batch` from `get_single_item`,
        optionally adding an "mlm_loss_scale" entry."""
        batched_dataset = bucketed_batch(
            self.get_single_item(),
            self.batch_size,
            buckets=self.buckets,
            element_length_fn=lambda feats: np.sum(feats["attention_mask"]),
            drop_last=self.drop_last,
            seed=self.shuffle_seed,
        )
        for batch in batched_dataset:
            if self.dynamic_mlm_scale:
                masked_lm_mask = batch["masked_lm_mask"]
                # With drop_last=False the final batch can be shorter than
                # `self.batch_size`; size the scale to the actual batch.
                current_batch_size = masked_lm_mask.shape[0]
                # Guard against a batch with no masked tokens, which would
                # otherwise produce an infinite scale (and a NaN loss).
                num_masked = torch.sum(masked_lm_mask).clamp(min=1)
                # NOTE(review): the numerator stays the configured batch
                # size even for a short final batch — confirm intended.
                scale = self.batch_size / num_masked
                batch["mlm_loss_scale"] = scale.expand(current_batch_size, 1)
            yield batch

    def _worker_init_fn(self, worker_id):
        """Per-worker setup: seed the RNG and select this worker's files.

        Called by the DataLoader in each worker process. When the loader
        has no workers, `create_dataloader` calls it once directly.
        """
        worker_info = torch.utils.data.get_worker_info()

        if worker_info is not None:
            worker_id = worker_info.id
            num_workers = worker_info.num_workers
        else:
            # Single-process
            worker_id = 0
            num_workers = 1

        self.processed_buffers = 0
        if self.shuffle_seed is not None:
            # Offset the seed so each worker shuffles differently.
            self.shuffle_seed += worker_id + 1
        self.rng = random.Random(self.shuffle_seed)

        # Shard the data across multiple processes.
        self.csv_files_per_task_per_worker = shard_list_interleaved(
            self.files_in_task, worker_id, num_workers
        )
        if self.shuffle:
            self.rng.shuffle(self.csv_files_per_task_per_worker)

    def create_dataloader(self):
        """
        Create a `torch.utils.data.DataLoader` over this dataset.

        Batching is done inside `__iter__`, so the loader uses
        `batch_size=None`. With no workers, `_worker_init_fn` is invoked
        directly in the current process.

        :returns: The configured DataLoader.
        """
        if self.num_workers:
            dataloader = torch.utils.data.DataLoader(
                self,
                batch_size=None,
                num_workers=self.num_workers,
                prefetch_factor=self.prefetch_factor,
                persistent_workers=self.persistent_workers,
                worker_init_fn=self._worker_init_fn,
            )
        else:
            dataloader = torch.utils.data.DataLoader(self, batch_size=None)
            self._worker_init_fn(0)
        return dataloader