Add LossRecorder and use moving average in all places

This commit is contained in:
Yuta Hayashibe
2023-10-27 17:49:49 +09:00
parent 2a23713f71
commit 3d2bb1a8f1
5 changed files with 33 additions and 30 deletions

View File

@@ -295,7 +295,7 @@ def train(args):
for m in training_models:
m.train()
loss_total = 0
loss_recorder = train_util.LossRecorder()
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
@@ -405,9 +405,8 @@ def train(args):
)
accelerator.log(logs, step=global_step)
# TODO moving averageにする
loss_total += current_loss
avr_loss = loss_total / (step + 1)
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.get_moving_average()
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
@@ -415,7 +414,7 @@ def train(args):
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(train_dataloader)}
logs = {"loss/epoch": loss_recorder.get_moving_average()}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()

View File

@@ -4685,3 +4685,20 @@ class collator_class:
dataset.set_current_epoch(self.current_epoch.value)
dataset.set_current_step(self.current_step.value)
return examples[0]
class LossRecorder:
    """Record per-step training losses and expose a one-epoch moving average.

    During the first epoch (``epoch == 0``) each loss is appended, growing the
    window.  In later epochs the value at the same ``step`` index is replaced,
    so the running total always covers exactly one epoch's worth of steps and
    the average can be maintained in O(1) per update.
    """

    def __init__(self) -> None:
        # Per-step losses for the most recent pass over the data.
        self.loss_list: List[float] = []
        # Running sum of loss_list, kept incrementally so averaging is O(1).
        self.loss_total: float = 0.0

    def add(self, *, epoch: int, step: int, loss: float) -> None:
        """Record ``loss`` for (``epoch``, ``step``).

        In epoch 0 the loss is appended; afterwards the stale value recorded
        at the same step of the previous epoch is replaced.
        """
        if epoch == 0 or step >= len(self.loss_list):
            # First epoch, or an epoch with more steps than previously seen
            # (e.g. dataloader length changed): extend the window instead of
            # indexing past the end.
            self.loss_list.append(loss)
        else:
            # Swap out the previous epoch's value at this step.
            self.loss_total -= self.loss_list[step]
            self.loss_list[step] = loss
        self.loss_total += loss

    def get_moving_average(self) -> float:
        """Return the average loss over the recorded window (0.0 if empty)."""
        if not self.loss_list:
            # Avoid ZeroDivisionError when queried before the first step.
            return 0.0
        return self.loss_total / len(self.loss_list)

View File

@@ -459,7 +459,7 @@ def train(args):
for m in training_models:
m.train()
loss_total = 0
loss_recorder = train_util.LossRecorder()
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
@@ -632,9 +632,8 @@ def train(args):
accelerator.log(logs, step=global_step)
# TODO moving averageにする
loss_total += current_loss
avr_loss = loss_total / (step + 1)
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.get_moving_average()
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
@@ -642,7 +641,7 @@ def train(args):
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(train_dataloader)}
logs = {"loss/epoch": loss_recorder.get_moving_average()}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()

View File

@@ -264,8 +264,7 @@ def train(args):
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
loss_list = []
loss_total = 0.0
loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
@@ -392,13 +391,8 @@ def train(args):
)
accelerator.log(logs, step=global_step)
if epoch == 0:
loss_list.append(current_loss)
else:
loss_total -= loss_list[step]
loss_list[step] = current_loss
loss_total += current_loss
avr_loss = loss_total / len(loss_list)
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.get_moving_average()
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
@@ -406,7 +400,7 @@ def train(args):
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(loss_list)}
logs = {"loss/epoch": loss_recorder.get_moving_average()}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()

View File

@@ -703,8 +703,7 @@ class NetworkTrainer:
"network_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
)
loss_list = []
loss_total = 0.0
loss_recorder = train_util.LossRecorder()
del train_dataset_group
# callback for step start
@@ -854,13 +853,8 @@ class NetworkTrainer:
remove_model(remove_ckpt_name)
current_loss = loss.detach().item()
if epoch == 0:
loss_list.append(current_loss)
else:
loss_total -= loss_list[step]
loss_list[step] = current_loss
loss_total += current_loss
avr_loss = loss_total / len(loss_list)
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.get_moving_average()
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
@@ -875,7 +869,7 @@ class NetworkTrainer:
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(loss_list)}
logs = {"loss/epoch": loss_recorder.get_moving_average()}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()