Source code for modelzoo.common.pytorch.model_utils.checkpoint_converters.mup

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Tuple

from modelzoo.common.pytorch.model_utils.checkpoint_converters.base_converter import (
    BaseConfigConverter,
    BaseDictionaryConverter,
    ConversionRule,
    FormatVersions,
)


class ConfigConverter_sP_muP(BaseConfigConverter):
    """Transforms a CS muP config to a CS sP config."""

    def __init__(self):
        super().__init__()
        self.rules = [
            ConversionRule(["output_logits_scale"]),
            ConversionRule(["embeddings_scale"]),
            ConversionRule(["scale_qk_dot_by_d"]),
            ConversionRule(
                ["share_embedding_weights"],
                action=self.set_share_embedding_weights,
            ),
            ConversionRule(
                [r".*"], action=self.replaceKey
            ),  # Catch-all for everything else
        ]

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return ("sP", "muP")

    @staticmethod
    def file_formats() -> Tuple[str, str]:
        return ()

    @staticmethod
    def is_mup(config):
        return _is_mup(config)

    def set_share_embedding_weights(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        if from_index == 1 and (
            "output_logits_scale" in old_state_dict
            or "embeddings_scale" in old_state_dict
        ):
            new_state_dict[new_key] = False
        else:
            new_state_dict[new_key] = old_state_dict[old_key]
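
# NOTE (illustrative sketch, not part of the converter API): the dicts below
# are hypothetical flat model configs used only to show how muP detection
# behaves. ConfigConverter_sP_muP.is_mup is truthy only when all three muP
# keys are present. When converting from muP (from_index == 1), the rule above
# disables share_embedding_weights, since lm_head and the embeddings are
# rescaled by different factors and can no longer be tied.
#
#   >>> bool(ConfigConverter_sP_muP.is_mup(
#   ...     {"scale_qk_dot_by_d": True, "embeddings_scale": 10.0, "output_logits_scale": 0.25}
#   ... ))
#   True
#   >>> bool(ConfigConverter_sP_muP.is_mup({"hidden_size": 768}))
#   False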
class Converter_sP_muP(BaseDictionaryConverter):
    """Transforms a CS muP checkpoint into a CS sP checkpoint.

    muP: Maximal Update Parametrization.
    sP: Standard Parametrization.
    """

    def __init__(self):
        super().__init__()
        self.rules = [
            ConversionRule(
                [r".+\.proj_k_dense_layer.*"],
                action=self.scale_k_projection,
            ),
            ConversionRule(
                [r"(?:model\.|)lm_head\.weight"],
                action=self.scale_lm_head,
            ),
            ConversionRule(
                [r"(?:model\.|)embedding_layer\.word_embeddings\.weight"],
                action=self.scale_embeddings,
            ),
            ConversionRule(
                [
                    r"(?:model\.|)embedding_layer\.position_embeddings(?:\.embed)?\.weight"
                ],
                action=self.scale_embeddings,
            ),
            ConversionRule(
                [r"(?:model\.|)embedding_ln_f\.(?:weight|bias)"],
                action=self.scale_embedding_layernorm,
            ),
            ConversionRule(
                [r".*"], action=self.replaceKey
            ),  # Catch-all for everything else
        ]

    def scale_k_projection(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        config = action_fn_args["configs"][1]
        if config["model"].get('scale_qk_dot_by_d', False):
            d_model = config["model"]["hidden_size"]
            n_heads = config["model"]["num_heads"]
            d_sqrt = math.sqrt(d_model // n_heads)
            new_state_dict[new_key] = old_state_dict[old_key] / d_sqrt
        else:
            new_state_dict[new_key] = old_state_dict[old_key]

    def scale_lm_head(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        config = action_fn_args["configs"][1]
        if "output_logits_scale" in config["model"]:
            output_scale = config["model"]["output_logits_scale"]
            new_state_dict[new_key] = old_state_dict[old_key] * output_scale
        else:
            new_state_dict[new_key] = old_state_dict[old_key]

    def scale_embeddings(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        config = action_fn_args["configs"][1]
        # Fold embeddings_scale into word/position embeddings if embedding
        # layer norm *is not* enabled
        if "embeddings_scale" in config["model"] and not config["model"].get(
            "embedding_layer_norm", False
        ):
            emb_scale = config["model"]["embeddings_scale"]
            new_state_dict[new_key] = old_state_dict[old_key] * emb_scale
        else:
            new_state_dict[new_key] = old_state_dict[old_key]

    def scale_embedding_layernorm(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        config = action_fn_args["configs"][1]
        # Fold embeddings_scale into embedding layer norm if embedding
        # layer norm *is* enabled
        if "embeddings_scale" in config["model"] and config["model"].get(
            "embedding_layer_norm", False
        ):
            emb_scale = config["model"]["embeddings_scale"]
            new_state_dict[new_key] = old_state_dict[old_key] * emb_scale
        else:
            new_state_dict[new_key] = old_state_dict[old_key]

    @staticmethod
    def is_mup(config):
        return _is_mup(config.get('model', {}))

    @staticmethod
    def formats():
        return ("sP", "muP")
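
# Worked example (hypothetical dimensions, for illustration only): with
# hidden_size=768 and num_heads=12, the per-head dimension is 768 // 12 = 64,
# so muP -> sP conversion divides the proj_k weights by sqrt(64) = 8.0.
# Combined with sP's standard 1/sqrt(d_head) attention scaling, this matches
# the 1/d_head scaling of the q.k dot product that scale_qk_dot_by_d enables.
# Similarly, output_logits_scale is folded into lm_head, and embeddings_scale
# is folded into the embeddings (or into the embedding layer norm when
# embedding_layer_norm is enabled).
#
#   >>> import math
#   >>> math.sqrt(768 // 12)
#   8.0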
def _is_mup(model_config):
    scale_qk_dot_by_d = model_config.get('scale_qk_dot_by_d', False)
    embeddings_scale = model_config.get('embeddings_scale', None)
    output_logits_scale = model_config.get('output_logits_scale', None)

    all_set = scale_qk_dot_by_d and embeddings_scale and output_logits_scale
    any_set = scale_qk_dot_by_d or embeddings_scale or output_logits_scale

    if any_set and not all_set:
        raise ValueError(
            "This looks like an incomplete muP config. Either all or none of "
            "\"scale_qk_dot_by_d\", \"embeddings_scale\", and "
            "\"output_logits_scale\" may be specified, but this config "
            "specifies only some of them."
        )
    return all_set
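
# A minimal, runnable usage sketch (not part of the original module). The
# configs below are hypothetical and only illustrate the detection logic:
# _is_mup is truthy only when all three muP keys are set, and it raises when
# the config specifies some of them but not all.
if __name__ == "__main__":
    mup_model_config = {
        "scale_qk_dot_by_d": True,
        "embeddings_scale": 10.0,
        "output_logits_scale": 0.25,
    }
    sp_model_config = {"hidden_size": 768}

    assert _is_mup(mup_model_config)
    assert not _is_mup(sp_model_config)
    # Converter_sP_muP.is_mup expects the full config with a "model" section.
    assert Converter_sP_muP.is_mup({"model": mup_model_config})

    try:
        # Incomplete muP config: only one of the three keys is present.
        _is_mup({"embeddings_scale": 10.0})
    except ValueError as e:
        print(f"Rejected incomplete muP config: {e}")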