diff --git a/README.md b/README.md index a0b02f10..1089dd00 100644 --- a/README.md +++ b/README.md @@ -2,24 +2,41 @@ This repository contains training, generation and utility scripts for Stable Dif ## FLUX.1 LoRA training (WIP) -__Aug 9, 2024__: +This feature is experimental. The options and the training script may change in the future. Please let us know if you have any ideas to improve the training. + +Aug 10, 2024: LoRA key prefix is changed to `lora_unet` from `lora_flux` to make it compatible with ComfyUI. Please update PyTorch to 2.4.0. We have tested with PyTorch 2.4.0 with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. -We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. +We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. Sample command is below. It will work with 24GB VRAM GPUs. ``` -accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py --pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 --network_train_unet_only --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml --output_dir path/to/output/dir --output_name flux-lora-name +accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py --pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --cache_latents_to_disk 
--save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 --network_train_unet_only --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml --output_dir path/to/output/dir --output_name flux-lora-name --timestep_sampling sigma --model_prediction_type raw --discrete_flow_shift 1.0 --loss_type l2 ``` +LoRAs for Text Encoders are not tested yet. + +We have added some new options (Aug 10, 2024): `--timestep_sampling`, `--sigmoid_scale`, `--model_prediction_type` and `--discrete_flow_shift`. The options are as follows: + +- `--timestep_sampling` is the method to sample timesteps (0-1): `sigma` (sigma-based, same as SD3), `uniform` (uniform random), or `sigmoid` (sigmoid of random normal, same as x-flux). +- `--sigmoid_scale` is the scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). The default is 1.0. Larger values will make the sampling more uniform. +- `--model_prediction_type` is how to interpret and process the model prediction: `raw` (use as is, same as x-flux), `additive` (add to noisy input), `sigma_scaled` (apply sigma scaling, same as SD3). +- `--discrete_flow_shift` is the discrete flow shift for the Euler Discrete Scheduler, default is 3.0 (same as SD3). + +`--loss_type` may be useful for FLUX.1 training. The default is `l2`. + +In our experiments, `--timestep_sampling sigma --model_prediction_type raw --discrete_flow_shift 1.0` with `--loss_type l2` seems to work better than the default (SD3) settings. The multiplier of LoRA should be adjusted. Other settings may work better, so please try different settings. 
+ +We are also not sure how many epochs are needed for convergence, and how the learning rate should be adjusted. + +The trained LoRA model can be used with ComfyUI. + The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options. ``` -python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors +python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors;1.0 ``` -Unfortnately the training result is not good. Please let us know if you have any idea to improve the training. - ## SD3 training SD3 training is done with `sd3_train.py`. 
diff --git a/flux_train_network.py b/flux_train_network.py index e4be97ad..69b6e8ea 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -135,7 +135,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): pass def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: - noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0) + noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) self.noise_scheduler_copy = copy.deepcopy(noise_scheduler) return noise_scheduler @@ -211,21 +211,32 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): noise = torch.randn_like(latents) bsz = latents.shape[0] - # Sample a random timestep for each image - # for weighting schemes where we sample timesteps non-uniformly - u = compute_density_for_timestep_sampling( - weighting_scheme=args.weighting_scheme, - batch_size=bsz, - logit_mean=args.logit_mean, - logit_std=args.logit_std, - mode_scale=args.mode_scale, - ) - indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long() - timesteps = self.noise_scheduler_copy.timesteps[indices].to(device=accelerator.device) + if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": + # Simple random t-based noise sampling + if args.timestep_sampling == "sigmoid": + # https://github.com/XLabs-AI/x-flux/tree/main + t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=accelerator.device)) + else: + t = torch.rand((bsz,), device=accelerator.device) + timesteps = t * 1000.0 + t = t.view(-1, 1, 1, 1) + noisy_model_input = (1 - t) * latents + t * noise + else: + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + 
mode_scale=args.mode_scale, + ) + indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long() + timesteps = self.noise_scheduler_copy.timesteps[indices].to(device=accelerator.device) - # Add noise according to flow matching. - sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=weight_dtype) - noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + # Add noise according to flow matching. + sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=weight_dtype) + noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents # pack latents and get img_ids packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 @@ -264,11 +275,20 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): # unpack latents model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) - model_pred = model_pred * (-sigmas) + noisy_model_input + if args.model_prediction_type == "raw": + # use model_pred as is + weighting = None + elif args.model_prediction_type == "additive": + # add the model_pred to the noisy_model_input + model_pred = model_pred + noisy_model_input + weighting = None + elif args.model_prediction_type == "sigma_scaled": + # apply sigma scaling + model_pred = model_pred * (-sigmas) + noisy_model_input - # these weighting schemes use a uniform timestep sampling - # and instead post-weight the loss - weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) # flow matching loss: this is different from SD3 target = noise - latents @@ -278,6 +298,21 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): def post_process_loss(self, loss, args, timesteps, noise_scheduler): return loss + def get_sai_model_spec(self, args): + return 
train_util.get_sai_model_spec(None, args, False, True, False, flux="dev") + + def update_metadata(self, metadata, args): + metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask + metadata["ss_weighting_scheme"] = args.weighting_scheme + metadata["ss_logit_mean"] = args.logit_mean + metadata["ss_logit_std"] = args.logit_std + metadata["ss_mode_scale"] = args.mode_scale + metadata["ss_guidance_scale"] = args.guidance_scale + metadata["ss_timestep_sampling"] = args.timestep_sampling + metadata["ss_sigmoid_scale"] = args.sigmoid_scale + metadata["ss_model_prediction_type"] = args.model_prediction_type + metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift + def setup_parser() -> argparse.ArgumentParser: parser = train_network.setup_parser() @@ -318,6 +353,34 @@ def setup_parser() -> argparse.ArgumentParser: default=3.5, help="the FLUX.1 dev variant is a guidance distilled model", ) + + parser.add_argument( + "--timestep_sampling", + choices=["sigma", "uniform", "sigmoid"], + default="sigma", + help="Method to sample timesteps: sigma-based, uniform random, or sigmoid of random normal. / タイムステップをサンプリングする方法:sigma、random uniform、またはrandom normalのsigmoid。", + ) + parser.add_argument( + "--sigmoid_scale", + type=float, + default=1.0, + help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"の場合のみ有効)。', + ) + parser.add_argument( + "--model_prediction_type", + choices=["raw", "additive", "sigma_scaled"], + default="sigma_scaled", + help="How to interpret and process the model prediction: " + "raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)." + " / モデル予測の解釈と処理方法:" + "raw(そのまま使用)、additive(ノイズ入力に加算)、sigma_scaled(シグマスケーリングを適用)。", + ) + parser.add_argument( + "--discrete_flow_shift", + type=float, + default=3.0, + help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. 
/ Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。", + ) return parser diff --git a/library/sai_model_spec.py b/library/sai_model_spec.py index af073677..ad72ec00 100644 --- a/library/sai_model_spec.py +++ b/library/sai_model_spec.py @@ -59,6 +59,8 @@ ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v" ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base" ARCH_SD3_M = "stable-diffusion-3-medium" ARCH_SD3_UNKNOWN = "stable-diffusion-3" +ARCH_FLUX_1_DEV = "flux-1-dev" +ARCH_FLUX_1_UNKNOWN = "flux-1" ADAPTER_LORA = "lora" ADAPTER_TEXTUAL_INVERSION = "textual-inversion" @@ -66,6 +68,7 @@ ADAPTER_TEXTUAL_INVERSION = "textual-inversion" IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models" IMPL_COMFY_UI = "https://github.com/comfyanonymous/ComfyUI" IMPL_DIFFUSERS = "diffusers" +IMPL_FLUX = "https://github.com/black-forest-labs/flux" PRED_TYPE_EPSILON = "epsilon" PRED_TYPE_V = "v" @@ -118,10 +121,11 @@ def build_metadata( merged_from: Optional[str] = None, timesteps: Optional[Tuple[int, int]] = None, clip_skip: Optional[int] = None, - sd3: str = None, + sd3: Optional[str] = None, + flux: Optional[str] = None, ): """ - sd3: only supports "m" + sd3: only supports "m", flux: only supports "dev" """ # if state_dict is None, hash is not calculated @@ -140,6 +144,11 @@ def build_metadata( arch = ARCH_SD3_M else: arch = ARCH_SD3_UNKNOWN + elif flux is not None: + if flux == "dev": + arch = ARCH_FLUX_1_DEV + else: + arch = ARCH_FLUX_1_UNKNOWN elif v2: if v_parameterization: arch = ARCH_SD_V2_768_V @@ -158,7 +167,10 @@ def build_metadata( if not lora and not textual_inversion and is_stable_diffusion_ckpt is None: is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion - if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt: + if flux is not None: + # Flux + impl = IMPL_FLUX + elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt: # Stable Diffusion ckpt, TI, SDXL LoRA impl = IMPL_STABILITY_AI 
else: @@ -216,7 +228,7 @@ def build_metadata( reso = (reso[0], reso[0]) else: # resolution is defined in dataset, so use default - if sdxl or sd3 is not None: + if sdxl or sd3 is not None or flux is not None: reso = 1024 elif v2 and v_parameterization: reso = 768 @@ -227,7 +239,9 @@ def build_metadata( metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}" - if v_parameterization: + if flux is not None: + del metadata["modelspec.prediction_type"] + elif v_parameterization: metadata["modelspec.prediction_type"] = PRED_TYPE_V else: metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON diff --git a/library/strategy_flux.py b/library/strategy_flux.py index f194ccf6..13459d32 100644 --- a/library/strategy_flux.py +++ b/library/strategy_flux.py @@ -63,11 +63,11 @@ class FluxTextEncodingStrategy(TextEncodingStrategy): l_pooled = None if t5xxl is not None and t5_tokens is not None: - # t5_out is [1, max length, 4096] + # t5_out is [b, max length, 4096] t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), return_dict=False, output_hidden_states=True) if apply_t5_attn_mask: t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1) - txt_ids = torch.zeros(1, t5_out.shape[1], 3, device=t5_out.device) + txt_ids = torch.zeros(t5_out.shape[0], t5_out.shape[1], 3, device=t5_out.device) else: t5_out = None txt_ids = None diff --git a/library/train_util.py b/library/train_util.py index fc458a88..6b74bb3f 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3186,6 +3186,7 @@ def get_sai_model_spec( textual_inversion: bool, is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA sd3: str = None, + flux: str = None, ): timestamp = time.time() @@ -3220,6 +3221,7 @@ def get_sai_model_spec( timesteps=timesteps, clip_skip=args.clip_skip, # None or int sd3=sd3, + flux=flux, ) return metadata @@ -3642,8 +3644,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: "--loss_type", type=str, default="l2", - choices=["l2", "huber", 
"smooth_l1"], - help="The type of loss function to use (L2, Huber, or smooth L1), default is L2 / 使用する損失関数の種類(L2、Huber、またはsmooth L1)、デフォルトはL2", + choices=["l1", "l2", "huber", "smooth_l1"], + help="The type of loss function to use (L1, L2, Huber, or smooth L1), default is L2 / 使用する損失関数の種類(L1、L2、Huber、またはsmooth L1)、デフォルトはL2", ) parser.add_argument( "--huber_schedule", @@ -5359,9 +5361,10 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): def conditional_loss( model_pred: torch.Tensor, target: torch.Tensor, reduction: str = "mean", loss_type: str = "l2", huber_c: float = 0.1 ): - if loss_type == "l2": loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction) + elif loss_type == "l1": + loss = torch.nn.functional.l1_loss(model_pred, target, reduction=reduction) elif loss_type == "huber": loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) if reduction == "mean": diff --git a/networks/lora_flux.py b/networks/lora_flux.py index 332a73d9..a4dab287 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -316,7 +316,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh class LoRANetwork(torch.nn.Module): FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"] TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] - LORA_PREFIX_FLUX = "lora_flux" + LORA_PREFIX_FLUX = "lora_unet" # make ComfyUI compatible LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1" LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te2" diff --git a/train_network.py b/train_network.py index 48d98862..367203f5 100644 --- a/train_network.py +++ b/train_network.py @@ -226,6 +226,12 @@ class NetworkTrainer: loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) return loss + def get_sai_model_spec(self, args): + return train_util.get_sai_model_spec(None, args, self.is_sdxl, True, False) + + def update_metadata(self, metadata, args): + pass + # endregion def train(self, 
args): @@ -521,10 +527,13 @@ class NetworkTrainer: unet_weight_dtype = torch.float8_e4m3fn te_weight_dtype = torch.float8_e4m3fn - unet.to(accelerator.device) # this makes faster `to(dtype)` below + # unet.to(accelerator.device) # this makes faster `to(dtype)` below, but consumes 23 GB VRAM + # unet.to(dtype=unet_weight_dtype) # without moving to gpu, this takes a lot of time and main memory + + unet.to(accelerator.device, dtype=unet_weight_dtype) # this seems to be safer than above unet.requires_grad_(False) - unet.to(dtype=unet_weight_dtype) # this takes long time and large memory + unet.to(dtype=unet_weight_dtype) for t_enc in text_encoders: t_enc.requires_grad_(False) @@ -718,8 +727,11 @@ class NetworkTrainer: "ss_loss_type": args.loss_type, "ss_huber_schedule": args.huber_schedule, "ss_huber_c": args.huber_c, + "ss_fp8_base": args.fp8_base, } + self.update_metadata(metadata, args) # architecture specific metadata + if use_user_config: # save metadata of multiple datasets # NOTE: pack "ss_datasets" value as json one time @@ -964,7 +976,7 @@ class NetworkTrainer: metadata["ss_epoch"] = str(epoch_no) metadata_to_save = minimum_metadata if args.no_metadata else metadata - sai_metadata = train_util.get_sai_model_spec(None, args, self.is_sdxl, True, False) + sai_metadata = self.get_sai_model_spec(args) metadata_to_save.update(sai_metadata) unwrapped_nw.save_weights(ckpt_file, save_dtype, metadata_to_save)