Compare commits

..

4 Commits

Author SHA1 Message Date
Kohya S
dbe78a8638 scale crafter 2024-02-25 08:53:43 +09:00
Kohya S
bae116a031 Merge branch 'dev' into flexible-zero-slicing 2024-02-24 22:42:41 +09:00
Kohya S
6aa2d99219 make mask for flexible zero slicing from attncouple mask 2024-02-24 21:38:57 +09:00
Kohya S
725bab124b impl flexible zero slicing 2024-02-24 21:00:38 +09:00
7 changed files with 374 additions and 149 deletions

View File

@@ -18,4 +18,4 @@ jobs:
- uses: actions/checkout@v4
- name: typos-action
uses: crate-ci/typos@v1.17.2
uses: crate-ci/typos@v1.16.26

View File

@@ -249,16 +249,6 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum
## Change History
### Mar 15, 2024 / 2024/3/15: v0.8.5
- Fixed a bug that the value of timestep embedding during SDXL training was incorrect.
- The inference with the generation script is also fixed.
- The impact is unknown, but please update for SDXL training.
- SDXL 学習時の timestep embedding の値が誤っていたのを修正しました。
- 生成スクリプトでの推論時についてもあわせて修正しました。
- 影響の度合いは不明ですが、SDXL の学習時にはアップデートをお願いいたします。
### Feb 24, 2024 / 2024/2/24: v0.8.4
- The log output has been improved. PR [#905](https://github.com/kohya-ss/sd-scripts/pull/905) Thanks to shirayu!

View File

@@ -1,5 +1,6 @@
import itertools
import json
from types import SimpleNamespace
from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable
import glob
import importlib
@@ -61,12 +62,6 @@ from library.sdxl_original_unet import InferSdxlUNet2DConditionModel
from library.original_unet import FlashAttentionFunction
from networks.control_net_lllite import ControlNetLLLite
from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
# scheduler:
SCHEDULER_LINEAR_START = 0.00085
@@ -88,12 +83,12 @@ CLIP_VISION_MODEL = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
if mem_eff_attn:
logger.info("Enable memory efficient attention for U-Net")
print("Enable memory efficient attention for U-Net")
# これはDiffusersのU-Netではなく自前のU-Netなので置き換えなくても良い
unet.set_use_memory_efficient_attention(False, True)
elif xformers:
logger.info("Enable xformers for U-Net")
print("Enable xformers for U-Net")
try:
import xformers.ops
except ImportError:
@@ -101,7 +96,7 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio
unet.set_use_memory_efficient_attention(True, False)
elif sdpa:
logger.info("Enable SDPA for U-Net")
print("Enable SDPA for U-Net")
unet.set_use_memory_efficient_attention(False, False)
unet.set_use_sdpa(True)
@@ -118,7 +113,7 @@ def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xform
def replace_vae_attn_to_memory_efficient():
logger.info("VAE Attention.forward has been replaced to FlashAttention (not xformers)")
print("VAE Attention.forward has been replaced to FlashAttention (not xformers)")
flash_func = FlashAttentionFunction
def forward_flash_attn(self, hidden_states, **kwargs):
@@ -174,7 +169,7 @@ def replace_vae_attn_to_memory_efficient():
def replace_vae_attn_to_xformers():
logger.info("VAE: Attention.forward has been replaced to xformers")
print("VAE: Attention.forward has been replaced to xformers")
import xformers.ops
def forward_xformers(self, hidden_states, **kwargs):
@@ -230,7 +225,7 @@ def replace_vae_attn_to_xformers():
def replace_vae_attn_to_sdpa():
logger.info("VAE: Attention.forward has been replaced to sdpa")
print("VAE: Attention.forward has been replaced to sdpa")
def forward_sdpa(self, hidden_states, **kwargs):
residual = hidden_states
@@ -392,10 +387,10 @@ class PipelineLike:
def set_gradual_latent(self, gradual_latent):
if gradual_latent is None:
logger.info("gradual_latent is disabled")
print("gradual_latent is disabled")
self.gradual_latent = None
else:
logger.info(f"gradual_latent is enabled: {gradual_latent}")
print(f"gradual_latent is enabled: {gradual_latent}")
self.gradual_latent = gradual_latent # (ds_ratio, start_timesteps, every_n_steps, ratio_step)
@torch.no_grad()
@@ -473,7 +468,7 @@ class PipelineLike:
do_classifier_free_guidance = guidance_scale > 1.0
if not do_classifier_free_guidance and negative_scale is not None:
logger.warning(f"negative_scale is ignored if guidance scalle <= 1.0")
print(f"negative_scale is ignored if guidance scalle <= 1.0")
negative_scale = None
# get unconditional embeddings for classifier free guidance
@@ -582,7 +577,7 @@ class PipelineLike:
text_pool = text_pool[num_sub_prompts - 1 :: num_sub_prompts] # last subprompt
if init_image is not None and self.clip_vision_model is not None:
logger.info(f"encode by clip_vision_model and apply clip_vision_strength={self.clip_vision_strength}")
print(f"encode by clip_vision_model and apply clip_vision_strength={self.clip_vision_strength}")
vision_input = self.clip_vision_processor(init_image, return_tensors="pt", device=self.device)
pixel_values = vision_input["pixel_values"].to(self.device, dtype=text_embeddings.dtype)
@@ -748,8 +743,8 @@ class PipelineLike:
enable_gradual_latent = False
if self.gradual_latent:
if not hasattr(self.scheduler, "set_gradual_latent_params"):
logger.warning("gradual_latent is not supported for this scheduler. Ignoring.")
logger.warning(f"{self.scheduler.__class__.__name__}")
print("gradual_latent is not supported for this scheduler. Ignoring.")
print(self.scheduler.__class__.__name__)
else:
enable_gradual_latent = True
step_elapsed = 1000
@@ -798,7 +793,7 @@ class PipelineLike:
if not enabled or ratio >= 1.0:
continue
if ratio < i / len(timesteps):
logger.info(f"ControlNetLLLite {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})")
print(f"ControlNetLLLite {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})")
control_net.set_cond_image(None)
each_control_net_enabled[j] = False
@@ -1019,7 +1014,7 @@ def get_prompts_with_weights(tokenizer: CLIPTokenizer, token_replacer, prompt: L
if word.strip() == "BREAK":
# pad until next multiple of tokenizer's max token length
pad_len = tokenizer.model_max_length - (len(text_token) % tokenizer.model_max_length)
logger.info(f"BREAK pad_len: {pad_len}")
print(f"BREAK pad_len: {pad_len}")
for i in range(pad_len):
# v2のときEOSをつけるべきかどうかわからないぜ
# if i == 0:
@@ -1049,7 +1044,7 @@ def get_prompts_with_weights(tokenizer: CLIPTokenizer, token_replacer, prompt: L
tokens.append(text_token)
weights.append(text_weight)
if truncated:
logger.warning("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
print("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
return tokens, weights
@@ -1350,7 +1345,7 @@ def handle_dynamic_prompt_variants(prompt, repeat_count):
elif len(count_range) == 2:
count_range = [int(count_range[0]), int(count_range[1])]
else:
logger.warning(f"invalid count range: {count_range}")
print(f"invalid count range: {count_range}")
count_range = [1, 1]
if count_range[0] > count_range[1]:
count_range = [count_range[1], count_range[0]]
@@ -1494,9 +1489,9 @@ def main(args):
# assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません"
if args.v_parameterization and not args.v2:
logger.warning("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
if args.v2 and args.clip_skip is not None:
logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
# モデルを読み込む
if not os.path.exists(args.ckpt): # ファイルがないならパターンで探し、一つだけ該当すればそれを使う
@@ -1516,7 +1511,7 @@ def main(args):
else:
# if `text_encoder_2` subdirectory exists, sdxl
is_sdxl = os.path.isdir(os.path.join(name_or_path, "text_encoder_2"))
logger.info(f"SDXL: {is_sdxl}")
print(f"SDXL: {is_sdxl}")
if is_sdxl:
if args.clip_skip is None:
@@ -1532,10 +1527,10 @@ def main(args):
args.clip_skip = 2 if args.v2 else 1
if use_stable_diffusion_format:
logger.info("load StableDiffusion checkpoint")
print("load StableDiffusion checkpoint")
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt)
else:
logger.info("load Diffusers pretrained models")
print("load Diffusers pretrained models")
loading_pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype)
text_encoder = loading_pipe.text_encoder
vae = loading_pipe.vae
@@ -1559,7 +1554,7 @@ def main(args):
# VAEを読み込む
if args.vae is not None:
vae = model_util.load_vae(args.vae, dtype)
logger.info("additional VAE loaded")
print("additional VAE loaded")
# xformers、Hypernetwork対応
if not args.diffusers_xformers:
@@ -1568,7 +1563,7 @@ def main(args):
replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa)
# tokenizerを読み込む
logger.info("loading tokenizer")
print("loading tokenizer")
if is_sdxl:
tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
tokenizers = [tokenizer1, tokenizer2]
@@ -1660,7 +1655,7 @@ def main(args):
noise = None
if noise == None:
logger.warning(f"unexpected noise request: {self.sampler_noise_index}, {shape}")
print(f"unexpected noise request: {self.sampler_noise_index}, {shape}")
noise = torch.randn(shape, dtype=dtype, device=device, generator=generator)
self.sampler_noise_index += 1
@@ -1721,7 +1716,7 @@ def main(args):
vae_dtype = dtype
if args.no_half_vae:
logger.info("set vae_dtype to float32")
print("set vae_dtype to float32")
vae_dtype = torch.float32
vae.to(vae_dtype).to(device)
vae.eval()
@@ -1745,10 +1740,10 @@ def main(args):
network_merge = args.network_merge_n_models
else:
network_merge = 0
logger.info(f"network_merge: {network_merge}")
print(f"network_merge: {network_merge}")
for i, network_module in enumerate(args.network_module):
logger.info("import network module: {network_module}")
print("import network module:", network_module)
imported_module = importlib.import_module(network_module)
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
@@ -1766,7 +1761,7 @@ def main(args):
raise ValueError("No weight. Weight is required.")
network_weight = args.network_weights[i]
logger.info(f"load network weights from: {network_weight}")
print("load network weights from:", network_weight)
if model_util.is_safetensors(network_weight) and args.network_show_meta:
from safetensors.torch import safe_open
@@ -1774,7 +1769,7 @@ def main(args):
with safe_open(network_weight, framework="pt") as f:
metadata = f.metadata()
if metadata is not None:
logger.info(f"metadata for: {network_weight}: {metadata}")
print(f"metadata for: {network_weight}: {metadata}")
network, weights_sd = imported_module.create_network_from_weights(
network_mul, network_weight, vae, text_encoders, unet, for_inference=True, **net_kwargs
@@ -1784,20 +1779,20 @@ def main(args):
mergeable = network.is_mergeable()
if network_merge and not mergeable:
logger.warning("network is not mergiable. ignore merge option.")
print("network is not mergiable. ignore merge option.")
if not mergeable or i >= network_merge:
# not merging
network.apply_to(text_encoders, unet)
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
logger.info(f"weights are loaded: {info}")
print(f"weights are loaded: {info}")
if args.opt_channels_last:
network.to(memory_format=torch.channels_last)
network.to(dtype).to(device)
if network_pre_calc:
logger.info("backup original weights")
print("backup original weights")
network.backup_weights()
networks.append(network)
@@ -1811,7 +1806,7 @@ def main(args):
# upscalerの指定があれば取得する
upscaler = None
if args.highres_fix_upscaler:
logger.info("import upscaler module: {args.highres_fix_upscaler}")
print("import upscaler module:", args.highres_fix_upscaler)
imported_module = importlib.import_module(args.highres_fix_upscaler)
us_kwargs = {}
@@ -1820,7 +1815,7 @@ def main(args):
key, value = net_arg.split("=")
us_kwargs[key] = value
logger.info("create upscaler")
print("create upscaler")
upscaler = imported_module.create_upscaler(**us_kwargs)
upscaler.to(dtype).to(device)
@@ -1839,7 +1834,7 @@ def main(args):
control_net_lllites: List[Tuple[ControlNetLLLite, float]] = []
if args.control_net_lllite_models:
for i, model_file in enumerate(args.control_net_lllite_models):
logger.info(f"loading ControlNet-LLLite: {model_file}")
print(f"loading ControlNet-LLLite: {model_file}")
from safetensors.torch import load_file
@@ -1873,7 +1868,7 @@ def main(args):
), "ControlNet and ControlNet-LLLite cannot be used at the same time"
if args.opt_channels_last:
logger.info(f"set optimizing: channels last")
print(f"set optimizing: channels last")
for text_encoder in text_encoders:
text_encoder.to(memory_format=torch.channels_last)
vae.to(memory_format=torch.channels_last)
@@ -1900,7 +1895,7 @@ def main(args):
)
pipe.set_control_nets(control_nets)
pipe.set_control_net_lllites(control_net_lllites)
logger.info("pipeline is ready.")
print("pipeline is ready.")
if args.diffusers_xformers:
pipe.enable_xformers_memory_efficient_attention()
@@ -1971,7 +1966,7 @@ def main(args):
token_ids1 = tokenizers[0].convert_tokens_to_ids(token_strings)
token_ids2 = tokenizers[1].convert_tokens_to_ids(token_strings) if is_sdxl else None
logger.info(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids1} and {token_ids2}")
print(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids1} and {token_ids2}")
assert (
min(token_ids1) == token_ids1[0] and token_ids1[-1] == token_ids1[0] + len(token_ids1) - 1
), f"token ids1 is not ordered"
@@ -2008,7 +2003,7 @@ def main(args):
# promptを取得する
prompt_list = None
if args.from_file is not None:
logger.info(f"reading prompts from {args.from_file}")
print(f"reading prompts from {args.from_file}")
with open(args.from_file, "r", encoding="utf-8") as f:
prompt_list = f.read().splitlines()
prompt_list = [d for d in prompt_list if len(d.strip()) > 0 and d[0] != "#"]
@@ -2025,7 +2020,7 @@ def main(args):
spec.loader.exec_module(module)
return module
logger.info(f"reading prompts from module: {args.from_module}")
print(f"reading prompts from module: {args.from_module}")
prompt_module = load_module_from_path("prompt_module", args.from_module)
prompter = prompt_module.get_prompter(args, pipe, networks)
@@ -2056,7 +2051,7 @@ def main(args):
for p in paths:
image = Image.open(p)
if image.mode != "RGB":
logger.info(f"convert image to RGB from {image.mode}: {p}")
print(f"convert image to RGB from {image.mode}: {p}")
image = image.convert("RGB")
images.append(image)
@@ -2072,14 +2067,14 @@ def main(args):
return resized
if args.image_path is not None:
logger.info(f"load image for img2img: {args.image_path}")
print(f"load image for img2img: {args.image_path}")
init_images = load_images(args.image_path)
assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}"
logger.info(f"loaded {len(init_images)} images for img2img")
print(f"loaded {len(init_images)} images for img2img")
# CLIP Vision
if args.clip_vision_strength is not None:
logger.info(f"load CLIP Vision model: {CLIP_VISION_MODEL}")
print(f"load CLIP Vision model: {CLIP_VISION_MODEL}")
vision_model = CLIPVisionModelWithProjection.from_pretrained(CLIP_VISION_MODEL, projection_dim=1280)
vision_model.to(device, dtype)
processor = CLIPImageProcessor.from_pretrained(CLIP_VISION_MODEL)
@@ -2087,22 +2082,22 @@ def main(args):
pipe.clip_vision_model = vision_model
pipe.clip_vision_processor = processor
pipe.clip_vision_strength = args.clip_vision_strength
logger.info(f"CLIP Vision model loaded.")
print(f"CLIP Vision model loaded.")
else:
init_images = None
if args.mask_path is not None:
logger.info(f"load mask for inpainting: {args.mask_path}")
print(f"load mask for inpainting: {args.mask_path}")
mask_images = load_images(args.mask_path)
assert len(mask_images) > 0, f"No mask image / マスク画像がありません: {args.image_path}"
logger.info(f"loaded {len(mask_images)} mask images for inpainting")
print(f"loaded {len(mask_images)} mask images for inpainting")
else:
mask_images = None
# promptがないとき、画像のPngInfoから取得する
if init_images is not None and prompter is None and not args.interactive:
logger.info("get prompts from images' metadata")
print("get prompts from images' metadata")
prompt_list = []
for img in init_images:
if "prompt" in img.text:
@@ -2124,6 +2119,37 @@ def main(args):
l.extend([im] * args.images_per_prompt)
mask_images = l
# Flexible Zero Slicing
if args.flexible_zero_slicing_depth is not None:
# CV2 が必要
import cv2
# mask 画像は背景 255、zero にする部分 0 とする
np_mask = np.array(mask_images[0].convert("RGB"))
fz_mask = np.full(np_mask.shape, 255, dtype=np.uint8)
# 各チャンネルに対して処理
for i in range(3):
# チャンネルを抽出
channel = np_mask[:, :, i]
# 輪郭を検出
contours, _ = cv2.findContours(channel, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
# 輪郭を新しい配列に描画
cv2.drawContours(fz_mask, contours, -1, (0, 0, 0), 1)
fz_mask = fz_mask.astype(np.float32) / 255.0
fz_mask = fz_mask[:, :, 0]
fz_mask = torch.from_numpy(fz_mask).to(dtype).to(device)
# only for sdxl
unet.set_flexible_zero_slicing(fz_mask, args.flexible_zero_slicing_depth, args.flexible_zero_slicing_timesteps)
# Dilated Conv Hires fix
if args.dilated_conv_hires_fix_depth is not None:
unet.set_dilated_conv(args.dilated_conv_hires_fix_depth, args.dilated_conv_hires_fix_timesteps)
# 画像サイズにオプション指定があるときはリサイズする
if args.W is not None and args.H is not None:
# highres fix を考慮に入れる
@@ -2133,17 +2159,17 @@ def main(args):
h = int(h * args.highres_fix_scale + 0.5)
if init_images is not None:
logger.info(f"resize img2img source images to {w}*{h}")
print(f"resize img2img source images to {w}*{h}")
init_images = resize_images(init_images, (w, h))
if mask_images is not None:
logger.info(f"resize img2img mask images to {w}*{h}")
print(f"resize img2img mask images to {w}*{h}")
mask_images = resize_images(mask_images, (w, h))
regional_network = False
if networks and mask_images:
# mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応
regional_network = True
logger.info("use mask as region")
print("use mask as region")
size = None
for i, network in enumerate(networks):
@@ -2152,10 +2178,17 @@ def main(args):
if args.network_regional_mask_max_color_codes:
# カラーコードでマスクを指定する
ch0 = (i + 1) & 1
ch1 = ((i + 1) >> 1) & 1
ch2 = ((i + 1) >> 2) & 1
np_mask = np.all(np_mask == np.array([ch0, ch1, ch2]) * 255, axis=2)
# 0-7: RGB 3bitで8色, 0/255
# 8-15: RGB 3bitで8色, 0/127
code = (i % 7) + 1
r = code & 1
g = (code & 2) >> 1
b = (code & 4) >> 2
if i < 7:
color = (r * 255, g * 255, b * 255)
else:
color = (r * 127, g * 127, b * 127)
np_mask = np.all(np_mask == color, axis=2)
np_mask = np_mask.astype(np.uint8) * 255
else:
np_mask = np_mask[:, :, i]
@@ -2168,14 +2201,14 @@ def main(args):
prev_image = None # for VGG16 guided
if args.guide_image_path is not None:
logger.info(f"load image for ControlNet guidance: {args.guide_image_path}")
print(f"load image for ControlNet guidance: {args.guide_image_path}")
guide_images = []
for p in args.guide_image_path:
guide_images.extend(load_images(p))
logger.info(f"loaded {len(guide_images)} guide images for guidance")
print(f"loaded {len(guide_images)} guide images for guidance")
if len(guide_images) == 0:
logger.warning(
print(
f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}"
)
guide_images = None
@@ -2206,7 +2239,7 @@ def main(args):
max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples
for gen_iter in range(args.n_iter):
logger.info(f"iteration {gen_iter+1}/{args.n_iter}")
print(f"iteration {gen_iter+1}/{args.n_iter}")
if args.iter_same_seed:
iter_seed = seed_random.randint(0, 2**32 - 1)
else:
@@ -2225,7 +2258,7 @@ def main(args):
# 1st stageのバッチを作成して呼び出すサイズを小さくして呼び出す
is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling
logger.info("process 1st stage")
print("process 1st stage")
batch_1st = []
for _, base, ext in batch:
@@ -2270,7 +2303,7 @@ def main(args):
images_1st = process_batch(batch_1st, True, True)
# 2nd stageのバッチを作成して以下処理する
logger.info("process 2nd stage")
print("process 2nd stage")
width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height
if upscaler:
@@ -2443,7 +2476,7 @@ def main(args):
n.restore_weights()
for n in networks:
n.pre_calculation()
logger.info("pre-calculation... done")
print("pre-calculation... done")
images = pipe(
prompts,
@@ -2526,7 +2559,7 @@ def main(args):
cv2.waitKey()
cv2.destroyAllWindows()
except ImportError:
logger.warning(
print(
"opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません"
)
@@ -2541,7 +2574,7 @@ def main(args):
# interactive
valid = False
while not valid:
logger.info("\nType prompt:")
print("\nType prompt:")
try:
raw_prompt = input()
except EOFError:
@@ -2601,74 +2634,74 @@ def main(args):
prompt_args = raw_prompt.strip().split(" --")
prompt = prompt_args[0]
length = len(prompter) if hasattr(prompter, "__len__") else 0
logger.info(f"prompt {prompt_index+1}/{length}: {prompt}")
print(f"prompt {prompt_index+1}/{length}: {prompt}")
for parg in prompt_args[1:]:
try:
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
if m:
width = int(m.group(1))
logger.info(f"width: {width}")
print(f"width: {width}")
continue
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
if m:
height = int(m.group(1))
logger.info(f"height: {height}")
print(f"height: {height}")
continue
m = re.match(r"ow (\d+)", parg, re.IGNORECASE)
if m:
original_width = int(m.group(1))
logger.info(f"original width: {original_width}")
print(f"original width: {original_width}")
continue
m = re.match(r"oh (\d+)", parg, re.IGNORECASE)
if m:
original_height = int(m.group(1))
logger.info(f"original height: {original_height}")
print(f"original height: {original_height}")
continue
m = re.match(r"nw (\d+)", parg, re.IGNORECASE)
if m:
original_width_negative = int(m.group(1))
logger.info(f"original width negative: {original_width_negative}")
print(f"original width negative: {original_width_negative}")
continue
m = re.match(r"nh (\d+)", parg, re.IGNORECASE)
if m:
original_height_negative = int(m.group(1))
logger.info(f"original height negative: {original_height_negative}")
print(f"original height negative: {original_height_negative}")
continue
m = re.match(r"ct (\d+)", parg, re.IGNORECASE)
if m:
crop_top = int(m.group(1))
logger.info(f"crop top: {crop_top}")
print(f"crop top: {crop_top}")
continue
m = re.match(r"cl (\d+)", parg, re.IGNORECASE)
if m:
crop_left = int(m.group(1))
logger.info(f"crop left: {crop_left}")
print(f"crop left: {crop_left}")
continue
m = re.match(r"s (\d+)", parg, re.IGNORECASE)
if m: # steps
steps = max(1, min(1000, int(m.group(1))))
logger.info(f"steps: {steps}")
print(f"steps: {steps}")
continue
m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE)
if m: # seed
seeds = [int(d) for d in m.group(1).split(",")]
logger.info(f"seeds: {seeds}")
print(f"seeds: {seeds}")
continue
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
if m: # scale
scale = float(m.group(1))
logger.info(f"scale: {scale}")
print(f"scale: {scale}")
continue
m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE)
@@ -2677,25 +2710,25 @@ def main(args):
negative_scale = None
else:
negative_scale = float(m.group(1))
logger.info(f"negative scale: {negative_scale}")
print(f"negative scale: {negative_scale}")
continue
m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE)
if m: # strength
strength = float(m.group(1))
logger.info(f"strength: {strength}")
print(f"strength: {strength}")
continue
m = re.match(r"n (.+)", parg, re.IGNORECASE)
if m: # negative prompt
negative_prompt = m.group(1)
logger.info(f"negative prompt: {negative_prompt}")
print(f"negative prompt: {negative_prompt}")
continue
m = re.match(r"c (.+)", parg, re.IGNORECASE)
if m: # clip prompt
clip_prompt = m.group(1)
logger.info(f"clip prompt: {clip_prompt}")
print(f"clip prompt: {clip_prompt}")
continue
m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE)
@@ -2703,89 +2736,89 @@ def main(args):
network_muls = [float(v) for v in m.group(1).split(",")]
while len(network_muls) < len(networks):
network_muls.append(network_muls[-1])
logger.info(f"network mul: {network_muls}")
print(f"network mul: {network_muls}")
continue
# Deep Shrink
m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink depth 1
ds_depth_1 = int(m.group(1))
logger.info(f"deep shrink depth 1: {ds_depth_1}")
print(f"deep shrink depth 1: {ds_depth_1}")
continue
m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink timesteps 1
ds_timesteps_1 = int(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
logger.info(f"deep shrink timesteps 1: {ds_timesteps_1}")
print(f"deep shrink timesteps 1: {ds_timesteps_1}")
continue
m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink depth 2
ds_depth_2 = int(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
logger.info(f"deep shrink depth 2: {ds_depth_2}")
print(f"deep shrink depth 2: {ds_depth_2}")
continue
m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink timesteps 2
ds_timesteps_2 = int(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
logger.info(f"deep shrink timesteps 2: {ds_timesteps_2}")
print(f"deep shrink timesteps 2: {ds_timesteps_2}")
continue
m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink ratio
ds_ratio = float(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
logger.info(f"deep shrink ratio: {ds_ratio}")
print(f"deep shrink ratio: {ds_ratio}")
continue
# Gradual Latent
m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent timesteps
gl_timesteps = int(m.group(1))
logger.info(f"gradual latent timesteps: {gl_timesteps}")
print(f"gradual latent timesteps: {gl_timesteps}")
continue
m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent ratio
gl_ratio = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
logger.info(f"gradual latent ratio: {ds_ratio}")
print(f"gradual latent ratio: {ds_ratio}")
continue
m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent every n steps
gl_every_n_steps = int(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
logger.info(f"gradual latent every n steps: {gl_every_n_steps}")
print(f"gradual latent every n steps: {gl_every_n_steps}")
continue
m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent ratio step
gl_ratio_step = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
logger.info(f"gradual latent ratio step: {gl_ratio_step}")
print(f"gradual latent ratio step: {gl_ratio_step}")
continue
m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent s noise
gl_s_noise = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
logger.info(f"gradual latent s noise: {gl_s_noise}")
print(f"gradual latent s noise: {gl_s_noise}")
continue
m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE)
if m: # gradual latent unsharp params
gl_unsharp_params = m.group(1)
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
logger.info(f"gradual latent unsharp params: {gl_unsharp_params}")
print(f"gradual latent unsharp params: {gl_unsharp_params}")
continue
except ValueError as ex:
logger.error(f"Exception in parsing / 解析エラー: {parg}")
logger.error(f"{ex}")
print(f"Exception in parsing / 解析エラー: {parg}")
print(ex)
# override Deep Shrink
if ds_depth_1 is not None:
@@ -2831,7 +2864,7 @@ def main(args):
if seed is None:
seed = seed_random.randint(0, 2**32 - 1)
if args.interactive:
logger.info(f"seed: {seed}")
print(f"seed: {seed}")
# prepare init image, guide image and mask
init_image = mask_image = guide_image = None
@@ -2847,7 +2880,7 @@ def main(args):
width = width - width % 32
height = height - height % 32
if width != init_image.size[0] or height != init_image.size[1]:
logger.warning(
print(
f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます"
)
@@ -2909,14 +2942,12 @@ def main(args):
process_batch(batch_data, highres_fix)
batch_data.clear()
logger.info("done!")
print("done!")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
parser.add_argument(
"--sdxl", action="store_true", help="load Stable Diffusion XL model / Stable Diffusion XLのモデルを読み込む"
)
@@ -3320,6 +3351,38 @@ def setup_parser() -> argparse.ArgumentParser:
+ " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength, target-x. `3,0.5,0.5,1` または `3,1.0,1.0,0` が推奨",
)
# parser.add_argument(
# "--flexible_zero_slicing_mask",
# type=str,
# default=None,
# help="mask for flexible zero slicing / flexible zero slicingのマスク",
# )
parser.add_argument(
"--flexible_zero_slicing_depth",
type=int,
default=None,
help="depth for flexible zero slicing / flexible zero slicingのdepth",
)
parser.add_argument(
"--flexible_zero_slicing_timesteps",
type=int,
default=None,
help="timesteps for flexible zero slicing / flexible zero slicingのtimesteps",
)
parser.add_argument(
"--dilated_conv_hires_fix_depth",
type=int,
default=None,
help="depth for dilated conv hires fix / dilated conv hires fixのdepth",
)
parser.add_argument(
"--dilated_conv_hires_fix_timesteps",
type=int,
default=None,
help="timesteps for dilated conv hires fix / dilated conv hires fixのtimesteps",
)
# # parser.add_argument(
# "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像"
# )

View File

@@ -489,10 +489,10 @@ class PipelineLike:
def set_gradual_latent(self, gradual_latent):
if gradual_latent is None:
logger.info("gradual_latent is disabled")
print("gradual_latent is disabled")
self.gradual_latent = None
else:
logger.info(f"gradual_latent is enabled: {gradual_latent}")
print(f"gradual_latent is enabled: {gradual_latent}")
self.gradual_latent = gradual_latent # (ds_ratio, start_timesteps, every_n_steps, ratio_step)
# region xformersとか使う部分独自に書き換えるので関係なし
@@ -971,8 +971,8 @@ class PipelineLike:
enable_gradual_latent = False
if self.gradual_latent:
if not hasattr(self.scheduler, "set_gradual_latent_params"):
logger.info("gradual_latent is not supported for this scheduler. Ignoring.")
logger.info(f'{self.scheduler.__class__.__name__}')
print("gradual_latent is not supported for this scheduler. Ignoring.")
print(self.scheduler.__class__.__name__)
else:
enable_gradual_latent = True
step_elapsed = 1000
@@ -3314,42 +3314,42 @@ def main(args):
m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent timesteps
gl_timesteps = int(m.group(1))
logger.info(f"gradual latent timesteps: {gl_timesteps}")
print(f"gradual latent timesteps: {gl_timesteps}")
continue
m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent ratio
gl_ratio = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
logger.info(f"gradual latent ratio: {ds_ratio}")
print(f"gradual latent ratio: {ds_ratio}")
continue
m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent every n steps
gl_every_n_steps = int(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
logger.info(f"gradual latent every n steps: {gl_every_n_steps}")
print(f"gradual latent every n steps: {gl_every_n_steps}")
continue
m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent ratio step
gl_ratio_step = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
logger.info(f"gradual latent ratio step: {gl_ratio_step}")
print(f"gradual latent ratio step: {gl_ratio_step}")
continue
m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent s noise
gl_s_noise = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
logger.info(f"gradual latent s noise: {gl_s_noise}")
print(f"gradual latent s noise: {gl_s_noise}")
continue
m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE)
if m: # gradual latent unsharp params
gl_unsharp_params = m.group(1)
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
logger.info(f"gradual latent unsharp params: {gl_unsharp_params}")
print(f"gradual latent unsharp params: {gl_unsharp_params}")
continue
except ValueError as ex:
@@ -3369,7 +3369,7 @@ def main(args):
if gl_unsharp_params is not None:
unsharp_params = gl_unsharp_params.split(",")
us_ksize, us_sigma, us_strength = [float(v) for v in unsharp_params[:3]]
logger.info(f'{unsharp_params}')
print(unsharp_params)
us_target_x = True if len(unsharp_params) < 4 else bool(int(unsharp_params[3]))
us_ksize = int(us_ksize)
else:

View File

@@ -24,7 +24,7 @@
import math
from types import SimpleNamespace
from typing import Any, Optional
from typing import Any, List, Optional
import torch
import torch.utils.checkpoint
from torch import nn
@@ -1076,7 +1076,7 @@ class SdxlUNet2DConditionModel(nn.Module):
timesteps = timesteps.expand(x.shape[0])
hs = []
t_emb = get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0) # , repeat_only=False)
t_emb = get_timestep_embedding(timesteps, self.model_channels) # , repeat_only=False)
t_emb = t_emb.to(x.dtype)
emb = self.time_embed(t_emb)
@@ -1116,6 +1116,46 @@ class SdxlUNet2DConditionModel(nn.Module):
return h
def get_mask_from_mask_dic(mask_dic, shape):
if mask_dic is None or len(mask_dic) == 0:
return None
mask = mask_dic.get(shape, None)
if mask is None:
# resize from the original mask
mask = mask_dic.get((0, 0), None)
org_dtype = mask.dtype
if org_dtype == torch.bfloat16:
mask = mask.to(torch.float32)
mask = F.interpolate(mask, size=shape, mode="area") # area is needed for keeping the mask value less than 1
mask = (mask == 1).to(dtype=org_dtype, device=mask.device)
mask_dic[shape] = mask
# for m in mask[0,0]:
# print("".join([f"{int(v)}" for v in m]))
return mask
# class Conv2dZeroSlicing(nn.Conv2d):
# def __init__(self, *args, **kwargs):
# super().__init__(*args, **kwargs)
# self.mask_dic = None
# self.enable_flag = None
# def set_reference_for_enable_and_mask_dic(self, enable_flag, mask_dic):
# self.enable_flag = enable_flag
# self.mask_dic = mask_dic
# def forward(self, input: torch.Tensor) -> torch.Tensor:
# print(self.enable_flag, self.mask_dic, input.shape[-2:])
# if self.enable_flag is None or not self.enable_flag[0] or self.mask_dic is None or len(self.mask_dic) == 0:
# return super().forward(input)
# mask = get_mask_from_mask_dic(self.mask_dic, input.shape[-2:])
# if mask is not None:
# input = input * mask
# return super().forward(input)
class InferSdxlUNet2DConditionModel:
def __init__(self, original_unet: SdxlUNet2DConditionModel, **kwargs):
self.delegate = original_unet
@@ -1131,6 +1171,92 @@ class InferSdxlUNet2DConditionModel:
self.ds_timesteps_2 = None
self.ds_ratio = None
# Dilated Conv
self.dc_depth = None
self.dc_timesteps = None
self.dc_enable_flag = [False]
for name, module in self.delegate.named_modules():
if isinstance(module, nn.Conv2d):
if module.kernel_size == (3, 3) and module.dilation == (1, 1):
module.dc_enable_flag = self.dc_enable_flag
# replace forward method
module.dc_original_forward = module.forward
def make_forward_dilated_conv(module):
def forward_conv2d_dilated_conv(input: torch.Tensor) -> torch.Tensor:
if module.dc_enable_flag[0]:
module.dilation = (1, 2)
module.padding = (1, 2)
else:
module.dilation = (1, 1)
module.padding = (1, 1)
return module.dc_original_forward(input)
return forward_conv2d_dilated_conv
module.forward = make_forward_dilated_conv(module)
# flexible zero slicing
self.fz_depth = None
self.fz_enable_flag = [False]
self.fz_mask_dic = {}
for name, module in self.delegate.named_modules():
if isinstance(module, nn.Conv2d):
if module.kernel_size == (3, 3):
module.fz_enable_flag = self.fz_enable_flag
module.fz_mask_dic = self.fz_mask_dic
# replace forward method
module.fz_original_forward = module.forward
def make_forward(module):
def forward_conv2d_zero_slicing(input: torch.Tensor) -> torch.Tensor:
if not module.fz_enable_flag[0] or len(module.fz_mask_dic) == 0:
return module.fz_original_forward(input)
mask = get_mask_from_mask_dic(module.fz_mask_dic, input.shape[-2:])
input = input * mask
return module.fz_original_forward(input)
return forward_conv2d_zero_slicing
module.forward = make_forward(module)
# def forward_conv2d_zero_slicing(self, input: torch.Tensor) -> torch.Tensor:
# print(self.__class__.__name__, "forward_conv2d_zero_slicing")
# print(self.enable_flag, self.mask_dic, input.shape[-2:])
# if self.fz_depth is None or not self.fz_enable_flag[0] or self.fz_mask_dic is None or len(self.fz_mask_dic) == 0:
# return self.original_forward(input)
# mask = get_mask_from_mask_dic(self.fz_mask_dic, input.shape[-2:])
# if mask is not None:
# input = input * mask
# return self.original_forward(input)
# for name, module in list(self.delegate.named_modules()):
# if isinstance(module, nn.Conv2d):
# if module.kernel_size == (3, 3):
# # replace Conv2d with Conv2dZeroSlicing
# new_conv2d = Conv2dZeroSlicing(
# module.in_channels,
# module.out_channels,
# module.kernel_size,
# module.stride,
# module.padding,
# module.dilation,
# module.groups,
# module.bias is not None,
# module.padding_mode,
# )
# new_conv2d.set_reference_for_enable_and_mask_dic(self.fz_enable_flag, self.fz_mask_dic)
# print(f"replace {name} with Conv2dZeroSlicing")
# setattr(self.delegate, name, new_conv2d)
# # copy parameters
# new_conv2d.weight = module.weight
# new_conv2d.bias = module.bias
# call original model's methods
def __getattr__(self, name):
return getattr(self.delegate, name)
@@ -1156,6 +1282,32 @@ class InferSdxlUNet2DConditionModel:
self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
self.ds_ratio = ds_ratio
def set_flexible_zero_slicing(self, mask: torch.Tensor, depth: int, timesteps: int = None):
# mask is arbitrary shape, 0 for zero slicing.
if depth is None or depth < 0:
logger.info("Flexible zero slicing is disabled.")
self.fz_depth = None
self.fz_mask = None
self.fz_timesteps = None
self.fz_mask_dic.clear()
else:
logger.info(f"Flexible zero slicing is enabled: [depth={depth}]")
self.fz_depth = depth
self.fz_mask = mask
self.fz_timesteps = timesteps
self.fz_mask_dic.clear()
self.fz_mask_dic[(0, 0)] = mask.unsqueeze(0).unsqueeze(0)
def set_dilated_conv(self, depth: int, timesteps: int = None):
if depth is None or depth < 0:
logger.info("Dilated Conv is disabled.")
self.dc_depth = None
self.dc_timesteps = None
else:
logger.info(f"Dilated Conv is enabled: [depth={depth}]")
self.dc_depth = depth
self.dc_timesteps = timesteps
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
r"""
current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink.
@@ -1166,7 +1318,7 @@ class InferSdxlUNet2DConditionModel:
timesteps = timesteps.expand(x.shape[0])
hs = []
t_emb = get_timestep_embedding(timesteps, _self.model_channels, downscale_freq_shift=0) # , repeat_only=False)
t_emb = get_timestep_embedding(timesteps, _self.model_channels) # , repeat_only=False)
t_emb = t_emb.to(x.dtype)
emb = _self.time_embed(t_emb)
@@ -1190,7 +1342,18 @@ class InferSdxlUNet2DConditionModel:
# h = x.type(self.dtype)
h = x
self.fz_enable_flag[0] = False
for depth, module in enumerate(_self.input_blocks):
# Dilated Conv
if self.dc_depth is not None:
self.dc_enable_flag[0] = depth >= self.dc_depth and timesteps[0] > self.dc_timesteps
# Flexible Zero Slicing
if self.fz_depth is not None:
self.fz_enable_flag[0] = depth >= self.fz_depth and timesteps[0] > self.fz_timesteps
# print(f"Flexible Zero Slicing: depth={depth}, timesteps={timesteps[0]}, enable={self.fz_enable_flag[0]}")
# Deep Shrink
if self.ds_depth_1 is not None:
if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
@@ -1210,7 +1373,16 @@ class InferSdxlUNet2DConditionModel:
h = call_module(_self.middle_block, h, emb, context)
for module in _self.output_blocks:
for depth, module in enumerate(_self.output_blocks):
# Dilated Conv
if self.dc_depth is not None and len(_self.output_blocks) - depth <= self.dc_depth:
self.dc_enable_flag[0] = False
# Flexible Zero Slicing
if self.fz_depth is not None and len(self.output_blocks) - depth <= self.fz_depth:
self.fz_enable_flag[0] = False
# print(f"Flexible Zero Slicing: depth={depth}, timesteps={timesteps[0]}, enable={self.fz_enable_flag[0]}")
# Deep Shrink
if self.ds_depth_1 is not None:
if hs[-1].shape[-2:] != h.shape[-2:]:

View File

@@ -327,10 +327,10 @@ class DyLoRANetwork(torch.nn.Module):
for i, text_encoder in enumerate(text_encoders):
if len(text_encoders) > 1:
index = i + 1
logger.info(f"create LoRA for Text Encoder {index}")
print(f"create LoRA for Text Encoder {index}")
else:
index = None
logger.info("create LoRA for Text Encoder")
print(f"create LoRA for Text Encoder")
text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
self.text_encoder_loras.extend(text_encoder_loras)

View File

@@ -380,10 +380,10 @@ class PipelineLike:
def set_gradual_latent(self, gradual_latent):
if gradual_latent is None:
logger.info("gradual_latent is disabled")
print("gradual_latent is disabled")
self.gradual_latent = None
else:
logger.info(f"gradual_latent is enabled: {gradual_latent}")
print(f"gradual_latent is enabled: {gradual_latent}")
self.gradual_latent = gradual_latent # (ds_ratio, start_timesteps, every_n_steps, ratio_step)
@torch.no_grad()
@@ -789,8 +789,8 @@ class PipelineLike:
enable_gradual_latent = False
if self.gradual_latent:
if not hasattr(self.scheduler, "set_gradual_latent_params"):
logger.info("gradual_latent is not supported for this scheduler. Ignoring.")
logger.info(f'{self.scheduler.__class__.__name__}')
print("gradual_latent is not supported for this scheduler. Ignoring.")
print(self.scheduler.__class__.__name__)
else:
enable_gradual_latent = True
step_elapsed = 1000
@@ -2614,84 +2614,84 @@ def main(args):
m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent timesteps
gl_timesteps = int(m.group(1))
logger.info(f"gradual latent timesteps: {gl_timesteps}")
print(f"gradual latent timesteps: {gl_timesteps}")
continue
m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent ratio
gl_ratio = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
logger.info(f"gradual latent ratio: {ds_ratio}")
print(f"gradual latent ratio: {ds_ratio}")
continue
m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent every n steps
gl_every_n_steps = int(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
logger.info(f"gradual latent every n steps: {gl_every_n_steps}")
print(f"gradual latent every n steps: {gl_every_n_steps}")
continue
m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent ratio step
gl_ratio_step = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
logger.info(f"gradual latent ratio step: {gl_ratio_step}")
print(f"gradual latent ratio step: {gl_ratio_step}")
continue
m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent s noise
gl_s_noise = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
logger.info(f"gradual latent s noise: {gl_s_noise}")
print(f"gradual latent s noise: {gl_s_noise}")
continue
m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE)
if m: # gradual latent unsharp params
gl_unsharp_params = m.group(1)
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
logger.info(f"gradual latent unsharp params: {gl_unsharp_params}")
print(f"gradual latent unsharp params: {gl_unsharp_params}")
continue
# Gradual Latent
m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent timesteps
gl_timesteps = int(m.group(1))
logger.info(f"gradual latent timesteps: {gl_timesteps}")
print(f"gradual latent timesteps: {gl_timesteps}")
continue
m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent ratio
gl_ratio = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
logger.info(f"gradual latent ratio: {ds_ratio}")
print(f"gradual latent ratio: {ds_ratio}")
continue
m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent every n steps
gl_every_n_steps = int(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
logger.info(f"gradual latent every n steps: {gl_every_n_steps}")
print(f"gradual latent every n steps: {gl_every_n_steps}")
continue
m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent ratio step
gl_ratio_step = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
logger.info(f"gradual latent ratio step: {gl_ratio_step}")
print(f"gradual latent ratio step: {gl_ratio_step}")
continue
m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent s noise
gl_s_noise = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
logger.info(f"gradual latent s noise: {gl_s_noise}")
print(f"gradual latent s noise: {gl_s_noise}")
continue
m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE)
if m: # gradual latent unsharp params
gl_unsharp_params = m.group(1)
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
logger.info(f"gradual latent unsharp params: {gl_unsharp_params}")
print(f"gradual latent unsharp params: {gl_unsharp_params}")
continue
except ValueError as ex: