Source code for cerebras.modelzoo.data_preparation.nlp.slimpajama.dedup.dedup_train

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
from glob import glob

from lm_dataformat import Reader
from tqdm import tqdm

# isort: off
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), "../../../../"))
# isort: on
from cerebras.modelzoo.data_preparation.nlp.slimpajama.utils import (
    rm_if_exists,
    sha256str,
    write_lmd_dataset,
)


def deduplicate_train_holdout_sets(
    train_path, holdout_path, deduped_train_path, chunk_id
):
    # Calculate hashes on holdout set.
    seen = set()
    if os.path.exists("hashes.txt"):
        with open("hashes.txt") as fh:
            for line in tqdm(fh):
                seen.add(line.strip())
    else:
        hashf = open("hashes.txt", "w")
        for f in tqdm(glob(f"{holdout_path}/*/*.zst")):
            reader = Reader(f)
            for doc_id, text in enumerate(
                reader._stream_data(jsonl_key="text")
            ):
                hash = sha256str(text)
                hashf.write(hash + "\n")
                seen.add(hash)
        hashf.close()
    print("Finished collecting hashes for eval", len(seen))

    rm_if_exists(f"{deduped_train_path}/chunk{chunk_id}")
    os.makedirs(f"{deduped_train_path}/chunk{chunk_id}")

    total_written = 0
    # Remove elements from train set with hashes seen in eval set.
    for f in tqdm(glob(f"{train_path}/chunk{chunk_id}/*.zst")):

        def filtered_docs():
            reader = Reader(f)
            for doc_id, doc in enumerate(reader._stream_data(get_meta=True)):
                text, meta = doc
                hash = sha256str(text)
                if hash not in seen:
                    yield text, meta
                else:
                    print("Found an intersection!!!")

        with open(
            f"{deduped_train_path}/chunk{chunk_id}/" + f.split("/")[-1], "wb"
        ) as fout_dedup_train:
            total_written += write_lmd_dataset(
                fout_dedup_train,
                filtered_docs(),
                indices=None,
                return_total_written=True,
            )
    print(f"Total written: {total_written}")
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("chunk_id", type=int)
    parser.add_argument("--src_dir", type=str)
    parser.add_argument("--tgt_dir", type=str)
    parser.add_argument("--out_dir", type=str)
    args = parser.parse_args()

    deduplicate_train_holdout_sets(
        args.src_dir,
        args.tgt_dir,
        args.out_dir,
        args.chunk_id,
    )
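
# Example invocation (illustrative only; the paths below are placeholders, not
# taken from this module). The positional argument selects which train chunk to
# deduplicate, --src_dir points at the train split, --tgt_dir at the holdout
# split whose document hashes are collected, and --out_dir receives the
# deduplicated shards:
#
#   python dedup_train.py 0 \
#       --src_dir /path/to/train \
#       --tgt_dir /path/to/holdout \
#       --out_dir /path/to/train_deduped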