diff --git a/tests/library/test_sai_model_spec.py b/tests/library/test_sai_model_spec.py new file mode 100644 index 00000000..92dcf4c6 --- /dev/null +++ b/tests/library/test_sai_model_spec.py @@ -0,0 +1,349 @@ +"""Tests for sai_model_spec module.""" +import pytest +import time + +from library import sai_model_spec + + +class MockArgs: + """Mock argparse.Namespace for testing.""" + + def __init__(self, **kwargs): + # Default values + self.v2 = False + self.v_parameterization = False + self.resolution = 512 + self.metadata_title = None + self.metadata_author = None + self.metadata_description = None + self.metadata_license = None + self.metadata_tags = None + self.min_timestep = None + self.max_timestep = None + self.clip_skip = None + self.output_name = "test_output" + + # Override with provided values + for key, value in kwargs.items(): + setattr(self, key, value) + + +class TestModelSpecMetadata: + """Test the ModelSpecMetadata dataclass.""" + + def test_creation_and_conversion(self): + """Test creating dataclass and converting to metadata dict.""" + metadata = sai_model_spec.ModelSpecMetadata( + architecture="stable-diffusion-v1", + implementation="diffusers", + title="Test Model", + author="Test Author", + description=None # Test None exclusion + ) + + assert metadata.architecture == "stable-diffusion-v1" + assert metadata.sai_model_spec == "1.0.1" + + metadata_dict = metadata.to_metadata_dict() + assert "modelspec.architecture" in metadata_dict + assert "modelspec.author" in metadata_dict + assert "modelspec.description" not in metadata_dict # None values excluded + assert metadata_dict["modelspec.sai_model_spec"] == "1.0.1" + + def test_additional_fields_handling(self): + """Test handling of additional metadata fields.""" + additional = {"custom_field": "custom_value", "modelspec.prefixed": "prefixed_value"} + + metadata = sai_model_spec.ModelSpecMetadata( + architecture="stable-diffusion-v1", + implementation="diffusers", + title="Test Model", + additional_fields=additional + ) + + metadata_dict = metadata.to_metadata_dict() + assert "modelspec.custom_field" in metadata_dict + assert "modelspec.prefixed" in metadata_dict + assert metadata_dict["modelspec.custom_field"] == "custom_value" + + def test_from_args_extraction(self): + """Test creating ModelSpecMetadata from args with metadata_* fields.""" + args = MockArgs( + metadata_author="Test Author", + metadata_trigger_phrase="anime style", + metadata_usage_hint="Use CFG 7.5" + ) + + metadata = sai_model_spec.ModelSpecMetadata.from_args( + args, + architecture="stable-diffusion-v1", + implementation="diffusers", + title="Test Model" + ) + + assert metadata.author == "Test Author" + assert metadata.additional_fields["trigger_phrase"] == "anime style" + assert metadata.additional_fields["usage_hint"] == "Use CFG 7.5" + + +class TestArchitectureDetection: + """Test architecture detection for different model types.""" + + @pytest.mark.parametrize("config,expected", [ + ({"v2": False, "v_parameterization": False, "sdxl": True}, "stable-diffusion-xl-v1-base"), + ({"v2": False, "v_parameterization": False, "sdxl": False, "model_config": {"flux": "dev"}}, "flux-1-dev"), + ({"v2": False, "v_parameterization": False, "sdxl": False, "model_config": {"flux": "chroma"}}, "chroma"), + ({"v2": False, "v_parameterization": False, "sdxl": False, "model_config": {"sd3": "large"}}, "stable-diffusion-3-large"), + ({"v2": True, "v_parameterization": True, "sdxl": False}, "stable-diffusion-v2-768-v"), + ({"v2": False, "v_parameterization": False, "sdxl": False}, "stable-diffusion-v1"), + ]) + def test_architecture_detection(self, config, expected): + """Test architecture detection for various model configurations.""" + model_config = config.pop("model_config", None) + arch = sai_model_spec.determine_architecture( + lora=False, textual_inversion=False, model_config=model_config, **config + ) + assert arch == expected + + def test_adapter_suffixes(self): + """Test LoRA and textual inversion suffixes.""" + lora_arch = sai_model_spec.determine_architecture( + v2=False, v_parameterization=False, sdxl=True, + lora=True, textual_inversion=False + ) + assert lora_arch == "stable-diffusion-xl-v1-base/lora" + + ti_arch = sai_model_spec.determine_architecture( + v2=False, v_parameterization=False, sdxl=False, + lora=False, textual_inversion=True + ) + assert ti_arch == "stable-diffusion-v1/textual-inversion" + + +class TestImplementationDetection: + """Test implementation detection for different model types.""" + + @pytest.mark.parametrize("config,expected", [ + ({"model_config": {"flux": "dev"}}, "https://github.com/black-forest-labs/flux"), + ({"model_config": {"flux": "chroma"}}, "https://huggingface.co/lodestones/Chroma"), + ({"model_config": {"lumina": "lumina2"}}, "https://github.com/Alpha-VLLM/Lumina-Image-2.0"), + ({"lora": True, "sdxl": True}, "https://github.com/Stability-AI/generative-models"), + ({"lora": True, "sdxl": False}, "diffusers"), + ]) + def test_implementation_detection(self, config, expected): + """Test implementation detection for various configurations.""" + model_config = config.pop("model_config", None) + impl = sai_model_spec.determine_implementation( + lora=config.get("lora", False), + textual_inversion=False, + sdxl=config.get("sdxl", False), + model_config=model_config + ) + assert impl == expected + + +class TestResolutionHandling: + """Test resolution parsing and defaults.""" + + @pytest.mark.parametrize("input_reso,expected", [ + ((768, 1024), "768x1024"), + (768, "768x768"), + ("768,1024", "768x1024"), + ]) + def test_explicit_resolution_formats(self, input_reso, expected): + """Test different resolution input formats.""" + res = sai_model_spec.determine_resolution(reso=input_reso) + assert res == expected + + @pytest.mark.parametrize("config,expected", [ + ({"sdxl": True}, "1024x1024"), + ({"model_config": {"flux": "dev"}}, "1024x1024"), + ({"v2": True, "v_parameterization": True}, "768x768"), + ({}, "512x512"), # Default SD v1 + ]) + def test_default_resolutions(self, config, expected): + """Test default resolution detection.""" + model_config = config.pop("model_config", None) + res = sai_model_spec.determine_resolution(model_config=model_config, **config) + assert res == expected + + +class TestThumbnailProcessing: + """Test thumbnail data URL processing.""" + + def test_file_to_data_url(self): + """Test converting file to data URL.""" + import tempfile + import os + + # Create a tiny test PNG (1x1 pixel) + test_png_data = b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\x0cIDATx\x9cc\xff\xff\xff\x00\x00\x00\x04\x00\x01\x9d\xb3\xa7c\x00\x00\x00\x00IEND\xaeB`\x82' + + with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f: + f.write(test_png_data) + temp_path = f.name + + try: + data_url = sai_model_spec.file_to_data_url(temp_path) + + # Check format + assert data_url.startswith("data:image/png;base64,") + + # Check it's a reasonable length (base64 encoded) + assert len(data_url) > 50 + + # Verify we can decode it back + import base64 + encoded_part = data_url.split(",", 1)[1] + decoded_data = base64.b64decode(encoded_part) + assert decoded_data == test_png_data + + finally: + os.unlink(temp_path) + + def test_file_to_data_url_nonexistent_file(self): + """Test error handling for nonexistent files.""" + import pytest + + with pytest.raises(FileNotFoundError): + sai_model_spec.file_to_data_url("/nonexistent/file.png") + + def test_thumbnail_processing_in_metadata(self): + """Test thumbnail processing in build_metadata_dataclass.""" + import tempfile + import os + + # Create a test image file + test_png_data = b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\x0cIDATx\x9cc\xff\xff\xff\x00\x00\x00\x04\x00\x01\x9d\xb3\xa7c\x00\x00\x00\x00IEND\xaeB`\x82' + + with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f: + f.write(test_png_data) + temp_path = f.name + + try: + timestamp = time.time() + + # Test with file path - should be converted to data URL + metadata = sai_model_spec.build_metadata_dataclass( + state_dict=None, + v2=False, + v_parameterization=False, + sdxl=False, + lora=False, + textual_inversion=False, + timestamp=timestamp, + title="Test Model", + optional_metadata={"thumbnail": temp_path} + ) + + # Should be converted to data URL + assert "thumbnail" in metadata.additional_fields + assert metadata.additional_fields["thumbnail"].startswith("data:image/png;base64,") + + finally: + os.unlink(temp_path) + + def test_thumbnail_data_url_passthrough(self): + """Test that existing data URLs are passed through unchanged.""" + timestamp = time.time() + + existing_data_url = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==" + + metadata = sai_model_spec.build_metadata_dataclass( + state_dict=None, + v2=False, + v_parameterization=False, + sdxl=False, + lora=False, + textual_inversion=False, + timestamp=timestamp, + title="Test Model", + optional_metadata={"thumbnail": existing_data_url} + ) + + # Should be unchanged + assert metadata.additional_fields["thumbnail"] == existing_data_url + + def test_thumbnail_invalid_file_handling(self): + """Test graceful handling of invalid thumbnail files.""" + timestamp = time.time() + + metadata = sai_model_spec.build_metadata_dataclass( + state_dict=None, + v2=False, + v_parameterization=False, + sdxl=False, + lora=False, + textual_inversion=False, + timestamp=timestamp, + title="Test Model", + optional_metadata={"thumbnail": "/nonexistent/file.png"} + ) + + # Should be removed from additional_fields due to error + assert "thumbnail" not in metadata.additional_fields + + +class TestBuildMetadataIntegration: + """Test the complete metadata building workflow.""" + + def test_sdxl_model_workflow(self): + """Test complete workflow for SDXL model.""" + timestamp = time.time() + + metadata = sai_model_spec.build_metadata_dataclass( + state_dict=None, + v2=False, + v_parameterization=False, + sdxl=True, + lora=False, + textual_inversion=False, + timestamp=timestamp, + title="Test SDXL Model" + ) + + assert metadata.architecture == "stable-diffusion-xl-v1-base" + assert metadata.implementation == "https://github.com/Stability-AI/generative-models" + assert metadata.resolution == "1024x1024" + assert metadata.prediction_type == "epsilon" + + def test_flux_model_workflow(self): + """Test complete workflow for Flux model.""" + timestamp = time.time() + + metadata = sai_model_spec.build_metadata_dataclass( + state_dict=None, + v2=False, + v_parameterization=False, + sdxl=False, + lora=False, + textual_inversion=False, + timestamp=timestamp, + title="Test Flux Model", + model_config={"flux": "dev"}, + optional_metadata={"trigger_phrase": "anime style"} + ) + + assert metadata.architecture == "flux-1-dev" + assert metadata.implementation == "https://github.com/black-forest-labs/flux" + assert metadata.prediction_type is None # Flux doesn't use prediction_type + assert metadata.additional_fields["trigger_phrase"] == "anime style" + + def test_legacy_function_compatibility(self): + """Test that legacy build_metadata function works correctly.""" + timestamp = time.time() + + metadata_dict = sai_model_spec.build_metadata( + state_dict=None, + v2=False, + v_parameterization=False, + sdxl=True, + lora=False, + textual_inversion=False, + timestamp=timestamp, + title="Test Model" + ) + + assert isinstance(metadata_dict, dict) + assert metadata_dict["modelspec.sai_model_spec"] == "1.0.1" + assert metadata_dict["modelspec.architecture"] == "stable-diffusion-xl-v1-base" \ No newline at end of file