mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 21:52:27 +00:00
Compare commits
1 Commits
v0.8.5
...
nw_applica
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ae0872ba3b |
36
README.md
36
README.md
@@ -1,3 +1,39 @@
|
||||
## LoRAの層別適用率の探索について
|
||||
|
||||
層別適用率を探索する `train_network_appl_weights.py` を追加してあります。現在は SDXL のみ対応しています。
|
||||
|
||||
LoRA 等の学習済みネットワークに対して、層別適用率を変化させながら通常の学習プロセスを実行することで、適用率を探索します。つまり、どのような層別適用率を適用すると、学習データに近い画像が生成されるかを探索することができます。
|
||||
|
||||
層別適用率の合計をペナルティとすることが可能です。つまり、画像を再現しつつ、影響の少ない層の適用率が低くなるような適用率が探索できるはずです。
|
||||
|
||||
複数のネットワークを対象に探索できます。また探索には最低 1 枚の学習データが必要になります。
|
||||
|
||||
(何枚程度から正しく動くかは確認していません。50枚程度の画像でテスト済みです。また学習データは LoRA 学習時のデータでなくてもよいはずですが、未確認です。)
|
||||
|
||||
コマンドラインオプションは `sdxl_train_network.py` とほぼ同じですが、以下のオプションが追加、拡張されています。
|
||||
|
||||
- `--application_loss_weight` : 層別適用率を loss に加える際の重みです。デフォルトは 0.0001 です。大きくすると、なるべく適用率を低くするように学習します。0 を指定するとペナルティが適用されないため、再現度が最も高くなる適用率を自由に探索します。
|
||||
- `--network_module` : 探索対象の複数のモジュールを指定することができます。たとえば `--network_module networks.lora networks.lora` のように指定します。
|
||||
- `--network_weights` : 探索対象の複数のネットワークの重みを指定することができます。たとえば `--network_weights model1.safetensors model2.safetensors` のように指定します。
|
||||
|
||||
層別適用率のパラメータ数は 20個で、`BASE, IN00-08, MID, OUT00-08` となります。`BASE` は Text Encoder に適用されます。(Text Encoder を対象とした LoRA の動作は未確認です。)
|
||||
|
||||
パラメータは一応ファイルに保存されますが、画面に表示される値をコピーして保存することをお勧めします。
|
||||
|
||||
### 備考
|
||||
|
||||
オプティマイザ AdamW、学習率 1e-1 で動作確認しています。学習率はかなり高めに設定してよいようです。この設定では LoRA 学習時の 1/20 ~ 1/10 ほどの epoch 数でそれなりの結果が得られます。
|
||||
|
||||
`application_loss_weight` を 0.0001 より大きくすると合計の適用率がかなり低くなる(=LoRA があまり適用されない)ようです。条件にもよると思いますので、適宜調整してください。
|
||||
|
||||
適用率に負の値を使うと、影響の少ない層の適用率を極端に低くして合計を小さくする、という動きをしてしまうので、負の値は10倍の重み付けをしてあります(-0.01 は 0.1 とほぼ同じペナルティ)。重み付けを変更するときはソースを修正してください。
|
||||
|
||||
「必要ない層への適用率を下げて影響範囲を小さくする」という使い方だけでなく、「あるキャラクターがあるポーズをしている画像を教師データに、キャラクターを維持しつつポーズを取るための LoRA の適用率を探索する」、「ある画風のあるキャラクターの画像を教師データに、画風 LoRA とキャラクター LoRA の適用率を探索する」などの使い方が考えられます。
|
||||
|
||||
もしかすると、「あるキャラクターの、あえて別の画風の画像を教師データに、キャラクターの属性を再現するのに必要な層を探す」、「理想とする画像を教師データに、使えそうな LoRA を多数適用し、その中から最も再現度が高い適用率を探す(ただし LoRA の数が多いほど学習が遅くなります)」といった使い方もできるかもしれません。
|
||||
|
||||
---
|
||||
|
||||
__SDXL is now supported. The sdxl branch has been merged into the main branch. If you update the repository, please follow the upgrade instructions. Also, the version of accelerate has been updated, so please run accelerate config again.__ The documentation for SDXL training is [here](./README.md#sdxl-training).
|
||||
|
||||
This repository contains training, generation and utility scripts for Stable Diffusion.
|
||||
|
||||
@@ -511,7 +511,9 @@ def get_block_dims_and_alphas(
|
||||
len(block_dims) == num_total_blocks
|
||||
), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
|
||||
else:
|
||||
print(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
|
||||
print(
|
||||
f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります"
|
||||
)
|
||||
block_dims = [network_dim] * num_total_blocks
|
||||
|
||||
if block_alphas is not None:
|
||||
@@ -1223,3 +1225,40 @@ class LoRANetwork(torch.nn.Module):
|
||||
norms.append(scalednorm.item())
|
||||
|
||||
return keys_scaled, sum(norms) / len(norms), max(norms)
|
||||
|
||||
# region application weight
|
||||
|
||||
def get_number_of_blocks(self):
|
||||
# only for SDXL
|
||||
return 20
|
||||
|
||||
def has_text_encoder_block(self):
|
||||
return self.text_encoder_loras is not None and len(self.text_encoder_loras) > 0
|
||||
|
||||
def set_block_wise_weights(self, weights):
|
||||
if self.text_encoder_loras:
|
||||
for lora in self.text_encoder_loras:
|
||||
lora.multiplier = weights[0]
|
||||
|
||||
for lora in self.unet_loras:
|
||||
# determine block index
|
||||
key = lora.lora_name[10:] # remove "lora_unet_"
|
||||
if key.startswith("input_blocks"):
|
||||
block_index = int(key.split("_")[2]) + 1 # 1-9
|
||||
elif key.startswith("middle_block"):
|
||||
block_index = 10 # int(key.split("_")[2]) + 10
|
||||
elif key.startswith("output_blocks"):
|
||||
block_index = int(key.split("_")[2]) + 11 # 11-19
|
||||
else:
|
||||
print(f"unknown block: {key}")
|
||||
block_index = 0
|
||||
|
||||
lora.multiplier = weights[block_index]
|
||||
|
||||
# print(f"{lora.lora_name} block index: {block_index}, weight: {lora.multiplier}")
|
||||
# print(f"set block-wise weights to {weights}")
|
||||
|
||||
# TODO LoRA の weight をあらかじめ計算しておいて multiplier を掛けるだけにすると速くなるはず
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
1039
train_network_appl_weights.py
Normal file
1039
train_network_appl_weights.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user