From 208ad06be063d5d5283685f0e1be09cea395e761 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?=
Date: Thu, 27 Apr 2023 00:25:54 +0800
Subject: [PATCH] Update fine_tune.py for DAdapt

---
 fine_tune.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/fine_tune.py b/fine_tune.py
index b6a8d1d7..4227dd04 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -376,7 +376,7 @@ def train(args):
             current_loss = loss.detach().item()  # this is an average, so batch size should not matter
             if args.logging_dir is not None:
                 logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
-                if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value
+                if args.optimizer_type.lower() == "DAdaptation".lower() or args.optimizer_type.lower() == "DAdaptAdam".lower() or args.optimizer_type.lower() == "DAdaptAdaGrad".lower() or args.optimizer_type.lower() == "DAdaptAdan".lower() or args.optimizer_type.lower() == "DAdaptSGD".lower():  # tracking d*lr value
                     logs["lr/d*lr"] = (
                         lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
                     )
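
For reference, below is a minimal standalone sketch of the quantity this patch logs. It assumes the dadaptation package is installed and uses DAdaptAdam, one of the optimizer types the widened condition now matches; the toy model and the direct optimizer access are illustrative only, whereas the training script reads the same fields through lr_scheduler.optimizers[0].

import torch
from dadaptation import DAdaptAdam  # one of the D-Adaptation variants the patch now recognizes

model = torch.nn.Linear(4, 2)
# D-Adaptation optimizers are typically run with lr=1.0; the adapted step size lives in "d".
optimizer = DAdaptAdam(model.parameters(), lr=1.0)

# One toy training step so the optimizer has been exercised at least once.
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

# The effective learning rate is d * lr, which the patched code logs as "lr/d*lr".
group = optimizer.param_groups[0]
print("lr/d*lr:", group["d"] * group["lr"])

With the widened condition, the same lr/d*lr curve is written to the logging directory for DAdaptAdam, DAdaptAdaGrad, DAdaptAdan, and DAdaptSGD runs, not only for the original DAdaptation optimizer type.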