mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 21:52:27 +00:00
fix: inference for Chroma model
This commit is contained in:
@@ -78,16 +78,19 @@ def denoise(
|
||||
neg_t5_attn_mask: Optional[torch.Tensor] = None,
|
||||
cfg_scale: Optional[float] = None,
|
||||
):
|
||||
# this is ignored for schnell
|
||||
logger.info(f"guidance: {guidance}, cfg_scale: {cfg_scale}")
|
||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||
|
||||
# prepare classifier free guidance
|
||||
if neg_txt is not None and neg_vec is not None:
|
||||
logger.info(f"guidance: {guidance}, cfg_scale: {cfg_scale}")
|
||||
do_cfg = neg_txt is not None and (cfg_scale is not None and cfg_scale != 1.0)
|
||||
|
||||
# this is ignored for schnell
|
||||
guidance_vec = torch.full((img.shape[0] * (2 if do_cfg else 1),), guidance, device=img.device, dtype=img.dtype)
|
||||
|
||||
if do_cfg:
|
||||
print("Using classifier free guidance")
|
||||
b_img_ids = torch.cat([img_ids, img_ids], dim=0)
|
||||
b_txt_ids = torch.cat([txt_ids, txt_ids], dim=0)
|
||||
b_txt = torch.cat([neg_txt, txt], dim=0)
|
||||
b_vec = torch.cat([neg_vec, vec], dim=0)
|
||||
b_vec = torch.cat([neg_vec, vec], dim=0) if neg_vec is not None else None
|
||||
if t5_attn_mask is not None and neg_t5_attn_mask is not None:
|
||||
b_t5_attn_mask = torch.cat([neg_t5_attn_mask, t5_attn_mask], dim=0)
|
||||
else:
|
||||
@@ -103,17 +106,13 @@ def denoise(
|
||||
t_vec = torch.full((b_img_ids.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||
|
||||
# classifier free guidance
|
||||
if neg_txt is not None and neg_vec is not None:
|
||||
if do_cfg:
|
||||
b_img = torch.cat([img, img], dim=0)
|
||||
else:
|
||||
b_img = img
|
||||
|
||||
# For Chroma model, y might be None, so create dummy tensor
|
||||
if b_vec is None:
|
||||
y_input = torch.zeros_like(b_txt[:, :1, :]) # dummy tensor
|
||||
else:
|
||||
y_input = b_vec
|
||||
|
||||
y_input = b_vec
|
||||
|
||||
pred = model(
|
||||
img=b_img,
|
||||
img_ids=b_img_ids,
|
||||
@@ -126,7 +125,7 @@ def denoise(
|
||||
)
|
||||
|
||||
# classifier free guidance
|
||||
if neg_txt is not None and neg_vec is not None:
|
||||
if do_cfg:
|
||||
pred_uncond, pred = torch.chunk(pred, 2, dim=0)
|
||||
pred = pred_uncond + cfg_scale * (pred - pred_uncond)
|
||||
|
||||
@@ -309,7 +308,7 @@ def generate_image(
|
||||
neg_l_pooled, neg_t5_out, neg_t5_attn_mask = None, None, None
|
||||
|
||||
# NaN check
|
||||
if torch.isnan(l_pooled).any():
|
||||
if l_pooled is not None and torch.isnan(l_pooled).any():
|
||||
raise ValueError("NaN in l_pooled")
|
||||
if torch.isnan(t5_out).any():
|
||||
raise ValueError("NaN in t5_out")
|
||||
@@ -329,6 +328,7 @@ def generate_image(
|
||||
|
||||
img_ids = img_ids.to(device)
|
||||
t5_attn_mask = t5_attn_mask.to(device) if args.apply_t5_attn_mask else None
|
||||
neg_t5_attn_mask = neg_t5_attn_mask.to(device) if neg_t5_attn_mask is not None and args.apply_t5_attn_mask else None
|
||||
|
||||
x = do_sample(
|
||||
accelerator,
|
||||
|
||||
@@ -240,7 +240,7 @@ class DoubleStreamBlock(nn.Module):
|
||||
k = torch.cat((txt_k, img_k), dim=2)
|
||||
v = torch.cat((txt_v, img_v), dim=2)
|
||||
|
||||
attn = attention(q, k, v, pe=pe, mask=mask)
|
||||
attn = attention(q, k, v, pe=pe, attn_mask=mask)
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
||||
|
||||
# calculate the img bloks
|
||||
@@ -343,7 +343,7 @@ class SingleStreamBlock(nn.Module):
|
||||
q, k = self.norm(q, k, v)
|
||||
|
||||
# compute attention
|
||||
attn = attention(q, k, v, pe=pe, mask=mask)
|
||||
attn = attention(q, k, v, pe=pe, attn_mask=mask)
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||
# replaced with compiled fn
|
||||
@@ -555,6 +555,11 @@ class Chroma(Flux):
|
||||
guidance: Tensor | None = None,
|
||||
txt_attention_mask: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
# print(
|
||||
# f"Chroma forward: img shape {img.shape}, txt shape {txt.shape}, img_ids shape {img_ids.shape}, txt_ids shape {txt_ids.shape}"
|
||||
# )
|
||||
# print(f"timesteps: {timesteps}, guidance: {guidance}")
|
||||
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
|
||||
|
||||
@@ -146,7 +146,7 @@ def load_flow_model(
|
||||
from . import chroma_models
|
||||
|
||||
# build model
|
||||
logger.info("Building Chroma model from BFL checkpoint")
|
||||
logger.info("Building Chroma model")
|
||||
with torch.device("meta"):
|
||||
model = chroma_models.Chroma(chroma_models.chroma_params)
|
||||
if dtype is not None:
|
||||
|
||||
Reference in New Issue
Block a user