Use @property

This commit is contained in:
Yuta Hayashibe
2023-10-27 18:14:27 +09:00
parent efef5c8ead
commit 0d21925bdf
5 changed files with 10 additions and 9 deletions

View File

@@ -406,7 +406,7 @@ def train(args):
accelerator.log(logs, step=global_step)
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.get_moving_average()
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
@@ -414,7 +414,7 @@ def train(args):
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_recorder.get_moving_average()}
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()

View File

@@ -4700,5 +4700,6 @@ class LossRecorder:
self.loss_list[step] = loss
self.loss_total += loss
def get_moving_average(self) -> float:
@property
def moving_average(self) -> float:
return self.loss_total / len(self.loss_list)

View File

@@ -633,7 +633,7 @@ def train(args):
accelerator.log(logs, step=global_step)
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.get_moving_average()
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
@@ -641,7 +641,7 @@ def train(args):
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_recorder.get_moving_average()}
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()

View File

@@ -392,7 +392,7 @@ def train(args):
accelerator.log(logs, step=global_step)
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.get_moving_average()
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
@@ -400,7 +400,7 @@ def train(args):
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_recorder.get_moving_average()}
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()

View File

@@ -854,7 +854,7 @@ class NetworkTrainer:
current_loss = loss.detach().item()
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.get_moving_average()
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
@@ -869,7 +869,7 @@ class NetworkTrainer:
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_recorder.get_moving_average()}
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()