From 404ddb060d04285d72ffff9342542eec71d9c352 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 20 Jul 2025 14:08:54 +0900 Subject: [PATCH] fix: inference for Chroma model --- flux_minimal_inference.py | 30 +++++++++++++++--------------- library/chroma_models.py | 9 +++++++-- library/flux_utils.py | 2 +- 3 files changed, 23 insertions(+), 18 deletions(-) diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py index a7bff74d..550904d2 100644 --- a/flux_minimal_inference.py +++ b/flux_minimal_inference.py @@ -78,16 +78,19 @@ def denoise( neg_t5_attn_mask: Optional[torch.Tensor] = None, cfg_scale: Optional[float] = None, ): - # this is ignored for schnell - logger.info(f"guidance: {guidance}, cfg_scale: {cfg_scale}") - guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) - # prepare classifier free guidance - if neg_txt is not None and neg_vec is not None: + logger.info(f"guidance: {guidance}, cfg_scale: {cfg_scale}") + do_cfg = neg_txt is not None and (cfg_scale is not None and cfg_scale != 1.0) + + # this is ignored for schnell + guidance_vec = torch.full((img.shape[0] * (2 if do_cfg else 1),), guidance, device=img.device, dtype=img.dtype) + + if do_cfg: + print("Using classifier free guidance") b_img_ids = torch.cat([img_ids, img_ids], dim=0) b_txt_ids = torch.cat([txt_ids, txt_ids], dim=0) b_txt = torch.cat([neg_txt, txt], dim=0) - b_vec = torch.cat([neg_vec, vec], dim=0) + b_vec = torch.cat([neg_vec, vec], dim=0) if neg_vec is not None else None if t5_attn_mask is not None and neg_t5_attn_mask is not None: b_t5_attn_mask = torch.cat([neg_t5_attn_mask, t5_attn_mask], dim=0) else: @@ -103,17 +106,13 @@ def denoise( t_vec = torch.full((b_img_ids.shape[0],), t_curr, dtype=img.dtype, device=img.device) # classifier free guidance - if neg_txt is not None and neg_vec is not None: + if do_cfg: b_img = torch.cat([img, img], dim=0) else: b_img = img - # For Chroma model, y might be None, so create dummy tensor - if b_vec is None: - y_input = torch.zeros_like(b_txt[:, :1, :]) # dummy tensor - else: - y_input = b_vec - + y_input = b_vec + pred = model( img=b_img, img_ids=b_img_ids, @@ -126,7 +125,7 @@ def denoise( ) # classifier free guidance - if neg_txt is not None and neg_vec is not None: + if do_cfg: pred_uncond, pred = torch.chunk(pred, 2, dim=0) pred = pred_uncond + cfg_scale * (pred - pred_uncond) @@ -309,7 +308,7 @@ def generate_image( neg_l_pooled, neg_t5_out, neg_t5_attn_mask = None, None, None # NaN check - if torch.isnan(l_pooled).any(): + if l_pooled is not None and torch.isnan(l_pooled).any(): raise ValueError("NaN in l_pooled") if torch.isnan(t5_out).any(): raise ValueError("NaN in t5_out") @@ -329,6 +328,7 @@ def generate_image( img_ids = img_ids.to(device) t5_attn_mask = t5_attn_mask.to(device) if args.apply_t5_attn_mask else None + neg_t5_attn_mask = neg_t5_attn_mask.to(device) if neg_t5_attn_mask is not None and args.apply_t5_attn_mask else None x = do_sample( accelerator, diff --git a/library/chroma_models.py b/library/chroma_models.py index e1da751b..f725db87 100644 --- a/library/chroma_models.py +++ b/library/chroma_models.py @@ -240,7 +240,7 @@ class DoubleStreamBlock(nn.Module): k = torch.cat((txt_k, img_k), dim=2) v = torch.cat((txt_v, img_v), dim=2) - attn = attention(q, k, v, pe=pe, mask=mask) + attn = attention(q, k, v, pe=pe, attn_mask=mask) txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] # calculate the img bloks @@ -343,7 +343,7 @@ class SingleStreamBlock(nn.Module): q, k = self.norm(q, k, v) # compute attention - attn = attention(q, k, v, pe=pe, mask=mask) + attn = attention(q, k, v, pe=pe, attn_mask=mask) # compute activation in mlp stream, cat again and run second linear layer output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) # replaced with compiled fn @@ -555,6 +555,11 @@ class Chroma(Flux): guidance: Tensor | None = None, txt_attention_mask: Tensor | None = None, ) -> Tensor: + # print( + # f"Chroma forward: img shape {img.shape}, txt shape {txt.shape}, img_ids shape {img_ids.shape}, txt_ids shape {txt_ids.shape}" + # ) + # print(f"timesteps: {timesteps}, guidance: {guidance}") + if img.ndim != 3 or txt.ndim != 3: raise ValueError("Input img and txt tensors must have 3 dimensions.") diff --git a/library/flux_utils.py b/library/flux_utils.py index a5cfcdff..dda7c789 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -146,7 +146,7 @@ def load_flow_model( from . import chroma_models # build model - logger.info("Building Chroma model from BFL checkpoint") + logger.info("Building Chroma model") with torch.device("meta"): model = chroma_models.Chroma(chroma_models.chroma_params) if dtype is not None: