Compare commits

...

1 Commits

Author SHA1 Message Date
Kohya S
4bc3d0d6d4 Add Dual Resolution U-Net Hires fix 2023-12-25 23:24:02 +09:00
3 changed files with 254 additions and 0 deletions

View File

@@ -1,3 +1,14 @@
# Dual Resolution U-Net Hires fix
- 複数解像度でU-Netを呼び出し、結果を mix する hires fix です。プロンプトオプションでのみ指定できます。
- `--drr` : 初期サイズ(小解像度)の、元の解像度に対する比率。0.5 なら 1/2 になります。他の `--dr*` オプションだけを指定した場合は 0.5 が使われます。
- `--drst` : 複数解像度を開始する timestep。800なら最初から20%経過した時点から開始。省略時は 750。
- `--dret` : 終了する timestep。600なら最初から40%経過した時点で終了。省略時は 500。
- `--drstr` : 適用する小解像度の重み。1.0 なら大解像度と同じ重みになります。省略時は 1.0。
---
__SDXL is now supported. The sdxl branch has been merged into the main branch. If you update the repository, please follow the upgrade instructions. Also, the version of accelerate has been updated, so please run accelerate config again.__ The documentation for SDXL training is [here](./README.md#sdxl-training).
This repository contains training, generation and utility scripts for Stable Diffusion.

View File

@@ -454,6 +454,26 @@ class PipelineLike:
self.control_nets: List[ControlNetInfo] = []
self.control_net_enabled = True # control_netsが空ならTrueでもFalseでもControlNetは動作しない
# Dual Resolution U-Net
self.dual_reso_unet_size_ratio = None
# 低くすると画像が壊れる可能性が高いがディテールが豊かになる(気がする) / Lowering this value may cause the image to break, but it will be more detailed (I think)
self.dual_reso_unet_strength = 1.0
self.dual_reso_unet_start_timesteps = 750
self.dual_reso_unet_end_timesteps = 500
def set_dual_resolution_unet(self, size_ratio, strength, start_timesteps, end_timesteps):
    """Store the Dual Resolution U-Net settings on the pipeline and report them.

    Args:
        size_ratio: Ratio stored as ``dual_reso_unet_size_ratio``; ``None``
            means the feature is reported as disabled.
        strength: Stored as ``dual_reso_unet_strength``.
        start_timesteps: Stored as ``dual_reso_unet_start_timesteps``.
        end_timesteps: Stored as ``dual_reso_unet_end_timesteps``.
    """
    # All four values are always recorded, even when the feature ends up disabled.
    self.dual_reso_unet_end_timesteps = end_timesteps
    self.dual_reso_unet_start_timesteps = start_timesteps
    self.dual_reso_unet_strength = strength
    self.dual_reso_unet_size_ratio = size_ratio

    # Guard clause: a missing ratio is the "off" switch.
    if self.dual_reso_unet_size_ratio is None:
        print("Disable Dual Resolution U-Net")
        return

    summary = f"Enable Dual Resolution U-Net: size {self.dual_reso_unet_size_ratio}, strength {self.dual_reso_unet_strength}, timesteps: {self.dual_reso_unet_start_timesteps} - {self.dual_reso_unet_end_timesteps}"
    print(summary)
# Textual Inversion
def add_token_replacement(self, target_token_id, rep_token_ids):
self.token_replacements[target_token_id] = rep_token_ids
@@ -863,6 +883,18 @@ class PipelineLike:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
latents = latents.to(self.device)
# Dual Resolution U-Net
# scale the initial noise by scale factor
if self.dual_reso_unet_size_ratio is not None:
print(f"scale the initial noise by scale factor: {self.dual_reso_unet_size_ratio}")
org_dtype = latents.dtype
if org_dtype == torch.bfloat16:
latents = latents.float()
resize_h = int((height // 8) * self.dual_reso_unet_size_ratio)
resize_w = int((width // 8) * self.dual_reso_unet_size_ratio)
# we don't need any interpolation for noise
latents = torch.nn.functional.interpolate(latents, (resize_h, resize_w), mode="nearest").to(org_dtype)
timesteps = self.scheduler.timesteps.to(self.device)
# scale the initial noise by the standard deviation required by the scheduler
@@ -959,6 +991,24 @@ class PipelineLike:
text_emb_last = text_embeddings
for i, t in enumerate(tqdm(timesteps)):
# Dual Resolution U-Net
# t > start: latents is low-res, otherwise: original size
enable_dual_reso_unet = False
if self.dual_reso_unet_size_ratio is not None:
org_dtype = latents.dtype
unet_h = int((height // 8) * self.dual_reso_unet_size_ratio)
unet_w = int((width // 8) * self.dual_reso_unet_size_ratio)
if self.dual_reso_unet_start_timesteps >= t:
# resize latent to original size if necessary (first timesteps only)
if latents.shape[2] == unet_h and latents.shape[3] == unet_w:
print(f"resize latent to original size: {latents.shape} -> {(height // 8, width // 8)}")
if org_dtype == torch.bfloat16:
latents = latents.float()
latents = torch.nn.functional.interpolate(latents, (height // 8, width // 8), mode="bicubic").to(org_dtype)
if self.dual_reso_unet_start_timesteps >= t > self.dual_reso_unet_end_timesteps:
enable_dual_reso_unet = True
# expand the latents if we are doing classifier free guidance
latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -980,6 +1030,39 @@ class PipelineLike:
else:
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# Dual Resolution U-Net
if enable_dual_reso_unet and latents.shape[2] != unet_h and latents.shape[3] != unet_w:
# call U-Net with low resolution
print(f"call low res U-Net: {(unet_h, unet_w)}")
low_res_latents = latents
low_res_latents = latents.repeat((num_latent_input, 1, 1, 1))
low_res_latents = self.scheduler.scale_model_input(low_res_latents, t)
if org_dtype == torch.bfloat16:
low_res_latents = low_res_latents.float()
low_res_latents = torch.nn.functional.interpolate(low_res_latents, (unet_h, unet_w), mode="bicubic").to(org_dtype)
low_res_noise_pred = self.unet(low_res_latents, t, encoder_hidden_states=text_embeddings).sample
if org_dtype == torch.bfloat16:
low_res_noise_pred = low_res_noise_pred.float()
low_res_noise_pred = torch.nn.functional.interpolate(
low_res_noise_pred, (latents.shape[2], latents.shape[3]), mode="bicubic"
).to(org_dtype)
# # unsharp mask
# import torchvision.transforms.functional as TF
# blurred = TF.gaussian_blur(low_res_noise_pred, kernel_size=3, sigma=1.0)
# low_res_noise_pred = low_res_noise_pred + (low_res_noise_pred - blurred) * 0.5
# なぜか足すとうまくいく。ただし色合いが変わる気がする
# I don't know why, but it works well when added. However, the color seems to change.
noise_pred += low_res_noise_pred * self.dual_reso_unet_strength
# # これはうまく動かない / This does not work well
# noise_pred = noise_pred * (1 - self.dual_reso_unet_strength) + low_res_noise_pred * self.dual_reso_unet_strength
# perform guidance
if do_classifier_free_guidance:
if negative_scale is None:
@@ -3104,6 +3187,12 @@ def main(args):
ds_timesteps_2 = args.ds_timesteps_2
ds_ratio = args.ds_ratio
# Dual Resolution U-Net
dru_ratio = None # means no override
dru_start_timesteps = 750 # TODO add to args
dru_end_timesteps = 500
dru_strength = 1.0
prompt_args = raw_prompt.strip().split(" --")
prompt = prompt_args[0]
print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")
@@ -3210,6 +3299,32 @@ def main(args):
print(f"deep shrink ratio: {ds_ratio}")
continue
# Dual Resolution U-Net
if m: # dual resolution u-net ratio
dru_ratio = float(m.group(1))
print(f"dual resolution u-net size ratio: {dru_ratio}")
continue
m = re.match(r"drst ([\d\.]+)", parg, re.IGNORECASE)
if m: # dual resolution u-net start timesteps
dru_start_timesteps = int(m.group(1))
dru_ratio = dru_ratio if dru_ratio is not None else -1 # -1 means override
print(f"dual resolution u-net start timesteps: {dru_start_timesteps}")
continue
m = re.match(r"dret ([\d\.]+)", parg, re.IGNORECASE)
if m: # dual resolution u-net end timesteps
dru_end_timesteps = int(m.group(1))
dru_ratio = dru_ratio if dru_ratio is not None else -1 # -1 means override
print(f"dual resolution u-net end timesteps: {dru_end_timesteps}")
continue
m = re.match(r"drstr ([\d\.]+)", parg, re.IGNORECASE)
if m: # dual resolution u-net strength
dru_strength = float(m.group(1))
dru_ratio = dru_ratio if dru_ratio is not None else -1 # -1 means override
print(f"dual resolution u-net strength: {dru_strength}")
except ValueError as ex:
print(f"Exception in parsing / 解析エラー: {parg}")
print(ex)
@@ -3220,6 +3335,12 @@ def main(args):
ds_depth_1 = args.ds_depth_1 or 3
unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio)
# override Dual Resolution U-Net
if dru_ratio is not None:
if dru_ratio < 0:
dru_ratio = 0.5 # default TODO add to args
pipe.set_dual_resolution_unet(dru_ratio, dru_strength, dru_start_timesteps, dru_end_timesteps)
# prepare seed
if seeds is not None: # given in prompt
# 数が足りないなら前のをそのまま使う

View File

@@ -345,6 +345,26 @@ class PipelineLike:
self.control_nets: List[ControlNetLLLite] = []
self.control_net_enabled = True # control_netsが空ならTrueでもFalseでもControlNetは動作しない
# Dual Resolution U-Net
self.dual_reso_unet_size_ratio = None
# 低くすると画像が壊れる可能性が高いがディテールが豊かになる(気がする) / Lowering this value may cause the image to break, but it will be more detailed (I think)
self.dual_reso_unet_strength = 1.0
self.dual_reso_unet_start_timesteps = 750
self.dual_reso_unet_end_timesteps = 500
def set_dual_resolution_unet(self, size_ratio, strength, start_timesteps, end_timesteps):
    """Record the Dual Resolution U-Net configuration and print its status.

    Args:
        size_ratio: Value kept in ``dual_reso_unet_size_ratio``. Passing
            ``None`` makes this method print the "Disable" message.
        strength: Value kept in ``dual_reso_unet_strength``.
        start_timesteps: Value kept in ``dual_reso_unet_start_timesteps``.
        end_timesteps: Value kept in ``dual_reso_unet_end_timesteps``.
    """
    # Assign the whole configuration in one step; every field is overwritten
    # regardless of whether the feature is being enabled or disabled.
    (
        self.dual_reso_unet_size_ratio,
        self.dual_reso_unet_strength,
        self.dual_reso_unet_start_timesteps,
        self.dual_reso_unet_end_timesteps,
    ) = (size_ratio, strength, start_timesteps, end_timesteps)

    feature_enabled = self.dual_reso_unet_size_ratio is not None
    if not feature_enabled:
        print("Disable Dual Resolution U-Net")
        return

    print(
        f"Enable Dual Resolution U-Net: size {self.dual_reso_unet_size_ratio}, strength {self.dual_reso_unet_strength}, timesteps: {self.dual_reso_unet_start_timesteps} - {self.dual_reso_unet_end_timesteps}"
    )
# Textual Inversion
def add_token_replacement(self, text_encoder_index, target_token_id, rep_token_ids):
self.token_replacements_list[text_encoder_index][target_token_id] = rep_token_ids
@@ -615,6 +635,18 @@ class PipelineLike:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
latents = latents.to(self.device)
# Dual Resolution U-Net
# scale the initial noise by scale factor
if self.dual_reso_unet_size_ratio is not None:
print(f"scale the initial noise by scale factor: {self.dual_reso_unet_size_ratio}")
org_dtype = latents.dtype
if org_dtype == torch.bfloat16:
latents = latents.float()
resize_h = int((height // 8) * self.dual_reso_unet_size_ratio)
resize_w = int((width // 8) * self.dual_reso_unet_size_ratio)
# we don't need any interpolation for noise
latents = torch.nn.functional.interpolate(latents, (resize_h, resize_w), mode="nearest").to(org_dtype)
timesteps = self.scheduler.timesteps.to(self.device)
# scale the initial noise by the standard deviation required by the scheduler
@@ -710,6 +742,24 @@ class PipelineLike:
each_control_net_enabled = [self.control_net_enabled] * len(self.control_nets)
for i, t in enumerate(tqdm(timesteps)):
# Dual Resolution U-Net
# t > start: latents is low-res, otherwise: original size
enable_dual_res_unet = False
if self.dual_reso_unet_size_ratio is not None:
org_dtype = latents.dtype
unet_h = int((height // 8) * self.dual_reso_unet_size_ratio)
unet_w = int((width // 8) * self.dual_reso_unet_size_ratio)
if self.dual_reso_unet_start_timesteps >= t:
# resize latent to original size if necessary (first timesteps only)
if latents.shape[2] == unet_h and latents.shape[3] == unet_w:
print(f"resize latent to original size: {latents.shape} -> {(height // 8, width // 8)}")
if org_dtype == torch.bfloat16:
latents = latents.float()
latents = torch.nn.functional.interpolate(latents, (height // 8, width // 8), mode="bicubic").to(org_dtype)
if self.dual_reso_unet_start_timesteps >= t > self.dual_reso_unet_end_timesteps:
enable_dual_res_unet = True
# expand the latents if we are doing classifier free guidance
latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -748,6 +798,39 @@ class PipelineLike:
# else:
noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings)
# Dual Resolution U-Net
if enable_dual_res_unet and latents.shape[2] != unet_h and latents.shape[3] != unet_w:
# call U-Net with low resolution
print(f"call low res U-Net: {(unet_h, unet_w)}")
low_res_latents = latents
low_res_latents = latents.repeat((num_latent_input, 1, 1, 1))
low_res_latents = self.scheduler.scale_model_input(low_res_latents, t)
if org_dtype == torch.bfloat16:
low_res_latents = low_res_latents.float()
low_res_latents = torch.nn.functional.interpolate(low_res_latents, (unet_h, unet_w), mode="bicubic").to(org_dtype)
# # unsharp mask
# import torchvision.transforms.functional as TF
# blurred = TF.gaussian_blur(low_res_latents, kernel_size=3, sigma=1.0)
# low_res_latents = low_res_latents + (low_res_latents - blurred) * 0.5
low_res_noise_pred = self.unet(low_res_latents, t, text_embeddings, vector_embeddings)
if org_dtype == torch.bfloat16:
low_res_noise_pred = low_res_noise_pred.float()
low_res_noise_pred = torch.nn.functional.interpolate(
low_res_noise_pred, (latents.shape[2], latents.shape[3]), mode="bicubic"
).to(org_dtype)
# なぜか足すとうまくいく。ただし色合いが変わる気がする
# I don't know why, but it works well when added. However, the color seems to change.
noise_pred += low_res_noise_pred * self.dual_reso_unet_strength
# # これはうまく動かない / This does not work well
# noise_pred = noise_pred * (1 - self.dual_res_unet_strength) + low_res_noise_pred * self.dual_res_unet_strength
# perform guidance
if do_classifier_free_guidance:
if negative_scale is None:
@@ -2305,6 +2388,12 @@ def main(args):
ds_timesteps_2 = args.ds_timesteps_2
ds_ratio = args.ds_ratio
# Dual Resolution U-Net
dru_ratio = None # means no override
dru_start_timesteps = 750 # TODO add to args
dru_end_timesteps = 500
dru_strength = 1.0
prompt_args = raw_prompt.strip().split(" --")
prompt = prompt_args[0]
print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")
@@ -2447,6 +2536,33 @@ def main(args):
print(f"deep shrink ratio: {ds_ratio}")
continue
# Dual Resolution U-Net
m = re.match(r"drr ([\d\.]+)", parg, re.IGNORECASE)
if m: # dual resolution u-net ratio
dru_ratio = float(m.group(1))
print(f"dual resolution u-net size ratio: {dru_ratio}")
continue
m = re.match(r"drst ([\d\.]+)", parg, re.IGNORECASE)
if m: # dual resolution u-net start timesteps
dru_start_timesteps = int(m.group(1))
dru_ratio = dru_ratio if dru_ratio is not None else -1 # -1 means override
print(f"dual resolution u-net start timesteps: {dru_start_timesteps}")
continue
m = re.match(r"dret ([\d\.]+)", parg, re.IGNORECASE)
if m: # dual resolution u-net end timesteps
dru_end_timesteps = int(m.group(1))
dru_ratio = dru_ratio if dru_ratio is not None else -1 # -1 means override
print(f"dual resolution u-net end timesteps: {dru_end_timesteps}")
continue
m = re.match(r"drstr ([\d\.]+)", parg, re.IGNORECASE)
if m: # dual resolution u-net strength
dru_strength = float(m.group(1))
dru_ratio = dru_ratio if dru_ratio is not None else -1 # -1 means override
print(f"dual resolution u-net strength: {dru_strength}")
except ValueError as ex:
print(f"Exception in parsing / 解析エラー: {parg}")
print(ex)
@@ -2457,6 +2573,12 @@ def main(args):
ds_depth_1 = args.ds_depth_1 or 3
unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio)
# override Dual Resolution U-Net
if dru_ratio is not None:
if dru_ratio < 0:
dru_ratio = 0.5 # default TODO add to args
pipe.set_dual_resolution_unet(dru_ratio, dru_strength, dru_start_timesteps, dru_end_timesteps)
# prepare seed
if seeds is not None: # given in prompt
# 数が足りないなら前のをそのまま使う