diff --git a/train_db.py b/train_db.py index c210767b..a3154cd1 100644 --- a/train_db.py +++ b/train_db.py @@ -206,6 +206,7 @@ def train(args): if accelerator.is_main_process: accelerator.init_trackers("dreambooth") + loss_list = [] for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") train_dataset.set_current_epoch(epoch + 1) @@ -216,7 +217,6 @@ def train(args): if args.gradient_checkpointing or global_step < args.stop_text_encoder_training: text_encoder.train() - loss_total = 0 for step, batch in enumerate(train_dataloader): # 指定したステップ数でText Encoderの学習を止める if global_step == args.stop_text_encoder_training: @@ -291,8 +291,11 @@ def train(args): logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]} accelerator.log(logs, step=global_step) - loss_total += current_loss - avr_loss = loss_total / (step+1) + if epoch == 0: + loss_list.append(current_loss) + else: + loss_list[step] = current_loss + avr_loss = sum(loss_list) / len(loss_list) logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) @@ -300,7 +303,7 @@ def train(args): break if args.logging_dir is not None: - logs = {"epoch_loss": loss_total / len(train_dataloader)} + logs = {"epoch_loss": sum(loss_list) / len(loss_list)} accelerator.log(logs, step=epoch+1) accelerator.wait_for_everyone() diff --git a/train_network.py b/train_network.py index bb3159fd..c9c3c468 100644 --- a/train_network.py +++ b/train_network.py @@ -378,6 +378,7 @@ def train(args): if accelerator.is_main_process: accelerator.init_trackers("network_train") + loss_list = [] for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") train_dataset.set_current_epoch(epoch + 1) @@ -386,7 +387,6 @@ def train(args): network.on_epoch_start(text_encoder, unet) - loss_total = 0 for step, batch in enumerate(train_dataloader): with accelerator.accumulate(network): with torch.no_grad(): @@ -446,8 +446,11 @@ def train(args): global_step += 1 current_loss = loss.detach().item() - loss_total += current_loss - avr_loss = loss_total / (step+1) + if epoch == 0: + loss_list.append(current_loss) + else: + loss_list[step] = current_loss + avr_loss = sum(loss_list) / len(loss_list) logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) @@ -459,7 +462,7 @@ def train(args): break if args.logging_dir is not None: - logs = {"loss/epoch": loss_total / len(train_dataloader)} + logs = {"loss/epoch": sum(loss_list) / len(loss_list)} accelerator.log(logs, step=epoch+1) accelerator.wait_for_everyone()