move code for xformers to train_util

This commit is contained in:
Kohya S
2023-01-02 16:08:21 +09:00
parent bda0e8333c
commit 6b522b34c1
4 changed files with 940 additions and 1501 deletions

View File

@@ -1,35 +1,4 @@
# v2: select precision for saved checkpoint
# v3: add logging for tensorboard, fix to shuffle=False in DataLoader (shuffling is in dataset)
# v4: support SD2.0, add lr scheduler options, supports save_every_n_epochs and save_state for DiffUsers model
# v5: refactor to use model_util, support safetensors, add settings to use Diffusers' xformers, add log prefix
# v6: model_util update
# v7: support Diffusers 0.10.0 (v-parameterization training, safetensors in Diffusers) and accelerate 0.15.0, support full path in metadata
# v8: experimental full fp16 training.
# v9: add keep_tokens and save_model_as option, flip augmentation
# このスクリプトのライセンスは、train_dreambooth.pyと同じくApache License 2.0とします
# License:
# Copyright 2022 Kohya S. @kohya_ss
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# License of included scripts:
# Diffusers: ASL 2.0 https://github.com/huggingface/diffusers/blob/main/LICENSE
# Memory efficient attention:
# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
# MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
# training with captions
import argparse
import math
@@ -51,17 +20,7 @@ from einops import rearrange
from torch import einsum
import library.model_util as model_util
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ
# checkpointファイル名
EPOCH_STATE_NAME = "epoch-{:06d}-state"
LAST_STATE_NAME = "last-state"
LAST_DIFFUSERS_DIR_NAME = "last"
EPOCH_DIFFUSERS_DIR_NAME = "epoch-{:06d}"
import library.train_util as train_util
def collate_fn(examples):
@@ -288,9 +247,9 @@ def train(args):
# tokenizerを読み込む
print("prepare tokenizer")
if args.v2:
tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
tokenizer = CLIPTokenizer.from_pretrained(train_util.V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
else:
tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)
tokenizer = CLIPTokenizer.from_pretrained(train_util.TOKENIZER_PATH)
if args.max_token_length is not None:
print(f"update token length: {args.max_token_length}")
@@ -403,7 +362,7 @@ def train(args):
# Windows版のxformersはfloatで学習できないのでxformersを使わない設定も可能にしておく必要がある
print("Disable Diffusers' xformers")
set_diffusers_xformers_flag(unet, False)
replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
if not fine_tuning:
# Hypernetwork
@@ -667,7 +626,7 @@ def train(args):
model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, unwrap_model(text_encoder), unwrap_model(unet),
src_stable_diffusion_ckpt, epoch + 1, global_step, save_dtype, vae)
else:
out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1))
out_dir = os.path.join(args.output_dir, train_util.EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1))
os.makedirs(out_dir, exist_ok=True)
model_util.save_diffusers_checkpoint(args.v2, out_dir, unwrap_model(text_encoder), unwrap_model(unet),
src_diffusers_model_path, vae=vae, use_safetensors=use_safetensors)
@@ -676,7 +635,7 @@ def train(args):
if args.save_state:
print("saving state.")
accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(epoch + 1)))
accelerator.save_state(os.path.join(args.output_dir, train_util.EPOCH_STATE_NAME.format(epoch + 1)))
is_main_process = accelerator.is_main_process
if is_main_process:
@@ -690,7 +649,7 @@ def train(args):
if args.save_state:
print("saving last state.")
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME))
accelerator.save_state(os.path.join(args.output_dir, train_util.LAST_STATE_NAME))
del accelerator # この後メモリを使うのでこれは消す
@@ -706,7 +665,7 @@ def train(args):
else:
# Create the pipeline using using the trained modules and save it.
print(f"save trained model as Diffusers to {args.output_dir}")
out_dir = os.path.join(args.output_dir, LAST_DIFFUSERS_DIR_NAME)
out_dir = os.path.join(args.output_dir, train_util.LAST_DIFFUSERS_DIR_NAME)
os.makedirs(out_dir, exist_ok=True)
model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet,
src_diffusers_model_path, vae=vae, use_safetensors=use_safetensors)
@@ -717,275 +676,6 @@ def train(args):
print("model saved.")
# region モジュール入れ替え部
"""
高速化のためのモジュール入れ替え
"""
# FlashAttentionを使うCrossAttention
# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
# constants
EPSILON = 1e-6
# helper functions
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
# flash attention forwards and backwards
# https://arxiv.org/abs/2205.14135
class FlashAttentionFunction(torch.autograd.function.Function):
@ staticmethod
@ torch.no_grad()
def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
""" Algorithm 2 in the paper """
device = q.device
dtype = q.dtype
max_neg_value = -torch.finfo(q.dtype).max
qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
o = torch.zeros_like(q)
all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)
scale = (q.shape[-1] ** -0.5)
if not exists(mask):
mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
else:
mask = rearrange(mask, 'b n -> b 1 1 n')
mask = mask.split(q_bucket_size, dim=-1)
row_splits = zip(
q.split(q_bucket_size, dim=-2),
o.split(q_bucket_size, dim=-2),
mask,
all_row_sums.split(q_bucket_size, dim=-2),
all_row_maxes.split(q_bucket_size, dim=-2),
)
for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
q_start_index = ind * q_bucket_size - qk_len_diff
col_splits = zip(
k.split(k_bucket_size, dim=-2),
v.split(k_bucket_size, dim=-2),
)
for k_ind, (kc, vc) in enumerate(col_splits):
k_start_index = k_ind * k_bucket_size
attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
if exists(row_mask):
attn_weights.masked_fill_(~row_mask, max_neg_value)
if causal and q_start_index < (k_start_index + k_bucket_size - 1):
causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool,
device=device).triu(q_start_index - k_start_index + 1)
attn_weights.masked_fill_(causal_mask, max_neg_value)
block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
attn_weights -= block_row_maxes
exp_weights = torch.exp(attn_weights)
if exists(row_mask):
exp_weights.masked_fill_(~row_mask, 0.)
block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)
new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc)
exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums
oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)
row_maxes.copy_(new_row_maxes)
row_sums.copy_(new_row_sums)
ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
return o
@ staticmethod
@ torch.no_grad()
def backward(ctx, do):
""" Algorithm 4 in the paper """
causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
q, k, v, o, l, m = ctx.saved_tensors
device = q.device
max_neg_value = -torch.finfo(q.dtype).max
qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
dq = torch.zeros_like(q)
dk = torch.zeros_like(k)
dv = torch.zeros_like(v)
row_splits = zip(
q.split(q_bucket_size, dim=-2),
o.split(q_bucket_size, dim=-2),
do.split(q_bucket_size, dim=-2),
mask,
l.split(q_bucket_size, dim=-2),
m.split(q_bucket_size, dim=-2),
dq.split(q_bucket_size, dim=-2)
)
for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
q_start_index = ind * q_bucket_size - qk_len_diff
col_splits = zip(
k.split(k_bucket_size, dim=-2),
v.split(k_bucket_size, dim=-2),
dk.split(k_bucket_size, dim=-2),
dv.split(k_bucket_size, dim=-2),
)
for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
k_start_index = k_ind * k_bucket_size
attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
if causal and q_start_index < (k_start_index + k_bucket_size - 1):
causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool,
device=device).triu(q_start_index - k_start_index + 1)
attn_weights.masked_fill_(causal_mask, max_neg_value)
exp_attn_weights = torch.exp(attn_weights - mc)
if exists(row_mask):
exp_attn_weights.masked_fill_(~row_mask, 0.)
p = exp_attn_weights / lc
dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc)
dp = einsum('... i d, ... j d -> ... i j', doc, vc)
D = (doc * oc).sum(dim=-1, keepdims=True)
ds = p * scale * (dp - D)
dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc)
dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc)
dqc.add_(dq_chunk)
dkc.add_(dk_chunk)
dvc.add_(dv_chunk)
return dq, dk, dv, None, None, None, None
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
if mem_eff_attn:
replace_unet_cross_attn_to_memory_efficient()
elif xformers:
replace_unet_cross_attn_to_xformers()
def replace_unet_cross_attn_to_memory_efficient():
print("Replace CrossAttention.forward to use FlashAttention (not xformers)")
flash_func = FlashAttentionFunction
def forward_flash_attn(self, x, context=None, mask=None):
q_bucket_size = 512
k_bucket_size = 1024
h = self.heads
q = self.to_q(x)
context = context if context is not None else x
context = context.to(x.dtype)
if hasattr(self, 'hypernetwork') and self.hypernetwork is not None:
context_k, context_v = self.hypernetwork.forward(x, context)
context_k = context_k.to(x.dtype)
context_v = context_v.to(x.dtype)
else:
context_k = context
context_v = context
k = self.to_k(context_k)
v = self.to_v(context_v)
del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)
out = rearrange(out, 'b h n d -> b n (h d)')
# diffusers 0.7.0~ わざわざ変えるなよ (;´Д`)
out = self.to_out[0](out)
out = self.to_out[1](out)
return out
diffusers.models.attention.CrossAttention.forward = forward_flash_attn
def replace_unet_cross_attn_to_xformers():
print("Replace CrossAttention.forward to use xformers")
try:
import xformers.ops
except ImportError:
raise ImportError("No xformers / xformersがインストールされていないようです")
def forward_xformers(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
context = default(context, x)
context = context.to(x.dtype)
if hasattr(self, 'hypernetwork') and self.hypernetwork is not None:
context_k, context_v = self.hypernetwork.forward(x, context)
context_k = context_k.to(x.dtype)
context_v = context_v.to(x.dtype)
else:
context_k = context
context_v = context
k_in = self.to_k(context_k)
v_in = self.to_v(context_v)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
# diffusers 0.7.0~
out = self.to_out[0](out)
out = self.to_out[1](out)
return out
diffusers.models.attention.CrossAttention.forward = forward_xformers
# endregion
if __name__ == '__main__':
# torch.cuda.set_per_process_memory_fraction(0.48)
parser = argparse.ArgumentParser()