Source code for cerebras.modelzoo.data_preparation.raw_dataset_processor.utils

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gzip
import io
import json
import numbers
import tarfile
import types
from typing import Any, Callable, Dict, Iterator, List, Optional

import jsonlines
import pyarrow.parquet as pq
import zstandard


class Reader:
    """Stream records out of a list of raw dataset files.

    Supported formats are selected by file extension in ``stream_data``:
    ``.jsonl``, ``.jsonl.zst``, ``.jsonl.zst.tar``, ``.txt``, ``.json.gz``,
    ``.parquet`` and ``.fasta``.
    """

    def __init__(
        self,
        file_list: List[str],
        keys: Optional[Dict],
        format_hook_fn: Callable,
    ) -> None:
        """
        Initialize the Reader instance.

        Args:
            file_list (List[str]): List of file paths to be read.
            keys (Optional[Dict]): Dictionary mapping the type of a key
                (e.g. ``"text_key"``) to its name in the data.
            format_hook_fn (Callable): Hook stored on the instance as
                ``self.format_hook_fn``; it is not called by this class.
        """
        self.file_list = file_list
        self.keys = keys
        self.format_hook_fn = format_hook_fn

    def handle_jsonl(
        self,
        jsonl_reader: Any,
        get_meta: bool,
        autojoin_paragraphs: bool,
        para_joiner: str,
    ) -> Iterator[Dict[str, Any]]:
        """
        Handle JSONL data and yield processed entries.

        Args:
            jsonl_reader (Any): The JSONL reader object. It must expose
                ``iter(type=dict, skip_invalid=True)``.
            get_meta (bool): Flag to determine if meta data should be extracted.
            autojoin_paragraphs (bool): Accepted for interface compatibility;
                currently unused.
            para_joiner (str): Accepted for interface compatibility;
                currently unused.

        Returns:
            Iterator[Dict[str, Any]]: Yields one entry per object. Each entry
            maps every configured key name to its value (or ``None`` when the
            field is missing or falsy) plus ``"doc_idx"``, and ``"meta"`` when
            requested and present.
        """
        for idx, ob in enumerate(
            jsonl_reader.iter(type=dict, skip_invalid=True)
        ):
            if isinstance(ob, str):
                assert not get_meta
                yield {"text": ob, "doc_idx": idx}
                continue

            entry = {}
            for field_name in self.keys.values():
                if field_name not in ob:
                    entry[field_name] = None
                    continue

                text = ob[field_name]
                # Falsy values (empty string, 0, empty list, ...) are
                # normalized to None.
                if not text:
                    entry[field_name] = None
                    continue

                # Special case: numeric data is typecast to a string.
                if isinstance(text, numbers.Number):
                    text = str(text)
                entry[field_name] = text

            if get_meta and "meta" in ob:
                entry["meta"] = ob["meta"]
            entry["doc_idx"] = idx
            yield entry

    def read_txt(self, file: str) -> Iterator[Any]:
        """
        Read and process text file.

        Args:
            file (str): Path to the .txt file.

        Returns:
            Iterator[Any]: Yields a single entry holding the whole file
            content under the configured ``"text_key"`` name.
        """
        with open(file, "r") as fh:
            text = fh.read()
        yield {self.keys["text_key"]: text}

    def read_jsongz(
        self,
        file: str,
    ) -> Iterator[Any]:
        """
        Read and process gzipped JSON file.

        Each line of the decompressed file is parsed as one JSON object.

        Args:
            file (str): Path to the .json.gz file.

        Returns:
            Iterator[Any]: Yields one entry per line, holding the configured
            text field and the line index as ``"doc_idx"``.
        """
        text_key = self.keys["text_key"]
        with gzip.open(file, "rb") as f:
            data_gen = (
                {
                    text_key: json.loads(line.decode("utf-8").strip())[
                        text_key
                    ],
                    "doc_idx": idx,
                }
                for idx, line in enumerate(f)
            )
            # Yield the entries themselves (not the generator object) so the
            # output matches every other reader, and keep the file open while
            # they are being consumed.
            yield from data_gen

    def read_jsonl(
        self,
        file: str,
        get_meta: bool = False,
        autojoin_paragraphs: bool = True,
        para_joiner: str = "\n\n",
    ) -> Iterator[Any]:
        """
        Read and process JSONL file.

        Args:
            file (str): Path to the .jsonl file.
            get_meta (bool): Flag to determine if meta data should be extracted.
            autojoin_paragraphs (bool): Forwarded to ``handle_jsonl``.
            para_joiner (str): Forwarded to ``handle_jsonl``.

        Returns:
            Iterator[Any]: Yields processed data entries.
        """
        with open(file, "r") as fh:
            rdr = jsonlines.Reader(fh)
            data_gen = self.handle_jsonl(
                rdr, get_meta, autojoin_paragraphs, para_joiner
            )
            assert isinstance(data_gen, types.GeneratorType)
            yield from data_gen

    def read_jsonl_zst(
        self,
        file: str,
        get_meta: bool = False,
        autojoin_paragraphs: bool = True,
        para_joiner: str = "\n\n",
    ) -> Iterator[Any]:
        """
        Read and process ZST compressed JSONL file.

        Args:
            file (str): Path to the .jsonl.zst file.
            get_meta (bool): Flag to determine if meta data should be extracted.
            autojoin_paragraphs (bool): Forwarded to ``handle_jsonl``.
            para_joiner (str): Forwarded to ``handle_jsonl``.

        Returns:
            Iterator[Any]: Yields processed data entries.
        """
        with open(file, "rb") as fh:
            cctx = zstandard.ZstdDecompressor()
            reader = io.BufferedReader(cctx.stream_reader(fh))
            rdr = jsonlines.Reader(reader)
            data_gen = self.handle_jsonl(
                rdr, get_meta, autojoin_paragraphs, para_joiner
            )
            assert isinstance(data_gen, types.GeneratorType)
            yield from data_gen

    def read_jsonl_tar(
        self,
        file: str,
        get_meta: bool = False,
        autojoin_paragraphs: bool = True,
        para_joiner: str = "\n\n",
    ) -> Iterator[Any]:
        """
        Read and process TAR archive containing ZST compressed JSONL files.

        Args:
            file (str): Path to the .jsonl.zst.tar file.
            get_meta (bool): Flag to determine if meta data should be extracted.
            autojoin_paragraphs (bool): Forwarded to ``handle_jsonl``.
            para_joiner (str): Forwarded to ``handle_jsonl``.

        Returns:
            Iterator[Any]: Yields processed data entries from every file
            member of the archive, in archive order.
        """
        with tarfile.open(file, "r") as archive:
            for member in archive:
                extracted = archive.extractfile(member)
                # extractfile() returns None for members that are not files
                # or links (e.g. directories); there is nothing to read.
                if extracted is None:
                    continue
                with extracted as f:
                    cctx = zstandard.ZstdDecompressor()
                    reader = io.BufferedReader(cctx.stream_reader(f))
                    rdr = jsonlines.Reader(reader)
                    data_gen = self.handle_jsonl(
                        rdr,
                        get_meta,
                        autojoin_paragraphs,
                        para_joiner,
                    )
                    assert isinstance(data_gen, types.GeneratorType)
                    yield from data_gen

    def read_parquet(self, file: str) -> Iterator[Any]:
        """
        Read and process Parquet file.

        The file is read one row group at a time. Only the columns named by
        the non-``None`` values of ``self.keys`` are extracted.

        Args:
            file (str): Path to the .parquet file.

        Returns:
            Iterator[Any]: Yields one dict per row, mapping column name to
            value. Numeric values are converted to strings.
        """
        parquet_file = pq.ParquetFile(file)
        column_names = [
            name for name in self.keys.values() if name is not None
        ]

        for row_group_index in range(parquet_file.num_row_groups):
            table = parquet_file.read_row_group(row_group_index)
            columns = {name: table.column(name) for name in column_names}
            for i in range(table.num_rows):
                entry = {}
                for name, col in columns.items():
                    value = col[i].as_py()
                    if isinstance(value, numbers.Number):
                        value = str(value)
                    entry[name] = value
                yield entry

    def read_fasta(self, file: str) -> Iterator[Dict[str, Any]]:
        """
        Read and process Fasta file without using BioPython.

        Args:
            file (str): Path to the .fasta file.

        Returns:
            Iterator[Dict[str, Any]]: Yields one ``{"text": sequence}`` entry
            per record, where the sequence is the concatenation of the
            record's non-empty lines. Header lines are not included in the
            output.
        """
        with open(file, 'r') as fasta_file:
            record_id = None
            sequence_lines = []

            for line in fasta_file:
                line = line.strip()
                if not line:
                    continue  # Skip empty lines

                if not line.startswith(">"):
                    sequence_lines.append(line)
                    continue

                # A new header closes the previous record, if there was one.
                if record_id is not None:
                    yield {"text": ''.join(sequence_lines)}
                record_id = line[1:]  # Header without the ">" symbol
                sequence_lines = []  # Reset the sequence for a new record

            # Don't forget to yield the last record in the file
            if record_id is not None:
                yield {"text": ''.join(sequence_lines)}

    def stream_data(self, get_meta: bool = False) -> Iterator[Any]:
        """
        Stream and process data from multiple file formats.

        Files are processed in ``self.file_list`` order and dispatched on
        their extension. Files with an unsupported extension are skipped
        with a warning.

        Args:
            get_meta (bool): Flag to determine if meta data should be
                extracted. Only the JSONL based formats support it; the
                other formats assert that it is ``False``.

        Returns:
            Iterator[Any]: Yields processed data chunks.
        """
        # Imported locally because the module does not define a logger.
        import logging

        for f in self.file_list:
            if f.endswith(".jsonl"):
                yield from self.read_jsonl(f, get_meta)
            elif f.endswith(".jsonl.zst"):
                yield from self.read_jsonl_zst(f, get_meta)
            elif f.endswith(".jsonl.zst.tar"):
                yield from self.read_jsonl_tar(f, get_meta)
            elif f.endswith(".txt"):
                assert not get_meta
                yield from self.read_txt(f)
            elif f.endswith(".json.gz"):
                assert not get_meta
                yield from self.read_jsongz(f)
            elif f.endswith(".parquet"):
                assert not get_meta
                yield from self.read_parquet(f)
            elif f.endswith(".fasta"):
                assert not get_meta
                yield from self.read_fasta(f)
            else:
                logging.getLogger(__name__).warning(
                    f"Skipping {f} as streaming for that filetype is not implemented"
                )