# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
File: write_csv_qa.py
Creates pre-processed CSV files of SQuAD for various models. Called by
{model}/fine_tuning/qa/write_csv_qa.sh with the appropriate command-line
arguments to adjust processing for each model.
"""
import argparse
import os
import random
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), "../../../.."))

from modelzoo.common.input.utils import save_params
from modelzoo.transformers.data_processing.qa.qa_utils import (
    convert_examples_to_features_and_write,
    read_squad_examples,
)
from modelzoo.transformers.data_processing.tokenizers.Tokenization import (
    FullTokenizer,
)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_dir",
        required=True,
        help="Directory containing train-v1.1.json",
    )
    # Note: this arg is reused as the path to the pretrained SentencePiece
    # tokenizer ".model" file for T5, whereas for BERT it points to a vocab
    # ".txt" file. A separate arg could be added, but a single arg preserves
    # compatibility between the two usages.
    parser.add_argument(
        "--vocab_file",
        required=True,
        help="The vocabulary file the pretrained model was trained on: a "
        "vocab '.txt' file for BERT, or a SentencePiece '.model' file for T5.",
    )
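    # Illustrative --vocab_file values (placeholder paths, not shipped here):
    #   bert: /path/to/vocab.txt      (WordPiece vocabulary text file)
    #   t5:   /path/to/spiece.model   (SentencePiece model file)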
    parser.add_argument(
        "--data_split_type",
        choices=["train", "dev", "all"],
        default="all",
        help="Dataset split, choose from 'train', 'dev' or 'all'.",
    )
    parser.add_argument(
        "--do_lower_case",
        required=False,
        action="store_true",
        help="Whether to convert tokens to lowercase.",
    )
    parser.add_argument(
        "--max_seq_length",
        required=False,
        type=int,
        default=384,
        help="The maximum total input sequence length after tokenization.",
    )
    parser.add_argument(
        "--doc_stride",
        required=False,
        type=int,
        default=128,
        help="When splitting up a long document into chunks, how much stride to "
        "take between chunks.",
    )
    parser.add_argument(
        "--max_query_length",
        required=False,
        type=int,
        default=64,
        help="The maximum number of tokens for the question. Questions longer than "
        "this will be truncated to this length.",
    )
    parser.add_argument(
        "--version_2_with_negative",
        required=False,
        action="store_true",
        help="If true, the SQuAD examples contain some that do not have an answer.",
    )
    parser.add_argument(
        "--output_dir",
        required=False,
        default=os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "preprocessed_csv_dir"
        ),
        help="Directory to store pre-processed CSV files.",
    )
    parser.add_argument(
        "--num_output_files",
        type=int,
        default=16,
        help="Number of files on disk to separate CSV files into. "
        "Defaults to 16.",
    )
    parser.add_argument(
        "--tokenizer_scheme",
        required=True,
        type=str,
        help="Specify which tokenization scheme should be used based on the "
        "desired model. Currently supports BERT and T5.",
    )
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    print("***** Configuration *****")
    for key, val in vars(args).items():
        print(' {}: {}'.format(key, val))
    print("**************************")
    print("")
    write_csv_files(args)


def get_tokenizer_fns(args):
    if args.tokenizer_scheme == 'bert':
        tokenizer = FullTokenizer(
            vocab_file=args.vocab_file, do_lower_case=args.do_lower_case
        )
        tokenize_fn = tokenizer.tokenize
        convert_tokens_to_ids_fn = tokenizer.convert_tokens_to_ids
    elif args.tokenizer_scheme == 't5':
        import sentencepiece as spm

        tokenizer = spm.SentencePieceProcessor()
        tokenizer.load(args.vocab_file)
        tokenize_fn = tokenizer.encode_as_pieces
        convert_tokens_to_ids_fn = tokenizer.piece_to_id
    else:
        raise ValueError("Tokenization scheme for this model not supported")
    return tokenize_fn, convert_tokens_to_ids_fn
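
# Illustrative behavior of the returned callables (output is hypothetical and
# depends on the vocabulary file supplied):
#   tokenize_fn, to_ids_fn = get_tokenizer_fns(args)
#   tokens = tokenize_fn("What is SQuAD?")  # e.g. ['what', 'is', 'sq', '##ua', '##d', '?'] for BERT
#   ids = to_ids_fn(tokens)                 # corresponding integer vocabulary ids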


def write_csv_files(args):
    task_name = os.path.basename(args.data_dir.lower())
    output_dir = os.path.abspath(args.output_dir)

    rng = random.Random(12345)
    tokenize_fn, convert_tokens_to_ids_fn = get_tokenizer_fns(args)

    to_write = [args.data_split_type]
    if args.data_split_type == "all":
        to_write = ["train", "dev"]

    num_examples_dict = dict()
    for data_split_type in to_write:
        data_split_type_dir = os.path.join(output_dir, data_split_type)
        if not os.path.exists(data_split_type_dir):
            os.makedirs(data_split_type_dir)

        if data_split_type == "train":
            input_fn = "train-v1.1.json"
            file_prefix = "train-v1.1"
        elif data_split_type == "dev":
            input_fn = "dev-v1.1.json"
            file_prefix = "dev-v1.1"
        else:
            assert False, "Unknown data_split_type: %s" % args.data_split_type

        input_file = os.path.join(args.data_dir, input_fn)
        examples = read_squad_examples(
            input_file=input_file,
            is_training=True,
            version_2_with_negative=args.version_2_with_negative,
        )
        rng.shuffle(examples)

        (
            num_examples_written,
            meta_data,
        ) = convert_examples_to_features_and_write(
            examples=examples,
            tokenize_fn=tokenize_fn,
            convert_tokens_to_ids_fn=convert_tokens_to_ids_fn,
            max_seq_length=args.max_seq_length,
            doc_stride=args.doc_stride,
            max_query_length=args.max_query_length,
            output_dir=data_split_type_dir,
            file_prefix=file_prefix,
            num_output_files=args.num_output_files,
            tokenizer_scheme=args.tokenizer_scheme,
            is_training=True,
        )

        num_examples_dict[data_split_type] = num_examples_written
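        # meta.dat records one "<csv filename> <number of examples>" line per
        # output CSV shard for this split.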
        meta_file = os.path.join(data_split_type_dir, "meta.dat")
        with open(meta_file, "w") as fout:
            for output_file, num_lines in meta_data.items():
                fout.write("%s %s\n" % (output_file, num_lines))

    # Write args passed and number of examples per split
    args_dict = vars(args)
    args_dict["num_examples"] = num_examples_dict
    save_params(args_dict, model_dir=args.output_dir)


if __name__ == "__main__":
    main()