change method name, add comments

This commit is contained in:
Kohya S
2023-07-30 13:34:07 +09:00
parent e6034b7eb6
commit b62185b821
2 changed files with 11 additions and 15 deletions

View File

@@ -136,7 +136,8 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
return new_sd, logit_scale
def _load_state_dict(model, state_dict, device, dtype=None):
# load state_dict without allocating new tensors
def _load_state_dict_on_device(model, state_dict, device, dtype=None):
# dtype will use fp32 as default
missing_keys = list(model.state_dict().keys() - state_dict.keys())
unexpected_keys = list(state_dict.keys() - model.state_dict().keys())
@@ -145,26 +146,21 @@ def _load_state_dict(model, state_dict, device, dtype=None):
if not missing_keys and not unexpected_keys:
for k in list(state_dict.keys()):
set_module_tensor_to_device(model, k, device, value=state_dict.pop(k), dtype=dtype)
return '<All keys matched successfully>'
return "<All keys matched successfully>"
# error_msgs
error_msgs: List[str] = []
if missing_keys:
error_msgs.insert(
0, 'Missing key(s) in state_dict: {}. '.format(
', '.join('"{}"'.format(k) for k in missing_keys)))
error_msgs.insert(0, "Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in missing_keys)))
if unexpected_keys:
error_msgs.insert(
0, 'Unexpected key(s) in state_dict: {}. '.format(
', '.join('"{}"'.format(k) for k in unexpected_keys)))
error_msgs.insert(0, "Unexpected key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in unexpected_keys)))
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
model.__class__.__name__, "\n\t".join(error_msgs)))
raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs)))
def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None):
# model_version is reserved for future use
# dtype is reserved for full_fp16/bf16 integration
# dtype is reserved for full_fp16/bf16 integration. Text Encoder will remain fp32, because it runs on CPU when caching
# Load the state dict
if model_util.is_safetensors(ckpt_path):
@@ -172,7 +168,7 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
try:
state_dict = load_file(ckpt_path, device=map_location)
except:
state_dict = load_file(ckpt_path) # prevent device invalid Error
state_dict = load_file(ckpt_path) # prevent device invalid Error
epoch = None
global_step = None
else:
@@ -197,7 +193,7 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
for k in list(state_dict.keys()):
if k.startswith("model.diffusion_model."):
unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k)
info = _load_state_dict(unet, unet_sd, device=map_location)
info = _load_state_dict_on_device(unet, unet_sd, device=map_location)
print("U-Net: ", info)
# Text Encoders

View File

@@ -98,8 +98,8 @@ def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version
# Diffusers U-Net to original U-Net
state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict())
with init_empty_weights():
unet = sdxl_original_unet.SdxlUNet2DConditionModel()
sdxl_model_util._load_state_dict(unet, state_dict, device=device)
unet = sdxl_original_unet.SdxlUNet2DConditionModel() # overwrite unet
sdxl_model_util._load_state_dict_on_device(unet, state_dict, device=device)
print("U-Net converted to original U-Net")
logit_scale = None