# Source code for modelzoo.fc_mnist.pytorch.model

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F

import cerebras_pytorch as cstorch
from cerebras_pytorch.metrics import AccuracyMetric


class MNIST(nn.Module):
    """Stack of fully connected layers mapping 784 input features to 10 logits.

    Each hidden ``nn.Linear`` layer is followed by the optional non-linearity
    and by dropout; a final ``nn.Linear`` layer produces the 10 output logits.

    Keys read from ``model_params``:
        use_bias: Whether the linear layers carry a bias term (default True).
        hidden_sizes: Width of each hidden layer. When present it takes
            precedence over ``hidden_size`` / ``depth``.
        hidden_size, depth: Used only when ``hidden_sizes`` is absent; builds
            ``depth`` hidden layers that are all ``hidden_size`` wide.
        activation_fn: ``"relu"`` or ``None`` (required key).
        dropout: Probability handed to ``nn.Dropout`` (required key).
    """

    def __init__(self, model_params):
        """Build the layers described by ``model_params``.

        The dictionary is only read, never modified.

        Raises:
            KeyError: If a required key is missing.
            ValueError: If ``activation_fn`` is neither ``"relu"`` nor ``None``.
        """
        super().__init__()
        use_bias = model_params.get("use_bias", True)

        # Resolve the hidden layer widths into a local list rather than
        # writing "depth" / "hidden_sizes" back into the caller's dict.
        if "hidden_sizes" in model_params:
            hidden_sizes = list(model_params["hidden_sizes"])
        else:
            # Same hidden size across all dense layers.
            hidden_sizes = [model_params["hidden_size"]] * model_params["depth"]

        # Hidden layers are created first and the output layer last, so the
        # order of parameter initialisation is the same for a fixed seed.
        layers = []
        input_size = 784
        for hidden_size in hidden_sizes:
            layers.append(nn.Linear(input_size, hidden_size, bias=use_bias))
            input_size = hidden_size
        self.fc_layers = nn.ModuleList(layers)

        self.last_layer = nn.Linear(input_size, 10, bias=use_bias)
        self.nonlin = self._get_nonlinear(model_params)
        self.dropout = nn.Dropout(model_params["dropout"])

    def forward(self, inputs):
        """Return the logits for ``inputs``.

        Everything after the first dimension of ``inputs`` is flattened before
        the first linear layer is applied.
        """
        x = torch.flatten(inputs, 1)
        for fc_layer in self.fc_layers:
            x = fc_layer(x)
            # Explicit None check: do not rely on the truthiness of a module.
            if self.nonlin is not None:
                x = self.nonlin(x)
            x = self.dropout(x)

        return self.last_layer(x)

    def _get_nonlinear(self, model_params):
        """Return the activation module selected by ``activation_fn``.

        Returns:
            ``nn.ReLU()`` for ``"relu"``, ``None`` when the setting is ``None``.

        Raises:
            ValueError: For any other value.
        """
        activation_fn = model_params["activation_fn"]
        if activation_fn == "relu":
            return nn.ReLU()
        if activation_fn is None:
            return None
        raise ValueError("supports activation_fn: 'relu' or null")
class MNISTModel(nn.Module):
    """Wraps the ``MNIST`` network together with its loss and optional accuracy metric.

    Keys read from ``params``:
        params["runconfig"]["mode"]: Eval metrics are kept only when this is
            ``"eval"``.
        params["model"]: Model settings. This class reads
            ``compute_eval_metrics``, ``disable_softmax`` and ``to_float16``;
            a deep copy of the whole dict is passed on to ``MNIST``.
    """

    def __init__(self, params):
        """Build the network, the loss function and any requested metrics.

        Note:
            Outside eval mode a non-empty ``compute_eval_metrics`` entry in
            ``params["model"]`` is overwritten in place with ``[]``.

        Raises:
            ValueError: If a requested metric name does not contain
                ``"accuracy"``.
        """
        super().__init__()

        model_cfg = params["model"]
        # Disable eval metrics in non-eval modes
        if params["runconfig"]["mode"] != "eval" and model_cfg.get(
            "compute_eval_metrics", []
        ):
            model_cfg["compute_eval_metrics"] = []

        # The network gets its own copy of the settings.
        model_params = deepcopy(model_cfg)
        self.model = self.build_model(model_params)
        self.loss_fn = nn.NLLLoss()
        self.disable_softmax = model_cfg.get("disable_softmax", False)

        self.accuracy_metric = None
        metric_names = self._metric_names(
            model_params.get("compute_eval_metrics", [])
        )
        for name in metric_names:
            if "accuracy" in name:
                self.accuracy_metric = AccuracyMetric(name=name)
            else:
                raise ValueError(f"Unknown metric: {name}")

    @staticmethod
    def _metric_names(setting):
        """Normalise the ``compute_eval_metrics`` setting to a list of names.

        ``True`` selects every supported metric. ``False`` and ``None`` select
        none; previously these fell through to the ``for`` loop and raised
        ``TypeError``. A bare string is treated as a single metric name
        instead of being iterated character by character.
        """
        if isinstance(setting, bool):
            return ["accuracy"] if setting else []  # True means all metrics
        if setting is None:
            return []
        if isinstance(setting, str):
            return [setting]
        return list(setting)

    def build_model(self, model_params):
        """Create the ``MNIST`` network and cast it to the configured dtype.

        The dtype is ``cstorch.amp.get_half_dtype()`` when
        ``model_params["to_float16"]`` is truthy, otherwise ``torch.float32``.
        """
        dtype = (
            cstorch.amp.get_half_dtype()
            if model_params["to_float16"]
            else torch.float32
        )
        model = MNIST(model_params)
        model.to(dtype)
        return model

    def forward(self, data):
        """Compute the loss for one ``(inputs, labels)`` batch.

        When an accuracy metric is configured it is updated with the argmax
        over the last dimension of the logits.

        Returns:
            The ``nn.NLLLoss`` value for the batch.
        """
        inputs, labels = data
        pred_logits = self.model(inputs)

        if self.accuracy_metric is not None:
            labels = labels.clone()
            predictions = pred_logits.argmax(-1).int()
            self.accuracy_metric(labels=labels, predictions=predictions)

        # NOTE(review): with disable_softmax set, the raw logits are passed
        # straight to NLLLoss. Presumably log-softmax is applied elsewhere in
        # that configuration; confirm before relying on the loss value.
        if not self.disable_softmax:
            pred_logits = F.log_softmax(pred_logits, dim=1)

        return self.loss_fn(pred_logits, labels)