Files
Kohya-ss-sd-scripts/tests/test_lumina_train_network.py

178 lines
6.7 KiB
Python

import pytest
import torch
from unittest.mock import MagicMock, patch
import argparse
from library import lumina_models, lumina_util
from lumina_train_network import LuminaNetworkTrainer
@pytest.fixture
def lumina_trainer():
return LuminaNetworkTrainer()
@pytest.fixture
def mock_args():
args = MagicMock()
args.pretrained_model_name_or_path = "test_path"
args.disable_mmap_load_safetensors = False
args.use_flash_attn = False
args.use_sage_attn = False
args.fp8_base = False
args.blocks_to_swap = None
args.gemma2 = "test_gemma2_path"
args.ae = "test_ae_path"
args.cache_text_encoder_outputs = True
args.cache_text_encoder_outputs_to_disk = False
args.network_train_unet_only = False
return args
@pytest.fixture
def mock_accelerator():
accelerator = MagicMock()
accelerator.device = torch.device("cpu")
accelerator.prepare.side_effect = lambda x, **kwargs: x
accelerator.unwrap_model.side_effect = lambda x: x
return accelerator
def test_assert_extra_args(lumina_trainer, mock_args):
train_dataset_group = MagicMock()
train_dataset_group.verify_bucket_reso_steps = MagicMock()
val_dataset_group = MagicMock()
val_dataset_group.verify_bucket_reso_steps = MagicMock()
# Test with default settings
lumina_trainer.assert_extra_args(mock_args, train_dataset_group, val_dataset_group)
# Verify verify_bucket_reso_steps was called for both groups
assert train_dataset_group.verify_bucket_reso_steps.call_count > 0
assert val_dataset_group.verify_bucket_reso_steps.call_count > 0
# Check text encoder output caching
assert lumina_trainer.train_gemma2 is (not mock_args.network_train_unet_only)
assert mock_args.cache_text_encoder_outputs is True
def test_load_target_model(lumina_trainer, mock_args, mock_accelerator):
# Patch lumina_util methods
with (
patch("library.lumina_util.load_lumina_model") as mock_load_lumina_model,
patch("library.lumina_util.load_gemma2") as mock_load_gemma2,
patch("library.lumina_util.load_ae") as mock_load_ae,
):
# Create mock models
mock_model = MagicMock(spec=lumina_models.NextDiT)
mock_model.dtype = torch.float32
mock_gemma2 = MagicMock()
mock_ae = MagicMock()
mock_load_lumina_model.return_value = mock_model
mock_load_gemma2.return_value = mock_gemma2
mock_load_ae.return_value = mock_ae
# Test load_target_model
version, gemma2_list, ae, model = lumina_trainer.load_target_model(mock_args, torch.float32, mock_accelerator)
# Verify calls and return values
assert version == lumina_util.MODEL_VERSION_LUMINA_V2
assert gemma2_list == [mock_gemma2]
assert ae == mock_ae
assert model == mock_model
# Verify load calls
mock_load_lumina_model.assert_called_once()
mock_load_gemma2.assert_called_once()
mock_load_ae.assert_called_once()
def test_get_strategies(lumina_trainer, mock_args):
# Test tokenize strategy
try:
tokenize_strategy = lumina_trainer.get_tokenize_strategy(mock_args)
assert tokenize_strategy.__class__.__name__ == "LuminaTokenizeStrategy"
except OSError as e:
# If the tokenizer is not found (due to gated repo), we can skip the test
print(f"Skipping LuminaTokenizeStrategy test due to OSError: {e}")
# Test latents caching strategy
latents_strategy = lumina_trainer.get_latents_caching_strategy(mock_args)
assert latents_strategy.__class__.__name__ == "LuminaLatentsCachingStrategy"
# Test text encoding strategy
text_encoding_strategy = lumina_trainer.get_text_encoding_strategy(mock_args)
assert text_encoding_strategy.__class__.__name__ == "LuminaTextEncodingStrategy"
def test_text_encoder_output_caching_strategy(lumina_trainer, mock_args):
# Call assert_extra_args to set train_gemma2
train_dataset_group = MagicMock()
train_dataset_group.verify_bucket_reso_steps = MagicMock()
val_dataset_group = MagicMock()
val_dataset_group.verify_bucket_reso_steps = MagicMock()
lumina_trainer.assert_extra_args(mock_args, train_dataset_group, val_dataset_group)
# With text encoder caching enabled
mock_args.skip_cache_check = False
mock_args.text_encoder_batch_size = 16
strategy = lumina_trainer.get_text_encoder_outputs_caching_strategy(mock_args)
assert strategy.__class__.__name__ == "LuminaTextEncoderOutputsCachingStrategy"
assert strategy.cache_to_disk is False # based on mock_args
# With text encoder caching disabled
mock_args.cache_text_encoder_outputs = False
strategy = lumina_trainer.get_text_encoder_outputs_caching_strategy(mock_args)
assert strategy is None
def test_noise_scheduler(lumina_trainer, mock_args):
device = torch.device("cpu")
noise_scheduler = lumina_trainer.get_noise_scheduler(mock_args, device)
assert noise_scheduler.__class__.__name__ == "FlowMatchEulerDiscreteScheduler"
assert noise_scheduler.num_train_timesteps == 1000
assert hasattr(lumina_trainer, "noise_scheduler_copy")
def test_sai_model_spec(lumina_trainer, mock_args):
with patch("library.train_util.get_sai_model_spec") as mock_get_spec:
mock_get_spec.return_value = "test_spec"
spec = lumina_trainer.get_sai_model_spec(mock_args)
assert spec == "test_spec"
mock_get_spec.assert_called_once_with(None, mock_args, False, True, False, lumina="lumina2")
def test_update_metadata(lumina_trainer, mock_args):
metadata = {}
lumina_trainer.update_metadata(metadata, mock_args)
assert "ss_weighting_scheme" in metadata
assert "ss_logit_mean" in metadata
assert "ss_logit_std" in metadata
assert "ss_mode_scale" in metadata
assert "ss_timestep_sampling" in metadata
assert "ss_sigmoid_scale" in metadata
assert "ss_model_prediction_type" in metadata
assert "ss_discrete_flow_shift" in metadata
def test_is_text_encoder_not_needed_for_training(lumina_trainer, mock_args):
# Test with text encoder output caching, but not training text encoder
mock_args.cache_text_encoder_outputs = True
with patch.object(lumina_trainer, "is_train_text_encoder", return_value=False):
result = lumina_trainer.is_text_encoder_not_needed_for_training(mock_args)
assert result is True
# Test with text encoder output caching and training text encoder
with patch.object(lumina_trainer, "is_train_text_encoder", return_value=True):
result = lumina_trainer.is_text_encoder_not_needed_for_training(mock_args)
assert result is False
# Test with no text encoder output caching
mock_args.cache_text_encoder_outputs = False
result = lumina_trainer.is_text_encoder_not_needed_for_training(mock_args)
assert result is False