Show the moving average loss

Yuta Hayashibe
2023-02-14 19:46:27 +09:00
parent b32abdd327
commit 21f5b618c3
2 changed files with 14 additions and 8 deletions

View File

@@ -206,6 +206,7 @@ def train(args):
   if accelerator.is_main_process:
     accelerator.init_trackers("dreambooth")

+  loss_list = []
   for epoch in range(num_train_epochs):
     print(f"epoch {epoch+1}/{num_train_epochs}")
     train_dataset.set_current_epoch(epoch + 1)
@@ -216,7 +217,6 @@ def train(args):
     if args.gradient_checkpointing or global_step < args.stop_text_encoder_training:
       text_encoder.train()

-    loss_total = 0
     for step, batch in enumerate(train_dataloader):
       # stop training the Text Encoder at the specified number of steps
       if global_step == args.stop_text_encoder_training:
@@ -291,8 +291,11 @@ def train(args):
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
accelerator.log(logs, step=global_step)
loss_total += current_loss
avr_loss = loss_total / (step+1)
if epoch == 0:
loss_list.append(current_loss)
else:
loss_list[step] = current_loss
avr_loss = sum(loss_list) / len(loss_list)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
@@ -300,7 +303,7 @@ def train(args):
         break

     if args.logging_dir is not None:
-      logs = {"epoch_loss": loss_total / len(train_dataloader)}
+      logs = {"epoch_loss": sum(loss_list) / len(loss_list)}
       accelerator.log(logs, step=epoch+1)

     accelerator.wait_for_everyone()
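Both files make the same change: rather than averaging `loss_total` over only the steps seen so far in the current epoch, the scripts keep a `loss_list` holding the most recent loss for each step position, so the progress bar shows a moving average over (up to) one full epoch of steps. A minimal standalone sketch of that logic follows; `demo_moving_average` and the simulated inputs are illustrative, not part of the patch.

```python
# Minimal sketch of the moving-average logic added by this commit.
# The real scripts read `current_loss = loss.detach().item()` per step;
# here we feed in made-up per-step losses instead.
def demo_moving_average(losses_per_epoch):
    loss_list = []
    for epoch, epoch_losses in enumerate(losses_per_epoch):
        for step, current_loss in enumerate(epoch_losses):
            if epoch == 0:
                loss_list.append(current_loss)   # window is still filling
            else:
                loss_list[step] = current_loss   # overwrite the value from one epoch ago
            avr_loss = sum(loss_list) / len(loss_list)
            print(f"epoch {epoch+1} step {step+1}: avr_loss={avr_loss:.4f}")

# Two epochs of three steps each; losses trend downward.
demo_moving_average([[1.0, 0.8, 0.9], [0.7, 0.6, 0.5]])
```

During the first epoch the average covers only the steps seen so far; from the second epoch on it always spans exactly `len(train_dataloader)` values.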

View File

@@ -378,6 +378,7 @@ def train(args):
   if accelerator.is_main_process:
     accelerator.init_trackers("network_train")

+  loss_list = []
   for epoch in range(num_train_epochs):
     print(f"epoch {epoch+1}/{num_train_epochs}")
     train_dataset.set_current_epoch(epoch + 1)
@@ -386,7 +387,6 @@ def train(args):
     network.on_epoch_start(text_encoder, unet)

-    loss_total = 0
     for step, batch in enumerate(train_dataloader):
       with accelerator.accumulate(network):
         with torch.no_grad():
@@ -446,8 +446,11 @@ def train(args):
         global_step += 1

       current_loss = loss.detach().item()
-      loss_total += current_loss
-      avr_loss = loss_total / (step+1)
+      if epoch == 0:
+        loss_list.append(current_loss)
+      else:
+        loss_list[step] = current_loss
+      avr_loss = sum(loss_list) / len(loss_list)
       logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
       progress_bar.set_postfix(**logs)
@@ -459,7 +462,7 @@ def train(args):
         break

     if args.logging_dir is not None:
-      logs = {"loss/epoch": loss_total / len(train_dataloader)}
+      logs = {"loss/epoch": sum(loss_list) / len(loss_list)}
       accelerator.log(logs, step=epoch+1)

     accelerator.wait_for_everyone()
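The epoch-level logs change accordingly: `loss_total / len(train_dataloader)` becomes `sum(loss_list) / len(loss_list)`, which gives the same value at the end of a complete epoch but stays well defined if training stops mid-epoch via `args.max_train_steps`. One cost of the new per-step code is that `sum(loss_list)` is recomputed on every step, i.e. O(steps per epoch). A hypothetical O(1) refinement, not part of this commit, would maintain a running total:

```python
# Hypothetical O(1) variant (NOT in this commit): keep a running total so the
# moving average needs no per-step sum() over the whole window.
loss_list: list[float] = []
loss_total = 0.0

def update_moving_average(epoch: int, step: int, current_loss: float) -> float:
    global loss_total
    if epoch == 0:
        loss_list.append(current_loss)
        loss_total += current_loss
    else:
        # Swap the stale value recorded one epoch ago out of the total.
        loss_total += current_loss - loss_list[step]
        loss_list[step] = current_loss
    return loss_total / len(loss_list)
```

The trade-off is that floating-point error can slowly accumulate in `loss_total` over long runs, which the straightforward re-summation in the patch avoids.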