mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 21:52:27 +00:00
call optimizer eval/train fn before/after validation
This commit is contained in:
@@ -1381,6 +1381,8 @@ class NetworkTrainer:
|
||||
and global_step % args.validate_every_n_steps == 0
|
||||
)
|
||||
if accelerator.sync_gradients and validation_steps > 0 and should_validate_step:
|
||||
optimizer_eval_fn()
|
||||
|
||||
val_progress_bar = tqdm(
|
||||
range(validation_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="validation steps"
|
||||
)
|
||||
@@ -1429,6 +1431,8 @@ class NetworkTrainer:
|
||||
}
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
optimizer_train_fn()
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
@@ -1438,6 +1442,8 @@ class NetworkTrainer:
|
||||
)
|
||||
|
||||
if should_validate_epoch and len(val_dataloader) > 0:
|
||||
optimizer_eval_fn()
|
||||
|
||||
val_progress_bar = tqdm(
|
||||
range(validation_steps),
|
||||
smoothing=0,
|
||||
@@ -1493,6 +1499,8 @@ class NetworkTrainer:
|
||||
}
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
optimizer_train_fn()
|
||||
|
||||
# END OF EPOCH
|
||||
if is_tracking:
|
||||
logs = {"loss/epoch_average": loss_recorder.moving_average, "epoch": epoch + 1}
|
||||
|
||||
Reference in New Issue
Block a user