From b4e862626aaba996ffe8b7f942ce5ce21d762919 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 20 Jul 2025 19:00:09 +0900 Subject: [PATCH] feat: add LoRA training support for Chroma --- flux_minimal_inference.py | 2 +- flux_train.py | 2 +- flux_train_control_net.py | 7 +- flux_train_network.py | 102 +++++++++------------ library/chroma_models.py | 50 ++++++---- library/flux_models.py | 177 +----------------------------------- library/flux_train_utils.py | 19 ++-- library/flux_utils.py | 43 ++++++++- library/sai_model_spec.py | 14 ++- library/train_util.py | 2 +- 10 files changed, 158 insertions(+), 260 deletions(-) diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py index 550904d2..86e8e1b1 100644 --- a/flux_minimal_inference.py +++ b/flux_minimal_inference.py @@ -468,7 +468,7 @@ if __name__ == "__main__": # t5xxl = accelerator.prepare(t5xxl) # DiT - model_type, is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device, model_type=args.model_type) + is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device, model_type=args.model_type) model.eval() logger.info(f"Casting model to {flux_dtype}") model.to(flux_dtype) # make sure model is dtype diff --git a/flux_train.py b/flux_train.py index 1d2cc68b..84db34cf 100644 --- a/flux_train.py +++ b/flux_train.py @@ -270,7 +270,7 @@ def train(args): clean_memory_on_device(accelerator.device) # load FLUX - model_type, _, flux = flux_utils.load_flow_model( + _, flux = flux_utils.load_flow_model( args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors, model_type="flux" ) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index 3c038c32..93c20dab 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -68,6 +68,11 @@ def train(args): if not args.skip_cache_check: args.skip_cache_check = args.skip_latents_validity_check + if args.model_type != "flux": + raise ValueError( + f"FLUX.1 ControlNet 
training requires model_type='flux'. / FLUX.1 ControlNetの学習にはmodel_type='flux'を指定してください。" + ) + # assert ( # not args.weighted_captions # ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" @@ -258,7 +263,7 @@ def train(args): clean_memory_on_device(accelerator.device) # load FLUX - model_type, is_schnell, flux = flux_utils.load_flow_model( + is_schnell, flux = flux_utils.load_flow_model( args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors, model_type="flux" ) flux.requires_grad_(False) diff --git a/flux_train_network.py b/flux_train_network.py index b2bf8e7c..1b61ac72 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -35,6 +35,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): self.sample_prompts_te_outputs = None self.is_schnell: Optional[bool] = None self.is_swapping_blocks: bool = False + self.model_type: Optional[str] = None def assert_extra_args( self, @@ -45,6 +46,12 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): super().assert_extra_args(args, train_dataset_group, val_dataset_group) # sdxl_train_util.verify_sdxl_training_args(args) + self.model_type = args.model_type # "flux" or "chroma" + if self.model_type != "chroma": + self.use_clip_l = True + else: + self.use_clip_l = False # Chroma does not use CLIP-L + if args.fp8_base_unet: args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1 @@ -60,7 +67,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" # prepare CLIP-L/T5XXL training flags - self.train_clip_l = not args.network_train_unet_only + self.train_clip_l = not args.network_train_unet_only and self.use_clip_l self.train_t5xxl = False # default 
is False even if args.network_train_unet_only is False if args.max_token_length is not None: @@ -95,8 +102,12 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): loading_dtype = None if args.fp8_base else weight_dtype # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future - self.model_type, self.is_schnell, model = flux_utils.load_flow_model( - args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, model_type="flux" + _, model = flux_utils.load_flow_model( + args.pretrained_model_name_or_path, + loading_dtype, + "cpu", + disable_mmap=args.disable_mmap_load_safetensors, + model_type=self.model_type, ) if args.fp8_base: # check dtype of model @@ -120,7 +131,10 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") model.enable_block_swap(args.blocks_to_swap, accelerator.device) - clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + if self.use_clip_l: + clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + else: + clip_l = flux_utils.dummy_clip_l() # dummy CLIP-L for Chroma, which does not use CLIP-L clip_l.eval() # if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8) @@ -141,13 +155,20 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) - return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model + model_version = flux_utils.MODEL_VERSION_FLUX_V1 if self.model_type != "chroma" else flux_utils.MODEL_VERSION_CHROMA + return model_version, [clip_l, t5xxl], ae, model def get_tokenize_strategy(self, args): - _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) + # This method is called before 
`assert_extra_args`, so we cannot use `self.is_schnell` here. + # Instead, we analyze the checkpoint state to determine if it is schnell. + if args.model_type != "chroma": + _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) + else: + is_schnell = False + self.is_schnell = is_schnell if args.t5xxl_max_token_length is None: - if is_schnell: + if self.is_schnell: t5xxl_max_token_length = 256 else: t5xxl_max_token_length = 512 @@ -268,23 +289,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): text_encoders[0].to(accelerator.device, dtype=weight_dtype) text_encoders[1].to(accelerator.device) - # def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): - # noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype - - # # get size embeddings - # orig_size = batch["original_sizes_hw"] - # crop_size = batch["crop_top_lefts"] - # target_size = batch["target_sizes_hw"] - # embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) - - # # concat embeddings - # encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds - # vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) - # text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) - - # noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) - # return noise_pred - def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux): text_encoders = text_encoder # for compatibility text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders) @@ -292,36 +296,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): flux_train_utils.sample_images( accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs ) - # return - - """ - class 
FluxUpperLowerWrapper(torch.nn.Module): - def __init__(self, flux_upper: flux_models.FluxUpper, flux_lower: flux_models.FluxLower, device: torch.device): - super().__init__() - self.flux_upper = flux_upper - self.flux_lower = flux_lower - self.target_device = device - - def prepare_block_swap_before_forward(self): - pass - - def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None): - self.flux_lower.to("cpu") - clean_memory_on_device(self.target_device) - self.flux_upper.to(self.target_device) - img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance, txt_attention_mask) - self.flux_upper.to("cpu") - clean_memory_on_device(self.target_device) - self.flux_lower.to(self.target_device) - return self.flux_lower(img, txt, vec, pe, txt_attention_mask) - - wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device) - clean_memory_on_device(accelerator.device) - flux_train_utils.sample_images( - accelerator, args, epoch, global_step, wrapper, ae, text_encoders, self.sample_prompts_te_outputs - ) - clean_memory_on_device(accelerator.device) - """ def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) @@ -366,7 +340,11 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): # ensure guidance_scale in args is float guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) - # ensure the hidden state will require grad + # get modulation vectors for Chroma + input_vec = None + if self.model_type == "chroma": + input_vec = unet.get_input_vec(timesteps=timesteps, guidance=guidance_vec, batch_size=bsz) + if args.gradient_checkpointing: noisy_model_input.requires_grad_(True) for t in text_encoder_conds: @@ -374,13 +352,15 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): t.requires_grad_(True) 
img_ids.requires_grad_(True) guidance_vec.requires_grad_(True) + if input_vec is not None: + input_vec.requires_grad_(True) # Predict the noise residual l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds if not args.apply_t5_attn_mask: t5_attn_mask = None - def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask): + def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask, input_vec): # grad is enabled even if unet is not in train mode, because Text Encoder is in train mode with torch.set_grad_enabled(is_train), accelerator.autocast(): # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) @@ -393,6 +373,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): timesteps=timesteps / 1000, guidance=guidance_vec, txt_attention_mask=t5_attn_mask, + input_vec=input_vec, ) return model_pred @@ -405,6 +386,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): timesteps=timesteps, guidance_vec=guidance_vec, t5_attn_mask=t5_attn_mask, + input_vec=input_vec, ) # unpack latents @@ -436,6 +418,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): timesteps=timesteps[diff_output_pr_indices], guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None, t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None, + input_vec=input_vec[diff_output_pr_indices] if input_vec is not None else None, ) network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step @@ -454,9 +437,14 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): return loss def get_sai_model_spec(self, args): - return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev") + if self.model_type != "chroma": + model_description = "schnell" if self.is_schnell else "dev" + else: + 
model_description = "chroma" + return train_util.get_sai_model_spec(None, args, False, True, False, flux=model_description) def update_metadata(self, metadata, args): + metadata["ss_model_type"] = args.model_type metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask metadata["ss_weighting_scheme"] = args.weighting_scheme metadata["ss_logit_mean"] = args.logit_mean diff --git a/library/chroma_models.py b/library/chroma_models.py index 1b62f20f..e5d3b547 100644 --- a/library/chroma_models.py +++ b/library/chroma_models.py @@ -601,13 +601,30 @@ class Chroma(Flux): self.gradient_checkpointing = False self.cpu_offload_checkpointing = False - def get_mod_vectors( - self, - timesteps: Tensor, - guidance: Tensor | None = None, - batch_size: int | None = None, - requires_grad: bool = False, - ) -> Tensor: + def get_model_type(self) -> str: + return "chroma" + + def enable_gradient_checkpointing(self, cpu_offload: bool = False): + self.gradient_checkpointing = True + self.cpu_offload_checkpointing = cpu_offload + + self.distilled_guidance_layer.enable_gradient_checkpointing() + for block in self.double_blocks + self.single_blocks: + block.enable_gradient_checkpointing() + + print(f"Chroma: Gradient checkpointing enabled.") + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + + self.distilled_guidance_layer.disable_gradient_checkpointing() + for block in self.double_blocks + self.single_blocks: + block.disable_gradient_checkpointing() + + print("Chroma: Gradient checkpointing disabled.") + + def get_input_vec(self, timesteps: Tensor, guidance: Tensor | None = None, batch_size: int | None = None) -> Tensor: distill_timestep = timestep_embedding(timesteps, self.approximator_in_dim // 4) # TODO: need to add toggle to omit this from schnell but that's not a priority distil_guidance = timestep_embedding(guidance, self.approximator_in_dim // 4) @@ -619,10 +636,7 @@ class Chroma(Flux): timestep_guidance 
= torch.cat([distill_timestep, distil_guidance], dim=1).unsqueeze(1).repeat(1, self.mod_index_length, 1) # then and only then we could concatenate it together input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1) - if requires_grad: - input_vec = input_vec.requires_grad_(True) - mod_vectors = self.distilled_guidance_layer(input_vec) - return mod_vectors + return input_vec def forward( self, @@ -637,7 +651,7 @@ class Chroma(Flux): guidance: Tensor | None = None, txt_attention_mask: Tensor | None = None, attn_padding: int = 1, - mod_vectors: Tensor | None = None, + input_vec: Tensor | None = None, ) -> Tensor: # print( # f"Chroma forward: img shape {img.shape}, txt shape {txt.shape}, img_ids shape {img_ids.shape}, txt_ids shape {txt_ids.shape}" @@ -651,7 +665,7 @@ class Chroma(Flux): img = self.img_in(img) txt = self.txt_in(txt) - if mod_vectors is None: + if input_vec is None: # TODO: # need to fix grad accumulation issue here for now it's in no grad mode # besides, i don't want to wash out the PFP that's trained on this model weights anyway @@ -659,14 +673,18 @@ class Chroma(Flux): # alternatively doing forward pass for every block manually is doable but slow # custom backward probably be better with torch.no_grad(): - # kohya-ss: I'm not sure why requires_grad is set to True here - mod_vectors = self.get_mod_vectors(timesteps, guidance, img.shape[0], requires_grad=True) + input_vec = self.get_input_vec(timesteps, guidance, img.shape[0]) + # kohya-ss: I'm not sure why requires_grad is set to True here + input_vec.requires_grad = True + mod_vectors = self.distilled_guidance_layer(input_vec) + else: + mod_vectors = self.distilled_guidance_layer(input_vec) mod_vectors_dict = distribute_modulations(mod_vectors, self.depth_single_blocks, self.depth_double_blocks) # calculate text length for each batch instead of masking txt_emb_len = txt.shape[1] - txt_seq_len = txt_attention_mask[:, :txt_emb_len].sum(dim=-1) # (batch_size, ) + txt_seq_len = 
txt_attention_mask[:, :txt_emb_len].sum(dim=-1).to(torch.int64) # (batch_size, ) txt_seq_len = torch.clip(txt_seq_len + attn_padding, 0, txt_emb_len) max_txt_len = torch.max(txt_seq_len).item() # max text length in the batch diff --git a/library/flux_models.py b/library/flux_models.py index 328ad481..6f889755 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -930,6 +930,9 @@ class Flux(nn.Module): self.num_double_blocks = len(self.double_blocks) self.num_single_blocks = len(self.single_blocks) + def get_model_type(self) -> str: + return "flux" + @property def device(self): return next(self.parameters()).device @@ -1018,6 +1021,7 @@ class Flux(nn.Module): block_controlnet_single_hidden_states=None, guidance: Tensor | None = None, txt_attention_mask: Tensor | None = None, + input_vec: Tensor | None = None, ) -> Tensor: if img.ndim != 3 or txt.ndim != 3: raise ValueError("Input img and txt tensors must have 3 dimensions.") @@ -1169,7 +1173,7 @@ class ControlNetFlux(nn.Module): nn.SiLU(), nn.Conv2d(16, 16, 3, padding=1, stride=2), nn.SiLU(), - zero_module(nn.Conv2d(16, 16, 3, padding=1)) + zero_module(nn.Conv2d(16, 16, 3, padding=1)), ) @property @@ -1320,174 +1324,3 @@ class ControlNetFlux(nn.Module): controlnet_single_block_samples = controlnet_single_block_samples + (block_sample,) return controlnet_block_samples, controlnet_single_block_samples - - -""" -class FluxUpper(nn.Module): - "" - Transformer model for flow matching on sequences. 
- "" - - def __init__(self, params: FluxParams): - super().__init__() - - self.params = params - self.in_channels = params.in_channels - self.out_channels = self.in_channels - if params.hidden_size % params.num_heads != 0: - raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}") - pe_dim = params.hidden_size // params.num_heads - if sum(params.axes_dim) != pe_dim: - raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") - self.hidden_size = params.hidden_size - self.num_heads = params.num_heads - self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) - self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) - self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) - self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) - self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() - self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) - - self.double_blocks = nn.ModuleList( - [ - DoubleStreamBlock( - self.hidden_size, - self.num_heads, - mlp_ratio=params.mlp_ratio, - qkv_bias=params.qkv_bias, - ) - for _ in range(params.depth) - ] - ) - - self.gradient_checkpointing = False - - @property - def device(self): - return next(self.parameters()).device - - @property - def dtype(self): - return next(self.parameters()).dtype - - def enable_gradient_checkpointing(self): - self.gradient_checkpointing = True - - self.time_in.enable_gradient_checkpointing() - self.vector_in.enable_gradient_checkpointing() - if self.guidance_in.__class__ != nn.Identity: - self.guidance_in.enable_gradient_checkpointing() - - for block in self.double_blocks: - block.enable_gradient_checkpointing() - - print("FLUX: Gradient checkpointing enabled.") - - def disable_gradient_checkpointing(self): - self.gradient_checkpointing = False - - self.time_in.disable_gradient_checkpointing() - 
self.vector_in.disable_gradient_checkpointing() - if self.guidance_in.__class__ != nn.Identity: - self.guidance_in.disable_gradient_checkpointing() - - for block in self.double_blocks: - block.disable_gradient_checkpointing() - - print("FLUX: Gradient checkpointing disabled.") - - def forward( - self, - img: Tensor, - img_ids: Tensor, - txt: Tensor, - txt_ids: Tensor, - timesteps: Tensor, - y: Tensor, - guidance: Tensor | None = None, - txt_attention_mask: Tensor | None = None, - ) -> Tensor: - if img.ndim != 3 or txt.ndim != 3: - raise ValueError("Input img and txt tensors must have 3 dimensions.") - - # running on sequences img - img = self.img_in(img) - vec = self.time_in(timestep_embedding(timesteps, 256)) - if self.params.guidance_embed: - if guidance is None: - raise ValueError("Didn't get guidance strength for guidance distilled model.") - vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) - vec = vec + self.vector_in(y) - txt = self.txt_in(txt) - - ids = torch.cat((txt_ids, img_ids), dim=1) - pe = self.pe_embedder(ids) - - for block in self.double_blocks: - img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) - - return img, txt, vec, pe - - -class FluxLower(nn.Module): - "" - Transformer model for flow matching on sequences. 
- "" - - def __init__(self, params: FluxParams): - super().__init__() - self.hidden_size = params.hidden_size - self.num_heads = params.num_heads - self.out_channels = params.in_channels - - self.single_blocks = nn.ModuleList( - [ - SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) - for _ in range(params.depth_single_blocks) - ] - ) - - self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) - - self.gradient_checkpointing = False - - @property - def device(self): - return next(self.parameters()).device - - @property - def dtype(self): - return next(self.parameters()).dtype - - def enable_gradient_checkpointing(self): - self.gradient_checkpointing = True - - for block in self.single_blocks: - block.enable_gradient_checkpointing() - - print("FLUX: Gradient checkpointing enabled.") - - def disable_gradient_checkpointing(self): - self.gradient_checkpointing = False - - for block in self.single_blocks: - block.disable_gradient_checkpointing() - - print("FLUX: Gradient checkpointing disabled.") - - def forward( - self, - img: Tensor, - txt: Tensor, - vec: Tensor | None = None, - pe: Tensor | None = None, - txt_attention_mask: Tensor | None = None, - ) -> Tensor: - img = torch.cat((txt, img), 1) - for block in self.single_blocks: - img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) - img = img[:, txt.shape[1] :, ...] 
- - img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) - return img -""" diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 8392e559..f3eb8199 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -154,9 +154,8 @@ def sample_image_inference( sample_steps = prompt_dict.get("sample_steps", 20) width = prompt_dict.get("width", 512) height = prompt_dict.get("height", 512) - # TODO refactor variable names - cfg_scale = prompt_dict.get("guidance_scale", 1.0) - emb_guidance_scale = prompt_dict.get("scale", 3.5) + emb_guidance_scale = prompt_dict.get("guidance_scale", 3.5) + cfg_scale = prompt_dict.get("scale", 1.0) seed = prompt_dict.get("seed") controlnet_image = prompt_dict.get("controlnet_image") prompt: str = prompt_dict.get("prompt", "") @@ -242,7 +241,7 @@ def sample_image_inference( dtype=weight_dtype, generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None, ) - timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) # FLUX.1 dev -> shift=True + timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) # Chroma can use shift=True img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype) t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None @@ -403,8 +402,8 @@ def denoise( y=torch.cat([neg_l_pooled, vec], dim=0), block_controlnet_hidden_states=block_samples, block_controlnet_single_hidden_states=block_single_samples, - timesteps=t_vec, - guidance=guidance_vec, + timesteps=t_vec.repeat(2), + guidance=guidance_vec.repeat(2), txt_attention_mask=nc_c_t5_attn_mask, ) neg_pred, pred = torch.chunk(nc_c_pred, 2, dim=0) @@ -680,3 +679,11 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser): default=3.0, help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. 
/ Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。", ) + + parser.add_argument( + "--model_type", + type=str, + choices=["flux", "chroma"], + default="flux", + help="Model type to use for training / トレーニングに使用するモデルタイプ:flux or chroma (default: flux)", + ) diff --git a/library/flux_utils.py b/library/flux_utils.py index dda7c789..3f0a0d63 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -23,6 +23,7 @@ from library.utils import load_safetensors MODEL_VERSION_FLUX_V1 = "flux1" MODEL_NAME_DEV = "dev" MODEL_NAME_SCHNELL = "schnell" +MODEL_VERSION_CHROMA = "chroma" def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]: @@ -97,7 +98,7 @@ def load_flow_model( device: Union[str, torch.device], disable_mmap: bool = False, model_type: str = "flux", -) -> Tuple[str, bool, flux_models.Flux]: +) -> Tuple[bool, flux_models.Flux]: if model_type == "flux": is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path) name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL @@ -140,7 +141,7 @@ def load_flow_model( info = model.load_state_dict(sd, strict=False, assign=True) logger.info(f"Loaded Flux: {info}") - return model_type, is_schnell, model + return is_schnell, model elif model_type == "chroma": from . import chroma_models @@ -166,7 +167,7 @@ def load_flow_model( info = model.load_state_dict(sd, strict=False, assign=True) logger.info(f"Loaded Chroma: {info}") is_schnell = False # Chroma is not schnell - return model_type, is_schnell, model + return is_schnell, model else: raise ValueError(f"Unsupported model_type: {model_type}. Supported types are 'flux' and 'chroma'.") @@ -203,6 +204,42 @@ def load_controlnet( return controlnet +def dummy_clip_l() -> torch.nn.Module: + """ + Returns a dummy CLIP-L model with the output shape of (N, 77, 1).
+ """ + return DummyCLIPL() + + +class DummyTextModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.embeddings = torch.nn.Parameter(torch.zeros(1)) + + +class DummyCLIPL(torch.nn.Module): + def __init__(self): + super().__init__() + self.output_shape = (77, 1) # Note: The original code had (77, 768), but we use (77, 1) for the dummy output + self.dummy_param = torch.nn.Parameter(torch.zeros(1)) # get dtype and device from this parameter + self.text_model = DummyTextModel() + + @property + def device(self): + return self.dummy_param.device + + @property + def dtype(self): + return self.dummy_param.dtype + + def forward(self, *args, **kwargs): + """ + Returns a dummy output with the shape of (N, 77, 1). + """ + batch_size = args[0].shape[0] if args else 1 + return {"pooler_output": torch.zeros(batch_size, *self.output_shape, device=self.device, dtype=self.dtype)} + + def load_clip_l( ckpt_path: Optional[str], dtype: torch.dtype, diff --git a/library/sai_model_spec.py b/library/sai_model_spec.py index 8896c047..662a6b2e 100644 --- a/library/sai_model_spec.py +++ b/library/sai_model_spec.py @@ -60,6 +60,8 @@ ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base" ARCH_SD3_M = "stable-diffusion-3" # may be followed by "-m" or "-5-large" etc.
# ARCH_SD3_UNKNOWN = "stable-diffusion-3" ARCH_FLUX_1_DEV = "flux-1-dev" +ARCH_FLUX_1_SCHNELL = "flux-1-schnell" +ARCH_FLUX_1_CHROMA = "chroma" # for Flux Chroma ARCH_FLUX_1_UNKNOWN = "flux-1" ADAPTER_LORA = "lora" @@ -69,6 +71,7 @@ IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models" IMPL_COMFY_UI = "https://github.com/comfyanonymous/ComfyUI" IMPL_DIFFUSERS = "diffusers" IMPL_FLUX = "https://github.com/black-forest-labs/flux" +IMPL_CHROMA = "https://huggingface.co/lodestones/Chroma" PRED_TYPE_EPSILON = "epsilon" PRED_TYPE_V = "v" @@ -125,7 +128,7 @@ def build_metadata( flux: Optional[str] = None, ): """ - sd3: only supports "m", flux: only supports "dev" + sd3: only supports "m", flux: supports "dev", "schnell" or "chroma" """ # if state_dict is None, hash is not calculated @@ -144,6 +147,10 @@ def build_metadata( elif flux is not None: if flux == "dev": arch = ARCH_FLUX_1_DEV + elif flux == "schnell": + arch = ARCH_FLUX_1_SCHNELL + elif flux == "chroma": + arch = ARCH_FLUX_1_CHROMA else: arch = ARCH_FLUX_1_UNKNOWN elif v2: @@ -166,7 +173,10 @@ def build_metadata( if flux is not None: # Flux - impl = IMPL_FLUX + if flux == "chroma": + impl = IMPL_CHROMA + else: + impl = IMPL_FLUX elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt: # Stable Diffusion ckpt, TI, SDXL LoRA impl = IMPL_STABILITY_AI diff --git a/library/train_util.py b/library/train_util.py index 36d419fd..b09963fb 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3482,7 +3482,7 @@ def get_sai_model_spec( textual_inversion: bool, is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA sd3: str = None, - flux: str = None, + flux: str = None, # "dev", "schnell" or "chroma" ): timestamp = time.time()