# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import gzip
import io
import json
import logging
import os
import re
from dataclasses import asdict, dataclass
from functools import reduce
from itertools import repeat
from math import ceil
from multiprocessing import Pool
from pathlib import Path
import h5py
import jsonlines
import numpy as np
import pyarrow.parquet as pq
import yaml
import zstandard
from lm_dataformat import tarfile_reader
from tqdm import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from modelzoo.transformers.data_processing.tokenizers.BPETokenizer import (
BPETokenizer,
)
from modelzoo.transformers.data_processing.tokenizers.HFTokenizer import (
HFTokenizer,
)
from modelzoo.transformers.data_processing.utils import split_list
logger = logging.getLogger("utils")
logger.setLevel(logging.INFO)
## Added .parquet extension to the list of valid extensions
VALID_EXTENSIONS = [
'.jsonl',
'.jsonl.zst',
'.jsonl.zst.tar',
'.txt',
'.json.gz',
'.parquet',
]
def has_valid_extension(file):
return any([file.endswith(ext) for ext in VALID_EXTENSIONS])
def _listdir_or_file(x):
if isinstance(x, list):
return reduce(lambda x, y: x + y, map(listdir_or_file, sorted(x)))
if os.path.isfile(x):
return [x]
elif os.path.isdir(x):
return [str(Path(x) / fn) for fn in sorted(os.listdir(x))]
else:
raise FileNotFoundError(f"{x} not found")
def listdir_or_file(x):
return list(filter(has_valid_extension, _listdir_or_file(x)))
def add_common_args(parser):
"""
For the argparse to parse arguments for subcommands, we add common
command line arguments to each subcommand parser here.
"""
parser.add_argument(
"--params",
type=str,
default=None,
help="Path to the YAML config file for setting dataset preprocessing hyper-parameters.",
)
parser.add_argument(
"--input_dir", type=str, help="Directory where raw data is stored.",
)
parser.add_argument(
"--metadata_files",
type=str,
default=None,
help="Path to text file containing a list of file names "
"corresponding to the raw input documents to be "
"processed and stored; can handle multiple metadata files "
"separated by comma.",
)
parser.add_argument(
"--output_dir",
type=str,
help="Directory where HDF5 files will be stored.",
)
parser.add_argument(
"--processes", type=int, help="Number of processes to use.",
)
parser.add_argument(
"--tokenizer_type",
type=str,
choices=["GPT2Tokenizer", "NeoXTokenizer", "HuggingFaceTokenizer"],
help=(
"Type of tokenizer to use for HDF5 dataset generation. "
"Can be one of `GPT2Tokenizer`, `NeoXTokenizer` or `HuggingFaceTokenizer`."
),
)
parser.add_argument(
"--huggingface_tokenizer",
type=str,
default=None,
help=(
"Name/Path to the HuggingFace tokenizer."
"Only used when tokenizer_type=HuggingFaceTokenizer"
),
)
parser.add_argument(
"--vocab_file", type=str, help="Path to the vocabulary file."
)
parser.add_argument(
"--encoder_file", type=str, help="Path to the encoder file."
)
parser.add_argument(
"--eos_id", type=int, help="Token id of the end of sentence token",
)
parser.add_argument(
"--pad_id", type=int, help="Token id of the padding token."
)
parser.add_argument(
"--max_seq_length", type=int, help="Maximum sequence length.",
)
parser.add_argument(
"--short_seq_prob",
type=float,
default=0.0,
help=(
"Probability of creating sequences which are shorter than the"
+ " maximum sequence length."
),
)
parser.add_argument(
"--use_ftfy",
type=str,
choices=["True", "False"],
help="Whether to fix text with ftfy. Defaults to `True`.",
)
parser.add_argument(
"--ftfy_normalizer",
type=str,
choices=["NFC", None],
help=(
"Choose what kind of unicode normalization is applied. Usually, we "
"apply `NFC` normalization, so that letters followed by combining "
"characters become single combined characters. Using `None` "
"applies no normalization while fixing text."
),
)
parser.add_argument(
"--wikitext_detokenize",
type=str,
choices=["True", "False"],
help="Whether to use wikitext detokenizer to fix text. Defaults to `False`.",
)
parser.add_argument(
"--output_name",
type=str,
default="examples",
help=(
"Name of the dataset; i.e. prefix to use for HDF5 file names."
+ "Defaults to `examples`."
),
)
parser.add_argument(
"--files_per_record",
type=int,
help="Text files to write per HDF5 file.",
)
parser.add_argument(
"--write_in_batch",
type=str,
choices=["True", "False"],
help="Whether to write the samples in batch for the HDF5 format, "
"setting to false will save memory but a bit slower. Defaults to "
"`True`.",
)
parser.add_argument(
"--write_remainder",
type=str,
choices=["True", "False"],
help="Write the remainder files when data is left over from "
"processing. Defaults to `True`.",
)
parser.add_argument(
"--pack_sequences",
type=str,
choices=["True", "False"],
help="Concatenate a document smaller than maximum sequence length with "
"other documents, instead of filling it with Padding token. Defaults "
"to `True`.",
)
parser.add_argument(
"--min_sequence_len",
type=int,
default=6,
help=(
"sequences shorter than min_sequence_len tokens in length will be skipped"
),
)
parser.add_argument(
"--resume_from_checkpoint",
type=str,
choices=["True", "False"],
help="Resume record writing from a given checkpoint. Defaults to `False`.",
)
parser.add_argument(
"--display_pbar",
type=str,
choices=["True", "False"],
help="Display progress while runs. Defaults to `False`.",
)
parser.add_argument(
"--seed", type=int, help="Random seed.",
)
def add_lm_args(parser):
"""
The language-modeling format is common enough (FIM is very similar)
that we can re-use the arguments for it
"""
parser.add_argument(
"--jsonl_key",
type=str,
default=None,
help="The key name in input jsonl files from which the raw text will be "
"extracted in order to further process it.",
)
parser.add_argument(
"--split_text_to_tokenize",
type=str,
choices=["True", "False"],
help="Whether to split the text into smaller chunks before tokenizing.",
)
parser.add_argument(
"--chunk_len_to_split",
type=int,
help="Length of the chunk size to split the text document into.",
)
parser.add_argument(
"--remove_bos_in_chunks",
type=str,
choices=["True", "False"],
help="Whether to ignore bos token id in chunks when splitting the text.",
)
def get_parser(desc):
"""Argparser definition for command line arguments from user.
Returns:
Argparse namespace object with command line arguments.
"""
parser = argparse.ArgumentParser(description=desc)
subparser = parser.add_subparsers(
description="Sub command for HDF5 conversion.",
dest="mode",
required=True,
help="Sub command to choose saving the raw text into HDF5 files or "
"pre-processed text converted into token ids at desired maximum "
"sequence length.",
)
### LMData ###
lm_parser = subparser.add_parser(
"LMData", help="Language modeling dataset in `.jsonl` or `.txt` format."
)
add_common_args(lm_parser)
add_lm_args(lm_parser)
### Summarization ###
summarization_parser = subparser.add_parser(
"Summarization", help="Fine-tuning dataset in plane text format."
)
add_common_args(summarization_parser)
summarization_parser.add_argument(
"--sep_token",
type=str,
default=None,
help="Token added between prompt and completion in preprocessed sequences.",
)
summarization_parser.add_argument(
"--prompt_key", type=str, help="Json key for the prompt.",
)
summarization_parser.add_argument(
"--completion_key", type=str, help="Json key for the completion.",
)
### FIM ###
fim_parser = subparser.add_parser(
"FIM", help="Pre-processing to allow Fill-in-the-Middle objective"
)
add_common_args(fim_parser)
add_lm_args(fim_parser)
fim_parser.add_argument(
"--fim_rate",
type=float,
default=0.9,
help="Percent of samples to undergo FIM transformation",
)
fim_parser.add_argument(
"--spm_rate",
type=float,
default=0.5,
help="""Percent of FIM samples to go into SPM format (as opposed
to PSM)""",
)
fim_parser.add_argument(
"--fim_prefix_tok",
type=str,
help="Can specify the special token denoting FIM prefix section",
)
fim_parser.add_argument(
"--fim_middle_tok",
type=str,
help="Can specify the special token denoting FIM middle section",
)
fim_parser.add_argument(
"--fim_suffix_tok",
type=str,
help="Can specify the special token denoting FIM suffix section",
)
### LMData (VSL) ###
lm_vsl_parser = subparser.add_parser(
"LMData_VSL",
help="Language modeling dataset for variable sequence length training.",
)
add_common_args(lm_vsl_parser)
add_lm_args(lm_vsl_parser)
lm_vsl_parser.add_argument(
"--fold_long_doc",
type=str,
choices=["True", "False"],
help="Whether to fold long documents into multiple sequences. Defaults to `True`.",
)
lm_vsl_parser.add_argument(
"--position_ids_dtype",
type=str,
help="Dtype for VSL position ids. Defaults to `int32`.",
)
### Summarization (VSL) ###
summarization_vsl_parser = subparser.add_parser(
"Summarization_VSL",
help="Fine-tuning dataset in plane text format for variable sequence length training.",
)
add_common_args(summarization_vsl_parser)
summarization_vsl_parser.add_argument(
"--sep_token",
type=str,
default=None,
help="Token added between prompt and completion in preprocessed sequences.",
)
summarization_vsl_parser.add_argument(
"--prompt_key", type=str, help="Json key for the prompt.",
)
summarization_vsl_parser.add_argument(
"--completion_key", type=str, help="Json key for the completion.",
)
summarization_vsl_parser.add_argument(
"--position_ids_dtype",
type=str,
help="Dtype for VSL position ids. Defaults to `int32`.",
)
### Customize ###
custom_parser = subparser.add_parser(
"Customize", help="Provide customized dataset processor."
)
add_common_args(custom_parser)
custom_parser.add_argument(
"--module",
type=str,
help="Python file name contains the custom dataset processor.",
)
custom_parser.add_argument(
"--dataset_processor",
type=str,
help="Name of the custom dataset processor.",
)
return parser.parse_args()
def update_params(params, args):
"""
Update config parameters with CLI arguments
"""
setup_params = [
"input_dir",
"metadata_files",
"output_dir",
"processes",
"module",
"dataset_processor",
]
processing_params = [
"tokenizer_type",
"huggingface_tokenizer",
"vocab_file",
"encoder_file",
"eos_id",
"pad_id",
"split_text_to_tokenize",
"chunk_len_to_split",
"remove_bos_in_chunks",
"max_seq_length",
"short_seq_prob",
"output_name",
"files_per_record",
"write_in_batch",
"write_remainder",
"resume_from_checkpoint",
"display_pbar",
"seed",
"fim_rate",
"spm_rate",
"fim_prefix_tok",
"fim_middle_tok",
"fim_suffix_tok",
]
dataset_params = [
"use_ftfy",
"ftfy_normalizer",
"wikitext_detokenize",
"jsonl_key",
"pack_sequences",
"min_sequence_len",
"input_ids_dtype",
"input_mask_dtype",
"inverted_mask",
"prompt_key",
"completion_key",
"sep_token",
"fold_long_doc",
"seq_lengths_dtype",
]
processor_map = {
"lmdata": "LMDataPreprocessor",
"summarization": "SummarizationPreprocessor",
"fim": "FIMDataPreprocessor",
"lmdata_vsl": "VSLLMDataPreprocessor",
"summarization_vsl": "VSLSummarizationPreprocessor",
}
mode = args.pop("mode").lower()
if mode != "customize":
params["setup"]["dataset_processor"] = processor_map[mode]
for key, value in args.items():
if value in ["True", "False"]:
value = value == "True"
if value is not None:
if key in setup_params:
params["setup"][key] = value
elif key in processing_params:
params["processing"][key] = value
elif key in dataset_params:
params["dataset"][key] = value
else:
raise ValueError(f"Unexpected arguments: {key}")
set_defaults(params)
def set_defaults(params):
params["processing"]["eos_id"] = params["processing"].get("eos_id")
params["processing"]["pad_id"] = params["processing"].get("pad_id")
params["processing"]["split_text_to_tokenize"] = params["processing"].get(
"split_text_to_tokenize", False
)
params["processing"]["chunk_len_to_split"] = params["processing"].get(
"chunk_len_to_split", 2000
)
params["processing"]["remove_bos_in_chunks"] = params["processing"].get(
"remove_bos_in_chunks", False
)
params["processing"]["write_in_batch"] = params["processing"].get(
"write_in_batch", True
)
params["processing"]["write_remainder"] = params["processing"].get(
"write_remainder", True
)
params["processing"]["resume_from_checkpoint"] = params["processing"].get(
"resume_from_checkpoint", False
)
params["processing"]["display_pbar"] = params["processing"].get(
"display_pbar", False
)
params["dataset"]["use_ftfy"] = params["dataset"].get("use_ftfy", True)
params["dataset"]["ftfy_normalizer"] = params["dataset"].get(
"ftfy_normalizer", "NFC"
)
params["dataset"]["wikitext_detokenize"] = params["dataset"].get(
"wikitext_detokenize", False
)
params["dataset"]["pack_sequences"] = params["dataset"].get(
"pack_sequences", True
)
def get_params(desc):
"""Retrieve configuration parameters
Returns:
        params (Dict): Dictionary containing the parameters used to configure
the data processing.
"""
args = get_parser(desc)
args = vars(args)
params_file = args.pop("params", None)
if params_file:
with open(params_file, 'r') as stream:
params = yaml.safe_load(stream)
else:
params = {}
for section in ["setup", "processing", "dataset"]:
if not params.get(section, None):
params[section] = {}
update_params(params, args)
return params
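# A minimal sketch of the YAML layout consumed by `get_params` (the keys shown
# are a subset and the values are hypothetical); CLI arguments then override
# or fill in these sections via `update_params`:
#
#   setup:
#     input_dir: ./raw_data
#     output_dir: ./hdf5_out
#     processes: 4
#   processing:
#     tokenizer_type: GPT2Tokenizer
#     max_seq_length: 2048
#   dataset:
#     use_ftfy: True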
def dump_args(args, json_params_file):
"""
Write the input params to file.
"""
logger.info(f"User arguments can be found at {json_params_file}.")
# write initial params to file
with open(json_params_file, "w") as _fout:
json.dump(args, _fout, indent=4, sort_keys=True)
def dump_result(
results,
dataset_stats,
json_params_file,
eos_id=None,
pad_id=None,
vocab_size=None,
):
"""
Write outputs of execution
"""
with open(json_params_file, "r") as _fin:
data = json.load(_fin)
post_process = {}
post_process["discarded_files"] = results["discarded"]
post_process["processed_files"] = results["processed"]
post_process["successful_files"] = results["successful"]
post_process["n_examples"] = results["examples"]
post_process["raw_chars_count"] = results["raw_chars_count"]
post_process["raw_bytes_count"] = results["raw_bytes_count"]
if eos_id is not None:
post_process["eos_id"] = eos_id
if pad_id is not None:
post_process["pad_id"] = pad_id
if vocab_size is not None:
post_process["vocab_size"] = vocab_size
data["post-process"] = post_process
data["h5_dataset_stats"] = asdict(dataset_stats)
with open(json_params_file, "w") as _fout:
json.dump(data, _fout, indent=4, sort_keys=True)
@dataclass
class VerificationArgs:
processes: int
files_per_record: int
max_seq_length: int
tokenizer_obj: object
eos_id: int
pad_id: int
use_vsl: bool
def get_verification_args(processes, data_processor):
"""Get arguments for verifying HDF5 dataset.
Args:
        processes (int): Number of processes to use for verification.
data_processor: Class containing methods that specify how the dataset
will be processed and written into HDF5 files.
"""
verification_args = VerificationArgs(
processes,
data_processor.files_per_record,
data_processor.max_seq_length,
data_processor.tokenizer,
data_processor.eos_id,
data_processor.pad_id,
getattr(data_processor, "use_vsl", False),
)
return verification_args
def process_dataset(files, dataset_processor, processes):
"""Process a dataset and write it into HDF5 format.
Args:
files (list): List of files to process.
dataset_processor: Class containing methods that specify how the dataset
will be processed and written into HDF5 files.
processes (int): Number of processes to use.
Returns:
        Dictionary containing the results of execution, specifically the number
        of processed, discarded, and successful files, as well as the number of
        examples from all processes.
"""
if processes < 2:
# Run only single process run, with process number set as 0.
return dataset_processor.create_dataset((files, 0))
try:
n_proc = processes
n_chunks = ceil(len(files) / n_proc)
remain = len(files) % n_proc
if n_chunks == 1 and remain:
n_proc = remain
logger.warning(
f"There aren't enough files to distribute to {processes} "
f"processes, resetting it to {n_proc}. If you're working with a "
"small number of compressed archives and could extract it into "
"txt files, you might be able to get more benefits from the "
f"available {processes} processes."
)
files = split_list(files, n_chunks)
except ValueError as e:
# We hit errors in two potential scenarios,
# 1) Files is an empty list, in which case there is nothing to split
# 2) There are more processes than files, in which case we cannot split
# the files to processes correctly, as there will be many idle
# processes which are not doing anything.
logger.error(e)
raise
with Pool(processes=n_proc) as pool:
pbar = tqdm(
pool.imap(
dataset_processor.create_dataset,
zip(files, range(len(files)),),
),
total=len(files),
)
meta = {
"discarded": 0,
"processed": 0,
"successful": 0,
"examples": 0,
"raw_chars_count": 0,
"raw_bytes_count": 0,
}
for results in pbar:
pbar.update()
for k, v in results.items():
meta[k] += v
return meta
@dataclass
class DatasetStats:
num_sequences: int
num_tokens: int
detokenized_bytes: int
detokenized_chars: int
non_pad_tokens: int
loss_valid_tokens: int
def collect_stats(data_arr, args):
"""Collect statistics of the dataset.
Args:
data_arr (numpy.ndarray): Numpy array containing the dataset.
        args (VerificationArgs): Arguments for verifying HDF5 dataset.
"""
num_sequences = data_arr.shape[0]
num_tokens = data_arr.shape[0] * data_arr.shape[2]
non_pad_tokens = np.logical_and(
data_arr[:, 0, :] != args.eos_id, data_arr[:, 0, :] != args.pad_id
).sum()
loss_valid_tokens = data_arr[:, 1, :].sum()
detokenized_bytes = 0
detokenized_chars = 0
for i in range(data_arr.shape[0]):
line_str = args.tokenizer_obj.decode(data_arr[i, 0, :])
detokenized_bytes += len(line_str.encode("utf-8"))
detokenized_chars += len(line_str)
return DatasetStats(
num_sequences,
num_tokens,
detokenized_bytes,
detokenized_chars,
int(non_pad_tokens), # cast to int to support saving to json
int(loss_valid_tokens), # cast to int to support saving to json
)
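# A small worked example of the statistics above (hypothetical ids, eos_id=2,
# pad_id=0): for a single sequence stored as
#   input_ids = [5, 6, 2, 0], input_mask = [1, 1, 1, 0], labels = [6, 2, 0, 0]
# the stats are num_sequences=1, num_tokens=4, non_pad_tokens=2 (the eos and
# pad ids are excluded), and loss_valid_tokens=3 (sum of the mask row);
# detokenized_bytes/chars depend on the tokenizer's decode output.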
def verify_saved_hdf5_files(params):
"""
This function is used to do sanity checks at the end of the creation
of hdf5 files.
    This function loads every .h5 file generated and checks:
1. The data type
2. Shape of the dataset
    3. That labels and inputs are as expected
"""
h5_files_path, args, vocab_size = params
h5_stats = DatasetStats(
0, 0, 0, 0, 0, 0
) # stats over list of files in a process
for h5_file_path in h5_files_path:
with h5py.File(h5_file_path, mode="r") as h5_file:
n_examples = h5_file.attrs["n_examples"]
dataset = h5_file["data"]
data_arr = dataset[()]
expected_dtype = "i4"
if args.use_vsl:
expected_shape = (5, args.max_seq_length)
else:
expected_shape = (3, args.max_seq_length)
assert dataset.dtype == expected_dtype, (
f"Error in {h5_file}, conversion is corrupted as the "
f"datatype is unexpected. Expected: {expected_dtype}, "
f"received {dataset.dtype}."
)
data_shape = data_arr.shape
assert (
data_shape[1:] == expected_shape or args.max_seq_length == -1
), (
f"Error in {h5_file}, conversion is corrupted as the "
f"shape of example is unexpected. Expected:"
f" {expected_shape}, received {data_shape[1:]}."
)
assert (data_arr < vocab_size).all(), (
f"Error in {h5_file}, conversion is corrupted as the "
f"input ids are greater than vocab size."
f"Please ensure that a correct tokenizer is used "
f"and the eos_id and pad_id are correct within the "
f"tokenizer vocabulary size."
)
file_stats = collect_stats(data_arr, args)
assert n_examples == file_stats.num_sequences, (
f"Error in {h5_file}, conversion is corrupted as the "
f"number of examples in file is unexpected. Expected:"
f" {n_examples}, collected {file_stats.num_sequences}."
)
assert file_stats.num_tokens == n_examples * args.max_seq_length, (
f"Error in {h5_file}, conversion is corrupted as the "
f"number of tokens in file is unexpected. Expected:"
f" {n_examples * args.max_seq_length}, collected "
f"{file_stats.num_tokens}."
)
h5_stats.num_sequences += file_stats.num_sequences
h5_stats.num_tokens += file_stats.num_tokens
h5_stats.detokenized_bytes += file_stats.detokenized_bytes
h5_stats.detokenized_chars += file_stats.detokenized_chars
h5_stats.non_pad_tokens += file_stats.non_pad_tokens
h5_stats.loss_valid_tokens += file_stats.loss_valid_tokens
return h5_stats
def verify_saved_hdf5_files_mp(files, args, vocab_size):
"""Verify the generated HDF5 dataset.
Args:
files (list): List of files to process.
args (VerificationArgs): Arguments for verifying HDF5 dataset.
vocab_size (int): Size of the vocabulary from data_processor.
"""
if args.processes == 1:
return verify_saved_hdf5_files((files, args, vocab_size))
try:
n_proc = args.processes
n_chunks = ceil(len(files) / n_proc)
remain = len(files) % n_proc
if n_chunks == 1 and remain:
n_proc = remain
logger.warning(
f"There aren't enough files to distribute to {args.processes} "
f"processes, resetting it to {n_proc}."
)
files = split_list(files, n_chunks)
except ValueError as e:
## In this case files is an empty list. This happens if no output hdf5 file is created.
## This may happen when the output preprocessed dataset is too small to fit in 1 hdf5 file and write_remainder = False
        logger.info(
            "No output HDF5 files were created. This may happen when the "
            "output preprocessed dataset is too small to fit in one HDF5 file "
            "and `write_remainder` is set to `False`. Change the "
            "`write_remainder` flag to `True` to get output HDF5 files."
        )
return DatasetStats(0, 0, 0, 0, 0, 0)
dataset_stats = DatasetStats(0, 0, 0, 0, 0, 0)
with Pool(processes=n_proc) as pool:
pbar = tqdm(desc="Verifying HDF5 files", total=len(files),)
for stats in pool.imap(
verify_saved_hdf5_files,
zip(files, repeat(args), repeat(vocab_size),),
):
dataset_stats.num_sequences += stats.num_sequences
dataset_stats.num_tokens += stats.num_tokens
dataset_stats.detokenized_bytes += stats.detokenized_bytes
dataset_stats.detokenized_chars += stats.detokenized_chars
dataset_stats.non_pad_tokens += stats.non_pad_tokens
dataset_stats.loss_valid_tokens += stats.loss_valid_tokens
pbar.update()
return dataset_stats
def handle_jsonl(
jsonl_reader, get_meta, autojoin_paragraphs, para_joiner, key=None
):
for ob in jsonl_reader:
# naive jsonl where each object is just the string itself, with no meta. For legacy compatibility.
if isinstance(ob, str):
assert not get_meta
yield ob
continue
        if key is None:
yield ob
else:
text = ob[key]
if autojoin_paragraphs and isinstance(text, list):
text = para_joiner.join(text)
if get_meta:
yield text, (ob['meta'] if 'meta' in ob else {})
else:
yield text
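# Typical records handled by the generator above (values are hypothetical),
# assuming key="text":
#   "a bare string document"            -> yielded as-is (legacy format)
#   {"text": "hello", "meta": {...}}    -> yields "hello", or ("hello", meta) if get_meta
#   {"text": ["para 1", "para 2"]}      -> paragraphs joined with para_joiner before yielding
# When key is None, the whole object is yielded unchanged.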
# Slightly modified version of the Reader class from lm_dataformat.
# from https://github.com/leogao2/lm_dataformat/blob/master/lm_dataformat/__init__.py
class Reader:
    def __init__(self, in_path, tokenizable_columns):
self.in_path = in_path
## required for reading parquet data
self.tokenizable_columns = tokenizable_columns
def stream_data(self, get_meta=False):
self.f_name = ""
files = listdir_or_file(self.in_path)
jsonl_key = self.tokenizable_columns.get('jsonl_key', None)
prompt_key = self.tokenizable_columns.get('prompt_key', None)
completion_key = self.tokenizable_columns.get('completion_key', None)
if not files:
raise FileNotFoundError(f"No valid file(s) found in {self.in_path}")
for f in files:
self.f_name = f
if f.endswith('.jsonl'):
yield from self.read_jsonl(f, get_meta, key=jsonl_key)
elif f.endswith('.jsonl.zst'):
yield from self.read_jsonl_zst(f, get_meta, key=jsonl_key)
elif f.endswith('.jsonl.zst.tar'):
yield from self.read_jsonl_tar(f, get_meta, key=jsonl_key)
elif f.endswith('.json.zst'):
assert not get_meta
yield from self.read_json(f)
elif f.endswith('.txt'):
assert not get_meta
yield from self.read_txt(f)
elif f.endswith('.json.gz'):
assert not get_meta
yield from self.read_jsongz(f)
            elif f.endswith('.parquet'):
yield from self.read_parquet(
f,
jsonl_key=jsonl_key,
prompt_key=prompt_key,
completion_key=completion_key,
)
else:
# shouldn't be reached
print(
f'Skipping {f} as streaming for that filetype is not implemented'
)
def read_txt(self, file):
with open(file, 'r') as fh:
yield fh.read()
def read_gz(self, file):
with gzip.open(file, 'rb') as f:
for line in f:
yield line.decode('utf-8')
def read_jsongz(self, file):
for line in self.read_gz(file):
yield json.loads(line)
def read_json(self, file):
with open(file, 'rb') as fh:
cctx = zstandard.ZstdDecompressor()
reader = cctx.stream_reader(fh)
ob = json.load(reader)
yield from ob
def read_jsonl(
self,
file,
get_meta=False,
autojoin_paragraphs=True,
para_joiner='\n\n',
key=None,
):
with jsonlines.open(file) as rdr:
yield from handle_jsonl(
rdr, get_meta, autojoin_paragraphs, para_joiner, key
)
def read_parquet(
self, file, jsonl_key=None, prompt_key=None, completion_key=None
):
source = pq.ParquetFile(file)
num_row_groups = source.num_row_groups
for idx in range(num_row_groups):
table = source.read_row_group(
idx
) # Read the table outside of the blocks
if jsonl_key:
for cell in table.column(jsonl_key):
yield str(cell.as_py())
elif prompt_key and completion_key:
doc = {}
zipped_columns = zip(
table.column(prompt_key), table.column(completion_key)
)
for prompt, completion in zipped_columns:
yield {
prompt_key: str(prompt.as_py()),
completion_key: str(completion.as_py()),
}
else:
                ## no tokenizable key was provided for this file, so yield an empty doc
yield {}
def read_jsonl_zst(
self,
file,
get_meta=False,
autojoin_paragraphs=True,
para_joiner='\n\n',
key=None,
):
with open(file, 'rb') as fh:
cctx = zstandard.ZstdDecompressor()
reader = io.BufferedReader(cctx.stream_reader(fh))
rdr = jsonlines.Reader(reader)
yield from handle_jsonl(
rdr, get_meta, autojoin_paragraphs, para_joiner, key
)
def read_jsonl_tar(
self,
file,
get_meta=False,
autojoin_paragraphs=True,
para_joiner='\n\n',
key=None,
):
with open(file, 'rb') as fh:
for f in tarfile_reader(fh, streaming=True):
cctx = zstandard.ZstdDecompressor()
reader = io.BufferedReader(cctx.stream_reader(f))
rdr = jsonlines.Reader(reader)
yield from handle_jsonl(
rdr, get_meta, autojoin_paragraphs, para_joiner, key
)
f.close()
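# Example usage of the Reader class (a sketch; the path and keys are
# hypothetical): stream documents from a directory of `.jsonl` files whose
# text lives under the "text" key:
#   >>> reader = Reader("/path/to/raw_data", {"jsonl_key": "text"})
#   >>> for doc in reader.stream_data():
#   ...     pass  # tokenize or otherwise process each document
# For prompt/completion data (e.g. parquet), pass
# {"prompt_key": ..., "completion_key": ...} instead; each yielded item is
# then a dict with those two keys.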
def validate_tokens(tokens, min_len=2):
is_valid = len(tokens) >= min_len
if not is_valid:
logger.warning(
f"token_ids must have at least {min_len} elements, skipping this example..."
)
return is_valid
def create_features_auto_lm(
token_ids,
max_sequence_length,
short_seq_prob=0,
inverted_mask=False,
pad_id=0,
min_len=10,
input_ids_dtype="int32",
input_mask_dtype="int32",
labels_dtype="int32",
rng=None,
):
"""Given a list of token_ids, generate input sequence and labels.
Args:
token_ids (sequence): List containing token ids for creating features,
labels and input mask from.
max_sequence_length (int): Maximum sequence length for data writes.
short_seq_prob (float): Probability of generating short sequences from
data. Defaults to `0`.
inverted_mask (bool): Invert mask if specified for runtime execution.
Defaults to `False`.
min_len (int): Minimum length of token_ids to be considered a valid
sequence.
pad_id (int): Id for pad token. Defaults to `0`.
input_ids_dtype (str): Dtype as string for input ids.
Defaults to `int32`.
input_mask_dtype (str): Dtype as string for input mask.
Defaults to `int32`.
labels_dtype (str): Dtype as string for labels. Defaults to `int32`.
rng (random.Random obj): Instance of random object, with states set.
Defaults to `None`.
Returns:
Tuple containing features and labels
"""
if not validate_tokens(token_ids, min_len=min_len):
return []
if rng.random() < short_seq_prob:
token_ids = token_ids[0 : rng.randint(2, max_sequence_length - 1)]
input_ids = token_ids[:-1]
labels = token_ids[1:]
input_mask = [1] * len(input_ids)
# padding
num_pad = max_sequence_length - len(input_ids)
padding = [pad_id] * num_pad
input_ids.extend(padding)
labels.extend(padding)
input_mask.extend([0] * num_pad)
# assertions to ensure correct output shapes
assert (
len(input_ids) == max_sequence_length
and len(labels) == max_sequence_length
and len(input_mask) == max_sequence_length
), "Wrong sequence length"
# create feature dict
features = dict()
features["input_ids"] = getattr(np, input_ids_dtype)(input_ids)
features["input_mask"] = getattr(np, input_mask_dtype)(input_mask)
if inverted_mask:
features["input_mask"] = np.equal(features["input_mask"], 0).astype(
features["input_mask"].dtype
)
labels = getattr(np, labels_dtype)(labels)
return np.stack([features["input_ids"], features["input_mask"], labels])
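# A minimal worked example (hypothetical token ids; min_len lowered for
# illustration, default dtypes):
#   >>> import random
#   >>> create_features_auto_lm(
#   ...     [5, 6, 7, 8], max_sequence_length=6, pad_id=0, min_len=2,
#   ...     rng=random.Random(0),
#   ... )
#   input_ids  -> [5, 6, 7, 0, 0, 0]
#   input_mask -> [1, 1, 1, 0, 0, 0]
#   labels     -> [6, 7, 8, 0, 0, 0]
# The returned numpy array stacks these three rows into shape (3, 6).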
def create_features_summarization(
prompt_ids,
completion_ids,
max_sequence_length,
eos_id=0,
sep_id=None,
pad_id=0,
min_len=10,
inverted_mask=False,
input_ids_dtype="int32",
input_mask_dtype="int32",
labels_dtype="int32",
):
"""
Given a list of prompt_ids and completion_ids, generate input sequence
and labels.
Args:
        prompt_ids (sequence): List containing token ids for the prompt to
            create features, labels and input mask from.
        completion_ids (sequence): List containing token ids for the completion
            to create features, labels and input mask from.
max_sequence_length (int): Maximum sequence length for data writes.
eos_id (int): Id for end of sequence token. Defaults to `0`.
sep_id (int): Id for separator token. Defaults to `None`.
pad_id (int): Id for pad token. Defaults to `0`.
min_len (int): Minimum length of token_ids to be considered a valid
sequence.
inverted_mask (bool): Invert mask if specified for runtime execution.
Defaults to `False`.
input_ids_dtype (str): Dtype as string for input ids.
Defaults to `int32`.
input_mask_dtype (str): Dtype as string for input mask.
Defaults to `int32`.
labels_dtype (str): Dtype as string for labels. Defaults to `int32`.
"""
# extra <EOS>
total_len = len(prompt_ids) + len(completion_ids) + 1
if sep_id is not None:
total_len += 1
if total_len > max_sequence_length:
logger.warning(
"prompt_ids + completion_ids > max_sequence_length, skipping this example..."
)
return []
if total_len < min_len:
logger.warning(
"prompt_ids + completion_ids < min_sequence_len, skipping this example..."
)
return []
token_ids = prompt_ids
if sep_id is not None:
token_ids = token_ids + [sep_id]
token_ids = token_ids + completion_ids + [eos_id]
token_mask = [0] * (len(prompt_ids))
if sep_id is not None:
token_mask += [1]
else:
# if no sep_id, prediction starts at the last token of prompt_ids
token_mask[-1] = 1
token_mask += [1] * len(completion_ids)
token_mask += [0] # EOS
# add padding
token_ids_pad = max_sequence_length + 1 - len(token_ids)
input_mask_pad = max_sequence_length - len(token_mask)
token_ids.extend([pad_id] * token_ids_pad)
token_mask.extend([0] * input_mask_pad)
input_ids = token_ids[:-1]
labels = token_ids[1:]
assert (
len(input_ids) == max_sequence_length
and len(labels) == max_sequence_length
and len(token_mask) == max_sequence_length
), "Wrong sequence length"
features = dict()
features["input_ids"] = getattr(np, input_ids_dtype)(input_ids)
features["input_mask"] = getattr(np, input_mask_dtype)(token_mask)
if inverted_mask:
features["input_mask"] = np.equal(features["input_mask"], 0).astype(
features["input_mask"].dtype
)
labels = getattr(np, labels_dtype)(labels)
return np.stack([features["input_ids"], features["input_mask"], labels])
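# A minimal worked example (hypothetical token ids; min_len lowered for
# illustration, sep_id left as None):
#   >>> create_features_summarization(
#   ...     prompt_ids=[11, 12], completion_ids=[13, 14],
#   ...     max_sequence_length=6, eos_id=2, pad_id=0, min_len=2,
#   ... )
#   token ids with eos appended: [11, 12, 13, 14, 2]
#   input_ids  -> [11, 12, 13, 14, 2, 0]
#   input_mask -> [0, 1, 1, 1, 0, 0]  (loss starts at the last prompt token
#                                      because sep_id is None)
#   labels     -> [12, 13, 14, 2, 0, 0]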
def create_features_auto_lm_vsl(
bin,
max_sequence_length,
num_pad,
pad_id=0,
inverted_mask=False,
input_ids_dtype="int32",
input_mask_dtype="int32",
labels_dtype="int32",
attention_span_dtype="int32",
position_ids_dtype="int32",
):
"""Given a list of VSL sequences, generate input features and labels.
Args:
bin (list(sequence)): list of VSL sequences.
max_sequence_length (int): Maximum sequence length for data writes.
num_pad (int): number of padding tokens in the sequence.
pad_id (int): Id for pad token. Defaults to `0`.
inverted_mask (bool): Invert mask if specified for runtime execution.
Defaults to `False`.
input_ids_dtype (str): Dtype as string for input ids.
Defaults to `int32`.
input_mask_dtype (str): Dtype as string for input mask.
Defaults to `int32`.
labels_dtype (str): Dtype as string for labels. Defaults to `int32`.
        attention_span_dtype (str): Dtype as string for attention span in VSL.
            Defaults to `int32`.
        position_ids_dtype (str): Dtype as string for position ids in VSL.
            Defaults to `int32`.
Returns:
Tuple containing features and labels
"""
input_ids, labels, attention_span, position_ids = [], [], [], []
for sample in bin:
input_ids.extend(sample[:-1])
labels.extend(sample[1:])
sample_len = len(sample) - 1
attention_span.extend(list(range(sample_len - 1, -1, -1)))
position_ids.extend(list(range(sample_len)))
input_mask = [1] * len(input_ids)
# padding
padding = [pad_id] * num_pad
input_ids.extend(padding)
labels.extend(padding)
padding = [0] * num_pad
input_mask.extend(padding)
attention_span.extend(padding)
position_ids.extend(padding)
# assertions to ensure correct output shapes
assert (
len(input_ids) == max_sequence_length
and len(labels) == max_sequence_length
and len(input_mask) == max_sequence_length
and len(attention_span) == max_sequence_length
and len(position_ids) == max_sequence_length
), "Wrong sequence length"
input_ids = getattr(np, input_ids_dtype)(input_ids)
input_mask = getattr(np, input_mask_dtype)(input_mask)
if inverted_mask:
input_mask = np.equal(input_mask, 0).astype(input_mask.dtype)
labels = getattr(np, labels_dtype)(labels)
attention_span = getattr(np, attention_span_dtype)(attention_span)
position_ids = getattr(np, position_ids_dtype)(position_ids)
return np.stack(
[input_ids, input_mask, labels, attention_span, position_ids]
)
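# A minimal worked example (hypothetical token ids): packing two documents
# [1, 2, 3] and [4, 5, 6, 7] into one sequence with max_sequence_length=8.
# Each document contributes len - 1 input positions, so num_pad=3 here:
#   >>> create_features_auto_lm_vsl(
#   ...     [[1, 2, 3], [4, 5, 6, 7]], max_sequence_length=8, num_pad=3, pad_id=0,
#   ... )
#   input_ids      -> [1, 2, 4, 5, 6, 0, 0, 0]
#   input_mask     -> [1, 1, 1, 1, 1, 0, 0, 0]
#   labels         -> [2, 3, 5, 6, 7, 0, 0, 0]
#   attention_span -> [1, 0, 2, 1, 0, 0, 0, 0]
#   position_ids   -> [0, 1, 0, 1, 2, 0, 0, 0]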
def create_features_summarization_vsl(
bin,
max_sequence_length,
num_pad,
pad_id=0,
eos_id=0,
sep_id=None,
inverted_mask=False,
input_ids_dtype="int32",
input_mask_dtype="int32",
labels_dtype="int32",
attention_span_dtype="int32",
position_ids_dtype="int32",
):
"""Given a list of VSL sequences, generate input features and labels.
Args:
bin (list(sequence)): list of VSL sequences.
max_sequence_length (int): Maximum sequence length for data writes.
num_pad (int): number of padding tokens in the sequence.
pad_id (int): Id for pad token. Defaults to `0`.
eos_id (int): Id for end of sequence token. Defaults to `0`.
sep_id (int): Id for separator token. Defaults to `None`.
inverted_mask (bool): Invert mask if specified for runtime execution.
Defaults to `False`.
input_ids_dtype (str): Dtype as string for input ids.
Defaults to `int32`.
input_mask_dtype (str): Dtype as string for input mask.
Defaults to `int32`.
labels_dtype (str): Dtype as string for labels. Defaults to `int32`.
        attention_span_dtype (str): Dtype as string for attention span in VSL.
Defaults to `int32`.
position_ids_dtype (str): Dtype as string for position ids in VSL.
Defaults to `int32`.
Returns:
Tuple containing features and labels
"""
input_ids, input_mask, labels, attention_span, position_ids = (
[],
[],
[],
[],
[],
)
for prompt_ids, completion_ids in bin:
token_ids = prompt_ids
if sep_id is not None:
token_ids = token_ids + [sep_id]
token_ids = token_ids + completion_ids + [eos_id]
token_mask = [0] * len(prompt_ids)
if sep_id is not None:
token_mask += [1]
else:
# if no sep_id, prediction starts at the last token of prompt_ids
token_mask[-1] = 1
token_mask += [1] * len(completion_ids)
input_ids.extend(token_ids[:-1])
labels.extend(token_ids[1:])
input_mask.extend(token_mask)
sample_len = len(token_ids) - 1
attention_span.extend(list(range(sample_len - 1, -1, -1)))
position_ids.extend(list(range(sample_len)))
# padding
padding = [pad_id] * num_pad
input_ids.extend(padding)
labels.extend(padding)
padding = [0] * num_pad
input_mask.extend(padding)
attention_span.extend(padding)
position_ids.extend(padding)
# assertions to ensure correct output shapes
assert (
len(input_ids) == max_sequence_length
and len(labels) == max_sequence_length
and len(input_mask) == max_sequence_length
and len(attention_span) == max_sequence_length
and len(position_ids) == max_sequence_length
), "Wrong sequence length"
input_ids = getattr(np, input_ids_dtype)(input_ids)
input_mask = getattr(np, input_mask_dtype)(input_mask)
if inverted_mask:
input_mask = np.equal(input_mask, 0).astype(input_mask.dtype)
labels = getattr(np, labels_dtype)(labels)
attention_span = getattr(np, attention_span_dtype)(attention_span)
position_ids = getattr(np, position_ids_dtype)(position_ids)
return np.stack(
[input_ids, input_mask, labels, attention_span, position_ids]
)
def get_files(input_dir=None, filetypes=None, metadata_files=None):
"""Get all files of given filetypes from input directory.
Args:
input_dir (str): Input directory to read files from.
filetypes (list): File types to fetch from the given input
directory. Defaults to `None`.
metadata_files (str): Comma separated string of metadata files.
Returns:
List of lists containing all file paths as strings
"""
if not filetypes:
filetypes = [
'.jsonl',
'.json.gz',
'.jsonl.zst',
'.jsonl.zst.tar',
'.txt',
'.parquet',
]
if isinstance(filetypes, str):
filetypes = [filetypes]
filetypes = tuple(filetypes)
assert input_dir or metadata_files, (
"User need to provide `input_dir` or `metadata_files`, "
"but neither was provided."
)
if metadata_files:
if isinstance(metadata_files, str):
metadata_files = [metadata_files]
if input_dir:
logger.warning(
"Both `input_dir` and `metadata_files` were provided, "
"ignoring `input_dir` and using `metadata_files`."
)
input_files = []
for _file in metadata_files:
with open(_file, "r") as _fin:
input_files.extend(_fin.readlines())
input_files_list = [x.strip() for x in input_files if x]
flattened_list = [x for x in input_files_list if x.endswith(filetypes)]
else:
files = [list(Path(input_dir).rglob(f"*{ft}")) for ft in filetypes]
# flatten list of list -> list and stringify Paths
flattened_list = [str(item) for sublist in files for item in sublist]
if not flattened_list:
raise Exception(
f"Did not find any files at this path {input_dir}, please "
f"ensure your files are in format {filetypes}."
)
return flattened_list
def wikitext_detokenizer(string):
"""Detokenizer for wikitext. Used for special handling of data for substrings.
Args:
        string (str): String to detokenize before tokenization.
Returns:
Detokenized string
"""
# contractions
string = string.replace("s '", "s'")
string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
# number separators
string = string.replace(" @-@ ", "-")
string = string.replace(" @,@ ", ",")
string = string.replace(" @.@ ", ".")
# punctuation
string = string.replace(" : ", ": ")
string = string.replace(" ; ", "; ")
string = string.replace(" . ", ". ")
string = string.replace(" ! ", "! ")
string = string.replace(" ? ", "? ")
string = string.replace(" , ", ", ")
# double brackets
string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
# miscellaneous
string = string.replace("= = = =", "====")
string = string.replace("= = =", "===")
string = string.replace("= =", "==")
string = string.replace(" " + chr(176) + " ", chr(176))
string = string.replace(" \n", "\n")
string = string.replace("\n ", "\n")
string = string.replace(" N ", " 1 ")
string = string.replace(" 's", "'s")
return string
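# Example of the substitutions above on a wikitext-style string:
#   >>> wikitext_detokenizer("the team scored 3 @,@ 000 points ; the game ended")
#   'the team scored 3,000 points; the game ended'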
def read_checkpoint(checkpoint_path, resume_from_checkpoint=True):
"""Checkpoint reader for execution.
Args:
checkpoint_path (str): Path to read checkpoint data from
resume_from_checkpoint (bool): Resume from checkpoint for execution.
Defaults to `True`.
Returns:
Tuple containing number of files processed and the count of tfrecords/HDF5 files
written to output directory.
"""
if resume_from_checkpoint and os.path.isfile(checkpoint_path):
try:
resume_files_processed, count = [
int(i) for i in open(checkpoint_path, "r").read().split(", ")
]
logger.info(
f"Resuming from file number: {count}, "
f"with raw file number processed: {resume_files_processed}"
)
return resume_files_processed, count
except Exception as e:
# if checkpoint path is at initialization,
# file may exist, but no data might be written in the file
# in that event, do not do anything, go to the final return
logger.error(e)
return 0, 0
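# The checkpoint file written during processing is a single line with two
# comma-separated integers, e.g. "120, 3" (values hypothetical), meaning 120
# raw input files have been processed and 3 output files written so far.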
# routine to split the text into smaller sequences
def split_text_and_tokenize(
text, tokenizer, max_tok_len=2000, remove_bos_in_chunks=True
):
"""Function to split the text into smaller sequences of length max_tok_len
and then tokenize each of the smaller sequences. This is done to avoid
performance issues with tokenizers like LlamaTokenizer which are slow for
long sequences.
Args:
text (str): text to be tokenized
tokenizer (Tokenizer): tokenizer to be used
max_tok_len (int, optional): max length of each sequence. Defaults to 2000.
remove_bos_in_chunks (bool, optional): whether to ignore bos token id in
chunks. Defaults to True.
Returns:
tok_ids (list): list of token ids for the text
"""
curr_start = 0
tok_ids = []
while curr_start < len(text):
curr_end = min(text.find(' ', curr_start + max_tok_len), len(text))
if curr_end < 0:
curr_substr = text[curr_start:]
curr_end = len(text)
else:
curr_substr = text[curr_start:curr_end]
if curr_start == 0:
# keep special tokens for the first chunk
bos_token_id = [tokenizer.encode(curr_substr)[0]]
curr_tok_ids = (
tokenizer.encode(curr_substr)[1:]
if remove_bos_in_chunks
else tokenizer.encode(curr_substr)
)
tok_ids.extend(curr_tok_ids)
curr_start = curr_end
    # the per-chunk token ids have been concatenated with `extend` above to form the full sequence
    # NOTE: re-add the bos token id here if it was stripped from the chunks;
    # the eos id is added by the caller of this function
return bos_token_id + tok_ids if remove_bos_in_chunks else tok_ids
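# A short sketch of the chunking above (the tokenizer and its ids are not
# specified here): with max_tok_len=5, the text "hello brave new world" is cut
# at the first space at or after 5 characters past each chunk start, giving the
# chunks "hello", " brave", " new world"; each chunk is tokenized separately,
# any bos id is stripped from chunks when remove_bos_in_chunks=True, and a
# single bos id is re-attached to the front of the concatenated token list.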
def chunk(
sample, tokenizer, fim_rate, spm_rate,
):
"""
Since we do character-level FIM we need to detokenize, determine boundaries
to split, and re-tokenize after splitting. We chunk but do not shuffle and add
special tokens because we might have to truncate or pad the tokens since they
have been split at the character-level and re-tokenized, leading to potentially
different lengths than the original sequence.
If the sub-context is designated to be an AR (auto-regressive) sequence and not FIM, we store
as [[], [], [sequence]] for convenience in the truncate_helper function.
Args:
        sample (np.array): tokenized sample to potentially split into prefix/middle/suffix.
        tokenizer (Tokenizer): tokenizer used to decode and re-encode the sample.
        fim_rate (float): probability that the sample undergoes the FIM transformation.
        spm_rate (float): probability that a FIM'ed sample uses the SPM format.
Returns:
List[List[int]], str: List of token lists corresponding to the
prefix/middle/suffix tokens, or 2 empty lists plus the whole
sequence in case of auto-regressive (AR) sequence. Also returns
string representing the format of the sequence (i.e. SPM or
PSM or AR)
"""
if np.random.binomial(1, fim_rate): # sample bernoulli dist
contents = tokenizer.decode(sample, skip_special_tokens=False)
try:
# A boundary can be =0 (prefix will be empty)
# a boundary can be =len(contents) (suffix will be empty)
# The two boundaries can be equal (middle will be empty)
boundaries = list(
np.random.randint(low=0, high=len(contents) + 1, size=2)
)
boundaries.sort()
except ValueError as e:
logging.info(len(contents))
logging.info(contents)
logging.info(e)
raise e
prefix = contents[: boundaries[0]]
middle = contents[boundaries[0] : boundaries[1]]
suffix = contents[boundaries[1] :]
prefix = tokenizer.encode(prefix)
middle = tokenizer.encode(middle)
suffix = tokenizer.encode(suffix)
is_spm = np.random.binomial(1, spm_rate)
fim_format = "SPM" if is_spm else "PSM"
return [prefix, middle, suffix], fim_format
else:
# don't do FIM preproc
fim_format = "AR"
return [[], [], sample.tolist()], fim_format
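# A sketch of the FIM chunking above (the boundaries are drawn at random, so
# this split is only illustrative): if the sample decodes to "hello world" and
# the two character boundaries come out as (3, 8), then
#   prefix = "hel", middle = "lo wo", suffix = "rld"
# are each re-tokenized and returned together with the chosen format ("SPM" or
# "PSM"); with probability 1 - fim_rate the sample is instead returned as an
# AR sequence, [[], [], sample.tolist()].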
def truncate_helper(samples_lst, diff, sample_idx):
"""
The goal of our truncation scheme is to avoid removing tokens from the
middle section. We first remove from the end of suffix, and then from the
beginning of the prefix. We store the chunks in lists in the original order
so that we can easily perform this truncation. Since each sub-context can have
different amounts of tokens in suffix/prefix, we store unique indices for the
section to remove from. If we run out of tokens to remove from, we switch to the next.
This way we can switch to the prefix of one context while still removing from suffix
of another. If the sub-context is AR (auto-regressive) and not FIM, the AR sequence
is stored as [[], [], [sequence]] so that the remove_idx being 2 will simultaneously
work for the AR and FIM sequences.
Args:
samples_lst (List[List[int]]): List of lists that contain token ids
        diff (int): Number of tokens to remove
sample_idx (int): Index for the sample from the dataset, for use in
logging if we remove from the middle.
Returns:
(List[List[int]]): List of lists of token ids that have been truncated
"""
num_groups = len(samples_lst)
remove_idxs = [2] * num_groups # remove from suffixes first
i = 0
while diff:
remove_idx_i = remove_idxs[i]
sample_i = samples_lst[i]
if sample_i[remove_idx_i]:
pop_idx = (
-1 if remove_idx_i == 2 else 0
) # remove from end of suffix but beginning of prefix
sample_i[remove_idx_i].pop(pop_idx)
diff -= 1
else:
remove_idxs[i] = (
remove_idxs[i] + 1
) % 3 # order of removal is end of suffix, beginning of prefix, then beginning of middle
if remove_idxs[i] == 1:
logging.info(
f"""Context {i} in the {sample_idx}-th data sample has
begun truncating from the middle section, meaning
the prefix and suffix sections have been exhausted.
"""
)
i = (i + 1) % num_groups
return samples_lst
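# A minimal worked example (hypothetical token ids): a single FIM context
#   samples_lst = [[[1], [2, 3], [4, 5, 6]]]   # (prefix, middle, suffix)
# with diff=2 removes two tokens from the end of the suffix first:
#   >>> truncate_helper([[[1], [2, 3], [4, 5, 6]]], diff=2, sample_idx=0)
#   [[[1], [2, 3], [4]]]
# Only after a suffix (and then a prefix) is exhausted does removal reach the middle.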
def pad_helper(samples_lst, diff, fim_pad_tok_id):
"""
Helper for padding. We put all padding tokens into the last sequence.
Args:
samples_lst (List[List[int]]): List of lists that contain token ids
diff (int): Number of tokens to pad
fim_pad_tok_id (int): Id for padding token
Returns:
(List[List[int]]): List of lists of token ids with padding
"""
padding = np.full(np.abs(diff), fim_pad_tok_id)
samples_lst[-1].append(padding)
return samples_lst
def truncate_or_pad_helper(
segments_fim_format_pairs, diff, fim_pad_tok_id, sample_idx
):
"""
Since we perform FIM at character-level, we potentially split characters
in the middle of a word. This can lead to non-standard token sequences,
and after re-tokenizing we might need to truncate or pad to get back to
the original context length. This function ensures that our outputs are
back at their original length.
Args:
segments_fim_format_pairs (List[Tuple[List[List[int]], str]]): This list of tuples is used
to store the prefix/middle/suffix token-id lists and the corresponding FIM formats (PSM/SPM) to
be used downstream in the FIM formatting.
diff (int): The number of tokens to add or remove. Positive means truncate, negative means pad
fim_pad_tok_id (int): Id of padding token
    Returns:
(List[Tuple[List[List[int]], str]]): The element of the tuples will
now be lists that are truncated or padded such that the concatenation of all these tokens, along
with the special tokens, will be equal to the original sequence length.
"""
segments = [pair[0] for pair in segments_fim_format_pairs]
fim_formats = [pair[1] for pair in segments_fim_format_pairs]
if diff >= 0:
segments = truncate_helper(segments, diff, sample_idx)
else:
segments = pad_helper(segments, diff, fim_pad_tok_id)
return [(segments[i], fim_formats[i]) for i in range(len(segments))]
def fim(
sample_array,
sample_idx,
tokenizer,
fim_rate,
spm_rate,
suffix_tok_id,
prefix_tok_id,
middle_tok_id,
fim_pad_tok_id,
eos_tok_id,
opt_bos_tok_id,
):
"""
Takes in an array of input_ids, mask, and labels, and performs the
FIM operation to re-arrange into PSM and SPM format with some probability
Args:
sample_array (np.array): Stack of input_ids, mask, and labels after tokenization. Labels are off-by-one of input_ids
as in standard auto-regressive training
        sample_idx (int): Index of the sample from the dataset, used for logging.
tokenizer (Tokenizer): Tokenizer object
fim_rate (float): Determines what percentage of contexts are FIM'ed
spm_rate (float): Determines what percentage of FIM'ed contexts are in SPM format. 1 - spm_rate determines PSM
suffix_tok_id (int): Id for special token denoting suffix section in a FIM'ed context
prefix_tok_id (int): Id for special token denoting prefix section in a FIM'ed context
middle_tok_id (int): Id for special token denoting middle section in a FIM'ed context
fim_pad_tok_id (int): Id for padding
        eos_tok_id (int): Id for the end-of-sequence token
opt_bos_tok_id (list): Optionally a list containing the bos token id,
otherwise will be empty list. Empty list will be a no-op in the
concatenation. Bos-token will only exist if model's tokenizer adds
bos-token by default.
Returns:
fim_outputs (np.array): Stack of input_ids, mask, and labels after FIM transformation. Mask and labels have been
adjusted to still filter padding tokens and represent the following token, respectively.
"""
assert (
fim_rate <= 1 and fim_rate >= 0
), "FIM rate must be a probability 0 <= rate <= 1"
sample = sample_array[0, :]
mask = sample_array[1, :]
max_seq_len = sample.shape[0]
segment_breaks = np.argwhere(
sample == eos_tok_id
) # split sample by document
segments_fim_format_pairs = []
if segment_breaks.shape != (0, 1): # FIM each sub-context
curr_start_position = 0
for loc in np.nditer(segment_breaks):
# Only permute non-empty segments.
if loc - curr_start_position > 0:
segments, fim_format = chunk(
sample=sample[curr_start_position:loc],
tokenizer=tokenizer,
fim_rate=fim_rate,
spm_rate=spm_rate,
)
segments_fim_format_pairs.append((segments, fim_format))
curr_start_position = loc + 1 # jump over the EOD token
# Permute the segment after the last EOD
segments, fim_format = chunk(
sample=sample[curr_start_position:],
tokenizer=tokenizer,
fim_rate=fim_rate,
spm_rate=spm_rate,
)
segments_fim_format_pairs.append((segments, fim_format))
else: # FIM over full context
segments, fim_format = chunk(
sample=sample,
tokenizer=tokenizer,
fim_rate=fim_rate,
spm_rate=spm_rate,
)
segments_fim_format_pairs.append((segments, fim_format))
def flatten_2d(arr):
return np.concatenate([np.concatenate(subarr) for subarr in arr])
total_len = flatten_2d(
[pair[0] for pair in segments_fim_format_pairs]
).shape[0]
# we factor in the final EOS, which we add before splitting into
# inputs and labels, i.e. sequence[:-1] and sequence[1:], and the
# optional bos token
add_constant = -1
for (_, fmt) in segments_fim_format_pairs:
if fmt == "AR":
add_constant += 1
else:
add_constant += 4
if opt_bos_tok_id:
add_constant += 1
diff = (total_len + add_constant) - max_seq_len
segments_fim_format_pairs = truncate_or_pad_helper(
segments_fim_format_pairs, diff, fim_pad_tok_id, sample_idx,
)
inputs, mask, labels = format_fim(
segments_fim_format_pairs,
max_seq_len,
suffix_tok_id,
prefix_tok_id,
middle_tok_id,
eos_tok_id,
opt_bos_tok_id,
)
try:
assert inputs.shape[0] == max_seq_len
assert mask.shape[0] == max_seq_len
assert labels.shape[0] == max_seq_len
    except AssertionError:
        logging.error(
            "The inputs/masks/labels were not the correct size after the FIM "
            "process. Shapes of each are printed below, along with the "
            "correct maximum sequence length that each sequence should be."
        )
        logging.error(f"inputs shape: {inputs.shape}, expected: {max_seq_len}")
        logging.error(f"mask shape: {mask.shape}, expected: {max_seq_len}")
        logging.error(f"labels shape: {labels.shape}, expected: {max_seq_len}")
        raise AssertionError
try:
assert labels[-1] == eos_tok_id
    except AssertionError:
logging.error("The sequence did not end with an EOS token")
raise AssertionError
# end FIM-specific code
fim_outputs = np.stack([inputs, mask, labels], axis=0)
return fim_outputs
def get_tokenizer_vocab(tokenizer):
if isinstance(tokenizer, BPETokenizer):
tokenizer_vocab = tokenizer.encoder
elif isinstance(tokenizer, HFTokenizer):
tokenizer_vocab = tokenizer.tokenizer.get_vocab()
elif isinstance(tokenizer, PreTrainedTokenizer) or isinstance(
tokenizer, PreTrainedTokenizerFast
):
tokenizer_vocab = tokenizer.vocab
else:
raise NotImplementedError(
"We do not support specified tokenizer\
type."
)
return tokenizer_vocab
def check_fim_special_tokens(params, tokenizer):
# Check that input config lists the FIM special tokens
assert (
"fim_suffix_tok" in params['processing']
and "fim_prefix_tok" in params['processing']
and "fim_middle_tok" in params['processing']
), """Configs for FIM pre-processing must include the special tokens that
denote prefix, middle, and suffix tokens."""
# Check that the provided tokens are in the tokenizer
pre_tok = params['processing'].get("fim_prefix_tok")
mid_tok = params['processing'].get("fim_middle_tok")
suf_tok = params['processing'].get("fim_suffix_tok")
tokenizer_vocab = get_tokenizer_vocab(tokenizer)
assert (
pre_tok in tokenizer_vocab
and mid_tok in tokenizer_vocab
and suf_tok in tokenizer_vocab
), """Please ensure that the provided FIM special tokens are in the
specified tokenizer."""
def handle_bos_token_default(tokenizer):
"""
When performing FIM, we tokenize each chunk again after splitting.
Therefore, if the tokenizer adds bos-token by default, we will get
extra bos-tokens in the middle of the sequence. In this function,
we set the tokenizer bos default to False, and return a flag that
indicates whether we will need to add bos-token in the final
fim formatting function.
"""
if hasattr(tokenizer, "add_bos_token") and tokenizer.add_bos_token:
tokenizer.add_bos_token = False
bos_tok_id = tokenizer.encode(tokenizer.bos_token)[-1]
return True, [bos_tok_id]
return False, []