mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
Implement pseudo Huber loss for Flux and SD3
This commit is contained in:
@@ -380,7 +380,7 @@ def train(args):
|
|||||||
|
|
||||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||||
# with noise offset and/or multires noise if specified
|
# with noise offset and/or multires noise if specified
|
||||||
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
|
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
|
||||||
args, noise_scheduler, latents
|
args, noise_scheduler, latents
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -397,7 +397,7 @@ def train(args):
|
|||||||
if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
|
if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
|
||||||
# do not mean over batch dimension for snr weight or scale v-pred loss
|
# do not mean over batch dimension for snr weight or scale v-pred loss
|
||||||
loss = train_util.conditional_loss(
|
loss = train_util.conditional_loss(
|
||||||
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
|
args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
|
||||||
)
|
)
|
||||||
loss = loss.mean([1, 2, 3])
|
loss = loss.mean([1, 2, 3])
|
||||||
|
|
||||||
@@ -411,7 +411,7 @@ def train(args):
|
|||||||
loss = loss.mean() # mean over batch dimension
|
loss = loss.mean() # mean over batch dimension
|
||||||
else:
|
else:
|
||||||
loss = train_util.conditional_loss(
|
loss = train_util.conditional_loss(
|
||||||
noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
|
args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
|
||||||
)
|
)
|
||||||
|
|
||||||
accelerator.backward(loss)
|
accelerator.backward(loss)
|
||||||
|
|||||||
@@ -667,7 +667,7 @@ def train(args):
|
|||||||
|
|
||||||
# calculate loss
|
# calculate loss
|
||||||
loss = train_util.conditional_loss(
|
loss = train_util.conditional_loss(
|
||||||
model_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=None
|
args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
|
||||||
)
|
)
|
||||||
if weighting is not None:
|
if weighting is not None:
|
||||||
loss = loss * weighting
|
loss = loss * weighting
|
||||||
|
|||||||
@@ -468,7 +468,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
)
|
)
|
||||||
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
|
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
|
||||||
|
|
||||||
return model_pred, target, timesteps, None, weighting
|
return model_pred, target, timesteps, weighting
|
||||||
|
|
||||||
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
|
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
|
||||||
return loss
|
return loss
|
||||||
|
|||||||
@@ -3905,7 +3905,14 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
|||||||
"--huber_c",
|
"--huber_c",
|
||||||
type=float,
|
type=float,
|
||||||
default=0.1,
|
default=0.1,
|
||||||
help="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1",
|
help="The Huber loss decay parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--huber_scale",
|
||||||
|
type=float,
|
||||||
|
default=1.0,
|
||||||
|
help="The Huber loss scale parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 1.0 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -5821,29 +5828,10 @@ def save_sd_model_on_train_end_common(
|
|||||||
huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)
|
huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)
|
||||||
|
|
||||||
|
|
||||||
def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device):
|
def get_timesteps(min_timestep, max_timestep, b_size, device):
|
||||||
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
|
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device)
|
||||||
|
timesteps = timesteps.long()
|
||||||
if args.loss_type == "huber" or args.loss_type == "smooth_l1":
|
return timesteps
|
||||||
if args.huber_schedule == "exponential":
|
|
||||||
alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
|
|
||||||
huber_c = torch.exp(-alpha * timesteps)
|
|
||||||
elif args.huber_schedule == "snr":
|
|
||||||
alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps)
|
|
||||||
sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
|
|
||||||
huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
|
|
||||||
elif args.huber_schedule == "constant":
|
|
||||||
huber_c = torch.full((b_size,), args.huber_c)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")
|
|
||||||
huber_c = huber_c.to(device)
|
|
||||||
elif args.loss_type == "l2":
|
|
||||||
huber_c = None # may be anything, as it's not used
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f"Unknown loss type {args.loss_type}")
|
|
||||||
|
|
||||||
timesteps = timesteps.long().to(device)
|
|
||||||
return timesteps, huber_c
|
|
||||||
|
|
||||||
|
|
||||||
def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
|
def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
|
||||||
@@ -5865,7 +5853,7 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
|
|||||||
min_timestep = 0 if args.min_timestep is None else args.min_timestep
|
min_timestep = 0 if args.min_timestep is None else args.min_timestep
|
||||||
max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep
|
max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep
|
||||||
|
|
||||||
timesteps, huber_c = get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, latents.device)
|
timesteps = get_timesteps(min_timestep, max_timestep, b_size, latents.device)
|
||||||
|
|
||||||
# Add noise to the latents according to the noise magnitude at each timestep
|
# Add noise to the latents according to the noise magnitude at each timestep
|
||||||
# (this is the forward diffusion process)
|
# (this is the forward diffusion process)
|
||||||
@@ -5878,24 +5866,46 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
|
|||||||
else:
|
else:
|
||||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||||
|
|
||||||
return noise, noisy_latents, timesteps, huber_c
|
return noise, noisy_latents, timesteps
|
||||||
|
|
||||||
|
|
||||||
|
def get_huber_threshold(args, timesteps: torch.Tensor, noise_scheduler) -> torch.Tensor:
|
||||||
|
b_size = timesteps.shape[0]
|
||||||
|
if args.huber_schedule == "exponential":
|
||||||
|
alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
|
||||||
|
result = torch.exp(-alpha * timesteps) * args.huber_scale
|
||||||
|
elif args.huber_schedule == "snr":
|
||||||
|
if not hasattr(noise_scheduler, 'alphas_cumprod'):
|
||||||
|
raise NotImplementedError(f"Huber schedule 'snr' is not supported with the current model.")
|
||||||
|
alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu())
|
||||||
|
sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
|
||||||
|
result = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
|
||||||
|
result = result.to(timesteps.device)
|
||||||
|
elif args.huber_schedule == "constant":
|
||||||
|
result = torch.full((b_size,), args.huber_c * args.huber_scale, device=timesteps.device)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
def conditional_loss(
|
def conditional_loss(
|
||||||
model_pred: torch.Tensor, target: torch.Tensor, reduction: str, loss_type: str, huber_c: Optional[torch.Tensor]
|
args, model_pred: torch.Tensor, target: torch.Tensor, timesteps: torch.Tensor, reduction: str, noise_scheduler
|
||||||
):
|
):
|
||||||
if loss_type == "l2":
|
if args.loss_type == "l2":
|
||||||
loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)
|
loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)
|
||||||
elif loss_type == "l1":
|
elif args.loss_type == "l1":
|
||||||
loss = torch.nn.functional.l1_loss(model_pred, target, reduction=reduction)
|
loss = torch.nn.functional.l1_loss(model_pred, target, reduction=reduction)
|
||||||
elif loss_type == "huber":
|
elif args.loss_type == "huber":
|
||||||
|
huber_c = get_huber_threshold(args, timesteps, noise_scheduler)
|
||||||
huber_c = huber_c.view(-1, 1, 1, 1)
|
huber_c = huber_c.view(-1, 1, 1, 1)
|
||||||
loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
|
loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
|
||||||
if reduction == "mean":
|
if reduction == "mean":
|
||||||
loss = torch.mean(loss)
|
loss = torch.mean(loss)
|
||||||
elif reduction == "sum":
|
elif reduction == "sum":
|
||||||
loss = torch.sum(loss)
|
loss = torch.sum(loss)
|
||||||
elif loss_type == "smooth_l1":
|
elif args.loss_type == "smooth_l1":
|
||||||
|
huber_c = get_huber_threshold(args, timesteps, noise_scheduler)
|
||||||
huber_c = huber_c.view(-1, 1, 1, 1)
|
huber_c = huber_c.view(-1, 1, 1, 1)
|
||||||
loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
|
loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
|
||||||
if reduction == "mean":
|
if reduction == "mean":
|
||||||
@@ -5903,7 +5913,7 @@ def conditional_loss(
|
|||||||
elif reduction == "sum":
|
elif reduction == "sum":
|
||||||
loss = torch.sum(loss)
|
loss = torch.sum(loss)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Unsupported Loss Type {loss_type}")
|
raise NotImplementedError(f"Unsupported Loss Type: {args.loss_type}")
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -845,7 +845,7 @@ def train(args):
|
|||||||
# )
|
# )
|
||||||
# calculate loss
|
# calculate loss
|
||||||
loss = train_util.conditional_loss(
|
loss = train_util.conditional_loss(
|
||||||
model_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=None
|
args, model_pred.float(), target.float(), timesteps, "none", noise_scheduler
|
||||||
)
|
)
|
||||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||||
loss = apply_masked_loss(loss, batch)
|
loss = apply_masked_loss(loss, batch)
|
||||||
|
|||||||
@@ -378,7 +378,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
|||||||
|
|
||||||
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
|
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
|
||||||
|
|
||||||
return model_pred, target, timesteps, None, weighting
|
return model_pred, target, timesteps, weighting
|
||||||
|
|
||||||
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
|
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
|
||||||
return loss
|
return loss
|
||||||
|
|||||||
@@ -695,7 +695,7 @@ def train(args):
|
|||||||
|
|
||||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||||
# with noise offset and/or multires noise if specified
|
# with noise offset and/or multires noise if specified
|
||||||
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
|
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
|
||||||
args, noise_scheduler, latents
|
args, noise_scheduler, latents
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -720,7 +720,7 @@ def train(args):
|
|||||||
):
|
):
|
||||||
# do not mean over batch dimension for snr weight or scale v-pred loss
|
# do not mean over batch dimension for snr weight or scale v-pred loss
|
||||||
loss = train_util.conditional_loss(
|
loss = train_util.conditional_loss(
|
||||||
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
|
args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
|
||||||
)
|
)
|
||||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||||
loss = apply_masked_loss(loss, batch)
|
loss = apply_masked_loss(loss, batch)
|
||||||
@@ -738,7 +738,7 @@ def train(args):
|
|||||||
loss = loss.mean() # mean over batch dimension
|
loss = loss.mean() # mean over batch dimension
|
||||||
else:
|
else:
|
||||||
loss = train_util.conditional_loss(
|
loss = train_util.conditional_loss(
|
||||||
noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
|
args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
|
||||||
)
|
)
|
||||||
|
|
||||||
accelerator.backward(loss)
|
accelerator.backward(loss)
|
||||||
|
|||||||
@@ -512,7 +512,7 @@ def train(args):
|
|||||||
|
|
||||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||||
# with noise offset and/or multires noise if specified
|
# with noise offset and/or multires noise if specified
|
||||||
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
|
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
|
||||||
args, noise_scheduler, latents
|
args, noise_scheduler, latents
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -534,7 +534,7 @@ def train(args):
|
|||||||
target = noise
|
target = noise
|
||||||
|
|
||||||
loss = train_util.conditional_loss(
|
loss = train_util.conditional_loss(
|
||||||
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
|
args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
|
||||||
)
|
)
|
||||||
loss = loss.mean([1, 2, 3])
|
loss = loss.mean([1, 2, 3])
|
||||||
|
|
||||||
|
|||||||
@@ -463,7 +463,7 @@ def train(args):
|
|||||||
|
|
||||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||||
# with noise offset and/or multires noise if specified
|
# with noise offset and/or multires noise if specified
|
||||||
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
|
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
|
||||||
args, noise_scheduler, latents
|
args, noise_scheduler, latents
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -485,7 +485,7 @@ def train(args):
|
|||||||
target = noise
|
target = noise
|
||||||
|
|
||||||
loss = train_util.conditional_loss(
|
loss = train_util.conditional_loss(
|
||||||
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
|
args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
|
||||||
)
|
)
|
||||||
loss = loss.mean([1, 2, 3])
|
loss = loss.mean([1, 2, 3])
|
||||||
|
|
||||||
|
|||||||
@@ -406,7 +406,7 @@ def train(args):
|
|||||||
|
|
||||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||||
# with noise offset and/or multires noise if specified
|
# with noise offset and/or multires noise if specified
|
||||||
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
||||||
|
|
||||||
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
|
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
|
||||||
|
|
||||||
@@ -426,7 +426,9 @@ def train(args):
|
|||||||
else:
|
else:
|
||||||
target = noise
|
target = noise
|
||||||
|
|
||||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
|
loss = train_util.conditional_loss(
|
||||||
|
args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
|
||||||
|
)
|
||||||
loss = loss.mean([1, 2, 3])
|
loss = loss.mean([1, 2, 3])
|
||||||
|
|
||||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||||
|
|||||||
@@ -464,8 +464,8 @@ def train(args):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Sample a random timestep for each image
|
# Sample a random timestep for each image
|
||||||
timesteps, huber_c = train_util.get_timesteps_and_huber_c(
|
timesteps = train_util.get_timesteps(
|
||||||
args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler, b_size, latents.device
|
0, noise_scheduler.config.num_train_timesteps, b_size, latents.device
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add noise to the latents according to the noise magnitude at each timestep
|
# Add noise to the latents according to the noise magnitude at each timestep
|
||||||
@@ -499,7 +499,7 @@ def train(args):
|
|||||||
target = noise
|
target = noise
|
||||||
|
|
||||||
loss = train_util.conditional_loss(
|
loss = train_util.conditional_loss(
|
||||||
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
|
args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
|
||||||
)
|
)
|
||||||
loss = loss.mean([1, 2, 3])
|
loss = loss.mean([1, 2, 3])
|
||||||
|
|
||||||
|
|||||||
@@ -370,7 +370,7 @@ def train(args):
|
|||||||
|
|
||||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||||
# with noise offset and/or multires noise if specified
|
# with noise offset and/or multires noise if specified
|
||||||
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
|
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
|
||||||
args, noise_scheduler, latents
|
args, noise_scheduler, latents
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -385,7 +385,7 @@ def train(args):
|
|||||||
target = noise
|
target = noise
|
||||||
|
|
||||||
loss = train_util.conditional_loss(
|
loss = train_util.conditional_loss(
|
||||||
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
|
args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
|
||||||
)
|
)
|
||||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||||
loss = apply_masked_loss(loss, batch)
|
loss = apply_masked_loss(loss, batch)
|
||||||
|
|||||||
@@ -192,7 +192,7 @@ class NetworkTrainer:
|
|||||||
):
|
):
|
||||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||||
# with noise offset and/or multires noise if specified
|
# with noise offset and/or multires noise if specified
|
||||||
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
||||||
|
|
||||||
# ensure the hidden state will require grad
|
# ensure the hidden state will require grad
|
||||||
if args.gradient_checkpointing:
|
if args.gradient_checkpointing:
|
||||||
@@ -244,7 +244,7 @@ class NetworkTrainer:
|
|||||||
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
|
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
|
||||||
target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)
|
target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)
|
||||||
|
|
||||||
return noise_pred, target, timesteps, huber_c, None
|
return noise_pred, target, timesteps, None
|
||||||
|
|
||||||
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
|
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
|
||||||
if args.min_snr_gamma:
|
if args.min_snr_gamma:
|
||||||
@@ -806,6 +806,7 @@ class NetworkTrainer:
|
|||||||
"ss_ip_noise_gamma_random_strength": args.ip_noise_gamma_random_strength,
|
"ss_ip_noise_gamma_random_strength": args.ip_noise_gamma_random_strength,
|
||||||
"ss_loss_type": args.loss_type,
|
"ss_loss_type": args.loss_type,
|
||||||
"ss_huber_schedule": args.huber_schedule,
|
"ss_huber_schedule": args.huber_schedule,
|
||||||
|
"ss_huber_scale": args.huber_scale,
|
||||||
"ss_huber_c": args.huber_c,
|
"ss_huber_c": args.huber_c,
|
||||||
"ss_fp8_base": bool(args.fp8_base),
|
"ss_fp8_base": bool(args.fp8_base),
|
||||||
"ss_fp8_base_unet": bool(args.fp8_base_unet),
|
"ss_fp8_base_unet": bool(args.fp8_base_unet),
|
||||||
@@ -1193,7 +1194,7 @@ class NetworkTrainer:
|
|||||||
text_encoder_conds[i] = encoded_text_encoder_conds[i]
|
text_encoder_conds[i] = encoded_text_encoder_conds[i]
|
||||||
|
|
||||||
# sample noise, call unet, get target
|
# sample noise, call unet, get target
|
||||||
noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target(
|
noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target(
|
||||||
args,
|
args,
|
||||||
accelerator,
|
accelerator,
|
||||||
noise_scheduler,
|
noise_scheduler,
|
||||||
@@ -1207,7 +1208,7 @@ class NetworkTrainer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
loss = train_util.conditional_loss(
|
loss = train_util.conditional_loss(
|
||||||
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
|
args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
|
||||||
)
|
)
|
||||||
if weighting is not None:
|
if weighting is not None:
|
||||||
loss = loss * weighting
|
loss = loss * weighting
|
||||||
|
|||||||
@@ -585,7 +585,7 @@ class TextualInversionTrainer:
|
|||||||
|
|
||||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||||
# with noise offset and/or multires noise if specified
|
# with noise offset and/or multires noise if specified
|
||||||
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
|
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
|
||||||
args, noise_scheduler, latents
|
args, noise_scheduler, latents
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -602,7 +602,7 @@ class TextualInversionTrainer:
|
|||||||
target = noise
|
target = noise
|
||||||
|
|
||||||
loss = train_util.conditional_loss(
|
loss = train_util.conditional_loss(
|
||||||
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
|
args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
|
||||||
)
|
)
|
||||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||||
loss = apply_masked_loss(loss, batch)
|
loss = apply_masked_loss(loss, batch)
|
||||||
|
|||||||
@@ -461,7 +461,7 @@ def train(args):
|
|||||||
|
|
||||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||||
# with noise offset and/or multires noise if specified
|
# with noise offset and/or multires noise if specified
|
||||||
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
||||||
|
|
||||||
# Predict the noise residual
|
# Predict the noise residual
|
||||||
with accelerator.autocast():
|
with accelerator.autocast():
|
||||||
@@ -473,7 +473,9 @@ def train(args):
|
|||||||
else:
|
else:
|
||||||
target = noise
|
target = noise
|
||||||
|
|
||||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
|
loss = train_util.conditional_loss(
|
||||||
|
args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
|
||||||
|
)
|
||||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||||
loss = apply_masked_loss(loss, batch)
|
loss = apply_masked_loss(loss, batch)
|
||||||
loss = loss.mean([1, 2, 3])
|
loss = loss.mean([1, 2, 3])
|
||||||
|
|||||||
Reference in New Issue
Block a user