refactor get_scheduler etc.

This commit is contained in:
Kohya S
2023-02-20 22:47:43 +09:00
parent 12d30afb39
commit 663aad2b0d
5 changed files with 119 additions and 100 deletions

View File

@@ -149,7 +149,7 @@ def train(args):
# 学習に必要なクラスを準備する
print("prepare optimizer, data loader etc.")
optimizer_name, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
_, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
@@ -163,8 +163,10 @@ def train(args):
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# lr schedulerを用意する
lr_scheduler = diffusers.optimization.get_scheduler(
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
lr_scheduler = train_util.get_scheduler_fix(
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
# 実験的機能勾配も含めたfp16学習を行う モデル全体をfp16にする
if args.full_fp16:
@@ -284,8 +286,11 @@ def train(args):
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
if args.optimizer_type == "DAdaptation".lower(): # tracking d*lr value
logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
accelerator.log(logs, step=global_step)
# TODO moving averageにする
loss_total += current_loss
avr_loss = loss_total / (step+1)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
@@ -295,7 +300,7 @@ def train(args):
break
if args.logging_dir is not None:
logs = {"epoch_loss": loss_total / len(train_dataloader)}
logs = {"loss/epoch": loss_total / len(train_dataloader)}
accelerator.log(logs, step=epoch+1)
accelerator.wait_for_everyone()

View File

@@ -5,6 +5,7 @@ import json
import shutil
import time
from typing import Dict, List, NamedTuple, Tuple
from typing import Optional, Union
from accelerate import Accelerator
from torch.autograd.function import Function
import glob
@@ -17,9 +18,11 @@ from io import BytesIO
from tqdm import tqdm
import torch
from torch.optim import Optimizer
from torchvision import transforms
from transformers import CLIPTokenizer
import diffusers
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
from diffusers import DDPMScheduler, StableDiffusionPipeline
import albumentations as albu
import numpy as np
@@ -1368,12 +1371,18 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
def add_optimizer_arguments(parser: argparse.ArgumentParser):
parser.add_argument("--optimizer_type", type=str, default="AdamW",
help="Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit")
help="Optimizer to use / オプティマイザの種類: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation")
parser.add_argument("--use_8bit_adam", action="store_true",
help="use 8bit AdamW optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使うbitsandbytesのインストールが必要")
parser.add_argument("--use_lion_optimizer", action="store_true",
help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う lion-pytorch のインストールが必要)")
# parser.add_argument("--use_dadaptation_optimizer", action="store_true",
# help="use dadaptation optimizer (requires dadaptation) / dadaptaionオプティマイザを使う dadaptation のインストールが必要)")
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
parser.add_argument("--optimizer_momentum", type=float, default=0.9,
help="Momentum value for optimizers")
parser.add_argument("--optimizer_weightdecay", type=float, default=0.01,
help="Momentum value for optimizers for SGD optimizers")
parser.add_argument("--optimizer_weight_decay", type=float, default=0.01,
help="Weight decay for optimizers")
parser.add_argument("--optimizer_beta1", type=float, default=0.9,
help="beta1 parameter for Adam optimizers")
@@ -1407,12 +1416,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長未指定で75、150または225が指定可")
# parser.add_argument("--use_8bit_adam", action="store_true",
# help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使うbitsandbytesのインストールが必要")
# parser.add_argument("--use_lion_optimizer", action="store_true",
# help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う lion-pytorch のインストールが必要)")
# parser.add_argument("--use_dadaptation_optimizer", action="store_true",
# help="use dadaptation optimizer (requires dadaptation) / dadaptaionオプティマイザを使う dadaptation のインストールが必要)")
parser.add_argument("--mem_eff_attn", action="store_true",
help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
parser.add_argument("--xformers", action="store_true",
@@ -1520,14 +1523,19 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser):
# region utils
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, dadaption"
def get_optimizer(args, trainable_params):
# Prepare optimizer/学習に必要なクラスを準備する
optimizer_type = args.optimizer_type.lower()
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, dadaptation"
optimizer_type = args.optimizer_type
if args.use_8bit_adam:
optimizer_type = "AdamW8bit"
elif args.use_lion_optimizer:
optimizer_type = "Lion"
optimizer_type = optimizer_type.lower()
betas = (args.optimizer_beta1, args.optimizer_beta2)
weight_decay = args.optimizer_weightdecay
weight_decay = args.optimizer_weight_decay
momentum = args.optimizer_momentum
lr = args.learning_rate
@@ -1563,17 +1571,18 @@ def get_optimizer(args, trainable_params):
optimizer_class = torch.optim.SGD
optimizer = optimizer_class(trainable_params, lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=True)
elif optimizer_type == "dadaptation".lower():
elif optimizer_type == "DAdaptation".lower():
try:
import dadaptation
except ImportError:
raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
print(f"use dadaptation optimizer")
print(f"use D-Adaptation Adam optimizer | betas: {betas}, Weight Decay: {weight_decay}")
optimizer_class = dadaptation.DAdaptAdam
if args.learning_rate <= 0.1:
print('learning rate is too low. If using dadaptaion, set learning rate around 1.0.')
print('recommend option: lr=1.0')
optimizer = optimizer_class(trainable_params, lr=lr)
if lr <= 0.1:
print(f'learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: {lr}')
print('recommend option: lr=1.0 / 推奨は1.0です')
optimizer = optimizer_class(trainable_params, lr=lr, betas=betas, weight_decay=weight_decay)
else:
print(f"use AdamW optimizer | betas: {betas}, Weight Decay: {weight_decay}")
optimizer_class = torch.optim.AdamW
@@ -1584,6 +1593,69 @@ def get_optimizer(args, trainable_params):
return optimizer_name, optimizer
# Monkeypatch: a newer get_scheduler() function that overrides the current version of diffusers.optimization.get_scheduler
# The code is taken from https://github.com/huggingface/diffusers (diffusers.optimization), commit d87cc15977b87160c30abaace3894e802ad9e1e6,
# which is a newer release of diffusers than the one currently packaged with sd-scripts.
# This code can be removed once a newer diffusers version (v0.12.1 or greater) has been tested and integrated into sd-scripts.
def get_scheduler_fix(
    name: Union[str, SchedulerType],
    optimizer: Optimizer,
    num_warmup_steps: Optional[int] = None,
    num_training_steps: Optional[int] = None,
    num_cycles: int = 1,
    power: float = 1.0,
):
    """
    Unified API to get any scheduler from its name.

    Args:
        name (`str` or `SchedulerType`):
            The name of the scheduler to use.
        optimizer (`torch.optim.Optimizer`):
            The optimizer that will be used during training.
        num_warmup_steps (`int`, *optional*):
            The number of warmup steps to do. This is not required by all schedulers (hence the argument being
            optional), the function will raise an error if it's unset and the scheduler type requires it.
        num_training_steps (`int`, *optional*):
            The number of training steps to do. This is not required by all schedulers (hence the argument being
            optional), the function will raise an error if it's unset and the scheduler type requires it.
        num_cycles (`int`, *optional*, defaults to 1):
            The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
            Only forwarded for that scheduler type.
        power (`float`, *optional*, defaults to 1.0):
            Power factor. See `POLYNOMIAL` scheduler. Only forwarded for that scheduler type.

    Returns:
        The object returned by the schedule function registered for `name` in `TYPE_TO_SCHEDULER_FUNCTION`.

    Raises:
        ValueError: If `num_warmup_steps` is `None` for any scheduler other than `CONSTANT`, or if
            `num_training_steps` is `None` for any scheduler other than `CONSTANT` and `CONSTANT_WITH_WARMUP`.
    """
    # Normalize a plain string into the SchedulerType used as the lookup key.
    name = SchedulerType(name)
    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]

    # The constant schedule takes nothing but the optimizer.
    if name == SchedulerType.CONSTANT:
        return schedule_func(optimizer)

    # All other schedulers require `num_warmup_steps`
    if num_warmup_steps is None:
        raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")

    if name == SchedulerType.CONSTANT_WITH_WARMUP:
        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)

    # All other schedulers require `num_training_steps`
    if num_training_steps is None:
        raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")

    # These two scheduler types accept one extra keyword each; this is what the
    # fixed version adds over a plain (warmup, training steps) call.
    if name == SchedulerType.COSINE_WITH_RESTARTS:
        return schedule_func(
            optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
        )

    if name == SchedulerType.POLYNOMIAL:
        return schedule_func(
            optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
        )

    return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
# backward compatibility
if args.caption_extention is not None:

View File

@@ -120,7 +120,7 @@ def train(args):
else:
trainable_params = unet.parameters()
optimizer_name, optimizer = train_util.get_optimizer(args, trainable_params)
_, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
@@ -136,9 +136,11 @@ def train(args):
if args.stop_text_encoder_training is None:
args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
# lr schedulerを用意する
lr_scheduler = diffusers.optimization.get_scheduler(
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps)
# lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する
lr_scheduler = train_util.get_scheduler_fix(
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
num_training_steps=args.max_train_steps,
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
# 実験的機能勾配も含めたfp16学習を行う モデル全体をfp16にする
if args.full_fp16:
@@ -280,6 +282,8 @@ def train(args):
current_loss = loss.detach().item()
if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
if args.optimizer_type == "DAdaptation".lower(): # tracking d*lr value
logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
accelerator.log(logs, step=global_step)
if epoch == 0:

View File

@@ -1,8 +1,5 @@
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
from torch.optim import Optimizer
from torch.cuda.amp import autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from typing import Optional, Union
import importlib
import argparse
import gc
@@ -26,6 +23,7 @@ def collate_fn(examples):
return examples[0]
# TODO 他のスクリプトと共通化する
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
logs = {"loss/current": current_loss, "loss/average": avr_loss}
@@ -37,75 +35,12 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_sche
logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
logs["lr/unet"] = lr_scheduler.get_last_lr()[-1] # may be same to textencoder
if args.use_dadaptation_optimizer: # tracking d*lr value of unet.
if args.optimizer_type == "DAdaptation".lower(): # tracking d*lr value of unet.
logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]['d']*lr_scheduler.optimizers[-1].param_groups[0]['lr']
return logs
# Monkeypatch: a newer get_scheduler() function that overrides the current version of diffusers.optimization.get_scheduler
# The code is taken from https://github.com/huggingface/diffusers (diffusers.optimization), commit d87cc15977b87160c30abaace3894e802ad9e1e6,
# which is a newer release of diffusers than the one currently packaged with sd-scripts.
# This code can be removed once a newer diffusers version (v0.12.1 or greater) has been tested and integrated into sd-scripts.
def get_scheduler_fix(
    name: Union[str, SchedulerType],
    optimizer: Optimizer,
    num_warmup_steps: Optional[int] = None,
    num_training_steps: Optional[int] = None,
    num_cycles: int = 1,
    power: float = 1.0,
):
    """
    Unified API to get any scheduler from its name.

    Args:
        name (`str` or `SchedulerType`):
            The name of the scheduler to use.
        optimizer (`torch.optim.Optimizer`):
            The optimizer that will be used during training.
        num_warmup_steps (`int`, *optional*):
            The number of warmup steps to do. This is not required by all schedulers (hence the argument being
            optional), the function will raise an error if it's unset and the scheduler type requires it.
        num_training_steps (`int`, *optional*):
            The number of training steps to do. This is not required by all schedulers (hence the argument being
            optional), the function will raise an error if it's unset and the scheduler type requires it.
        num_cycles (`int`, *optional*, defaults to 1):
            The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
            Only forwarded for that scheduler type.
        power (`float`, *optional*, defaults to 1.0):
            Power factor. See `POLYNOMIAL` scheduler. Only forwarded for that scheduler type.

    Returns:
        The object returned by the schedule function registered for `name` in `TYPE_TO_SCHEDULER_FUNCTION`.

    Raises:
        ValueError: If `num_warmup_steps` is `None` for any scheduler other than `CONSTANT`, or if
            `num_training_steps` is `None` for any scheduler other than `CONSTANT` and `CONSTANT_WITH_WARMUP`.
    """
    # Normalize a plain string into the SchedulerType used as the lookup key.
    name = SchedulerType(name)
    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]

    # The constant schedule takes nothing but the optimizer.
    if name == SchedulerType.CONSTANT:
        return schedule_func(optimizer)

    # All other schedulers require `num_warmup_steps`
    if num_warmup_steps is None:
        raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")

    if name == SchedulerType.CONSTANT_WITH_WARMUP:
        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)

    # All other schedulers require `num_training_steps`
    if num_training_steps is None:
        raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")

    # These two scheduler types accept one extra keyword each; this is what the
    # fixed version adds over a plain (warmup, training steps) call.
    if name == SchedulerType.COSINE_WITH_RESTARTS:
        return schedule_func(
            optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
        )

    if name == SchedulerType.POLYNOMIAL:
        return schedule_func(
            optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
        )

    return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
def train(args):
session_id = random.randint(0, 2**32)
training_started_at = time.time()
@@ -164,7 +99,7 @@ def train(args):
if args.lowram:
text_encoder.to("cuda")
unet.to("cuda")
# モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
@@ -226,8 +161,7 @@ def train(args):
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# lr schedulerを用意する
# lr_scheduler = diffusers.optimization.get_scheduler(
lr_scheduler = get_scheduler_fix(
lr_scheduler = train_util.get_scheduler_fix(
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)

View File

@@ -199,7 +199,7 @@ def train(args):
# 学習に必要なクラスを準備する
print("prepare optimizer, data loader etc.")
trainable_params = text_encoder.get_input_embeddings().parameters()
optimizer_name, optimizer = train_util.get_optimizer(args, trainable_params)
_, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
@@ -213,8 +213,10 @@ def train(args):
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# lr schedulerを用意する
lr_scheduler = diffusers.optimization.get_scheduler(
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
lr_scheduler = train_util.get_scheduler_fix(
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
# acceleratorがなんかよろしくやってくれるらしい
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
@@ -356,6 +358,8 @@ def train(args):
current_loss = loss.detach().item()
if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
if args.optimizer_type == "DAdaptation".lower(): # tracking d*lr value
logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
accelerator.log(logs, step=global_step)
loss_total += current_loss