Files
Kohya-ss-sd-scripts/tests/library/test_lumina_util.py
2025-06-09 21:14:51 -04:00

113 lines
4.2 KiB
Python

import torch
from torch.nn.modules import conv
from library import lumina_util
def test_unpack_latents():
# Create a test tensor
# Shape: [batch, height*width, channels*patch_height*patch_width]
x = torch.randn(2, 4, 16) # 2 batches, 4 tokens, 16 channels
packed_latent_height = 2
packed_latent_width = 2
# Unpack the latents
unpacked = lumina_util.unpack_latents(x, packed_latent_height, packed_latent_width)
# Check output shape
# Expected shape: [batch, channels, height*patch_height, width*patch_width]
assert unpacked.shape == (2, 4, 4, 4)
def test_pack_latents():
# Create a test tensor
# Shape: [batch, channels, height*patch_height, width*patch_width]
x = torch.randn(2, 4, 4, 4)
# Pack the latents
packed = lumina_util.pack_latents(x)
# Check output shape
# Expected shape: [batch, height*width, channels*patch_height*patch_width]
assert packed.shape == (2, 4, 16)
def test_convert_diffusers_sd_to_alpha_vllm():
num_double_blocks = 2
# Predefined test cases based on the actual conversion map
test_cases = [
# Static key conversions with possible list mappings
{
"original_keys": ["time_caption_embed.caption_embedder.0.weight"],
"original_pattern": ["time_caption_embed.caption_embedder.0.weight"],
"expected_converted_keys": ["cap_embedder.0.weight"],
},
{
"original_keys": ["patch_embedder.proj.weight"],
"original_pattern": ["patch_embedder.proj.weight"],
"expected_converted_keys": ["x_embedder.weight"],
},
{
"original_keys": ["transformer_blocks.0.norm1.weight"],
"original_pattern": ["transformer_blocks.().norm1.weight"],
"expected_converted_keys": ["layers.0.attention_norm1.weight"],
},
]
for test_case in test_cases:
for original_key, original_pattern, expected_converted_key in zip(
test_case["original_keys"], test_case["original_pattern"], test_case["expected_converted_keys"]
):
# Create test state dict
test_sd = {original_key: torch.randn(10, 10)}
# Convert the state dict
converted_sd = lumina_util.convert_diffusers_sd_to_alpha_vllm(test_sd, num_double_blocks)
# Verify conversion (handle both string and list keys)
# Find the correct converted key
match_found = False
if expected_converted_key in converted_sd:
# Verify tensor preservation
assert torch.allclose(converted_sd[expected_converted_key], test_sd[original_key], atol=1e-6), (
f"Tensor mismatch for {original_key}"
)
match_found = True
break
assert match_found, f"Failed to convert {original_key}"
# Ensure original key is also present
assert original_key in converted_sd
# Test with block-specific keys
block_specific_cases = [
{
"original_pattern": "transformer_blocks.().norm1.weight",
"converted_pattern": "layers.().attention_norm1.weight",
}
]
for case in block_specific_cases:
for block_idx in range(2): # Test multiple block indices
# Prepare block-specific keys
block_original_key = case["original_pattern"].replace("()", str(block_idx))
block_converted_key = case["converted_pattern"].replace("()", str(block_idx))
print(block_original_key, block_converted_key)
# Create test state dict
test_sd = {block_original_key: torch.randn(10, 10)}
# Convert the state dict
converted_sd = lumina_util.convert_diffusers_sd_to_alpha_vllm(test_sd, num_double_blocks)
# Verify conversion
# assert block_converted_key in converted_sd, f"Failed to convert block key {block_original_key}"
assert torch.allclose(converted_sd[block_converted_key], test_sd[block_original_key], atol=1e-6), (
f"Tensor mismatch for block key {block_original_key}"
)
# Ensure original key is also present
assert block_original_key in converted_sd