Compare commits

..

2 Commits

Author   SHA1       Message                        Date
Kohya S  a701fe5c37 fix typos                      2023-10-03 23:07:36 +09:00
Kohya S  4c5d6d1ba3 initial version of wuerstchen  2023-10-03 22:59:56 +09:00
11 changed files with 869 additions and 679 deletions

View File

@@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v3
- name: typos-action
uses: crate-ci/typos@v1.16.15

View File

@@ -249,40 +249,6 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum
## Change History
### Oct 9. 2023 / 2023/10/9
- `tag_images_by_wd_14_tagger.py` now supports Onnx. If you use Onnx, TensorFlow is no longer required. [#864](https://github.com/kohya-ss/sd-scripts/pull/864) Thanks to Isotr0py!
- The `--onnx` option is added. To use Onnx, specify the `--onnx` option.
- Please install Onnx and other required packages.
1. Uninstall TensorFlow.
1. `pip install tensorboard==2.14.1` This is required for the specified version of protobuf.
1. `pip install protobuf==3.20.3` This is required for Onnx.
1. `pip install onnx==1.14.1`
1. `pip install onnxruntime-gpu==1.16.0` or `pip install onnxruntime==1.16.0`
- `--append_tags` option is added to `tag_images_by_wd_14_tagger.py`. This option appends the tags to the existing tags, instead of replacing them. [#858](https://github.com/kohya-ss/sd-scripts/pull/858) Thanks to a-l-e-x-d-s-9!
- [OFT](https://oft.wyliu.com/) is now supported.
- You can use `networks.oft` for the network module in `sdxl_train_network.py`. The usage is the same as `networks.lora`; some options are not supported. A usage sketch follows this list.
- `sdxl_gen_img.py` also supports OFT as `--network_module`.
- OFT currently supports SDXL only. This is because the current implementation tweaks Q/K/V and O in the transformer blocks, and SD1/2 have far fewer transformer blocks than SDXL.
- The implementation is heavily based on laksjdjf's [OFT implementation](https://github.com/laksjdjf/sd-trainer/blob/dev/networks/lora_modules.py). Thanks to laksjdjf!
- Other bug fixes and improvements.
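
A usage sketch for OFT (editor's illustration, not part of this diff), assuming the SDXL `vae`, `text_encoders`, and `unet` objects are already loaded; the names follow the `create_network` and `apply_to` signatures in the `networks.oft` module shown later in this comparison:

```python
from networks import oft

# network_dim = number of diagonal blocks, network_alpha = constraint (alpha * out_dim)
network = oft.create_network(
    multiplier=1.0,
    network_dim=4,
    network_alpha=1.0,
    vae=vae,                     # assumed: loaded AutoencoderKL
    text_encoder=text_encoders,  # assumed: loaded CLIP text encoder(s)
    unet=unet,                   # assumed: loaded SDXL U-Net
)
network.apply_to(text_encoders, unet, apply_text_encoder=False, apply_unet=True)
```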
### Oct 1. 2023 / 2023/10/1
- SDXL training is now available in the main branch. The sdxl branch is merged into the main branch.

View File

@@ -1,14 +1,16 @@
import argparse
import csv
import glob
import os
from pathlib import Path
import cv2
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
import cv2
from tqdm import tqdm
import numpy as np
from tensorflow.keras.models import load_model
from huggingface_hub import hf_hub_download
import torch
from pathlib import Path
import library.train_util as train_util
@@ -18,7 +20,6 @@ IMAGE_SIZE = 448
# wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2/ wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2
DEFAULT_WD14_TAGGER_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"]
FILES_ONNX = ["model.onnx"]
SUB_DIR = "variables"
SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
CSV_FILE = FILES[-1]
@@ -80,10 +81,7 @@ def main(args):
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
if not os.path.exists(args.model_dir) or args.force_download:
print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
files = FILES
if args.onnx:
files += FILES_ONNX
for file in files:
for file in FILES:
hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
for file in SUB_DIR_FILES:
hf_hub_download(
@@ -98,46 +96,7 @@ def main(args):
print("using existing wd14 tagger model")
# load images
if args.onnx:
import onnx
import onnxruntime as ort
onnx_path = f"{args.model_dir}/model.onnx"
print("Running wd14 tagger with onnx")
print(f"loading onnx model: {onnx_path}")
if not os.path.exists(onnx_path):
raise Exception(
f"onnx model not found: {onnx_path}, please redownload the model with --force_download"
+ " / onnxモデルが見つかりませんでした。--force_downloadで再ダウンロードしてください"
)
model = onnx.load(onnx_path)
input_name = model.graph.input[0].name
try:
batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_value
except:
batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_param
if args.batch_size != batch_size and type(batch_size) != str:
# some rebatch model may use 'N' as dynamic axes
print(
f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}"
)
args.batch_size = batch_size
del model
ort_sess = ort.InferenceSession(
onnx_path,
providers=["CUDAExecutionProvider"]
if "CUDAExecutionProvider" in ort.get_available_providers()
else ["CPUExecutionProvider"],
)
else:
from tensorflow.keras.models import load_model
model = load_model(f"{args.model_dir}")
model = load_model(args.model_dir)
# label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
# read the CSV ourselves to avoid adding another dependency
@@ -165,14 +124,8 @@ def main(args):
def run_batch(path_imgs):
imgs = np.array([im for _, im in path_imgs])
if args.onnx:
if len(imgs) < args.batch_size:
imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0)
probs = ort_sess.run(None, {input_name: imgs})[0] # onnx output numpy
probs = probs[: len(path_imgs)]
else:
probs = model(imgs, training=False)
probs = probs.numpy()
probs = model(imgs, training=False)
probs = probs.numpy()
for (image_path, _), prob in zip(path_imgs, probs):
# the first 4 entries are ratings, so skip them
@@ -212,27 +165,9 @@ def main(args):
if len(character_tag_text) > 0:
character_tag_text = character_tag_text[2:]
caption_file = os.path.splitext(image_path)[0] + args.caption_extension
tag_text = ", ".join(combined_tags)
if args.append_tags:
# Check if file exists
if os.path.exists(caption_file):
with open(caption_file, "rt", encoding="utf-8") as f:
# Read file and remove new lines
existing_content = f.read().strip("\n") # Remove newlines
# Split the content into tags and store them in a list
existing_tags = [tag.strip() for tag in existing_content.split(",") if tag.strip()]
# Check and remove repeating tags in tag_text
new_tags = [tag for tag in combined_tags if tag not in existing_tags]
# Create new tag_text
tag_text = ", ".join(existing_tags + new_tags)
with open(caption_file, "wt", encoding="utf-8") as f:
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
f.write(tag_text + "\n")
if args.debug:
print(f"\n{image_path}:\n Character tags: {character_tag_text}\n General tags: {general_tag_text}")
@@ -348,15 +283,12 @@ def setup_parser() -> argparse.ArgumentParser:
help="comma-separated list of undesired tags to remove from the output / 出力から除外したいタグのカンマ区切りのリスト",
)
parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する")
parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する")
parser.add_argument("--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する")
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
# restore the previously misspelled option for backward compatibility

View File

@@ -96,7 +96,6 @@ try:
except:
pass
# JPEG-XL on Linux
try:
from jxlpy import JXLImagePlugin
@@ -104,14 +103,6 @@ try:
except:
pass
# JPEG-XL on Windows
try:
import pillow_jxl
IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
except:
pass
IMAGE_TRANSFORMS = transforms.Compose(
[
transforms.ToTensor(),
@@ -2004,7 +1995,7 @@ def debug_dataset(train_dataset, show_input_ids=False):
if show_input_ids:
print(f"input ids: {iid}")
if "input_ids2" in example:
if "input_ids2" in example and example["input_ids2"] is not None:
print(f"input ids2: {example['input_ids2'][j]}")
if example["images"] is not None:
im = example["images"][j]
@@ -2021,6 +2012,11 @@ def debug_dataset(train_dataset, show_input_ids=False):
cond_img = cond_img[:, :, ::-1]
if os.name == "nt":
cv2.imshow("cond_img", cond_img)
for key in example.keys():
if key in ["images", "conditioning_images", "input_ids", "input_ids2"]:
continue
print(f"{key}: {example[key][j] if example[key] is not None else None}")
if os.name == "nt": # only windows
cv2.imshow("img", im)

View File

@@ -1,430 +0,0 @@
# OFT network module
import math
import os
from typing import Dict, List, Optional, Tuple, Type, Union
from diffusers import AutoencoderKL
from transformers import CLIPTextModel
import numpy as np
import torch
import re
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
class OFTModule(torch.nn.Module):
"""
replaces forward method of the original Linear, instead of replacing the original Linear module.
"""
def __init__(
self,
oft_name,
org_module: torch.nn.Module,
multiplier=1.0,
dim=4,
alpha=1,
):
"""
dim -> num blocks
alpha -> constraint
"""
super().__init__()
self.oft_name = oft_name
self.num_blocks = dim
if "Linear" in org_module.__class__.__name__:
out_dim = org_module.out_features
elif "Conv" in org_module.__class__.__name__:
out_dim = org_module.out_channels
if type(alpha) == torch.Tensor:
alpha = alpha.detach().numpy()
self.constraint = alpha * out_dim
self.register_buffer("alpha", torch.tensor(alpha))
self.block_size = out_dim // self.num_blocks
self.oft_blocks = torch.nn.Parameter(torch.zeros(self.num_blocks, self.block_size, self.block_size))
self.out_dim = out_dim
self.shape = org_module.weight.shape
self.multiplier = multiplier
self.org_module = [org_module]  # wrap in a list so it is not registered as a submodule
def apply_to(self):
self.org_forward = self.org_module[0].forward
self.org_module[0].forward = self.forward
def get_weight(self, multiplier=None):
if multiplier is None:
multiplier = self.multiplier
block_Q = self.oft_blocks - self.oft_blocks.transpose(1, 2)
norm_Q = torch.norm(block_Q.flatten())
new_norm_Q = torch.clamp(norm_Q, max=self.constraint)
block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1)
block_R = torch.matmul(I + block_Q, (I - block_Q).inverse())
block_R_weighted = multiplier * block_R + (1 - multiplier) * I
R = torch.block_diag(*block_R_weighted)
return R
def forward(self, x, scale=None):
x = self.org_forward(x)
if self.multiplier == 0.0:
return x
R = self.get_weight().to(x.device, dtype=x.dtype)
if x.dim() == 4:
x = x.permute(0, 2, 3, 1)
x = torch.matmul(x, R)
x = x.permute(0, 3, 1, 2)
else:
x = torch.matmul(x, R)
return x
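# Editor's sketch (not part of the module): get_weight builds R block by block
# via the Cayley transform R = (I + Q)(I - Q)^-1, which maps any skew-symmetric
# Q to an orthogonal matrix, so each block is a pure rotation. Minimal check:
def _cayley_orthogonality_check(block_size=4):
    W = torch.randn(block_size, block_size)
    Q = W - W.T                            # skew-symmetric, as in get_weight
    I = torch.eye(block_size)
    R = (I + Q) @ torch.linalg.inv(I - Q)  # Cayley transform
    assert torch.allclose(R @ R.T, I, atol=1e-5)  # R is orthogonal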
class OFTInfModule(OFTModule):
def __init__(
self,
oft_name,
org_module: torch.nn.Module,
multiplier=1.0,
dim=4,
alpha=1,
**kwargs,
):
# no dropout for inference
super().__init__(oft_name, org_module, multiplier, dim, alpha)
self.enabled = True
self.network: OFTNetwork = None
def set_network(self, network):
self.network = network
def forward(self, x, scale=None):
if not self.enabled:
return self.org_forward(x)
return super().forward(x, scale)
def merge_to(self, multiplier=None, sign=1):
R = self.get_weight(multiplier) * sign
# get org weight
org_sd = self.org_module[0].state_dict()
org_weight = org_sd["weight"]
R = R.to(org_weight.device, dtype=org_weight.dtype)
if org_weight.dim() == 4:
weight = torch.einsum("oihw, op -> pihw", org_weight, R)
else:
weight = torch.einsum("oi, op -> pi", org_weight, R)
# set weight to org_module
org_sd["weight"] = weight
self.org_module[0].load_state_dict(org_sd)
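# Editor's sketch (not part of the module): for a Linear layer the einsum above
# equals W' = R^T @ W, so the merged forward x @ W'.T reproduces the runtime
# path of org_forward followed by x @ R:
def _merge_equivalence_check(out_dim=6, in_dim=3):
    W, R, x = torch.randn(out_dim, in_dim), torch.randn(out_dim, out_dim), torch.randn(2, in_dim)
    runtime = (x @ W.T) @ R                      # forward(): org output, then @ R
    merged = torch.einsum("oi, op -> pi", W, R)  # merge_to() for Linear
    assert torch.allclose(runtime, x @ merged.T, atol=1e-5)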
def create_network(
multiplier: float,
network_dim: Optional[int],
network_alpha: Optional[float],
vae: AutoencoderKL,
text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
unet,
neuron_dropout: Optional[float] = None,
**kwargs,
):
if network_dim is None:
network_dim = 4 # default
if network_alpha is None:
network_alpha = 1.0
enable_all_linear = kwargs.get("enable_all_linear", None)
enable_conv = kwargs.get("enable_conv", None)
if enable_all_linear is not None:
enable_all_linear = bool(enable_all_linear)
if enable_conv is not None:
enable_conv = bool(enable_conv)
network = OFTNetwork(
text_encoder,
unet,
multiplier=multiplier,
dim=network_dim,
alpha=network_alpha,
enable_all_linear=enable_all_linear,
enable_conv=enable_conv,
verbose=True,
)
return network
# Create network from weights for inference; weights are not loaded here (because they may be merged into the model)
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
if weights_sd is None:
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file, safe_open
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
# check dim and alpha, and whether the weights include conv2d
dim = None
alpha = None
has_conv2d = None
all_linear = None
for name, param in weights_sd.items():
if name.endswith(".alpha"):
if alpha is None:
alpha = param.item()
else:
if dim is None:
dim = param.size()[0]
if has_conv2d is None and param.dim() == 4:
has_conv2d = True
if all_linear is None:
if param.dim() == 3 and "attn" not in name:
all_linear = True
if dim is not None and alpha is not None and has_conv2d is not None:
break
if has_conv2d is None:
has_conv2d = False
if all_linear is None:
all_linear = False
module_class = OFTInfModule if for_inference else OFTModule
network = OFTNetwork(
text_encoder,
unet,
multiplier=multiplier,
dim=dim,
alpha=alpha,
enable_all_linear=all_linear,
enable_conv=has_conv2d,
module_class=module_class,
)
return network, weights_sd
class OFTNetwork(torch.nn.Module):
UNET_TARGET_REPLACE_MODULE_ATTN_ONLY = ["CrossAttention"]
UNET_TARGET_REPLACE_MODULE_ALL_LINEAR = ["Transformer2DModel"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
OFT_PREFIX_UNET = "oft_unet"  # probably better not to change this
def __init__(
self,
text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
unet,
multiplier: float = 1.0,
dim: int = 4,
alpha: float = 1,
enable_all_linear: Optional[bool] = False,
enable_conv: Optional[bool] = False,
module_class: Type[object] = OFTModule,
verbose: Optional[bool] = False,
) -> None:
super().__init__()
self.multiplier = multiplier
self.dim = dim
self.alpha = alpha
print(
f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}"
)
# create module instances
def create_modules(
root_module: torch.nn.Module,
target_replace_modules: List[torch.nn.Module],
) -> List[OFTModule]:
prefix = self.OFT_PREFIX_UNET
ofts = []
for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules:
for child_name, child_module in module.named_modules():
is_linear = "Linear" in child_module.__class__.__name__
is_conv2d = "Conv2d" in child_module.__class__.__name__
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
if is_linear or is_conv2d_1x1 or (is_conv2d and enable_conv):
oft_name = prefix + "." + name + "." + child_name
oft_name = oft_name.replace(".", "_")
# print(oft_name)
oft = module_class(
oft_name,
child_module,
self.multiplier,
dim,
alpha,
)
ofts.append(oft)
return ofts
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
if enable_all_linear:
target_modules = OFTNetwork.UNET_TARGET_REPLACE_MODULE_ALL_LINEAR
else:
target_modules = OFTNetwork.UNET_TARGET_REPLACE_MODULE_ATTN_ONLY
if enable_conv:
target_modules += OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
self.unet_ofts: List[OFTModule] = create_modules(unet, target_modules)
print(f"create OFT for U-Net: {len(self.unet_ofts)} modules.")
# assertion
names = set()
for oft in self.unet_ofts:
assert oft.oft_name not in names, f"duplicated oft name: {oft.oft_name}"
names.add(oft.oft_name)
def set_multiplier(self, multiplier):
self.multiplier = multiplier
for oft in self.unet_ofts:
oft.multiplier = self.multiplier
def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
info = self.load_state_dict(weights_sd, False)
return info
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
assert apply_unet, "apply_unet must be True"
for oft in self.unet_ofts:
oft.apply_to()
self.add_module(oft.oft_name, oft)
# returns whether the weights can be merged
def is_mergeable(self):
return True
# TODO refactor to common function with apply_to
def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
print("enable OFT for U-Net")
for oft in self.unet_ofts:
sd_for_lora = {}
for key in weights_sd.keys():
if key.startswith(oft.oft_name):
sd_for_lora[key[len(oft.oft_name) + 1 :]] = weights_sd[key]
oft.load_state_dict(sd_for_lora, False)
oft.merge_to()
print(f"weights are merged")
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
self.requires_grad_(True)
all_params = []
def enumerate_params(ofts):
params = []
for oft in ofts:
params.extend(oft.parameters())
# print num of params
num_params = 0
for p in params:
num_params += p.numel()
print(f"OFT params: {num_params}")
return params
param_data = {"params": enumerate_params(self.unet_ofts)}
if unet_lr is not None:
param_data["lr"] = unet_lr
all_params.append(param_data)
return all_params
def enable_gradient_checkpointing(self):
# not supported
pass
def prepare_grad_etc(self, text_encoder, unet):
self.requires_grad_(True)
def on_epoch_start(self, text_encoder, unet):
self.train()
def get_trainable_params(self):
return self.parameters()
def save_weights(self, file, dtype, metadata):
if metadata is not None and len(metadata) == 0:
metadata = None
state_dict = self.state_dict()
if dtype is not None:
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(dtype)
state_dict[key] = v
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import save_file
from library import train_util
# Precalculate model hashes to save time on indexing
if metadata is None:
metadata = {}
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
save_file(state_dict, file, metadata)
else:
torch.save(state_dict, file)
def backup_weights(self):
# back up the original weights
ofts: List[OFTInfModule] = self.unet_ofts
for oft in ofts:
org_module = oft.org_module[0]
if not hasattr(org_module, "_lora_org_weight"):
sd = org_module.state_dict()
org_module._lora_org_weight = sd["weight"].detach().clone()
org_module._lora_restored = True
def restore_weights(self):
# restore the original weights
ofts: List[OFTInfModule] = self.unet_ofts
for oft in ofts:
org_module = oft.org_module[0]
if not org_module._lora_restored:
sd = org_module.state_dict()
sd["weight"] = org_module._lora_org_weight
org_module.load_state_dict(sd)
org_module._lora_restored = True
def pre_calculation(self):
# pre-calculate (merge) the weights
ofts: List[OFTInfModule] = self.unet_ofts
for oft in ofts:
org_module = oft.org_module[0]
oft.merge_to()
# sd = org_module.state_dict()
# org_weight = sd["weight"]
# lora_weight = oft.get_weight().to(org_weight.device, dtype=org_weight.dtype)
# sd["weight"] = org_weight + lora_weight
# assert sd["weight"].shape == org_weight.shape
# org_module.load_state_dict(sd)
org_module._lora_restored = False
oft.enabled = False

View File

@@ -19,14 +19,8 @@ huggingface-hub==0.15.1
# requests==2.28.2
# timm==0.6.12
# fairscale==0.4.13
# for WD14 captioning (tensorflow)
# for WD14 captioning
# tensorflow==2.10.1
# for WD14 captioning (onnx)
# onnx==1.14.1
# onnxruntime-gpu==1.16.0
# onnxruntime==1.16.0
# this is for onnx:
# protobuf==3.20.3
# open clip for SDXL
open-clip-torch==2.20.0
# for kohya_ss library

View File

@@ -1,84 +0,0 @@
import argparse
import os
import torch
from safetensors import safe_open
from safetensors.torch import load_file, save_file
from tqdm import tqdm
def split(args):
# load embedding
if args.embedding.endswith(".safetensors"):
embedding = load_file(args.embedding)
with safe_open(args.embedding, framework="pt") as f:
metadata = f.metadata()
else:
embedding = torch.load(args.embedding)
metadata = None
# check format
if "emb_params" in embedding:
# SD1/2
keys = ["emb_params"]
elif "clip_l" in embedding:
# SDXL
keys = ["clip_l", "clip_g"]
else:
print("Unknown embedding format")
exit()
num_vectors = embedding[keys[0]].shape[0]
# prepare output directory
os.makedirs(args.output_dir, exist_ok=True)
# prepare splits
if args.vectors_per_split is not None:
num_splits = (num_vectors + args.vectors_per_split - 1) // args.vectors_per_split
vectors_for_split = [args.vectors_per_split] * num_splits
if sum(vectors_for_split) > num_vectors:
vectors_for_split[-1] -= sum(vectors_for_split) - num_vectors
assert sum(vectors_for_split) == num_vectors
elif args.vectors is not None:
vectors_for_split = args.vectors
num_splits = len(vectors_for_split)
else:
print("Must specify either --vectors_per_split or --vectors / --vectors_per_split または --vectors のどちらかを指定する必要があります")
exit()
assert (
sum(vectors_for_split) == num_vectors
), "Sum of vectors must be equal to the number of vectors in the embedding / 分割したベクトルの合計はembeddingのベクトル数と等しくなければなりません"
# split
basename = os.path.splitext(os.path.basename(args.embedding))[0]
done_vectors = 0
for i, num_vectors in enumerate(vectors_for_split):
print(f"Splitting {num_vectors} vectors...")
split_embedding = {}
for key in keys:
split_embedding[key] = embedding[key][done_vectors : done_vectors + num_vectors]
output_file = os.path.join(args.output_dir, f"{basename}_{i}.safetensors")
save_file(split_embedding, output_file, metadata)
print(f"Saved to {output_file}")
done_vectors += num_vectors
print("Done")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Split an embedding into multiple embeddings")
parser.add_argument("--embedding", type=str, help="Embedding to split")
parser.add_argument("--output_dir", type=str, help="Output directory")
parser.add_argument(
"--vectors_per_split",
type=int,
default=None,
help="Number of vectors per split. If num_vectors is 8 and vectors_per_split is 3, then 3, 3, 2 vectors will be split",
)
parser.add_argument("--vectors", type=int, default=None, nargs="*", help="number of vectors for each split. e.g. 3 3 2")
args = parser.parse_args()
split(args)
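# Editor's illustration (not part of the tool): the chunking above turns an
# 8-vector embedding with --vectors_per_split 3 into chunks of 3, 3 and 2.
def _split_sizes_example():
    emb = {"emb_params": torch.randn(8, 768)}  # SD1/2-style embedding (assumed shape)
    per_split, num_vectors = 3, emb["emb_params"].shape[0]
    sizes = [per_split] * ((num_vectors + per_split - 1) // per_split)
    sizes[-1] -= sum(sizes) - num_vectors
    done = 0
    for n in sizes:
        print(emb["emb_params"][done : done + n].shape)  # (3, 768), (3, 768), (2, 768)
        done += n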

View File

@@ -283,10 +283,7 @@ class NetworkTrainer:
if args.dim_from_weights:
network, _ = network_module.create_network_from_weights(1, args.network_weights, vae, text_encoder, unet, **net_kwargs)
else:
if "dropout" not in net_kwargs:
# workaround for LyCORIS (;^ω^)
net_kwargs["dropout"] = args.network_dropout
# LyCORIS will work with this...
network = network_module.create_network(
1.0,
args.network_dim,

View File

@@ -7,13 +7,10 @@ import toml
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
@@ -170,13 +167,6 @@ class TextualInversionTrainer:
args.output_name = args.token_string
use_template = args.use_object_template or args.use_style_template
assert (
args.token_string is not None or args.token_strings is not None
), "token_string or token_strings must be specified / token_stringまたはtoken_stringsを指定してください"
assert (
not use_template or args.token_strings is None
), "token_strings cannot be used with template / token_stringsはテンプレートと一緒に使えません"
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
@@ -225,17 +215,9 @@ class TextualInversionTrainer:
# add new word to tokenizer, count is num_vectors_per_token
# if token_string is hoge, "hoge", "hoge1", "hoge2", ... are added
if args.token_strings is not None:
token_strings = args.token_strings
assert (
len(token_strings) == args.num_vectors_per_token
), f"num_vectors_per_token is mismatch for token_strings / token_stringsの数がnum_vectors_per_tokenと合いません: {len(token_strings)}"
for token_string in token_strings:
self.assert_token_string(token_string, tokenizers)
else:
self.assert_token_string(args.token_string, tokenizers)
token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
self.assert_token_string(args.token_string, tokenizers)
token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
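# Editor's note: with token_string="hoge" and num_vectors_per_token=3, the
# line above yields ["hoge", "hoge1", "hoge2"], and each string becomes one
# new token added to the tokenizer.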
token_ids_list = []
token_embeds_list = []
for i, (tokenizer, text_encoder, init_token_ids) in enumerate(zip(tokenizers, text_encoders, init_token_ids_list)):
@@ -350,7 +332,7 @@ class TextualInversionTrainer:
prompt_replacement = None
else:
# for sample generation
if args.num_vectors_per_token > 1 and args.token_strings is None:
if args.num_vectors_per_token > 1:
replace_to = " ".join(token_strings)
train_dataset_group.add_replacement(args.token_string, replace_to)
prompt_replacement = (args.token_string, replace_to)
@@ -770,13 +752,6 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること",
)
parser.add_argument(
"--token_strings",
type=str,
default=None,
nargs="*",
help="token strings used in training for multiple embedding / 複数のembeddingsの個別学習時に使用されるトークン文字列",
)
parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可")
parser.add_argument(
"--use_object_template",

View File

@@ -0,0 +1,196 @@
# use Diffusers' pipeline to generate images
import argparse
import datetime
import math
import os
import random
import re
from einops import repeat
import numpy as np
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from tqdm import tqdm
from PIL import Image
from transformers import CLIPTextModel, PreTrainedTokenizerFast
from diffusers.pipelines.wuerstchen.modeling_wuerstchen_prior import WuerstchenPrior
from diffusers import AutoPipelineForText2Image, DDPMWuerstchenScheduler
# from diffusers.pipelines.wuerstchen.pipeline_wuerstchen_prior import DEFAULT_STAGE_C_TIMESTEPS
from wuerstchen_train import EfficientNetEncoder
def generate(args):
dtype = torch.float32
if args.fp16:
dtype = torch.float16
elif args.bf16:
dtype = torch.bfloat16
device = args.device
os.makedirs(args.outdir, exist_ok=True)
# load tokenizer
print("load tokenizer")
tokenizer = PreTrainedTokenizerFast.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="tokenizer")
# load text encoder
print("load text encoder")
text_encoder = CLIPTextModel.from_pretrained(
args.pretrained_prior_model_name_or_path, subfolder="text_encoder", torch_dtype=dtype
)
# load prior model
print("load prior model")
prior: WuerstchenPrior = WuerstchenPrior.from_pretrained(
args.pretrained_prior_model_name_or_path, subfolder="prior", torch_dtype=dtype
)
# function to set the xformers flag for Diffusers models
def set_diffusers_xformers_flag(model, valid):
def fn_recursive_set_mem_eff(module: torch.nn.Module):
if hasattr(module, "set_use_memory_efficient_attention_xformers"):
module.set_use_memory_efficient_attention_xformers(valid)
for child in module.children():
fn_recursive_set_mem_eff(child)
fn_recursive_set_mem_eff(model)
# enable xformers / memory efficient attention in the model
if args.diffusers_xformers:
print("Use xformers by Diffusers")
set_diffusers_xformers_flag(prior, True)
# load pipeline
print("load pipeline")
pipeline = AutoPipelineForText2Image.from_pretrained(
args.pretrained_decoder_model_name_or_path,
prior_prior=prior,
prior_text_encoder=text_encoder,
prior_tokenizer=tokenizer,
)
pipeline = pipeline.to(device, torch_dtype=dtype)
# generate image
while True:
width = args.w
height = args.h
seed = args.seed
negative_prompt = None
if args.interactive:
print("prompt:")
prompt = input()
if prompt == "":
break
# parse prompt
prompt_args = prompt.split(" --")
prompt = prompt_args[0]
for parg in prompt_args[1:]:
try:
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
if m:
width = int(m.group(1))
print(f"width: {width}")
continue
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
if m:
height = int(m.group(1))
print(f"height: {height}")
continue
m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE)
if m: # seed
seed = int(m.group(1))
print(f"seed: {seed}")
continue
m = re.match(r"n (.+)", parg, re.IGNORECASE)
if m: # negative prompt
negative_prompt = m.group(1)
print(f"negative prompt: {negative_prompt}")
continue
except ValueError as ex:
print(f"Exception in parsing / 解析エラー: {parg}")
print(ex)
else:
prompt = args.prompt
negative_prompt = args.negative_prompt
if seed is None:
generator = None
else:
generator = torch.Generator(device=device).manual_seed(seed)
with torch.autocast(device):
image = pipeline(
prompt,
negative_prompt=negative_prompt,
# prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
generator=generator,
width=width,
height=height,
).images[0]
# save image
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
image.save(os.path.join(args.outdir, f"image_{timestamp}.png"))
if not args.interactive:
break
print("Done!")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
# train_util.add_sd_models_arguments(parser)
parser.add_argument(
"--pretrained_prior_model_name_or_path",
type=str,
default="warp-ai/wuerstchen-prior",
required=False,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--pretrained_decoder_model_name_or_path",
type=str,
default="warp-ai/wuerstchen",
required=False,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument("--prompt", type=str, default="A photo of a cat")
parser.add_argument("--negative_prompt", type=str, default="")
parser.add_argument("--outdir", type=str, default=".")
parser.add_argument("--w", type=int, default=1024)
parser.add_argument("--h", type=int, default=1024)
parser.add_argument("--interactive", action="store_true")
parser.add_argument("--fp16", action="store_true")
parser.add_argument("--bf16", action="store_true")
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--seed", type=int, default=None)
parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
generate(args)
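# Editor's note: interactive prompts are parsed by splitting on " --", so a
# session might look like this (hypothetical example):
#   prompt:
#   a cat wearing a hat --w 1536 --h 1024 --d 42 --n blurry, low quality
# -> width 1536, height 1024, seed 42, negative prompt "blurry, low quality"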

View File

@@ -0,0 +1,648 @@
# training with captions
# heavily based on https://github.com/kashif/diffusers
import argparse
import gc
import math
import os
from multiprocessing import Value
from typing import List
import toml
from tqdm import tqdm
import torch
from torchvision.models import efficientnet_v2_l, efficientnet_v2_s
from torchvision import transforms
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from accelerate.utils import set_seed
from transformers import CLIPTextModel, PreTrainedTokenizerFast
from diffusers import AutoPipelineForText2Image, DDPMWuerstchenScheduler
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from diffusers.pipelines.wuerstchen.modeling_wuerstchen_prior import WuerstchenPrior
from huggingface_hub import hf_hub_download
import library.train_util as train_util
import library.config_util as config_util
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
class EfficientNetEncoder(ModelMixin, ConfigMixin):
@register_to_config
def __init__(self, c_latent=16, c_cond=1280, effnet="efficientnet_v2_s"):
super().__init__()
if effnet == "efficientnet_v2_s":
self.backbone = efficientnet_v2_s(weights="DEFAULT").features
else:
self.backbone = efficientnet_v2_l(weights="DEFAULT").features
self.mapper = torch.nn.Sequential(
torch.nn.Conv2d(c_cond, c_latent, kernel_size=1, bias=False),
torch.nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1
)
def forward(self, x):
return self.mapper(self.backbone(x))
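# Editor's note (shape sketch, assuming a 768px input): x (b, 3, 768, 768)
# -> efficientnet_v2_s features, stride 32, 1280 channels -> (b, 1280, 24, 24)
# -> 1x1 conv mapper -> (b, c_latent=16, 24, 24)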
class DatasetWrapper(train_util.DatasetGroup):
r"""
Wrapper for datasets to be used with DataLoader.
Adds effnet_pixel_values and text_mask to each item.
"""
# there is probably a cleaner way that avoids copying these attributes
def __init__(self, dataset, tokenizer):
self.dataset = dataset
self.image_data = dataset.image_data
self.tokenizer = tokenizer
self.num_train_images = dataset.num_train_images
self.datasets = dataset.datasets
# images are already resized
self.effnet_transforms = transforms.Compose(
[
transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
]
)
def __getitem__(self, idx):
item = self.dataset[idx]
# create attention mask by input_ids
input_ids = item["input_ids"]
attention_mask = torch.ones_like(input_ids)
attention_mask[input_ids == self.tokenizer.pad_token_id] = 0
text_mask = attention_mask.bool()
item["text_mask"] = text_mask
# create effnet input
images = item["images"]
# effnet_pixel_values = [self.effnet_transforms(image) for image in images]
# effnet_pixel_values = torch.stack(effnet_pixel_values, dim=0)
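# editor's note: dataset images are normalized to [-1, 1], so map them back
# to [0, 1] before applying the ImageNet normalization above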
effnet_pixel_values = self.effnet_transforms((images + 1.0) / 2.0)
effnet_pixel_values = effnet_pixel_values.to(memory_format=torch.contiguous_format)
item["effnet_pixel_values"] = effnet_pixel_values
return item
def __len__(self):
return len(self.dataset)
def add_replacement(self, str_from, str_to):
self.dataset.add_replacement(str_from, str_to)
def enable_XTI(self, *args, **kwargs):
self.dataset.enable_XTI(*args, **kwargs)
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
self.dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)
def cache_text_encoder_outputs(
self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
):
self.dataset.cache_text_encoder_outputs(tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process)
def set_caching_mode(self, caching_mode):
self.dataset.set_caching_mode(caching_mode)
def verify_bucket_reso_steps(self, min_steps: int):
self.dataset.verify_bucket_reso_steps(min_steps)
def is_latent_cacheable(self) -> bool:
return self.dataset.is_latent_cacheable()
def is_text_encoder_output_cacheable(self) -> bool:
return self.dataset.is_text_encoder_output_cacheable()
def set_current_epoch(self, epoch):
self.dataset.set_current_epoch(epoch)
def set_current_step(self, step):
self.dataset.set_current_step(step)
def set_max_train_steps(self, max_train_steps):
self.dataset.set_max_train_steps(max_train_steps)
def disable_token_padding(self):
self.dataset.disable_token_padding()
def get_hidden_states(args: argparse.Namespace, input_ids, text_mask, tokenizer, text_encoder, weight_dtype=None):
# with no_token_padding, the length is not max length, return result immediately
if input_ids.size()[-1] != tokenizer.model_max_length:
return text_encoder(input_ids, attention_mask=text_mask)[0]
# input_ids: b,n,77
b_size = input_ids.size()[0]
input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77
text_mask = text_mask.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77
if args.clip_skip is None:
encoder_hidden_states = text_encoder(input_ids)[0]
else:
enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True)
encoder_hidden_states = enc_out["hidden_states"][-args.clip_skip]
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
# bs*3, 77, 768 or 1024
encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))
if args.max_token_length is not None:
# v1: <BOS>...<EOS> の三連を <BOS>...<EOS> へ戻す
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
for i in range(1, args.max_token_length, tokenizer.model_max_length):
states_list.append(encoder_hidden_states[:, i : i + tokenizer.model_max_length - 2])  # from after <BOS> to before <EOS>
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS>
encoder_hidden_states = torch.cat(states_list, dim=1)
if weight_dtype is not None:
# this is required for additional network training
encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
return encoder_hidden_states
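# Editor's note (shape sketch, assuming max_token_length=225 and
# model_max_length=77): input_ids (b, 3, 77) -> (b*3, 77) -> encoder ->
# (b*3, 77, hidden) -> reshape -> (b, 231, hidden) -> BOS/EOS folding ->
# (b, 1 + 3*75 + 1, hidden) = (b, 227, hidden)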
def train(args):
# TODO: add checking for unsupported args
# TODO: cache image encoder outputs instead of latents
# train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
use_dreambooth_method = args.in_json is None
if args.seed is not None:
set_seed(args.seed)  # initialize the random number generator
print("prepare tokenizer")
tokenizer = PreTrainedTokenizerFast.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="tokenizer")
# prepare the dataset
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
if args.dataset_config is not None:
print(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
if use_dreambooth_method:
print("Using DreamBooth method.")
user_config = {
"datasets": [
{
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
args.train_data_dir, args.reg_data_dir
)
}
]
}
else:
print("Training with captions.")
user_config = {
"datasets": [
{
"subsets": [
{
"image_dir": args.train_data_dir,
"metadata_file": args.in_json,
}
]
}
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
train_dataset_group.verify_bucket_reso_steps(32)
# wrap for wuerstchen
train_dataset_group = DatasetWrapper(train_dataset_group, tokenizer)
if args.debug_dataset:
train_util.debug_dataset(train_dataset_group, True)
return
if len(train_dataset_group) == 0:
print(
"No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
)
return
# prepare the accelerator
print("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
# prepare a dtype matching the mixed precision setting and cast as appropriate
weight_dtype, _ = train_util.prepare_dtype(args)
# Load scheduler, effnet, tokenizer, clip_model
print("prepare scheduler, effnet, clip_model")
noise_scheduler = DDPMWuerstchenScheduler()
# TODO support explicit local caching for faster loading
pretrained_checkpoint_file = hf_hub_download("dome272/wuerstchen", filename="model_v2_stage_b.pt")
state_dict = torch.load(pretrained_checkpoint_file, map_location="cpu")
image_encoder = EfficientNetEncoder()
image_encoder.load_state_dict(state_dict["effnet_state_dict"])
image_encoder.eval()
text_encoder = CLIPTextModel.from_pretrained(
args.pretrained_prior_model_name_or_path, subfolder="text_encoder", torch_dtype=weight_dtype
)
# Freeze text_encoder and image_encoder
text_encoder.requires_grad_(False)
image_encoder.requires_grad_(False)
# load prior model
prior: WuerstchenPrior = WuerstchenPrior.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior")
# EMA is not supported yet
# function to set the xformers flag for Diffusers models
def set_diffusers_xformers_flag(model, valid):
def fn_recursive_set_mem_eff(module: torch.nn.Module):
if hasattr(module, "set_use_memory_efficient_attention_xformers"):
module.set_use_memory_efficient_attention_xformers(valid)
for child in module.children():
fn_recursive_set_mem_eff(child)
fn_recursive_set_mem_eff(model)
# enable xformers / memory efficient attention in the model
if args.diffusers_xformers:
accelerator.print("Use xformers by Diffusers")
set_diffusers_xformers_flag(prior, True)
# prepare for training: put the models into the appropriate state
training_models = []
if args.gradient_checkpointing:
# prior.enable_gradient_checkpointing()
print("*" * 80)
print("*** Prior model does not support gradient checkpointing. ***")
print("*" * 80)
training_models.append(prior)
text_encoder.requires_grad_(False)
text_encoder.eval()
for m in training_models:
m.requires_grad_(True)
params = []
for m in training_models:
params.extend(m.parameters())
params_to_optimize = params
# calculate number of trainable parameters
n_params = 0
for p in params:
n_params += p.numel()
accelerator.print(f"number of models: {len(training_models)}")
accelerator.print(f"number of trainable parameters: {n_params}")
# prepare the classes needed for training
accelerator.print("prepare optimizer, data loader etc.")
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
# prepare the DataLoader
# num_workers=0 means the DataLoader runs in the main process
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count - 1, but at most the specified number
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# compute the number of training steps
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# also send the number of training steps to the dataset side
train_dataset_group.set_max_train_steps(args.max_train_steps)
# prepare the lr scheduler
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# experimental feature: perform fp16/bf16 training including gradients by casting the entire model to fp16/bf16
if args.full_fp16:
assert (
args.mixed_precision == "fp16"
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
accelerator.print("enable full fp16 training.")
prior.to(weight_dtype)
text_encoder.to(weight_dtype)
elif args.full_bf16:
assert (
args.mixed_precision == "bf16"
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
accelerator.print("enable full bf16 training.")
prior.to(weight_dtype)
text_encoder.to(weight_dtype)
# accelerator is supposed to handle this properly for us
prior, image_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
prior, image_encoder, optimizer, train_dataloader, lr_scheduler
)
(prior, image_encoder) = train_util.transform_models_if_DDP([prior, image_encoder])
text_encoder.to(weight_dtype)
text_encoder.to(accelerator.device)
image_encoder.to(weight_dtype)
image_encoder.to(accelerator.device)
# experimental feature: perform fp16 training including gradients; patch PyTorch to enable grad scaling in fp16
if args.full_fp16:
train_util.patch_accelerator_for_fp16_training(accelerator)
# resume training if specified
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
# compute the number of epochs
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
# train
# total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
accelerator.print("running training / 学習開始")
accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
# accelerator.print(
# f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
# )
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
if accelerator.is_main_process:
init_kwargs = {}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
"wuerstchen_finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
)
# workaround for DDPMWuerstchenScheduler
def add_noise(
scheduler: DDPMWuerstchenScheduler,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
alphas_cumprod_timesteps = scheduler._alpha_cumprod(timesteps, original_samples.device)
sqrt_alpha_prod = alphas_cumprod_timesteps**0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod_timesteps) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
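# Editor's note (assumption based on the code above): this is the standard
# diffusion forward process,
#   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise,
# with alpha_bar_t taken from the scheduler's _alpha_cumprod at timestep t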
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
for m in training_models:
m.train()
loss_total = 0
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(training_models[0]):  # accumulate does not seem to support multiple models, so do this for now
input_ids = batch["input_ids"]
text_mask = batch["text_mask"]
effnet_pixel_values = batch["effnet_pixel_values"]
with torch.no_grad():
input_ids = input_ids.to(accelerator.device)
text_mask = text_mask.to(accelerator.device)
prompt_embeds = get_hidden_states(
args, input_ids, text_mask, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
)
image_embeds = image_encoder(effnet_pixel_values)
image_embeds = image_embeds.add(1.0).div(42.0) # scale
# Sample noise that we'll add to the image_embeds
noise = torch.randn_like(image_embeds)
bsz = image_embeds.shape[0]
# Sample a random timestep for each image
# TODO support mul/add/clamp
timesteps = torch.rand((bsz,), device=image_embeds.device, dtype=weight_dtype)
# add noise to latents: this is the same as Diffuzz.diffuse in diffuzz.py
# noisy_latents = noise_scheduler.add_noise(image_embeds, noise, timesteps)
noisy_latents = add_noise(noise_scheduler, image_embeds, noise, timesteps)
# Predict the noise residual
with accelerator.autocast():
noise_pred = prior(noisy_latents, timesteps, prompt_embeds)
target = noise
# TODO add consistency loss
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = []
for m in training_models:
params_to_clip.extend(m.parameters())
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
# TODO generate sample images here
# sample_images(
# accelerator,
# args,
# None,
# global_step,
# accelerator.device,
# vae,
# [tokenizer1, tokenizer2],
# [text_encoder, text_encoder2],
# prior,
# )
# save the model every specified number of steps
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
# TODO simplify to save prior only
pipeline = AutoPipelineForText2Image.from_pretrained(
args.pretrained_decoder_model_name_or_path,
prior_prior=accelerator.unwrap_model(prior),
prior_text_encoder=accelerator.unwrap_model(text_encoder),
prior_tokenizer=tokenizer,
)
ckpt_name = train_util.get_step_ckpt_name(args, "", global_step)
pipeline.prior_pipe.save_pretrained(os.path.join(args.output_dir, ckpt_name))
# TODO remove older saved models
current_loss = loss.detach().item()  # this is a mean, so batch size should not matter
if args.logging_dir is not None:
logs = {"loss": current_loss}
logs["lr"] = float(lr_scheduler.get_last_lr()[0])
if (
args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
): # tracking d*lr value
logs["lr/d*lr"] = (
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
)
accelerator.log(logs, step=global_step)
# TODO use a moving average
loss_total += current_loss
avr_loss = loss_total / (step + 1)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(train_dataloader)}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()
if args.save_every_n_epochs is not None:
if accelerator.is_main_process:
epoch_no = epoch + 1
saving = epoch_no % args.save_every_n_epochs == 0 and epoch_no < num_train_epochs
if saving:
pipeline = AutoPipelineForText2Image.from_pretrained(
args.pretrained_decoder_model_name_or_path,
prior_prior=accelerator.unwrap_model(prior),
prior_text_encoder=accelerator.unwrap_model(text_encoder),
prior_tokenizer=tokenizer,
)
ckpt_name = train_util.get_epoch_ckpt_name(args, "", epoch)
pipeline.prior_pipe.save_pretrained(os.path.join(args.output_dir, ckpt_name))
# TODO remove older saved models
# TODO generate sample images here
is_main_process = accelerator.is_main_process
accelerator.end_training()
if args.save_state: # and is_main_process:
train_util.save_state_on_train_end(args, accelerator)
# del accelerator  # delete this since memory is used after this point
if is_main_process:
pipeline = AutoPipelineForText2Image.from_pretrained(
args.pretrained_decoder_model_name_or_path,
prior_prior=accelerator.unwrap_model(prior),
prior_text_encoder=accelerator.unwrap_model(text_encoder),
prior_tokenizer=tokenizer,
)
ckpt_name = train_util.get_last_ckpt_name(args, "")
pipeline.prior_pipe.save_pretrained(os.path.join(args.output_dir, ckpt_name))
print("model saved.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
# train_util.add_sd_models_arguments(parser)
parser.add_argument(
"--pretrained_prior_model_name_or_path",
type=str,
default="warp-ai/wuerstchen-prior",
required=False,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--pretrained_decoder_model_name_or_path",
type=str,
default="warp-ai/wuerstchen",
required=False,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, False)
# train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
# TODO add assertion for SD related arguments
parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
train(args)
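# Editor's note: a hypothetical launch command (script and file names assumed;
# this diff view omits filenames, and the flags come from setup_parser and
# train_util above):
#   accelerate launch wuerstchen_train.py --train_data_dir data --in_json meta.json
#     --output_dir out --max_train_steps 1000 --mixed_precision bf16 --full_bf16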