Source code for modelzoo.vision.pytorch.dit.checkpoint_converter.vae_hf_cs

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# isort: off
import os
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), "../../../../../"))
# isort: on

"""
command:
python modelzoo/vision/pytorch/dit/checkpoint_converter/vae_hf_cs.py --dest_ckpt_path=<path to converted checkpoint>
"""
import argparse
import logging
from typing import Tuple

import cerebras_pytorch as cstorch

LOGFORMAT = '%(asctime)s %(levelname)-4s[%(filename)s:%(lineno)d] %(message)s'
logging.basicConfig(level=logging.INFO, format=LOGFORMAT)

from modelzoo.common.pytorch.model_utils.checkpoint_converters.base_converter import (
    BaseCheckpointConverter,
    BaseCheckpointConverter_HF_CS,
    BaseConfigConverter,
    ConversionRule,
    EquivalentSubkey,
)


def get_parser_args():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--src_ckpt_path",
        type=str,
        required=False,
        default=None,
        help=(
            "Path to HF pretrained VAE checkpoint .bin file. "
            "If not provided, the file is automatically downloaded from "
            "https://huggingface.co/stabilityai/sd-vae-ft-mse/resolve/main/diffusion_pytorch_model.bin"
        ),
    )
    parser.add_argument(
        "--dest_ckpt_path",
        type=str,
        required=False,
        default=os.path.join(
            os.path.dirname(__file__), "mz_stabilityai-sd-vae-ft-mse_ckpt.bin"
        ),
        help="Path to converted modelzoo-compatible checkpoint",
    )
    parser.add_argument(
        "--params_path",
        type=str,
        required=False,
        default=os.path.abspath(
            os.path.join(
                os.path.dirname(__file__),
                "../configs/params_dit_small_patchsize_2x2.yaml",
            )
        ),
        help="Path to VAE model params yaml",
    )
    args = parser.parse_args()
    return args


class Converter_VAEModel_HF_CS19(BaseCheckpointConverter_HF_CS):
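    """Converts a HuggingFace `stabilityai/sd-vae-ft-mse` VAE checkpoint
    into the Cerebras modelzoo VAE format (CS 1.9).

    Most keys are copied unchanged. The attention blocks are remapped:
    `group_norm` moves out of the attention module into a separate `norms`
    list, and `query`/`key`/`value`/`proj_attn` become
    `proj_q_dense_layer`/`proj_k_dense_layer`/`proj_v_dense_layer`/
    `proj_output_dense_layer`.
    """
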
    def __init__(self):
        super().__init__()
        self.rules = [
            ConversionRule(
                [
                    # same keys
                    EquivalentSubkey("encoder.conv", "encoder.conv"),
                    r".*\.(?:weight|bias)",
                ],
                action=BaseCheckpointConverter.replaceKey,
            ),
            ConversionRule(
                [
                    # same keys
                    EquivalentSubkey(
                        "encoder.down_blocks", "encoder.down_blocks"
                    ),
                    ".*",
                ],
                action=BaseCheckpointConverter.replaceKey,
            ),
            ConversionRule(
                [
                    # same keys
                    EquivalentSubkey(
                        "encoder.mid_block.resnets",
                        "encoder.mid_block.resnets",
                    ),
                    r".*(?:weight|bias)",
                ],
                action=BaseCheckpointConverter.replaceKey,
            ),
            ConversionRule(
                [
                    # encoder.mid_block.attentions.0.group_norm.weight
                    #     -> encoder.mid_block.norms.0.weight
                    EquivalentSubkey(
                        "encoder.mid_block.attentions",
                        "encoder.mid_block.norms",
                    ),
                    r"\.\d+\.",
                    EquivalentSubkey("group_norm.", ""),
                    r".*(?:weight|bias)",
                ],
                action=BaseCheckpointConverter.replaceKey,
            ),
            ConversionRule(
                [
                    # encoder.mid_block.attentions.0.query.weight
                    #     -> encoder.mid_block.attentions.0.proj_q_dense_layer.weight
                    EquivalentSubkey(
                        "encoder.mid_block.attentions",
                        "encoder.mid_block.attentions",
                    ),
                    r"\.\d+\.",
                    EquivalentSubkey("query", "proj_q_dense_layer"),
                    r"\.(?:weight|bias)",
                ],
                action=BaseCheckpointConverter.replaceKey,
            ),
            ConversionRule(
                [
                    # encoder.mid_block.attentions.0.key.weight
                    #     -> encoder.mid_block.attentions.0.proj_k_dense_layer.weight
                    EquivalentSubkey(
                        "encoder.mid_block.attentions",
                        "encoder.mid_block.attentions",
                    ),
                    r"\.\d+\.",
                    EquivalentSubkey("key", "proj_k_dense_layer"),
                    r"\.(?:weight|bias)",
                ],
                action=BaseCheckpointConverter.replaceKey,
            ),
            ConversionRule(
                [
                    # encoder.mid_block.attentions.0.value.weight
                    #     -> encoder.mid_block.attentions.0.proj_v_dense_layer.weight
                    EquivalentSubkey(
                        "encoder.mid_block.attentions",
                        "encoder.mid_block.attentions",
                    ),
                    r"\.\d+\.",
                    EquivalentSubkey("value", "proj_v_dense_layer"),
                    r"\.(?:weight|bias)",
                ],
                action=BaseCheckpointConverter.replaceKey,
            ),
            ConversionRule(
                [
                    # encoder.mid_block.attentions.0.proj_attn.weight
                    #     -> encoder.mid_block.attentions.0.proj_output_dense_layer.weight
                    EquivalentSubkey(
                        "encoder.mid_block.attentions",
                        "encoder.mid_block.attentions",
                    ),
                    r"\.\d+\.",
                    EquivalentSubkey("proj_attn", "proj_output_dense_layer"),
                    r"\.(?:weight|bias)",
                ],
                action=BaseCheckpointConverter.replaceKey,
            ),
            ConversionRule(
                [
                    # same keys
                    EquivalentSubkey("decoder.conv", "decoder.conv"),
                    r".*\.(?:weight|bias)",
                ],
                action=BaseCheckpointConverter.replaceKey,
            ),
            ConversionRule(
                [
                    # same keys
                    EquivalentSubkey("decoder.up_blocks", "decoder.up_blocks"),
                    r".*(?:weight|bias)",
                ],
                action=BaseCheckpointConverter.replaceKey,
            ),
            ConversionRule(
                [
                    # same keys
                    EquivalentSubkey(
                        "decoder.mid_block.resnets",
                        "decoder.mid_block.resnets",
                    ),
                    r".*(?:weight|bias)",
                ],
                action=BaseCheckpointConverter.replaceKey,
            ),
            ConversionRule(
                [
                    # decoder.mid_block.attentions.0.group_norm.weight
                    #     -> decoder.mid_block.norms.0.weight
                    EquivalentSubkey(
                        "decoder.mid_block.attentions",
                        "decoder.mid_block.norms",
                    ),
                    r"\.\d+\.",
                    EquivalentSubkey("group_norm.", ""),
                    r".*(?:weight|bias)",
                ],
                action=BaseCheckpointConverter.replaceKey,
            ),
            ConversionRule(
                [
                    # decoder.mid_block.attentions.0.query.weight
                    #     -> decoder.mid_block.attentions.0.proj_q_dense_layer.weight
                    EquivalentSubkey(
                        "decoder.mid_block.attentions",
                        "decoder.mid_block.attentions",
                    ),
                    r"\.\d+\.",
                    EquivalentSubkey("query", "proj_q_dense_layer"),
                    r"\.(?:weight|bias)",
                ],
                action=BaseCheckpointConverter.replaceKey,
            ),
            ConversionRule(
                [
                    # decoder.mid_block.attentions.0.key.weight
                    #     -> decoder.mid_block.attentions.0.proj_k_dense_layer.weight
                    EquivalentSubkey(
                        "decoder.mid_block.attentions",
                        "decoder.mid_block.attentions",
                    ),
                    r"\.\d+\.",
                    EquivalentSubkey("key", "proj_k_dense_layer"),
                    r"\.(?:weight|bias)",
                ],
                action=BaseCheckpointConverter.replaceKey,
            ),
            ConversionRule(
                [
                    # decoder.mid_block.attentions.0.value.weight
                    #     -> decoder.mid_block.attentions.0.proj_v_dense_layer.weight
                    EquivalentSubkey(
                        "decoder.mid_block.attentions",
                        "decoder.mid_block.attentions",
                    ),
                    r"\.\d+\.",
                    EquivalentSubkey("value", "proj_v_dense_layer"),
                    r"\.(?:weight|bias)",
                ],
                action=BaseCheckpointConverter.replaceKey,
            ),
            ConversionRule(
                [
                    # decoder.mid_block.attentions.0.proj_attn.weight
                    #     -> decoder.mid_block.attentions.0.proj_output_dense_layer.weight
                    EquivalentSubkey(
                        "decoder.mid_block.attentions",
                        "decoder.mid_block.attentions",
                    ),
                    r"\.\d+\.",
                    EquivalentSubkey("proj_attn", "proj_output_dense_layer"),
                    r"\.(?:weight|bias)",
                ],
                action=BaseCheckpointConverter.replaceKey,
            ),
            ConversionRule(
                [
                    # same keys
                    EquivalentSubkey("quant_conv", "quant_conv"),
                    r".*(?:weight|bias)",
                ],
                action=BaseCheckpointConverter.replaceKey,
            ),
            ConversionRule(
                [
                    # same keys
                    EquivalentSubkey("post_quant_conv", "post_quant_conv"),
                    r".*(?:weight|bias)",
                ],
                action=BaseCheckpointConverter.replaceKey,
            ),
        ]
    @staticmethod
    def formats() -> Tuple[str, str]:
        return ("vae_HF", "cs-1.9")

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return None


if __name__ == "__main__": import yaml from modelzoo.vision.pytorch.dit.layers.vae.VAEModel import ( AutoencoderKL as CSAutoencoderKL, ) args = get_parser_args() if args.src_ckpt_path is None: import requests url = "https://huggingface.co/stabilityai/sd-vae-ft-mse/resolve/main/diffusion_pytorch_model.bin" logging.info( f"No `src_ckpt_path` provided, downloading the model from {url}" ) response = requests.get(url) response.raise_for_status() args.src_ckpt_path = os.path.join( os.path.dirname(__file__), "hf_stabilityai-sd-vae-ft-mse_ckpt.bin" ) with open(args.src_ckpt_path, "wb") as fh: fh.write(response.content) logging.info( f"Downloaded source pretrained ckpt at {args.src_ckpt_path}" ) old_state_dict = cstorch.load(args.src_ckpt_path) # VAE Params for CS modelzoo with open(args.params_path, "r") as fh: vae_params = yaml.safe_load(fh)["model"]["vae"] # Initialize CS VAE model cs_vae = CSAutoencoderKL(**vae_params) new_state_dict = cs_vae.state_dict() logging.info(f"Converting checkpoint...") # Convert converter = Converter_VAEModel_HF_CS19() matched_all_keys = converter.convert_all_keys( old_state_dict=old_state_dict, new_state_dict=new_state_dict, from_index=0, ) logging.info(f"matched_all_keys:{matched_all_keys}") cstorch.save( new_state_dict, args.dest_ckpt_path, ) logging.info(f"DONE: Converting checkpoint, saved at {args.dest_ckpt_path}")