From bb167f94ca417e97ea1a6018b17119df6abade91 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sun, 23 Jul 2023 13:17:11 +0800 Subject: [PATCH] init unet with empty weights --- library/sdxl_model_util.py | 13 +++++++------ library/sdxl_train_util.py | 9 ++++++--- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index 41a05e95..69357517 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -1,4 +1,6 @@ import torch +from accelerate import init_empty_weights +from accelerate.utils.modeling import set_module_tensor_to_device from safetensors.torch import load_file, save_file from transformers import CLIPTextModel, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer from diffusers import AutoencoderKL, EulerDiscreteScheduler, StableDiffusionXLPipeline, UNet2DConditionModel @@ -156,16 +158,15 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location): # U-Net print("building U-Net") - unet = sdxl_original_unet.SdxlUNet2DConditionModel() + with init_empty_weights(): + unet = sdxl_original_unet.SdxlUNet2DConditionModel() print("loading U-Net from checkpoint") - unet_sd = {} for k in list(state_dict.keys()): if k.startswith("model.diffusion_model."): - unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k) - info = unet.load_state_dict(unet_sd) - print("U-Net: ", info) - del unet_sd + set_module_tensor_to_device(unet, k.replace("model.diffusion_model.", ""), map_location, value=state_dict.pop(k)) + # TODO: catch missing_keys and unexpected_keys with _IncompatibleKeys + # print("U-Net: ", info) # Text Encoders print("building text encoders") diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index 34312afc..f37cadab 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -5,6 +5,8 @@ import os from types import SimpleNamespace from typing import Any import torch +from accelerate import init_empty_weights +from accelerate.utils.modeling import set_module_tensor_to_device from tqdm import tqdm from transformers import CLIPTokenizer import open_clip @@ -92,10 +94,11 @@ def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtyp del pipe # Diffusers U-Net to original U-Net - original_unet = sdxl_original_unet.SdxlUNet2DConditionModel() state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict()) - original_unet.load_state_dict(state_dict) - unet = original_unet + with init_empty_weights(): + unet = sdxl_original_unet.SdxlUNet2DConditionModel() + for k in list(state_dict.keys()): + set_module_tensor_to_device(unet, k, device, value=state_dict.pop(k)) print("U-Net converted to original U-Net") logit_scale = None