Source code for cerebras.modelzoo.data_preparation.nlp.chunk_data_processing.summarization_data_token_generator

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
SummarizationTokenGenerator Module

This module provides the SummarizationTokenGenerator class which is designed to tokenize
prompt/completion data and create features suitable for summarization tasks. The class
utilizes the BPETokenizer from the modelzoo.transformers.data_processing.tokenizers
package for tokenization.

Usage:
    token_generator = SummarizationTokenGenerator(params, tokenizer, eos_id, pad_id)
    tokenized_features, data_stats = token_generator.encode(
        [("prompt_text", "completion_text")]
    )
"""

import logging
from typing import Dict, List, Tuple

import ftfy
import numpy as np

from cerebras.modelzoo.data_preparation.nlp.hdf5_preprocessing.utils import (
    wikitext_detokenizer,
)

logger = logging.getLogger("summarization_token_generator")
logger.setLevel(logging.INFO)


def create_features_summarization(
    prompt_ids,
    completion_ids,
    max_sequence_length,
    eos_id=0,
    sep_id=None,
    pad_id=0,
    completion_prefix_mask_len=0,
    min_len=10,
    inverted_mask=False,
    input_ids_dtype="int32",
    input_mask_dtype="int32",
    labels_dtype="int32",
):
    """
    Given a list of prompt_ids and completion_ids, generate the input sequence
    and labels.

    Args:
        prompt_ids (sequence): List containing token ids for the prompt, used
            to create features, labels and input mask.
        completion_ids (sequence): List containing token ids for the
            completion, used to create features, labels and input mask.
        max_sequence_length (int): Maximum sequence length for data writes.
        eos_id (int): Id for end of sequence token. Defaults to `0`.
        sep_id (int): Id for separator token. Defaults to `None`.
        pad_id (int): Id for pad token. Defaults to `0`.
        completion_prefix_mask_len (int): Number of tokens taken up by the
            completion prefix. Defaults to `0`.
        min_len (int): Minimum length of token_ids to be considered a valid
            sequence. Defaults to `10`.
        inverted_mask (bool): Invert mask if specified for runtime execution.
            Defaults to `False`.
        input_ids_dtype (str): Dtype as string for input ids.
            Defaults to `int32`.
        input_mask_dtype (str): Dtype as string for input mask.
            Defaults to `int32`.
        labels_dtype (str): Dtype as string for labels. Defaults to `int32`.
    """
    # This function is a modified copy of the function with the same name in
    # the hdf5_preprocessing utils. The difference here is that, in addition
    # to the original data, we also return the number of pad tokens, masked
    # tokens, loss-valid tokens and the total number of tokens.

    # extra <EOS>
    total_len = len(prompt_ids) + len(completion_ids) + 1
    if sep_id is not None:
        total_len += 1
    if total_len > max_sequence_length:
        logger.warning(
            "prompt_ids + completion_ids > max_sequence_length, skipping this example..."
        )
        return []
    if total_len < min_len:
        logger.warning(
            "prompt_ids + completion_ids < min_sequence_len, skipping this example..."
        )
        return []

    token_ids = []
    if sep_id is not None:
        prompt_ids = prompt_ids + [sep_id]
    completion_ids += [eos_id]
    token_ids = prompt_ids + completion_ids

    token_mask = [0] * (len(prompt_ids) - 1)
    ## mask the completion prefix tokens
    token_mask += [0] * completion_prefix_mask_len
    # start prediction on the last prompt token (including if it's sep or eos)
    # or the last completion prefix token
    token_mask += [1]
    token_mask += [1] * (len(completion_ids) - completion_prefix_mask_len - 1)
    token_mask += [0]  # EOS

    # add padding
    token_ids_pad = max_sequence_length + 1 - len(token_ids)
    input_mask_pad = max_sequence_length - len(token_mask)

    token_ids.extend([pad_id] * token_ids_pad)
    token_mask.extend([0] * input_mask_pad)

    input_ids = token_ids[:-1]
    labels = token_ids[1:]

    assert (
        len(input_ids) == max_sequence_length
        and len(labels) == max_sequence_length
        and len(token_mask) == max_sequence_length
    ), "Wrong sequence length"

    features = dict()
    features["input_ids"] = getattr(np, input_ids_dtype)(input_ids)
    features["input_mask"] = getattr(np, input_mask_dtype)(token_mask)
    # cast labels as well so that `labels_dtype` is honored in the stacked output
    labels = getattr(np, labels_dtype)(labels)

    if inverted_mask:
        features["input_mask"] = np.equal(features["input_mask"], 0).astype(
            features["input_mask"].dtype
        )
    return np.stack([features["input_ids"], features["input_mask"], labels])
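
# Illustrative sketch (not part of the original source): with no separator
# token and no completion prefix, the returned array stacks three rows of
# length `max_sequence_length`:
#   row 0 -> input_ids : prompt + completion + [eos_id] + padding, with the
#            final position dropped
#   row 1 -> input_mask: 0 over the prompt except its final position, 1 from
#            the final prompt position through the final completion position
#            (the positions whose next-token labels are the completion tokens
#            and EOS), and 0 over the EOS position and padding
#   row 2 -> labels    : the same token sequence shifted left by one position
#
#   features = create_features_summarization(
#       prompt_ids=[11, 12, 13, 14],
#       completion_ids=[21, 22, 23],
#       max_sequence_length=16,
#       eos_id=0,
#       pad_id=0,
#       min_len=2,
#   )
#   # features.shape == (3, 16)

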
class SummarizationTokenGenerator:
    def __init__(self, params, tokenizer, eos_id, pad_id):
        """
        Initialize the SummarizationTokenGenerator class.

        Args:
            params (dict): Parameters dictionary containing the `dataset` and
                `processing` sub-dictionaries.
            tokenizer: Tokenizer used to encode prompt and completion text.
            eos_id (int): Id for end of sequence token.
            pad_id (int): Id for pad token.
        """
        dataset_params = params["dataset"]
        processing_params = params["processing"]
        self.files_per_record = processing_params.pop(
            "files_per_record", 50000
        )  ## redundant param
        self.tokenizer = tokenizer
        self.use_ftfy = dataset_params.pop("use_ftfy", False)
        self.ftfy_normalizer = dataset_params.pop("ftfy_normalizer", "NFC")
        self.wikitext_detokenize = dataset_params.pop(
            "wikitext_detokenize", False
        )
        self.input_ids_dtype = dataset_params.pop("input_ids_dtype", "int32")
        self.input_mask_dtype = dataset_params.pop("input_mask_dtype", "int32")
        self.sep_token = dataset_params.pop("sep_token", None)
        self.sep_id = None
        if self.sep_token:
            self.sep_id = self.tokenizer.get_token_id(self.sep_token)

        self.default_add_bos = False
        if (
            hasattr(self.tokenizer, "add_bos_token")
            and self.tokenizer.add_bos_token
        ):
            self.default_add_bos = True

        self.inverted_mask = dataset_params.pop("inverted_mask", False)
        self.min_sequence_len = dataset_params.pop("min_sequence_len", 10)
        self.max_seq_length = processing_params.pop("max_seq_length", 2048)
        self.eos_id = eos_id
        self.pad_id = pad_id
        self.prompt_prefix = dataset_params.pop("prompt_prefix", None)
        self.completion_prefix = dataset_params.pop("completion_prefix", None)
        self.tokenize_prefix_tags()

    def tokenize_prefix_tags(self):
        if self.prompt_prefix is not None:
            self.prompt_prefix_toks = self.tokenize_text(self.prompt_prefix)
            self.prompt_prefix_toks_len = len(self.prompt_prefix_toks)
        else:
            self.prompt_prefix_toks_len = 0

        if self.completion_prefix is not None:
            self.comp_prefix_toks = self.tokenize_text(self.completion_prefix)
            self.comp_prefix_toks_len = len(self.comp_prefix_toks)
            if self.default_add_bos:
                self.comp_prefix_toks_len -= 1  ## Remove the default bos token
        else:
            self.comp_prefix_toks_len = 0

    def prepend_prefix(
        self,
        prompt_ids: List[int],
        completion_ids: List[int],
        index: int,
    ) -> Tuple[List[int], List[int]]:
        """
        Prepends prefixes to prompt ids and completion ids and manages the
        beginning of sequence (BOS) tokens.

        Args:
            prompt_ids: A list of integer IDs representing the prompt.
            completion_ids: A list of integer IDs representing the completion.
            index: The index indicating the position of the sequence being
                processed.

        Returns:
            A tuple of two lists: the updated prompt_ids and completion_ids.
        """
        if self.default_add_bos:
            prompt_ids_suffix = (
                prompt_ids[1:]
                if index > 0 or self.prompt_prefix
                else prompt_ids
            )
            completion_ids_suffix = completion_ids[1:]
        else:
            prompt_ids_suffix = prompt_ids
            completion_ids_suffix = completion_ids

        if self.prompt_prefix is not None:
            prompt_prefix_tokens = (
                self.prompt_prefix_toks[1:]
                if self.default_add_bos and index > 0
                else self.prompt_prefix_toks
            )
            prompt_ids = prompt_prefix_tokens + prompt_ids_suffix
        else:
            prompt_ids = prompt_ids_suffix

        if self.completion_prefix is not None:
            completion_prefix_tokens = (
                self.comp_prefix_toks[1:]
                if self.default_add_bos
                else self.comp_prefix_toks
            )
            completion_ids = completion_prefix_tokens + completion_ids_suffix
        else:
            completion_ids = completion_ids_suffix

        return prompt_ids, completion_ids

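    # Descriptive note (not in the original source): when the tokenizer adds a
    # BOS token by default, `prepend_prefix` keeps that BOS only at the very
    # start of the first sequence (index 0), either on the prompt prefix or on
    # the prompt itself. Completions, completion prefixes, and every sequence
    # after the first have their duplicate BOS stripped before the prefixes
    # are prepended, so a BOS token never appears mid-sequence.
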
    def tokenize_text(self, text: str) -> List[int]:
        """
        Tokenize the provided text.

        Args:
            text (str): Text to tokenize.

        Returns:
            List[int]: List of token IDs.
        """
        return self.tokenizer.encode(text)

    def check_valid_doc(self, prompt, completion):
        doc_field = None
        if prompt == "" and completion == "":
            doc_field = "both prompt and completion"
        elif prompt == "":
            doc_field = "prompt"
        elif completion == "":
            doc_field = "completion"

        if doc_field:
            logger.warning(f"{doc_field} is empty. Skipping this doc...")
            return False
        else:
            return True

    def encode(self, doc: List[Tuple]) -> Tuple[List[np.ndarray], Dict]:
        """
        Tokenize and encode the doc for text summarization.

        Args:
            doc (List[Tuple]): List of (prompt, completion) pairs to encode.

        Returns:
            Tuple[List[np.ndarray], Dict]: Tuple of encoded features for text
            summarization and dataset stats.
        """
        raw_chars_count = 0
        raw_bytes_count = 0
        files_processed = 0
        discarded_files = 0
        normalized_chars_count = 0
        normalized_bytes_count = 0
        num_pad_tokens = 0
        non_pad_tokens = 0
        num_masked_tokens = 0
        loss_valid_tokens = 0
        num_tokens = 0

        sample_list = []
        for i, (prompt, completion) in enumerate(doc):
            if not self.check_valid_doc(prompt, completion):
                continue
            raw_chars_count += len(prompt) + len(completion)
            raw_bytes_count += len(prompt.encode("utf-8")) + len(
                completion.encode("utf-8")
            )
            files_processed += 1

            if self.use_ftfy:
                prompt = ftfy.fix_text(
                    prompt, normalization=self.ftfy_normalizer
                )
                completion = ftfy.fix_text(
                    completion, normalization=self.ftfy_normalizer
                )
            if self.wikitext_detokenize:
                prompt = wikitext_detokenizer(prompt)
                completion = wikitext_detokenizer(completion)

            normalized_chars_count += len(prompt) + len(completion)
            normalized_bytes_count += len(prompt.encode("utf-8")) + len(
                completion.encode("utf-8")
            )

            prompt_encoded = self.tokenize_text(prompt)
            completion_encoded = self.tokenize_text(completion)
            prompt_encoded, completion_encoded = self.prepend_prefix(
                prompt_encoded, completion_encoded, i
            )
            sample = create_features_summarization(
                prompt_encoded,
                completion_encoded,
                self.max_seq_length,
                self.eos_id,
                self.sep_id,
                self.pad_id,
                self.comp_prefix_toks_len,
                min_len=self.min_sequence_len,
                inverted_mask=self.inverted_mask,
                input_ids_dtype=self.input_ids_dtype,
                input_mask_dtype=self.input_mask_dtype,
                labels_dtype=self.input_ids_dtype,
            )

            if len(sample) == 0:
                discarded_files += 1
            else:
                num_loss_valid_tokens = int(sample[1, :].sum())
                num_pad_tokens += int((sample[0, :] == self.pad_id).sum())
                non_pad_tokens += int(
                    np.logical_and(
                        sample[0, :] != self.eos_id, sample[0, :] != self.pad_id
                    ).sum()
                )
                num_masked_tokens += (
                    self.max_seq_length - num_loss_valid_tokens
                )
                loss_valid_tokens += num_loss_valid_tokens
                num_tokens += int(sample[0, :].shape[0])
                sample = np.expand_dims(sample, axis=0)
                # accumulate samples along the leading (batch) dimension
                if len(sample_list) > 0:
                    sample_list = np.concatenate(
                        [sample_list, sample], axis=0
                    )
                else:
                    sample_list = sample

        data_stats = {
            "discarded": discarded_files,
            "processed": files_processed,
            "successful": files_processed - discarded_files,
            "raw_chars_count": raw_chars_count,
            "raw_bytes_count": raw_bytes_count,
            "normalized_chars_count": normalized_chars_count,
            "normalized_bytes_count": normalized_bytes_count,
            "num_pad_tokens": num_pad_tokens,
            "non_pad_tokens": non_pad_tokens,
            "num_masked_tokens": num_masked_tokens,
            "loss_valid_tokens": loss_valid_tokens,
            "num_tokens": num_tokens,
        }

        return sample_list, data_stats

    def get_token_id(self, token: str) -> int:
        """
        Get the token ID for the given token.

        Args:
            token (str): Token for which the ID is needed.

        Returns:
            int: Token ID.
        """
        return self.tokenizer.get_token_id(token)
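

# Minimal usage sketch (illustrative only, not part of the original module).
# `_WhitespaceTokenizer` is a hypothetical stand-in that implements just the
# two methods this generator calls on its tokenizer (`encode` and
# `get_token_id`); any real tokenizer exposing the same interface can be
# passed instead.
if __name__ == "__main__":

    class _WhitespaceTokenizer:
        def __init__(self):
            self.vocab = {}

        def encode(self, text):
            # Map each whitespace-separated word to a stable integer id,
            # starting at 2 so ids never collide with eos_id/pad_id below.
            return [
                self.vocab.setdefault(word, len(self.vocab) + 2)
                for word in text.split()
            ]

        def get_token_id(self, token):
            return self.vocab.setdefault(token, len(self.vocab) + 2)

    params = {
        "dataset": {"use_ftfy": False, "min_sequence_len": 2},
        "processing": {"max_seq_length": 32},
    }
    token_generator = SummarizationTokenGenerator(
        params, _WhitespaceTokenizer(), eos_id=0, pad_id=0
    )
    samples, stats = token_generator.encode(
        [("summarize the following article :", "a short summary .")]
    )
    # each encoded sample stacks input_ids, input_mask and labels
    print("encoded samples shape:", np.shape(samples))
    print("data stats:", stats)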