mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 21:52:27 +00:00
Compare commits
2 Commits
feat-safet
...
free-u
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2bdcd9b2db | ||
|
|
40525d4f4b |
@@ -996,6 +996,109 @@ class SdxlUNet2DConditionModel(nn.Module):
|
||||
[GroupNorm32(32, self.model_channels), nn.SiLU(), nn.Conv2d(self.model_channels, self.out_channels, 3, padding=1)]
|
||||
)
|
||||
|
||||
# FreeU
|
||||
self.freeU = False
|
||||
self.freeUB1 = 1.0
|
||||
self.freeUB2 = 1.0
|
||||
self.freeUS1 = 1.0
|
||||
self.freeUS2 = 1.0
|
||||
self.freeURThres = 1
|
||||
|
||||
# implementation of FreeU
|
||||
# FreeU: Free Lunch in Diffusion U-Net https://arxiv.org/abs/2309.11497
|
||||
|
||||
def set_free_u_enabled(self, enabled: bool, b1=1.0, b2=1.0, s1=1.0, s2=1.0, rthresh=1):
    """
    Switch FreeU on or off and remember its hyper-parameters on the model.

    The values are only stored here; they are read later during the forward pass.

    :param enabled: whether the FreeU branches in the forward pass are taken
    :param b1: backbone multiplier, stored as ``freeUB1`` (used for 1280-channel features)
    :param b2: backbone multiplier, stored as ``freeUB2`` (used for 640-channel features)
    :param s1: skip-feature spectral scale, stored as ``freeUS1`` (used for 1280-channel skips)
    :param s2: skip-feature spectral scale, stored as ``freeUS2`` (used for 640-channel skips)
    :param rthresh: frequency threshold handed to ``spectral_modulation``, stored as ``freeURThres``
    """
    # Report the requested configuration before applying it.
    print(f"FreeU: {enabled}, b1={b1}, b2={b2}, s1={s1}, s2={s2}, rthresh={rthresh}")

    self.freeU = enabled
    self.freeUB1, self.freeUB2 = b1, b2  # backbone scaling factors
    self.freeUS1, self.freeUS2 = s1, s2  # skip-connection scaling factors
    self.freeURThres = rthresh
|
||||
|
||||
def spectral_modulation(self, skip_feature, sl=1.0, rthresh=1):
    """
    Modify a skip feature in the frequency domain (FreeU skip-feature scaling).

    This follows the approach of the official FreeU repository rather than the
    radial mask described in the paper (https://arxiv.org/abs/2309.11497):
    the 2D spectrum over the last two dimensions is shifted so that the zero
    frequency sits at the centre, a square window around the centre is
    multiplied by ``sl``, and the result is transformed back.

    :param skip_feature: skip feature tensor [b, c, H, W]
    :param sl: scaling factor applied inside the low-frequency window
    :param rthresh: frequency threshold; half-width of the window in frequency
        bins (must be an integer because it is used as a slice offset)
    :return: modified skip feature with the same shape and dtype as the input
    """
    import torch.fft

    org_dtype = skip_feature.dtype

    # FFT over the spatial dimensions; the input is cast to float32 first and
    # cast back to its original dtype at the end.
    x_freq = torch.fft.fftn(skip_feature.float(), dim=(-2, -1))
    x_freq = torch.fft.fftshift(x_freq, dim=(-2, -1))

    H, W = x_freq.shape[-2:]
    crow, ccol = H // 2, W // 2

    # Clamp the window start at 0: a negative slice start would wrap around to
    # the end of the axis and scale the wrong (much smaller) region whenever
    # rthresh is larger than half of the spatial size.
    top = max(crow - rthresh, 0)
    left = max(ccol - rthresh, 0)

    # The mask is identical for every batch item and channel, so build it once
    # as [H, W] and let broadcasting apply it to the whole [b, c, H, W] spectrum.
    mask = torch.ones((H, W), device=skip_feature.device)
    mask[top : crow + rthresh, left : ccol + rthresh] = sl
    x_freq = x_freq * mask

    # IFFT: undo the shift, transform back and keep the real part only.
    x_freq = torch.fft.ifftshift(x_freq, dim=(-2, -1))
    modified_skip_feature = torch.fft.ifftn(x_freq, dim=(-2, -1)).real

    return modified_skip_feature.to(org_dtype)
|
||||
|
||||
# region diffusers compatibility
|
||||
def prepare_config(self):
|
||||
self.config = SimpleNamespace()
|
||||
@@ -1079,11 +1182,30 @@ class SdxlUNet2DConditionModel(nn.Module):
|
||||
h = x
|
||||
for module in self.input_blocks:
|
||||
h = call_module(module, h, emb, context)
|
||||
hs.append(h)
|
||||
|
||||
if self.freeU:
|
||||
ch = h.shape[1]
|
||||
s = self.freeUS1 if ch == 1280 else (self.freeUS2 if ch == 640 else 1.0)
|
||||
if s == 1.0:
|
||||
h_mod = h
|
||||
else:
|
||||
h_mod = self.spectral_modulation(h, s, self.freeURThres)
|
||||
hs.append(h_mod)
|
||||
else:
|
||||
hs.append(h)
|
||||
|
||||
h = call_module(self.middle_block, h, emb, context)
|
||||
|
||||
for module in self.output_blocks:
|
||||
if self.freeU:
|
||||
ch = h.shape[1]
|
||||
if ch == 1280:
|
||||
h[:, : ch // 2] = h[:, : ch // 2] * self.freeUB1
|
||||
elif ch == 640:
|
||||
h[:, : ch // 2] = h[:, : ch // 2] * self.freeUB2
|
||||
# else:
|
||||
# print(f"disable freeU: {ch}")
|
||||
|
||||
h = torch.cat([h, hs.pop()], dim=1)
|
||||
h = call_module(module, h, emb, context)
|
||||
|
||||
|
||||
@@ -1521,6 +1521,10 @@ def main(args):
|
||||
text_encoder2.to(dtype).to(device)
|
||||
unet.to(dtype).to(device)
|
||||
|
||||
# freeU
|
||||
# unet.set_free_u_enabled(False, 1.0, 1.0, 0)
|
||||
unet.set_free_u_enabled(True, 1.1, 1.2, 0.9, 0.2)
|
||||
|
||||
# networkを組み込む
|
||||
if args.network_module:
|
||||
networks = []
|
||||
|
||||
Reference in New Issue
Block a user