mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
* fix: update extend-exclude list in _typos.toml to include configs * fix: exclude anima tests from pytest * feat: add entry for 'temperal' in extend-words section of _typos.toml for Qwen-Image VAE * fix: update default value for --discrete_flow_shift in anima training guide * feat: add Qwen-Image VAE * feat: simplify encode_tokens * feat: use unified attention module, add wrapper for state dict compatibility * feat: loading with dynamic fp8 optimization and LoRA support * feat: add anima minimal inference script (WIP) * format: format * feat: simplify target module selection by regular expression patterns * feat: kept caption dropout rate in cache and handle in training script * feat: update train_llm_adapter and verbose default values to string type * fix: use strategy instead of using tokenizers directly * feat: add dtype property and all-zero mask handling in cross-attention in LLMAdapterTransformerBlock * feat: support 5d tensor in get_noisy_model_input_and_timesteps * feat: update loss calculation to support 5d tensor * fix: update argument names in anima_train_utils to align with other archtectures * feat: simplify Anima training script and update empty caption handling * feat: support LoRA format without `net.` prefix * fix: update to work fp8_scaled option * feat: add regex-based learning rates and dimensions handling in create_network * fix: improve regex matching for module selection and learning rates in LoRANetwork * fix: update logging message for regex match in LoRANetwork * fix: keep latents 4D except DiT call * feat: enhance block swap functionality for inference and training in Anima model * feat: refactor Anima training script * feat: optimize VAE processing by adjusting tensor dimensions and data types * fix: wait all block trasfer before siwtching offloader mode * feat: update Anima training guide with new argument specifications and regex-based module selection. Thank you Claude! 
* feat: support LORA for Qwen3 * feat: update Anima SAI model spec metadata handling * fix: remove unused code * feat: split CFG processing in do_sample function to reduce memory usage * feat: add VAE chunking and caching options to reduce memory usage * feat: optimize RMSNorm forward method and remove unused torch_attention_op * Update library/strategy_anima.py Use torch.all instead of all. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update library/safetensors_utils.py Fix duplicated new_key for concat_hook. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update anima_minimal_inference.py Remove unused code. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update anima_train.py Remove unused import. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update library/anima_train_utils.py Remove unused import. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix: review with Copilot * feat: add script to convert LoRA format to ComfyUI compatible format (WIP, not tested yet) * feat: add process_escape function to handle escape sequences in prompts * feat: enhance LoRA weight handling in model loading and add text encoder loading function * feat: improve ComfyUI conversion script with prefix constants and module name adjustments * feat: update caption dropout documentation to clarify cache regeneration requirement * feat: add clarification on learning rate adjustments * feat: add note on PyTorch version requirement to prevent NaN loss --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
608 lines
24 KiB
Python
608 lines
24 KiB
Python
"""
|
|
Diagnostic script to test Anima latent & text encoder caching independently.
|
|
|
|
Usage:
|
|
python manual_test_anima_cache.py \
|
|
--image_dir /path/to/images \
|
|
--qwen3_path /path/to/qwen3 \
|
|
--vae_path /path/to/vae.safetensors \
|
|
[--t5_tokenizer_path /path/to/t5] \
|
|
[--cache_to_disk]
|
|
|
|
The image_dir should contain pairs of:
|
|
image1.png + image1.txt
|
|
image2.jpg + image2.txt
|
|
...
|
|
"""
|
|
|
|
import argparse
|
|
import glob
|
|
import os
|
|
import sys
|
|
import traceback
|
|
|
|
import numpy as np
|
|
import torch
|
|
from PIL import Image
|
|
from torchvision import transforms
|
|
|
|
# Helpers

# Lower-case file extensions that find_image_caption_pairs accepts as images.
IMAGE_EXTENSIONS = set([".png", ".jpg", ".jpeg", ".webp", ".bmp", ".tiff"])

# Pixel preprocessing applied before VAE encoding (mirrors sd-scripts training):
#   ToTensor  -> float CHW tensor scaled to [0, 1]
#   Normalize -> (x - 0.5) / 0.5, i.e. [0, 1] mapped onto [-1, 1]
IMAGE_TRANSFORMS = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)
|
|
|
|
|
|
def find_image_caption_pairs(image_dir: str):
    """Collect the images in ``image_dir`` together with their captions.

    Only the top level of the directory is scanned, in sorted filename order.
    A file counts as an image when its extension, lower-cased, is in
    IMAGE_EXTENSIONS. The caption comes from the sibling file with the same
    stem and a ``.txt`` extension (read as UTF-8, surrounding whitespace
    stripped). Images without such a file are still returned, with an empty
    caption.

    Returns:
        A list of ``(image_path, caption)`` tuples.
    """
    found = []
    for entry in sorted(os.listdir(image_dir)):
        image_path = os.path.join(image_dir, entry)
        stem_path, extension = os.path.splitext(image_path)
        if extension.lower() not in IMAGE_EXTENSIONS:
            continue

        caption_path = stem_path + ".txt"
        caption_text = ""
        if os.path.exists(caption_path):
            with open(caption_path, "r", encoding="utf-8") as caption_file:
                caption_text = caption_file.read().strip()

        found.append((image_path, caption_text))
    return found
|
|
|
|
|
|
def print_tensor_info(name: str, t, indent=2):
    """Print a one-line summary (dtype, shape, min/max/mean) of ``t``.

    Args:
        name: Label printed before the summary.
        t: A ``numpy.ndarray``, a ``torch.Tensor``, ``None``, or any other value.
            Other values are printed with their type and ``str`` form.
        indent: Number of leading spaces.

    Empty arrays and tensors are reported as ``(empty)`` instead of raising.
    ``min``/``max`` are undefined for zero elements, and a crash here would
    abort the whole diagnostic. For torch tensors, the mean is computed in
    float32 so that integer inputs (e.g. token ids) are supported.
    """
    prefix = " " * indent
    if t is None:
        print(f"{prefix}{name}: None")
        return

    if isinstance(t, np.ndarray):
        header = f"{prefix}{name}: numpy {t.dtype} shape={t.shape}"
        if t.size == 0:
            print(f"{header} (empty)")
            return
        lo, hi, mean = t.min(), t.max(), t.mean()
    elif isinstance(t, torch.Tensor):
        header = f"{prefix}{name}: torch {t.dtype} shape={tuple(t.shape)}"
        if t.numel() == 0:
            print(f"{header} (empty)")
            return
        lo = t.min().item()
        hi = t.max().item()
        mean = t.float().mean().item()
    else:
        print(f"{prefix}{name}: type={type(t)} value={t}")
        return

    print(f"{header} min={lo:.4f} max={hi:.4f} mean={mean:.4f}")
|
|
|
|
|
|
# Test 1: Latent Cache
|
|
|
|
|
|
def test_latent_cache(args, pairs):
    """Check VAE latent encoding and, optionally, the .npz latent-cache round-trip.

    For every image in ``pairs`` (captions are ignored here) this:

    * loads the image as RGB, applies IMAGE_TRANSFORMS and checks that the
      result lies within [-1, 1] (with a 0.01 tolerance);
    * adds batch and temporal axes, (C, H, W) -> (1, C, 1, H, W), encodes with
      ``vae.encode_pixels_to_latents`` and checks the latents for NaN/Inf;
    * with ``args.cache_to_disk``, saves the latents (plus ``original_size``
      and ``crop_ltrb``) to ``<image stem>_test_latent.npz`` next to the
      image, reloads and compares them, then deletes the file.

    Args:
        args: Parsed CLI namespace; reads ``vae_path`` and ``cache_to_disk``.
        pairs: ``(image_path, caption)`` tuples from find_image_caption_pairs.

    Check results are printed rather than raised. Exceptions from loading,
    encoding or file I/O propagate to the caller. Even then, the temporary
    .npz file is deleted, the VAE is moved to CPU and the CUDA cache is
    cleared.
    """
    print("\n" + "=" * 70)
    print("TEST 1: LATENT CACHING (VAE encode -> cache -> reload)")
    print("=" * 70)

    from library import qwen_image_autoencoder_kl

    # Load VAE
    print("\n[1.1] Loading VAE...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    vae_dtype = torch.float32
    vae = qwen_image_autoencoder_kl.load_vae(args.vae_path, dtype=vae_dtype, device=device)
    print(f" VAE loaded on {device}, dtype={vae_dtype}")

    try:
        for img_path, _caption in pairs:
            print(f"\n[1.2] Processing: {os.path.basename(img_path)}")

            # Load image; the context manager closes the file handle promptly.
            with Image.open(img_path) as pil_img:
                img = pil_img.convert("RGB")
            img_np = np.array(img)
            print(f" Raw image: {img_np.shape} dtype={img_np.dtype} " f"min={img_np.min()} max={img_np.max()}")

            # Apply IMAGE_TRANSFORMS (same as sd-scripts training)
            img_tensor = IMAGE_TRANSFORMS(img_np)
            print(
                f" After IMAGE_TRANSFORMS: shape={tuple(img_tensor.shape)} " f"min={img_tensor.min():.4f} max={img_tensor.max():.4f}"
            )

            # Check range is [-1, 1] (0.01 slack for float rounding)
            if img_tensor.min() < -1.01 or img_tensor.max() > 1.01:
                print(" ** WARNING: tensor out of [-1, 1] range!")
            else:
                print(" OK: tensor in [-1, 1] range")

            # Encode with VAE
            img_batch = img_tensor.unsqueeze(0).to(device, dtype=vae_dtype)  # (1, C, H, W)
            img_5d = img_batch.unsqueeze(2)  # (1, C, 1, H, W) - add temporal dim
            print(f" VAE input: shape={tuple(img_5d.shape)} dtype={img_5d.dtype}")

            with torch.no_grad():
                latents = vae.encode_pixels_to_latents(img_5d)
            latents_cpu = latents.cpu()
            print_tensor_info("Encoded latents", latents_cpu)

            # Check for NaN/Inf
            if torch.any(torch.isnan(latents_cpu)):
                print(" ** ERROR: NaN in latents!")
            elif torch.any(torch.isinf(latents_cpu)):
                print(" ** ERROR: Inf in latents!")
            else:
                print(" OK: no NaN/Inf")

            # Test disk cache round-trip
            if args.cache_to_disk:
                npz_path = os.path.splitext(img_path)[0] + "_test_latent.npz"
                latents_np = latents_cpu.float().numpy()
                h, w = img_np.shape[:2]
                np.savez(
                    npz_path,
                    latents=latents_np,
                    original_size=np.array([w, h]),
                    crop_ltrb=np.array([0, 0, 0, 0]),
                )
                print(f" Saved to: {npz_path}")

                try:
                    # Reload. Close the NpzFile before os.remove(): Windows
                    # refuses to delete a file that is still open.
                    with np.load(npz_path) as loaded:
                        loaded_latents = loaded["latents"]
                    print_tensor_info("Reloaded latents", loaded_latents)

                    # Compare
                    diff = np.abs(latents_np - loaded_latents).max()
                    print(f" Max diff (save vs load): {diff:.2e}")
                    if diff > 1e-5:
                        print(" ** WARNING: latent cache round-trip has significant diff!")
                    else:
                        print(" OK: round-trip matches")
                finally:
                    # Always remove the temp file so it never lingers in image_dir.
                    os.remove(npz_path)
                    print(f" Cleaned up {npz_path}")
    finally:
        vae.to("cpu")
        del vae
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print("\n[1.3] Latent cache test DONE.")
|
|
|
|
|
|
# Test 2: Text Encoder Output Cache
|
|
|
|
|
|
def test_text_encoder_cache(args, pairs):
    """Check Qwen3 text-encoder outputs and the ways they are cached and re-read.

    Steps (numbered as printed):
      2.1  load the Qwen3 and T5 tokenizers;
      2.2  load the Qwen3 text encoder;
      2.3  tokenize every caption with AnimaTokenizeStrategy;
      2.4  encode with AnimaTextEncodingStrategy.encode_tokens and check for NaN/Inf;
      2.5  convert the outputs to numpy and, with ``args.cache_to_disk``,
           round-trip each sample through ``<image stem>_test_te.npz``
           (the file is deleted afterwards);
      2.6  stack the per-sample arrays with torch.FloatTensor to mimic the
           in-memory cache path, and evaluate the "re-encode" condition;
      2.7  unpack the stacked outputs and check the float -> long token-id conversion;
      2.8  apply drop_cached_text_encoder_outputs with dropout_rate=0.5.

    Args:
        args: Parsed CLI namespace; reads ``qwen3_path``, ``t5_tokenizer_path``,
            ``qwen3_max_length``, ``t5_max_length`` and ``cache_to_disk``.
        pairs: ``(image_path, caption)`` tuples from find_image_caption_pairs.

    Check results are printed rather than raised; step 2.7 also catches and
    prints its own exceptions. The text encoder is moved to CPU and the CUDA
    cache is cleared even if a step fails.
    """
    # TODO Rewrite this
    print("\n" + "=" * 70)
    print("TEST 2: TEXT ENCODER OUTPUT CACHING")
    print("=" * 70)

    from library import anima_utils

    # Load tokenizers
    print("\n[2.1] Loading tokenizers...")
    qwen3_tokenizer = anima_utils.load_qwen3_tokenizer(args.qwen3_path)
    t5_tokenizer = anima_utils.load_t5_tokenizer(getattr(args, "t5_tokenizer_path", None))
    print(f" Qwen3 tokenizer vocab: {qwen3_tokenizer.vocab_size}")
    print(f" T5 tokenizer vocab: {t5_tokenizer.vocab_size}")

    # Load text encoder
    print("\n[2.2] Loading Qwen3 text encoder...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    te_dtype = torch.bfloat16 if device == "cuda" else torch.float32
    qwen3_model, _ = anima_utils.load_qwen3_text_encoder(args.qwen3_path, dtype=te_dtype, device=device)
    qwen3_model.eval()

    try:
        # Create strategy objects
        from library.strategy_anima import AnimaTokenizeStrategy, AnimaTextEncodingStrategy

        tokenize_strategy = AnimaTokenizeStrategy(
            qwen3_tokenizer=qwen3_tokenizer,
            t5_tokenizer=t5_tokenizer,
            qwen3_max_length=args.qwen3_max_length,
            t5_max_length=args.t5_max_length,
        )
        text_encoding_strategy = AnimaTextEncodingStrategy()

        captions = [cap for _, cap in pairs]
        print(f"\n[2.3] Tokenizing {len(captions)} captions...")
        for i, cap in enumerate(captions):
            print(f" [{i}] \"{cap[:80]}{'...' if len(cap) > 80 else ''}\"")

        tokens_and_masks = tokenize_strategy.tokenize(captions)
        qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask = tokens_and_masks

        print("\n Tokenization results:")
        print_tensor_info("qwen3_input_ids", qwen3_input_ids)
        print_tensor_info("qwen3_attn_mask", qwen3_attn_mask)
        print_tensor_info("t5_input_ids", t5_input_ids)
        print_tensor_info("t5_attn_mask", t5_attn_mask)

        # Encode
        print("\n[2.4] Encoding with Qwen3 text encoder...")
        with torch.no_grad():
            prompt_embeds, attn_mask, t5_ids_out, t5_mask_out = text_encoding_strategy.encode_tokens(
                tokenize_strategy, [qwen3_model], tokens_and_masks
            )

        print(" Encoding results:")
        print_tensor_info("prompt_embeds", prompt_embeds)
        print_tensor_info("attn_mask", attn_mask)
        print_tensor_info("t5_input_ids", t5_ids_out)
        print_tensor_info("t5_attn_mask", t5_mask_out)

        # Check for NaN/Inf
        if torch.any(torch.isnan(prompt_embeds)):
            print(" ** ERROR: NaN in prompt_embeds!")
        elif torch.any(torch.isinf(prompt_embeds)):
            print(" ** ERROR: Inf in prompt_embeds!")
        else:
            print(" OK: no NaN/Inf in prompt_embeds")

        # Test cache round-trip (simulate what AnimaTextEncoderOutputsCachingStrategy does)
        print("\n[2.5] Testing cache round-trip (encode -> numpy -> npz -> reload -> tensor)...")

        # Convert to numpy (same as cache_batch_outputs in strategy_anima.py).
        # bfloat16 is upcast first because torch cannot export it to numpy.
        pe_cpu = prompt_embeds.cpu()
        if pe_cpu.dtype == torch.bfloat16:
            pe_cpu = pe_cpu.float()
        pe_np = pe_cpu.numpy()
        am_np = attn_mask.cpu().numpy()
        t5_ids_np = t5_ids_out.cpu().numpy().astype(np.int32)
        t5_mask_np = t5_mask_out.cpu().numpy().astype(np.int32)

        print(" Numpy conversions:")
        print_tensor_info("prompt_embeds_np", pe_np)
        print_tensor_info("attn_mask_np", am_np)
        print_tensor_info("t5_input_ids_np", t5_ids_np)
        print_tensor_info("t5_attn_mask_np", t5_mask_np)

        if args.cache_to_disk:
            # Save per-sample (simulating cache_batch_outputs)
            for i, (img_path, _caption) in enumerate(pairs):
                sample_npz = os.path.splitext(img_path)[0] + "_test_te.npz"
                np.savez(
                    sample_npz,
                    prompt_embeds=pe_np[i],
                    attn_mask=am_np[i],
                    t5_input_ids=t5_ids_np[i],
                    t5_attn_mask=t5_mask_np[i],
                )
                print(f" Saved: {sample_npz}")

                try:
                    # Reload (simulating load_outputs_npz). The context manager
                    # closes the archive so os.remove() below also works on Windows.
                    with np.load(sample_npz) as data:
                        print(f" Reloaded keys: {list(data.keys())}")
                        print_tensor_info(" loaded prompt_embeds", data["prompt_embeds"], indent=4)
                        print_tensor_info(" loaded attn_mask", data["attn_mask"], indent=4)
                        print_tensor_info(" loaded t5_input_ids", data["t5_input_ids"], indent=4)
                        print_tensor_info(" loaded t5_attn_mask", data["t5_attn_mask"], indent=4)

                        # Check diff
                        diff_pe = np.abs(pe_np[i] - data["prompt_embeds"]).max()
                        diff_t5 = np.abs(t5_ids_np[i] - data["t5_input_ids"]).max()

                    print(f" Max diff prompt_embeds: {diff_pe:.2e}")
                    print(f" Max diff t5_input_ids: {diff_t5:.2e}")
                    if diff_pe > 1e-5 or diff_t5 > 0:
                        print(" ** WARNING: cache round-trip mismatch!")
                    else:
                        print(" OK: round-trip matches")
                finally:
                    # Always remove the temp file so it never lingers in image_dir.
                    os.remove(sample_npz)
                    print(f" Cleaned up {sample_npz}")

        # Test in-memory cache round-trip (simulating what __getitem__ does)
        print("\n[2.6] Testing in-memory cache simulation (tuple -> none_or_stack_elements -> batch)...")

        # Simulate per-sample storage (like info.text_encoder_outputs = tuple)
        per_sample_cached = [(pe_np[i], am_np[i], t5_ids_np[i], t5_mask_np[i]) for i in range(len(captions))]

        # Simulate none_or_stack_elements with torch.FloatTensor converter
        # This is what train_util.py __getitem__ does at line 1784
        stacked = []
        for elem_idx in range(4):
            arrays = [sample[elem_idx] for sample in per_sample_cached]
            stacked.append(torch.stack([torch.FloatTensor(a) for a in arrays]))

        print(" Stacked batch (like batch['text_encoder_outputs_list']):")
        names = ["prompt_embeds", "attn_mask", "t5_input_ids", "t5_attn_mask"]
        for name, tensor in zip(names, stacked):
            print_tensor_info(name, tensor)

        # Check condition: len(text_encoder_conds) == 0 or text_encoder_conds[0] is None
        text_encoder_conds = stacked
        cond_check_1 = len(text_encoder_conds) == 0
        cond_check_2 = text_encoder_conds[0] is None
        print("\n Condition check (should both be False when caching works):")
        print(f" len(text_encoder_conds) == 0 : {cond_check_1}")
        print(f" text_encoder_conds[0] is None: {cond_check_2}")
        if not cond_check_1 and not cond_check_2:
            print(" OK: cached text encoder outputs would be used")
        else:
            print(" ** BUG: code would try to re-encode (and crash on None input_ids_list)!")

        # Test unpack for get_noise_pred_and_target (line 311)
        print("\n[2.7] Testing unpack: prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoder_conds")
        try:
            pe_batch, am_batch, t5_ids_batch, t5_mask_batch = text_encoder_conds
            print(" Unpack OK")
            print_tensor_info("prompt_embeds", pe_batch)
            print_tensor_info("attn_mask", am_batch)
            print_tensor_info("t5_input_ids", t5_ids_batch)
            print_tensor_info("t5_attn_mask", t5_mask_batch)

            # Check t5_input_ids are integers (they were converted to FloatTensor!)
            if t5_ids_batch.dtype not in (torch.long, torch.int32):
                print(f"\n ** NOTE: t5_input_ids dtype is {t5_ids_batch.dtype}, will be cast to long at line 316")
                t5_ids_long = t5_ids_batch.to(dtype=torch.long)
                # Check if any precision was lost
                diff = (t5_ids_batch - t5_ids_long.float()).abs().max()
                print(f" Float->Long precision loss: {diff:.2e}")
                if diff > 0.5:
                    print(" ** ERROR: token IDs corrupted by float conversion!")
                else:
                    print(" OK: float->long conversion is lossless for these IDs")
        except Exception as e:
            # Deliberately best-effort: report and continue with step 2.8.
            print(f" ** ERROR unpacking: {e}")
            traceback.print_exc()

        # Test drop_cached_text_encoder_outputs
        print("\n[2.8] Testing drop_cached_text_encoder_outputs (caption dropout)...")
        dropout_strategy = AnimaTextEncodingStrategy(
            dropout_rate=0.5,  # high rate so that some drops are likely
        )
        dropped = dropout_strategy.drop_cached_text_encoder_outputs(*stacked)
        print(f" Returned {len(dropped)} tensors")
        for name, tensor in zip(names, dropped):
            print_tensor_info(f"dropped_{name}", tensor)

        # Check which items were dropped (an all-zero prompt_embeds row counts as dropped)
        for i in range(len(captions)):
            is_zero = (dropped[0][i].abs().sum() == 0).item()
            print(f" Sample {i}: {'DROPPED' if is_zero else 'KEPT'}")
    finally:
        qwen3_model.to("cpu")
        del qwen3_model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print("\n[2.9] Text encoder cache test DONE.")
|
|
|
|
|
|
# Test 3: Full batch simulation
|
|
|
|
|
|
def test_full_batch_simulation(args, pairs):
    """Simulate the cached-outputs path of process_batch end to end.

    Steps (numbered as printed):
      3.1  load the tokenizers, the Qwen3 text encoder and the VAE;
      3.2  encode all captions and convert the outputs to numpy as the cache would;
      3.3  encode every image to latents;
      3.4  build a batch dict holding the cached text-encoder outputs,
           ``input_ids_list=None`` and only the first image's latents;
      3.5  evaluate the "re-encode" condition with train_text_encoder True vs False;
      3.6  unpack the cached outputs, move them to the device, and check that
           token ids survive the float -> long conversion.

    Args:
        args: Parsed CLI namespace; reads ``qwen3_path``, ``t5_tokenizer_path``,
            ``vae_path``, ``qwen3_max_length`` and ``t5_max_length``.
        pairs: ``(image_path, caption)`` tuples from find_image_caption_pairs.

    NOTE(review): this test loads the VAE via ``anima_utils.load_anima_vae`` and
    encodes with ``vae.encode(x, vae_scale)``. test_latent_cache instead uses
    ``qwen_image_autoencoder_kl.load_vae`` / ``encode_pixels_to_latents``.
    Confirm which API is current.

    Models are moved to CPU and the CUDA cache is cleared even if a step fails.
    """
    print("\n" + "=" * 70)
    print("TEST 3: FULL BATCH SIMULATION (mimics process_batch flow)")
    print("=" * 70)

    from library import anima_utils
    from library.strategy_anima import AnimaTokenizeStrategy, AnimaTextEncodingStrategy

    device = "cuda" if torch.cuda.is_available() else "cpu"
    te_dtype = torch.bfloat16 if device == "cuda" else torch.float32
    vae_dtype = torch.float32

    # Load all models
    print("\n[3.1] Loading models...")
    qwen3_tokenizer = anima_utils.load_qwen3_tokenizer(args.qwen3_path)
    t5_tokenizer = anima_utils.load_t5_tokenizer(getattr(args, "t5_tokenizer_path", None))
    qwen3_model, _ = anima_utils.load_qwen3_text_encoder(args.qwen3_path, dtype=te_dtype, device=device)
    qwen3_model.eval()

    vae = None  # set inside try so a VAE load failure still releases qwen3_model
    try:
        vae, _, _, vae_scale = anima_utils.load_anima_vae(args.vae_path, dtype=vae_dtype, device=device)

        tokenize_strategy = AnimaTokenizeStrategy(
            qwen3_tokenizer=qwen3_tokenizer,
            t5_tokenizer=t5_tokenizer,
            qwen3_max_length=args.qwen3_max_length,
            t5_max_length=args.t5_max_length,
        )
        text_encoding_strategy = AnimaTextEncodingStrategy(dropout_rate=0.0)

        captions = [cap for _, cap in pairs]

        # --- Simulate caching phase ---
        print("\n[3.2] Simulating text encoder caching phase...")
        tokens_and_masks = tokenize_strategy.tokenize(captions)
        with torch.no_grad():
            te_outputs = text_encoding_strategy.encode_tokens(
                tokenize_strategy,
                [qwen3_model],
                tokens_and_masks,
                enable_dropout=False,
            )
        prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = te_outputs

        # Convert to numpy (same as cache_batch_outputs)
        pe_np = prompt_embeds.cpu().float().numpy()
        am_np = attn_mask.cpu().numpy()
        t5_ids_np = t5_input_ids.cpu().numpy().astype(np.int32)
        t5_mask_np = t5_attn_mask.cpu().numpy().astype(np.int32)

        # Per-sample storage (like info.text_encoder_outputs)
        per_sample_te = [(pe_np[i], am_np[i], t5_ids_np[i], t5_mask_np[i]) for i in range(len(captions))]

        print("\n[3.3] Simulating latent caching phase...")
        per_sample_latents = []
        for img_path, _ in pairs:
            with Image.open(img_path) as pil_img:
                img_np = np.array(pil_img.convert("RGB"))
            img_tensor = IMAGE_TRANSFORMS(img_np).unsqueeze(0).unsqueeze(2)  # (1,C,1,H,W)
            img_tensor = img_tensor.to(device, dtype=vae_dtype)
            with torch.no_grad():
                lat = vae.encode(img_tensor, vae_scale).cpu()
            per_sample_latents.append(lat.squeeze(0))  # (C,1,H,W)
            print(f" {os.path.basename(img_path)}: latent shape={tuple(lat.shape)}")

        # --- Simulate batch construction (__getitem__) ---
        print("\n[3.4] Simulating batch construction...")

        # Use first image's latents only (images may have different resolutions).
        # NOTE(review): the latent batch size is therefore 1 while the text outputs
        # and loss_weights cover all captions; only the cache plumbing is exercised.
        latents_batch = per_sample_latents[0].unsqueeze(0)  # (1,C,1,H,W)
        print(f" Using first image latent for simulation: shape={tuple(latents_batch.shape)}")

        # Stack text encoder outputs (none_or_stack_elements)
        text_encoder_outputs_list = []
        for elem_idx in range(4):
            arrays = [s[elem_idx] for s in per_sample_te]
            text_encoder_outputs_list.append(torch.stack([torch.FloatTensor(a) for a in arrays]))

        # input_ids_list is None when caching
        input_ids_list = None

        batch = {
            "latents": latents_batch,
            "text_encoder_outputs_list": text_encoder_outputs_list,
            "input_ids_list": input_ids_list,
            "loss_weights": torch.ones(len(captions)),
        }

        print(f" batch keys: {list(batch.keys())}")
        print(f" batch['latents']: shape={tuple(batch['latents'].shape)}")
        print(f" batch['text_encoder_outputs_list']: {len(batch['text_encoder_outputs_list'])} tensors")
        print(f" batch['input_ids_list']: {batch['input_ids_list']}")

        # --- Simulate process_batch logic ---
        print("\n[3.5] Simulating process_batch logic...")

        text_encoder_conds = []
        te_out = batch.get("text_encoder_outputs_list", None)
        if te_out is not None:
            text_encoder_conds = te_out
            print(f" text_encoder_conds loaded from cache: {len(text_encoder_conds)} tensors")
        else:
            print(" text_encoder_conds: empty (no cache)")

        # The critical condition
        train_text_encoder_TRUE = True  # OLD behavior (base class default, no override)
        train_text_encoder_FALSE = False  # NEW behavior (with is_train_text_encoder override)

        conds_empty = len(text_encoder_conds) == 0
        # Index [0] only when non-empty, mirroring the short-circuit of the `or`
        # conditions below; printing it unguarded raised IndexError on an empty list.
        first_cond_is_none = (not conds_empty) and text_encoder_conds[0] is None
        cond_old = conds_empty or first_cond_is_none or train_text_encoder_TRUE
        cond_new = conds_empty or first_cond_is_none or train_text_encoder_FALSE

        print("\n === CRITICAL CONDITION CHECK ===")
        print(f" len(text_encoder_conds) == 0 : {conds_empty}")
        print(f" text_encoder_conds[0] is None: {'N/A (empty)' if conds_empty else first_cond_is_none}")
        print(f" train_text_encoder (OLD=True) : {train_text_encoder_TRUE}")
        print(f" train_text_encoder (NEW=False): {train_text_encoder_FALSE}")
        print()
        print(f" Condition with OLD behavior (no override): {cond_old}")
        msg = (
            "ENTERS re-encode block -> accesses batch['input_ids_list'] -> CRASH!"
            if cond_old
            else "SKIPS re-encode block -> uses cache -> OK"
        )
        print(f" -> {msg}")
        print(f" Condition with NEW behavior (override): {cond_new}")
        print(f" -> {'ENTERS re-encode block' if cond_new else 'SKIPS re-encode block -> uses cache -> OK'}")

        if cond_old and not cond_new:
            print("\n ** CONFIRMED: the is_train_text_encoder override fixes the crash **")

        # Simulate the rest of process_batch
        print("\n[3.6] Simulating get_noise_pred_and_target unpack...")
        try:
            pe, am, t5_ids, t5_mask = text_encoder_conds
            pe = pe.to(device, dtype=te_dtype)
            am = am.to(device)
            t5_ids = t5_ids.to(device, dtype=torch.long)
            t5_mask = t5_mask.to(device)

            print(" Unpack + device transfer OK:")
            print_tensor_info("prompt_embeds", pe)
            print_tensor_info("attn_mask", am)
            print_tensor_info("t5_input_ids", t5_ids)
            print_tensor_info("t5_attn_mask", t5_mask)

            # Verify t5_input_ids didn't get corrupted by float conversion
            t5_ids_orig = torch.tensor(t5_ids_np, dtype=torch.long, device=device)
            id_match = torch.all(t5_ids == t5_ids_orig).item()
            print(f"\n t5_input_ids integrity (float->long roundtrip): {'OK' if id_match else '** MISMATCH **'}")
            if not id_match:
                diff_count = (t5_ids != t5_ids_orig).sum().item()
                print(f" {diff_count} token IDs differ!")
                # Show example
                idx = torch.where(t5_ids != t5_ids_orig)
                if len(idx[0]) > 0:
                    i, j = idx[0][0].item(), idx[1][0].item()
                    print(f" Example: position [{i},{j}] original={t5_ids_orig[i,j].item()} loaded={t5_ids[i,j].item()}")
        except Exception as e:
            # Deliberately best-effort: report and let the test finish.
            print(f" ** ERROR: {e}")
            traceback.print_exc()
    finally:
        # Cleanup
        if vae is not None:
            vae.to("cpu")
        qwen3_model.to("cpu")
        del vae, qwen3_model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print("\n[3.7] Full batch simulation DONE.")
|
|
|
|
|
|
# Main
|
|
|
|
|
|
def main():
    """Parse CLI arguments, run the selected cache tests and print a summary.

    Each test runs independently. If one raises, its traceback is printed,
    it is recorded as ``FAIL: <message>``, and the remaining tests still run.
    The process exits with status 1 when no images are found or when any test
    failed, so the script can be used in automated checks. An ``--image_dir``
    that is not a directory is reported through argparse (exit status 2).
    """
    parser = argparse.ArgumentParser(description="Test Anima caching mechanisms")
    parser.add_argument("--image_dir", type=str, required=True, help="Directory with image+txt pairs")
    parser.add_argument("--qwen3_path", type=str, required=True, help="Path to Qwen3 model (directory or safetensors)")
    parser.add_argument("--vae_path", type=str, required=True, help="Path to WanVAE safetensors")
    parser.add_argument("--t5_tokenizer_path", type=str, default=None, help="Path to T5 tokenizer (optional, uses bundled config)")
    parser.add_argument("--qwen3_max_length", type=int, default=512)
    parser.add_argument("--t5_max_length", type=int, default=512)
    parser.add_argument("--cache_to_disk", action="store_true", help="Also test disk cache round-trip")
    parser.add_argument("--skip_latent", action="store_true", help="Skip latent cache test")
    parser.add_argument("--skip_text", action="store_true", help="Skip text encoder cache test")
    parser.add_argument("--skip_full", action="store_true", help="Skip full batch simulation")
    args = parser.parse_args()

    # Fail with a clear usage error instead of a raw os.listdir traceback.
    if not os.path.isdir(args.image_dir):
        parser.error(f"--image_dir is not a directory: {args.image_dir}")

    # Find pairs. Images without a .txt caption are included with an empty
    # caption, so an empty result means no supported images at all.
    pairs = find_image_caption_pairs(args.image_dir)
    if not pairs:
        print(f"ERROR: No images found in {args.image_dir}")
        print("Expected: image.png + image.txt, image.jpg + image.txt, etc.")
        sys.exit(1)

    print(f"Found {len(pairs)} image-caption pairs:")
    for img_path, cap in pairs:
        print(f" {os.path.basename(img_path)}: \"{cap[:60]}{'...' if len(cap) > 60 else ''}\"")

    # (result key, skip flag, failure banner, test function), in execution order.
    test_plan = [
        ("latent_cache", args.skip_latent, "LATENT CACHE TEST FAILED", test_latent_cache),
        ("text_encoder_cache", args.skip_text, "TEXT ENCODER CACHE TEST FAILED", test_text_encoder_cache),
        ("full_batch_sim", args.skip_full, "FULL BATCH SIMULATION FAILED", test_full_batch_simulation),
    ]

    results = {}
    for key, skipped, failure_banner, test_fn in test_plan:
        if skipped:
            continue
        try:
            test_fn(args, pairs)
            results[key] = "PASS"
        except Exception as e:
            # Top-level boundary: report and move on to the next test.
            print(f"\n** {failure_banner}: {e}")
            traceback.print_exc()
            results[key] = f"FAIL: {e}"

    # Summary
    print("\n" + "=" * 70)
    print("SUMMARY")
    print("=" * 70)
    for test, result in results.items():
        status = "OK" if result == "PASS" else "FAIL"
        print(f" [{status}] {test}: {result}")
    print()

    # Non-zero exit so failures are not hidden from scripts/CI.
    if any(result != "PASS" for result in results.values()):
        sys.exit(1)


if __name__ == "__main__":
    main()
|