Show "avr_loss" instead of "loss" because it is a moving average

This commit is contained in:
Yuta Hayashibe
2023-10-27 17:59:58 +09:00
parent 3d2bb1a8f1
commit efef5c8ead
4 changed files with 4 additions and 4 deletions

View File

@@ -407,7 +407,7 @@ def train(args):
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.get_moving_average()
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step >= args.max_train_steps:

View File

@@ -634,7 +634,7 @@ def train(args):
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.get_moving_average()
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step >= args.max_train_steps:

View File

@@ -393,7 +393,7 @@ def train(args):
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.get_moving_average()
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step >= args.max_train_steps:

View File

@@ -855,7 +855,7 @@ class NetworkTrainer:
current_loss = loss.detach().item()
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.get_moving_average()
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if args.scale_weight_norms: