Source code for modelzoo.transformers.data_processing.scripts.hdf5_preprocessing.utils

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import gzip
import io
import json
import logging
import os
import re
from dataclasses import asdict, dataclass
from itertools import repeat
from math import ceil
from multiprocessing import Pool
from pathlib import Path

import h5py
import jsonlines
import numpy as np
import yaml
import zstandard
from lm_dataformat import listdir_or_file, tarfile_reader
from tqdm import tqdm

from modelzoo.transformers.data_processing.utils import split_list

logger = logging.getLogger("utils")
logger.setLevel(logging.INFO)


def add_common_args(parser):
    """
    For the argparse to parse arguments for subcommands, we add common
    command line arguments to each subcommand parser here.
    """
    parser.add_argument(
        "--params",
        type=str,
        default=None,
        help="Path to the YAML config file for setting dataset preprocessing hyper-parameters.",
    )
    parser.add_argument(
        "--input_dir",
        type=str,
        help="Directory where raw data is stored.",
    )
    parser.add_argument(
        "--metadata_files",
        type=str,
        default=None,
        help="Path to text file containing a list of file names "
        "corresponding to the raw input documents to be "
        "processed and stored; can handle multiple metadata files "
        "separated by comma.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        help="Directory where HDF5 files will be stored.",
    )
    parser.add_argument(
        "--processes",
        type=int,
        help="Number of processes to use.",
    )
    parser.add_argument(
        "--tokenizer_type",
        type=str,
        choices=["GPT2Tokenizer", "NeoXTokenizer", "HuggingFaceTokenizer"],
        help=(
            "Type of tokenizer to use for HDF5 dataset generation. "
            "Can be one of `GPT2Tokenizer`, `NeoXTokenizer` or `HuggingFaceTokenizer`."
        ),
    )
    parser.add_argument(
        "--huggingface_tokenizer",
        type=str,
        default=None,
        help=(
            "Name/path of the HuggingFace tokenizer. "
            "Only used when tokenizer_type=HuggingFaceTokenizer."
        ),
    )
    parser.add_argument(
        "--vocab_file", type=str, help="Path to the vocabulary file."
    )
    parser.add_argument(
        "--encoder_file", type=str, help="Path to the encoder file."
    )
    parser.add_argument(
        "--eos_id",
        type=int,
        help="Token id of the end of sentence token.",
    )
    parser.add_argument(
        "--pad_id", type=int, help="Token id of the padding token."
    )
    parser.add_argument(
        "--max_seq_length",
        type=int,
        help="Maximum sequence length.",
    )
    parser.add_argument(
        "--short_seq_prob",
        type=float,
        default=0.0,
        help=(
            "Probability of creating sequences which are shorter than the"
            " maximum sequence length."
        ),
    )
    parser.add_argument(
        "--use_ftfy",
        type=str,
        choices=["True", "False"],
        help="Whether to fix text with ftfy. Defaults to `True`.",
    )
    parser.add_argument(
        "--ftfy_normalizer",
        type=str,
        choices=["NFC", None],
        help=(
            "Choose what kind of unicode normalization is applied. Usually, we "
            "apply `NFC` normalization, so that letters followed by combining "
            "characters become single combined characters. Using `None` "
            "applies no normalization while fixing text."
        ),
    )
    parser.add_argument(
        "--wikitext_detokenize",
        type=str,
        choices=["True", "False"],
        help="Whether to use wikitext detokenizer to fix text. Defaults to `False`.",
    )
    parser.add_argument(
        "--output_name",
        type=str,
        default="examples",
        help=(
            "Name of the dataset, i.e. the prefix to use for HDF5 file names. "
            "Defaults to `examples`."
        ),
    )
    parser.add_argument(
        "--files_per_record",
        type=int,
        help="Text files to write per HDF5 file.",
    )
    parser.add_argument(
        "--write_in_batch",
        type=str,
        choices=["True", "False"],
        help="Whether to write the samples in batch for the HDF5 format; "
        "setting to `False` saves memory but is a bit slower. Defaults to "
        "`True`.",
    )
    parser.add_argument(
        "--write_remainder",
        type=str,
        choices=["True", "False"],
        help="Write the remainder files when data is left over from "
        "processing. Defaults to `True`.",
    )
    parser.add_argument(
        "--pack_sequences",
        type=str,
        choices=["True", "False"],
        help="Concatenate a document smaller than maximum sequence length with "
        "other documents, instead of filling it with the padding token. "
        "Defaults to `True`.",
    )
    parser.add_argument(
        "--min_sequence_len",
        type=int,
        default=6,
        help="Sequences shorter than `min_sequence_len` tokens will be skipped.",
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        choices=["True", "False"],
        help="Resume record writing from a given checkpoint. Defaults to `False`.",
    )
    parser.add_argument(
        "--display_pbar",
        type=str,
        choices=["True", "False"],
        help="Display progress while running. Defaults to `False`.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        help="Random seed.",
    )
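

# Illustrative sketch (not part of the original module): builds a parser with a
# single subcommand and parses a hypothetical command line, to show how the
# common flags above are attached to a subparser. The flag values and paths are
# placeholders, not defaults from this repository.
def _example_common_args():
    parser = argparse.ArgumentParser(description="example")
    sub = parser.add_subparsers(dest="mode")
    lm = sub.add_parser("LMData")
    add_common_args(lm)
    return parser.parse_args(
        [
            "LMData",
            "--input_dir", "./raw_data",          # hypothetical path
            "--output_dir", "./hdf5_data",        # hypothetical path
            "--tokenizer_type", "GPT2Tokenizer",
            "--max_seq_length", "2048",
        ]
    )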
def get_parser(desc):
    """Argparser definition for command line arguments from user.

    Returns:
        Argparse namespace object with command line arguments.
    """
    parser = argparse.ArgumentParser(description=desc)
    subparser = parser.add_subparsers(
        description="Sub command for HDF5 conversion.",
        dest="mode",
        required=True,
        help="Sub command to choose saving the raw text into HDF5 files or "
        "pre-processed text converted into token ids at desired maximum "
        "sequence length.",
    )
    lm_parser = subparser.add_parser(
        "LMData", help="Language modeling dataset in `.jsonl` or `.txt` format."
    )
    add_common_args(lm_parser)
    lm_parser.add_argument(
        "--jsonl_key",
        type=str,
        default=None,
        help="The key name in input jsonl files from which the raw text will be "
        "extracted in order to further process it.",
    )
    lm_parser.add_argument(
        "--split_text_to_tokenize",
        type=str,
        choices=["True", "False"],
        help="Whether to split the text into smaller chunks before tokenizing.",
    )
    lm_parser.add_argument(
        "--chunk_len_to_split",
        type=int,
        help="Length of the chunk size to split the text document into.",
    )
    lm_parser.add_argument(
        "--remove_bos_in_chunks",
        type=str,
        choices=["True", "False"],
        help="Whether to ignore bos token id in chunks when splitting the text.",
    )
    summarization_parser = subparser.add_parser(
        "Summarization", help="Fine-tuning dataset in plain text format."
    )
    add_common_args(summarization_parser)
    summarization_parser.add_argument(
        "--sep_token",
        type=str,
        default=None,
        help="Token added between prompt and completion in preprocessed sequences.",
    )
    summarization_parser.add_argument(
        "--prompt_key",
        type=str,
        help="Json key for the prompt.",
    )
    summarization_parser.add_argument(
        "--completion_key",
        type=str,
        help="Json key for the completion.",
    )
    custom_parser = subparser.add_parser(
        "Customize", help="Provide customized dataset processor."
    )
    add_common_args(custom_parser)
    custom_parser.add_argument(
        "--module",
        type=str,
        help="Python file name that contains the custom dataset processor.",
    )
    custom_parser.add_argument(
        "--dataset_processor",
        type=str,
        help="Name of the custom dataset processor.",
    )
    return parser.parse_args()
def update_params(params, args):
    """
    Update config parameters with CLI arguments
    """
    setup_params = [
        "input_dir",
        "metadata_files",
        "output_dir",
        "processes",
        "module",
        "dataset_processor",
    ]
    processing_params = [
        "tokenizer_type",
        "huggingface_tokenizer",
        "vocab_file",
        "encoder_file",
        "eos_id",
        "pad_id",
        "split_text_to_tokenize",
        "chunk_len_to_split",
        "remove_bos_in_chunks",
        "max_seq_length",
        "short_seq_prob",
        "output_name",
        "files_per_record",
        "write_in_batch",
        "write_remainder",
        "resume_from_checkpoint",
        "display_pbar",
        "seed",
    ]
    dataset_params = [
        "use_ftfy",
        "ftfy_normalizer",
        "wikitext_detokenize",
        "jsonl_key",
        "pack_sequences",
        "min_sequence_len",
        "input_ids_dtype",
        "input_mask_dtype",
        "inverted_mask",
        "prompt_key",
        "completion_key",
        "sep_token",
    ]
    processor_map = {
        "lmdata": "LMDataPreprocessor",
        "summarization": "SummarizationPreprocessor",
    }

    mode = args.pop("mode").lower()
    if mode != "customize":
        params["setup"]["dataset_processor"] = processor_map[mode]

    for key, value in args.items():
        if value in ["True", "False"]:
            value = value == "True"
        if value is not None:
            if key in setup_params:
                params["setup"][key] = value
            elif key in processing_params:
                params["processing"][key] = value
            elif key in dataset_params:
                params["dataset"][key] = value
            else:
                raise ValueError(f"Unexpected arguments: {key}")

    set_defaults(params)
def set_defaults(params):
    params["processing"]["eos_id"] = params["processing"].get("eos_id")
    params["processing"]["pad_id"] = params["processing"].get("pad_id")
    params["processing"]["split_text_to_tokenize"] = params["processing"].get(
        "split_text_to_tokenize", False
    )
    params["processing"]["chunk_len_to_split"] = params["processing"].get(
        "chunk_len_to_split", 2000
    )
    params["processing"]["remove_bos_in_chunks"] = params["processing"].get(
        "remove_bos_in_chunks", False
    )
    params["processing"]["write_in_batch"] = params["processing"].get(
        "write_in_batch", True
    )
    params["processing"]["write_remainder"] = params["processing"].get(
        "write_remainder", True
    )
    params["processing"]["resume_from_checkpoint"] = params["processing"].get(
        "resume_from_checkpoint", False
    )
    params["processing"]["display_pbar"] = params["processing"].get(
        "display_pbar", False
    )
    params["dataset"]["use_ftfy"] = params["dataset"].get("use_ftfy", True)
    params["dataset"]["ftfy_normalizer"] = params["dataset"].get(
        "ftfy_normalizer", "NFC"
    )
    params["dataset"]["wikitext_detokenize"] = params["dataset"].get(
        "wikitext_detokenize", False
    )
    params["dataset"]["pack_sequences"] = params["dataset"].get(
        "pack_sequences", True
    )
def get_params(desc):
    """Retrieve configuration parameters.

    Returns:
        params (Dict): Dictionary containing the parameters used to configure
            the data processing.
    """
    args = get_parser(desc)
    args = vars(args)

    params_file = args.pop("params", None)
    if params_file:
        with open(params_file, 'r') as stream:
            params = yaml.safe_load(stream)
    else:
        params = {}

    for section in ["setup", "processing", "dataset"]:
        if not params.get(section, None):
            params[section] = {}

    update_params(params, args)
    return params
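

# Illustrative sketch (not part of the original module): the params dict
# returned by `get_params` always has `setup`, `processing`, and `dataset`
# sections (see `update_params` and `set_defaults` above). A minimal YAML
# config mirroring that layout might look like the following; the paths and
# values are hypothetical placeholders:
#
#   setup:
#       input_dir: ./raw_data
#       output_dir: ./hdf5_data
#       processes: 8
#   processing:
#       tokenizer_type: GPT2Tokenizer
#       max_seq_length: 2048
#   dataset:
#       use_ftfy: True
#       pack_sequences: True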
def dump_args(args, json_params_file):
    """
    Write the input params to file.
    """
    logger.info(f"User arguments can be found at {json_params_file}.")

    # write initial params to file
    with open(json_params_file, "w") as _fout:
        json.dump(args, _fout, indent=4, sort_keys=True)
def dump_result(
    results,
    dataset_stats,
    json_params_file,
    eos_id=None,
    pad_id=None,
    vocab_size=None,
):
    """
    Write outputs of execution
    """
    with open(json_params_file, "r") as _fin:
        data = json.load(_fin)

    post_process = {}
    post_process["discarded_files"] = results["discarded"]
    post_process["processed_files"] = results["processed"]
    post_process["successful_files"] = results["successful"]
    post_process["n_examples"] = results["examples"]
    post_process["raw_chars_count"] = results["raw_chars_count"]
    post_process["raw_bytes_count"] = results["raw_bytes_count"]

    if eos_id is not None:
        post_process["eos_id"] = eos_id
    if pad_id is not None:
        post_process["pad_id"] = pad_id
    if vocab_size is not None:
        post_process["vocab_size"] = vocab_size

    data["post-process"] = post_process
    data["h5_dataset_stats"] = asdict(dataset_stats)

    with open(json_params_file, "w") as _fout:
        json.dump(data, _fout, indent=4, sort_keys=True)
@dataclass
class VerificationArgs:
    processes: int
    files_per_record: int
    max_seq_length: int
    tokenizer_obj: object
    eos_id: int
    pad_id: int
def get_verification_args(processes, data_processor):
    """Get arguments for verifying HDF5 dataset.

    Args:
        processes (int): Number of processes to use for verification.
        data_processor: Class containing methods that specify how the dataset
            will be processed and written into HDF5 files.
    """
    verification_args = VerificationArgs(
        processes,
        data_processor.files_per_record,
        data_processor.max_seq_length,
        data_processor.tokenizer,
        data_processor.eos_id,
        data_processor.pad_id,
    )
    return verification_args
def process_dataset(files, dataset_processor, processes):
    """Process a dataset and write it into HDF5 format.

    Args:
        files (list): List of files to process.
        dataset_processor: Class containing methods that specify how the
            dataset will be processed and written into HDF5 files.
        processes (int): Number of processes to use.

    Returns:
        Dictionary containing results of execution, specifically the number of
        processed, discarded, and successful files, as well as the number of
        examples, aggregated across all processes.
    """
    if processes < 2:
        # Run only a single process, with process number set as 0.
        return dataset_processor.create_dataset((files, 0))

    try:
        n_proc = processes
        n_chunks = ceil(len(files) / n_proc)
        remain = len(files) % n_proc
        if n_chunks == 1 and remain:
            n_proc = remain
            logger.warning(
                f"There aren't enough files to distribute to {processes} "
                f"processes, resetting it to {n_proc}. If you're working with a "
                "small number of compressed archives and could extract them into "
                "txt files, you might be able to get more benefit from the "
                f"available {processes} processes."
            )
        files = split_list(files, n_chunks)
    except ValueError as e:
        # We hit errors in two potential scenarios:
        # 1) Files is an empty list, in which case there is nothing to split.
        # 2) There are more processes than files, in which case we cannot split
        #    the files to processes correctly, as there would be many idle
        #    processes which are not doing anything.
        logger.error(e)
        raise

    with Pool(processes=n_proc) as pool:
        pbar = tqdm(
            pool.imap(
                dataset_processor.create_dataset, zip(files, range(len(files))),
            ),
            total=len(files),
        )
        meta = {
            "discarded": 0,
            "processed": 0,
            "successful": 0,
            "examples": 0,
            "raw_chars_count": 0,
            "raw_bytes_count": 0,
        }
        for results in pbar:
            pbar.update()
            for k, v in results.items():
                meta[k] += v

        return meta
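

# Illustrative sketch (not part of the original module): `process_dataset`
# only assumes that `dataset_processor.create_dataset` accepts a
# `(files, process_no)` tuple and returns a dict with the counter keys that
# `process_dataset` aggregates. `_ToyProcessor` is a hypothetical stand-in used
# to show that contract; it is not one of the real preprocessor classes.
class _ToyProcessor:
    def create_dataset(self, files_and_rank):
        files, process_no = files_and_rank
        return {
            "discarded": 0,
            "processed": len(files),
            "successful": len(files),
            "examples": 0,
            "raw_chars_count": 0,
            "raw_bytes_count": 0,
        }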
@dataclass
class DatasetStats:
    num_sequences: int
    num_tokens: int
    detokenized_bytes: int
    detokenized_chars: int
    non_pad_tokens: int
    loss_valid_tokens: int
def collect_stats(data_arr, args):
    """Collect statistics of the dataset.

    Args:
        data_arr (numpy.ndarray): Numpy array containing the dataset.
        args (VerificationArgs): Arguments for verifying HDF5 dataset.
    """
    num_sequences = data_arr.shape[0]
    num_tokens = data_arr.shape[0] * data_arr.shape[2]
    non_pad_tokens = np.logical_and(
        data_arr[:, 0, :] != args.eos_id, data_arr[:, 0, :] != args.pad_id
    ).sum()
    loss_valid_tokens = data_arr[:, 1, :].sum()
    detokenized_bytes = 0
    detokenized_chars = 0
    for i in range(data_arr.shape[0]):
        line_str = args.tokenizer_obj.decode(data_arr[i, 0, :])
        detokenized_bytes += len(line_str.encode("utf-8"))
        detokenized_chars += len(line_str)

    return DatasetStats(
        num_sequences,
        num_tokens,
        detokenized_bytes,
        detokenized_chars,
        int(non_pad_tokens),  # cast to int to support saving to json
        int(loss_valid_tokens),  # cast to int to support saving to json
    )
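

# Illustrative sketch (not part of the original module): `collect_stats`
# expects data of shape (num_sequences, 3, max_seq_length), where index 0 of
# the middle axis holds input ids, index 1 the input/loss mask, and index 2
# the labels (see `create_features_auto_lm` below). `_FakeTokenizer` and the
# toy values are hypothetical stand-ins for the real tokenizer and data.
def _example_collect_stats():
    class _FakeTokenizer:
        def decode(self, ids):
            return " ".join(str(int(i)) for i in ids)

    data_arr = np.zeros((2, 3, 8), dtype="int32")  # 2 sequences of length 8
    data_arr[:, 0, :4] = 5  # non-pad input ids in the first 4 positions
    data_arr[:, 1, :4] = 1  # loss mask marks the same positions as valid
    args = VerificationArgs(
        processes=1,
        files_per_record=1,
        max_seq_length=8,
        tokenizer_obj=_FakeTokenizer(),
        eos_id=0,
        pad_id=0,
    )
    return collect_stats(data_arr, args)  # DatasetStats(num_sequences=2, ...)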
def verify_saved_hdf5_files(params):
    """
    This function is used to do sanity checks at the end of the creation
    of hdf5 files. It loads every .h5 file generated and checks:
        1. The data type
        2. Shape of the dataset
        3. Fact that labels and inputs are as expected
    """
    h5_files_path, args, vocab_size = params
    h5_stats = DatasetStats(
        0, 0, 0, 0, 0, 0
    )  # stats over list of files in a process
    for h5_file_path in h5_files_path:
        with h5py.File(h5_file_path, mode="r") as h5_file:
            n_examples = h5_file.attrs["n_examples"]
            dataset = h5_file["data"]
            data_arr = dataset[()]
            expected_dtype = "i4"
            assert dataset.dtype == expected_dtype, (
                f"Error in {h5_file}, conversion is corrupted as the "
                f"datatype is unexpected. Expected: {expected_dtype}, "
                f"received {dataset.dtype}."
            )
            data_shape = data_arr.shape
            assert (
                data_shape[1:] == (3, args.max_seq_length)
                or args.max_seq_length == -1
            ), (
                f"Error in {h5_file}, conversion is corrupted as the "
                f"shape of example is unexpected. Expected:"
                f" {(3, args.max_seq_length)}, received {data_shape[1:]}."
            )
            assert (data_arr < vocab_size).all(), (
                f"Error in {h5_file}, conversion is corrupted as the "
                f"input ids are greater than vocab size. "
                f"Please ensure that a correct tokenizer is used "
                f"and the eos_id and pad_id are correct within the "
                f"tokenizer vocabulary size."
            )
            file_stats = collect_stats(data_arr, args)
            assert n_examples == file_stats.num_sequences, (
                f"Error in {h5_file}, conversion is corrupted as the "
                f"number of examples in file is unexpected. Expected:"
                f" {n_examples}, collected {file_stats.num_sequences}."
            )
            assert file_stats.num_tokens == n_examples * args.max_seq_length, (
                f"Error in {h5_file}, conversion is corrupted as the "
                f"number of tokens in file is unexpected. Expected:"
                f" {n_examples * args.max_seq_length}, collected "
                f"{file_stats.num_tokens}."
            )
            h5_stats.num_sequences += file_stats.num_sequences
            h5_stats.num_tokens += file_stats.num_tokens
            h5_stats.detokenized_bytes += file_stats.detokenized_bytes
            h5_stats.detokenized_chars += file_stats.detokenized_chars
            h5_stats.non_pad_tokens += file_stats.non_pad_tokens
            h5_stats.loss_valid_tokens += file_stats.loss_valid_tokens
    return h5_stats
def verify_saved_hdf5_files_mp(files, args, vocab_size):
    """Verify the generated HDF5 dataset.

    Args:
        files (list): List of files to process.
        args (VerificationArgs): Arguments for verifying HDF5 dataset.
        vocab_size (int): Size of the vocabulary from data_processor.
    """
    if args.processes == 1:
        return verify_saved_hdf5_files((files, args, vocab_size))

    try:
        n_proc = args.processes
        n_chunks = ceil(len(files) / n_proc)
        remain = len(files) % n_proc
        if n_chunks == 1 and remain:
            n_proc = remain
            logger.warning(
                f"There aren't enough files to distribute to {args.processes} "
                f"processes, resetting it to {n_proc}."
            )
        files = split_list(files, n_chunks)
    except ValueError as e:
        # We hit errors in one potential scenario:
        # Files is an empty list, in which case there is nothing to split.
        logger.error(e)
        raise

    dataset_stats = DatasetStats(0, 0, 0, 0, 0, 0)
    with Pool(processes=n_proc) as pool:
        pbar = tqdm(desc="Verifying HDF5 files", total=len(files))
        for stats in pool.imap(
            verify_saved_hdf5_files,
            zip(files, repeat(args), repeat(vocab_size)),
        ):
            dataset_stats.num_sequences += stats.num_sequences
            dataset_stats.num_tokens += stats.num_tokens
            dataset_stats.detokenized_bytes += stats.detokenized_bytes
            dataset_stats.detokenized_chars += stats.detokenized_chars
            dataset_stats.non_pad_tokens += stats.non_pad_tokens
            dataset_stats.loss_valid_tokens += stats.loss_valid_tokens
            pbar.update()

    return dataset_stats
def handle_jsonl(
    jsonl_reader, get_meta, autojoin_paragraphs, para_joiner, key=None
):
    for ob in jsonl_reader:
        # naive jsonl where each object is just the string itself, with no
        # meta. For legacy compatibility.
        if isinstance(ob, str):
            assert not get_meta
            yield ob
            continue

        if key is None:
            yield ob
        else:
            text = ob[key]

            if autojoin_paragraphs and isinstance(text, list):
                text = para_joiner.join(text)

            if get_meta:
                yield text, (ob['meta'] if 'meta' in ob else {})
            else:
                yield text
# Slightly modified version of the Reader class from lm_dataformat. # from https://github.com/leogao2/lm_dataformat/blob/master/lm_dataformat/__init__.py
class Reader:
    def __init__(self, in_path):
        self.in_path = in_path

    def stream_data(self, get_meta=False, jsonl_key=None):
        self.f_name = ""
        files = listdir_or_file(self.in_path)
        if not files:
            raise FileNotFoundError(f"No valid file(s) found in {self.in_path}")

        for f in files:
            self.f_name = f
            if f.endswith('.jsonl'):
                yield from self.read_jsonl(f, get_meta, key=jsonl_key)
            elif f.endswith('.jsonl.zst'):
                yield from self.read_jsonl_zst(f, get_meta, key=jsonl_key)
            elif f.endswith('.jsonl.zst.tar'):
                yield from self.read_jsonl_tar(f, get_meta, key=jsonl_key)
            elif f.endswith('.json.zst'):
                assert not get_meta
                yield from self.read_json(f)
            elif f.endswith('.txt'):
                assert not get_meta
                yield from self.read_txt(f)
            elif f.endswith('.json.gz'):
                assert not get_meta
                yield from self.read_jsongz(f)
            else:
                # shouldn't be reached
                print(
                    f'Skipping {f} as streaming for that filetype is not implemented'
                )

    def read_txt(self, file):
        with open(file, 'r') as fh:
            yield fh.read()

    def read_gz(self, file):
        with gzip.open(file, 'rb') as f:
            for line in f:
                yield line.decode('utf-8')

    def read_jsongz(self, file):
        for line in self.read_gz(file):
            yield json.loads(line)

    def read_json(self, file):
        with open(file, 'rb') as fh:
            cctx = zstandard.ZstdDecompressor()
            reader = cctx.stream_reader(fh)
            ob = json.load(reader)
            yield from ob

    def read_jsonl(
        self,
        file,
        get_meta=False,
        autojoin_paragraphs=True,
        para_joiner='\n\n',
        key=None,
    ):
        with jsonlines.open(file) as rdr:
            yield from handle_jsonl(
                rdr, get_meta, autojoin_paragraphs, para_joiner, key
            )

    def read_jsonl_zst(
        self,
        file,
        get_meta=False,
        autojoin_paragraphs=True,
        para_joiner='\n\n',
        key=None,
    ):
        with open(file, 'rb') as fh:
            cctx = zstandard.ZstdDecompressor()
            reader = io.BufferedReader(cctx.stream_reader(fh))
            rdr = jsonlines.Reader(reader)
            yield from handle_jsonl(
                rdr, get_meta, autojoin_paragraphs, para_joiner, key
            )

    def read_jsonl_tar(
        self,
        file,
        get_meta=False,
        autojoin_paragraphs=True,
        para_joiner='\n\n',
        key=None,
    ):
        with open(file, 'rb') as fh:
            for f in tarfile_reader(fh, streaming=True):
                cctx = zstandard.ZstdDecompressor()
                reader = io.BufferedReader(cctx.stream_reader(f))
                rdr = jsonlines.Reader(reader)
                yield from handle_jsonl(
                    rdr, get_meta, autojoin_paragraphs, para_joiner, key
                )
                f.close()
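

# Illustrative sketch (not part of the original module): streaming raw
# documents from a directory of `.jsonl` files where each record stores its
# text under a "text" key. The path and key below are hypothetical
# placeholders.
def _example_stream_jsonl(input_dir="./raw_data", jsonl_key="text"):
    reader = Reader(input_dir)
    for doc in reader.stream_data(get_meta=False, jsonl_key=jsonl_key):
        yield doc  # each `doc` is one raw text document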
def validate_tokens(tokens, min_len=2):
    is_valid = len(tokens) >= min_len
    if not is_valid:
        logger.warning(
            f"token_ids must have at least {min_len} elements, skipping this example..."
        )
    return is_valid
def create_features_auto_lm(
    token_ids,
    max_sequence_length,
    short_seq_prob=0,
    inverted_mask=False,
    pad_id=0,
    min_len=10,
    input_ids_dtype="int32",
    input_mask_dtype="int32",
    labels_dtype="int32",
    rng=None,
):
    """Given a list of token_ids, generate input sequence and labels.

    Args:
        token_ids (sequence): List containing token ids for creating features,
            labels and input mask from.
        max_sequence_length (int): Maximum sequence length for data writes.
        short_seq_prob (float): Probability of generating short sequences from
            data. Defaults to `0`.
        inverted_mask (bool): Invert mask if specified for runtime execution.
            Defaults to `False`.
        pad_id (int): Id for pad token. Defaults to `0`.
        min_len (int): Minimum length of token_ids to be considered a valid
            sequence.
        input_ids_dtype (str): Dtype as string for input ids.
            Defaults to `int32`.
        input_mask_dtype (str): Dtype as string for input mask.
            Defaults to `int32`.
        labels_dtype (str): Dtype as string for labels. Defaults to `int32`.
        rng (random.Random obj): Instance of random object, with states set.
            Defaults to `None`.

    Returns:
        Numpy array of shape (3, max_sequence_length) stacking input ids,
        input mask and labels, or an empty list if the example is skipped.
    """
    if not validate_tokens(token_ids, min_len=min_len):
        return []

    if rng.random() < short_seq_prob:
        token_ids = token_ids[0 : rng.randint(2, max_sequence_length - 1)]

    input_ids = token_ids[:-1]
    labels = token_ids[1:]
    input_mask = [1] * len(input_ids)

    # padding
    num_pad = max_sequence_length - len(input_ids)
    padding = [pad_id] * num_pad

    input_ids.extend(padding)
    labels.extend(padding)
    input_mask.extend([0] * num_pad)

    # assertions to ensure correct output shapes
    assert (
        len(input_ids) == max_sequence_length
        and len(labels) == max_sequence_length
        and len(input_mask) == max_sequence_length
    ), "Wrong sequence length"

    # create feature dict
    features = dict()
    features["input_ids"] = getattr(np, input_ids_dtype)(input_ids)
    features["input_mask"] = getattr(np, input_mask_dtype)(input_mask)

    if inverted_mask:
        features["input_mask"] = np.equal(features["input_mask"], 0).astype(
            features["input_mask"].dtype
        )
    labels = getattr(np, labels_dtype)(labels)

    return np.stack([features["input_ids"], features["input_mask"], labels])
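

# Illustrative sketch (not part of the original module): turning a toy list of
# token ids into the (3, max_sequence_length) feature array. The token ids and
# lengths below are hypothetical.
def _example_create_features_auto_lm():
    import random

    token_ids = list(range(1, 12))  # 11 hypothetical token ids
    features = create_features_auto_lm(
        token_ids,
        max_sequence_length=16,
        pad_id=0,
        rng=random.Random(0),
    )
    # features[0] -> input ids (token_ids[:-1] plus padding)
    # features[1] -> input mask (1 for real tokens, 0 for padding)
    # features[2] -> labels (token_ids[1:] plus padding)
    return features  # shape (3, 16)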
def create_features_summarization(
    prompt_ids,
    completion_ids,
    max_sequence_length,
    eos_id=0,
    sep_id=None,
    pad_id=0,
    min_len=10,
    inverted_mask=False,
    input_ids_dtype="int32",
    input_mask_dtype="int32",
    labels_dtype="int32",
):
    """
    Given a list of prompt_ids and completion_ids, generate input sequence
    and labels.

    Args:
        prompt_ids (sequence): List containing token ids for the prompt to
            create features, labels and input mask from.
        completion_ids (sequence): List containing token ids for the completion
            to create features, labels and input mask from.
        max_sequence_length (int): Maximum sequence length for data writes.
        eos_id (int): Id for end of sequence token. Defaults to `0`.
        sep_id (int): Id for separator token. Defaults to `None`.
        pad_id (int): Id for pad token. Defaults to `0`.
        min_len (int): Minimum length of token_ids to be considered a valid
            sequence.
        inverted_mask (bool): Invert mask if specified for runtime execution.
            Defaults to `False`.
        input_ids_dtype (str): Dtype as string for input ids.
            Defaults to `int32`.
        input_mask_dtype (str): Dtype as string for input mask.
            Defaults to `int32`.
        labels_dtype (str): Dtype as string for labels. Defaults to `int32`.
    """
    # extra <EOS>
    total_len = len(prompt_ids) + len(completion_ids) + 1
    if sep_id is not None:
        total_len += 1
    if total_len > max_sequence_length:
        logger.warning(
            "prompt_ids + completion_ids > max_sequence_length, skipping this example..."
        )
        return []
    if total_len < min_len:
        logger.warning(
            "prompt_ids + completion_ids < min_sequence_len, skipping this example..."
        )
        return []

    token_ids = prompt_ids
    if sep_id is not None:
        token_ids = token_ids + [sep_id]
    token_ids = token_ids + completion_ids + [eos_id]

    token_mask = [0] * (len(prompt_ids))
    if sep_id is not None:
        token_mask += [1]
    else:
        # if no sep_id, prediction starts at the last token of prompt_ids
        token_mask[-1] = 1
    token_mask += [1] * len(completion_ids)
    token_mask += [0]  # EOS

    # add padding
    token_ids_pad = max_sequence_length + 1 - len(token_ids)
    input_mask_pad = max_sequence_length - len(token_mask)

    token_ids.extend([pad_id] * token_ids_pad)
    token_mask.extend([0] * input_mask_pad)

    input_ids = token_ids[:-1]
    labels = token_ids[1:]

    assert (
        len(input_ids) == max_sequence_length
        and len(labels) == max_sequence_length
        and len(token_mask) == max_sequence_length
    ), "Wrong sequence length"

    features = dict()
    features["input_ids"] = getattr(np, input_ids_dtype)(input_ids)
    features["input_mask"] = getattr(np, input_mask_dtype)(token_mask)

    if inverted_mask:
        features["input_mask"] = np.equal(features["input_mask"], 0).astype(
            features["input_mask"].dtype
        )
    labels = getattr(np, labels_dtype)(labels)

    return np.stack([features["input_ids"], features["input_mask"], labels])
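

# Illustrative sketch (not part of the original module): packing a hypothetical
# prompt/completion pair into the (3, max_sequence_length) feature array. The
# token ids below are placeholders, with 2 standing in for a separator token
# and 1 for the end-of-sequence token.
def _example_create_features_summarization():
    prompt_ids = [10, 11, 12, 13, 14]      # hypothetical prompt tokens
    completion_ids = [20, 21, 22, 23, 24]  # hypothetical completion tokens
    return create_features_summarization(
        prompt_ids,
        completion_ids,
        max_sequence_length=16,
        eos_id=1,
        sep_id=2,
        pad_id=0,
    )  # shape (3, 16); mask is 1 only over separator and completion tokens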
def get_files(input_dir=None, filetypes=None, metadata_files=None):
    """Get all files of given filetypes from input directory.

    Args:
        input_dir (str): Input directory to read files from.
        filetypes (list): File types to fetch from the given input
            directory. Defaults to `None`.
        metadata_files (str): Comma separated string of metadata files.

    Returns:
        Flattened list containing all file paths as strings.
    """
    if not filetypes:
        filetypes = [
            '.jsonl',
            '.json.gz',
            '.jsonl.zst',
            '.jsonl.zst.tar',
            '.txt',
        ]
    if isinstance(filetypes, str):
        filetypes = [filetypes]
    filetypes = tuple(filetypes)
    assert input_dir or metadata_files, (
        "User needs to provide `input_dir` or `metadata_files`, "
        "but neither was provided."
    )
    if metadata_files:
        if isinstance(metadata_files, str):
            metadata_files = [metadata_files]

        if input_dir:
            logger.warning(
                "Both `input_dir` and `metadata_files` were provided, "
                "ignoring `input_dir` and using `metadata_files`."
            )

        input_files = []
        for _file in metadata_files:
            with open(_file, "r") as _fin:
                input_files.extend(_fin.readlines())

        input_files_list = [x.strip() for x in input_files if x]
        flattened_list = [x for x in input_files_list if x.endswith(filetypes)]
    else:
        files = [list(Path(input_dir).rglob(f"*{ft}")) for ft in filetypes]
        # flatten list of list -> list and stringify Paths
        flattened_list = [str(item) for sublist in files for item in sublist]
    if not flattened_list:
        raise Exception(
            f"Did not find any files at this path {input_dir}, please "
            f"ensure your files are in format {filetypes}."
        )
    return flattened_list
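

# Illustrative sketch (not part of the original module): collecting input
# files either by walking a hypothetical directory or from a hypothetical
# metadata file that lists one path per line.
def _example_get_files():
    jsonl_files = get_files(input_dir="./raw_data", filetypes=[".jsonl"])
    listed_files = get_files(metadata_files="./metadata.txt")
    return jsonl_files, listed_files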
def wikitext_detokenizer(string):
    """Detokenizer for wikitext. Used for special handling of data for substrings.

    Args:
        string (str): String to detokenize before tokenization.

    Returns:
        Detokenized string
    """
    # contractions
    string = string.replace("s '", "s'")
    string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
    # number separators
    string = string.replace(" @-@ ", "-")
    string = string.replace(" @,@ ", ",")
    string = string.replace(" @.@ ", ".")
    # punctuation
    string = string.replace(" : ", ": ")
    string = string.replace(" ; ", "; ")
    string = string.replace(" . ", ". ")
    string = string.replace(" ! ", "! ")
    string = string.replace(" ? ", "? ")
    string = string.replace(" , ", ", ")
    # double brackets
    string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
    string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
    string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
    string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
    string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
    # miscellaneous
    string = string.replace("= = = =", "====")
    string = string.replace("= = =", "===")
    string = string.replace("= =", "==")
    string = string.replace(" " + chr(176) + " ", chr(176))
    string = string.replace(" \n", "\n")
    string = string.replace("\n ", "\n")
    string = string.replace(" N ", " 1 ")
    string = string.replace(" 's", "'s")

    return string
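

# Illustrative sketch (not part of the original module): a made-up
# wikitext-style line and the kind of cleanup the detokenizer applies to it.
def _example_wikitext_detokenize():
    raw = "the game 's graphics , released in 1996 @-@ 1997 , were praised"
    return wikitext_detokenizer(raw)
    # -> "the game's graphics, released in 1996-1997, were praised"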
def read_checkpoint(checkpoint_path, resume_from_checkpoint=True):
    """Checkpoint reader for execution.

    Args:
        checkpoint_path (str): Path to read checkpoint data from
        resume_from_checkpoint (bool): Resume from checkpoint for execution.
            Defaults to `True`.

    Returns:
        Tuple containing number of files processed and the count of
        tfrecords/HDF5 files written to output directory.
    """
    if resume_from_checkpoint and os.path.isfile(checkpoint_path):
        try:
            resume_files_processed, count = [
                int(i) for i in open(checkpoint_path, "r").read().split(", ")
            ]
            logger.info(
                f"Resuming from file number: {count}, "
                f"with raw file number processed: {resume_files_processed}"
            )
            return resume_files_processed, count
        except Exception as e:
            # if checkpoint path is at initialization,
            # file may exist, but no data might be written in the file
            # in that event, do not do anything, go to the final return
            logger.error(e)
    return 0, 0
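

# Illustrative sketch (not part of the original module): the checkpoint file
# read above is expected to hold two comma-separated integers, e.g. a file
# containing "12, 3" resumes with 12 raw files already processed and 3 output
# files already written. The path below is a hypothetical placeholder.
def _example_read_checkpoint(tmp_path="./checkpoint.txt"):
    with open(tmp_path, "w") as fh:
        fh.write("12, 3")
    return read_checkpoint(tmp_path)  # -> (12, 3)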