# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Tuple
from modelzoo.common.pytorch.model_utils.checkpoint_converters.base_converter import (
BaseConfigConverter,
BaseDictionaryConverter,
ConversionRule,
FormatVersions,
)
class ConfigConverter_sP_muP(BaseConfigConverter):
    """Transforms a CS muP config to a CS sP config.

    The muP-only settings (``output_logits_scale``, ``embeddings_scale`` and
    ``scale_qk_dot_by_d``) are matched by rules that have no action, so they
    are not copied into the converted config. NOTE(review): this relies on
    BaseConfigConverter skipping rules whose action is None — confirm.
    ``share_embedding_weights`` may be forced off (see
    ``set_share_embedding_weights``). Every other key is copied unchanged by
    the catch-all rule.
    """

    def __init__(self):
        super().__init__()
        self.rules = [
            # muP-only settings: no action, so they do not reach the sP config.
            ConversionRule(["output_logits_scale"]),
            ConversionRule(["embeddings_scale"]),
            ConversionRule(["scale_qk_dot_by_d"]),
            ConversionRule(
                ["share_embedding_weights"],
                action=self.set_share_embedding_weights,
            ),
            ConversionRule(
                [r".*"], action=self.replaceKey
            ),  # Catch-all for everything else
        ]

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        """Return the (index 0, index 1) format names: sP and muP."""
        return ("sP", "muP")

    @staticmethod
    def file_formats() -> Tuple[str, ...]:
        """Return the file formats handled by this converter.

        This converter declares none, so an empty tuple is returned.
        """
        return ()

    @staticmethod
    def is_mup(config):
        """Return whether ``config`` (a model-level config dict) is muP.

        Raises:
            ValueError: If only some of the muP settings are present.
        """
        return _is_mup(config)

    def set_share_embedding_weights(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        """Convert ``share_embedding_weights``, untying embeddings if needed.

        When converting from muP (``from_index == 1``) and the config carries
        ``output_logits_scale`` or ``embeddings_scale``, the checkpoint
        converter folds those scales into the lm_head and embedding weights
        separately. The two tensors then differ, so weight sharing is turned
        off. Otherwise the value is copied unchanged.

        NOTE(review): this rule only runs when ``share_embedding_weights`` is
        present in the source config. If a muP config relies on the model's
        default instead, the sP config is not updated here — confirm the
        default is handled elsewhere.
        """
        if from_index == 1 and (
            "output_logits_scale" in old_state_dict
            or "embeddings_scale" in old_state_dict
        ):
            new_state_dict[new_key] = False
        else:
            new_state_dict[new_key] = old_state_dict[old_key]
class Converter_sP_muP(BaseDictionaryConverter):
    """Transforms a CS muP checkpoint into a CS sP checkpoint.

    muP: Maximal Update Parametrization.
    sP: Standard Parametrization.

    The muP multipliers found in the model section of the muP config
    (``action_fn_args["configs"][1]["model"]``) are folded into the weights:

    * ``scale_qk_dot_by_d``: K projection parameters are divided by
      ``sqrt(hidden_size // num_heads)``.
    * ``output_logits_scale``: lm_head weight and bias are multiplied by it.
    * ``embeddings_scale``: word/position embeddings are multiplied by it.
      If ``embedding_layer_norm`` is enabled, the embedding layer norm
      weight/bias are multiplied instead.

    Any multiplier that is absent or ``None`` leaves its tensors unchanged.
    All other tensors are copied by the catch-all rule.
    """

    def __init__(self):
        super().__init__()
        self.rules = [
            ConversionRule(
                [r".+\.proj_k_dense_layer.*"], action=self.scale_k_projection,
            ),
            # Match the bias as well as the weight. Scaling the logits by s
            # equals scaling both W and b by s. NOTE(review): assumes the
            # model applies output_logits_scale after the lm_head bias is
            # added — confirm.
            ConversionRule(
                [r"(?:model\.|)lm_head\.(?:weight|bias)"],
                action=self.scale_lm_head,
            ),
            ConversionRule(
                [r"(?:model\.|)embedding_layer\.word_embeddings\.weight"],
                action=self.scale_embeddings,
            ),
            ConversionRule(
                [
                    r"(?:model\.|)embedding_layer\.position_embeddings(?:\.embed)?\.weight"
                ],
                action=self.scale_embeddings,
            ),
            ConversionRule(
                [r"(?:model\.|)embedding_ln_f\.(?:weight|bias)"],
                action=self.scale_embedding_layernorm,
            ),
            ConversionRule(
                [r".*"], action=self.replaceKey
            ),  # Catch-all for everything else
        ]

    @staticmethod
    def _model_config(action_fn_args):
        """Return the ``model`` section of the muP config (``configs[1]``)."""
        return action_fn_args["configs"][1]["model"]

    @staticmethod
    def _store_scaled(old_key, new_key, old_state_dict, new_state_dict, scale):
        """Copy a tensor to ``new_key``, multiplied by ``scale`` unless it is None."""
        value = old_state_dict[old_key]
        new_state_dict[new_key] = value if scale is None else value * scale

    def scale_k_projection(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        """Fold ``scale_qk_dot_by_d`` into the K projection.

        muP divides attention scores by ``d``, while sP divides by
        ``sqrt(d)``. Dividing K (weight and bias) by an extra ``sqrt(d)``
        makes the sP scores match.
        """
        model_config = self._model_config(action_fn_args)
        value = old_state_dict[old_key]
        if model_config.get('scale_qk_dot_by_d', False):
            # NOTE(review): assumes the head dim is hidden_size // num_heads
            # (i.e. no separate attention inner dim) — confirm for the model.
            head_dim = model_config["hidden_size"] // model_config["num_heads"]
            value = value / math.sqrt(head_dim)
        new_state_dict[new_key] = value

    def scale_lm_head(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        """Multiply lm_head parameters by ``output_logits_scale`` (if set)."""
        output_scale = self._model_config(action_fn_args).get(
            "output_logits_scale"
        )
        self._store_scaled(
            old_key, new_key, old_state_dict, new_state_dict, output_scale
        )

    def scale_embeddings(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        """Fold ``embeddings_scale`` into word/position embeddings.

        This applies only when embedding layer norm is *not* enabled. When it
        is, the scale goes into the layer norm instead (see
        ``scale_embedding_layernorm``) and these tensors are copied unchanged.
        """
        model_config = self._model_config(action_fn_args)
        emb_scale = model_config.get("embeddings_scale")
        if model_config.get("embedding_layer_norm", False):
            emb_scale = None
        self._store_scaled(
            old_key, new_key, old_state_dict, new_state_dict, emb_scale
        )

    def scale_embedding_layernorm(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        """Fold ``embeddings_scale`` into the embedding layer norm.

        This applies only when embedding layer norm *is* enabled. Both the
        weight and the bias are multiplied. NOTE(review): this is correct if
        the model applies embeddings_scale after the embedding layer norm —
        confirm.
        """
        model_config = self._model_config(action_fn_args)
        emb_scale = None
        if model_config.get("embedding_layer_norm", False):
            emb_scale = model_config.get("embeddings_scale")
        self._store_scaled(
            old_key, new_key, old_state_dict, new_state_dict, emb_scale
        )

    @staticmethod
    def is_mup(config):
        """Return whether a full CS config (with a ``model`` section) is muP.

        Raises:
            ValueError: If only some of the muP settings are present.
        """
        return _is_mup(config.get('model', {}))

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        """Return the (index 0, index 1) format names: sP and muP."""
        return ("sP", "muP")
def _is_mup(model_config):
scale_qk_dot_by_d = model_config.get('scale_qk_dot_by_d', False)
embeddings_scale = model_config.get('embeddings_scale', None)
output_logits_scale = model_config.get('output_logits_scale', None)
all_set = scale_qk_dot_by_d and embeddings_scale and output_logits_scale
any_set = scale_qk_dot_by_d or embeddings_scale or output_logits_scale
if any_set and not all_set:
raise ValueError(
"This looks like an incomplete muP config. Either all of or none of "
"\"scale_qk_dot_by_d\", \"embeddings_scale\", \"output_logits_scale\" can be "
"specified, but this config only has some that are specified."
)
return all_set