# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
import torch
import torch.nn as nn
import torch.nn.functional as F
import cerebras_pytorch as cstorch
from cerebras_pytorch.metrics import AccuracyMetric
class MNIST(nn.Module):
    """Fully connected classifier producing 10 logits from 784 input features.

    A stack of ``nn.Linear`` hidden layers, each followed by an optional
    non-linearity and then dropout, and a final ``nn.Linear`` to 10 outputs.

    Args:
        model_params (dict): Model configuration. Keys read:

            - ``hidden_sizes`` (list[int], optional): width of each hidden
              layer. When the key is absent or set to ``None``,
              ``[hidden_size] * depth`` is used instead.
            - ``hidden_size`` (int), ``depth`` (int): required only when
              ``hidden_sizes`` is absent or ``None``.
            - ``use_bias`` (bool, default ``True``): bias on every linear layer.
            - ``dropout`` (float): dropout probability applied after every
              hidden layer.
            - ``activation_fn`` (``"relu"`` or ``None``): hidden non-linearity.

    Note:
        ``model_params`` is updated in place so that ``depth`` and
        ``hidden_sizes`` are both present and consistent afterwards.
    """

    def __init__(self, model_params):
        super().__init__()
        use_bias = model_params.get("use_bias", True)

        # Populate whichever of `hidden_sizes` / `depth` was not supplied.
        # An explicit `hidden_sizes: None` is treated like a missing key
        # (previously it crashed on `len(None)`).
        if model_params.get("hidden_sizes") is not None:
            model_params["depth"] = len(model_params["hidden_sizes"])
        else:
            # Same hidden size across all dense layers.
            model_params["hidden_sizes"] = [
                model_params["hidden_size"]
            ] * model_params["depth"]

        input_size = 784  # flattened feature count fed to the first layer
        layers = []
        for hidden_size in model_params["hidden_sizes"]:
            layers.append(nn.Linear(input_size, hidden_size, bias=use_bias))
            input_size = hidden_size
        self.fc_layers = nn.ModuleList(layers)
        # With no hidden layers this maps the 784 inputs straight to 10 logits.
        self.last_layer = nn.Linear(input_size, 10, bias=use_bias)
        self.nonlin = self._get_nonlinear(model_params)
        self.dropout = nn.Dropout(model_params["dropout"])

    def forward(self, inputs):
        """Return unnormalized class logits of shape ``(batch, 10)``.

        Args:
            inputs (torch.Tensor): Batch tensor; every dimension after the
                first is flattened, so the trailing dims must multiply to 784.
        """
        x = torch.flatten(inputs, 1)
        for fc_layer in self.fc_layers:
            x = fc_layer(x)
            if self.nonlin is not None:
                x = self.nonlin(x)
            x = self.dropout(x)
        return self.last_layer(x)

    def _get_nonlinear(self, model_params):
        """Map ``model_params["activation_fn"]`` to a module, or ``None``.

        Raises:
            ValueError: If the activation is neither ``"relu"`` nor ``None``.
        """
        activation = model_params["activation_fn"]
        if activation == "relu":
            return nn.ReLU()
        if activation is None:
            return None
        raise ValueError("supports activation_fn: 'relu' or null")
class MNISTModel(nn.Module):
    """Wrapper around ``MNIST`` that returns the negative log-likelihood loss.

    Args:
        params (dict): Run configuration. Reads ``params["runconfig"]["mode"]``
            and the ``params["model"]`` section; the architecture keys are
            consumed by ``MNIST``. Additional model keys read here:

            - ``to_float16`` (bool): cast the network to
              ``cstorch.amp.get_half_dtype()``; otherwise ``torch.float32``.
            - ``disable_softmax`` (bool, default ``False``): feed the raw
              logits to ``NLLLoss`` instead of applying ``log_softmax`` first.
            - ``compute_eval_metrics`` (bool, str, list[str] or None):
              ``True`` selects ``["accuracy"]``; ``False``/``None`` select no
              metrics. Only names containing ``"accuracy"`` are supported.

    Raises:
        ValueError: For a metric name that does not contain ``"accuracy"``.

    Note:
        Outside ``eval`` mode, a truthy ``params["model"]["compute_eval_metrics"]``
        is reset to ``[]`` in the caller's dict.
    """

    def __init__(self, params):
        super().__init__()
        # Disable eval metrics in non-eval modes (mutates the caller's params).
        if params["runconfig"]["mode"] != "eval" and params["model"].get(
            "compute_eval_metrics", []
        ):
            params["model"]["compute_eval_metrics"] = []

        # Copy so the architecture's in-place config updates stay local.
        model_params = deepcopy(params["model"])
        self.model = self.build_model(model_params)
        self.loss_fn = nn.NLLLoss()
        self.disable_softmax = params["model"].get("disable_softmax", False)

        self.accuracy_metric = None
        metric_names = self._metric_names(
            model_params.get("compute_eval_metrics", [])
        )
        for name in metric_names:
            if "accuracy" in name:
                # If several accuracy names are given, the last one wins.
                self.accuracy_metric = AccuracyMetric(name=name)
            else:
                raise ValueError(f"Unknown metric: {name}")

    @staticmethod
    def _metric_names(compute_eval_metrics):
        """Normalize a ``compute_eval_metrics`` setting to a list of names.

        ``True`` selects all supported metrics (``["accuracy"]``); ``False``
        and ``None`` select none; a bare string is a single metric name; any
        other iterable is taken as a sequence of names.
        """
        if isinstance(compute_eval_metrics, bool):
            return ["accuracy"] if compute_eval_metrics else []
        if compute_eval_metrics is None:
            return []
        if isinstance(compute_eval_metrics, str):
            return [compute_eval_metrics]
        return list(compute_eval_metrics)

    def build_model(self, model_params):
        """Construct the ``MNIST`` network and cast it to the configured dtype.

        Args:
            model_params (dict): Model section of the config; must contain
                ``to_float16`` plus the keys ``MNIST`` requires.

        Returns:
            MNIST: The network, cast to the half dtype or ``torch.float32``.
        """
        dtype = (
            cstorch.amp.get_half_dtype()
            if model_params["to_float16"]
            else torch.float32
        )
        model = MNIST(model_params)
        model.to(dtype)
        return model

    def forward(self, data):
        """Compute the classification loss for one batch.

        Args:
            data: ``(inputs, labels)`` pair. ``labels`` are passed to
                ``nn.NLLLoss`` as targets (presumably int64 class indices of
                shape ``(batch,)`` -- confirm against the dataloader).

        Returns:
            torch.Tensor: ``nn.NLLLoss`` output (default ``reduction="mean"``).
        """
        inputs, labels = data
        pred_logits = self.model(inputs)
        if self.accuracy_metric is not None:
            # NOTE(review): the reason for cloning labels (which the loss
            # then also uses) isn't evident here; kept as in the original.
            labels = labels.clone()
            predictions = pred_logits.argmax(-1).int()
            self.accuracy_metric(labels=labels, predictions=predictions)
        if not self.disable_softmax:
            pred_logits = F.log_softmax(pred_logits, dim=1)
        return self.loss_fn(pred_logits, labels)