mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 21:52:27 +00:00
Compare commits
1 Commits
new_cache
...
dual_reso_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4bc3d0d6d4 |
11
README.md
11
README.md
@@ -1,3 +1,14 @@
|
||||
# Dual Resolution U-Net Hires fix
|
||||
|
||||
- 複数解像度でU-Netを呼び出し、結果を mix する hires fix です。プロンプトオプションでのみ指定できます。
|
||||
- `--drr` : 初期サイズの解像度の比率。0.5 なら 1/2 になります。
|
||||
- `--drst` : 複数解像度を開始する timestep。800なら最初から20%経過した時点から開始。
|
||||
- `--dret` : 終了する timestep。600なら最初から40%経過した時点で終了。
|
||||
- `--drstr` : 適用する小解像度の重み。1.0 なら大解像度と同じ重みになります。
|
||||
|
||||
---
|
||||
|
||||
|
||||
__SDXL is now supported. The sdxl branch has been merged into the main branch. If you update the repository, please follow the upgrade instructions. Also, the version of accelerate has been updated, so please run accelerate config again.__ The documentation for SDXL training is [here](./README.md#sdxl-training).
|
||||
|
||||
This repository contains training, generation and utility scripts for Stable Diffusion.
|
||||
|
||||
@@ -454,6 +454,26 @@ class PipelineLike:
|
||||
self.control_nets: List[ControlNetInfo] = []
|
||||
self.control_net_enabled = True # control_netsが空ならTrueでもFalseでもControlNetは動作しない
|
||||
|
||||
# Dual Resolution U-Net
|
||||
self.dual_reso_unet_size_ratio = None
|
||||
|
||||
# 低くすると画像が壊れる可能性が高いがディテールが豊かになる(気がする) / Lowering this value may cause the image to break, but it will be more detailed (I think)
|
||||
self.dual_reso_unet_strength = 1.0
|
||||
self.dual_reso_unet_start_timesteps = 750
|
||||
self.dual_reso_unet_end_timesteps = 500
|
||||
|
||||
def set_dual_resolution_unet(self, size_ratio, strength, start_timesteps, end_timesteps):
|
||||
self.dual_reso_unet_size_ratio = size_ratio
|
||||
self.dual_reso_unet_strength = strength
|
||||
self.dual_reso_unet_start_timesteps = start_timesteps
|
||||
self.dual_reso_unet_end_timesteps = end_timesteps
|
||||
if self.dual_reso_unet_size_ratio is not None:
|
||||
print(
|
||||
f"Enable Dual Resolution U-Net: size {self.dual_reso_unet_size_ratio}, strength {self.dual_reso_unet_strength}, timesteps: {self.dual_reso_unet_start_timesteps} - {self.dual_reso_unet_end_timesteps}"
|
||||
)
|
||||
else:
|
||||
print("Disable Dual Resolution U-Net")
|
||||
|
||||
# Textual Inversion
|
||||
def add_token_replacement(self, target_token_id, rep_token_ids):
|
||||
self.token_replacements[target_token_id] = rep_token_ids
|
||||
@@ -863,6 +883,18 @@ class PipelineLike:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
|
||||
latents = latents.to(self.device)
|
||||
|
||||
# Dual Resolution U-Net
|
||||
# scale the initial noise by scale factor
|
||||
if self.dual_reso_unet_size_ratio is not None:
|
||||
print(f"scale the initial noise by scale factor: {self.dual_reso_unet_size_ratio}")
|
||||
org_dtype = latents.dtype
|
||||
if org_dtype == torch.bfloat16:
|
||||
latents = latents.float()
|
||||
resize_h = int((height // 8) * self.dual_reso_unet_size_ratio)
|
||||
resize_w = int((width // 8) * self.dual_reso_unet_size_ratio)
|
||||
# we don't need any interpolation for noise
|
||||
latents = torch.nn.functional.interpolate(latents, (resize_h, resize_w), mode="nearest").to(org_dtype)
|
||||
|
||||
timesteps = self.scheduler.timesteps.to(self.device)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
@@ -959,6 +991,24 @@ class PipelineLike:
|
||||
text_emb_last = text_embeddings
|
||||
|
||||
for i, t in enumerate(tqdm(timesteps)):
|
||||
# Dual Resolution U-Net
|
||||
# t > start: latents is low-res, otherwise: original size
|
||||
enable_dual_reso_unet = False
|
||||
if self.dual_reso_unet_size_ratio is not None:
|
||||
org_dtype = latents.dtype
|
||||
unet_h = int((height // 8) * self.dual_reso_unet_size_ratio)
|
||||
unet_w = int((width // 8) * self.dual_reso_unet_size_ratio)
|
||||
|
||||
if self.dual_reso_unet_start_timesteps >= t:
|
||||
# resize latent to original size if necessary (first timesteps only)
|
||||
if latents.shape[2] == unet_h and latents.shape[3] == unet_w:
|
||||
print(f"resize latent to original size: {latents.shape} -> {(height // 8, width // 8)}")
|
||||
if org_dtype == torch.bfloat16:
|
||||
latents = latents.float()
|
||||
latents = torch.nn.functional.interpolate(latents, (height // 8, width // 8), mode="bicubic").to(org_dtype)
|
||||
if self.dual_reso_unet_start_timesteps >= t > self.dual_reso_unet_end_timesteps:
|
||||
enable_dual_reso_unet = True
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
@@ -980,6 +1030,39 @@ class PipelineLike:
|
||||
else:
|
||||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
||||
|
||||
# Dual Resolution U-Net
|
||||
if enable_dual_reso_unet and latents.shape[2] != unet_h and latents.shape[3] != unet_w:
|
||||
# call U-Net with low resolution
|
||||
print(f"call low res U-Net: {(unet_h, unet_w)}")
|
||||
low_res_latents = latents
|
||||
low_res_latents = latents.repeat((num_latent_input, 1, 1, 1))
|
||||
low_res_latents = self.scheduler.scale_model_input(low_res_latents, t)
|
||||
|
||||
if org_dtype == torch.bfloat16:
|
||||
low_res_latents = low_res_latents.float()
|
||||
low_res_latents = torch.nn.functional.interpolate(low_res_latents, (unet_h, unet_w), mode="bicubic").to(org_dtype)
|
||||
|
||||
low_res_noise_pred = self.unet(low_res_latents, t, encoder_hidden_states=text_embeddings).sample
|
||||
|
||||
if org_dtype == torch.bfloat16:
|
||||
low_res_noise_pred = low_res_noise_pred.float()
|
||||
low_res_noise_pred = torch.nn.functional.interpolate(
|
||||
low_res_noise_pred, (latents.shape[2], latents.shape[3]), mode="bicubic"
|
||||
).to(org_dtype)
|
||||
|
||||
# # unsharp mask
|
||||
# import torchvision.transforms.functional as TF
|
||||
|
||||
# blurred = TF.gaussian_blur(low_res_noise_pred, kernel_size=3, sigma=1.0)
|
||||
# low_res_noise_pred = low_res_noise_pred + (low_res_noise_pred - blurred) * 0.5
|
||||
|
||||
# なぜか足すとうまくいく。ただし色合いが変わる気がする
|
||||
# I don't know why, but it works well when added. However, the color seems to change.
|
||||
noise_pred += low_res_noise_pred * self.dual_reso_unet_strength
|
||||
|
||||
# # これはうまく動かない / This does not work well
|
||||
# noise_pred = noise_pred * (1 - self.dual_reso_unet_strength) + low_res_noise_pred * self.dual_reso_unet_strength
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
if negative_scale is None:
|
||||
@@ -3104,6 +3187,12 @@ def main(args):
|
||||
ds_timesteps_2 = args.ds_timesteps_2
|
||||
ds_ratio = args.ds_ratio
|
||||
|
||||
# Dual Resolution U-Net
|
||||
dru_ratio = None # means no override
|
||||
dru_start_timesteps = 750 # TODO add to args
|
||||
dru_end_timesteps = 500
|
||||
dru_strength = 1.0
|
||||
|
||||
prompt_args = raw_prompt.strip().split(" --")
|
||||
prompt = prompt_args[0]
|
||||
print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")
|
||||
@@ -3210,6 +3299,32 @@ def main(args):
|
||||
print(f"deep shrink ratio: {ds_ratio}")
|
||||
continue
|
||||
|
||||
# Dual Resolution U-Net
|
||||
if m: # dual resolution u-net ratio
|
||||
dru_ratio = float(m.group(1))
|
||||
print(f"dual resolution u-net size ratio: {dru_ratio}")
|
||||
continue
|
||||
|
||||
m = re.match(r"drst ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # dual resolution u-net start timesteps
|
||||
dru_start_timesteps = int(m.group(1))
|
||||
dru_ratio = dru_ratio if dru_ratio is not None else -1 # -1 means override
|
||||
print(f"dual resolution u-net start timesteps: {dru_start_timesteps}")
|
||||
continue
|
||||
|
||||
m = re.match(r"dret ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # dual resolution u-net end timesteps
|
||||
dru_end_timesteps = int(m.group(1))
|
||||
dru_ratio = dru_ratio if dru_ratio is not None else -1 # -1 means override
|
||||
print(f"dual resolution u-net end timesteps: {dru_end_timesteps}")
|
||||
continue
|
||||
|
||||
m = re.match(r"drstr ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # dual resolution u-net strength
|
||||
dru_strength = float(m.group(1))
|
||||
dru_ratio = dru_ratio if dru_ratio is not None else -1 # -1 means override
|
||||
print(f"dual resolution u-net strength: {dru_strength}")
|
||||
|
||||
except ValueError as ex:
|
||||
print(f"Exception in parsing / 解析エラー: {parg}")
|
||||
print(ex)
|
||||
@@ -3220,6 +3335,12 @@ def main(args):
|
||||
ds_depth_1 = args.ds_depth_1 or 3
|
||||
unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio)
|
||||
|
||||
# override Dual Resolution U-Net
|
||||
if dru_ratio is not None:
|
||||
if dru_ratio < 0:
|
||||
dru_ratio = 0.5 # default TODO add to args
|
||||
pipe.set_dual_resolution_unet(dru_ratio, dru_strength, dru_start_timesteps, dru_end_timesteps)
|
||||
|
||||
# prepare seed
|
||||
if seeds is not None: # given in prompt
|
||||
# 数が足りないなら前のをそのまま使う
|
||||
|
||||
122
sdxl_gen_img.py
122
sdxl_gen_img.py
@@ -345,6 +345,26 @@ class PipelineLike:
|
||||
self.control_nets: List[ControlNetLLLite] = []
|
||||
self.control_net_enabled = True # control_netsが空ならTrueでもFalseでもControlNetは動作しない
|
||||
|
||||
# Dual Resolution U-Net
|
||||
self.dual_reso_unet_size_ratio = None
|
||||
|
||||
# 低くすると画像が壊れる可能性が高いがディテールが豊かになる(気がする) / Lowering this value may cause the image to break, but it will be more detailed (I think)
|
||||
self.dual_reso_unet_strength = 1.0
|
||||
self.dual_reso_unet_start_timesteps = 750
|
||||
self.dual_reso_unet_end_timesteps = 500
|
||||
|
||||
def set_dual_resolution_unet(self, size_ratio, strength, start_timesteps, end_timesteps):
|
||||
self.dual_reso_unet_size_ratio = size_ratio
|
||||
self.dual_reso_unet_strength = strength
|
||||
self.dual_reso_unet_start_timesteps = start_timesteps
|
||||
self.dual_reso_unet_end_timesteps = end_timesteps
|
||||
if self.dual_reso_unet_size_ratio is not None:
|
||||
print(
|
||||
f"Enable Dual Resolution U-Net: size {self.dual_reso_unet_size_ratio}, strength {self.dual_reso_unet_strength}, timesteps: {self.dual_reso_unet_start_timesteps} - {self.dual_reso_unet_end_timesteps}"
|
||||
)
|
||||
else:
|
||||
print("Disable Dual Resolution U-Net")
|
||||
|
||||
# Textual Inversion
|
||||
def add_token_replacement(self, text_encoder_index, target_token_id, rep_token_ids):
|
||||
self.token_replacements_list[text_encoder_index][target_token_id] = rep_token_ids
|
||||
@@ -615,6 +635,18 @@ class PipelineLike:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
|
||||
latents = latents.to(self.device)
|
||||
|
||||
# Dual Resolution U-Net
|
||||
# scale the initial noise by scale factor
|
||||
if self.dual_reso_unet_size_ratio is not None:
|
||||
print(f"scale the initial noise by scale factor: {self.dual_reso_unet_size_ratio}")
|
||||
org_dtype = latents.dtype
|
||||
if org_dtype == torch.bfloat16:
|
||||
latents = latents.float()
|
||||
resize_h = int((height // 8) * self.dual_reso_unet_size_ratio)
|
||||
resize_w = int((width // 8) * self.dual_reso_unet_size_ratio)
|
||||
# we don't need any interpolation for noise
|
||||
latents = torch.nn.functional.interpolate(latents, (resize_h, resize_w), mode="nearest").to(org_dtype)
|
||||
|
||||
timesteps = self.scheduler.timesteps.to(self.device)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
@@ -710,6 +742,24 @@ class PipelineLike:
|
||||
|
||||
each_control_net_enabled = [self.control_net_enabled] * len(self.control_nets)
|
||||
for i, t in enumerate(tqdm(timesteps)):
|
||||
# Dual Resolution U-Net
|
||||
# t > start: latents is low-res, otherwise: original size
|
||||
enable_dual_res_unet = False
|
||||
if self.dual_reso_unet_size_ratio is not None:
|
||||
org_dtype = latents.dtype
|
||||
unet_h = int((height // 8) * self.dual_reso_unet_size_ratio)
|
||||
unet_w = int((width // 8) * self.dual_reso_unet_size_ratio)
|
||||
|
||||
if self.dual_reso_unet_start_timesteps >= t:
|
||||
# resize latent to original size if necessary (first timesteps only)
|
||||
if latents.shape[2] == unet_h and latents.shape[3] == unet_w:
|
||||
print(f"resize latent to original size: {latents.shape} -> {(height // 8, width // 8)}")
|
||||
if org_dtype == torch.bfloat16:
|
||||
latents = latents.float()
|
||||
latents = torch.nn.functional.interpolate(latents, (height // 8, width // 8), mode="bicubic").to(org_dtype)
|
||||
if self.dual_reso_unet_start_timesteps >= t > self.dual_reso_unet_end_timesteps:
|
||||
enable_dual_res_unet = True
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
@@ -748,6 +798,39 @@ class PipelineLike:
|
||||
# else:
|
||||
noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings)
|
||||
|
||||
# Dual Resolution U-Net
|
||||
if enable_dual_res_unet and latents.shape[2] != unet_h and latents.shape[3] != unet_w:
|
||||
# call U-Net with low resolution
|
||||
print(f"call low res U-Net: {(unet_h, unet_w)}")
|
||||
low_res_latents = latents
|
||||
low_res_latents = latents.repeat((num_latent_input, 1, 1, 1))
|
||||
low_res_latents = self.scheduler.scale_model_input(low_res_latents, t)
|
||||
|
||||
if org_dtype == torch.bfloat16:
|
||||
low_res_latents = low_res_latents.float()
|
||||
low_res_latents = torch.nn.functional.interpolate(low_res_latents, (unet_h, unet_w), mode="bicubic").to(org_dtype)
|
||||
|
||||
# # unsharp mask
|
||||
# import torchvision.transforms.functional as TF
|
||||
|
||||
# blurred = TF.gaussian_blur(low_res_latents, kernel_size=3, sigma=1.0)
|
||||
# low_res_latents = low_res_latents + (low_res_latents - blurred) * 0.5
|
||||
|
||||
low_res_noise_pred = self.unet(low_res_latents, t, text_embeddings, vector_embeddings)
|
||||
|
||||
if org_dtype == torch.bfloat16:
|
||||
low_res_noise_pred = low_res_noise_pred.float()
|
||||
low_res_noise_pred = torch.nn.functional.interpolate(
|
||||
low_res_noise_pred, (latents.shape[2], latents.shape[3]), mode="bicubic"
|
||||
).to(org_dtype)
|
||||
|
||||
# なぜか足すとうまくいく。ただし色合いが変わる気がする
|
||||
# I don't know why, but it works well when added. However, the color seems to change.
|
||||
noise_pred += low_res_noise_pred * self.dual_reso_unet_strength
|
||||
|
||||
# # これはうまく動かない / This does not work well
|
||||
# noise_pred = noise_pred * (1 - self.dual_res_unet_strength) + low_res_noise_pred * self.dual_res_unet_strength
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
if negative_scale is None:
|
||||
@@ -2305,6 +2388,12 @@ def main(args):
|
||||
ds_timesteps_2 = args.ds_timesteps_2
|
||||
ds_ratio = args.ds_ratio
|
||||
|
||||
# Dual Resolution U-Net
|
||||
dru_ratio = None # means no override
|
||||
dru_start_timesteps = 750 # TODO add to args
|
||||
dru_end_timesteps = 500
|
||||
dru_strength = 1.0
|
||||
|
||||
prompt_args = raw_prompt.strip().split(" --")
|
||||
prompt = prompt_args[0]
|
||||
print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")
|
||||
@@ -2447,6 +2536,33 @@ def main(args):
|
||||
print(f"deep shrink ratio: {ds_ratio}")
|
||||
continue
|
||||
|
||||
# Dual Resolution U-Net
|
||||
m = re.match(r"drr ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # dual resolution u-net ratio
|
||||
dru_ratio = float(m.group(1))
|
||||
print(f"dual resolution u-net size ratio: {dru_ratio}")
|
||||
continue
|
||||
|
||||
m = re.match(r"drst ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # dual resolution u-net start timesteps
|
||||
dru_start_timesteps = int(m.group(1))
|
||||
dru_ratio = dru_ratio if dru_ratio is not None else -1 # -1 means override
|
||||
print(f"dual resolution u-net start timesteps: {dru_start_timesteps}")
|
||||
continue
|
||||
|
||||
m = re.match(r"dret ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # dual resolution u-net end timesteps
|
||||
dru_end_timesteps = int(m.group(1))
|
||||
dru_ratio = dru_ratio if dru_ratio is not None else -1 # -1 means override
|
||||
print(f"dual resolution u-net end timesteps: {dru_end_timesteps}")
|
||||
continue
|
||||
|
||||
m = re.match(r"drstr ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # dual resolution u-net strength
|
||||
dru_strength = float(m.group(1))
|
||||
dru_ratio = dru_ratio if dru_ratio is not None else -1 # -1 means override
|
||||
print(f"dual resolution u-net strength: {dru_strength}")
|
||||
|
||||
except ValueError as ex:
|
||||
print(f"Exception in parsing / 解析エラー: {parg}")
|
||||
print(ex)
|
||||
@@ -2457,6 +2573,12 @@ def main(args):
|
||||
ds_depth_1 = args.ds_depth_1 or 3
|
||||
unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio)
|
||||
|
||||
# override Dual Resolution U-Net
|
||||
if dru_ratio is not None:
|
||||
if dru_ratio < 0:
|
||||
dru_ratio = 0.5 # default TODO add to args
|
||||
pipe.set_dual_resolution_unet(dru_ratio, dru_strength, dru_start_timesteps, dru_end_timesteps)
|
||||
|
||||
# prepare seed
|
||||
if seeds is not None: # given in prompt
|
||||
# 数が足りないなら前のをそのまま使う
|
||||
|
||||
Reference in New Issue
Block a user