# Source code for cerebras.modelzoo.data_preparation.data_preprocessing.data_dedup.dedup

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import logging
import os
import shutil
from multiprocessing import cpu_count

from cerebras.modelzoo.data_preparation.data_preprocessing.data_dedup import (
    deduplicate_dataset,
    generate_connected_components,
    generate_duplicate_pairs,
    to_hash,
)


def deduplication_pipeline(args):
    """Run the end-to-end deduplication pipeline over a directory of documents.

    Four stages run in order, each configured through an
    ``argparse.Namespace`` derived from ``args``:

    1. ``to_hash.generate_hashes`` generates MinHash data. Its
       ``output_dir`` is set to ``args.input_dir``, so intermediate output
       lands next to the input documents.
    2. ``generate_duplicate_pairs.generate_pairs`` writes duplicate pairs to
       ``<input_dir>/duplicates/duplicate_pairs.txt``. It uses
       ``jaccard_threshold=0.8`` and ``processes=cpu_count()``.
    3. ``generate_connected_components.generate_connected_components_mp``
       writes ``<input_dir>/duplicates/connected_components.pickle``.
    4. ``deduplicate_dataset.generate_duplicates`` reads that pickle and
       writes the deduplicated dataset to ``args.output_dir``, which is
       created if missing.

    Args:
        args: Namespace providing ``dataset_name``, ``input_dir``,
            ``jsonl_key``, ``format``, ``output_dir`` and ``clean_up``.
            When ``clean_up`` is truthy, ``minhash_nfc/``, ``duplicates/``
            and ``duplicate_pairs.pickle`` under ``input_dir`` are removed
            at the end (presumably artifacts of stages 1-3; verify).
            Clean-up is best-effort: missing paths are ignored, and other
            removal errors are logged as warnings rather than raised.
    """
    logging.info("Starting deduplication pipeline...")

    # Generate MinHash. Output is written into input_dir so the later
    # stages can read it from there.
    to_hash_args = argparse.Namespace(
        dataset_name=args.dataset_name,
        input_dir=args.input_dir,
        output_dir=args.input_dir,
        job_id=0,
        jsonl_key=args.jsonl_key,
        format=args.format,
        threshold=0,
        window_size=6,
        batch_size=100,
        docs_per_core=1000,
        n_jobs=1,
    )
    to_hash.generate_hashes(to_hash_args)
    logging.info("Finished generating MinHash!")

    # Generate duplicate pairs.
    duplicates_dir = os.path.join(args.input_dir, "duplicates")
    os.makedirs(duplicates_dir, exist_ok=True)
    generate_duplicate_pairs_args = argparse.Namespace(
        input_dir=to_hash_args.output_dir,
        out_file=os.path.join(duplicates_dir, "duplicate_pairs.txt"),
        jaccard_threshold=0.8,
        processes=cpu_count(),
    )
    generate_duplicate_pairs.generate_pairs(generate_duplicate_pairs_args)
    logging.info("Finished generating duplicate pairs!")

    # Build the graph of connected components from the duplicate pairs.
    generate_connected_components_args = argparse.Namespace(
        input_dir=duplicates_dir,
        out_file=os.path.join(duplicates_dir, "connected_components.pickle"),
    )
    generate_connected_components.generate_connected_components_mp(
        generate_connected_components_args
    )
    logging.info("Finished generating graph of connected components!")

    # Deduplicate dataset.
    os.makedirs(args.output_dir, exist_ok=True)
    deduplicate_dataset_args = argparse.Namespace(
        input_file=generate_connected_components_args.out_file,
        input_dir=args.input_dir,
        output_dir=args.output_dir,
        format=args.format,
        jsonl_key=args.jsonl_key,
    )
    deduplicate_dataset.generate_duplicates(deduplicate_dataset_args)
    logging.info("Finished deduplicating dataset!")

    if args.clean_up:
        logging.info("Cleaning up...")
        # Remove artifacts in-process instead of through a shell string.
        # An unquoted path (e.g. one containing spaces) broke the old
        # `cd ... && rm ...` command, and shell metacharacters in the
        # path would have been executed.
        for rel_path, remove in (
            ("minhash_nfc", shutil.rmtree),
            ("duplicates", shutil.rmtree),
            ("duplicate_pairs.pickle", os.remove),
        ):
            path = os.path.join(args.input_dir, rel_path)
            try:
                remove(path)
            except FileNotFoundError:
                # Nothing to clean up; matches the old best-effort behavior.
                continue
            except OSError as err:
                logging.warning("Could not remove %s: %s", path, err)
def str_to_bool(value):
    """Convert a command-line token such as ``True``/``false``/``1``/``no`` to a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string, including
    ``"False"``, as ``True``. This converter interprets the text instead.

    Args:
        value: The raw argument string, or an existing ``bool``, which is
            returned unchanged.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean
            spelling. ``argparse`` turns this into a usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean value (true/false), got {value!r}"
    )


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset_name", type=str, help="Name of the dataset", required=True
    )
    parser.add_argument(
        "--input_dir",
        type=str,
        help="Input directory which contains documents",
        required=True,
    )
    parser.add_argument(
        "--jsonl_key", type=str, help="JSONL key for the dataset", required=True
    )
    parser.add_argument(
        "--format", type=str, help="Format of the dataset", required=True
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        help="Output directory for deduplicated dataset",
        required=True,
    )
    # Accepts `--clean_up`, `--clean_up True` and `--clean_up False`.
    # With the previous `type=bool`, "False" parsed as True.
    parser.add_argument(
        "--clean_up",
        type=str_to_bool,
        nargs="?",
        const=True,
        help="Delete intermediate, created files",
        default=False,
        required=False,
    )
    args = parser.parse_args()
    deduplication_pipeline(args)