mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 09:30:28 +00:00
support LoRA training for Stable Cascade Stage C
This commit is contained in:
55
README.md
55
README.md
@@ -45,6 +45,32 @@ If the latents cache files for SD/SDXL exist (extension `*.npz`), it will be rea
|
||||
|
||||
After that, run `finetune/prepare_buckets_latents.py` with the `--stable_cascade` option to create latents cache files for Stable Cascade (suffix `_sc_latents.npz` is added).
|
||||
|
||||
## LoRA training
|
||||
|
||||
`stable_cascade_train_c_network.py` is used for LoRA training. The main options are the same as `train_network.py`, and the same options as `stable_cascade_train_stage_c.py` have been added.
|
||||
|
||||
__This is an experimental feature, so the format of the saved weights may change in the future and become incompatible.__
|
||||
|
||||
There is no compatibility with the official LoRA, and the implementation of Text Encoder embedding training (Pivotal Tuning) in the official implementation is not implemented here.
|
||||
|
||||
Text Encoder LoRA training is implemented, but untested.
|
||||
|
||||
## Image generation
|
||||
|
||||
Basic image generation functionality is available in `stable_cascade_gen_img.py`. See `--help` for usage.
|
||||
|
||||
When using LoRA, specify `--network_module networks.lora --network_mul 1 --network_weights lora_weights.safetensors`.
|
||||
|
||||
The following prompt options are available.
|
||||
|
||||
* `--n` Negative prompt up to the next option.
|
||||
* `--w` Specifies the width of the generated image.
|
||||
* `--h` Specifies the height of the generated image.
|
||||
* `--d` Specifies the seed of the generated image.
|
||||
* `--l` Specifies the CFG scale of the generated image.
|
||||
* `--s` Specifies the number of steps in the generation.
|
||||
* `--t` Specifies the t_start of the generation.
|
||||
* `--f` Specifies the shift of the generation.
|
||||
|
||||
# Stable Cascade Stage C の学習
|
||||
|
||||
@@ -92,6 +118,35 @@ SD/SDXL 向けの latents キャッシュファイル(拡張子 `*.npz`)が
|
||||
|
||||
その後、`finetune/prepare_buckets_latents.py` をオプション `--stable_cascade` を指定して実行すると、Stable Cascade 向けの latents キャッシュファイル(接尾辞 `_sc_latents.npz` が付きます)が作成されます。
|
||||
|
||||
|
||||
## LoRA 等の学習
|
||||
|
||||
LoRA の学習は `stable_cascade_train_c_network.py` で行います。主なオプションは `train_network.py` と同様で、`stable_cascade_train_stage_c.py` と同様のオプションが追加されています。
|
||||
|
||||
__実験的機能のため、保存される重みのフォーマットは将来的に変更され、互換性がなくなる可能性があります。__
|
||||
|
||||
公式の LoRA と重みの互換性はありません。また公式で実装されている Text Encoder の embedding 学習(Pivotal Tuning)も実装されていません。
|
||||
|
||||
Text Encoder の LoRA 学習は実装してありますが、未テストです。
|
||||
|
||||
## 画像生成
|
||||
|
||||
最低限の画像生成機能が `stable_cascade_gen_img.py` にあります。使用法は `--help` を参照してください。
|
||||
|
||||
LoRA 使用時は `--network_module networks.lora --network_mul 1 --network_weights lora_weights.safetensors` のように指定します。
|
||||
|
||||
プロンプトオプションとして以下が使用できます。
|
||||
|
||||
* `--n` Negative prompt up to the next option.
|
||||
* `--w` Specifies the width of the generated image.
|
||||
* `--h` Specifies the height of the generated image.
|
||||
* `--d` Specifies the seed of the generated image.
|
||||
* `--l` Specifies the CFG scale of the generated image.
|
||||
* `--s` Specifies the number of steps in the generation.
|
||||
* `--t` Specifies the t_start of the generation.
|
||||
* `--f` Specifies the shift of the generation.
|
||||
|
||||
|
||||
---
|
||||
|
||||
__SDXL is now supported. The sdxl branch has been merged into the main branch. If you update the repository, please follow the upgrade instructions. Also, the version of accelerate has been updated, so please run accelerate config again.__ The documentation for SDXL training is [here](./README.md#sdxl-training).
|
||||
|
||||
@@ -12,6 +12,9 @@ import torch.utils.checkpoint
|
||||
import torchvision
|
||||
|
||||
|
||||
MODEL_VERSION_STABLE_CASCADE = "stable_cascade"
|
||||
|
||||
|
||||
# region VectorQuantize
|
||||
|
||||
# from torchtools https://github.com/pabloppp/pytorch-tools
|
||||
|
||||
@@ -193,7 +193,7 @@ def load_previewer_model(previewer_checkpoint_path, dtype=None, device="cpu") ->
|
||||
return previewer
|
||||
|
||||
|
||||
def get_sai_model_spec(args):
|
||||
def get_sai_model_spec(args, lora=False):
|
||||
timestamp = time.time()
|
||||
|
||||
reso = args.resolution
|
||||
@@ -212,7 +212,7 @@ def get_sai_model_spec(args):
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
lora,
|
||||
False,
|
||||
timestamp,
|
||||
title=title,
|
||||
|
||||
@@ -841,9 +841,14 @@ class LoRANetwork(torch.nn.Module):
|
||||
is_linear = child_module.__class__.__name__ == "Linear"
|
||||
is_conv2d = child_module.__class__.__name__ == "Conv2d"
|
||||
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
|
||||
is_group_conv2d = is_conv2d and child_module.groups > 1
|
||||
|
||||
if is_linear or is_conv2d:
|
||||
lora_name = prefix + "." + name + "." + child_name
|
||||
# if is_group_conv2d:
|
||||
# logger.info(f"skip group conv2d: {name}.{child_name}")
|
||||
# continue
|
||||
|
||||
if is_linear or (is_conv2d and not is_group_conv2d):
|
||||
lora_name = prefix + "." + name + ("." + child_name if child_name else "")
|
||||
lora_name = lora_name.replace(".", "_")
|
||||
|
||||
dim = None
|
||||
@@ -915,6 +920,11 @@ class LoRANetwork(torch.nn.Module):
|
||||
if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
|
||||
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
||||
|
||||
# XXX temporary solution for Stable Cascade Stage C: replace all modules
|
||||
if "StageC" in unet.__class__.__name__:
|
||||
logger.info("replace all modules for Stable Cascade Stage C")
|
||||
target_modules = ["Linear", "Conv2d"]
|
||||
|
||||
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
|
||||
logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import argparse
|
||||
import importlib
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
@@ -66,6 +67,39 @@ def main(args):
|
||||
else:
|
||||
previewer = None
|
||||
|
||||
# LoRA
|
||||
for i, network_module in enumerate(args.network_module):
|
||||
print("import network module:", network_module)
|
||||
imported_module = importlib.import_module(network_module)
|
||||
|
||||
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
|
||||
|
||||
net_kwargs = {}
|
||||
if args.network_args and i < len(args.network_args):
|
||||
network_args = args.network_args[i]
|
||||
# TODO escape special chars
|
||||
network_args = network_args.split(";")
|
||||
for net_arg in network_args:
|
||||
key, value = net_arg.split("=")
|
||||
net_kwargs[key] = value
|
||||
|
||||
if args.network_weights is None or len(args.network_weights) <= i:
|
||||
raise ValueError("No weight. Weight is required.")
|
||||
|
||||
network_weight = args.network_weights[i]
|
||||
print("load network weights from:", network_weight)
|
||||
|
||||
network, weights_sd = imported_module.create_network_from_weights(
|
||||
network_mul, network_weight, effnet, text_model, generator_c, for_inference=True, **net_kwargs
|
||||
)
|
||||
if network is None:
|
||||
return
|
||||
|
||||
mergeable = network.is_mergeable()
|
||||
assert mergeable, "not-mergeable network is not supported yet."
|
||||
|
||||
network.merge_to(text_model, generator_c, weights_sd, dtype, device)
|
||||
|
||||
# 謎のクラス gdf
|
||||
gdf_c = sc.GDF(
|
||||
schedule=sc.CosineSchedule(clamp_range=[0.0001, 0.9999]),
|
||||
@@ -122,7 +156,7 @@ def main(args):
|
||||
cfg = 4
|
||||
timesteps = 20
|
||||
shift = 2
|
||||
t_start = 1.0 # t_start is not an option, but it is a parameter
|
||||
t_start = 1.0
|
||||
negative_prompt = ""
|
||||
seed = None
|
||||
|
||||
@@ -299,6 +333,26 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--fp16", action="store_true")
|
||||
parser.add_argument("--outdir", type=str, default="../outputs", help="dir to write results to / 生成画像の出力先")
|
||||
parser.add_argument("--lowvram", action="store_true", help="if specified, use low VRAM mode")
|
||||
parser.add_argument(
|
||||
"--network_module",
|
||||
type=str,
|
||||
default=None,
|
||||
nargs="*",
|
||||
help="additional network module to use / 追加ネットワークを使う時そのモジュール名",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--network_weights", type=str, default=None, nargs="*", help="additional network weights to load / 追加ネットワークの重み"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--network_args",
|
||||
type=str,
|
||||
default=None,
|
||||
nargs="*",
|
||||
help="additional arguments for network (key=value) / ネットワークへの追加の引数",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
|
||||
1090
stable_cascade_train_c_network.py
Normal file
1090
stable_cascade_train_c_network.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user