mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
fix: --fp8_vl to work
This commit is contained in:
@@ -250,7 +250,7 @@ def sample_image_inference(
|
||||
arg_c_null = None
|
||||
|
||||
gen_args = SimpleNamespace(
|
||||
image_size=(height, width), infer_steps=sample_steps, flow_shift=flow_shift, guidance_scale=cfg_scale
|
||||
image_size=(height, width), infer_steps=sample_steps, flow_shift=flow_shift, guidance_scale=cfg_scale, fp8=args.fp8_scaled
|
||||
)
|
||||
|
||||
from hunyuan_image_minimal_inference import generate_body # import here to avoid circular import
|
||||
|
||||
@@ -15,7 +15,7 @@ from transformers.models.t5.modeling_t5 import T5Stack
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
from library.safetensors_utils import load_safetensors
|
||||
from library.utils import setup_logging
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
@@ -542,7 +542,6 @@ def get_qwen_prompt_embeds_from_tokens(
|
||||
attention_mask = attention_mask.to(device=device)
|
||||
|
||||
if dtype.itemsize == 1: # fp8
|
||||
# TODO dtype should be vlm.dtype?
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=True):
|
||||
encoder_hidden_states = vlm(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
|
||||
else:
|
||||
@@ -564,7 +563,7 @@ def get_qwen_prompt_embeds_from_tokens(
|
||||
|
||||
prompt_embeds = hidden_states[:, drop_idx:, :]
|
||||
encoder_attention_mask = attention_mask[:, drop_idx:]
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
prompt_embeds = prompt_embeds.to(device=device)
|
||||
|
||||
return prompt_embeds, encoder_attention_mask
|
||||
|
||||
|
||||
Reference in New Issue
Block a user