fix to work linear/cosine scheduler closes #1651 ref #1393

This commit is contained in:
Kohya S
2024-09-29 23:18:16 +09:00
parent 1567549220
commit 012e7e63a5

View File

@@ -4496,6 +4496,15 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
**lr_scheduler_kwargs,
)
# these schedulers do not require `num_decay_steps`
if name == SchedulerType.LINEAR or name == SchedulerType.COSINE:
return schedule_func(
optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
**lr_scheduler_kwargs,
)
# All other schedulers require `num_decay_steps`
if num_decay_steps is None:
raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.")