From 0ee75fd75df55ebfbd280ea2a366277657c9a467 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 3 Sep 2023 12:24:15 +0900 Subject: [PATCH] fix typos, add comments etc. --- library/sdxl_model_util.py | 2 +- library/sdxl_train_util.py | 6 ++++-- library/train_util.py | 7 ++++--- networks/control_net_lllite_for_train.py | 2 +- sdxl_train_control_net_lllite_alt.py | 2 +- 5 files changed, 11 insertions(+), 8 deletions(-) diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index e54da796..6647b439 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -259,7 +259,7 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty elif k.startswith("conditioner.embedders.1.model."): te2_sd[k] = state_dict.pop(k) - info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location) # remain fp32 + info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location) # remain fp32 print("text encoder 1:", info1) converted_sd, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77) diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index d8529ef2..f637d993 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -37,7 +37,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype): model_version, weight_dtype, accelerator.device if args.lowram else "cpu", - model_dtype + model_dtype, ) # work on low-ram device @@ -56,7 +56,9 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype): return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info -def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None): +def _load_target_model( + name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None +): # model_dtype only work with full fp16/bf16 name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers diff --git a/library/train_util.py b/library/train_util.py index bb774792..200b486c 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2898,7 +2898,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: "--ip_noise_gamma", type=float, default=None, - help="enable input perturbation noise. used for regularization. recommended value: around 0.1 (from arxiv.org/abs/2301.11706) / ", + help="enable input perturbation noise. used for regularization. recommended value: around 0.1 (from arxiv.org/abs/2301.11706) " + + "/ input perturbation noiseを有効にする。正則化に使用される。推奨値: 0.1程度 (arxiv.org/abs/2301.11706 より)", ) # parser.add_argument( # "--perlin_noise", @@ -4353,11 +4354,11 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=latents.device) timesteps = timesteps.long() + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) if args.ip_noise_gamma: noisy_latents = noise_scheduler.add_noise(latents, noise + args.ip_noise_gamma * torch.randn_like(latents), timesteps) else: - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) return noise, noisy_latents, timesteps diff --git a/networks/control_net_lllite_for_train.py b/networks/control_net_lllite_for_train.py index 2bbefbf8..02688001 100644 --- a/networks/control_net_lllite_for_train.py +++ b/networks/control_net_lllite_for_train.py @@ -1,4 +1,4 @@ -# cond_imageをU-Netのforardで渡すバージョンのControlNet-LLLite検証用実装 +# cond_imageをU-Netのforwardで渡すバージョンのControlNet-LLLite検証用実装 # ControlNet-LLLite implementation for verification with cond_image passed in U-Net's forward import os diff --git a/sdxl_train_control_net_lllite_alt.py b/sdxl_train_control_net_lllite_alt.py index 20e7de4b..757194a1 100644 --- a/sdxl_train_control_net_lllite_alt.py +++ b/sdxl_train_control_net_lllite_alt.py @@ -1,4 +1,4 @@ -# cond_imageをU-Netのforardで渡すバージョンのControlNet-LLLite検証用学習コード +# cond_imageをU-Netのforwardで渡すバージョンのControlNet-LLLite検証用学習コード # training code for ControlNet-LLLite with passing cond_image to U-Net's forward import argparse