mirror of https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
add lr_scheduler_type etc
@@ -178,9 +178,7 @@ def train(args):
     print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")

   # lr schedulerを用意する
-  lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
-                                              num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-                                              num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
+  lr_scheduler = train_util.get_scheduler_fix(args, optimizer)

   # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
   if args.full_fp16:
@@ -1518,6 +1518,11 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
   parser.add_argument("--optimizer_args", type=str, default=None, nargs='*',
                       help="additional arguments for optimizer (like \"weight_decay=0.01 betas=0.9,0.999 ...\") / オプティマイザの追加引数(例: \"weight_decay=0.01 betas=0.9,0.999 ...\")")

+  parser.add_argument("--lr_scheduler_type", type=str, default="",
+                      help="custom scheduler module")
+  parser.add_argument("--lr_scheduler_args", type=str, default=None, nargs='*',
+                      help="additional arguments for scheduler (like \"weight_decay=0.01 betas=0.9,0.999 ...\") / スケジューラの追加引数(例: \"weight_decay=0.01 betas=0.9,0.999 ...\")")
+
   parser.add_argument("--lr_scheduler", type=str, default="constant",
                       help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup, adafactor")
   parser.add_argument("--lr_warmup_steps", type=int, default=0,
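The two new flags feed the reworked get_scheduler_fix below: --lr_scheduler_type selects a scheduler class by name and --lr_scheduler_args supplies its keyword arguments as key=value pairs. A minimal sketch of how they compose, assuming the stock torch.optim.lr_scheduler.CosineAnnealingLR (the flag values are illustrative, not part of the commit):

    import torch

    # a throwaway optimizer so the sketch runs standalone
    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.AdamW(params, lr=1e-4)

    # hypothetical CLI values (illustrative only):
    #   --lr_scheduler_type CosineAnnealingLR --lr_scheduler_args "T_max=1000" "eta_min=1e-6"
    # roughly what the new code path in get_scheduler_fix resolves them to:
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000.0, eta_min=1e-6)

Note that the argument parsing below coerces every non-boolean value to float, so T_max arrives as 1000.0 rather than 1000; CosineAnnealingLR tolerates that, but schedulers that insist on integer parameters may not.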
@@ -1843,14 +1848,7 @@ def get_optimizer(args, trainable_params):
 # This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts


-def get_scheduler_fix(
-    name: Union[str, SchedulerType],
-    optimizer: Optimizer,
-    num_warmup_steps: Optional[int] = None,
-    num_training_steps: Optional[int] = None,
-    num_cycles: int = 1,
-    power: float = 1.0,
-):
+def get_scheduler_fix(args, optimizer: Optimizer):
   """
   Unified API to get any scheduler from its name.
   Args:
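With the simplified signature, every scheduler hyperparameter is read from args inside the function. A minimal sketch of the attributes the new body expects, with placeholder values (the attribute names come from the hunk below; the numbers are illustrative only, and the call assumes a training script where train_util is imported and optimizer already exists):

    from argparse import Namespace

    args = Namespace(
        lr_scheduler="cosine_with_restarts",   # one of the built-in names from --lr_scheduler
        lr_warmup_steps=100,
        max_train_steps=1600,
        gradient_accumulation_steps=1,
        lr_scheduler_num_cycles=1,
        lr_scheduler_power=1.0,
        lr_scheduler_type="",                  # empty -> stay on the built-in schedulers
        lr_scheduler_args=None,                # no extra key=value overrides
    )
    lr_scheduler = train_util.get_scheduler_fix(args, optimizer)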
@@ -1871,6 +1869,45 @@ def get_scheduler_fix(
     last_epoch (`int`, *optional*, defaults to -1):
       The index of the last epoch when resuming training.
   """
+
+  name = args.lr_scheduler
+  num_warmup_steps = args.lr_warmup_steps
+  num_training_steps = args.max_train_steps * args.gradient_accumulation_steps
+  num_cycles = args.lr_scheduler_num_cycles
+  power = args.lr_scheduler_power
+
+  lr_scheduler_kwargs = {}                            # get custom lr_scheduler kwargs
+  if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0:
+    for arg in args.lr_scheduler_args:
+      key, value = arg.split('=')
+
+      value = value.split(",")
+      for i in range(len(value)):
+        if value[i].lower() == "true" or value[i].lower() == "false":
+          value[i] = (value[i].lower() == "true")
+        else:
+          value[i] = float(value[i])
+      if len(value) == 1:
+        value = value[0]
+      else:
+        value = tuple(value)
+
+      lr_scheduler_kwargs[key] = value
+
+  # using any lr_scheduler from other library
+  if args.lr_scheduler_type:
+    lr_scheduler_type = args.lr_scheduler_type
+    print(f"use {lr_scheduler_type} | {lr_scheduler_kwargs} as lr_scheduler")
+    if "." not in lr_scheduler_type:                  # default to use torch.optim
+      lr_scheduler_module = torch.optim.lr_scheduler
+    else:
+      values = lr_scheduler_type.split(".")
+      lr_scheduler_module = importlib.import_module(".".join(values[:-1]))
+      lr_scheduler_type = values[-1]
+    lr_scheduler_class = getattr(lr_scheduler_module, lr_scheduler_type)
+    lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_kwargs)
+    return lr_scheduler
+
   if name.startswith("adafactor"):
     assert type(optimizer) == transformers.optimization.Adafactor, f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください"
     initial_lr = float(name.split(':')[1])
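The kwargs block above turns each "key=value" token into a Python value: comma-separated items become a tuple, "true"/"false" become booleans, and everything else is coerced to float. A standalone sketch of that behaviour (the helper name is hypothetical, not part of the commit):

    def parse_scheduler_arg(arg: str):
        # mirrors the per-argument loop body of get_scheduler_fix above
        key, value = arg.split('=')
        items = value.split(",")
        for i, item in enumerate(items):
            if item.lower() in ("true", "false"):
                items[i] = (item.lower() == "true")      # booleans stay booleans
            else:
                items[i] = float(item)                   # everything else becomes float
        return key, items[0] if len(items) == 1 else tuple(items)

    print(parse_scheduler_arg("T_max=1000"))             # ('T_max', 1000.0)
    print(parse_scheduler_arg("betas=0.9,0.999"))        # ('betas', (0.9, 0.999))
    print(parse_scheduler_arg("verbose=true"))           # ('verbose', True)

When --lr_scheduler_type contains a dot, the prefix is imported with importlib and the last component is looked up with getattr, so scheduler classes from other libraries can be plugged in without further changes; a bare name falls back to torch.optim.lr_scheduler.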
@@ -150,9 +150,7 @@ def train(args):
     args.stop_text_encoder_training = args.max_train_steps + 1  # do not stop until end

   # lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する
-  lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
-                                              num_training_steps=args.max_train_steps,
-                                              num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
+  lr_scheduler = train_util.get_scheduler_fix(args, optimizer)

   # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
   if args.full_fp16:
@@ -179,9 +179,7 @@ def train(args):
     print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")

   # lr schedulerを用意する
-  lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
-                                              num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-                                              num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
+  lr_scheduler = train_util.get_scheduler_fix(args, optimizer)

   # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
   if args.full_fp16:
@@ -235,9 +235,7 @@ def train(args):
     print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")

   # lr schedulerを用意する
-  lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
-                                              num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-                                              num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
+  lr_scheduler = train_util.get_scheduler_fix(args, optimizer)

   # acceleratorがなんかよろしくやってくれるらしい
   text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(