feat: block swap for inference and initial impl for HunyuanImage LoRA (not working)

This commit is contained in:
Kohya S
2025-09-11 22:15:22 +09:00
parent 5149be5a87
commit 7f983c558d
16 changed files with 1363 additions and 1303 deletions

View File

@@ -29,7 +29,10 @@ koo="koo"
yos="yos"
wn="wn"
hime="hime"
OT="OT"
[files]
extend-exclude = ["_typos.toml", "venv"]
# [files]
# # Extend the default list of files to check
# extend-exclude = [
# "library/hunyuan_image_text_encoder.py",
# ]

View File

@@ -7,8 +7,8 @@ import os
import re
import time
import copy
from types import ModuleType
from typing import Tuple, Optional, List, Any, Dict
from types import ModuleType, SimpleNamespace
from typing import Tuple, Optional, List, Any, Dict, Union
import numpy as np
import torch
@@ -21,7 +21,7 @@ from PIL import Image
from library import hunyuan_image_models, hunyuan_image_text_encoder, hunyuan_image_utils
from library import hunyuan_image_vae
from library.hunyuan_image_vae import HunyuanVAE2D
from library.device_utils import clean_memory_on_device
from library.device_utils import clean_memory_on_device, synchronize_device
from networks import lora_hunyuan_image
@@ -29,7 +29,6 @@ lycoris_available = find_spec("lycoris") is not None
if lycoris_available:
from lycoris.kohya import create_network_from_weights
from library.custom_offloading_utils import synchronize_device
from library.utils import mem_eff_save_file, setup_logging
setup_logging()
@@ -513,10 +512,11 @@ def prepare_text_inputs(
else:
move_models_to_device_if_needed()
embed, mask = hunyuan_image_text_encoder.get_qwen_prompt_embeds(tokenizer_vlm, text_encoder_vlm, prompt)
ocr_mask, embed_byt5, mask_byt5 = hunyuan_image_text_encoder.get_glyph_prompt_embeds(
tokenizer_byt5, text_encoder_byt5, prompt
)
with torch.no_grad():
embed, mask = hunyuan_image_text_encoder.get_qwen_prompt_embeds(tokenizer_vlm, text_encoder_vlm, prompt)
ocr_mask, embed_byt5, mask_byt5 = hunyuan_image_text_encoder.get_glyph_prompt_embeds(
tokenizer_byt5, text_encoder_byt5, prompt
)
embed = embed.cpu()
mask = mask.cpu()
embed_byt5 = embed_byt5.cpu()
@@ -531,12 +531,13 @@ def prepare_text_inputs(
else:
move_models_to_device_if_needed()
negative_embed, negative_mask = hunyuan_image_text_encoder.get_qwen_prompt_embeds(
tokenizer_vlm, text_encoder_vlm, negative_prompt
)
negative_ocr_mask, negative_embed_byt5, negative_mask_byt5 = hunyuan_image_text_encoder.get_glyph_prompt_embeds(
tokenizer_byt5, text_encoder_byt5, negative_prompt
)
with torch.no_grad():
negative_embed, negative_mask = hunyuan_image_text_encoder.get_qwen_prompt_embeds(
tokenizer_vlm, text_encoder_vlm, negative_prompt
)
negative_ocr_mask, negative_embed_byt5, negative_mask_byt5 = hunyuan_image_text_encoder.get_glyph_prompt_embeds(
tokenizer_byt5, text_encoder_byt5, negative_prompt
)
negative_embed = negative_embed.cpu()
negative_mask = negative_mask.cpu()
negative_embed_byt5 = negative_embed_byt5.cpu()
@@ -617,6 +618,18 @@ def generate(
# model.move_to_device_except_swap_blocks(device) # Handles block swap correctly
# model.prepare_block_swap_before_forward()
return generate_body(args, model, context, context_null, device, seed)
def generate_body(
args: Union[argparse.Namespace, SimpleNamespace],
model: hunyuan_image_models.HYImageDiffusionTransformer,
context: Dict[str, Any],
context_null: Optional[Dict[str, Any]],
device: torch.device,
seed: int,
) -> torch.Tensor:
# set random generator
seed_g = torch.Generator(device="cpu")
seed_g.manual_seed(seed)
@@ -633,6 +646,10 @@ def generate(
embed_byt5 = context["embed_byt5"].to(device, dtype=torch.bfloat16)
mask_byt5 = context["mask_byt5"].to(device, dtype=torch.bfloat16)
ocr_mask = context["ocr_mask"] # list of bool
if context_null is None:
context_null = context # dummy for unconditional
negative_embed = context_null["embed"].to(device, dtype=torch.bfloat16)
negative_mask = context_null["mask"].to(device, dtype=torch.bfloat16)
negative_embed_byt5 = context_null["embed_byt5"].to(device, dtype=torch.bfloat16)

View File

@@ -0,0 +1,640 @@
import argparse
import copy
from typing import Any, Optional, Union
import argparse
import os
import time
from types import SimpleNamespace
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from accelerate import Accelerator, PartialState
from library import hunyuan_image_models, hunyuan_image_vae, strategy_base, train_util
from library.device_utils import clean_memory_on_device, init_ipex
init_ipex()
import train_network
from library import (
flux_train_utils,
hunyuan_image_models,
hunyuan_image_text_encoder,
hunyuan_image_utils,
hunyuan_image_vae,
sai_model_spec,
sd3_train_utils,
strategy_base,
strategy_hunyuan_image,
train_util,
)
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# region sampling
# TODO commonize with flux_utils
def sample_images(
    accelerator: Accelerator,
    args: argparse.Namespace,
    epoch,
    steps,
    dit,
    vae,
    text_encoders,
    sample_prompts_te_outputs,
    prompt_replacement=None,
):
    """Generate and save sample images during training, if it is time to do so.

    Sampling is skipped unless one of the following holds:
      * ``steps == 0`` and ``args.sample_at_first`` is set;
      * ``args.sample_every_n_epochs`` is set and ``epoch`` is a multiple of it;
      * otherwise ``steps`` is a multiple of ``args.sample_every_n_steps`` and
        ``epoch`` is None (i.e. this is not an end-of-epoch call).

    Args:
        accelerator: Accelerator used to unwrap models and for autocast.
        args: Training arguments (``sample_*`` options, ``output_dir``).
        epoch: Current epoch number, or None when called on a step boundary.
        steps: Current global step.
        dit: The (possibly wrapped) diffusion transformer.
        vae: VAE used to decode latents to images.
        text_encoders: Text encoders, or None when their outputs are cached.
        sample_prompts_te_outputs: Optional mapping of prompt -> cached text encoder outputs.
        prompt_replacement: Optional ``(old, new)`` pair applied to each prompt.

    The torch CPU/CUDA RNG states are saved before sampling and restored afterwards.
    """
    if steps == 0:
        if not args.sample_at_first:
            return
    else:
        if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
            return
        if args.sample_every_n_epochs is not None:
            # sample_every_n_steps is ignored
            if epoch is None or epoch % args.sample_every_n_epochs != 0:
                return
        else:
            if steps % args.sample_every_n_steps != 0 or epoch is not None:  # steps is not divisible or end of epoch
                return

    logger.info("")
    logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
    if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None:
        logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
        return

    distributed_state = PartialState()  # for multi gpu distributed inference. this is a singleton, so it's safe to use it here

    # unwrap the DiT and text encoder(s)
    # NOTE: the previous version also tried to unwrap a `controlnet` variable that does not exist in
    # this function, which raised UnboundLocalError on every sampling call. It has been removed.
    dit = accelerator.unwrap_model(dit)
    if text_encoders is not None:
        text_encoders = [(accelerator.unwrap_model(te) if te is not None else None) for te in text_encoders]

    prompts = train_util.load_prompts(args.sample_prompts)

    save_dir = args.output_dir + "/sample"
    os.makedirs(save_dir, exist_ok=True)

    # save random state to restore later
    rng_state = torch.get_rng_state()
    cuda_rng_state = None
    try:
        cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
    except Exception:
        # best effort: without a CUDA RNG snapshot we simply do not restore it afterwards
        logger.debug("could not capture CUDA RNG state; it will not be restored after sampling", exc_info=True)

    if distributed_state.num_processes <= 1:
        # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
        with torch.no_grad(), accelerator.autocast():
            for prompt_dict in prompts:
                sample_image_inference(
                    accelerator,
                    args,
                    dit,
                    text_encoders,
                    vae,
                    save_dir,
                    prompt_dict,
                    epoch,
                    steps,
                    sample_prompts_te_outputs,
                    prompt_replacement,
                )
    else:
        # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
        # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
        num_processes = distributed_state.num_processes
        per_process_prompts = [prompts[i::num_processes] for i in range(num_processes)]  # list of lists

        with torch.no_grad():
            with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
                for prompt_dict in prompt_dict_lists[0]:
                    sample_image_inference(
                        accelerator,
                        args,
                        dit,
                        text_encoders,
                        vae,
                        save_dir,
                        prompt_dict,
                        epoch,
                        steps,
                        sample_prompts_te_outputs,
                        prompt_replacement,
                    )

    torch.set_rng_state(rng_state)
    if cuda_rng_state is not None:
        torch.cuda.set_rng_state(cuda_rng_state)

    clean_memory_on_device(accelerator.device)
def sample_image_inference(
    accelerator: Accelerator,
    args: argparse.Namespace,
    dit: hunyuan_image_models.HYImageDiffusionTransformer,
    text_encoders: Optional[list[nn.Module]],
    vae: hunyuan_image_vae.HunyuanVAE2D,
    save_dir,
    prompt_dict,
    epoch,
    steps,
    sample_prompts_te_outputs,
    prompt_replacement,
):
    """Generate one sample image for ``prompt_dict`` and save it as a PNG in ``save_dir``.

    Args:
        accelerator: Accelerator providing the target device and trackers.
        args: Training arguments; ``args.output_name`` is used as filename prefix.
        dit: Unwrapped diffusion transformer used for denoising.
        text_encoders: Text encoders, or None if only cached outputs are available.
        vae: VAE used to decode the generated latents.
        save_dir: Directory the image is written to.
        prompt_dict: Prompt settings. Reads ``prompt``, ``negative_prompt``, ``sample_steps``,
            ``width``, ``height``, ``scale``, ``seed``, ``flow_shift`` and requires ``enum``.
        epoch: Current epoch or None; selects the filename suffix.
        steps: Current global step.
        sample_prompts_te_outputs: Optional mapping of prompt -> cached text encoder outputs.
        prompt_replacement: Optional ``(old, new)`` pair applied to the prompts.

    The image is also logged to wandb (without committing the step) when a wandb tracker is active.
    """
    assert isinstance(prompt_dict, dict)
    negative_prompt = prompt_dict.get("negative_prompt")
    sample_steps = prompt_dict.get("sample_steps", 20)
    width = prompt_dict.get("width", 512)
    height = prompt_dict.get("height", 512)
    cfg_scale = prompt_dict.get("scale", 1.0)
    seed = prompt_dict.get("seed")
    prompt: str = prompt_dict.get("prompt", "")
    flow_shift: float = prompt_dict.get("flow_shift", 4.0)
    # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)

    if prompt_replacement is not None:
        prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
        if negative_prompt is not None:
            negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])

    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        generation_seed = seed
    else:
        # True random sample image generation.
        # generate_body seeds a torch.Generator with the value it is given, and manual_seed(None)
        # raises TypeError, so derive a concrete seed from the freshly randomized global RNG.
        generation_seed = torch.seed() % (2**32)
        torch.cuda.seed()

    if negative_prompt is None:
        negative_prompt = ""

    height = max(64, height - height % 16)  # round to divisible by 16
    width = max(64, width - width % 16)  # round to divisible by 16

    logger.info(f"prompt: {prompt}")
    if cfg_scale != 1.0:
        logger.info(f"negative_prompt: {negative_prompt}")
    elif negative_prompt != "":
        logger.info("negative prompt is ignored because scale is 1.0")
    logger.info(f"height: {height}")
    logger.info(f"width: {width}")
    logger.info(f"sample_steps: {sample_steps}")
    if cfg_scale != 1.0:
        logger.info(f"CFG scale: {cfg_scale}")
    logger.info(f"flow_shift: {flow_shift}")
    # logger.info(f"sample_sampler: {sampler_name}")
    if seed is not None:
        logger.info(f"seed: {seed}")

    # encode prompts
    tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
    encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()

    def encode_prompt(prpt):
        # Start from the cached outputs (if any), then overwrite with freshly encoded
        # outputs for every entry the text encoders actually produced.
        text_encoder_conds = []
        if sample_prompts_te_outputs and prpt in sample_prompts_te_outputs:
            text_encoder_conds = sample_prompts_te_outputs[prpt]
            print(f"Using cached text encoder outputs for prompt: {prpt}")
        if text_encoders is not None:
            print(f"Encoding prompt: {prpt}")
            tokens_and_masks = tokenize_strategy.tokenize(prpt)
            encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)

            # if text_encoder_conds is not cached, use encoded_text_encoder_conds
            if len(text_encoder_conds) == 0:
                text_encoder_conds = encoded_text_encoder_conds
            else:
                # if encoded_text_encoder_conds is not None, update cached text_encoder_conds
                for i, encoded in enumerate(encoded_text_encoder_conds):
                    if encoded is not None:
                        text_encoder_conds[i] = encoded
        return text_encoder_conds

    vl_embed, vl_mask, byt5_embed, byt5_mask, ocr_mask = encode_prompt(prompt)
    arg_c = {
        "embed": vl_embed,
        "mask": vl_mask,
        "embed_byt5": byt5_embed,
        "mask_byt5": byt5_mask,
        "ocr_mask": ocr_mask,
        "prompt": prompt,
    }

    # encode negative prompts (only needed when classifier-free guidance is active)
    if cfg_scale != 1.0:
        neg_vl_embed, neg_vl_mask, neg_byt5_embed, neg_byt5_mask, neg_ocr_mask = encode_prompt(negative_prompt)
        arg_c_null = {
            "embed": neg_vl_embed,
            "mask": neg_vl_mask,
            "embed_byt5": neg_byt5_embed,
            "mask_byt5": neg_byt5_mask,
            "ocr_mask": neg_ocr_mask,
            "prompt": negative_prompt,
        }
    else:
        arg_c_null = None

    gen_args = SimpleNamespace(
        image_size=(height, width), infer_steps=sample_steps, flow_shift=flow_shift, guidance_scale=cfg_scale
    )

    from hunyuan_image_minimal_inference import generate_body  # import here to avoid circular import

    latents = generate_body(gen_args, dit, arg_c, arg_c_null, accelerator.device, generation_seed)

    # latent to image
    clean_memory_on_device(accelerator.device)
    org_vae_device = vae.device  # will be on cpu
    vae.to(accelerator.device)  # distributed_state.device is same as accelerator.device
    with torch.autocast(accelerator.device.type, vae.dtype, enabled=True), torch.no_grad():
        # Undo the scaling applied when encoding (see shift_scale_latents) before decoding.
        # The previous version divided an unassigned `x` here instead of the generated latents.
        x = latents.to(accelerator.device) / hunyuan_image_vae.VAE_SCALE_FACTOR
        x = vae.decode(x)
    vae.to(org_vae_device)
    clean_memory_on_device(accelerator.device)

    x = x.clamp(-1, 1)
    x = x.permute(0, 2, 3, 1)
    image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])

    # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
    # but adding 'enum' to the filename should be enough
    ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
    num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
    seed_suffix = "" if seed is None else f"_{seed}"
    i: int = prompt_dict["enum"]
    img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
    image.save(os.path.join(save_dir, img_filename))

    # send images to wandb if enabled
    if "wandb" in [tracker.name for tracker in accelerator.trackers]:
        wandb_tracker = accelerator.get_tracker("wandb")

        import wandb

        # not to commit images to avoid inconsistency between training and logging steps
        wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False)  # positive prompt as a caption
# endregion
class HunyuanImageNetworkTrainer(train_network.NetworkTrainer):
    """Network (LoRA) trainer specialization for HunyuanImage-2.1.

    Overrides the hooks of ``train_network.NetworkTrainer`` for model loading, text encoder
    output caching, sampling, noise prediction and metadata.
    """

    def __init__(self):
        super().__init__()
        # prompt -> cached text encoder outputs for the sample prompts (filled when caching is enabled)
        self.sample_prompts_te_outputs = None
        # True when DiT blocks are swapped between CPU and the accelerator device
        self.is_swapping_blocks: bool = False

    def assert_extra_args(
        self,
        args,
        train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
        val_dataset_group: Optional[train_util.DatasetGroup],
    ):
        """Validate and normalize HunyuanImage-specific arguments; may set ``args.cache_text_encoder_outputs``."""
        super().assert_extra_args(args, train_dataset_group, val_dataset_group)
        # sdxl_train_util.verify_sdxl_training_args(args)

        if args.mixed_precision == "fp16":
            logger.warning(
                "mixed_precision bf16 is recommended for HunyuanImage-2.1 / HunyuanImage-2.1ではmixed_precision bf16が推奨されます"
            )

        if (args.fp8_base or args.fp8_base_unet) and not args.fp8_scaled:
            logger.warning(
                "fp8_base and fp8_base_unet are not supported. Use fp8_scaled instead / fp8_baseとfp8_base_unetはサポートされていません。代わりにfp8_scaledを使用してください"
            )
        if args.fp8_scaled and (args.fp8_base or args.fp8_base_unet):
            logger.info(
                "fp8_scaled is used, so fp8_base and fp8_base_unet are ignored / fp8_scaledが使われているので、fp8_baseとfp8_base_unetは無視されます"
            )

        if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
            logger.warning(
                "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります"
            )
            args.cache_text_encoder_outputs = True

        if args.cache_text_encoder_outputs:
            assert (
                train_dataset_group.is_text_encoder_output_cacheable()
            ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"

        train_dataset_group.verify_bucket_reso_steps(32)
        if val_dataset_group is not None:
            val_dataset_group.verify_bucket_reso_steps(32)

    def load_target_model(self, args, weight_dtype, accelerator):
        """Load the DiT, both text encoders and the VAE.

        Returns:
            ``(model_version, [text_encoder_vlm, text_encoder_byt5], vae, model)``.
            Text encoders and VAE are loaded on CPU; the DiT is loaded on CPU only when block swap is enabled.
        """
        self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0

        # currently offload to cpu for some models
        loading_dtype = None if args.fp8_scaled else weight_dtype
        loading_device = "cpu" if self.is_swapping_blocks else accelerator.device
        split_attn = True

        attn_mode = "torch"

        model = hunyuan_image_models.load_hunyuan_image_model(
            accelerator.device,
            args.pretrained_model_name_or_path,
            attn_mode,
            split_attn,
            loading_device,
            loading_dtype,
            args.fp8_scaled,
        )

        if self.is_swapping_blocks:
            # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
            logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
            model.enable_block_swap(args.blocks_to_swap, accelerator.device)

        vl_dtype = torch.bfloat16
        vl_device = "cpu"
        _, text_encoder_vlm = hunyuan_image_text_encoder.load_qwen2_5_vl(
            args.text_encoder, dtype=vl_dtype, device=vl_device, disable_mmap=args.disable_mmap_load_safetensors
        )
        _, text_encoder_byt5 = hunyuan_image_text_encoder.load_byt5(
            args.byt5, dtype=torch.float16, device=vl_device, disable_mmap=args.disable_mmap_load_safetensors
        )

        vae = hunyuan_image_vae.load_vae(args.vae, "cpu", disable_mmap=args.disable_mmap_load_safetensors)

        model_version = hunyuan_image_utils.MODEL_VERSION_2_1
        return model_version, [text_encoder_vlm, text_encoder_byt5], vae, model

    def get_tokenize_strategy(self, args):
        """Return the HunyuanImage tokenize strategy using ``args.tokenizer_cache_dir``."""
        return strategy_hunyuan_image.HunyuanImageTokenizeStrategy(args.tokenizer_cache_dir)

    def get_tokenizers(self, tokenize_strategy: strategy_hunyuan_image.HunyuanImageTokenizeStrategy):
        """Return ``[vlm_tokenizer, byt5_tokenizer]`` from the tokenize strategy."""
        return [tokenize_strategy.vlm_tokenizer, tokenize_strategy.byt5_tokenizer]

    def get_latents_caching_strategy(self, args):
        """Return the latents caching strategy configured from ``args``."""
        return strategy_hunyuan_image.HunyuanImageLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False)

    def get_text_encoding_strategy(self, args):
        """Return the HunyuanImage text encoding strategy."""
        return strategy_hunyuan_image.HunyuanImageTextEncodingStrategy()

    def post_process_network(self, args, accelerator, network, text_encoders, unet):
        """No post-processing is required for HunyuanImage networks."""
        pass

    def get_models_for_text_encoding(self, args, accelerator, text_encoders):
        """Return the text encoders needed at encode time, or None when their outputs are cached."""
        if args.cache_text_encoder_outputs:
            return None  # no text encoders are needed for encoding because both are cached
        else:
            return text_encoders

    def get_text_encoders_train_flags(self, args, text_encoders):
        """Return per-encoder training flags; always ``[False, False]``."""
        # HunyuanImage-2.1 does not support training VLM or byT5
        return [False, False]

    def get_text_encoder_outputs_caching_strategy(self, args):
        """Return the text encoder outputs caching strategy, or None when caching is disabled."""
        if args.cache_text_encoder_outputs:
            # The text encoders are never trained here (see get_text_encoders_train_flags), so the
            # last argument is False. Presumably this is the is_partial flag -- TODO confirm.
            return strategy_hunyuan_image.HunyuanImageTextEncoderOutputsCachingStrategy(
                args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False
            )
        else:
            return None

    def cache_text_encoder_outputs_if_needed(
        self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
    ):
        """Cache text encoder outputs for the dataset and sample prompts, or keep the encoders on device.

        When caching is enabled, the text encoders are moved to the accelerator device, the dataset
        and (optionally) the sample prompts are encoded, and the encoders are moved back to CPU.
        Unless ``args.lowram`` is set, the VAE and the DiT are parked on CPU during caching and
        restored afterwards. When caching is disabled, the text encoders are moved to the device.
        """
        if args.cache_text_encoder_outputs:
            if not args.lowram:
                # reduce memory consumption
                logger.info("move vae and unet to cpu to save memory")
                org_vae_device = vae.device
                org_unet_device = unet.device
                vae.to("cpu")
                unet.to("cpu")
                clean_memory_on_device(accelerator.device)

            logger.info("move text encoders to gpu")
            text_encoders[0].to(accelerator.device)
            text_encoders[1].to(accelerator.device)

            # VLM (bf16) and byT5 (fp16) are used for encoding, so we cannot use autocast here
            dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)

            # cache sample prompts
            if args.sample_prompts is not None:
                logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}")

                tokenize_strategy: strategy_hunyuan_image.HunyuanImageTokenizeStrategy = (
                    strategy_base.TokenizeStrategy.get_strategy()
                )
                text_encoding_strategy: strategy_hunyuan_image.HunyuanImageTextEncodingStrategy = (
                    strategy_base.TextEncodingStrategy.get_strategy()
                )

                prompts = train_util.load_prompts(args.sample_prompts)
                sample_prompts_te_outputs = {}  # key: prompt, value: text encoder outputs
                # NOTE(review): autocast is used here although the dataset caching above avoids it -- confirm intended
                with accelerator.autocast(), torch.no_grad():
                    for prompt_dict in prompts:
                        for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
                            if p not in sample_prompts_te_outputs:
                                logger.info(f"cache Text Encoder outputs for prompt: {p}")
                                tokens_and_masks = tokenize_strategy.tokenize(p)
                                sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
                                    tokenize_strategy, text_encoders, tokens_and_masks
                                )
                self.sample_prompts_te_outputs = sample_prompts_te_outputs

            accelerator.wait_for_everyone()

            # move back to cpu
            logger.info("move VLM back to cpu")
            text_encoders[0].to("cpu")
            logger.info("move byT5 back to cpu")
            text_encoders[1].to("cpu")
            clean_memory_on_device(accelerator.device)

            if not args.lowram:
                logger.info("move vae and unet back to original device")
                vae.to(org_vae_device)
                unet.to(org_unet_device)
        else:
            # outputs are fetched from the Text Encoders on every step, so keep them on the GPU
            text_encoders[0].to(accelerator.device)
            text_encoders[1].to(accelerator.device)

    def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux):
        """Generate sample images with the HunyuanImage sampler defined in this module.

        ``flux`` is the DiT and ``ae`` the VAE; the parameter names are kept for compatibility with
        the base trainer's hook signature.
        """
        text_encoders = text_encoder  # for compatibility
        text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders)

        # Use this module's HunyuanImage sampler rather than the FLUX one: the FLUX sampler does not
        # go through the HunyuanImage text conditioning / generate_body path implemented above.
        sample_images(accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs)

    def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
        """Create the flow-matching Euler scheduler and keep a deep copy in ``self.noise_scheduler_copy``."""
        noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
        self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
        return noise_scheduler

    def encode_images_to_latents(self, args, vae, images):
        """Encode images with the VAE and return the result of ``vae.encode``."""
        return vae.encode(images)

    def shift_scale_latents(self, args, latents):
        """Scale encoded latents by ``VAE_SCALE_FACTOR`` for training."""
        # for encoding, we need to scale the latents
        return latents * hunyuan_image_vae.VAE_SCALE_FACTOR

    def get_noise_pred_and_target(
        self,
        args,
        accelerator,
        noise_scheduler,
        latents,
        batch,
        text_encoder_conds,
        unet: hunyuan_image_models.HYImageDiffusionTransformer,
        network,
        weight_dtype,
        train_unet,
        is_train=True,
    ):
        """Run the DiT on noised latents and return the flow-matching prediction and target.

        Returns:
            ``(model_pred, target, timesteps, None)`` where ``target = noise - latents``.
            The last element (loss weighting) is always None.
        """
        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)

        # get noisy model input and timesteps
        noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
            args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
        )

        if args.gradient_checkpointing:
            noisy_model_input.requires_grad_(True)
            for t in text_encoder_conds:
                if t is not None and t.dtype.is_floating_point:
                    t.requires_grad_(True)

        # Predict the noise residual
        # ocr_mask is for inference only, so it is not used here
        vlm_embed, vlm_mask, byt5_embed, byt5_mask, ocr_mask = text_encoder_conds

        with torch.set_grad_enabled(is_train), accelerator.autocast():
            # the model receives timesteps divided by 1000
            model_pred = unet(noisy_model_input, timesteps / 1000, vlm_embed, vlm_mask, byt5_embed, byt5_mask)

        # model prediction and weighting is omitted for HunyuanImage-2.1 currently

        # flow matching loss
        target = noise - latents

        # differential output preservation is not used for HunyuanImage-2.1 currently

        return model_pred, target, timesteps, None

    def post_process_loss(self, loss, args, timesteps, noise_scheduler):
        """Return the loss unchanged."""
        return loss

    def get_sai_model_spec(self, args):
        """Build and return the SAI model spec for HunyuanImage 2.1."""
        # The spec was previously built but never returned, so callers always received None.
        # NOTE(review): the helper name suggests a dataclass is returned -- confirm the caller does
        # not expect a plain metadata dict.
        return train_util.get_sai_model_spec_dataclass(None, args, False, True, False, hunyuan_image="2.1")

    def update_metadata(self, metadata, args):
        """Add HunyuanImage-specific training options to ``metadata`` in place."""
        metadata["ss_model_type"] = args.model_type
        metadata["ss_logit_mean"] = args.logit_mean
        metadata["ss_logit_std"] = args.logit_std
        metadata["ss_mode_scale"] = args.mode_scale
        metadata["ss_timestep_sampling"] = args.timestep_sampling
        metadata["ss_sigmoid_scale"] = args.sigmoid_scale
        metadata["ss_model_prediction_type"] = args.model_prediction_type
        metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift

    def is_text_encoder_not_needed_for_training(self, args):
        """Return True when text encoder outputs are cached and the text encoders are not trained."""
        return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)

    def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
        """No-op: text encoder training is not supported for HunyuanImage-2.1."""
        pass

    def cast_text_encoder(self):
        """Return False so the text encoders keep their own dtypes."""
        return False  # VLM is bf16, byT5 is fp16, so do not cast to other dtype

    def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
        """No-op: fp8 text encoders are not supported for HunyuanImage-2.1 currently."""
        pass

    def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
        """Re-arm block swapping after a validation step."""
        if self.is_swapping_blocks:
            # prepare for next forward: because backward pass is not called, we need to prepare it here
            accelerator.unwrap_model(unet).prepare_block_swap_before_forward()

    def prepare_unet_with_accelerator(
        self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
    ) -> torch.nn.Module:
        """Prepare the DiT with the accelerator, keeping swapped blocks off the device."""
        if not self.is_swapping_blocks:
            # without block swap, the default preparation (including device placement) is fine
            return super().prepare_unet_with_accelerator(args, accelerator, unet)

        # with block swap, the accelerator must not move the whole model to the device
        model: hunyuan_image_models.HYImageDiffusionTransformer = unet
        model = accelerator.prepare(model, device_placement=[not self.is_swapping_blocks])
        accelerator.unwrap_model(model).move_to_device_except_swap_blocks(accelerator.device)  # reduce peak memory usage
        accelerator.unwrap_model(model).prepare_block_swap_before_forward()

        return model
def setup_parser() -> argparse.ArgumentParser:
    """Build the command line parser for HunyuanImage network training.

    Starts from the generic network-training parser, adds the shared DiT training
    arguments, then registers the flow-matching specific options below.
    """
    parser = train_network.setup_parser()
    train_util.add_dit_training_arguments(parser)

    # (flag, keyword arguments) pairs, registered in this order
    flow_matching_options = (
        (
            "--timestep_sampling",
            dict(
                choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"],
                default="sigma",
                help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and FLUX.1 shifting."
                " / タイムステップをサンプリングする方法sigma、random uniform、random normalのsigmoid、sigmoidのシフト、FLUX.1のシフト。",
            ),
        ),
        (
            "--sigmoid_scale",
            dict(
                type=float,
                default=1.0,
                help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率timestep-samplingが"sigmoid"の場合のみ有効)。',
            ),
        ),
        (
            "--model_prediction_type",
            dict(
                choices=["raw", "additive", "sigma_scaled"],
                default="sigma_scaled",
                help="How to interpret and process the model prediction: "
                "raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)."
                " / モデル予測の解釈と処理方法:"
                "rawそのまま使用、additiveイズ入力に加算、sigma_scaledシグマスケーリングを適用",
            ),
        ),
        (
            "--discrete_flow_shift",
            dict(
                type=float,
                default=3.0,
                help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
            ),
        ),
    )
    for flag, options in flow_matching_options:
        parser.add_argument(flag, **options)

    return parser
if __name__ == "__main__":
    # Script entry point: parse the command line, validate it, merge in the optional
    # config file, then hand the final arguments to the trainer.
    parser = setup_parser()

    args = parser.parse_args()
    train_util.verify_command_line_training_args(args)
    args = train_util.read_config_from_file(args, parser)

    HunyuanImageNetworkTrainer().train(args)

View File

@@ -1,19 +1,12 @@
from concurrent.futures import ThreadPoolExecutor
import time
from typing import Optional, Union, Callable, Tuple
from typing import Any, Optional, Union, Callable, Tuple
import torch
import torch.nn as nn
from library.device_utils import clean_memory_on_device
from library.device_utils import clean_memory_on_device, synchronize_device
def synchronize_device(device: torch.device):
if device.type == "cuda":
torch.cuda.synchronize()
elif device.type == "xpu":
torch.xpu.synchronize()
elif device.type == "mps":
torch.mps.synchronize()
# region block swap utils
def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
@@ -71,7 +64,6 @@ def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, l
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
# device to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
@@ -97,7 +89,8 @@ class Offloader:
common offloading class
"""
def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
def __init__(self, block_type: str, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
self.block_type = block_type
self.num_blocks = num_blocks
self.blocks_to_swap = blocks_to_swap
self.device = device
@@ -117,12 +110,16 @@ class Offloader:
def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
if self.debug:
start_time = time.perf_counter()
print(f"Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to {'CUDA' if self.cuda_available else 'device'}")
print(
f"[{self.block_type}] Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to {'CUDA' if self.cuda_available else 'device'}"
)
self.swap_weight_devices(block_to_cpu, block_to_cuda)
if self.debug:
print(f"Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s")
print(
f"[{self.block_type}] Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter() - start_time:.2f}s"
)
return bidx_to_cpu, bidx_to_cuda # , event
block_to_cpu = blocks[block_idx_to_cpu]
@@ -137,7 +134,7 @@ class Offloader:
return
if self.debug:
print(f"Wait for block {block_idx}")
print(f"[{self.block_type}] Wait for block {block_idx}")
start_time = time.perf_counter()
future = self.futures.pop(block_idx)
@@ -146,33 +143,41 @@ class Offloader:
assert block_idx == bidx_to_cuda, f"Block index mismatch: {block_idx} != {bidx_to_cuda}"
if self.debug:
print(f"Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")
print(f"[{self.block_type}] Waited for block {block_idx}: {time.perf_counter() - start_time:.2f}s")
# Gradient tensors
_grad_t = Union[tuple[torch.Tensor, ...], torch.Tensor]
class ModelOffloader(Offloader):
"""
supports forward offloading
"""
def __init__(self, blocks: Union[list[nn.Module], nn.ModuleList], blocks_to_swap: int, device: torch.device, debug: bool = False):
super().__init__(len(blocks), blocks_to_swap, device, debug)
def __init__(
self, blocks: list[nn.Module], blocks_to_swap: int, supports_backward: bool, device: torch.device, debug: bool = False
):
block_type = f"{blocks[0].__class__.__name__}" if len(blocks) > 0 else "Unknown"
super().__init__(block_type, len(blocks), blocks_to_swap, device, debug)
# register backward hooks
self.remove_handles = []
for i, block in enumerate(blocks):
hook = self.create_backward_hook(blocks, i)
if hook is not None:
handle = block.register_full_backward_hook(hook)
self.remove_handles.append(handle)
self.supports_backward = supports_backward
self.forward_only = not supports_backward # forward only offloading: can be changed to True for inference
if self.supports_backward:
# register backward hooks
self.remove_handles = []
for i, block in enumerate(blocks):
hook = self.create_backward_hook(blocks, i)
if hook is not None:
handle = block.register_full_backward_hook(hook)
self.remove_handles.append(handle)
def set_forward_only(self, forward_only: bool):
self.forward_only = forward_only
def __del__(self):
for handle in self.remove_handles:
handle.remove()
if self.supports_backward:
for handle in self.remove_handles:
handle.remove()
def create_backward_hook(self, blocks: Union[list[nn.Module], nn.ModuleList], block_index: int) -> Optional[Callable[[nn.Module, _grad_t, _grad_t], Union[None, _grad_t]]]:
def create_backward_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]:
# -1 for 0-based index
num_blocks_propagated = self.num_blocks - block_index - 1
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap
@@ -186,7 +191,7 @@ class ModelOffloader(Offloader):
block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated
block_idx_to_wait = block_index - 1
def backward_hook(module: nn.Module, grad_input: _grad_t, grad_output: _grad_t):
def backward_hook(module, grad_input, grad_output):
if self.debug:
print(f"Backward hook for block {block_index}")
@@ -198,20 +203,20 @@ class ModelOffloader(Offloader):
return backward_hook
def prepare_block_devices_before_forward(self, blocks: Union[list[nn.Module], nn.ModuleList]):
def prepare_block_devices_before_forward(self, blocks: list[nn.Module]):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return
if self.debug:
print("Prepare block devices before forward")
print(f"[{self.block_type}] Prepare block devices before forward")
for b in blocks[0 : self.num_blocks - self.blocks_to_swap]:
b.to(self.device)
weighs_to_device(b, self.device) # make sure weights are on device
for b in blocks[self.num_blocks - self.blocks_to_swap :]:
b.to(self.device) # move block to device first
weighs_to_device(b, torch.device("cpu")) # make sure weights are on cpu
b.to(self.device) # move block to device first. this makes sure that buffers (non weights) are on the device
weighs_to_device(b, "cpu") # make sure weights are on cpu
synchronize_device(self.device)
clean_memory_on_device(self.device)
@@ -221,11 +226,85 @@ class ModelOffloader(Offloader):
return
self._wait_blocks_move(block_idx)
def submit_move_blocks(self, blocks: Union[list[nn.Module], nn.ModuleList], block_idx: int):
def submit_move_blocks(self, blocks: list[nn.Module], block_idx: int):
# check if blocks_to_swap is enabled
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return
if block_idx >= self.blocks_to_swap:
# if backward is enabled, we do not swap blocks in forward pass more than blocks_to_swap, because it should be on GPU
if not self.forward_only and block_idx >= self.blocks_to_swap:
return
block_idx_to_cpu = block_idx
block_idx_to_cuda = self.num_blocks - self.blocks_to_swap + block_idx
block_idx_to_cuda = block_idx_to_cuda % self.num_blocks # this works for forward-only offloading
self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
# endregion
# region cpu offload utils
def to_device(x: Any, device: torch.device) -> Any:
    """
    Recursively moves torch.Tensor objects (and containers thereof) to the given device.

    Args:
        x: A torch.Tensor, or a (possibly nested) list, tuple, or dict containing tensors.
        device: The device tensors are moved to.

    Returns:
        The same structure as x, with all torch.Tensor objects moved to `device`.
        Non-tensor objects are returned unchanged.
    """
    if isinstance(x, torch.Tensor):
        return x.to(device)
    if isinstance(x, dict):
        return {key: to_device(value, device) for key, value in x.items()}
    if isinstance(x, list):
        return [to_device(item, device) for item in x]
    if isinstance(x, tuple):
        return tuple(to_device(item, device) for item in x)
    # anything else (ints, strings, None, ...) passes through untouched
    return x
def to_cpu(x: Any) -> Any:
    """
    Return a copy of `x` in which every torch.Tensor lives on the CPU.

    Args:
        x: A torch.Tensor, or a (possibly nested) list, tuple, or dict containing tensors.

    Returns:
        A structure shaped like `x`; lists, tuples and dicts are rebuilt with their
        elements converted, tensors are moved with `.cpu()`, and every other object
        is handed back as-is.
    """
    if isinstance(x, torch.Tensor):
        return x.cpu()
    if isinstance(x, dict):
        return {key: to_cpu(value) for key, value in x.items()}
    if isinstance(x, list):
        return [to_cpu(item) for item in x]
    if isinstance(x, tuple):
        return tuple(to_cpu(item) for item in x)
    return x
def create_cpu_offloading_wrapper(func: Callable, device: torch.device) -> Callable:
    """
    Create a wrapper that runs `func` on `device` and returns its outputs on the CPU.

    The wrapper moves its positional inputs to `device` (recursing into lists,
    tuples and dicts via `to_device`), calls `func` with them, and then moves the
    outputs to the CPU via `to_cpu`.

    Args:
        func: The original function to wrap.
        device: The device the inputs are moved to before `func` is called.

    Returns:
        A function accepting the same positional arguments as `func` whose
        tensor outputs are on the CPU.
    """

    def custom_forward(*inputs):
        # `func` and `device` are only read here, so plain closure capture suffices.
        device_inputs = to_device(inputs, device)
        outputs = func(*device_inputs)
        return to_cpu(outputs)

    return custom_forward
# endregion

View File

@@ -2,6 +2,7 @@ import functools
import gc
import torch
try:
# intel gpu support for pytorch older than 2.5
# ipex is not needed after pytorch 2.5
@@ -51,6 +52,15 @@ def clean_memory_on_device(device: torch.device):
torch.mps.empty_cache()
def synchronize_device(device: torch.device):
    """
    Wait for the work queued on `device` to finish.

    CUDA and XPU are synchronized for the device that was passed in, not for
    whichever device happens to be current, so the call is also correct when
    `device` is not the current device. MPS exposes a single device and takes
    no argument. Any other device type (e.g. CPU) is a no-op.

    Args:
        device: The device to synchronize.
    """
    if device.type == "cuda":
        torch.cuda.synchronize(device)
    elif device.type == "xpu":
        torch.xpu.synchronize(device)
    elif device.type == "mps":
        torch.mps.synchronize()
@functools.lru_cache(maxsize=None)
def get_preferred_device() -> torch.device:
r"""

View File

@@ -7,6 +7,7 @@ import torch
import torch.nn as nn
from accelerate import init_empty_weights
from library import custom_offloading_utils
from library.fp8_optimization_utils import apply_fp8_monkey_patch
from library.lora_utils import load_safetensors_with_lora_and_fp8
from library.utils import setup_logging
@@ -132,6 +133,74 @@ class HYImageDiffusionTransformer(nn.Module):
self.final_layer = FinalLayer(self.hidden_size, self.patch_size, self.out_channels, nn.SiLU)
self.gradient_checkpointing = False
self.cpu_offload_checkpointing = False
self.blocks_to_swap = None
self.offloader_double = None
self.offloader_single = None
self.num_double_blocks = len(self.double_blocks)
self.num_single_blocks = len(self.single_blocks)
def enable_gradient_checkpointing(self, cpu_offload: bool = False):
    """
    Switch on gradient checkpointing for the model and for every transformer block.

    Args:
        cpu_offload: Forwarded to each block's `enable_gradient_checkpointing` and
            stored as `cpu_offload_checkpointing`.
    """
    self.gradient_checkpointing, self.cpu_offload_checkpointing = True, cpu_offload

    # double-stream blocks first, then single-stream blocks
    for block_group in (self.double_blocks, self.single_blocks):
        for blk in block_group:
            blk.enable_gradient_checkpointing(cpu_offload=cpu_offload)

    print(f"HunyuanImage-2.1: Gradient checkpointing enabled. CPU offload: {cpu_offload}")
def disable_gradient_checkpointing(self):
    """Switch off gradient checkpointing (and its CPU-offload variant) for the model and every block."""
    self.gradient_checkpointing = self.cpu_offload_checkpointing = False

    # double-stream blocks first, then single-stream blocks
    for block_group in (self.double_blocks, self.single_blocks):
        for blk in block_group:
            blk.disable_gradient_checkpointing()

    print("HunyuanImage-2.1: Gradient checkpointing disabled.")
def enable_block_swap(self, num_blocks: int, device: torch.device, supports_backward: bool = False):
    """
    Set up block swapping for the double-stream and single-stream blocks.

    `num_blocks` is split in two: half of it (rounded down) is the number of
    double blocks to swap, and twice the remainder is the number of single
    blocks to swap. At least two blocks of each kind must stay unswapped.

    Args:
        num_blocks: Total swap budget; stored as `self.blocks_to_swap`.
        device: Passed through to both `ModelOffloader` instances.
        supports_backward: Passed through to both `ModelOffloader` instances.
    """
    self.blocks_to_swap = num_blocks

    n_double = num_blocks // 2
    n_single = (num_blocks - n_double) * 2
    max_double = self.num_double_blocks - 2
    max_single = self.num_single_blocks - 2

    assert n_double <= max_double and n_single <= max_single, (
        f"Cannot swap more than {max_double} double blocks and {max_single} single blocks. "
        f"Requested {n_double} double blocks and {n_single} single blocks."
    )

    # NOTE(review): positional order (blocks, count, supports_backward, device) is kept exactly
    # as written; ModelOffloader.__init__ is not visible here - confirm it matches.
    make_offloader = custom_offloading_utils.ModelOffloader
    self.offloader_double = make_offloader(self.double_blocks, n_double, supports_backward, device)
    self.offloader_single = make_offloader(self.single_blocks, n_single, supports_backward, device)
    # , debug=True

    print(
        f"HunyuanImage-2.1: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {n_double}, single blocks: {n_single}."
    )
def move_to_device_except_swap_blocks(self, device: torch.device):
    """
    Move the model to `device`, leaving the transformer blocks where they are when block swap is on.

    The model is assumed to be on the CPU. With block swap enabled, the block lists
    are temporarily replaced by empty `nn.ModuleList`s so that `self.to(device)`
    does not touch them (this reduces temporary memory usage), then restored.
    """
    if not self.blocks_to_swap:
        # no block swap: move everything
        self.to(device)
        return

    held_double, held_single = self.double_blocks, self.single_blocks
    self.double_blocks = nn.ModuleList()
    self.single_blocks = nn.ModuleList()

    self.to(device)

    self.double_blocks = held_double
    self.single_blocks = held_single
def prepare_block_swap_before_forward(self):
    """Ask both offloaders to place their blocks before a forward pass; no-op when block swap is off."""
    swap_count = self.blocks_to_swap
    if swap_count is None or swap_count == 0:
        return

    # double-stream blocks are prepared first, then single-stream blocks
    pairs = (
        (self.offloader_double, self.double_blocks),
        (self.offloader_single, self.single_blocks),
    )
    for offloader, blocks in pairs:
        offloader.prepare_block_devices_before_forward(blocks)
def get_rotary_pos_embed(self, rope_sizes):
"""
Generate 2D rotary position embeddings for image tokens.
@@ -255,16 +324,29 @@ class HYImageDiffusionTransformer(nn.Module):
txt = txt[:, :max_txt_len, :]
txt_seq_len = txt.shape[1]
input_device = img.device
# Process through double-stream blocks (separate image/text attention)
for index, block in enumerate(self.double_blocks):
if self.blocks_to_swap:
self.offloader_double.wait_for_block(index)
img, txt = block(img, txt, vec, freqs_cis, seq_lens)
if self.blocks_to_swap:
self.offloader_double.submit_move_blocks(self.double_blocks, index)
# Concatenate image and text tokens for joint processing
x = torch.cat((img, txt), 1)
# Process through single-stream blocks (joint attention)
for index, block in enumerate(self.single_blocks):
if self.blocks_to_swap:
self.offloader_single.wait_for_block(index)
x = block(x, vec, txt_seq_len, freqs_cis, seq_lens)
if self.blocks_to_swap:
self.offloader_single.submit_move_blocks(self.single_blocks, index)
x = x.to(input_device)
vec = vec.to(input_device)
img = x[:, :img_seq_len, ...]

View File

@@ -6,6 +6,7 @@ import torch
import torch.nn as nn
from einops import rearrange
from library import custom_offloading_utils
from library.attention import attention
from library.hunyuan_image_utils import timestep_embedding, apply_rotary_emb, _to_tuple, apply_gate, modulate
from library.attention import attention
@@ -608,7 +609,18 @@ class MMDoubleStreamBlock(nn.Module):
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.txt_mlp = MLP(hidden_size, mlp_hidden_dim, act_layer=lambda: nn.GELU(approximate="tanh"), bias=True)
def forward(
self.gradient_checkpointing = False
self.cpu_offload_checkpointing = False
def enable_gradient_checkpointing(self, cpu_offload: bool = False):
    """Turn on gradient checkpointing for this block; `cpu_offload` is stored as `cpu_offload_checkpointing`."""
    self.gradient_checkpointing, self.cpu_offload_checkpointing = True, cpu_offload
def disable_gradient_checkpointing(self):
    """Turn off gradient checkpointing and its CPU-offload variant for this block."""
    self.gradient_checkpointing = self.cpu_offload_checkpointing = False
def _forward(
self, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, freqs_cis: tuple = None, seq_lens: list[int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
# Extract modulation parameters for image and text streams
@@ -688,6 +700,18 @@ class MMDoubleStreamBlock(nn.Module):
return img, txt
def forward(
    self, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, freqs_cis: tuple = None, seq_lens: list[int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Run the double-stream block, through gradient checkpointing when it is enabled and the module is training.

    With `cpu_offload_checkpointing` set, `_forward` is additionally wrapped so that
    it runs on the device of `img_attn_qkv.weight` via the CPU-offloading wrapper.

    Returns:
        Whatever `_forward` returns (the updated image and text streams).
    """
    if not (self.gradient_checkpointing and self.training):
        return self._forward(img, txt, vec, freqs_cis, seq_lens)

    forward_fn = self._forward
    if self.cpu_offload_checkpointing:
        # The helper is named `create_cpu_offloading_wrapper` (the name the single-stream
        # block uses as well); the previous `cpu_offload_wrapper` spelling does not match it.
        forward_fn = custom_offloading_utils.create_cpu_offloading_wrapper(forward_fn, self.img_attn_qkv.weight.device)

    return torch.utils.checkpoint.checkpoint(forward_fn, img, txt, vec, freqs_cis, seq_lens, use_reentrant=False)
class MMSingleStreamBlock(nn.Module):
"""
@@ -748,7 +772,18 @@ class MMSingleStreamBlock(nn.Module):
self.mlp_act = nn.GELU(approximate="tanh")
self.modulation = ModulateDiT(hidden_size, factor=3, act_layer=nn.SiLU)
def forward(
self.gradient_checkpointing = False
self.cpu_offload_checkpointing = False
def enable_gradient_checkpointing(self, cpu_offload: bool = False):
    """Turn on gradient checkpointing for this block; `cpu_offload` is stored as `cpu_offload_checkpointing`."""
    self.gradient_checkpointing, self.cpu_offload_checkpointing = True, cpu_offload
def disable_gradient_checkpointing(self):
    """Turn off gradient checkpointing and its CPU-offload variant for this block."""
    self.gradient_checkpointing = self.cpu_offload_checkpointing = False
def _forward(
self,
x: torch.Tensor,
vec: torch.Tensor,
@@ -800,5 +835,22 @@ class MMSingleStreamBlock(nn.Module):
return x + apply_gate(output, gate=mod_gate)
def forward(
    self,
    x: torch.Tensor,
    vec: torch.Tensor,
    txt_len: int,
    freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
    seq_lens: list[int] = None,
) -> torch.Tensor:
    """
    Run the single-stream block, through gradient checkpointing when it is enabled and the module is training.

    With `cpu_offload_checkpointing` set, `_forward` is additionally wrapped by the
    CPU-offloading wrapper, targeting the device of `linear1.weight`.
    """
    use_checkpointing = self.gradient_checkpointing and self.training
    if not use_checkpointing:
        return self._forward(x, vec, txt_len, freqs_cis, seq_lens)

    checkpointed_fn = self._forward
    if self.cpu_offload_checkpointing:
        compute_device = self.linear1.weight.device
        checkpointed_fn = custom_offloading_utils.create_cpu_offloading_wrapper(checkpointed_fn, compute_device)

    return torch.utils.checkpoint.checkpoint(checkpointed_fn, x, vec, txt_len, freqs_cis, seq_lens, use_reentrant=False)
# endregion

View File

@@ -24,7 +24,7 @@ logger = logging.getLogger(__name__)
BYT5_TOKENIZER_PATH = "google/byt5-small"
QWEN_2_5_VL_IMAGE_ID ="Qwen/Qwen2.5-VL-7B-Instruct"
QWEN_2_5_VL_IMAGE_ID = "Qwen/Qwen2.5-VL-7B-Instruct"
# Copy from Glyph-SDXL-V2
@@ -228,6 +228,7 @@ def load_byt5(
info = byt5_text_encoder.load_state_dict(sd, strict=True, assign=True)
byt5_text_encoder.to(device)
byt5_text_encoder.eval()
logger.info(f"BYT5 text encoder loaded with info: {info}")
return byt5_tokenizer, byt5_text_encoder
@@ -404,6 +405,7 @@ def load_qwen2_5_vl(
info = qwen2_5_vl.load_state_dict(sd, strict=True, assign=True)
logger.info(f"Loaded Qwen2.5-VL: {info}")
qwen2_5_vl.to(device)
qwen2_5_vl.eval()
if dtype is not None:
if dtype.itemsize == 1: # fp8
@@ -494,43 +496,59 @@ def load_qwen2_5_vl(
# Load tokenizer
logger.info(f"Loading tokenizer from {QWEN_2_5_VL_IMAGE_ID}")
tokenizer = Qwen2Tokenizer.from_pretrained(QWEN_2_5_VL_IMAGE_ID)
tokenizer = Qwen2Tokenizer.from_pretrained(QWEN_2_5_VL_IMAGE_ID)
return tokenizer, qwen2_5_vl
TOKENIZER_MAX_LENGTH = 1024
PROMPT_TEMPLATE_ENCODE_START_IDX = 34
def get_qwen_prompt_embeds(
tokenizer: Qwen2Tokenizer, vlm: Qwen2_5_VLForConditionalGeneration, prompt: Union[str, list[str]] = None
):
tokenizer_max_length = 1024
) -> Tuple[torch.Tensor, torch.Tensor]:
input_ids, mask = get_qwen_tokens(tokenizer, prompt)
return get_qwen_prompt_embeds_from_tokens(vlm, input_ids, mask)
def get_qwen_tokens(tokenizer: Qwen2Tokenizer, prompt: Union[str, list[str]] = None) -> Tuple[torch.Tensor, torch.Tensor]:
tokenizer_max_length = TOKENIZER_MAX_LENGTH
# HunyuanImage-2.1 does not use "<|im_start|>assistant\n" in the prompt template
prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
# \n<|im_start|>assistant\n"
prompt_template_encode_start_idx = 34
prompt_template_encode_start_idx = PROMPT_TEMPLATE_ENCODE_START_IDX
# default_sample_size = 128
device = vlm.device
dtype = vlm.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
template = prompt_template_encode
drop_idx = prompt_template_encode_start_idx
txt = [template.format(e) for e in prompt]
txt_tokens = tokenizer(txt, max_length=tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt").to(
device
)
txt_tokens = tokenizer(txt, max_length=tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt")
return txt_tokens.input_ids, txt_tokens.attention_mask
def get_qwen_prompt_embeds_from_tokens(
vlm: Qwen2_5_VLForConditionalGeneration, input_ids: torch.Tensor, attention_mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
tokenizer_max_length = TOKENIZER_MAX_LENGTH
drop_idx = PROMPT_TEMPLATE_ENCODE_START_IDX
device = vlm.device
dtype = vlm.dtype
input_ids = input_ids.to(device=device)
attention_mask = attention_mask.to(device=device)
if dtype.itemsize == 1: # fp8
# TODO dtype should be vlm.dtype?
with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=True):
encoder_hidden_states = vlm(
input_ids=txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask, output_hidden_states=True
)
encoder_hidden_states = vlm(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
else:
with torch.no_grad(), torch.autocast(device_type=device.type, dtype=dtype, enabled=True):
encoder_hidden_states = vlm(
input_ids=txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask, output_hidden_states=True
)
encoder_hidden_states = vlm(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
hidden_states = encoder_hidden_states.hidden_states[-3] # use the 3rd last layer's hidden states for HunyuanImage-2.1
if hidden_states.shape[1] > tokenizer_max_length + drop_idx:
logger.warning(f"Hidden states shape {hidden_states.shape} exceeds max length {tokenizer_max_length + drop_idx}")
@@ -545,7 +563,7 @@ def get_qwen_prompt_embeds(
# ----------------------------------------------------------
prompt_embeds = hidden_states[:, drop_idx:, :]
encoder_attention_mask = txt_tokens.attention_mask[:, drop_idx:]
encoder_attention_mask = attention_mask[:, drop_idx:]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
return prompt_embeds, encoder_attention_mask
@@ -565,17 +583,42 @@ def format_prompt(texts, styles):
return prompt
BYT5_MAX_LENGTH = 128
def get_glyph_prompt_embeds(
tokenizer: T5Tokenizer, text_encoder: T5Stack, prompt: Union[str, list[str]] = None
tokenizer: T5Tokenizer, text_encoder: T5Stack, prompt: Optional[str] = None
) -> Tuple[list[bool], torch.Tensor, torch.Tensor]:
byt5_max_length = 128
if not prompt:
byt5_tokens, byt5_text_mask = get_byt5_text_tokens(tokenizer, prompt)
return get_byt5_prompt_embeds_from_tokens(text_encoder, byt5_tokens, byt5_text_mask)
def get_byt5_prompt_embeds_from_tokens(
    text_encoder: T5Stack, byt5_text_ids: Optional[torch.Tensor], byt5_text_mask: Optional[torch.Tensor]
) -> Tuple[list[bool], torch.Tensor, torch.Tensor]:
    """
    Encode already-tokenized glyph text with the byT5 encoder.

    Args:
        text_encoder: The byT5 encoder; its `device` and `dtype` drive placement and autocast.
        byt5_text_ids: Token ids, or None when there is no glyph text.
        byt5_text_mask: Attention mask for the ids, or None when there is no glyph text.

    Returns:
        `([True], embeddings, mask)` when tokens are given, otherwise
        `([False], zeros, zeros)` placeholders of length `BYT5_MAX_LENGTH`.
    """
    max_len = BYT5_MAX_LENGTH

    if byt5_text_ids is None or byt5_text_mask is None:
        # nothing to encode: hand back all-zero embeddings and an all-zero int64 mask
        empty_embeds = torch.zeros((1, max_len, 1472), device=text_encoder.device)
        empty_mask = torch.zeros((1, max_len), device=text_encoder.device, dtype=torch.int64)
        return ([False], empty_embeds, empty_mask)

    byt5_text_ids = byt5_text_ids.to(device=text_encoder.device)
    byt5_text_mask = byt5_text_mask.to(device=text_encoder.device)

    autocast_ctx = torch.autocast(device_type=text_encoder.device.type, dtype=text_encoder.dtype, enabled=True)
    with torch.no_grad(), autocast_ctx:
        encoder_outputs = text_encoder(byt5_text_ids, attention_mask=byt5_text_mask.float())

    # the first element of the encoder output is used as the embeddings
    return [True], encoder_outputs[0], byt5_text_mask
def get_byt5_text_tokens(tokenizer, prompt):
if not prompt:
return None, None
try:
text_prompt_texts = []
# pattern_quote_single = r"\'(.*?)\'"
@@ -594,56 +637,26 @@ def get_glyph_prompt_embeds(
text_prompt_texts.extend(matches_quote_chinese_double)
if not text_prompt_texts:
return (
[False],
torch.zeros((1, byt5_max_length, 1472), device=text_encoder.device),
torch.zeros((1, byt5_max_length), device=text_encoder.device, dtype=torch.int64),
)
return None, None
text_prompt_style_list = [{"color": None, "font-family": None} for _ in range(len(text_prompt_texts))]
glyph_text_formatted = format_prompt(text_prompt_texts, text_prompt_style_list)
logger.info(f"Glyph text formatted: {glyph_text_formatted}")
byt5_text_ids, byt5_text_mask = get_byt5_text_tokens(tokenizer, byt5_max_length, glyph_text_formatted)
byt5_text_inputs = tokenizer(
glyph_text_formatted,
padding="max_length",
max_length=BYT5_MAX_LENGTH,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
byt5_text_ids = byt5_text_ids.to(device=text_encoder.device)
byt5_text_mask = byt5_text_mask.to(device=text_encoder.device)
byt5_text_ids = byt5_text_inputs.input_ids
byt5_text_mask = byt5_text_inputs.attention_mask
byt5_prompt_embeds = text_encoder(byt5_text_ids, attention_mask=byt5_text_mask.float())
byt5_emb = byt5_prompt_embeds[0]
return [True], byt5_emb, byt5_text_mask
return byt5_text_ids, byt5_text_mask
except Exception as e:
logger.warning(f"Warning: Error in glyph encoding, using fallback: {e}")
return (
[False],
torch.zeros((1, byt5_max_length, 1472), device=text_encoder.device),
torch.zeros((1, byt5_max_length), device=text_encoder.device, dtype=torch.int64),
)
def get_byt5_text_tokens(tokenizer, max_length, text_list):
"""
Get byT5 text tokens.
Args:
tokenizer: The tokenizer object
max_length: Maximum token length
text_list: List or string of text
Returns:
Tuple of (byt5_text_ids, byt5_text_mask)
"""
if isinstance(text_list, list):
text_prompt = " ".join(text_list)
else:
text_prompt = text_list
byt5_text_inputs = tokenizer(
text_prompt, padding="max_length", max_length=max_length, truncation=True, add_special_tokens=True, return_tensors="pt"
)
byt5_text_ids = byt5_text_inputs.input_ids
byt5_text_mask = byt5_text_inputs.attention_mask
return byt5_text_ids, byt5_text_mask
return None, None

View File

@@ -5,6 +5,18 @@ import math
from typing import Tuple, Union, Optional
import torch
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
MODEL_VERSION_2_1 = "hunyuan-image-2.1"
# region model
def _to_tuple(x, dim=2):
"""
@@ -206,7 +218,7 @@ def reshape_for_broadcast(
x.shape[1],
x.shape[-1],
), f"Frequency tensor shape {freqs_cis[0].shape} incompatible with target shape {x.shape}"
shape = [d if i == 1 or i == x.ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
@@ -248,7 +260,7 @@ def apply_rotary_emb(
cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)
cos, sin = cos.to(device), sin.to(device)
# Apply rotation: x' = x * cos + rotate_half(x) * sin
xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).to(dtype)
xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).to(dtype)
@@ -256,6 +268,11 @@ def apply_rotary_emb(
return xq_out, xk_out
# endregion
# region inference
def get_timesteps_sigmas(sampling_steps: int, shift: float, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Generate timesteps and sigmas for diffusion sampling.
@@ -291,6 +308,9 @@ def step(latents, noise_pred, sigmas, step_i):
return latents.float() - (sigmas[step_i] - sigmas[step_i + 1]) * noise_pred.float()
# endregion
# region AdaptiveProjectedGuidance
@@ -298,6 +318,7 @@ class MomentumBuffer:
"""
Exponential moving average buffer for APG momentum.
"""
def __init__(self, momentum: float):
self.momentum = momentum
self.running_average = 0
@@ -318,10 +339,10 @@ def normalized_guidance_apg(
):
"""
Apply normalized adaptive projected guidance.
Projects the guidance vector to reduce over-saturation while maintaining
directional control by decomposing into parallel and orthogonal components.
Args:
pred_cond: Conditional prediction.
pred_uncond: Unconditional prediction.
@@ -330,7 +351,7 @@ def normalized_guidance_apg(
eta: Scaling factor for parallel component.
norm_threshold: Maximum norm for guidance vector clipping.
use_original_formulation: Whether to use original APG formulation.
Returns:
Guided prediction tensor.
"""
@@ -366,10 +387,11 @@ def normalized_guidance_apg(
class AdaptiveProjectedGuidance:
"""
Adaptive Projected Guidance for classifier-free guidance.
Implements APG which projects the guidance vector to reduce over-saturation
while maintaining directional control.
"""
def __init__(
self,
guidance_scale: float = 7.5,
@@ -406,9 +428,6 @@ class AdaptiveProjectedGuidance:
return pred
# endregion
def apply_classifier_free_guidance(
noise_pred_text: torch.Tensor,
noise_pred_uncond: torch.Tensor,
@@ -459,3 +478,6 @@ def apply_classifier_free_guidance(
noise_pred = cfg_guider(noise_pred_text, noise_pred_uncond, step=step)
return noise_pred
# endregion

View File

@@ -7,7 +7,7 @@ import torch
from tqdm import tqdm
from library.custom_offloading_utils import synchronize_device
from library.device_utils import synchronize_device
from library.fp8_optimization_utils import load_safetensors_with_fp8_optimization
from library.utils import MemoryEfficientSafeOpen, setup_logging

View File

@@ -37,18 +37,16 @@ metadata = {
BASE_METADATA = {
# === MUST ===
"modelspec.sai_model_spec": "1.0.1",
"modelspec.sai_model_spec": "1.0.1",
"modelspec.architecture": None,
"modelspec.implementation": None,
"modelspec.title": None,
"modelspec.resolution": None,
# === SHOULD ===
"modelspec.description": None,
"modelspec.author": None,
"modelspec.date": None,
"modelspec.hash_sha256": None,
# === CAN===
"modelspec.implementation_version": None,
"modelspec.license": None,
@@ -81,6 +79,8 @@ ARCH_FLUX_1_CHROMA = "chroma" # for Flux Chroma
ARCH_FLUX_1_UNKNOWN = "flux-1"
ARCH_LUMINA_2 = "lumina-2"
ARCH_LUMINA_UNKNOWN = "lumina"
ARCH_HUNYUAN_IMAGE_2_1 = "hunyuan-image-2.1"
ARCH_HUNYUAN_IMAGE_UNKNOWN = "hunyuan-image"
ADAPTER_LORA = "lora"
ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
@@ -91,6 +91,7 @@ IMPL_DIFFUSERS = "diffusers"
IMPL_FLUX = "https://github.com/black-forest-labs/flux"
IMPL_CHROMA = "https://huggingface.co/lodestones/Chroma"
IMPL_LUMINA = "https://github.com/Alpha-VLLM/Lumina-Image-2.0"
IMPL_HUNYUAN_IMAGE = "https://github.com/Tencent-Hunyuan/HunyuanImage-2.1"
PRED_TYPE_EPSILON = "epsilon"
PRED_TYPE_V = "v"
@@ -102,20 +103,20 @@ class ModelSpecMetadata:
ModelSpec 1.0.1 compliant metadata for safetensors models.
All fields correspond to modelspec.* keys in the final metadata.
"""
# === MUST ===
architecture: str
implementation: str
title: str
resolution: str
sai_model_spec: str = "1.0.1"
# === SHOULD ===
description: str | None = None
author: str | None = None
date: str | None = None
hash_sha256: str | None = None
# === CAN ===
implementation_version: str | None = None
license: str | None = None
@@ -131,14 +132,14 @@ class ModelSpecMetadata:
is_negative_embedding: str | None = None
unet_dtype: str | None = None
vae_dtype: str | None = None
# === Additional metadata ===
additional_fields: dict[str, str] = field(default_factory=dict)
def to_metadata_dict(self) -> dict[str, str]:
"""Convert dataclass to metadata dictionary with modelspec. prefixes."""
metadata = {}
# Add all non-None fields with modelspec prefix
for field_name, value in self.__dict__.items():
if field_name == "additional_fields":
@@ -150,14 +151,14 @@ class ModelSpecMetadata:
metadata[f"modelspec.{key}"] = val
elif value is not None:
metadata[f"modelspec.{field_name}"] = value
return metadata
@classmethod
def from_args(cls, args, **kwargs) -> "ModelSpecMetadata":
"""Create ModelSpecMetadata from argparse Namespace, extracting metadata_* fields."""
metadata_fields = {}
# Extract all metadata_* attributes from args
for attr_name in dir(args):
if attr_name.startswith("metadata_") and not attr_name.startswith("metadata___"):
@@ -166,7 +167,7 @@ class ModelSpecMetadata:
# Remove metadata_ prefix
field_name = attr_name[9:] # len("metadata_") = 9
metadata_fields[field_name] = value
# Handle known standard fields
standard_fields = {
"author": metadata_fields.pop("author", None),
@@ -174,30 +175,25 @@ class ModelSpecMetadata:
"license": metadata_fields.pop("license", None),
"tags": metadata_fields.pop("tags", None),
}
# Remove None values
standard_fields = {k: v for k, v in standard_fields.items() if v is not None}
# Merge with kwargs and remaining metadata fields
all_fields = {**standard_fields, **kwargs}
if metadata_fields:
all_fields["additional_fields"] = metadata_fields
return cls(**all_fields)
def determine_architecture(
v2: bool,
v_parameterization: bool,
sdxl: bool,
lora: bool,
textual_inversion: bool,
model_config: dict[str, str] | None = None
v2: bool, v_parameterization: bool, sdxl: bool, lora: bool, textual_inversion: bool, model_config: dict[str, str] | None = None
) -> str:
"""Determine model architecture string from parameters."""
model_config = model_config or {}
if sdxl:
arch = ARCH_SD_XL_V1_BASE
elif "sd3" in model_config:
@@ -218,17 +214,23 @@ def determine_architecture(
arch = ARCH_LUMINA_2
else:
arch = ARCH_LUMINA_UNKNOWN
elif "hunyuan_image" in model_config:
hunyuan_image_type = model_config["hunyuan_image"]
if hunyuan_image_type == "2.1":
arch = ARCH_HUNYUAN_IMAGE_2_1
else:
arch = ARCH_HUNYUAN_IMAGE_UNKNOWN
elif v2:
arch = ARCH_SD_V2_768_V if v_parameterization else ARCH_SD_V2_512
else:
arch = ARCH_SD_V1
# Add adapter suffix
if lora:
arch += f"/{ADAPTER_LORA}"
elif textual_inversion:
arch += f"/{ADAPTER_TEXTUAL_INVERSION}"
return arch
@@ -237,12 +239,12 @@ def determine_implementation(
textual_inversion: bool,
sdxl: bool,
model_config: dict[str, str] | None = None,
is_stable_diffusion_ckpt: bool | None = None
is_stable_diffusion_ckpt: bool | None = None,
) -> str:
"""Determine implementation string from parameters."""
model_config = model_config or {}
if "flux" in model_config:
if model_config["flux"] == "chroma":
return IMPL_CHROMA
@@ -265,16 +267,16 @@ def get_implementation_version() -> str:
capture_output=True,
text=True,
cwd=os.path.dirname(os.path.dirname(__file__)), # Go up to sd-scripts root
timeout=5
timeout=5,
)
if result.returncode == 0:
commit_hash = result.stdout.strip()
return f"sd-scripts/{commit_hash}"
else:
logger.warning("Failed to get git commit hash, using fallback")
return "sd-scripts/unknown"
except (subprocess.TimeoutExpired, subprocess.SubprocessError, FileNotFoundError) as e:
logger.warning(f"Could not determine git commit: {e}")
return "sd-scripts/unknown"
@@ -284,19 +286,19 @@ def file_to_data_url(file_path: str) -> str:
"""Convert a file path to a data URL for embedding in metadata."""
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
# Get MIME type
mime_type, _ = mimetypes.guess_type(file_path)
if mime_type is None:
# Default to binary if we can't detect
mime_type = "application/octet-stream"
# Read file and encode as base64
with open(file_path, "rb") as f:
file_data = f.read()
encoded_data = base64.b64encode(file_data).decode("ascii")
return f"data:{mime_type};base64,{encoded_data}"
@@ -305,12 +307,12 @@ def determine_resolution(
sdxl: bool = False,
model_config: dict[str, str] | None = None,
v2: bool = False,
v_parameterization: bool = False
v_parameterization: bool = False,
) -> str:
"""Determine resolution string from parameters."""
model_config = model_config or {}
if reso is not None:
# Handle comma separated string
if isinstance(reso, str):
@@ -318,21 +320,18 @@ def determine_resolution(
# Handle single int
if isinstance(reso, int):
reso = (reso, reso)
# Handle single-element tuple
# Handle single-element tuple
if len(reso) == 1:
reso = (reso[0], reso[0])
else:
# Determine default resolution based on model type
if (sdxl or
"sd3" in model_config or
"flux" in model_config or
"lumina" in model_config):
if sdxl or "sd3" in model_config or "flux" in model_config or "lumina" in model_config:
reso = (1024, 1024)
elif v2 and v_parameterization:
reso = (768, 768)
else:
reso = (512, 512)
return f"{reso[0]}x{reso[1]}"
@@ -388,23 +387,19 @@ def build_metadata_dataclass(
) -> ModelSpecMetadata:
"""
Build ModelSpec 1.0.1 compliant metadata dataclass.
Args:
model_config: Dict containing model type info, e.g. {"flux": "dev"}, {"sd3": "large"}
optional_metadata: Dict of additional metadata fields to include
"""
# Use helper functions for complex logic
architecture = determine_architecture(
v2, v_parameterization, sdxl, lora, textual_inversion, model_config
)
architecture = determine_architecture(v2, v_parameterization, sdxl, lora, textual_inversion, model_config)
if not lora and not textual_inversion and is_stable_diffusion_ckpt is None:
is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion
implementation = determine_implementation(
lora, textual_inversion, sdxl, model_config, is_stable_diffusion_ckpt
)
implementation = determine_implementation(lora, textual_inversion, sdxl, model_config, is_stable_diffusion_ckpt)
if title is None:
if lora:
@@ -421,9 +416,7 @@ def build_metadata_dataclass(
date = datetime.datetime.fromtimestamp(int_ts).isoformat()
# Use helper function for resolution
resolution = determine_resolution(
reso, sdxl, model_config, v2, v_parameterization
)
resolution = determine_resolution(reso, sdxl, model_config, v2, v_parameterization)
# Handle prediction type - Flux models don't use prediction_type
model_config = model_config or {}
@@ -488,7 +481,7 @@ def build_metadata_dataclass(
prediction_type=prediction_type,
timestep_range=timestep_range,
encoder_layer=encoder_layer,
additional_fields=processed_optional_metadata
additional_fields=processed_optional_metadata,
)
return metadata
@@ -518,7 +511,7 @@ def build_metadata(
"""
Build ModelSpec 1.0.1 compliant metadata for safetensors models.
Legacy function that returns dict - prefer build_metadata_dataclass for new code.
Args:
model_config: Dict containing model type info, e.g. {"flux": "dev"}, {"sd3": "large"}
optional_metadata: Dict of additional metadata fields to include
@@ -545,7 +538,7 @@ def build_metadata(
model_config=model_config,
optional_metadata=optional_metadata,
)
return metadata_obj.to_metadata_dict()
@@ -581,7 +574,7 @@ def build_merged_from(models: list[str]) -> str:
def add_model_spec_arguments(parser: argparse.ArgumentParser):
"""Add all ModelSpec metadata arguments to the parser."""
parser.add_argument(
"--metadata_title",
type=str,

View File

@@ -0,0 +1,187 @@
import os
from typing import Any, List, Optional, Tuple, Union
import torch
import numpy as np
from transformers import AutoTokenizer, Qwen2Tokenizer
from library import hunyuan_image_text_encoder, hunyuan_image_vae, train_util
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class HunyuanImageTokenizeStrategy(TokenizeStrategy):
    """Tokenizes prompts for HunyuanImage with both the Qwen2.5-VL tokenizer and the byT5 tokenizer."""

    def __init__(self, tokenizer_cache_dir: Optional[str] = None) -> None:
        te = hunyuan_image_text_encoder
        # both tokenizers go through the base-class loader with the same cache directory
        self.vlm_tokenizer = self._load_tokenizer(
            Qwen2Tokenizer, te.QWEN_2_5_VL_IMAGE_ID, tokenizer_cache_dir=tokenizer_cache_dir
        )
        self.byt5_tokenizer = self._load_tokenizer(
            AutoTokenizer, te.BYT5_TOKENIZER_PATH, tokenizer_cache_dir=tokenizer_cache_dir
        )

    def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
        """Return `[vlm_tokens, vlm_mask, byt5_tokens, byt5_mask]` for one prompt or a list of prompts."""
        if isinstance(text, str):
            prompts = [text]
        else:
            prompts = text

        qwen_ids, qwen_mask = hunyuan_image_text_encoder.get_qwen_tokens(self.vlm_tokenizer, prompts)
        # NOTE(review): the byT5 helper receives the whole prompt list here - confirm it
        # accepts a list rather than a single prompt string.
        glyph_ids, glyph_mask = hunyuan_image_text_encoder.get_byt5_text_tokens(self.byt5_tokenizer, prompts)

        return [qwen_ids, qwen_mask, glyph_ids, glyph_mask]
class HunyuanImageTextEncodingStrategy(TextEncodingStrategy):
    """Turns HunyuanImage token tensors into Qwen2.5-VL and byT5 embeddings."""

    def __init__(self) -> None:
        """Nothing to configure."""

    def encode_tokens(
        self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
    ) -> List[torch.Tensor]:
        """
        Encode tokens produced by the tokenize strategy.

        Args:
            tokenize_strategy: Not used by this implementation.
            models: `[qwen2vlm, byt5]`.
            tokens: `[vlm_tokens, vlm_mask, byt5_tokens, byt5_mask]`.

        Returns:
            `[vlm_embed, vlm_mask, byt5_embed, byt5_mask, ocr_mask]`.
        """
        qwen_ids, qwen_mask_in, glyph_ids, glyph_mask_in = tokens
        qwen_model, byt5_model = models
        te = hunyuan_image_text_encoder

        # autocast and no_grad are handled in hunyuan_image_text_encoder
        qwen_embed, qwen_mask = te.get_qwen_prompt_embeds_from_tokens(qwen_model, qwen_ids, qwen_mask_in)
        ocr_flags, glyph_embed, glyph_mask = te.get_byt5_prompt_embeds_from_tokens(byt5_model, glyph_ids, glyph_mask_in)

        return [qwen_embed, qwen_mask, glyph_embed, glyph_mask, ocr_flags]
class HunyuanImageTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
    """Cache HunyuanImage text-encoder outputs either in memory or in a per-image ``.npz`` file.

    Every cache entry consists of five arrays stored under the keys listed in ``_NPZ_KEYS``.
    """

    HUNYUAN_IMAGE_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_hi_te.npz"

    # Keys every cache file must contain, in the order load_outputs_npz returns them.
    _NPZ_KEYS = ("vlm_embed", "vlm_mask", "byt5_embed", "byt5_mask", "ocr_mask")

    def __init__(
        self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False
    ) -> None:
        """Forward all settings unchanged to the base caching strategy."""
        super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)

    def get_outputs_npz_path(self, image_abs_path: str) -> str:
        """Return the cache path: the image path with its extension replaced by ``_hi_te.npz``."""
        return (
            os.path.splitext(image_abs_path)[0]
            + HunyuanImageTextEncoderOutputsCachingStrategy.HUNYUAN_IMAGE_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
        )

    def is_disk_cached_outputs_expected(self, npz_path: str) -> bool:
        """Tell whether ``npz_path`` already holds a usable cache entry.

        Returns:
            False when disk caching is disabled, the file is missing, or any required key is absent.
            True when the file exists and either validity checking is skipped or all keys are present.

        Raises:
            Exception: whatever ``np.load`` raises for an unreadable file; it is logged and re-raised.
        """
        if not self.cache_to_disk:
            return False
        if not os.path.exists(npz_path):
            return False
        if self.skip_disk_cache_validity_check:
            return True

        try:
            # np.load keeps the archive's file handle open until closed; the context manager releases it.
            with np.load(npz_path) as npz:
                return all(key in npz for key in self._NPZ_KEYS)
        except Exception:
            logger.error(f"Error loading file: {npz_path}")
            raise

    def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
        """Load one cache entry.

        Returns:
            ``[vlm_embed, vlm_mask, byt5_embed, byt5_mask, ocr_mask]``.
        """
        # Arrays are read into memory on item access, so they stay valid after the archive is closed.
        with np.load(npz_path) as data:
            return [data[key] for key in self._NPZ_KEYS]

    def cache_batch_outputs(
        self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
    ):
        """Encode the captions of ``infos`` in one batch and store the per-item outputs.

        With ``cache_to_disk`` each item's arrays are written to ``info.text_encoder_outputs_npz``;
        otherwise they are attached to the item as the tuple ``info.text_encoder_outputs``.

        Args:
            tokenize_strategy: used to tokenize the captions and passed on to ``encode_tokens``.
            models: text encoders, passed on to ``encode_tokens``.
            text_encoding_strategy: strategy whose ``encode_tokens`` produces the five outputs.
            infos: items exposing ``caption`` and, for disk caching, ``text_encoder_outputs_npz``.
        """
        hunyuan_image_text_encoding_strategy: HunyuanImageTextEncodingStrategy = text_encoding_strategy
        captions = [info.caption for info in infos]

        tokens_and_masks = tokenize_strategy.tokenize(captions)
        with torch.no_grad():
            vlm_embed, vlm_mask, byt5_embed, byt5_mask, ocr_mask = hunyuan_image_text_encoding_strategy.encode_tokens(
                tokenize_strategy, models, tokens_and_masks
            )

        # numpy has no bfloat16 dtype, so such tensors are upcast to float32 before conversion.
        if vlm_embed.dtype == torch.bfloat16:
            vlm_embed = vlm_embed.float()
        if byt5_embed.dtype == torch.bfloat16:
            byt5_embed = byt5_embed.float()

        vlm_embed = vlm_embed.cpu().numpy()
        vlm_mask = vlm_mask.cpu().numpy()
        byt5_embed = byt5_embed.cpu().numpy()
        byt5_mask = byt5_mask.cpu().numpy()
        ocr_mask = np.array(ocr_mask, dtype=bool)

        for i, info in enumerate(infos):
            vlm_embed_i = vlm_embed[i]
            vlm_mask_i = vlm_mask[i]
            byt5_embed_i = byt5_embed[i]
            byt5_mask_i = byt5_mask[i]
            ocr_mask_i = ocr_mask[i]

            if self.cache_to_disk:
                np.savez(
                    info.text_encoder_outputs_npz,
                    vlm_embed=vlm_embed_i,
                    vlm_mask=vlm_mask_i,
                    byt5_embed=byt5_embed_i,
                    byt5_mask=byt5_mask_i,
                    ocr_mask=ocr_mask_i,
                )
            else:
                info.text_encoder_outputs = (vlm_embed_i, vlm_mask_i, byt5_embed_i, byt5_mask_i, ocr_mask_i)
class HunyuanImageLatentsCachingStrategy(LatentsCachingStrategy):
    """Cache VAE latents for HunyuanImage, delegating the actual work to the base-class ``_default_*`` helpers."""

    HUNYUAN_IMAGE_LATENTS_NPZ_SUFFIX = "_hi.npz"

    def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
        """Forward all settings unchanged to the base caching strategy."""
        super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)

    @property
    def cache_suffix(self) -> str:
        """File-name suffix shared by all latent cache files of this strategy."""
        return HunyuanImageLatentsCachingStrategy.HUNYUAN_IMAGE_LATENTS_NPZ_SUFFIX

    def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
        """Return ``<path without extension>_<size0>x<size1>_hi.npz`` with both sizes zero-padded to four digits."""
        stem, _ext = os.path.splitext(absolute_path)
        size_tag = f"_{image_size[0]:04d}x{image_size[1]:04d}"
        return stem + size_tag + HunyuanImageLatentsCachingStrategy.HUNYUAN_IMAGE_LATENTS_NPZ_SUFFIX

    def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
        """Delegate the cache validity check to the base helper in multi-resolution mode.

        The leading ``32`` is passed as the helper's first argument (presumably the latent stride — confirm).
        """
        return self._default_is_disk_cached_latents_expected(
            32,
            bucket_reso,
            npz_path,
            flip_aug,
            alpha_mask,
            multi_resolution=True,
        )

    def load_latents_from_disk(
        self, npz_path: str, bucket_reso: Tuple[int, int]
    ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
        """Load cached latents for ``bucket_reso`` through the base helper (supports multi-resolution)."""
        loaded = self._default_load_latents_from_disk(32, npz_path, bucket_reso)
        return loaded

    # TODO remove circular dependency for ImageInfo
    def cache_batch_latents(
        self, vae: hunyuan_image_vae.HunyuanVAE2D, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool
    ):
        """Encode a batch of images with ``vae`` and cache the latents via the base helper.

        Device memory is cleaned afterwards unless ``train_util.HIGH_VRAM`` is set.
        """

        def encode_by_vae(img_tensor):
            # Encode and draw a sample from the result returned by the VAE.
            return vae.encode(img_tensor).sample()

        self._default_cache_batch_latents(
            encode_by_vae, vae.device, vae.dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True
        )

        if train_util.HIGH_VRAM:
            return
        train_util.clean_memory_on_device(vae.device)

View File

@@ -3588,6 +3588,7 @@ def get_sai_model_spec_dataclass(
sd3: str = None,
flux: str = None,
lumina: str = None,
hunyuan_image: str = None,
optional_metadata: dict[str, str] | None = None,
) -> sai_model_spec.ModelSpecMetadata:
"""
@@ -3617,6 +3618,8 @@ def get_sai_model_spec_dataclass(
model_config["flux"] = flux
if lumina is not None:
model_config["lumina"] = lumina
if hunyuan_image is not None:
model_config["hunyuan_image"] = hunyuan_image
# Use the dataclass function directly
return sai_model_spec.build_metadata_dataclass(
@@ -3987,11 +3990,21 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
choices=["no", "fp16", "bf16"],
help="use mixed precision / 混合精度を使う場合、その精度",
)
parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する")
parser.add_argument(
"--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する"
"--full_fp16",
action="store_true",
help="fp16 training including gradients, some models are not supported / 勾配も含めてfp16で学習する、一部のモデルではサポートされていません",
)
parser.add_argument(
"--full_bf16",
action="store_true",
help="bf16 training including gradients, some models are not supported / 勾配も含めてbf16で学習する、一部のモデルではサポートされていません",
) # TODO move to SDXL training, because it is not supported by SD1/2
parser.add_argument("--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う")
parser.add_argument(
"--fp8_base",
action="store_true",
help="use fp8 for base model, some models are not supported / base modelにfp8を使う、一部のモデルではサポートされていません",
)
parser.add_argument(
"--ddp_timeout",
@@ -6305,6 +6318,11 @@ def line_to_prompt_dict(line: str) -> dict:
prompt_dict["renorm_cfg"] = float(m.group(1))
continue
m = re.match(r"fs (.+)", parg, re.IGNORECASE)
if m:
prompt_dict["flow_shift"] = m.group(1)
continue
except ValueError as ex:
logger.error(f"Exception in parsing / 解析エラー: {parg}")
logger.error(ex)

View File

@@ -713,6 +713,10 @@ class LoRANetwork(torch.nn.Module):
LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1"
LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te3" # make ComfyUI compatible
# Sizes used to split fused projection weights: three 3072-wide parts followed by one
# 12288-wide part. Callers in this class take the first three entries for "qkv only"
# and the whole list for "qkv + mlp".
# NOTE(review): presumably a classmethod so subclasses can supply other widths - confirm.
@classmethod
def get_qkv_mlp_split_dims(cls) -> List[int]:
return [3072] * 3 + [12288]
def __init__(
self,
text_encoders: Union[List[CLIPTextModel], CLIPTextModel],
@@ -842,7 +846,7 @@ class LoRANetwork(torch.nn.Module):
break
# if modules_dim is None, we use default lora_dim. if modules_dim is not None, we use the specified dim (no default)
if dim is None and modules_dim is None:
if dim is None and modules_dim is None:
if is_linear or is_conv2d_1x1:
dim = default_dim if default_dim is not None else self.lora_dim
alpha = self.alpha
@@ -901,9 +905,9 @@ class LoRANetwork(torch.nn.Module):
split_dims = None
if is_flux and split_qkv:
if "double" in lora_name and "qkv" in lora_name:
split_dims = [3072] * 3
(split_dims,) = self.get_qkv_mlp_split_dims()[:3] # qkv only
elif "single" in lora_name and "linear1" in lora_name:
split_dims = [3072] * 3 + [12288]
split_dims = self.get_qkv_mlp_split_dims() # qkv + mlp
lora = module_class(
lora_name,
@@ -1036,9 +1040,9 @@ class LoRANetwork(torch.nn.Module):
# split qkv
for key in list(state_dict.keys()):
if "double" in key and "qkv" in key:
split_dims = [3072] * 3
split_dims = self.get_qkv_mlp_split_dims()[:3] # qkv only
elif "single" in key and "linear1" in key:
split_dims = [3072] * 3 + [12288]
split_dims = self.get_qkv_mlp_split_dims() # qkv + mlp
else:
continue
@@ -1092,9 +1096,9 @@ class LoRANetwork(torch.nn.Module):
new_state_dict = {}
for key in list(state_dict.keys()):
if "double" in key and "qkv" in key:
split_dims = [3072] * 3
split_dims = self.get_qkv_mlp_split_dims()[:3] # qkv only
elif "single" in key and "linear1" in key:
split_dims = [3072] * 3 + [12288]
split_dims = self.get_qkv_mlp_split_dims() # qkv + mlp
else:
new_state_dict[key] = state_dict[key]
continue

File diff suppressed because it is too large Load Diff

View File

@@ -475,6 +475,9 @@ class NetworkTrainer:
return loss.mean()
# Tells train() whether a text encoder that is not on the CPU should be cast to te_weight_dtype.
# This base implementation always answers True.
# NOTE(review): presumably the HunyuanImage trainer overrides this to return False - confirm.
def cast_text_encoder(self):
return True # default for other than HunyuanImage
def train(self, args):
session_id = random.randint(0, 2**32)
training_started_at = time.time()
@@ -832,7 +835,7 @@ class NetworkTrainer:
t_enc.requires_grad_(False)
# in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16
if t_enc.device.type != "cpu":
if t_enc.device.type != "cpu" and self.cast_text_encoder():
t_enc.to(dtype=te_weight_dtype)
# nn.Embedding not support FP8