fix: make --fp8_vl work

This commit is contained in:
Kohya S
2025-09-18 23:46:18 +09:00
parent f6b4bdc83f
commit f834b2e0d4
2 changed files with 3 additions and 4 deletions

View File

@@ -250,7 +250,7 @@ def sample_image_inference(
arg_c_null = None
gen_args = SimpleNamespace(
image_size=(height, width), infer_steps=sample_steps, flow_shift=flow_shift, guidance_scale=cfg_scale
image_size=(height, width), infer_steps=sample_steps, flow_shift=flow_shift, guidance_scale=cfg_scale, fp8=args.fp8_scaled
)
from hunyuan_image_minimal_inference import generate_body # import here to avoid circular import

View File

@@ -15,7 +15,7 @@ from transformers.models.t5.modeling_t5 import T5Stack
from accelerate import init_empty_weights
from library.safetensors_utils import load_safetensors
from library.utils import setup_logging
from library.utils import setup_logging
setup_logging()
import logging
@@ -542,7 +542,6 @@ def get_qwen_prompt_embeds_from_tokens(
attention_mask = attention_mask.to(device=device)
if dtype.itemsize == 1: # fp8
# TODO dtype should be vlm.dtype?
with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=True):
encoder_hidden_states = vlm(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
else:
@@ -564,7 +563,7 @@ def get_qwen_prompt_embeds_from_tokens(
prompt_embeds = hidden_states[:, drop_idx:, :]
encoder_attention_mask = attention_mask[:, drop_idx:]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
prompt_embeds = prompt_embeds.to(device=device)
return prompt_embeds, encoder_attention_mask