diff --git a/hunyuan_image_train_network.py b/hunyuan_image_train_network.py
index a3c0cd89..60aa2178 100644
--- a/hunyuan_image_train_network.py
+++ b/hunyuan_image_train_network.py
@@ -250,7 +250,7 @@ def sample_image_inference(
         arg_c_null = None
 
     gen_args = SimpleNamespace(
-        image_size=(height, width), infer_steps=sample_steps, flow_shift=flow_shift, guidance_scale=cfg_scale
+        image_size=(height, width), infer_steps=sample_steps, flow_shift=flow_shift, guidance_scale=cfg_scale, fp8=args.fp8_scaled
     )
 
     from hunyuan_image_minimal_inference import generate_body  # import here to avoid circular import
diff --git a/library/hunyuan_image_text_encoder.py b/library/hunyuan_image_text_encoder.py
index 509f9bd2..2171b410 100644
--- a/library/hunyuan_image_text_encoder.py
+++ b/library/hunyuan_image_text_encoder.py
@@ -15,7 +15,7 @@ from transformers.models.t5.modeling_t5 import T5Stack
 from accelerate import init_empty_weights
 
 from library.safetensors_utils import load_safetensors
-from library.utils import setup_logging
+from library.utils import setup_logging
 
 setup_logging()
 import logging
@@ -542,7 +542,6 @@ def get_qwen_prompt_embeds_from_tokens(
     attention_mask = attention_mask.to(device=device)
 
     if dtype.itemsize == 1:  # fp8
-        # TODO dtype should be vlm.dtype?
         with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=True):
             encoder_hidden_states = vlm(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
     else:
@@ -564,7 +563,7 @@ def get_qwen_prompt_embeds_from_tokens(
     prompt_embeds = hidden_states[:, drop_idx:, :]
     encoder_attention_mask = attention_mask[:, drop_idx:]
 
-    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+    prompt_embeds = prompt_embeds.to(device=device)
 
     return prompt_embeds, encoder_attention_mask