diff --git a/fine_tune.py b/fine_tune.py
index 96aa362b..80290e72 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -149,7 +149,7 @@ def train(args):
   # 学習に必要なクラスを準備する
   print("prepare optimizer, data loader etc.")
 
-  _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
+  _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
 
   # dataloaderを準備する
   # DataLoaderのプロセス数:0はメインプロセスになる
@@ -163,10 +163,9 @@ def train(args):
     print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
   # lr schedulerを用意する
-  lr_scheduler = train_util.get_scheduler_fix(
-      args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
-      num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-      num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
+  lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
+                                              num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+                                              num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
 
   # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
   if args.full_fp16:
@@ -268,11 +267,11 @@ def train(args):
           loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
 
         accelerator.backward(loss)
-        if accelerator.sync_gradients:
+        if accelerator.sync_gradients and args.max_grad_norm != 0.0:
           params_to_clip = []
           for m in training_models:
             params_to_clip.extend(m.parameters())
-          accelerator.clip_grad_norm_(params_to_clip, 1.0)  # args.max_grad_norm)
+          accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
 
         optimizer.step()
         lr_scheduler.step()
@@ -285,8 +284,8 @@ def train(args):
 
         current_loss = loss.detach().item()  # 平均なのでbatch sizeは関係ないはず
         if args.logging_dir is not None:
-          logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
-          if args.optimizer_type == "DAdaptation".lower():  # tracking d*lr value
+          logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
+          if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value
             logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
           accelerator.log(logs, step=global_step)
 
diff --git a/library/train_util.py b/library/train_util.py
index 329b27fc..37642dd5 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -1,6 +1,7 @@
 # common functions for training
 
 import argparse
+import importlib
 import json
 import shutil
 import time
@@ -21,6 +22,7 @@ import torch
 from torch.optim import Optimizer
 from torchvision import transforms
 from transformers import CLIPTokenizer
+import transformers
 import diffusers
 from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
 from diffusers import DDPMScheduler, StableDiffusionPipeline
@@ -1371,28 +1373,29 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
 
 def add_optimizer_arguments(parser: argparse.ArgumentParser):
   parser.add_argument("--optimizer_type", type=str, default="AdamW",
-                      help="Optimizer to use / オプティマイザの種類: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation")
+                      help="Optimizer to use / オプティマイザの種類: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor")
+
+  # backward compatibility
   parser.add_argument("--use_8bit_adam", action="store_true",
                       help="use 8bit AdamW optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
   parser.add_argument("--use_lion_optimizer", action="store_true",
                       help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)")
-  # parser.add_argument("--use_dadaptation_optimizer", action="store_true",
-  #                     help="use dadaptation optimizer (requires dadaptation) / dadaptaionオプティマイザを使う( dadaptation のインストールが必要)")
 
   parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
-  parser.add_argument("--optimizer_momentum", type=float, default=0.9,
-                      help="Momentum value for optimizers for SGD optimizers")
-  parser.add_argument("--optimizer_weight_decay", type=float, default=0.01,
-                      help="Weight decay for optimizers")
-  parser.add_argument("--optimizer_beta1", type=float, default=0.9,
-                      help="beta1 parameter for Adam optimizers")
-  parser.add_argument("--optimizer_beta2", type=float, default=0.999,
-                      help="beta2 parameter for Adam optimizers")
+  parser.add_argument("--max_grad_norm", default=1.0, type=float,
+                      help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない")
+
+  parser.add_argument("--optimizer_args", type=str, default=None, nargs='*',
+                      help="additional arguments for optimizer (like \"weight_decay=0.01 betas=0.9,0.999 ...\") / オプティマイザの追加引数(例: \"weight_decay=0.01 betas=0.9,0.999 ...\")")
 
   parser.add_argument("--lr_scheduler", type=str, default="constant",
-                      help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
+                      help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup, adafactor")
   parser.add_argument("--lr_warmup_steps", type=int, default=0,
                       help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
+  parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
+                      help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
+  parser.add_argument("--lr_scheduler_power", type=float, default=1,
+                      help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")
 
 
 def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
@@ -1525,18 +1528,37 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser):
 
 
 def get_optimizer(args, trainable_params):
-  # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, dadaptation"
+  # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, Adafactor"
 
   optimizer_type = args.optimizer_type
   if args.use_8bit_adam:
+    print(f"*** use_8bit_adam option is specified. optimizer_type is ignored / use_8bit_adamオプションが指定されているためoptimizer_typeは無視されます")
     optimizer_type = "AdamW8bit"
   elif args.use_lion_optimizer:
+    print(f"*** use_lion_optimizer option is specified. optimizer_type is ignored / use_lion_optimizerオプションが指定されているためoptimizer_typeは無視されます")
     optimizer_type = "Lion"
   optimizer_type = optimizer_type.lower()
 
-  betas = (args.optimizer_beta1, args.optimizer_beta2)
-  weight_decay = args.optimizer_weight_decay
-  momentum = args.optimizer_momentum
+  # 引数を分解する:boolとfloat、tupleのみ対応
+  optimizer_kwargs = {}
+  if args.optimizer_args is not None and len(args.optimizer_args) > 0:
+    for arg in args.optimizer_args:
+      key, value = arg.split('=')
+
+      value = value.split(",")
+      for i in range(len(value)):
+        if value[i].lower() == "true" or value[i].lower() == "false":
+          value[i] = (value[i].lower() == "true")
+        else:
+          value[i] = float(value[i])
+      if len(value) == 1:
+        value = value[0]
+      else:
+        value = tuple(value)
+
+      optimizer_kwargs[key] = value
+  print("optkwargs:", optimizer_kwargs)
+
   lr = args.learning_rate
 
   if optimizer_type == "AdamW8bit".lower():
@@ -1544,53 +1566,128 @@ def get_optimizer(args, trainable_params):
       import bitsandbytes as bnb
     except ImportError:
       raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
-    print(f"use 8-bit AdamW optimizer | betas: {betas}, Weight Decay: {weight_decay}")
+    print(f"use 8-bit AdamW optimizer | {optimizer_kwargs}")
     optimizer_class = bnb.optim.AdamW8bit
-    optimizer = optimizer_class(trainable_params, lr=lr, betas=betas, weight_decay=weight_decay)
+    optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
 
   elif optimizer_type == "SGDNesterov8bit".lower():
     try:
       import bitsandbytes as bnb
     except ImportError:
       raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
-    print(f"use 8-bit SGD with Nesterov optimizer | Momentum: {momentum}, Weight Decay: {weight_decay}")
+    print(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}")
+    if "momentum" not in optimizer_kwargs:
+      print(f"8-bit SGD with Nesterov must be with momentum, set momentum to 0.9 / 8-bit SGD with Nesterovはmomentum指定が必須のため0.9に設定します")
+      optimizer_kwargs["momentum"] = 0.9
+
     optimizer_class = bnb.optim.SGD8bit
-    optimizer = optimizer_class(trainable_params, lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=True)
+    optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)
 
   elif optimizer_type == "Lion".lower():
     try:
       import lion_pytorch
     except ImportError:
       raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
-    print(f"use Lion optimizer | betas: {betas}, Weight Decay: {weight_decay}")
+    print(f"use Lion optimizer | {optimizer_kwargs}")
     optimizer_class = lion_pytorch.Lion
-    optimizer = optimizer_class(trainable_params, lr=lr, betas=betas, weight_decay=weight_decay)
+    optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
 
   elif optimizer_type == "SGDNesterov".lower():
-    print(f"use SGD with Nesterov optimizer | Momentum: {momentum}, Weight Decay: {weight_decay}")
+    print(f"use SGD with Nesterov optimizer | {optimizer_kwargs}")
+    if "momentum" not in optimizer_kwargs:
+      print(f"SGD with Nesterov must be with momentum, set momentum to 0.9 / SGD with Nesterovはmomentum指定が必須のため0.9に設定します")
+      optimizer_kwargs["momentum"] = 0.9
+
     optimizer_class = torch.optim.SGD
-    optimizer = optimizer_class(trainable_params, lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=True)
+    optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)
 
   elif optimizer_type == "DAdaptation".lower():
     try:
       import dadaptation
     except ImportError:
       raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
-    print(f"use D-Adaptation Adam optimizer | betas: {betas}, Weight Decay: {weight_decay}")
-    optimizer_class = dadaptation.DAdaptAdam
-    if lr <= 0.1:
-      print(f'learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: {lr}')
+    print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}")
+
+    min_lr = lr
+    if type(trainable_params) == list and type(trainable_params[0]) == dict:
+      for group in trainable_params:
+        min_lr = min(min_lr, group.get("lr", lr))
+
+    if min_lr <= 0.1:
+      print(
+          f'learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: {min_lr}')
       print('recommend option: lr=1.0 / 推奨は1.0です')
-    optimizer = optimizer_class(trainable_params, lr=lr, betas=betas, weight_decay=weight_decay)
+
+    optimizer_class = dadaptation.DAdaptAdam
+    optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
+  elif optimizer_type == "Adafactor".lower():
+    # 引数を確認して適宜補正する
+    if "relative_step" not in optimizer_kwargs:
+      optimizer_kwargs["relative_step"] = True  # default
+    if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False):
+      print(f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします")
+      optimizer_kwargs["relative_step"] = True
+    print(f"use Adafactor optimizer | {optimizer_kwargs}")
+
+    if optimizer_kwargs["relative_step"]:
+      print(f"relative_step is true / relative_stepがtrueです")
+      if lr != 0.0:
+        print(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます")
+      args.learning_rate = None
+
+      # trainable_paramsがgroupだった時の処理:lrを削除する
+      if type(trainable_params) == list and type(trainable_params[0]) == dict:
+        has_group_lr = False
+        for group in trainable_params:
+          p = group.pop("lr", None)
+          has_group_lr = has_group_lr or (p is not None)
+
+        if has_group_lr:
+          # 一応argsを無効にしておく TODO 依存関係が逆転してるのであまり望ましくない
+          print(f"unet_lr and text_encoder_lr are ignored / unet_lrとtext_encoder_lrは無視されます")
+          args.unet_lr = None
+          args.text_encoder_lr = None
+
+      if args.lr_scheduler != "adafactor":
+        print(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します")
+      args.lr_scheduler = f"adafactor:{lr}"  # ちょっと微妙だけど
+
+      lr = None
+    else:
+      if args.max_grad_norm != 0.0:
+        print(f"because max_grad_norm is set, clip_grad_norm is enabled. consider set to 0 / max_grad_normが設定されているためclip_grad_normが有効になります。0に設定して無効にしたほうがいいかもしれません")
+      if args.lr_scheduler != "constant_with_warmup":
+        print(f"constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません")
+      if optimizer_kwargs.get("clip_threshold", 1.0) != 1.0:
+        print(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません")
+
+    optimizer_class = transformers.optimization.Adafactor
+    optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
+  elif optimizer_type == "AdamW".lower():
+    print(f"use AdamW optimizer | {optimizer_kwargs}")
+    optimizer_class = torch.optim.AdamW
+    optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
 
   else:
-    print(f"use AdamW optimizer | betas: {betas}, Weight Decay: {weight_decay}")
-    optimizer_class = torch.optim.AdamW
-    optimizer = optimizer_class(trainable_params, lr=lr, betas=betas, weight_decay=weight_decay)
+    # 任意のoptimizerを使う
+    optimizer_type = args.optimizer_type  # lowerでないやつ(微妙)
+    print(f"use {optimizer_type} | {optimizer_kwargs}")
+    if "." not in optimizer_type:
+      optimizer_module = torch.optim
+    else:
+      values = optimizer_type.split(".")
+      optimizer_module = importlib.import_module(".".join(values[:-1]))
+      optimizer_type = values[-1]
+
+    optimizer_class = getattr(optimizer_module, optimizer_type)
+    optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
 
   optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
+  optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()])
 
-  return optimizer_name, optimizer
+  return optimizer_name, optimizer_args, optimizer
 
 
 # Monkeypatch newer get_scheduler() function overridng current version of diffusers.optimizer.get_scheduler
@@ -1627,6 +1724,12 @@ def get_scheduler_fix(
         last_epoch (`int`, *optional*, defaults to -1):
             The index of the last epoch when resuming training.
     """
+  if name.startswith("adafactor"):
+    assert type(optimizer) == transformers.optimization.Adafactor, f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください"
+    initial_lr = float(name.split(':')[1])
+    # print("adafactor scheduler init lr", initial_lr)
+    return transformers.optimization.AdafactorSchedule(optimizer, initial_lr)
+
   name = SchedulerType(name)
   schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
   if name == SchedulerType.CONSTANT:
@@ -1744,13 +1847,19 @@ def prepare_dtype(args: argparse.Namespace):
 
 
 def load_target_model(args: argparse.Namespace, weight_dtype):
-  load_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path)  # determine SD or Diffusers
+  name_or_path = args.pretrained_model_name_or_path
+  name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
+  load_stable_diffusion_format = os.path.isfile(name_or_path)  # determine SD or Diffusers
   if load_stable_diffusion_format:
     print("load StableDiffusion checkpoint")
-    text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path)
+    text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path)
   else:
     print("load Diffusers pretrained models")
-    pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None)
+    try:
+      pipe = StableDiffusionPipeline.from_pretrained(name_or_path, tokenizer=None, safety_checker=None)
+    except EnvironmentError as ex:
+      print(
+          f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}")
     text_encoder = pipe.text_encoder
     vae = pipe.vae
     unet = pipe.unet
diff --git a/train_db.py b/train_db.py
index 268d90a1..03fba1a6 100644
--- a/train_db.py
+++ b/train_db.py
@@ -120,7 +120,7 @@ def train(args):
   else:
     trainable_params = unet.parameters()
 
-  _, optimizer = train_util.get_optimizer(args, trainable_params)
+  _, _, optimizer = train_util.get_optimizer(args, trainable_params)
 
   # dataloaderを準備する
   # DataLoaderのプロセス数:0はメインプロセスになる
@@ -137,10 +137,9 @@ def train(args):
     args.stop_text_encoder_training = args.max_train_steps + 1  # do not stop until end
 
   # lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する
-  lr_scheduler = train_util.get_scheduler_fix(
-      args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
-      num_training_steps=args.max_train_steps,
-      num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
+  lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
+                                              num_training_steps=args.max_train_steps,
+                                              num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
 
   # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
   if args.full_fp16:
@@ -263,12 +262,12 @@ def train(args):
           loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし
 
         accelerator.backward(loss)
-        if accelerator.sync_gradients:
+        if accelerator.sync_gradients and args.max_grad_norm != 0.0:
           if train_text_encoder:
             params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters()))
           else:
            params_to_clip = unet.parameters()
-          accelerator.clip_grad_norm_(params_to_clip, 1.0)  # args.max_grad_norm)
+          accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
 
         optimizer.step()
         lr_scheduler.step()
@@ -281,8 +280,8 @@ def train(args):
 
         current_loss = loss.detach().item()
         if args.logging_dir is not None:
-          logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
-          if args.optimizer_type == "DAdaptation".lower():  # tracking d*lr value
+          logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
+          if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value
             logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
           accelerator.log(logs, step=global_step)
 
diff --git a/train_network.py b/train_network.py
index 3c4fb8d9..b01ec117 100644
--- a/train_network.py
+++ b/train_network.py
@@ -28,14 +28,14 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_sche
   logs = {"loss/current": current_loss, "loss/average": avr_loss}
 
   if args.network_train_unet_only:
-    logs["lr/unet"] = lr_scheduler.get_last_lr()[0]
+    logs["lr/unet"] = float(lr_scheduler.get_last_lr()[0])
   elif args.network_train_text_encoder_only:
-    logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
+    logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
   else:
-    logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
-    logs["lr/unet"] = lr_scheduler.get_last_lr()[-1]  # may be same to textencoder
+    logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
+    logs["lr/unet"] = float(lr_scheduler.get_last_lr()[-1])  # may be same to textencoder
 
-  if args.optimizer_type == "DAdaptation".lower():  # tracking d*lr value of unet.
+  if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value of unet.
     logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]['d']*lr_scheduler.optimizers[-1].param_groups[0]['lr']
 
   return logs
@@ -147,7 +147,7 @@ def train(args):
   print("prepare optimizer, data loader etc.")
 
   trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
-  optimizer_name, optimizer = train_util.get_optimizer(args, trainable_params)
+  optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
 
   # dataloaderを準備する
   # DataLoaderのプロセス数:0はメインプロセスになる
@@ -161,10 +161,9 @@ def train(args):
     print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
   # lr schedulerを用意する
-  lr_scheduler = train_util.get_scheduler_fix(
-      args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
-      num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-      num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
+  lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
+                                              num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+                                              num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
 
   # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
   if args.full_fp16:
@@ -287,7 +286,7 @@ def train(args):
         "ss_bucket_info": json.dumps(train_dataset.bucket_info),
         "ss_training_comment": args.training_comment,  # will not be updated after training
         "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
-        "ss_optimizer": optimizer_name
+        "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else "")
     }
 
     # uncomment if another network is added
@@ -380,9 +379,9 @@ def train(args):
           loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし
 
         accelerator.backward(loss)
-        if accelerator.sync_gradients:
+        if accelerator.sync_gradients and args.max_grad_norm != 0.0:
           params_to_clip = network.get_trainable_params()
-          accelerator.clip_grad_norm_(params_to_clip, 1.0)  # args.max_grad_norm)
+          accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
 
         optimizer.step()
         lr_scheduler.step()
@@ -478,10 +477,6 @@ if __name__ == '__main__':
   parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
   parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
 
-  parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
-                      help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
-  parser.add_argument("--lr_scheduler_power", type=float, default=1,
-                      help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")
 
   parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み")
 
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index 07dcc199..b4ddd763 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -199,7 +199,7 @@ def train(args):
   # 学習に必要なクラスを準備する
   print("prepare optimizer, data loader etc.")
   trainable_params = text_encoder.get_input_embeddings().parameters()
-  _, optimizer = train_util.get_optimizer(args, trainable_params)
+  _, _, optimizer = train_util.get_optimizer(args, trainable_params)
 
   # dataloaderを準備する
   # DataLoaderのプロセス数:0はメインプロセスになる
@@ -213,10 +213,9 @@ def train(args):
     print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
   # lr schedulerを用意する
-  lr_scheduler = train_util.get_scheduler_fix(
-      args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
-      num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-      num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
+  lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
+                                              num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+                                              num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
 
   # acceleratorがなんかよろしくやってくれるらしい
   text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
@@ -338,9 +337,9 @@ def train(args):
           loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし
 
         accelerator.backward(loss)
-        if accelerator.sync_gradients:
+        if accelerator.sync_gradients and args.max_grad_norm != 0.0:
           params_to_clip = text_encoder.get_input_embeddings().parameters()
-          accelerator.clip_grad_norm_(params_to_clip, 1.0)  # args.max_grad_norm)
+          accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
 
         optimizer.step()
         lr_scheduler.step()
@@ -357,8 +356,8 @@ def train(args):
 
         current_loss = loss.detach().item()
        if args.logging_dir is not None:
-          logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
-          if args.optimizer_type == "DAdaptation".lower():  # tracking d*lr value
+          logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
+          if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value
            logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
           accelerator.log(logs, step=global_step)
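
A minimal sketch (not part of the patch itself) of how the new --optimizer_args values are parsed by get_optimizer above: each "key=value" pair is split on "=", the value is split on ",", each element is coerced to a bool or float, and multi-element values become tuples. The helper name parse_optimizer_args is illustrative only; the patch inlines this logic inside get_optimizer.

  # Hedged sketch mirroring the optimizer_args parsing added in get_optimizer.
  # Only bool, float, and float-tuple values are supported, per the comment in the diff.
  def parse_optimizer_args(optimizer_args):
    optimizer_kwargs = {}
    for arg in optimizer_args or []:
      key, value = arg.split('=')
      value = value.split(",")
      for i in range(len(value)):
        if value[i].lower() in ("true", "false"):
          value[i] = (value[i].lower() == "true")   # bool
        else:
          value[i] = float(value[i])                # float
      optimizer_kwargs[key] = value[0] if len(value) == 1 else tuple(value)
    return optimizer_kwargs

  # e.g. --optimizer_type AdamW --optimizer_args "weight_decay=0.01" "betas=0.9,0.999"
  print(parse_optimizer_args(["weight_decay=0.01", "betas=0.9,0.999"]))
  # -> {'weight_decay': 0.01, 'betas': (0.9, 0.999)}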