mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 21:52:27 +00:00
* fix: update extend-exclude list in _typos.toml to include configs * fix: exclude anima tests from pytest * feat: add entry for 'temperal' in extend-words section of _typos.toml for Qwen-Image VAE * fix: update default value for --discrete_flow_shift in anima training guide * feat: add Qwen-Image VAE * feat: simplify encode_tokens * feat: use unified attention module, add wrapper for state dict compatibility * feat: loading with dynamic fp8 optimization and LoRA support * feat: add anima minimal inference script (WIP) * format: format * feat: simplify target module selection by regular expression patterns * feat: keep caption dropout rate in cache and handle in training script * feat: update train_llm_adapter and verbose default values to string type * fix: use strategy instead of using tokenizers directly * feat: add dtype property and all-zero mask handling in cross-attention in LLMAdapterTransformerBlock * feat: support 5d tensor in get_noisy_model_input_and_timesteps * feat: update loss calculation to support 5d tensor * fix: update argument names in anima_train_utils to align with other architectures * feat: simplify Anima training script and update empty caption handling * feat: support LoRA format without `net.` prefix * fix: make fp8_scaled option work * feat: add regex-based learning rates and dimensions handling in create_network * fix: improve regex matching for module selection and learning rates in LoRANetwork * fix: update logging message for regex match in LoRANetwork * fix: keep latents 4D except DiT call * feat: enhance block swap functionality for inference and training in Anima model * feat: refactor Anima training script * feat: optimize VAE processing by adjusting tensor dimensions and data types * fix: wait for all block transfers before switching offloader mode * feat: update Anima training guide with new argument specifications and regex-based module selection. Thank you Claude! 
* feat: support LORA for Qwen3 * feat: update Anima SAI model spec metadata handling * fix: remove unused code * feat: split CFG processing in do_sample function to reduce memory usage * feat: add VAE chunking and caching options to reduce memory usage * feat: optimize RMSNorm forward method and remove unused torch_attention_op * Update library/strategy_anima.py Use torch.all instead of all. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update library/safetensors_utils.py Fix duplicated new_key for concat_hook. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update anima_minimal_inference.py Remove unused code. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update anima_train.py Remove unused import. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update library/anima_train_utils.py Remove unused import. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix: review with Copilot * feat: add script to convert LoRA format to ComfyUI compatible format (WIP, not tested yet) * feat: add process_escape function to handle escape sequences in prompts * feat: enhance LoRA weight handling in model loading and add text encoder loading function * feat: improve ComfyUI conversion script with prefix constants and module name adjustments * feat: update caption dropout documentation to clarify cache regeneration requirement * feat: add clarification on learning rate adjustments * feat: add note on PyTorch version requirement to prevent NaN loss --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
616 lines
21 KiB
Python
616 lines
21 KiB
Python
# Anima Training Utilities
|
||
|
||
import argparse
|
||
import gc
|
||
import math
|
||
import os
|
||
import time
|
||
from typing import Optional
|
||
|
||
import numpy as np
|
||
import torch
|
||
from accelerate import Accelerator
|
||
from tqdm import tqdm
|
||
from PIL import Image
|
||
|
||
from library.device_utils import init_ipex, clean_memory_on_device, synchronize_device
|
||
from library import anima_models, anima_utils, train_util, qwen_image_autoencoder_kl
|
||
|
||
init_ipex()
|
||
|
||
from .utils import setup_logging
|
||
|
||
setup_logging()
|
||
import logging
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
# Anima-specific training arguments
|
||
|
||
|
||
def add_anima_training_arguments(parser: argparse.ArgumentParser) -> None:
    """Add Anima-specific training arguments to the parser."""
    # --- Model component paths ---
    parser.add_argument(
        "--qwen3",
        type=str,
        default=None,
        help="Path to Qwen3-0.6B model (safetensors file or directory)",
    )
    parser.add_argument(
        "--llm_adapter_path",
        type=str,
        default=None,
        help="Path to separate LLM adapter weights. If None, adapter is loaded from DiT file if present",
    )
    # --- Per-module-group learning rates ---
    # Convention for all of these: None inherits the base LR, 0 freezes the group
    # (consumed by get_anima_param_groups below).
    parser.add_argument(
        "--llm_adapter_lr",
        type=float,
        default=None,
        help="Learning rate for LLM adapter. None=same as base LR, 0=freeze adapter",
    )
    parser.add_argument(
        "--self_attn_lr",
        type=float,
        default=None,
        help="Learning rate for self-attention layers. None=same as base LR, 0=freeze",
    )
    parser.add_argument(
        "--cross_attn_lr",
        type=float,
        default=None,
        help="Learning rate for cross-attention layers. None=same as base LR, 0=freeze",
    )
    parser.add_argument(
        "--mlp_lr",
        type=float,
        default=None,
        help="Learning rate for MLP layers. None=same as base LR, 0=freeze",
    )
    parser.add_argument(
        "--mod_lr",
        type=float,
        default=None,
        help="Learning rate for AdaLN modulation layers. None=same as base LR, 0=freeze. Note: mod layers are not included in LoRA by default.",
    )
    # --- Tokenizer configuration ---
    parser.add_argument(
        "--t5_tokenizer_path",
        type=str,
        default=None,
        help="Path to T5 tokenizer directory. If None, uses default configs/t5_old/",
    )
    parser.add_argument(
        "--qwen3_max_token_length",
        type=int,
        default=512,
        help="Maximum token length for Qwen3 tokenizer (default: 512)",
    )
    parser.add_argument(
        "--t5_max_token_length",
        type=int,
        default=512,
        help="Maximum token length for T5 tokenizer (default: 512)",
    )
    # --- Rectified-flow timestep sampling ---
    parser.add_argument(
        "--discrete_flow_shift",
        type=float,
        default=1.0,
        help="Timestep distribution shift for rectified flow training (default: 1.0)",
    )
    parser.add_argument(
        "--timestep_sampling",
        type=str,
        default="sigmoid",
        choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"],
        help="Timestep sampling method (default: sigmoid (logit normal))",
    )
    parser.add_argument(
        "--sigmoid_scale",
        type=float,
        default=1.0,
        help="Scale factor for sigmoid (logit_normal) timestep sampling (default: 1.0)",
    )
    # --- Attention implementation ---
    parser.add_argument(
        "--attn_mode",
        choices=["torch", "xformers", "flash", "sageattn", "sdpa"],  # "sdpa" is for backward compatibility
        default=None,
        help="Attention implementation to use. Default is None (torch). xformers requires --split_attn. sageattn does not support training (inference only). This option overrides --xformers or --sdpa."
        " / 使用するAttentionの実装。デフォルトはNone(torch)です。xformersは--split_attnの指定が必要です。sageattnはトレーニングをサポートしていません(推論のみ)。このオプションは--xformersまたは--sdpaを上書きします。",
    )
    parser.add_argument(
        "--split_attn",
        action="store_true",
        help="split attention computation to reduce memory usage / メモリ使用量を減らすためにattention時にバッチを分割する",
    )
    # --- VAE memory options ---
    parser.add_argument(
        "--vae_chunk_size",
        type=int,
        default=None,
        help="Spatial chunk size for VAE encoding/decoding to reduce memory usage. Must be even number. If not specified, chunking is disabled (official behavior)."
        + " / メモリ使用量を減らすためのVAEエンコード/デコードの空間チャンクサイズ。偶数である必要があります。未指定の場合、チャンク処理は無効になります(公式の動作)。",
    )
    parser.add_argument(
        "--vae_disable_cache",
        action="store_true",
        help="Disable internal VAE caching mechanism to reduce memory usage. Encoding / decoding will also be faster, but this differs from official behavior."
        + " / VAEのメモリ使用量を減らすために内部のキャッシュ機構を無効にします。エンコード/デコードも速くなりますが、公式の動作とは異なります。",
    )
|
||
|
||
|
||
# Loss weighting
|
||
|
||
|
||
def compute_loss_weighting_for_anima(weighting_scheme: str, sigmas: torch.Tensor) -> torch.Tensor:
    """Compute per-sample loss weighting for Anima training.

    Same schemes as SD3 but can add Anima-specific ones if needed in future.

    Args:
        weighting_scheme: One of "sigma_sqrt", "cosmap", "none"/None. Any
            unrecognized value falls back to uniform weighting.
        sigmas: Noise levels, any shape; the result has the same shape.

    Returns:
        Weighting tensor broadcastable against the loss.
    """
    if weighting_scheme == "sigma_sqrt":
        return (sigmas**-2.0).float()
    if weighting_scheme == "cosmap":
        # CosMap weighting from the SD3 paper: 2 / (pi * (1 - 2*sigma + 2*sigma^2))
        bot = 1 - 2 * sigmas + 2 * sigmas**2
        return 2 / (math.pi * bot)
    # "none", None, or any unknown scheme: uniform weighting.
    # (The original had identical duplicated branches for these cases.)
    return torch.ones_like(sigmas)
|
||
|
||
|
||
# Parameter groups (6 groups with separate LRs)
|
||
def get_anima_param_groups(
    dit,
    base_lr: float,
    self_attn_lr: Optional[float] = None,
    cross_attn_lr: Optional[float] = None,
    mlp_lr: Optional[float] = None,
    mod_lr: Optional[float] = None,
    llm_adapter_lr: Optional[float] = None,
):
    """Build optimizer parameter groups for Anima with per-category learning rates.

    Parameters are classified by substring match on their qualified name into
    six buckets (llm_adapter, self_attn, cross_attn, mlp, mod, base). For each
    bucket, ``None`` means "use base_lr" and ``0`` means "freeze" (the bucket's
    parameters get ``requires_grad=False`` and no optimizer group is created).

    Args:
        dit: Anima model (anything exposing ``named_parameters()``)
        base_lr: Base learning rate
        self_attn_lr: LR for self-attention layers (None = base_lr, 0 = freeze)
        cross_attn_lr: LR for cross-attention layers
        mlp_lr: LR for MLP layers
        mod_lr: LR for AdaLN modulation layers
        llm_adapter_lr: LR for LLM adapter

    Returns:
        List of ``{"params": [...], "lr": float}`` dicts for the optimizer.
    """
    # None means "inherit the base learning rate" for every category.
    self_attn_lr = base_lr if self_attn_lr is None else self_attn_lr
    cross_attn_lr = base_lr if cross_attn_lr is None else cross_attn_lr
    mlp_lr = base_lr if mlp_lr is None else mlp_lr
    mod_lr = base_lr if mod_lr is None else mod_lr
    llm_adapter_lr = base_lr if llm_adapter_lr is None else llm_adapter_lr

    buckets = {key: [] for key in ("base", "self_attn", "cross_attn", "mlp", "mod", "llm_adapter")}
    # Substring -> bucket, checked in order; the first match wins
    # (llm_adapter is checked before the attention/mlp patterns).
    patterns = (
        ("llm_adapter", "llm_adapter"),
        (".self_attn", "self_attn"),
        (".cross_attn", "cross_attn"),
        (".mlp", "mlp"),
        (".adaln_modulation", "mod"),
    )

    for param_name, param in dit.named_parameters():
        # Keep the fully-qualified name on the tensor for debugging.
        param.original_name = param_name
        for needle, bucket in patterns:
            if needle in param_name:
                buckets[bucket].append(param)
                break
        else:
            buckets["base"].append(param)

    logger.info("Parameter groups:")
    logger.info(f"  base_params: {len(buckets['base'])} (lr={base_lr})")
    logger.info(f"  self_attn_params: {len(buckets['self_attn'])} (lr={self_attn_lr})")
    logger.info(f"  cross_attn_params: {len(buckets['cross_attn'])} (lr={cross_attn_lr})")
    logger.info(f"  mlp_params: {len(buckets['mlp'])} (lr={mlp_lr})")
    logger.info(f"  mod_params: {len(buckets['mod'])} (lr={mod_lr})")
    logger.info(f"  llm_adapter_params: {len(buckets['llm_adapter'])} (lr={llm_adapter_lr})")

    group_specs = (
        ("base", base_lr),
        ("self_attn", self_attn_lr),
        ("cross_attn", cross_attn_lr),
        ("mlp", mlp_lr),
        ("mod", mod_lr),
        ("llm_adapter", llm_adapter_lr),
    )
    param_groups = []
    for name, lr in group_specs:
        params = buckets[name]
        if lr == 0:
            # lr == 0 means freeze: disable grads, emit no optimizer group.
            for p in params:
                p.requires_grad_(False)
            logger.info(f"  Frozen {name} params ({len(params)} parameters)")
        elif len(params) > 0:
            param_groups.append({"params": params, "lr": lr})

    total_trainable = sum(p.numel() for group in param_groups for p in group["params"] if p.requires_grad)
    logger.info(f"Total trainable parameters: {total_trainable:,}")

    return param_groups
|
||
|
||
|
||
# Save functions
|
||
def save_anima_model_on_train_end(
    args: argparse.Namespace,
    save_dtype: torch.dtype,
    epoch: int,
    global_step: int,
    dit: anima_models.Anima,
):
    """Save the Anima DiT checkpoint once training has finished."""

    def saver(ckpt_file, epoch_no, global_step):
        # Build SAI model-spec metadata for the checkpoint header.
        spec = train_util.get_sai_model_spec_dataclass(
            None, args, False, False, False, is_stable_diffusion_ckpt=True, anima="preview"
        )
        # Save with 'net.' prefix for ComfyUI compatibility
        anima_utils.save_anima_model(ckpt_file, dit.state_dict(), spec.to_metadata_dict(), save_dtype)

    train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, saver, None)
|
||
|
||
|
||
def save_anima_model_on_epoch_end_or_stepwise(
    args: argparse.Namespace,
    on_epoch_end: bool,
    accelerator: Accelerator,
    save_dtype: torch.dtype,
    epoch: int,
    num_train_epochs: int,
    global_step: int,
    dit: anima_models.Anima,
):
    """Save the Anima DiT checkpoint at an epoch boundary or a save-every-N step."""

    def saver(ckpt_file, epoch_no, global_step):
        # Build SAI model-spec metadata for the checkpoint header.
        spec = train_util.get_sai_model_spec_dataclass(
            None, args, False, False, False, is_stable_diffusion_ckpt=True, anima="preview"
        )
        anima_utils.save_anima_model(ckpt_file, dit.state_dict(), spec.to_metadata_dict(), save_dtype)

    # Shared scheduling/rotation logic lives in train_util; only the saver differs.
    train_util.save_sd_model_on_epoch_end_or_stepwise_common(
        args, on_epoch_end, accelerator, True, True, epoch, num_train_epochs, global_step, saver, None
    )
|
||
|
||
|
||
# Sampling (Euler discrete for rectified flow)
|
||
def do_sample(
    height: int,
    width: int,
    seed: Optional[int],
    dit: anima_models.Anima,
    crossattn_emb: torch.Tensor,
    steps: int,
    dtype: torch.dtype,
    device: torch.device,
    guidance_scale: float = 1.0,
    flow_shift: float = 3.0,
    neg_crossattn_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Generate a sample using Euler discrete sampling for rectified flow.

    Args:
        height, width: Output image dimensions
        seed: Random seed (None for random)
        dit: Anima model
        crossattn_emb: Cross-attention embeddings (B, N, D)
        steps: Number of sampling steps
        dtype: Compute dtype
        device: Compute device
        guidance_scale: CFG scale (1.0 = no guidance)
        flow_shift: Flow shift parameter for rectified flow
        neg_crossattn_emb: Negative cross-attention embeddings for CFG

    Returns:
        Denoised latents
    """
    # Latent shape: (1, 16, 1, H/8, W/8) for single image.
    # (Original allocated a zeros tensor just to read its .size(); use the shape directly.)
    latent_h = height // 8
    latent_w = width // 8
    latent_shape = (1, 16, 1, latent_h, latent_w)

    # Generate noise on CPU so results are reproducible across devices.
    generator = torch.manual_seed(seed) if seed is not None else None
    noise = torch.randn(latent_shape, dtype=torch.float32, generator=generator, device="cpu").to(dtype).to(device)

    # Timestep schedule: linear from 1.0 to 0.0
    sigmas = torch.linspace(1.0, 0.0, steps + 1, device=device, dtype=dtype)
    flow_shift = float(flow_shift)
    if flow_shift != 1.0:
        sigmas = (sigmas * flow_shift) / (1 + (flow_shift - 1) * sigmas)

    # Start from pure noise. No clone needed: `x` is only rebound below, never mutated in place.
    x = noise

    # Padding mask (zeros = no padding) — resized in prepare_embedded_sequence to match latent dims
    padding_mask = torch.zeros(1, 1, latent_h, latent_w, dtype=dtype, device=device)

    use_cfg = guidance_scale > 1.0 and neg_crossattn_emb is not None

    for i in tqdm(range(steps), desc="Sampling"):
        sigma = sigmas[i]
        t = sigma.unsqueeze(0)  # (1,)

        if use_cfg:
            # CFG: two separate passes to reduce memory usage
            pos_out = dit(x, t, crossattn_emb, padding_mask=padding_mask).float()
            neg_out = dit(x, t, neg_crossattn_emb, padding_mask=padding_mask).float()
            model_output = neg_out + guidance_scale * (pos_out - neg_out)
        else:
            model_output = dit(x, t, crossattn_emb, padding_mask=padding_mask).float()

        # Euler step: dt is negative (sigmas decrease), stepping toward sigma=0.
        dt = sigmas[i + 1] - sigma
        x = x + model_output * dt
        x = x.to(dtype)

    return x
|
||
|
||
|
||
def sample_images(
    accelerator: Accelerator,
    args: argparse.Namespace,
    epoch,
    steps,
    dit: anima_models.Anima,
    vae,
    text_encoder,
    tokenize_strategy,
    text_encoding_strategy,
    sample_prompts_te_outputs=None,
    prompt_replacement=None,
):
    """Generate sample images during training.

    This is a simplified sampler for Anima - it generates images using the current model state.

    Scheduling: at step 0 samples only when --sample_at_first is set; otherwise
    honors --sample_every_n_epochs (epoch-boundary calls, where `epoch` is not
    None) or --sample_every_n_steps (step calls). Saves images to
    <output_dir>/sample and restores RNG state afterwards.
    """
    if steps == 0:
        if not args.sample_at_first:
            return
    else:
        if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
            return
        if args.sample_every_n_epochs is not None:
            # Epoch-based sampling takes precedence over step-based sampling.
            if epoch is None or epoch % args.sample_every_n_epochs != 0:
                return
        else:
            # `epoch is not None` marks an epoch-end call; step-based sampling skips those
            # so a step isn't sampled twice.
            if steps % args.sample_every_n_steps != 0 or epoch is not None:
                return

    logger.info(f"Generating sample images at step {steps}")
    # A missing prompt file is tolerated if pre-computed TE outputs were supplied.
    if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None:
        logger.error(f"No prompt file: {args.sample_prompts}")
        return

    # Unwrap models from the accelerate wrappers before direct method calls.
    dit = accelerator.unwrap_model(dit)
    if text_encoder is not None:
        text_encoder = accelerator.unwrap_model(text_encoder)

    # Switch the block-swap offloader into inference mode; reverted below.
    dit.switch_block_swap_for_inference()

    prompts = train_util.load_prompts(args.sample_prompts)
    save_dir = os.path.join(args.output_dir, "sample")
    os.makedirs(save_dir, exist_ok=True)

    # Save RNG state so per-prompt seeding doesn't perturb training randomness.
    rng_state = torch.get_rng_state()
    cuda_rng_state = None
    try:
        # Best-effort: CUDA RNG state capture may fail on some setups.
        cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
    except Exception:
        pass

    with torch.no_grad(), accelerator.autocast():
        for prompt_dict in prompts:
            # Reset block-swap bookkeeping before each prompt's forward passes.
            dit.prepare_block_swap_before_forward()
            _sample_image_inference(
                accelerator,
                args,
                dit,
                text_encoder,
                vae,
                tokenize_strategy,
                text_encoding_strategy,
                save_dir,
                prompt_dict,
                epoch,
                steps,
                sample_prompts_te_outputs,
                prompt_replacement,
            )

    # Restore RNG state
    torch.set_rng_state(rng_state)
    if cuda_rng_state is not None:
        torch.cuda.set_rng_state(cuda_rng_state)

    # Return offloader to training mode and free scratch memory.
    dit.switch_block_swap_for_training()
    clean_memory_on_device(accelerator.device)
|
||
|
||
|
||
def _sample_image_inference(
    accelerator,
    args,
    dit,
    text_encoder,
    vae: qwen_image_autoencoder_kl.AutoencoderKLQwenImage,
    tokenize_strategy,
    text_encoding_strategy,
    save_dir,
    prompt_dict,
    epoch,
    steps,
    sample_prompts_te_outputs,
    prompt_replacement,
):
    """Generate a single sample image for one prompt dict and save it as PNG.

    Reads prompt/size/steps/scale/seed/flow_shift from `prompt_dict`, encodes the
    prompt (using cached TE outputs when available), runs `do_sample`, decodes
    via the VAE, and writes the image to `save_dir` (also logging to wandb if a
    wandb tracker is active).
    """
    prompt = prompt_dict.get("prompt", "")
    negative_prompt = prompt_dict.get("negative_prompt", "")
    sample_steps = prompt_dict.get("sample_steps", 30)
    width = prompt_dict.get("width", 512)
    height = prompt_dict.get("height", 512)
    scale = prompt_dict.get("scale", 7.5)
    seed = prompt_dict.get("seed")
    flow_shift = prompt_dict.get("flow_shift", 3.0)

    # Apply (search, replace) substitution to both prompts if requested.
    if prompt_replacement is not None:
        prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
        if negative_prompt:
            negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])

    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # seed all CUDA devices for multi-GPU

    # Snap dimensions down to multiples of 16 (minimum 64).
    height = max(64, height - height % 16)
    width = max(64, width - width % 16)

    logger.info(
        f"  prompt: {prompt}, size: {width}x{height}, steps: {sample_steps}, scale: {scale}, flow_shift: {flow_shift}, seed: {seed}"
    )

    # Encode prompt
    def encode_prompt(prpt):
        # Prefer pre-computed text-encoder outputs (text encoder may be offloaded).
        if sample_prompts_te_outputs and prpt in sample_prompts_te_outputs:
            return sample_prompts_te_outputs[prpt]
        if text_encoder is not None:
            tokens = tokenize_strategy.tokenize(prpt)
            encoded = text_encoding_strategy.encode_tokens(tokenize_strategy, [text_encoder], tokens)
            return encoded
        # No cache hit and no live encoder: caller skips this sample.
        return None

    encoded = encode_prompt(prompt)
    if encoded is None:
        logger.warning("Cannot encode prompt, skipping sample")
        return

    # Expected 4-tuple layout: (Qwen3 embeddings, Qwen3 attn mask, T5 token ids, T5 attn mask)
    # — TODO confirm against the encoding strategy's return contract.
    prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = encoded

    # Convert to tensors if numpy (cached outputs are stored without a batch dim)
    if isinstance(prompt_embeds, np.ndarray):
        prompt_embeds = torch.from_numpy(prompt_embeds).unsqueeze(0)
        attn_mask = torch.from_numpy(attn_mask).unsqueeze(0)
        t5_input_ids = torch.from_numpy(t5_input_ids).unsqueeze(0)
        t5_attn_mask = torch.from_numpy(t5_attn_mask).unsqueeze(0)

    prompt_embeds = prompt_embeds.to(accelerator.device, dtype=dit.dtype)
    attn_mask = attn_mask.to(accelerator.device)
    t5_input_ids = t5_input_ids.to(accelerator.device, dtype=torch.long)
    t5_attn_mask = t5_attn_mask.to(accelerator.device)

    # Process through LLM adapter if available
    if dit.use_llm_adapter:
        crossattn_emb = dit.llm_adapter(
            source_hidden_states=prompt_embeds,
            target_input_ids=t5_input_ids,
            target_attention_mask=t5_attn_mask,
            source_attention_mask=attn_mask,
        )
        # Zero out embeddings at padded target positions (in place).
        crossattn_emb[~t5_attn_mask.bool()] = 0
    else:
        crossattn_emb = prompt_embeds

    # Encode negative prompt for CFG
    neg_crossattn_emb = None
    if scale > 1.0 and negative_prompt is not None:
        neg_encoded = encode_prompt(negative_prompt)
        if neg_encoded is not None:
            # Mirror of the positive-prompt path above.
            neg_pe, neg_am, neg_t5_ids, neg_t5_am = neg_encoded
            if isinstance(neg_pe, np.ndarray):
                neg_pe = torch.from_numpy(neg_pe).unsqueeze(0)
                neg_am = torch.from_numpy(neg_am).unsqueeze(0)
                neg_t5_ids = torch.from_numpy(neg_t5_ids).unsqueeze(0)
                neg_t5_am = torch.from_numpy(neg_t5_am).unsqueeze(0)

            neg_pe = neg_pe.to(accelerator.device, dtype=dit.dtype)
            neg_am = neg_am.to(accelerator.device)
            neg_t5_ids = neg_t5_ids.to(accelerator.device, dtype=torch.long)
            neg_t5_am = neg_t5_am.to(accelerator.device)

            if dit.use_llm_adapter:
                neg_crossattn_emb = dit.llm_adapter(
                    source_hidden_states=neg_pe,
                    target_input_ids=neg_t5_ids,
                    target_attention_mask=neg_t5_am,
                    source_attention_mask=neg_am,
                )
                neg_crossattn_emb[~neg_t5_am.bool()] = 0
            else:
                neg_crossattn_emb = neg_pe

    # Generate sample
    clean_memory_on_device(accelerator.device)
    latents = do_sample(
        height, width, seed, dit, crossattn_emb, sample_steps, dit.dtype, accelerator.device, scale, flow_shift, neg_crossattn_emb
    )

    # Decode latents — move the VAE onto the compute device only for the decode,
    # then back, to keep peak memory low.
    gc.collect()
    synchronize_device(accelerator.device)
    clean_memory_on_device(accelerator.device)
    org_vae_device = vae.device
    vae.to(accelerator.device)
    decoded = vae.decode_to_pixels(latents)
    vae.to(org_vae_device)
    clean_memory_on_device(accelerator.device)

    # Convert to image: map [-1, 1] pixels to [0, 1] and drop the batch dim.
    image = decoded.float()
    image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0]
    # Remove temporal dim if present
    if image.ndim == 4:
        image = image[:, 0, :, :]
    # CHW float -> HWC uint8
    decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2)
    decoded_np = decoded_np.astype(np.uint8)

    image = Image.fromarray(decoded_np)

    # Filename: <output_name>_<e{epoch}|{steps}>_<enum>_<timestamp><_seed>.png
    ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
    num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
    seed_suffix = "" if seed is None else f"_{seed}"
    i = prompt_dict.get("enum", 0)
    img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
    image.save(os.path.join(save_dir, img_filename))

    # Log to wandb if enabled
    if "wandb" in [tracker.name for tracker in accelerator.trackers]:
        wandb_tracker = accelerator.get_tracker("wandb")
        import wandb

        # commit=False: let the training loop flush this with the next metrics log.
        wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False)
|