feat: implement modulation vector extraction for Chroma and update related methods

This commit is contained in:
Kohya S
2025-07-30 21:34:49 +09:00
parent 450630c6bd
commit 96feb61c0a
4 changed files with 24 additions and 28 deletions

View File

@@ -113,6 +113,8 @@ def denoise(
y_input = b_vec
mod_vectors = model.get_mod_vectors(timesteps=t_vec, guidance=guidance_vec, batch_size=b_img.shape[0])
pred = model(
img=b_img,
img_ids=b_img_ids,
@@ -122,6 +124,7 @@ def denoise(
timesteps=t_vec,
guidance=guidance_vec,
txt_attention_mask=b_t5_attn_mask,
mod_vectors=mod_vectors,
)
# classifier free guidance

View File

@@ -341,7 +341,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)
# get modulation vectors for Chroma
input_vec = unet.get_input_vec(timesteps=timesteps / 1000, guidance=guidance_vec, batch_size=bsz)
with accelerator.autocast(), torch.no_grad():
mod_vectors = unet.get_mod_vectors(timesteps=timesteps / 1000, guidance=guidance_vec, batch_size=bsz)
if args.gradient_checkpointing:
noisy_model_input.requires_grad_(True)
@@ -350,15 +351,15 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
t.requires_grad_(True)
img_ids.requires_grad_(True)
guidance_vec.requires_grad_(True)
if input_vec is not None:
input_vec.requires_grad_(True)
if mod_vectors is not None:
mod_vectors.requires_grad_(True)
# Predict the noise residual
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
if not args.apply_t5_attn_mask:
t5_attn_mask = None
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask, input_vec):
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask, mod_vectors):
# grad is enabled even if unet is not in train mode, because Text Encoder is in train mode
with torch.set_grad_enabled(is_train), accelerator.autocast():
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
@@ -371,7 +372,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
timesteps=timesteps / 1000,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
input_vec=input_vec,
mod_vectors=mod_vectors,
)
return model_pred
@@ -384,7 +385,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
timesteps=timesteps,
guidance_vec=guidance_vec,
t5_attn_mask=t5_attn_mask,
input_vec=input_vec,
mod_vectors=mod_vectors,
)
# unpack latents
@@ -416,7 +417,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
timesteps=timesteps[diff_output_pr_indices],
guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None,
t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None,
input_vec=input_vec[diff_output_pr_indices] if input_vec is not None else None,
mod_vectors=mod_vectors[diff_output_pr_indices] if mod_vectors is not None else None,
)
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step

View File

@@ -641,7 +641,10 @@ class Chroma(Flux):
print("Chroma: Gradient checkpointing disabled.")
def get_input_vec(self, timesteps: Tensor, guidance: Tensor | None = None, batch_size: int | None = None) -> Tensor:
def get_mod_vectors(self, timesteps: Tensor, guidance: Tensor | None = None, batch_size: int | None = None) -> Tensor:
# We extract this logic from forward to clarify the propagation of the gradients
# original comment: https://github.com/lodestone-rock/flow/blob/c76f63058980d0488826936025889e256a2e0458/src/models/chroma/model.py#L195
# print(f"Chroma get_input_vec: timesteps {timesteps}, guidance: {guidance}, batch_size: {batch_size}")
distill_timestep = timestep_embedding(timesteps, self.approximator_in_dim // 4)
# TODO: need to add toggle to omit this from schnell but that's not a priority
@@ -654,7 +657,9 @@ class Chroma(Flux):
timestep_guidance = torch.cat([distill_timestep, distil_guidance], dim=1).unsqueeze(1).repeat(1, self.mod_index_length, 1)
# then and only then we could concatenate it together
input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1)
return input_vec
mod_vectors = self.distilled_guidance_layer(input_vec)
return mod_vectors
def forward(
self,
@@ -669,7 +674,7 @@ class Chroma(Flux):
guidance: Tensor | None = None,
txt_attention_mask: Tensor | None = None,
attn_padding: int = 1,
input_vec: Tensor | None = None,
mod_vectors: Tensor | None = None,
) -> Tensor:
# print(
# f"Chroma forward: img shape {img.shape}, txt shape {txt.shape}, img_ids shape {img_ids.shape}, txt_ids shape {txt_ids.shape}"
@@ -684,22 +689,9 @@ class Chroma(Flux):
img = self.img_in(img)
txt = self.txt_in(txt)
if input_vec is None:
# TODO:
# need to fix grad accumulation issue here for now it's in no grad mode
# besides, i don't want to wash out the PFP that's trained on this model weights anyway
# the fan out operation here is deleting the backward graph
# alternatively doing forward pass for every block manually is doable but slow
# custom backward probably be better
if mod_vectors is None: # fallback to the original logic
with torch.no_grad():
input_vec = self.get_input_vec(timesteps, guidance, img.shape[0])
# kohya-ss: I'm not sure why requires_grad is set to True here
# original code: https://github.com/lodestone-rock/flow/blob/c76f63058980d0488826936025889e256a2e0458/src/models/chroma/model.py#L217
input_vec.requires_grad = True
mod_vectors = self.distilled_guidance_layer(input_vec)
else:
mod_vectors = self.distilled_guidance_layer(input_vec)
mod_vectors = self.get_mod_vectors(timesteps, guidance, img.shape[0])
mod_vectors_dict = distribute_modulations(mod_vectors, self.depth_single_blocks, self.depth_double_blocks)
# calculate text length for each batch instead of masking

View File

@@ -1009,8 +1009,8 @@ class Flux(nn.Module):
self.offloader_double.prepare_block_devices_before_forward(self.double_blocks)
self.offloader_single.prepare_block_devices_before_forward(self.single_blocks)
def get_input_vec(self, timesteps: Tensor, guidance: Tensor | None = None, batch_size: int | None = None) -> Tensor:
return None # FLUX.1 does not use input_vec, but Chroma does.
def get_mod_vectors(self, timesteps: Tensor, guidance: Tensor | None = None, batch_size: int | None = None) -> Tensor:
return None # FLUX.1 does not use mod_vectors, but Chroma does.
def forward(
self,
@@ -1024,7 +1024,7 @@ class Flux(nn.Module):
block_controlnet_single_hidden_states=None,
guidance: Tensor | None = None,
txt_attention_mask: Tensor | None = None,
input_vec: Tensor | None = None,
mod_vectors: Tensor | None = None,
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")