Merge branch 'sd3' into lumina

青龍聖者@bdsqlsz
2025-04-06 16:13:43 +08:00
committed by GitHub
19 changed files with 808 additions and 74 deletions


@@ -14,6 +14,19 @@ The command to install PyTorch is as follows:
### Recent Updates
Mar 30, 2025:
- LoRA-GGPO is added for FLUX.1 LoRA training. Thank you to rockerBOO for PR [#1974](https://github.com/kohya-ss/sd-scripts/pull/1974).
  - Specify `--network_args ggpo_sigma=0.03 ggpo_beta=0.01` on the command line, or `network_args = ["ggpo_sigma=0.03", "ggpo_beta=0.01"]` in the .toml file. See the PR for details.
- The interpolation method used when resizing the original image to the training size can now be specified. Thank you to rockerBOO for PR [#1936](https://github.com/kohya-ss/sd-scripts/pull/1936).
Mar 20, 2025:
- `pytorch-optimizer` is added to requirements.txt. Thank you to gesen2egee for PR [#1985](https://github.com/kohya-ss/sd-scripts/pull/1985).
  - For example, you can use the CAME optimizer with `--optimizer_type "pytorch_optimizer.CAME" --optimizer_args "weight_decay=0.01"`.
Mar 6, 2025:
- Added a utility script to merge the weights of SD3's DiT, VAE (optional), CLIP-L, CLIP-G, and T5XXL into a single .safetensors file. Run `tools/merge_sd3_safetensors.py`; see `--help` for usage. PR [#1960](https://github.com/kohya-ss/sd-scripts/pull/1960)
Feb 26, 2025:
- Improved the validation loss calculation in `train_network.py`, `sdxl_train_network.py`, `flux_train_network.py`, and `sd3_train_network.py`. PR [#1903](https://github.com/kohya-ss/sd-scripts/pull/1903)
@@ -744,6 +757,8 @@ Not available yet.
[__Change History__](#change-history) is moved to the bottom of the page.
更新履歴は[ページ末尾](#change-history)に移しました。
Latest update: 2025-03-21 (Version 0.9.1)
[日本語版READMEはこちら](./README-ja.md)
The development version is in the `dev` branch. Please check the dev branch for the latest changes.
@@ -887,6 +902,11 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
## Change History
### Mar 21, 2025 / 2025-03-21 Version 0.9.1
- Fixed a bug where some of the LoRA modules for the CLIP Text Encoder were not trained. Thank you Nekotekina for PR [#1964](https://github.com/kohya-ss/sd-scripts/pull/1964)
  - The CLIP Text Encoder now has 264 LoRA modules, the same as before; only 88 modules were trained in the previous version.
### Jan 17, 2025 / 2025-01-17 Version 0.9.0
- __important__ The dependent libraries are updated. Please see [Upgrade](#upgrade) and update the libraries.


@@ -152,6 +152,7 @@ These options are related to subset configuration.
| `keep_tokens_separator` | `“|||”` | o | o | o |
| `secondary_separator` | `“;;;”` | o | o | o |
| `enable_wildcard` | `true` | o | o | o |
| `resize_interpolation` | (not specified) | o | o | o |
* `num_repeats`
* Specifies the number of repeats for images in a subset. This is equivalent to `--dataset_repeats` in fine-tuning but can be specified for any training method.
@@ -165,6 +166,8 @@ These options are related to subset configuration.
* Specifies an additional separator. The part separated by this separator is treated as one tag and is shuffled and dropped. It is then replaced by `caption_separator`. For example, if you specify `aaa;;;bbb;;;ccc`, it will be replaced by `aaa,bbb,ccc` or dropped together.
* `enable_wildcard`
* Enables wildcard notation. This will be explained later.
* `resize_interpolation`
* Specifies the interpolation method used when resizing images. Normally, there is no need to specify this. The following options can be specified: `lanczos`, `nearest`, `bilinear`, `linear`, `bicubic`, `cubic`, `area`, `box`. By default (when not specified), `area` is used for downscaling, and `lanczos` is used for upscaling. If this option is specified, the same interpolation method will be used for both upscaling and downscaling. When `lanczos` or `box` is specified, PIL is used; for other options, OpenCV is used.
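For example, a dataset config could pin a single interpolation method for all resizing. A minimal `.toml` sketch (paths are hypothetical, and `bicubic` is chosen only for illustration):

```toml
[[datasets]]
resolution = 1024

  [[datasets.subsets]]
  image_dir = "path/to/images"       # hypothetical path
  num_repeats = 1
  resize_interpolation = "bicubic"   # used for both upscaling and downscaling
```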
### DreamBooth-specific options


@@ -144,6 +144,7 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学
| `keep_tokens_separator` | `“|||”` | o | o | o |
| `secondary_separator` | `“;;;”` | o | o | o |
| `enable_wildcard` | `true` | o | o | o |
| `resize_interpolation` | (通常は設定しません) | o | o | o |
* `num_repeats`
* サブセットの画像の繰り返し回数を指定します。fine tuning における `--dataset_repeats` に相当しますが、`num_repeats` はどの学習方法でも指定可能です。
@@ -162,6 +163,9 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学
* `enable_wildcard`
* ワイルドカード記法および複数行キャプションを有効にします。ワイルドカード記法、複数行キャプションについては後述します。
* `resize_interpolation`
* 画像のリサイズ時に使用する補間方法を指定します。通常は指定しなくて構いません。`lanczos`, `nearest`, `bilinear`, `linear`, `bicubic`, `cubic`, `area`, `box` が指定可能です。デフォルト(未指定時)は、縮小時は `area`、拡大時は `lanczos` になります。このオプションを指定すると、拡大時・縮小時とも同じ補間方法が使用されます。`lanczos` または `box` を指定すると PIL が、それ以外を指定すると OpenCV が使用されます。
### DreamBooth 方式専用のオプション
DreamBooth 方式のオプションは、サブセット向けオプションのみ存在します。


@@ -11,7 +11,7 @@ from PIL import Image
from tqdm import tqdm
import library.train_util as train_util
from library.utils import setup_logging, pil_resize
from library.utils import setup_logging, resize_image
setup_logging()
import logging
@@ -42,10 +42,7 @@ def preprocess_image(image):
pad_t = pad_y // 2
image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)
if size > IMAGE_SIZE:
image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), cv2.INTER_AREA)
else:
image = pil_resize(image, (IMAGE_SIZE, IMAGE_SIZE))
image = resize_image(image, image.shape[0], image.shape[1], IMAGE_SIZE, IMAGE_SIZE)
image = image.astype(np.float32)
return image


@@ -76,6 +76,7 @@ class BaseSubsetParams:
validation_seed: int = 0
validation_split: float = 0.0
system_prompt: Optional[str] = None
resize_interpolation: Optional[str] = None
@dataclass
@@ -108,7 +109,7 @@ class BaseDatasetParams:
validation_seed: Optional[int] = None
validation_split: float = 0.0
system_prompt: Optional[str] = None
resize_interpolation: Optional[str] = None
@dataclass
class DreamBoothDatasetParams(BaseDatasetParams):
@@ -199,6 +200,7 @@ class ConfigSanitizer:
"caption_suffix": str,
"custom_attributes": dict,
"system_prompt": str,
"resize_interpolation": str,
}
# DO means DropOut
DO_SUBSET_ASCENDABLE_SCHEMA = {
@@ -245,6 +247,7 @@ class ConfigSanitizer:
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
"network_multiplier": float,
"system_prompt": str,
"resize_interpolation": str,
}
# options handled by argparse but not handled by user config
@@ -529,6 +532,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
[{dataset_type} {i}]
batch_size: {dataset.batch_size}
resolution: {(dataset.width, dataset.height)}
resize_interpolation: {dataset.resize_interpolation}
enable_bucket: {dataset.enable_bucket}
system_prompt: {dataset.system_prompt}
""")
@@ -563,6 +567,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
token_warmup_min: {subset.token_warmup_min},
token_warmup_step: {subset.token_warmup_step},
alpha_mask: {subset.alpha_mask}
resize_interpolation: {subset.resize_interpolation}
custom_attributes: {subset.custom_attributes}
system_prompt: {subset.system_prompt}
"""), " ")

library/jpeg_xl_util.py (new file)

@@ -0,0 +1,186 @@
# Modified from https://github.com/Fraetor/jxl_decode Original license: MIT
# Added partial read support for up to 200x speedup
import os
from typing import List, Tuple
class JXLBitstream:
"""
A stream of bits with methods for easy handling.
"""
def __init__(self, file, offset: int = 0, offsets: List[List[int]] = None):
self.shift = 0
self.bitstream = bytearray()
self.file = file
self.offset = offset
self.offsets = offsets
if self.offsets:
self.offset = self.offsets[0][1]
self.previous_data_len = 0
self.index = 0
self.file.seek(self.offset)
def get_bits(self, length: int = 1) -> int:
if self.offsets and self.shift + length > self.previous_data_len + self.offsets[self.index][2]:
self.partial_to_read_length = length
if self.shift < self.previous_data_len + self.offsets[self.index][2]:
self.partial_read(0, length)
self.bitstream.extend(self.file.read(self.partial_to_read_length))
else:
self.bitstream.extend(self.file.read(length))
bitmask = 2**length - 1
bits = (int.from_bytes(self.bitstream, "little") >> self.shift) & bitmask
self.shift += length
return bits
def partial_read(self, current_length: int, length: int) -> None:
self.previous_data_len += self.offsets[self.index][2]
to_read_length = self.previous_data_len - (self.shift + current_length)
self.bitstream.extend(self.file.read(to_read_length))
current_length += to_read_length
self.partial_to_read_length -= to_read_length
self.index += 1
self.file.seek(self.offsets[self.index][1])
if self.shift + length > self.previous_data_len + self.offsets[self.index][2]:
self.partial_read(current_length, length)
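The core of `get_bits` is an LSB-first read over a little-endian integer. As a simplified standalone sketch (the helper name `read_bits` is illustrative; the real class additionally tracks partial reads across codestream boxes):

```python
def read_bits(data: bytes, shift: int, length: int) -> int:
    """Read `length` bits starting at bit offset `shift`, LSB-first,
    treating the whole byte string as one little-endian integer."""
    value = int.from_bytes(data, "little")
    return (value >> shift) & ((1 << length) - 1)

# 0xA5 = 0b10100101: the low nibble is read first, then the high nibble
print(read_bits(b"\xa5", 0, 4))  # 5  (0b0101)
print(read_bits(b"\xa5", 4, 4))  # 10 (0b1010)
```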
def decode_codestream(file, offset: int = 0, offsets: List[List[int]] = None) -> Tuple[int,int]:
"""
Decodes the actual codestream.
JXL codestream specification: http://www-internal/2022/18181-1
"""
# Convert codestream to int within an object to get some handy methods.
codestream = JXLBitstream(file, offset=offset, offsets=offsets)
# Skip signature
codestream.get_bits(16)
# SizeHeader
div8 = codestream.get_bits(1)
if div8:
height = 8 * (1 + codestream.get_bits(5))
else:
distribution = codestream.get_bits(2)
match distribution:
case 0:
height = 1 + codestream.get_bits(9)
case 1:
height = 1 + codestream.get_bits(13)
case 2:
height = 1 + codestream.get_bits(18)
case 3:
height = 1 + codestream.get_bits(30)
ratio = codestream.get_bits(3)
if div8 and not ratio:
width = 8 * (1 + codestream.get_bits(5))
elif not ratio:
distribution = codestream.get_bits(2)
match distribution:
case 0:
width = 1 + codestream.get_bits(9)
case 1:
width = 1 + codestream.get_bits(13)
case 2:
width = 1 + codestream.get_bits(18)
case 3:
width = 1 + codestream.get_bits(30)
else:
match ratio:
case 1:
width = height
case 2:
width = (height * 12) // 10
case 3:
width = (height * 4) // 3
case 4:
width = (height * 3) // 2
case 5:
width = (height * 16) // 9
case 6:
width = (height * 5) // 4
case 7:
width = (height * 2) // 1
return width, height
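The `ratio` branch above encodes a fixed table of aspect ratios from the SizeHeader. Pulled out as a standalone sketch (names are illustrative):

```python
# ratio code -> (numerator, denominator), matching the match statement above
RATIOS = {1: (1, 1), 2: (12, 10), 3: (4, 3), 4: (3, 2),
          5: (16, 9), 6: (5, 4), 7: (2, 1)}

def width_from_ratio(height: int, ratio: int) -> int:
    """Derive the image width from its height and a SizeHeader ratio code."""
    num, den = RATIOS[ratio]
    return (height * num) // den

print(width_from_ratio(1080, 5))  # 1920 (16:9)
```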
def decode_container(file) -> Tuple[int,int]:
"""
Parses the ISOBMFF container, extracts the codestream, and decodes it.
JXL container specification: http://www-internal/2022/18181-2
"""
def parse_box(file, file_start: int) -> dict:
file.seek(file_start)
LBox = int.from_bytes(file.read(4), "big")
XLBox = None
if 1 < LBox <= 8:
raise ValueError(f"Invalid LBox at byte {file_start}.")
if LBox == 1:
file.seek(file_start + 8)
XLBox = int.from_bytes(file.read(8), "big")
if XLBox <= 16:
raise ValueError(f"Invalid XLBox at byte {file_start}.")
if XLBox:
header_length = 16
box_length = XLBox
else:
header_length = 8
if LBox == 0:
box_length = os.fstat(file.fileno()).st_size - file_start
else:
box_length = LBox
file.seek(file_start + 4)
box_type = file.read(4)
file.seek(file_start)
return {
"length": box_length,
"type": box_type,
"offset": header_length,
}
file.seek(0)
# Reject files missing required boxes. These two boxes are required to be at
# the start and contain no values, so we can manually check their presence.
# Signature box. (Redundant as has already been checked.)
if file.read(12) != bytes.fromhex("0000000C 4A584C20 0D0A870A"):
raise ValueError("Invalid signature box.")
# File Type box.
if file.read(20) != bytes.fromhex(
"00000014 66747970 6A786C20 00000000 6A786C20"
):
raise ValueError("Invalid file type box.")
offset = 0
offsets = []
data_offset_not_found = True
container_pointer = 32
file_size = os.fstat(file.fileno()).st_size
while data_offset_not_found:
box = parse_box(file, container_pointer)
match box["type"]:
case b"jxlc":
offset = container_pointer + box["offset"]
data_offset_not_found = False
case b"jxlp":
file.seek(container_pointer + box["offset"])
index = int.from_bytes(file.read(4), "big")
offsets.append([index, container_pointer + box["offset"] + 4, box["length"] - box["offset"] - 4])
container_pointer += box["length"]
if container_pointer >= file_size:
data_offset_not_found = False
if offsets:
offsets.sort(key=lambda i: i[0])
file.seek(0)
return decode_codestream(file, offset=offset, offsets=offsets)
def get_jxl_size(path: str) -> Tuple[int,int]:
with open(path, "rb") as file:
if file.read(2) == bytes.fromhex("FF0A"):
return decode_codestream(file)
return decode_container(file)
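The ISOBMFF box header handled by `parse_box` is a big-endian 32-bit length plus a 4-byte type, where `LBox == 1` signals a 64-bit extended length (XLBox) and `LBox == 0` means "to end of file". A simplified in-memory sketch of those rules (reading fields sequentially rather than re-seeking as the real code does):

```python
import io
import struct

def parse_box_header(buf, start: int, file_size: int):
    """Return (type, total box length, header length) for the box at `start`."""
    buf.seek(start)
    lbox = struct.unpack(">I", buf.read(4))[0]
    box_type = buf.read(4)
    if lbox == 1:  # 64-bit extended size (XLBox) follows the type field
        xlbox = struct.unpack(">Q", buf.read(8))[0]
        return box_type, xlbox, 16
    if lbox == 0:  # box extends to the end of the file
        return box_type, file_size - start, 8
    return box_type, lbox, 8

data = struct.pack(">I", 16) + b"jxlc" + b"\x00" * 8
print(parse_box_header(io.BytesIO(data), 0, len(data)))  # (b'jxlc', 16, 8)
```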


@@ -74,7 +74,7 @@ import library.model_util as model_util
import library.huggingface_util as huggingface_util
import library.sai_model_spec as sai_model_spec
import library.deepspeed_utils as deepspeed_utils
from library.utils import setup_logging, pil_resize
from library.utils import setup_logging, resize_image, validate_interpolation_fn
setup_logging()
import logging
@@ -113,14 +113,16 @@ except:
# JPEG-XL on Linux
try:
from jxlpy import JXLImagePlugin
from library.jpeg_xl_util import get_jxl_size
IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
except:
pass
# JPEG-XL on Windows
# JPEG-XL on Linux and Windows
try:
import pillow_jxl
from library.jpeg_xl_util import get_jxl_size
IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
except:
@@ -205,6 +207,7 @@ class ImageInfo:
self.text_encoder_pool2: Optional[torch.Tensor] = None
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
self.resize_interpolation: Optional[str] = None
self.system_prompt: Optional[str] = None
@@ -432,6 +435,7 @@ class BaseSubset:
validation_seed: Optional[int] = None,
validation_split: Optional[float] = 0.0,
system_prompt: Optional[str] = None,
resize_interpolation: Optional[str] = None,
) -> None:
self.image_dir = image_dir
self.alpha_mask = alpha_mask if alpha_mask is not None else False
@@ -463,6 +467,7 @@ class BaseSubset:
self.validation_split = validation_split
self.system_prompt = system_prompt
self.resize_interpolation = resize_interpolation
class DreamBoothSubset(BaseSubset):
@@ -496,6 +501,7 @@ class DreamBoothSubset(BaseSubset):
validation_seed: Optional[int] = None,
validation_split: Optional[float] = 0.0,
system_prompt: Optional[str] = None,
resize_interpolation: Optional[str] = None,
) -> None:
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
@@ -524,6 +530,7 @@ class DreamBoothSubset(BaseSubset):
validation_seed=validation_seed,
validation_split=validation_split,
system_prompt=system_prompt,
resize_interpolation=resize_interpolation,
)
self.is_reg = is_reg
@@ -567,6 +574,7 @@ class FineTuningSubset(BaseSubset):
validation_seed: Optional[int] = None,
validation_split: Optional[float] = 0.0,
system_prompt: Optional[str] = None,
resize_interpolation: Optional[str] = None,
) -> None:
assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
@@ -595,6 +603,7 @@ class FineTuningSubset(BaseSubset):
validation_seed=validation_seed,
validation_split=validation_split,
system_prompt=system_prompt,
resize_interpolation=resize_interpolation,
)
self.metadata_file = metadata_file
@@ -634,6 +643,7 @@ class ControlNetSubset(BaseSubset):
validation_seed: Optional[int] = None,
validation_split: Optional[float] = 0.0,
system_prompt: Optional[str] = None,
resize_interpolation: Optional[str] = None,
) -> None:
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
@@ -662,6 +672,7 @@ class ControlNetSubset(BaseSubset):
validation_seed=validation_seed,
validation_split=validation_split,
system_prompt=system_prompt,
resize_interpolation=resize_interpolation,
)
self.conditioning_data_dir = conditioning_data_dir
@@ -682,6 +693,7 @@ class BaseDataset(torch.utils.data.Dataset):
resolution: Optional[Tuple[int, int]],
network_multiplier: float,
debug_dataset: bool,
resize_interpolation: Optional[str] = None
) -> None:
super().__init__()
@@ -716,6 +728,10 @@ class BaseDataset(torch.utils.data.Dataset):
self.image_transforms = IMAGE_TRANSFORMS
if resize_interpolation is not None:
assert validate_interpolation_fn(resize_interpolation), f"Resize interpolation \"{resize_interpolation}\" is not a valid interpolation"
self.resize_interpolation = resize_interpolation
self.image_data: Dict[str, ImageInfo] = {}
self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {}
@@ -1459,6 +1475,8 @@ class BaseDataset(torch.utils.data.Dataset):
)
def get_image_size(self, image_path):
if image_path.endswith(".jxl") or image_path.endswith(".JXL"):
return get_jxl_size(image_path)
# return imagesize.get(image_path)
image_size = imagesize.get(image_path)
if image_size[0] <= 0:
@@ -1505,7 +1523,7 @@ class BaseDataset(torch.utils.data.Dataset):
nh = int(height * scale + 0.5)
nw = int(width * scale + 0.5)
assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}"
image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA)
image = resize_image(image, width, height, nw, nh, subset.resize_interpolation)
face_cx = int(face_cx * scale + 0.5)
face_cy = int(face_cy * scale + 0.5)
height, width = nh, nw
@@ -1602,7 +1620,7 @@ class BaseDataset(torch.utils.data.Dataset):
if self.enable_bucket:
img, original_size, crop_ltrb = trim_and_resize_if_required(
subset.random_crop, img, image_info.bucket_reso, image_info.resized_size
subset.random_crop, img, image_info.bucket_reso, image_info.resized_size, resize_interpolation=image_info.resize_interpolation
)
else:
if face_cx > 0: # 顔位置情報あり
@@ -1866,8 +1884,9 @@ class DreamBoothDataset(BaseDataset):
validation_split: float,
validation_seed: Optional[int],
system_prompt: Optional[str],
resize_interpolation: Optional[str],
) -> None:
super().__init__(resolution, network_multiplier, debug_dataset)
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
assert resolution is not None, f"resolution is required / resolution解像度指定は必須です"
@@ -2099,6 +2118,7 @@ class DreamBoothDataset(BaseDataset):
system_prompt = f"{self.system_prompt or subset.system_prompt} {system_prompt_special_token} " if self.system_prompt or subset.system_prompt else ""
for img_path, caption, size in zip(img_paths, captions, sizes):
info = ImageInfo(img_path, num_repeats, system_prompt + caption, subset.is_reg, img_path)
info.resize_interpolation = subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
if size is not None:
info.image_size = size
if subset.is_reg:
@@ -2154,8 +2174,9 @@ class FineTuningDataset(BaseDataset):
debug_dataset: bool,
validation_seed: int,
validation_split: float,
resize_interpolation: Optional[str],
) -> None:
super().__init__(resolution, network_multiplier, debug_dataset)
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
self.batch_size = batch_size
@@ -2381,9 +2402,10 @@ class ControlNetDataset(BaseDataset):
bucket_no_upscale: bool,
debug_dataset: bool,
validation_split: float,
validation_seed: Optional[int],
validation_seed: Optional[int],
resize_interpolation: Optional[str] = None,
) -> None:
super().__init__(resolution, network_multiplier, debug_dataset)
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
db_subsets = []
for subset in subsets:
@@ -2415,6 +2437,7 @@ class ControlNetDataset(BaseDataset):
subset.caption_suffix,
subset.token_warmup_min,
subset.token_warmup_step,
resize_interpolation=subset.resize_interpolation,
)
db_subsets.append(db_subset)
@@ -2433,6 +2456,7 @@ class ControlNetDataset(BaseDataset):
debug_dataset,
validation_split,
validation_seed,
resize_interpolation,
)
# config_util等から参照される値をいれておく(若干微妙なのでなんとかしたい)
@@ -2441,7 +2465,8 @@ class ControlNetDataset(BaseDataset):
self.num_train_images = self.dreambooth_dataset_delegate.num_train_images
self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
self.validation_split = validation_split
self.validation_seed = validation_seed
self.validation_seed = validation_seed
self.resize_interpolation = resize_interpolation
# assert all conditioning data exists
missing_imgs = []
@@ -2529,9 +2554,8 @@ class ControlNetDataset(BaseDataset):
assert (
cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1]
), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}"
cond_img = cv2.resize(
cond_img, image_info.resized_size, interpolation=cv2.INTER_AREA
) # INTER_AREAでやりたいのでcv2でリサイズ
cond_img = resize_image(cond_img, original_size_hw[1], original_size_hw[0], target_size_hw[1], target_size_hw[0], self.resize_interpolation)
# TODO support random crop
# 現在サポートしているcropはrandomではなく中央のみ
@@ -2545,7 +2569,7 @@ class ControlNetDataset(BaseDataset):
# ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
# resize to target
if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]:
cond_img = pil_resize(cond_img, (int(target_size_hw[1]), int(target_size_hw[0])))
cond_img = resize_image(cond_img, cond_img.shape[0], cond_img.shape[1], target_size_hw[1], target_size_hw[0], self.resize_interpolation)
if flipped:
cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride
@@ -2942,17 +2966,13 @@ def load_image(image_path, alpha=False):
# 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top, crop right, crop bottom)
def trim_and_resize_if_required(
random_crop: bool, image: np.ndarray, reso, resized_size: Tuple[int, int]
random_crop: bool, image: np.ndarray, reso, resized_size: Tuple[int, int], resize_interpolation: Optional[str] = None
) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int, int, int]]:
image_height, image_width = image.shape[0:2]
original_size = (image_width, image_height) # size before resize
if image_width != resized_size[0] or image_height != resized_size[1]:
# リサイズする
if image_width > resized_size[0] and image_height > resized_size[1]:
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
else:
image = pil_resize(image, resized_size)
image = resize_image(image, image_width, image_height, resized_size[0], resized_size[1], resize_interpolation)
image_height, image_width = image.shape[0:2]
@@ -2997,7 +3017,7 @@ def load_images_and_masks_for_caching(
for info in image_infos:
image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8)
# TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation)
original_sizes.append(original_size)
crop_ltrbs.append(crop_ltrb)
@@ -3038,7 +3058,7 @@ def cache_batch_latents(
for info in image_infos:
image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8)
# TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation)
info.latents_original_size = original_size
info.latents_crop_ltrb = crop_ltrb
@@ -4518,7 +4538,13 @@ def add_dataset_arguments(
action="store_true",
help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します",
)
parser.add_argument(
"--resize_interpolation",
type=str,
default=None,
choices=["lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area"],
help="Resize interpolation when required. Default: area. Options: lanczos, nearest, bilinear, bicubic, area / 必要に応じてリサイズ時の補間方法を指定します。デフォルト: area オプション: lanczos, nearest, bilinear, bicubic, area",
)
parser.add_argument(
"--token_warmup_min",
type=int,
@@ -6566,3 +6592,4 @@ class LossRecorder:
if losses == 0:
return 0
return self.loss_total / losses


@@ -16,7 +16,6 @@ from PIL import Image
import numpy as np
from safetensors.torch import load_file
def fire_in_thread(f, *args, **kwargs):
threading.Thread(target=f, args=args, kwargs=kwargs).start()
@@ -89,6 +88,8 @@ def setup_logging(args=None, log_level=None, reset=False):
logger = logging.getLogger(__name__)
logger.info(msg_init)
setup_logging()
logger = logging.getLogger(__name__)
# endregion
@@ -261,11 +262,10 @@ def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata:
class MemoryEfficientSafeOpen:
# does not support metadata loading
def __init__(self, filename):
self.filename = filename
self.header, self.header_size = self._read_header()
self.file = open(filename, "rb")
self.header, self.header_size = self._read_header()
def __enter__(self):
return self
@@ -276,6 +276,9 @@ class MemoryEfficientSafeOpen:
def keys(self):
return [k for k in self.header.keys() if k != "__metadata__"]
def metadata(self) -> Dict[str, str]:
return self.header.get("__metadata__", {})
def get_tensor(self, key):
if key not in self.header:
raise KeyError(f"Tensor '{key}' not found in the file")
@@ -293,10 +296,9 @@ class MemoryEfficientSafeOpen:
return self._deserialize_tensor(tensor_bytes, metadata)
def _read_header(self):
with open(self.filename, "rb") as f:
header_size = struct.unpack("<Q", f.read(8))[0]
header_json = f.read(header_size).decode("utf-8")
return json.loads(header_json), header_size
header_size = struct.unpack("<Q", self.file.read(8))[0]
header_json = self.file.read(header_size).decode("utf-8")
return json.loads(header_json), header_size
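The preamble parsed here is the standard safetensors layout: an unsigned 64-bit little-endian header size, followed by that many bytes of JSON. A minimal standalone sketch (the tensor entry is invented for illustration):

```python
import io
import json
import struct

def read_safetensors_header(f):
    """Parse the safetensors preamble: u64 little-endian length, then JSON."""
    header_size = struct.unpack("<Q", f.read(8))[0]
    header = json.loads(f.read(header_size).decode("utf-8"))
    return header, header_size

header = {"__metadata__": {"format": "pt"},
          "w": {"dtype": "F32", "shape": [2], "data_offsets": [0, 8]}}
blob = json.dumps(header).encode("utf-8")
f = io.BytesIO(struct.pack("<Q", len(blob)) + blob)
parsed, size = read_safetensors_header(f)
print(parsed["w"]["shape"])  # [2]
```

Reading the preamble from the already-open handle, as the change above does, avoids opening the same file twice.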
def _deserialize_tensor(self, tensor_bytes, metadata):
dtype = self._get_torch_dtype(metadata["dtype"])
@@ -377,7 +379,7 @@ def load_safetensors(
# region Image utils
def pil_resize(image, size, interpolation=Image.LANCZOS):
def pil_resize(image, size, interpolation):
has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False
if has_alpha:
@@ -385,7 +387,7 @@ def pil_resize(image, size, interpolation=Image.LANCZOS):
else:
pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
resized_pil = pil_image.resize(size, interpolation)
resized_pil = pil_image.resize(size, resample=interpolation)
# Convert back to cv2 format
if has_alpha:
@@ -396,6 +398,117 @@ def pil_resize(image, size, interpolation=Image.LANCZOS):
return resized_cv2
def resize_image(image: np.ndarray, width: int, height: int, resized_width: int, resized_height: int, resize_interpolation: Optional[str] = None):
"""
Resize an image using the given interpolation method. Defaults to `area` when downscaling and `lanczos` when upscaling.
Args:
image: numpy.ndarray
width: int Original image width
height: int Original image height
resized_width: int Resized image width
resized_height: int Resized image height
resize_interpolation: Optional[str] Resize interpolation method "lanczos", "area", "bilinear", "bicubic", "nearest", "box"
Returns:
image
"""
# Ensure all size parameters are actual integers
width = int(width)
height = int(height)
resized_width = int(resized_width)
resized_height = int(resized_height)
if resize_interpolation is None:
if width >= resized_width and height >= resized_height:
resize_interpolation = "area"
else:
resize_interpolation = "lanczos"
# we use PIL for lanczos (for backward compatibility) and box, cv2 for others
use_pil = resize_interpolation in ["lanczos", "lanczos4", "box"]
resized_size = (resized_width, resized_height)
if use_pil:
interpolation = get_pil_interpolation(resize_interpolation)
image = pil_resize(image, resized_size, interpolation=interpolation)
logger.debug(f"resize image using {resize_interpolation} (PIL)")
else:
interpolation = get_cv2_interpolation(resize_interpolation)
image = cv2.resize(image, resized_size, interpolation=interpolation)
logger.debug(f"resize image using {resize_interpolation} (cv2)")
return image
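The fallback rule in `resize_image` can be isolated as a pure function (a sketch; the helper name is illustrative): `area` is used only when neither axis is being enlarged, otherwise `lanczos`.

```python
def default_interpolation(width: int, height: int,
                          resized_width: int, resized_height: int) -> str:
    """Mirror resize_image's default: area for downscaling (moire-free),
    lanczos for any upscale (sharper edges)."""
    if width >= resized_width and height >= resized_height:
        return "area"
    return "lanczos"

print(default_interpolation(1024, 1024, 512, 512))  # area
print(default_interpolation(512, 512, 1024, 768))   # lanczos
```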
def get_cv2_interpolation(interpolation: Optional[str]) -> Optional[int]:
"""
Convert interpolation value to cv2 interpolation integer
https://docs.opencv.org/3.4/da/d54/group__imgproc__transform.html#ga5bb5a1fea74ea38e1a5445ca803ff121
"""
if interpolation is None:
return None
if interpolation == "lanczos" or interpolation == "lanczos4":
# Lanczos interpolation over 8x8 neighborhood
return cv2.INTER_LANCZOS4
elif interpolation == "nearest":
# Bit exact nearest neighbor interpolation. This will produce same results as the nearest neighbor method in PIL, scikit-image or Matlab.
return cv2.INTER_NEAREST_EXACT
elif interpolation == "bilinear" or interpolation == "linear":
# bilinear interpolation
return cv2.INTER_LINEAR
elif interpolation == "bicubic" or interpolation == "cubic":
# bicubic interpolation
return cv2.INTER_CUBIC
elif interpolation == "area":
# resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method.
return cv2.INTER_AREA
elif interpolation == "box":
# cv2 has no dedicated box filter; INTER_AREA performs the equivalent pixel-area averaging.
return cv2.INTER_AREA
else:
return None
def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resampling]:
"""
Convert interpolation value to PIL interpolation
https://pillow.readthedocs.io/en/stable/handbook/concepts.html#concept-filters
"""
if interpolation is None:
return None
if interpolation == "lanczos":
return Image.Resampling.LANCZOS
elif interpolation == "nearest":
# Pick one nearest pixel from the input image. Ignore all other input pixels.
return Image.Resampling.NEAREST
elif interpolation == "bilinear" or interpolation == "linear":
# For resize calculate the output pixel value using linear interpolation on all pixels that may contribute to the output value. For other transformations linear interpolation over a 2x2 environment in the input image is used.
return Image.Resampling.BILINEAR
elif interpolation == "bicubic" or interpolation == "cubic":
# For resize calculate the output pixel value using cubic interpolation on all pixels that may contribute to the output value. For other transformations cubic interpolation over a 4x4 environment in the input image is used.
return Image.Resampling.BICUBIC
elif interpolation == "area":
# Image.Resampling.BOX may be more appropriate if upscaling
# Area interpolation is related to cv2.INTER_AREA
# Produces a sharper image than Resampling.BILINEAR and doesn't have the local dislocations of Resampling.BOX.
return Image.Resampling.HAMMING
elif interpolation == "box":
# Each pixel of source image contributes to one pixel of the destination image with identical weights. For upscaling is equivalent of Resampling.NEAREST.
return Image.Resampling.BOX
else:
return None
def validate_interpolation_fn(interpolation_str: str) -> bool:
"""
Check if an interpolation method is supported
"""
return interpolation_str in ["lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area", "box"]
# endregion
# TODO make inf_utils.py


@@ -268,7 +268,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
class DyLoRANetwork(torch.nn.Module):
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
LORA_PREFIX_UNET = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te"


@@ -866,7 +866,7 @@ class LoRANetwork(torch.nn.Module):
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
LORA_PREFIX_UNET = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te"

View File

@@ -278,7 +278,7 @@ def merge_lora_weights(pipe, weights_sd: Dict, multiplier: float = 1.0):
class LoRANetwork(torch.nn.Module):
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
LORA_PREFIX_UNET = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te"

View File

@@ -755,7 +755,7 @@ class LoRANetwork(torch.nn.Module):
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
LORA_PREFIX_UNET = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te"

View File

@@ -9,11 +9,13 @@
import math
import os
from contextlib import contextmanager
from typing import Dict, List, Optional, Tuple, Type, Union
from diffusers import AutoencoderKL
from transformers import CLIPTextModel
import numpy as np
import torch
from torch import Tensor
import re
from library.utils import setup_logging
from library.sdxl_original_unet import SdxlUNet2DConditionModel
@@ -44,6 +46,8 @@ class LoRAModule(torch.nn.Module):
rank_dropout=None,
module_dropout=None,
split_dims: Optional[List[int]] = None,
ggpo_beta: Optional[float] = None,
ggpo_sigma: Optional[float] = None,
):
"""
if alpha == 0 or None, alpha is rank (no scaling).
@@ -103,9 +107,20 @@ class LoRAModule(torch.nn.Module):
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
self.ggpo_sigma = ggpo_sigma
self.ggpo_beta = ggpo_beta
if self.ggpo_beta is not None and self.ggpo_sigma is not None:
self.combined_weight_norms = None
self.grad_norms = None
self.perturbation_norm_factor = 1.0 / math.sqrt(org_module.weight.shape[0])
self.initialize_norm_cache(org_module.weight)
self.org_module_shape: tuple[int] = org_module.weight.shape
def apply_to(self):
self.org_forward = self.org_module.forward
self.org_module.forward = self.forward
del self.org_module
def forward(self, x):
@@ -140,7 +155,17 @@ class LoRAModule(torch.nn.Module):
lx = self.lora_up(lx)
return org_forwarded + lx * self.multiplier * scale
# LoRA Gradient-Guided Perturbation Optimization
if self.training and self.ggpo_sigma is not None and self.ggpo_beta is not None and self.combined_weight_norms is not None and self.grad_norms is not None:
with torch.no_grad():
perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms ** 2)) + (self.ggpo_beta * (self.grad_norms ** 2))
perturbation_scale_factor = (perturbation_scale * self.perturbation_norm_factor).to(self.device)
perturbation = torch.randn(self.org_module_shape, dtype=self.dtype, device=self.device)
perturbation.mul_(perturbation_scale_factor)
perturbation_output = x @ perturbation.T # Result: (batch × n)
return org_forwarded + (self.multiplier * scale * lx) + perturbation_output
else:
return org_forwarded + lx * self.multiplier * scale
else:
lxs = [lora_down(x) for lora_down in self.lora_down]
@@ -167,6 +192,116 @@ class LoRAModule(torch.nn.Module):
return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale
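The GGPO branch in `forward` above can be sketched in isolation. This NumPy illustration (not the module's actual code path; the function name is ours) reproduces the perturbation math: a per-row scale built from the combined weight norms and gradient norms, normalized by `1/sqrt(out_features)`, applied to a Gaussian matrix of the original weight's shape, and added to the output via `x @ P.T`:

```python
import numpy as np

def ggpo_perturbation(x, weight_shape, combined_weight_norms, grad_norms,
                      ggpo_sigma=0.03, ggpo_beta=0.01, seed=0):
    # per-row scale: sigma * |combined weight norm| + beta * grad_norm ** 2
    # (the diff's sqrt(norms ** 2) is just the absolute value)
    scale = ggpo_sigma * np.abs(combined_weight_norms) + ggpo_beta * grad_norms ** 2
    norm_factor = 1.0 / np.sqrt(weight_shape[0])  # 1 / sqrt(out_features)
    rng = np.random.default_rng(seed)
    perturbation = rng.standard_normal(weight_shape) * (scale * norm_factor)
    return x @ perturbation.T  # (batch, out_features), added to the module output
```

Because the scale is computed per output row, rows with larger combined weight or gradient norms receive proportionally larger perturbations.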
@torch.no_grad()
def initialize_norm_cache(self, org_module_weight: Tensor):
# Choose a reasonable sample size
n_rows = org_module_weight.shape[0]
sample_size = min(1000, n_rows) # Cap at 1000 samples or use all if smaller
# Sample random indices across all rows
indices = torch.randperm(n_rows)[:sample_size]
# Convert to a supported data type first, then index
# Use float32 for indexing operations
weights_float32 = org_module_weight.to(dtype=torch.float32)
sampled_weights = weights_float32[indices].to(device=self.device)
# Calculate sampled norms
sampled_norms = torch.norm(sampled_weights, dim=1, keepdim=True)
# Store the mean norm as our estimate
self.org_weight_norm_estimate = sampled_norms.mean()
# Optional: store standard deviation for confidence intervals
self.org_weight_norm_std = sampled_norms.std()
# Free memory
del sampled_weights, weights_float32
@torch.no_grad()
def validate_norm_approximation(self, org_module_weight: Tensor, verbose=True):
# Calculate the true norm (this will be slow but it's just for validation)
true_norms = []
chunk_size = 1024 # Process in chunks to avoid OOM
for i in range(0, org_module_weight.shape[0], chunk_size):
end_idx = min(i + chunk_size, org_module_weight.shape[0])
chunk = org_module_weight[i:end_idx].to(device=self.device, dtype=self.dtype)
chunk_norms = torch.norm(chunk, dim=1, keepdim=True)
true_norms.append(chunk_norms.cpu())
del chunk
true_norms = torch.cat(true_norms, dim=0)
true_mean_norm = true_norms.mean().item()
# Compare with our estimate
estimated_norm = self.org_weight_norm_estimate.item()
# Calculate error metrics
absolute_error = abs(true_mean_norm - estimated_norm)
relative_error = absolute_error / true_mean_norm * 100 # as percentage
if verbose:
logger.info(f"True mean norm: {true_mean_norm:.6f}")
logger.info(f"Estimated norm: {estimated_norm:.6f}")
logger.info(f"Absolute error: {absolute_error:.6f}")
logger.info(f"Relative error: {relative_error:.2f}%")
return {
'true_mean_norm': true_mean_norm,
'estimated_norm': estimated_norm,
'absolute_error': absolute_error,
'relative_error': relative_error
}
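The sampling estimate in `initialize_norm_cache` above can be sketched standalone. A NumPy illustration of the same idea (the function name is ours): sample up to 1000 rows, take their L2 norms, and use the sample mean (and std) as an estimate of the true mean row norm, avoiding a full pass over a large weight matrix:

```python
import numpy as np

def estimate_row_norm(weight: np.ndarray, sample_size: int = 1000, seed: int = 0):
    rng = np.random.default_rng(seed)
    n_rows = weight.shape[0]
    k = min(sample_size, n_rows)  # cap at sample_size, or use all rows if fewer
    idx = rng.permutation(n_rows)[:k]
    sampled_norms = np.linalg.norm(weight[idx], axis=1)
    # mean is the estimate; std gives a rough confidence interval
    return sampled_norms.mean(), sampled_norms.std()
```

For typical weight matrices the sample mean lands within a few percent of the true mean norm, which is what `validate_norm_approximation` checks.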
@torch.no_grad()
def update_norms(self):
# GGPO is not enabled, so there are no norms to update
if self.ggpo_beta is None or self.ggpo_sigma is None:
return
# only update norms when we are training
if self.training is False:
return
module_weights = self.lora_up.weight @ self.lora_down.weight
module_weights.mul_(self.scale)  # in-place; plain mul() would discard the result
self.weight_norms = torch.norm(module_weights, dim=1, keepdim=True)
self.combined_weight_norms = torch.sqrt((self.org_weight_norm_estimate**2) +
torch.sum(module_weights**2, dim=1, keepdim=True))
@torch.no_grad()
def update_grad_norms(self):
if self.training is False:
logger.info(f"skipping update_grad_norms for {self.lora_name}")
return
lora_down_grad = None
lora_up_grad = None
for name, param in self.named_parameters():
if name == "lora_down.weight":
lora_down_grad = param.grad
elif name == "lora_up.weight":
lora_up_grad = param.grad
# Calculate gradient norms if we have both gradients
if lora_down_grad is not None and lora_up_grad is not None:
with torch.autocast(self.device.type):
approx_grad = self.scale * ((self.lora_up.weight @ lora_down_grad) + (lora_up_grad @ self.lora_down.weight))
self.grad_norms = torch.norm(approx_grad, dim=1, keepdim=True)
@property
def device(self):
return next(self.parameters()).device
@property
def dtype(self):
return next(self.parameters()).dtype
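`update_grad_norms` above combines the two factor gradients into a full-size gradient estimate without running autograd on the merged weight: for `W = scale * (up @ down)`, it uses `scale * (up @ grad_down + grad_up @ down)` and takes per-row norms. A NumPy sketch of just that formula (function name is ours):

```python
import numpy as np

def approx_grad_row_norms(up, down, grad_up, grad_down, scale):
    # approx_grad has the shape of the merged weight: (out_features, in_features)
    approx_grad = scale * (up @ grad_down + grad_up @ down)
    # one norm per output row, matching the per-row perturbation scale
    return np.linalg.norm(approx_grad, axis=1, keepdims=True)
```

The product never materializes anything larger than the merged weight itself, so the extra memory cost per step is bounded.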
class LoRAInfModule(LoRAModule):
def __init__(
@@ -420,6 +555,16 @@ def create_network(
if split_qkv is not None:
split_qkv = True if split_qkv == "True" else False
ggpo_beta = kwargs.get("ggpo_beta", None)
ggpo_sigma = kwargs.get("ggpo_sigma", None)
if ggpo_beta is not None:
ggpo_beta = float(ggpo_beta)
if ggpo_sigma is not None:
ggpo_sigma = float(ggpo_sigma)
# train T5XXL
train_t5xxl = kwargs.get("train_t5xxl", False)
if train_t5xxl is not None:
@@ -449,6 +594,8 @@ def create_network(
in_dims=in_dims,
train_double_block_indices=train_double_block_indices,
train_single_block_indices=train_single_block_indices,
ggpo_beta=ggpo_beta,
ggpo_sigma=ggpo_sigma,
verbose=verbose,
)
@@ -561,6 +708,8 @@ class LoRANetwork(torch.nn.Module):
in_dims: Optional[List[int]] = None,
train_double_block_indices: Optional[List[bool]] = None,
train_single_block_indices: Optional[List[bool]] = None,
ggpo_beta: Optional[float] = None,
ggpo_sigma: Optional[float] = None,
verbose: Optional[bool] = False,
) -> None:
super().__init__()
@@ -599,10 +748,16 @@ class LoRANetwork(torch.nn.Module):
# logger.info(
# f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
# )
if ggpo_beta is not None and ggpo_sigma is not None:
logger.info(f"LoRA-GGPO training sigma: {ggpo_sigma} beta: {ggpo_beta}")
if self.split_qkv:
logger.info(f"split qkv for LoRA")
if self.train_blocks is not None:
logger.info(f"train {self.train_blocks} blocks only")
if train_t5xxl:
logger.info(f"train T5XXL as well")
@@ -722,6 +877,8 @@ class LoRANetwork(torch.nn.Module):
rank_dropout=rank_dropout,
module_dropout=module_dropout,
split_dims=split_dims,
ggpo_beta=ggpo_beta,
ggpo_sigma=ggpo_sigma,
)
loras.append(lora)
@@ -790,6 +947,36 @@ class LoRANetwork(torch.nn.Module):
for lora in self.text_encoder_loras + self.unet_loras:
lora.enabled = is_enabled
def update_norms(self):
for lora in self.text_encoder_loras + self.unet_loras:
lora.update_norms()
def update_grad_norms(self):
for lora in self.text_encoder_loras + self.unet_loras:
lora.update_grad_norms()
def grad_norms(self) -> Tensor:
grad_norms = []
for lora in self.text_encoder_loras + self.unet_loras:
if hasattr(lora, "grad_norms") and lora.grad_norms is not None:
grad_norms.append(lora.grad_norms.mean(dim=0))
return torch.stack(grad_norms) if len(grad_norms) > 0 else torch.tensor([])
def weight_norms(self) -> Tensor:
weight_norms = []
for lora in self.text_encoder_loras + self.unet_loras:
if hasattr(lora, "weight_norms") and lora.weight_norms is not None:
weight_norms.append(lora.weight_norms.mean(dim=0))
return torch.stack(weight_norms) if len(weight_norms) > 0 else torch.tensor([])
def combined_weight_norms(self) -> Tensor:
combined_weight_norms = []
for lora in self.text_encoder_loras + self.unet_loras:
if hasattr(lora, "combined_weight_norms") and lora.combined_weight_norms is not None:
combined_weight_norms.append(lora.combined_weight_norms.mean(dim=0))
return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else torch.tensor([])
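The three collectors above share one pattern: take each module's per-row norms when present, average them, and stack, with an empty result yielding an empty tensor rather than raising. A NumPy sketch of that pattern (helper name is ours):

```python
import numpy as np

def collect_norms(per_module_norms):
    # skip modules that have not populated their norms yet (None)
    means = [n.mean(axis=0) for n in per_module_norms if n is not None]
    # empty input -> empty array, so callers can .mean() guard on .size
    return np.stack(means) if means else np.array([])
```

Returning an empty container keeps the training-loop logging code free of special cases for steps before the first norm update.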
def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file

View File

@@ -7,9 +7,11 @@ opencv-python==4.8.1.78
einops==0.7.0
pytorch-lightning==1.9.0
bitsandbytes==0.44.0
prodigyopt==1.0
lion-pytorch==0.0.6
schedulefree==1.4
pytorch-optimizer==3.5.0
prodigy-plus-schedule-free==1.9.0
prodigyopt==1.1.2
tensorboard
safetensors==0.4.4
# gradio==3.16.2

View File

@@ -24,7 +24,6 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
self.is_sdxl = True
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
sdxl_train_util.verify_sdxl_training_args(args)
if args.cache_text_encoder_outputs:

View File

@@ -15,7 +15,7 @@ import os
from anime_face_detector import create_detector
from tqdm import tqdm
import numpy as np
from library.utils import setup_logging, pil_resize
from library.utils import setup_logging, resize_image
setup_logging()
import logging
logger = logging.getLogger(__name__)
@@ -170,12 +170,9 @@ def process(args):
scale = max(cur_crop_width / w, cur_crop_height / h)
if scale != 1.0:
w = int(w * scale + .5)
h = int(h * scale + .5)
if scale < 1.0:
face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA)
else:
face_img = pil_resize(face_img, (w, h))
rw = int(w * scale + .5)
rh = int(h * scale + .5)
face_img = resize_image(face_img, w, h, rw, rh)
cx = int(cx * scale + .5)
cy = int(cy * scale + .5)
fw = int(fw * scale + .5)

View File

@@ -0,0 +1,166 @@
import argparse
import os
import gc
from typing import Dict, Optional, Union
import torch
from safetensors.torch import safe_open
from library.utils import setup_logging
from library.utils import load_safetensors, mem_eff_save_file, str_to_dtype
setup_logging()
import logging
logger = logging.getLogger(__name__)
def merge_safetensors(
dit_path: str,
vae_path: Optional[str] = None,
clip_l_path: Optional[str] = None,
clip_g_path: Optional[str] = None,
t5xxl_path: Optional[str] = None,
output_path: str = "merged_model.safetensors",
device: str = "cpu",
save_precision: Optional[str] = None,
):
"""
Merge multiple safetensors files into a single file
Args:
dit_path: Path to the DiT/MMDiT model
vae_path: Path to the VAE model
clip_l_path: Path to the CLIP-L model
clip_g_path: Path to the CLIP-G model
t5xxl_path: Path to the T5-XXL model
output_path: Path to save the merged model
device: Device to load tensors to
save_precision: Target dtype for model weights (e.g. 'fp16', 'bf16')
"""
logger.info("Starting to merge safetensors files...")
# Convert save_precision string to torch dtype if specified
if save_precision:
target_dtype = str_to_dtype(save_precision)
else:
target_dtype = None
# 1. Get DiT metadata if available
metadata = None
try:
with safe_open(dit_path, framework="pt") as f:
metadata = f.metadata() # may be None
if metadata:
logger.info(f"Found metadata in DiT model: {metadata}")
except Exception as e:
logger.warning(f"Failed to read metadata from DiT model: {e}")
# 2. Create empty merged state dict
merged_state_dict = {}
# 3. Load and merge each model with memory management
# DiT/MMDiT - prefix: model.diffusion_model.
# This state dict may have VAE keys.
logger.info(f"Loading DiT model from {dit_path}")
dit_state_dict = load_safetensors(dit_path, device=device, disable_mmap=True, dtype=target_dtype)
logger.info(f"Adding DiT model with {len(dit_state_dict)} keys")
for key, value in dit_state_dict.items():
if key.startswith("model.diffusion_model.") or key.startswith("first_stage_model."):
merged_state_dict[key] = value
else:
merged_state_dict[f"model.diffusion_model.{key}"] = value
# Free memory
del dit_state_dict
gc.collect()
# VAE - prefix: first_stage_model.
# May be omitted if VAE is already included in DiT model.
if vae_path:
logger.info(f"Loading VAE model from {vae_path}")
vae_state_dict = load_safetensors(vae_path, device=device, disable_mmap=True, dtype=target_dtype)
logger.info(f"Adding VAE model with {len(vae_state_dict)} keys")
for key, value in vae_state_dict.items():
if key.startswith("first_stage_model."):
merged_state_dict[key] = value
else:
merged_state_dict[f"first_stage_model.{key}"] = value
# Free memory
del vae_state_dict
gc.collect()
# CLIP-L - prefix: text_encoders.clip_l.
if clip_l_path:
logger.info(f"Loading CLIP-L model from {clip_l_path}")
clip_l_state_dict = load_safetensors(clip_l_path, device=device, disable_mmap=True, dtype=target_dtype)
logger.info(f"Adding CLIP-L model with {len(clip_l_state_dict)} keys")
for key, value in clip_l_state_dict.items():
if key.startswith("text_encoders.clip_l.transformer."):
merged_state_dict[key] = value
else:
merged_state_dict[f"text_encoders.clip_l.transformer.{key}"] = value
# Free memory
del clip_l_state_dict
gc.collect()
# CLIP-G - prefix: text_encoders.clip_g.
if clip_g_path:
logger.info(f"Loading CLIP-G model from {clip_g_path}")
clip_g_state_dict = load_safetensors(clip_g_path, device=device, disable_mmap=True, dtype=target_dtype)
logger.info(f"Adding CLIP-G model with {len(clip_g_state_dict)} keys")
for key, value in clip_g_state_dict.items():
if key.startswith("text_encoders.clip_g.transformer."):
merged_state_dict[key] = value
else:
merged_state_dict[f"text_encoders.clip_g.transformer.{key}"] = value
# Free memory
del clip_g_state_dict
gc.collect()
# T5-XXL - prefix: text_encoders.t5xxl.
if t5xxl_path:
logger.info(f"Loading T5-XXL model from {t5xxl_path}")
t5xxl_state_dict = load_safetensors(t5xxl_path, device=device, disable_mmap=True, dtype=target_dtype)
logger.info(f"Adding T5-XXL model with {len(t5xxl_state_dict)} keys")
for key, value in t5xxl_state_dict.items():
if key.startswith("text_encoders.t5xxl.transformer."):
merged_state_dict[key] = value
else:
merged_state_dict[f"text_encoders.t5xxl.transformer.{key}"] = value
# Free memory
del t5xxl_state_dict
gc.collect()
# 4. Save merged state dict
logger.info(f"Saving merged model to {output_path} with {len(merged_state_dict)} keys total")
mem_eff_save_file(merged_state_dict, output_path, metadata)
logger.info("Successfully merged safetensors files")
def main():
parser = argparse.ArgumentParser(description="Merge Stable Diffusion 3.5 model components into a single safetensors file")
parser.add_argument("--dit", required=True, help="Path to the DiT/MMDiT model")
parser.add_argument("--vae", help="Path to the VAE model. May be omitted if VAE is included in DiT model")
parser.add_argument("--clip_l", help="Path to the CLIP-L model")
parser.add_argument("--clip_g", help="Path to the CLIP-G model")
parser.add_argument("--t5xxl", help="Path to the T5-XXL model")
parser.add_argument("--output", default="merged_model.safetensors", help="Path to save the merged model")
parser.add_argument("--device", default="cpu", help="Device to load tensors to")
parser.add_argument("--save_precision", type=str, help="Precision to save the model in (e.g., 'fp16', 'bf16', 'float16', etc.)")
args = parser.parse_args()
merge_safetensors(
dit_path=args.dit,
vae_path=args.vae,
clip_l_path=args.clip_l,
clip_g_path=args.clip_g,
t5xxl_path=args.t5xxl,
output_path=args.output,
device=args.device,
save_precision=args.save_precision,
)
if __name__ == "__main__":
main()
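Each component in the merge script above follows the same key-prefixing rule: keep a key as-is when it already carries the component's prefix, otherwise prepend it (the DiT case additionally passes `first_stage_model.` keys through unchanged). A minimal sketch of that rule in isolation (function name is ours):

```python
def prefix_keys(state_dict: dict, prefix: str) -> dict:
    """Prepend prefix to keys that do not already carry it."""
    merged = {}
    for key, value in state_dict.items():
        merged[key if key.startswith(prefix) else f"{prefix}{key}"] = value
    return merged
```

This makes the merge idempotent with respect to input checkpoints that are already in single-file layout.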

View File

@@ -6,7 +6,7 @@ import shutil
import math
from PIL import Image
import numpy as np
from library.utils import setup_logging, pil_resize
from library.utils import setup_logging, resize_image
setup_logging()
import logging
logger = logging.getLogger(__name__)
@@ -22,14 +22,6 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
if not os.path.exists(dst_img_folder):
os.makedirs(dst_img_folder)
# Select interpolation method
if interpolation == 'lanczos4':
pil_interpolation = Image.LANCZOS
elif interpolation == 'cubic':
pil_interpolation = Image.BICUBIC
else:
cv2_interpolation = cv2.INTER_AREA
# Iterate through all files in src_img_folder
img_exts = (".png", ".jpg", ".jpeg", ".webp", ".bmp") # copy from train_util.py
for filename in os.listdir(src_img_folder):
@@ -63,11 +55,7 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
new_height = int(img.shape[0] * math.sqrt(scale_factor))
new_width = int(img.shape[1] * math.sqrt(scale_factor))
# Resize image
if cv2_interpolation:
img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation)
else:
img = pil_resize(img, (new_width, new_height), interpolation=pil_interpolation)
img = resize_image(img, img.shape[0], img.shape[1], new_height, new_width, interpolation)
else:
new_height, new_width = img.shape[0:2]
@@ -113,8 +101,8 @@ def setup_parser() -> argparse.ArgumentParser:
help='Maximum resolution(s) in the format "512x512,384x384, etc, etc" / 最大画像サイズをカンマ区切りで指定 ("512x512,384x384, etc, etc" など)', default="512x512,384x384,256x256,128x128")
parser.add_argument('--divisible_by', type=int,
help='Ensure new dimensions are divisible by this value / リサイズ後の画像のサイズをこの値で割り切れるようにします', default=1)
parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4'],
default='area', help='Interpolation method for resizing / リサイズ時の補間方法')
parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4', 'nearest', 'linear', 'box'],
default=None, help='Interpolation method for resizing. Defaults to area when downscaling and lanczos when upscaling / サイズ変更の補間方法。指定がない場合、縮小時はarea、拡大時はlanczosになります。')
parser.add_argument('--save_as_png', action='store_true', help='Save as png format / png形式で保存')
parser.add_argument('--copy_associated_files', action='store_true',
help='Copy files with same base name to images (captions etc) / 画像と同じファイル名(拡張子を除く)のファイルもコピーする')
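The new `--interpolation` default described in the help text can be sketched as a small selector (a hypothetical helper; the script's actual resolution happens inside `resize_image`): area for downscaling, lanczos for upscaling, and any explicit choice wins.

```python
def default_interpolation(src_w, src_h, dst_w, dst_h, interpolation=None):
    # an explicit user choice always takes precedence
    if interpolation is not None:
        return interpolation
    # area preserves detail when shrinking; lanczos is sharper when enlarging
    return "area" if dst_w * dst_h < src_w * src_h else "lanczos"
```

Tying the default to the resize direction matches the comment in the PIL mapping above: area-style filtering is only the right choice when downscaling.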

View File

@@ -69,13 +69,20 @@ class NetworkTrainer:
keys_scaled=None,
mean_norm=None,
maximum_norm=None,
mean_grad_norm=None,
mean_combined_norm=None,
):
logs = {"loss/current": current_loss, "loss/average": avr_loss}
if keys_scaled is not None:
logs["max_norm/keys_scaled"] = keys_scaled
logs["max_norm/average_key_norm"] = mean_norm
logs["max_norm/max_key_norm"] = maximum_norm
if mean_norm is not None:
logs["norm/avg_key_norm"] = mean_norm
if mean_grad_norm is not None:
logs["norm/avg_grad_norm"] = mean_grad_norm
if mean_combined_norm is not None:
logs["norm/avg_combined_norm"] = mean_combined_norm
lrs = lr_scheduler.get_last_lr()
for i, lr in enumerate(lrs):
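The conditional log assembly above can be sketched in isolation: each norm entry is emitted only when its value exists, so runs without GGPO or `--scale_weight_norms` log exactly what they did before (helper name is ours):

```python
def build_norm_logs(keys_scaled=None, mean_norm=None, maximum_norm=None,
                    mean_grad_norm=None, mean_combined_norm=None):
    logs = {}
    if keys_scaled is not None:  # only set when scale_weight_norms is active
        logs["max_norm/keys_scaled"] = keys_scaled
        logs["max_norm/average_key_norm"] = mean_norm
        logs["max_norm/max_key_norm"] = maximum_norm
    if mean_norm is not None:
        logs["norm/avg_key_norm"] = mean_norm
    if mean_grad_norm is not None:  # only populated by GGPO training
        logs["norm/avg_grad_norm"] = mean_grad_norm
    if mean_combined_norm is not None:
        logs["norm/avg_combined_norm"] = mean_combined_norm
    return logs
```

Gating each key on its value keeps dashboards backward compatible: absent metrics simply never appear rather than logging `None`.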
@@ -652,6 +659,10 @@ class NetworkTrainer:
return
network_has_multiplier = hasattr(network, "set_multiplier")
# TODO remove `hasattr`s by setting up methods if not defined in the network like (hacky but works):
# if not hasattr(network, "prepare_network"):
# network.prepare_network = lambda args: None
if hasattr(network, "prepare_network"):
network.prepare_network(args)
if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"):
@@ -1018,6 +1029,7 @@ class NetworkTrainer:
"ss_max_validation_steps": args.max_validation_steps,
"ss_validate_every_n_epochs": args.validate_every_n_epochs,
"ss_validate_every_n_steps": args.validate_every_n_steps,
"ss_resize_interpolation": args.resize_interpolation,
}
self.update_metadata(metadata, args) # architecture specific metadata
@@ -1043,6 +1055,7 @@ class NetworkTrainer:
"max_bucket_reso": dataset.max_bucket_reso,
"tag_frequency": dataset.tag_frequency,
"bucket_info": dataset.bucket_info,
"resize_interpolation": dataset.resize_interpolation,
}
subsets_metadata = []
@@ -1060,6 +1073,7 @@ class NetworkTrainer:
"enable_wildcard": bool(subset.enable_wildcard),
"caption_prefix": subset.caption_prefix,
"caption_suffix": subset.caption_suffix,
"resize_interpolation": subset.resize_interpolation,
}
image_dir_or_metadata_file = None
@@ -1401,6 +1415,11 @@ class NetworkTrainer:
params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
if hasattr(network, "update_grad_norms"):
network.update_grad_norms()
if hasattr(network, "update_norms"):
network.update_norms()
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
@@ -1409,9 +1428,23 @@ class NetworkTrainer:
keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(
args.scale_weight_norms, accelerator.device
)
mean_grad_norm = None
mean_combined_norm = None
max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
else:
keys_scaled, mean_norm, maximum_norm = None, None, None
if hasattr(network, "weight_norms"):
mean_norm = network.weight_norms().mean().item()
mean_grad_norm = network.grad_norms().mean().item()
mean_combined_norm = network.combined_weight_norms().mean().item()
weight_norms = network.weight_norms()
maximum_norm = weight_norms.max().item() if weight_norms.numel() > 0 else None
keys_scaled = None
max_mean_logs = {}
else:
keys_scaled, mean_norm, maximum_norm = None, None, None
mean_grad_norm = None
mean_combined_norm = None
max_mean_logs = {}
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
@@ -1444,14 +1477,21 @@ class NetworkTrainer:
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if args.scale_weight_norms:
progress_bar.set_postfix(**{**max_mean_logs, **logs})
progress_bar.set_postfix(**{**max_mean_logs, **logs})
if is_tracking:
logs = self.generate_step_logs(
args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm
args,
current_loss,
avr_loss,
lr_scheduler,
lr_descriptions,
optimizer,
keys_scaled,
mean_norm,
maximum_norm,
mean_grad_norm,
mean_combined_norm,
)
self.step_logging(accelerator, logs, global_step, epoch + 1)