# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# isort: off
import os
import sys
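# Make the repository root importable when this script is run directly.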
sys.path.append(os.path.join(os.path.dirname(__file__), "../../../../../"))
# isort: on
"""
command:
python modelzoo/vision/pytorch/dit/checkpoint_converter/vae_hf_cs.py --dest_ckpt_path=<path to converted checkpoint>
"""
import argparse
import logging
from typing import Optional, Tuple

import cerebras_pytorch as cstorch

# Configure logging before importing modelzoo modules.
LOGFORMAT = '%(asctime)s %(levelname)-4s[%(filename)s:%(lineno)d] %(message)s'
logging.basicConfig(level=logging.INFO, format=LOGFORMAT)

from modelzoo.common.pytorch.model_utils.checkpoint_converters.base_converter import (
    BaseCheckpointConverter,
    BaseCheckpointConverter_HF_CS,
    BaseConfigConverter,
    ConversionRule,
    EquivalentSubkey,
)
def get_parser_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--src_ckpt_path",
type=str,
required=False,
default=None,
help=f"Path to HF Pretrained VAE checkpoint .bin file. "
f"If not provided, file is automatically downloaded from "
f"https://huggingface.co/stabilityai/sd-vae-ft-mse/resolve/main/diffusion_pytorch_model.bin",
)
parser.add_argument(
"--dest_ckpt_path",
type=str,
required=False,
default=os.path.join(
os.path.dirname(__file__), "mz_stabilityai-sd-vae-ft-mse_ckpt.bin"
),
help="Path to converted modelzoo compatible checkpoint",
)
parser.add_argument(
"--params_path",
type=str,
required=False,
default=os.path.abspath(
os.path.join(
os.path.dirname(__file__),
"../configs/params_dit_small_patchsize_2x2.yaml",
)
),
help="Path to VAE model params yaml",
)
args = parser.parse_args()
return args
class Converter_VAEModel_HF_CS19(BaseCheckpointConverter_HF_CS):
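    """Maps HF diffusers `AutoencoderKL` parameter names to the CS VAE names.

    Most keys are identical in both models; the mid-block attention parameters
    are renamed, e.g.:
        encoder.mid_block.attentions.0.group_norm.weight
            -> encoder.mid_block.norms.0.weight
        encoder.mid_block.attentions.0.query.weight
            -> encoder.mid_block.attentions.0.proj_q_dense_layer.weight
    """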
    def __init__(self):
super().__init__()
self.rules = [
ConversionRule(
[ # same keys
EquivalentSubkey("encoder.conv", "encoder.conv"),
".*\.(?:weight|bias)",
],
action=BaseCheckpointConverter.replaceKey,
),
ConversionRule(
[ # same keys
EquivalentSubkey(
"encoder.down_blocks", "encoder.down_blocks"
),
".*",
],
action=BaseCheckpointConverter.replaceKey,
),
ConversionRule(
[ # same keys
EquivalentSubkey(
"encoder.mid_block.resnets", "encoder.mid_block.resnets"
),
".*(?:weight|bias)",
],
action=BaseCheckpointConverter.replaceKey,
),
ConversionRule(
[ # encoder.mid_block.attentions.0.group_norm.weight -> encoder.mid_block.norms.0.weight
EquivalentSubkey(
"encoder.mid_block.attentions",
"encoder.mid_block.norms",
),
"\.\d+\.",
EquivalentSubkey("group_norm.", ""),
".*(?:weight|bias)",
],
action=BaseCheckpointConverter.replaceKey,
),
ConversionRule(
[ # encoder.mid_block.attentions.0.query.weight -> encoder.mid_block.attentions.0.proj_q_dense_layer.weight
EquivalentSubkey(
"encoder.mid_block.attentions",
"encoder.mid_block.attentions",
),
"\.\d+\.",
EquivalentSubkey("query", "proj_q_dense_layer"),
"\.(?:weight|bias)",
],
action=BaseCheckpointConverter.replaceKey,
),
ConversionRule(
[ # encoder.mid_block.attentions.0.key.weight -> encoder.mid_block.attentions.0.proj_k_dense_layer.weight
EquivalentSubkey(
"encoder.mid_block.attentions",
"encoder.mid_block.attentions",
),
"\.\d+\.",
EquivalentSubkey("key", "proj_k_dense_layer"),
"\.(?:weight|bias)",
],
action=BaseCheckpointConverter.replaceKey,
),
ConversionRule(
[ # encoder.mid_block.attentions.0.value.weight -> encoder.mid_block.attentions.0.proj_v_dense_layer.weight
EquivalentSubkey(
"encoder.mid_block.attentions",
"encoder.mid_block.attentions",
),
"\.\d+\.",
EquivalentSubkey("value", "proj_v_dense_layer"),
"\.(?:weight|bias)",
],
action=BaseCheckpointConverter.replaceKey,
),
ConversionRule(
[ # encoder.mid_block.attentions.0.proj_attn.weight -> encoder.mid_block.attentions.0.proj_output_dense_layer.weight
EquivalentSubkey(
"encoder.mid_block.attentions",
"encoder.mid_block.attentions",
),
"\.\d+\.",
EquivalentSubkey("proj_attn", "proj_output_dense_layer"),
"\.(?:weight|bias)",
],
action=BaseCheckpointConverter.replaceKey,
),
ConversionRule(
[ # same keys
EquivalentSubkey("decoder.conv", "decoder.conv"),
".*\.(?:weight|bias)",
],
action=BaseCheckpointConverter.replaceKey,
),
ConversionRule(
[ # same keys
EquivalentSubkey("decoder.up_blocks", "decoder.up_blocks"),
".*(?:weight|bias)",
],
action=BaseCheckpointConverter.replaceKey,
),
ConversionRule(
[ # same keys
EquivalentSubkey(
"decoder.mid_block.resnets", "decoder.mid_block.resnets"
),
".*(?:weight|bias)",
],
action=BaseCheckpointConverter.replaceKey,
),
ConversionRule(
[ # decoder.mid_block.attentions.0.group_norm.weight -> decoder.mid_block.norms.0.weight
EquivalentSubkey(
"decoder.mid_block.attentions",
"decoder.mid_block.norms",
),
"\.\d+\.",
EquivalentSubkey("group_norm.", ""),
".*(?:weight|bias)",
],
action=BaseCheckpointConverter.replaceKey,
),
ConversionRule(
[ # decoder.mid_block.attentions.0.query.weight -> decoder.mid_block.attentions.0.proj_q_dense_layer.weight
EquivalentSubkey(
"decoder.mid_block.attentions",
"decoder.mid_block.attentions",
),
"\.\d+\.",
EquivalentSubkey("query", "proj_q_dense_layer"),
"\.(?:weight|bias)",
],
action=BaseCheckpointConverter.replaceKey,
),
ConversionRule(
[ # decoder.mid_block.attentions.0.key.weight -> decoder.mid_block.attentions.0.proj_k_dense_layer.weight
EquivalentSubkey(
"decoder.mid_block.attentions",
"decoder.mid_block.attentions",
),
"\.\d+\.",
EquivalentSubkey("key", "proj_k_dense_layer"),
"\.(?:weight|bias)",
],
action=BaseCheckpointConverter.replaceKey,
),
ConversionRule(
[ # decoder.mid_block.attentions.0.value.weight -> decoder.mid_block.attentions.0.proj_v_dense_layer.weight
EquivalentSubkey(
"decoder.mid_block.attentions",
"decoder.mid_block.attentions",
),
"\.\d+\.",
EquivalentSubkey("value", "proj_v_dense_layer"),
"\.(?:weight|bias)",
],
action=BaseCheckpointConverter.replaceKey,
),
ConversionRule(
[ # decoder.mid_block.attentions.0.proj_attn.weight -> decoder.mid_block.attentions.0.proj_output_dense_layer.weight
EquivalentSubkey(
"decoder.mid_block.attentions",
"decoder.mid_block.attentions",
),
"\.\d+\.",
EquivalentSubkey("proj_attn", "proj_output_dense_layer"),
"\.(?:weight|bias)",
],
action=BaseCheckpointConverter.replaceKey,
),
ConversionRule(
[ # same keys
EquivalentSubkey("quant_conv", "quant_conv"),
".*(?:weight|bias)",
],
action=BaseCheckpointConverter.replaceKey,
),
ConversionRule(
[ # same keys
EquivalentSubkey("post_quant_conv", "post_quant_conv"),
".*(?:weight|bias)",
],
action=BaseCheckpointConverter.replaceKey,
),
]
@staticmethod
def formats() -> Tuple[str, str]:
return ("vae_HF", "cs-1.9")
    @staticmethod
    def get_config_converter_class() -> Optional[BaseConfigConverter]:
        # This converter only maps state dict keys; there is no model config
        # to convert.
        return None
if __name__ == "__main__":
import yaml
from modelzoo.vision.pytorch.dit.layers.vae.VAEModel import (
AutoencoderKL as CSAutoencoderKL,
)
args = get_parser_args()
if args.src_ckpt_path is None:
import requests
url = "https://huggingface.co/stabilityai/sd-vae-ft-mse/resolve/main/diffusion_pytorch_model.bin"
logging.info(
f"No `src_ckpt_path` provided, downloading the model from {url}"
)
response = requests.get(url)
response.raise_for_status()
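        # Note: this buffers the entire checkpoint in memory before writing it
        # out. For very large files, requests.get(url, stream=True) together
        # with response.iter_content() would keep memory usage low; the simple
        # eager download is kept here for brevity.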
args.src_ckpt_path = os.path.join(
os.path.dirname(__file__), "hf_stabilityai-sd-vae-ft-mse_ckpt.bin"
)
with open(args.src_ckpt_path, "wb") as fh:
fh.write(response.content)
        logging.info(
            f"Downloaded pretrained source checkpoint to {args.src_ckpt_path}"
        )
old_state_dict = cstorch.load(args.src_ckpt_path)
# VAE Params for CS modelzoo
with open(args.params_path, "r") as fh:
vae_params = yaml.safe_load(fh)["model"]["vae"]
# Initialize CS VAE model
cs_vae = CSAutoencoderKL(**vae_params)
new_state_dict = cs_vae.state_dict()
logging.info(f"Converting checkpoint...")
# Convert
converter = Converter_VAEModel_HF_CS19()
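    # from_index=0 converts from the first entry of `formats()` ("vae_HF") to
    # the second ("cs-1.9"), i.e. HF -> CS.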
matched_all_keys = converter.convert_all_keys(
old_state_dict=old_state_dict,
new_state_dict=new_state_dict,
from_index=0,
)
logging.info(f"matched_all_keys:{matched_all_keys}")
cstorch.save(
new_state_dict, args.dest_ckpt_path,
)
logging.info(f"DONE: Converting checkpoint, saved at {args.dest_ckpt_path}")