diff --git a/fine_tune.py b/fine_tune.py
index 2ecb4ff3..0de4aff1 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -295,7 +295,7 @@ def train(args):
         for m in training_models:
             m.train()

-        loss_total = 0
+        loss_recorder = train_util.LossRecorder()
         for step, batch in enumerate(train_dataloader):
             current_step.value = global_step
             with accelerator.accumulate(training_models[0]):  # it seems multiple models aren't supported, but do it this way for now
@@ -405,9 +405,8 @@ def train(args):
                 )
                 accelerator.log(logs, step=global_step)

-            # TODO: change this to a moving average
-            loss_total += current_loss
-            avr_loss = loss_total / (step + 1)
+            loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
+            avr_loss: float = loss_recorder.get_moving_average()
             logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)

@@ -415,7 +414,7 @@ def train(args):
                 break

         if args.logging_dir is not None:
-            logs = {"loss/epoch": loss_total / len(train_dataloader)}
+            logs = {"loss/epoch": loss_recorder.get_moving_average()}
             accelerator.log(logs, step=epoch + 1)

         accelerator.wait_for_everyone()
diff --git a/library/train_util.py b/library/train_util.py
index 51610e70..7f7190b3 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -4685,3 +4685,20 @@ class collator_class:
         dataset.set_current_epoch(self.current_epoch.value)
         dataset.set_current_step(self.current_step.value)
         return examples[0]
+
+
+class LossRecorder:
+    def __init__(self):
+        self.loss_list: List[float] = []
+        self.loss_total: float = 0.0
+
+    def add(self, *, epoch: int, step: int, loss: float) -> None:
+        if epoch == 0:
+            self.loss_list.append(loss)
+        else:
+            self.loss_total -= self.loss_list[step]
+            self.loss_list[step] = loss
+        self.loss_total += loss
+
+    def get_moving_average(self) -> float:
+        return self.loss_total / len(self.loss_list)
diff --git a/sdxl_train.py b/sdxl_train.py
index 7bde3cab..5e5d528d 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -459,7 +459,7 @@ def train(args):
         for m in training_models:
             m.train()

-        loss_total = 0
+        loss_recorder = train_util.LossRecorder()
         for step, batch in enumerate(train_dataloader):
             current_step.value = global_step
             with accelerator.accumulate(training_models[0]):  # it seems multiple models aren't supported, but do it this way for now
@@ -632,9 +632,8 @@ def train(args):

                 accelerator.log(logs, step=global_step)

-            # TODO: change this to a moving average
-            loss_total += current_loss
-            avr_loss = loss_total / (step + 1)
+            loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
+            avr_loss: float = loss_recorder.get_moving_average()
             logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)

@@ -642,7 +641,7 @@ def train(args):
                 break

         if args.logging_dir is not None:
-            logs = {"loss/epoch": loss_total / len(train_dataloader)}
+            logs = {"loss/epoch": loss_recorder.get_moving_average()}
             accelerator.log(logs, step=epoch + 1)

         accelerator.wait_for_everyone()
diff --git a/train_db.py b/train_db.py
index a1b9cac8..221a1e45 100644
--- a/train_db.py
+++ b/train_db.py
@@ -264,8 +264,7 @@ def train(args):
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)

-    loss_list = []
-    loss_total = 0.0
+    loss_recorder = train_util.LossRecorder()
     for epoch in range(num_train_epochs):
         accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
         current_epoch.value = epoch + 1
@@ -392,13 +391,8 @@ def train(args):
                 )
                 accelerator.log(logs, step=global_step)

-            if epoch == 0:
-                loss_list.append(current_loss)
-            else:
-                loss_total -= loss_list[step]
-                loss_list[step] = current_loss
-            loss_total += current_loss
-            avr_loss = loss_total / len(loss_list)
+            loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
+            avr_loss: float = loss_recorder.get_moving_average()
             logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)

@@ -406,7 +400,7 @@ def train(args):
                 break

         if args.logging_dir is not None:
-            logs = {"loss/epoch": loss_total / len(loss_list)}
+            logs = {"loss/epoch": loss_recorder.get_moving_average()}
             accelerator.log(logs, step=epoch + 1)

         accelerator.wait_for_everyone()
diff --git a/train_network.py b/train_network.py
index 2232a384..aeefe2a5 100644
--- a/train_network.py
+++ b/train_network.py
@@ -703,8 +703,7 @@ class NetworkTrainer:
                 "network_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
             )

-        loss_list = []
-        loss_total = 0.0
+        loss_recorder = train_util.LossRecorder()
         del train_dataset_group

         # callback for step start
@@ -854,13 +853,8 @@ class NetworkTrainer:
                                 remove_model(remove_ckpt_name)

                 current_loss = loss.detach().item()
-                if epoch == 0:
-                    loss_list.append(current_loss)
-                else:
-                    loss_total -= loss_list[step]
-                    loss_list[step] = current_loss
-                loss_total += current_loss
-                avr_loss = loss_total / len(loss_list)
+                loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
+                avr_loss: float = loss_recorder.get_moving_average()
                 logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
                 progress_bar.set_postfix(**logs)

@@ -875,7 +869,7 @@ class NetworkTrainer:
                     break

             if args.logging_dir is not None:
-                logs = {"loss/epoch": loss_total / len(loss_list)}
+                logs = {"loss/epoch": loss_recorder.get_moving_average()}
                 accelerator.log(logs, step=epoch + 1)

             accelerator.wait_for_everyone()
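For reference, the new LossRecorder centralizes the loss_list / loss_total bookkeeping that train_db.py and train_network.py previously duplicated inline, and upgrades fine_tune.py and sdxl_train.py from a since-epoch-start mean (the removed TODO path) to the same scheme: during epoch 0, add() appends each step's loss; from epoch 1 onward it overwrites the slot for the same step index, so get_moving_average() averages over the most recent epoch's worth of steps. A minimal usage sketch with illustrative loss values, assuming the repository root is on sys.path:

    from library.train_util import LossRecorder

    recorder = LossRecorder()

    # Epoch 0: losses are appended, so the average covers every step so far.
    recorder.add(epoch=0, step=0, loss=1.0)
    recorder.add(epoch=0, step=1, loss=3.0)
    print(recorder.get_moving_average())  # (1.0 + 3.0) / 2 = 2.0

    # Epoch 1: step 0 overwrites last epoch's slot (1.0 out, 2.0 in),
    # keeping the averaging window fixed at one epoch of steps.
    recorder.add(epoch=1, step=0, loss=2.0)
    print(recorder.get_moving_average())  # (2.0 + 3.0) / 2 = 2.5

Two caveats follow from the implementation: get_moving_average() divides by len(self.loss_list), so it must not be called before the first add(), and add() with epoch > 0 indexes loss_list[step], so it assumes no later epoch yields more steps than epoch 0. Both hold in the training loops above, where add() always precedes the average and the dataloader length is fixed across epochs.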