Add parser args for other trainers.

This commit is contained in:
rockerBOO 2025-08-03 00:58:25 -04:00
parent 9bb50c26c4
commit c149cf283b
No known key found for this signature in database
GPG Key ID: 0D4EAF00DCABC97B
17 changed files with 150 additions and 112 deletions

View File

@ -27,6 +27,7 @@ logger = logging.getLogger(__name__)
import library.train_util as train_util
import library.config_util as config_util
import library.sai_model_spec as sai_model_spec
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
@ -519,6 +520,7 @@ def setup_parser() -> argparse.ArgumentParser:
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
sai_model_spec.add_model_spec_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
deepspeed_utils.add_deepspeed_arguments(parser)

View File

@ -30,7 +30,7 @@ from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from accelerate.utils import set_seed
from library import deepspeed_utils, flux_train_utils, flux_utils, strategy_base, strategy_flux
from library import deepspeed_utils, flux_train_utils, flux_utils, strategy_base, strategy_flux, sai_model_spec
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
import library.train_util as train_util
@ -787,6 +787,7 @@ def setup_parser() -> argparse.ArgumentParser:
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser) # TODO split this
sai_model_spec.add_model_spec_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, False)
train_util.add_masked_loss_arguments(parser)

View File

@ -32,6 +32,7 @@ init_ipex()
from accelerate.utils import set_seed
import library.train_util as train_util
import library.sai_model_spec as sai_model_spec
from library import (
deepspeed_utils,
flux_train_utils,
@ -820,6 +821,7 @@ def setup_parser() -> argparse.ArgumentParser:
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser) # TODO split this
sai_model_spec.add_model_spec_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
train_util.add_masked_loss_arguments(parser)

View File

@ -31,6 +31,7 @@ from library import (
lumina_util,
strategy_base,
strategy_lumina,
sai_model_spec
)
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
@ -904,6 +905,7 @@ def setup_parser() -> argparse.ArgumentParser:
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser) # TODO split this
sai_model_spec.add_model_spec_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, False)
train_util.add_masked_loss_arguments(parser)

View File

@ -20,6 +20,8 @@ init_ipex()
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from library import deepspeed_utils, sd3_models, sd3_train_utils, sd3_utils, strategy_base, strategy_sd3
import library.sai_model_spec as sai_model_spec
from library.sdxl_train_util import match_mixed_precision
# , sdxl_model_util
@ -986,6 +988,7 @@ def setup_parser() -> argparse.ArgumentParser:
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
sai_model_spec.add_model_spec_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, False)
train_util.add_masked_loss_arguments(parser)

View File

@ -17,7 +17,7 @@ init_ipex()
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from library import deepspeed_utils, sdxl_model_util, strategy_base, strategy_sd, strategy_sdxl
from library import deepspeed_utils, sdxl_model_util, strategy_base, strategy_sd, strategy_sdxl, sai_model_spec
import library.train_util as train_util
@ -893,6 +893,7 @@ def setup_parser() -> argparse.ArgumentParser:
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
sai_model_spec.add_model_spec_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, False)
train_util.add_masked_loss_arguments(parser)

View File

@ -25,6 +25,7 @@ from library import (
strategy_base,
strategy_sd,
strategy_sdxl,
sai_model_spec
)
import library.train_util as train_util
@ -664,6 +665,7 @@ def setup_parser() -> argparse.ArgumentParser:
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
sai_model_spec.add_model_spec_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
# train_util.add_masked_loss_arguments(parser)

View File

@ -32,6 +32,7 @@ from library import (
strategy_base,
strategy_sd,
strategy_sdxl,
sai_model_spec,
)
import library.model_util as model_util
@ -589,6 +590,7 @@ def setup_parser() -> argparse.ArgumentParser:
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
sai_model_spec.add_model_spec_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
deepspeed_utils.add_deepspeed_arguments(parser)

View File

@ -24,6 +24,7 @@ from library import deepspeed_utils, sai_model_spec, sdxl_model_util, sdxl_origi
import library.model_util as model_util
import library.train_util as train_util
import library.config_util as config_util
import library.sai_model_spec as sai_model_spec
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
@ -536,6 +537,7 @@ def setup_parser() -> argparse.ArgumentParser:
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
sai_model_spec.add_model_spec_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
deepspeed_utils.add_deepspeed_arguments(parser)

View File

@ -1,4 +1,5 @@
"""Tests for sai_model_spec module."""
import pytest
import time
@ -7,7 +8,7 @@ from library import sai_model_spec
class MockArgs:
"""Mock argparse.Namespace for testing."""
def __init__(self, **kwargs):
# Default values
self.v2 = False
@ -22,7 +23,7 @@ class MockArgs:
self.max_timestep = None
self.clip_skip = None
self.output_name = "test_output"
# Override with provided values
for key, value in kwargs.items():
setattr(self, key, value)
@ -30,57 +31,56 @@ class MockArgs:
class TestModelSpecMetadata:
"""Test the ModelSpecMetadata dataclass."""
def test_creation_and_conversion(self):
"""Test creating dataclass and converting to metadata dict."""
metadata = sai_model_spec.ModelSpecMetadata(
architecture="stable-diffusion-v1",
implementation="diffusers",
title="Test Model",
resolution="512x512",
author="Test Author",
description=None # Test None exclusion
description=None, # Test None exclusion
)
assert metadata.architecture == "stable-diffusion-v1"
assert metadata.sai_model_spec == "1.0.1"
metadata_dict = metadata.to_metadata_dict()
assert "modelspec.architecture" in metadata_dict
assert "modelspec.author" in metadata_dict
assert "modelspec.description" not in metadata_dict # None values excluded
assert metadata_dict["modelspec.sai_model_spec"] == "1.0.1"
def test_additional_fields_handling(self):
"""Test handling of additional metadata fields."""
additional = {"custom_field": "custom_value", "modelspec.prefixed": "prefixed_value"}
metadata = sai_model_spec.ModelSpecMetadata(
architecture="stable-diffusion-v1",
implementation="diffusers",
title="Test Model",
additional_fields=additional
resolution="512x512",
additional_fields=additional,
)
metadata_dict = metadata.to_metadata_dict()
assert "modelspec.custom_field" in metadata_dict
assert "modelspec.prefixed" in metadata_dict
assert metadata_dict["modelspec.custom_field"] == "custom_value"
def test_from_args_extraction(self):
"""Test creating ModelSpecMetadata from args with metadata_* fields."""
args = MockArgs(
metadata_author="Test Author",
metadata_trigger_phrase="anime style",
metadata_usage_hint="Use CFG 7.5"
)
args = MockArgs(metadata_author="Test Author", metadata_trigger_phrase="anime style", metadata_usage_hint="Use CFG 7.5")
metadata = sai_model_spec.ModelSpecMetadata.from_args(
args,
architecture="stable-diffusion-v1",
implementation="diffusers",
title="Test Model"
title="Test Model",
resolution="512x512",
)
assert metadata.author == "Test Author"
assert metadata.additional_fields["trigger_phrase"] == "anime style"
assert metadata.additional_fields["usage_hint"] == "Use CFG 7.5"
@ -88,79 +88,87 @@ class TestModelSpecMetadata:
class TestArchitectureDetection:
"""Test architecture detection for different model types."""
@pytest.mark.parametrize("config,expected", [
({"v2": False, "v_parameterization": False, "sdxl": True}, "stable-diffusion-xl-v1-base"),
({"v2": False, "v_parameterization": False, "sdxl": False, "model_config": {"flux": "dev"}}, "flux-1-dev"),
({"v2": False, "v_parameterization": False, "sdxl": False, "model_config": {"flux": "chroma"}}, "chroma"),
({"v2": False, "v_parameterization": False, "sdxl": False, "model_config": {"sd3": "large"}}, "stable-diffusion-3-large"),
({"v2": True, "v_parameterization": True, "sdxl": False}, "stable-diffusion-v2-768-v"),
({"v2": False, "v_parameterization": False, "sdxl": False}, "stable-diffusion-v1"),
])
@pytest.mark.parametrize(
"config,expected",
[
({"v2": False, "v_parameterization": False, "sdxl": True}, "stable-diffusion-xl-v1-base"),
({"v2": False, "v_parameterization": False, "sdxl": False, "model_config": {"flux": "dev"}}, "flux-1-dev"),
({"v2": False, "v_parameterization": False, "sdxl": False, "model_config": {"flux": "chroma"}}, "chroma"),
(
{"v2": False, "v_parameterization": False, "sdxl": False, "model_config": {"sd3": "large"}},
"stable-diffusion-3-large",
),
({"v2": True, "v_parameterization": True, "sdxl": False}, "stable-diffusion-v2-768-v"),
({"v2": False, "v_parameterization": False, "sdxl": False}, "stable-diffusion-v1"),
],
)
def test_architecture_detection(self, config, expected):
"""Test architecture detection for various model configurations."""
model_config = config.pop("model_config", None)
arch = sai_model_spec.determine_architecture(
lora=False, textual_inversion=False, model_config=model_config, **config
)
arch = sai_model_spec.determine_architecture(lora=False, textual_inversion=False, model_config=model_config, **config)
assert arch == expected
def test_adapter_suffixes(self):
"""Test LoRA and textual inversion suffixes."""
lora_arch = sai_model_spec.determine_architecture(
v2=False, v_parameterization=False, sdxl=True,
lora=True, textual_inversion=False
v2=False, v_parameterization=False, sdxl=True, lora=True, textual_inversion=False
)
assert lora_arch == "stable-diffusion-xl-v1-base/lora"
ti_arch = sai_model_spec.determine_architecture(
v2=False, v_parameterization=False, sdxl=False,
lora=False, textual_inversion=True
v2=False, v_parameterization=False, sdxl=False, lora=False, textual_inversion=True
)
assert ti_arch == "stable-diffusion-v1/textual-inversion"
class TestImplementationDetection:
"""Test implementation detection for different model types."""
@pytest.mark.parametrize("config,expected", [
({"model_config": {"flux": "dev"}}, "https://github.com/black-forest-labs/flux"),
({"model_config": {"flux": "chroma"}}, "https://huggingface.co/lodestones/Chroma"),
({"model_config": {"lumina": "lumina2"}}, "https://github.com/Alpha-VLLM/Lumina-Image-2.0"),
({"lora": True, "sdxl": True}, "https://github.com/Stability-AI/generative-models"),
({"lora": True, "sdxl": False}, "diffusers"),
])
@pytest.mark.parametrize(
"config,expected",
[
({"model_config": {"flux": "dev"}}, "https://github.com/black-forest-labs/flux"),
({"model_config": {"flux": "chroma"}}, "https://huggingface.co/lodestones/Chroma"),
({"model_config": {"lumina": "lumina2"}}, "https://github.com/Alpha-VLLM/Lumina-Image-2.0"),
({"lora": True, "sdxl": True}, "https://github.com/Stability-AI/generative-models"),
({"lora": True, "sdxl": False}, "diffusers"),
],
)
def test_implementation_detection(self, config, expected):
"""Test implementation detection for various configurations."""
model_config = config.pop("model_config", None)
impl = sai_model_spec.determine_implementation(
lora=config.get("lora", False),
textual_inversion=False,
sdxl=config.get("sdxl", False),
model_config=model_config
lora=config.get("lora", False), textual_inversion=False, sdxl=config.get("sdxl", False), model_config=model_config
)
assert impl == expected
class TestResolutionHandling:
"""Test resolution parsing and defaults."""
@pytest.mark.parametrize("input_reso,expected", [
((768, 1024), "768x1024"),
(768, "768x768"),
("768,1024", "768x1024"),
])
@pytest.mark.parametrize(
"input_reso,expected",
[
((768, 1024), "768x1024"),
(768, "768x768"),
("768,1024", "768x1024"),
],
)
def test_explicit_resolution_formats(self, input_reso, expected):
"""Test different resolution input formats."""
res = sai_model_spec.determine_resolution(reso=input_reso)
assert res == expected
@pytest.mark.parametrize("config,expected", [
({"sdxl": True}, "1024x1024"),
({"model_config": {"flux": "dev"}}, "1024x1024"),
({"v2": True, "v_parameterization": True}, "768x768"),
({}, "512x512"), # Default SD v1
])
@pytest.mark.parametrize(
"config,expected",
[
({"sdxl": True}, "1024x1024"),
({"model_config": {"flux": "dev"}}, "1024x1024"),
({"v2": True, "v_parameterization": True}, "768x768"),
({}, "512x512"), # Default SD v1
],
)
def test_default_resolutions(self, config, expected):
"""Test default resolution detection."""
model_config = config.pop("model_config", None)
@ -170,59 +178,60 @@ class TestResolutionHandling:
class TestThumbnailProcessing:
"""Test thumbnail data URL processing."""
def test_file_to_data_url(self):
"""Test converting file to data URL."""
import tempfile
import os
# Create a tiny test PNG (1x1 pixel)
test_png_data = b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\x0cIDATx\x9cc\xff\xff\xff\x00\x00\x00\x04\x00\x01\x9d\xb3\xa7c\x00\x00\x00\x00IEND\xaeB`\x82'
with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f:
test_png_data = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\x0cIDATx\x9cc\xff\xff\xff\x00\x00\x00\x04\x00\x01\x9d\xb3\xa7c\x00\x00\x00\x00IEND\xaeB`\x82"
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
f.write(test_png_data)
temp_path = f.name
try:
data_url = sai_model_spec.file_to_data_url(temp_path)
# Check format
assert data_url.startswith("data:image/png;base64,")
# Check it's a reasonable length (base64 encoded)
assert len(data_url) > 50
# Verify we can decode it back
import base64
encoded_part = data_url.split(",", 1)[1]
decoded_data = base64.b64decode(encoded_part)
assert decoded_data == test_png_data
finally:
os.unlink(temp_path)
def test_file_to_data_url_nonexistent_file(self):
"""Test error handling for nonexistent files."""
import pytest
with pytest.raises(FileNotFoundError):
sai_model_spec.file_to_data_url("/nonexistent/file.png")
def test_thumbnail_processing_in_metadata(self):
"""Test thumbnail processing in build_metadata_dataclass."""
import tempfile
import os
# Create a test image file
test_png_data = b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\x0cIDATx\x9cc\xff\xff\xff\x00\x00\x00\x04\x00\x01\x9d\xb3\xa7c\x00\x00\x00\x00IEND\xaeB`\x82'
with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f:
test_png_data = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\x0cIDATx\x9cc\xff\xff\xff\x00\x00\x00\x04\x00\x01\x9d\xb3\xa7c\x00\x00\x00\x00IEND\xaeB`\x82"
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
f.write(test_png_data)
temp_path = f.name
try:
timestamp = time.time()
# Test with file path - should be converted to data URL
metadata = sai_model_spec.build_metadata_dataclass(
state_dict=None,
@ -233,22 +242,24 @@ class TestThumbnailProcessing:
textual_inversion=False,
timestamp=timestamp,
title="Test Model",
optional_metadata={"thumbnail": temp_path}
optional_metadata={"thumbnail": temp_path},
)
# Should be converted to data URL
assert "thumbnail" in metadata.additional_fields
assert metadata.additional_fields["thumbnail"].startswith("data:image/png;base64,")
finally:
os.unlink(temp_path)
def test_thumbnail_data_url_passthrough(self):
"""Test that existing data URLs are passed through unchanged."""
timestamp = time.time()
existing_data_url = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg=="
existing_data_url = (
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg=="
)
metadata = sai_model_spec.build_metadata_dataclass(
state_dict=None,
v2=False,
@ -258,16 +269,16 @@ class TestThumbnailProcessing:
textual_inversion=False,
timestamp=timestamp,
title="Test Model",
optional_metadata={"thumbnail": existing_data_url}
optional_metadata={"thumbnail": existing_data_url},
)
# Should be unchanged
assert metadata.additional_fields["thumbnail"] == existing_data_url
def test_thumbnail_invalid_file_handling(self):
"""Test graceful handling of invalid thumbnail files."""
timestamp = time.time()
metadata = sai_model_spec.build_metadata_dataclass(
state_dict=None,
v2=False,
@ -277,20 +288,20 @@ class TestThumbnailProcessing:
textual_inversion=False,
timestamp=timestamp,
title="Test Model",
optional_metadata={"thumbnail": "/nonexistent/file.png"}
optional_metadata={"thumbnail": "/nonexistent/file.png"},
)
# Should be removed from additional_fields due to error
assert "thumbnail" not in metadata.additional_fields
class TestBuildMetadataIntegration:
"""Test the complete metadata building workflow."""
def test_sdxl_model_workflow(self):
"""Test complete workflow for SDXL model."""
timestamp = time.time()
metadata = sai_model_spec.build_metadata_dataclass(
state_dict=None,
v2=False,
@ -299,18 +310,18 @@ class TestBuildMetadataIntegration:
lora=False,
textual_inversion=False,
timestamp=timestamp,
title="Test SDXL Model"
title="Test SDXL Model",
)
assert metadata.architecture == "stable-diffusion-xl-v1-base"
assert metadata.implementation == "https://github.com/Stability-AI/generative-models"
assert metadata.resolution == "1024x1024"
assert metadata.prediction_type == "epsilon"
def test_flux_model_workflow(self):
"""Test complete workflow for Flux model."""
timestamp = time.time()
metadata = sai_model_spec.build_metadata_dataclass(
state_dict=None,
v2=False,
@ -321,18 +332,18 @@ class TestBuildMetadataIntegration:
timestamp=timestamp,
title="Test Flux Model",
model_config={"flux": "dev"},
optional_metadata={"trigger_phrase": "anime style"}
optional_metadata={"trigger_phrase": "anime style"},
)
assert metadata.architecture == "flux-1-dev"
assert metadata.implementation == "https://github.com/black-forest-labs/flux"
assert metadata.prediction_type is None # Flux doesn't use prediction_type
assert metadata.additional_fields["trigger_phrase"] == "anime style"
def test_legacy_function_compatibility(self):
"""Test that legacy build_metadata function works correctly."""
timestamp = time.time()
metadata_dict = sai_model_spec.build_metadata(
state_dict=None,
v2=False,
@ -341,9 +352,9 @@ class TestBuildMetadataIntegration:
lora=False,
textual_inversion=False,
timestamp=timestamp,
title="Test Model"
title="Test Model",
)
assert isinstance(metadata_dict, dict)
assert metadata_dict["modelspec.sai_model_spec"] == "1.0.1"
assert metadata_dict["modelspec.architecture"] == "stable-diffusion-xl-v1-base"
assert metadata_dict["modelspec.architecture"] == "stable-diffusion-xl-v1-base"

View File

@ -12,6 +12,7 @@ from tqdm import tqdm
from library import config_util, flux_train_utils, flux_utils, strategy_base, strategy_flux, strategy_sd, strategy_sdxl
from library import train_util
from library import sdxl_train_util
import library.sai_model_spec as sai_model_spec
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
@ -161,6 +162,7 @@ def setup_parser() -> argparse.ArgumentParser:
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
sai_model_spec.add_model_spec_arguments(parser)
train_util.add_training_arguments(parser, True)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_masked_loss_arguments(parser)

View File

@ -22,6 +22,7 @@ from library import (
from library import train_util
from library import sdxl_train_util
from library import utils
import library.sai_model_spec as sai_model_spec
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
@ -188,6 +189,7 @@ def setup_parser() -> argparse.ArgumentParser:
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
sai_model_spec.add_model_spec_arguments(parser)
train_util.add_training_arguments(parser, True)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_masked_loss_arguments(parser)

View File

@ -25,6 +25,7 @@ from safetensors.torch import load_file
import library.model_util as model_util
import library.train_util as train_util
import library.config_util as config_util
import library.sai_model_spec as sai_model_spec
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,

View File

@ -22,6 +22,7 @@ from diffusers import DDPMScheduler
import library.train_util as train_util
import library.config_util as config_util
import library.sai_model_spec as sai_model_spec
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
@ -512,6 +513,7 @@ def setup_parser() -> argparse.ArgumentParser:
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
sai_model_spec.add_model_spec_arguments(parser)
train_util.add_dataset_arguments(parser, True, False, True)
train_util.add_training_arguments(parser, True)
train_util.add_masked_loss_arguments(parser)

View File

@ -24,7 +24,7 @@ from accelerate.utils import set_seed
from accelerate import Accelerator
from diffusers import DDPMScheduler
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from library import deepspeed_utils, model_util, sai_model_spec, strategy_base, strategy_sd
from library import deepspeed_utils, model_util, sai_model_spec, strategy_base, strategy_sd
import library.train_util as train_util
from library.train_util import DreamBoothDataset
@ -1711,6 +1711,7 @@ def setup_parser() -> argparse.ArgumentParser:
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
sai_model_spec.add_model_spec_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, True)
train_util.add_masked_loss_arguments(parser)
@ -1718,7 +1719,6 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
sai_model_spec.add_model_spec_arguments(parser)
parser.add_argument(
"--cpu_offload_checkpointing",

View File

@ -16,7 +16,7 @@ init_ipex()
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from transformers import CLIPTokenizer
from library import deepspeed_utils, model_util, strategy_base, strategy_sd
from library import deepspeed_utils, model_util, strategy_base, strategy_sd, sai_model_spec
import library.train_util as train_util
import library.huggingface_util as huggingface_util
@ -771,6 +771,7 @@ def setup_parser() -> argparse.ArgumentParser:
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
sai_model_spec.add_model_spec_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, False)
train_util.add_training_arguments(parser, True)
train_util.add_masked_loss_arguments(parser)

View File

@ -21,6 +21,7 @@ import library
import library.train_util as train_util
import library.huggingface_util as huggingface_util
import library.config_util as config_util
import library.sai_model_spec as sai_model_spec
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
@ -668,6 +669,7 @@ def setup_parser() -> argparse.ArgumentParser:
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
sai_model_spec.add_model_spec_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, False)
train_util.add_training_arguments(parser, True)
train_util.add_masked_loss_arguments(parser)