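"""Unit tests for LuminaNetworkTrainer (lumina_train_network.py).

Heavyweight dependencies (parsed CLI arguments, the accelerator, the model
loaders) are replaced with MagicMock objects, so these tests exercise the
trainer's setup and strategy wiring without loading any real checkpoints.
"""
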
import pytest
import torch
from unittest.mock import MagicMock, patch

from library import lumina_models, lumina_util
from lumina_train_network import LuminaNetworkTrainer
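

# Fixture: a fresh LuminaNetworkTrainer instance for each test.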
@pytest.fixture
def lumina_trainer():
    return LuminaNetworkTrainer()
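

# Fixture: a Namespace-style stand-in for parsed CLI arguments, preloaded
# with the flags the Lumina trainer reads during setup.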
@pytest.fixture
def mock_args():
    args = MagicMock()
    args.pretrained_model_name_or_path = "test_path"
    args.disable_mmap_load_safetensors = False
    args.use_flash_attn = False
    args.use_sage_attn = False
    args.fp8_base = False
    args.blocks_to_swap = None
    args.gemma2 = "test_gemma2_path"
    args.ae = "test_ae_path"
    args.cache_text_encoder_outputs = True
    args.cache_text_encoder_outputs_to_disk = False
    args.network_train_unet_only = False
    return args
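

# Fixture: an accelerator stub pinned to CPU whose prepare/unwrap calls are
# identity functions, so models pass through unchanged.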
@pytest.fixture
def mock_accelerator():
    accelerator = MagicMock()
    accelerator.device = torch.device("cpu")
    accelerator.prepare.side_effect = lambda x, **kwargs: x
    accelerator.unwrap_model.side_effect = lambda x: x
    return accelerator
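

# assert_extra_args should verify bucket resolution steps on both dataset
# groups and derive train_gemma2 from network_train_unet_only.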
def test_assert_extra_args(lumina_trainer, mock_args):
    train_dataset_group = MagicMock()
    train_dataset_group.verify_bucket_reso_steps = MagicMock()
    val_dataset_group = MagicMock()
    val_dataset_group.verify_bucket_reso_steps = MagicMock()

    # Test with default settings
    lumina_trainer.assert_extra_args(mock_args, train_dataset_group, val_dataset_group)

    # Verify verify_bucket_reso_steps was called for both groups
    assert train_dataset_group.verify_bucket_reso_steps.call_count > 0
    assert val_dataset_group.verify_bucket_reso_steps.call_count > 0

    # Gemma2 is trained unless network_train_unet_only is set, and text
    # encoder output caching stays enabled
    assert lumina_trainer.train_gemma2 is (not mock_args.network_train_unet_only)
    assert mock_args.cache_text_encoder_outputs is True
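

# load_target_model should return (version, text encoders, AE, DiT), with
# each lumina_util loader invoked exactly once.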
def test_load_target_model(lumina_trainer, mock_args, mock_accelerator):
    # Patch lumina_util methods
    with (
        patch("library.lumina_util.load_lumina_model") as mock_load_lumina_model,
        patch("library.lumina_util.load_gemma2") as mock_load_gemma2,
        patch("library.lumina_util.load_ae") as mock_load_ae,
    ):
        # Create mock models
        mock_model = MagicMock(spec=lumina_models.NextDiT)
        mock_model.dtype = torch.float32
        mock_gemma2 = MagicMock()
        mock_ae = MagicMock()

        mock_load_lumina_model.return_value = mock_model
        mock_load_gemma2.return_value = mock_gemma2
        mock_load_ae.return_value = mock_ae

        # Test load_target_model
        version, gemma2_list, ae, model = lumina_trainer.load_target_model(mock_args, torch.float32, mock_accelerator)

        # Verify calls and return values
        assert version == lumina_util.MODEL_VERSION_LUMINA_V2
        assert gemma2_list == [mock_gemma2]
        assert ae == mock_ae
        assert model == mock_model

        # Verify load calls
        mock_load_lumina_model.assert_called_once()
        mock_load_gemma2.assert_called_once()
        mock_load_ae.assert_called_once()
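

# Each strategy getter should return the Lumina-specific strategy class.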
def test_get_strategies(lumina_trainer, mock_args):
    # Test tokenize strategy
    try:
        tokenize_strategy = lumina_trainer.get_tokenize_strategy(mock_args)
        assert tokenize_strategy.__class__.__name__ == "LuminaTokenizeStrategy"
    except OSError as e:
        # The tokenizer may live in a gated repo; skip only this check if it
        # cannot be downloaded
        print(f"Skipping LuminaTokenizeStrategy test due to OSError: {e}")

    # Test latents caching strategy
    latents_strategy = lumina_trainer.get_latents_caching_strategy(mock_args)
    assert latents_strategy.__class__.__name__ == "LuminaLatentsCachingStrategy"

    # Test text encoding strategy
    text_encoding_strategy = lumina_trainer.get_text_encoding_strategy(mock_args)
    assert text_encoding_strategy.__class__.__name__ == "LuminaTextEncodingStrategy"
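

# A caching strategy should only be created while text encoder output
# caching is enabled; otherwise the getter returns None.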
def test_text_encoder_output_caching_strategy(lumina_trainer, mock_args):
    # Call assert_extra_args first so that train_gemma2 is set
    train_dataset_group = MagicMock()
    train_dataset_group.verify_bucket_reso_steps = MagicMock()
    val_dataset_group = MagicMock()
    val_dataset_group.verify_bucket_reso_steps = MagicMock()
    lumina_trainer.assert_extra_args(mock_args, train_dataset_group, val_dataset_group)

    # With text encoder caching enabled
    mock_args.skip_cache_check = False
    mock_args.text_encoder_batch_size = 16
    strategy = lumina_trainer.get_text_encoder_outputs_caching_strategy(mock_args)

    assert strategy.__class__.__name__ == "LuminaTextEncoderOutputsCachingStrategy"
    assert strategy.cache_to_disk is False  # cache_text_encoder_outputs_to_disk is False in mock_args

    # With text encoder caching disabled
    mock_args.cache_text_encoder_outputs = False
    strategy = lumina_trainer.get_text_encoder_outputs_caching_strategy(mock_args)
    assert strategy is None
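

# The trainer should use a flow-matching Euler scheduler with 1000 training
# timesteps and store a noise_scheduler_copy attribute.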
def test_noise_scheduler(lumina_trainer, mock_args):
    device = torch.device("cpu")
    noise_scheduler = lumina_trainer.get_noise_scheduler(mock_args, device)

    assert noise_scheduler.__class__.__name__ == "FlowMatchEulerDiscreteScheduler"
    assert noise_scheduler.num_train_timesteps == 1000
    assert hasattr(lumina_trainer, "noise_scheduler_copy")
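

# get_sai_model_spec should delegate to train_util.get_sai_model_spec with
# the lumina="lumina2" tag.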
def test_sai_model_spec(lumina_trainer, mock_args):
    with patch("library.train_util.get_sai_model_spec") as mock_get_spec:
        mock_get_spec.return_value = "test_spec"
        spec = lumina_trainer.get_sai_model_spec(mock_args)
        assert spec == "test_spec"
        mock_get_spec.assert_called_once_with(None, mock_args, False, True, False, lumina="lumina2")
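

# update_metadata should record the flow-matching training hyperparameters
# under their ss_* metadata keys.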
def test_update_metadata(lumina_trainer, mock_args):
    metadata = {}
    lumina_trainer.update_metadata(metadata, mock_args)

    assert "ss_weighting_scheme" in metadata
    assert "ss_logit_mean" in metadata
    assert "ss_logit_std" in metadata
    assert "ss_mode_scale" in metadata
    assert "ss_timestep_sampling" in metadata
    assert "ss_sigmoid_scale" in metadata
    assert "ss_model_prediction_type" in metadata
    assert "ss_discrete_flow_shift" in metadata
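

# The text encoder is only droppable when its outputs are cached and it is
# not itself being trained.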
def test_is_text_encoder_not_needed_for_training(lumina_trainer, mock_args):
    # Test with text encoder output caching, but not training text encoder
    mock_args.cache_text_encoder_outputs = True
    with patch.object(lumina_trainer, "is_train_text_encoder", return_value=False):
        result = lumina_trainer.is_text_encoder_not_needed_for_training(mock_args)
        assert result is True

    # Test with text encoder output caching and training text encoder
    with patch.object(lumina_trainer, "is_train_text_encoder", return_value=True):
        result = lumina_trainer.is_text_encoder_not_needed_for_training(mock_args)
        assert result is False

    # Test with no text encoder output caching
    mock_args.cache_text_encoder_outputs = False
    result = lumina_trainer.is_text_encoder_not_needed_for_training(mock_args)
    assert result is False
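

# These tests can be run with e.g. `pytest -v <path to this file>` from the
# sd-scripts repository root, so that `library` and `lumina_train_network`
# resolve as imports.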