Model#

This page covers how to pass a model into the Trainer. The model is the main Module on which all training and validation is run, and every Trainer instance requires one.

Prerequisites#

Read Trainer Overview and Trainer Configuration Overview first. They give a basic overview of how to run Model Zoo models, and this page uses the tools and configurations they describe.

Configure the model#

To set the model that the Trainer trains and validates, use the model argument.

All subkeys under model are passed as arguments to the model class. The model class is determined by the model_fn in your run script.

trainer:
  init:
    ...
    model:
      vocab_size: 1024
      max_position_embeddings: 1024
      ...
    ...
  ...
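The mapping from model subkeys to constructor arguments can be sketched in plain Python. This is only an illustration of the mechanism, not the Model Zoo implementation: ToyModel and build_model are hypothetical names introduced here, and the config is written as a dict with the same nesting as the YAML above.

```python
from dataclasses import dataclass


@dataclass
class ToyModel:
    """Hypothetical stand-in for a model class; not a Model Zoo class."""

    vocab_size: int
    max_position_embeddings: int


def build_model(model_class, config):
    """Instantiate model_class from the `model` subkeys of a Trainer-style config.

    Hypothetical helper: every subkey under trainer.init.model is forwarded
    as a keyword argument to the model class.
    """
    model_kwargs = config["trainer"]["init"]["model"]
    return model_class(**model_kwargs)


# Same nesting as the YAML configuration, expressed as a plain dict.
config = {
    "trainer": {
        "init": {
            "model": {
                "vocab_size": 1024,
                "max_position_embeddings": 1024,
            },
        },
    },
}

model = build_model(ToyModel, config)
print(model)
```

In this sketch, a subkey that the model class does not accept raises a TypeError at construction time, so a misspelled key surfaces immediately rather than being silently ignored.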

Note

If you pass the model as a Module directly, it is best to first initialize the model inside the Cerebras device context.

For example:

import cerebras.pytorch as cstorch
from cerebras.modelzoo import Trainer
from cerebras.modelzoo.models.nlp.gpt2.model import Gpt2Model

# Initialize the Cerebras backend for efficient processing.
backend = cstorch.backend("CSX")

# Use the backend's device context manager for initializing the model.
with backend.device:
    model = Gpt2Model(
        vocab_size=1024,
        max_position_embeddings=1024,
        ...,
    )

# Pass the backend and the initialized model to the Trainer.
trainer = Trainer(
    ...,
    backend=backend,
    model=model,
    ...,
)
...

This ensures that the model parameters are automatically moved to the Cerebras device, which reduces memory usage and speeds up initialization. For more information, see Efficient weight initialization.

Conclusion#

That covers how to specify the model to train and validate with the Trainer. You should now understand how to configure the model and how the Trainer accepts it.
