mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-18 01:30:02 +00:00)
fix to work with wrapped optimizer by accelerate
@@ -271,8 +271,8 @@ def train(args):
 
     # make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
     if use_schedule_free_optimizer:
-        optimizer_train_if_needed = lambda: optimizer.train()
-        optimizer_eval_if_needed = lambda: optimizer.eval()
+        optimizer_train_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).train()
+        optimizer_eval_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).eval()
     else:
         optimizer_train_if_needed = lambda: None
         optimizer_eval_if_needed = lambda: None
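
The same two-line change is repeated across the training scripts. The reason: when accelerator.prepare() wraps the optimizer, the result is accelerate's AcceleratedOptimizer, and the schedule-free train()/eval() methods live on the inner optimizer, reachable through its .optimizer attribute. Below is a minimal sketch of the unwrap pattern; DummyScheduleFree and Wrapper are illustrative stand-ins, not part of sd-scripts or accelerate.

import torch

def unwrap_optimizer(optimizer):
    # accelerate's AcceleratedOptimizer stores the original optimizer in its
    # .optimizer attribute; a bare optimizer has no such attribute, so hasattr
    # distinguishes the two cases exactly as the lambdas in the diff do
    return optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer

class DummyScheduleFree(torch.optim.SGD):
    # stand-in for a schedule-free optimizer that exposes train()/eval()
    def train(self):
        print("optimizer in train mode")

    def eval(self):
        print("optimizer in eval mode")

class Wrapper:
    # stand-in for accelerate's AcceleratedOptimizer
    def __init__(self, optimizer):
        self.optimizer = optimizer

opt = DummyScheduleFree([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
unwrap_optimizer(opt).train()          # bare optimizer: called directly
unwrap_optimizer(Wrapper(opt)).eval()  # wrapped: the call reaches the inner optimizer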

@@ -435,8 +435,8 @@ def train(args):
 
     # make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
     if use_schedule_free_optimizer:
-        optimizer_train_if_needed = lambda: optimizer.train()
-        optimizer_eval_if_needed = lambda: optimizer.eval()
+        optimizer_train_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).train()
+        optimizer_eval_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).eval()
     else:
         optimizer_train_if_needed = lambda: None
         optimizer_eval_if_needed = lambda: None

@@ -644,7 +644,7 @@ def train(args):
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
 
                 optimizer.step()
-                lr_scheduler.step()  # if schedule-free optimizer is used, this is a no-op
+                lr_scheduler.step()  # if schedule-free optimizer is used, this is a no-op
                 optimizer.zero_grad(set_to_none=True)
 
                 optimizer_eval_if_needed()
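
For context, a hedged sketch of the protocol that optimizer_train_if_needed()/optimizer_eval_if_needed() implement. Schedule-free optimizers keep an internal average of the iterates, so they must be in train mode while stepping and in eval mode whenever the weights are read for sampling, validation, or saving; no LR schedule is applied externally because it is built into the optimizer. The snippet assumes the schedulefree package.

import torch
import schedulefree  # assumption: pip install schedulefree

model = torch.nn.Linear(4, 1)
optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=1e-3)

optimizer.train()  # must precede any optimizer.step()
for _ in range(10):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()  # no lr_scheduler.step() needed; the schedule is built in
    optimizer.zero_grad(set_to_none=True)

optimizer.eval()  # must precede sampling, evaluation, or checkpoint saving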

@@ -294,8 +294,8 @@ def train(args):
 
     # make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
     if use_schedule_free_optimizer:
-        optimizer_train_if_needed = lambda: optimizer.train()
-        optimizer_eval_if_needed = lambda: optimizer.eval()
+        optimizer_train_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).train()
+        optimizer_eval_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).eval()
     else:
         optimizer_train_if_needed = lambda: None
         optimizer_eval_if_needed = lambda: None

@@ -12,6 +12,7 @@ from tqdm import tqdm
 
 import torch
 from library.device_utils import init_ipex, clean_memory_on_device
 
 init_ipex()
 
 from torch.nn.parallel import DistributedDataParallel as DDP

@@ -255,16 +256,14 @@ def train(args):
 
     # apparently accelerator handles this for us somehow
     use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
-    unet, network, optimizer, train_dataloader = accelerator.prepare(
-        unet, network, optimizer, train_dataloader
-    )
+    unet, network, optimizer, train_dataloader = accelerator.prepare(unet, network, optimizer, train_dataloader)
     if not use_schedule_free_optimizer:
         lr_scheduler = accelerator.prepare(lr_scheduler)
 
     # make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
     if use_schedule_free_optimizer:
-        optimizer_train_if_needed = lambda: optimizer.train()
-        optimizer_eval_if_needed = lambda: optimizer.eval()
+        optimizer_train_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).train()
+        optimizer_eval_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).eval()
     else:
         optimizer_train_if_needed = lambda: None
         optimizer_eval_if_needed = lambda: None
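
A small demonstration of why the hasattr check matters after this prepare() call: accelerator.prepare() returns an AcceleratedOptimizer that forwards step() and zero_grad() to the wrapped optimizer but, at the time of this fix, not the schedule-free train()/eval() methods. A sketch, assuming accelerate is installed and run as a plain script:

import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model, optimizer = accelerator.prepare(model, optimizer)

print(type(optimizer).__name__)            # AcceleratedOptimizer
print(type(optimizer.optimizer).__name__)  # AdamW, the original optimizer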

@@ -419,7 +418,9 @@ def train(args):
 
                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
+                    args, noise_scheduler, latents
+                )
 
                 noisy_latents = noisy_latents.to(weight_dtype)  # TODO check why noisy_latents is not weight_dtype

@@ -439,7 +440,9 @@ def train(args):
                 else:
                     target = noise
 
-                loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
+                loss = train_util.conditional_loss(
+                    noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+                )
                 loss = loss.mean([1, 2, 3])
 
                 loss_weights = batch["loss_weights"]  # weight for each sample
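
The call above passes reduction, loss_type, and huber_c through to train_util.conditional_loss. As a rough sketch only (the real sd-scripts implementation may differ in detail), a loss switch with that signature could look like the following; the exact pseudo-Huber parameterization is an assumption.

import torch
import torch.nn.functional as F

def conditional_loss(pred, target, reduction="none", loss_type="l2", huber_c=0.1):
    # hypothetical re-implementation for illustration only
    if loss_type == "l2":
        return F.mse_loss(pred, target, reduction=reduction)
    if loss_type == "huber":
        # pseudo-Huber: smooth near zero, linear for large errors, scaled by huber_c
        loss = 2 * huber_c * (torch.sqrt((pred - target) ** 2 + huber_c**2) - huber_c)
        return loss.mean() if reduction == "mean" else loss
    raise NotImplementedError(f"unknown loss_type: {loss_type}")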

@@ -284,8 +284,8 @@ def train(args):
 
     # make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
     if use_schedule_free_optimizer:
-        optimizer_train_if_needed = lambda: optimizer.train()
-        optimizer_eval_if_needed = lambda: optimizer.eval()
+        optimizer_train_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).train()
+        optimizer_eval_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).eval()
     else:
         optimizer_train_if_needed = lambda: None
         optimizer_eval_if_needed = lambda: None

@@ -237,9 +237,7 @@ def train(args):
 
     else:
         if train_text_encoder:
-            unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(
-                unet, text_encoder, optimizer, train_dataloader
-            )
+            unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(unet, text_encoder, optimizer, train_dataloader)
             training_models = [unet, text_encoder]
         else:
             unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)

@@ -249,8 +247,8 @@ def train(args):
 
     # make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
     if use_schedule_free_optimizer:
-        optimizer_train_if_needed = lambda: optimizer.train()
-        optimizer_eval_if_needed = lambda: optimizer.eval()
+        optimizer_train_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).train()
+        optimizer_eval_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).eval()
     else:
         optimizer_train_if_needed = lambda: None
         optimizer_eval_if_needed = lambda: None

@@ -446,8 +446,8 @@ class NetworkTrainer:
 
         # make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
         if use_schedule_free_optimizer:
-            optimizer_train_if_needed = lambda: optimizer.train()
-            optimizer_eval_if_needed = lambda: optimizer.eval()
+            optimizer_train_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).train()
+            optimizer_eval_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).eval()
         else:
             optimizer_train_if_needed = lambda: None
             optimizer_eval_if_needed = lambda: None

@@ -432,8 +432,8 @@ class TextualInversionTrainer:
 
         # make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
        if use_schedule_free_optimizer:
-            optimizer_train_if_needed = lambda: optimizer.train()
-            optimizer_eval_if_needed = lambda: optimizer.eval()
+            optimizer_train_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).train()
+            optimizer_eval_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).eval()
        else:
            optimizer_train_if_needed = lambda: None
            optimizer_eval_if_needed = lambda: None

@@ -336,16 +336,14 @@ def train(args):
 
     # apparently accelerator handles this for us somehow
     use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
-    text_encoder, optimizer, train_dataloader = accelerator.prepare(
-        text_encoder, optimizer, train_dataloader
-    )
+    text_encoder, optimizer, train_dataloader = accelerator.prepare(text_encoder, optimizer, train_dataloader)
     if not use_schedule_free_optimizer:
         lr_scheduler = accelerator.prepare(lr_scheduler)
 
     # make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
     if use_schedule_free_optimizer:
-        optimizer_train_if_needed = lambda: optimizer.train()
-        optimizer_eval_if_needed = lambda: optimizer.eval()
+        optimizer_train_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).train()
+        optimizer_eval_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).eval()
     else:
         optimizer_train_if_needed = lambda: None
         optimizer_eval_if_needed = lambda: None

@@ -473,7 +471,9 @@ def train(args):
 
                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
+                    args, noise_scheduler, latents
+                )
 
                 # Predict the noise residual
                 with accelerator.autocast():

@@ -485,7 +485,9 @@ def train(args):
                 else:
                     target = noise
 
-                loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
+                loss = train_util.conditional_loss(
+                    noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+                )
                 if args.masked_loss:
                     loss = apply_masked_loss(loss, batch)
                 loss = loss.mean([1, 2, 3])