Source code for cerebras.modelzoo.data_preparation.nlp.slimpajama.preprocessing.filter

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import pickle
import queue
import re
import string
import time
from collections import defaultdict
from glob import glob
from multiprocessing import Process, Queue, cpu_count

from lm_dataformat import Reader


def clean(s):
    # lower cased
    s = s.lower()
    # remove punctuation
    s = s.translate(str.maketrans("", "", string.punctuation))
    # remove consecutive spaces, newlines, tabs in the middle and in the beginning / end
    s = re.sub(r"\s+", " ", s.strip())
    return s
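
# Illustrative example (not part of the original module): clean() lower-cases
# the text, strips punctuation, and collapses runs of whitespace, so
#
#     >>> clean("  Hello,\n\tWorld!! ")
#     'hello world'
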
def get_short_documents(
    input_dir, threshold, n_proc, proc_idx, docs_queue, dataset_name
):
    # Collect the files belonging to the requested dataset; common_crawl is
    # nested one directory level deeper than the other corpora.
    if dataset_name == "all":
        files = []
        for dataset in [
            "arxiv",
            "stackexchange",
            "book",
            "wikipedia",
            "github",
            "c4",
            "common_crawl",
        ]:
            if dataset == "common_crawl":
                files.extend(glob(f"{input_dir}/{dataset}/*/*"))
            else:
                files.extend(glob(f"{input_dir}/{dataset}/*"))
    elif dataset_name == "common_crawl":
        files = glob(f"{input_dir}/{dataset_name}/*/*")
    else:
        files = glob(f"{input_dir}/{dataset_name}/*")

    # Each worker processes every n_proc-th file, starting at its own index.
    files = sorted(files)
    for i in range(proc_idx, len(files), n_proc):
        reader = Reader(files[i])
        for doc_id, doc in enumerate(reader._stream_data(jsonl_key="text")):
            # Report documents whose cleaned text is shorter than the threshold
            # as "<relative file path>@<document index>".
            if len(clean(doc)) < threshold:
                docs_queue.put(
                    files[i].replace(input_dir + "/", "") + f"@{doc_id}"
                )
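
# Minimal single-process sketch (illustrative; the input path, threshold, and
# dataset name below are hypothetical):
#
#     q = Queue()
#     get_short_documents("/data/raw", 200, 1, 0, q, "arxiv")
#     while not q.empty():
#         print(q.get())  # e.g. "arxiv/shard_0.jsonl@3"
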
def filter_dataset(args):
    short_documents = defaultdict(set)
    docs_queue = Queue(64 * 10000)
    n_proc = cpu_count()
    processes = []
    # Spawn one worker per CPU; each worker pushes "<file>@<doc index>" entries
    # for documents shorter than the threshold onto docs_queue.
    for process_id in range(n_proc):
        p = Process(
            target=get_short_documents,
            args=(
                args.input_dir,
                args.threshold,
                n_proc,
                process_id,
                docs_queue,
                args.dataset_name,
            ),
        )
        processes.append(p)
        p.start()

    # Drain the queue until no worker has produced an item for 30 seconds.
    i = 0
    start_time = time.time()
    while True:
        try:
            item = docs_queue.get(timeout=30)
            file_name, doc_idx = item.split("@")
            short_documents[file_name].add(int(doc_idx))
            if i % 10 == 0:
                print(
                    f"Processed {i / args.n_docs * 100}%. ",
                    time.time() - start_time,
                )
            i += 1
        except queue.Empty:
            break

    for p in processes:
        p.join()

    print("Finished processing, writing to disk!")
    with open(args.output_file, "wb") as fout:
        pickle.dump(short_documents, fout)
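
# Reading the result back (illustrative; the output file name and the relative
# file path below are hypothetical):
#
#     with open("short_docs.pkl", "rb") as fin:
#         short_documents = pickle.load(fin)
#     # Maps each relative file path to the set of indices of its short documents.
#     print(17 in short_documents.get("c4/c4_shard_0.jsonl", set()))
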
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("input_dir", help="Dataset input directory.")
    parser.add_argument("output_file", help="File to output short docs to.")
    parser.add_argument(
        "n_docs", type=int, help="Number of documents in the dataset."
    )
    # Corpus to process: one of the names handled in get_short_documents
    # (e.g. "arxiv", "common_crawl") or "all".
    parser.add_argument("dataset_name")
    parser.add_argument(
        "threshold", type=int, help="Minimum length of the document to keep."
    )
    args = parser.parse_args()
    filter_dataset(args)
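
# Example invocation (arguments are positional, in the order defined above;
# the values shown are hypothetical):
#
#     python filter.py /data/raw short_docs.pkl 1000000 common_crawl 200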