mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
Merge branch 'sd3' into lumina
This commit is contained in:
20
README.md
20
README.md
@@ -14,6 +14,19 @@ The command to install PyTorch is as follows:
|
||||
|
||||
### Recent Updates
|
||||
|
||||
Mar 30, 2025:
|
||||
- LoRA-GGPO is added for FLUX.1 LoRA training. Thank you to rockerBOO for PR [#1974](https://github.com/kohya-ss/sd-scripts/pull/1974).
|
||||
- Specify `--network_args ggpo_sigma=0.03 ggpo_beta=0.01` in the command line or `network_args = ["ggpo_sigma=0.03", "ggpo_beta=0.01"]` in .toml file. See PR for details.
|
||||
- The interpolation method for resizing the original image to the training size can now be specified. Thank you to rockerBOO for PR [#1936](https://github.com/kohya-ss/sd-scripts/pull/1936).
|
||||
|
||||
Mar 20, 2025:
|
||||
- `pytorch-optimizer` is added to requirements.txt. Thank you to gesen2egee for PR [#1985](https://github.com/kohya-ss/sd-scripts/pull/1985).
|
||||
- For example, you can use CAME optimizer with `--optimizer_type "pytorch_optimizer.CAME" --optimizer_args "weight_decay=0.01"`.
|
||||
|
||||
Mar 6, 2025:
|
||||
|
||||
- Added a utility script to merge the weights of SD3's DiT, VAE (optional), CLIP-L, CLIP-G, and T5XXL into a single .safetensors file. Run `tools/merge_sd3_safetensors.py`. See `--help` for usage. PR [#1960](https://github.com/kohya-ss/sd-scripts/pull/1960)
|
||||
|
||||
Feb 26, 2025:
|
||||
|
||||
- Improve the validation loss calculation in `train_network.py`, `sdxl_train_network.py`, `flux_train_network.py`, and `sd3_train_network.py`. PR [#1903](https://github.com/kohya-ss/sd-scripts/pull/1903)
|
||||
@@ -744,6 +757,8 @@ Not available yet.
|
||||
[__Change History__](#change-history) is moved to the bottom of the page.
|
||||
更新履歴は[ページ末尾](#change-history)に移しました。
|
||||
|
||||
Latest update: 2025-03-21 (Version 0.9.1)
|
||||
|
||||
[日本語版READMEはこちら](./README-ja.md)
|
||||
|
||||
The development version is in the `dev` branch. Please check the dev branch for the latest changes.
|
||||
@@ -887,6 +902,11 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
|
||||
|
||||
## Change History
|
||||
|
||||
### Mar 21, 2025 / 2025-03-21 Version 0.9.1
|
||||
|
||||
- Fixed a bug where some of LoRA modules for CLIP Text Encoder were not trained. Thank you Nekotekina for PR [#1964](https://github.com/kohya-ss/sd-scripts/pull/1964)
|
||||
- The LoRA modules for CLIP Text Encoder are now 264 modules, which is the same as before. Only 88 modules were trained in the previous version.
|
||||
|
||||
### Jan 17, 2025 / 2025-01-17 Version 0.9.0
|
||||
|
||||
- __important__ The dependent libraries are updated. Please see [Upgrade](#upgrade) and update the libraries.
|
||||
|
||||
@@ -152,6 +152,7 @@ These options are related to subset configuration.
|
||||
| `keep_tokens_separator` | `“|||”` | o | o | o |
|
||||
| `secondary_separator` | `“;;;”` | o | o | o |
|
||||
| `enable_wildcard` | `true` | o | o | o |
|
||||
| `resize_interpolation` | (not specified) | o | o | o |
|
||||
|
||||
* `num_repeats`
|
||||
* Specifies the number of repeats for images in a subset. This is equivalent to `--dataset_repeats` in fine-tuning but can be specified for any training method.
|
||||
@@ -165,6 +166,8 @@ These options are related to subset configuration.
|
||||
* Specifies an additional separator. The part separated by this separator is treated as one tag and is shuffled and dropped. It is then replaced by `caption_separator`. For example, if you specify `aaa;;;bbb;;;ccc`, it will be replaced by `aaa,bbb,ccc` or dropped together.
|
||||
* `enable_wildcard`
|
||||
* Enables wildcard notation. This will be explained later.
|
||||
* `resize_interpolation`
|
||||
* Specifies the interpolation method used when resizing images. Normally, there is no need to specify this. The following options can be specified: `lanczos`, `nearest`, `bilinear`, `linear`, `bicubic`, `cubic`, `area`, `box`. By default (when not specified), `area` is used for downscaling, and `lanczos` is used for upscaling. If this option is specified, the same interpolation method will be used for both upscaling and downscaling. When `lanczos` or `box` is specified, PIL is used; for other options, OpenCV is used.
|
||||
|
||||
### DreamBooth-specific options
|
||||
|
||||
|
||||
@@ -144,6 +144,7 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学
|
||||
| `keep_tokens_separator` | `“|||”` | o | o | o |
|
||||
| `secondary_separator` | `“;;;”` | o | o | o |
|
||||
| `enable_wildcard` | `true` | o | o | o |
|
||||
| `resize_interpolation` |(通常は設定しません) | o | o | o |
|
||||
|
||||
* `num_repeats`
|
||||
* サブセットの画像の繰り返し回数を指定します。fine tuning における `--dataset_repeats` に相当しますが、`num_repeats` はどの学習方法でも指定可能です。
|
||||
@@ -162,6 +163,9 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学
|
||||
* `enable_wildcard`
|
||||
* ワイルドカード記法および複数行キャプションを有効にします。ワイルドカード記法、複数行キャプションについては後述します。
|
||||
|
||||
* `resize_interpolation`
|
||||
* 画像のリサイズ時に使用する補間方法を指定します。通常は指定しなくて構いません。`lanczos`, `nearest`, `bilinear`, `linear`, `bicubic`, `cubic`, `area`, `box` が指定可能です。デフォルト(未指定時)は、縮小時は `area`、拡大時は `lanczos` になります。このオプションを指定すると、拡大時・縮小時とも同じ補間方法が使用されます。`lanczos`、`box`を指定するとPILが、それ以外を指定するとOpenCVが使用されます。
|
||||
|
||||
### DreamBooth 方式専用のオプション
|
||||
|
||||
DreamBooth 方式のオプションは、サブセット向けオプションのみ存在します。
|
||||
|
||||
@@ -11,7 +11,7 @@ from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
import library.train_util as train_util
|
||||
from library.utils import setup_logging, pil_resize
|
||||
from library.utils import setup_logging, resize_image
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
@@ -42,10 +42,7 @@ def preprocess_image(image):
|
||||
pad_t = pad_y // 2
|
||||
image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)
|
||||
|
||||
if size > IMAGE_SIZE:
|
||||
image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), cv2.INTER_AREA)
|
||||
else:
|
||||
image = pil_resize(image, (IMAGE_SIZE, IMAGE_SIZE))
|
||||
image = resize_image(image, image.shape[0], image.shape[1], IMAGE_SIZE, IMAGE_SIZE)
|
||||
|
||||
image = image.astype(np.float32)
|
||||
return image
|
||||
|
||||
@@ -76,6 +76,7 @@ class BaseSubsetParams:
|
||||
validation_seed: int = 0
|
||||
validation_split: float = 0.0
|
||||
system_prompt: Optional[str] = None
|
||||
resize_interpolation: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -108,7 +109,7 @@ class BaseDatasetParams:
|
||||
validation_seed: Optional[int] = None
|
||||
validation_split: float = 0.0
|
||||
system_prompt: Optional[str] = None
|
||||
|
||||
resize_interpolation: Optional[str] = None
|
||||
|
||||
@dataclass
|
||||
class DreamBoothDatasetParams(BaseDatasetParams):
|
||||
@@ -199,6 +200,7 @@ class ConfigSanitizer:
|
||||
"caption_suffix": str,
|
||||
"custom_attributes": dict,
|
||||
"system_prompt": str,
|
||||
"resize_interpolation": str,
|
||||
}
|
||||
# DO means DropOut
|
||||
DO_SUBSET_ASCENDABLE_SCHEMA = {
|
||||
@@ -245,6 +247,7 @@ class ConfigSanitizer:
|
||||
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
|
||||
"network_multiplier": float,
|
||||
"system_prompt": str,
|
||||
"resize_interpolation": str,
|
||||
}
|
||||
|
||||
# options handled by argparse but not handled by user config
|
||||
@@ -529,6 +532,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
|
||||
[{dataset_type} {i}]
|
||||
batch_size: {dataset.batch_size}
|
||||
resolution: {(dataset.width, dataset.height)}
|
||||
resize_interpolation: {dataset.resize_interpolation}
|
||||
enable_bucket: {dataset.enable_bucket}
|
||||
system_prompt: {dataset.system_prompt}
|
||||
""")
|
||||
@@ -563,6 +567,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
|
||||
token_warmup_min: {subset.token_warmup_min},
|
||||
token_warmup_step: {subset.token_warmup_step},
|
||||
alpha_mask: {subset.alpha_mask}
|
||||
resize_interpolation: {subset.resize_interpolation}
|
||||
custom_attributes: {subset.custom_attributes}
|
||||
system_prompt: {subset.system_prompt}
|
||||
"""), " ")
|
||||
|
||||
186
library/jpeg_xl_util.py
Normal file
186
library/jpeg_xl_util.py
Normal file
@@ -0,0 +1,186 @@
|
||||
# Modified from https://github.com/Fraetor/jxl_decode Original license: MIT
|
||||
# Added partial read support for up to 200x speedup
|
||||
|
||||
import os
|
||||
from typing import List, Tuple
|
||||
|
||||
class JXLBitstream:
|
||||
"""
|
||||
A stream of bits with methods for easy handling.
|
||||
"""
|
||||
|
||||
def __init__(self, file, offset: int = 0, offsets: List[List[int]] = None):
|
||||
self.shift = 0
|
||||
self.bitstream = bytearray()
|
||||
self.file = file
|
||||
self.offset = offset
|
||||
self.offsets = offsets
|
||||
if self.offsets:
|
||||
self.offset = self.offsets[0][1]
|
||||
self.previous_data_len = 0
|
||||
self.index = 0
|
||||
self.file.seek(self.offset)
|
||||
|
||||
def get_bits(self, length: int = 1) -> int:
|
||||
if self.offsets and self.shift + length > self.previous_data_len + self.offsets[self.index][2]:
|
||||
self.partial_to_read_length = length
|
||||
if self.shift < self.previous_data_len + self.offsets[self.index][2]:
|
||||
self.partial_read(0, length)
|
||||
self.bitstream.extend(self.file.read(self.partial_to_read_length))
|
||||
else:
|
||||
self.bitstream.extend(self.file.read(length))
|
||||
bitmask = 2**length - 1
|
||||
bits = (int.from_bytes(self.bitstream, "little") >> self.shift) & bitmask
|
||||
self.shift += length
|
||||
return bits
|
||||
|
||||
def partial_read(self, current_length: int, length: int) -> None:
|
||||
self.previous_data_len += self.offsets[self.index][2]
|
||||
to_read_length = self.previous_data_len - (self.shift + current_length)
|
||||
self.bitstream.extend(self.file.read(to_read_length))
|
||||
current_length += to_read_length
|
||||
self.partial_to_read_length -= to_read_length
|
||||
self.index += 1
|
||||
self.file.seek(self.offsets[self.index][1])
|
||||
if self.shift + length > self.previous_data_len + self.offsets[self.index][2]:
|
||||
self.partial_read(current_length, length)
|
||||
|
||||
|
||||
def decode_codestream(file, offset: int = 0, offsets: List[List[int]] = None) -> Tuple[int,int]:
|
||||
"""
|
||||
Decodes the actual codestream.
|
||||
JXL codestream specification: http://www-internal/2022/18181-1
|
||||
"""
|
||||
|
||||
# Convert codestream to int within an object to get some handy methods.
|
||||
codestream = JXLBitstream(file, offset=offset, offsets=offsets)
|
||||
|
||||
# Skip signature
|
||||
codestream.get_bits(16)
|
||||
|
||||
# SizeHeader
|
||||
div8 = codestream.get_bits(1)
|
||||
if div8:
|
||||
height = 8 * (1 + codestream.get_bits(5))
|
||||
else:
|
||||
distribution = codestream.get_bits(2)
|
||||
match distribution:
|
||||
case 0:
|
||||
height = 1 + codestream.get_bits(9)
|
||||
case 1:
|
||||
height = 1 + codestream.get_bits(13)
|
||||
case 2:
|
||||
height = 1 + codestream.get_bits(18)
|
||||
case 3:
|
||||
height = 1 + codestream.get_bits(30)
|
||||
ratio = codestream.get_bits(3)
|
||||
if div8 and not ratio:
|
||||
width = 8 * (1 + codestream.get_bits(5))
|
||||
elif not ratio:
|
||||
distribution = codestream.get_bits(2)
|
||||
match distribution:
|
||||
case 0:
|
||||
width = 1 + codestream.get_bits(9)
|
||||
case 1:
|
||||
width = 1 + codestream.get_bits(13)
|
||||
case 2:
|
||||
width = 1 + codestream.get_bits(18)
|
||||
case 3:
|
||||
width = 1 + codestream.get_bits(30)
|
||||
else:
|
||||
match ratio:
|
||||
case 1:
|
||||
width = height
|
||||
case 2:
|
||||
width = (height * 12) // 10
|
||||
case 3:
|
||||
width = (height * 4) // 3
|
||||
case 4:
|
||||
width = (height * 3) // 2
|
||||
case 5:
|
||||
width = (height * 16) // 9
|
||||
case 6:
|
||||
width = (height * 5) // 4
|
||||
case 7:
|
||||
width = (height * 2) // 1
|
||||
return width, height
|
||||
|
||||
|
||||
def decode_container(file) -> Tuple[int,int]:
|
||||
"""
|
||||
Parses the ISOBMFF container, extracts the codestream, and decodes it.
|
||||
JXL container specification: http://www-internal/2022/18181-2
|
||||
"""
|
||||
|
||||
def parse_box(file, file_start: int) -> dict:
|
||||
file.seek(file_start)
|
||||
LBox = int.from_bytes(file.read(4), "big")
|
||||
XLBox = None
|
||||
if 1 < LBox <= 8:
|
||||
raise ValueError(f"Invalid LBox at byte {file_start}.")
|
||||
if LBox == 1:
|
||||
file.seek(file_start + 8)
|
||||
XLBox = int.from_bytes(file.read(8), "big")
|
||||
if XLBox <= 16:
|
||||
raise ValueError(f"Invalid XLBox at byte {file_start}.")
|
||||
if XLBox:
|
||||
header_length = 16
|
||||
box_length = XLBox
|
||||
else:
|
||||
header_length = 8
|
||||
if LBox == 0:
|
||||
box_length = os.fstat(file.fileno()).st_size - file_start
|
||||
else:
|
||||
box_length = LBox
|
||||
file.seek(file_start + 4)
|
||||
box_type = file.read(4)
|
||||
file.seek(file_start)
|
||||
return {
|
||||
"length": box_length,
|
||||
"type": box_type,
|
||||
"offset": header_length,
|
||||
}
|
||||
|
||||
file.seek(0)
|
||||
# Reject files missing required boxes. These two boxes are required to be at
|
||||
# the start and contain no values, so we can manually check there presence.
|
||||
# Signature box. (Redundant as has already been checked.)
|
||||
if file.read(12) != bytes.fromhex("0000000C 4A584C20 0D0A870A"):
|
||||
raise ValueError("Invalid signature box.")
|
||||
# File Type box.
|
||||
if file.read(20) != bytes.fromhex(
|
||||
"00000014 66747970 6A786C20 00000000 6A786C20"
|
||||
):
|
||||
raise ValueError("Invalid file type box.")
|
||||
|
||||
offset = 0
|
||||
offsets = []
|
||||
data_offset_not_found = True
|
||||
container_pointer = 32
|
||||
file_size = os.fstat(file.fileno()).st_size
|
||||
while data_offset_not_found:
|
||||
box = parse_box(file, container_pointer)
|
||||
match box["type"]:
|
||||
case b"jxlc":
|
||||
offset = container_pointer + box["offset"]
|
||||
data_offset_not_found = False
|
||||
case b"jxlp":
|
||||
file.seek(container_pointer + box["offset"])
|
||||
index = int.from_bytes(file.read(4), "big")
|
||||
offsets.append([index, container_pointer + box["offset"] + 4, box["length"] - box["offset"] - 4])
|
||||
container_pointer += box["length"]
|
||||
if container_pointer >= file_size:
|
||||
data_offset_not_found = False
|
||||
|
||||
if offsets:
|
||||
offsets.sort(key=lambda i: i[0])
|
||||
file.seek(0)
|
||||
|
||||
return decode_codestream(file, offset=offset, offsets=offsets)
|
||||
|
||||
|
||||
def get_jxl_size(path: str) -> Tuple[int,int]:
|
||||
with open(path, "rb") as file:
|
||||
if file.read(2) == bytes.fromhex("FF0A"):
|
||||
return decode_codestream(file)
|
||||
return decode_container(file)
|
||||
@@ -74,7 +74,7 @@ import library.model_util as model_util
|
||||
import library.huggingface_util as huggingface_util
|
||||
import library.sai_model_spec as sai_model_spec
|
||||
import library.deepspeed_utils as deepspeed_utils
|
||||
from library.utils import setup_logging, pil_resize
|
||||
from library.utils import setup_logging, resize_image, validate_interpolation_fn
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
@@ -113,14 +113,16 @@ except:
|
||||
# JPEG-XL on Linux
|
||||
try:
|
||||
from jxlpy import JXLImagePlugin
|
||||
from library.jpeg_xl_util import get_jxl_size
|
||||
|
||||
IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
|
||||
except:
|
||||
pass
|
||||
|
||||
# JPEG-XL on Windows
|
||||
# JPEG-XL on Linux and Windows
|
||||
try:
|
||||
import pillow_jxl
|
||||
from library.jpeg_xl_util import get_jxl_size
|
||||
|
||||
IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
|
||||
except:
|
||||
@@ -205,6 +207,7 @@ class ImageInfo:
|
||||
self.text_encoder_pool2: Optional[torch.Tensor] = None
|
||||
|
||||
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
|
||||
self.resize_interpolation: Optional[str] = None
|
||||
|
||||
self.system_prompt: Optional[str] = None
|
||||
|
||||
@@ -432,6 +435,7 @@ class BaseSubset:
|
||||
validation_seed: Optional[int] = None,
|
||||
validation_split: Optional[float] = 0.0,
|
||||
system_prompt: Optional[str] = None
|
||||
resize_interpolation: Optional[str] = None,
|
||||
) -> None:
|
||||
self.image_dir = image_dir
|
||||
self.alpha_mask = alpha_mask if alpha_mask is not None else False
|
||||
@@ -463,6 +467,7 @@ class BaseSubset:
|
||||
self.validation_split = validation_split
|
||||
|
||||
self.system_prompt = system_prompt
|
||||
self.resize_interpolation = resize_interpolation
|
||||
|
||||
|
||||
class DreamBoothSubset(BaseSubset):
|
||||
@@ -496,6 +501,7 @@ class DreamBoothSubset(BaseSubset):
|
||||
validation_seed: Optional[int] = None,
|
||||
validation_split: Optional[float] = 0.0,
|
||||
system_prompt: Optional[str] = None
|
||||
resize_interpolation: Optional[str] = None,
|
||||
) -> None:
|
||||
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
|
||||
|
||||
@@ -524,6 +530,7 @@ class DreamBoothSubset(BaseSubset):
|
||||
validation_seed=validation_seed,
|
||||
validation_split=validation_split,
|
||||
system_prompt=system_prompt
|
||||
resize_interpolation=resize_interpolation,
|
||||
)
|
||||
|
||||
self.is_reg = is_reg
|
||||
@@ -567,6 +574,7 @@ class FineTuningSubset(BaseSubset):
|
||||
validation_seed: Optional[int] = None,
|
||||
validation_split: Optional[float] = 0.0,
|
||||
system_prompt: Optional[str] = None
|
||||
resize_interpolation: Optional[str] = None,
|
||||
) -> None:
|
||||
assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
|
||||
|
||||
@@ -595,6 +603,7 @@ class FineTuningSubset(BaseSubset):
|
||||
validation_seed=validation_seed,
|
||||
validation_split=validation_split,
|
||||
system_prompt=system_prompt
|
||||
resize_interpolation=resize_interpolation,
|
||||
)
|
||||
|
||||
self.metadata_file = metadata_file
|
||||
@@ -634,6 +643,7 @@ class ControlNetSubset(BaseSubset):
|
||||
validation_seed: Optional[int] = None,
|
||||
validation_split: Optional[float] = 0.0,
|
||||
system_prompt: Optional[str] = None
|
||||
resize_interpolation: Optional[str] = None,
|
||||
) -> None:
|
||||
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
|
||||
|
||||
@@ -662,6 +672,7 @@ class ControlNetSubset(BaseSubset):
|
||||
validation_seed=validation_seed,
|
||||
validation_split=validation_split,
|
||||
system_prompt=system_prompt
|
||||
resize_interpolation=resize_interpolation,
|
||||
)
|
||||
|
||||
self.conditioning_data_dir = conditioning_data_dir
|
||||
@@ -682,6 +693,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
resolution: Optional[Tuple[int, int]],
|
||||
network_multiplier: float,
|
||||
debug_dataset: bool,
|
||||
resize_interpolation: Optional[str] = None
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -716,6 +728,10 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
self.image_transforms = IMAGE_TRANSFORMS
|
||||
|
||||
if resize_interpolation is not None:
|
||||
assert validate_interpolation_fn(resize_interpolation), f"Resize interpolation \"{resize_interpolation}\" is not a valid interpolation"
|
||||
self.resize_interpolation = resize_interpolation
|
||||
|
||||
self.image_data: Dict[str, ImageInfo] = {}
|
||||
self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {}
|
||||
|
||||
@@ -1459,6 +1475,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
)
|
||||
|
||||
def get_image_size(self, image_path):
|
||||
if image_path.endswith(".jxl") or image_path.endswith(".JXL"):
|
||||
return get_jxl_size(image_path)
|
||||
# return imagesize.get(image_path)
|
||||
image_size = imagesize.get(image_path)
|
||||
if image_size[0] <= 0:
|
||||
@@ -1505,7 +1523,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
nh = int(height * scale + 0.5)
|
||||
nw = int(width * scale + 0.5)
|
||||
assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}"
|
||||
image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA)
|
||||
image = resize_image(image, width, height, nw, nh, subset.resize_interpolation)
|
||||
face_cx = int(face_cx * scale + 0.5)
|
||||
face_cy = int(face_cy * scale + 0.5)
|
||||
height, width = nh, nw
|
||||
@@ -1602,7 +1620,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
if self.enable_bucket:
|
||||
img, original_size, crop_ltrb = trim_and_resize_if_required(
|
||||
subset.random_crop, img, image_info.bucket_reso, image_info.resized_size
|
||||
subset.random_crop, img, image_info.bucket_reso, image_info.resized_size, resize_interpolation=image_info.resize_interpolation
|
||||
)
|
||||
else:
|
||||
if face_cx > 0: # 顔位置情報あり
|
||||
@@ -1866,8 +1884,9 @@ class DreamBoothDataset(BaseDataset):
|
||||
validation_split: float,
|
||||
validation_seed: Optional[int],
|
||||
system_prompt: Optional[str],
|
||||
resize_interpolation: Optional[str],
|
||||
) -> None:
|
||||
super().__init__(resolution, network_multiplier, debug_dataset)
|
||||
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
|
||||
|
||||
assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
|
||||
|
||||
@@ -2099,6 +2118,7 @@ class DreamBoothDataset(BaseDataset):
|
||||
system_prompt = f"{self.system_prompt or subset.system_prompt} {system_prompt_special_token} " if self.system_prompt or subset.system_prompt else ""
|
||||
for img_path, caption, size in zip(img_paths, captions, sizes):
|
||||
info = ImageInfo(img_path, num_repeats, system_prompt + caption, subset.is_reg, img_path)
|
||||
info.resize_interpolation = subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
|
||||
if size is not None:
|
||||
info.image_size = size
|
||||
if subset.is_reg:
|
||||
@@ -2154,8 +2174,9 @@ class FineTuningDataset(BaseDataset):
|
||||
debug_dataset: bool,
|
||||
validation_seed: int,
|
||||
validation_split: float,
|
||||
resize_interpolation: Optional[str],
|
||||
) -> None:
|
||||
super().__init__(resolution, network_multiplier, debug_dataset)
|
||||
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
|
||||
|
||||
self.batch_size = batch_size
|
||||
|
||||
@@ -2381,9 +2402,10 @@ class ControlNetDataset(BaseDataset):
|
||||
bucket_no_upscale: bool,
|
||||
debug_dataset: bool,
|
||||
validation_split: float,
|
||||
validation_seed: Optional[int],
|
||||
validation_seed: Optional[int],
|
||||
resize_interpolation: Optional[str] = None,
|
||||
) -> None:
|
||||
super().__init__(resolution, network_multiplier, debug_dataset)
|
||||
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
|
||||
|
||||
db_subsets = []
|
||||
for subset in subsets:
|
||||
@@ -2415,6 +2437,7 @@ class ControlNetDataset(BaseDataset):
|
||||
subset.caption_suffix,
|
||||
subset.token_warmup_min,
|
||||
subset.token_warmup_step,
|
||||
resize_interpolation=subset.resize_interpolation,
|
||||
)
|
||||
db_subsets.append(db_subset)
|
||||
|
||||
@@ -2433,6 +2456,7 @@ class ControlNetDataset(BaseDataset):
|
||||
debug_dataset,
|
||||
validation_split,
|
||||
validation_seed,
|
||||
resize_interpolation,
|
||||
)
|
||||
|
||||
# config_util等から参照される値をいれておく(若干微妙なのでなんとかしたい)
|
||||
@@ -2441,7 +2465,8 @@ class ControlNetDataset(BaseDataset):
|
||||
self.num_train_images = self.dreambooth_dataset_delegate.num_train_images
|
||||
self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
|
||||
self.validation_split = validation_split
|
||||
self.validation_seed = validation_seed
|
||||
self.validation_seed = validation_seed
|
||||
self.resize_interpolation = resize_interpolation
|
||||
|
||||
# assert all conditioning data exists
|
||||
missing_imgs = []
|
||||
@@ -2529,9 +2554,8 @@ class ControlNetDataset(BaseDataset):
|
||||
assert (
|
||||
cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1]
|
||||
), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}"
|
||||
cond_img = cv2.resize(
|
||||
cond_img, image_info.resized_size, interpolation=cv2.INTER_AREA
|
||||
) # INTER_AREAでやりたいのでcv2でリサイズ
|
||||
|
||||
cond_img = resize_image(cond_img, original_size_hw[1], original_size_hw[0], target_size_hw[1], target_size_hw[0], self.resize_interpolation)
|
||||
|
||||
# TODO support random crop
|
||||
# 現在サポートしているcropはrandomではなく中央のみ
|
||||
@@ -2545,7 +2569,7 @@ class ControlNetDataset(BaseDataset):
|
||||
# ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
|
||||
# resize to target
|
||||
if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]:
|
||||
cond_img = pil_resize(cond_img, (int(target_size_hw[1]), int(target_size_hw[0])))
|
||||
cond_img = resize_image(cond_img, cond_img.shape[0], cond_img.shape[1], target_size_hw[1], target_size_hw[0], self.resize_interpolation)
|
||||
|
||||
if flipped:
|
||||
cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride
|
||||
@@ -2942,17 +2966,13 @@ def load_image(image_path, alpha=False):
|
||||
|
||||
# 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top, crop right, crop bottom)
|
||||
def trim_and_resize_if_required(
|
||||
random_crop: bool, image: np.ndarray, reso, resized_size: Tuple[int, int]
|
||||
random_crop: bool, image: np.ndarray, reso, resized_size: Tuple[int, int], resize_interpolation: Optional[str] = None
|
||||
) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int, int, int]]:
|
||||
image_height, image_width = image.shape[0:2]
|
||||
original_size = (image_width, image_height) # size before resize
|
||||
|
||||
if image_width != resized_size[0] or image_height != resized_size[1]:
|
||||
# リサイズする
|
||||
if image_width > resized_size[0] and image_height > resized_size[1]:
|
||||
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
|
||||
else:
|
||||
image = pil_resize(image, resized_size)
|
||||
image = resize_image(image, image_width, image_height, resized_size[0], resized_size[1], resize_interpolation)
|
||||
|
||||
image_height, image_width = image.shape[0:2]
|
||||
|
||||
@@ -2997,7 +3017,7 @@ def load_images_and_masks_for_caching(
|
||||
for info in image_infos:
|
||||
image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8)
|
||||
# TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要
|
||||
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
|
||||
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation)
|
||||
|
||||
original_sizes.append(original_size)
|
||||
crop_ltrbs.append(crop_ltrb)
|
||||
@@ -3038,7 +3058,7 @@ def cache_batch_latents(
|
||||
for info in image_infos:
|
||||
image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8)
|
||||
# TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要
|
||||
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
|
||||
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation)
|
||||
|
||||
info.latents_original_size = original_size
|
||||
info.latents_crop_ltrb = crop_ltrb
|
||||
@@ -4518,7 +4538,13 @@ def add_dataset_arguments(
|
||||
action="store_true",
|
||||
help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--resize_interpolation",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area"],
|
||||
help="Resize interpolation when required. Default: area Options: lanczos, nearest, bilinear, bicubic, area / 必要に応じてサイズ補間を変更します。デフォルト: area オプション: lanczos, nearest, bilinear, bicubic, area",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--token_warmup_min",
|
||||
type=int,
|
||||
@@ -6566,3 +6592,4 @@ class LossRecorder:
|
||||
if losses == 0:
|
||||
return 0
|
||||
return self.loss_total / losses
|
||||
|
||||
|
||||
131
library/utils.py
131
library/utils.py
@@ -16,7 +16,6 @@ from PIL import Image
|
||||
import numpy as np
|
||||
from safetensors.torch import load_file
|
||||
|
||||
|
||||
def fire_in_thread(f, *args, **kwargs):
|
||||
threading.Thread(target=f, args=args, kwargs=kwargs).start()
|
||||
|
||||
@@ -89,6 +88,8 @@ def setup_logging(args=None, log_level=None, reset=False):
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info(msg_init)
|
||||
|
||||
setup_logging()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# endregion
|
||||
|
||||
@@ -261,11 +262,10 @@ def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata:
|
||||
|
||||
|
||||
class MemoryEfficientSafeOpen:
|
||||
# does not support metadata loading
|
||||
def __init__(self, filename):
|
||||
self.filename = filename
|
||||
self.header, self.header_size = self._read_header()
|
||||
self.file = open(filename, "rb")
|
||||
self.header, self.header_size = self._read_header()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
@@ -276,6 +276,9 @@ class MemoryEfficientSafeOpen:
|
||||
def keys(self):
|
||||
return [k for k in self.header.keys() if k != "__metadata__"]
|
||||
|
||||
def metadata(self) -> Dict[str, str]:
|
||||
return self.header.get("__metadata__", {})
|
||||
|
||||
def get_tensor(self, key):
|
||||
if key not in self.header:
|
||||
raise KeyError(f"Tensor '{key}' not found in the file")
|
||||
@@ -293,10 +296,9 @@ class MemoryEfficientSafeOpen:
|
||||
return self._deserialize_tensor(tensor_bytes, metadata)
|
||||
|
||||
def _read_header(self):
|
||||
with open(self.filename, "rb") as f:
|
||||
header_size = struct.unpack("<Q", f.read(8))[0]
|
||||
header_json = f.read(header_size).decode("utf-8")
|
||||
return json.loads(header_json), header_size
|
||||
header_size = struct.unpack("<Q", self.file.read(8))[0]
|
||||
header_json = self.file.read(header_size).decode("utf-8")
|
||||
return json.loads(header_json), header_size
|
||||
|
||||
def _deserialize_tensor(self, tensor_bytes, metadata):
|
||||
dtype = self._get_torch_dtype(metadata["dtype"])
|
||||
@@ -377,7 +379,7 @@ def load_safetensors(
|
||||
# region Image utils
|
||||
|
||||
|
||||
def pil_resize(image, size, interpolation=Image.LANCZOS):
|
||||
def pil_resize(image, size, interpolation):
|
||||
has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False
|
||||
|
||||
if has_alpha:
|
||||
@@ -385,7 +387,7 @@ def pil_resize(image, size, interpolation=Image.LANCZOS):
|
||||
else:
|
||||
pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
|
||||
|
||||
resized_pil = pil_image.resize(size, interpolation)
|
||||
resized_pil = pil_image.resize(size, resample=interpolation)
|
||||
|
||||
# Convert back to cv2 format
|
||||
if has_alpha:
|
||||
@@ -396,6 +398,117 @@ def pil_resize(image, size, interpolation=Image.LANCZOS):
|
||||
return resized_cv2
|
||||
|
||||
|
||||
def resize_image(image: np.ndarray, width: int, height: int, resized_width: int, resized_height: int, resize_interpolation: Optional[str] = None):
|
||||
"""
|
||||
Resize image with resize interpolation. Default interpolation to AREA if image is smaller, else LANCZOS.
|
||||
|
||||
Args:
|
||||
image: numpy.ndarray
|
||||
width: int Original image width
|
||||
height: int Original image height
|
||||
resized_width: int Resized image width
|
||||
resized_height: int Resized image height
|
||||
resize_interpolation: Optional[str] Resize interpolation method "lanczos", "area", "bilinear", "bicubic", "nearest", "box"
|
||||
|
||||
Returns:
|
||||
image
|
||||
"""
|
||||
|
||||
# Ensure all size parameters are actual integers
|
||||
width = int(width)
|
||||
height = int(height)
|
||||
resized_width = int(resized_width)
|
||||
resized_height = int(resized_height)
|
||||
|
||||
if resize_interpolation is None:
|
||||
if width >= resized_width and height >= resized_height:
|
||||
resize_interpolation = "area"
|
||||
else:
|
||||
resize_interpolation = "lanczos"
|
||||
|
||||
# we use PIL for lanczos (for backward compatibility) and box, cv2 for others
|
||||
use_pil = resize_interpolation in ["lanczos", "lanczos4", "box"]
|
||||
|
||||
resized_size = (resized_width, resized_height)
|
||||
if use_pil:
|
||||
interpolation = get_pil_interpolation(resize_interpolation)
|
||||
image = pil_resize(image, resized_size, interpolation=interpolation)
|
||||
logger.debug(f"resize image using {resize_interpolation} (PIL)")
|
||||
else:
|
||||
interpolation = get_cv2_interpolation(resize_interpolation)
|
||||
image = cv2.resize(image, resized_size, interpolation=interpolation)
|
||||
logger.debug(f"resize image using {resize_interpolation} (cv2)")
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def get_cv2_interpolation(interpolation: Optional[str]) -> Optional[int]:
|
||||
"""
|
||||
Convert interpolation value to cv2 interpolation integer
|
||||
|
||||
https://docs.opencv.org/3.4/da/d54/group__imgproc__transform.html#ga5bb5a1fea74ea38e1a5445ca803ff121
|
||||
"""
|
||||
if interpolation is None:
|
||||
return None
|
||||
|
||||
if interpolation == "lanczos" or interpolation == "lanczos4":
|
||||
# Lanczos interpolation over 8x8 neighborhood
|
||||
return cv2.INTER_LANCZOS4
|
||||
elif interpolation == "nearest":
|
||||
# Bit exact nearest neighbor interpolation. This will produce same results as the nearest neighbor method in PIL, scikit-image or Matlab.
|
||||
return cv2.INTER_NEAREST_EXACT
|
||||
elif interpolation == "bilinear" or interpolation == "linear":
|
||||
# bilinear interpolation
|
||||
return cv2.INTER_LINEAR
|
||||
elif interpolation == "bicubic" or interpolation == "cubic":
|
||||
# bicubic interpolation
|
||||
return cv2.INTER_CUBIC
|
||||
elif interpolation == "area":
|
||||
# resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method.
|
||||
return cv2.INTER_AREA
|
||||
elif interpolation == "box":
|
||||
# resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method.
|
||||
return cv2.INTER_AREA
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resampling]:
|
||||
"""
|
||||
Convert interpolation value to PIL interpolation
|
||||
|
||||
https://pillow.readthedocs.io/en/stable/handbook/concepts.html#concept-filters
|
||||
"""
|
||||
if interpolation is None:
|
||||
return None
|
||||
|
||||
if interpolation == "lanczos":
|
||||
return Image.Resampling.LANCZOS
|
||||
elif interpolation == "nearest":
|
||||
# Pick one nearest pixel from the input image. Ignore all other input pixels.
|
||||
return Image.Resampling.NEAREST
|
||||
elif interpolation == "bilinear" or interpolation == "linear":
|
||||
# For resize calculate the output pixel value using linear interpolation on all pixels that may contribute to the output value. For other transformations linear interpolation over a 2x2 environment in the input image is used.
|
||||
return Image.Resampling.BILINEAR
|
||||
elif interpolation == "bicubic" or interpolation == "cubic":
|
||||
# For resize calculate the output pixel value using cubic interpolation on all pixels that may contribute to the output value. For other transformations cubic interpolation over a 4x4 environment in the input image is used.
|
||||
return Image.Resampling.BICUBIC
|
||||
elif interpolation == "area":
|
||||
# Image.Resampling.BOX may be more appropriate if upscaling
|
||||
# Area interpolation is related to cv2.INTER_AREA
|
||||
# Produces a sharper image than Resampling.BILINEAR, doesn’t have dislocations on local level like with Resampling.BOX.
|
||||
return Image.Resampling.HAMMING
|
||||
elif interpolation == "box":
|
||||
# Each pixel of source image contributes to one pixel of the destination image with identical weights. For upscaling is equivalent of Resampling.NEAREST.
|
||||
return Image.Resampling.BOX
|
||||
else:
|
||||
return None
|
||||
|
||||
def validate_interpolation_fn(interpolation_str: str) -> bool:
|
||||
"""
|
||||
Check if a interpolation function is supported
|
||||
"""
|
||||
return interpolation_str in ["lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area", "box"]
|
||||
|
||||
# endregion
|
||||
|
||||
# TODO make inf_utils.py
|
||||
|
||||
@@ -268,7 +268,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
|
||||
class DyLoRANetwork(torch.nn.Module):
|
||||
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
|
||||
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
||||
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
||||
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
|
||||
LORA_PREFIX_UNET = "lora_unet"
|
||||
LORA_PREFIX_TEXT_ENCODER = "lora_te"
|
||||
|
||||
|
||||
@@ -866,7 +866,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
|
||||
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
|
||||
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
||||
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
||||
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
|
||||
LORA_PREFIX_UNET = "lora_unet"
|
||||
LORA_PREFIX_TEXT_ENCODER = "lora_te"
|
||||
|
||||
|
||||
@@ -278,7 +278,7 @@ def merge_lora_weights(pipe, weights_sd: Dict, multiplier: float = 1.0):
|
||||
class LoRANetwork(torch.nn.Module):
|
||||
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
|
||||
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
||||
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
||||
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
|
||||
LORA_PREFIX_UNET = "lora_unet"
|
||||
LORA_PREFIX_TEXT_ENCODER = "lora_te"
|
||||
|
||||
|
||||
@@ -755,7 +755,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
|
||||
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
|
||||
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
||||
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
||||
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
|
||||
LORA_PREFIX_UNET = "lora_unet"
|
||||
LORA_PREFIX_TEXT_ENCODER = "lora_te"
|
||||
|
||||
|
||||
@@ -9,11 +9,13 @@
|
||||
|
||||
import math
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||
from diffusers import AutoencoderKL
|
||||
from transformers import CLIPTextModel
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import re
|
||||
from library.utils import setup_logging
|
||||
from library.sdxl_original_unet import SdxlUNet2DConditionModel
|
||||
@@ -44,6 +46,8 @@ class LoRAModule(torch.nn.Module):
|
||||
rank_dropout=None,
|
||||
module_dropout=None,
|
||||
split_dims: Optional[List[int]] = None,
|
||||
ggpo_beta: Optional[float] = None,
|
||||
ggpo_sigma: Optional[float] = None,
|
||||
):
|
||||
"""
|
||||
if alpha == 0 or None, alpha is rank (no scaling).
|
||||
@@ -103,9 +107,20 @@ class LoRAModule(torch.nn.Module):
|
||||
self.rank_dropout = rank_dropout
|
||||
self.module_dropout = module_dropout
|
||||
|
||||
self.ggpo_sigma = ggpo_sigma
|
||||
self.ggpo_beta = ggpo_beta
|
||||
|
||||
if self.ggpo_beta is not None and self.ggpo_sigma is not None:
|
||||
self.combined_weight_norms = None
|
||||
self.grad_norms = None
|
||||
self.perturbation_norm_factor = 1.0 / math.sqrt(org_module.weight.shape[0])
|
||||
self.initialize_norm_cache(org_module.weight)
|
||||
self.org_module_shape: tuple[int] = org_module.weight.shape
|
||||
|
||||
def apply_to(self):
|
||||
self.org_forward = self.org_module.forward
|
||||
self.org_module.forward = self.forward
|
||||
|
||||
del self.org_module
|
||||
|
||||
def forward(self, x):
|
||||
@@ -140,7 +155,17 @@ class LoRAModule(torch.nn.Module):
|
||||
|
||||
lx = self.lora_up(lx)
|
||||
|
||||
return org_forwarded + lx * self.multiplier * scale
|
||||
# LoRA Gradient-Guided Perturbation Optimization
|
||||
if self.training and self.ggpo_sigma is not None and self.ggpo_beta is not None and self.combined_weight_norms is not None and self.grad_norms is not None:
|
||||
with torch.no_grad():
|
||||
perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms ** 2)) + (self.ggpo_beta * (self.grad_norms ** 2))
|
||||
perturbation_scale_factor = (perturbation_scale * self.perturbation_norm_factor).to(self.device)
|
||||
perturbation = torch.randn(self.org_module_shape, dtype=self.dtype, device=self.device)
|
||||
perturbation.mul_(perturbation_scale_factor)
|
||||
perturbation_output = x @ perturbation.T # Result: (batch × n)
|
||||
return org_forwarded + (self.multiplier * scale * lx) + perturbation_output
|
||||
else:
|
||||
return org_forwarded + lx * self.multiplier * scale
|
||||
else:
|
||||
lxs = [lora_down(x) for lora_down in self.lora_down]
|
||||
|
||||
@@ -167,6 +192,116 @@ class LoRAModule(torch.nn.Module):
|
||||
|
||||
return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale
|
||||
|
||||
@torch.no_grad()
|
||||
def initialize_norm_cache(self, org_module_weight: Tensor):
|
||||
# Choose a reasonable sample size
|
||||
n_rows = org_module_weight.shape[0]
|
||||
sample_size = min(1000, n_rows) # Cap at 1000 samples or use all if smaller
|
||||
|
||||
# Sample random indices across all rows
|
||||
indices = torch.randperm(n_rows)[:sample_size]
|
||||
|
||||
# Convert to a supported data type first, then index
|
||||
# Use float32 for indexing operations
|
||||
weights_float32 = org_module_weight.to(dtype=torch.float32)
|
||||
sampled_weights = weights_float32[indices].to(device=self.device)
|
||||
|
||||
# Calculate sampled norms
|
||||
sampled_norms = torch.norm(sampled_weights, dim=1, keepdim=True)
|
||||
|
||||
# Store the mean norm as our estimate
|
||||
self.org_weight_norm_estimate = sampled_norms.mean()
|
||||
|
||||
# Optional: store standard deviation for confidence intervals
|
||||
self.org_weight_norm_std = sampled_norms.std()
|
||||
|
||||
# Free memory
|
||||
del sampled_weights, weights_float32
|
||||
|
||||
@torch.no_grad()
|
||||
def validate_norm_approximation(self, org_module_weight: Tensor, verbose=True):
|
||||
# Calculate the true norm (this will be slow but it's just for validation)
|
||||
true_norms = []
|
||||
chunk_size = 1024 # Process in chunks to avoid OOM
|
||||
|
||||
for i in range(0, org_module_weight.shape[0], chunk_size):
|
||||
end_idx = min(i + chunk_size, org_module_weight.shape[0])
|
||||
chunk = org_module_weight[i:end_idx].to(device=self.device, dtype=self.dtype)
|
||||
chunk_norms = torch.norm(chunk, dim=1, keepdim=True)
|
||||
true_norms.append(chunk_norms.cpu())
|
||||
del chunk
|
||||
|
||||
true_norms = torch.cat(true_norms, dim=0)
|
||||
true_mean_norm = true_norms.mean().item()
|
||||
|
||||
# Compare with our estimate
|
||||
estimated_norm = self.org_weight_norm_estimate.item()
|
||||
|
||||
# Calculate error metrics
|
||||
absolute_error = abs(true_mean_norm - estimated_norm)
|
||||
relative_error = absolute_error / true_mean_norm * 100 # as percentage
|
||||
|
||||
if verbose:
|
||||
logger.info(f"True mean norm: {true_mean_norm:.6f}")
|
||||
logger.info(f"Estimated norm: {estimated_norm:.6f}")
|
||||
logger.info(f"Absolute error: {absolute_error:.6f}")
|
||||
logger.info(f"Relative error: {relative_error:.2f}%")
|
||||
|
||||
return {
|
||||
'true_mean_norm': true_mean_norm,
|
||||
'estimated_norm': estimated_norm,
|
||||
'absolute_error': absolute_error,
|
||||
'relative_error': relative_error
|
||||
}
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def update_norms(self):
|
||||
# Not running GGPO so not currently running update norms
|
||||
if self.ggpo_beta is None or self.ggpo_sigma is None:
|
||||
return
|
||||
|
||||
# only update norms when we are training
|
||||
if self.training is False:
|
||||
return
|
||||
|
||||
module_weights = self.lora_up.weight @ self.lora_down.weight
|
||||
module_weights.mul(self.scale)
|
||||
|
||||
self.weight_norms = torch.norm(module_weights, dim=1, keepdim=True)
|
||||
self.combined_weight_norms = torch.sqrt((self.org_weight_norm_estimate**2) +
|
||||
torch.sum(module_weights**2, dim=1, keepdim=True))
|
||||
|
||||
@torch.no_grad()
|
||||
def update_grad_norms(self):
|
||||
if self.training is False:
|
||||
print(f"skipping update_grad_norms for {self.lora_name}")
|
||||
return
|
||||
|
||||
lora_down_grad = None
|
||||
lora_up_grad = None
|
||||
|
||||
for name, param in self.named_parameters():
|
||||
if name == "lora_down.weight":
|
||||
lora_down_grad = param.grad
|
||||
elif name == "lora_up.weight":
|
||||
lora_up_grad = param.grad
|
||||
|
||||
# Calculate gradient norms if we have both gradients
|
||||
if lora_down_grad is not None and lora_up_grad is not None:
|
||||
with torch.autocast(self.device.type):
|
||||
approx_grad = self.scale * ((self.lora_up.weight @ lora_down_grad) + (lora_up_grad @ self.lora_down.weight))
|
||||
self.grad_norms = torch.norm(approx_grad, dim=1, keepdim=True)
|
||||
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(self.parameters()).dtype
|
||||
|
||||
|
||||
class LoRAInfModule(LoRAModule):
|
||||
def __init__(
|
||||
@@ -420,6 +555,16 @@ def create_network(
|
||||
if split_qkv is not None:
|
||||
split_qkv = True if split_qkv == "True" else False
|
||||
|
||||
ggpo_beta = kwargs.get("ggpo_beta", None)
|
||||
ggpo_sigma = kwargs.get("ggpo_sigma", None)
|
||||
|
||||
if ggpo_beta is not None:
|
||||
ggpo_beta = float(ggpo_beta)
|
||||
|
||||
if ggpo_sigma is not None:
|
||||
ggpo_sigma = float(ggpo_sigma)
|
||||
|
||||
|
||||
# train T5XXL
|
||||
train_t5xxl = kwargs.get("train_t5xxl", False)
|
||||
if train_t5xxl is not None:
|
||||
@@ -449,6 +594,8 @@ def create_network(
|
||||
in_dims=in_dims,
|
||||
train_double_block_indices=train_double_block_indices,
|
||||
train_single_block_indices=train_single_block_indices,
|
||||
ggpo_beta=ggpo_beta,
|
||||
ggpo_sigma=ggpo_sigma,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
@@ -561,6 +708,8 @@ class LoRANetwork(torch.nn.Module):
|
||||
in_dims: Optional[List[int]] = None,
|
||||
train_double_block_indices: Optional[List[bool]] = None,
|
||||
train_single_block_indices: Optional[List[bool]] = None,
|
||||
ggpo_beta: Optional[float] = None,
|
||||
ggpo_sigma: Optional[float] = None,
|
||||
verbose: Optional[bool] = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -599,10 +748,16 @@ class LoRANetwork(torch.nn.Module):
|
||||
# logger.info(
|
||||
# f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
|
||||
# )
|
||||
|
||||
if ggpo_beta is not None and ggpo_sigma is not None:
|
||||
logger.info(f"LoRA-GGPO training sigma: {ggpo_sigma} beta: {ggpo_beta}")
|
||||
|
||||
if self.split_qkv:
|
||||
logger.info(f"split qkv for LoRA")
|
||||
if self.train_blocks is not None:
|
||||
logger.info(f"train {self.train_blocks} blocks only")
|
||||
|
||||
|
||||
if train_t5xxl:
|
||||
logger.info(f"train T5XXL as well")
|
||||
|
||||
@@ -722,6 +877,8 @@ class LoRANetwork(torch.nn.Module):
|
||||
rank_dropout=rank_dropout,
|
||||
module_dropout=module_dropout,
|
||||
split_dims=split_dims,
|
||||
ggpo_beta=ggpo_beta,
|
||||
ggpo_sigma=ggpo_sigma,
|
||||
)
|
||||
loras.append(lora)
|
||||
|
||||
@@ -790,6 +947,36 @@ class LoRANetwork(torch.nn.Module):
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
lora.enabled = is_enabled
|
||||
|
||||
def update_norms(self):
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
lora.update_norms()
|
||||
|
||||
def update_grad_norms(self):
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
lora.update_grad_norms()
|
||||
|
||||
def grad_norms(self) -> Tensor:
|
||||
grad_norms = []
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
if hasattr(lora, "grad_norms") and lora.grad_norms is not None:
|
||||
grad_norms.append(lora.grad_norms.mean(dim=0))
|
||||
return torch.stack(grad_norms) if len(grad_norms) > 0 else torch.tensor([])
|
||||
|
||||
def weight_norms(self) -> Tensor:
|
||||
weight_norms = []
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
if hasattr(lora, "weight_norms") and lora.weight_norms is not None:
|
||||
weight_norms.append(lora.weight_norms.mean(dim=0))
|
||||
return torch.stack(weight_norms) if len(weight_norms) > 0 else torch.tensor([])
|
||||
|
||||
def combined_weight_norms(self) -> Tensor:
|
||||
combined_weight_norms = []
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
if hasattr(lora, "combined_weight_norms") and lora.combined_weight_norms is not None:
|
||||
combined_weight_norms.append(lora.combined_weight_norms.mean(dim=0))
|
||||
return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else torch.tensor([])
|
||||
|
||||
|
||||
def load_weights(self, file):
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
@@ -7,9 +7,11 @@ opencv-python==4.8.1.78
|
||||
einops==0.7.0
|
||||
pytorch-lightning==1.9.0
|
||||
bitsandbytes==0.44.0
|
||||
prodigyopt==1.0
|
||||
lion-pytorch==0.0.6
|
||||
schedulefree==1.4
|
||||
pytorch-optimizer==3.5.0
|
||||
prodigy-plus-schedule-free==1.9.0
|
||||
prodigyopt==1.1.2
|
||||
tensorboard
|
||||
safetensors==0.4.4
|
||||
# gradio==3.16.2
|
||||
|
||||
@@ -24,7 +24,6 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
self.is_sdxl = True
|
||||
|
||||
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
|
||||
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
|
||||
sdxl_train_util.verify_sdxl_training_args(args)
|
||||
|
||||
if args.cache_text_encoder_outputs:
|
||||
|
||||
@@ -15,7 +15,7 @@ import os
|
||||
from anime_face_detector import create_detector
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
from library.utils import setup_logging, pil_resize
|
||||
from library.utils import setup_logging, resize_image
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -170,12 +170,9 @@ def process(args):
|
||||
scale = max(cur_crop_width / w, cur_crop_height / h)
|
||||
|
||||
if scale != 1.0:
|
||||
w = int(w * scale + .5)
|
||||
h = int(h * scale + .5)
|
||||
if scale < 1.0:
|
||||
face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA)
|
||||
else:
|
||||
face_img = pil_resize(face_img, (w, h))
|
||||
rw = int(w * scale + .5)
|
||||
rh = int(h * scale + .5)
|
||||
face_img = resize_image(face_img, w, h, rw, rh)
|
||||
cx = int(cx * scale + .5)
|
||||
cy = int(cy * scale + .5)
|
||||
fw = int(fw * scale + .5)
|
||||
|
||||
166
tools/merge_sd3_safetensors.py
Normal file
166
tools/merge_sd3_safetensors.py
Normal file
@@ -0,0 +1,166 @@
|
||||
import argparse
|
||||
import os
|
||||
import gc
|
||||
from typing import Dict, Optional, Union
|
||||
import torch
|
||||
from safetensors.torch import safe_open
|
||||
|
||||
from library.utils import setup_logging
|
||||
from library.utils import load_safetensors, mem_eff_save_file, str_to_dtype
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def merge_safetensors(
|
||||
dit_path: str,
|
||||
vae_path: Optional[str] = None,
|
||||
clip_l_path: Optional[str] = None,
|
||||
clip_g_path: Optional[str] = None,
|
||||
t5xxl_path: Optional[str] = None,
|
||||
output_path: str = "merged_model.safetensors",
|
||||
device: str = "cpu",
|
||||
save_precision: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Merge multiple safetensors files into a single file
|
||||
|
||||
Args:
|
||||
dit_path: Path to the DiT/MMDiT model
|
||||
vae_path: Path to the VAE model
|
||||
clip_l_path: Path to the CLIP-L model
|
||||
clip_g_path: Path to the CLIP-G model
|
||||
t5xxl_path: Path to the T5-XXL model
|
||||
output_path: Path to save the merged model
|
||||
device: Device to load tensors to
|
||||
save_precision: Target dtype for model weights (e.g. 'fp16', 'bf16')
|
||||
"""
|
||||
logger.info("Starting to merge safetensors files...")
|
||||
|
||||
# Convert save_precision string to torch dtype if specified
|
||||
if save_precision:
|
||||
target_dtype = str_to_dtype(save_precision)
|
||||
else:
|
||||
target_dtype = None
|
||||
|
||||
# 1. Get DiT metadata if available
|
||||
metadata = None
|
||||
try:
|
||||
with safe_open(dit_path, framework="pt") as f:
|
||||
metadata = f.metadata() # may be None
|
||||
if metadata:
|
||||
logger.info(f"Found metadata in DiT model: {metadata}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to read metadata from DiT model: {e}")
|
||||
|
||||
# 2. Create empty merged state dict
|
||||
merged_state_dict = {}
|
||||
|
||||
# 3. Load and merge each model with memory management
|
||||
|
||||
# DiT/MMDiT - prefix: model.diffusion_model.
|
||||
# This state dict may have VAE keys.
|
||||
logger.info(f"Loading DiT model from {dit_path}")
|
||||
dit_state_dict = load_safetensors(dit_path, device=device, disable_mmap=True, dtype=target_dtype)
|
||||
logger.info(f"Adding DiT model with {len(dit_state_dict)} keys")
|
||||
for key, value in dit_state_dict.items():
|
||||
if key.startswith("model.diffusion_model.") or key.startswith("first_stage_model."):
|
||||
merged_state_dict[key] = value
|
||||
else:
|
||||
merged_state_dict[f"model.diffusion_model.{key}"] = value
|
||||
# Free memory
|
||||
del dit_state_dict
|
||||
gc.collect()
|
||||
|
||||
# VAE - prefix: first_stage_model.
|
||||
# May be omitted if VAE is already included in DiT model.
|
||||
if vae_path:
|
||||
logger.info(f"Loading VAE model from {vae_path}")
|
||||
vae_state_dict = load_safetensors(vae_path, device=device, disable_mmap=True, dtype=target_dtype)
|
||||
logger.info(f"Adding VAE model with {len(vae_state_dict)} keys")
|
||||
for key, value in vae_state_dict.items():
|
||||
if key.startswith("first_stage_model."):
|
||||
merged_state_dict[key] = value
|
||||
else:
|
||||
merged_state_dict[f"first_stage_model.{key}"] = value
|
||||
# Free memory
|
||||
del vae_state_dict
|
||||
gc.collect()
|
||||
|
||||
# CLIP-L - prefix: text_encoders.clip_l.
|
||||
if clip_l_path:
|
||||
logger.info(f"Loading CLIP-L model from {clip_l_path}")
|
||||
clip_l_state_dict = load_safetensors(clip_l_path, device=device, disable_mmap=True, dtype=target_dtype)
|
||||
logger.info(f"Adding CLIP-L model with {len(clip_l_state_dict)} keys")
|
||||
for key, value in clip_l_state_dict.items():
|
||||
if key.startswith("text_encoders.clip_l.transformer."):
|
||||
merged_state_dict[key] = value
|
||||
else:
|
||||
merged_state_dict[f"text_encoders.clip_l.transformer.{key}"] = value
|
||||
# Free memory
|
||||
del clip_l_state_dict
|
||||
gc.collect()
|
||||
|
||||
# CLIP-G - prefix: text_encoders.clip_g.
|
||||
if clip_g_path:
|
||||
logger.info(f"Loading CLIP-G model from {clip_g_path}")
|
||||
clip_g_state_dict = load_safetensors(clip_g_path, device=device, disable_mmap=True, dtype=target_dtype)
|
||||
logger.info(f"Adding CLIP-G model with {len(clip_g_state_dict)} keys")
|
||||
for key, value in clip_g_state_dict.items():
|
||||
if key.startswith("text_encoders.clip_g.transformer."):
|
||||
merged_state_dict[key] = value
|
||||
else:
|
||||
merged_state_dict[f"text_encoders.clip_g.transformer.{key}"] = value
|
||||
# Free memory
|
||||
del clip_g_state_dict
|
||||
gc.collect()
|
||||
|
||||
# T5-XXL - prefix: text_encoders.t5xxl.
|
||||
if t5xxl_path:
|
||||
logger.info(f"Loading T5-XXL model from {t5xxl_path}")
|
||||
t5xxl_state_dict = load_safetensors(t5xxl_path, device=device, disable_mmap=True, dtype=target_dtype)
|
||||
logger.info(f"Adding T5-XXL model with {len(t5xxl_state_dict)} keys")
|
||||
for key, value in t5xxl_state_dict.items():
|
||||
if key.startswith("text_encoders.t5xxl.transformer."):
|
||||
merged_state_dict[key] = value
|
||||
else:
|
||||
merged_state_dict[f"text_encoders.t5xxl.transformer.{key}"] = value
|
||||
# Free memory
|
||||
del t5xxl_state_dict
|
||||
gc.collect()
|
||||
|
||||
# 4. Save merged state dict
|
||||
logger.info(f"Saving merged model to {output_path} with {len(merged_state_dict)} keys total")
|
||||
mem_eff_save_file(merged_state_dict, output_path, metadata)
|
||||
logger.info("Successfully merged safetensors files")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Merge Stable Diffusion 3.5 model components into a single safetensors file")
|
||||
parser.add_argument("--dit", required=True, help="Path to the DiT/MMDiT model")
|
||||
parser.add_argument("--vae", help="Path to the VAE model. May be omitted if VAE is included in DiT model")
|
||||
parser.add_argument("--clip_l", help="Path to the CLIP-L model")
|
||||
parser.add_argument("--clip_g", help="Path to the CLIP-G model")
|
||||
parser.add_argument("--t5xxl", help="Path to the T5-XXL model")
|
||||
parser.add_argument("--output", default="merged_model.safetensors", help="Path to save the merged model")
|
||||
parser.add_argument("--device", default="cpu", help="Device to load tensors to")
|
||||
parser.add_argument("--save_precision", type=str, help="Precision to save the model in (e.g., 'fp16', 'bf16', 'float16', etc.)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
merge_safetensors(
|
||||
dit_path=args.dit,
|
||||
vae_path=args.vae,
|
||||
clip_l_path=args.clip_l,
|
||||
clip_g_path=args.clip_g,
|
||||
t5xxl_path=args.t5xxl,
|
||||
output_path=args.output,
|
||||
device=args.device,
|
||||
save_precision=args.save_precision,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -6,7 +6,7 @@ import shutil
|
||||
import math
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
from library.utils import setup_logging, pil_resize
|
||||
from library.utils import setup_logging, resize_image
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -22,14 +22,6 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
|
||||
if not os.path.exists(dst_img_folder):
|
||||
os.makedirs(dst_img_folder)
|
||||
|
||||
# Select interpolation method
|
||||
if interpolation == 'lanczos4':
|
||||
pil_interpolation = Image.LANCZOS
|
||||
elif interpolation == 'cubic':
|
||||
pil_interpolation = Image.BICUBIC
|
||||
else:
|
||||
cv2_interpolation = cv2.INTER_AREA
|
||||
|
||||
# Iterate through all files in src_img_folder
|
||||
img_exts = (".png", ".jpg", ".jpeg", ".webp", ".bmp") # copy from train_util.py
|
||||
for filename in os.listdir(src_img_folder):
|
||||
@@ -63,11 +55,7 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
|
||||
new_height = int(img.shape[0] * math.sqrt(scale_factor))
|
||||
new_width = int(img.shape[1] * math.sqrt(scale_factor))
|
||||
|
||||
# Resize image
|
||||
if cv2_interpolation:
|
||||
img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation)
|
||||
else:
|
||||
img = pil_resize(img, (new_width, new_height), interpolation=pil_interpolation)
|
||||
img = resize_image(img, img.shape[0], img.shape[1], new_height, new_width, interpolation)
|
||||
else:
|
||||
new_height, new_width = img.shape[0:2]
|
||||
|
||||
@@ -113,8 +101,8 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help='Maximum resolution(s) in the format "512x512,384x384, etc, etc" / 最大画像サイズをカンマ区切りで指定 ("512x512,384x384, etc, etc" など)', default="512x512,384x384,256x256,128x128")
|
||||
parser.add_argument('--divisible_by', type=int,
|
||||
help='Ensure new dimensions are divisible by this value / リサイズ後の画像のサイズをこの値で割り切れるようにします', default=1)
|
||||
parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4'],
|
||||
default='area', help='Interpolation method for resizing / リサイズ時の補完方法')
|
||||
parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4', 'nearest', 'linear', 'box'],
|
||||
default=None, help='Interpolation method for resizing. Default to area if smaller, lanczos if larger / サイズ変更の補間方法。小さい場合はデフォルトでエリア、大きい場合はランチョスになります。')
|
||||
parser.add_argument('--save_as_png', action='store_true', help='Save as png format / png形式で保存')
|
||||
parser.add_argument('--copy_associated_files', action='store_true',
|
||||
help='Copy files with same base name to images (captions etc) / 画像と同じファイル名(拡張子を除く)のファイルもコピーする')
|
||||
|
||||
@@ -69,13 +69,20 @@ class NetworkTrainer:
|
||||
keys_scaled=None,
|
||||
mean_norm=None,
|
||||
maximum_norm=None,
|
||||
mean_grad_norm=None,
|
||||
mean_combined_norm=None,
|
||||
):
|
||||
logs = {"loss/current": current_loss, "loss/average": avr_loss}
|
||||
|
||||
if keys_scaled is not None:
|
||||
logs["max_norm/keys_scaled"] = keys_scaled
|
||||
logs["max_norm/average_key_norm"] = mean_norm
|
||||
logs["max_norm/max_key_norm"] = maximum_norm
|
||||
if mean_norm is not None:
|
||||
logs["norm/avg_key_norm"] = mean_norm
|
||||
if mean_grad_norm is not None:
|
||||
logs["norm/avg_grad_norm"] = mean_grad_norm
|
||||
if mean_combined_norm is not None:
|
||||
logs["norm/avg_combined_norm"] = mean_combined_norm
|
||||
|
||||
lrs = lr_scheduler.get_last_lr()
|
||||
for i, lr in enumerate(lrs):
|
||||
@@ -652,6 +659,10 @@ class NetworkTrainer:
|
||||
return
|
||||
network_has_multiplier = hasattr(network, "set_multiplier")
|
||||
|
||||
# TODO remove `hasattr`s by setting up methods if not defined in the network like (hacky but works):
|
||||
# if not hasattr(network, "prepare_network"):
|
||||
# network.prepare_network = lambda args: None
|
||||
|
||||
if hasattr(network, "prepare_network"):
|
||||
network.prepare_network(args)
|
||||
if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"):
|
||||
@@ -1018,6 +1029,7 @@ class NetworkTrainer:
|
||||
"ss_max_validation_steps": args.max_validation_steps,
|
||||
"ss_validate_every_n_epochs": args.validate_every_n_epochs,
|
||||
"ss_validate_every_n_steps": args.validate_every_n_steps,
|
||||
"ss_resize_interpolation": args.resize_interpolation,
|
||||
}
|
||||
|
||||
self.update_metadata(metadata, args) # architecture specific metadata
|
||||
@@ -1043,6 +1055,7 @@ class NetworkTrainer:
|
||||
"max_bucket_reso": dataset.max_bucket_reso,
|
||||
"tag_frequency": dataset.tag_frequency,
|
||||
"bucket_info": dataset.bucket_info,
|
||||
"resize_interpolation": dataset.resize_interpolation,
|
||||
}
|
||||
|
||||
subsets_metadata = []
|
||||
@@ -1060,6 +1073,7 @@ class NetworkTrainer:
|
||||
"enable_wildcard": bool(subset.enable_wildcard),
|
||||
"caption_prefix": subset.caption_prefix,
|
||||
"caption_suffix": subset.caption_suffix,
|
||||
"resize_interpolation": subset.resize_interpolation,
|
||||
}
|
||||
|
||||
image_dir_or_metadata_file = None
|
||||
@@ -1401,6 +1415,11 @@ class NetworkTrainer:
|
||||
params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
|
||||
if hasattr(network, "update_grad_norms"):
|
||||
network.update_grad_norms()
|
||||
if hasattr(network, "update_norms"):
|
||||
network.update_norms()
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
@@ -1409,9 +1428,23 @@ class NetworkTrainer:
|
||||
keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(
|
||||
args.scale_weight_norms, accelerator.device
|
||||
)
|
||||
mean_grad_norm = None
|
||||
mean_combined_norm = None
|
||||
max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
|
||||
else:
|
||||
keys_scaled, mean_norm, maximum_norm = None, None, None
|
||||
if hasattr(network, "weight_norms"):
|
||||
mean_norm = network.weight_norms().mean().item()
|
||||
mean_grad_norm = network.grad_norms().mean().item()
|
||||
mean_combined_norm = network.combined_weight_norms().mean().item()
|
||||
weight_norms = network.weight_norms()
|
||||
maximum_norm = weight_norms.max().item() if weight_norms.numel() > 0 else None
|
||||
keys_scaled = None
|
||||
max_mean_logs = {}
|
||||
else:
|
||||
keys_scaled, mean_norm, maximum_norm = None, None, None
|
||||
mean_grad_norm = None
|
||||
mean_combined_norm = None
|
||||
max_mean_logs = {}
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
@@ -1444,14 +1477,21 @@ class NetworkTrainer:
|
||||
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
||||
avr_loss: float = loss_recorder.moving_average
|
||||
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if args.scale_weight_norms:
|
||||
progress_bar.set_postfix(**{**max_mean_logs, **logs})
|
||||
progress_bar.set_postfix(**{**max_mean_logs, **logs})
|
||||
|
||||
if is_tracking:
|
||||
logs = self.generate_step_logs(
|
||||
args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm
|
||||
args,
|
||||
current_loss,
|
||||
avr_loss,
|
||||
lr_scheduler,
|
||||
lr_descriptions,
|
||||
optimizer,
|
||||
keys_scaled,
|
||||
mean_norm,
|
||||
maximum_norm,
|
||||
mean_grad_norm,
|
||||
mean_combined_norm,
|
||||
)
|
||||
self.step_logging(accelerator, logs, global_step, epoch + 1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user