Mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-06 13:47:06 +00:00

Commit: Add Adafactor optimizer

fine_tune.py (17 changed lines)

@@ -149,7 +149,7 @@ def train(args):
 
   # prepare the classes needed for training
   print("prepare optimizer, data loader etc.")
-  _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
+  _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
 
   # prepare the dataloader
   # number of DataLoader workers: 0 means the main process
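
Note: get_optimizer (rewritten later in this commit, in train_util.py) now returns a three-tuple instead of a pair. A minimal sketch of the new calling convention, with illustrative return values:

    # Sketch of the new contract: (optimizer_name, optimizer_args, optimizer).
    # optimizer_name is the full class path; optimizer_args is a "k=v,..." summary
    # of --optimizer_args (empty string when none were given).
    optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
    # optimizer_name -> e.g. "torch.optim.AdamW"
    # optimizer_args -> e.g. "weight_decay=0.01"
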
@@ -163,10 +163,9 @@ def train(args):
     print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
   # prepare the lr scheduler
-  lr_scheduler = train_util.get_scheduler_fix(
-      args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
-      num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-      num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
+  lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
+                                              num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+                                              num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
 
   # experimental feature: fp16 training including gradients; makes the whole model fp16
   if args.full_fp16:
@@ -268,11 +267,11 @@ def train(args):
         loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
 
         accelerator.backward(loss)
-        if accelerator.sync_gradients:
+        if accelerator.sync_gradients and args.max_grad_norm != 0.0:
           params_to_clip = []
           for m in training_models:
            params_to_clip.extend(m.parameters())
-          accelerator.clip_grad_norm_(params_to_clip, 1.0)  # args.max_grad_norm)
+          accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
 
        optimizer.step()
        lr_scheduler.step()
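
Note: clipping is now skipped entirely when --max_grad_norm is 0, and the previously hard-coded norm of 1.0 is replaced by the flag's value. A hedged sketch of the resulting behavior (flag values are examples, not defaults):

    # --max_grad_norm 1.0  -> clip to norm 1.0 (the old hard-coded behavior)
    # --max_grad_norm 0    -> skip clip_grad_norm_ entirely (useful with Adafactor)
    if accelerator.sync_gradients and args.max_grad_norm != 0.0:
      accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
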
@@ -285,8 +284,8 @@ def train(args):
 
         current_loss = loss.detach().item()  # this is a mean, so batch size should not matter
         if args.logging_dir is not None:
-          logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
-          if args.optimizer_type == "DAdaptation".lower():  # tracking d*lr value
+          logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
+          if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value
             logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
           accelerator.log(logs, step=global_step)
 
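
For D-Adaptation the effective step size is the adapted d multiplied by the nominal lr, which is why the log reads both values from the wrapped optimizer's param group. A hedged sketch (assumes a DAdaptAdam optimizer behind an accelerate-wrapped scheduler, as in this diff):

    # d is D-Adaptation's adapted scale; lr is the user-supplied rate (typically 1.0).
    group = lr_scheduler.optimizers[0].param_groups[0]
    logs["lr/d*lr"] = group['d'] * group['lr']  # effective learning rate
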

train_util.py

@@ -1,6 +1,7 @@
 # common functions for training
 
 import argparse
+import importlib
 import json
 import shutil
 import time
@@ -21,6 +22,7 @@ import torch
 from torch.optim import Optimizer
 from torchvision import transforms
 from transformers import CLIPTokenizer
+import transformers
 import diffusers
 from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
 from diffusers import DDPMScheduler, StableDiffusionPipeline
@@ -1371,28 +1373,29 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
 
 def add_optimizer_arguments(parser: argparse.ArgumentParser):
   parser.add_argument("--optimizer_type", type=str, default="AdamW",
-                      help="Optimizer to use / オプティマイザの種類: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation")
+                      help="Optimizer to use / オプティマイザの種類: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor")
 
   # backward compatibility
   parser.add_argument("--use_8bit_adam", action="store_true",
                       help="use 8bit AdamW optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
   parser.add_argument("--use_lion_optimizer", action="store_true",
                       help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)")
   # parser.add_argument("--use_dadaptation_optimizer", action="store_true",
   #                     help="use dadaptation optimizer (requires dadaptation) / dadaptaionオプティマイザを使う( dadaptation のインストールが必要)")
 
   parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
-  parser.add_argument("--optimizer_momentum", type=float, default=0.9,
-                      help="Momentum value for SGD optimizers")
-  parser.add_argument("--optimizer_weight_decay", type=float, default=0.01,
-                      help="Weight decay for optimizers")
-  parser.add_argument("--optimizer_beta1", type=float, default=0.9,
-                      help="beta1 parameter for Adam optimizers")
-  parser.add_argument("--optimizer_beta2", type=float, default=0.999,
-                      help="beta2 parameter for Adam optimizers")
+  parser.add_argument("--max_grad_norm", default=1.0, type=float,
+                      help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない")
+
+  parser.add_argument("--optimizer_args", type=str, default=None, nargs='*',
+                      help="additional arguments for optimizer (like \"weight_decay=0.01 betas=0.9,0.999 ...\") / オプティマイザの追加引数(例: \"weight_decay=0.01 betas=0.9,0.999 ...\")")
 
   parser.add_argument("--lr_scheduler", type=str, default="constant",
-                      help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
+                      help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup, adafactor")
   parser.add_argument("--lr_warmup_steps", type=int, default=0,
                       help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
+  parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
+                      help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
+  parser.add_argument("--lr_scheduler_power", type=float, default=1,
+                      help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")
 
 
 def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
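
For illustration (not part of the commit), the new flag accepts space-separated key=value tokens; a sketch assuming train_util.py is importable from the repository root:

    # Hypothetical usage of the new --optimizer_args flag (values are examples).
    import argparse
    import train_util

    parser = argparse.ArgumentParser()
    train_util.add_optimizer_arguments(parser)
    args = parser.parse_args(["--optimizer_args", "weight_decay=0.01", "betas=0.9,0.999"])
    # args.optimizer_args == ["weight_decay=0.01", "betas=0.9,0.999"]
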
@@ -1525,18 +1528,37 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser):
 
 
 def get_optimizer(args, trainable_params):
-  # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, dadaptation"
+  # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, Adafactor"
 
   optimizer_type = args.optimizer_type
   if args.use_8bit_adam:
     print(f"*** use_8bit_adam option is specified. optimizer_type is ignored / use_8bit_adamオプションが指定されているためoptimizer_typeは無視されます")
     optimizer_type = "AdamW8bit"
   elif args.use_lion_optimizer:
     print(f"*** use_lion_optimizer option is specified. optimizer_type is ignored / use_lion_optimizerオプションが指定されているためoptimizer_typeは無視されます")
     optimizer_type = "Lion"
   optimizer_type = optimizer_type.lower()
 
-  betas = (args.optimizer_beta1, args.optimizer_beta2)
-  weight_decay = args.optimizer_weight_decay
-  momentum = args.optimizer_momentum
+  # parse the extra arguments: only bool, float, and tuples of them are supported
+  optimizer_kwargs = {}
+  if args.optimizer_args is not None and len(args.optimizer_args) > 0:
+    for arg in args.optimizer_args:
+      key, value = arg.split('=')
+
+      value = value.split(",")
+      for i in range(len(value)):
+        if value[i].lower() == "true" or value[i].lower() == "false":
+          value[i] = (value[i].lower() == "true")
+        else:
+          value[i] = float(value[i])
+      if len(value) == 1:
+        value = value[0]
+      else:
+        value = tuple(value)
+
+      optimizer_kwargs[key] = value
+  print("optkwargs:", optimizer_kwargs)
 
   lr = args.learning_rate
 
   if optimizer_type == "AdamW8bit".lower():
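
The parsing loop above only understands bools, floats, and comma-separated tuples of them; a string-valued option would raise ValueError in float(). A standalone sketch of the same logic, for clarity (the helper name is hypothetical, not part of the commit):

    # Hypothetical standalone version of the --optimizer_args parsing above.
    def parse_optimizer_args(optimizer_args):
      optimizer_kwargs = {}
      for arg in optimizer_args:
        key, value = arg.split('=')
        value = value.split(",")
        for i in range(len(value)):
          if value[i].lower() in ("true", "false"):
            value[i] = (value[i].lower() == "true")
          else:
            value[i] = float(value[i])
        optimizer_kwargs[key] = value[0] if len(value) == 1 else tuple(value)
      return optimizer_kwargs

    print(parse_optimizer_args(["weight_decay=0.01", "betas=0.9,0.999", "relative_step=True"]))
    # -> {'weight_decay': 0.01, 'betas': (0.9, 0.999), 'relative_step': True}
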
@@ -1544,53 +1566,128 @@ def get_optimizer(args, trainable_params):
     try:
       import bitsandbytes as bnb
     except ImportError:
       raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
-    print(f"use 8-bit AdamW optimizer | betas: {betas}, Weight Decay: {weight_decay}")
+    print(f"use 8-bit AdamW optimizer | {optimizer_kwargs}")
     optimizer_class = bnb.optim.AdamW8bit
-    optimizer = optimizer_class(trainable_params, lr=lr, betas=betas, weight_decay=weight_decay)
+    optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
 
   elif optimizer_type == "SGDNesterov8bit".lower():
     try:
       import bitsandbytes as bnb
     except ImportError:
       raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
-    print(f"use 8-bit SGD with Nesterov optimizer | Momentum: {momentum}, Weight Decay: {weight_decay}")
+    print(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}")
+    if "momentum" not in optimizer_kwargs:
+      print(f"8-bit SGD with Nesterov must be with momentum, set momentum to 0.9 / 8-bit SGD with Nesterovはmomentum指定が必須のため0.9に設定します")
+      optimizer_kwargs["momentum"] = 0.9
+
     optimizer_class = bnb.optim.SGD8bit
-    optimizer = optimizer_class(trainable_params, lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=True)
+    optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)
 
   elif optimizer_type == "Lion".lower():
     try:
       import lion_pytorch
     except ImportError:
       raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
-    print(f"use Lion optimizer | betas: {betas}, Weight Decay: {weight_decay}")
+    print(f"use Lion optimizer | {optimizer_kwargs}")
     optimizer_class = lion_pytorch.Lion
-    optimizer = optimizer_class(trainable_params, lr=lr, betas=betas, weight_decay=weight_decay)
+    optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
 
   elif optimizer_type == "SGDNesterov".lower():
-    print(f"use SGD with Nesterov optimizer | Momentum: {momentum}, Weight Decay: {weight_decay}")
+    print(f"use SGD with Nesterov optimizer | {optimizer_kwargs}")
+    if "momentum" not in optimizer_kwargs:
+      print(f"SGD with Nesterov must be with momentum, set momentum to 0.9 / SGD with Nesterovはmomentum指定が必須のため0.9に設定します")
+      optimizer_kwargs["momentum"] = 0.9
+
     optimizer_class = torch.optim.SGD
-    optimizer = optimizer_class(trainable_params, lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=True)
+    optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)
 
   elif optimizer_type == "DAdaptation".lower():
     try:
       import dadaptation
     except ImportError:
       raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
-    print(f"use D-Adaptation Adam optimizer | betas: {betas}, Weight Decay: {weight_decay}")
-    optimizer_class = dadaptation.DAdaptAdam
-    if lr <= 0.1:
-      print(f'learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: {lr}')
-    optimizer = optimizer_class(trainable_params, lr=lr, betas=betas, weight_decay=weight_decay)
+    print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}")
+
+    min_lr = lr
+    if type(trainable_params) == list and type(trainable_params[0]) == dict:
+      for group in trainable_params:
+        min_lr = min(min_lr, group.get("lr", lr))
+
+    if min_lr <= 0.1:
+      print(
+          f'learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: {min_lr}')
+      print('recommend option: lr=1.0 / 推奨は1.0です')
+
+    optimizer_class = dadaptation.DAdaptAdam
+    optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
+  elif optimizer_type == "Adafactor".lower():
+    # check the arguments and correct them where needed
+    if "relative_step" not in optimizer_kwargs:
+      optimizer_kwargs["relative_step"] = True  # default
+    if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False):
+      print(f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします")
+      optimizer_kwargs["relative_step"] = True
+    print(f"use Adafactor optimizer | {optimizer_kwargs}")
+
+    if optimizer_kwargs["relative_step"]:
+      print(f"relative_step is true / relative_stepがtrueです")
+      if lr != 0.0:
+        print(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます")
+      args.learning_rate = None
+
+      # when trainable_params is a list of param groups, drop the per-group lr
+      if type(trainable_params) == list and type(trainable_params[0]) == dict:
+        has_group_lr = False
+        for group in trainable_params:
+          p = group.pop("lr", None)
+          has_group_lr = has_group_lr or (p is not None)
+
+        if has_group_lr:
+          # also clear the args values  TODO the dependency is inverted here, which is not ideal
+          print(f"unet_lr and text_encoder_lr are ignored / unet_lrとtext_encoder_lrは無視されます")
+          args.unet_lr = None
+          args.text_encoder_lr = None
+
+      if args.lr_scheduler != "adafactor":
+        print(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します")
+        args.lr_scheduler = f"adafactor:{lr}"  # a bit hacky
+
+      lr = None
+    else:
+      if args.max_grad_norm != 0.0:
+        print(f"because max_grad_norm is set, clip_grad_norm is enabled. consider set to 0 / max_grad_normが設定されているためclip_grad_normが有効になります。0に設定して無効にしたほうがいいかもしれません")
+      if args.lr_scheduler != "constant_with_warmup":
+        print(f"constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません")
+      if optimizer_kwargs.get("clip_threshold", 1.0) != 1.0:
+        print(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません")
 
+    optimizer_class = transformers.optimization.Adafactor
+    optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
+  elif optimizer_type == "AdamW".lower():
+    print(f"use AdamW optimizer | {optimizer_kwargs}")
+    optimizer_class = torch.optim.AdamW
+    optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
 
   else:
-    print(f"use AdamW optimizer | betas: {betas}, Weight Decay: {weight_decay}")
-    optimizer_class = torch.optim.AdamW
-    optimizer = optimizer_class(trainable_params, lr=lr, betas=betas, weight_decay=weight_decay)
+    # use an arbitrary optimizer
+    optimizer_type = args.optimizer_type  # the non-lowercased name (not ideal)
+    print(f"use {optimizer_type} | {optimizer_kwargs}")
+    if "." not in optimizer_type:
+      optimizer_module = torch.optim
+    else:
+      values = optimizer_type.split(".")
+      optimizer_module = importlib.import_module(".".join(values[:-1]))
+      optimizer_type = values[-1]
+
+    optimizer_class = getattr(optimizer_module, optimizer_type)
+    optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
 
   optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
+  optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()])
 
-  return optimizer_name, optimizer
+  return optimizer_name, optimizer_args, optimizer
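
Two usage patterns fall out of the new branches; a hedged, self-contained sketch (not part of the commit; parameter values are examples):

    import importlib
    import torch
    import transformers

    params = [torch.nn.Parameter(torch.zeros(4, 4))]  # stand-in trainable params

    # 1) Adafactor in its default relative_step mode: lr must be None, and the
    #    paired AdafactorSchedule reports the internally computed rate.
    optimizer = transformers.optimization.Adafactor(params, lr=None, relative_step=True, warmup_init=True)
    lr_scheduler = transformers.optimization.AdafactorSchedule(optimizer, initial_lr=1e-3)

    # 2) The fallback branch: any dotted --optimizer_type resolves via importlib.
    module = importlib.import_module("torch.optim")  # e.g. from "torch.optim.RMSprop"
    optimizer_class = getattr(module, "RMSprop")     # hypothetical choice
    optimizer = optimizer_class(params, lr=1e-4)
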
# Monkeypatch newer get_scheduler() function overriding current version of diffusers.optimizer.get_scheduler

@@ -1627,6 +1724,12 @@ def get_scheduler_fix(
         last_epoch (`int`, *optional*, defaults to -1):
             The index of the last epoch when resuming training.
   """
+  if name.startswith("adafactor"):
+    assert type(optimizer) == transformers.optimization.Adafactor, f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください"
+    initial_lr = float(name.split(':')[1])
+    # print("adafactor scheduler init lr", initial_lr)
+    return transformers.optimization.AdafactorSchedule(optimizer, initial_lr)
+
   name = SchedulerType(name)
   schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
   if name == SchedulerType.CONSTANT:
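
The scheduler name itself carries the initial lr; a hedged illustration of the round trip between get_optimizer and get_scheduler_fix:

    # get_optimizer sets args.lr_scheduler = f"adafactor:{lr}"; get_scheduler_fix
    # decodes it and hands the value to AdafactorSchedule.
    name = "adafactor:1e-3"                 # example value written by get_optimizer
    initial_lr = float(name.split(':')[1])  # -> 0.001
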
@@ -1744,13 +1847,19 @@ def prepare_dtype(args: argparse.Namespace):
 
 
 def load_target_model(args: argparse.Namespace, weight_dtype):
-  load_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path)  # determine SD or Diffusers
+  name_or_path = args.pretrained_model_name_or_path
+  name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
+  load_stable_diffusion_format = os.path.isfile(name_or_path)  # determine SD or Diffusers
   if load_stable_diffusion_format:
     print("load StableDiffusion checkpoint")
-    text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path)
+    text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path)
   else:
     print("load Diffusers pretrained models")
-    pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None)
+    try:
+      pipe = StableDiffusionPipeline.from_pretrained(name_or_path, tokenizer=None, safety_checker=None)
+    except EnvironmentError as ex:
+      print(
+          f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}")
     text_encoder = pipe.text_encoder
     vae = pipe.vae
     unet = pipe.unet
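
A minimal sketch of the new resolution step: a symlinked checkpoint path is dereferenced before the SD-vs-Diffusers decision, so os.path.isfile() sees the link target (the path below is hypothetical):

    import os

    name_or_path = "models/current.ckpt"  # may be a symlink
    name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
    load_stable_diffusion_format = os.path.isfile(name_or_path)  # True -> single SD checkpoint
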

train_db.py (17 changed lines)

@@ -120,7 +120,7 @@ def train(args):
   else:
     trainable_params = unet.parameters()
 
-  _, optimizer = train_util.get_optimizer(args, trainable_params)
+  _, _, optimizer = train_util.get_optimizer(args, trainable_params)
 
   # prepare the dataloader
   # number of DataLoader workers: 0 means the main process
@@ -137,10 +137,9 @@ def train(args):
     args.stop_text_encoder_training = args.max_train_steps + 1  # do not stop until end
 
   # prepare the lr scheduler  TODO the handling of gradient_accumulation_steps may be off somehow; check later
-  lr_scheduler = train_util.get_scheduler_fix(
-      args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
-      num_training_steps=args.max_train_steps,
-      num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
+  lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
+                                              num_training_steps=args.max_train_steps,
+                                              num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
 
   # experimental feature: fp16 training including gradients; makes the whole model fp16
   if args.full_fp16:
@@ -263,12 +262,12 @@ def train(args):
         loss = loss.mean()  # already a mean, so no need to divide by batch_size
 
         accelerator.backward(loss)
-        if accelerator.sync_gradients:
+        if accelerator.sync_gradients and args.max_grad_norm != 0.0:
           if train_text_encoder:
             params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters()))
           else:
            params_to_clip = unet.parameters()
-          accelerator.clip_grad_norm_(params_to_clip, 1.0)  # args.max_grad_norm)
+          accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
 
        optimizer.step()
        lr_scheduler.step()
@@ -281,8 +280,8 @@ def train(args):
 
         current_loss = loss.detach().item()
         if args.logging_dir is not None:
-          logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
-          if args.optimizer_type == "DAdaptation".lower():  # tracking d*lr value
+          logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
+          if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value
             logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
           accelerator.log(logs, step=global_step)
 

train_network.py

@@ -28,14 +28,14 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
   logs = {"loss/current": current_loss, "loss/average": avr_loss}
 
   if args.network_train_unet_only:
-    logs["lr/unet"] = lr_scheduler.get_last_lr()[0]
+    logs["lr/unet"] = float(lr_scheduler.get_last_lr()[0])
   elif args.network_train_text_encoder_only:
-    logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
+    logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
   else:
-    logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
-    logs["lr/unet"] = lr_scheduler.get_last_lr()[-1]  # may be the same as textencoder
+    logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
+    logs["lr/unet"] = float(lr_scheduler.get_last_lr()[-1])  # may be the same as textencoder
 
-  if args.optimizer_type == "DAdaptation".lower():  # tracking d*lr value of unet.
+  if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value of unet.
     logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]['d']*lr_scheduler.optimizers[-1].param_groups[0]['lr']
 
   return logs
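
A hedged note on the float() casts: schedulers such as transformers' AdafactorSchedule can return tensor-valued learning rates, and casting keeps the logged values plain scalars:

    lr_value = lr_scheduler.get_last_lr()[0]  # may be a 0-d tensor under AdafactorSchedule
    logs["lr/unet"] = float(lr_value)         # always a Python float for the logger
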
@@ -147,7 +147,7 @@ def train(args):
   print("prepare optimizer, data loader etc.")
 
   trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
-  optimizer_name, optimizer = train_util.get_optimizer(args, trainable_params)
+  optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
 
   # prepare the dataloader
   # number of DataLoader workers: 0 means the main process
@@ -161,10 +161,9 @@ def train(args):
     print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
   # prepare the lr scheduler
-  lr_scheduler = train_util.get_scheduler_fix(
-      args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
-      num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-      num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
+  lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
+                                              num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+                                              num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
 
   # experimental feature: fp16 training including gradients; makes the whole model fp16
   if args.full_fp16:
@@ -287,7 +286,7 @@ def train(args):
         "ss_bucket_info": json.dumps(train_dataset.bucket_info),
         "ss_training_comment": args.training_comment,  # will not be updated after training
         "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
-        "ss_optimizer": optimizer_name
+        "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else "")
     }
 
     # uncomment if another network is added
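
An illustrative (hypothetical) value of the extended metadata entry:

    optimizer_name = "transformers.optimization.Adafactor"
    optimizer_args = "relative_step=True,scale_parameter=True"
    ss_optimizer = optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else "")
    # -> "transformers.optimization.Adafactor(relative_step=True,scale_parameter=True)"
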
@@ -380,9 +379,9 @@ def train(args):
         loss = loss.mean()  # already a mean, so no need to divide by batch_size
 
         accelerator.backward(loss)
-        if accelerator.sync_gradients:
+        if accelerator.sync_gradients and args.max_grad_norm != 0.0:
           params_to_clip = network.get_trainable_params()
-          accelerator.clip_grad_norm_(params_to_clip, 1.0)  # args.max_grad_norm)
+          accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
 
        optimizer.step()
        lr_scheduler.step()
@@ -478,10 +477,6 @@ if __name__ == '__main__':
 
   parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
   parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
-  parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
-                      help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
-  parser.add_argument("--lr_scheduler_power", type=float, default=1,
-                      help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")
 
   parser.add_argument("--network_weights", type=str, default=None,
                       help="pretrained weights for network / 学習するネットワークの初期重み")

train_textual_inversion.py

@@ -199,7 +199,7 @@ def train(args):
   # prepare the classes needed for training
   print("prepare optimizer, data loader etc.")
   trainable_params = text_encoder.get_input_embeddings().parameters()
-  _, optimizer = train_util.get_optimizer(args, trainable_params)
+  _, _, optimizer = train_util.get_optimizer(args, trainable_params)
 
   # prepare the dataloader
   # number of DataLoader workers: 0 means the main process
@@ -213,10 +213,9 @@ def train(args):
     print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
   # prepare the lr scheduler
-  lr_scheduler = train_util.get_scheduler_fix(
-      args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
-      num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-      num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
+  lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
+                                              num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+                                              num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
 
   # accelerator apparently handles this for us
   text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
@@ -338,9 +337,9 @@ def train(args):
         loss = loss.mean()  # already a mean, so no need to divide by batch_size
 
         accelerator.backward(loss)
-        if accelerator.sync_gradients:
+        if accelerator.sync_gradients and args.max_grad_norm != 0.0:
           params_to_clip = text_encoder.get_input_embeddings().parameters()
-          accelerator.clip_grad_norm_(params_to_clip, 1.0)  # args.max_grad_norm)
+          accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
 
        optimizer.step()
        lr_scheduler.step()
@@ -357,8 +356,8 @@ def train(args):
 
         current_loss = loss.detach().item()
         if args.logging_dir is not None:
-          logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
-          if args.optimizer_type == "DAdaptation".lower():  # tracking d*lr value
+          logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
+          if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value
             logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
           accelerator.log(logs, step=global_step)
 