mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
113 lines
4.2 KiB
Python
113 lines
4.2 KiB
Python
import torch
|
|
from torch.nn.modules import conv
|
|
|
|
from library import lumina_util
|
|
|
|
|
|
def test_unpack_latents():
    """unpack_latents reshapes packed tokens back into image-latent layout."""
    # Packed layout: [batch, height*width, channels*patch_height*patch_width]
    packed = torch.randn(2, 4, 16)  # 2 batches, 4 tokens, 16 channels per token
    latent_height, latent_width = 2, 2

    result = lumina_util.unpack_latents(packed, latent_height, latent_width)

    # Unpacked layout: [batch, channels, height*patch_height, width*patch_width]
    assert result.shape == (2, 4, 4, 4)
|
|
|
|
|
|
def test_pack_latents():
    """pack_latents flattens image latents into token-sequence layout."""
    # Input layout: [batch, channels, height*patch_height, width*patch_width]
    latents = torch.randn(2, 4, 4, 4)

    result = lumina_util.pack_latents(latents)

    # Packed layout: [batch, height*width, channels*patch_height*patch_width]
    assert result.shape == (2, 4, 16)
|
|
|
|
|
|
def test_convert_diffusers_sd_to_alpha_vllm():
    """Verify convert_diffusers_sd_to_alpha_vllm renames Diffusers state-dict keys
    to the Alpha-VLLM naming scheme.

    Covers static (fixed-string) key renames and block-indexed renames where
    "()" in the pattern is the placeholder for the transformer-block index.
    Also checks that tensor values pass through the conversion unchanged.
    """
    num_double_blocks = 2

    # Static key conversions: (diffusers key, expected alpha-vllm key).
    static_cases = [
        ("time_caption_embed.caption_embedder.0.weight", "cap_embedder.0.weight"),
        ("patch_embedder.proj.weight", "x_embedder.weight"),
        ("transformer_blocks.0.norm1.weight", "layers.0.attention_norm1.weight"),
    ]

    for original_key, expected_key in static_cases:
        test_sd = {original_key: torch.randn(10, 10)}

        converted_sd = lumina_util.convert_diffusers_sd_to_alpha_vllm(test_sd, num_double_blocks)

        # The renamed key must exist and must carry the exact same tensor.
        assert expected_key in converted_sd, f"Failed to convert {original_key}"
        assert torch.allclose(converted_sd[expected_key], test_sd[original_key], atol=1e-6), (
            f"Tensor mismatch for {original_key}"
        )
        # NOTE(review): the original test also expected the source key to remain
        # in the converted dict — confirm the converter intentionally keeps it.
        assert original_key in converted_sd

    # Block-indexed conversions: (diffusers pattern, alpha-vllm pattern),
    # with "()" standing in for the block index.
    block_specific_cases = [
        ("transformer_blocks.().norm1.weight", "layers.().attention_norm1.weight"),
    ]

    for original_pattern, converted_pattern in block_specific_cases:
        for block_idx in range(num_double_blocks):
            block_original_key = original_pattern.replace("()", str(block_idx))
            block_converted_key = converted_pattern.replace("()", str(block_idx))

            test_sd = {block_original_key: torch.randn(10, 10)}

            converted_sd = lumina_util.convert_diffusers_sd_to_alpha_vllm(test_sd, num_double_blocks)

            # Renamed key exists and the tensor survives unchanged.
            assert block_converted_key in converted_sd, f"Failed to convert block key {block_original_key}"
            assert torch.allclose(converted_sd[block_converted_key], test_sd[block_original_key], atol=1e-6), (
                f"Tensor mismatch for block key {block_original_key}"
            )
            # NOTE(review): as above — original key expected to remain; confirm intended.
            assert block_original_key in converted_sd