mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 21:52:27 +00:00
Add LossRecorder and use moving average in all places
This commit is contained in:
14
train_db.py
14
train_db.py
@@ -264,8 +264,7 @@ def train(args):
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
|
||||
|
||||
loss_list = []
|
||||
loss_total = 0.0
|
||||
loss_recorder = train_util.LossRecorder()
|
||||
for epoch in range(num_train_epochs):
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
current_epoch.value = epoch + 1
|
||||
@@ -392,13 +391,8 @@ def train(args):
|
||||
)
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
if epoch == 0:
|
||||
loss_list.append(current_loss)
|
||||
else:
|
||||
loss_total -= loss_list[step]
|
||||
loss_list[step] = current_loss
|
||||
loss_total += current_loss
|
||||
avr_loss = loss_total / len(loss_list)
|
||||
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
||||
avr_loss: float = loss_recorder.get_moving_average()
|
||||
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
@@ -406,7 +400,7 @@ def train(args):
|
||||
break
|
||||
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss/epoch": loss_total / len(loss_list)}
|
||||
logs = {"loss/epoch": loss_recorder.get_moving_average()}
|
||||
accelerator.log(logs, step=epoch + 1)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
Reference in New Issue
Block a user