support LoRA training for Stable Cascade Stage C

This commit is contained in:
Kohya S
2024-02-20 08:27:11 +09:00
parent 806a6237fb
commit 71e03559e2
6 changed files with 1217 additions and 5 deletions

View File

@@ -45,6 +45,32 @@ If the latents cache files for SD/SDXL exist (extension `*.npz`), it will be rea
After that, run `finetune/prepare_buckets_latents.py` with the `--stable_cascade` option to create latents cache files for Stable Cascade (suffix `_sc_latents.npz` is added).
## LoRA training
`stable_cascade_train_c_network.py` is used for LoRA training. The main options are the same as `train_network.py`, and the same options as `stable_cascade_train_stage_c.py` have been added.
__This is an experimental feature, so the format of the saved weights may change in the future and become incompatible.__
There is no compatibility with the official LoRA, and the implementation of Text Encoder embedding training (Pivotal Tuning) in the official implementation is not implemented here.
Text Encoder LoRA training is implemented, but untested.
## Image generation
Basic image generation functionality is available in `stable_cascade_gen_img.py`. See `--help` for usage.
When using LoRA, specify `--network_module networks.lora --network_mul 1 --network_weights lora_weights.safetensors`.
The following prompt options are available.
* `--n` Negative prompt up to the next option.
* `--w` Specifies the width of the generated image.
* `--h` Specifies the height of the generated image.
* `--d` Specifies the seed of the generated image.
* `--l` Specifies the CFG scale of the generated image.
* `--s` Specifies the number of steps in the generation.
* `--t` Specifies the t_start of the generation.
* `--f` Specifies the shift of the generation.
# Stable Cascade Stage C の学習
@@ -92,6 +118,35 @@ SD/SDXL 向けの latents キャッシュファイル(拡張子 `*.npz`)が
その後、`finetune/prepare_buckets_latents.py` をオプション `--stable_cascade` を指定して実行すると、Stable Cascade 向けの latents キャッシュファイル(接尾辞 `_sc_latents.npz` が付きます)が作成されます。
## LoRA 等の学習
LoRA の学習は `stable_cascade_train_c_network.py` で行います。主なオプションは `train_network.py` と同様で、`stable_cascade_train_stage_c.py` と同様のオプションが追加されています。
__実験的機能のため、保存される重みのフォーマットは将来的に変更され、互換性がなくなる可能性があります。__
公式の LoRA と重みの互換性はありません。また公式で実装されている Text Encoder の embedding 学習Pivotal Tuningも実装されていません。
Text Encoder の LoRA 学習は実装してありますが、未テストです。
## 画像生成
最低限の画像生成機能が `stable_cascade_gen_img.py` にあります。使用法は `--help` を参照してください。
LoRA 使用時は `--network_module networks.lora --network_mul 1 --network_weights lora_weights.safetensors` のように指定します。
プロンプトオプションとして以下が使用できます。
* `--n` Negative prompt up to the next option.
* `--w` Specifies the width of the generated image.
* `--h` Specifies the height of the generated image.
* `--d` Specifies the seed of the generated image.
* `--l` Specifies the CFG scale of the generated image.
* `--s` Specifies the number of steps in the generation.
* `--t` Specifies the t_start of the generation.
* `--f` Specifies the shift of the generation.
---
__SDXL is now supported. The sdxl branch has been merged into the main branch. If you update the repository, please follow the upgrade instructions. Also, the version of accelerate has been updated, so please run accelerate config again.__ The documentation for SDXL training is [here](./README.md#sdxl-training).

View File

@@ -12,6 +12,9 @@ import torch.utils.checkpoint
import torchvision
MODEL_VERSION_STABLE_CASCADE = "stable_cascade"
# region VectorQuantize
# from torchtools https://github.com/pabloppp/pytorch-tools

View File

@@ -193,7 +193,7 @@ def load_previewer_model(previewer_checkpoint_path, dtype=None, device="cpu") ->
return previewer
def get_sai_model_spec(args):
def get_sai_model_spec(args, lora=False):
timestamp = time.time()
reso = args.resolution
@@ -212,7 +212,7 @@ def get_sai_model_spec(args):
False,
False,
False,
False,
lora,
False,
timestamp,
title=title,

View File

@@ -841,9 +841,14 @@ class LoRANetwork(torch.nn.Module):
is_linear = child_module.__class__.__name__ == "Linear"
is_conv2d = child_module.__class__.__name__ == "Conv2d"
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
is_group_conv2d = is_conv2d and child_module.groups > 1
if is_linear or is_conv2d:
lora_name = prefix + "." + name + "." + child_name
# if is_group_conv2d:
# logger.info(f"skip group conv2d: {name}.{child_name}")
# continue
if is_linear or (is_conv2d and not is_group_conv2d):
lora_name = prefix + "." + name + ("." + child_name if child_name else "")
lora_name = lora_name.replace(".", "_")
dim = None
@@ -915,6 +920,11 @@ class LoRANetwork(torch.nn.Module):
if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
# XXX temporary solution for Stable Cascade Stage C: replace all modules
if "StageC" in unet.__class__.__name__:
logger.info("replace all modules for Stable Cascade Stage C")
target_modules = ["Linear", "Conv2d"]
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")

View File

@@ -1,4 +1,5 @@
import argparse
import importlib
import math
import os
import random
@@ -66,6 +67,39 @@ def main(args):
else:
previewer = None
# LoRA
for i, network_module in enumerate(args.network_module):
print("import network module:", network_module)
imported_module = importlib.import_module(network_module)
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
net_kwargs = {}
if args.network_args and i < len(args.network_args):
network_args = args.network_args[i]
# TODO escape special chars
network_args = network_args.split(";")
for net_arg in network_args:
key, value = net_arg.split("=")
net_kwargs[key] = value
if args.network_weights is None or len(args.network_weights) <= i:
raise ValueError("No weight. Weight is required.")
network_weight = args.network_weights[i]
print("load network weights from:", network_weight)
network, weights_sd = imported_module.create_network_from_weights(
network_mul, network_weight, effnet, text_model, generator_c, for_inference=True, **net_kwargs
)
if network is None:
return
mergeable = network.is_mergeable()
assert mergeable, "not-mergeable network is not supported yet."
network.merge_to(text_model, generator_c, weights_sd, dtype, device)
# 謎のクラス gdf
gdf_c = sc.GDF(
schedule=sc.CosineSchedule(clamp_range=[0.0001, 0.9999]),
@@ -122,7 +156,7 @@ def main(args):
cfg = 4
timesteps = 20
shift = 2
t_start = 1.0 # t_start is not an option, but it is a parameter
t_start = 1.0
negative_prompt = ""
seed = None
@@ -299,6 +333,26 @@ if __name__ == "__main__":
parser.add_argument("--fp16", action="store_true")
parser.add_argument("--outdir", type=str, default="../outputs", help="dir to write results to / 生成画像の出力先")
parser.add_argument("--lowvram", action="store_true", help="if specified, use low VRAM mode")
parser.add_argument(
"--network_module",
type=str,
default=None,
nargs="*",
help="additional network module to use / 追加ネットワークを使う時そのモジュール名",
)
parser.add_argument(
"--network_weights", type=str, default=None, nargs="*", help="additional network weights to load / 追加ネットワークの重み"
)
parser.add_argument(
"--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率"
)
parser.add_argument(
"--network_args",
type=str,
default=None,
nargs="*",
help="additional arguments for network (key=value) / ネットワークへの追加の引数",
)
args = parser.parse_args()
main(args)

File diff suppressed because it is too large Load Diff