ZeroCool22

convert_to_ckpt.py (2023)

Apr 19th, 2023
193
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 8.91 KB | None | 0 0
  1. # Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
  2. # *Only* converts the UNet, VAE, and Text Encoder.
  3. # Does not convert optimizer state or any other thing.
  4. # Written by jachiam
  5.  
  6. import argparse
  7. import os.path as osp
  8.  
  9. import torch
  10.  
  11.  
  12. # =================#
  13. # UNet Conversion #
  14. # =================#
  15.  
  16. unet_conversion_map = [
  17.     # (stable-diffusion, HF Diffusers)
  18.     ("time_embed.0.weight", "time_embedding.linear_1.weight"),
  19.     ("time_embed.0.bias", "time_embedding.linear_1.bias"),
  20.     ("time_embed.2.weight", "time_embedding.linear_2.weight"),
  21.     ("time_embed.2.bias", "time_embedding.linear_2.bias"),
  22.     ("input_blocks.0.0.weight", "conv_in.weight"),
  23.     ("input_blocks.0.0.bias", "conv_in.bias"),
  24.     ("out.0.weight", "conv_norm_out.weight"),
  25.     ("out.0.bias", "conv_norm_out.bias"),
  26.     ("out.2.weight", "conv_out.weight"),
  27.     ("out.2.bias", "conv_out.bias"),
  28. ]
  29.  
  30. unet_conversion_map_resnet = [
  31.     # (stable-diffusion, HF Diffusers)
  32.     ("in_layers.0", "norm1"),
  33.     ("in_layers.2", "conv1"),
  34.     ("out_layers.0", "norm2"),
  35.     ("out_layers.3", "conv2"),
  36.     ("emb_layers.1", "time_emb_proj"),
  37.     ("skip_connection", "conv_shortcut"),
  38. ]
  39.  
  40. unet_conversion_map_layer = []
  41. # hardcoded number of downblocks and resnets/attentions...
  42. # would need smarter logic for other networks.
  43. for i in range(4):
  44.     # loop over downblocks/upblocks
  45.  
  46.     for j in range(2):
  47.         # loop over resnets/attentions for downblocks
  48.         hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
  49.         sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
  50.         unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
  51.  
  52.         if i < 3:
  53.             # no attention layers in down_blocks.3
  54.             hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
  55.             sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
  56.             unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
  57.  
  58.     for j in range(3):
  59.         # loop over resnets/attentions for upblocks
  60.         hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
  61.         sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
  62.         unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
  63.  
  64.         if i > 0:
  65.             # no attention layers in up_blocks.0
  66.             hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
  67.             sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
  68.             unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
  69.  
  70.     if i < 3:
  71.         # no downsample in down_blocks.3
  72.         hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
  73.         sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
  74.         unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
  75.  
  76.         # no upsample in up_blocks.3
  77.         hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
  78.         sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
  79.         unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
  80.  
  81. hf_mid_atn_prefix = "mid_block.attentions.0."
  82. sd_mid_atn_prefix = "middle_block.1."
  83. unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
  84.  
  85. for j in range(2):
  86.     hf_mid_res_prefix = f"mid_block.resnets.{j}."
  87.     sd_mid_res_prefix = f"middle_block.{2*j}."
  88.     unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
  89.  
  90.  
  91. def convert_unet_state_dict(unet_state_dict):
  92.     # buyer beware: this is a *brittle* function,
  93.     # and correct output requires that all of these pieces interact in
  94.     # the exact order in which I have arranged them.
  95.     mapping = {k: k for k in unet_state_dict.keys()}
  96.     for sd_name, hf_name in unet_conversion_map:
  97.         mapping[hf_name] = sd_name
  98.     for k, v in mapping.items():
  99.         if "resnets" in k:
  100.             for sd_part, hf_part in unet_conversion_map_resnet:
  101.                 v = v.replace(hf_part, sd_part)
  102.             mapping[k] = v
  103.     for k, v in mapping.items():
  104.         for sd_part, hf_part in unet_conversion_map_layer:
  105.             v = v.replace(hf_part, sd_part)
  106.         mapping[k] = v
  107.     new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
  108.     return new_state_dict
  109.  
  110.  
  111. # ================#
  112. # VAE Conversion #
  113. # ================#
  114.  
  115. vae_conversion_map = [
  116.     # (stable-diffusion, HF Diffusers)
  117.     ("nin_shortcut", "conv_shortcut"),
  118.     ("norm_out", "conv_norm_out"),
  119.     ("mid.attn_1.", "mid_block.attentions.0."),
  120. ]
  121.  
  122. for i in range(4):
  123.     # down_blocks have two resnets
  124.     for j in range(2):
  125.         hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
  126.         sd_down_prefix = f"encoder.down.{i}.block.{j}."
  127.         vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
  128.  
  129.     if i < 3:
  130.         hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
  131.         sd_downsample_prefix = f"down.{i}.downsample."
  132.         vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
  133.  
  134.         hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
  135.         sd_upsample_prefix = f"up.{3-i}.upsample."
  136.         vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
  137.  
  138.     # up_blocks have three resnets
  139.     # also, up blocks in hf are numbered in reverse from sd
  140.     for j in range(3):
  141.         hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
  142.         sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
  143.         vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
  144.  
  145. # this part accounts for mid blocks in both the encoder and the decoder
  146. for i in range(2):
  147.     hf_mid_res_prefix = f"mid_block.resnets.{i}."
  148.     sd_mid_res_prefix = f"mid.block_{i+1}."
  149.     vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
  150.  
  151.  
  152. vae_conversion_map_attn = [
  153.     # (stable-diffusion, HF Diffusers)
  154.     ("norm.", "group_norm."),
  155.     ("q.", "query."),
  156.     ("k.", "key."),
  157.     ("v.", "value."),
  158.     ("proj_out.", "proj_attn."),
  159. ]
  160.  
  161.  
  162. def reshape_weight_for_sd(w):
  163.     # convert HF linear weights to SD conv2d weights
  164.     return w.reshape(*w.shape, 1, 1)
  165.  
  166.  
  167. def convert_vae_state_dict(vae_state_dict):
  168.     mapping = {k: k for k in vae_state_dict.keys()}
  169.     for k, v in mapping.items():
  170.         for sd_part, hf_part in vae_conversion_map:
  171.             v = v.replace(hf_part, sd_part)
  172.         mapping[k] = v
  173.     for k, v in mapping.items():
  174.         if "attentions" in k:
  175.             for sd_part, hf_part in vae_conversion_map_attn:
  176.                 v = v.replace(hf_part, sd_part)
  177.             mapping[k] = v
  178.     new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
  179.     weights_to_convert = ["q", "k", "v", "proj_out"]
  180.     for k, v in new_state_dict.items():
  181.         for weight_name in weights_to_convert:
  182.             if f"mid.attn_1.{weight_name}.weight" in k:
  183.                 print(f"Reshaping {k} for SD format")
  184.                 new_state_dict[k] = reshape_weight_for_sd(v)
  185.     return new_state_dict
  186.  
  187.  
  188. # =========================#
  189. # Text Encoder Conversion #
  190. # =========================#
  191. # pretty much a no-op
  192.  
  193.  
  194. def convert_text_enc_state_dict(text_enc_dict):
  195.     return text_enc_dict
  196.  
  197.  
  198. if __name__ == "__main__":
  199.     parser = argparse.ArgumentParser()
  200.  
  201.     parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
  202.     parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
  203.     parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
  204.  
  205.     args = parser.parse_args()
  206.  
  207.     assert args.model_path is not None, "Must provide a model path!"
  208.  
  209.     assert args.checkpoint_path is not None, "Must provide a checkpoint path!"
  210.  
  211.     unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
  212.     vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
  213.     text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")
  214.  
  215.     # Convert the UNet model
  216.     unet_state_dict = torch.load(unet_path, map_location='cpu')
  217.     unet_state_dict = convert_unet_state_dict(unet_state_dict)
  218.     unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
  219.  
  220.     # Convert the VAE model
  221.     vae_state_dict = torch.load(vae_path, map_location='cpu')
  222.     vae_state_dict = convert_vae_state_dict(vae_state_dict)
  223.     vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
  224.  
  225.     # Convert the text encoder model
  226.     text_enc_dict = torch.load(text_enc_path, map_location='cpu')
  227.     text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
  228.     text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
  229.  
  230.     # Put together new checkpoint
  231.     state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
  232.     if args.half:
  233.         state_dict = {k:v.half() for k,v in state_dict.items()}
  234.     state_dict = {"state_dict": state_dict}
  235.     torch.save(state_dict, args.checkpoint_path)
Add Comment
Please, Sign In to add comment