From c1ef6dcabc0fe47fd5483a1597e524419569a017 Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Mon, 6 May 2024 13:17:14 +0900
Subject: [PATCH] fix to work with wrapped optimizer by accelerate

---
 fine_tune.py                         |  4 ++--
 sdxl_train.py                        |  6 +++---
 sdxl_train_control_net_lllite.py     |  4 ++--
 sdxl_train_control_net_lllite_old.py | 17 ++++++++++-------
 train_controlnet.py                  |  4 ++--
 train_db.py                          |  8 +++-----
 train_network.py                     |  4 ++--
 train_textual_inversion.py           |  4 ++--
 train_textual_inversion_XTI.py       | 16 +++++++++-------
 9 files changed, 35 insertions(+), 32 deletions(-)

diff --git a/fine_tune.py b/fine_tune.py
index b82a67ae..2c4d3685 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -271,8 +271,8 @@ def train(args):

     # make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
     if use_schedule_free_optimizer:
-        optimizer_train_if_needed = lambda: optimizer.train()
-        optimizer_eval_if_needed = lambda: optimizer.eval()
+        optimizer_train_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).train()
+        optimizer_eval_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).eval()
     else:
         optimizer_train_if_needed = lambda: None
         optimizer_eval_if_needed = lambda: None
diff --git a/sdxl_train.py b/sdxl_train.py
index 8944f3a0..ed5a6493 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -435,8 +435,8 @@ def train(args):

     # make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
     if use_schedule_free_optimizer:
-        optimizer_train_if_needed = lambda: optimizer.train()
-        optimizer_eval_if_needed = lambda: optimizer.eval()
+        optimizer_train_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).train()
+        optimizer_eval_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).eval()
     else:
         optimizer_train_if_needed = lambda: None
         optimizer_eval_if_needed = lambda: None
@@ -644,7 +644,7 @@ def train(args):
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                 optimizer.step()
-                lr_scheduler.step() # if schedule-free optimizer is used, this is a no-op
+                lr_scheduler.step()  # if schedule-free optimizer is used, this is a no-op
                 optimizer.zero_grad(set_to_none=True)

                 optimizer_eval_if_needed()
diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py
index 6e0c2c8a..54b6d0b0 100644
--- a/sdxl_train_control_net_lllite.py
+++ b/sdxl_train_control_net_lllite.py
@@ -294,8 +294,8 @@ def train(args):

     # make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
     if use_schedule_free_optimizer:
-        optimizer_train_if_needed = lambda: optimizer.train()
-        optimizer_eval_if_needed = lambda: optimizer.eval()
+        optimizer_train_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).train()
+        optimizer_eval_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).eval()
     else:
         optimizer_train_if_needed = lambda: None
         optimizer_eval_if_needed = lambda: None
diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py
index 8585df13..babaa026 100644
--- a/sdxl_train_control_net_lllite_old.py
+++ b/sdxl_train_control_net_lllite_old.py
@@ -12,6 +12,7 @@ from tqdm import tqdm

 import torch
 from library.device_utils import init_ipex, clean_memory_on_device
+
 init_ipex()

 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -255,16 +256,14 @@ def train(args):
     # accelerator apparently takes care of this for us
     use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")

-    unet, network, optimizer, train_dataloader = accelerator.prepare(
-        unet, network, optimizer, train_dataloader
-    )
+    unet, network, optimizer, train_dataloader = accelerator.prepare(unet, network, optimizer, train_dataloader)
     if not use_schedule_free_optimizer:
         lr_scheduler = accelerator.prepare(lr_scheduler)

     # make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
     if use_schedule_free_optimizer:
-        optimizer_train_if_needed = lambda: optimizer.train()
-        optimizer_eval_if_needed = lambda: optimizer.eval()
+        optimizer_train_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).train()
+        optimizer_eval_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).eval()
     else:
         optimizer_train_if_needed = lambda: None
         optimizer_eval_if_needed = lambda: None
@@ -419,7 +418,9 @@ def train(args):

                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
+                    args, noise_scheduler, latents
+                )

                 noisy_latents = noisy_latents.to(weight_dtype)  # TODO check why noisy_latents is not weight_dtype

@@ -439,7 +440,9 @@ def train(args):
                 else:
                     target = noise

-                loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
+                loss = train_util.conditional_loss(
+                    noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+                )
                 loss = loss.mean([1, 2, 3])

                 loss_weights = batch["loss_weights"]  # weight for each sample
diff --git a/train_controlnet.py b/train_controlnet.py
index 1785607b..bc9da356 100644
--- a/train_controlnet.py
+++ b/train_controlnet.py
@@ -284,8 +284,8 @@ def train(args):

     # make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
     if use_schedule_free_optimizer:
-        optimizer_train_if_needed = lambda: optimizer.train()
-        optimizer_eval_if_needed = lambda: optimizer.eval()
+        optimizer_train_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).train()
+        optimizer_eval_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).eval()
     else:
         optimizer_train_if_needed = lambda: None
         optimizer_eval_if_needed = lambda: None
diff --git a/train_db.py b/train_db.py
index 36ed867a..c56630da 100644
--- a/train_db.py
+++ b/train_db.py
@@ -237,9 +237,7 @@ def train(args):

     else:
         if train_text_encoder:
-            unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(
-                unet, text_encoder, optimizer, train_dataloader
-            )
+            unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(unet, text_encoder, optimizer, train_dataloader)
             training_models = [unet, text_encoder]
         else:
             unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
@@ -249,8 +247,8 @@ def train(args):

     # make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
     if use_schedule_free_optimizer:
-        optimizer_train_if_needed = lambda: optimizer.train()
-        optimizer_eval_if_needed = lambda: optimizer.eval()
+        optimizer_train_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).train()
+        optimizer_eval_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).eval()
     else:
         optimizer_train_if_needed = lambda: None
         optimizer_eval_if_needed = lambda: None
diff --git a/train_network.py b/train_network.py
index 68341378..f287acac 100644
--- a/train_network.py
+++ b/train_network.py
@@ -446,8 +446,8 @@ class NetworkTrainer:

         # make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
         if use_schedule_free_optimizer:
-            optimizer_train_if_needed = lambda: optimizer.train()
-            optimizer_eval_if_needed = lambda: optimizer.eval()
+            optimizer_train_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).train()
+            optimizer_eval_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).eval()
         else:
             optimizer_train_if_needed = lambda: None
             optimizer_eval_if_needed = lambda: None
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index de93273b..c71ec53d 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -432,8 +432,8 @@ class TextualInversionTrainer:

         # make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
         if use_schedule_free_optimizer:
-            optimizer_train_if_needed = lambda: optimizer.train()
-            optimizer_eval_if_needed = lambda: optimizer.eval()
+            optimizer_train_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).train()
+            optimizer_eval_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).eval()
         else:
             optimizer_train_if_needed = lambda: None
             optimizer_eval_if_needed = lambda: None
diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py
index cb38e798..a9d10d6e 100644
--- a/train_textual_inversion_XTI.py
+++ b/train_textual_inversion_XTI.py
@@ -336,16 +336,14 @@ def train(args):
     # accelerator apparently takes care of this for us
     use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")

-    text_encoder, optimizer, train_dataloader = accelerator.prepare(
-        text_encoder, optimizer, train_dataloader
-    )
+    text_encoder, optimizer, train_dataloader = accelerator.prepare(text_encoder, optimizer, train_dataloader)
     if not use_schedule_free_optimizer:
         lr_scheduler = accelerator.prepare(lr_scheduler)

     # make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
     if use_schedule_free_optimizer:
-        optimizer_train_if_needed = lambda: optimizer.train()
-        optimizer_eval_if_needed = lambda: optimizer.eval()
+        optimizer_train_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).train()
+        optimizer_eval_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).eval()
     else:
         optimizer_train_if_needed = lambda: None
         optimizer_eval_if_needed = lambda: None
@@ -473,7 +471,9 @@ def train(args):

                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
+                    args, noise_scheduler, latents
+                )

                 # Predict the noise residual
                 with accelerator.autocast():
@@ -485,7 +485,9 @@ def train(args):
                 else:
                     target = noise

-                loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
+                loss = train_util.conditional_loss(
+                    noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+                )
                 if args.masked_loss:
                     loss = apply_masked_loss(loss, batch)
                 loss = loss.mean([1, 2, 3])
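
Note (not part of the patch): a minimal sketch of why the unwrap is needed. accelerator.prepare() returns an AcceleratedOptimizer wrapper that forwards step() and zero_grad() but, at least in the accelerate versions this patch targets, does not expose the train()/eval() methods that schedule-free optimizers require; the original optimizer is kept in the wrapper's `optimizer` attribute, which is what the hasattr fallback above reaches for. The AdamWScheduleFree class and the toy Linear model below are only illustrative assumptions (they require the schedulefree package), not code from this repository.

import torch
from accelerate import Accelerator
from schedulefree import AdamWScheduleFree  # example schedule-free optimizer; assumes schedulefree is installed

model = torch.nn.Linear(4, 4)
optimizer = AdamWScheduleFree(model.parameters(), lr=1e-3)

accelerator = Accelerator()
model, optimizer = accelerator.prepare(model, optimizer)  # optimizer is now wrapped by accelerate

# Calling optimizer.train() directly may fail because the wrapper does not necessarily proxy it;
# unwrap first, exactly as the lambdas in this patch do.
inner = optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer
inner.train()  # schedule-free optimizers must be in train mode before optimizer.step()
# ... forward / backward / optimizer.step() / optimizer.zero_grad() ...
inner.eval()   # switch to eval mode before validation or before saving weights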