fix to work with optimizer wrapped by accelerate

Kohya S
2024-05-06 13:17:14 +09:00
parent 5fe9ded188
commit c1ef6dcabc
9 changed files with 35 additions and 32 deletions

View File

@@ -271,8 +271,8 @@ def train(args):
# make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
if use_schedule_free_optimizer:
-optimizer_train_if_needed = lambda: optimizer.train()
-optimizer_eval_if_needed = lambda: optimizer.eval()
+optimizer_train_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).train()
+optimizer_eval_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).eval()
else:
optimizer_train_if_needed = lambda: None
optimizer_eval_if_needed = lambda: None

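Every trainer touched by this commit makes the same change: once the optimizer has gone through accelerator.prepare(), it may be accelerate's AcceleratedOptimizer wrapper rather than the schedule-free optimizer itself, and the wrapper keeps the original instance in its optimizer attribute, so the train()/eval() calls that schedule-free optimizers require have to reach the inner object. A minimal sketch of the idea; the unwrap_optimizer helper name is illustrative and not part of the codebase, and optimizer stands for the prepared optimizer in the surrounding script:

    def unwrap_optimizer(opt):
        # accelerate's AcceleratedOptimizer keeps the original optimizer in `.optimizer`;
        # a bare (unprepared) optimizer has no such attribute and is returned unchanged.
        return opt.optimizer if hasattr(opt, "optimizer") else opt

    # equivalent to the lambdas introduced by this commit
    optimizer_train_if_needed = lambda: unwrap_optimizer(optimizer).train()
    optimizer_eval_if_needed = lambda: unwrap_optimizer(optimizer).eval()
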
View File

@@ -435,8 +435,8 @@ def train(args):
# make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
if use_schedule_free_optimizer:
-optimizer_train_if_needed = lambda: optimizer.train()
-optimizer_eval_if_needed = lambda: optimizer.eval()
+optimizer_train_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).train()
+optimizer_eval_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).eval()
else:
optimizer_train_if_needed = lambda: None
optimizer_eval_if_needed = lambda: None
@@ -644,7 +644,7 @@ def train(args):
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
-lr_scheduler.step()  # if schedule-free optimizer is used, this is a no-op
+lr_scheduler.step()  # if schedule-free optimizer is used, this is a no-op
optimizer.zero_grad(set_to_none=True)
optimizer_eval_if_needed()

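The hunk above keeps the per-step order unchanged: gradients are clipped, the optimizer steps, lr_scheduler.step() stays in place (a no-op with a schedule-free optimizer), gradients are zeroed, and optimizer_eval_if_needed() then switches a schedule-free optimizer back to eval mode. A rough sketch of that order in context; the names follow the diff, while the loop scaffolding and compute_loss are illustrative placeholders, not the actual training loop:

    for step, batch in enumerate(train_dataloader):
        optimizer_train_if_needed()  # schedule-free: put the optimizer into train mode before updates
        loss = compute_loss(batch)   # placeholder for the actual forward pass and loss computation
        accelerator.backward(loss)
        accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
        optimizer.step()
        lr_scheduler.step()          # no-op when a schedule-free optimizer is used
        optimizer.zero_grad(set_to_none=True)
        optimizer_eval_if_needed()   # schedule-free: back to eval mode, e.g. before sampling or saving
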
View File

@@ -294,8 +294,8 @@ def train(args):
# make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
if use_schedule_free_optimizer:
-optimizer_train_if_needed = lambda: optimizer.train()
-optimizer_eval_if_needed = lambda: optimizer.eval()
+optimizer_train_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).train()
+optimizer_eval_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).eval()
else:
optimizer_train_if_needed = lambda: None
optimizer_eval_if_needed = lambda: None

View File

@@ -12,6 +12,7 @@ from tqdm import tqdm
import torch
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
+from torch.nn.parallel import DistributedDataParallel as DDP
@@ -255,16 +256,14 @@ def train(args):
# apparently accelerator takes care of this for us
use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
-unet, network, optimizer, train_dataloader = accelerator.prepare(
-    unet, network, optimizer, train_dataloader
-)
+unet, network, optimizer, train_dataloader = accelerator.prepare(unet, network, optimizer, train_dataloader)
if not use_schedule_free_optimizer:
lr_scheduler = accelerator.prepare(lr_scheduler)
# make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
if use_schedule_free_optimizer:
-optimizer_train_if_needed = lambda: optimizer.train()
-optimizer_eval_if_needed = lambda: optimizer.eval()
+optimizer_train_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).train()
+optimizer_eval_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).eval()
else:
optimizer_train_if_needed = lambda: None
optimizer_eval_if_needed = lambda: None
@@ -419,7 +418,9 @@ def train(args):
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
-noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
+    args, noise_scheduler, latents
+)
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
@@ -439,7 +440,9 @@ def train(args):
else:
target = noise
-loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
+loss = train_util.conditional_loss(
+    noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"]  # weight for each sample

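Putting this file's hunks together, the post-commit setup reads roughly as follows; this is a consolidated sketch rather than a verbatim excerpt, with names taken from the diff:

    use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
    unet, network, optimizer, train_dataloader = accelerator.prepare(unet, network, optimizer, train_dataloader)
    if not use_schedule_free_optimizer:
        # only a real scheduler needs preparing; with a schedule-free optimizer the scheduler step is a no-op
        lr_scheduler = accelerator.prepare(lr_scheduler)

    if use_schedule_free_optimizer:
        # optimizer may now be accelerate's wrapper, so reach the inner object when present
        optimizer_train_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).train()
        optimizer_eval_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).eval()
    else:
        optimizer_train_if_needed = lambda: None
        optimizer_eval_if_needed = lambda: None
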
View File

@@ -284,8 +284,8 @@ def train(args):
# make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
if use_schedule_free_optimizer:
-optimizer_train_if_needed = lambda: optimizer.train()
-optimizer_eval_if_needed = lambda: optimizer.eval()
+optimizer_train_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).train()
+optimizer_eval_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).eval()
else:
optimizer_train_if_needed = lambda: None
optimizer_eval_if_needed = lambda: None

View File

@@ -237,9 +237,7 @@ def train(args):
else:
if train_text_encoder:
-unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(
-    unet, text_encoder, optimizer, train_dataloader
-)
+unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(unet, text_encoder, optimizer, train_dataloader)
training_models = [unet, text_encoder]
else:
unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
@@ -249,8 +247,8 @@ def train(args):
# make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
if use_schedule_free_optimizer:
-optimizer_train_if_needed = lambda: optimizer.train()
-optimizer_eval_if_needed = lambda: optimizer.eval()
+optimizer_train_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).train()
+optimizer_eval_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).eval()
else:
optimizer_train_if_needed = lambda: None
optimizer_eval_if_needed = lambda: None

View File

@@ -446,8 +446,8 @@ class NetworkTrainer:
# make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
if use_schedule_free_optimizer:
-optimizer_train_if_needed = lambda: optimizer.train()
-optimizer_eval_if_needed = lambda: optimizer.eval()
+optimizer_train_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).train()
+optimizer_eval_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).eval()
else:
optimizer_train_if_needed = lambda: None
optimizer_eval_if_needed = lambda: None

View File

@@ -432,8 +432,8 @@ class TextualInversionTrainer:
# make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
if use_schedule_free_optimizer:
-optimizer_train_if_needed = lambda: optimizer.train()
-optimizer_eval_if_needed = lambda: optimizer.eval()
+optimizer_train_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).train()
+optimizer_eval_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).eval()
else:
optimizer_train_if_needed = lambda: None
optimizer_eval_if_needed = lambda: None

View File

@@ -336,16 +336,14 @@ def train(args):
# apparently accelerator takes care of this for us
use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
-text_encoder, optimizer, train_dataloader = accelerator.prepare(
-    text_encoder, optimizer, train_dataloader
-)
+text_encoder, optimizer, train_dataloader = accelerator.prepare(text_encoder, optimizer, train_dataloader)
if not use_schedule_free_optimizer:
lr_scheduler = accelerator.prepare(lr_scheduler)
# make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
if use_schedule_free_optimizer:
-optimizer_train_if_needed = lambda: optimizer.train()
-optimizer_eval_if_needed = lambda: optimizer.eval()
+optimizer_train_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).train()
+optimizer_eval_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).eval()
else:
optimizer_train_if_needed = lambda: None
optimizer_eval_if_needed = lambda: None
@@ -473,7 +471,9 @@ def train(args):
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
-noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
+    args, noise_scheduler, latents
+)
# Predict the noise residual
with accelerator.autocast():
@@ -485,7 +485,9 @@ def train(args):
else:
target = noise
-loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
+loss = train_util.conditional_loss(
+    noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+)
if args.masked_loss:
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])