diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py
index 0df61e84..8e975252 100644
--- a/sdxl_train_control_net_lllite.py
+++ b/sdxl_train_control_net_lllite.py
@@ -350,8 +350,7 @@ def train(args):
             "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
         )
 
-    loss_list = []
-    loss_total = 0.0
+    loss_recorder = train_util.LossRecorder()
     del train_dataset_group
 
     # function for saving/removing
@@ -500,14 +499,9 @@ def train(args):
                     remove_model(remove_ckpt_name)
 
             current_loss = loss.detach().item()
-            if epoch == 0:
-                loss_list.append(current_loss)
-            else:
-                loss_total -= loss_list[step]
-                loss_list[step] = current_loss
-            loss_total += current_loss
-            avr_loss = loss_total / len(loss_list)
-            logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
+            loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
+            avr_loss: float = loss_recorder.moving_average
+            logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
 
             if args.logging_dir is not None:
@@ -518,7 +512,7 @@ def train(args):
                 break
 
         if args.logging_dir is not None:
-            logs = {"loss/epoch": loss_total / len(loss_list)}
+            logs = {"loss/epoch": loss_recorder.moving_average}
             accelerator.log(logs, step=epoch + 1)
 
         accelerator.wait_for_everyone()
diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py
index 79920a97..066aca59 100644
--- a/sdxl_train_control_net_lllite_old.py
+++ b/sdxl_train_control_net_lllite_old.py
@@ -323,8 +323,7 @@ def train(args):
             "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
         )
 
-    loss_list = []
-    loss_total = 0.0
+    loss_recorder = train_util.LossRecorder()
     del train_dataset_group
 
     # function for saving/removing
@@ -470,14 +469,9 @@ def train(args):
                     remove_model(remove_ckpt_name)
 
             current_loss = loss.detach().item()
-            if epoch == 0:
-                loss_list.append(current_loss)
-            else:
-                loss_total -= loss_list[step]
-                loss_list[step] = current_loss
-            loss_total += current_loss
-            avr_loss = loss_total / len(loss_list)
-            logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
+            loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
+            avr_loss: float = loss_recorder.moving_average
+            logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
 
             if args.logging_dir is not None:
@@ -488,7 +482,7 @@ def train(args):
                 break
 
         if args.logging_dir is not None:
-            logs = {"loss/epoch": loss_total / len(loss_list)}
+            logs = {"loss/epoch": loss_recorder.moving_average}
             accelerator.log(logs, step=epoch + 1)
 
         accelerator.wait_for_everyone()
diff --git a/train_controlnet.py b/train_controlnet.py
index 5bc8d399..bbd915cb 100644
--- a/train_controlnet.py
+++ b/train_controlnet.py
@@ -337,8 +337,7 @@ def train(args):
            init_kwargs = toml.load(args.log_tracker_config)
        accelerator.init_trackers("controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
 
-    loss_list = []
-    loss_total = 0.0
+    loss_recorder = train_util.LossRecorder()
     del train_dataset_group
 
     # function for saving/removing
@@ -500,14 +499,9 @@ def train(args):
                     remove_model(remove_ckpt_name)
 
             current_loss = loss.detach().item()
-            if epoch == 0:
-                loss_list.append(current_loss)
-            else:
-                loss_total -= loss_list[step]
-                loss_list[step] = current_loss
-            loss_total += current_loss
-            avr_loss = loss_total / len(loss_list)
-            logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
+            loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
+            avr_loss: float = loss_recorder.moving_average
+            logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
 
             if args.logging_dir is not None:
@@ -518,7 +512,7 @@ def train(args):
                 break
 
         if args.logging_dir is not None:
-            logs = {"loss/epoch": loss_total / len(loss_list)}
+            logs = {"loss/epoch": loss_recorder.moving_average}
             accelerator.log(logs, step=epoch + 1)
 
         accelerator.wait_for_everyone()
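All three call sites now delegate the per-step bookkeeping to `train_util.LossRecorder`, assuming an `add(epoch=..., step=..., loss=...)` method and a `moving_average` property. The class itself is not part of this diff; a minimal sketch that reproduces the inline logic removed above (the actual implementation in `train_util` may differ in detail):

```python
from typing import List


class LossRecorder:
    """Tracks per-step losses and exposes a moving average over the last epoch.

    Minimal sketch mirroring the bookkeeping this diff removes; not the
    verbatim implementation from train_util.
    """

    def __init__(self) -> None:
        self.loss_list: List[float] = []
        self.loss_total: float = 0.0

    def add(self, *, epoch: int, step: int, loss: float) -> None:
        if epoch == 0:
            # First epoch: the window grows by one slot per step.
            self.loss_list.append(loss)
        else:
            # Later epochs: evict the value previously recorded at this step index.
            self.loss_total -= self.loss_list[step]
            self.loss_list[step] = loss
        self.loss_total += loss

    @property
    def moving_average(self) -> float:
        return self.loss_total / len(self.loss_list)
```

This keeps the average O(1) per step: each step's slot is overwritten in place after the first epoch, so `moving_average` is always the mean loss over the most recent full pass through the dataloader. That is also why the progress-bar key changes from `loss` to `avr_loss`: the logged value is a moving average, not the raw per-step loss.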