mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-06 13:47:06 +00:00)
fix: refactor huber-loss calculation in multiple training scripts
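All of the hunks below apply the same refactor: the per-timestep Huber threshold is now computed once per batch by a new train_util.get_huber_threshold_if_needed helper and passed into train_util.conditional_loss together with args.loss_type, instead of conditional_loss receiving the whole args namespace plus timesteps and the noise scheduler and deriving the threshold itself. In outline (paraphrased from the hunks below, not an exact excerpt):

# before: conditional_loss derived huber_c internally from args/timesteps/scheduler
loss = train_util.conditional_loss(args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler)

# after: the threshold is computed once up front (None unless loss_type is huber/smooth_l1)
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)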
fine_tune.py (13 lines changed)
@@ -380,9 +380,7 @@ def train(args):

         # Sample noise, sample a random timestep for each image, and add noise to the latents,
         # with noise offset and/or multires noise if specified
-        noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
-            args, noise_scheduler, latents
-        )
+        noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)

         # Predict the noise residual
         with accelerator.autocast():
@@ -394,11 +392,10 @@ def train(args):
         else:
             target = noise

+        huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
         if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
             # do not mean over batch dimension for snr weight or scale v-pred loss
-            loss = train_util.conditional_loss(
-                args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
-            )
+            loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
             loss = loss.mean([1, 2, 3])

             if args.min_snr_gamma:
@@ -410,9 +407,7 @@ def train(args):

             loss = loss.mean()  # mean over batch dimension
         else:
-            loss = train_util.conditional_loss(
-                args, noise_pred.float(), target.float(), timesteps, "mean", noise_scheduler
-            )
+            loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "mean", huber_c)

         accelerator.backward(loss)
         if accelerator.sync_gradients and args.max_grad_norm != 0.0:
@@ -666,9 +666,8 @@ def train(args):
             target = noise - latents

             # calculate loss
-            loss = train_util.conditional_loss(
-                args, model_pred.float(), target.float(), timesteps, "none", noise_scheduler
-            )
+            huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
+            loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c)
             if weighting is not None:
                 loss = loss * weighting
             if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
@@ -5869,7 +5869,10 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
     return noise, noisy_latents, timesteps


-def get_huber_threshold(args, timesteps: torch.Tensor, noise_scheduler) -> torch.Tensor:
+def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler) -> Optional[torch.Tensor]:
+    if not (args.loss_type == "huber" or args.loss_type == "smooth_l1"):
+        return None
+
     b_size = timesteps.shape[0]
     if args.huber_schedule == "exponential":
         alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
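For reference, a minimal standalone sketch of what the exponential schedule above computes. The exp(-alpha * t) form of the per-timestep threshold is an assumption extrapolated from the alpha definition shown in the hunk, and the concrete numbers are illustrative only, not values from the repo:

import math

import torch

huber_c = 0.1            # illustrative --huber_c value
num_train_timesteps = 1000

# alpha as defined in the hunk above
alpha = -math.log(huber_c) / num_train_timesteps

# assumed threshold form: decays from 1.0 at t=0 down to huber_c at t=T
timesteps = torch.tensor([0.0, 500.0, 1000.0])
print(torch.exp(-alpha * timesteps))  # tensor([1.0000, 0.3162, 0.1000])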
@@ -5890,22 +5893,20 @@ def get_huber_threshold(args, timesteps: torch.Tensor, noise_scheduler) -> torch.Tensor:


 def conditional_loss(
-    args, model_pred: torch.Tensor, target: torch.Tensor, timesteps: torch.Tensor, reduction: str, noise_scheduler
+    model_pred: torch.Tensor, target: torch.Tensor, loss_type: str, reduction: str, huber_c: Optional[torch.Tensor] = None
 ):
-    if args.loss_type == "l2":
+    if loss_type == "l2":
         loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)
-    elif args.loss_type == "l1":
+    elif loss_type == "l1":
         loss = torch.nn.functional.l1_loss(model_pred, target, reduction=reduction)
-    elif args.loss_type == "huber":
-        huber_c = get_huber_threshold(args, timesteps, noise_scheduler)
+    elif loss_type == "huber":
         huber_c = huber_c.view(-1, 1, 1, 1)
         loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
         if reduction == "mean":
             loss = torch.mean(loss)
         elif reduction == "sum":
             loss = torch.sum(loss)
-    elif args.loss_type == "smooth_l1":
-        huber_c = get_huber_threshold(args, timesteps, noise_scheduler)
+    elif loss_type == "smooth_l1":
         huber_c = huber_c.view(-1, 1, 1, 1)
         loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
         if reduction == "mean":
@@ -5913,7 +5914,7 @@ def conditional_loss(
         elif reduction == "sum":
             loss = torch.sum(loss)
     else:
-        raise NotImplementedError(f"Unsupported Loss Type: {args.loss_type}")
+        raise NotImplementedError(f"Unsupported Loss Type: {loss_type}")
     return loss
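As a sanity check on the formula above (a standalone sketch, not code from the commit): the smooth pseudo-Huber form 2 * c * (sqrt(d**2 + c**2) - c) behaves like the squared error d**2 for residuals much smaller than c and like the scaled absolute error 2 * c * |d| for residuals much larger than c, which is what makes it a differentiable Huber variant:

import math

c = 0.1  # illustrative threshold
for d in (0.01, 1.0):
    pseudo_huber = 2 * c * (math.sqrt(d**2 + c**2) - c)
    print(f"d={d:.2f}  loss={pseudo_huber:.5f}  d^2={d**2:.5f}  2c|d|={2 * c * abs(d):.5f}")
# d=0.01  loss=0.00010  d^2=0.00010  2c|d|=0.00200   (quadratic regime)
# d=1.00  loss=0.18100  d^2=1.00000  2c|d|=0.20000   (linear regime)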
@@ -5923,7 +5924,7 @@ def append_lr_to_logs(logs, lr_scheduler, optimizer_type, including_unet=True):
         names.append("unet")
     names.append("text_encoder1")
     names.append("text_encoder2")
-    names.append("text_encoder3") # SD3
+    names.append("text_encoder3")  # SD3

     append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names)
@@ -844,7 +844,8 @@ def train(args):
             # 1,
             # )
             # calculate loss
-            loss = train_util.conditional_loss(args, model_pred.float(), target.float(), timesteps, "none", dummy_scheduler)
+            huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, dummy_scheduler)
+            loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c)
             if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
                 loss = apply_masked_loss(loss, batch)
             loss = loss.mean([1, 2, 3])
@@ -695,9 +695,7 @@ def train(args):

         # Sample noise, sample a random timestep for each image, and add noise to the latents,
         # with noise offset and/or multires noise if specified
-        noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
-            args, noise_scheduler, latents
-        )
+        noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)

         noisy_latents = noisy_latents.to(weight_dtype)  # TODO check why noisy_latents is not weight_dtype
@@ -711,6 +709,7 @@ def train(args):
         else:
             target = noise

+        huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
         if (
             args.min_snr_gamma
             or args.scale_v_pred_loss_like_noise_pred
@@ -719,9 +718,7 @@ def train(args):
             or args.masked_loss
         ):
             # do not mean over batch dimension for snr weight or scale v-pred loss
-            loss = train_util.conditional_loss(
-                args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
-            )
+            loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
             if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
                 loss = apply_masked_loss(loss, batch)
             loss = loss.mean([1, 2, 3])
@@ -737,9 +734,7 @@ def train(args):

             loss = loss.mean()  # mean over batch dimension
         else:
-            loss = train_util.conditional_loss(
-                args, noise_pred.float(), target.float(), timesteps, "mean", noise_scheduler
-            )
+            loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "mean", huber_c)

         accelerator.backward(loss)
@@ -512,9 +512,7 @@ def train(args):

         # Sample noise, sample a random timestep for each image, and add noise to the latents,
         # with noise offset and/or multires noise if specified
-        noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
-            args, noise_scheduler, latents
-        )
+        noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)

         controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype)
@@ -533,9 +531,8 @@ def train(args):
         else:
             target = noise

-        loss = train_util.conditional_loss(
-            args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
-        )
+        huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
+        loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
         loss = loss.mean([1, 2, 3])

         loss_weights = batch["loss_weights"]  # weight for each sample
@@ -463,9 +463,7 @@ def train(args):

         # Sample noise, sample a random timestep for each image, and add noise to the latents,
         # with noise offset and/or multires noise if specified
-        noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
-            args, noise_scheduler, latents
-        )
+        noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)

         noisy_latents = noisy_latents.to(weight_dtype)  # TODO check why noisy_latents is not weight_dtype
@@ -484,9 +482,8 @@ def train(args):
         else:
             target = noise

-        loss = train_util.conditional_loss(
-            args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
-        )
+        huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
+        loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
         loss = loss.mean([1, 2, 3])

         loss_weights = batch["loss_weights"]  # weight for each sample
@@ -12,6 +12,7 @@ from tqdm import tqdm

 import torch
 from library.device_utils import init_ipex, clean_memory_on_device

 init_ipex()

 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -324,7 +325,9 @@ def train(args):
     if args.log_tracker_config is not None:
         init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers(
-            "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
+            "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name,
+            config=train_util.get_sanitized_config_or_none(args),
+            init_kwargs=init_kwargs,
         )

     loss_recorder = train_util.LossRecorder()
@@ -426,9 +429,8 @@ def train(args):
         else:
             target = noise

-        loss = train_util.conditional_loss(
-            args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
-        )
+        huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
+        loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
         loss = loss.mean([1, 2, 3])

         loss_weights = batch["loss_weights"]  # weight for each sample
@@ -307,10 +307,12 @@ def train(args):

     if args.fused_backward_pass:
         import library.adafactor_fused

         library.adafactor_fused.patch_adafactor_fused(optimizer)
         for param_group in optimizer.param_groups:
             for parameter in param_group["params"]:
                 if parameter.requires_grad:

                     def __grad_hook(tensor: torch.Tensor, param_group=param_group):
                         if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                             accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
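The fused-backward block above appears to register a per-parameter gradient hook so each parameter can be clipped and stepped as soon as its gradient is ready, instead of after the full backward pass. A minimal standalone sketch of that pattern (assumptions: torch >= 2.1 for register_post_accumulate_grad_hook, and an optimizer patched, as with patch_adafactor_fused above, to expose a step_param(param, group) method; that name is hypothetical here, and this is not the commit's exact code):

import torch

def attach_fused_backward_hooks(optimizer, max_grad_norm: float = 1.0):
    for param_group in optimizer.param_groups:
        for parameter in param_group["params"]:
            if not parameter.requires_grad:
                continue

            def __grad_hook(tensor: torch.Tensor, param_group=param_group):
                # clip, step, and free this parameter's grad right after it
                # finishes accumulating, rather than after the whole backward
                if max_grad_norm != 0.0:
                    torch.nn.utils.clip_grad_norm_(tensor, max_grad_norm)
                optimizer.step_param(tensor, param_group)  # hypothetical fused step
                tensor.grad = None

            # torch >= 2.1: fires once the parameter's grad is fully accumulated
            parameter.register_post_accumulate_grad_hook(__grad_hook)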
@@ -464,9 +466,7 @@ def train(args):
         )

         # Sample a random timestep for each image
-        timesteps = train_util.get_timesteps(
-            0, noise_scheduler.config.num_train_timesteps, b_size, latents.device
-        )
+        timesteps = train_util.get_timesteps(0, noise_scheduler.config.num_train_timesteps, b_size, latents.device)

         # Add noise to the latents according to the noise magnitude at each timestep
         # (this is the forward diffusion process)
@@ -498,9 +498,8 @@ def train(args):
         else:
             target = noise

-        loss = train_util.conditional_loss(
-            args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
-        )
+        huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
+        loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
         loss = loss.mean([1, 2, 3])

         loss_weights = batch["loss_weights"]  # weight for each sample
@@ -370,9 +370,7 @@ def train(args):

         # Sample noise, sample a random timestep for each image, and add noise to the latents,
         # with noise offset and/or multires noise if specified
-        noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
-            args, noise_scheduler, latents
-        )
+        noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)

         # Predict the noise residual
         with accelerator.autocast():
@@ -384,9 +382,8 @@ def train(args):
         else:
             target = noise

-        loss = train_util.conditional_loss(
-            args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
-        )
+        huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
+        loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
         if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
             loss = apply_masked_loss(loss, batch)
         loss = loss.mean([1, 2, 3])
@@ -1207,9 +1207,8 @@ class NetworkTrainer:
                     train_unet,
                 )

-                loss = train_util.conditional_loss(
-                    args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
-                )
+                huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
+                loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
                 if weighting is not None:
                     loss = loss * weighting
                 if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
@@ -601,9 +601,8 @@ class TextualInversionTrainer:
                 else:
                     target = noise

-                loss = train_util.conditional_loss(
-                    args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
-                )
+                huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
+                loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
                 if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
                     loss = apply_masked_loss(loss, batch)
                 loss = loss.mean([1, 2, 3])
@@ -407,7 +407,9 @@ def train(args):
     if args.log_tracker_config is not None:
         init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers(
-            "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
+            "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name,
+            config=train_util.get_sanitized_config_or_none(args),
+            init_kwargs=init_kwargs,
         )

     # function for saving/removing
@@ -473,9 +475,8 @@ def train(args):
         else:
             target = noise

-        loss = train_util.conditional_loss(
-            args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
-        )
+        huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
+        loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
         if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
             loss = apply_masked_loss(loss, batch)
         loss = loss.mean([1, 2, 3])
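Finally, a self-contained sketch of how the new calling convention fits together shape-wise: get_huber_threshold_if_needed returns one threshold per sample (shape [B], or None for l1/l2), conditional_loss views it to [B, 1, 1, 1] so it broadcasts over a [B, C, H, W] prediction, and the training loops above then reduce per sample with mean([1, 2, 3]). Stand-in tensors, not the scripts' real inputs:

import torch

B, C, H, W = 4, 4, 64, 64
noise_pred = torch.randn(B, C, H, W)
target = torch.randn(B, C, H, W)
huber_c = torch.full((B,), 0.1)  # one threshold per sample

# what conditional_loss does for loss_type="huber", reduction="none"
c = huber_c.view(-1, 1, 1, 1)
loss = 2 * c * (torch.sqrt((noise_pred - target) ** 2 + c**2) - c)

# what the callers do next: per-sample loss, then batch mean
loss = loss.mean([1, 2, 3])
print(loss.shape)   # torch.Size([4])
print(loss.mean())  # scalar, ready for accelerator.backward(...)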