Train an LLM with a large or small context window#

Overview#

One of the advantages of the CS-X is the ability to train a Large Language Model (LLM) with a large context window (also known as the sequence length). A larger context window enables an LLM to handle longer inputs. Conversely, it is sometimes advantageous to train with a shorter context window (e.g., during instruction fine-tuning) for efficiency.

Follow the steps in this guide to train a language model with a large (> 2048 tokens) or small (< 2048 tokens) context window.

Procedure#

For large context window#

1. Data processing: when creating your dataset with the create_hdf5_dataset.py script, set the --max_seq_length argument to the desired sequence length.

For example, for a sequence length of 4096 tokens:

python create_hdf5_dataset.py LMData \
     --input_dir /path/to/data \
     --tokenizer_type NeoXTokenizer \
     --encoder_file /path/to/encoder \
     --max_seq_length 4096 \
     --ftfy True \
     --pack_sequences False
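
After processing, you can sanity-check that the generated HDF5 shards use the intended sequence length. The snippet below is a minimal sketch, not part of the script: it assumes a hypothetical output directory and that each shard stores its samples in an HDF5 dataset named "data" whose last dimension equals --max_seq_length; adjust the path and the key to match your actual output.

# Sanity check: confirm the processed HDF5 shards use the intended sequence length.
# Assumes each shard stores samples under a dataset named "data" whose last
# dimension is the sequence length; adjust the key if your output differs.
import glob

import h5py

expected_msl = 4096

for path in glob.glob("/path/to/processed/output/*.h5"):
    with h5py.File(path, "r") as f:
        seq_len = f["data"].shape[-1]
        assert seq_len == expected_msl, f"{path}: found sequence length {seq_len}"
print("All shards match the expected max_seq_length.")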

2. Model training: after processing your data, set both max_sequence_length and max_position_embeddings to the new sequence length in the model’s configuration YAML file.

For example:

train_input:
    data_processor: "GptHDF5DataProcessor"
    data_dir:
         - ./path/to/train/dataset/with_msl4096/
    ...
    max_sequence_length: 4096
    ...

eval_input:
    data_processor: "GptHDF5DataProcessor"
    data_dir: "./path/to/eval/datast/with_msl4096"
    max_sequence_length: 4096
    ...

model:
    ...
    max_position_embeddings: 4096
    ...

For small context window#

For example, a model pretrained with a sequence length of 2048 tokens may be instruction fine-tuned on a dataset whose longest sequence is 256 tokens. Because the sequences in this dataset are padded rather than packed, training with a sequence length of 256 is more efficient than padding every sample all the way to 2048 tokens.
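
To make the efficiency argument concrete, the back-of-the-envelope calculation below (a sketch using an assumed average sample length, not a measured figure) compares how many positions per sample carry real tokens when padding to 2048 versus 256:

# Back-of-the-envelope comparison of padding overhead (illustrative numbers only).
# Assumes an average instruction-tuning sample of ~180 real tokens.
avg_real_tokens = 180

for msl in (2048, 256):
    padding = msl - avg_real_tokens
    useful_fraction = avg_real_tokens / msl
    print(f"max_seq_length={msl}: {padding} padding tokens per sample, "
          f"{useful_fraction:.0%} of positions carry real tokens")

With these assumed numbers, fewer than 9% of the positions in a 2048-token sample carry real data, versus roughly 70% at 256 tokens.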

1. Data processing: when calling the create_hdf5_dataset.py script, set the --max_seq_length argument to the desired sequence length.

python create_hdf5_dataset.py Summarization \
     --input_dir /path/to/data \
     --tokenizer_type NeoXTokenizer \
     --encoder_file /path/to/encoder \
     --max_seq_length 256 \
     --ftfy True \
     --sep_token null \
     --prompt_key "prompt" \
     --completion_key "completion" \
     --pack_sequences False
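
To pick a safe value for --max_seq_length, you can first measure the longest tokenized prompt/completion pair in your raw data. The sketch below is an assumption-laden example: it supposes the raw data is a JSONL file with "prompt" and "completion" fields and that the encoder file is a Hugging Face tokenizers JSON file; adapt the paths and field names to your setup, and leave some headroom for any separator or end-of-text tokens added during preprocessing.

# Find the longest tokenized prompt+completion pair to pick a safe max_seq_length.
# Assumes a JSONL file with "prompt" and "completion" fields and a Hugging Face
# tokenizers JSON file for the NeoX encoder; adjust paths and keys as needed.
# Note: this count does not include separator/EOS tokens added during preprocessing.
import json

from tokenizers import Tokenizer

tokenizer = Tokenizer.from_file("/path/to/encoder")

longest = 0
with open("/path/to/data/train.jsonl") as f:
    for line in f:
        sample = json.loads(line)
        text = sample["prompt"] + sample["completion"]
        longest = max(longest, len(tokenizer.encode(text).ids))

print(f"Longest sample is {longest} tokens; choose --max_seq_length >= {longest}")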

2. Model training: after processing your data, set max_sequence_length to the desired value in the model’s configuration YAML file.

For example:

train_input:
    data_processor: "GptHDF5DataProcessor"
    data_dir:
         - ./path/to/train/dataset/with_msl256/
    ...
    max_sequence_length: 256
    ...

eval_input:
    data_processor: "GptHDF5DataProcessor"
    data_dir: "./path/to/eval/datast/with_msl4096"
    max_sequence_length: 256
    ...

model:
    ...
    max_position_embeddings: 2048
    ...

Implementation notes#

Note that when training a pretrained model with a smaller context window, the max_position_embeddings parameter of the pretrained model remains unchanged; only the max_sequence_length parameter needs to be updated.

In contrast, when training with a larger context window, both the max_position_embeddings and max_sequence_length parameters need to be increased.
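
If you want to guard against a mismatch between these two parameters, a small check over the configuration file can enforce the rule before launching a run. This is a minimal sketch, assuming a hypothetical params.yaml laid out like the examples above:

# Minimal config check: max_sequence_length must not exceed max_position_embeddings.
# Assumes a params.yaml structured like the examples above (hypothetical path).
import yaml

with open("params.yaml") as f:
    params = yaml.safe_load(f)

msl = params["train_input"]["max_sequence_length"]
mpe = params["model"]["max_position_embeddings"]

assert msl <= mpe, (
    f"max_sequence_length ({msl}) exceeds max_position_embeddings ({mpe}); "
    "increase max_position_embeddings for large-context training."
)
print(f"OK: max_sequence_length={msl}, max_position_embeddings={mpe}")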

Conclusion#

Training a Large Language Model (LLM) with an appropriately sized context window balances capability and efficiency. A larger context window lets the model learn from and generate longer inputs, while a smaller one avoids wasting computation on padding during instruction fine-tuning. In both cases the workflow is the same: process the dataset with the desired --max_seq_length, then set max_sequence_length (and, for larger windows, max_position_embeddings) in the model’s configuration YAML file.