Implement pseudo Huber loss for Flux and SD3

This commit is contained in:
recris
2024-11-27 18:11:51 +00:00
parent 2a61fc0784
commit 420a180d93
15 changed files with 76 additions and 61 deletions

View File

@@ -380,7 +380,7 @@ def train(args):
# Sample noise, sample a random timestep for each image, and add noise to the latents, # Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified # with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents args, noise_scheduler, latents
) )
@@ -397,7 +397,7 @@ def train(args):
if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss: if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
# do not mean over batch dimension for snr weight or scale v-pred loss # do not mean over batch dimension for snr weight or scale v-pred loss
loss = train_util.conditional_loss( loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
) )
loss = loss.mean([1, 2, 3]) loss = loss.mean([1, 2, 3])
@@ -411,7 +411,7 @@ def train(args):
loss = loss.mean() # mean over batch dimension loss = loss.mean() # mean over batch dimension
else: else:
loss = train_util.conditional_loss( loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
) )
accelerator.backward(loss) accelerator.backward(loss)

View File

@@ -667,7 +667,7 @@ def train(args):
# calculate loss # calculate loss
loss = train_util.conditional_loss( loss = train_util.conditional_loss(
model_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=None args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
) )
if weighting is not None: if weighting is not None:
loss = loss * weighting loss = loss * weighting

View File

@@ -468,7 +468,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
) )
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype) target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
return model_pred, target, timesteps, None, weighting return model_pred, target, timesteps, weighting
def post_process_loss(self, loss, args, timesteps, noise_scheduler): def post_process_loss(self, loss, args, timesteps, noise_scheduler):
return loss return loss

View File

@@ -3905,7 +3905,14 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
"--huber_c", "--huber_c",
type=float, type=float,
default=0.1, default=0.1,
help="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1", help="The Huber loss decay parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1",
)
parser.add_argument(
"--huber_scale",
type=float,
default=1.0,
help="The Huber loss scale parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 1.0 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1",
) )
parser.add_argument( parser.add_argument(
@@ -5821,29 +5828,10 @@ def save_sd_model_on_train_end_common(
huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True) huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)
def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device): def get_timesteps(min_timestep, max_timestep, b_size, device):
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu") timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device)
timesteps = timesteps.long()
if args.loss_type == "huber" or args.loss_type == "smooth_l1": return timesteps
if args.huber_schedule == "exponential":
alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
huber_c = torch.exp(-alpha * timesteps)
elif args.huber_schedule == "snr":
alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps)
sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
elif args.huber_schedule == "constant":
huber_c = torch.full((b_size,), args.huber_c)
else:
raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")
huber_c = huber_c.to(device)
elif args.loss_type == "l2":
huber_c = None # may be anything, as it's not used
else:
raise NotImplementedError(f"Unknown loss type {args.loss_type}")
timesteps = timesteps.long().to(device)
return timesteps, huber_c
def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
@@ -5865,7 +5853,7 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
min_timestep = 0 if args.min_timestep is None else args.min_timestep min_timestep = 0 if args.min_timestep is None else args.min_timestep
max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep
timesteps, huber_c = get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, latents.device) timesteps = get_timesteps(min_timestep, max_timestep, b_size, latents.device)
# Add noise to the latents according to the noise magnitude at each timestep # Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process) # (this is the forward diffusion process)
@@ -5878,24 +5866,46 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
else: else:
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
return noise, noisy_latents, timesteps, huber_c return noise, noisy_latents, timesteps
def get_huber_threshold(args, timesteps: torch.Tensor, noise_scheduler) -> torch.Tensor:
b_size = timesteps.shape[0]
if args.huber_schedule == "exponential":
alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
result = torch.exp(-alpha * timesteps) * args.huber_scale
elif args.huber_schedule == "snr":
if not hasattr(noise_scheduler, 'alphas_cumprod'):
raise NotImplementedError(f"Huber schedule 'snr' is not supported with the current model.")
alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu())
sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
result = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
result = result.to(timesteps.device)
elif args.huber_schedule == "constant":
result = torch.full((b_size,), args.huber_c * args.huber_scale, device=timesteps.device)
else:
raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")
return result
def conditional_loss( def conditional_loss(
model_pred: torch.Tensor, target: torch.Tensor, reduction: str, loss_type: str, huber_c: Optional[torch.Tensor] args, model_pred: torch.Tensor, target: torch.Tensor, timesteps: torch.Tensor, reduction: str, noise_scheduler
): ):
if loss_type == "l2": if args.loss_type == "l2":
loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction) loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)
elif loss_type == "l1": elif args.loss_type == "l1":
loss = torch.nn.functional.l1_loss(model_pred, target, reduction=reduction) loss = torch.nn.functional.l1_loss(model_pred, target, reduction=reduction)
elif loss_type == "huber": elif args.loss_type == "huber":
huber_c = get_huber_threshold(args, timesteps, noise_scheduler)
huber_c = huber_c.view(-1, 1, 1, 1) huber_c = huber_c.view(-1, 1, 1, 1)
loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
if reduction == "mean": if reduction == "mean":
loss = torch.mean(loss) loss = torch.mean(loss)
elif reduction == "sum": elif reduction == "sum":
loss = torch.sum(loss) loss = torch.sum(loss)
elif loss_type == "smooth_l1": elif args.loss_type == "smooth_l1":
huber_c = get_huber_threshold(args, timesteps, noise_scheduler)
huber_c = huber_c.view(-1, 1, 1, 1) huber_c = huber_c.view(-1, 1, 1, 1)
loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
if reduction == "mean": if reduction == "mean":
@@ -5903,7 +5913,7 @@ def conditional_loss(
elif reduction == "sum": elif reduction == "sum":
loss = torch.sum(loss) loss = torch.sum(loss)
else: else:
raise NotImplementedError(f"Unsupported Loss Type {loss_type}") raise NotImplementedError(f"Unsupported Loss Type: {args.loss_type}")
return loss return loss

View File

@@ -845,7 +845,7 @@ def train(args):
# ) # )
# calculate loss # calculate loss
loss = train_util.conditional_loss( loss = train_util.conditional_loss(
model_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=None args, model_pred.float(), target.float(), timesteps, "none", noise_scheduler
) )
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch) loss = apply_masked_loss(loss, batch)

View File

@@ -378,7 +378,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype) target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
return model_pred, target, timesteps, None, weighting return model_pred, target, timesteps, weighting
def post_process_loss(self, loss, args, timesteps, noise_scheduler): def post_process_loss(self, loss, args, timesteps, noise_scheduler):
return loss return loss

View File

@@ -695,7 +695,7 @@ def train(args):
# Sample noise, sample a random timestep for each image, and add noise to the latents, # Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified # with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents args, noise_scheduler, latents
) )
@@ -720,7 +720,7 @@ def train(args):
): ):
# do not mean over batch dimension for snr weight or scale v-pred loss # do not mean over batch dimension for snr weight or scale v-pred loss
loss = train_util.conditional_loss( loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
) )
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch) loss = apply_masked_loss(loss, batch)
@@ -738,7 +738,7 @@ def train(args):
loss = loss.mean() # mean over batch dimension loss = loss.mean() # mean over batch dimension
else: else:
loss = train_util.conditional_loss( loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
) )
accelerator.backward(loss) accelerator.backward(loss)

View File

@@ -512,7 +512,7 @@ def train(args):
# Sample noise, sample a random timestep for each image, and add noise to the latents, # Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified # with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents args, noise_scheduler, latents
) )
@@ -534,7 +534,7 @@ def train(args):
target = noise target = noise
loss = train_util.conditional_loss( loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
) )
loss = loss.mean([1, 2, 3]) loss = loss.mean([1, 2, 3])

View File

@@ -463,7 +463,7 @@ def train(args):
# Sample noise, sample a random timestep for each image, and add noise to the latents, # Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified # with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents args, noise_scheduler, latents
) )
@@ -485,7 +485,7 @@ def train(args):
target = noise target = noise
loss = train_util.conditional_loss( loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
) )
loss = loss.mean([1, 2, 3]) loss = loss.mean([1, 2, 3])

View File

@@ -406,7 +406,7 @@ def train(args):
# Sample noise, sample a random timestep for each image, and add noise to the latents, # Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified # with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
@@ -426,7 +426,9 @@ def train(args):
else: else:
target = noise target = noise
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) loss = train_util.conditional_loss(
args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
)
loss = loss.mean([1, 2, 3]) loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight loss_weights = batch["loss_weights"] # 各sampleごとのweight

View File

@@ -464,8 +464,8 @@ def train(args):
) )
# Sample a random timestep for each image # Sample a random timestep for each image
timesteps, huber_c = train_util.get_timesteps_and_huber_c( timesteps = train_util.get_timesteps(
args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler, b_size, latents.device 0, noise_scheduler.config.num_train_timesteps, b_size, latents.device
) )
# Add noise to the latents according to the noise magnitude at each timestep # Add noise to the latents according to the noise magnitude at each timestep
@@ -499,7 +499,7 @@ def train(args):
target = noise target = noise
loss = train_util.conditional_loss( loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
) )
loss = loss.mean([1, 2, 3]) loss = loss.mean([1, 2, 3])

View File

@@ -370,7 +370,7 @@ def train(args):
# Sample noise, sample a random timestep for each image, and add noise to the latents, # Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified # with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents args, noise_scheduler, latents
) )
@@ -385,7 +385,7 @@ def train(args):
target = noise target = noise
loss = train_util.conditional_loss( loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
) )
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch) loss = apply_masked_loss(loss, batch)

View File

@@ -192,7 +192,7 @@ class NetworkTrainer:
): ):
# Sample noise, sample a random timestep for each image, and add noise to the latents, # Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified # with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
# ensure the hidden state will require grad # ensure the hidden state will require grad
if args.gradient_checkpointing: if args.gradient_checkpointing:
@@ -244,7 +244,7 @@ class NetworkTrainer:
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype) target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)
return noise_pred, target, timesteps, huber_c, None return noise_pred, target, timesteps, None
def post_process_loss(self, loss, args, timesteps, noise_scheduler): def post_process_loss(self, loss, args, timesteps, noise_scheduler):
if args.min_snr_gamma: if args.min_snr_gamma:
@@ -806,6 +806,7 @@ class NetworkTrainer:
"ss_ip_noise_gamma_random_strength": args.ip_noise_gamma_random_strength, "ss_ip_noise_gamma_random_strength": args.ip_noise_gamma_random_strength,
"ss_loss_type": args.loss_type, "ss_loss_type": args.loss_type,
"ss_huber_schedule": args.huber_schedule, "ss_huber_schedule": args.huber_schedule,
"ss_huber_scale": args.huber_scale,
"ss_huber_c": args.huber_c, "ss_huber_c": args.huber_c,
"ss_fp8_base": bool(args.fp8_base), "ss_fp8_base": bool(args.fp8_base),
"ss_fp8_base_unet": bool(args.fp8_base_unet), "ss_fp8_base_unet": bool(args.fp8_base_unet),
@@ -1193,7 +1194,7 @@ class NetworkTrainer:
text_encoder_conds[i] = encoded_text_encoder_conds[i] text_encoder_conds[i] = encoded_text_encoder_conds[i]
# sample noise, call unet, get target # sample noise, call unet, get target
noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target( noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target(
args, args,
accelerator, accelerator,
noise_scheduler, noise_scheduler,
@@ -1207,7 +1208,7 @@ class NetworkTrainer:
) )
loss = train_util.conditional_loss( loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
) )
if weighting is not None: if weighting is not None:
loss = loss * weighting loss = loss * weighting

View File

@@ -585,7 +585,7 @@ class TextualInversionTrainer:
# Sample noise, sample a random timestep for each image, and add noise to the latents, # Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified # with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents args, noise_scheduler, latents
) )
@@ -602,7 +602,7 @@ class TextualInversionTrainer:
target = noise target = noise
loss = train_util.conditional_loss( loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
) )
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch) loss = apply_masked_loss(loss, batch)

View File

@@ -461,7 +461,7 @@ def train(args):
# Sample noise, sample a random timestep for each image, and add noise to the latents, # Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified # with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
# Predict the noise residual # Predict the noise residual
with accelerator.autocast(): with accelerator.autocast():
@@ -473,7 +473,9 @@ def train(args):
else: else:
target = noise target = noise
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) loss = train_util.conditional_loss(
args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
)
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch) loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3]) loss = loss.mean([1, 2, 3])