diff --git a/fine_tune.py b/fine_tune.py
index 4a3f49c7..893066f7 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -289,6 +289,7 @@ def train(args):
         init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
 
+    loss_recorder = train_util.LossRecorder()
     for epoch in range(num_train_epochs):
         accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
         current_epoch.value = epoch + 1
@@ -296,7 +297,6 @@ def train(args):
         for m in training_models:
             m.train()
 
-        loss_total = 0
         for step, batch in enumerate(train_dataloader):
             current_step.value = global_step
             with accelerator.accumulate(training_models[0]):  # 複数モデルに対応していない模様だがとりあえずこうしておく
@@ -408,17 +408,16 @@ def train(args):
                 )
                 accelerator.log(logs, step=global_step)
 
-            # TODO moving averageにする
-            loss_total += current_loss
-            avr_loss = loss_total / (step + 1)
-            logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
+            loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
+            avr_loss: float = loss_recorder.moving_average
+            logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
 
             if global_step >= args.max_train_steps:
                 break
 
         if args.logging_dir is not None:
-            logs = {"loss/epoch": loss_total / len(train_dataloader)}
+            logs = {"loss/epoch": loss_recorder.moving_average}
             accelerator.log(logs, step=epoch + 1)
 
         accelerator.wait_for_everyone()
diff --git a/library/train_util.py b/library/train_util.py
index 0109b42a..0d592036 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -4697,3 +4697,21 @@ class collator_class:
             dataset.set_current_epoch(self.current_epoch.value)
             dataset.set_current_step(self.current_step.value)
         return examples[0]
+
+
+class LossRecorder:
+    def __init__(self):
+        self.loss_list: List[float] = []
+        self.loss_total: float = 0.0
+
+    def add(self, *, epoch: int, step: int, loss: float) -> None:
+        if epoch == 0:
+            self.loss_list.append(loss)
+        else:
+            self.loss_total -= self.loss_list[step]
+            self.loss_list[step] = loss
+        self.loss_total += loss
+
+    @property
+    def moving_average(self) -> float:
+        return self.loss_total / len(self.loss_list)
diff --git a/sdxl_train.py b/sdxl_train.py
index c368f27c..f067acd5 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -451,6 +451,7 @@ def train(args):
         init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
 
+    loss_recorder = train_util.LossRecorder()
     for epoch in range(num_train_epochs):
         accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
         current_epoch.value = epoch + 1
@@ -458,7 +459,6 @@ def train(args):
         for m in training_models:
             m.train()
 
-        loss_total = 0
         for step, batch in enumerate(train_dataloader):
             current_step.value = global_step
             with accelerator.accumulate(training_models[0]):  # 複数モデルに対応していない模様だがとりあえずこうしておく
@@ -633,17 +633,16 @@ def train(args):
 
                 accelerator.log(logs, step=global_step)
 
-            # TODO moving averageにする
-            loss_total += current_loss
-            avr_loss = loss_total / (step + 1)
-            logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
+            loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
+            avr_loss: float = loss_recorder.moving_average
+            logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
 
             if global_step >= args.max_train_steps:
                 break
 
         if args.logging_dir is not None:
-            logs = {"loss/epoch": loss_total / len(train_dataloader)}
+            logs = {"loss/epoch": loss_recorder.moving_average}
             accelerator.log(logs, step=epoch + 1)
 
         accelerator.wait_for_everyone()
diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py
index 7a141bb4..54abf697 100644
--- a/sdxl_train_control_net_lllite.py
+++ b/sdxl_train_control_net_lllite.py
@@ -351,8 +351,7 @@ def train(args):
             "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
         )
 
-    loss_list = []
-    loss_total = 0.0
+    loss_recorder = train_util.LossRecorder()
     del train_dataset_group
 
     # function for saving/removing
@@ -503,14 +502,9 @@ def train(args):
                         remove_model(remove_ckpt_name)
 
             current_loss = loss.detach().item()
-            if epoch == 0:
-                loss_list.append(current_loss)
-            else:
-                loss_total -= loss_list[step]
-                loss_list[step] = current_loss
-            loss_total += current_loss
-            avr_loss = loss_total / len(loss_list)
-            logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
+            loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
+            avr_loss: float = loss_recorder.moving_average
+            logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
 
             if args.logging_dir is not None:
@@ -521,7 +515,7 @@ def train(args):
                 break
 
         if args.logging_dir is not None:
-            logs = {"loss/epoch": loss_total / len(loss_list)}
+            logs = {"loss/epoch": loss_recorder.moving_average}
             accelerator.log(logs, step=epoch + 1)
 
         accelerator.wait_for_everyone()
diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py
index e256badc..f00f10ea 100644
--- a/sdxl_train_control_net_lllite_old.py
+++ b/sdxl_train_control_net_lllite_old.py
@@ -324,8 +324,7 @@ def train(args):
             "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
         )
 
-    loss_list = []
-    loss_total = 0.0
+    loss_recorder = train_util.LossRecorder()
     del train_dataset_group
 
     # function for saving/removing
@@ -473,14 +472,9 @@ def train(args):
                         remove_model(remove_ckpt_name)
 
             current_loss = loss.detach().item()
-            if epoch == 0:
-                loss_list.append(current_loss)
-            else:
-                loss_total -= loss_list[step]
-                loss_list[step] = current_loss
-            loss_total += current_loss
-            avr_loss = loss_total / len(loss_list)
-            logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
+            loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
+            avr_loss: float = loss_recorder.moving_average
+            logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
 
             if args.logging_dir is not None:
@@ -491,7 +485,7 @@ def train(args):
                 break
 
         if args.logging_dir is not None:
-            logs = {"loss/epoch": loss_total / len(loss_list)}
+            logs = {"loss/epoch": loss_recorder.moving_average}
             accelerator.log(logs, step=epoch + 1)
 
         accelerator.wait_for_everyone()
diff --git a/train_controlnet.py b/train_controlnet.py
index 5bc8d399..bbd915cb 100644
--- a/train_controlnet.py
+++ b/train_controlnet.py
@@ -337,8 +337,7 @@ def train(args):
         init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers("controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
 
-    loss_list = []
-    loss_total = 0.0
+    loss_recorder = train_util.LossRecorder()
     del train_dataset_group
 
     # function for saving/removing
@@ -500,14 +499,9 @@ def train(args):
                         remove_model(remove_ckpt_name)
 
             current_loss = loss.detach().item()
-            if epoch == 0:
-                loss_list.append(current_loss)
-            else:
-                loss_total -= loss_list[step]
-                loss_list[step] = current_loss
-            loss_total += current_loss
-            avr_loss = loss_total / len(loss_list)
-            logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
+            loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
+            avr_loss: float = loss_recorder.moving_average
+            logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
 
             if args.logging_dir is not None:
@@ -518,7 +512,7 @@ def train(args):
                 break
 
         if args.logging_dir is not None:
-            logs = {"loss/epoch": loss_total / len(loss_list)}
+            logs = {"loss/epoch": loss_recorder.moving_average}
             accelerator.log(logs, step=epoch + 1)
 
         accelerator.wait_for_everyone()
diff --git a/train_db.py b/train_db.py
index 7316c27e..59a124a2 100644
--- a/train_db.py
+++ b/train_db.py
@@ -265,8 +265,7 @@ def train(args):
         init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
 
-    loss_list = []
-    loss_total = 0.0
+    loss_recorder = train_util.LossRecorder()
     for epoch in range(num_train_epochs):
         accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
         current_epoch.value = epoch + 1
@@ -395,21 +394,16 @@ def train(args):
                 )
                 accelerator.log(logs, step=global_step)
 
-            if epoch == 0:
-                loss_list.append(current_loss)
-            else:
-                loss_total -= loss_list[step]
-                loss_list[step] = current_loss
-            loss_total += current_loss
-            avr_loss = loss_total / len(loss_list)
-            logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
+            loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
+            avr_loss: float = loss_recorder.moving_average
+            logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
 
             if global_step >= args.max_train_steps:
                 break
 
         if args.logging_dir is not None:
-            logs = {"loss/epoch": loss_total / len(loss_list)}
+            logs = {"loss/epoch": loss_recorder.moving_average}
             accelerator.log(logs, step=epoch + 1)
 
         accelerator.wait_for_everyone()
diff --git a/train_network.py b/train_network.py
index 38934c74..d50916b7 100644
--- a/train_network.py
+++ b/train_network.py
@@ -710,8 +710,7 @@ class NetworkTrainer:
                 "network_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
             )
 
-        loss_list = []
-        loss_total = 0.0
+        loss_recorder = train_util.LossRecorder()
         del train_dataset_group
 
         # callback for step start
@@ -863,14 +862,9 @@ class NetworkTrainer:
                            remove_model(remove_ckpt_name)
 
                 current_loss = loss.detach().item()
-                if epoch == 0:
-                    loss_list.append(current_loss)
-                else:
-                    loss_total -= loss_list[step]
-                    loss_list[step] = current_loss
-                loss_total += current_loss
-                avr_loss = loss_total / len(loss_list)
-                logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
+                loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
+                avr_loss: float = loss_recorder.moving_average
+                logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
                 progress_bar.set_postfix(**logs)
 
                 if args.scale_weight_norms:
@@ -884,7 +878,7 @@ class NetworkTrainer:
                     break
 
             if args.logging_dir is not None:
-                logs = {"loss/epoch": loss_total / len(loss_list)}
+                logs = {"loss/epoch": loss_recorder.moving_average}
                 accelerator.log(logs, step=epoch + 1)
 
             accelerator.wait_for_everyone()
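
Note: LossRecorder.moving_average is a one-epoch sliding window. During the first epoch it is the running mean of all steps recorded so far; from the second epoch on, each step's slot is overwritten in place, so the average always covers the most recent len(loss_list) steps. Two logging keys change as a result: the progress bar reports "avr_loss" instead of "loss", and "loss/epoch" is now this moving average rather than the exact epoch mean (loss_total / len(train_dataloader)) that fine_tune.py and sdxl_train.py logged before. A minimal usage sketch, assuming the patch is applied and the script is run from the repository root; the loss values are made up:

    from library import train_util

    recorder = train_util.LossRecorder()
    dummy_losses = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]  # hypothetical per-epoch losses
    for epoch, losses in enumerate(dummy_losses):
        for step, loss in enumerate(losses):
            recorder.add(epoch=epoch, step=step, loss=loss)
        print(f"epoch {epoch + 1} avr_loss: {recorder.moving_average:.3f}")
    # epoch 1 avr_loss: 2.000  (running mean while the window is still filling)
    # epoch 2 avr_loss: 5.000  (window has slid: only the latest epoch's values remain)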