Source code for modelzoo.transformers.data_processing.qa.write_csv_qa

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
File: write_csv_qa.py

Used to create pre-processed CSV files of SQuAD for various models. Called by {model}/fine_tuning/qa/write_csv_qa.sh with the correct command-line arguments to adjust processing for each model.
"""

import argparse
import os
import random
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), "../../../.."))
from modelzoo.common.input.utils import save_params
from modelzoo.transformers.data_processing.qa.qa_utils import (
    convert_examples_to_features_and_write,
    read_squad_examples,
)
from modelzoo.transformers.data_processing.tokenizers.Tokenization import (
    FullTokenizer,
)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_dir",
        required=True,
        help="Directory containing train-v1.1.json",
    )
    # Note: this arg is re-used as the path to the pretrained SentencePiece
    # tokenizer ".model" file, whereas in the BERT usage it is a ".txt" vocab
    # file. A separate argument may be cleaner, but sharing it preserves
    # compatibility between the two flows.
    parser.add_argument(
        "--vocab_file",
        required=True,
        help="The vocabulary file that the T5 Pretrained model was trained on.",
    )
    parser.add_argument(
        "--data_split_type",
        choices=["train", "dev", "all"],
        default="all",
        help="Dataset split, choose from 'train', 'dev' or 'all'.",
    )
    parser.add_argument(
        "--do_lower_case",
        required=False,
        action="store_true",
        help="Whether to convert tokens to lowercase",
    )
    parser.add_argument(
        "--max_seq_length",
        required=False,
        type=int,
        default=384,
        help="The maximum total input sequence length after tokenization.",
    )
    parser.add_argument(
        "--doc_stride",
        required=False,
        type=int,
        default=128,
        help="When splitting up a long document into chunks, how much stride to "
        "take between chunks.",
    )
    parser.add_argument(
        "--max_query_length",
        required=False,
        type=int,
        default=64,
        help="The maximum number of tokens for the question. Questions longer than "
        "this will be truncated to this length.",
    )
    parser.add_argument(
        "--version_2_with_negative",
        required=False,
        action="store_true",
        help="If true, the SQuAD examples contain some that do not have an answer.",
    )
    parser.add_argument(
        "--output_dir",
        required=False,
        default=os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "preprocessed_csv_dir"
        ),
        help="Directory to store pre-processed CSV files.",
    )
    parser.add_argument(
        "--num_output_files",
        type=int,
        default=16,
        help="Number of files on disk to separate csv files into. "
        "Defaults to 16.",
    )
    parser.add_argument(
        "--tokenizer_scheme",
        required=True,
        type=str,
        help="Specify which tokenization scheme should be used based on the "
        "desired model. Currently supports BERT and T5.",
    )
    args = parser.parse_args()
    return args
def main():
    args = parse_args()
    print("***** Configuration *****")
    for key, val in vars(args).items():
        print(' {}: {}'.format(key, val))
    print("**************************")
    print("")
    write_csv_files(args)
def get_tokenizer_fns(args):
    if args.tokenizer_scheme == 'bert':
        tokenizer = FullTokenizer(
            vocab_file=args.vocab_file, do_lower_case=args.do_lower_case
        )
        tokenize_fn = tokenizer.tokenize
        convert_tokens_to_ids_fn = tokenizer.convert_tokens_to_ids
    elif args.tokenizer_scheme == 't5':
        import sentencepiece as spm

        tokenizer = spm.SentencePieceProcessor()
        tokenizer.load(args.vocab_file)
        tokenize_fn = tokenizer.encode_as_pieces
        convert_tokens_to_ids_fn = tokenizer.piece_to_id
    else:
        raise ValueError("Tokenization scheme for this model not supported")
    return tokenize_fn, convert_tokens_to_ids_fn
def write_csv_files(args):
    task_name = os.path.basename(args.data_dir.lower())
    output_dir = os.path.abspath(args.output_dir)
    rng = random.Random(12345)

    tokenize_fn, convert_tokens_to_ids_fn = get_tokenizer_fns(args)

    to_write = [args.data_split_type]
    if args.data_split_type == "all":
        to_write = ["train", "dev"]

    num_examples_dict = dict()
    for data_split_type in to_write:
        data_split_type_dir = os.path.join(output_dir, data_split_type)
        if not os.path.exists(data_split_type_dir):
            os.makedirs(data_split_type_dir)

        if data_split_type == "train":
            input_fn = "train-v1.1.json"
            file_prefix = "train-v1.1"
        elif data_split_type == "dev":
            input_fn = "dev-v1.1.json"
            file_prefix = "dev-v1.1"
        else:
            assert False, "Unknown data_split_type: %s" % args.data_split_type

        input_file = os.path.join(args.data_dir, input_fn)
        examples = read_squad_examples(
            input_file=input_file,
            is_training=True,
            version_2_with_negative=args.version_2_with_negative,
        )
        rng.shuffle(examples)

        (
            num_examples_written,
            meta_data,
        ) = convert_examples_to_features_and_write(
            examples=examples,
            tokenize_fn=tokenize_fn,
            convert_tokens_to_ids_fn=convert_tokens_to_ids_fn,
            max_seq_length=args.max_seq_length,
            doc_stride=args.doc_stride,
            max_query_length=args.max_query_length,
            output_dir=data_split_type_dir,
            file_prefix=file_prefix,
            num_output_files=args.num_output_files,
            tokenizer_scheme=args.tokenizer_scheme,
            is_training=True,
        )

        num_examples_dict[data_split_type] = num_examples_written
        meta_file = os.path.join(data_split_type_dir, "meta.dat")
        with open(meta_file, "w") as fout:
            for output_file, num_lines in meta_data.items():
                fout.write("%s %s\n" % (output_file, num_lines))

    # Write args passed and number of examples
    args_dict = vars(args)
    args_dict["num_examples"] = num_examples_dict
    save_params(args_dict, model_dir=args.output_dir)
if __name__ == "__main__":
    main()