refactor get_scheduler etc.

fine_tune.py
@@ -149,7 +149,7 @@ def train(args):
   # 学習に必要なクラスを準備する
   print("prepare optimizer, data loader etc.")
-  optimizer_name, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
+  _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)

   # dataloaderを準備する
   # DataLoaderのプロセス数:0はメインプロセスになる
@@ -163,8 +163,10 @@ def train(args):
     print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")

   # lr schedulerを用意する
-  lr_scheduler = diffusers.optimization.get_scheduler(
-    args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
+  lr_scheduler = train_util.get_scheduler_fix(
+    args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
+    num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+    num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)

   # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
   if args.full_fp16:
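Note: with this hunk the scheduler is built through train_util.get_scheduler_fix, which makes the restart/polynomial parameters reachable from the command line. A minimal sketch of the call, assuming the `import library.train_util as train_util` import used by the training scripts (values are illustrative, not recommendations):

    # cosine schedule with 3 hard restarts over the accumulated step count
    lr_scheduler = train_util.get_scheduler_fix(
        "cosine_with_restarts", optimizer,
        num_warmup_steps=100,
        num_training_steps=1600 * 4,  # max_train_steps * gradient_accumulation_steps
        num_cycles=3, power=1.0)      # power is only consumed by the polynomial schedule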
@@ -284,8 +286,11 @@ def train(args):
       current_loss = loss.detach().item()  # 平均なのでbatch sizeは関係ないはず
       if args.logging_dir is not None:
         logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
+        if args.optimizer_type == "DAdaptation".lower():  # tracking d*lr value
+          logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
         accelerator.log(logs, step=global_step)

       # TODO moving averageにする
       loss_total += current_loss
       avr_loss = loss_total / (step+1)
       logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
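Note: D-Adaptation keeps the user-supplied lr fixed (typically 1.0) and learns a multiplier d, so the quantity worth charting is d*lr. A sketch of what the added logging lines compute (the `.optimizers` attribute is exposed by Accelerate's scheduler wrapper after accelerator.prepare):

    pg = lr_scheduler.optimizers[0].param_groups[0]
    logs["lr/d*lr"] = pg['d'] * pg['lr']  # the effective step size D-Adaptation is applying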
@@ -295,7 +300,7 @@ def train(args):
         break

     if args.logging_dir is not None:
-      logs = {"epoch_loss": loss_total / len(train_dataloader)}
+      logs = {"loss/epoch": loss_total / len(train_dataloader)}
       accelerator.log(logs, step=epoch+1)

     accelerator.wait_for_everyone()

library/train_util.py
@@ -5,6 +5,7 @@ import json
 import shutil
 import time
 from typing import Dict, List, NamedTuple, Tuple
+from typing import Optional, Union
 from accelerate import Accelerator
 from torch.autograd.function import Function
 import glob
@@ -17,9 +18,11 @@ from io import BytesIO

 from tqdm import tqdm
 import torch
+from torch.optim import Optimizer
 from torchvision import transforms
 from transformers import CLIPTokenizer
 import diffusers
+from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
 from diffusers import DDPMScheduler, StableDiffusionPipeline
 import albumentations as albu
 import numpy as np
@@ -1368,12 +1371,18 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):

 def add_optimizer_arguments(parser: argparse.ArgumentParser):
   parser.add_argument("--optimizer_type", type=str, default="AdamW",
-                      help="Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit")
+                      help="Optimizer to use / オプティマイザの種類: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation")
+  parser.add_argument("--use_8bit_adam", action="store_true",
+                      help="use 8bit AdamW optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
+  parser.add_argument("--use_lion_optimizer", action="store_true",
+                      help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)")
+  # parser.add_argument("--use_dadaptation_optimizer", action="store_true",
+  #                     help="use dadaptation optimizer (requires dadaptation) / dadaptaionオプティマイザを使う( dadaptation のインストールが必要)")

   parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
   parser.add_argument("--optimizer_momentum", type=float, default=0.9,
                       help="Momentum value for optimizers")
-  parser.add_argument("--optimizer_weightdecay", type=float, default=0.01,
-                      help="Momentum value for optimizers for SGD optimizers")
+  parser.add_argument("--optimizer_weight_decay", type=float, default=0.01,
+                      help="Weight decay for optimizers")
   parser.add_argument("--optimizer_beta1", type=float, default=0.9,
                       help="beta1 parameter for Adam optimizers")
@@ -1407,12 +1416,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
   parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
   parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
                       help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)")
-  # parser.add_argument("--use_8bit_adam", action="store_true",
-  #                     help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
-  # parser.add_argument("--use_lion_optimizer", action="store_true",
-  #                     help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)")
-  # parser.add_argument("--use_dadaptation_optimizer", action="store_true",
-  #                     help="use dadaptation optimizer (requires dadaptation) / dadaptaionオプティマイザを使う( dadaptation のインストールが必要)")
   parser.add_argument("--mem_eff_attn", action="store_true",
                       help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
   parser.add_argument("--xformers", action="store_true",
@@ -1520,14 +1523,19 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser):

 # region utils

-# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, dadaption"

 def get_optimizer(args, trainable_params):
   # Prepare optimizer/学習に必要なクラスを準備する
-  optimizer_type = args.optimizer_type.lower()
+  # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, dadaptation"
+
+  optimizer_type = args.optimizer_type
+  if args.use_8bit_adam:
+    optimizer_type = "AdamW8bit"
+  elif args.use_lion_optimizer:
+    optimizer_type = "Lion"
+  optimizer_type = optimizer_type.lower()

   betas = (args.optimizer_beta1, args.optimizer_beta2)
-  weight_decay = args.optimizer_weightdecay
+  weight_decay = args.optimizer_weight_decay
   momentum = args.optimizer_momentum
   lr = args.learning_rate
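Note: the legacy flags are folded into optimizer_type before the lower-cased dispatch, so --use_8bit_adam and --optimizer_type=AdamW8bit select the same code path. A hedged sketch of the resulting behaviour (the Namespace fields mirror the arguments defined above; the model is a stand-in):

    import argparse
    import torch

    model = torch.nn.Linear(8, 8)  # toy model for illustration
    args = argparse.Namespace(
        optimizer_type="AdamW", use_8bit_adam=True, use_lion_optimizer=False,
        optimizer_beta1=0.9, optimizer_beta2=0.999,
        optimizer_weight_decay=0.01, optimizer_momentum=0.9, learning_rate=2.0e-6)
    optimizer_name, optimizer = get_optimizer(args, trainable_params=model.parameters())
    # use_8bit_adam=True rewrites optimizer_type to "AdamW8bit" before dispatch
    # (constructing that optimizer requires bitsandbytes to be installed)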
@@ -1563,17 +1571,18 @@ def get_optimizer(args, trainable_params):
     optimizer_class = torch.optim.SGD
     optimizer = optimizer_class(trainable_params, lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=True)

-  elif optimizer_type == "dadaptation".lower():
+  elif optimizer_type == "DAdaptation".lower():
     try:
       import dadaptation
     except ImportError:
       raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
-    print(f"use dadaptation optimizer")
+    print(f"use D-Adaptation Adam optimizer | betas: {betas}, Weight Decay: {weight_decay}")
     optimizer_class = dadaptation.DAdaptAdam
-    if args.learning_rate <= 0.1:
-      print('learning rate is too low. If using dadaptaion, set learning rate around 1.0.')
-      print('recommend option: lr=1.0')
-    optimizer = optimizer_class(trainable_params, lr=lr)
+    if lr <= 0.1:
+      print(f'learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: {lr}')
+      print('recommend option: lr=1.0 / 推奨は1.0です')
+    optimizer = optimizer_class(trainable_params, lr=lr, betas=betas, weight_decay=weight_decay)

   else:
     print(f"use AdamW optimizer | betas: {betas}, Weight Decay: {weight_decay}")
     optimizer_class = torch.optim.AdamW
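Note: because D-Adaptation estimates the step size itself, lr acts as a scale on the adapted value rather than a raw learning rate, hence the warning when lr <= 0.1. A standalone sketch (requires the dadaptation package; hyperparameters are illustrative):

    import torch
    import dadaptation

    model = torch.nn.Linear(8, 8)  # toy model for illustration
    optimizer = dadaptation.DAdaptAdam(model.parameters(), lr=1.0,  # ~1.0 as recommended above
                                       betas=(0.9, 0.999), weight_decay=0.01)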
@@ -1584,6 +1593,69 @@ def get_optimizer(args, trainable_params):
   return optimizer_name, optimizer


+# Monkeypatch newer get_scheduler() function overridng current version of diffusers.optimizer.get_scheduler
+# code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6
+# Which is a newer release of diffusers than currently packaged with sd-scripts
+# This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
+
+
+def get_scheduler_fix(
+    name: Union[str, SchedulerType],
+    optimizer: Optimizer,
+    num_warmup_steps: Optional[int] = None,
+    num_training_steps: Optional[int] = None,
+    num_cycles: int = 1,
+    power: float = 1.0,
+):
+  """
+  Unified API to get any scheduler from its name.
+  Args:
+    name (`str` or `SchedulerType`):
+      The name of the scheduler to use.
+    optimizer (`torch.optim.Optimizer`):
+      The optimizer that will be used during training.
+    num_warmup_steps (`int`, *optional*):
+      The number of warmup steps to do. This is not required by all schedulers (hence the argument being
+      optional), the function will raise an error if it's unset and the scheduler type requires it.
+    num_training_steps (`int``, *optional*):
+      The number of training steps to do. This is not required by all schedulers (hence the argument being
+      optional), the function will raise an error if it's unset and the scheduler type requires it.
+    num_cycles (`int`, *optional*):
+      The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
+    power (`float`, *optional*, defaults to 1.0):
+      Power factor. See `POLYNOMIAL` scheduler
+    last_epoch (`int`, *optional*, defaults to -1):
+      The index of the last epoch when resuming training.
+  """
+  name = SchedulerType(name)
+  schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
+  if name == SchedulerType.CONSTANT:
+    return schedule_func(optimizer)
+
+  # All other schedulers require `num_warmup_steps`
+  if num_warmup_steps is None:
+    raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
+
+  if name == SchedulerType.CONSTANT_WITH_WARMUP:
+    return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
+
+  # All other schedulers require `num_training_steps`
+  if num_training_steps is None:
+    raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
+
+  if name == SchedulerType.COSINE_WITH_RESTARTS:
+    return schedule_func(
+      optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
+    )
+
+  if name == SchedulerType.POLYNOMIAL:
+    return schedule_func(
+      optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
+    )
+
+  return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
+
+
 def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
   # backward compatibility
   if args.caption_extention is not None:
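Note: relative to the get_scheduler bundled with the pinned diffusers, the backported function above only adds the num_cycles and power pass-throughs. A hedged usage sketch:

    import torch

    optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
    scheduler = get_scheduler_fix("cosine_with_restarts", optimizer,
                                  num_warmup_steps=100, num_training_steps=1000,
                                  num_cycles=3)   # "polynomial" would consume power= instead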

train_db.py
@@ -120,7 +120,7 @@ def train(args):
   else:
     trainable_params = unet.parameters()

-  optimizer_name, optimizer = train_util.get_optimizer(args, trainable_params)
+  _, optimizer = train_util.get_optimizer(args, trainable_params)

   # dataloaderを準備する
   # DataLoaderのプロセス数:0はメインプロセスになる
@@ -136,9 +136,11 @@ def train(args):
   if args.stop_text_encoder_training is None:
     args.stop_text_encoder_training = args.max_train_steps + 1  # do not stop until end

-  # lr schedulerを用意する
-  lr_scheduler = diffusers.optimization.get_scheduler(
-    args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps)
+  # lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する
+  lr_scheduler = train_util.get_scheduler_fix(
+    args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
+    num_training_steps=args.max_train_steps,
+    num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)

   # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
   if args.full_fp16:
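Note the TODO above: unlike fine_tune.py, this call sizes the schedule with num_training_steps set to args.max_train_steps alone, without the * args.gradient_accumulation_steps factor. When accumulation is enabled the schedule length may therefore disagree with the number of scheduler steps actually taken; the comment flags this for later verification rather than changing behaviour in this commit.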
@@ -280,6 +282,8 @@ def train(args):
       current_loss = loss.detach().item()
       if args.logging_dir is not None:
         logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
+        if args.optimizer_type == "DAdaptation".lower():  # tracking d*lr value
+          logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
         accelerator.log(logs, step=global_step)

       if epoch == 0:

train_network.py
@@ -1,8 +1,5 @@
-from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
-from torch.optim import Optimizer
 from torch.cuda.amp import autocast
 from torch.nn.parallel import DistributedDataParallel as DDP
-from typing import Optional, Union
 import importlib
 import argparse
 import gc
@@ -26,6 +23,7 @@ def collate_fn(examples):
   return examples[0]


+# TODO 他のスクリプトと共通化する
 def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
   logs = {"loss/current": current_loss, "loss/average": avr_loss}

@@ -37,75 +35,12 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
     logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
     logs["lr/unet"] = lr_scheduler.get_last_lr()[-1]  # may be same to textencoder

-  if args.use_dadaptation_optimizer:  # tracking d*lr value of unet.
+  if args.optimizer_type == "DAdaptation".lower():  # tracking d*lr value of unet.
     logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]['d']*lr_scheduler.optimizers[-1].param_groups[0]['lr']

   return logs


-# Monkeypatch newer get_scheduler() function overridng current version of diffusers.optimizer.get_scheduler
-# code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6
-# Which is a newer release of diffusers than currently packaged with sd-scripts
-# This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
-
-
-def get_scheduler_fix(
-    name: Union[str, SchedulerType],
-    optimizer: Optimizer,
-    num_warmup_steps: Optional[int] = None,
-    num_training_steps: Optional[int] = None,
-    num_cycles: int = 1,
-    power: float = 1.0,
-):
-  """
-  Unified API to get any scheduler from its name.
-  Args:
-    name (`str` or `SchedulerType`):
-      The name of the scheduler to use.
-    optimizer (`torch.optim.Optimizer`):
-      The optimizer that will be used during training.
-    num_warmup_steps (`int`, *optional*):
-      The number of warmup steps to do. This is not required by all schedulers (hence the argument being
-      optional), the function will raise an error if it's unset and the scheduler type requires it.
-    num_training_steps (`int``, *optional*):
-      The number of training steps to do. This is not required by all schedulers (hence the argument being
-      optional), the function will raise an error if it's unset and the scheduler type requires it.
-    num_cycles (`int`, *optional*):
-      The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
-    power (`float`, *optional*, defaults to 1.0):
-      Power factor. See `POLYNOMIAL` scheduler
-    last_epoch (`int`, *optional*, defaults to -1):
-      The index of the last epoch when resuming training.
-  """
-  name = SchedulerType(name)
-  schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
-  if name == SchedulerType.CONSTANT:
-    return schedule_func(optimizer)
-
-  # All other schedulers require `num_warmup_steps`
-  if num_warmup_steps is None:
-    raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
-
-  if name == SchedulerType.CONSTANT_WITH_WARMUP:
-    return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
-
-  # All other schedulers require `num_training_steps`
-  if num_training_steps is None:
-    raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
-
-  if name == SchedulerType.COSINE_WITH_RESTARTS:
-    return schedule_func(
-      optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
-    )
-
-  if name == SchedulerType.POLYNOMIAL:
-    return schedule_func(
-      optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
-    )
-
-  return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)


 def train(args):
   session_id = random.randint(0, 2**32)
   training_started_at = time.time()
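Note: the get_scheduler_fix removed above is the same backported function now added to library/train_util.py, so this hunk is a pure move for sharing across scripts. The generate_step_logs helper kept in this file is consumed in the training loop roughly like this (a sketch using the loop's own variable names):

    if args.logging_dir is not None:
        logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
        accelerator.log(logs, step=global_step)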
@@ -164,7 +99,7 @@ def train(args):
   if args.lowram:
     text_encoder.to("cuda")
     unet.to("cuda")

   # モデルに xformers とか memory efficient attention を組み込む
   train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
@@ -226,8 +161,7 @@ def train(args):
     print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")

   # lr schedulerを用意する
-  # lr_scheduler = diffusers.optimization.get_scheduler(
-  lr_scheduler = get_scheduler_fix(
+  lr_scheduler = train_util.get_scheduler_fix(
     args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
     num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
     num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)

train_textual_inversion.py
@@ -199,7 +199,7 @@ def train(args):
   # 学習に必要なクラスを準備する
   print("prepare optimizer, data loader etc.")
   trainable_params = text_encoder.get_input_embeddings().parameters()
-  optimizer_name, optimizer = train_util.get_optimizer(args, trainable_params)
+  _, optimizer = train_util.get_optimizer(args, trainable_params)

   # dataloaderを準備する
   # DataLoaderのプロセス数:0はメインプロセスになる
@@ -213,8 +213,10 @@ def train(args):
     print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")

   # lr schedulerを用意する
-  lr_scheduler = diffusers.optimization.get_scheduler(
-    args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
+  lr_scheduler = train_util.get_scheduler_fix(
+    args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
+    num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+    num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)

   # acceleratorがなんかよろしくやってくれるらしい
   text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
@@ -356,6 +358,8 @@ def train(args):
       current_loss = loss.detach().item()
       if args.logging_dir is not None:
         logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
+        if args.optimizer_type == "DAdaptation".lower():  # tracking d*lr value
+          logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
         accelerator.log(logs, step=global_step)

       loss_total += current_loss
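Note on the attribute chain used by these logging hunks: after accelerator.prepare(...), the LR scheduler appears to be wrapped in Accelerate's AcceleratedScheduler, which keeps the prepared optimizer(s) in an .optimizers list; the d*lr lookup reaches through that wrapper. The single-model scripts index optimizers[0], while train_network.py's generate_step_logs uses optimizers[-1], matching its comment about tracking the unet value; the two indices coincide when only one optimizer is prepared.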