mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
Add option to use Scheduled Huber Loss in all training pipelines to improve resilience to data corruption (#1228)
* add huber loss and huber_c compute to train_util
* add reduction modes
* add huber_c retrieval from timestep getter
* move get timesteps and huber to own function
* add conditional loss to all training scripts
* add cond loss to train network
* add (scheduled) huber_loss to args
* fixup twice timesteps getting
* PHL-schedule should depend on noise scheduler's num timesteps
* *2 multiplier to huber loss cause of 1/2 a^2 conv.

  The Taylor expansion of sqrt near zero gives 1/2 a^2, which differs from a^2 of the standard MSE loss. This change scales them better against one another
* add option for smooth l1 (huber / delta)
* unify huber scheduling
* add snr huber scheduler

---------

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
This commit is contained in:
@@ -354,7 +354,7 @@ def train(args):
|
|||||||
|
|
||||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||||
# with noise offset and/or multires noise if specified
|
# with noise offset and/or multires noise if specified
|
||||||
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
||||||
|
|
||||||
# Predict the noise residual
|
# Predict the noise residual
|
||||||
with accelerator.autocast():
|
with accelerator.autocast():
|
||||||
@@ -368,7 +368,7 @@ def train(args):
|
|||||||
|
|
||||||
if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
|
if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
|
||||||
# do not mean over batch dimension for snr weight or scale v-pred loss
|
# do not mean over batch dimension for snr weight or scale v-pred loss
|
||||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
|
||||||
loss = loss.mean([1, 2, 3])
|
loss = loss.mean([1, 2, 3])
|
||||||
|
|
||||||
if args.min_snr_gamma:
|
if args.min_snr_gamma:
|
||||||
@@ -380,7 +380,7 @@ def train(args):
|
|||||||
|
|
||||||
loss = loss.mean() # mean over batch dimension
|
loss = loss.mean() # mean over batch dimension
|
||||||
else:
|
else:
|
||||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
|
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c)
|
||||||
|
|
||||||
accelerator.backward(loss)
|
accelerator.backward(loss)
|
||||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||||
|
|||||||
@@ -3236,6 +3236,26 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
|||||||
default=None,
|
default=None,
|
||||||
help="set maximum time step for U-Net training (1~1000, default is 1000) / U-Net学習時のtime stepの最大値を設定する(1~1000で指定、省略時はデフォルト値(1000))",
|
help="set maximum time step for U-Net training (1~1000, default is 1000) / U-Net学習時のtime stepの最大値を設定する(1~1000で指定、省略時はデフォルト値(1000))",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--loss_type",
|
||||||
|
type=str,
|
||||||
|
default="l2",
|
||||||
|
choices=["l2", "huber", "smooth_l1"],
|
||||||
|
help="The type of loss to use and whether it's scheduled based on the timestep"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--huber_schedule",
|
||||||
|
type=str,
|
||||||
|
default="exponential",
|
||||||
|
choices=["constant", "exponential", "snr"],
|
||||||
|
help="The type of loss to use and whether it's scheduled based on the timestep"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--huber_c",
|
||||||
|
type=float,
|
||||||
|
default=0.1,
|
||||||
|
help="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lowram",
|
"--lowram",
|
||||||
@@ -4842,6 +4862,38 @@ def save_sd_model_on_train_end_common(
|
|||||||
if args.huggingface_repo_id is not None:
|
if args.huggingface_repo_id is not None:
|
||||||
huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)
|
huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)
|
||||||
|
|
||||||
|
def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device):
    """Sample training timesteps and compute the matching Huber parameter.

    With ``args.loss_type == "l2"`` one timestep is drawn independently for
    each of the ``b_size`` samples and the returned ``huber_c`` is a
    placeholder that the loss does not use.

    With ``"huber"`` or ``"smooth_l1"`` a single timestep is drawn and
    repeated for the whole batch, and ``huber_c`` is derived from that one
    timestep according to ``args.huber_schedule``:

    * ``"constant"``: ``args.huber_c`` unchanged.
    * ``"exponential"``: ``exp(-alpha * t)`` with
      ``alpha = -log(args.huber_c) / num_train_timesteps``, i.e. 1 at
      ``t = 0`` decaying towards ``args.huber_c`` at the last timestep.
    * ``"snr"``: ``(1 - args.huber_c) / (1 + sigma) ** 2 + args.huber_c``
      with ``sigma = sqrt((1 - alphas_cumprod[t]) / alphas_cumprod[t])``.

    Args:
        args: Namespace providing ``loss_type``, ``huber_schedule`` and
            ``huber_c``.
        min_timestep: Inclusive lower bound passed to ``torch.randint``.
        max_timestep: Exclusive upper bound passed to ``torch.randint``.
        noise_scheduler: Object exposing ``config.num_train_timesteps`` and
            an indexable ``alphas_cumprod``.
        b_size: Number of timesteps to return.
        device: Device the returned ``timesteps`` tensor is placed on.

    Returns:
        A ``(timesteps, huber_c)`` tuple: an int64 tensor of shape
        ``(b_size,)`` and a plain Python number.

    Raises:
        NotImplementedError: Unknown ``loss_type`` or ``huber_schedule``.
        ValueError: ``huber_c`` is not positive with the exponential schedule.
    """
    # TODO: when a Huber-style loss is selected, every sample in the batch
    # shares the same timestep so that one scalar huber_c applies to all of
    # them. In the future there may be a smarter (per-sample) way.
    if args.loss_type in ("huber", "smooth_l1"):
        # Drawn on the CPU so that .item() does not need a device sync.
        timesteps = torch.randint(min_timestep, max_timestep, (1,), device="cpu")
        timestep = timesteps.item()

        if args.huber_schedule == "exponential":
            if args.huber_c <= 0:
                # math.log would fail with an opaque "math domain error".
                raise ValueError(
                    f"huber_c must be positive for the exponential schedule, got {args.huber_c}"
                )
            alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
            huber_c = math.exp(-alpha * timestep)
        elif args.huber_schedule == "snr":
            alphas_cumprod = noise_scheduler.alphas_cumprod[timestep]
            sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
            # Convert to a Python float so every schedule returns the same
            # kind of value, whatever container alphas_cumprod is stored in.
            huber_c = float((1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c)
        elif args.huber_schedule == "constant":
            huber_c = args.huber_c
        else:
            raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")

        timesteps = timesteps.repeat(b_size).to(device)
    elif args.loss_type == "l2":
        timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device)
        huber_c = 1  # may be anything, as it's not used
    else:
        raise NotImplementedError(f"Unknown loss type {args.loss_type}")

    timesteps = timesteps.long()

    return timesteps, huber_c
|
||||||
|
|
||||||
def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
|
def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
|
||||||
# Sample noise that we'll add to the latents
|
# Sample noise that we'll add to the latents
|
||||||
@@ -4862,8 +4914,7 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
|
|||||||
min_timestep = 0 if args.min_timestep is None else args.min_timestep
|
min_timestep = 0 if args.min_timestep is None else args.min_timestep
|
||||||
max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep
|
max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep
|
||||||
|
|
||||||
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=latents.device)
|
timesteps, huber_c = get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, latents.device)
|
||||||
timesteps = timesteps.long()
|
|
||||||
|
|
||||||
# Add noise to the latents according to the noise magnitude at each timestep
|
# Add noise to the latents according to the noise magnitude at each timestep
|
||||||
# (this is the forward diffusion process)
|
# (this is the forward diffusion process)
|
||||||
@@ -4876,8 +4927,28 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
|
|||||||
else:
|
else:
|
||||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||||
|
|
||||||
return noise, noisy_latents, timesteps
|
return noise, noisy_latents, timesteps, huber_c
|
||||||
|
|
||||||
|
# NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already
|
||||||
|
def conditional_loss(
    model_pred: torch.Tensor, target: torch.Tensor, reduction: str = "mean", loss_type: str = "l2", huber_c: float = 0.1
):
    """Compute the training loss selected by ``loss_type``.

    * ``"l2"``: ``torch.nn.functional.mse_loss``.
    * ``"huber"``: ``2 * huber_c * (sqrt(d**2 + huber_c**2) - huber_c)`` with
      ``d = model_pred - target``. Near ``d = 0`` the square-root term is
      about ``d**2 / (2 * huber_c)``, so the ``2 * huber_c`` factor makes the
      loss approach the plain squared error there.
    * ``"smooth_l1"``: ``2 * (sqrt(d**2 + huber_c**2) - huber_c)``.

    When a scheduled Huber loss is used, ``huber_c`` must already have been
    computed from the timestep by the caller; this function does not look at
    timesteps.

    Args:
        model_pred: Model prediction.
        target: Tensor the prediction is compared against.
        reduction: ``"mean"``, ``"sum"`` or ``"none"``.
        loss_type: ``"l2"``, ``"huber"`` or ``"smooth_l1"``.
        huber_c: Huber parameter; ignored for ``"l2"``.

    Returns:
        The loss tensor: a scalar for ``"mean"`` / ``"sum"``, element-wise
        for ``"none"``.

    Raises:
        NotImplementedError: ``loss_type`` is not one of the supported values.
        ValueError: ``reduction`` is not one of the supported values.
    """
    if loss_type == "l2":
        return torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)

    if loss_type == "huber":
        scale = 2 * huber_c
    elif loss_type == "smooth_l1":
        scale = 2
    else:
        raise NotImplementedError(f"Unsupported Loss Type {loss_type}")

    loss = scale * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)

    if reduction == "mean":
        return torch.mean(loss)
    if reduction == "sum":
        return torch.sum(loss)
    if reduction == "none":
        return loss
    # Reject unknown modes the same way mse_loss does for the l2 branch,
    # instead of silently handing back an unreduced tensor.
    raise ValueError(f"{reduction} is not a valid value for reduction")
|
||||||
|
|
||||||
def append_lr_to_logs(logs, lr_scheduler, optimizer_type, including_unet=True):
|
def append_lr_to_logs(logs, lr_scheduler, optimizer_type, including_unet=True):
|
||||||
names = []
|
names = []
|
||||||
|
|||||||
@@ -582,7 +582,7 @@ def train(args):
|
|||||||
|
|
||||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||||
# with noise offset and/or multires noise if specified
|
# with noise offset and/or multires noise if specified
|
||||||
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
||||||
|
|
||||||
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
|
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
|
||||||
|
|
||||||
@@ -600,7 +600,7 @@ def train(args):
|
|||||||
or args.masked_loss
|
or args.masked_loss
|
||||||
):
|
):
|
||||||
# do not mean over batch dimension for snr weight or scale v-pred loss
|
# do not mean over batch dimension for snr weight or scale v-pred loss
|
||||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
|
||||||
if args.masked_loss:
|
if args.masked_loss:
|
||||||
loss = apply_masked_loss(loss, batch)
|
loss = apply_masked_loss(loss, batch)
|
||||||
loss = loss.mean([1, 2, 3])
|
loss = loss.mean([1, 2, 3])
|
||||||
@@ -616,7 +616,7 @@ def train(args):
|
|||||||
|
|
||||||
loss = loss.mean() # mean over batch dimension
|
loss = loss.mean() # mean over batch dimension
|
||||||
else:
|
else:
|
||||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
|
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c)
|
||||||
|
|
||||||
accelerator.backward(loss)
|
accelerator.backward(loss)
|
||||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||||
|
|||||||
@@ -439,7 +439,7 @@ def train(args):
|
|||||||
|
|
||||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||||
# with noise offset and/or multires noise if specified
|
# with noise offset and/or multires noise if specified
|
||||||
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
||||||
|
|
||||||
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
|
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
|
||||||
|
|
||||||
@@ -458,7 +458,7 @@ def train(args):
|
|||||||
else:
|
else:
|
||||||
target = noise
|
target = noise
|
||||||
|
|
||||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
|
||||||
loss = loss.mean([1, 2, 3])
|
loss = loss.mean([1, 2, 3])
|
||||||
|
|
||||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||||
|
|||||||
@@ -406,7 +406,7 @@ def train(args):
|
|||||||
|
|
||||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||||
# with noise offset and/or multires noise if specified
|
# with noise offset and/or multires noise if specified
|
||||||
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
||||||
|
|
||||||
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
|
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
|
||||||
|
|
||||||
@@ -426,7 +426,7 @@ def train(args):
|
|||||||
else:
|
else:
|
||||||
target = noise
|
target = noise
|
||||||
|
|
||||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
|
||||||
loss = loss.mean([1, 2, 3])
|
loss = loss.mean([1, 2, 3])
|
||||||
|
|
||||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||||
|
|||||||
@@ -420,13 +420,8 @@ def train(args):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Sample a random timestep for each image
|
# Sample a random timestep for each image
|
||||||
timesteps = torch.randint(
|
timesteps, huber_c = train_util.get_timesteps_and_huber_c(args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler, b_size, latents.device)
|
||||||
0,
|
|
||||||
noise_scheduler.config.num_train_timesteps,
|
|
||||||
(b_size,),
|
|
||||||
device=latents.device,
|
|
||||||
)
|
|
||||||
timesteps = timesteps.long()
|
|
||||||
# Add noise to the latents according to the noise magnitude at each timestep
|
# Add noise to the latents according to the noise magnitude at each timestep
|
||||||
# (this is the forward diffusion process)
|
# (this is the forward diffusion process)
|
||||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||||
@@ -457,7 +452,7 @@ def train(args):
|
|||||||
else:
|
else:
|
||||||
target = noise
|
target = noise
|
||||||
|
|
||||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
|
||||||
loss = loss.mean([1, 2, 3])
|
loss = loss.mean([1, 2, 3])
|
||||||
|
|
||||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||||
|
|||||||
@@ -346,7 +346,7 @@ def train(args):
|
|||||||
|
|
||||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||||
# with noise offset and/or multires noise if specified
|
# with noise offset and/or multires noise if specified
|
||||||
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
||||||
|
|
||||||
# Predict the noise residual
|
# Predict the noise residual
|
||||||
with accelerator.autocast():
|
with accelerator.autocast():
|
||||||
@@ -358,7 +358,7 @@ def train(args):
|
|||||||
else:
|
else:
|
||||||
target = noise
|
target = noise
|
||||||
|
|
||||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
|
||||||
if args.masked_loss:
|
if args.masked_loss:
|
||||||
loss = apply_masked_loss(loss, batch)
|
loss = apply_masked_loss(loss, batch)
|
||||||
loss = loss.mean([1, 2, 3])
|
loss = loss.mean([1, 2, 3])
|
||||||
|
|||||||
@@ -843,7 +843,7 @@ class NetworkTrainer:
|
|||||||
|
|
||||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||||
# with noise offset and/or multires noise if specified
|
# with noise offset and/or multires noise if specified
|
||||||
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
|
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
|
||||||
args, noise_scheduler, latents
|
args, noise_scheduler, latents
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -873,7 +873,7 @@ class NetworkTrainer:
|
|||||||
else:
|
else:
|
||||||
target = noise
|
target = noise
|
||||||
|
|
||||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
|
||||||
if args.masked_loss:
|
if args.masked_loss:
|
||||||
loss = apply_masked_loss(loss, batch)
|
loss = apply_masked_loss(loss, batch)
|
||||||
loss = loss.mean([1, 2, 3])
|
loss = loss.mean([1, 2, 3])
|
||||||
|
|||||||
@@ -572,7 +572,7 @@ class TextualInversionTrainer:
|
|||||||
|
|
||||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||||
# with noise offset and/or multires noise if specified
|
# with noise offset and/or multires noise if specified
|
||||||
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
|
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
|
||||||
args, noise_scheduler, latents
|
args, noise_scheduler, latents
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -588,7 +588,7 @@ class TextualInversionTrainer:
|
|||||||
else:
|
else:
|
||||||
target = noise
|
target = noise
|
||||||
|
|
||||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
|
||||||
if args.masked_loss:
|
if args.masked_loss:
|
||||||
loss = apply_masked_loss(loss, batch)
|
loss = apply_masked_loss(loss, batch)
|
||||||
loss = loss.mean([1, 2, 3])
|
loss = loss.mean([1, 2, 3])
|
||||||
|
|||||||
@@ -461,7 +461,7 @@ def train(args):
|
|||||||
|
|
||||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||||
# with noise offset and/or multires noise if specified
|
# with noise offset and/or multires noise if specified
|
||||||
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
||||||
|
|
||||||
# Predict the noise residual
|
# Predict the noise residual
|
||||||
with accelerator.autocast():
|
with accelerator.autocast():
|
||||||
@@ -473,7 +473,7 @@ def train(args):
|
|||||||
else:
|
else:
|
||||||
target = noise
|
target = noise
|
||||||
|
|
||||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
|
||||||
if args.masked_loss:
|
if args.masked_loss:
|
||||||
loss = apply_masked_loss(loss, batch)
|
loss = apply_masked_loss(loss, batch)
|
||||||
loss = loss.mean([1, 2, 3])
|
loss = loss.mean([1, 2, 3])
|
||||||
|
|||||||
Reference in New Issue
Block a user