add zero_terminal_snr option

This commit is contained in:
Kohya S
2023-07-18 23:17:23 +09:00
parent 0ec7166098
commit 6d2d8dfd2f
9 changed files with 65 additions and 6 deletions

View File

@@ -23,8 +23,6 @@ from library.custom_train_functions import (
apply_snr_weight, apply_snr_weight,
get_weighted_text_embeddings, get_weighted_text_embeddings,
prepare_scheduler_for_custom_training, prepare_scheduler_for_custom_training,
pyramid_noise_like,
apply_noise_offset,
scale_v_prediction_loss_like_noise_prediction, scale_v_prediction_loss_like_noise_prediction,
) )
@@ -273,6 +271,8 @@ def train(args):
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
) )
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
if args.zero_terminal_snr:
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
if accelerator.is_main_process: if accelerator.is_main_process:
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name) accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name)

View File

@@ -18,6 +18,42 @@ def prepare_scheduler_for_custom_training(noise_scheduler, device):
noise_scheduler.all_snr = all_snr.to(device) noise_scheduler.all_snr = all_snr.to(device)
def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
    """Rewrite the scheduler's beta schedule so the last timestep has zero SNR.

    The schedule is converted to sqrt(cumulative alpha product), shifted so
    its final entry is exactly zero, rescaled so its first entry keeps its
    original value, and converted back to per-step betas. The approach is
    the one referenced by the printed notice (arXiv:2305.08891).

    Args:
        noise_scheduler: object exposing a ``betas`` torch tensor
            (presumably 1-D, one entry per training timestep -- confirm
            against the scheduler class in use). Its ``betas``, ``alphas``
            and ``alphas_cumprod`` attributes are overwritten in place.

    Returns:
        None. The scheduler is modified as a side effect.
    """
    # NOTE(review): only betas / alphas / alphas_cumprod are refreshed here.
    # Anything cached on the scheduler from the old schedule before this call
    # (e.g. an ``all_snr`` attribute) is left untouched -- verify that this is
    # what the callers expect.
    print(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891")

    # betas -> sqrt of the cumulative product of alphas.
    sqrt_cumprod = (1 - noise_scheduler.betas).cumprod(0).sqrt()

    # Remember both end points of the original curve.
    first_value = sqrt_cumprod[0].clone()
    last_value = sqrt_cumprod[-1].clone()

    # Shift so the last timestep becomes zero, then scale so the first
    # timestep returns to its original value.
    sqrt_cumprod = sqrt_cumprod - last_value
    sqrt_cumprod = sqrt_cumprod * (first_value / (first_value - last_value))

    # sqrt(cumulative alphas) -> per-step alphas -> betas. The first alpha is
    # the first cumulative value; the rest are ratios of neighbours.
    cumprod = sqrt_cumprod**2
    step_alphas = torch.cat([cumprod[0:1], cumprod[1:] / cumprod[:-1]])
    fixed_betas = 1 - step_alphas

    # Re-derive the dependent tensors from the corrected betas so the three
    # attributes stay mutually consistent.
    fixed_alphas = 1.0 - fixed_betas
    noise_scheduler.betas = fixed_betas
    noise_scheduler.alphas = fixed_alphas
    noise_scheduler.alphas_cumprod = torch.cumprod(fixed_alphas, dim=0)
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma): def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])

View File

@@ -343,9 +343,10 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
def verify_sdxl_training_args(args: argparse.Namespace): def verify_sdxl_training_args(args: argparse.Namespace):
assert ( assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"
not args.v2 and not args.v_parameterization if args.v_parameterization:
), "v2 or v_parameterization cannot be enabled in SDXL training / SDXL学習ではv2とv_parameterizationを有効にすることはできません" print("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります")
if args.clip_skip is not None: if args.clip_skip is not None:
print("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません") print("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")

View File

@@ -2750,6 +2750,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
default=None, default=None,
help="add `latent mean absolute value * this value` to noise_offset (disabled if None, default) / latentの平均値の絶対値 * この値をnoise_offsetに加算するNoneの場合は無効、デフォルト", help="add `latent mean absolute value * this value` to noise_offset (disabled if None, default) / latentの平均値の絶対値 * この値をnoise_offsetに加算するNoneの場合は無効、デフォルト",
) )
parser.add_argument(
"--zero_terminal_snr",
action="store_true",
help="fix noise scheduler betas to enforce zero terminal SNR / noise schedulerのbetasを修正して、zero terminal SNRを強制する",
)
parser.add_argument( parser.add_argument(
"--min_timestep", "--min_timestep",
type=int, type=int,
@@ -2825,7 +2830,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
def verify_training_args(args: argparse.Namespace): def verify_training_args(args: argparse.Namespace):
if args.v_parameterization and not args.v2: if args.v_parameterization and not args.v2:
print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") print("v_parameterization should be with v2 not v1 or sdxl / v1やsdxlでv_parameterizationを使用することは想定されていません")
if args.v2 and args.clip_skip is not None: if args.v2 and args.clip_skip is not None:
print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
@@ -2856,6 +2861,12 @@ def verify_training_args(args: argparse.Namespace):
"scale_v_pred_loss_like_noise_pred can be enabled only with v_parameterization / scale_v_pred_loss_like_noise_predはv_parameterizationが有効なときのみ有効にできます" "scale_v_pred_loss_like_noise_pred can be enabled only with v_parameterization / scale_v_pred_loss_like_noise_predはv_parameterizationが有効なときのみ有効にできます"
) )
if args.zero_terminal_snr and not args.v_parameterization:
print(
f"zero_terminal_snr is enabled, but v_parameterization is not enabled. training will be unexpected"
+ " / zero_terminal_snrが有効ですが、v_parameterizationが有効ではありません。学習結果は想定外になる可能性があります"
)
def add_dataset_arguments( def add_dataset_arguments(
parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool

View File

@@ -350,6 +350,8 @@ def train(args):
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
) )
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
if args.zero_terminal_snr:
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
if accelerator.is_main_process: if accelerator.is_main_process:
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name) accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name)

View File

@@ -246,6 +246,8 @@ def train(args):
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
) )
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
if args.zero_terminal_snr:
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
if accelerator.is_main_process: if accelerator.is_main_process:
accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name) accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name)

View File

@@ -487,6 +487,7 @@ class NetworkTrainer:
"ss_multires_noise_iterations": args.multires_noise_iterations, "ss_multires_noise_iterations": args.multires_noise_iterations,
"ss_multires_noise_discount": args.multires_noise_discount, "ss_multires_noise_discount": args.multires_noise_discount,
"ss_adaptive_noise_scale": args.adaptive_noise_scale, "ss_adaptive_noise_scale": args.adaptive_noise_scale,
"ss_zero_terminal_snr": args.zero_terminal_snr,
"ss_training_comment": args.training_comment, # will not be updated after training "ss_training_comment": args.training_comment, # will not be updated after training
"ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(), "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
"ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""), "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""),
@@ -670,6 +671,8 @@ class NetworkTrainer:
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
) )
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
if args.zero_terminal_snr:
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
if accelerator.is_main_process: if accelerator.is_main_process:
accelerator.init_trackers("network_train" if args.log_tracker_name is None else args.log_tracker_name) accelerator.init_trackers("network_train" if args.log_tracker_name is None else args.log_tracker_name)

View File

@@ -487,6 +487,8 @@ class TextualInversionTrainer:
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
) )
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
if args.zero_terminal_snr:
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
if accelerator.is_main_process: if accelerator.is_main_process:
accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name) accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name)

View File

@@ -384,6 +384,8 @@ def train(args):
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
) )
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
if args.zero_terminal_snr:
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
if accelerator.is_main_process: if accelerator.is_main_process:
accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name) accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name)