diff --git a/fine_tune.py b/fine_tune.py
index a3588c37..96aa362b 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -149,7 +149,7 @@ def train(args):
 
   # 学習に必要なクラスを準備する
   print("prepare optimizer, data loader etc.")
-  optimizer_name, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
+  _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
 
   # dataloaderを準備する
   # DataLoaderのプロセス数:0はメインプロセスになる
@@ -163,8 +163,10 @@ def train(args):
     print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
   # lr schedulerを用意する
-  lr_scheduler = diffusers.optimization.get_scheduler(
-      args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
+  lr_scheduler = train_util.get_scheduler_fix(
+      args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
+      num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+      num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
 
   # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
   if args.full_fp16:
@@ -284,8 +286,11 @@ def train(args):
       current_loss = loss.detach().item()  # 平均なのでbatch sizeは関係ないはず
       if args.logging_dir is not None:
         logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
+        if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value
+          logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
         accelerator.log(logs, step=global_step)
 
+      # TODO moving averageにする
       loss_total += current_loss
       avr_loss = loss_total / (step+1)
       logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
@@ -295,7 +300,7 @@ def train(args):
         break
 
     if args.logging_dir is not None:
-      logs = {"epoch_loss": loss_total / len(train_dataloader)}
+      logs = {"loss/epoch": loss_total / len(train_dataloader)}
       accelerator.log(logs, step=epoch+1)
 
     accelerator.wait_for_everyone()
diff --git a/library/train_util.py b/library/train_util.py
index 1a28c39a..329b27fc 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -5,6 +5,7 @@ import json
 import shutil
 import time
 from typing import Dict, List, NamedTuple, Tuple
+from typing import Optional, Union
 from accelerate import Accelerator
 from torch.autograd.function import Function
 import glob
@@ -17,9 +18,11 @@
 from io import BytesIO
 from tqdm import tqdm
 import torch
+from torch.optim import Optimizer
 from torchvision import transforms
 from transformers import CLIPTokenizer
 import diffusers
+from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
 from diffusers import DDPMScheduler, StableDiffusionPipeline
 import albumentations as albu
 import numpy as np
@@ -1368,12 +1371,18 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
 
 
 def add_optimizer_arguments(parser: argparse.ArgumentParser):
   parser.add_argument("--optimizer_type", type=str, default="AdamW",
-                      help="Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit")
+                      help="Optimizer to use / オプティマイザの種類: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation")
+  parser.add_argument("--use_8bit_adam", action="store_true",
+                      help="use 8bit AdamW optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
+  parser.add_argument("--use_lion_optimizer", action="store_true",
+                      help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)")
+  # parser.add_argument("--use_dadaptation_optimizer", action="store_true",
+  #                     help="use dadaptation optimizer (requires dadaptation) / dadaptationオプティマイザを使う( dadaptation のインストールが必要)")
   parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
   parser.add_argument("--optimizer_momentum", type=float, default=0.9,
-                      help="Momentum value for optimizers")
-  parser.add_argument("--optimizer_weightdecay", type=float, default=0.01,
+                      help="Momentum value for SGD optimizers")
+  parser.add_argument("--optimizer_weight_decay", type=float, default=0.01,
                       help="Weight decay for optimizers")
   parser.add_argument("--optimizer_beta1", type=float, default=0.9,
                       help="beta1 parameter for Adam optimizers")
@@ -1407,12 +1416,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
   parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
   parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
                       help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)")
-  # parser.add_argument("--use_8bit_adam", action="store_true",
-  #                     help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
-  # parser.add_argument("--use_lion_optimizer", action="store_true",
-  #                     help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)")
-  # parser.add_argument("--use_dadaptation_optimizer", action="store_true",
-  #                     help="use dadaptation optimizer (requires dadaptation) / dadaptaionオプティマイザを使う( dadaptation のインストールが必要)")
   parser.add_argument("--mem_eff_attn", action="store_true",
                       help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
   parser.add_argument("--xformers", action="store_true",
@@ -1520,14 +1523,19 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser):
 
 # region utils
 
 
-# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, dadaption"
 def get_optimizer(args, trainable_params):
-  # Prepare optimizer/学習に必要なクラスを準備する
-  optimizer_type = args.optimizer_type.lower()
+  # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, dadaptation"
+
+  optimizer_type = args.optimizer_type
+  if args.use_8bit_adam:
+    optimizer_type = "AdamW8bit"
+  elif args.use_lion_optimizer:
+    optimizer_type = "Lion"
+  optimizer_type = optimizer_type.lower()
 
   betas = (args.optimizer_beta1, args.optimizer_beta2)
-  weight_decay = args.optimizer_weightdecay
+  weight_decay = args.optimizer_weight_decay
   momentum = args.optimizer_momentum
   lr = args.learning_rate
@@ -1563,17 +1571,18 @@ def get_optimizer(args, trainable_params):
     optimizer_class = torch.optim.SGD
     optimizer = optimizer_class(trainable_params, lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=True)
 
-  elif optimizer_type == "dadaptation".lower():
+  elif optimizer_type == "DAdaptation".lower():
     try:
       import dadaptation
     except ImportError:
       raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
-    print(f"use dadaptation optimizer")
+    print(f"use D-Adaptation Adam optimizer | betas: {betas}, Weight Decay: {weight_decay}")
     optimizer_class = dadaptation.DAdaptAdam
-    if args.learning_rate <= 0.1:
-      print('learning rate is too low. If using dadaptaion, set learning rate around 1.0.')
-      print('recommend option: lr=1.0')
-    optimizer = optimizer_class(trainable_params, lr=lr)
+    if lr <= 0.1:
+      print(f'learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: {lr}')
+      print('recommend option: lr=1.0 / 推奨は1.0です')
+    optimizer = optimizer_class(trainable_params, lr=lr, betas=betas, weight_decay=weight_decay)
+
   else:
     print(f"use AdamW optimizer | betas: {betas}, Weight Decay: {weight_decay}")
     optimizer_class = torch.optim.AdamW
@@ -1584,6 +1593,69 @@ def get_optimizer(args, trainable_params):
 
   return optimizer_name, optimizer
 
 
+# Monkeypatch newer get_scheduler() function overriding current version of diffusers.optimization.get_scheduler
+# code is taken from https://github.com/huggingface/diffusers diffusers.optimization, commit d87cc15977b87160c30abaace3894e802ad9e1e6
+# Which is a newer release of diffusers than currently packaged with sd-scripts
+# This code can be removed when a newer diffusers version (v0.12.1 or greater) is tested and implemented in sd-scripts
+
+
+def get_scheduler_fix(
+    name: Union[str, SchedulerType],
+    optimizer: Optimizer,
+    num_warmup_steps: Optional[int] = None,
+    num_training_steps: Optional[int] = None,
+    num_cycles: int = 1,
+    power: float = 1.0,
+):
+    """
+    Unified API to get any scheduler from its name.
+    Args:
+        name (`str` or `SchedulerType`):
+            The name of the scheduler to use.
+        optimizer (`torch.optim.Optimizer`):
+            The optimizer that will be used during training.
+        num_warmup_steps (`int`, *optional*):
+            The number of warmup steps to do. This is not required by all schedulers (hence the argument being
+            optional), the function will raise an error if it's unset and the scheduler type requires it.
+        num_training_steps (`int`, *optional*):
+            The number of training steps to do. This is not required by all schedulers (hence the argument being
+            optional), the function will raise an error if it's unset and the scheduler type requires it.
+        num_cycles (`int`, *optional*):
+            The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
+        power (`float`, *optional*, defaults to 1.0):
+            Power factor. See `POLYNOMIAL` scheduler
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+    """
+    name = SchedulerType(name)
+    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
+    if name == SchedulerType.CONSTANT:
+        return schedule_func(optimizer)
+
+    # All other schedulers require `num_warmup_steps`
+    if num_warmup_steps is None:
+        raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
+
+    if name == SchedulerType.CONSTANT_WITH_WARMUP:
+        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
+
+    # All other schedulers require `num_training_steps`
+    if num_training_steps is None:
+        raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
+
+    if name == SchedulerType.COSINE_WITH_RESTARTS:
+        return schedule_func(
+            optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
+        )
+
+    if name == SchedulerType.POLYNOMIAL:
+        return schedule_func(
+            optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
+        )
+
+    return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
+
+
 def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
   # backward compatibility
   if args.caption_extention is not None:
diff --git a/train_db.py b/train_db.py
index 51e588fc..268d90a1 100644
--- a/train_db.py
+++ b/train_db.py
@@ -120,7 +120,7 @@ def train(args):
   else:
     trainable_params = unet.parameters()
 
-  optimizer_name, optimizer = train_util.get_optimizer(args, trainable_params)
+  _, optimizer = train_util.get_optimizer(args, trainable_params)
 
   # dataloaderを準備する
   # DataLoaderのプロセス数:0はメインプロセスになる
@@ -136,9 +136,11 @@ def train(args):
   if args.stop_text_encoder_training is None:
     args.stop_text_encoder_training = args.max_train_steps + 1  # do not stop until end
 
-  # lr schedulerを用意する
-  lr_scheduler = diffusers.optimization.get_scheduler(
-      args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps)
+  # lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する
+  lr_scheduler = train_util.get_scheduler_fix(
+      args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
+      num_training_steps=args.max_train_steps,
+      num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
 
   # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
   if args.full_fp16:
@@ -280,6 +282,8 @@ def train(args):
       current_loss = loss.detach().item()
       if args.logging_dir is not None:
         logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
+        if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value
+          logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
         accelerator.log(logs, step=global_step)
 
       if epoch == 0:
diff --git a/train_network.py b/train_network.py
index df987325..3c4fb8d9 100644
--- a/train_network.py
+++ b/train_network.py
@@ -1,8 +1,5 @@
-from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
-from torch.optim import Optimizer
 from torch.cuda.amp import autocast
 from torch.nn.parallel import DistributedDataParallel as DDP
-from typing import Optional, Union
 import importlib
 import argparse
 import gc
@@ -26,6 +23,7 @@ def collate_fn(examples):
   return examples[0]
 
 
+# TODO 他のスクリプトと共通化する
 def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
   logs = {"loss/current": current_loss, "loss/average": avr_loss}
 
@@ -37,75 +35,12 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
     logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
     logs["lr/unet"] = lr_scheduler.get_last_lr()[-1]  # may be same to textencoder
 
-  if args.use_dadaptation_optimizer:  # tracking d*lr value of unet.
+  if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value of unet.
     logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]['d']*lr_scheduler.optimizers[-1].param_groups[0]['lr']
 
   return logs
 
 
-# Monkeypatch newer get_scheduler() function overridng current version of diffusers.optimizer.get_scheduler
-# code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6
-# Which is a newer release of diffusers than currently packaged with sd-scripts
-# This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
-
-
-def get_scheduler_fix(
-    name: Union[str, SchedulerType],
-    optimizer: Optimizer,
-    num_warmup_steps: Optional[int] = None,
-    num_training_steps: Optional[int] = None,
-    num_cycles: int = 1,
-    power: float = 1.0,
-):
-    """
-    Unified API to get any scheduler from its name.
-    Args:
-        name (`str` or `SchedulerType`):
-            The name of the scheduler to use.
-        optimizer (`torch.optim.Optimizer`):
-            The optimizer that will be used during training.
-        num_warmup_steps (`int`, *optional*):
-            The number of warmup steps to do. This is not required by all schedulers (hence the argument being
-            optional), the function will raise an error if it's unset and the scheduler type requires it.
-        num_training_steps (`int``, *optional*):
-            The number of training steps to do. This is not required by all schedulers (hence the argument being
-            optional), the function will raise an error if it's unset and the scheduler type requires it.
-        num_cycles (`int`, *optional*):
-            The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
-        power (`float`, *optional*, defaults to 1.0):
-            Power factor. See `POLYNOMIAL` scheduler
-        last_epoch (`int`, *optional*, defaults to -1):
-            The index of the last epoch when resuming training.
-    """
-    name = SchedulerType(name)
-    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
-    if name == SchedulerType.CONSTANT:
-        return schedule_func(optimizer)
-
-    # All other schedulers require `num_warmup_steps`
-    if num_warmup_steps is None:
-        raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
-
-    if name == SchedulerType.CONSTANT_WITH_WARMUP:
-        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
-
-    # All other schedulers require `num_training_steps`
-    if num_training_steps is None:
-        raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
-
-    if name == SchedulerType.COSINE_WITH_RESTARTS:
-        return schedule_func(
-            optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
-        )
-
-    if name == SchedulerType.POLYNOMIAL:
-        return schedule_func(
-            optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
-        )
-
-    return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
-
-
 def train(args):
   session_id = random.randint(0, 2**32)
   training_started_at = time.time()
@@ -164,7 +99,7 @@ def train(args):
   if args.lowram:
     text_encoder.to("cuda")
     unet.to("cuda")
-  
+
   # モデルに xformers とか memory efficient attention を組み込む
   train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
 
@@ -226,8 +161,7 @@ def train(args):
     print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
   # lr schedulerを用意する
-  # lr_scheduler = diffusers.optimization.get_scheduler(
-  lr_scheduler = get_scheduler_fix(
+  lr_scheduler = train_util.get_scheduler_fix(
       args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
       num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
      num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index 1913da7e..07dcc199 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -199,7 +199,7 @@ def train(args):
   # 学習に必要なクラスを準備する
   print("prepare optimizer, data loader etc.")
   trainable_params = text_encoder.get_input_embeddings().parameters()
-  optimizer_name, optimizer = train_util.get_optimizer(args, trainable_params)
+  _, optimizer = train_util.get_optimizer(args, trainable_params)
 
   # dataloaderを準備する
   # DataLoaderのプロセス数:0はメインプロセスになる
@@ -213,8 +213,10 @@ def train(args):
     print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
   # lr schedulerを用意する
-  lr_scheduler = diffusers.optimization.get_scheduler(
-      args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
+  lr_scheduler = train_util.get_scheduler_fix(
+      args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
+      num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+      num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
 
   # acceleratorがなんかよろしくやってくれるらしい
   text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
@@ -356,6 +358,8 @@ def train(args):
       current_loss = loss.detach().item()
       if args.logging_dir is not None:
        logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
+        if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value
+          logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
         accelerator.log(logs, step=global_step)
 
       loss_total += current_loss
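Usage sketch (illustrative, not part of the patch): the snippet below shows how the reworked helpers fit together, i.e. how the flags added by add_optimizer_arguments drive train_util.get_optimizer and how train_util.get_scheduler_fix is called with the extra num_cycles/power parameters. It assumes the sd-scripts repository root is on sys.path and that the dadaptation package is installed; the model, warmup/step counts, and scheduler choice are placeholders, not values taken from the patch.

import argparse
import torch

import library.train_util as train_util  # module modified by this patch

parser = argparse.ArgumentParser()
train_util.add_optimizer_arguments(parser)  # adds --optimizer_type, --use_8bit_adam, --learning_rate, ...
args = parser.parse_args([
    "--optimizer_type", "DAdaptation",  # requires `pip install dadaptation`
    "--learning_rate", "1.0",           # D-Adaptation expects lr around 1.0 (see warning in get_optimizer)
])

model = torch.nn.Linear(8, 8)           # placeholder for the U-Net / text encoder parameters
_, optimizer = train_util.get_optimizer(args, trainable_params=model.parameters())

# get_scheduler_fix mirrors the newer diffusers get_scheduler() and forwards
# num_cycles / power to the cosine_with_restarts and polynomial schedules.
lr_scheduler = train_util.get_scheduler_fix(
    "cosine_with_restarts", optimizer,
    num_warmup_steps=100, num_training_steps=1000, num_cycles=3)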