Min-SNR Weighting Strategy: Refactored and added to all trainers

This commit is contained in:
AI-Casanova
2023-03-22 01:25:49 +00:00
parent 795a6bd2d8
commit 64c923230e
6 changed files with 43 additions and 14 deletions

View File

@@ -19,7 +19,8 @@ from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight
def collate_fn(examples):
return examples[0]
@@ -304,6 +305,9 @@ def train(args):
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
if args.min_snr_gamma:
loss = apply_snr_weight(loss, latents, noisy_latents, args.min_snr_gamma)
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = []
@@ -396,6 +400,8 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")