support sdpa

This commit is contained in:
ykume
2023-06-11 21:26:15 +09:00
parent 4d0c06e397
commit 9e1683cf2b
9 changed files with 177 additions and 84 deletions

View File

@@ -75,8 +75,6 @@ cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_set
accelerate config
```
update: ``python -m venv venv`` is seemed to be safer than ``python -m venv --system-site-packages venv`` (some user have packages in global python).
Answers to accelerate config:
```txt
@@ -94,6 +92,30 @@ note: Some user reports ``ValueError: fp16 mixed precision requires a GPU`` is o
(Single GPU with id `0` will be used.)
### Experimental: Use PyTorch 2.0
In this case, you need to install PyTorch 2.0 and xformers 0.0.20. Instead of the above, please type the following:
```powershell
git clone https://github.com/kohya-ss/sd-scripts.git
cd sd-scripts
python -m venv venv
.\venv\Scripts\activate
pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118
pip install --upgrade -r requirements.txt
pip install xformers==0.0.20
cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py
accelerate config
```
Answers to accelerate config should be the same as above.
### about PyTorch and xformers
Other versions of PyTorch and xformers seem to have problems with training.

View File

@@ -141,7 +141,7 @@ def train(args):
# Windows版のxformersはfloatで学習できないのでxformersを使わない設定も可能にしておく必要がある
print("Disable Diffusers' xformers")
set_diffusers_xformers_flag(unet, False)
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
# 学習を準備する
if cache_latents:

View File

@@ -137,7 +137,7 @@ USE_CUTOUTS = False
"""
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
if mem_eff_attn:
print("Enable memory efficient attention for U-Net")
@@ -151,56 +151,26 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio
raise ImportError("No xformers / xformersがインストールされていないようです")
unet.set_use_memory_efficient_attention(True, False)
elif sdpa:
print("Enable SDPA for U-Net")
unet.set_use_memory_efficient_attention(False, False)
unet.set_use_sdpa(True)
# TODO common train_util.py
def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers):
def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers, sdpa):
if mem_eff_attn:
replace_vae_attn_to_memory_efficient()
elif xformers:
# とりあえずDiffusersのxformersを使う。AttentionがあるのはMidBlockのみ
replace_vae_attn_to_xformers()
elif sdpa:
replace_vae_attn_to_sdpa()
def replace_vae_attn_to_memory_efficient():
print("VAE Attention.forward has been replaced to FlashAttention (not xformers)")
flash_func = FlashAttentionFunction
def forward_flash_attn_0_14(self, hidden_states, **kwargs):
q_bucket_size = 512
k_bucket_size = 1024
residual = hidden_states
batch, channel, height, width = hidden_states.shape
# norm
hidden_states = self.group_norm(hidden_states)
hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
# proj to q, k, v
query_proj = self.query(hidden_states)
key_proj = self.key(hidden_states)
value_proj = self.value(hidden_states)
query_proj, key_proj, value_proj = map(
lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (query_proj, key_proj, value_proj)
)
out = flash_func.apply(query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size)
out = rearrange(out, "b h n d -> b n (h d)")
# compute next hidden_states
# linear proj
hidden_states = self.proj_attn(hidden_states)
hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
# res connect and rescale
hidden_states = (hidden_states + residual) / self.rescale_output_factor
return hidden_states
def forward_flash_attn(self, hidden_states, **kwargs):
q_bucket_size = 512
k_bucket_size = 1024
@@ -238,6 +208,15 @@ def replace_vae_attn_to_memory_efficient():
hidden_states = (hidden_states + residual) / self.rescale_output_factor
return hidden_states
def forward_flash_attn_0_14(self, hidden_states, **kwargs):
if not hasattr(self, "to_q"):
self.to_q = self.query
self.to_k = self.key
self.to_v = self.value
self.to_out = [self.proj_attn, torch.nn.Identity()]
self.heads = self.num_heads
return forward_flash_attn(self, hidden_states, **kwargs)
if diffusers.__version__ < "0.15.0":
diffusers.models.attention.AttentionBlock.forward = forward_flash_attn_0_14
else:
@@ -248,40 +227,6 @@ def replace_vae_attn_to_xformers():
print("VAE: Attention.forward has been replaced to xformers")
import xformers.ops
def forward_xformers_0_14(self, hidden_states, **kwargs):
residual = hidden_states
batch, channel, height, width = hidden_states.shape
# norm
hidden_states = self.group_norm(hidden_states)
hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
# proj to q, k, v
query_proj = self.query(hidden_states)
key_proj = self.key(hidden_states)
value_proj = self.value(hidden_states)
query_proj, key_proj, value_proj = map(
lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (query_proj, key_proj, value_proj)
)
query_proj = query_proj.contiguous()
key_proj = key_proj.contiguous()
value_proj = value_proj.contiguous()
out = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None)
out = rearrange(out, "b h n d -> b n (h d)")
# compute next hidden_states
hidden_states = self.proj_attn(hidden_states)
hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
# res connect and rescale
hidden_states = (hidden_states + residual) / self.rescale_output_factor
return hidden_states
def forward_xformers(self, hidden_states, **kwargs):
residual = hidden_states
batch, channel, height, width = hidden_states.shape
@@ -319,12 +264,75 @@ def replace_vae_attn_to_xformers():
hidden_states = (hidden_states + residual) / self.rescale_output_factor
return hidden_states
def forward_xformers_0_14(self, hidden_states, **kwargs):
if not hasattr(self, "to_q"):
self.to_q = self.query
self.to_k = self.key
self.to_v = self.value
self.to_out = [self.proj_attn, torch.nn.Identity()]
self.heads = self.num_heads
return forward_xformers(self, hidden_states, **kwargs)
if diffusers.__version__ < "0.15.0":
diffusers.models.attention.AttentionBlock.forward = forward_xformers_0_14
else:
diffusers.models.attention_processor.Attention.forward = forward_xformers
def replace_vae_attn_to_sdpa():
print("VAE: Attention.forward has been replaced to sdpa")
def forward_sdpa(self, hidden_states, **kwargs):
residual = hidden_states
batch, channel, height, width = hidden_states.shape
# norm
hidden_states = self.group_norm(hidden_states)
hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
# proj to q, k, v
query_proj = self.to_q(hidden_states)
key_proj = self.to_k(hidden_states)
value_proj = self.to_v(hidden_states)
query_proj, key_proj, value_proj = map(
lambda t: rearrange(t, "b n (h d) -> b n h d", h=self.heads), (query_proj, key_proj, value_proj)
)
out = torch.nn.functional.scaled_dot_product_attention(
query_proj, key_proj, value_proj, attn_mask=None, dropout_p=0.0, is_causal=False
)
out = rearrange(out, "b n h d -> b n (h d)")
# compute next hidden_states
# linear proj
hidden_states = self.to_out[0](hidden_states)
# dropout
hidden_states = self.to_out[1](hidden_states)
hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
# res connect and rescale
hidden_states = (hidden_states + residual) / self.rescale_output_factor
return hidden_states
def forward_sdpa_0_14(self, hidden_states, **kwargs):
if not hasattr(self, "to_q"):
self.to_q = self.query
self.to_k = self.key
self.to_v = self.value
self.to_out = [self.proj_attn, torch.nn.Identity()]
self.heads = self.num_heads
return forward_sdpa(self, hidden_states, **kwargs)
if diffusers.__version__ < "0.15.0":
diffusers.models.attention.AttentionBlock.forward = forward_sdpa_0_14
else:
diffusers.models.attention_processor.Attention.forward = forward_sdpa
# endregion
# region 画像生成の本体lpw_stable_diffusion.py ASLからコピーして修正
@@ -2082,8 +2090,9 @@ def main(args):
# xformers、Hypernetwork対応
if not args.diffusers_xformers:
replace_unet_modules(unet, not args.xformers, args.xformers)
replace_vae_modules(vae, not args.xformers, args.xformers)
mem_eff = not (args.xformers or args.sdpa)
replace_unet_modules(unet, mem_eff, args.xformers, args.sdpa)
replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa)
# tokenizerを読み込む
print("loading tokenizer")
@@ -3176,6 +3185,7 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument("--fp16", action="store_true", help="use fp16 / fp16を指定し省メモリ化する")
parser.add_argument("--bf16", action="store_true", help="use bfloat16 / bfloat16を指定し省メモリ化する")
parser.add_argument("--xformers", action="store_true", help="use xformers / xformersを使用し高速化する")
parser.add_argument("--sdpa", action="store_true", help="use sdpa in PyTorch 2 / sdpa")
parser.add_argument(
"--diffusers_xformers",
action="store_true",

View File

@@ -494,6 +494,9 @@ class DownBlock2D(nn.Module):
def set_use_memory_efficient_attention(self, xformers, mem_eff):
pass
def set_use_sdpa(self, sdpa):
pass
def forward(self, hidden_states, temb=None):
output_states = ()
@@ -564,11 +567,15 @@ class CrossAttention(nn.Module):
self.use_memory_efficient_attention_xformers = False
self.use_memory_efficient_attention_mem_eff = False
self.use_sdpa = False
def set_use_memory_efficient_attention(self, xformers, mem_eff):
self.use_memory_efficient_attention_xformers = xformers
self.use_memory_efficient_attention_mem_eff = mem_eff
def set_use_sdpa(self, sdpa):
self.use_sdpa = sdpa
def reshape_heads_to_batch_dim(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
@@ -588,6 +595,8 @@ class CrossAttention(nn.Module):
return self.forward_memory_efficient_xformers(hidden_states, context, mask)
if self.use_memory_efficient_attention_mem_eff:
return self.forward_memory_efficient_mem_eff(hidden_states, context, mask)
if self.use_sdpa:
return self.forward_sdpa(hidden_states, context, mask)
query = self.to_q(hidden_states)
context = context if context is not None else hidden_states
@@ -676,6 +685,26 @@ class CrossAttention(nn.Module):
out = self.to_out[0](out)
return out
def forward_sdpa(self, x, context=None, mask=None):
import xformers.ops
h = self.heads
q_in = self.to_q(x)
context = context if context is not None else x
context = context.to(x.dtype)
k_in = self.to_k(context)
v_in = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
out = rearrange(out, "b h n d -> b n (h d)", h=h)
out = self.to_out[0](out)
return out
# feedforward
class GEGLU(nn.Module):
@@ -759,6 +788,10 @@ class BasicTransformerBlock(nn.Module):
self.attn1.set_use_memory_efficient_attention(xformers, mem_eff)
self.attn2.set_use_memory_efficient_attention(xformers, mem_eff)
def set_use_sdpa(self, sdpa: bool):
self.attn1.set_use_sdpa(sdpa)
self.attn2.set_use_sdpa(sdpa)
def forward(self, hidden_states, context=None, timestep=None):
# 1. Self-Attention
norm_hidden_states = self.norm1(hidden_states)
@@ -820,6 +853,10 @@ class Transformer2DModel(nn.Module):
for transformer in self.transformer_blocks:
transformer.set_use_memory_efficient_attention(xformers, mem_eff)
def set_use_sdpa(self, sdpa):
for transformer in self.transformer_blocks:
transformer.set_use_sdpa(sdpa)
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
# 1. Input
batch, _, height, weight = hidden_states.shape
@@ -901,6 +938,10 @@ class CrossAttnDownBlock2D(nn.Module):
for attn in self.attentions:
attn.set_use_memory_efficient_attention(xformers, mem_eff)
def set_use_sdpa(self, sdpa):
for attn in self.attentions:
attn.set_use_sdpa(sdpa)
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
output_states = ()
@@ -978,6 +1019,10 @@ class UNetMidBlock2DCrossAttn(nn.Module):
for attn in self.attentions:
attn.set_use_memory_efficient_attention(xformers, mem_eff)
def set_use_sdpa(self, sdpa):
for attn in self.attentions:
attn.set_use_sdpa(sdpa)
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
for i, resnet in enumerate(self.resnets):
attn = None if i == 0 else self.attentions[i - 1]
@@ -1079,6 +1124,9 @@ class UpBlock2D(nn.Module):
def set_use_memory_efficient_attention(self, xformers, mem_eff):
pass
def set_use_sdpa(self, sdpa):
pass
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
for resnet in self.resnets:
# pop res hidden states
@@ -1159,6 +1207,10 @@ class CrossAttnUpBlock2D(nn.Module):
for attn in self.attentions:
attn.set_use_memory_efficient_attention(xformers, mem_eff)
def set_use_sdpa(self, spda):
for attn in self.attentions:
attn.set_use_sdpa(spda)
def forward(
self,
hidden_states,
@@ -1393,10 +1445,15 @@ class UNet2DConditionModel(nn.Module):
def disable_gradient_checkpointing(self):
self.set_gradient_checkpointing(value=False)
def set_use_memory_efficient_attention(self, xformers: bool,mem_eff:bool) -> None:
def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool) -> None:
modules = self.down_blocks + [self.mid_block] + self.up_blocks
for module in modules:
module.set_use_memory_efficient_attention(xformers,mem_eff)
module.set_use_memory_efficient_attention(xformers, mem_eff)
def set_use_sdpa(self, sdpa: bool) -> None:
modules = self.down_blocks + [self.mid_block] + self.up_blocks
for module in modules:
module.set_use_sdpa(sdpa)
def set_gradient_checkpointing(self, value=False):
modules = self.down_blocks + [self.mid_block] + self.up_blocks

View File

@@ -1788,7 +1788,7 @@ class FlashAttentionFunction(torch.autograd.function.Function):
return dq, dk, dv, None, None, None, None
def replace_unet_modules(unet:UNet2DConditionModel, mem_eff_attn, xformers):
def replace_unet_modules(unet:UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
if mem_eff_attn:
print("Enable memory efficient attention for U-Net")
unet.set_use_memory_efficient_attention(False, True)
@@ -1800,6 +1800,9 @@ def replace_unet_modules(unet:UNet2DConditionModel, mem_eff_attn, xformers):
raise ImportError("No xformers / xformersがインストールされていないようです")
unet.set_use_memory_efficient_attention(True, False)
elif sdpa:
print("Enable SDPA for U-Net")
unet.set_use_sdpa(True)
"""
def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers):
@@ -2048,6 +2051,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う",
)
parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
parser.add_argument("--sdpa", action="store_true", help="use sdpa for CrossAttention (requires PyTorch 2.0) / CrossAttentionにsdpaを使うPyTorch 2.0が必要)")
parser.add_argument(
"--vae", type=str, default=None, help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ"
)

View File

@@ -119,7 +119,7 @@ def train(args):
use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())
# モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
# 学習を準備する
if cache_latents:

View File

@@ -160,7 +160,7 @@ def train(args):
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
# モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
# 差分追加学習のためにモデルを読み込む
import sys

View File

@@ -231,7 +231,7 @@ def train(args):
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
# モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
# 学習を準備する
if cache_latents:

View File

@@ -264,7 +264,7 @@ def train(args):
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
# モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
original_unet.UNet2DConditionModel.forward = unet_forward_XTI
original_unet.CrossAttnDownBlock2D.forward = downblock_forward_XTI
original_unet.CrossAttnUpBlock2D.forward = upblock_forward_XTI