Source code for cerebras.modelzoo.data_preparation.nlp.chunk_data_processing.fim_data_token_generator

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
FIMTokenGenerator Module

This module provides the FIMTokenGenerator class, an extension of the
LMDataTokenGenerator class tailored for fill-in-the-middle (FIM) tasks.

Usage:
    from cerebras.modelzoo.data_preparation.nlp.chunk_data_processing.fim_data_token_generator import (
        FIMTokenGenerator,
    )

    # Initialize the token generator with the required parameters
    token_generator = FIMTokenGenerator(params, tokenizer_impl, eos_id, pad_id)

    # Tokenize and encode text data
    tokenized_data, stats = token_generator.encode("Your sample text to process.")
"""

import logging
from typing import Dict, List, Tuple

import ftfy
import numpy as np

from cerebras.modelzoo.data_preparation.nlp.chunk_data_processing.lm_data_token_generator import (
    LMDataTokenGenerator,
)
from cerebras.modelzoo.data_preparation.nlp.hdf5_preprocessing.utils import (
    check_fim_special_tokens,
    fim,
    handle_bos_token_default,
    wikitext_detokenizer,
)

logger = logging.getLogger(__file__)
logger.setLevel(logging.INFO)


class FIMTokenGenerator(LMDataTokenGenerator):
    def __init__(self, params, tokenizer, eos_id, pad_id):
        """
        Initialize the FIMTokenGenerator class.

        Args:
            params (dict): Parameters from the config file.
            tokenizer: Tokenizer instance used to encode the data.
            eos_id (int): Token id of the end-of-sequence token.
            pad_id (int): Token id of the padding token.
        """
        super(FIMTokenGenerator, self).__init__(
            params, tokenizer, eos_id, pad_id
        )
        processing_params = params["processing"]
        self.fim_rate = processing_params.pop("fim_rate", None)
        self.spm_rate = processing_params.pop("spm_rate", None)

        # Ensure that the FIM special tokens are specified in the config and
        # that the specified tokens are actually in the tokenizer.
        check_fim_special_tokens(params, self.tokenizer)

        # Some tokenizers use a BOS id at the beginning and others do not.
        # Here we get a flag for whether to use BOS by default, and the BOS id
        # if needed.
        self.default_bos_token, self.opt_bos_tok_id = handle_bos_token_default(
            self.tokenizer
        )

        self.suffix_tok_id = self.tokenizer.encode(
            params['processing'].get("fim_suffix_tok")
        )[-1]
        self.prefix_tok_id = self.tokenizer.encode(
            params['processing'].get("fim_prefix_tok")
        )[-1]
        self.middle_tok_id = self.tokenizer.encode(
            params['processing'].get("fim_middle_tok")
        )[-1]
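
    # Illustrative sketch of the "processing" config section consumed by
    # __init__ above. The sentinel token strings and rate values below are
    # assumptions, not repository defaults; the token strings must correspond
    # to tokens already present in the tokenizer's vocabulary. fim_rate is
    # typically the probability that a sample is FIM-transformed, and spm_rate
    # the probability of SPM (vs. PSM) ordering among transformed samples.
    #
    #     "processing": {
    #         "fim_rate": 0.9,
    #         "spm_rate": 0.5,
    #         "fim_prefix_tok": "<fim_prefix>",
    #         "fim_middle_tok": "<fim_middle>",
    #         "fim_suffix_tok": "<fim_suffix>",
    #         # ... other processing parameters used by LMDataTokenGenerator ...
    #     }
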
    def encode(self, data: str) -> Tuple[List[np.ndarray], Dict]:
        """
        Tokenize and encode the data for auto-regressive language modeling.

        Args:
            data (str): Text to tokenize.

        Returns:
            Tuple[List[np.ndarray], Dict]: Tuple of encoded features for
            auto-regressive language modeling and dataset stats.
        """
        raw_chars_count = len(data)
        raw_bytes_count = len(data.encode("utf-8"))
        files_processed = 1
        discarded_files = 0

        normalized_chars_count = raw_chars_count
        normalized_bytes_count = raw_bytes_count
        if self.use_ftfy:
            data = ftfy.fix_text(data, normalization=self.ftfy_normalizer)
            normalized_chars_count = len(data)
            normalized_bytes_count = len(data.encode("utf-8"))
        if self.wikitext_detokenize:
            data = wikitext_detokenizer(data)
            normalized_chars_count = len(data)
            normalized_bytes_count = len(data.encode("utf-8"))

        tokenized_text, stats = self.tokenize_text_auto_lm(data)
        result = []

        # Reset the stats for pad tokens and masked tokens and recompute for FIM
        num_masked_tokens = 0
        num_pad_tokens = 0
        loss_valid_tokens = 0
        num_tokens = 0
        discarded_files = stats["discarded_files"] + stats["empty_chunks"]
        if tokenized_text == []:
            discarded_files += 1

        for i, sample in enumerate(tokenized_text):
            if sample != []:
                sample = fim(
                    sample,
                    i,
                    self.tokenizer,
                    self.fim_rate,
                    self.spm_rate,
                    self.suffix_tok_id,
                    self.prefix_tok_id,
                    self.middle_tok_id,
                    self.pad_id,
                    self.eos_id,
                    self.opt_bos_tok_id,
                )
                num_pad_tokens += (sample[0, :] == self.pad_id).sum()
                num_masked_tokens += (sample[1, :] == 0).sum()
                loss_valid_tokens += sample[1, :].sum()
                num_tokens += sample.shape[1]
            else:
                discarded_files += 1
            result.append(sample)

        data_stats = {
            "discarded": discarded_files,
            "processed": files_processed,
            "successful": files_processed - discarded_files,
            "raw_chars_count": raw_chars_count,
            "raw_bytes_count": raw_bytes_count,
            "num_pad_tokens": int(num_pad_tokens),
            "num_masked_tokens": int(num_masked_tokens),
            "loss_valid_tokens": int(loss_valid_tokens),
            "num_tokens": int(num_tokens),
            "normalized_chars_count": normalized_chars_count,
            "normalized_bytes_count": normalized_bytes_count,
        }
        return result, data_stats
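
# Usage sketch (assumes a params dict with a "processing" section like the
# one sketched above, and a tokenizer whose vocabulary contains the FIM
# sentinel tokens):
#
#     token_generator = FIMTokenGenerator(params, tokenizer, eos_id, pad_id)
#     samples, stats = token_generator.encode(
#         "def add(a, b):\n    return a + b\n"
#     )
#
# "samples" is a list of 2-D numpy feature arrays; as used in the statistics
# computed in encode, row 0 holds the token ids (compared against pad_id) and
# row 1 the loss mask. "stats" aggregates token, padding, and discard counts
# for the processed chunk.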