From eb68892ab11b3af4ea6121d0f361b49aaa8d1724 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Thu, 9 Mar 2023 16:51:22 +0800 Subject: [PATCH] add lr_scheduler_type etc --- fine_tune.py | 4 +-- library/train_util.py | 53 ++++++++++++++++++++++++++++++++------ train_db.py | 4 +-- train_network.py | 4 +-- train_textual_inversion.py | 4 +-- 5 files changed, 49 insertions(+), 20 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 12557597..89bc1aa6 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -178,9 +178,7 @@ def train(args): print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") # lr schedulerを用意する - lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, - num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, - num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power) + lr_scheduler = train_util.get_scheduler_fix(args, optimizer) # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする if args.full_fp16: diff --git a/library/train_util.py b/library/train_util.py index e15ce133..d912a45f 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1518,6 +1518,11 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): parser.add_argument("--optimizer_args", type=str, default=None, nargs='*', help="additional arguments for optimizer (like \"weight_decay=0.01 betas=0.9,0.999 ...\") / オプティマイザの追加引数(例: \"weight_decay=0.01 betas=0.9,0.999 ...\")") + parser.add_argument("--lr_scheduler_type", type=str, default="", + help="custom scheduler module") + parser.add_argument("--lr_scheduler_args", type=str, default=None, nargs='*', + help="additional arguments for scheduler (like \"weight_decay=0.01 betas=0.9,0.999 ...\") / スケジューラの追加引数(例: \"weight_decay=0.01 betas=0.9,0.999 ...\")") + parser.add_argument("--lr_scheduler", type=str, default="constant", help="scheduler to use for learning rate / 学習率のスケジューラ: 
linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup, adafactor") parser.add_argument("--lr_warmup_steps", type=int, default=0, @@ -1843,14 +1848,7 @@ def get_optimizer(args, trainable_params): # This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts -def get_scheduler_fix( - name: Union[str, SchedulerType], - optimizer: Optimizer, - num_warmup_steps: Optional[int] = None, - num_training_steps: Optional[int] = None, - num_cycles: int = 1, - power: float = 1.0, -): +def get_scheduler_fix(args, optimizer: Optimizer): """ Unified API to get any scheduler from its name. Args: @@ -1871,6 +1869,45 @@ def get_scheduler_fix( last_epoch (`int`, *optional*, defaults to -1): The index of the last epoch when resuming training. """ + + name = args.lr_scheduler + num_warmup_steps = args.lr_warmup_steps + num_training_steps = args.max_train_steps * args.gradient_accumulation_steps + num_cycles = args.lr_scheduler_num_cycles + power = args.lr_scheduler_power + + lr_scheduler_kwargs = {} # get custom lr_scheduler kwargs + if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0: + for arg in args.lr_scheduler_args: + key, value = arg.split('=') + + value = value.split(",") + for i in range(len(value)): + if value[i].lower() == "true" or value[i].lower() == "false": + value[i] = (value[i].lower() == "true") + else: + value[i] = float(value[i]) + if len(value) == 1: + value = value[0] + else: + value = tuple(value) + + lr_scheduler_kwargs[key] = value + + # using any lr_scheduler from other library + if args.lr_scheduler_type: + lr_scheduler_type = args.lr_scheduler_type + print(f"use {lr_scheduler_type} | {lr_scheduler_kwargs} as lr_scheduler") + if "." 
not in lr_scheduler_type: # default to use torch.optim + lr_scheduler_module = torch.optim.lr_scheduler + else: + values = lr_scheduler_type.split(".") + lr_scheduler_module = importlib.import_module(".".join(values[:-1])) + lr_scheduler_type = values[-1] + lr_scheduler_class = getattr(lr_scheduler_module, lr_scheduler_type) + lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_kwargs) + return lr_scheduler + if name.startswith("adafactor"): assert type(optimizer) == transformers.optimization.Adafactor, f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください" initial_lr = float(name.split(':')[1]) diff --git a/train_db.py b/train_db.py index a3021177..f246bd0c 100644 --- a/train_db.py +++ b/train_db.py @@ -150,9 +150,7 @@ def train(args): args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end # lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する - lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, - num_training_steps=args.max_train_steps, - num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power) + lr_scheduler = train_util.get_scheduler_fix(args, optimizer) # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする if args.full_fp16: diff --git a/train_network.py b/train_network.py index ef5a0831..cfc0b15a 100644 --- a/train_network.py +++ b/train_network.py @@ -179,9 +179,7 @@ def train(args): print(f"override steps. 
steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") # lr schedulerを用意する - lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, - num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, - num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power) + lr_scheduler = train_util.get_scheduler_fix(args, optimizer) # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする if args.full_fp16: diff --git a/train_textual_inversion.py b/train_textual_inversion.py index d91a78ff..9f789517 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -235,9 +235,7 @@ def train(args): print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") # lr schedulerを用意する - lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, - num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, - num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power) + lr_scheduler = train_util.get_scheduler_fix(args, optimizer) # acceleratorがなんかよろしくやってくれるらしい text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(