diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py index 86e8e1b1..d5f2d8d9 100644 --- a/flux_minimal_inference.py +++ b/flux_minimal_inference.py @@ -113,6 +113,8 @@ def denoise( y_input = b_vec + mod_vectors = model.get_mod_vectors(timesteps=t_vec, guidance=guidance_vec, batch_size=b_img.shape[0]) + pred = model( img=b_img, img_ids=b_img_ids, @@ -122,6 +124,7 @@ def denoise( timesteps=t_vec, guidance=guidance_vec, txt_attention_mask=b_t5_attn_mask, + mod_vectors=mod_vectors, ) # classifier free guidance diff --git a/flux_train_network.py b/flux_train_network.py index 13e9ae2a..2d9ab248 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -341,7 +341,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) # get modulation vectors for Chroma - input_vec = unet.get_input_vec(timesteps=timesteps / 1000, guidance=guidance_vec, batch_size=bsz) + with accelerator.autocast(), torch.no_grad(): + mod_vectors = unet.get_mod_vectors(timesteps=timesteps / 1000, guidance=guidance_vec, batch_size=bsz) if args.gradient_checkpointing: noisy_model_input.requires_grad_(True) @@ -350,15 +351,15 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): t.requires_grad_(True) img_ids.requires_grad_(True) guidance_vec.requires_grad_(True) - if input_vec is not None: - input_vec.requires_grad_(True) + if mod_vectors is not None: + mod_vectors.requires_grad_(True) # Predict the noise residual l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds if not args.apply_t5_attn_mask: t5_attn_mask = None - def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask, input_vec): + def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask, mod_vectors): # grad is enabled even if unet is not in train mode, because Text Encoder is in train mode with torch.set_grad_enabled(is_train), accelerator.autocast(): 
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) @@ -371,7 +372,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): timesteps=timesteps / 1000, guidance=guidance_vec, txt_attention_mask=t5_attn_mask, - input_vec=input_vec, + mod_vectors=mod_vectors, ) return model_pred @@ -384,7 +385,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): timesteps=timesteps, guidance_vec=guidance_vec, t5_attn_mask=t5_attn_mask, - input_vec=input_vec, + mod_vectors=mod_vectors, ) # unpack latents @@ -416,7 +417,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): timesteps=timesteps[diff_output_pr_indices], guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None, t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None, - input_vec=input_vec[diff_output_pr_indices] if input_vec is not None else None, + mod_vectors=mod_vectors[diff_output_pr_indices] if mod_vectors is not None else None, ) network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step diff --git a/library/chroma_models.py b/library/chroma_models.py index 0c93f526..d5ac1f39 100644 --- a/library/chroma_models.py +++ b/library/chroma_models.py @@ -641,7 +641,10 @@ class Chroma(Flux): print("Chroma: Gradient checkpointing disabled.") - def get_input_vec(self, timesteps: Tensor, guidance: Tensor | None = None, batch_size: int | None = None) -> Tensor: + def get_mod_vectors(self, timesteps: Tensor, guidance: Tensor | None = None, batch_size: int | None = None) -> Tensor: + # We extract this logic from forward to clarify the propagation of the gradients + # original comment: https://github.com/lodestone-rock/flow/blob/c76f63058980d0488826936025889e256a2e0458/src/models/chroma/model.py#L195 + # print(f"Chroma get_mod_vectors: timesteps {timesteps}, guidance: {guidance}, batch_size: 
{batch_size}") distill_timestep = timestep_embedding(timesteps, self.approximator_in_dim // 4) # TODO: need to add toggle to omit this from schnell but that's not a priority @@ -654,7 +657,9 @@ class Chroma(Flux): timestep_guidance = torch.cat([distill_timestep, distil_guidance], dim=1).unsqueeze(1).repeat(1, self.mod_index_length, 1) # then and only then we could concatenate it together input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1) - return input_vec + + mod_vectors = self.distilled_guidance_layer(input_vec) + return mod_vectors def forward( self, @@ -669,7 +674,7 @@ class Chroma(Flux): guidance: Tensor | None = None, txt_attention_mask: Tensor | None = None, attn_padding: int = 1, - input_vec: Tensor | None = None, + mod_vectors: Tensor | None = None, ) -> Tensor: # print( # f"Chroma forward: img shape {img.shape}, txt shape {txt.shape}, img_ids shape {img_ids.shape}, txt_ids shape {txt_ids.shape}" @@ -684,22 +689,9 @@ class Chroma(Flux): img = self.img_in(img) txt = self.txt_in(txt) - if input_vec is None: - # TODO: - # need to fix grad accumulation issue here for now it's in no grad mode - # besides, i don't want to wash out the PFP that's trained on this model weights anyway - # the fan out operation here is deleting the backward graph - # alternatively doing forward pass for every block manually is doable but slow - # custom backward probably be better + if mod_vectors is None: # fallback to the original logic with torch.no_grad(): - input_vec = self.get_input_vec(timesteps, guidance, img.shape[0]) - - # kohya-ss: I'm not sure why requires_grad is set to True here - # original code: https://github.com/lodestone-rock/flow/blob/c76f63058980d0488826936025889e256a2e0458/src/models/chroma/model.py#L217 - input_vec.requires_grad = True - mod_vectors = self.distilled_guidance_layer(input_vec) - else: - mod_vectors = self.distilled_guidance_layer(input_vec) + mod_vectors = self.get_mod_vectors(timesteps, guidance, img.shape[0]) mod_vectors_dict 
= distribute_modulations(mod_vectors, self.depth_single_blocks, self.depth_double_blocks) # calculate text length for each batch instead of masking diff --git a/library/flux_models.py b/library/flux_models.py index 63d699d4..d2d7e06c 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -1009,8 +1009,8 @@ class Flux(nn.Module): self.offloader_double.prepare_block_devices_before_forward(self.double_blocks) self.offloader_single.prepare_block_devices_before_forward(self.single_blocks) - def get_input_vec(self, timesteps: Tensor, guidance: Tensor | None = None, batch_size: int | None = None) -> Tensor: - return None # FLUX.1 does not use input_vec, but Chroma does. + def get_mod_vectors(self, timesteps: Tensor, guidance: Tensor | None = None, batch_size: int | None = None) -> Tensor: + return None # FLUX.1 does not use mod_vectors, but Chroma does. def forward( self, @@ -1024,7 +1024,7 @@ class Flux(nn.Module): block_controlnet_single_hidden_states=None, guidance: Tensor | None = None, txt_attention_mask: Tensor | None = None, - input_vec: Tensor | None = None, + mod_vectors: Tensor | None = None, ) -> Tensor: if img.ndim != 3 or txt.ndim != 3: raise ValueError("Input img and txt tensors must have 3 dimensions.")