Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-06 13:47:06 +00:00)

Commit: support SD3.5M
@@ -51,7 +51,7 @@ class SD3Params:
     pos_embed_max_size: int
     adm_in_channels: int
     qk_norm: Optional[str]
-    x_block_self_attn_layers: List[int]
+    x_block_self_attn_layers: list[int]
     context_embedder_in_features: int
    context_embedder_out_features: int
     model_type: str
@@ -510,6 +510,7 @@ class SingleDiTBlock(nn.Module):
         scale_mod_only: bool = False,
         swiglu: bool = False,
         qk_norm: Optional[str] = None,
+        x_block_self_attn: bool = False,
         **block_kwargs,
     ):
         super().__init__()
@@ -519,13 +520,14 @@ class SingleDiTBlock(nn.Module):
             self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
         else:
             self.norm1 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
-        self.attn = AttentionLinears(
-            dim=hidden_size,
-            num_heads=num_heads,
-            qkv_bias=qkv_bias,
-            pre_only=pre_only,
-            qk_norm=qk_norm,
-        )
+        self.attn = AttentionLinears(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, pre_only=pre_only, qk_norm=qk_norm)
+
+        self.x_block_self_attn = x_block_self_attn
+        if self.x_block_self_attn:
+            assert not pre_only
+            assert not scale_mod_only
+            self.attn2 = AttentionLinears(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, pre_only=False, qk_norm=qk_norm)
+
         if not pre_only:
             if not rmsnorm:
                 self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
@@ -546,7 +548,9 @@ class SingleDiTBlock(nn.Module):
                 multiple_of=256,
             )
         self.scale_mod_only = scale_mod_only
-        if not scale_mod_only:
+        if self.x_block_self_attn:
+            n_mods = 9
+        elif not scale_mod_only:
             n_mods = 6 if not pre_only else 2
         else:
             n_mods = 4 if not pre_only else 1
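The 9-way branch exists because an x_block_self_attn layer needs shift/scale/gate for the second attention on top of the usual six MSA/MLP terms. A minimal sketch of the resulting split; the sizes are hypothetical, and only the chunk bookkeeping is taken from the diff (adaLN_modulation is assumed to be the usual SiLU + Linear producing n_mods * hidden_size features):

import torch

B, hidden_size = 2, 1536
mods = torch.randn(B, 9 * hidden_size)  # stands in for adaLN_modulation(c) with n_mods = 9
(shift_msa, scale_msa, gate_msa,
 shift_mlp, scale_mlp, gate_mlp,
 shift_msa2, scale_msa2, gate_msa2) = mods.chunk(9, dim=1)
assert shift_msa2.shape == (B, hidden_size)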
@@ -556,63 +560,64 @@ class SingleDiTBlock(nn.Module):
     def pre_attention(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
         if not self.pre_only:
             if not self.scale_mod_only:
-                (
-                    shift_msa,
-                    scale_msa,
-                    gate_msa,
-                    shift_mlp,
-                    scale_mlp,
-                    gate_mlp,
-                ) = self.adaLN_modulation(
-                    c
-                ).chunk(6, dim=-1)
+                (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp) = self.adaLN_modulation(c).chunk(6, dim=-1)
             else:
                 shift_msa = None
                 shift_mlp = None
-                (
-                    scale_msa,
-                    gate_msa,
-                    scale_mlp,
-                    gate_mlp,
-                ) = self.adaLN_modulation(
-                    c
-                ).chunk(4, dim=-1)
+                (scale_msa, gate_msa, scale_mlp, gate_mlp) = self.adaLN_modulation(c).chunk(4, dim=-1)
             qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
-            return qkv, (
-                x,
-                gate_msa,
-                shift_mlp,
-                scale_mlp,
-                gate_mlp,
-            )
+            return qkv, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp)
         else:
             if not self.scale_mod_only:
-                (
-                    shift_msa,
-                    scale_msa,
-                ) = self.adaLN_modulation(
-                    c
-                ).chunk(2, dim=-1)
+                (shift_msa, scale_msa) = self.adaLN_modulation(c).chunk(2, dim=-1)
             else:
                 shift_msa = None
                 scale_msa = self.adaLN_modulation(c)
             qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
             return qkv, None
 
+    def pre_attention_x(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
+        assert self.x_block_self_attn
+        (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2) = self.adaLN_modulation(
+            c
+        ).chunk(9, dim=1)
+        x_norm = self.norm1(x)
+        qkv = self.attn.pre_attention(modulate(x_norm, shift_msa, scale_msa))
+        qkv2 = self.attn2.pre_attention(modulate(x_norm, shift_msa2, scale_msa2))
+        return qkv, qkv2, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2)
+
     def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp):
         assert not self.pre_only
         x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn)
         x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
         return x
 
+    def post_attention_x(self, attn, attn2, x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2, attn1_dropout: float = 0.0):
+        assert not self.pre_only
+        if attn1_dropout > 0.0:
+            # Use torch.bernoulli to implement dropout, only dropout the batch dimension
+            attn1_dropout = torch.bernoulli(torch.full((attn.size(0), 1, 1), 1 - attn1_dropout, device=attn.device))
+            attn_ = gate_msa.unsqueeze(1) * self.attn.post_attention(attn) * attn1_dropout
+        else:
+            attn_ = gate_msa.unsqueeze(1) * self.attn.post_attention(attn)
+        x = x + attn_
+        attn2_ = gate_msa2.unsqueeze(1) * self.attn2.post_attention(attn2)
+        x = x + attn2_
+        mlp_ = gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
+        x = x + mlp_
+        return x
+
 
 # JointBlock + block_mixing in mmdit.py
 class MMDiTBlock(nn.Module):
     def __init__(self, *args, **kwargs):
         super().__init__()
         pre_only = kwargs.pop("pre_only")
+        x_block_self_attn = kwargs.pop("x_block_self_attn")
 
         self.context_block = SingleDiTBlock(*args, pre_only=pre_only, **kwargs)
-        self.x_block = SingleDiTBlock(*args, pre_only=False, **kwargs)
+        self.x_block = SingleDiTBlock(*args, pre_only=False, x_block_self_attn=x_block_self_attn, **kwargs)
 
         self.head_dim = self.x_block.attn.head_dim
         self.mode = self.x_block.attn_mode
         self.gradient_checkpointing = False
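Note that attn1_dropout in post_attention_x zeroes the first attention's contribution for whole batch elements rather than individual activations, and does not rescale the samples it keeps. A standalone sketch of that masking, assuming a (B, L, D) attention output:

import torch

def batch_dropout(attn_out: torch.Tensor, p: float) -> torch.Tensor:
    # One Bernoulli draw per batch element, broadcast over sequence and channels.
    # Unlike nn.Dropout, surviving samples are not rescaled by 1 / (1 - p).
    keep = torch.bernoulli(torch.full((attn_out.size(0), 1, 1), 1 - p, device=attn_out.device))
    return attn_out * keep

out = batch_dropout(torch.randn(4, 77, 1536), 0.5)  # some rows become exactly zero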
@@ -622,7 +627,11 @@ class MMDiTBlock(nn.Module):

     def _forward(self, context, x, c):
         ctx_qkv, ctx_intermediate = self.context_block.pre_attention(context, c)
-        x_qkv, x_intermediate = self.x_block.pre_attention(x, c)
+
+        if self.x_block.x_block_self_attn:
+            x_qkv, x_qkv2, x_intermediates = self.x_block.pre_attention_x(x, c)
+        else:
+            x_qkv, x_intermediates = self.x_block.pre_attention(x, c)

         ctx_len = ctx_qkv[0].size(1)

@@ -634,11 +643,18 @@ class MMDiTBlock(nn.Module):
         ctx_attn_out = attn[:, :ctx_len]
         x_attn_out = attn[:, ctx_len:]

-        x = self.x_block.post_attention(x_attn_out, *x_intermediate)
+        if self.x_block.x_block_self_attn:
+            x_q2, x_k2, x_v2 = x_qkv2
+            attn2 = attention(x_q2, x_k2, x_v2, self.x_block.attn2.num_heads)
+            x = self.x_block.post_attention_x(x_attn_out, attn2, *x_intermediates)
+        else:
+            x = self.x_block.post_attention(x_attn_out, *x_intermediates)
+
         if not self.context_block.pre_only:
             context = self.context_block.post_attention(ctx_attn_out, *ctx_intermediate)
         else:
             context = None

         return context, x

     def forward(self, *args, **kwargs):
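The lines elided between these two hunks presumably do what MMDiT always does: concatenate the context and image q/k/v along the sequence dimension and run a single joint attention, which is why ctx_len is needed to split the result back apart. A simplified single-head sketch of that pattern (F.scaled_dot_product_attention stands in for this repo's attention helper; head reshaping is omitted):

import torch
import torch.nn.functional as F

def joint_attention_sketch(ctx_qkv, x_qkv):
    # Concatenate context and image tokens along the sequence dimension,
    # attend once, then split the output back at ctx_len.
    q, k, v = (torch.cat([c, i], dim=1) for c, i in zip(ctx_qkv, x_qkv))
    attn = F.scaled_dot_product_attention(q, k, v)
    ctx_len = ctx_qkv[0].size(1)
    return attn[:, :ctx_len], attn[:, ctx_len:]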
@@ -678,7 +694,9 @@ class MMDiT(nn.Module):
         pos_embed_max_size: Optional[int] = None,
         num_patches=None,
         qk_norm: Optional[str] = None,
+        x_block_self_attn_layers: Optional[list[int]] = [],
         qkv_bias: bool = True,
+        pos_emb_random_crop_rate: float = 0.0,
         model_type: str = "sd3m",
     ):
         super().__init__()
@@ -691,6 +709,8 @@ class MMDiT(nn.Module):
         self.pos_embed_scaling_factor = pos_embed_scaling_factor
         self.pos_embed_offset = pos_embed_offset
         self.pos_embed_max_size = pos_embed_max_size
+        self.x_block_self_attn_layers = x_block_self_attn_layers
+        self.pos_emb_random_crop_rate = pos_emb_random_crop_rate
         self.gradient_checkpointing = use_checkpoint

         # hidden_size = default(hidden_size, 64 * depth)
@@ -751,6 +771,7 @@ class MMDiT(nn.Module):
                     scale_mod_only=scale_mod_only,
                     swiglu=swiglu,
                     qk_norm=qk_norm,
+                    x_block_self_attn=(i in self.x_block_self_attn_layers),
                 )
                 for i in range(depth)
             ]
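x_block_self_attn is decided per block index, so one MMDiT can mix plain and dual-attention layers; in this commit's detection logic only "3-5-medium" checkpoints yield a non-empty list. A trivial check of the flag computation (the layer list here is illustrative, not taken from a real checkpoint):

x_block_self_attn_layers = [0, 1, 2]  # hypothetical
print([i in x_block_self_attn_layers for i in range(5)])  # [True, True, True, False, False]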
@@ -832,7 +853,10 @@ class MMDiT(nn.Module):
         nn.init.constant_(self.final_layer.linear.weight, 0)
         nn.init.constant_(self.final_layer.linear.bias, 0)

-    def cropped_pos_embed(self, h, w, device=None):
+    def set_pos_emb_random_crop_rate(self, rate: float):
+        self.pos_emb_random_crop_rate = rate
+
+    def cropped_pos_embed(self, h, w, device=None, random_crop: bool = False):
         p = self.x_embedder.patch_size
         # patched size
         h = (h + 1) // p
@@ -842,8 +866,14 @@ class MMDiT(nn.Module):
         assert self.pos_embed_max_size is not None
         assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size)
         assert w <= self.pos_embed_max_size, (w, self.pos_embed_max_size)
-        top = (self.pos_embed_max_size - h) // 2
-        left = (self.pos_embed_max_size - w) // 2
+
+        if not random_crop:
+            top = (self.pos_embed_max_size - h) // 2
+            left = (self.pos_embed_max_size - w) // 2
+        else:
+            top = torch.randint(0, self.pos_embed_max_size - h + 1, (1,)).item()
+            left = torch.randint(0, self.pos_embed_max_size - w + 1, (1,)).item()
+
         spatial_pos_embed = self.pos_embed.reshape(
             1,
             self.pos_embed_max_size,
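When random_crop is set, the h x w window into the pos_embed_max_size grid gets a uniformly sampled top-left corner instead of the centered one; the + 1 keeps the window fully inside the grid because torch.randint's upper bound is exclusive. A small numeric sketch with hypothetical patched sizes:

import torch

pos_embed_max_size, h, w = 192, 64, 96  # hypothetical patched sizes
top_center, left_center = (pos_embed_max_size - h) // 2, (pos_embed_max_size - w) // 2
top_rand = torch.randint(0, pos_embed_max_size - h + 1, (1,)).item()   # 0..128 inclusive
left_rand = torch.randint(0, pos_embed_max_size - w + 1, (1,)).item()  # 0..96 inclusive
print((top_center, left_center), (top_rand, left_rand))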
@@ -896,9 +926,12 @@ class MMDiT(nn.Module):
         t: (N,) tensor of diffusion timesteps
         y: (N, D) tensor of class labels
         """
+        pos_emb_random_crop = (
+            False if self.pos_emb_random_crop_rate == 0.0 else torch.rand(1).item() < self.pos_emb_random_crop_rate
+        )
+
         B, C, H, W = x.shape
-        x = self.x_embedder(x) + self.cropped_pos_embed(H, W, device=x.device).to(dtype=x.dtype)
+        x = self.x_embedder(x) + self.cropped_pos_embed(H, W, device=x.device, random_crop=pos_emb_random_crop).to(dtype=x.dtype)
         c = self.t_embedder(t, dtype=x.dtype)  # (N, D)
         if y is not None and self.y_embedder is not None:
             y = self.y_embedder(y)  # (N, D)
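The rate is interpreted as one independent coin flip per forward pass, so a rate of 0.1 randomizes the crop on roughly 10% of training steps and leaves the rest center-cropped; the equivalent standalone logic:

import torch

pos_emb_random_crop_rate = 0.1  # example value
use_random_crop = pos_emb_random_crop_rate != 0.0 and torch.rand(1).item() < pos_emb_random_crop_rate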
@@ -977,6 +1010,7 @@ def create_sd3_mmdit(params: SD3Params, attn_mode: str = "torch") -> MMDiT:
         depth=params.depth,
         mlp_ratio=4,
         qk_norm=params.qk_norm,
+        x_block_self_attn_layers=params.x_block_self_attn_layers,
         num_patches=params.num_patches,
         attn_mode=attn_mode,
         model_type=params.model_type,

@@ -239,6 +239,13 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser):
         default=0.0,
         help="Dropout rate for T5 encoder, default is 0.0 / T5エンコーダのドロップアウト率、デフォルトは0.0",
     )
+    parser.add_argument(
+        "--pos_emb_random_crop_rate",
+        type=float,
+        default=0.0,
+        help="Random crop rate for positional embeddings, default is 0.0. Only for SD3.5M"
+        " / 位置埋め込みのランダムクロップ率、デフォルトは0.0。SD3.5M以外では予期しない動作になります",
+    )

     # copy from Diffusers
     parser.add_argument(

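For reference, a minimal reproduction of how the new flag parses (this is not the project's full argument parser, just the one new option in isolation):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--pos_emb_random_crop_rate", type=float, default=0.0)
args = parser.parse_args(["--pos_emb_random_crop_rate", "0.1"])
print(args.pos_emb_random_crop_rate)  # 0.1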
@@ -41,20 +41,21 @@ def analyze_state_dict_state(state_dict: Dict, prefix: str = ""):

     # x_block_self_attn_layers.append(int(key.split(".x_block.attn2.ln_k.weight")[0].split(".")[-1]))
     x_block_self_attn_layers = []
-    re_attn = re.compile(r".(\d+).x_block.attn2.ln_k.weight")
+    re_attn = re.compile(r"\.(\d+)\.x_block\.attn2\.ln_k\.weight")
     for key in list(state_dict.keys()):
-        m = re_attn.match(key)
+        m = re_attn.search(key)
         if m:
             x_block_self_attn_layers.append(int(m.group(1)))

-    assert len(x_block_self_attn_layers) == 0, "x_block_self_attn_layers is not supported"
-
     context_embedder_in_features = context_shape[1]
     context_embedder_out_features = context_shape[0]

-    # only supports 3-5-large and 3-medium
+    # only supports 3-5-large, medium or 3-medium
     if qk_norm is not None:
-        model_type = "3-5-large"
+        if len(x_block_self_attn_layers) == 0:
+            model_type = "3-5-large"
+        else:
+            model_type = "3-5-medium"
     else:
         model_type = "3-medium"

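Both regex fixes are load-bearing: the old pattern left its dots unescaped, so each one matched an arbitrary character, and it used re.match, which anchors at the start of the string and therefore never fires on prefixed state-dict keys. The commit also drops the old guard asserting the layer list is empty, since these layers are now supported. A quick check against a hypothetical key:

import re

key = "joint_blocks.12.x_block.attn2.ln_k.weight"  # hypothetical state-dict key
old = re.compile(r".(\d+).x_block.attn2.ln_k.weight")
new = re.compile(r"\.(\d+)\.x_block\.attn2\.ln_k\.weight")
print(old.match(key))            # None: match() only anchors at position 0
print(new.search(key).group(1))  # "12"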
@@ -353,17 +353,15 @@ def train(args):
        accelerator.wait_for_everyone()

    # load MMDIT
-    mmdit = sd3_utils.load_mmdit(
-        sd3_state_dict,
-        model_dtype,
-        "cpu",
-    )
+    mmdit = sd3_utils.load_mmdit(sd3_state_dict, model_dtype, "cpu")

    # attn_mode = "xformers" if args.xformers else "torch"
    # assert (
    #     attn_mode == "torch"
    # ), f"attn_mode {attn_mode} is not supported yet. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。"

+    mmdit.set_pos_emb_random_crop_rate(args.pos_emb_random_crop_rate)
+
    if args.gradient_checkpointing:
        mmdit.enable_gradient_checkpointing()

@@ -65,6 +65,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
        )
        mmdit = sd3_utils.load_mmdit(state_dict, loading_dtype, "cpu")
        self.model_type = mmdit.model_type
+        mmdit.set_pos_emb_random_crop_rate(args.pos_emb_random_crop_rate)

        if args.fp8_base:
            # check dtype of model
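Taken together, the commit wires detection and training end to end: analyze_state_dict_state labels a checkpoint "3-5-medium" when qk_norm is set and x_block self-attention layers are present, the network trainer records mmdit.model_type, and both training entry points forward args.pos_emb_random_crop_rate into the model through set_pos_emb_random_crop_rate.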