mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-07 14:08:49 +00:00
Compare commits
6 Commits
v0.8.3
...
nw_applica
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ae0872ba3b | ||
|
|
7f948db158 | ||
|
|
9d7729c00d | ||
|
|
988dee02b9 | ||
|
|
d4b9568269 | ||
|
|
ccc3a481e7 |
36
README.md
36
README.md
@@ -1,3 +1,39 @@
|
||||
## LoRAの層別適用率の探索について
|
||||
|
||||
層別適用率を探索する `train_network_appl_weights.py` を追加してあります。現在は SDXL のみ対応しています。
|
||||
|
||||
LoRA 等の学習済みネットワークに対して、層別適用率を変化させながら通常の学習プロセスを実行することで、適用率を探索します。つまり、どのような層別適用率を適用すると、学習データに近い画像が生成されるかを探索することができます。
|
||||
|
||||
層別適用率の合計をペナルティとすることが可能です。つまり、画像を再現しつつ、影響の少ない層の適用率が低くなるような適用率が探索できるはずです。
|
||||
|
||||
複数のネットワークを対象に探索できます。また探索には最低 1 枚の学習データが必要になります。
|
||||
|
||||
(何枚程度から正しく動くかは確認していません。50枚程度の画像でテスト済みです。また学習データは LoRA 学習時のデータでなくてもよいはずですが、未確認です。)
|
||||
|
||||
コマンドラインオプションは `sdxl_train_network.py` とほぼ同じですが、以下のオプションが追加、拡張されています。
|
||||
|
||||
- `--application_loss_weight` : 層別適用率を loss に加える際の重みです。デフォルトは 0.0001 です。大きくすると、なるべく適用率を低くするように学習します。0 を指定するとペナルティが適用されないため、再現度が最も高くなる適用率を自由に探索します。
|
||||
- `--network_module` : 探索対象の複数のモジュールを指定することができます。たとえば `--network_module networks.lora networks.lora` のように指定します。
|
||||
- `--network_weights` : 探索対象の複数のネットワークの重みを指定することができます。たとえば `--network_weights model1.safetensors model2.safetensors` のように指定します。
|
||||
|
||||
層別適用率のパラメータ数は 20個で、`BASE, IN00-08, MID, OUT00-08` となります。`BASE` は Text Encoder に適用されます。(Text Encoder を対象とした LoRA の動作は未確認です。)
|
||||
|
||||
パラメータは一応ファイルに保存されますが、画面に表示される値をコピーして保存することをお勧めします。
|
||||
|
||||
### 備考
|
||||
|
||||
オプティマイザ AdamW、学習率 1e-1 で動作確認しています。学習率はかなり高めに設定してよいようです。この設定では LoRA 学習時の 1/20 ~ 1/10 ほどの epoch 数でそれなりの結果が得られます。
|
||||
|
||||
`application_loss_weight` を 0.0001 より大きくすると合計の適用率がかなり低くなる(=LoRA があまり適用されない)ようです。条件にもよると思いますので、適宜調整してください。
|
||||
|
||||
適用率に負の値を使うと、影響の少ない層の適用率を極端に低くして合計を小さくする、という動きをしてしまうので、負の値は10倍の重み付けをしてあります(-0.01 は 0.1 とほぼ同じペナルティ)。重み付けを変更するときはソースを修正してください。
|
||||
|
||||
「必要ない層への適用率を下げて影響範囲を小さくする」という使い方だけでなく、「あるキャラクターがあるポーズをしている画像を教師データに、キャラクターを維持しつつポーズを取るための LoRA の適用率を探索する」、「ある画風のあるキャラクターの画像を教師データに、画風 LoRA とキャラクター LoRA の適用率を探索する」などの使い方が考えられます。
|
||||
|
||||
もしかすると、「あるキャラクターの、あえて別の画風の画像を教師データに、キャラクターの属性を再現するのに必要な層を探す」、「理想とする画像を教師データに、使えそうな LoRA を多数適用し、その中から最も再現度が高い適用率を探す(ただし LoRA の数が多いほど学習が遅くなります)」といった使い方もできるかもしれません。
|
||||
|
||||
---
|
||||
|
||||
__SDXL is now supported. The sdxl branch has been merged into the main branch. If you update the repository, please follow the upgrade instructions. Also, the version of accelerate has been updated, so please run accelerate config again.__ The documentation for SDXL training is [here](./README.md#sdxl-training).
|
||||
|
||||
This repository contains training, generation and utility scripts for Stable Diffusion.
|
||||
|
||||
@@ -125,9 +125,13 @@ def ipex_init(): # pylint: disable=too-many-statements
|
||||
|
||||
# AMP:
|
||||
torch.cuda.amp = torch.xpu.amp
|
||||
torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled
|
||||
torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype
|
||||
|
||||
if not hasattr(torch.cuda.amp, "common"):
|
||||
torch.cuda.amp.common = contextlib.nullcontext()
|
||||
torch.cuda.amp.common.amp_definitely_not_available = lambda: False
|
||||
|
||||
try:
|
||||
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
@@ -151,15 +155,16 @@ def ipex_init(): # pylint: disable=too-many-statements
|
||||
torch.cuda.has_half = True
|
||||
torch.cuda.is_bf16_supported = lambda *args, **kwargs: True
|
||||
torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
|
||||
torch.version.cuda = "11.7"
|
||||
torch.cuda.get_device_capability = lambda *args, **kwargs: [11,7]
|
||||
torch.cuda.get_device_properties.major = 11
|
||||
torch.cuda.get_device_properties.minor = 7
|
||||
torch.backends.cuda.is_built = lambda *args, **kwargs: True
|
||||
torch.version.cuda = "12.1"
|
||||
torch.cuda.get_device_capability = lambda *args, **kwargs: [12,1]
|
||||
torch.cuda.get_device_properties.major = 12
|
||||
torch.cuda.get_device_properties.minor = 1
|
||||
torch.cuda.ipc_collect = lambda *args, **kwargs: None
|
||||
torch.cuda.utilization = lambda *args, **kwargs: 0
|
||||
|
||||
ipex_hijacks()
|
||||
if not torch.xpu.has_fp64_dtype():
|
||||
if not torch.xpu.has_fp64_dtype() or os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is not None:
|
||||
try:
|
||||
from .diffusers import ipex_diffusers
|
||||
ipex_diffusers()
|
||||
|
||||
@@ -124,6 +124,7 @@ def torch_bmm_32_bit(input, mat2, *, out=None):
|
||||
)
|
||||
else:
|
||||
return original_torch_bmm(input, mat2, out=out)
|
||||
torch.xpu.synchronize(input.device)
|
||||
return hidden_states
|
||||
|
||||
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
||||
@@ -172,4 +173,5 @@ def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropo
|
||||
)
|
||||
else:
|
||||
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
|
||||
torch.xpu.synchronize(query.device)
|
||||
return hidden_states
|
||||
|
||||
@@ -149,6 +149,7 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
|
||||
|
||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
|
||||
del attn_slice
|
||||
torch.xpu.synchronize(query.device)
|
||||
else:
|
||||
query_slice = query[start_idx:end_idx]
|
||||
key_slice = key[start_idx:end_idx]
|
||||
@@ -283,6 +284,7 @@ class AttnProcessor:
|
||||
|
||||
hidden_states[start_idx:end_idx] = attn_slice
|
||||
del attn_slice
|
||||
torch.xpu.synchronize(query.device)
|
||||
else:
|
||||
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
||||
hidden_states = torch.bmm(attention_probs, value)
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
import contextlib
|
||||
import os
|
||||
from functools import wraps
|
||||
from contextlib import nullcontext
|
||||
import torch
|
||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
||||
import numpy as np
|
||||
|
||||
device_supports_fp64 = torch.xpu.has_fp64_dtype()
|
||||
|
||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
|
||||
|
||||
@@ -11,7 +16,7 @@ class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstr
|
||||
return module.to("xpu")
|
||||
|
||||
def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
|
||||
return contextlib.nullcontext()
|
||||
return nullcontext()
|
||||
|
||||
@property
|
||||
def is_cuda(self):
|
||||
@@ -25,15 +30,17 @@ def return_xpu(device):
|
||||
|
||||
|
||||
# Autocast
|
||||
original_autocast = torch.autocast
|
||||
def ipex_autocast(*args, **kwargs):
|
||||
if len(args) > 0 and args[0] == "cuda":
|
||||
return original_autocast("xpu", *args[1:], **kwargs)
|
||||
original_autocast_init = torch.amp.autocast_mode.autocast.__init__
|
||||
@wraps(torch.amp.autocast_mode.autocast.__init__)
|
||||
def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=None):
|
||||
if device_type == "cuda":
|
||||
return original_autocast_init(self, device_type="xpu", dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
|
||||
else:
|
||||
return original_autocast(*args, **kwargs)
|
||||
return original_autocast_init(self, device_type=device_type, dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
|
||||
|
||||
# Latent Antialias CPU Offload:
|
||||
original_interpolate = torch.nn.functional.interpolate
|
||||
@wraps(torch.nn.functional.interpolate)
|
||||
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
|
||||
if antialias or align_corners is not None:
|
||||
return_device = tensor.device
|
||||
@@ -44,15 +51,29 @@ def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corn
|
||||
return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode,
|
||||
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias)
|
||||
|
||||
|
||||
# Diffusers Float64 (Alchemist GPUs doesn't support 64 bit):
|
||||
original_from_numpy = torch.from_numpy
|
||||
@wraps(torch.from_numpy)
|
||||
def from_numpy(ndarray):
|
||||
if ndarray.dtype == float:
|
||||
return original_from_numpy(ndarray.astype('float32'))
|
||||
else:
|
||||
return original_from_numpy(ndarray)
|
||||
|
||||
if torch.xpu.has_fp64_dtype():
|
||||
original_as_tensor = torch.as_tensor
|
||||
@wraps(torch.as_tensor)
|
||||
def as_tensor(data, dtype=None, device=None):
|
||||
if check_device(device):
|
||||
device = return_xpu(device)
|
||||
if isinstance(data, np.ndarray) and data.dtype == float and not (
|
||||
(isinstance(device, torch.device) and device.type == "cpu") or (isinstance(device, str) and "cpu" in device)):
|
||||
return original_as_tensor(data, dtype=torch.float32, device=device)
|
||||
else:
|
||||
return original_as_tensor(data, dtype=dtype, device=device)
|
||||
|
||||
|
||||
if device_supports_fp64 and os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is None:
|
||||
original_torch_bmm = torch.bmm
|
||||
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
||||
else:
|
||||
@@ -66,20 +87,25 @@ else:
|
||||
|
||||
|
||||
# Data Type Errors:
|
||||
@wraps(torch.bmm)
|
||||
def torch_bmm(input, mat2, *, out=None):
|
||||
if input.dtype != mat2.dtype:
|
||||
mat2 = mat2.to(input.dtype)
|
||||
return original_torch_bmm(input, mat2, out=out)
|
||||
|
||||
@wraps(torch.nn.functional.scaled_dot_product_attention)
|
||||
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
|
||||
if query.dtype != key.dtype:
|
||||
key = key.to(dtype=query.dtype)
|
||||
if query.dtype != value.dtype:
|
||||
value = value.to(dtype=query.dtype)
|
||||
if attn_mask is not None and query.dtype != attn_mask.dtype:
|
||||
attn_mask = attn_mask.to(dtype=query.dtype)
|
||||
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
|
||||
|
||||
# A1111 FP16
|
||||
original_functional_group_norm = torch.nn.functional.group_norm
|
||||
@wraps(torch.nn.functional.group_norm)
|
||||
def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05):
|
||||
if weight is not None and input.dtype != weight.data.dtype:
|
||||
input = input.to(dtype=weight.data.dtype)
|
||||
@@ -89,6 +115,7 @@ def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05):
|
||||
|
||||
# A1111 BF16
|
||||
original_functional_layer_norm = torch.nn.functional.layer_norm
|
||||
@wraps(torch.nn.functional.layer_norm)
|
||||
def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
|
||||
if weight is not None and input.dtype != weight.data.dtype:
|
||||
input = input.to(dtype=weight.data.dtype)
|
||||
@@ -98,6 +125,7 @@ def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1
|
||||
|
||||
# Training
|
||||
original_functional_linear = torch.nn.functional.linear
|
||||
@wraps(torch.nn.functional.linear)
|
||||
def functional_linear(input, weight, bias=None):
|
||||
if input.dtype != weight.data.dtype:
|
||||
input = input.to(dtype=weight.data.dtype)
|
||||
@@ -106,6 +134,7 @@ def functional_linear(input, weight, bias=None):
|
||||
return original_functional_linear(input, weight, bias=bias)
|
||||
|
||||
original_functional_conv2d = torch.nn.functional.conv2d
|
||||
@wraps(torch.nn.functional.conv2d)
|
||||
def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
||||
if input.dtype != weight.data.dtype:
|
||||
input = input.to(dtype=weight.data.dtype)
|
||||
@@ -115,6 +144,7 @@ def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1,
|
||||
|
||||
# A1111 Embedding BF16
|
||||
original_torch_cat = torch.cat
|
||||
@wraps(torch.cat)
|
||||
def torch_cat(tensor, *args, **kwargs):
|
||||
if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
|
||||
return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs)
|
||||
@@ -123,6 +153,7 @@ def torch_cat(tensor, *args, **kwargs):
|
||||
|
||||
# SwinIR BF16:
|
||||
original_functional_pad = torch.nn.functional.pad
|
||||
@wraps(torch.nn.functional.pad)
|
||||
def functional_pad(input, pad, mode='constant', value=None):
|
||||
if mode == 'reflect' and input.dtype == torch.bfloat16:
|
||||
return original_functional_pad(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16)
|
||||
@@ -131,13 +162,20 @@ def functional_pad(input, pad, mode='constant', value=None):
|
||||
|
||||
|
||||
original_torch_tensor = torch.tensor
|
||||
def torch_tensor(*args, device=None, **kwargs):
|
||||
@wraps(torch.tensor)
|
||||
def torch_tensor(data, *args, dtype=None, device=None, **kwargs):
|
||||
if check_device(device):
|
||||
return original_torch_tensor(*args, device=return_xpu(device), **kwargs)
|
||||
else:
|
||||
return original_torch_tensor(*args, device=device, **kwargs)
|
||||
device = return_xpu(device)
|
||||
if not device_supports_fp64:
|
||||
if (isinstance(device, torch.device) and device.type == "xpu") or (isinstance(device, str) and "xpu" in device):
|
||||
if dtype == torch.float64:
|
||||
dtype = torch.float32
|
||||
elif dtype is None and (hasattr(data, "dtype") and (data.dtype == torch.float64 or data.dtype == float)):
|
||||
dtype = torch.float32
|
||||
return original_torch_tensor(data, *args, dtype=dtype, device=device, **kwargs)
|
||||
|
||||
original_Tensor_to = torch.Tensor.to
|
||||
@wraps(torch.Tensor.to)
|
||||
def Tensor_to(self, device=None, *args, **kwargs):
|
||||
if check_device(device):
|
||||
return original_Tensor_to(self, return_xpu(device), *args, **kwargs)
|
||||
@@ -145,6 +183,7 @@ def Tensor_to(self, device=None, *args, **kwargs):
|
||||
return original_Tensor_to(self, device, *args, **kwargs)
|
||||
|
||||
original_Tensor_cuda = torch.Tensor.cuda
|
||||
@wraps(torch.Tensor.cuda)
|
||||
def Tensor_cuda(self, device=None, *args, **kwargs):
|
||||
if check_device(device):
|
||||
return original_Tensor_cuda(self, return_xpu(device), *args, **kwargs)
|
||||
@@ -152,6 +191,7 @@ def Tensor_cuda(self, device=None, *args, **kwargs):
|
||||
return original_Tensor_cuda(self, device, *args, **kwargs)
|
||||
|
||||
original_UntypedStorage_init = torch.UntypedStorage.__init__
|
||||
@wraps(torch.UntypedStorage.__init__)
|
||||
def UntypedStorage_init(*args, device=None, **kwargs):
|
||||
if check_device(device):
|
||||
return original_UntypedStorage_init(*args, device=return_xpu(device), **kwargs)
|
||||
@@ -159,6 +199,7 @@ def UntypedStorage_init(*args, device=None, **kwargs):
|
||||
return original_UntypedStorage_init(*args, device=device, **kwargs)
|
||||
|
||||
original_UntypedStorage_cuda = torch.UntypedStorage.cuda
|
||||
@wraps(torch.UntypedStorage.cuda)
|
||||
def UntypedStorage_cuda(self, device=None, *args, **kwargs):
|
||||
if check_device(device):
|
||||
return original_UntypedStorage_cuda(self, return_xpu(device), *args, **kwargs)
|
||||
@@ -166,6 +207,7 @@ def UntypedStorage_cuda(self, device=None, *args, **kwargs):
|
||||
return original_UntypedStorage_cuda(self, device, *args, **kwargs)
|
||||
|
||||
original_torch_empty = torch.empty
|
||||
@wraps(torch.empty)
|
||||
def torch_empty(*args, device=None, **kwargs):
|
||||
if check_device(device):
|
||||
return original_torch_empty(*args, device=return_xpu(device), **kwargs)
|
||||
@@ -173,6 +215,7 @@ def torch_empty(*args, device=None, **kwargs):
|
||||
return original_torch_empty(*args, device=device, **kwargs)
|
||||
|
||||
original_torch_randn = torch.randn
|
||||
@wraps(torch.randn)
|
||||
def torch_randn(*args, device=None, **kwargs):
|
||||
if check_device(device):
|
||||
return original_torch_randn(*args, device=return_xpu(device), **kwargs)
|
||||
@@ -180,6 +223,7 @@ def torch_randn(*args, device=None, **kwargs):
|
||||
return original_torch_randn(*args, device=device, **kwargs)
|
||||
|
||||
original_torch_ones = torch.ones
|
||||
@wraps(torch.ones)
|
||||
def torch_ones(*args, device=None, **kwargs):
|
||||
if check_device(device):
|
||||
return original_torch_ones(*args, device=return_xpu(device), **kwargs)
|
||||
@@ -187,6 +231,7 @@ def torch_ones(*args, device=None, **kwargs):
|
||||
return original_torch_ones(*args, device=device, **kwargs)
|
||||
|
||||
original_torch_zeros = torch.zeros
|
||||
@wraps(torch.zeros)
|
||||
def torch_zeros(*args, device=None, **kwargs):
|
||||
if check_device(device):
|
||||
return original_torch_zeros(*args, device=return_xpu(device), **kwargs)
|
||||
@@ -194,6 +239,7 @@ def torch_zeros(*args, device=None, **kwargs):
|
||||
return original_torch_zeros(*args, device=device, **kwargs)
|
||||
|
||||
original_torch_linspace = torch.linspace
|
||||
@wraps(torch.linspace)
|
||||
def torch_linspace(*args, device=None, **kwargs):
|
||||
if check_device(device):
|
||||
return original_torch_linspace(*args, device=return_xpu(device), **kwargs)
|
||||
@@ -201,6 +247,7 @@ def torch_linspace(*args, device=None, **kwargs):
|
||||
return original_torch_linspace(*args, device=device, **kwargs)
|
||||
|
||||
original_torch_Generator = torch.Generator
|
||||
@wraps(torch.Generator)
|
||||
def torch_Generator(device=None):
|
||||
if check_device(device):
|
||||
return original_torch_Generator(return_xpu(device))
|
||||
@@ -208,12 +255,14 @@ def torch_Generator(device=None):
|
||||
return original_torch_Generator(device)
|
||||
|
||||
original_torch_load = torch.load
|
||||
@wraps(torch.load)
|
||||
def torch_load(f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs):
|
||||
if check_device(map_location):
|
||||
return original_torch_load(f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs)
|
||||
else:
|
||||
return original_torch_load(f, map_location=map_location, pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs)
|
||||
|
||||
|
||||
# Hijack Functions:
|
||||
def ipex_hijacks():
|
||||
torch.tensor = torch_tensor
|
||||
@@ -232,7 +281,7 @@ def ipex_hijacks():
|
||||
torch.backends.cuda.sdp_kernel = return_null_context
|
||||
torch.nn.DataParallel = DummyDataParallel
|
||||
torch.UntypedStorage.is_cuda = is_cuda
|
||||
torch.autocast = ipex_autocast
|
||||
torch.amp.autocast_mode.autocast.__init__ = autocast_init
|
||||
|
||||
torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
|
||||
torch.nn.functional.group_norm = functional_group_norm
|
||||
@@ -244,5 +293,6 @@ def ipex_hijacks():
|
||||
|
||||
torch.bmm = torch_bmm
|
||||
torch.cat = torch_cat
|
||||
if not torch.xpu.has_fp64_dtype():
|
||||
if not device_supports_fp64:
|
||||
torch.from_numpy = from_numpy
|
||||
torch.as_tensor = as_tensor
|
||||
|
||||
@@ -511,7 +511,9 @@ def get_block_dims_and_alphas(
|
||||
len(block_dims) == num_total_blocks
|
||||
), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
|
||||
else:
|
||||
print(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
|
||||
print(
|
||||
f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります"
|
||||
)
|
||||
block_dims = [network_dim] * num_total_blocks
|
||||
|
||||
if block_alphas is not None:
|
||||
@@ -1223,3 +1225,40 @@ class LoRANetwork(torch.nn.Module):
|
||||
norms.append(scalednorm.item())
|
||||
|
||||
return keys_scaled, sum(norms) / len(norms), max(norms)
|
||||
|
||||
# region application weight
|
||||
|
||||
def get_number_of_blocks(self):
|
||||
# only for SDXL
|
||||
return 20
|
||||
|
||||
def has_text_encoder_block(self):
|
||||
return self.text_encoder_loras is not None and len(self.text_encoder_loras) > 0
|
||||
|
||||
def set_block_wise_weights(self, weights):
|
||||
if self.text_encoder_loras:
|
||||
for lora in self.text_encoder_loras:
|
||||
lora.multiplier = weights[0]
|
||||
|
||||
for lora in self.unet_loras:
|
||||
# determine block index
|
||||
key = lora.lora_name[10:] # remove "lora_unet_"
|
||||
if key.startswith("input_blocks"):
|
||||
block_index = int(key.split("_")[2]) + 1 # 1-9
|
||||
elif key.startswith("middle_block"):
|
||||
block_index = 10 # int(key.split("_")[2]) + 10
|
||||
elif key.startswith("output_blocks"):
|
||||
block_index = int(key.split("_")[2]) + 11 # 11-19
|
||||
else:
|
||||
print(f"unknown block: {key}")
|
||||
block_index = 0
|
||||
|
||||
lora.multiplier = weights[block_index]
|
||||
|
||||
# print(f"{lora.lora_name} block index: {block_index}, weight: {lora.multiplier}")
|
||||
# print(f"set block-wise weights to {weights}")
|
||||
|
||||
# TODO LoRA の weight をあらかじめ計算しておいて multiplier を掛けるだけにすると速くなるはず
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import math
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
@@ -6,8 +5,6 @@ import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
from tqdm import tqdm
|
||||
from library import sai_model_spec, train_util
|
||||
import library.model_util as model_util
|
||||
import lora
|
||||
|
||||
|
||||
CLAMP_QUANTILE = 0.99
|
||||
|
||||
1039
train_network_appl_weights.py
Normal file
1039
train_network_appl_weights.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user