diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index 69878750..bfe752d5 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -253,12 +253,12 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser): " / 複数解像度学習時に解像度ごとに位置埋め込みをスケーリングする。SD3.5M以外では予期しない動作になります", ) - # copy from Diffusers + # Dependencies on the Diffusers noise sampler have been removed for clarity. parser.add_argument( "--weighting_scheme", type=str, - default="logit_normal", - choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"], + default="uniform", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "uniform"], help="weighting scheme for timestep distribution and loss / タイムステップ分布と損失のための重み付けスキーム", ) parser.add_argument( @@ -279,8 +279,13 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser): default=1.29, help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`. / モード重み付けスキームのスケール。`'mode'`を`weighting_scheme`として使用する場合のみ有効", ) - - + parser.add_argument( + "--training_shift", + type=float, + default=1.0, + help="Discrete flow shift for training timestep distribution adjustment, applied in addition to the weighting scheme, default is 1.0. /タイムステップ分布のための離散フローシフト、重み付けスキームの上に適用される、デフォルトは1.0。", + ) + def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True): assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません" if args.v_parameterization: @@ -965,14 +970,20 @@ def get_noisy_model_input_and_timesteps( logit_std=args.logit_std, mode_scale=args.mode_scale, ) - indices = (u * noise_scheduler.config.num_train_timesteps).long() - timesteps = noise_scheduler.timesteps[indices].to(device=device) + t_min = args.min_timestep if args.min_timestep is not None else 0 + t_max = args.max_timestep if args.max_timestep is not None else 1000 + shift = args.training_shift - # Add noise according to flow matching. 
- sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) + # weighting shift, value >1 will shift distribution to noisy side (focus more on overall structure), value <1 will shift towards less-noisy side (focus more on details) + u = (u * shift) / (1 + (shift - 1) * u) + + indices = (u * (t_max-t_min) + t_min).long() + timesteps = indices.to(device=device, dtype=dtype) + + # sigmas according to flow matching + sigmas = timesteps / 1000 + sigmas = sigmas.view(-1,1,1,1) noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents return noisy_model_input, timesteps, sigmas - -# endregion