Improve wandb logging (#1576)

* fix: wrong training steps were recorded to wandb, and no log was sent when logging_dir was not specified

* fix: checking of whether wandb is enabled

* feat: log images to wandb with their positive prompt as captions

* feat: logging sample images' caption for sd3 and flux

* fix: import wandb before use
This commit is contained in:
Plat
2024-09-11 22:21:16 +09:00
committed by GitHub
parent d83f2e92da
commit a823fd9fb8
14 changed files with 80 additions and 49 deletions

View File

@@ -337,6 +337,9 @@ def train(args):
train_util.sample_images(
accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
)
if len(accelerator.trackers) > 0:
# log empty object to commit the sample images to wandb
accelerator.log({}, step=0)
loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
@@ -456,7 +459,7 @@ def train(args):
)
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = {"loss": current_loss}
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
accelerator.log(logs, step=global_step)
@@ -469,7 +472,7 @@ def train(args):
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)

View File

@@ -629,6 +629,9 @@ def train(args):
# For --sample_at_first
flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs)
if len(accelerator.trackers) > 0:
# log empty object to commit the sample images to wandb
accelerator.log({}, step=0)
loss_recorder = train_util.LossRecorder()
epoch = 0 # avoid error when max_train_steps is 0
@@ -777,7 +780,7 @@ def train(args):
)
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = {"loss": current_loss}
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
@@ -791,7 +794,7 @@ def train(args):
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)

View File

@@ -254,17 +254,19 @@ def sample_image_inference(
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
image.save(os.path.join(save_dir, img_filename))
# wandb有効時のみログを送信
try:
# send images to wandb if enabled
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
wandb_tracker = accelerator.get_tracker("wandb")
try:
import wandb
except ImportError: # 事前に一度確認するのでここはエラー出ないはず
raise ImportError("No wandb / wandb がインストールされていないようです")
wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
except: # wandb 無効時
pass
import wandb
# not to commit images to avoid inconsistency between training and logging steps
wandb_tracker.log(
{f"sample_{i}": wandb.Image(
image,
caption=prompt # positive prompt as a caption
)},
commit=False
)
def time_shift(mu: float, sigma: float, t: torch.Tensor):

View File

@@ -604,17 +604,19 @@ def sample_image_inference(
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
image.save(os.path.join(save_dir, img_filename))
# wandb有効時のみログを送信
try:
# send images to wandb if enabled
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
wandb_tracker = accelerator.get_tracker("wandb")
try:
import wandb
except ImportError: # 事前に一度確認するのでここはエラー出ないはず
raise ImportError("No wandb / wandb がインストールされていないようです")
wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
except: # wandb 無効時
pass
import wandb
# not to commit images to avoid inconsistency between training and logging steps
wandb_tracker.log(
{f"sample_{i}": wandb.Image(
image,
caption=prompt # positive prompt as a caption
)},
commit=False
)
# region Diffusers

View File

@@ -5832,17 +5832,19 @@ def sample_image_inference(
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
image.save(os.path.join(save_dir, img_filename))
# wandb有効時のみログを送信
try:
# send images to wandb if enabled
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
wandb_tracker = accelerator.get_tracker("wandb")
try:
import wandb
except ImportError: # 事前に一度確認するのでここはエラー出ないはず
raise ImportError("No wandb / wandb がインストールされていないようです")
wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
except: # wandb 無効時
pass
import wandb
# not to commit images to avoid inconsistency between training and logging steps
wandb_tracker.log(
{f"sample_{i}": wandb.Image(
image,
caption=prompt # positive prompt as a caption
)},
commit=False
)
# endregion

View File

@@ -682,6 +682,9 @@ def train(args):
# For --sample_at_first
sd3_train_utils.sample_images(accelerator, args, 0, global_step, mmdit, vae, [clip_l, clip_g, t5xxl], sample_prompts_te_outputs)
if len(accelerator.trackers) > 0:
# log empty object to commit the sample images to wandb
accelerator.log({}, step=0)
# following function will be moved to sd3_train_utils
@@ -901,7 +904,7 @@ def train(args):
)
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = {"loss": current_loss}
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=train_mmdit)
@@ -915,7 +918,7 @@ def train(args):
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)

View File

@@ -617,6 +617,9 @@ def train(args):
sdxl_train_util.sample_images(
accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, [text_encoder1, text_encoder2], unet
)
if len(accelerator.trackers) > 0:
# log empty object to commit the sample images to wandb
accelerator.log({}, step=0)
loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
@@ -797,7 +800,7 @@ def train(args):
)
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = {"loss": current_loss}
if block_lrs is None:
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=train_unet)
@@ -814,7 +817,7 @@ def train(args):
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)

View File

@@ -541,14 +541,14 @@ def train(args):
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)

View File

@@ -480,14 +480,14 @@ def train(args):
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)

View File

@@ -409,6 +409,9 @@ def train(args):
train_util.sample_images(
accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, controlnet=controlnet
)
if len(accelerator.trackers) > 0:
# log empty object to commit the sample images to wandb
accelerator.log({}, step=0)
# training loop
for epoch in range(num_train_epochs):
@@ -542,14 +545,14 @@ def train(args):
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)

View File

@@ -315,6 +315,9 @@ def train(args):
train_util.sample_images(
accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
)
if len(accelerator.trackers) > 0:
# log empty object to commit the sample images to wandb
accelerator.log({}, step=0)
loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
@@ -445,7 +448,7 @@ def train(args):
)
current_loss = loss.detach().item()
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = {"loss": current_loss}
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
accelerator.log(logs, step=global_step)
@@ -458,7 +461,7 @@ def train(args):
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)

View File

@@ -1038,6 +1038,9 @@ class NetworkTrainer:
# For --sample_at_first
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
if len(accelerator.trackers) > 0:
# log empty object to commit the sample images to wandb
accelerator.log({}, step=0)
# training loop
if initial_step > 0: # only if skip_until_initial_step is specified
@@ -1224,7 +1227,7 @@ class NetworkTrainer:
if args.scale_weight_norms:
progress_bar.set_postfix(**{**max_mean_logs, **logs})
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = self.generate_step_logs(
args, current_loss, avr_loss, lr_scheduler, lr_descriptions, keys_scaled, mean_norm, maximum_norm
)
@@ -1233,7 +1236,7 @@ class NetworkTrainer:
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)

View File

@@ -550,6 +550,9 @@ class TextualInversionTrainer:
unet,
prompt_replacement,
)
if len(accelerator.trackers) > 0:
# log empty object to commit the sample images to wandb
accelerator.log({}, step=0)
# training loop
for epoch in range(num_train_epochs):
@@ -684,7 +687,7 @@ class TextualInversionTrainer:
remove_model(remove_ckpt_name)
current_loss = loss.detach().item()
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
if (
args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
@@ -702,7 +705,7 @@ class TextualInversionTrainer:
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_total / len(train_dataloader)}
accelerator.log(logs, step=epoch + 1)
@@ -739,6 +742,7 @@ class TextualInversionTrainer:
unet,
prompt_replacement,
)
accelerator.log({})
# end of epoch

View File

@@ -538,7 +538,7 @@ def train(args):
remove_model(remove_ckpt_name)
current_loss = loss.detach().item()
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
if (
args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
@@ -556,7 +556,7 @@ def train(args):
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_total / len(train_dataloader)}
accelerator.log(logs, step=epoch + 1)