Initialize U-Net with empty weights

This commit is contained in:
Isotr0py
2023-07-23 13:17:11 +08:00
parent d1864e2430
commit bb167f94ca
2 changed files with 13 additions and 9 deletions

View File

@@ -1,4 +1,6 @@
import torch
from accelerate import init_empty_weights
from accelerate.utils.modeling import set_module_tensor_to_device
from safetensors.torch import load_file, save_file
from transformers import CLIPTextModel, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
from diffusers import AutoencoderKL, EulerDiscreteScheduler, StableDiffusionXLPipeline, UNet2DConditionModel
@@ -156,16 +158,15 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location):
# U-Net
print("building U-Net")
unet = sdxl_original_unet.SdxlUNet2DConditionModel()
with init_empty_weights():
unet = sdxl_original_unet.SdxlUNet2DConditionModel()
print("loading U-Net from checkpoint")
unet_sd = {}
for k in list(state_dict.keys()):
if k.startswith("model.diffusion_model."):
unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k)
info = unet.load_state_dict(unet_sd)
print("U-Net: ", info)
del unet_sd
set_module_tensor_to_device(unet, k.replace("model.diffusion_model.", ""), map_location, value=state_dict.pop(k))
# TODO: catch missing_keys and unexpected_keys with _IncompatibleKeys
# print("U-Net: ", info)
# Text Encoders
print("building text encoders")

View File

@@ -5,6 +5,8 @@ import os
from types import SimpleNamespace
from typing import Any
import torch
from accelerate import init_empty_weights
from accelerate.utils.modeling import set_module_tensor_to_device
from tqdm import tqdm
from transformers import CLIPTokenizer
import open_clip
@@ -92,10 +94,11 @@ def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtyp
del pipe
# Diffusers U-Net to original U-Net
original_unet = sdxl_original_unet.SdxlUNet2DConditionModel()
state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict())
original_unet.load_state_dict(state_dict)
unet = original_unet
with init_empty_weights():
unet = sdxl_original_unet.SdxlUNet2DConditionModel()
for k in list(state_dict.keys()):
set_module_tensor_to_device(unet, k, device, value=state_dict.pop(k))
print("U-Net converted to original U-Net")
logit_scale = None