From 066b1bb57e58603ce21acb6c3c7aaddc19338153 Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Mon, 27 Mar 2023 20:47:11 +0900
Subject: [PATCH] fix do not mean in batch dim when min_snr_gamma

---
 fine_tune.py | 26 ++++++++++++++++----------
 1 file changed, 16 insertions(+), 10 deletions(-)

diff --git a/fine_tune.py b/fine_tune.py
index 8ae1bb29..0f42741b 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -21,7 +21,8 @@ from library.config_util import (
     BlueprintGenerator,
 )
 import library.custom_train_functions as custom_train_functions
-from library.custom_train_functions import apply_snr_weight
+from library.custom_train_functions import apply_snr_weight
+

 def train(args):
     train_util.verify_training_args(args)
@@ -62,9 +63,9 @@ def train(args):
     blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
     train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)

-    current_epoch = Value('i',0)
-    current_step = Value('i',0)
-    collater = train_util.collater_class(current_epoch,current_step)
+    current_epoch = Value("i", 0)
+    current_step = Value("i", 0)
+    collater = train_util.collater_class(current_epoch, current_step)

     if args.debug_dataset:
         train_util.debug_dataset(train_dataset_group)
@@ -196,7 +197,9 @@ def train(args):

     # 学習ステップ数を計算する
     if args.max_train_epochs is not None:
-        args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
+        args.max_train_steps = args.max_train_epochs * math.ceil(
+            len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
+        )
         print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")

     # データセット側にも学習ステップを送信
@@ -260,7 +263,7 @@ def train(args):

     for epoch in range(num_train_epochs):
         print(f"epoch {epoch+1}/{num_train_epochs}")
-        current_epoch.value = epoch+1
+        current_epoch.value = epoch + 1

         for m in training_models:
             m.train()
@@ -308,10 +311,14 @@ def train(args):
                 else:
                     target = noise

-                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
-
                 if args.min_snr_gamma:
-                    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
+                    # do not mean over batch dimension for snr weight
+                    loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+                    loss = loss.mean([1, 2, 3])
+                    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
+                    loss = loss.mean()  # mean over batch dimension
+                else:
+                    loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")

                 accelerator.backward(loss)
                 if accelerator.sync_gradients and args.max_grad_norm != 0.0:
@@ -406,7 +413,6 @@ def setup_parser() -> argparse.ArgumentParser:
     train_util.add_optimizer_arguments(parser)
     config_util.add_config_arguments(parser)
     custom_train_functions.add_custom_train_arguments(parser)
-
     parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
     parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
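
Background on the change: the Min-SNR-gamma weight (Hang et al., "Efficient
Diffusion Training via Min-SNR Weighting Strategy", arXiv:2303.09556) depends
on each sample's timestep, so each sample's loss must be weighted individually
before averaging over the batch. With reduction="mean" the loss was already
collapsed to a single scalar, so apply_snr_weight could not apply per-timestep
weights correctly; the patch keeps the loss per-sample (reduction="none", then
mean over only the C/H/W dims) until after weighting. apply_snr_weight itself
lives in library/custom_train_functions.py and is not part of this diff; a
minimal sketch of such a weighting function, assuming a diffusers-style
noise_scheduler that exposes alphas_cumprod (the repository's actual
implementation may differ):

    import torch

    def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
        # loss: shape (batch,) -- already averaged over C/H/W via
        # loss.mean([1, 2, 3]) by the caller, but NOT over the batch dim
        alphas_cumprod = noise_scheduler.alphas_cumprod.to(loss.device)
        # SNR(t) = alpha_bar_t / (1 - alpha_bar_t), one value per sample
        snr = alphas_cumprod[timesteps] / (1.0 - alphas_cumprod[timesteps])
        # Min-SNR-gamma weight for epsilon prediction: min(SNR, gamma) / SNR,
        # i.e. gamma / SNR clamped at 1, which down-weights low-noise steps
        snr_weight = (gamma / snr).clamp(max=1.0).float()
        return loss * snr_weight  # per-sample; the caller takes .mean() after

Under this scheme, samples drawn at low-noise timesteps (high SNR) get a
weight of gamma / SNR below 1, while high-noise timesteps keep weight 1,
which is the rebalancing that --min_snr_gamma enables.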