mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 21:52:27 +00:00
update FLUX LoRA training
29
README.md
@@ -2,24 +2,41 @@ This repository contains training, generation and utility scripts for Stable Dif

 ## FLUX.1 LoRA training (WIP)

-__Aug 9, 2024__:
+This feature is experimental. The options and the training script may change in the future. Please let us know if you have any ideas for improving the training.

+Aug 10, 2024: The LoRA key prefix has been changed from `lora_flux` to `lora_unet` to make it compatible with ComfyUI.
+
 Please update PyTorch to 2.4.0. We have tested with PyTorch 2.4.0 with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe.

-We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options.
+We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. A sample command is below. It will work with 24GB VRAM GPUs.

 ```
-accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py --pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 --network_train_unet_only --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml --output_dir path/to/output/dir --output_name flux-lora-name
+accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py --pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 --network_train_unet_only --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml --output_dir path/to/output/dir --output_name flux-lora-name --timestep_sampling sigma --model_prediction_type raw --discrete_flow_shift 1.0 --loss_type l2
 ```

+LoRAs for Text Encoders are not tested yet.
+
+We have added some new options (Aug 10, 2024): `--timestep_sampling`, `--sigmoid_scale`, `--model_prediction_type` and `--discrete_flow_shift`. The options are as follows:
+
+- `--timestep_sampling` is the method to sample timesteps (0-1): `sigma` (sigma-based, same as SD3), `uniform` (uniform random), or `sigmoid` (sigmoid of random normal, same as x-flux).
+- `--sigmoid_scale` is the scale factor for sigmoid timestep sampling (only used when `--timestep_sampling` is `sigmoid`). The default is 1.0. Larger values will make the sampling more uniform.
+- `--model_prediction_type` is how to interpret and process the model prediction: `raw` (use as is, same as x-flux), `additive` (add to noisy input), `sigma_scaled` (apply sigma scaling, same as SD3).
+- `--discrete_flow_shift` is the discrete flow shift for the Euler Discrete Scheduler. The default is 3.0 (same as SD3).
+
+`--loss_type` may be useful for FLUX.1 training. The default is `l2`.
+
+In our experiments, `--timestep_sampling sigma --model_prediction_type raw --discrete_flow_shift 1.0` with `--loss_type l2` seems to work better than the default (SD3) settings. The LoRA multiplier should be adjusted. Other settings may work better, so please try different settings.
+
+We are also not sure how many epochs are needed for convergence or how the learning rate should be adjusted.
+
+The trained LoRA model can be used with ComfyUI.
+
 The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options.

 ```
-python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors
+python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors;1.0
 ```

-Unfortunately the training result is not good. Please let us know if you have any ideas for improving the training.
-
 ## SD3 training

 SD3 training is done with `sd3_train.py`.
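The `uniform` and `sigmoid` choices of `--timestep_sampling` described in the README text can be illustrated with a small standard-library sketch. The helper names `sample_t` and `spread` are hypothetical and are not part of sd-scripts, which draws whole batches with `torch`; the sketch only shows how `--sigmoid_scale` changes the spread of the sampled timesteps.

```python
import math
import random
import statistics
from typing import Optional


def sample_t(method: str, sigmoid_scale: float = 1.0, rng: Optional[random.Random] = None) -> float:
    """Draw one timestep t (hypothetical helper, not part of sd-scripts)."""
    rng = rng or random.Random()
    if method == "uniform":
        # uniform random value in [0, 1)
        return rng.random()
    if method == "sigmoid":
        # sigmoid of a scaled standard-normal draw, strictly between 0 and 1
        z = rng.gauss(0.0, 1.0)
        return 1.0 / (1.0 + math.exp(-sigmoid_scale * z))
    raise ValueError(f"unsupported method: {method}")


def spread(method: str, sigmoid_scale: float, n: int = 2000, seed: int = 0) -> float:
    """Population standard deviation of n samples, as a rough measure of spread."""
    rng = random.Random(seed)
    return statistics.pstdev([sample_t(method, sigmoid_scale, rng) for _ in range(n)])


if __name__ == "__main__":
    for scale in (1.0, 3.0):
        print(scale, round(spread("sigmoid", scale), 3))
```

With a scale of 1.0 the samples cluster around 0.5; a larger scale pushes more of them toward 0 and 1, so the spread grows.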
@@ -135,7 +135,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
         pass

     def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
-        noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0)
+        noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
         self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
         return noise_scheduler

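The effect of the `shift` argument can be sketched as follows. This assumes the scheduler uses the time-shift formula of the diffusers `FlowMatchEulerDiscreteScheduler`, `shift * sigma / (1 + (shift - 1) * sigma)`; the diff does not show the body of `sd3_train_utils.FlowMatchEulerDiscreteScheduler`, so treat this as an assumption. `shift_sigma` is a hypothetical helper name.

```python
def shift_sigma(sigma: float, shift: float) -> float:
    """Apply the assumed flow-matching time shift to a sigma in [0, 1]."""
    return shift * sigma / (1.0 + (shift - 1.0) * sigma)


if __name__ == "__main__":
    # shift=1.0 leaves sigma unchanged; shift=3.0 moves mid-range sigmas upward
    for shift in (1.0, 3.0):
        print(shift, [shift_sigma(s, shift) for s in (0.0, 0.5, 1.0)])
```

Under this assumption, `--discrete_flow_shift 1.0` is the identity mapping, while the default of 3.0 biases training toward noisier timesteps.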
@@ -211,21 +211,32 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
         noise = torch.randn_like(latents)
         bsz = latents.shape[0]

-        # Sample a random timestep for each image
-        # for weighting schemes where we sample timesteps non-uniformly
-        u = compute_density_for_timestep_sampling(
-            weighting_scheme=args.weighting_scheme,
-            batch_size=bsz,
-            logit_mean=args.logit_mean,
-            logit_std=args.logit_std,
-            mode_scale=args.mode_scale,
-        )
-        indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long()
-        timesteps = self.noise_scheduler_copy.timesteps[indices].to(device=accelerator.device)
+        if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
+            # Simple random t-based noise sampling
+            if args.timestep_sampling == "sigmoid":
+                # https://github.com/XLabs-AI/x-flux/tree/main
+                t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=accelerator.device))
+            else:
+                t = torch.rand((bsz,), device=accelerator.device)
+            timesteps = t * 1000.0
+            t = t.view(-1, 1, 1, 1)
+            noisy_model_input = (1 - t) * latents + t * noise
+        else:
+            # Sample a random timestep for each image
+            # for weighting schemes where we sample timesteps non-uniformly
+            u = compute_density_for_timestep_sampling(
+                weighting_scheme=args.weighting_scheme,
+                batch_size=bsz,
+                logit_mean=args.logit_mean,
+                logit_std=args.logit_std,
+                mode_scale=args.mode_scale,
+            )
+            indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long()
+            timesteps = self.noise_scheduler_copy.timesteps[indices].to(device=accelerator.device)

-        # Add noise according to flow matching.
-        sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=weight_dtype)
-        noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
+            # Add noise according to flow matching.
+            sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=weight_dtype)
+            noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents

         # pack latents and get img_ids
         packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
@@ -264,11 +275,20 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
         # unpack latents
         model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)

-        model_pred = model_pred * (-sigmas) + noisy_model_input
+        if args.model_prediction_type == "raw":
+            # use model_pred as is
+            weighting = None
+        elif args.model_prediction_type == "additive":
+            # add the model_pred to the noisy_model_input
+            model_pred = model_pred + noisy_model_input
+            weighting = None
+        elif args.model_prediction_type == "sigma_scaled":
+            # apply sigma scaling
+            model_pred = model_pred * (-sigmas) + noisy_model_input

-        # these weighting schemes use a uniform timestep sampling
-        # and instead post-weight the loss
-        weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
+            # these weighting schemes use a uniform timestep sampling
+            # and instead post-weight the loss
+            weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)

         # flow matching loss: this is different from SD3
         target = noise - latents
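The three `--model_prediction_type` branches in this hunk can be checked on scalars. The sketch below uses plain floats instead of tensors, and `process_prediction` is a hypothetical helper name. It applies each branch to the flow-matching velocity `noise - latents`, which is the training target in this hunk.

```python
def process_prediction(model_pred: float, noisy_input: float, sigma: float, prediction_type: str) -> float:
    """Scalar version of the three branches (hypothetical helper, floats instead of tensors)."""
    if prediction_type == "raw":
        # use the prediction as is
        return model_pred
    if prediction_type == "additive":
        # add the prediction to the noisy input
        return model_pred + noisy_input
    if prediction_type == "sigma_scaled":
        # scale by -sigma, then add the noisy input
        return model_pred * (-sigma) + noisy_input
    raise ValueError(f"unsupported prediction type: {prediction_type}")


latent, noise, sigma = 2.0, -1.0, 0.25
noisy_input = sigma * noise + (1.0 - sigma) * latent  # flow-matching interpolation
velocity = noise - latent  # the training target in the hunk above

if __name__ == "__main__":
    for kind in ("raw", "additive", "sigma_scaled"):
        print(kind, process_prediction(velocity, noisy_input, sigma, kind))
```

In this toy case, `sigma_scaled` applied to the exact velocity returns the clean latent, while `raw` leaves the velocity unchanged and so matches the target directly.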
@@ -278,6 +298,21 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
     def post_process_loss(self, loss, args, timesteps, noise_scheduler):
         return loss

+    def get_sai_model_spec(self, args):
+        return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev")
+
+    def update_metadata(self, metadata, args):
+        metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask
+        metadata["ss_weighting_scheme"] = args.weighting_scheme
+        metadata["ss_logit_mean"] = args.logit_mean
+        metadata["ss_logit_std"] = args.logit_std
+        metadata["ss_mode_scale"] = args.mode_scale
+        metadata["ss_guidance_scale"] = args.guidance_scale
+        metadata["ss_timestep_sampling"] = args.timestep_sampling
+        metadata["ss_sigmoid_scale"] = args.sigmoid_scale
+        metadata["ss_model_prediction_type"] = args.model_prediction_type
+        metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
+

 def setup_parser() -> argparse.ArgumentParser:
     parser = train_network.setup_parser()
@@ -318,6 +353,34 @@ def setup_parser() -> argparse.ArgumentParser:
         default=3.5,
         help="the FLUX.1 dev variant is a guidance distilled model",
     )
+
+    parser.add_argument(
+        "--timestep_sampling",
+        choices=["sigma", "uniform", "sigmoid"],
+        default="sigma",
+        help="Method to sample timesteps: sigma-based, uniform random, or sigmoid of random normal. / タイムステップをサンプリングする方法:sigma、random uniform、またはrandom normalのsigmoid。",
+    )
+    parser.add_argument(
+        "--sigmoid_scale",
+        type=float,
+        default=1.0,
+        help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"の場合のみ有効)。',
+    )
+    parser.add_argument(
+        "--model_prediction_type",
+        choices=["raw", "additive", "sigma_scaled"],
+        default="sigma_scaled",
+        help="How to interpret and process the model prediction: "
+        "raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)."
+        " / モデル予測の解釈と処理方法:"
+        "raw(そのまま使用)、additive(ノイズ入力に加算)、sigma_scaled(シグマスケーリングを適用)。",
+    )
+    parser.add_argument(
+        "--discrete_flow_shift",
+        type=float,
+        default=3.0,
+        help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
+    )
     return parser

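A stripped-down parser shows how the new options behave. It registers only the four arguments added in this hunk, with help strings omitted; the real `setup_parser()` builds on `train_network.setup_parser()` and registers many more. `build_parser` is a hypothetical name.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Only the four options added in this hunk (hypothetical standalone parser)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--timestep_sampling", choices=["sigma", "uniform", "sigmoid"], default="sigma")
    parser.add_argument("--sigmoid_scale", type=float, default=1.0)
    parser.add_argument("--model_prediction_type", choices=["raw", "additive", "sigma_scaled"], default="sigma_scaled")
    parser.add_argument("--discrete_flow_shift", type=float, default=3.0)
    return parser


# No flags: the SD3-style defaults apply.
defaults = build_parser().parse_args([])
# The settings recommended in the README text.
tuned = build_parser().parse_args(
    ["--timestep_sampling", "sigma", "--model_prediction_type", "raw", "--discrete_flow_shift", "1.0"]
)

if __name__ == "__main__":
    print(defaults)
    print(tuned)
```

Because `choices` is set, a value outside the list (for example `--timestep_sampling logit`) makes `argparse` exit with a usage error instead of reaching the training code.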
@@ -59,6 +59,8 @@ ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v"
 ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base"
 ARCH_SD3_M = "stable-diffusion-3-medium"
 ARCH_SD3_UNKNOWN = "stable-diffusion-3"
+ARCH_FLUX_1_DEV = "flux-1-dev"
+ARCH_FLUX_1_UNKNOWN = "flux-1"

 ADAPTER_LORA = "lora"
 ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
@@ -66,6 +68,7 @@ ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
 IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models"
 IMPL_COMFY_UI = "https://github.com/comfyanonymous/ComfyUI"
 IMPL_DIFFUSERS = "diffusers"
+IMPL_FLUX = "https://github.com/black-forest-labs/flux"

 PRED_TYPE_EPSILON = "epsilon"
 PRED_TYPE_V = "v"
@@ -118,10 +121,11 @@ def build_metadata(
     merged_from: Optional[str] = None,
     timesteps: Optional[Tuple[int, int]] = None,
     clip_skip: Optional[int] = None,
-    sd3: str = None,
+    sd3: Optional[str] = None,
+    flux: Optional[str] = None,
 ):
     """
-    sd3: only supports "m"
+    sd3: only supports "m", flux: only supports "dev"
     """
     # if state_dict is None, hash is not calculated

@@ -140,6 +144,11 @@ def build_metadata(
             arch = ARCH_SD3_M
         else:
             arch = ARCH_SD3_UNKNOWN
+    elif flux is not None:
+        if flux == "dev":
+            arch = ARCH_FLUX_1_DEV
+        else:
+            arch = ARCH_FLUX_1_UNKNOWN
     elif v2:
         if v_parameterization:
             arch = ARCH_SD_V2_768_V
@@ -158,7 +167,10 @@ def build_metadata(
     if not lora and not textual_inversion and is_stable_diffusion_ckpt is None:
         is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion

-    if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
+    if flux is not None:
+        # Flux
+        impl = IMPL_FLUX
+    elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
         # Stable Diffusion ckpt, TI, SDXL LoRA
         impl = IMPL_STABILITY_AI
     else:
@@ -216,7 +228,7 @@ def build_metadata(
             reso = (reso[0], reso[0])
     else:
         # resolution is defined in dataset, so use default
-        if sdxl or sd3 is not None:
+        if sdxl or sd3 is not None or flux is not None:
             reso = 1024
         elif v2 and v_parameterization:
             reso = 768
@@ -227,7 +239,9 @@ def build_metadata(

     metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}"

-    if v_parameterization:
+    if flux is not None:
+        del metadata["modelspec.prediction_type"]
+    elif v_parameterization:
         metadata["modelspec.prediction_type"] = PRED_TYPE_V
     else:
         metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON
@@ -63,11 +63,11 @@ class FluxTextEncodingStrategy(TextEncodingStrategy):
             l_pooled = None

         if t5xxl is not None and t5_tokens is not None:
-            # t5_out is [1, max length, 4096]
+            # t5_out is [b, max length, 4096]
             t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), return_dict=False, output_hidden_states=True)
             if apply_t5_attn_mask:
                 t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1)
-            txt_ids = torch.zeros(1, t5_out.shape[1], 3, device=t5_out.device)
+            txt_ids = torch.zeros(t5_out.shape[0], t5_out.shape[1], 3, device=t5_out.device)
         else:
             t5_out = None
             txt_ids = None
@@ -3186,6 +3186,7 @@ def get_sai_model_spec(
     textual_inversion: bool,
     is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA
     sd3: str = None,
+    flux: str = None,
 ):
     timestamp = time.time()

@@ -3220,6 +3221,7 @@ def get_sai_model_spec(
         timesteps=timesteps,
         clip_skip=args.clip_skip, # None or int
         sd3=sd3,
+        flux=flux,
     )
     return metadata

@@ -3642,8 +3644,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
         "--loss_type",
         type=str,
         default="l2",
-        choices=["l2", "huber", "smooth_l1"],
-        help="The type of loss function to use (L2, Huber, or smooth L1), default is L2 / 使用する損失関数の種類(L2、Huber、またはsmooth L1)、デフォルトはL2",
+        choices=["l1", "l2", "huber", "smooth_l1"],
+        help="The type of loss function to use (L1, L2, Huber, or smooth L1), default is L2 / 使用する損失関数の種類(L1、L2、Huber、またはsmooth L1)、デフォルトはL2",
     )
     parser.add_argument(
         "--huber_schedule",
@@ -5359,9 +5361,10 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
 def conditional_loss(
     model_pred: torch.Tensor, target: torch.Tensor, reduction: str = "mean", loss_type: str = "l2", huber_c: float = 0.1
 ):
-
     if loss_type == "l2":
         loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)
+    elif loss_type == "l1":
+        loss = torch.nn.functional.l1_loss(model_pred, target, reduction=reduction)
     elif loss_type == "huber":
         loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
         if reduction == "mean":
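The per-element formulas behind the `l2`, `l1`, and `huber` branches visible in this hunk can be written out on scalars. `elementwise_loss` is a hypothetical helper that mirrors the expressions in the diff before any reduction; `smooth_l1` is left out because its body is not shown here.

```python
import math


def elementwise_loss(pred: float, target: float, loss_type: str = "l2", huber_c: float = 0.1) -> float:
    """Per-element loss before reduction (hypothetical scalar helper)."""
    diff = pred - target
    if loss_type == "l2":
        # squared error, as in mse_loss
        return diff * diff
    if loss_type == "l1":
        # absolute error, as in l1_loss
        return abs(diff)
    if loss_type == "huber":
        # pseudo-Huber form used in the hunk above
        return 2 * huber_c * (math.sqrt(diff * diff + huber_c**2) - huber_c)
    raise ValueError(f"unsupported loss type: {loss_type}")


if __name__ == "__main__":
    for kind in ("l2", "l1", "huber"):
        print(kind, elementwise_loss(3.0, 0.0, kind, huber_c=4.0))
```

For large errors the `l1` and `huber` values grow linearly rather than quadratically, which is the usual reason to try them when outliers dominate an `l2` loss.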
@@ -316,7 +316,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh
 class LoRANetwork(torch.nn.Module):
     FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"]
     TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
-    LORA_PREFIX_FLUX = "lora_flux"
+    LORA_PREFIX_FLUX = "lora_unet" # make ComfyUI compatible
     LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1"
     LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te2"

@@ -226,6 +226,12 @@ class NetworkTrainer:
             loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
         return loss

+    def get_sai_model_spec(self, args):
+        return train_util.get_sai_model_spec(None, args, self.is_sdxl, True, False)
+
+    def update_metadata(self, metadata, args):
+        pass
+
     # endregion

     def train(self, args):
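The two new methods are overridable hooks: the base trainer calls them, and an architecture-specific subclass replaces them. A minimal sketch of that pattern, with hypothetical class names and plain dicts standing in for `args` and the model spec, could look like this:

```python
class BaseTrainer:
    """Stand-in for the base trainer (hypothetical, greatly simplified)."""

    def get_model_spec(self, args: dict) -> dict:
        return {"modelspec.architecture": "generic"}

    def update_metadata(self, metadata: dict, args: dict) -> None:
        # default hook: no architecture-specific metadata
        pass

    def collect_metadata(self, args: dict) -> dict:
        metadata = {"ss_loss_type": args["loss_type"]}
        self.update_metadata(metadata, args)  # architecture-specific keys
        metadata.update(self.get_model_spec(args))  # model spec keys
        return metadata


class FluxLikeTrainer(BaseTrainer):
    """Stand-in for an architecture-specific subclass."""

    def get_model_spec(self, args: dict) -> dict:
        return {"modelspec.architecture": "flux-1-dev"}

    def update_metadata(self, metadata: dict, args: dict) -> None:
        metadata["ss_timestep_sampling"] = args["timestep_sampling"]


example_args = {"loss_type": "l2", "timestep_sampling": "sigma"}

if __name__ == "__main__":
    print(BaseTrainer().collect_metadata(example_args))
    print(FluxLikeTrainer().collect_metadata(example_args))
```

The base class keeps a single code path for saving, and each subclass contributes only its own keys.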
@@ -521,10 +527,13 @@ class NetworkTrainer:
             unet_weight_dtype = torch.float8_e4m3fn
             te_weight_dtype = torch.float8_e4m3fn

-        unet.to(accelerator.device) # this makes faster `to(dtype)` below
+        # unet.to(accelerator.device) # this makes faster `to(dtype)` below, but consumes 23 GB VRAM
+        # unet.to(dtype=unet_weight_dtype) # without moving to gpu, this takes a lot of time and main memory
+
+        unet.to(accelerator.device, dtype=unet_weight_dtype) # this seems to be safer than above

         unet.requires_grad_(False)
-        unet.to(dtype=unet_weight_dtype) # this takes long time and large memory
+        unet.to(dtype=unet_weight_dtype)
         for t_enc in text_encoders:
             t_enc.requires_grad_(False)

@@ -718,8 +727,11 @@ class NetworkTrainer:
             "ss_loss_type": args.loss_type,
             "ss_huber_schedule": args.huber_schedule,
             "ss_huber_c": args.huber_c,
+            "ss_fp8_base": args.fp8_base,
         }

+        self.update_metadata(metadata, args) # architecture specific metadata
+
         if use_user_config:
             # save metadata of multiple datasets
             # NOTE: pack "ss_datasets" value as json one time
@@ -964,7 +976,7 @@ class NetworkTrainer:
             metadata["ss_epoch"] = str(epoch_no)

             metadata_to_save = minimum_metadata if args.no_metadata else metadata
-            sai_metadata = train_util.get_sai_model_spec(None, args, self.is_sdxl, True, False)
+            sai_metadata = self.get_sai_model_spec(args)
             metadata_to_save.update(sai_metadata)

             unwrapped_nw.save_weights(ckpt_file, save_dtype, metadata_to_save)