# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
import json
import logging
import os
import re
import sys
from pathlib import Path
import numpy as np
import yaml
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from cerebras.modelzoo.data_preparation.nlp.tokenizers.BPETokenizer import (
BPETokenizer,
)
from cerebras.modelzoo.data_preparation.nlp.tokenizers.HFTokenizer import (
HFTokenizer,
)
logger = logging.getLogger("utils")
logger.setLevel(logging.INFO)

# Input-file extensions accepted by `has_valid_extension` (and therefore by
# `listdir_or_file`). Includes `.parquet` and `.fasta`.
VALID_EXTENSIONS = [
    ".jsonl",
    ".jsonl.zst",
    ".jsonl.zst.tar",
    ".txt",
    ".json.gz",
    ".parquet",
    ".fasta",
]
# Wording shared verbatim by the vicuna_v0, llava_v0 and llava_v1 styles.
_HUMAN_ASSISTANT_PROMPT = (
    "A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions."
)

# System prompts keyed by chat-template style name.
SYSTEM_PROMPT_REGISTRY = {
    "zephyr": "<|system|>\n</s>",
    "vicuna_v0": _HUMAN_ASSISTANT_PROMPT,
    "vicuna_v1": (
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions."
    ),
    "llava_plain": "",
    "llava_v0": _HUMAN_ASSISTANT_PROMPT,
    "llava_v1": _HUMAN_ASSISTANT_PROMPT,
    "mistral_instruct": "",
}
def has_valid_extension(file):
    """Return True if the path string ``file`` ends with any entry of ``VALID_EXTENSIONS``."""
    # str.endswith accepts a tuple and tests every suffix in a single call.
    return file.endswith(tuple(VALID_EXTENSIONS))
def _listdir_or_file(x):
    """Expand ``x`` into a list of candidate path strings (not extension-filtered).

    Args:
        x: A file path, a directory path, or a list of such paths.

    Returns:
        list[str]:
            - A file path yields ``[x]``.
            - A directory yields its immediate entries, sorted by name. This is
              not recursive, and subdirectory names are included.
            - A list yields, in sorted order, the concatenation of
              ``listdir_or_file`` applied to each entry. Each entry is
              therefore extension-filtered individually.

    Raises:
        FileNotFoundError: if ``x`` (or a list entry) is neither a file nor a
            directory.
    """
    if isinstance(x, list):
        # Flatten with a comprehension instead of functools.reduce. This needs
        # no extra import and also handles an empty list.
        return [path for entry in sorted(x) for path in listdir_or_file(entry)]
    if os.path.isfile(x):
        return [x]
    if os.path.isdir(x):
        return [str(Path(x) / name) for name in sorted(os.listdir(x))]
    raise FileNotFoundError(f"{x} not found")
def listdir_or_file(x):
    """Expand ``x`` (file, directory, or list of either) and keep only paths
    whose extension passes ``has_valid_extension``."""
    return [path for path in _listdir_or_file(x) if has_valid_extension(path)]
def dump_result(
    results,
    json_params_file,
    eos_id=None,
    pad_id=None,
    vocab_size=None,
):
    """
    Record execution statistics under the ``"post-process"`` key of a JSON file.

    The existing JSON in ``json_params_file`` is loaded, and a
    ``"post-process"`` section is built and written back (sorted keys, 4-space
    indent). The section is built as follows:
      - Well-known counters are popped out of ``results`` and renamed, each
        defaulting to 0.
      - Whatever remains in ``results`` is copied through as-is.
      - ``eos_id``/``pad_id``/``vocab_size`` are added when not ``None``.

    Note: ``results`` is mutated (the well-known counters are popped).

    Args:
        results (dict): Counters collected during processing.
        json_params_file (str): Path of an existing JSON params file.
        eos_id (int, optional): EOS token id to record.
        pad_id (int, optional): Pad token id to record.
        vocab_size (int, optional): Tokenizer vocabulary size to record.
    """
    with open(json_params_file, "r") as handle:
        data = json.load(handle)

    # (key in `results`, key written to the JSON)
    renamed_counters = (
        ("discarded", "discarded_files"),
        ("processed", "processed_files"),
        ("successful", "successful_files"),
        ("examples", "n_examples"),
        ("raw_chars_count", "raw_chars_count"),
        ("raw_bytes_count", "raw_bytes_count"),
    )
    post_process = {
        out_key: results.pop(in_key, 0) for in_key, out_key in renamed_counters
    }
    # Remaining key/value pairs are passed through unchanged.
    post_process.update(results.items())

    optional_fields = {"eos_id": eos_id, "pad_id": pad_id, "vocab_size": vocab_size}
    post_process.update(
        (name, value) for name, value in optional_fields.items() if value is not None
    )

    data["post-process"] = post_process
    with open(json_params_file, "w") as handle:
        json.dump(data, handle, indent=4, sort_keys=True)
def dump_args(args, json_params_file):
    """
    Write the input params to file, without the params that are redundant to record.

    ``args`` is deep-copied. From every dict-valued section of the copy, the
    keys listed in ``redundant_params`` are removed, and the filtered copy is
    written as JSON (sorted keys, 4-space indent). ``args`` itself is not
    modified.

    NOTE: this module defines ``dump_args`` a second time further down. That
    later definition is the one bound to the name at import time.

    Args:
        args (dict): Mapping of section name -> dict of params.
        json_params_file (str): Output path; overwritten.
    """
    logger.info(f"User arguments can be found at {json_params_file}.")
    redundant_params = [
        "eos_id",
        "pad_id",
        "display_pbar",
        "files_per_record",
        "output_name",
        "write_remainder",
    ]
    relevant_args = copy.deepcopy(args)
    # Remove the redundant params from every section. Non-dict values are left
    # as-is; `key in` / `del` on them would be a substring test or a TypeError.
    for sub_dict in relevant_args.values():
        if not isinstance(sub_dict, dict):
            continue
        for key in redundant_params:
            sub_dict.pop(key, None)
    # Write the filtered params. Previously the unfiltered `args` was dumped,
    # so the filtering above had no effect.
    with open(json_params_file, "w") as _fout:
        json.dump(relevant_args, _fout, indent=4, sort_keys=True)
def get_parser(desc, argv=None):
    """Argparser definition for command line arguments from user.

    Despite its name, this returns the parsed arguments, not the parser.

    Args:
        desc (str): Description shown in ``--help``.
        argv (list[str], optional): Arguments to parse. When ``None`` (the
            default) argparse reads ``sys.argv[1:]``, matching the previous
            behavior. Passing an explicit list makes the function usable from
            code and tests.

    Returns:
        argparse.Namespace: Parsed command line arguments (``config``).
    """
    parser = argparse.ArgumentParser(description=desc)
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="Path to the YAML config file for setting dataset preprocessing hyper-parameters.",
    )
    return parser.parse_args(argv)
def update_params(params, args):
    """
    Merge CLI arguments into the config, then fill in defaults.

    Processing of each entry of ``args``:
      1. The strings ``"True"``/``"False"`` are converted to booleans.
      2. ``None`` values are skipped.
      3. Every other value is routed into ``params["setup"]``,
         ``params["processing"]`` or ``params["dataset"]`` according to the
         known-key tables below. Tables are checked in that order and the
         first match wins; ``chat_template`` is listed in both processing and
         dataset, so it lands in processing.

    Finally ``set_defaults(params)`` is applied.

    Args:
        params (dict): Config with ``setup``/``processing``/``dataset``
            sub-dicts; updated in place.
        args (dict): Argument name -> value.

    Raises:
        ValueError: if a non-``None`` argument matches no known key.
    """
    routing = (
        (
            "setup",
            frozenset(
                {
                    "data",
                    "metadata_files",
                    "output_dir",
                    "processes",
                    "module",
                    "dataset_processor",
                    "mode",
                }
            ),
        ),
        (
            "processing",
            frozenset(
                {
                    "tokenizer_type",
                    "huggingface_tokenizer",
                    "vocab_file",
                    "encoder_file",
                    "eos_id",
                    "pad_id",
                    "split_text_to_tokenize",
                    "chunk_len_to_split",
                    "remove_bos_in_chunks",
                    "max_seq_length",
                    "short_seq_prob",
                    "output_name",
                    "files_per_record",
                    "write_in_batch",
                    "resume_from_checkpoint",
                    "seed",
                    "fim_rate",
                    "spm_rate",
                    "fim_prefix_tok",
                    "fim_middle_tok",
                    "fim_suffix_tok",
                    "auth_token",
                    "max_chunk_size",
                    "shuffle",
                    "shuffle_seed",
                    "fraction_of_RAM_alloted",
                    "drop_input",
                    "loss_mask_weight",
                    "chat_template",
                    "multimodal_mode",
                }
            ),
        ),
        (
            "dataset",
            frozenset(
                {
                    "use_ftfy",
                    "ftfy_normalizer",
                    "wikitext_detokenize",
                    "jsonl_key",
                    "pack_sequences",
                    "min_sequence_len",
                    "input_ids_dtype",
                    "input_mask_dtype",
                    "inverted_mask",
                    "prompt_key",
                    "completion_key",
                    "sep_token",
                    "fold_long_doc",
                    "seq_lengths_dtype",
                    "chosen_key",
                    "rejected_key",
                    "user_role",
                    "assistant_role",
                    "chat_template",
                    # NOTE(review): spelled "respose"; confirm against the
                    # consumers of this key before renaming it.
                    "respose_delimiter",
                    "prompt_prefix",
                    "completion_prefix",
                    "eos_after_prompt",
                    "multi_turn_key",
                    "multi_turn_content_key",
                    "image_key",
                    "image_token",
                    "multi_modal_non_image_ex_key",
                    "image_dir",
                    "num_patches",
                    "system_prompt_style",
                }
            ),
        ),
    )
    for name, raw_value in args.items():
        value = (raw_value == "True") if raw_value in ("True", "False") else raw_value
        if value is None:
            continue
        for section, known_keys in routing:
            if name in known_keys:
                params[section][name] = value
                break
        else:
            raise ValueError(f"Unexpected arguments: {name}")
    set_defaults(params)
def set_defaults(params):
    """Fill missing ``processing``/``dataset`` options with their default values.

    Existing keys, including those explicitly set to ``None``, are left
    untouched. Missing keys are added with the defaults below. ``params`` is
    updated in place.
    """
    defaults_by_section = {
        "processing": (
            ("eos_id", None),
            ("pad_id", None),
            ("split_text_to_tokenize", False),
            ("chunk_len_to_split", 2000),
            ("remove_bos_in_chunks", False),
            ("write_in_batch", True),
            ("resume_from_checkpoint", False),
            ("auth_token", None),
        ),
        "dataset": (
            ("use_ftfy", True),
            ("ftfy_normalizer", "NFC"),
            ("wikitext_detokenize", False),
        ),
    }
    for section_name, defaults in defaults_by_section.items():
        section = params[section_name]
        for key, default in defaults:
            section.setdefault(key, default)
def get_params(desc):
    """Retrieve configuration parameters.

    Steps:
      1. Parse the command line via ``get_parser``.
      2. Load the YAML file named by ``--config``, if any.
      3. Ensure the ``setup``, ``processing`` and ``dataset`` sections exist
         (missing or empty sections become ``{}``).
      4. Merge the remaining CLI args via ``update_params``.

    Args:
        desc (str): Description for the argument parser.

    Returns:
        params (Dict): Dictionary contains the parameters used to configure
            the data processing.

    Raises:
        ValueError: if the config file's top level is not a mapping.
    """
    args = vars(get_parser(desc))
    params_file = args.pop("config", None)
    params = {}
    if params_file:
        with open(params_file, "r") as stream:
            loaded = yaml.safe_load(stream)
        # An empty YAML document loads as None; treat it as an empty config
        # instead of crashing on `.get` below.
        params = {} if loaded is None else loaded
        if not isinstance(params, dict):
            raise ValueError(
                f"Expected a mapping at the top level of {params_file}, "
                f"got {type(params).__name__}."
            )
    for section in ("setup", "processing", "dataset"):
        if not params.get(section):
            params[section] = {}
    update_params(params, args)
    return params
def dump_args(args, json_params_file):
    """
    Write ``args`` verbatim to ``json_params_file`` as JSON.

    Output uses sorted keys and a 4-space indent; an existing file is
    overwritten.

    NOTE: an earlier ``dump_args`` in this module (which also strips some
    redundant params) is overridden by this later definition at import time.
    """
    with open(json_params_file, "w") as handle:
        json.dump(args, handle, sort_keys=True, indent=4)
def validate_tokens(tokens, min_len=2):
    """Return True if ``tokens`` has at least ``min_len`` elements.

    When it does not, a warning is logged and False is returned; callers then
    skip the example.
    """
    if len(tokens) >= min_len:
        return True
    logger.warning(
        f"token_ids must have at least {min_len} elements, skipping this example..."
    )
    return False
def create_features_auto_lm_vsl(
    bin,
    max_sequence_length,
    num_pad,
    pad_id=0,
    inverted_mask=False,
    input_ids_dtype="int32",
    input_mask_dtype="int32",
    labels_dtype="int32",
    attention_span_dtype="int32",
    position_ids_dtype="int32",
):
    """Given a list of VSL sequences, generate input features and labels.

    All sequences in ``bin`` are concatenated into one token stream:
      - inputs are the stream without its last token;
      - labels are the stream shifted left by one.

    Per sequence:
      - the attention span counts down to 0 and position ids restart at 0;
      - every sequence except the last has its final (separator) position
        masked with 0;
      - the last sequence contributes one fewer position because of the
        input/label shift.

    All rows are right-padded to ``max_sequence_length``: ``pad_id`` for
    input ids and labels, 0 for the others.

    Args:
        bin (list(sequence)): list of VSL sequences.
        max_sequence_length (int): Maximum sequence length for data writes.
        num_pad (int): Kept for API compatibility. The padding amount is
            recomputed from ``max_sequence_length`` and the packed length, so
            this value is not used.
        pad_id (int): Id for pad token. Defaults to `0`.
        inverted_mask (bool): Invert mask if specified for runtime execution.
            Defaults to `False`.
        input_ids_dtype (str): Dtype as string for input ids.
            Defaults to `int32`.
        input_mask_dtype (str): Dtype as string for input mask.
            Defaults to `int32`.
        labels_dtype (str): Dtype as string for labels. Defaults to `int32`.
        attention_span_dtype (str): Dtype as string for keys attention span in VSL.
            Defaults to `int32`.
        position_ids_dtype (str): Dtype as string for position ids and
            attention span in VSL. Defaults to `int32`.

    Returns:
        np.ndarray of shape ``(5, max_sequence_length)`` with rows
        ``[input_ids, input_mask, labels, attention_span, position_ids]``.

    Raises:
        AssertionError: if any row does not end up with length
            ``max_sequence_length`` (e.g. the packed inputs are too long).
    """
    token_stream = []
    input_mask, attention_span, position_ids = [], [], []
    last_index = len(bin) - 1
    for index, sequence in enumerate(bin):
        seq_len = len(sequence)
        token_stream.extend(sequence)
        # The final sequence loses one position to the input/label shift below.
        visible_len = seq_len - 1 if index == last_index else seq_len
        attention_span.extend(range(visible_len - 1, -1, -1))
        position_ids.extend(range(visible_len))
        input_mask.extend([1] * (seq_len - 1))
        if index != last_index:
            input_mask.append(0)  # the separator (eos) position is masked out

    input_ids = token_stream[:-1]
    labels = token_stream[1:]

    pad_len = max_sequence_length - len(input_ids)
    input_ids.extend([pad_id] * pad_len)
    labels.extend([pad_id] * pad_len)
    zero_pad = [0] * pad_len
    for row in (input_mask, attention_span, position_ids):
        row.extend(zero_pad)

    # assertions to ensure correct output shapes
    rows = (input_ids, labels, input_mask, attention_span, position_ids)
    assert all(len(row) == max_sequence_length for row in rows), "Wrong sequence length"

    ids_arr = getattr(np, input_ids_dtype)(input_ids)
    mask_arr = getattr(np, input_mask_dtype)(input_mask)
    if inverted_mask:
        mask_arr = np.equal(mask_arr, 0).astype(mask_arr.dtype)
    labels_arr = getattr(np, labels_dtype)(labels)
    span_arr = getattr(np, attention_span_dtype)(attention_span)
    position_arr = getattr(np, position_ids_dtype)(position_ids)
    return np.stack([ids_arr, mask_arr, labels_arr, span_arr, position_arr])
def get_files(input_dir=None, filetypes=None, metadata_files=None):
    """Get all files of given filetypes from an input directory or metadata files.

    Args:
        input_dir (str): Directory searched recursively for matching files.
            Ignored (with a warning) when ``metadata_files`` is also given.
        filetypes (list | str): File extensions to keep. Defaults to the
            built-in list of supported extensions when empty or ``None``.
        metadata_files (str | list[str]): Path, or list of paths, of text
            files listing one input file path per line.

    Returns:
        list[str]: Flat list of matching file paths. In directory mode,
        results are grouped by filetype and de-duplicated.

    Raises:
        AssertionError: if neither ``input_dir`` nor ``metadata_files`` is given.
        Exception: if no matching files are found.
    """
    if not filetypes:
        filetypes = [
            '.jsonl',
            '.json.gz',
            '.jsonl.zst',
            '.jsonl.zst.tar',
            '.txt',
            '.parquet',
            '.fasta',
        ]
    if isinstance(filetypes, str):
        filetypes = [filetypes]
    # str.endswith needs a tuple to test several suffixes at once.
    filetypes = tuple(filetypes)
    assert input_dir or metadata_files, (
        "User need to provide `input_dir` or `metadata_files`, "
        "but neither was provided."
    )
    if metadata_files:
        if isinstance(metadata_files, str):
            metadata_files = [metadata_files]
        if input_dir:
            logger.warning(
                "Both `input_dir` and `metadata_files` were provided, "
                "ignoring `input_dir` and using `metadata_files`."
            )
        input_files = []
        for _file in metadata_files:
            with open(_file, "r") as _fin:
                input_files.extend(_fin.readlines())
        stripped = (line.strip() for line in input_files)
        # Skip blank lines explicitly; keep only listed paths with a wanted
        # extension. Duplicates are kept, since listing a file twice may be
        # intentional.
        flattened_list = [
            path for path in stripped if path and path.endswith(filetypes)
        ]
    else:
        matches_per_type = [Path(input_dir).rglob(f"*{ft}") for ft in filetypes]
        # Flatten and stringify the Paths. Overlapping filetypes (e.g. ".gz"
        # and ".json.gz") would yield the same file twice, so de-duplicate
        # while preserving order.
        flattened_list = list(
            dict.fromkeys(
                str(item) for matches in matches_per_type for item in matches
            )
        )
    if not flattened_list:
        # Report the source that was actually searched.
        location = (
            f"in the metadata files {metadata_files}"
            if metadata_files
            else f"at this path {input_dir}"
        )
        raise Exception(
            f"Did not find any files {location}, please "
            f"ensure your files are in format {filetypes}."
        )
    return flattened_list
def wikitext_detokenizer(string):
    """Detokenizer for wikitext. Used for special handling of data for substrings.

    Undoes the whitespace-separated tokenization of wikitext: contractions,
    ``@-@``-style number separators, spacing around punctuation and inside
    brackets/quotes, and ``= =`` heading markers.

    Args:
        string (str): String to detokenize before tokenization.

    Returns:
        str: Detokenized string.
    """
    # contractions
    string = string.replace("s '", "s'")
    # Re-attach an apostrophe to a following digit, e.g. "' 90s" -> "'90s".
    # The previous pattern r"/' [0-9]/" used JavaScript-style slash
    # delimiters. It therefore only matched text with literal slashes and
    # replaced it with the literal characters "/'[0-9]/".
    string = re.sub(r"' ([0-9])", r"'\1", string)
    # number separators
    string = string.replace(" @-@ ", "-")
    string = string.replace(" @,@ ", ",")
    string = string.replace(" @.@ ", ".")
    # punctuation
    string = string.replace(" : ", ": ")
    string = string.replace(" ; ", "; ")
    string = string.replace(" . ", ". ")
    string = string.replace(" ! ", "! ")
    string = string.replace(" ? ", "? ")
    string = string.replace(" , ", ", ")
    # strip whitespace just inside paired brackets and quotes
    string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
    string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
    string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
    string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
    string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
    # miscellaneous
    string = string.replace("= = = =", "====")
    string = string.replace("= = =", "===")
    string = string.replace("= =", "==")
    string = string.replace(" " + chr(176) + " ", chr(176))  # degree sign
    string = string.replace(" \n", "\n")
    string = string.replace("\n ", "\n")
    string = string.replace(" N ", " 1 ")
    string = string.replace(" 's", "'s")
    return string
# routine to split the text into smaller sequences
def split_text_and_tokenize(
    text, tokenizer, max_tok_len=2000, remove_bos_in_chunks=True
):
    """Split the text into smaller chunks and tokenize each chunk separately.

    This avoids performance issues with tokenizers like LlamaTokenizer, which
    are slow for long sequences.

    Chunk boundaries are chosen in *characters*, not tokens. Each chunk runs
    at least ``max_tok_len`` characters from its start and ends just before
    the next space (or at the end of the text), so words are not cut.

    Args:
        text (str): Text to be tokenized.
        tokenizer (Tokenizer): Object exposing ``encode(str) -> list[int]``.
        max_tok_len (int, optional): Minimum chunk length in characters.
            Values below 1 are treated as 1 so the loop always advances.
            Defaults to 2000.
        remove_bos_in_chunks (bool, optional): If True, the first token of
            every chunk's encoding (assumed to be a BOS token) is dropped, and
            the first chunk's first token is prepended once to the result.
            Defaults to True.

    Returns:
        tok_ids (list): List of token ids for the text (``[]`` for empty text).
    """
    if not text:
        return []
    # Guard against max_tok_len <= 0. Otherwise `find` could return the
    # current start position and the loop would never advance.
    step = max(max_tok_len, 1)
    text_len = len(text)
    bos_token_id = None
    tok_ids = []
    curr_start = 0
    while curr_start < text_len:
        split_at = text.find(" ", curr_start + step)
        curr_end = text_len if split_at < 0 else split_at
        # Encode each chunk exactly once.
        curr_tok_ids = tokenizer.encode(text[curr_start:curr_end])
        if remove_bos_in_chunks:
            if bos_token_id is None:
                # Keep the leading special token of the first chunk only. A
                # slice tolerates an empty encoding, where indexing raised.
                bos_token_id = list(curr_tok_ids[:1])
            curr_tok_ids = curr_tok_ids[1:]
        tok_ids.extend(curr_tok_ids)
        curr_start = curr_end
    # NOTE: the bos token id is added here if needed; the eos id is added by
    # the caller.
    return bos_token_id + tok_ids if remove_bos_in_chunks else tok_ids
def chunk(
    sample,
    tokenizer,
    fim_rate,
    spm_rate,
):
    """
    Split one sub-context for character-level FIM, or keep it auto-regressive.

    Since FIM is done at the character level, the tokens are detokenized,
    split at two random character boundaries, and each piece is
    re-tokenized. Nothing is shuffled and no special tokens are added here:
    re-tokenization can change the total length, so truncation/padding has
    to happen first. An AR (non-FIM) sub-context is returned as
    ``[[], [], sequence]`` so that ``truncate_helper`` can treat both kinds
    uniformly.

    Random draws, in order:
      1. ``np.random.binomial(1, fim_rate)`` decides FIM vs AR.
      2. For FIM only, ``np.random.randint`` picks two boundaries.
      3. For FIM only, ``np.random.binomial(1, spm_rate)`` picks SPM vs PSM.

    Args:
        sample (np.array): Token ids of one sub-context (``.tolist()`` is used).
        tokenizer (Tokenizer): Provides ``decode(ids, skip_special_tokens=...)``
            and ``encode(str)``.
        fim_rate (float): Probability of applying FIM to this sub-context.
        spm_rate (float): Probability that a FIM'ed sub-context uses SPM
            rather than PSM.

    Returns:
        List[List[int]], str: The prefix/middle/suffix token lists (or two
        empty lists plus the whole sequence for AR), and the format string
        "SPM", "PSM" or "AR".
    """
    # Bernoulli draw: keep auto-regressive unless FIM is selected.
    if not np.random.binomial(1, fim_rate):
        return [[], [], sample.tolist()], "AR"

    contents = tokenizer.decode(sample, skip_special_tokens=False)
    try:
        # A boundary may be 0 (empty prefix) or len(contents) (empty
        # suffix), and the two may coincide (empty middle).
        left, right = sorted(
            np.random.randint(low=0, high=len(contents) + 1, size=2)
        )
    except ValueError as err:
        logging.info(len(contents))
        logging.info(contents)
        logging.info(err)
        raise
    pieces = (contents[:left], contents[left:right], contents[right:])
    prefix, middle, suffix = [tokenizer.encode(piece) for piece in pieces]
    fim_format = "SPM" if np.random.binomial(1, spm_rate) else "PSM"
    return [prefix, middle, suffix], fim_format
def truncate_helper(samples_lst, diff, sample_idx):
    """
    Remove ``diff`` tokens in a way that spares the middle sections for as long as possible.

    Removal order within each sub-context:
      1. the end of the suffix;
      2. the beginning of the prefix;
      3. the beginning of the middle (logged, since it means the prefix and
         suffix are exhausted).

    Sub-contexts are visited round-robin, one step each. Each sub-context
    keeps its own position in that order, so one context can already be
    trimming its prefix while another is still trimming its suffix. An AR
    sub-context is stored as ``[[], [], [sequence]]``, so starting at index 2
    works for AR and FIM alike.

    Args:
        samples_lst (List[List[List[int]]]): One ``[prefix, middle, suffix]``
            list per sub-context; the inner lists are mutated in place.
        diff (int): Number of tokens to remove. Values <= 0 remove nothing.
        sample_idx (int): Index for the sample from the dataset, for use in
            logging if we remove from the middle.

    Returns:
        (List[List[int]]): ``samples_lst`` after truncation (same object).

    Raises:
        ValueError: if ``diff`` exceeds the number of tokens available.
            Previously this made the loop spin forever.
    """
    available = sum(len(section) for sample in samples_lst for section in sample)
    if diff > available:
        raise ValueError(
            f"Cannot truncate {diff} tokens from the {sample_idx}-th data "
            f"sample: only {available} tokens are available."
        )
    num_groups = len(samples_lst)
    remove_idxs = [2] * num_groups  # remove from suffixes first
    i = 0
    # `> 0` rather than truthiness: a negative diff would never reach 0.
    while diff > 0:
        remove_idx_i = remove_idxs[i]
        sample_i = samples_lst[i]
        if sample_i[remove_idx_i]:
            # remove from end of suffix but beginning of prefix/middle
            pop_idx = -1 if remove_idx_i == 2 else 0
            sample_i[remove_idx_i].pop(pop_idx)
            diff -= 1
        else:
            # order of removal is end of suffix, beginning of prefix, then
            # beginning of middle
            remove_idxs[i] = (remove_idxs[i] + 1) % 3
            if remove_idxs[i] == 1:
                logging.info(
                    f"Context {i} in the {sample_idx}-th data sample has\n"
                    "begun truncating from the middle section, meaning\n"
                    "the prefix and suffix sections have been exhausted.\n"
                )
        i = (i + 1) % num_groups
    return samples_lst
def pad_helper(samples_lst, diff, fim_pad_tok_id):
    """
    Helper for padding: all padding goes into the last sub-context.

    A single numpy array holding ``abs(diff)`` copies of ``fim_pad_tok_id``
    is appended as an extra element to ``samples_lst[-1]``, which is mutated
    in place.

    Args:
        samples_lst (List[List[int]]): List of lists that contain token ids
        diff (int): Number of tokens to pad (its sign is ignored).
        fim_pad_tok_id (int): Id for padding token

    Returns:
        (List[List[int]]): ``samples_lst`` (same object) with the padding appended.
    """
    pad_tokens = np.full(abs(diff), fim_pad_tok_id)
    last_context = samples_lst[-1]
    last_context.append(pad_tokens)
    return samples_lst
def truncate_or_pad_helper(
    segments_fim_format_pairs, diff, fim_pad_tok_id, sample_idx
):
    """
    Restore the original context length after character-level FIM.

    Splitting at the character level can break words, and re-tokenizing can
    change the token count. This function truncates or pads the segments so
    that, together with the special tokens, they match the original sequence
    length.

    Args:
        segments_fim_format_pairs (List[Tuple[List[List[int]], str]]): Pairs
            of ``[prefix, middle, suffix]`` token lists and their FIM format
            (PSM/SPM/AR).
        diff (int): Number of tokens to add or remove. Positive (or zero)
            means truncate, negative means pad.
        fim_pad_tok_id (int): Id of padding token.
        sample_idx (int): Index of the sample in the dataset, used for logging.

    Returns:
        (List[Tuple[List[List[int]], str]]): The same pairs with segments
        truncated or padded.
    """
    segments = []
    fim_formats = []
    for pair in segments_fim_format_pairs:
        segments.append(pair[0])
        fim_formats.append(pair[1])
    if diff < 0:
        segments = pad_helper(segments, diff, fim_pad_tok_id)
    else:
        segments = truncate_helper(segments, diff, sample_idx)
    return list(zip(segments, fim_formats))
def fim(
    sample_array,
    sample_idx,
    tokenizer,
    fim_rate,
    spm_rate,
    suffix_tok_id,
    prefix_tok_id,
    middle_tok_id,
    fim_pad_tok_id,
    eos_tok_id,
    opt_bos_tok_id,
):
    """
    Takes in an array of input_ids, mask, and labels, and performs the
    FIM operation to re-arrange into PSM and SPM format with some probability.

    The sample is split at EOS tokens into sub-contexts. Each sub-context is
    FIM'ed (or kept AR) by ``chunk``. The result is truncated/padded back to
    the original length and then formatted by ``format_fim``.

    Args:
        sample_array (np.array): Stack of input_ids, mask, and labels after
            tokenization. Labels are off-by-one of input_ids as in standard
            auto-regressive training. Only row 0 (input_ids) is read here.
        sample_idx (int): Index of sample from dataset, used for logging.
        tokenizer (Tokenizer): Tokenizer object
        fim_rate (float): Determines what percentage of contexts are FIM'ed
        spm_rate (float): Determines what percentage of FIM'ed contexts are in
            SPM format. 1 - spm_rate determines PSM
        suffix_tok_id (int): Id for special token denoting suffix section in a FIM'ed context
        prefix_tok_id (int): Id for special token denoting prefix section in a FIM'ed context
        middle_tok_id (int): Id for special token denoting middle section in a FIM'ed context
        fim_pad_tok_id (int): Id for padding
        eos_tok_id (int): Id for the end-of-sequence
        opt_bos_tok_id (list): Optionally a list containing the bos token id,
            otherwise will be empty list. Empty list will be a no-op in the
            concatenation. Bos-token will only exist if model's tokenizer adds
            bos-token by default.

    Returns:
        fim_outputs (np.array): Stack of input_ids, mask, and labels after FIM
            transformation. Mask and labels have been adjusted to still filter
            padding tokens and represent the following token, respectively.

    Raises:
        AssertionError: if ``fim_rate`` is outside [0, 1], if any output row
            does not have the original length, or if the labels do not end
            with ``eos_tok_id``.
    """
    assert (
        0 <= fim_rate <= 1
    ), "FIM rate must be a probability 0 <= rate <= 1"
    sample = sample_array[0, :]
    max_seq_len = sample.shape[0]
    # Positions of every EOS token, shape (num_eos, 1): split sample by document.
    segment_breaks = np.argwhere(sample == eos_tok_id)
    segments_fim_format_pairs = []
    if segment_breaks.shape != (0, 1):  # FIM each sub-context
        curr_start_position = 0
        for loc in np.nditer(segment_breaks):
            # Only permute non-empty segments.
            if loc - curr_start_position > 0:
                segments, fim_format = chunk(
                    sample=sample[curr_start_position:loc],
                    tokenizer=tokenizer,
                    fim_rate=fim_rate,
                    spm_rate=spm_rate,
                )
                segments_fim_format_pairs.append((segments, fim_format))
            curr_start_position = loc + 1  # jump over the EOD token
        # Permute the segment after the last EOD
        segments, fim_format = chunk(
            sample=sample[curr_start_position:],
            tokenizer=tokenizer,
            fim_rate=fim_rate,
            spm_rate=spm_rate,
        )
        segments_fim_format_pairs.append((segments, fim_format))
    else:  # FIM over full context
        segments, fim_format = chunk(
            sample=sample,
            tokenizer=tokenizer,
            fim_rate=fim_rate,
            spm_rate=spm_rate,
        )
        segments_fim_format_pairs.append((segments, fim_format))

    def flatten_2d(arr):
        return np.concatenate([np.concatenate(subarr) for subarr in arr])

    total_len = flatten_2d(
        [pair[0] for pair in segments_fim_format_pairs]
    ).shape[0]
    # We factor in the final EOS, which is added before splitting into
    # inputs and labels (sequence[:-1] and sequence[1:]), and the optional
    # bos token.
    # NOTE(review): +1 per AR segment and +4 per FIM segment presumably
    # count the EOS and the three FIM sentinels added by format_fim; confirm.
    add_constant = -1
    for _, fmt in segments_fim_format_pairs:
        add_constant += 1 if fmt == "AR" else 4
    if opt_bos_tok_id:
        add_constant += 1
    diff = (total_len + add_constant) - max_seq_len
    segments_fim_format_pairs = truncate_or_pad_helper(
        segments_fim_format_pairs,
        diff,
        fim_pad_tok_id,
        sample_idx,
    )
    inputs, mask, labels = format_fim(
        segments_fim_format_pairs,
        max_seq_len,
        suffix_tok_id,
        prefix_tok_id,
        middle_tok_id,
        eos_tok_id,
        opt_bos_tok_id,
    )
    # Explicit checks instead of asserts wrapped in a bare `except:`. The
    # logging calls pass their values as format arguments; the previous
    # `logging.error(shape, max_seq_len)` produced a logging formatting error.
    bad_shapes = {
        name: arr.shape
        for name, arr in (("inputs", inputs), ("mask", mask), ("labels", labels))
        if arr.shape[0] != max_seq_len
    }
    if bad_shapes:
        logging.error(
            "The inputs/masks/labels were not the correct size after the FIM "
            "process. Expected length %d for each; mismatched shapes: %s",
            max_seq_len,
            bad_shapes,
        )
        raise AssertionError(
            f"FIM output has the wrong length (expected {max_seq_len}): {bad_shapes}"
        )
    if labels[-1] != eos_tok_id:
        logging.error("The sequence did not end with an EOS token")
        raise AssertionError("The sequence did not end with an EOS token")
    # end FIM-specific code
    return np.stack([inputs, mask, labels], axis=0)
def get_tokenizer_vocab(tokenizer):
    """Return the token -> id vocabulary mapping of a supported tokenizer.

    Args:
        tokenizer: A ``BPETokenizer``, ``HFTokenizer``, or transformers
            ``PreTrainedTokenizer`` / ``PreTrainedTokenizerFast``.

    Returns:
        The vocabulary mapping:
          - ``BPETokenizer``: its ``encoder`` attribute;
          - ``HFTokenizer``: ``get_vocab()`` of the wrapped tokenizer;
          - transformers tokenizers: ``get_vocab()``.

    Raises:
        NotImplementedError: for any other tokenizer type.
    """
    if isinstance(tokenizer, BPETokenizer):
        return tokenizer.encoder
    if isinstance(tokenizer, HFTokenizer):
        return tokenizer.tokenizer.get_vocab()
    if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
        # `get_vocab()` is the documented transformers tokenizer API and
        # includes added tokens (FIM sentinels are often added tokens). The
        # `.vocab` attribute is not provided by every slow tokenizer class.
        return tokenizer.get_vocab()
    # The old message relied on a backslash line continuation inside the
    # string literal, which mangled the text.
    raise NotImplementedError(
        f"We do not support specified tokenizer type: {type(tokenizer).__name__}."
    )
def check_fim_special_tokens(params, tokenizer):
    """Validate the FIM special-token configuration against ``tokenizer``.

    Args:
        params (dict): Config whose ``processing`` section must name
            ``fim_prefix_tok``, ``fim_middle_tok`` and ``fim_suffix_tok``.
        tokenizer: Any tokenizer accepted by ``get_tokenizer_vocab``.

    Raises:
        AssertionError: if any of the three keys is missing from the config,
            or if a configured token is not in the tokenizer vocabulary.
    """
    processing = params['processing']
    # 1) The config must list all three FIM special tokens.
    assert all(
        name in processing
        for name in ("fim_suffix_tok", "fim_prefix_tok", "fim_middle_tok")
    ), (
        "Configs for FIM pre-processing must include the special tokens that\n"
        "denote prefix, middle, and suffix tokens."
    )
    # 2) Every configured token must exist in the tokenizer's vocabulary.
    vocab = get_tokenizer_vocab(tokenizer)
    assert all(
        processing.get(name) in vocab
        for name in ("fim_prefix_tok", "fim_middle_tok", "fim_suffix_tok")
    ), (
        "Please ensure that the provided FIM special tokens are in the\n"
        "specified tokenizer."
    )
def handle_bos_token_default(tokenizer):
    """
    Turn off automatic BOS insertion for FIM re-tokenization.

    During FIM every chunk is tokenized again after splitting. If the
    tokenizer prepends a BOS token by default, extra BOS tokens would appear
    mid-sequence. If ``tokenizer.add_bos_token`` is truthy, this function:
      - sets it to False (mutating the tokenizer);
      - returns ``(True, [bos_id])``, where ``bos_id`` is the last id of
        ``tokenizer.encode(tokenizer.bos_token)``, so the caller can add BOS
        once during final FIM formatting.
    Otherwise it returns ``(False, [])``.
    """
    if not getattr(tokenizer, "add_bos_token", False):
        return False, []
    tokenizer.add_bos_token = False
    encoded_bos = tokenizer.encode(tokenizer.bos_token)
    return True, [encoded_bos[-1]]
def get_size(obj, seen=None):
    """Recursively estimate the size of ``obj`` in bytes.

    ``sys.getsizeof`` is shallow. This function also adds the sizes of:
      - dict values and keys;
      - an object's ``__dict__``;
      - the items of any other iterable except ``str``/``bytes``/``bytearray``.

    Object ids already in ``seen`` contribute 0, so shared and
    self-referential objects are counted once. Note that iterating consumes
    one-shot iterators passed in.
    """
    total = sys.getsizeof(obj)
    seen = set() if seen is None else seen
    key = id(obj)
    if key in seen:
        return 0
    # Mark before recursing so self-referential objects terminate.
    seen.add(key)
    if isinstance(obj, dict):
        total += sum(get_size(value, seen) for value in obj.values())
        total += sum(get_size(name, seen) for name in obj.keys())
    elif hasattr(obj, '__dict__'):
        total += get_size(obj.__dict__, seen)
    elif hasattr(obj, '__iter__') and not isinstance(obj, (str, bytes, bytearray)):
        total += sum(get_size(item, seen) for item in obj)
    return total
def append_eos_to_multiple_semantic_regions(
    formatted_data,
    data_ranges,
    eos_token,
    image_token,
    is_chat_data,
):
    """Extend semantic-region spans so they also cover EOS strings that follow them.

    First, every occurrence of ``eos_token`` in ``formatted_data`` at or after
    the start of the first region is located.

    Then EOS occurrences and regions are walked in order. An EOS that starts
    at or after a region's end and before the next region's start is
    absorbed into that region: its ``"indices"`` end becomes the EOS end.
    This does not happen when ``image_token`` occurs between the region end
    and the EOS. For the last region, the current EOS is absorbed under the
    same image-token check.

    For chat data with at most one EOS occurrence, ``"handle_turn_token"`` is
    also set:
      - on a region followed by a gap before the next region;
      - on the last region when text follows it.

    Args:
        formatted_data (str): The full formatted text the region indices
            point into.
        data_ranges (list[dict]): Regions; each has ``"indices"`` =
            (start, end) character offsets into ``formatted_data``. Mutated
            in place.
        eos_token (str): EOS string to look for. If falsy, nothing is done.
        image_token (str | None): If given, an EOS is not absorbed when this
            string appears between the region end and the EOS.
        is_chat_data (bool): Enables the ``handle_turn_token`` marking.

    Returns:
        list[dict]: ``data_ranges`` (the same list object).
    """
    if data_ranges == [] or not eos_token:
        return data_ranges
    # Character spans (start, end) of every EOS occurrence at or after the
    # start of the first region.
    eos_indices = []
    start_search_index = data_ranges[0].get("indices")[0]
    while start_search_index < len(formatted_data):
        eos_start_idx = formatted_data.find(eos_token, start_search_index)
        if eos_start_idx == -1:
            ## No eos found. Break
            break
        eos_end_idx = eos_start_idx + len(eos_token)
        start_search_index = eos_end_idx
        eos_indices.append((eos_start_idx, eos_end_idx))
    # Two-pointer walk over the EOS occurrences and the regions (both ordered).
    current_eos_pos = 0
    current_data_range_pos = 0
    while current_eos_pos < len(eos_indices) and current_data_range_pos < len(
        data_ranges
    ):
        ## Check if eos occurs between current and next region
        eos_start_idx, eos_end_idx = eos_indices[current_eos_pos]
        region_start_idx, region_end_idx = data_ranges[
            current_data_range_pos
        ].get("indices")
        if current_data_range_pos + 1 < len(data_ranges):
            next_region_start_idx, next_region_end_idx = data_ranges[
                current_data_range_pos + 1
            ].get("indices")
            if region_end_idx <= eos_start_idx < next_region_start_idx:
                # -1 means no image token lies between this region and the EOS.
                image_start_idx = (
                    -1
                    if image_token is None
                    else formatted_data[region_end_idx:eos_start_idx].find(
                        image_token
                    )
                )
                if image_start_idx == -1:
                    # Absorb the EOS: the region now ends where the EOS ends.
                    indices_incl_eos = (region_start_idx, eos_end_idx)
                    data_ranges[current_data_range_pos][
                        "indices"
                    ] = indices_incl_eos
                # NOTE(review): the nesting of this increment was reconstructed
                # from a source view without indentation. Confirm whether it
                # belongs inside `if image_start_idx == -1`.
                current_eos_pos += 1
        else:
            ## insert eos in the last region
            # NOTE(review): unlike the branch above, nothing checks that the
            # EOS starts at or after `region_end_idx`. An EOS lying inside an
            # earlier region may therefore be attached here and shrink this
            # region's span. Confirm this is intended.
            image_start_idx = (
                -1
                if image_token is None
                else formatted_data[region_end_idx:eos_start_idx].find(
                    image_token
                )
            )
            if image_start_idx == -1:
                indices_incl_eos = (region_start_idx, eos_end_idx)
                data_ranges[current_data_range_pos][
                    "indices"
                ] = indices_incl_eos
            current_eos_pos += 1
        current_data_range_pos += 1
    if (
        not is_chat_data or len(eos_indices) > 1
    ):  ## 1 because the last eot could be eos
        return data_ranges
    # Chat data with at most one EOS: flag regions that are followed by a gap,
    # and flag the last region when text follows it.
    for i in range(1, len(data_ranges)):
        start_idx, end_idx = data_ranges[i].get("indices")
        previous_start_idx, previous_end_idx = data_ranges[i - 1].get("indices")
        if previous_end_idx != start_idx:
            handle_turn_token = True  # NOTE(review): unused local
            data_ranges[i - 1]["handle_turn_token"] = True
        # NOTE(review): nesting reconstructed (loop level assumed); confirm.
        if i == len(data_ranges) - 1:
            if end_idx < len(formatted_data):
                data_ranges[i]["handle_turn_token"] = True
    return data_ranges
def find_token_range(region, offsets):
    """Convert a region's character span into an end-exclusive token-index span.

    ``region["indices"]`` (character offsets ``(start, end)``) is popped, so
    the input dict is mutated. ``offsets`` is a sequence of per-token
    ``(char_start, char_end)`` pairs (presumably a tokenizer offset mapping).
    The token span is chosen as follows:
      - Start: the first token whose span contains the start character, or
        that begins after it. The latter case covers tokenizers such as NeoX
        that emit whitespace as a separate token.
      - End: the first token that ends at or after the end character and
        begins before it.

    Returns:
        dict: ``{"indices": (token_start, token_end + 1), "loss_weight": ...,
        "attention_mask": ...}``; the last two are copied from ``region``
        (``None`` if absent).

    Raises:
        ValueError: if either boundary cannot be located in ``offsets``.
    """
    char_start, char_end = region.pop('indices')

    for first_token, (tok_begin, tok_finish) in enumerate(offsets):
        if tok_begin <= char_start < tok_finish or tok_begin > char_start:
            break
    else:
        raise ValueError(
            "The implementation of offset mapping of this tokenizer may be incorrect. Check the huggingface implementation for more details."
        )

    for last_token, (tok_begin, tok_finish) in enumerate(offsets):
        if tok_finish >= char_end and tok_begin < char_end:
            break
    else:
        raise ValueError(
            "The huggingface implementation of offset mapping of this tokenizer may be incorrect. Check the huggingface implementation for more details."
        )

    return {
        "indices": (first_token, last_token + 1),
        "loss_weight": region.get("loss_weight"),
        "attention_mask": region.get("attention_mask"),
    }
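## Notes on the offset-mapping flow used by `find_token_range`:
##   formatted_data = prompt + completion (the full formatted string)
##   tokens = tokenizer.tokenize(formatted_data)
##   each token's offset mapping points back into formatted_data
##   ("_text" vs " ,text": tokenizers differ in whether leading whitespace
##   belongs to the following token or is a separate token).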