update FLUX LoRA training

This commit is contained in:
Kohya S
2024-08-10 23:42:05 +09:00
parent 358f13f2c9
commit 8a0f12dde8
7 changed files with 148 additions and 39 deletions

View File

@@ -2,24 +2,41 @@ This repository contains training, generation and utility scripts for Stable Dif
## FLUX.1 LoRA training (WIP)
__Aug 9, 2024__:
This feature is experimental. The options and the training script may change in the future. Please let us know if you have any idea to improve the training.
Aug 10, 2024: LoRA key prefix is changed to `lora_unet` from `lora_flex` to make it compatible with ComfyUI.
Please update PyTorch to 2.4.0. We have tested with PyTorch 2.4.0 with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe.
We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options.
We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. Sample command is below. It will work with 24GB VRAM GPUs.
```
accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py --pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 --network_train_unet_only --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml --output_dir path/to/output/dir --output_name flux-lora-name
accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py --pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 --network_train_unet_only --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml --output_dir path/to/output/dir --output_name flux-lora-name --timestep_sampling sigma --model_prediction_type raw --discrete_flow_shift 1.0 --loss_type l2
```
LoRAs for Text Encoders are not tested yet.
We have added some new options (Aug 10, 2024): `--time_sampling`, `--sigmoid_scale`, `--model_prediction_type` and `--discrete_flow_shift`. The options are as follows:
- `--timestep_sampling` is the method to sample timesteps (0-1): `sigma` (sigma-based, same as SD3), `uniform` (uniform random), or `sigmoid` (sigmoid of random normal, same as x-flux).
- `--sigmoid_scale` is the scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). The default is 1.0. Larger values will make the sampling more uniform.
- `--model_prediction_type` is how to interpret and process the model prediction: `raw` (use as is, same as x-flux), `additive` (add to noisy input), `sigma_scaled` (apply sigma scaling, same as SD3).
- `--discrete_flow_shift` is the discrete flow shift for the Euler Discrete Scheduler, default is 3.0 (same as SD3).
`--loss_type` may be useful for FLUX.1 training. The default is `l2`.
In our experiments, `--timestep_sampling sigma --model_prediction_type raw --discrete_flow_shift 1.0` with `--loss_type l2` seems to work better than the default (SD3) settings. The multiplier of LoRA should be adjusted. Other settings may work better, so please try different settings.
We also not sure how many epochs are needed for convergence, and how the learning rate should be adjusted.
The trained LoRA model can be used with ComfyUI.
The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options.
```
python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors
python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors;1.0
```
Unfortnately the training result is not good. Please let us know if you have any idea to improve the training.
## SD3 training
SD3 training is done with `sd3_train.py`.

View File

@@ -135,7 +135,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
pass
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0)
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
return noise_scheduler
@@ -211,21 +211,32 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
u = compute_density_for_timestep_sampling(
weighting_scheme=args.weighting_scheme,
batch_size=bsz,
logit_mean=args.logit_mean,
logit_std=args.logit_std,
mode_scale=args.mode_scale,
)
indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long()
timesteps = self.noise_scheduler_copy.timesteps[indices].to(device=accelerator.device)
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
# Simple random t-based noise sampling
if args.timestep_sampling == "sigmoid":
# https://github.com/XLabs-AI/x-flux/tree/main
t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=accelerator.device))
else:
t = torch.rand((bsz,), device=accelerator.device)
timesteps = t * 1000.0
t = t.view(-1, 1, 1, 1)
noisy_model_input = (1 - t) * latents + t * noise
else:
# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
u = compute_density_for_timestep_sampling(
weighting_scheme=args.weighting_scheme,
batch_size=bsz,
logit_mean=args.logit_mean,
logit_std=args.logit_std,
mode_scale=args.mode_scale,
)
indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long()
timesteps = self.noise_scheduler_copy.timesteps[indices].to(device=accelerator.device)
# Add noise according to flow matching.
sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=weight_dtype)
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
# Add noise according to flow matching.
sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=weight_dtype)
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
# pack latents and get img_ids
packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
@@ -264,11 +275,20 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
# unpack latents
model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
model_pred = model_pred * (-sigmas) + noisy_model_input
if args.model_prediction_type == "raw":
# use model_pred as is
weighting = None
elif args.model_prediction_type == "additive":
# add the model_pred to the noisy_model_input
model_pred = model_pred + noisy_model_input
weighting = None
elif args.model_prediction_type == "sigma_scaled":
# apply sigma scaling
model_pred = model_pred * (-sigmas) + noisy_model_input
# these weighting schemes use a uniform timestep sampling
# and instead post-weight the loss
weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
# these weighting schemes use a uniform timestep sampling
# and instead post-weight the loss
weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
# flow matching loss: this is different from SD3
target = noise - latents
@@ -278,6 +298,21 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
return loss
def get_sai_model_spec(self, args):
return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev")
def update_metadata(self, metadata, args):
metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask
metadata["ss_weighting_scheme"] = args.weighting_scheme
metadata["ss_logit_mean"] = args.logit_mean
metadata["ss_logit_std"] = args.logit_std
metadata["ss_mode_scale"] = args.mode_scale
metadata["ss_guidance_scale"] = args.guidance_scale
metadata["ss_timestep_sampling"] = args.timestep_sampling
metadata["ss_sigmoid_scale"] = args.sigmoid_scale
metadata["ss_model_prediction_type"] = args.model_prediction_type
metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
def setup_parser() -> argparse.ArgumentParser:
parser = train_network.setup_parser()
@@ -318,6 +353,34 @@ def setup_parser() -> argparse.ArgumentParser:
default=3.5,
help="the FLUX.1 dev variant is a guidance distilled model",
)
parser.add_argument(
"--timestep_sampling",
choices=["sigma", "uniform", "sigmoid"],
default="sigma",
help="Method to sample timesteps: sigma-based, uniform random, or sigmoid of random normal. / タイムステップをサンプリングする方法sigma、random uniform、またはrandom normalのsigmoid。",
)
parser.add_argument(
"--sigmoid_scale",
type=float,
default=1.0,
help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率timestep-samplingが"sigmoid"の場合のみ有効)。',
)
parser.add_argument(
"--model_prediction_type",
choices=["raw", "additive", "sigma_scaled"],
default="sigma_scaled",
help="How to interpret and process the model prediction: "
"raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)."
" / モデル予測の解釈と処理方法:"
"rawそのまま使用、additiveイズ入力に加算、sigma_scaledシグマスケーリングを適用",
)
parser.add_argument(
"--discrete_flow_shift",
type=float,
default=3.0,
help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
)
return parser

View File

@@ -59,6 +59,8 @@ ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v"
ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base"
ARCH_SD3_M = "stable-diffusion-3-medium"
ARCH_SD3_UNKNOWN = "stable-diffusion-3"
ARCH_FLUX_1_DEV = "flux-1-dev"
ARCH_FLUX_1_UNKNOWN = "flux-1"
ADAPTER_LORA = "lora"
ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
@@ -66,6 +68,7 @@ ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models"
IMPL_COMFY_UI = "https://github.com/comfyanonymous/ComfyUI"
IMPL_DIFFUSERS = "diffusers"
IMPL_FLUX = "https://github.com/black-forest-labs/flux"
PRED_TYPE_EPSILON = "epsilon"
PRED_TYPE_V = "v"
@@ -118,10 +121,11 @@ def build_metadata(
merged_from: Optional[str] = None,
timesteps: Optional[Tuple[int, int]] = None,
clip_skip: Optional[int] = None,
sd3: str = None,
sd3: Optional[str] = None,
flux: Optional[str] = None,
):
"""
sd3: only supports "m"
sd3: only supports "m", flux: only supports "dev"
"""
# if state_dict is None, hash is not calculated
@@ -140,6 +144,11 @@ def build_metadata(
arch = ARCH_SD3_M
else:
arch = ARCH_SD3_UNKNOWN
elif flux is not None:
if flux == "dev":
arch = ARCH_FLUX_1_DEV
else:
arch = ARCH_FLUX_1_UNKNOWN
elif v2:
if v_parameterization:
arch = ARCH_SD_V2_768_V
@@ -158,7 +167,10 @@ def build_metadata(
if not lora and not textual_inversion and is_stable_diffusion_ckpt is None:
is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion
if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
if flux is not None:
# Flux
impl = IMPL_FLUX
elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
# Stable Diffusion ckpt, TI, SDXL LoRA
impl = IMPL_STABILITY_AI
else:
@@ -216,7 +228,7 @@ def build_metadata(
reso = (reso[0], reso[0])
else:
# resolution is defined in dataset, so use default
if sdxl or sd3 is not None:
if sdxl or sd3 is not None or flux is not None:
reso = 1024
elif v2 and v_parameterization:
reso = 768
@@ -227,7 +239,9 @@ def build_metadata(
metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}"
if v_parameterization:
if flux is not None:
del metadata["modelspec.prediction_type"]
elif v_parameterization:
metadata["modelspec.prediction_type"] = PRED_TYPE_V
else:
metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON

View File

@@ -63,11 +63,11 @@ class FluxTextEncodingStrategy(TextEncodingStrategy):
l_pooled = None
if t5xxl is not None and t5_tokens is not None:
# t5_out is [1, max length, 4096]
# t5_out is [b, max length, 4096]
t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), return_dict=False, output_hidden_states=True)
if apply_t5_attn_mask:
t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1)
txt_ids = torch.zeros(1, t5_out.shape[1], 3, device=t5_out.device)
txt_ids = torch.zeros(t5_out.shape[0], t5_out.shape[1], 3, device=t5_out.device)
else:
t5_out = None
txt_ids = None

View File

@@ -3186,6 +3186,7 @@ def get_sai_model_spec(
textual_inversion: bool,
is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA
sd3: str = None,
flux: str = None,
):
timestamp = time.time()
@@ -3220,6 +3221,7 @@ def get_sai_model_spec(
timesteps=timesteps,
clip_skip=args.clip_skip, # None or int
sd3=sd3,
flux=flux,
)
return metadata
@@ -3642,8 +3644,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
"--loss_type",
type=str,
default="l2",
choices=["l2", "huber", "smooth_l1"],
help="The type of loss function to use (L2, Huber, or smooth L1), default is L2 / 使用する損失関数の種類L2、Huber、またはsmooth L1、デフォルトはL2",
choices=["l1", "l2", "huber", "smooth_l1"],
help="The type of loss function to use (L1, L2, Huber, or smooth L1), default is L2 / 使用する損失関数の種類(L1、L2、Huber、またはsmooth L1、デフォルトはL2",
)
parser.add_argument(
"--huber_schedule",
@@ -5359,9 +5361,10 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
def conditional_loss(
model_pred: torch.Tensor, target: torch.Tensor, reduction: str = "mean", loss_type: str = "l2", huber_c: float = 0.1
):
if loss_type == "l2":
loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)
elif loss_type == "l1":
loss = torch.nn.functional.l1_loss(model_pred, target, reduction=reduction)
elif loss_type == "huber":
loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
if reduction == "mean":

View File

@@ -316,7 +316,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh
class LoRANetwork(torch.nn.Module):
FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
LORA_PREFIX_FLUX = "lora_flux"
LORA_PREFIX_FLUX = "lora_unet" # make ComfyUI compatible
LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1"
LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te2"

View File

@@ -226,6 +226,12 @@ class NetworkTrainer:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
return loss
def get_sai_model_spec(self, args):
return train_util.get_sai_model_spec(None, args, self.is_sdxl, True, False)
def update_metadata(self, metadata, args):
pass
# endregion
def train(self, args):
@@ -521,10 +527,13 @@ class NetworkTrainer:
unet_weight_dtype = torch.float8_e4m3fn
te_weight_dtype = torch.float8_e4m3fn
unet.to(accelerator.device) # this makes faster `to(dtype)` below
# unet.to(accelerator.device) # this makes faster `to(dtype)` below, but consumes 23 GB VRAM
# unet.to(dtype=unet_weight_dtype) # without moving to gpu, this takes a lot of time and main memory
unet.to(accelerator.device, dtype=unet_weight_dtype) # this seems to be safer than above
unet.requires_grad_(False)
unet.to(dtype=unet_weight_dtype) # this takes long time and large memory
unet.to(dtype=unet_weight_dtype)
for t_enc in text_encoders:
t_enc.requires_grad_(False)
@@ -718,8 +727,11 @@ class NetworkTrainer:
"ss_loss_type": args.loss_type,
"ss_huber_schedule": args.huber_schedule,
"ss_huber_c": args.huber_c,
"ss_fp8_base": args.fp8_base,
}
self.update_metadata(metadata, args) # architecture specific metadata
if use_user_config:
# save metadata of multiple datasets
# NOTE: pack "ss_datasets" value as json one time
@@ -964,7 +976,7 @@ class NetworkTrainer:
metadata["ss_epoch"] = str(epoch_no)
metadata_to_save = minimum_metadata if args.no_metadata else metadata
sai_metadata = train_util.get_sai_model_spec(None, args, self.is_sdxl, True, False)
sai_metadata = self.get_sai_model_spec(args)
metadata_to_save.update(sai_metadata)
unwrapped_nw.save_weights(ckpt_file, save_dtype, metadata_to_save)