From 938bd71844e24f08b6483717b1a13aab9cb83657 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Wed, 8 Feb 2023 18:31:27 +0800 Subject: [PATCH 01/18] lower ram usage --- train_network.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index 3e8f4e7d..710055e0 100644 --- a/train_network.py +++ b/train_network.py @@ -150,7 +150,9 @@ def train(args): # モデルを読み込む text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype) - + # unnecessary, but work on low-ram device + text_encoder.to("cuda") + unet.to("cuda") # モデルに xformers とか memory efficient attention を組み込む train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) From fb312acb7f07605e46662099f62181de197fb490 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Wed, 8 Feb 2023 18:54:55 +0800 Subject: [PATCH 02/18] support DistributedDataParallel --- train_network.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/train_network.py b/train_network.py index 710055e0..c2f9cbf6 100644 --- a/train_network.py +++ b/train_network.py @@ -267,6 +267,14 @@ def train(args): unet.eval() text_encoder.eval() + # support DistributedDataParallel + try: + text_encoder = text_encoder.module + unet = unet.module + network = network.module + except: + pass + network.prepare_grad_etc(text_encoder, unet) if not cache_latents: From c0be52a7731b58e305ae22ed26e57af6b7d61f5a Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Wed, 8 Feb 2023 19:05:39 +0800 Subject: [PATCH 03/18] ignore get_hidden_states expected scalar Error --- library/train_util.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 6f809deb..dc0724d7 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1560,9 +1560,10 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod else: enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True) encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip] - if weight_dtype is not None: - # this is required for additional network training - encoder_hidden_states = encoder_hidden_states.to(weight_dtype) + # uncomment code may raise expected scalar type Half but found Float when using DDP + # if weight_dtype is not None: + # # this is required for additional network training + # encoder_hidden_states = encoder_hidden_states.to(weight_dtype) encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) # bs*3, 77, 768 or 1024 From 5e96e1369da410e52725e2c7af8b9f9def956c4b Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Wed, 8 Feb 2023 19:23:39 +0800 Subject: [PATCH 04/18] fix get_hidden_states expected scalar Error --- library/train_util.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index dc0724d7..4c410567 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1560,10 +1560,6 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod else: enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True) encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip] - # uncomment code may raise expected scalar type Half but found Float when using DDP - # if weight_dtype is not None: - # # this is required for additional network training - # encoder_hidden_states = encoder_hidden_states.to(weight_dtype) encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) # bs*3, 77, 768 or 1024 @@ -1589,6 +1585,10 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod states_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2]) # の後から の前まで states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # encoder_hidden_states = torch.cat(states_list, dim=1) + + if weight_dtype is not None: + # this is required for additional network training + encoder_hidden_states = encoder_hidden_states.to(weight_dtype) return encoder_hidden_states From b599adc938de5a9d728dd73dc311ab95e1f786e0 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Wed, 8 Feb 2023 19:34:03 +0800 Subject: [PATCH 05/18] fix Input type error when using DDP --- train_network.py | 1 + 1 file changed, 1 insertion(+) diff --git a/train_network.py b/train_network.py index c2f9cbf6..fc387bc3 100644 --- a/train_network.py +++ b/train_network.py @@ -1,5 +1,6 @@ from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION from torch.optim import Optimizer +from torch.cuda.amp import autocast from typing import Optional, Union import importlib import argparse From 6473aa1dd7cd927852957b977cfdf4f86b4f17b9 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Wed, 8 Feb 2023 21:32:21 +0800 Subject: [PATCH 06/18] fix Input type error in noise_pred when using DDP --- train_network.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index fc387bc3..59f74211 100644 --- a/train_network.py +++ b/train_network.py @@ -417,7 +417,8 @@ def train(args): # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + with autocast(): + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Predict the noise residual noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample From 9a9ac79edff44e9c0cdb28d127f62fe5ce0cca07 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Wed, 8 Feb 2023 22:30:20 +0800 Subject: [PATCH 07/18] correct wrong inserted code for noise_pred fix --- train_network.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/train_network.py b/train_network.py index 59f74211..f247c74e 100644 --- a/train_network.py +++ b/train_network.py @@ -417,11 +417,11 @@ def train(args): # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) - with autocast(): - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Predict the noise residual - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + with autocast(): + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample if args.v_parameterization: # v-parameterization training From b8ad17902f91d0293daf6a685224d2bb59f9b301 Mon Sep 17 00:00:00 2001 From: Isotr0py <41363108+Isotr0py@users.noreply.github.com> Date: Wed, 8 Feb 2023 23:09:59 +0800 Subject: [PATCH 08/18] fix get_hidden_states expected scalar Error again --- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index f247c74e..928ad321 100644 --- a/train_network.py +++ b/train_network.py @@ -257,7 +257,7 @@ def train(args): unet.requires_grad_(False) unet.to(accelerator.device, dtype=weight_dtype) text_encoder.requires_grad_(False) - text_encoder.to(accelerator.device, dtype=weight_dtype) + text_encoder.to(accelerator.device) if args.gradient_checkpointing: # according to TI example in Diffusers, train is required unet.train() text_encoder.train() From 55521eece0fed681dea93d827fbfc34812c8d711 Mon Sep 17 00:00:00 2001 From: michaelgzhang <49577754+mgz-dev@users.noreply.github.com> Date: Sat, 11 Feb 2023 02:38:13 -0600 Subject: [PATCH 09/18] add verbosity option for resize_lora.py add --verbose flag to print additional statistics during resize_lora function correct some parameter references in resize_lora_model function --- networks/resize_lora.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/networks/resize_lora.py b/networks/resize_lora.py index 7beeb25e..29f87e6d 100644 --- a/networks/resize_lora.py +++ b/networks/resize_lora.py @@ -38,9 +38,10 @@ def save_to_file(file_name, model, state_dict, dtype, metadata): torch.save(model, file_name) -def resize_lora_model(lora_sd, new_rank, save_dtype, device): +def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose): network_alpha = None network_dim = None + verbose_str = "\n" CLAMP_QUANTILE = 0.99 @@ -96,6 +97,12 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device): U, S, Vh = torch.linalg.svd(full_weight_matrix) + if verbose: + s_sum = torch.sum(torch.abs(S)) + s_rank = torch.sum(torch.abs(S[:new_rank])) + verbose_str+=f"{block_down_name:76} | " + verbose_str+=f"sum(S) retained: {(s_rank)/s_sum:.1%}%, max(S) to max(S_dropped) ratio: {S[0]/S[new_rank]:0.1f}\n" + U = U[:, :new_rank] S = S[:new_rank] U = U @ torch.diag(S) @@ -113,7 +120,7 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device): U = U.unsqueeze(2).unsqueeze(3) Vh = Vh.unsqueeze(2).unsqueeze(3) - if args.device: + if device: U = U.to(org_device) Vh = Vh.to(org_device) @@ -127,6 +134,8 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device): lora_up_weight = None weights_loaded = False + if verbose: + print(verbose_str) print("resizing complete") return o_lora_sd, network_dim, new_alpha @@ -151,7 +160,7 @@ def resize(args): lora_sd, metadata = load_state_dict(args.model, merge_dtype) print("resizing rank...") - state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device) + state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.verbose) # update metadata if metadata is None: @@ -182,6 +191,8 @@ if __name__ == '__main__': parser.add_argument("--model", type=str, default=None, help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors") parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") + parser.add_argument("--verbose", action="store_true", + help="Display verbose resizing information") args = parser.parse_args() resize(args) From b35b053b8d7c75c64dec7dd7d94d9a9b8ef27e66 Mon Sep 17 00:00:00 2001 From: michaelgzhang <49577754+mgz-dev@users.noreply.github.com> Date: Sat, 11 Feb 2023 03:14:43 -0600 Subject: [PATCH 10/18] clean up print formatting --- networks/resize_lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/networks/resize_lora.py b/networks/resize_lora.py index 29f87e6d..e21bdabd 100644 --- a/networks/resize_lora.py +++ b/networks/resize_lora.py @@ -101,7 +101,7 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose): s_sum = torch.sum(torch.abs(S)) s_rank = torch.sum(torch.abs(S[:new_rank])) verbose_str+=f"{block_down_name:76} | " - verbose_str+=f"sum(S) retained: {(s_rank)/s_sum:.1%}%, max(S) to max(S_dropped) ratio: {S[0]/S[new_rank]:0.1f}\n" + verbose_str+=f"sum(S) retained: {(s_rank)/s_sum:.1%}, max(S) ratio: {S[0]/S[new_rank]:0.1f}\n" U = U[:, :new_rank] S = S[:new_rank] From 2b1a3080e7ddc329e3a3bf59126d8ccac80d0dae Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sun, 12 Feb 2023 15:32:38 +0800 Subject: [PATCH 11/18] Add type checking --- train_network.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/train_network.py b/train_network.py index 852aea8e..90771b31 100644 --- a/train_network.py +++ b/train_network.py @@ -267,18 +267,19 @@ def train(args): text_encoder.train() # set top parameter requires_grad = True for gradient checkpointing works - text_encoder.text_model.embeddings.requires_grad_(True) + if type(text_encoder) == DDP: + text_encoder.module.text_model.embeddings.requires_grad_(True) + else: + text_encoder.text_model.embeddings.requires_grad_(True) else: unet.eval() text_encoder.eval() # support DistributedDataParallel - try: - text_encoder = text_encoder.module - unet = unet.module - network = network.module - except: - pass + if type(text_encoder) == DDP: + text_encoder = text_encoder.module + unet = unet.module + network = network.module network.prepare_grad_etc(text_encoder, unet) From 5471b0deb0de4d717d45af94dcba9a596b1f1207 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Mon, 13 Feb 2023 02:58:06 -0800 Subject: [PATCH 12/18] Add commit hash to metadata --- library/train_util.py | 8 ++++++++ train_network.py | 3 ++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 24e15d1f..98890273 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -12,6 +12,7 @@ import math import os import random import hashlib +import subprocess from io import BytesIO from tqdm import tqdm @@ -1100,6 +1101,13 @@ def addnet_hash_safetensors(b): return hash_sha256.hexdigest() +def get_git_revision_hash() -> str: + try: + return subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('ascii').strip() + except: + return "(unknown)" + + # flash attention forwards and backwards # https://arxiv.org/abs/2205.14135 diff --git a/train_network.py b/train_network.py index bb3159fd..69aca345 100644 --- a/train_network.py +++ b/train_network.py @@ -344,7 +344,8 @@ def train(args): "ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info), "ss_tag_frequency": json.dumps(train_dataset.tag_frequency), "ss_bucket_info": json.dumps(train_dataset.bucket_info), - "ss_training_comment": args.training_comment # will not be updated after training + "ss_training_comment": args.training_comment, # will not be updated after training + "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash() } # uncomment if another network is added From 3c29784825f1d0402e22703f7557ee9fe346f135 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 14 Feb 2023 20:55:20 +0900 Subject: [PATCH 13/18] Add ja comment --- networks/resize_lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/networks/resize_lora.py b/networks/resize_lora.py index e21bdabd..271de8ef 100644 --- a/networks/resize_lora.py +++ b/networks/resize_lora.py @@ -192,7 +192,7 @@ if __name__ == '__main__': help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors") parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") parser.add_argument("--verbose", action="store_true", - help="Display verbose resizing information") + help="Display verbose resizing information / rank変更時の詳細情報を出力する") args = parser.parse_args() resize(args) From e0f007f2a912ba4312c6e4d622667fc03a891058 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 14 Feb 2023 20:55:38 +0900 Subject: [PATCH 14/18] Fix import --- train_network.py | 1 + 1 file changed, 1 insertion(+) diff --git a/train_network.py b/train_network.py index fdc466ec..b783379b 100644 --- a/train_network.py +++ b/train_network.py @@ -1,6 +1,7 @@ from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION from torch.optim import Optimizer from torch.cuda.amp import autocast +from torch.nn.parallel import DistributedDataParallel as DDP from typing import Optional, Union import importlib import argparse From 43c0a69843b4408e1a9e69ccfd6e8e37f9f69803 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 14 Feb 2023 21:15:48 +0900 Subject: [PATCH 15/18] Add noise_offset --- fine_tune.py | 3 +++ library/train_util.py | 20 +++++++++++--------- train_db.py | 5 ++++- train_network.py | 9 ++++++--- train_textual_inversion.py | 3 +++ 5 files changed, 27 insertions(+), 13 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 52921530..3ba63063 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -255,6 +255,9 @@ def train(args): # Sample noise that we'll add to the latents noise = torch.randn_like(latents, device=latents.device) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) diff --git a/library/train_util.py b/library/train_util.py index 9125108b..415f9b70 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -300,7 +300,7 @@ class BaseDataset(torch.utils.data.Dataset): if self.shuffle_keep_tokens is None: if self.shuffle_caption: random.shuffle(tokens) - + tokens = dropout_tags(tokens) else: if len(tokens) > self.shuffle_keep_tokens: @@ -309,7 +309,7 @@ class BaseDataset(torch.utils.data.Dataset): if self.shuffle_caption: random.shuffle(tokens) - + tokens = dropout_tags(tokens) tokens = keep_tokens + tokens @@ -1102,10 +1102,10 @@ def addnet_hash_safetensors(b): def get_git_revision_hash() -> str: - try: - return subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('ascii').strip() - except: - return "(unknown)" + try: + return subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('ascii').strip() + except: + return "(unknown)" # flash attention forwards and backwards @@ -1421,6 +1421,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup") parser.add_argument("--lr_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)") + parser.add_argument("--noise_offset", type=float, default=None, + help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)") if support_dreambooth: # DreamBooth training @@ -1653,10 +1655,10 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod states_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2]) # の後から の前まで states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # encoder_hidden_states = torch.cat(states_list, dim=1) - + if weight_dtype is not None: - # this is required for additional network training - encoder_hidden_states = encoder_hidden_states.to(weight_dtype) + # this is required for additional network training + encoder_hidden_states = encoder_hidden_states.to(weight_dtype) return encoder_hidden_states diff --git a/train_db.py b/train_db.py index c210767b..4a50dc94 100644 --- a/train_db.py +++ b/train_db.py @@ -233,10 +233,13 @@ def train(args): else: latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() latents = latents * 0.18215 + b_size = latents.shape[0] # Sample noise that we'll add to the latents noise = torch.randn_like(latents, device=latents.device) - b_size = latents.shape[0] + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) # Get the text embedding for conditioning with torch.set_grad_enabled(global_step < args.stop_text_encoder_training): diff --git a/train_network.py b/train_network.py index b783379b..1b8046d2 100644 --- a/train_network.py +++ b/train_network.py @@ -278,9 +278,9 @@ def train(args): # support DistributedDataParallel if type(text_encoder) == DDP: - text_encoder = text_encoder.module - unet = unet.module - network = network.module + text_encoder = text_encoder.module + unet = unet.module + network = network.module network.prepare_grad_etc(text_encoder, unet) @@ -419,6 +419,9 @@ def train(args): # Sample noise that we'll add to the latents noise = torch.randn_like(latents, device=latents.device) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 4aa91eee..010bd04b 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -320,6 +320,9 @@ def train(args): # Sample noise that we'll add to the latents noise = torch.randn_like(latents, device=latents.device) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) From 2aef2872fb7afe4b1ab37a0bea0e8974f7207f9e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 14 Feb 2023 21:28:34 +0900 Subject: [PATCH 16/18] update readme --- README.md | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/README.md b/README.md index 30921b26..a390f49e 100644 --- a/README.md +++ b/README.md @@ -124,6 +124,20 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ## Change History +- 14 Feb. 2023, 2023/2/14: + - Add support with multi-gpu trainining for ``train_newtork.py``. Thanks to Isotr0py! + - Add ``--verbose`` option for ``resize_lora.py``. For details, see [this PR](./pull/179). Thanks to mgz-dev! + - Git commit hash is added to the metadata for LoRA. Thanks to space-nuko! + - Add ``--noise_offset`` option for each traing scripts. + - Implementation of https://www.crosslabs.org//blog/diffusion-with-offset-noise + - This option may improve ability to generate darker/lighter images. May work with LoRA. + - ``train_newtork.py``でマルチGPU学習をサポートしました。Isotr0py氏に感謝します。 + - ``--verbose``オプションを ``resize_lora.py`` に追加しました。表示される情報の詳細は [こちらのPR](./pull/179) をご参照ください。mgz-dev氏に感謝します。 + - LoRAのメタデータにgitのcommit hashを追加しました。space-nuko氏に感謝します。 + - ``--noise_offset`` オプションを各学習スクリプトに追加しました。 + - こちらの記事の実装になります: https://www.crosslabs.org//blog/diffusion-with-offset-noise + - 全体的に暗い、明るい画像の生成結果が良くなる可能性があるようです。LoRA学習でも有効なようです。 + - 11 Feb. 2023, 2023/2/11: - ``lora_interrogator.py`` is added in ``networks`` folder. See ``python networks\lora_interrogator.py -h`` for usage. - For LoRAs where the activation word is unknown, this script compares the output of Text Encoder after applying LoRA to that of unapplied to find out which token is affected by LoRA. Hopefully you can figure out the activation word. LoRA trained with captions does not seem to be able to interrogate. From 3d400667d243c8f02586ce139ab3be21b03e2da6 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 14 Feb 2023 21:29:40 +0900 Subject: [PATCH 17/18] fix typos --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index a390f49e..7a9689e9 100644 --- a/README.md +++ b/README.md @@ -125,13 +125,13 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ## Change History - 14 Feb. 2023, 2023/2/14: - - Add support with multi-gpu trainining for ``train_newtork.py``. Thanks to Isotr0py! + - Add support with multi-gpu trainining for ``train_network.py``. Thanks to Isotr0py! - Add ``--verbose`` option for ``resize_lora.py``. For details, see [this PR](./pull/179). Thanks to mgz-dev! - Git commit hash is added to the metadata for LoRA. Thanks to space-nuko! - - Add ``--noise_offset`` option for each traing scripts. + - Add ``--noise_offset`` option for each training scripts. - Implementation of https://www.crosslabs.org//blog/diffusion-with-offset-noise - This option may improve ability to generate darker/lighter images. May work with LoRA. - - ``train_newtork.py``でマルチGPU学習をサポートしました。Isotr0py氏に感謝します。 + - ``train_network.py``でマルチGPU学習をサポートしました。Isotr0py氏に感謝します。 - ``--verbose``オプションを ``resize_lora.py`` に追加しました。表示される情報の詳細は [こちらのPR](./pull/179) をご参照ください。mgz-dev氏に感謝します。 - LoRAのメタデータにgitのcommit hashを追加しました。space-nuko氏に感謝します。 - ``--noise_offset`` オプションを各学習スクリプトに追加しました。 From 82713e9aa62d9a28b17b3920b2b9bccd74051f9b Mon Sep 17 00:00:00 2001 From: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Tue, 14 Feb 2023 21:41:04 +0900 Subject: [PATCH 18/18] Update README.md --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 7a9689e9..62551f27 100644 --- a/README.md +++ b/README.md @@ -126,13 +126,13 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - 14 Feb. 2023, 2023/2/14: - Add support with multi-gpu trainining for ``train_network.py``. Thanks to Isotr0py! - - Add ``--verbose`` option for ``resize_lora.py``. For details, see [this PR](./pull/179). Thanks to mgz-dev! + - Add ``--verbose`` option for ``resize_lora.py``. For details, see [this PR](https://github.com/kohya-ss/sd-scripts/pull/179). Thanks to mgz-dev! - Git commit hash is added to the metadata for LoRA. Thanks to space-nuko! - Add ``--noise_offset`` option for each training scripts. - Implementation of https://www.crosslabs.org//blog/diffusion-with-offset-noise - This option may improve ability to generate darker/lighter images. May work with LoRA. - ``train_network.py``でマルチGPU学習をサポートしました。Isotr0py氏に感謝します。 - - ``--verbose``オプションを ``resize_lora.py`` に追加しました。表示される情報の詳細は [こちらのPR](./pull/179) をご参照ください。mgz-dev氏に感謝します。 + - ``--verbose``オプションを ``resize_lora.py`` に追加しました。表示される情報の詳細は [こちらのPR](https://github.com/kohya-ss/sd-scripts/pull/179) をご参照ください。mgz-dev氏に感謝します。 - LoRAのメタデータにgitのcommit hashを追加しました。space-nuko氏に感謝します。 - ``--noise_offset`` オプションを各学習スクリプトに追加しました。 - こちらの記事の実装になります: https://www.crosslabs.org//blog/diffusion-with-offset-noise