Kohya-ss-sd-scripts/tests/manual_test_anima_cache.py
Kohya S. 34e7138b6a Add/modify some implementation for anima (#2261)
* fix: update extend-exclude list in _typos.toml to include configs

* fix: exclude anima tests from pytest

* feat: add entry for 'temperal' in extend-words section of _typos.toml for Qwen-Image VAE

* fix: update default value for --discrete_flow_shift in anima training guide

* feat: add Qwen-Image VAE

* feat: simplify encode_tokens

* feat: use unified attention module, add wrapper for state dict compatibility

* feat: loading with dynamic fp8 optimization and LoRA support

* feat: add anima minimal inference script (WIP)

* format: format

* feat: simplify target module selection by regular expression patterns

* feat: keep caption dropout rate in cache and handle it in training script

* feat: update train_llm_adapter and verbose default values to string type

* fix: use strategy instead of using tokenizers directly

* feat: add dtype property and all-zero mask handling in cross-attention in LLMAdapterTransformerBlock

* feat: support 5d tensor in get_noisy_model_input_and_timesteps

* feat: update loss calculation to support 5d tensor

* fix: update argument names in anima_train_utils to align with other architectures

* feat: simplify Anima training script and update empty caption handling

* feat: support LoRA format without `net.` prefix

* fix: make the fp8_scaled option work

* feat: add regex-based learning rates and dimensions handling in create_network

* fix: improve regex matching for module selection and learning rates in LoRANetwork

* fix: update logging message for regex match in LoRANetwork

* fix: keep latents 4D except DiT call

* feat: enhance block swap functionality for inference and training in Anima model

* feat: refactor Anima training script

* feat: optimize VAE processing by adjusting tensor dimensions and data types

* fix: wait for all block transfers before switching offloader mode

* feat: update Anima training guide with new argument specifications and regex-based module selection. Thank you Claude!

* feat: support LoRA for Qwen3

* feat: update Anima SAI model spec metadata handling

* fix: remove unused code

* feat: split CFG processing in do_sample function to reduce memory usage

* feat: add VAE chunking and caching options to reduce memory usage

* feat: optimize RMSNorm forward method and remove unused torch_attention_op

* Update library/strategy_anima.py

Use torch.all instead of all.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update library/safetensors_utils.py

Fix duplicated new_key for concat_hook.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update anima_minimal_inference.py

Remove unused code.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update anima_train.py

Remove unused import.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update library/anima_train_utils.py

Remove unused import.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* fix: review with Copilot

* feat: add script to convert LoRA format to ComfyUI compatible format (WIP, not tested yet)

* feat: add process_escape function to handle escape sequences in prompts

* feat: enhance LoRA weight handling in model loading and add text encoder loading function

* feat: improve ComfyUI conversion script with prefix constants and module name adjustments

* feat: update caption dropout documentation to clarify cache regeneration requirement

* feat: add clarification on learning rate adjustments

* feat: add note on PyTorch version requirement to prevent NaN loss

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-02-13 08:15:06 +09:00


"""
Diagnostic script to test Anima latent & text encoder caching independently.
Usage:
python manual_test_anima_cache.py \
--image_dir /path/to/images \
--qwen3_path /path/to/qwen3 \
--vae_path /path/to/vae.safetensors \
[--t5_tokenizer_path /path/to/t5] \
[--cache_to_disk]
The image_dir should contain pairs of:
image1.png + image1.txt
image2.jpg + image2.txt
...
"""
import argparse
import glob
import os
import sys
import traceback
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
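# NOTE: the heavier repo modules (library.anima_utils, library.strategy_anima,
# the VAE loader) are imported lazily inside each test function, so a broken
# component fails only its own test instead of the whole script.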
# Helpers
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".tiff"}
IMAGE_TRANSFORMS = transforms.Compose(
[
transforms.ToTensor(), # [0,1]
transforms.Normalize([0.5], [0.5]), # [-1,1]
]
)
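# ToTensor maps uint8 HWC [0, 255] to float CHW [0, 1]; Normalize then shifts
# to [-1, 1], mirroring the preprocessing used in sd-scripts training.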
def find_image_caption_pairs(image_dir: str):
"""Find (image_path, caption_text) pairs from a directory."""
pairs = []
for f in sorted(os.listdir(image_dir)):
ext = os.path.splitext(f)[1].lower()
if ext not in IMAGE_EXTENSIONS:
continue
img_path = os.path.join(image_dir, f)
txt_path = os.path.splitext(img_path)[0] + ".txt"
if os.path.exists(txt_path):
with open(txt_path, "r", encoding="utf-8") as fh:
caption = fh.read().strip()
else:
caption = ""
pairs.append((img_path, caption))
return pairs
def print_tensor_info(name: str, t, indent=2):
prefix = " " * indent
if t is None:
print(f"{prefix}{name}: None")
return
if isinstance(t, np.ndarray):
print(f"{prefix}{name}: numpy {t.dtype} shape={t.shape} " f"min={t.min():.4f} max={t.max():.4f} mean={t.mean():.4f}")
elif isinstance(t, torch.Tensor):
print(
f"{prefix}{name}: torch {t.dtype} shape={tuple(t.shape)} "
f"min={t.min().item():.4f} max={t.max().item():.4f} mean={t.float().mean().item():.4f}"
)
else:
print(f"{prefix}{name}: type={type(t)} value={t}")
# Test 1: Latent Cache
def test_latent_cache(args, pairs):
print("\n" + "=" * 70)
print("TEST 1: LATENT CACHING (VAE encode -> cache -> reload)")
print("=" * 70)
from library import qwen_image_autoencoder_kl
# Load VAE
print("\n[1.1] Loading VAE...")
device = "cuda" if torch.cuda.is_available() else "cpu"
vae_dtype = torch.float32
vae = qwen_image_autoencoder_kl.load_vae(args.vae_path, dtype=vae_dtype, device=device)
print(f" VAE loaded on {device}, dtype={vae_dtype}")
for img_path, caption in pairs:
print(f"\n[1.2] Processing: {os.path.basename(img_path)}")
# Load image
img = Image.open(img_path).convert("RGB")
img_np = np.array(img)
print(f" Raw image: {img_np.shape} dtype={img_np.dtype} " f"min={img_np.min()} max={img_np.max()}")
# Apply IMAGE_TRANSFORMS (same as sd-scripts training)
img_tensor = IMAGE_TRANSFORMS(img_np)
print(
f" After IMAGE_TRANSFORMS: shape={tuple(img_tensor.shape)} " f"min={img_tensor.min():.4f} max={img_tensor.max():.4f}"
)
# Check range is [-1, 1]
if img_tensor.min() < -1.01 or img_tensor.max() > 1.01:
print(" ** WARNING: tensor out of [-1, 1] range!")
else:
print(" OK: tensor in [-1, 1] range")
# Encode with VAE
img_batch = img_tensor.unsqueeze(0).to(device, dtype=vae_dtype) # (1, C, H, W)
img_5d = img_batch.unsqueeze(2) # (1, C, 1, H, W) - add temporal dim
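        # The Wan-derived VAE expects video-shaped input (B, C, T, H, W); a
        # still image is treated as a single-frame clip, hence T=1.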
print(f" VAE input: shape={tuple(img_5d.shape)} dtype={img_5d.dtype}")
with torch.no_grad():
latents = vae.encode_pixels_to_latents(img_5d)
latents_cpu = latents.cpu()
print_tensor_info("Encoded latents", latents_cpu)
# Check for NaN/Inf
if torch.any(torch.isnan(latents_cpu)):
print(" ** ERROR: NaN in latents!")
elif torch.any(torch.isinf(latents_cpu)):
print(" ** ERROR: Inf in latents!")
else:
print(" OK: no NaN/Inf")
# Test disk cache round-trip
if args.cache_to_disk:
npz_path = os.path.splitext(img_path)[0] + "_test_latent.npz"
latents_np = latents_cpu.float().numpy()
h, w = img_np.shape[:2]
np.savez(
npz_path,
latents=latents_np,
original_size=np.array([w, h]),
crop_ltrb=np.array([0, 0, 0, 0]),
)
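            # The key names (latents / original_size / crop_ltrb) are assumed to
            # match the training-time latent cache layout; only round-trip
            # fidelity is actually verified here.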
print(f" Saved to: {npz_path}")
# Reload
loaded = np.load(npz_path)
loaded_latents = loaded["latents"]
print_tensor_info("Reloaded latents", loaded_latents)
# Compare
diff = np.abs(latents_np - loaded_latents).max()
print(f" Max diff (save vs load): {diff:.2e}")
if diff > 1e-5:
print(" ** WARNING: latent cache round-trip has significant diff!")
else:
print(" OK: round-trip matches")
os.remove(npz_path)
print(f" Cleaned up {npz_path}")
vae.to("cpu")
del vae
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
print("\n[1.3] Latent cache test DONE.")
# Test 2: Text Encoder Output Cache
def test_text_encoder_cache(args, pairs):
print("\n" + "=" * 70)
print("TEST 2: TEXT ENCODER OUTPUT CACHING")
print("=" * 70)
from library import anima_utils
# Load tokenizers
print("\n[2.1] Loading tokenizers...")
qwen3_tokenizer = anima_utils.load_qwen3_tokenizer(args.qwen3_path)
t5_tokenizer = anima_utils.load_t5_tokenizer(getattr(args, "t5_tokenizer_path", None))
print(f" Qwen3 tokenizer vocab: {qwen3_tokenizer.vocab_size}")
print(f" T5 tokenizer vocab: {t5_tokenizer.vocab_size}")
# Load text encoder
print("\n[2.2] Loading Qwen3 text encoder...")
device = "cuda" if torch.cuda.is_available() else "cpu"
te_dtype = torch.bfloat16 if device == "cuda" else torch.float32
qwen3_model, _ = anima_utils.load_qwen3_text_encoder(args.qwen3_path, dtype=te_dtype, device=device)
qwen3_model.eval()
# Create strategy objects
from library.strategy_anima import AnimaTokenizeStrategy, AnimaTextEncodingStrategy
tokenize_strategy = AnimaTokenizeStrategy(
qwen3_tokenizer=qwen3_tokenizer,
t5_tokenizer=t5_tokenizer,
qwen3_max_length=args.qwen3_max_length,
t5_max_length=args.t5_max_length,
)
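    # Anima conditions on two text pipelines: Qwen3 ids are encoded into
    # embeddings below, while the T5 ids/mask appear to pass through
    # encode_tokens unchanged for later consumption by the model (see the
    # unpack in step 2.4).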
text_encoding_strategy = AnimaTextEncodingStrategy()
captions = [cap for _, cap in pairs]
print(f"\n[2.3] Tokenizing {len(captions)} captions...")
for i, cap in enumerate(captions):
print(f" [{i}] \"{cap[:80]}{'...' if len(cap) > 80 else ''}\"")
tokens_and_masks = tokenize_strategy.tokenize(captions)
qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask = tokens_and_masks
print(f"\n Tokenization results:")
print_tensor_info("qwen3_input_ids", qwen3_input_ids)
print_tensor_info("qwen3_attn_mask", qwen3_attn_mask)
print_tensor_info("t5_input_ids", t5_input_ids)
print_tensor_info("t5_attn_mask", t5_attn_mask)
# Encode
print(f"\n[2.4] Encoding with Qwen3 text encoder...")
with torch.no_grad():
prompt_embeds, attn_mask, t5_ids_out, t5_mask_out = text_encoding_strategy.encode_tokens(
tokenize_strategy, [qwen3_model], tokens_and_masks
)
print(f" Encoding results:")
print_tensor_info("prompt_embeds", prompt_embeds)
print_tensor_info("attn_mask", attn_mask)
print_tensor_info("t5_input_ids", t5_ids_out)
print_tensor_info("t5_attn_mask", t5_mask_out)
# Check for NaN/Inf
if torch.any(torch.isnan(prompt_embeds)):
print(" ** ERROR: NaN in prompt_embeds!")
elif torch.any(torch.isinf(prompt_embeds)):
print(" ** ERROR: Inf in prompt_embeds!")
else:
print(" OK: no NaN/Inf in prompt_embeds")
# Test cache round-trip (simulate what AnimaTextEncoderOutputsCachingStrategy does)
print(f"\n[2.5] Testing cache round-trip (encode -> numpy -> npz -> reload -> tensor)...")
# Convert to numpy (same as cache_batch_outputs in strategy_anima.py)
pe_cpu = prompt_embeds.cpu()
if pe_cpu.dtype == torch.bfloat16:
pe_cpu = pe_cpu.float()
pe_np = pe_cpu.numpy()
am_np = attn_mask.cpu().numpy()
t5_ids_np = t5_ids_out.cpu().numpy().astype(np.int32)
t5_mask_np = t5_mask_out.cpu().numpy().astype(np.int32)
print(f" Numpy conversions:")
print_tensor_info("prompt_embeds_np", pe_np)
print_tensor_info("attn_mask_np", am_np)
print_tensor_info("t5_input_ids_np", t5_ids_np)
print_tensor_info("t5_attn_mask_np", t5_mask_np)
    if args.cache_to_disk:
        # Save per-sample (simulating cache_batch_outputs)
for i in range(len(captions)):
sample_npz = os.path.splitext(pairs[i][0])[0] + "_test_te.npz"
np.savez(
sample_npz,
prompt_embeds=pe_np[i],
attn_mask=am_np[i],
t5_input_ids=t5_ids_np[i],
t5_attn_mask=t5_mask_np[i],
)
print(f" Saved: {sample_npz}")
# Reload (simulating load_outputs_npz)
data = np.load(sample_npz)
print(f" Reloaded keys: {list(data.keys())}")
print_tensor_info(" loaded prompt_embeds", data["prompt_embeds"], indent=4)
print_tensor_info(" loaded attn_mask", data["attn_mask"], indent=4)
print_tensor_info(" loaded t5_input_ids", data["t5_input_ids"], indent=4)
print_tensor_info(" loaded t5_attn_mask", data["t5_attn_mask"], indent=4)
# Check diff
diff_pe = np.abs(pe_np[i] - data["prompt_embeds"]).max()
diff_t5 = np.abs(t5_ids_np[i] - data["t5_input_ids"]).max()
print(f" Max diff prompt_embeds: {diff_pe:.2e}")
print(f" Max diff t5_input_ids: {diff_t5:.2e}")
if diff_pe > 1e-5 or diff_t5 > 0:
print(" ** WARNING: cache round-trip mismatch!")
else:
print(" OK: round-trip matches")
os.remove(sample_npz)
print(f" Cleaned up {sample_npz}")
# Test in-memory cache round-trip (simulating what __getitem__ does)
print(f"\n[2.6] Testing in-memory cache simulation (tuple -> none_or_stack_elements -> batch)...")
# Simulate per-sample storage (like info.text_encoder_outputs = tuple)
per_sample_cached = []
for i in range(len(captions)):
per_sample_cached.append((pe_np[i], am_np[i], t5_ids_np[i], t5_mask_np[i]))
# Simulate none_or_stack_elements with torch.FloatTensor converter
# This is what train_util.py __getitem__ does at line 1784
stacked = []
for elem_idx in range(4):
arrays = [sample[elem_idx] for sample in per_sample_cached]
stacked.append(torch.stack([torch.FloatTensor(a) for a in arrays]))
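    # NOTE: torch.FloatTensor silently casts the integer id/mask arrays to
    # float32 here; steps 2.7 and 3.6 verify that this cast is harmless.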
print(f" Stacked batch (like batch['text_encoder_outputs_list']):")
names = ["prompt_embeds", "attn_mask", "t5_input_ids", "t5_attn_mask"]
for name, tensor in zip(names, stacked):
print_tensor_info(name, tensor)
# Check condition: len(text_encoder_conds) == 0 or text_encoder_conds[0] is None
text_encoder_conds = stacked
cond_check_1 = len(text_encoder_conds) == 0
cond_check_2 = text_encoder_conds[0] is None
print(f"\n Condition check (should both be False when caching works):")
print(f" len(text_encoder_conds) == 0 : {cond_check_1}")
print(f" text_encoder_conds[0] is None: {cond_check_2}")
if not cond_check_1 and not cond_check_2:
print(" OK: cached text encoder outputs would be used")
else:
print(" ** BUG: code would try to re-encode (and crash on None input_ids_list)!")
# Test unpack for get_noise_pred_and_target (line 311)
print(f"\n[2.7] Testing unpack: prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoder_conds")
try:
pe_batch, am_batch, t5_ids_batch, t5_mask_batch = text_encoder_conds
print(f" Unpack OK")
print_tensor_info("prompt_embeds", pe_batch)
print_tensor_info("attn_mask", am_batch)
print_tensor_info("t5_input_ids", t5_ids_batch)
print_tensor_info("t5_attn_mask", t5_mask_batch)
# Check t5_input_ids are integers (they were converted to FloatTensor!)
if t5_ids_batch.dtype != torch.long and t5_ids_batch.dtype != torch.int32:
print(f"\n ** NOTE: t5_input_ids dtype is {t5_ids_batch.dtype}, will be cast to long at line 316")
t5_ids_long = t5_ids_batch.to(dtype=torch.long)
# Check if any precision was lost
diff = (t5_ids_batch - t5_ids_long.float()).abs().max()
print(f" Float->Long precision loss: {diff:.2e}")
if diff > 0.5:
print(" ** ERROR: token IDs corrupted by float conversion!")
else:
print(" OK: float->long conversion is lossless for these IDs")
except Exception as e:
print(f" ** ERROR unpacking: {e}")
traceback.print_exc()
# Test drop_cached_text_encoder_outputs
print(f"\n[2.8] Testing drop_cached_text_encoder_outputs (caption dropout)...")
dropout_strategy = AnimaTextEncodingStrategy(
dropout_rate=0.5, # high rate to ensure some drops
)
dropped = dropout_strategy.drop_cached_text_encoder_outputs(*stacked)
print(f" Returned {len(dropped)} tensors")
for name, tensor in zip(names, dropped):
print_tensor_info(f"dropped_{name}", tensor)
# Check which items were dropped
for i in range(len(captions)):
is_zero = (dropped[0][i].abs().sum() == 0).item()
print(f" Sample {i}: {'DROPPED' if is_zero else 'KEPT'}")
qwen3_model.to("cpu")
del qwen3_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
print("\n[2.8] Text encoder cache test DONE.")
# Test 3: Full batch simulation
def test_full_batch_simulation(args, pairs):
print("\n" + "=" * 70)
print("TEST 3: FULL BATCH SIMULATION (mimics process_batch flow)")
print("=" * 70)
from library import anima_utils
from library.strategy_anima import AnimaTokenizeStrategy, AnimaTextEncodingStrategy
device = "cuda" if torch.cuda.is_available() else "cpu"
te_dtype = torch.bfloat16 if device == "cuda" else torch.float32
vae_dtype = torch.float32
# Load all models
print("\n[3.1] Loading models...")
qwen3_tokenizer = anima_utils.load_qwen3_tokenizer(args.qwen3_path)
t5_tokenizer = anima_utils.load_t5_tokenizer(getattr(args, "t5_tokenizer_path", None))
qwen3_model, _ = anima_utils.load_qwen3_text_encoder(args.qwen3_path, dtype=te_dtype, device=device)
qwen3_model.eval()
vae, _, _, vae_scale = anima_utils.load_anima_vae(args.vae_path, dtype=vae_dtype, device=device)
tokenize_strategy = AnimaTokenizeStrategy(
qwen3_tokenizer=qwen3_tokenizer,
t5_tokenizer=t5_tokenizer,
qwen3_max_length=args.qwen3_max_length,
t5_max_length=args.t5_max_length,
)
text_encoding_strategy = AnimaTextEncodingStrategy(dropout_rate=0.0)
captions = [cap for _, cap in pairs]
# --- Simulate caching phase ---
print("\n[3.2] Simulating text encoder caching phase...")
tokens_and_masks = tokenize_strategy.tokenize(captions)
with torch.no_grad():
te_outputs = text_encoding_strategy.encode_tokens(
tokenize_strategy,
[qwen3_model],
tokens_and_masks,
enable_dropout=False,
)
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = te_outputs
# Convert to numpy (same as cache_batch_outputs)
pe_np = prompt_embeds.cpu().float().numpy()
am_np = attn_mask.cpu().numpy()
t5_ids_np = t5_input_ids.cpu().numpy().astype(np.int32)
t5_mask_np = t5_attn_mask.cpu().numpy().astype(np.int32)
# Per-sample storage (like info.text_encoder_outputs)
per_sample_te = [(pe_np[i], am_np[i], t5_ids_np[i], t5_mask_np[i]) for i in range(len(captions))]
print(f"\n[3.3] Simulating latent caching phase...")
per_sample_latents = []
for img_path, _ in pairs:
img = Image.open(img_path).convert("RGB")
img_np = np.array(img)
img_tensor = IMAGE_TRANSFORMS(img_np).unsqueeze(0).unsqueeze(2) # (1,C,1,H,W)
img_tensor = img_tensor.to(device, dtype=vae_dtype)
with torch.no_grad():
lat = vae.encode(img_tensor, vae_scale).cpu()
per_sample_latents.append(lat.squeeze(0)) # (C,1,H,W)
print(f" {os.path.basename(img_path)}: latent shape={tuple(lat.shape)}")
# --- Simulate batch construction (__getitem__) ---
print(f"\n[3.4] Simulating batch construction...")
# Use first image's latents only (images may have different resolutions)
latents_batch = per_sample_latents[0].unsqueeze(0) # (1,C,1,H,W)
print(f" Using first image latent for simulation: shape={tuple(latents_batch.shape)}")
# Stack text encoder outputs (none_or_stack_elements)
text_encoder_outputs_list = []
for elem_idx in range(4):
arrays = [s[elem_idx] for s in per_sample_te]
text_encoder_outputs_list.append(torch.stack([torch.FloatTensor(a) for a in arrays]))
# input_ids_list is None when caching
input_ids_list = None
batch = {
"latents": latents_batch,
"text_encoder_outputs_list": text_encoder_outputs_list,
"input_ids_list": input_ids_list,
"loss_weights": torch.ones(len(captions)),
}
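    # This dict stands in for the collated batch that process_batch receives;
    # input_ids_list stays None because tokenization is skipped once text
    # encoder outputs are cached.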
print(f" batch keys: {list(batch.keys())}")
print(f" batch['latents']: shape={tuple(batch['latents'].shape)}")
print(f" batch['text_encoder_outputs_list']: {len(batch['text_encoder_outputs_list'])} tensors")
print(f" batch['input_ids_list']: {batch['input_ids_list']}")
# --- Simulate process_batch logic ---
print(f"\n[3.5] Simulating process_batch logic...")
text_encoder_conds = []
te_out = batch.get("text_encoder_outputs_list", None)
if te_out is not None:
text_encoder_conds = te_out
print(f" text_encoder_conds loaded from cache: {len(text_encoder_conds)} tensors")
else:
print(f" text_encoder_conds: empty (no cache)")
# The critical condition
train_text_encoder_TRUE = True # OLD behavior (base class default, no override)
train_text_encoder_FALSE = False # NEW behavior (with is_train_text_encoder override)
cond_old = len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder_TRUE
cond_new = len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder_FALSE
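    # With the base-class default (train_text_encoder=True) the or-chain is
    # always True, so process_batch re-enters the encode path and crashes on the
    # None input_ids_list; the is_train_text_encoder override returns False,
    # letting the cached outputs be used.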
print(f"\n === CRITICAL CONDITION CHECK ===")
print(f" len(text_encoder_conds) == 0 : {len(text_encoder_conds) == 0}")
print(f" text_encoder_conds[0] is None: {text_encoder_conds[0] is None}")
print(f" train_text_encoder (OLD=True) : {train_text_encoder_TRUE}")
print(f" train_text_encoder (NEW=False): {train_text_encoder_FALSE}")
print(f"")
print(f" Condition with OLD behavior (no override): {cond_old}")
msg = (
"ENTERS re-encode block -> accesses batch['input_ids_list'] -> CRASH!"
if cond_old
else "SKIPS re-encode block -> uses cache -> OK"
)
print(f" -> {msg}")
print(f" Condition with NEW behavior (override): {cond_new}")
print(f" -> {'ENTERS re-encode block' if cond_new else 'SKIPS re-encode block -> uses cache -> OK'}")
if cond_old and not cond_new:
print(f"\n ** CONFIRMED: the is_train_text_encoder override fixes the crash **")
# Simulate the rest of process_batch
print(f"\n[3.6] Simulating get_noise_pred_and_target unpack...")
try:
pe, am, t5_ids, t5_mask = text_encoder_conds
pe = pe.to(device, dtype=te_dtype)
am = am.to(device)
t5_ids = t5_ids.to(device, dtype=torch.long)
t5_mask = t5_mask.to(device)
print(f" Unpack + device transfer OK:")
print_tensor_info("prompt_embeds", pe)
print_tensor_info("attn_mask", am)
print_tensor_info("t5_input_ids", t5_ids)
print_tensor_info("t5_attn_mask", t5_mask)
# Verify t5_input_ids didn't get corrupted by float conversion
t5_ids_orig = torch.tensor(t5_ids_np, dtype=torch.long, device=device)
id_match = torch.all(t5_ids == t5_ids_orig).item()
print(f"\n t5_input_ids integrity (float->long roundtrip): {'OK' if id_match else '** MISMATCH **'}")
if not id_match:
diff_count = (t5_ids != t5_ids_orig).sum().item()
print(f" {diff_count} token IDs differ!")
# Show example
idx = torch.where(t5_ids != t5_ids_orig)
if len(idx[0]) > 0:
i, j = idx[0][0].item(), idx[1][0].item()
print(f" Example: position [{i},{j}] original={t5_ids_orig[i,j].item()} loaded={t5_ids[i,j].item()}")
except Exception as e:
print(f" ** ERROR: {e}")
traceback.print_exc()
# Cleanup
vae.to("cpu")
qwen3_model.to("cpu")
del vae, qwen3_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
print("\n[3.7] Full batch simulation DONE.")
# Main
def main():
parser = argparse.ArgumentParser(description="Test Anima caching mechanisms")
parser.add_argument("--image_dir", type=str, required=True, help="Directory with image+txt pairs")
parser.add_argument("--qwen3_path", type=str, required=True, help="Path to Qwen3 model (directory or safetensors)")
parser.add_argument("--vae_path", type=str, required=True, help="Path to WanVAE safetensors")
parser.add_argument("--t5_tokenizer_path", type=str, default=None, help="Path to T5 tokenizer (optional, uses bundled config)")
parser.add_argument("--qwen3_max_length", type=int, default=512)
parser.add_argument("--t5_max_length", type=int, default=512)
parser.add_argument("--cache_to_disk", action="store_true", help="Also test disk cache round-trip")
parser.add_argument("--skip_latent", action="store_true", help="Skip latent cache test")
parser.add_argument("--skip_text", action="store_true", help="Skip text encoder cache test")
parser.add_argument("--skip_full", action="store_true", help="Skip full batch simulation")
args = parser.parse_args()
# Find pairs
pairs = find_image_caption_pairs(args.image_dir)
if len(pairs) == 0:
print(f"ERROR: No image+txt pairs found in {args.image_dir}")
print("Expected: image.png + image.txt, image.jpg + image.txt, etc.")
sys.exit(1)
print(f"Found {len(pairs)} image-caption pairs:")
for img_path, cap in pairs:
print(f" {os.path.basename(img_path)}: \"{cap[:60]}{'...' if len(cap) > 60 else ''}\"")
results = {}
if not args.skip_latent:
try:
test_latent_cache(args, pairs)
results["latent_cache"] = "PASS"
except Exception as e:
print(f"\n** LATENT CACHE TEST FAILED: {e}")
traceback.print_exc()
results["latent_cache"] = f"FAIL: {e}"
if not args.skip_text:
try:
test_text_encoder_cache(args, pairs)
results["text_encoder_cache"] = "PASS"
except Exception as e:
print(f"\n** TEXT ENCODER CACHE TEST FAILED: {e}")
traceback.print_exc()
results["text_encoder_cache"] = f"FAIL: {e}"
if not args.skip_full:
try:
test_full_batch_simulation(args, pairs)
results["full_batch_sim"] = "PASS"
except Exception as e:
print(f"\n** FULL BATCH SIMULATION FAILED: {e}")
traceback.print_exc()
results["full_batch_sim"] = f"FAIL: {e}"
# Summary
print("\n" + "=" * 70)
print("SUMMARY")
print("=" * 70)
for test, result in results.items():
status = "OK" if result == "PASS" else "FAIL"
print(f" [{status}] {test}: {result}")
print()
if __name__ == "__main__":
main()