mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
feat: add LoRA training support for Chroma
This commit is contained in:
@@ -468,7 +468,7 @@ if __name__ == "__main__":
|
||||
# t5xxl = accelerator.prepare(t5xxl)
|
||||
|
||||
# DiT
|
||||
model_type, is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device, model_type=args.model_type)
|
||||
is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device, model_type=args.model_type)
|
||||
model.eval()
|
||||
logger.info(f"Casting model to {flux_dtype}")
|
||||
model.to(flux_dtype) # make sure model is dtype
|
||||
|
||||
@@ -270,7 +270,7 @@ def train(args):
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
# load FLUX
|
||||
model_type, _, flux = flux_utils.load_flow_model(
|
||||
_, flux = flux_utils.load_flow_model(
|
||||
args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors, model_type="flux"
|
||||
)
|
||||
|
||||
|
||||
@@ -68,6 +68,11 @@ def train(args):
|
||||
if not args.skip_cache_check:
|
||||
args.skip_cache_check = args.skip_latents_validity_check
|
||||
|
||||
if args.model_type != "flux":
|
||||
raise ValueError(
|
||||
f"FLUX.1 ControlNet training requires model_type='flux'. / FLUX.1 ControlNetの学習にはmodel_type='flux'を指定してください。"
|
||||
)
|
||||
|
||||
# assert (
|
||||
# not args.weighted_captions
|
||||
# ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
|
||||
@@ -258,7 +263,7 @@ def train(args):
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
# load FLUX
|
||||
model_type, is_schnell, flux = flux_utils.load_flow_model(
|
||||
is_schnell, flux = flux_utils.load_flow_model(
|
||||
args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors, model_type="flux"
|
||||
)
|
||||
flux.requires_grad_(False)
|
||||
|
||||
@@ -35,6 +35,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
self.sample_prompts_te_outputs = None
|
||||
self.is_schnell: Optional[bool] = None
|
||||
self.is_swapping_blocks: bool = False
|
||||
self.model_type: Optional[str] = None
|
||||
|
||||
def assert_extra_args(
|
||||
self,
|
||||
@@ -45,6 +46,12 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
|
||||
# sdxl_train_util.verify_sdxl_training_args(args)
|
||||
|
||||
self.model_type = args.model_type # "flux" or "chroma"
|
||||
if self.model_type != "chroma":
|
||||
self.use_clip_l = True
|
||||
else:
|
||||
self.use_clip_l = False # Chroma does not use CLIP-L
|
||||
|
||||
if args.fp8_base_unet:
|
||||
args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1
|
||||
|
||||
@@ -60,7 +67,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
|
||||
|
||||
# prepare CLIP-L/T5XXL training flags
|
||||
self.train_clip_l = not args.network_train_unet_only
|
||||
self.train_clip_l = not args.network_train_unet_only and self.use_clip_l
|
||||
self.train_t5xxl = False # default is False even if args.network_train_unet_only is False
|
||||
|
||||
if args.max_token_length is not None:
|
||||
@@ -95,8 +102,12 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
loading_dtype = None if args.fp8_base else weight_dtype
|
||||
|
||||
# if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future
|
||||
self.model_type, self.is_schnell, model = flux_utils.load_flow_model(
|
||||
args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, model_type="flux"
|
||||
_, model = flux_utils.load_flow_model(
|
||||
args.pretrained_model_name_or_path,
|
||||
loading_dtype,
|
||||
"cpu",
|
||||
disable_mmap=args.disable_mmap_load_safetensors,
|
||||
model_type=self.model_type,
|
||||
)
|
||||
if args.fp8_base:
|
||||
# check dtype of model
|
||||
@@ -120,7 +131,10 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
|
||||
model.enable_block_swap(args.blocks_to_swap, accelerator.device)
|
||||
|
||||
clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
|
||||
if self.use_clip_l:
|
||||
clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
|
||||
else:
|
||||
clip_l = flux_utils.dummy_clip_l() # dummy CLIP-L for Chroma, which does not use CLIP-L
|
||||
clip_l.eval()
|
||||
|
||||
# if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8)
|
||||
@@ -141,13 +155,20 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
|
||||
ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
|
||||
|
||||
return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model
|
||||
model_version = flux_utils.MODEL_VERSION_FLUX_V1 if self.model_type != "chroma" else flux_utils.MODEL_VERSION_CHROMA
|
||||
return model_version, [clip_l, t5xxl], ae, model
|
||||
|
||||
def get_tokenize_strategy(self, args):
|
||||
_, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
|
||||
# This method is called before `assert_extra_args`, so we cannot use `self.is_schnell` here.
|
||||
# Instead, we analyze the checkpoint state to determine if it is schnell.
|
||||
if args.model_type != "chroma":
|
||||
_, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
|
||||
else:
|
||||
is_schnell = False
|
||||
self.is_schnell = is_schnell
|
||||
|
||||
if args.t5xxl_max_token_length is None:
|
||||
if is_schnell:
|
||||
if self.is_schnell:
|
||||
t5xxl_max_token_length = 256
|
||||
else:
|
||||
t5xxl_max_token_length = 512
|
||||
@@ -268,23 +289,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoders[1].to(accelerator.device)
|
||||
|
||||
# def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
|
||||
# noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
|
||||
|
||||
# # get size embeddings
|
||||
# orig_size = batch["original_sizes_hw"]
|
||||
# crop_size = batch["crop_top_lefts"]
|
||||
# target_size = batch["target_sizes_hw"]
|
||||
# embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
|
||||
|
||||
# # concat embeddings
|
||||
# encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
|
||||
# vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
|
||||
# text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
|
||||
|
||||
# noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
|
||||
# return noise_pred
|
||||
|
||||
def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux):
|
||||
text_encoders = text_encoder # for compatibility
|
||||
text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders)
|
||||
@@ -292,36 +296,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
flux_train_utils.sample_images(
|
||||
accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs
|
||||
)
|
||||
# return
|
||||
|
||||
"""
|
||||
class FluxUpperLowerWrapper(torch.nn.Module):
|
||||
def __init__(self, flux_upper: flux_models.FluxUpper, flux_lower: flux_models.FluxLower, device: torch.device):
|
||||
super().__init__()
|
||||
self.flux_upper = flux_upper
|
||||
self.flux_lower = flux_lower
|
||||
self.target_device = device
|
||||
|
||||
def prepare_block_swap_before_forward(self):
|
||||
pass
|
||||
|
||||
def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None):
|
||||
self.flux_lower.to("cpu")
|
||||
clean_memory_on_device(self.target_device)
|
||||
self.flux_upper.to(self.target_device)
|
||||
img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance, txt_attention_mask)
|
||||
self.flux_upper.to("cpu")
|
||||
clean_memory_on_device(self.target_device)
|
||||
self.flux_lower.to(self.target_device)
|
||||
return self.flux_lower(img, txt, vec, pe, txt_attention_mask)
|
||||
|
||||
wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device)
|
||||
clean_memory_on_device(accelerator.device)
|
||||
flux_train_utils.sample_images(
|
||||
accelerator, args, epoch, global_step, wrapper, ae, text_encoders, self.sample_prompts_te_outputs
|
||||
)
|
||||
clean_memory_on_device(accelerator.device)
|
||||
"""
|
||||
|
||||
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
|
||||
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
|
||||
@@ -366,7 +340,11 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
# ensure guidance_scale in args is float
|
||||
guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)
|
||||
|
||||
# ensure the hidden state will require grad
|
||||
# get modulation vectors for Chroma
|
||||
input_vec = None
|
||||
if self.model_type == "chroma":
|
||||
input_vec = unet.get_input_vec(timesteps=timesteps, guidance=guidance_vec, batch_size=bsz)
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
noisy_model_input.requires_grad_(True)
|
||||
for t in text_encoder_conds:
|
||||
@@ -374,13 +352,15 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
t.requires_grad_(True)
|
||||
img_ids.requires_grad_(True)
|
||||
guidance_vec.requires_grad_(True)
|
||||
if input_vec is not None:
|
||||
input_vec.requires_grad_(True)
|
||||
|
||||
# Predict the noise residual
|
||||
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
|
||||
if not args.apply_t5_attn_mask:
|
||||
t5_attn_mask = None
|
||||
|
||||
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
|
||||
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask, input_vec):
|
||||
# grad is enabled even if unet is not in train mode, because Text Encoder is in train mode
|
||||
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||
@@ -393,6 +373,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
timesteps=timesteps / 1000,
|
||||
guidance=guidance_vec,
|
||||
txt_attention_mask=t5_attn_mask,
|
||||
input_vec=input_vec,
|
||||
)
|
||||
return model_pred
|
||||
|
||||
@@ -405,6 +386,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
timesteps=timesteps,
|
||||
guidance_vec=guidance_vec,
|
||||
t5_attn_mask=t5_attn_mask,
|
||||
input_vec=input_vec,
|
||||
)
|
||||
|
||||
# unpack latents
|
||||
@@ -436,6 +418,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
timesteps=timesteps[diff_output_pr_indices],
|
||||
guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None,
|
||||
t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None,
|
||||
input_vec=input_vec[diff_output_pr_indices] if input_vec is not None else None,
|
||||
)
|
||||
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
|
||||
|
||||
@@ -454,9 +437,14 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
return loss
|
||||
|
||||
def get_sai_model_spec(self, args):
|
||||
return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev")
|
||||
if self.model_type != "chroma":
|
||||
model_description = "schnell" if self.is_schnell else "dev"
|
||||
else:
|
||||
model_description = "chroma"
|
||||
return train_util.get_sai_model_spec(None, args, False, True, False, flux=model_description)
|
||||
|
||||
def update_metadata(self, metadata, args):
|
||||
metadata["ss_model_type"] = args.model_type
|
||||
metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask
|
||||
metadata["ss_weighting_scheme"] = args.weighting_scheme
|
||||
metadata["ss_logit_mean"] = args.logit_mean
|
||||
|
||||
@@ -601,13 +601,30 @@ class Chroma(Flux):
|
||||
self.gradient_checkpointing = False
|
||||
self.cpu_offload_checkpointing = False
|
||||
|
||||
def get_mod_vectors(
|
||||
self,
|
||||
timesteps: Tensor,
|
||||
guidance: Tensor | None = None,
|
||||
batch_size: int | None = None,
|
||||
requires_grad: bool = False,
|
||||
) -> Tensor:
|
||||
def get_model_type(self) -> str:
|
||||
return "chroma"
|
||||
|
||||
def enable_gradient_checkpointing(self, cpu_offload: bool = False):
    """
    Turn on gradient checkpointing for this model and its sub-modules.

    Sets the `gradient_checkpointing` flag, records `cpu_offload` in
    `cpu_offload_checkpointing`, then enables checkpointing on the distilled
    guidance layer and on every double and single block.

    Args:
        cpu_offload: stored on this model as `cpu_offload_checkpointing`.
            It is not passed on to the sub-modules by this method.
    """
    self.gradient_checkpointing = True
    self.cpu_offload_checkpointing = cpu_offload

    self.distilled_guidance_layer.enable_gradient_checkpointing()
    for block in self.double_blocks + self.single_blocks:
        block.enable_gradient_checkpointing()

    # plain string literal: there is nothing to interpolate here
    print("Chroma: Gradient checkpointing enabled.")
|
||||
|
||||
def disable_gradient_checkpointing(self):
|
||||
self.gradient_checkpointing = False
|
||||
self.cpu_offload_checkpointing = False
|
||||
|
||||
self.distilled_guidance_layer.disable_gradient_checkpointing()
|
||||
for block in self.double_blocks + self.single_blocks:
|
||||
block.disable_gradient_checkpointing()
|
||||
|
||||
print("Chroma: Gradient checkpointing disabled.")
|
||||
|
||||
def get_input_vec(self, timesteps: Tensor, guidance: Tensor | None = None, batch_size: int | None = None) -> Tensor:
|
||||
distill_timestep = timestep_embedding(timesteps, self.approximator_in_dim // 4)
|
||||
# TODO: need to add toggle to omit this from schnell but that's not a priority
|
||||
distil_guidance = timestep_embedding(guidance, self.approximator_in_dim // 4)
|
||||
@@ -619,10 +636,7 @@ class Chroma(Flux):
|
||||
timestep_guidance = torch.cat([distill_timestep, distil_guidance], dim=1).unsqueeze(1).repeat(1, self.mod_index_length, 1)
|
||||
# then and only then we could concatenate it together
|
||||
input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1)
|
||||
if requires_grad:
|
||||
input_vec = input_vec.requires_grad_(True)
|
||||
mod_vectors = self.distilled_guidance_layer(input_vec)
|
||||
return mod_vectors
|
||||
return input_vec
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -637,7 +651,7 @@ class Chroma(Flux):
|
||||
guidance: Tensor | None = None,
|
||||
txt_attention_mask: Tensor | None = None,
|
||||
attn_padding: int = 1,
|
||||
mod_vectors: Tensor | None = None,
|
||||
input_vec: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
# print(
|
||||
# f"Chroma forward: img shape {img.shape}, txt shape {txt.shape}, img_ids shape {img_ids.shape}, txt_ids shape {txt_ids.shape}"
|
||||
@@ -651,7 +665,7 @@ class Chroma(Flux):
|
||||
img = self.img_in(img)
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
if mod_vectors is None:
|
||||
if input_vec is None:
|
||||
# TODO:
|
||||
# need to fix grad accumulation issue here for now it's in no grad mode
|
||||
# besides, i don't want to wash out the PFP that's trained on this model weights anyway
|
||||
@@ -659,14 +673,18 @@ class Chroma(Flux):
|
||||
# alternatively doing forward pass for every block manually is doable but slow
|
||||
# custom backward probably be better
|
||||
with torch.no_grad():
|
||||
# kohya-ss: I'm not sure why requires_grad is set to True here
|
||||
mod_vectors = self.get_mod_vectors(timesteps, guidance, img.shape[0], requires_grad=True)
|
||||
input_vec = self.get_input_vec(timesteps, guidance, img.shape[0])
|
||||
|
||||
# kohya-ss: I'm not sure why requires_grad is set to True here
|
||||
input_vec.requires_grad = True
|
||||
mod_vectors = self.distilled_guidance_layer(input_vec)
|
||||
else:
|
||||
mod_vectors = self.distilled_guidance_layer(input_vec)
|
||||
mod_vectors_dict = distribute_modulations(mod_vectors, self.depth_single_blocks, self.depth_double_blocks)
|
||||
|
||||
# calculate text length for each batch instead of masking
|
||||
txt_emb_len = txt.shape[1]
|
||||
txt_seq_len = txt_attention_mask[:, :txt_emb_len].sum(dim=-1) # (batch_size, )
|
||||
txt_seq_len = txt_attention_mask[:, :txt_emb_len].sum(dim=-1).to(torch.int64) # (batch_size, )
|
||||
txt_seq_len = torch.clip(txt_seq_len + attn_padding, 0, txt_emb_len)
|
||||
max_txt_len = torch.max(txt_seq_len).item() # max text length in the batch
|
||||
|
||||
|
||||
@@ -930,6 +930,9 @@ class Flux(nn.Module):
|
||||
self.num_double_blocks = len(self.double_blocks)
|
||||
self.num_single_blocks = len(self.single_blocks)
|
||||
|
||||
def get_model_type(self) -> str:
|
||||
return "flux"
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
@@ -1018,6 +1021,7 @@ class Flux(nn.Module):
|
||||
block_controlnet_single_hidden_states=None,
|
||||
guidance: Tensor | None = None,
|
||||
txt_attention_mask: Tensor | None = None,
|
||||
input_vec: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
@@ -1169,7 +1173,7 @@ class ControlNetFlux(nn.Module):
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(16, 16, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
zero_module(nn.Conv2d(16, 16, 3, padding=1))
|
||||
zero_module(nn.Conv2d(16, 16, 3, padding=1)),
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -1320,174 +1324,3 @@ class ControlNetFlux(nn.Module):
|
||||
controlnet_single_block_samples = controlnet_single_block_samples + (block_sample,)
|
||||
|
||||
return controlnet_block_samples, controlnet_single_block_samples
|
||||
|
||||
|
||||
"""
|
||||
class FluxUpper(nn.Module):
|
||||
""
|
||||
Transformer model for flow matching on sequences.
|
||||
""
|
||||
|
||||
def __init__(self, params: FluxParams):
|
||||
super().__init__()
|
||||
|
||||
self.params = params
|
||||
self.in_channels = params.in_channels
|
||||
self.out_channels = self.in_channels
|
||||
if params.hidden_size % params.num_heads != 0:
|
||||
raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
|
||||
pe_dim = params.hidden_size // params.num_heads
|
||||
if sum(params.axes_dim) != pe_dim:
|
||||
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
|
||||
self.hidden_size = params.hidden_size
|
||||
self.num_heads = params.num_heads
|
||||
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
||||
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
|
||||
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
|
||||
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
|
||||
self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
|
||||
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
|
||||
|
||||
self.double_blocks = nn.ModuleList(
|
||||
[
|
||||
DoubleStreamBlock(
|
||||
self.hidden_size,
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
qkv_bias=params.qkv_bias,
|
||||
)
|
||||
for _ in range(params.depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(self.parameters()).dtype
|
||||
|
||||
def enable_gradient_checkpointing(self):
|
||||
self.gradient_checkpointing = True
|
||||
|
||||
self.time_in.enable_gradient_checkpointing()
|
||||
self.vector_in.enable_gradient_checkpointing()
|
||||
if self.guidance_in.__class__ != nn.Identity:
|
||||
self.guidance_in.enable_gradient_checkpointing()
|
||||
|
||||
for block in self.double_blocks:
|
||||
block.enable_gradient_checkpointing()
|
||||
|
||||
print("FLUX: Gradient checkpointing enabled.")
|
||||
|
||||
def disable_gradient_checkpointing(self):
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
self.time_in.disable_gradient_checkpointing()
|
||||
self.vector_in.disable_gradient_checkpointing()
|
||||
if self.guidance_in.__class__ != nn.Identity:
|
||||
self.guidance_in.disable_gradient_checkpointing()
|
||||
|
||||
for block in self.double_blocks:
|
||||
block.disable_gradient_checkpointing()
|
||||
|
||||
print("FLUX: Gradient checkpointing disabled.")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
img: Tensor,
|
||||
img_ids: Tensor,
|
||||
txt: Tensor,
|
||||
txt_ids: Tensor,
|
||||
timesteps: Tensor,
|
||||
y: Tensor,
|
||||
guidance: Tensor | None = None,
|
||||
txt_attention_mask: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
|
||||
# running on sequences img
|
||||
img = self.img_in(img)
|
||||
vec = self.time_in(timestep_embedding(timesteps, 256))
|
||||
if self.params.guidance_embed:
|
||||
if guidance is None:
|
||||
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
|
||||
vec = vec + self.vector_in(y)
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
for block in self.double_blocks:
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
||||
|
||||
return img, txt, vec, pe
|
||||
|
||||
|
||||
class FluxLower(nn.Module):
|
||||
""
|
||||
Transformer model for flow matching on sequences.
|
||||
""
|
||||
|
||||
def __init__(self, params: FluxParams):
|
||||
super().__init__()
|
||||
self.hidden_size = params.hidden_size
|
||||
self.num_heads = params.num_heads
|
||||
self.out_channels = params.in_channels
|
||||
|
||||
self.single_blocks = nn.ModuleList(
|
||||
[
|
||||
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
|
||||
for _ in range(params.depth_single_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(self.parameters()).dtype
|
||||
|
||||
def enable_gradient_checkpointing(self):
|
||||
self.gradient_checkpointing = True
|
||||
|
||||
for block in self.single_blocks:
|
||||
block.enable_gradient_checkpointing()
|
||||
|
||||
print("FLUX: Gradient checkpointing enabled.")
|
||||
|
||||
def disable_gradient_checkpointing(self):
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
for block in self.single_blocks:
|
||||
block.disable_gradient_checkpointing()
|
||||
|
||||
print("FLUX: Gradient checkpointing disabled.")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
img: Tensor,
|
||||
txt: Tensor,
|
||||
vec: Tensor | None = None,
|
||||
pe: Tensor | None = None,
|
||||
txt_attention_mask: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
img = torch.cat((txt, img), 1)
|
||||
for block in self.single_blocks:
|
||||
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
|
||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||
return img
|
||||
"""
|
||||
|
||||
@@ -154,9 +154,8 @@ def sample_image_inference(
|
||||
sample_steps = prompt_dict.get("sample_steps", 20)
|
||||
width = prompt_dict.get("width", 512)
|
||||
height = prompt_dict.get("height", 512)
|
||||
# TODO refactor variable names
|
||||
cfg_scale = prompt_dict.get("guidance_scale", 1.0)
|
||||
emb_guidance_scale = prompt_dict.get("scale", 3.5)
|
||||
emb_guidance_scale = prompt_dict.get("guidance_scale", 3.5)
|
||||
cfg_scale = prompt_dict.get("scale", 1.0)
|
||||
seed = prompt_dict.get("seed")
|
||||
controlnet_image = prompt_dict.get("controlnet_image")
|
||||
prompt: str = prompt_dict.get("prompt", "")
|
||||
@@ -242,7 +241,7 @@ def sample_image_inference(
|
||||
dtype=weight_dtype,
|
||||
generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None,
|
||||
)
|
||||
timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) # FLUX.1 dev -> shift=True
|
||||
timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) # Chroma can use shift=True
|
||||
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype)
|
||||
t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None
|
||||
|
||||
@@ -403,8 +402,8 @@ def denoise(
|
||||
y=torch.cat([neg_l_pooled, vec], dim=0),
|
||||
block_controlnet_hidden_states=block_samples,
|
||||
block_controlnet_single_hidden_states=block_single_samples,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
timesteps=t_vec.repeat(2),
|
||||
guidance=guidance_vec.repeat(2),
|
||||
txt_attention_mask=nc_c_t5_attn_mask,
|
||||
)
|
||||
neg_pred, pred = torch.chunk(nc_c_pred, 2, dim=0)
|
||||
@@ -680,3 +679,11 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
|
||||
default=3.0,
|
||||
help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
type=str,
|
||||
choices=["flux", "chroma"],
|
||||
default="flux",
|
||||
help="Model type to use for training / トレーニングに使用するモデルタイプ:flux or chroma (default: flux)",
|
||||
)
|
||||
|
||||
@@ -23,6 +23,7 @@ from library.utils import load_safetensors
|
||||
MODEL_VERSION_FLUX_V1 = "flux1"
|
||||
MODEL_NAME_DEV = "dev"
|
||||
MODEL_NAME_SCHNELL = "schnell"
|
||||
MODEL_VERSION_CHROMA = "chroma"
|
||||
|
||||
|
||||
def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]:
|
||||
@@ -97,7 +98,7 @@ def load_flow_model(
|
||||
device: Union[str, torch.device],
|
||||
disable_mmap: bool = False,
|
||||
model_type: str = "flux",
|
||||
) -> Tuple[str, bool, flux_models.Flux]:
|
||||
) -> Tuple[bool, flux_models.Flux]:
|
||||
if model_type == "flux":
|
||||
is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path)
|
||||
name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL
|
||||
@@ -140,7 +141,7 @@ def load_flow_model(
|
||||
|
||||
info = model.load_state_dict(sd, strict=False, assign=True)
|
||||
logger.info(f"Loaded Flux: {info}")
|
||||
return model_type, is_schnell, model
|
||||
return is_schnell, model
|
||||
|
||||
elif model_type == "chroma":
|
||||
from . import chroma_models
|
||||
@@ -166,7 +167,7 @@ def load_flow_model(
|
||||
info = model.load_state_dict(sd, strict=False, assign=True)
|
||||
logger.info(f"Loaded Chroma: {info}")
|
||||
is_schnell = False # Chroma is not schnell
|
||||
return model_type, is_schnell, model
|
||||
return is_schnell, model
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported model_type: {model_type}. Supported types are 'flux' and 'chroma'.")
|
||||
@@ -203,6 +204,42 @@ def load_controlnet(
|
||||
return controlnet
|
||||
|
||||
|
||||
def dummy_clip_l() -> torch.nn.Module:
    """
    Returns a dummy CLIP-L module for models that do not use CLIP-L (e.g. Chroma).

    The returned module's forward pass yields a dict whose "pooler_output" is an
    all-zero tensor of shape (N, 77, 1). It is deliberately not the (N, 77, 768)
    shape of the code it stands in for.
    """
    return DummyCLIPL()
|
||||
|
||||
|
||||
class DummyTextModel(torch.nn.Module):
    """Parameter-only placeholder exposed as `DummyCLIPL.text_model`."""

    def __init__(self):
        super().__init__()
        # NOTE(review): presumably present so that code reaching for `text_model.embeddings`
        # on a CLIP-L keeps working with the dummy -- confirm against the callers
        self.embeddings = torch.nn.Parameter(torch.zeros(1))


class DummyCLIPL(torch.nn.Module):
    """
    Stand-in for CLIP-L that performs no computation.

    The forward pass returns an all-zero "pooler_output" of shape (N, 77, 1),
    created on this module's device and with this module's dtype.
    """

    def __init__(self):
        super().__init__()
        # Note: the original code had (77, 768), but (77, 1) is used for the dummy output
        self.output_shape = (77, 1)
        # device and dtype are read back from this parameter
        self.dummy_param = torch.nn.Parameter(torch.zeros(1))
        self.text_model = DummyTextModel()

    @property
    def device(self):
        return self.dummy_param.device

    @property
    def dtype(self):
        return self.dummy_param.dtype

    def forward(self, *args, **kwargs):
        """
        Returns a dummy output dict: {"pooler_output": zeros of shape (N, 77, 1)}.

        N is taken from the first dimension of the first positional argument, or of
        the `input_ids` keyword argument when called without positional arguments.
        When neither is given, N is 1.
        """
        tokens = args[0] if args else kwargs.get("input_ids")
        batch_size = tokens.shape[0] if tokens is not None else 1
        return {"pooler_output": torch.zeros(batch_size, *self.output_shape, device=self.device, dtype=self.dtype)}
|
||||
|
||||
|
||||
def load_clip_l(
|
||||
ckpt_path: Optional[str],
|
||||
dtype: torch.dtype,
|
||||
|
||||
@@ -60,6 +60,8 @@ ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base"
|
||||
ARCH_SD3_M = "stable-diffusion-3" # may be followed by "-m" or "-5-large" etc.
|
||||
# ARCH_SD3_UNKNOWN = "stable-diffusion-3"
|
||||
ARCH_FLUX_1_DEV = "flux-1-dev"
|
||||
ARCH_FLUX_1_SCHNELL = "flux-1-schnell"
|
||||
ARCH_FLUX_1_CHROMA = "chroma" # for Flux Chroma
|
||||
ARCH_FLUX_1_UNKNOWN = "flux-1"
|
||||
|
||||
ADAPTER_LORA = "lora"
|
||||
@@ -69,6 +71,7 @@ IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models"
|
||||
IMPL_COMFY_UI = "https://github.com/comfyanonymous/ComfyUI"
|
||||
IMPL_DIFFUSERS = "diffusers"
|
||||
IMPL_FLUX = "https://github.com/black-forest-labs/flux"
|
||||
IMPL_CHROMA = "https://huggingface.co/lodestones/Chroma"
|
||||
|
||||
PRED_TYPE_EPSILON = "epsilon"
|
||||
PRED_TYPE_V = "v"
|
||||
@@ -125,7 +128,7 @@ def build_metadata(
|
||||
flux: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
sd3: only supports "m", flux: only supports "dev"
|
||||
sd3: only supports "m", flux: supports "dev", "schnell" or "chroma"
|
||||
"""
|
||||
# if state_dict is None, hash is not calculated
|
||||
|
||||
@@ -144,6 +147,10 @@ def build_metadata(
|
||||
elif flux is not None:
|
||||
if flux == "dev":
|
||||
arch = ARCH_FLUX_1_DEV
|
||||
elif flux == "schnell":
|
||||
arch = ARCH_FLUX_1_SCHNELL
|
||||
elif flux == "chroma":
|
||||
arch = ARCH_FLUX_1_CHROMA
|
||||
else:
|
||||
arch = ARCH_FLUX_1_UNKNOWN
|
||||
elif v2:
|
||||
@@ -166,7 +173,10 @@ def build_metadata(
|
||||
|
||||
if flux is not None:
|
||||
# Flux
|
||||
impl = IMPL_FLUX
|
||||
if flux == "chroma":
|
||||
impl = IMPL_CHROMA
|
||||
else:
|
||||
impl = IMPL_FLUX
|
||||
elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
|
||||
# Stable Diffusion ckpt, TI, SDXL LoRA
|
||||
impl = IMPL_STABILITY_AI
|
||||
|
||||
@@ -3482,7 +3482,7 @@ def get_sai_model_spec(
|
||||
textual_inversion: bool,
|
||||
is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA
|
||||
sd3: str = None,
|
||||
flux: str = None,
|
||||
flux: str = None, # "dev", "schnell" or "chroma"
|
||||
):
|
||||
timestamp = time.time()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user