Update model spec to 1.0.1. Refactor model spec

This commit is contained in:
rockerBOO
2025-08-02 21:14:27 -04:00
parent 5dff02a65d
commit d24d733892
2 changed files with 654 additions and 199 deletions

View File

@@ -1,14 +1,19 @@
# based on https://github.com/Stability-AI/ModelSpec
import datetime
import hashlib
import argparse
import base64
import logging
import mimetypes
import subprocess
from dataclasses import dataclass, field
from io import BytesIO
import os
from typing import List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union
import safetensors
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
@@ -31,23 +36,44 @@ metadata = {
"""
BASE_METADATA = {
# === Must ===
"modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec
# === Universal MUST fields ===
"modelspec.sai_model_spec": "1.0.1", # Updated to latest spec version
"modelspec.architecture": None,
"modelspec.implementation": None,
"modelspec.title": None,
"modelspec.resolution": None,
# === Should ===
# === Universal SHOULD fields ===
"modelspec.description": None,
"modelspec.author": None,
"modelspec.date": None,
# === Can ===
"modelspec.hash_sha256": None,
# === Universal CAN fields ===
"modelspec.implementation_version": None,
"modelspec.license": None,
"modelspec.usage_hint": None,
"modelspec.thumbnail": None,
"modelspec.tags": None,
"modelspec.merged_from": None,
# === Image generation MUST fields ===
"modelspec.resolution": None,
# === Image generation CAN fields ===
"modelspec.trigger_phrase": None,
"modelspec.prediction_type": None,
"modelspec.timestep_range": None,
"modelspec.encoder_layer": None,
"modelspec.preprocessor": None,
"modelspec.is_negative_embedding": None,
"modelspec.unet_dtype": None,
"modelspec.vae_dtype": None,
# === Text prediction fields ===
"modelspec.data_format": None,
"modelspec.format_type": None,
"modelspec.language": None,
"modelspec.format_template": None,
}
# 別に使うやつだけ定義
@@ -80,6 +106,256 @@ PRED_TYPE_EPSILON = "epsilon"
PRED_TYPE_V = "v"
@dataclass
class ModelSpecMetadata:
    """
    ModelSpec 1.0.1 compliant metadata for safetensors models.

    Every attribute except ``additional_fields`` corresponds to the
    ``modelspec.<attribute>`` key of the final metadata dictionary.
    ``additional_fields`` carries arbitrary extra entries; its keys are emitted
    as-is when they already start with ``modelspec.`` and are prefixed otherwise.
    """

    # === Universal MUST fields ===
    architecture: str
    implementation: str
    title: str

    # === Universal SHOULD fields ===
    description: Optional[str] = None
    author: Optional[str] = None
    date: Optional[str] = None
    hash_sha256: Optional[str] = None

    # === Universal CAN fields ===
    sai_model_spec: str = "1.0.1"
    implementation_version: Optional[str] = None
    license: Optional[str] = None
    usage_hint: Optional[str] = None
    thumbnail: Optional[str] = None
    tags: Optional[str] = None
    merged_from: Optional[str] = None

    # === Image generation MUST fields ===
    resolution: Optional[str] = None

    # === Image generation CAN fields ===
    trigger_phrase: Optional[str] = None
    prediction_type: Optional[str] = None
    timestep_range: Optional[str] = None
    encoder_layer: Optional[str] = None
    preprocessor: Optional[str] = None
    is_negative_embedding: Optional[str] = None
    unet_dtype: Optional[str] = None
    vae_dtype: Optional[str] = None

    # === Text prediction fields ===
    data_format: Optional[str] = None
    format_type: Optional[str] = None
    language: Optional[str] = None
    format_template: Optional[str] = None

    # === Additional metadata ===
    additional_fields: Dict[str, str] = field(default_factory=dict)

    def to_metadata_dict(self) -> Dict[str, str]:
        """
        Convert the dataclass to a metadata dictionary with ``modelspec.`` prefixes.

        Returns:
            A dict containing one ``modelspec.*`` entry per attribute whose value is
            not None. Entries from ``additional_fields`` are written when the loop
            reaches that attribute (declared last), so they overwrite a regular
            attribute that maps to the same key. None-valued additional entries are
            skipped, just like None-valued attributes.
        """
        metadata = {}

        for field_name, value in self.__dict__.items():
            if field_name == "additional_fields":
                # Tolerate an explicit ``additional_fields=None``.
                for key, val in (value or {}).items():
                    if val is None:
                        # A None value is "unset"; never emit it into the metadata.
                        continue
                    full_key = key if key.startswith("modelspec.") else f"modelspec.{key}"
                    metadata[full_key] = val
            elif value is not None:
                metadata[f"modelspec.{field_name}"] = value

        return metadata

    @classmethod
    def from_args(cls, args, **kwargs) -> "ModelSpecMetadata":
        """
        Create ModelSpecMetadata from an argparse Namespace, extracting metadata_* fields.

        Args:
            args: Object whose ``metadata_*`` attributes are read (None values are ignored).
            **kwargs: Constructor arguments passed through; they take precedence over the
                author/description/license/tags values read from ``args``.

        Returns:
            A new instance. ``metadata_*`` attributes other than author, description,
            license and tags are stored in ``additional_fields``; they are merged on top
            of any ``additional_fields`` supplied through ``kwargs``.

        Raises:
            TypeError: If the required constructor fields are not supplied via ``kwargs``.
        """
        prefix = "metadata_"
        metadata_fields = {}

        # Collect every non-None metadata_* attribute, keyed without the prefix.
        for attr_name in dir(args):
            if not attr_name.startswith(prefix) or attr_name.startswith("metadata___"):
                continue
            value = getattr(args, attr_name, None)
            if value is not None:
                metadata_fields[attr_name[len(prefix):]] = value

        # These four map onto first-class dataclass fields.
        standard_fields = {}
        for name in ("author", "description", "license", "tags"):
            value = metadata_fields.pop(name, None)
            if value is not None:
                standard_fields[name] = value

        all_fields = {**standard_fields, **kwargs}

        if metadata_fields:
            # Keep caller-provided additional fields instead of discarding them;
            # values extracted from args still win on key conflicts.
            merged_additional = dict(all_fields.get("additional_fields") or {})
            merged_additional.update(metadata_fields)
            all_fields["additional_fields"] = merged_additional

        return cls(**all_fields)
def determine_architecture(
    v2: bool,
    v_parameterization: bool,
    sdxl: bool,
    lora: bool,
    textual_inversion: bool,
    model_config: Optional[dict] = None
) -> str:
    """
    Build the architecture string for a model.

    The base architecture is chosen in this priority order: SDXL, then the
    "sd3", "flux" and "lumina" entries of ``model_config``, then SD v2, and
    finally SD v1. A LoRA or textual-inversion adapter suffix is appended
    after a slash, with LoRA taking precedence.

    Args:
        v2: Whether the model is an SD v2 model.
        v_parameterization: Selects the v-prediction variant of SD v2.
        sdxl: Whether the model is an SDXL model.
        lora: Whether the weights are a LoRA adapter.
        textual_inversion: Whether the weights are a textual inversion embedding.
        model_config: Optional dict with model type info, e.g. {"flux": "dev"}.

    Returns:
        The architecture string, including the adapter suffix if any.
    """
    config = model_config or {}

    def _base_architecture() -> str:
        # First matching model family wins.
        if sdxl:
            return ARCH_SD_XL_V1_BASE
        if "sd3" in config:
            return ARCH_SD3_M + "-" + config["sd3"]
        if "flux" in config:
            variant = config["flux"]
            if variant == "dev":
                return ARCH_FLUX_1_DEV
            if variant == "schnell":
                return ARCH_FLUX_1_SCHNELL
            if variant == "chroma":
                return ARCH_FLUX_1_CHROMA
            return ARCH_FLUX_1_UNKNOWN
        if "lumina" in config:
            if config["lumina"] == "lumina2":
                return ARCH_LUMINA_2
            return ARCH_LUMINA_UNKNOWN
        if v2:
            if v_parameterization:
                return ARCH_SD_V2_768_V
            return ARCH_SD_V2_512
        return ARCH_SD_V1

    base = _base_architecture()

    # Adapter suffix: LoRA is checked before textual inversion.
    if lora:
        return base + f"/{ADAPTER_LORA}"
    if textual_inversion:
        return base + f"/{ADAPTER_TEXTUAL_INVERSION}"
    return base
def determine_implementation(
    lora: bool,
    textual_inversion: bool,
    sdxl: bool,
    model_config: Optional[dict] = None,
    is_stable_diffusion_ckpt: Optional[bool] = None
) -> str:
    """
    Pick the implementation string for a model.

    A "flux" entry in ``model_config`` is checked first (with "chroma" getting
    its own implementation), then a "lumina" entry. Otherwise SDXL LoRA,
    textual inversion and Stable Diffusion checkpoints use the Stability AI
    implementation and everything else uses the Diffusers one.

    Args:
        lora: Whether the weights are a LoRA adapter.
        textual_inversion: Whether the weights are a textual inversion embedding.
        sdxl: Whether the model is an SDXL model.
        model_config: Optional dict with model type info, e.g. {"flux": "dev"}.
        is_stable_diffusion_ckpt: Whether the model is a Stable Diffusion checkpoint.

    Returns:
        The implementation string.
    """
    config = model_config or {}

    if "flux" in config:
        return IMPL_CHROMA if config["flux"] == "chroma" else IMPL_FLUX
    if "lumina" in config:
        return IMPL_LUMINA

    uses_stability_format = (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt
    return IMPL_STABILITY_AI if uses_stability_format else IMPL_DIFFUSERS
def get_implementation_version() -> str:
    """
    Get the current implementation version as sd-scripts/{commit_hash}.

    Runs ``git rev-parse HEAD`` in the directory two levels above this file
    (expected to be the sd-scripts checkout root). This is a best-effort
    lookup: any failure to run git is logged as a warning and never raised.

    Returns:
        "sd-scripts/<commit hash>" on success, otherwise "sd-scripts/unknown".
    """
    fallback = "sd-scripts/unknown"

    # abspath guards against a relative __file__, for which dirname(dirname(...))
    # could collapse to "" and make the subprocess fail on an invalid cwd.
    repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    try:
        result = subprocess.run(
            ["git", "rev-parse", "HEAD"],
            capture_output=True,
            text=True,
            cwd=repo_root,
            timeout=5,
        )
    except (subprocess.SubprocessError, OSError) as e:
        # SubprocessError covers TimeoutExpired; OSError covers a missing git
        # executable (FileNotFoundError) as well as permission or cwd problems.
        logger.warning(f"Could not determine git commit: {e}")
        return fallback

    commit_hash = result.stdout.strip()
    if result.returncode == 0 and commit_hash:
        return f"sd-scripts/{commit_hash}"

    logger.warning("Failed to get git commit hash, using fallback")
    return fallback
def file_to_data_url(file_path: str) -> str:
    """
    Encode the contents of a file as a base64 data URL for embedding in metadata.

    Args:
        file_path: Path of the file to embed.

    Returns:
        A string of the form "data:<mime type>;base64,<payload>". The MIME type is
        guessed from the file name and falls back to "application/octet-stream".

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    # Only the type part of the (type, encoding) guess is needed.
    guessed_type = mimetypes.guess_type(file_path)[0]
    media_type = "application/octet-stream" if guessed_type is None else guessed_type

    with open(file_path, "rb") as handle:
        payload = base64.b64encode(handle.read()).decode("ascii")

    return f"data:{media_type};base64,{payload}"
def determine_resolution(
    reso: Optional[Union[int, Tuple[int, int]]] = None,
    sdxl: bool = False,
    model_config: Optional[dict] = None,
    v2: bool = False,
    v_parameterization: bool = False
) -> str:
    """
    Build the "<first>x<second>" resolution string for a model.

    Args:
        reso: Explicit resolution. May be a single int, a one- or two-element
            sequence, or a comma separated string such as "512,768". When None,
            a default is chosen from the model type.
        sdxl: Whether the model is an SDXL model (default 1024).
        model_config: Optional dict with model type info; an "sd3", "flux" or
            "lumina" key selects the 1024 default.
        v2: Whether the model is an SD v2 model.
        v_parameterization: Together with ``v2`` selects the 768 default.

    Returns:
        The resolution string, e.g. "1024x1024".
    """
    if reso is None:
        # No explicit value: fall back to the default of the model family.
        config = model_config or {}
        if sdxl or any(family in config for family in ("sd3", "flux", "lumina")):
            side = 1024
        elif v2 and v_parameterization:
            side = 768
        else:
            side = 512
        return f"{side}x{side}"

    dims = reso
    if isinstance(dims, str):
        # Comma separated string -> tuple of ints.
        dims = tuple(int(part) for part in dims.split(","))
    if isinstance(dims, int):
        dims = (dims, dims)
    if len(dims) == 1:
        # A single value describes a square resolution.
        dims = (dims[0], dims[0])

    return f"{dims[0]}x{dims[1]}"
def load_bytes_in_safetensors(tensors):
bytes = safetensors.torch.save(tensors)
b = BytesIO(bytes)
@@ -109,6 +385,135 @@ def update_hash_sha256(metadata: dict, state_dict: dict):
raise NotImplementedError
def build_metadata_dataclass(
    state_dict: Optional[dict],
    v2: bool,
    v_parameterization: bool,
    sdxl: bool,
    lora: bool,
    textual_inversion: bool,
    timestamp: float,
    title: Optional[str] = None,
    reso: Optional[Union[int, Tuple[int, int]]] = None,
    is_stable_diffusion_ckpt: Optional[bool] = None,
    author: Optional[str] = None,
    description: Optional[str] = None,
    license: Optional[str] = None,
    tags: Optional[str] = None,
    merged_from: Optional[str] = None,
    timesteps: Optional[Tuple[int, int]] = None,
    clip_skip: Optional[int] = None,
    model_config: Optional[dict] = None,
    optional_metadata: Optional[dict] = None,
) -> ModelSpecMetadata:
    """
    Build ModelSpec 1.0.1 compliant metadata dataclass.

    Args:
        state_dict: Model weights; currently unused because the hash calculation
            is not implemented yet.
        v2, v_parameterization, sdxl, lora, textual_inversion: Model type flags
            used to derive architecture, implementation, resolution and
            prediction type.
        timestamp: Creation time in seconds; used for the date and the default title.
        title: Model title. Defaults to "LoRA", "TextualInversion" or "Checkpoint"
            followed by "@<timestamp>".
        reso: Explicit resolution, or None to use the model default.
        is_stable_diffusion_ckpt: Defaults to True for full (non-LoRA,
            non-textual-inversion) models when left as None.
        author, description, license, tags, merged_from: Stored as given.
        timesteps: Timestep range; a single value is used for both ends.
        clip_skip: Stored as the encoder layer string when not None.
        model_config: Dict containing model type info, e.g. {"flux": "dev"}, {"sd3": "large"}
        optional_metadata: Dict of additional metadata fields to include. A
            "thumbnail" entry that is not already a data URL is treated as a file
            path and converted; it is dropped if the conversion fails.

    Returns:
        The populated ModelSpecMetadata instance.
    """
    architecture = determine_architecture(
        v2, v_parameterization, sdxl, lora, textual_inversion, model_config
    )

    # Full models are treated as Stable Diffusion checkpoints unless told otherwise.
    is_adapter = lora or textual_inversion
    if not is_adapter and is_stable_diffusion_ckpt is None:
        is_stable_diffusion_ckpt = True

    implementation = determine_implementation(
        lora, textual_inversion, sdxl, model_config, is_stable_diffusion_ckpt
    )

    if title is None:
        if lora:
            kind_label = "LoRA"
        elif textual_inversion:
            kind_label = "TextualInversion"
        else:
            kind_label = "Checkpoint"
        title = f"{kind_label}@{timestamp}"

    # Whole seconds only, rendered as an ISO-8601 string.
    date = datetime.datetime.fromtimestamp(int(timestamp)).isoformat()

    resolution = determine_resolution(
        reso, sdxl, model_config, v2, v_parameterization
    )

    # Flux models don't use prediction_type.
    if "flux" in (model_config or {}):
        prediction_type = None
    else:
        prediction_type = PRED_TYPE_V if v_parameterization else PRED_TYPE_EPSILON

    timestep_range = None
    if timesteps is not None:
        if isinstance(timesteps, (str, int)):
            timesteps = (timesteps, timesteps)
        if len(timesteps) == 1:
            timesteps = (timesteps[0], timesteps[0])
        timestep_range = f"{timesteps[0]},{timesteps[1]}"

    # Encoder layer (clip skip).
    encoder_layer = None if clip_skip is None else f"{clip_skip}"

    # TODO: Implement hash calculation when memory-efficient method is available
    # hash_sha256 = None
    # if state_dict is not None:
    #     hash_sha256 = precalculate_safetensors_hashes(state_dict)

    # Work on a copy so the caller's dict is never modified.
    extras = optional_metadata.copy() if optional_metadata else {}

    thumbnail_source = extras.get("thumbnail")
    if thumbnail_source and not thumbnail_source.startswith("data:"):
        # Not a data URL yet, so treat the value as a file path.
        try:
            extras["thumbnail"] = file_to_data_url(thumbnail_source)
            logger.info(f"Converted thumbnail file {thumbnail_source} to data URL")
        except FileNotFoundError as e:
            logger.warning(f"Thumbnail file not found, skipping: {e}")
            extras.pop("thumbnail")
        except Exception as e:
            logger.warning(f"Failed to convert thumbnail to data URL: {e}")
            extras.pop("thumbnail")

    # Only look up the version when the caller did not provide one.
    if "implementation_version" not in extras:
        extras["implementation_version"] = get_implementation_version()

    return ModelSpecMetadata(
        architecture=architecture,
        implementation=implementation,
        title=title,
        description=description,
        author=author,
        date=date,
        license=license,
        tags=tags,
        merged_from=merged_from,
        resolution=resolution,
        prediction_type=prediction_type,
        timestep_range=timestep_range,
        encoder_layer=encoder_layer,
        additional_fields=extras,
    )
def build_metadata(
state_dict: Optional[dict],
v2: bool,
@@ -127,164 +532,41 @@ def build_metadata(
merged_from: Optional[str] = None,
timesteps: Optional[Tuple[int, int]] = None,
clip_skip: Optional[int] = None,
sd3: Optional[str] = None,
flux: Optional[str] = None,
lumina: Optional[str] = None,
):
model_config: Optional[dict] = None,
optional_metadata: Optional[dict] = None,
) -> Dict[str, str]:
"""
sd3: only supports "m", flux: supports "dev", "schnell" or "chroma"
Build ModelSpec 1.0.1 compliant metadata for safetensors models.
Legacy function that returns dict - prefer build_metadata_dataclass for new code.
Args:
model_config: Dict containing model type info, e.g. {"flux": "dev"}, {"sd3": "large"}
optional_metadata: Dict of additional metadata fields to include
"""
# if state_dict is None, hash is not calculated
metadata = {}
metadata.update(BASE_METADATA)
# TODO メモリを消費せずかつ正しいハッシュ計算の方法がわかったら実装する
# if state_dict is not None:
# hash = precalculate_safetensors_hashes(state_dict)
# metadata["modelspec.hash_sha256"] = hash
if sdxl:
arch = ARCH_SD_XL_V1_BASE
elif sd3 is not None:
arch = ARCH_SD3_M + "-" + sd3
elif flux is not None:
if flux == "dev":
arch = ARCH_FLUX_1_DEV
elif flux == "schnell":
arch = ARCH_FLUX_1_SCHNELL
elif flux == "chroma":
arch = ARCH_FLUX_1_CHROMA
else:
arch = ARCH_FLUX_1_UNKNOWN
elif lumina is not None:
if lumina == "lumina2":
arch = ARCH_LUMINA_2
else:
arch = ARCH_LUMINA_UNKNOWN
elif v2:
if v_parameterization:
arch = ARCH_SD_V2_768_V
else:
arch = ARCH_SD_V2_512
else:
arch = ARCH_SD_V1
if lora:
arch += f"/{ADAPTER_LORA}"
elif textual_inversion:
arch += f"/{ADAPTER_TEXTUAL_INVERSION}"
metadata["modelspec.architecture"] = arch
if not lora and not textual_inversion and is_stable_diffusion_ckpt is None:
is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion
if flux is not None:
# Flux
if flux == "chroma":
impl = IMPL_CHROMA
else:
impl = IMPL_FLUX
elif lumina is not None:
# Lumina
impl = IMPL_LUMINA
elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
# Stable Diffusion ckpt, TI, SDXL LoRA
impl = IMPL_STABILITY_AI
else:
# v1/v2 LoRA or Diffusers
impl = IMPL_DIFFUSERS
metadata["modelspec.implementation"] = impl
if title is None:
if lora:
title = "LoRA"
elif textual_inversion:
title = "TextualInversion"
else:
title = "Checkpoint"
title += f"@{timestamp}"
metadata[MODELSPEC_TITLE] = title
if author is not None:
metadata["modelspec.author"] = author
else:
del metadata["modelspec.author"]
if description is not None:
metadata["modelspec.description"] = description
else:
del metadata["modelspec.description"]
if merged_from is not None:
metadata["modelspec.merged_from"] = merged_from
else:
del metadata["modelspec.merged_from"]
if license is not None:
metadata["modelspec.license"] = license
else:
del metadata["modelspec.license"]
if tags is not None:
metadata["modelspec.tags"] = tags
else:
del metadata["modelspec.tags"]
# remove microsecond from time
int_ts = int(timestamp)
# time to iso-8601 compliant date
date = datetime.datetime.fromtimestamp(int_ts).isoformat()
metadata["modelspec.date"] = date
if reso is not None:
# comma separated to tuple
if isinstance(reso, str):
reso = tuple(map(int, reso.split(",")))
if len(reso) == 1:
reso = (reso[0], reso[0])
else:
# resolution is defined in dataset, so use default
if sdxl or sd3 is not None or flux is not None or lumina is not None:
reso = 1024
elif v2 and v_parameterization:
reso = 768
else:
reso = 512
if isinstance(reso, int):
reso = (reso, reso)
metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}"
if flux is not None:
del metadata["modelspec.prediction_type"]
elif v_parameterization:
metadata["modelspec.prediction_type"] = PRED_TYPE_V
else:
metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON
if timesteps is not None:
if isinstance(timesteps, str) or isinstance(timesteps, int):
timesteps = (timesteps, timesteps)
if len(timesteps) == 1:
timesteps = (timesteps[0], timesteps[0])
metadata["modelspec.timestep_range"] = f"{timesteps[0]},{timesteps[1]}"
else:
del metadata["modelspec.timestep_range"]
if clip_skip is not None:
metadata["modelspec.encoder_layer"] = f"{clip_skip}"
else:
del metadata["modelspec.encoder_layer"]
# # assert all values are filled
# assert all([v is not None for v in metadata.values()]), metadata
if not all([v is not None for v in metadata.values()]):
logger.error(f"Internal error: some metadata values are None: {metadata}")
return metadata
# Use the dataclass function and convert to dict
metadata_obj = build_metadata_dataclass(
state_dict=state_dict,
v2=v2,
v_parameterization=v_parameterization,
sdxl=sdxl,
lora=lora,
textual_inversion=textual_inversion,
timestamp=timestamp,
title=title,
reso=reso,
is_stable_diffusion_ckpt=is_stable_diffusion_ckpt,
author=author,
description=description,
license=license,
tags=tags,
merged_from=merged_from,
timesteps=timesteps,
clip_skip=clip_skip,
model_config=model_config,
optional_metadata=optional_metadata,
)
return metadata_obj.to_metadata_dict()
# region utils
@@ -317,6 +599,121 @@ def build_merged_from(models: List[str]) -> str:
return ", ".join(titles)
def add_model_spec_arguments(parser: argparse.ArgumentParser):
    """
    Register all ModelSpec metadata options on ``parser``.

    Every option is an optional string argument that defaults to None.

    Args:
        parser: The argument parser to extend.
    """
    # (flag, help text) pairs in registration order.
    # Note: implementation_version is automatically set to sd-scripts/{commit_hash}
    option_specs = (
        # === Existing standard metadata fields ===
        (
            "--metadata_title",
            "title for model metadata (default is output_name) / メタデータに書き込まれるモデルタイトル、省略時はoutput_name",
        ),
        (
            "--metadata_author",
            "author name for model metadata / メタデータに書き込まれるモデル作者名",
        ),
        (
            "--metadata_description",
            "description for model metadata / メタデータに書き込まれるモデル説明",
        ),
        (
            "--metadata_license",
            "license for model metadata / メタデータに書き込まれるモデルライセンス",
        ),
        (
            "--metadata_tags",
            "tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り",
        ),
        # === Universal CAN fields ===
        (
            "--metadata_usage_hint",
            "usage hint for model metadata / メタデータに書き込まれる使用方法のヒント",
        ),
        (
            "--metadata_thumbnail",
            "thumbnail image as data URL or file path (will be converted to data URL) for model metadata / メタデータに書き込まれるサムネイル画像データURLまたはファイルパス、ファイルパスの場合はデータURLに変換されます",
        ),
        (
            "--metadata_merged_from",
            "source models for merged model metadata / メタデータに書き込まれるマージ元モデル名",
        ),
        # === Image generation CAN fields ===
        (
            "--metadata_trigger_phrase",
            "trigger phrase for model metadata / メタデータに書き込まれるトリガーフレーズ",
        ),
        (
            "--metadata_preprocessor",
            "preprocessor used for model metadata / メタデータに書き込まれる前処理手法",
        ),
        (
            "--metadata_is_negative_embedding",
            "whether this is a negative embedding for model metadata / メタデータに書き込まれるネガティブ埋め込みかどうか",
        ),
        (
            "--metadata_unet_dtype",
            "UNet data type for model metadata / メタデータに書き込まれるUNetのデータ型",
        ),
        (
            "--metadata_vae_dtype",
            "VAE data type for model metadata / メタデータに書き込まれるVAEのデータ型",
        ),
        # === Text prediction fields ===
        (
            "--metadata_data_format",
            "data format for text prediction model metadata / メタデータに書き込まれるテキスト予測モデルのデータ形式",
        ),
        (
            "--metadata_format_type",
            "format type for text prediction model metadata / メタデータに書き込まれるテキスト予測モデルの形式タイプ",
        ),
        (
            "--metadata_language",
            "language for text prediction model metadata / メタデータに書き込まれるテキスト予測モデルの言語",
        ),
        (
            "--metadata_format_template",
            "format template for text prediction model metadata / メタデータに書き込まれるテキスト予測モデルの形式テンプレート",
        ),
    )

    for flag, help_text in option_specs:
        parser.add_argument(flag, type=str, default=None, help=help_text)
# endregion

View File

@@ -3484,6 +3484,7 @@ def get_sai_model_spec(
sd3: str = None,
flux: str = None, # "dev", "schnell" or "chroma"
lumina: str = None,
optional_metadata: dict[str, str] | None = None
):
timestamp = time.time()
@@ -3500,6 +3501,34 @@ def get_sai_model_spec(
else:
timesteps = None
# Convert individual model parameters to model_config dict
# TODO: Update calls to this function to pass in the model config
model_config = {}
if sd3 is not None:
model_config["sd3"] = sd3
if flux is not None:
model_config["flux"] = flux
if lumina is not None:
model_config["lumina"] = lumina
# Extract metadata_* fields from args and merge with optional_metadata
extracted_metadata = {}
# Extract all metadata_* attributes from args
for attr_name in dir(args):
if attr_name.startswith("metadata_") and not attr_name.startswith("metadata___"):
value = getattr(args, attr_name, None)
if value is not None:
# Remove metadata_ prefix and exclude already handled fields
field_name = attr_name[9:] # len("metadata_") = 9
if field_name not in ["title", "author", "description", "license", "tags"]:
extracted_metadata[field_name] = value
# Merge extracted metadata with provided optional_metadata
all_optional_metadata = {**extracted_metadata}
if optional_metadata:
all_optional_metadata.update(optional_metadata)
metadata = sai_model_spec.build_metadata(
state_dict,
v2,
@@ -3517,13 +3546,75 @@ def get_sai_model_spec(
tags=args.metadata_tags,
timesteps=timesteps,
clip_skip=args.clip_skip, # None or int
sd3=sd3,
flux=flux,
lumina=lumina,
model_config=model_config,
optional_metadata=all_optional_metadata if all_optional_metadata else None,
)
return metadata
def get_sai_model_spec_dataclass(
    state_dict: dict,
    args: argparse.Namespace,
    sdxl: bool,
    lora: bool,
    textual_inversion: bool,
    is_stable_diffusion_ckpt: Optional[bool] = None,
    sd3: str = None,
    flux: str = None,
    lumina: str = None,
    optional_metadata: dict[str, str] | None = None
) -> sai_model_spec.ModelSpecMetadata:
    """
    Get ModelSpec metadata as a dataclass - preferred for new code.

    Automatically extracts metadata_* fields from args: every non-None
    ``args.metadata_*`` attribute other than title, author, description,
    license and tags (which are passed explicitly) is added to the optional
    metadata. Entries in ``optional_metadata`` take precedence over the
    extracted values.

    Args:
        state_dict: Model weights, forwarded to the metadata builder.
        args: Parsed training arguments; reads v2, v_parameterization,
            resolution, output_name, min_timestep, max_timestep, clip_skip and
            the metadata_* attributes.
        sdxl, lora, textual_inversion: Model type flags.
        is_stable_diffusion_ckpt: Forwarded to the metadata builder.
        sd3, flux, lumina: Model variant names; each non-None value is placed in
            the model config dict under its own key.
        optional_metadata: Additional metadata fields to include.

    Returns:
        The ModelSpecMetadata built by ``sai_model_spec.build_metadata_dataclass``.
    """
    timestamp = time.time()

    v2 = args.v2
    v_parameterization = args.v_parameterization
    reso = args.resolution

    title = args.metadata_title if args.metadata_title is not None else args.output_name

    if args.min_timestep is not None or args.max_timestep is not None:
        # A missing bound falls back to 0 / 1000 respectively.
        min_time_step = args.min_timestep if args.min_timestep is not None else 0
        max_time_step = args.max_timestep if args.max_timestep is not None else 1000
        timesteps = (min_time_step, max_time_step)
    else:
        timesteps = None

    # Convert individual model parameters to model_config dict
    model_config = {}
    if sd3 is not None:
        model_config["sd3"] = sd3
    if flux is not None:
        model_config["flux"] = flux
    if lumina is not None:
        model_config["lumina"] = lumina

    # Extract the remaining metadata_* fields from args, as the docstring promises.
    # The five fields below are passed to the builder as explicit keyword arguments.
    explicitly_passed = ("title", "author", "description", "license", "tags")
    prefix = "metadata_"
    all_optional_metadata = {}
    for attr_name in dir(args):
        if not attr_name.startswith(prefix) or attr_name.startswith("metadata___"):
            continue
        value = getattr(args, attr_name, None)
        field_name = attr_name[len(prefix):]
        if value is not None and field_name not in explicitly_passed:
            all_optional_metadata[field_name] = value

    # Explicitly provided optional metadata wins over values taken from args.
    if optional_metadata:
        all_optional_metadata.update(optional_metadata)

    return sai_model_spec.build_metadata_dataclass(
        state_dict,
        v2,
        v_parameterization,
        sdxl,
        lora,
        textual_inversion,
        timestamp,
        title=title,
        reso=reso,
        is_stable_diffusion_ckpt=is_stable_diffusion_ckpt,
        author=args.metadata_author,
        description=args.metadata_description,
        license=args.metadata_license,
        tags=args.metadata_tags,
        timesteps=timesteps,
        clip_skip=args.clip_skip,
        model_config=model_config,
        optional_metadata=all_optional_metadata if all_optional_metadata else None,
    )
def add_sd_models_arguments(parser: argparse.ArgumentParser):
# for pretrained models
parser.add_argument(
@@ -4103,39 +4194,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
parser.add_argument(
"--output_config", action="store_true", help="output command line args to given .toml file / 引数を.tomlファイルに出力する"
)
# SAI Model spec
parser.add_argument(
"--metadata_title",
type=str,
default=None,
help="title for model metadata (default is output_name) / メタデータに書き込まれるモデルタイトル、省略時はoutput_name",
)
parser.add_argument(
"--metadata_author",
type=str,
default=None,
help="author name for model metadata / メタデータに書き込まれるモデル作者名",
)
parser.add_argument(
"--metadata_description",
type=str,
default=None,
help="description for model metadata / メタデータに書き込まれるモデル説明",
)
parser.add_argument(
"--metadata_license",
type=str,
default=None,
help="license for model metadata / メタデータに書き込まれるモデルライセンス",
)
parser.add_argument(
"--metadata_tags",
type=str,
default=None,
help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り",
)
if support_dreambooth:
# DreamBooth training
parser.add_argument(