diff --git a/library/train_util.py b/library/train_util.py index 422dceca..27910dc9 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4496,6 +4496,15 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): **lr_scheduler_kwargs, ) + # these schedulers do not require `num_decay_steps` + if name == SchedulerType.LINEAR or name == SchedulerType.COSINE: + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + **lr_scheduler_kwargs, + ) + # All other schedulers require `num_decay_steps` if num_decay_steps is None: raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.")