feat: add LoRA training support for Chroma

This commit is contained in:
Kohya S
2025-07-20 19:00:09 +09:00
parent c4958b5dca
commit b4e862626a
10 changed files with 158 additions and 260 deletions

View File

@@ -468,7 +468,7 @@ if __name__ == "__main__":
# t5xxl = accelerator.prepare(t5xxl)
# DiT
model_type, is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device, model_type=args.model_type)
is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device, model_type=args.model_type)
model.eval()
logger.info(f"Casting model to {flux_dtype}")
model.to(flux_dtype) # make sure model is dtype

View File

@@ -270,7 +270,7 @@ def train(args):
clean_memory_on_device(accelerator.device)
# load FLUX
model_type, _, flux = flux_utils.load_flow_model(
_, flux = flux_utils.load_flow_model(
args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors, model_type="flux"
)

View File

@@ -68,6 +68,11 @@ def train(args):
if not args.skip_cache_check:
args.skip_cache_check = args.skip_latents_validity_check
if args.model_type != "flux":
raise ValueError(
f"FLUX.1 ControlNet training requires model_type='flux'. / FLUX.1 ControlNetの学習にはmodel_type='flux'を指定してください。"
)
# assert (
# not args.weighted_captions
# ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
@@ -258,7 +263,7 @@ def train(args):
clean_memory_on_device(accelerator.device)
# load FLUX
model_type, is_schnell, flux = flux_utils.load_flow_model(
is_schnell, flux = flux_utils.load_flow_model(
args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors, model_type="flux"
)
flux.requires_grad_(False)

View File

@@ -35,6 +35,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
self.sample_prompts_te_outputs = None
self.is_schnell: Optional[bool] = None
self.is_swapping_blocks: bool = False
self.model_type: Optional[str] = None
def assert_extra_args(
self,
@@ -45,6 +46,12 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
# sdxl_train_util.verify_sdxl_training_args(args)
self.model_type = args.model_type # "flux" or "chroma"
if self.model_type != "chroma":
self.use_clip_l = True
else:
self.use_clip_l = False # Chroma does not use CLIP-L
if args.fp8_base_unet:
args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1
@@ -60,7 +67,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
# prepare CLIP-L/T5XXL training flags
self.train_clip_l = not args.network_train_unet_only
self.train_clip_l = not args.network_train_unet_only and self.use_clip_l
self.train_t5xxl = False # default is False even if args.network_train_unet_only is False
if args.max_token_length is not None:
@@ -95,8 +102,12 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
loading_dtype = None if args.fp8_base else weight_dtype
# if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future
self.model_type, self.is_schnell, model = flux_utils.load_flow_model(
args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, model_type="flux"
_, model = flux_utils.load_flow_model(
args.pretrained_model_name_or_path,
loading_dtype,
"cpu",
disable_mmap=args.disable_mmap_load_safetensors,
model_type=self.model_type,
)
if args.fp8_base:
# check dtype of model
@@ -120,7 +131,10 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
model.enable_block_swap(args.blocks_to_swap, accelerator.device)
clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
if self.use_clip_l:
clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
else:
clip_l = flux_utils.dummy_clip_l() # dummy CLIP-L for Chroma, which does not use CLIP-L
clip_l.eval()
# if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8)
@@ -141,13 +155,20 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model
model_version = flux_utils.MODEL_VERSION_FLUX_V1 if self.model_type != "chroma" else flux_utils.MODEL_VERSION_CHROMA
return model_version, [clip_l, t5xxl], ae, model
def get_tokenize_strategy(self, args):
_, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
# This method is called before `assert_extra_args`, so we cannot use `self.is_schnell` here.
# Instead, we analyze the checkpoint state to determine if it is schnell.
if args.model_type != "chroma":
_, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
else:
is_schnell = False
self.is_schnell = is_schnell
if args.t5xxl_max_token_length is None:
if is_schnell:
if self.is_schnell:
t5xxl_max_token_length = 256
else:
t5xxl_max_token_length = 512
@@ -268,23 +289,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
text_encoders[1].to(accelerator.device)
# def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
# noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
# # get size embeddings
# orig_size = batch["original_sizes_hw"]
# crop_size = batch["crop_top_lefts"]
# target_size = batch["target_sizes_hw"]
# embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
# # concat embeddings
# encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
# vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
# text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
# noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
# return noise_pred
def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux):
text_encoders = text_encoder # for compatibility
text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders)
@@ -292,36 +296,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
flux_train_utils.sample_images(
accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs
)
# return
"""
class FluxUpperLowerWrapper(torch.nn.Module):
def __init__(self, flux_upper: flux_models.FluxUpper, flux_lower: flux_models.FluxLower, device: torch.device):
super().__init__()
self.flux_upper = flux_upper
self.flux_lower = flux_lower
self.target_device = device
def prepare_block_swap_before_forward(self):
pass
def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None):
self.flux_lower.to("cpu")
clean_memory_on_device(self.target_device)
self.flux_upper.to(self.target_device)
img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance, txt_attention_mask)
self.flux_upper.to("cpu")
clean_memory_on_device(self.target_device)
self.flux_lower.to(self.target_device)
return self.flux_lower(img, txt, vec, pe, txt_attention_mask)
wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device)
clean_memory_on_device(accelerator.device)
flux_train_utils.sample_images(
accelerator, args, epoch, global_step, wrapper, ae, text_encoders, self.sample_prompts_te_outputs
)
clean_memory_on_device(accelerator.device)
"""
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
@@ -366,7 +340,11 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
# ensure guidance_scale in args is float
guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)
# ensure the hidden state will require grad
# get modulation vectors for Chroma
input_vec = None
if self.model_type == "chroma":
input_vec = unet.get_input_vec(timesteps=timesteps, guidance=guidance_vec, batch_size=bsz)
if args.gradient_checkpointing:
noisy_model_input.requires_grad_(True)
for t in text_encoder_conds:
@@ -374,13 +352,15 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
t.requires_grad_(True)
img_ids.requires_grad_(True)
guidance_vec.requires_grad_(True)
if input_vec is not None:
input_vec.requires_grad_(True)
# Predict the noise residual
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
if not args.apply_t5_attn_mask:
t5_attn_mask = None
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask, input_vec):
# grad is enabled even if unet is not in train mode, because Text Encoder is in train mode
with torch.set_grad_enabled(is_train), accelerator.autocast():
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
@@ -393,6 +373,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
timesteps=timesteps / 1000,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
input_vec=input_vec,
)
return model_pred
@@ -405,6 +386,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
timesteps=timesteps,
guidance_vec=guidance_vec,
t5_attn_mask=t5_attn_mask,
input_vec=input_vec,
)
# unpack latents
@@ -436,6 +418,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
timesteps=timesteps[diff_output_pr_indices],
guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None,
t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None,
input_vec=input_vec[diff_output_pr_indices] if input_vec is not None else None,
)
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
@@ -454,9 +437,14 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
return loss
def get_sai_model_spec(self, args):
return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev")
if self.model_type != "chroma":
model_description = "schnell" if self.is_schnell else "dev"
else:
model_description = "chroma"
return train_util.get_sai_model_spec(None, args, False, True, False, flux=model_description)
def update_metadata(self, metadata, args):
metadata["ss_model_type"] = args.model_type
metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask
metadata["ss_weighting_scheme"] = args.weighting_scheme
metadata["ss_logit_mean"] = args.logit_mean

View File

@@ -601,13 +601,30 @@ class Chroma(Flux):
self.gradient_checkpointing = False
self.cpu_offload_checkpointing = False
def get_mod_vectors(
self,
timesteps: Tensor,
guidance: Tensor | None = None,
batch_size: int | None = None,
requires_grad: bool = False,
) -> Tensor:
def get_model_type(self) -> str:
return "chroma"
def enable_gradient_checkpointing(self, cpu_offload: bool = False):
self.gradient_checkpointing = True
self.cpu_offload_checkpointing = cpu_offload
self.distilled_guidance_layer.enable_gradient_checkpointing()
for block in self.double_blocks + self.single_blocks:
block.enable_gradient_checkpointing()
print(f"Chroma: Gradient checkpointing enabled.")
def disable_gradient_checkpointing(self):
self.gradient_checkpointing = False
self.cpu_offload_checkpointing = False
self.distilled_guidance_layer.disable_gradient_checkpointing()
for block in self.double_blocks + self.single_blocks:
block.disable_gradient_checkpointing()
print("Chroma: Gradient checkpointing disabled.")
def get_input_vec(self, timesteps: Tensor, guidance: Tensor | None = None, batch_size: int | None = None) -> Tensor:
distill_timestep = timestep_embedding(timesteps, self.approximator_in_dim // 4)
# TODO: need to add toggle to omit this from schnell but that's not a priority
distil_guidance = timestep_embedding(guidance, self.approximator_in_dim // 4)
@@ -619,10 +636,7 @@ class Chroma(Flux):
timestep_guidance = torch.cat([distill_timestep, distil_guidance], dim=1).unsqueeze(1).repeat(1, self.mod_index_length, 1)
# then and only then we could concatenate it together
input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1)
if requires_grad:
input_vec = input_vec.requires_grad_(True)
mod_vectors = self.distilled_guidance_layer(input_vec)
return mod_vectors
return input_vec
def forward(
self,
@@ -637,7 +651,7 @@ class Chroma(Flux):
guidance: Tensor | None = None,
txt_attention_mask: Tensor | None = None,
attn_padding: int = 1,
mod_vectors: Tensor | None = None,
input_vec: Tensor | None = None,
) -> Tensor:
# print(
# f"Chroma forward: img shape {img.shape}, txt shape {txt.shape}, img_ids shape {img_ids.shape}, txt_ids shape {txt_ids.shape}"
@@ -651,7 +665,7 @@ class Chroma(Flux):
img = self.img_in(img)
txt = self.txt_in(txt)
if mod_vectors is None:
if input_vec is None:
# TODO:
# need to fix grad accumulation issue here for now it's in no grad mode
# besides, i don't want to wash out the PFP that's trained on this model weights anyway
@@ -659,14 +673,18 @@ class Chroma(Flux):
# alternatively doing forward pass for every block manually is doable but slow
# custom backward probably be better
with torch.no_grad():
# kohya-ss: I'm not sure why requires_grad is set to True here
mod_vectors = self.get_mod_vectors(timesteps, guidance, img.shape[0], requires_grad=True)
input_vec = self.get_input_vec(timesteps, guidance, img.shape[0])
# kohya-ss: I'm not sure why requires_grad is set to True here
input_vec.requires_grad = True
mod_vectors = self.distilled_guidance_layer(input_vec)
else:
mod_vectors = self.distilled_guidance_layer(input_vec)
mod_vectors_dict = distribute_modulations(mod_vectors, self.depth_single_blocks, self.depth_double_blocks)
# calculate text length for each batch instead of masking
txt_emb_len = txt.shape[1]
txt_seq_len = txt_attention_mask[:, :txt_emb_len].sum(dim=-1) # (batch_size, )
txt_seq_len = txt_attention_mask[:, :txt_emb_len].sum(dim=-1).to(torch.int64) # (batch_size, )
txt_seq_len = torch.clip(txt_seq_len + attn_padding, 0, txt_emb_len)
max_txt_len = torch.max(txt_seq_len).item() # max text length in the batch

View File

@@ -930,6 +930,9 @@ class Flux(nn.Module):
self.num_double_blocks = len(self.double_blocks)
self.num_single_blocks = len(self.single_blocks)
def get_model_type(self) -> str:
return "flux"
@property
def device(self):
return next(self.parameters()).device
@@ -1018,6 +1021,7 @@ class Flux(nn.Module):
block_controlnet_single_hidden_states=None,
guidance: Tensor | None = None,
txt_attention_mask: Tensor | None = None,
input_vec: Tensor | None = None,
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -1169,7 +1173,7 @@ class ControlNetFlux(nn.Module):
nn.SiLU(),
nn.Conv2d(16, 16, 3, padding=1, stride=2),
nn.SiLU(),
zero_module(nn.Conv2d(16, 16, 3, padding=1))
zero_module(nn.Conv2d(16, 16, 3, padding=1)),
)
@property
@@ -1320,174 +1324,3 @@ class ControlNetFlux(nn.Module):
controlnet_single_block_samples = controlnet_single_block_samples + (block_sample,)
return controlnet_block_samples, controlnet_single_block_samples
"""
class FluxUpper(nn.Module):
""
Transformer model for flow matching on sequences.
""
def __init__(self, params: FluxParams):
super().__init__()
self.params = params
self.in_channels = params.in_channels
self.out_channels = self.in_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
pe_dim = params.hidden_size // params.num_heads
if sum(params.axes_dim) != pe_dim:
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
)
for _ in range(params.depth)
]
)
self.gradient_checkpointing = False
@property
def device(self):
return next(self.parameters()).device
@property
def dtype(self):
return next(self.parameters()).dtype
def enable_gradient_checkpointing(self):
self.gradient_checkpointing = True
self.time_in.enable_gradient_checkpointing()
self.vector_in.enable_gradient_checkpointing()
if self.guidance_in.__class__ != nn.Identity:
self.guidance_in.enable_gradient_checkpointing()
for block in self.double_blocks:
block.enable_gradient_checkpointing()
print("FLUX: Gradient checkpointing enabled.")
def disable_gradient_checkpointing(self):
self.gradient_checkpointing = False
self.time_in.disable_gradient_checkpointing()
self.vector_in.disable_gradient_checkpointing()
if self.guidance_in.__class__ != nn.Identity:
self.guidance_in.disable_gradient_checkpointing()
for block in self.double_blocks:
block.disable_gradient_checkpointing()
print("FLUX: Gradient checkpointing disabled.")
def forward(
self,
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
timesteps: Tensor,
y: Tensor,
guidance: Tensor | None = None,
txt_attention_mask: Tensor | None = None,
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
# running on sequences img
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256))
if self.params.guidance_embed:
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
vec = vec + self.vector_in(y)
txt = self.txt_in(txt)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
for block in self.double_blocks:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
return img, txt, vec, pe
class FluxLower(nn.Module):
""
Transformer model for flow matching on sequences.
""
def __init__(self, params: FluxParams):
super().__init__()
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.out_channels = params.in_channels
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
for _ in range(params.depth_single_blocks)
]
)
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
self.gradient_checkpointing = False
@property
def device(self):
return next(self.parameters()).device
@property
def dtype(self):
return next(self.parameters()).dtype
def enable_gradient_checkpointing(self):
self.gradient_checkpointing = True
for block in self.single_blocks:
block.enable_gradient_checkpointing()
print("FLUX: Gradient checkpointing enabled.")
def disable_gradient_checkpointing(self):
self.gradient_checkpointing = False
for block in self.single_blocks:
block.disable_gradient_checkpointing()
print("FLUX: Gradient checkpointing disabled.")
def forward(
self,
img: Tensor,
txt: Tensor,
vec: Tensor | None = None,
pe: Tensor | None = None,
txt_attention_mask: Tensor | None = None,
) -> Tensor:
img = torch.cat((txt, img), 1)
for block in self.single_blocks:
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
img = img[:, txt.shape[1] :, ...]
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
return img
"""

View File

@@ -154,9 +154,8 @@ def sample_image_inference(
sample_steps = prompt_dict.get("sample_steps", 20)
width = prompt_dict.get("width", 512)
height = prompt_dict.get("height", 512)
# TODO refactor variable names
cfg_scale = prompt_dict.get("guidance_scale", 1.0)
emb_guidance_scale = prompt_dict.get("scale", 3.5)
emb_guidance_scale = prompt_dict.get("guidance_scale", 3.5)
cfg_scale = prompt_dict.get("scale", 1.0)
seed = prompt_dict.get("seed")
controlnet_image = prompt_dict.get("controlnet_image")
prompt: str = prompt_dict.get("prompt", "")
@@ -242,7 +241,7 @@ def sample_image_inference(
dtype=weight_dtype,
generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None,
)
timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) # FLUX.1 dev -> shift=True
timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) # Chroma can use shift=True
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype)
t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None
@@ -403,8 +402,8 @@ def denoise(
y=torch.cat([neg_l_pooled, vec], dim=0),
block_controlnet_hidden_states=block_samples,
block_controlnet_single_hidden_states=block_single_samples,
timesteps=t_vec,
guidance=guidance_vec,
timesteps=t_vec.repeat(2),
guidance=guidance_vec.repeat(2),
txt_attention_mask=nc_c_t5_attn_mask,
)
neg_pred, pred = torch.chunk(nc_c_pred, 2, dim=0)
@@ -680,3 +679,11 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
default=3.0,
help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
)
parser.add_argument(
"--model_type",
type=str,
choices=["flux", "chroma"],
default="flux",
help="Model type to use for training / トレーニングに使用するモデルタイプflux or chroma (default: flux)",
)

View File

@@ -23,6 +23,7 @@ from library.utils import load_safetensors
MODEL_VERSION_FLUX_V1 = "flux1"
MODEL_NAME_DEV = "dev"
MODEL_NAME_SCHNELL = "schnell"
MODEL_VERSION_CHROMA = "chroma"
def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]:
@@ -97,7 +98,7 @@ def load_flow_model(
device: Union[str, torch.device],
disable_mmap: bool = False,
model_type: str = "flux",
) -> Tuple[str, bool, flux_models.Flux]:
) -> Tuple[bool, flux_models.Flux]:
if model_type == "flux":
is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path)
name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL
@@ -140,7 +141,7 @@ def load_flow_model(
info = model.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded Flux: {info}")
return model_type, is_schnell, model
return is_schnell, model
elif model_type == "chroma":
from . import chroma_models
@@ -166,7 +167,7 @@ def load_flow_model(
info = model.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded Chroma: {info}")
is_schnell = False # Chroma is not schnell
return model_type, is_schnell, model
return is_schnell, model
else:
raise ValueError(f"Unsupported model_type: {model_type}. Supported types are 'flux' and 'chroma'.")
@@ -203,6 +204,42 @@ def load_controlnet(
return controlnet
def dummy_clip_l() -> torch.nn.Module:
"""
Returns a dummy CLIP-L model with the output shape of (N, 77, 768).
"""
return DummyCLIPL()
class DummyTextModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.embeddings = torch.nn.Parameter(torch.zeros(1))
class DummyCLIPL(torch.nn.Module):
def __init__(self):
super().__init__()
self.output_shape = (77, 1) # Note: The original code had (77, 768), but we use (77, 1) for the dummy output
self.dummy_param = torch.nn.Parameter(torch.zeros(1)) # get dtype and device from this parameter
self.text_model = DummyTextModel()
@property
def device(self):
return self.dummy_param.device
@property
def dtype(self):
return self.dummy_param.dtype
def forward(self, *args, **kwargs):
"""
Returns a dummy output with the shape of (N, 77, 768).
"""
batch_size = args[0].shape[0] if args else 1
return {"pooler_output": torch.zeros(batch_size, *self.output_shape, device=self.device, dtype=self.dtype)}
def load_clip_l(
ckpt_path: Optional[str],
dtype: torch.dtype,

View File

@@ -60,6 +60,8 @@ ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base"
ARCH_SD3_M = "stable-diffusion-3" # may be followed by "-m" or "-5-large" etc.
# ARCH_SD3_UNKNOWN = "stable-diffusion-3"
ARCH_FLUX_1_DEV = "flux-1-dev"
ARCH_FLUX_1_SCHNELL = "flux-1-schnell"
ARCH_FLUX_1_CHROMA = "chroma" # for Flux Chroma
ARCH_FLUX_1_UNKNOWN = "flux-1"
ADAPTER_LORA = "lora"
@@ -69,6 +71,7 @@ IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models"
IMPL_COMFY_UI = "https://github.com/comfyanonymous/ComfyUI"
IMPL_DIFFUSERS = "diffusers"
IMPL_FLUX = "https://github.com/black-forest-labs/flux"
IMPL_CHROMA = "https://huggingface.co/lodestones/Chroma"
PRED_TYPE_EPSILON = "epsilon"
PRED_TYPE_V = "v"
@@ -125,7 +128,7 @@ def build_metadata(
flux: Optional[str] = None,
):
"""
sd3: only supports "m", flux: only supports "dev"
sd3: only supports "m", flux: supports "dev", "schnell" or "chroma"
"""
# if state_dict is None, hash is not calculated
@@ -144,6 +147,10 @@ def build_metadata(
elif flux is not None:
if flux == "dev":
arch = ARCH_FLUX_1_DEV
elif flux == "schnell":
arch = ARCH_FLUX_1_SCHNELL
elif flux == "chroma":
arch = ARCH_FLUX_1_CHROMA
else:
arch = ARCH_FLUX_1_UNKNOWN
elif v2:
@@ -166,7 +173,10 @@ def build_metadata(
if flux is not None:
# Flux
impl = IMPL_FLUX
if flux == "chroma":
impl = IMPL_CHROMA
else:
impl = IMPL_FLUX
elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
# Stable Diffusion ckpt, TI, SDXL LoRA
impl = IMPL_STABILITY_AI

View File

@@ -3482,7 +3482,7 @@ def get_sai_model_spec(
textual_inversion: bool,
is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA
sd3: str = None,
flux: str = None,
flux: str = None, # "dev", "schnell" or "chroma"
):
timestamp = time.time()