diff --git a/fine_tune.py b/fine_tune.py index e1ed4749..ffbbbb09 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -27,6 +27,7 @@ logger = logging.getLogger(__name__) import library.train_util as train_util import library.config_util as config_util +import library.sai_model_spec as sai_model_spec from library.config_util import ( ConfigSanitizer, BlueprintGenerator, @@ -519,6 +520,7 @@ def setup_parser() -> argparse.ArgumentParser: add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) + sai_model_spec.add_model_spec_arguments(parser) train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) deepspeed_utils.add_deepspeed_arguments(parser) diff --git a/flux_train.py b/flux_train.py index 84db34cf..4aa67220 100644 --- a/flux_train.py +++ b/flux_train.py @@ -30,7 +30,7 @@ from library.device_utils import init_ipex, clean_memory_on_device init_ipex() from accelerate.utils import set_seed -from library import deepspeed_utils, flux_train_utils, flux_utils, strategy_base, strategy_flux +from library import deepspeed_utils, flux_train_utils, flux_utils, strategy_base, strategy_flux, sai_model_spec from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler import library.train_util as train_util @@ -787,6 +787,7 @@ def setup_parser() -> argparse.ArgumentParser: add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) # TODO split this + sai_model_spec.add_model_spec_arguments(parser) train_util.add_dataset_arguments(parser, True, True, True) train_util.add_training_arguments(parser, False) train_util.add_masked_loss_arguments(parser) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index 93c20dab..01991405 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -32,6 +32,7 @@ init_ipex() from accelerate.utils import set_seed import library.train_util as train_util +import library.sai_model_spec as sai_model_spec from library import ( 
deepspeed_utils, flux_train_utils, @@ -820,6 +821,7 @@ def setup_parser() -> argparse.ArgumentParser: add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) # TODO split this + sai_model_spec.add_model_spec_arguments(parser) train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) train_util.add_masked_loss_arguments(parser) diff --git a/lumina_train.py b/lumina_train.py index a333427d..ca60c658 100644 --- a/lumina_train.py +++ b/lumina_train.py @@ -31,6 +31,7 @@ from library import ( lumina_util, strategy_base, strategy_lumina, + sai_model_spec ) from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler @@ -904,6 +905,7 @@ def setup_parser() -> argparse.ArgumentParser: add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) # TODO split this + sai_model_spec.add_model_spec_arguments(parser) train_util.add_dataset_arguments(parser, True, True, True) train_util.add_training_arguments(parser, False) train_util.add_masked_loss_arguments(parser) diff --git a/sd3_train.py b/sd3_train.py index 3bff6a50..355e13dd 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -20,6 +20,8 @@ init_ipex() from accelerate.utils import set_seed from diffusers import DDPMScheduler from library import deepspeed_utils, sd3_models, sd3_train_utils, sd3_utils, strategy_base, strategy_sd3 + +import library.sai_model_spec as sai_model_spec from library.sdxl_train_util import match_mixed_precision # , sdxl_model_util @@ -986,6 +988,7 @@ def setup_parser() -> argparse.ArgumentParser: add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) + sai_model_spec.add_model_spec_arguments(parser) train_util.add_dataset_arguments(parser, True, True, True) train_util.add_training_arguments(parser, False) train_util.add_masked_loss_arguments(parser) diff --git a/sdxl_train.py b/sdxl_train.py index a60f6df6..f454263a 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -17,7 +17,7 @@ init_ipex() from 
accelerate.utils import set_seed from diffusers import DDPMScheduler -from library import deepspeed_utils, sdxl_model_util, strategy_base, strategy_sd, strategy_sdxl +from library import deepspeed_utils, sdxl_model_util, strategy_base, strategy_sd, strategy_sdxl, sai_model_spec import library.train_util as train_util @@ -893,6 +893,7 @@ def setup_parser() -> argparse.ArgumentParser: add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) + sai_model_spec.add_model_spec_arguments(parser) train_util.add_dataset_arguments(parser, True, True, True) train_util.add_training_arguments(parser, False) train_util.add_masked_loss_arguments(parser) diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py index c6e8136f..3d107e57 100644 --- a/sdxl_train_control_net.py +++ b/sdxl_train_control_net.py @@ -25,6 +25,7 @@ from library import ( strategy_base, strategy_sd, strategy_sdxl, + sai_model_spec ) import library.train_util as train_util @@ -664,6 +665,7 @@ def setup_parser() -> argparse.ArgumentParser: add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) + sai_model_spec.add_model_spec_arguments(parser) train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) # train_util.add_masked_loss_arguments(parser) diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 00e51a67..4dd4b8d9 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -32,6 +32,7 @@ from library import ( strategy_base, strategy_sd, strategy_sdxl, + sai_model_spec, ) import library.model_util as model_util @@ -589,6 +590,7 @@ def setup_parser() -> argparse.ArgumentParser: add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) + sai_model_spec.add_model_spec_arguments(parser) train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) deepspeed_utils.add_deepspeed_arguments(parser) diff 
--git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index 63457cc6..0a9f4a92 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -536,6 +536,7 @@ def setup_parser() -> argparse.ArgumentParser: add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) + sai_model_spec.add_model_spec_arguments(parser) train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) deepspeed_utils.add_deepspeed_arguments(parser) diff --git a/tests/library/test_sai_model_spec.py b/tests/library/test_sai_model_spec.py index 92dcf4c6..0bbfa116 100644 --- a/tests/library/test_sai_model_spec.py +++ b/tests/library/test_sai_model_spec.py @@ -1,4 +1,5 @@ """Tests for sai_model_spec module.""" + import pytest import time @@ -7,7 +8,7 @@ from library import sai_model_spec class MockArgs: """Mock argparse.Namespace for testing.""" - + def __init__(self, **kwargs): # Default values self.v2 = False @@ -22,7 +23,7 @@ class MockArgs: self.max_timestep = None self.clip_skip = None self.output_name = "test_output" - + # Override with provided values for key, value in kwargs.items(): setattr(self, key, value) @@ -30,57 +31,56 @@ class MockArgs: class TestModelSpecMetadata: """Test the ModelSpecMetadata dataclass.""" - + def test_creation_and_conversion(self): """Test creating dataclass and converting to metadata dict.""" metadata = sai_model_spec.ModelSpecMetadata( architecture="stable-diffusion-v1", implementation="diffusers", title="Test Model", + resolution="512x512", author="Test Author", - description=None # Test None exclusion
+ description=None, # Test None exclusion ) - + assert metadata.architecture == "stable-diffusion-v1" assert metadata.sai_model_spec == "1.0.1" - + metadata_dict = metadata.to_metadata_dict() assert "modelspec.architecture" in metadata_dict assert "modelspec.author" in metadata_dict assert "modelspec.description" not in metadata_dict # None values excluded assert metadata_dict["modelspec.sai_model_spec"] == "1.0.1" - + def test_additional_fields_handling(self): """Test handling of additional metadata fields.""" additional = {"custom_field": "custom_value", "modelspec.prefixed": "prefixed_value"} - + metadata = sai_model_spec.ModelSpecMetadata( architecture="stable-diffusion-v1", implementation="diffusers", title="Test Model", - additional_fields=additional + resolution="512x512", + additional_fields=additional, ) - + metadata_dict = metadata.to_metadata_dict() assert "modelspec.custom_field" in metadata_dict assert "modelspec.prefixed" in metadata_dict assert metadata_dict["modelspec.custom_field"] == "custom_value" - + def test_from_args_extraction(self): """Test creating ModelSpecMetadata from args with metadata_* fields.""" - args = MockArgs( - metadata_author="Test Author", - metadata_trigger_phrase="anime style", - metadata_usage_hint="Use CFG 7.5" - ) - + args = MockArgs(metadata_author="Test Author", metadata_trigger_phrase="anime style", metadata_usage_hint="Use CFG 7.5") + metadata = sai_model_spec.ModelSpecMetadata.from_args( args, architecture="stable-diffusion-v1", implementation="diffusers", - title="Test Model" + title="Test Model", + resolution="512x512", ) - + assert metadata.author == "Test Author" assert metadata.additional_fields["trigger_phrase"] == "anime style" assert metadata.additional_fields["usage_hint"] == "Use CFG 7.5" @@ -88,79 +88,87 @@ class TestModelSpecMetadata: class TestArchitectureDetection: """Test architecture detection for different model types.""" - - @pytest.mark.parametrize("config,expected", [ - ({"v2": False, 
"v_parameterization": False, "sdxl": True}, "stable-diffusion-xl-v1-base"), - ({"v2": False, "v_parameterization": False, "sdxl": False, "model_config": {"flux": "dev"}}, "flux-1-dev"), - ({"v2": False, "v_parameterization": False, "sdxl": False, "model_config": {"flux": "chroma"}}, "chroma"), - ({"v2": False, "v_parameterization": False, "sdxl": False, "model_config": {"sd3": "large"}}, "stable-diffusion-3-large"), - ({"v2": True, "v_parameterization": True, "sdxl": False}, "stable-diffusion-v2-768-v"), - ({"v2": False, "v_parameterization": False, "sdxl": False}, "stable-diffusion-v1"), - ]) + + @pytest.mark.parametrize( + "config,expected", + [ + ({"v2": False, "v_parameterization": False, "sdxl": True}, "stable-diffusion-xl-v1-base"), + ({"v2": False, "v_parameterization": False, "sdxl": False, "model_config": {"flux": "dev"}}, "flux-1-dev"), + ({"v2": False, "v_parameterization": False, "sdxl": False, "model_config": {"flux": "chroma"}}, "chroma"), + ( + {"v2": False, "v_parameterization": False, "sdxl": False, "model_config": {"sd3": "large"}}, + "stable-diffusion-3-large", + ), + ({"v2": True, "v_parameterization": True, "sdxl": False}, "stable-diffusion-v2-768-v"), + ({"v2": False, "v_parameterization": False, "sdxl": False}, "stable-diffusion-v1"), + ], + ) def test_architecture_detection(self, config, expected): """Test architecture detection for various model configurations.""" model_config = config.pop("model_config", None) - arch = sai_model_spec.determine_architecture( - lora=False, textual_inversion=False, model_config=model_config, **config - ) + arch = sai_model_spec.determine_architecture(lora=False, textual_inversion=False, model_config=model_config, **config) assert arch == expected - + def test_adapter_suffixes(self): """Test LoRA and textual inversion suffixes.""" lora_arch = sai_model_spec.determine_architecture( - v2=False, v_parameterization=False, sdxl=True, - lora=True, textual_inversion=False + v2=False, v_parameterization=False, 
sdxl=True, lora=True, textual_inversion=False ) assert lora_arch == "stable-diffusion-xl-v1-base/lora" - + ti_arch = sai_model_spec.determine_architecture( - v2=False, v_parameterization=False, sdxl=False, - lora=False, textual_inversion=True + v2=False, v_parameterization=False, sdxl=False, lora=False, textual_inversion=True ) assert ti_arch == "stable-diffusion-v1/textual-inversion" class TestImplementationDetection: """Test implementation detection for different model types.""" - - @pytest.mark.parametrize("config,expected", [ - ({"model_config": {"flux": "dev"}}, "https://github.com/black-forest-labs/flux"), - ({"model_config": {"flux": "chroma"}}, "https://huggingface.co/lodestones/Chroma"), - ({"model_config": {"lumina": "lumina2"}}, "https://github.com/Alpha-VLLM/Lumina-Image-2.0"), - ({"lora": True, "sdxl": True}, "https://github.com/Stability-AI/generative-models"), - ({"lora": True, "sdxl": False}, "diffusers"), - ]) + + @pytest.mark.parametrize( + "config,expected", + [ + ({"model_config": {"flux": "dev"}}, "https://github.com/black-forest-labs/flux"), + ({"model_config": {"flux": "chroma"}}, "https://huggingface.co/lodestones/Chroma"), + ({"model_config": {"lumina": "lumina2"}}, "https://github.com/Alpha-VLLM/Lumina-Image-2.0"), + ({"lora": True, "sdxl": True}, "https://github.com/Stability-AI/generative-models"), + ({"lora": True, "sdxl": False}, "diffusers"), + ], + ) def test_implementation_detection(self, config, expected): """Test implementation detection for various configurations.""" model_config = config.pop("model_config", None) impl = sai_model_spec.determine_implementation( - lora=config.get("lora", False), - textual_inversion=False, - sdxl=config.get("sdxl", False), - model_config=model_config + lora=config.get("lora", False), textual_inversion=False, sdxl=config.get("sdxl", False), model_config=model_config ) assert impl == expected class TestResolutionHandling: """Test resolution parsing and defaults.""" - - 
@pytest.mark.parametrize("input_reso,expected", [ - ((768, 1024), "768x1024"), - (768, "768x768"), - ("768,1024", "768x1024"), - ]) + + @pytest.mark.parametrize( + "input_reso,expected", + [ + ((768, 1024), "768x1024"), + (768, "768x768"), + ("768,1024", "768x1024"), + ], + ) def test_explicit_resolution_formats(self, input_reso, expected): """Test different resolution input formats.""" res = sai_model_spec.determine_resolution(reso=input_reso) assert res == expected - - @pytest.mark.parametrize("config,expected", [ - ({"sdxl": True}, "1024x1024"), - ({"model_config": {"flux": "dev"}}, "1024x1024"), - ({"v2": True, "v_parameterization": True}, "768x768"), - ({}, "512x512"), # Default SD v1 - ]) + + @pytest.mark.parametrize( + "config,expected", + [ + ({"sdxl": True}, "1024x1024"), + ({"model_config": {"flux": "dev"}}, "1024x1024"), + ({"v2": True, "v_parameterization": True}, "768x768"), + ({}, "512x512"), # Default SD v1 + ], + ) def test_default_resolutions(self, config, expected): """Test default resolution detection.""" model_config = config.pop("model_config", None) @@ -170,59 +178,60 @@ class TestResolutionHandling: class TestThumbnailProcessing: """Test thumbnail data URL processing.""" - + def test_file_to_data_url(self): """Test converting file to data URL.""" import tempfile import os - + # Create a tiny test PNG (1x1 pixel) - test_png_data = b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\x0cIDATx\x9cc\xff\xff\xff\x00\x00\x00\x04\x00\x01\x9d\xb3\xa7c\x00\x00\x00\x00IEND\xaeB`\x82' - - with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f: + test_png_data = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\x0cIDATx\x9cc\xff\xff\xff\x00\x00\x00\x04\x00\x01\x9d\xb3\xa7c\x00\x00\x00\x00IEND\xaeB`\x82" + + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f: f.write(test_png_data) temp_path = f.name - + try: 
data_url = sai_model_spec.file_to_data_url(temp_path) - + # Check format assert data_url.startswith("data:image/png;base64,") - + # Check it's a reasonable length (base64 encoded) assert len(data_url) > 50 - + # Verify we can decode it back import base64 + encoded_part = data_url.split(",", 1)[1] decoded_data = base64.b64decode(encoded_part) assert decoded_data == test_png_data - + finally: os.unlink(temp_path) - + def test_file_to_data_url_nonexistent_file(self): """Test error handling for nonexistent files.""" import pytest - + with pytest.raises(FileNotFoundError): sai_model_spec.file_to_data_url("/nonexistent/file.png") - + def test_thumbnail_processing_in_metadata(self): """Test thumbnail processing in build_metadata_dataclass.""" import tempfile import os - + # Create a test image file - test_png_data = b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\x0cIDATx\x9cc\xff\xff\xff\x00\x00\x00\x04\x00\x01\x9d\xb3\xa7c\x00\x00\x00\x00IEND\xaeB`\x82' - - with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f: + test_png_data = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\x0cIDATx\x9cc\xff\xff\xff\x00\x00\x00\x04\x00\x01\x9d\xb3\xa7c\x00\x00\x00\x00IEND\xaeB`\x82" + + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f: f.write(test_png_data) temp_path = f.name - + try: timestamp = time.time() - + # Test with file path - should be converted to data URL metadata = sai_model_spec.build_metadata_dataclass( state_dict=None, @@ -233,22 +242,24 @@ class TestThumbnailProcessing: textual_inversion=False, timestamp=timestamp, title="Test Model", - optional_metadata={"thumbnail": temp_path} + optional_metadata={"thumbnail": temp_path}, ) - + # Should be converted to data URL assert "thumbnail" in metadata.additional_fields assert metadata.additional_fields["thumbnail"].startswith("data:image/png;base64,") - + finally: 
os.unlink(temp_path) - + def test_thumbnail_data_url_passthrough(self): """Test that existing data URLs are passed through unchanged.""" timestamp = time.time() - - existing_data_url = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==" - + + existing_data_url = ( + "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==" + ) + metadata = sai_model_spec.build_metadata_dataclass( state_dict=None, v2=False, @@ -258,16 +269,16 @@ class TestThumbnailProcessing: textual_inversion=False, timestamp=timestamp, title="Test Model", - optional_metadata={"thumbnail": existing_data_url} + optional_metadata={"thumbnail": existing_data_url}, ) - + # Should be unchanged assert metadata.additional_fields["thumbnail"] == existing_data_url - + def test_thumbnail_invalid_file_handling(self): """Test graceful handling of invalid thumbnail files.""" timestamp = time.time() - + metadata = sai_model_spec.build_metadata_dataclass( state_dict=None, v2=False, @@ -277,20 +288,20 @@ class TestThumbnailProcessing: textual_inversion=False, timestamp=timestamp, title="Test Model", - optional_metadata={"thumbnail": "/nonexistent/file.png"} + optional_metadata={"thumbnail": "/nonexistent/file.png"}, ) - + # Should be removed from additional_fields due to error assert "thumbnail" not in metadata.additional_fields class TestBuildMetadataIntegration: """Test the complete metadata building workflow.""" - + def test_sdxl_model_workflow(self): """Test complete workflow for SDXL model.""" timestamp = time.time() - + metadata = sai_model_spec.build_metadata_dataclass( state_dict=None, v2=False, @@ -299,18 +310,18 @@ class TestBuildMetadataIntegration: lora=False, textual_inversion=False, timestamp=timestamp, - title="Test SDXL Model" + title="Test SDXL Model", ) - + assert metadata.architecture == "stable-diffusion-xl-v1-base" assert metadata.implementation == 
"https://github.com/Stability-AI/generative-models" assert metadata.resolution == "1024x1024" assert metadata.prediction_type == "epsilon" - + def test_flux_model_workflow(self): """Test complete workflow for Flux model.""" timestamp = time.time() - + metadata = sai_model_spec.build_metadata_dataclass( state_dict=None, v2=False, @@ -321,18 +332,18 @@ class TestBuildMetadataIntegration: timestamp=timestamp, title="Test Flux Model", model_config={"flux": "dev"}, - optional_metadata={"trigger_phrase": "anime style"} + optional_metadata={"trigger_phrase": "anime style"}, ) - + assert metadata.architecture == "flux-1-dev" assert metadata.implementation == "https://github.com/black-forest-labs/flux" assert metadata.prediction_type is None # Flux doesn't use prediction_type assert metadata.additional_fields["trigger_phrase"] == "anime style" - + def test_legacy_function_compatibility(self): """Test that legacy build_metadata function works correctly.""" timestamp = time.time() - + metadata_dict = sai_model_spec.build_metadata( state_dict=None, v2=False, @@ -341,9 +352,9 @@ class TestBuildMetadataIntegration: lora=False, textual_inversion=False, timestamp=timestamp, - title="Test Model" + title="Test Model", ) - + assert isinstance(metadata_dict, dict) assert metadata_dict["modelspec.sai_model_spec"] == "1.0.1" - assert metadata_dict["modelspec.architecture"] == "stable-diffusion-xl-v1-base" \ No newline at end of file + assert metadata_dict["modelspec.architecture"] == "stable-diffusion-xl-v1-base" diff --git a/tools/cache_latents.py b/tools/cache_latents.py index 515ece98..5baddb5b 100644 --- a/tools/cache_latents.py +++ b/tools/cache_latents.py @@ -12,6 +12,7 @@ from tqdm import tqdm from library import config_util, flux_train_utils, flux_utils, strategy_base, strategy_flux, strategy_sd, strategy_sdxl from library import train_util from library import sdxl_train_util +import library.sai_model_spec as sai_model_spec from library.config_util import ( ConfigSanitizer, 
BlueprintGenerator, @@ -161,6 +162,7 @@ def setup_parser() -> argparse.ArgumentParser: add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) + sai_model_spec.add_model_spec_arguments(parser) train_util.add_training_arguments(parser, True) train_util.add_dataset_arguments(parser, True, True, True) train_util.add_masked_loss_arguments(parser) diff --git a/tools/cache_text_encoder_outputs.py b/tools/cache_text_encoder_outputs.py index 00459658..8e604292 100644 --- a/tools/cache_text_encoder_outputs.py +++ b/tools/cache_text_encoder_outputs.py @@ -22,6 +22,7 @@ from library import ( from library import train_util from library import sdxl_train_util from library import utils +import library.sai_model_spec as sai_model_spec from library.config_util import ( ConfigSanitizer, BlueprintGenerator, @@ -188,6 +189,7 @@ def setup_parser() -> argparse.ArgumentParser: add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) + sai_model_spec.add_model_spec_arguments(parser) train_util.add_training_arguments(parser, True) train_util.add_dataset_arguments(parser, True, True, True) train_util.add_masked_loss_arguments(parser) diff --git a/train_control_net.py b/train_control_net.py index ba016ac5..97cd1ebb 100644 --- a/train_control_net.py +++ b/train_control_net.py @@ -25,6 +25,7 @@ from safetensors.torch import load_file import library.model_util as model_util import library.train_util as train_util import library.config_util as config_util +import library.sai_model_spec as sai_model_spec from library.config_util import ( ConfigSanitizer, BlueprintGenerator, diff --git a/train_db.py b/train_db.py index edd67403..4bf3b31c 100644 --- a/train_db.py +++ b/train_db.py @@ -22,6 +22,7 @@ from diffusers import DDPMScheduler import library.train_util as train_util import library.config_util as config_util +import library.sai_model_spec as sai_model_spec from library.config_util import ( ConfigSanitizer, BlueprintGenerator, @@ -512,6 +513,7 @@ def 
setup_parser() -> argparse.ArgumentParser: add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) + sai_model_spec.add_model_spec_arguments(parser) train_util.add_dataset_arguments(parser, True, False, True) train_util.add_training_arguments(parser, True) train_util.add_masked_loss_arguments(parser) diff --git a/train_network.py b/train_network.py index aa42a3bf..e055f5d8 100644 --- a/train_network.py +++ b/train_network.py @@ -1711,6 +1711,7 @@ def setup_parser() -> argparse.ArgumentParser: add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) + sai_model_spec.add_model_spec_arguments(parser) train_util.add_dataset_arguments(parser, True, True, True) train_util.add_training_arguments(parser, True) train_util.add_masked_loss_arguments(parser) @@ -1718,7 +1719,6 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) custom_train_functions.add_custom_train_arguments(parser) - sai_model_spec.add_model_spec_arguments(parser) parser.add_argument( "--cpu_offload_checkpointing", diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 0c6568b0..8575698d 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -16,7 +16,7 @@ init_ipex() from accelerate.utils import set_seed from diffusers import DDPMScheduler from transformers import CLIPTokenizer -from library import deepspeed_utils, model_util, strategy_base, strategy_sd +from library import
deepspeed_utils, model_util, strategy_base, strategy_sd, sai_model_spec import library.train_util as train_util import library.huggingface_util as huggingface_util @@ -771,6 +771,7 @@ def setup_parser() -> argparse.ArgumentParser: add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) + sai_model_spec.add_model_spec_arguments(parser) train_util.add_dataset_arguments(parser, True, True, False) train_util.add_training_arguments(parser, True) train_util.add_masked_loss_arguments(parser) diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 6ff97d03..77821095 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -21,6 +21,7 @@ import library import library.train_util as train_util import library.huggingface_util as huggingface_util import library.config_util as config_util +import library.sai_model_spec as sai_model_spec from library.config_util import ( ConfigSanitizer, BlueprintGenerator, @@ -668,6 +669,7 @@ def setup_parser() -> argparse.ArgumentParser: add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) + sai_model_spec.add_model_spec_arguments(parser) train_util.add_dataset_arguments(parser, True, True, False) train_util.add_training_arguments(parser, True) train_util.add_masked_loss_arguments(parser)