mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
feat: implement modulation vector extraction for Chroma and update related methods
This commit is contained in:
@@ -113,6 +113,8 @@ def denoise(
|
||||
|
||||
y_input = b_vec
|
||||
|
||||
mod_vectors = model.get_mod_vectors(timesteps=t_vec, guidance=guidance_vec, batch_size=b_img.shape[0])
|
||||
|
||||
pred = model(
|
||||
img=b_img,
|
||||
img_ids=b_img_ids,
|
||||
@@ -122,6 +124,7 @@ def denoise(
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
txt_attention_mask=b_t5_attn_mask,
|
||||
mod_vectors=mod_vectors,
|
||||
)
|
||||
|
||||
# classifier free guidance
|
||||
|
||||
@@ -341,7 +341,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)
|
||||
|
||||
# get modulation vectors for Chroma
|
||||
input_vec = unet.get_input_vec(timesteps=timesteps / 1000, guidance=guidance_vec, batch_size=bsz)
|
||||
with accelerator.autocast(), torch.no_grad():
|
||||
mod_vectors = unet.get_mod_vectors(timesteps=timesteps / 1000, guidance=guidance_vec, batch_size=bsz)
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
noisy_model_input.requires_grad_(True)
|
||||
@@ -350,15 +351,15 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
t.requires_grad_(True)
|
||||
img_ids.requires_grad_(True)
|
||||
guidance_vec.requires_grad_(True)
|
||||
if input_vec is not None:
|
||||
input_vec.requires_grad_(True)
|
||||
if mod_vectors is not None:
|
||||
mod_vectors.requires_grad_(True)
|
||||
|
||||
# Predict the noise residual
|
||||
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
|
||||
if not args.apply_t5_attn_mask:
|
||||
t5_attn_mask = None
|
||||
|
||||
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask, input_vec):
|
||||
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask, mod_vectors):
|
||||
# grad is enabled even if unet is not in train mode, because Text Encoder is in train mode
|
||||
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||
@@ -371,7 +372,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
timesteps=timesteps / 1000,
|
||||
guidance=guidance_vec,
|
||||
txt_attention_mask=t5_attn_mask,
|
||||
input_vec=input_vec,
|
||||
mod_vectors=mod_vectors,
|
||||
)
|
||||
return model_pred
|
||||
|
||||
@@ -384,7 +385,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
timesteps=timesteps,
|
||||
guidance_vec=guidance_vec,
|
||||
t5_attn_mask=t5_attn_mask,
|
||||
input_vec=input_vec,
|
||||
mod_vectors=mod_vectors,
|
||||
)
|
||||
|
||||
# unpack latents
|
||||
@@ -416,7 +417,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
timesteps=timesteps[diff_output_pr_indices],
|
||||
guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None,
|
||||
t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None,
|
||||
input_vec=input_vec[diff_output_pr_indices] if input_vec is not None else None,
|
||||
mod_vectors=mod_vectors[diff_output_pr_indices] if mod_vectors is not None else None,
|
||||
)
|
||||
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
|
||||
|
||||
|
||||
@@ -641,7 +641,10 @@ class Chroma(Flux):
|
||||
|
||||
print("Chroma: Gradient checkpointing disabled.")
|
||||
|
||||
def get_input_vec(self, timesteps: Tensor, guidance: Tensor | None = None, batch_size: int | None = None) -> Tensor:
|
||||
def get_mod_vectors(self, timesteps: Tensor, guidance: Tensor | None = None, batch_size: int | None = None) -> Tensor:
|
||||
# We extract this logic from forward to clarify the propagation of the gradients
|
||||
# original comment: https://github.com/lodestone-rock/flow/blob/c76f63058980d0488826936025889e256a2e0458/src/models/chroma/model.py#L195
|
||||
|
||||
# print(f"Chroma get_input_vec: timesteps {timesteps}, guidance: {guidance}, batch_size: {batch_size}")
|
||||
distill_timestep = timestep_embedding(timesteps, self.approximator_in_dim // 4)
|
||||
# TODO: need to add toggle to omit this from schnell but that's not a priority
|
||||
@@ -654,7 +657,9 @@ class Chroma(Flux):
|
||||
timestep_guidance = torch.cat([distill_timestep, distil_guidance], dim=1).unsqueeze(1).repeat(1, self.mod_index_length, 1)
|
||||
# then and only then we could concatenate it together
|
||||
input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1)
|
||||
return input_vec
|
||||
|
||||
mod_vectors = self.distilled_guidance_layer(input_vec)
|
||||
return mod_vectors
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -669,7 +674,7 @@ class Chroma(Flux):
|
||||
guidance: Tensor | None = None,
|
||||
txt_attention_mask: Tensor | None = None,
|
||||
attn_padding: int = 1,
|
||||
input_vec: Tensor | None = None,
|
||||
mod_vectors: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
# print(
|
||||
# f"Chroma forward: img shape {img.shape}, txt shape {txt.shape}, img_ids shape {img_ids.shape}, txt_ids shape {txt_ids.shape}"
|
||||
@@ -684,22 +689,9 @@ class Chroma(Flux):
|
||||
img = self.img_in(img)
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
if input_vec is None:
|
||||
# TODO:
|
||||
# need to fix grad accumulation issue here for now it's in no grad mode
|
||||
# besides, i don't want to wash out the PFP that's trained on this model weights anyway
|
||||
# the fan out operation here is deleting the backward graph
|
||||
# alternatively doing forward pass for every block manually is doable but slow
|
||||
# custom backward probably be better
|
||||
if mod_vectors is None: # fallback to the original logic
|
||||
with torch.no_grad():
|
||||
input_vec = self.get_input_vec(timesteps, guidance, img.shape[0])
|
||||
|
||||
# kohya-ss: I'm not sure why requires_grad is set to True here
|
||||
# original code: https://github.com/lodestone-rock/flow/blob/c76f63058980d0488826936025889e256a2e0458/src/models/chroma/model.py#L217
|
||||
input_vec.requires_grad = True
|
||||
mod_vectors = self.distilled_guidance_layer(input_vec)
|
||||
else:
|
||||
mod_vectors = self.distilled_guidance_layer(input_vec)
|
||||
mod_vectors = self.get_mod_vectors(timesteps, guidance, img.shape[0])
|
||||
mod_vectors_dict = distribute_modulations(mod_vectors, self.depth_single_blocks, self.depth_double_blocks)
|
||||
|
||||
# calculate text length for each batch instead of masking
|
||||
|
||||
@@ -1009,8 +1009,8 @@ class Flux(nn.Module):
|
||||
self.offloader_double.prepare_block_devices_before_forward(self.double_blocks)
|
||||
self.offloader_single.prepare_block_devices_before_forward(self.single_blocks)
|
||||
|
||||
def get_input_vec(self, timesteps: Tensor, guidance: Tensor | None = None, batch_size: int | None = None) -> Tensor:
|
||||
return None # FLUX.1 does not use input_vec, but Chroma does.
|
||||
def get_mod_vectors(self, timesteps: Tensor, guidance: Tensor | None = None, batch_size: int | None = None) -> Tensor:
|
||||
return None # FLUX.1 does not use mod_vectors, but Chroma does.
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -1024,7 +1024,7 @@ class Flux(nn.Module):
|
||||
block_controlnet_single_hidden_states=None,
|
||||
guidance: Tensor | None = None,
|
||||
txt_attention_mask: Tensor | None = None,
|
||||
input_vec: Tensor | None = None,
|
||||
mod_vectors: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
|
||||
Reference in New Issue
Block a user