fix: inference for Chroma model

This commit is contained in:
Kohya S
2025-07-20 14:08:54 +09:00
parent 24d2ea86c7
commit 404ddb060d
3 changed files with 23 additions and 18 deletions

View File

@@ -78,16 +78,19 @@ def denoise(
neg_t5_attn_mask: Optional[torch.Tensor] = None,
cfg_scale: Optional[float] = None,
):
# this is ignored for schnell
logger.info(f"guidance: {guidance}, cfg_scale: {cfg_scale}")
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
# prepare classifier free guidance
if neg_txt is not None and neg_vec is not None:
logger.info(f"guidance: {guidance}, cfg_scale: {cfg_scale}")
do_cfg = neg_txt is not None and (cfg_scale is not None and cfg_scale != 1.0)
# this is ignored for schnell
guidance_vec = torch.full((img.shape[0] * (2 if do_cfg else 1),), guidance, device=img.device, dtype=img.dtype)
if do_cfg:
print("Using classifier free guidance")
b_img_ids = torch.cat([img_ids, img_ids], dim=0)
b_txt_ids = torch.cat([txt_ids, txt_ids], dim=0)
b_txt = torch.cat([neg_txt, txt], dim=0)
b_vec = torch.cat([neg_vec, vec], dim=0)
b_vec = torch.cat([neg_vec, vec], dim=0) if neg_vec is not None else None
if t5_attn_mask is not None and neg_t5_attn_mask is not None:
b_t5_attn_mask = torch.cat([neg_t5_attn_mask, t5_attn_mask], dim=0)
else:
@@ -103,17 +106,13 @@ def denoise(
t_vec = torch.full((b_img_ids.shape[0],), t_curr, dtype=img.dtype, device=img.device)
# classifier free guidance
if neg_txt is not None and neg_vec is not None:
if do_cfg:
b_img = torch.cat([img, img], dim=0)
else:
b_img = img
# For Chroma model, y might be None, so create dummy tensor
if b_vec is None:
y_input = torch.zeros_like(b_txt[:, :1, :]) # dummy tensor
else:
y_input = b_vec
y_input = b_vec
pred = model(
img=b_img,
img_ids=b_img_ids,
@@ -126,7 +125,7 @@ def denoise(
)
# classifier free guidance
if neg_txt is not None and neg_vec is not None:
if do_cfg:
pred_uncond, pred = torch.chunk(pred, 2, dim=0)
pred = pred_uncond + cfg_scale * (pred - pred_uncond)
@@ -309,7 +308,7 @@ def generate_image(
neg_l_pooled, neg_t5_out, neg_t5_attn_mask = None, None, None
# NaN check
if torch.isnan(l_pooled).any():
if l_pooled is not None and torch.isnan(l_pooled).any():
raise ValueError("NaN in l_pooled")
if torch.isnan(t5_out).any():
raise ValueError("NaN in t5_out")
@@ -329,6 +328,7 @@ def generate_image(
img_ids = img_ids.to(device)
t5_attn_mask = t5_attn_mask.to(device) if args.apply_t5_attn_mask else None
neg_t5_attn_mask = neg_t5_attn_mask.to(device) if neg_t5_attn_mask is not None and args.apply_t5_attn_mask else None
x = do_sample(
accelerator,

View File

@@ -240,7 +240,7 @@ class DoubleStreamBlock(nn.Module):
k = torch.cat((txt_k, img_k), dim=2)
v = torch.cat((txt_v, img_v), dim=2)
attn = attention(q, k, v, pe=pe, mask=mask)
attn = attention(q, k, v, pe=pe, attn_mask=mask)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
# calculate the img blocks
@@ -343,7 +343,7 @@ class SingleStreamBlock(nn.Module):
q, k = self.norm(q, k, v)
# compute attention
attn = attention(q, k, v, pe=pe, mask=mask)
attn = attention(q, k, v, pe=pe, attn_mask=mask)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
# replaced with compiled fn
@@ -555,6 +555,11 @@ class Chroma(Flux):
guidance: Tensor | None = None,
txt_attention_mask: Tensor | None = None,
) -> Tensor:
# print(
# f"Chroma forward: img shape {img.shape}, txt shape {txt.shape}, img_ids shape {img_ids.shape}, txt_ids shape {txt_ids.shape}"
# )
# print(f"timesteps: {timesteps}, guidance: {guidance}")
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")

View File

@@ -146,7 +146,7 @@ def load_flow_model(
from . import chroma_models
# build model
logger.info("Building Chroma model from BFL checkpoint")
logger.info("Building Chroma model")
with torch.device("meta"):
model = chroma_models.Chroma(chroma_models.chroma_params)
if dtype is not None: