diff --git a/fine_tune.py b/fine_tune.py
index 4ef47c37..45b27a41 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -381,7 +381,7 @@ def train(args):
             current_loss = loss.detach().item()  # 平均なのでbatch sizeは関係ないはず
             if args.logging_dir is not None:
                 logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
-                if args.optimizer_type.lower() == "DAdaptation".lower() or args.optimizer_type.lower() == "DAdaptAdam".lower() or args.optimizer_type.lower() == "DAdaptAdaGrad".lower() or args.optimizer_type.lower() == "DAdaptAdan".lower() or args.optimizer_type.lower() == "DAdaptSGD".lower():  # tracking d*lr value
+                if args.optimizer_type.lower().startswith("DAdapt".lower()):  # tracking d*lr value
                     logs["lr/d*lr"] = (
                         lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
                     )
diff --git a/library/train_util.py b/library/train_util.py
index c34894a8..fb1a37cd 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -2570,13 +2570,15 @@ def get_optimizer(args, trainable_params):
         optimizer_class = torch.optim.SGD
         optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)
 
-    elif optimizer_type == "DAdaptation".lower() or optimizer_type == "DAdaptAdam".lower():
+    elif optimizer_type.startswith("DAdapt".lower()):
+        # DAdaptation family
+        # check dadaptation is installed
         try:
             import dadaptation
         except ImportError:
             raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
-        print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}")
 
+        # check lr and lr_count, and print warning
         actual_lr = lr
         lr_count = 1
         if type(trainable_params) == list and type(trainable_params[0]) == dict:
@@ -2596,96 +2598,24 @@ def get_optimizer(args, trainable_params):
                 f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-Adaptationで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}"
             )
 
-        optimizer_class = dadaptation.DAdaptAdam
+        # set optimizer
+        if optimizer_type == "DAdaptation".lower() or optimizer_type == "DAdaptAdam".lower():
+            optimizer_class = dadaptation.DAdaptAdam
+            print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}")
+        elif optimizer_type == "DAdaptAdaGrad".lower():
+            optimizer_class = dadaptation.DAdaptAdaGrad
+            print(f"use D-Adaptation AdaGrad optimizer | {optimizer_kwargs}")
+        elif optimizer_type == "DAdaptAdan".lower():
+            optimizer_class = dadaptation.DAdaptAdan
+            print(f"use D-Adaptation Adan optimizer | {optimizer_kwargs}")
+        elif optimizer_type == "DAdaptSGD".lower():
+            optimizer_class = dadaptation.DAdaptSGD
+            print(f"use D-Adaptation SGD optimizer | {optimizer_kwargs}")
+        else:
+            raise ValueError(f"Unknown optimizer type: {optimizer_type}")
+
         optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
-
-    elif optimizer_type == "DAdaptAdaGrad".lower():
-        try:
-            import dadaptation
-        except ImportError:
-            raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
-        print(f"use D-Adaptation AdaGrad optimizer | {optimizer_kwargs}")
-        actual_lr = lr
-        lr_count = 1
-        if type(trainable_params) == list and type(trainable_params[0]) == dict:
-            lrs = set()
-            actual_lr = trainable_params[0].get("lr", actual_lr)
-            for group in trainable_params:
-                lrs.add(group.get("lr", actual_lr))
-            lr_count = len(lrs)
-
-        if actual_lr <= 0.1:
-            print(
-                f"learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: lr={actual_lr}"
-            )
-            print("recommend option: lr=1.0 / 推奨は1.0です")
-        if lr_count > 1:
-            print(
-                f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-Adaptationで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}"
-            )
-
-        optimizer_class = dadaptation.DAdaptAdaGrad
-        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
-
-    elif optimizer_type == "DAdaptAdan".lower():
-        try:
-            import dadaptation
-        except ImportError:
-            raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
-        print(f"use D-Adaptation Adan optimizer | {optimizer_kwargs}")
-
-        actual_lr = lr
-        lr_count = 1
-        if type(trainable_params) == list and type(trainable_params[0]) == dict:
-            lrs = set()
-            actual_lr = trainable_params[0].get("lr", actual_lr)
-            for group in trainable_params:
-                lrs.add(group.get("lr", actual_lr))
-            lr_count = len(lrs)
-
-        if actual_lr <= 0.1:
-            print(
-                f"learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: lr={actual_lr}"
-            )
-            print("recommend option: lr=1.0 / 推奨は1.0です")
-        if lr_count > 1:
-            print(
-                f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-Adaptationで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}"
-            )
-
-        optimizer_class = dadaptation.DAdaptAdan
-        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
-
-    elif optimizer_type == "DAdaptSGD".lower():
-        try:
-            import dadaptation
-        except ImportError:
-            raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
-        print(f"use D-Adaptation SGD optimizer | {optimizer_kwargs}")
-
-        actual_lr = lr
-        lr_count = 1
-        if type(trainable_params) == list and type(trainable_params[0]) == dict:
-            lrs = set()
-            actual_lr = trainable_params[0].get("lr", actual_lr)
-            for group in trainable_params:
-                lrs.add(group.get("lr", actual_lr))
-            lr_count = len(lrs)
-
-        if actual_lr <= 0.1:
-            print(
-                f"learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: lr={actual_lr}"
-            )
-            print("recommend option: lr=1.0 / 推奨は1.0です")
-        if lr_count > 1:
-            print(
-                f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-Adaptationで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}"
-            )
-
-        optimizer_class = dadaptation.DAdaptSGD
-        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
-
 
     elif optimizer_type == "Adafactor".lower():
         # 引数を確認して適宜補正する
         if "relative_step" not in optimizer_kwargs:
diff --git a/train_db.py b/train_db.py
index 94ef2bf9..c4b9cb19 100644
--- a/train_db.py
+++ b/train_db.py
@@ -367,7 +367,7 @@ def train(args):
             current_loss = loss.detach().item()
             if args.logging_dir is not None:
                 logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
-                if args.optimizer_type.lower() == "DAdaptation".lower() or args.optimizer_type.lower() == "DAdaptAdam".lower() or args.optimizer_type.lower() == "DAdaptAdaGrad".lower() or args.optimizer_type.lower() == "DAdaptAdan".lower() or args.optimizer_type.lower() == "DAdaptSGD".lower():  # tracking d*lr value
+                if args.optimizer_type.lower().startswith("DAdapt".lower()):  # tracking d*lr value
                     logs["lr/d*lr"] = (
                         lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
                     )
diff --git a/train_network.py b/train_network.py
index fb58b65c..96a23de0 100644
--- a/train_network.py
+++ b/train_network.py
@@ -43,7 +43,7 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_sche
         logs["lr/textencoder"] = float(lrs[0])
         logs["lr/unet"] = float(lrs[-1])  # may be same to textencoder
 
-        if args.optimizer_type.lower() == "DAdaptation".lower() or args.optimizer_type.lower() == "DAdaptAdam".lower() or args.optimizer_type.lower() == "DAdaptAdaGrad".lower() or args.optimizer_type.lower() == "DAdaptAdan".lower() or args.optimizer_type.lower() == "DAdaptSGD".lower():  # tracking d*lr value of unet.
+        if args.optimizer_type.lower().startswith("DAdapt".lower()):  # tracking d*lr value of unet.
             logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
     else:
         idx = 0
@@ -53,7 +53,7 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_sche
 
         for i in range(idx, len(lrs)):
             logs[f"lr/group{i}"] = float(lrs[i])
-            if args.optimizer_type.lower() == "DAdaptation".lower() or args.optimizer_type.lower() == "DAdaptAdam".lower() or args.optimizer_type.lower() == "DAdaptAdaGrad".lower() or args.optimizer_type.lower() == "DAdaptAdan".lower() or args.optimizer_type.lower() == "DAdaptSGD".lower():
+            if args.optimizer_type.lower().startswith("DAdapt".lower()):
                 logs[f"lr/d*lr/group{i}"] = (
                     lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
                 )
@@ -277,7 +277,7 @@ def train(args):
     else:
         unet.eval()
         text_encoder.eval()
-        
+
     network.prepare_grad_etc(text_encoder, unet)
 
     if not cache_latents:
@@ -713,7 +713,7 @@ def train(args):
     if is_main_process:
         ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
         save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True)
-        
+
         print("model saved.")
 
 
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index b3907878..c5d23acb 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -465,7 +465,7 @@ def train(args):
             current_loss = loss.detach().item()
             if args.logging_dir is not None:
                 logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
-                if args.optimizer_type.lower() == "DAdaptation".lower() or args.optimizer_type.lower() == "DAdaptAdam".lower() or args.optimizer_type.lower() == "DAdaptAdaGrad".lower() or args.optimizer_type.lower() == "DAdaptAdan".lower() or args.optimizer_type.lower() == "DAdaptSGD".lower():  # tracking d*lr value
+                if args.optimizer_type.lower().startswith("DAdapt".lower()):  # tracking d*lr value
                     logs["lr/d*lr"] = (
                         lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
                     )
diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py
index 5efe019d..7debee82 100644
--- a/train_textual_inversion_XTI.py
+++ b/train_textual_inversion_XTI.py
@@ -504,7 +504,7 @@ def train(args):
             current_loss = loss.detach().item()
             if args.logging_dir is not None:
                 logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
-                if args.optimizer_type.lower() == "DAdaptation".lower() or args.optimizer_type.lower() == "DAdaptAdam".lower() or args.optimizer_type.lower() == "DAdaptAdaGrad".lower() or args.optimizer_type.lower() == "DAdaptAdan".lower() or args.optimizer_type.lower() == "DAdaptSGD".lower():  # tracking d*lr value
+                if args.optimizer_type.lower().startswith("DAdapt".lower()):  # tracking d*lr value
                     logs["lr/d*lr"] = (
                         lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
                     )
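
For quick reference, the pattern repeated across these files can be summarized in a short standalone Python sketch. This is only an illustration, not part of the patch; the helper names is_dadaptation and d_times_lr are hypothetical, and the only assumption is the one the patch itself relies on, namely that D-Adaptation optimizers expose the adapted step size "d" in each param group.

# Illustrative sketch of the consolidated check and the logged metric.
def is_dadaptation(optimizer_type: str) -> bool:
    # "DAdaptation", "DAdaptAdam", "DAdaptAdaGrad", "DAdaptAdan" and "DAdaptSGD"
    # all share the "DAdapt" prefix, so one case-insensitive startswith()
    # replaces the previous chain of equality comparisons.
    return optimizer_type.lower().startswith("dadapt")

def d_times_lr(optimizer) -> float:
    # The value written to the logs as "lr/d*lr": the adapted step size "d"
    # multiplied by the configured lr of the tracked param group.
    group = optimizer.param_groups[0]
    return group["d"] * group["lr"]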