add lr_scheduler_type etc

This commit is contained in:
Isotr0py
2023-03-09 16:51:22 +08:00
parent 8d5ba29363
commit eb68892ab1
5 changed files with 49 additions and 20 deletions

View File

@@ -178,9 +178,7 @@ def train(args):
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# lr schedulerを用意する
lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
lr_scheduler = train_util.get_scheduler_fix(args, optimizer)
# 実験的機能勾配も含めたfp16学習を行う モデル全体をfp16にする
if args.full_fp16:

View File

@@ -1518,6 +1518,11 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
parser.add_argument("--optimizer_args", type=str, default=None, nargs='*',
help="additional arguments for optimizer (like \"weight_decay=0.01 betas=0.9,0.999 ...\") / オプティマイザの追加引数(例: \"weight_decay=0.01 betas=0.9,0.999 ...\"")
parser.add_argument("--lr_scheduler_type", type=str, default="",
help="custom scheduler module")
parser.add_argument("--lr_scheduler_args", type=str, default=None, nargs='*',
help="additional arguments for scheduler (like \"weight_decay=0.01 betas=0.9,0.999 ...\") / スケジューラの追加引数(例: \"weight_decay=0.01 betas=0.9,0.999 ...\"")
parser.add_argument("--lr_scheduler", type=str, default="constant",
help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup, adafactor")
parser.add_argument("--lr_warmup_steps", type=int, default=0,
@@ -1843,14 +1848,7 @@ def get_optimizer(args, trainable_params):
# This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
def get_scheduler_fix(
name: Union[str, SchedulerType],
optimizer: Optimizer,
num_warmup_steps: Optional[int] = None,
num_training_steps: Optional[int] = None,
num_cycles: int = 1,
power: float = 1.0,
):
def get_scheduler_fix(args,optimizer: Optimizer):
"""
Unified API to get any scheduler from its name.
Args:
@@ -1871,6 +1869,45 @@ def get_scheduler_fix(
last_epoch (`int`, *optional*, defaults to -1):
The index of the last epoch when resuming training.
"""
name = args.lr_scheduler
num_warmup_steps = args.lr_warmup_steps
num_training_steps = args.max_train_steps * args.gradient_accumulation_steps
num_cycles = args.lr_scheduler_num_cycles
power = args.lr_scheduler_power
lr_scheduler_kwargs = {} # get custom lr_scheduler kwargs
if args.lr_scheduler_args is not None and len(args.optimizer_args) > 0:
for arg in args.lr_scheduler_args:
key, value = arg.split('=')
value = value.split(",")
for i in range(len(value)):
if value[i].lower() == "true" or value[i].lower() == "false":
value[i] = (value[i].lower() == "true")
else:
value[i] = float(value[i])
if len(value) == 1:
value = value[0]
else:
value = tuple(value)
lr_scheduler_kwargs[key] = value
# using any lr_scheduler from other library
if args.lr_scheduler_type:
lr_scheduler_type = args.lr_scheduler_type
print(f"use {lr_scheduler_type} | {lr_scheduler_kwargs} as lr_scheduler")
if "." not in lr_scheduler_type: # default to use torch.optim
lr_scheduler_module = torch.optim.lr_scheduler
else:
values = lr_scheduler_type.split(".")
lr_scheduler_module = importlib.import_module(".".join(values[:-1]))
lr_scheduler_type = values[-1]
lr_scheduler_class = getattr(lr_scheduler_module, lr_scheduler_type)
lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_kwargs)
return lr_scheduler
if name.startswith("adafactor"):
assert type(optimizer) == transformers.optimization.Adafactor, f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください"
initial_lr = float(name.split(':')[1])

View File

@@ -150,9 +150,7 @@ def train(args):
args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
# lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する
lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
num_training_steps=args.max_train_steps,
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
lr_scheduler = train_util.get_scheduler_fix(args, optimizer)
# 実験的機能勾配も含めたfp16学習を行う モデル全体をfp16にする
if args.full_fp16:

View File

@@ -179,9 +179,7 @@ def train(args):
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# lr schedulerを用意する
lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
lr_scheduler = train_util.get_scheduler_fix(args, optimizer)
# 実験的機能勾配も含めたfp16学習を行う モデル全体をfp16にする
if args.full_fp16:

View File

@@ -235,9 +235,7 @@ def train(args):
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# lr schedulerを用意する
lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
lr_scheduler = train_util.get_scheduler_fix(args, optimizer)
# acceleratorがなんかよろしくやってくれるらしい
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(