Add documentation to model, use SDPA attention, sample images

This commit is contained in:
rockerBOO
2025-02-18 00:58:53 -05:00
parent 1aa2f00e85
commit 98efbc3bb7
5 changed files with 643 additions and 333 deletions

View File

@@ -13,6 +13,7 @@ import math
from typing import List, Optional, Tuple
from dataclasses import dataclass
from einops import rearrange
from flash_attn import flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
import torch
@@ -23,24 +24,16 @@ import torch.nn.functional as F
try:
from apex.normalization import FusedRMSNorm as RMSNorm
except ModuleNotFoundError:
except:
import warnings
warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")
memory_efficient_attention = None
try:
import xformers
except:
pass
try:
from xformers.ops import memory_efficient_attention
except:
memory_efficient_attention = None
@dataclass
class LuminaParams:
"""Parameters for Lumina model configuration"""
patch_size: int = 2
in_channels: int = 4
dim: int = 4096
@@ -68,7 +61,7 @@ class LuminaParams:
"""Returns the configuration for the 2B parameter model"""
return cls(
patch_size=2,
in_channels=16,
in_channels=16, # VAE channels
dim=2304,
n_layers=26,
n_heads=24,
@@ -76,21 +69,13 @@ class LuminaParams:
axes_dims=[32, 32, 32],
axes_lens=[300, 512, 512],
qk_norm=True,
cap_feat_dim=2304
cap_feat_dim=2304, # Gemma 2 hidden_size
)
@classmethod
def get_7b_config(cls) -> "LuminaParams":
"""Returns the configuration for the 7B parameter model"""
return cls(
patch_size=2,
dim=4096,
n_layers=32,
n_heads=32,
n_kv_heads=8,
axes_dims=[64, 64, 64],
axes_lens=[300, 512, 512]
)
return cls(patch_size=2, dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, axes_dims=[64, 64, 64], axes_lens=[300, 512, 512])
class GradientCheckpointMixin(nn.Module):
@@ -112,6 +97,7 @@ class GradientCheckpointMixin(nn.Module):
else:
return self._forward(*args, **kwargs)
#############################################################################
# RMSNorm #
#############################################################################
@@ -148,9 +134,18 @@ class RMSNorm(torch.nn.Module):
"""
return x * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x: Tensor):
"""
Apply RMSNorm to the input tensor.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The normalized tensor.
"""
x_dtype = x.dtype
# To handle float8 we need to convert the tensor to float
x = x.float()
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
return ((x * rrms) * self.weight.float()).to(dtype=x_dtype)
@@ -204,17 +199,11 @@ class TimestepEmbedder(GradientCheckpointMixin):
"""
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
half = dim // 2
freqs = torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32)
/ half
).to(device=t.device)
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
)
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def _forward(self, t):
@@ -222,6 +211,7 @@ class TimestepEmbedder(GradientCheckpointMixin):
t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
return t_emb
def to_cuda(x):
if isinstance(x, torch.Tensor):
return x.cuda()
@@ -266,6 +256,7 @@ class JointAttention(nn.Module):
dim (int): Number of input dimensions.
n_heads (int): Number of heads.
n_kv_heads (Optional[int]): Number of kv heads, if using GQA.
qk_norm (bool): Whether to use normalization for queries and keys.
"""
super().__init__()
@@ -295,6 +286,14 @@ class JointAttention(nn.Module):
else:
self.q_norm = self.k_norm = nn.Identity()
self.flash_attn = False
# self.attention_processor = xformers.ops.memory_efficient_attention
self.attention_processor = F.scaled_dot_product_attention
def set_attention_processor(self, attention_processor):
self.attention_processor = attention_processor
@staticmethod
def apply_rotary_emb(
x_in: torch.Tensor,
@@ -326,16 +325,12 @@ class JointAttention(nn.Module):
return x_out.type_as(x_in)
# copied from huggingface modeling_llama.py
def _upad_input(
self, query_layer, key_layer, value_layer, attention_mask, query_length
):
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
)
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
@@ -355,9 +350,7 @@ class JointAttention(nn.Module):
)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(
batch_size * kv_seq_len, self.n_local_heads, head_dim
),
query_layer.reshape(batch_size * kv_seq_len, self.n_local_heads, head_dim),
indices_k,
)
cu_seqlens_q = cu_seqlens_k
@@ -373,9 +366,7 @@ class JointAttention(nn.Module):
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
query_layer, attention_mask
)
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
return (
query_layer,
@@ -388,10 +379,10 @@ class JointAttention(nn.Module):
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
freqs_cis: torch.Tensor,
) -> torch.Tensor:
x: Tensor,
x_mask: Tensor,
freqs_cis: Tensor,
) -> Tensor:
"""
Args:
@@ -425,7 +416,7 @@ class JointAttention(nn.Module):
softmax_scale = math.sqrt(1 / self.head_dim)
if dtype in [torch.float16, torch.bfloat16]:
if self.flash_attn:
# begin var_len flash attn
(
query_states,
@@ -459,14 +450,13 @@ class JointAttention(nn.Module):
if n_rep >= 1:
xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
output = (
F.scaled_dot_product_attention(
self.attention_processor(
xq.permute(0, 2, 1, 3),
xk.permute(0, 2, 1, 3),
xv.permute(0, 2, 1, 3),
attn_mask=x_mask.bool()
.view(bsz, 1, 1, seqlen)
.expand(-1, self.n_local_heads, seqlen, -1),
attn_mask=x_mask.bool().view(bsz, 1, 1, seqlen).expand(-1, self.n_local_heads, seqlen, -1),
scale=softmax_scale,
)
.permute(0, 2, 1, 3)
@@ -474,10 +464,47 @@ class JointAttention(nn.Module):
)
output = output.flatten(-2)
return self.out(output)
def attention(q: Tensor, k: Tensor, v: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor:
    """
    Compute scaled dot-product attention and merge the head dimension.

    Args:
        q (Tensor): Query tensor of shape (B, H, L, D).
        k (Tensor): Key tensor of shape (B, H, L, D).
        v (Tensor): Value tensor of shape (B, H, L, D).
        attn_mask (Optional[Tensor]): Optional attention mask broadcastable
            to (B, H, L, L). Defaults to None (no masking).

    Returns:
        Tensor: Attention output of shape (B, L, H * D).
    """
    # Use the module alias `F` imported at file top for consistency with the
    # rest of this file.
    x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
    # "B H L D -> B L (H D)" without einops: move heads next to the feature
    # axis, then merge them — same pattern as JointAttention.forward.
    return x.transpose(1, 2).flatten(2)
def apply_rope(
    x_in: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> torch.Tensor:
    """
    Apply rotary position embeddings to a query or key tensor.

    The last dimension of ``x_in`` is interpreted as interleaved
    (real, imaginary) pairs and viewed as complex numbers, rotated by the
    precomputed complex exponentials in ``freqs_cis``, then converted back
    to a real tensor with the original shape and dtype.

    Args:
        x_in (torch.Tensor): Query or key tensor of shape (B, L, H, D)
            with an even last dimension D.
        freqs_cis (torch.Tensor): Precomputed complex frequency tensor,
            broadcastable against (B, L, 1, D // 2) after ``unsqueeze(2)``.

    Returns:
        torch.Tensor: A single tensor with rotary embeddings applied,
        same shape and dtype as ``x_in``.
    """
    # Complex arithmetic is not autocast-safe, so force full precision here.
    with torch.autocast("cuda", enabled=False):
        x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
        # Insert the head axis so freqs broadcast across all heads.
        freqs_cis = freqs_cis.unsqueeze(2)
        x_out = torch.view_as_real(x * freqs_cis).flatten(3)
        return x_out.type_as(x_in)
class FeedForward(nn.Module):
def __init__(
self,
@@ -554,10 +581,13 @@ class JointTransformerBlock(GradientCheckpointMixin):
n_kv_heads (Optional[int]): Number of attention heads in key and
value features (if using GQA), or set to None for the same as
query.
multiple_of (int):
ffn_dim_multiplier (Optional[float]):
norm_eps (float):
multiple_of (int): Number of multiple of the hidden dimension.
ffn_dim_multiplier (Optional[float]): Dimension multiplier for the
feedforward layer.
norm_eps (float): Epsilon value for normalization.
qk_norm (bool): Whether to use normalization for queries and keys.
modulation (bool): Whether to use modulation for the attention
layer.
"""
super().__init__()
self.dim = dim
@@ -593,32 +623,30 @@ class JointTransformerBlock(GradientCheckpointMixin):
self,
x: torch.Tensor,
x_mask: torch.Tensor,
freqs_cis: torch.Tensor,
pe: torch.Tensor,
adaln_input: Optional[torch.Tensor] = None,
):
"""
Perform a forward pass through the TransformerBlock.
Args:
x (torch.Tensor): Input tensor.
freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
x (Tensor): Input tensor.
pe (Tensor): Rope position embedding.
Returns:
torch.Tensor: Output tensor after applying attention and
Tensor: Output tensor after applying attention and
feedforward layers.
"""
if self.modulation:
assert adaln_input is not None
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(
adaln_input
).chunk(4, dim=1)
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
self.attention(
modulate(self.attention_norm1(x), scale_msa),
x_mask,
freqs_cis,
pe,
)
)
x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
@@ -632,7 +660,7 @@ class JointTransformerBlock(GradientCheckpointMixin):
self.attention(
self.attention_norm1(x),
x_mask,
freqs_cis,
pe,
)
)
x = x + self.ffn_norm2(
@@ -649,6 +677,14 @@ class FinalLayer(GradientCheckpointMixin):
"""
def __init__(self, hidden_size, patch_size, out_channels):
"""
Initialize the FinalLayer.
Args:
hidden_size (int): Hidden size of the input features.
patch_size (int): Patch size of the input features.
out_channels (int): Number of output channels.
"""
super().__init__()
self.norm_final = nn.LayerNorm(
hidden_size,
@@ -682,39 +718,21 @@ class FinalLayer(GradientCheckpointMixin):
class RopeEmbedder:
def __init__(
self,
theta: float = 10000.0,
axes_dims: List[int] = (16, 56, 56),
axes_lens: List[int] = (1, 512, 512),
):
def __init__(self, theta: float = 10000.0, axes_dims: List[int] = [16, 56, 56], axes_lens: List[int] = [1, 512, 512]):
super().__init__()
self.theta = theta
self.axes_dims = axes_dims
self.axes_lens = axes_lens
self.freqs_cis = NextDiT.precompute_freqs_cis(
self.axes_dims, self.axes_lens, theta=self.theta
)
self.freqs_cis = NextDiT.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
def get_freqs_cis(self, ids: torch.Tensor):
def __call__(self, ids: torch.Tensor):
device = ids.device
self.freqs_cis = [freqs_cis.to(ids.device) for freqs_cis in self.freqs_cis]
result = []
for i in range(len(self.axes_dims)):
index = (
ids[:, :, i : i + 1]
.repeat(1, 1, self.freqs_cis[i].shape[-1])
.to(torch.int64)
)
axes = self.freqs_cis[i].unsqueeze(0).repeat(index.shape[0], 1, 1)
result.append(
torch.gather(
axes,
dim=1,
index=index,
)
)
freqs = self.freqs_cis[i].to(ids.device)
index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
return torch.cat(result, dim=-1)
@@ -740,11 +758,63 @@ class NextDiT(nn.Module):
axes_dims: List[int] = [16, 56, 56],
axes_lens: List[int] = [1, 512, 512],
) -> None:
"""
Initialize the NextDiT model.
Args:
patch_size (int): Patch size of the input features.
in_channels (int): Number of input channels.
dim (int): Hidden size of the input features.
n_layers (int): Number of Transformer layers.
n_refiner_layers (int): Number of refiner layers.
n_heads (int): Number of attention heads.
n_kv_heads (Optional[int]): Number of attention heads in key and
value features (if using GQA), or set to None for the same as
query.
multiple_of (int): Multiple of the hidden size.
ffn_dim_multiplier (Optional[float]): Dimension multiplier for the
feedforward layer.
norm_eps (float): Epsilon value for normalization.
qk_norm (bool): Whether to use query key normalization.
cap_feat_dim (int): Dimension of the caption features.
axes_dims (List[int]): List of dimensions for the axes.
axes_lens (List[int]): List of lengths for the axes.
Returns:
None
"""
super().__init__()
self.in_channels = in_channels
self.out_channels = in_channels
self.patch_size = patch_size
self.t_embedder = TimestepEmbedder(min(dim, 1024))
self.cap_embedder = nn.Sequential(
RMSNorm(cap_feat_dim, eps=norm_eps),
nn.Linear(
cap_feat_dim,
dim,
bias=True,
),
)
self.context_refiner = nn.ModuleList(
[
JointTransformerBlock(
layer_id,
dim,
n_heads,
n_kv_heads,
multiple_of,
ffn_dim_multiplier,
norm_eps,
qk_norm,
modulation=False,
)
for layer_id in range(n_refiner_layers)
]
)
self.x_embedder = nn.Linear(
in_features=patch_size * patch_size * in_channels,
out_features=dim,
@@ -769,32 +839,7 @@ class NextDiT(nn.Module):
for layer_id in range(n_refiner_layers)
]
)
self.context_refiner = nn.ModuleList(
[
JointTransformerBlock(
layer_id,
dim,
n_heads,
n_kv_heads,
multiple_of,
ffn_dim_multiplier,
norm_eps,
qk_norm,
modulation=False,
)
for layer_id in range(n_refiner_layers)
]
)
self.t_embedder = TimestepEmbedder(min(dim, 1024))
self.cap_embedder = nn.Sequential(
RMSNorm(cap_feat_dim, eps=norm_eps),
nn.Linear(
cap_feat_dim,
dim,
bias=True,
),
)
nn.init.trunc_normal_(self.cap_embedder[1].weight, std=0.02)
# nn.init.zeros_(self.cap_embedder[1].weight)
nn.init.zeros_(self.cap_embedder[1].bias)
@@ -864,15 +909,26 @@ class NextDiT(nn.Module):
def unpatchify(
self,
x: torch.Tensor,
x: Tensor,
width: int,
height: int,
encoder_seq_lengths: List[int],
seq_lengths: List[int],
) -> torch.Tensor:
) -> Tensor:
"""
Unpatchify the input tensor and embed the caption features.
x: (N, T, patch_size**2 * C)
imgs: (N, H, W, C)
Args:
x (Tensor): Input tensor.
width (int): Width of the input tensor.
height (int): Height of the input tensor.
encoder_seq_lengths (List[int]): List of encoder sequence lengths.
seq_lengths (List[int]): List of sequence lengths
Returns:
output: (N, C, H, W)
"""
pH = pW = self.patch_size
@@ -891,13 +947,27 @@ class NextDiT(nn.Module):
def patchify_and_embed(
self,
x: torch.Tensor,
cap_feats: torch.Tensor,
cap_mask: torch.Tensor,
t: torch.Tensor,
) -> Tuple[
torch.Tensor, torch.Tensor, torch.Tensor, List[int], List[int]
]:
x: Tensor,
cap_feats: Tensor,
cap_mask: Tensor,
t: Tensor,
) -> Tuple[Tensor, Tensor, Tensor, List[int], List[int]]:
"""
Patchify and embed the input image and caption features.
Args:
x: (N, C, H, W) image latents
cap_feats: (N, C, D) caption features
cap_mask: (N, C, D) caption attention mask
t: (N), T timesteps
Returns:
Tuple[Tensor, Tensor, Tensor, List[int], List[int]]:
return x, attention_mask, freqs_cis, l_effective_cap_len, seq_lengths
"""
bsz, channels, height, width = x.shape
pH = pW = self.patch_size
device = x.device
@@ -915,40 +985,35 @@ class NextDiT(nn.Module):
H_tokens, W_tokens = height // pH, width // pW
position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
position_ids[i, cap_len : cap_len + seq_len, 0] = cap_len
row_ids = (
torch.arange(H_tokens, dtype=torch.int32, device=device)
.view(-1, 1)
.repeat(1, W_tokens)
.flatten()
)
col_ids = (
torch.arange(W_tokens, dtype=torch.int32, device=device)
.view(1, -1)
.repeat(H_tokens, 1)
.flatten()
)
position_ids[i, cap_len : cap_len + seq_len, 1] = row_ids
position_ids[i, cap_len : cap_len + seq_len, 2] = col_ids
position_ids[i, cap_len:seq_len, 0] = cap_len
freqs_cis = self.rope_embedder.get_freqs_cis(position_ids)
row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
position_ids[i, cap_len:seq_len, 1] = row_ids
position_ids[i, cap_len:seq_len, 2] = col_ids
# Get combined rotary embeddings
freqs_cis = self.rope_embedder(position_ids)
# Create separate rotary embeddings for captions and images
cap_freqs_cis = torch.zeros(bsz, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype)
img_freqs_cis = torch.zeros(bsz, image_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype)
for i, (cap_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
img_freqs_cis[i, :seq_len] = freqs_cis[i, cap_len : cap_len + seq_len]
img_freqs_cis[i, :image_seq_len] = freqs_cis[i, cap_len:seq_len]
x = x.view(bsz, channels, height // pH, pH, width // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)
x_mask = torch.zeros(bsz, max_seq_len, dtype=torch.bool, device=device)
# refine context
# Refine caption context
for layer in self.context_refiner:
cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis)
x = x.view(bsz, channels, height // pH, pH, width // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)
x_mask = torch.zeros(bsz, image_seq_len, dtype=torch.bool, device=device)
x = self.x_embedder(x)
# Refine image context
for layer in self.noise_refiner:
x = layer(x, x_mask, img_freqs_cis, t)
@@ -963,19 +1028,23 @@ class NextDiT(nn.Module):
return x, attention_mask, freqs_cis, l_effective_cap_len, seq_lengths
def forward(self, x, t, cap_feats, cap_mask):
def forward(self, x: Tensor, t: Tensor, cap_feats: Tensor, cap_mask: Tensor) -> Tensor:
"""
Forward pass of NextDiT.
t: (N,) tensor of diffusion timesteps
y: (N,) tensor of text tokens/features
Args:
x: (N, C, H, W) image latents
t: (N,) tensor of diffusion timesteps
cap_feats: (N, L, D) caption features
cap_mask: (N, L) caption attention mask
Returns:
x: (N, C, H, W) denoised latents
"""
_, _, height, width = x.shape # B, C, H, W
_, _, height, width = x.shape # B, C, H, W
t = self.t_embedder(t) # (N, D)
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
x, mask, freqs_cis, l_effective_cap_len, seq_lengths = self.patchify_and_embed(
x, cap_feats, cap_mask, t
)
x, mask, freqs_cis, l_effective_cap_len, seq_lengths = self.patchify_and_embed(x, cap_feats, cap_mask, t)
for layer in self.layers:
x = layer(x, mask, freqs_cis, t)
@@ -986,7 +1055,14 @@ class NextDiT(nn.Module):
return x
def forward_with_cfg(
self, x, t, cap_feats, cap_mask, cfg_scale, cfg_trunc=100, renorm_cfg=1
self,
x: Tensor,
t: Tensor,
cap_feats: Tensor,
cap_mask: Tensor,
cfg_scale: float,
cfg_trunc: int = 100,
renorm_cfg: float = 1.0,
):
"""
Forward pass of NextDiT, but also batches the unconditional forward pass
@@ -996,9 +1072,10 @@ class NextDiT(nn.Module):
half = x[: len(x) // 2]
if t[0] < cfg_trunc:
combined = torch.cat([half, half], dim=0) # [2, 16, 128, 128]
model_out = self.forward(
combined, t, cap_feats, cap_mask
) # [2, 16, 128, 128]
assert (
cap_mask.shape[0] == combined.shape[0]
), f"caption attention mask shape: {cap_mask.shape[0]} latents shape: {combined.shape[0]}"
model_out = self.forward(x, t, cap_feats, cap_mask) # [2, 16, 128, 128]
# For exact reproducibility reasons, we apply classifier-free guidance on only
# three channels by default. The standard approach to cfg applies it to all channels.
# This can be done by uncommenting the following line and commenting-out the line following that.
@@ -1009,13 +1086,9 @@ class NextDiT(nn.Module):
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
if float(renorm_cfg) > 0.0:
ori_pos_norm = torch.linalg.vector_norm(
cond_eps, dim=tuple(range(1, len(cond_eps.shape))), keepdim=True
)
ori_pos_norm = torch.linalg.vector_norm(cond_eps, dim=tuple(range(1, len(cond_eps.shape))), keepdim=True)
max_new_norm = ori_pos_norm * float(renorm_cfg)
new_pos_norm = torch.linalg.vector_norm(
half_eps, dim=tuple(range(1, len(half_eps.shape))), keepdim=True
)
new_pos_norm = torch.linalg.vector_norm(half_eps, dim=tuple(range(1, len(half_eps.shape))), keepdim=True)
if new_pos_norm >= max_new_norm:
half_eps = half_eps * (max_new_norm / new_pos_norm)
else:
@@ -1040,7 +1113,7 @@ class NextDiT(nn.Module):
dim: List[int],
end: List[int],
theta: float = 10000.0,
):
) -> List[Tensor]:
"""
Precompute the frequency tensor for complex exponentials (cis) with
given dimensions.
@@ -1057,19 +1130,17 @@ class NextDiT(nn.Module):
Defaults to 10000.0.
Returns:
torch.Tensor: Precomputed frequency tensor with complex
List[torch.Tensor]: Precomputed frequency tensor with complex
exponentials.
"""
freqs_cis = []
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
for i, (d, e) in enumerate(zip(dim, end)):
freqs = 1.0 / (
theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d)
)
timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
freqs = torch.outer(timestep, freqs).float()
freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(
torch.complex64
) # complex64
pos = torch.arange(e, dtype=freqs_dtype, device="cpu")
freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=freqs_dtype, device="cpu") / d))
freqs = torch.outer(pos, freqs)
freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs) # [S, D/2]
freqs_cis.append(freqs_cis_i)
return freqs_cis
@@ -1102,7 +1173,7 @@ class NextDiT(nn.Module):
def NextDiT_2B_GQA_patch2_Adaln_Refiner(params: Optional[LuminaParams] = None, **kwargs):
if params is None:
params = LuminaParams.get_2b_config()
return NextDiT(
patch_size=params.patch_size,
in_channels=params.in_channels,

View File

@@ -2,20 +2,20 @@ import argparse
import math
import os
import numpy as np
import toml
import json
import time
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Any
import torch
from torch import Tensor
from accelerate import Accelerator, PartialState
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import Gemma2Model
from tqdm import tqdm
from PIL import Image
from safetensors.torch import save_file
from library import lumina_models, lumina_util, strategy_base, train_util
from library import lumina_models, lumina_util, strategy_base, strategy_lumina, train_util
from library.device_utils import init_ipex, clean_memory_on_device
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
init_ipex()
@@ -30,19 +30,38 @@ logger = logging.getLogger(__name__)
# region sample images
@torch.no_grad()
def sample_images(
accelerator: Accelerator,
args: argparse.Namespace,
epoch,
steps,
nextdit,
ae,
gemma2_model,
sample_prompts_gemma2_outputs,
prompt_replacement=None,
controlnet=None
epoch: int,
global_step: int,
nextdit: lumina_models.NextDiT,
vae: torch.nn.Module,
gemma2_model: Gemma2Model,
sample_prompts_gemma2_outputs: List[Tuple[Tensor, Tensor, Tensor]],
prompt_replacement: Optional[Tuple[str, str]] = None,
controlnet=None,
):
if steps == 0:
"""
Generate sample images using the NextDiT model.
Args:
accelerator (Accelerator): Accelerator instance.
args (argparse.Namespace): Command-line arguments.
epoch (int): Current epoch number.
global_step (int): Current global step number.
nextdit (lumina_models.NextDiT): The NextDiT model instance.
vae (torch.nn.Module): The VAE module.
gemma2_model (Gemma2Model): The Gemma2 model instance.
sample_prompts_gemma2_outputs (List[Tuple[Tensor, Tensor, Tensor]]): List of tuples containing the encoded prompts, text masks, and timestep for each sample.
prompt_replacement (Optional[Tuple[str, str]], optional): Tuple containing the prompt and negative prompt replacements. Defaults to None.
controlnet: ControlNet model
Returns:
None
"""
if global_step == 0:
if not args.sample_at_first:
return
else:
@@ -53,11 +72,15 @@ def sample_images(
if epoch is None or epoch % args.sample_every_n_epochs != 0:
return
else:
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
if global_step % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
return
assert (
args.sample_prompts is not None
), "No sample prompts found. Provide `--sample_prompts` / サンプルプロンプトが見つかりません。`--sample_prompts` を指定してください"
logger.info("")
logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {global_step}")
if not os.path.isfile(args.sample_prompts) and sample_prompts_gemma2_outputs is None:
logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
return
@@ -87,22 +110,21 @@ def sample_images(
if distributed_state.num_processes <= 1:
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
with torch.no_grad(), accelerator.autocast():
for prompt_dict in prompts:
sample_image_inference(
accelerator,
args,
nextdit,
gemma2_model,
ae,
save_dir,
prompt_dict,
epoch,
steps,
sample_prompts_gemma2_outputs,
prompt_replacement,
controlnet
)
for prompt_dict in prompts:
sample_image_inference(
accelerator,
args,
nextdit,
gemma2_model,
vae,
save_dir,
prompt_dict,
epoch,
global_step,
sample_prompts_gemma2_outputs,
prompt_replacement,
controlnet,
)
else:
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
# prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
@@ -110,23 +132,22 @@ def sample_images(
for i in range(distributed_state.num_processes):
per_process_prompts.append(prompts[i :: distributed_state.num_processes])
with torch.no_grad():
with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
for prompt_dict in prompt_dict_lists[0]:
sample_image_inference(
accelerator,
args,
nextdit,
gemma2_model,
ae,
save_dir,
prompt_dict,
epoch,
steps,
sample_prompts_gemma2_outputs,
prompt_replacement,
controlnet
)
with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
for prompt_dict in prompt_dict_lists[0]:
sample_image_inference(
accelerator,
args,
nextdit,
gemma2_model,
vae,
save_dir,
prompt_dict,
epoch,
global_step,
sample_prompts_gemma2_outputs,
prompt_replacement,
controlnet,
)
torch.set_rng_state(rng_state)
if cuda_rng_state is not None:
@@ -135,43 +156,60 @@ def sample_images(
clean_memory_on_device(accelerator.device)
@torch.no_grad()
def sample_image_inference(
accelerator: Accelerator,
args: argparse.Namespace,
nextdit,
gemma2_model,
ae,
save_dir,
prompt_dict,
epoch,
steps,
sample_prompts_gemma2_outputs,
prompt_replacement,
# controlnet
nextdit: lumina_models.NextDiT,
gemma2_model: Gemma2Model,
vae: torch.nn.Module,
save_dir: str,
prompt_dict: Dict[str, str],
epoch: int,
global_step: int,
sample_prompts_gemma2_outputs: List[Tuple[Tensor, Tensor, Tensor]],
prompt_replacement: Optional[Tuple[str, str]] = None,
controlnet=None,
):
"""
Generates sample images
Args:
accelerator (Accelerator): Accelerator object
args (argparse.Namespace): Arguments object
nextdit (lumina_models.NextDiT): NextDiT model
gemma2_model (Gemma2Model): Gemma2 model
vae (torch.nn.Module): VAE model
save_dir (str): Directory to save images
prompt_dict (Dict[str, str]): Prompt dictionary
epoch (int): Epoch number
steps (int): Number of steps to run
sample_prompts_gemma2_outputs (List[Tuple[Tensor, Tensor, Tensor]]): List of tuples containing gemma2 outputs
prompt_replacement (Optional[Tuple[str, str]], optional): Replacement for positive and negative prompt. Defaults to None.
Returns:
None
"""
assert isinstance(prompt_dict, dict)
# negative_prompt = prompt_dict.get("negative_prompt")
sample_steps = prompt_dict.get("sample_steps", 20)
width = prompt_dict.get("width", 512)
height = prompt_dict.get("height", 512)
scale = prompt_dict.get("scale", 3.5)
seed = prompt_dict.get("seed")
sample_steps = prompt_dict.get("sample_steps", 38)
width = prompt_dict.get("width", 1024)
height = prompt_dict.get("height", 1024)
guidance_scale: int = prompt_dict.get("scale", 3.5)
seed: int = prompt_dict.get("seed", None)
controlnet_image = prompt_dict.get("controlnet_image")
prompt: str = prompt_dict.get("prompt", "")
negative_prompt: str = prompt_dict.get("negative_prompt", "")
# sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
if prompt_replacement is not None:
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
# if negative_prompt is not None:
# negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
if negative_prompt is not None:
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
generator = torch.Generator(device=accelerator.device)
if seed is not None:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
else:
# True random sample image generation
torch.seed()
torch.cuda.seed()
generator.manual_seed(seed)
# if negative_prompt is None:
# negative_prompt = ""
@@ -182,7 +220,7 @@ def sample_image_inference(
logger.info(f"height: {height}")
logger.info(f"width: {width}")
logger.info(f"sample_steps: {sample_steps}")
logger.info(f"scale: {scale}")
logger.info(f"scale: {guidance_scale}")
# logger.info(f"sample_sampler: {sampler_name}")
if seed is not None:
logger.info(f"seed: {seed}")
@@ -191,14 +229,16 @@ def sample_image_inference(
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
assert isinstance(tokenize_strategy, strategy_lumina.LuminaTokenizeStrategy)
assert isinstance(encoding_strategy, strategy_lumina.LuminaTextEncodingStrategy)
gemma2_conds = []
if sample_prompts_gemma2_outputs and prompt in sample_prompts_gemma2_outputs:
gemma2_conds = sample_prompts_gemma2_outputs[prompt]
print(f"Using cached Gemma2 outputs for prompt: {prompt}")
logger.info(f"Using cached Gemma2 outputs for prompt: {prompt}")
if gemma2_model is not None:
print(f"Encoding prompt with Gemma2: {prompt}")
logger.info(f"Encoding prompt with Gemma2: {prompt}")
tokens_and_masks = tokenize_strategy.tokenize(prompt)
# strategy has apply_gemma2_attn_mask option
encoded_gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks)
# if gemma2_conds is not cached, use encoded_gemma2_conds
@@ -211,22 +251,26 @@ def sample_image_inference(
gemma2_conds[i] = encoded_gemma2_conds[i]
# Unpack Gemma2 outputs
gemma2_hidden_states, gemma2_attn_mask, input_ids = gemma2_conds
gemma2_hidden_states, input_ids, gemma2_attn_mask = gemma2_conds
# sample image
weight_dtype = ae.dtype # TODO give dtype as argument
packed_latent_height = height // 16
packed_latent_width = width // 16
weight_dtype = vae.dtype # TODO give dtype as argument
latent_height = height // 8
latent_width = width // 8
noise = torch.randn(
1,
packed_latent_height * packed_latent_width,
16 * 2 * 2,
16,
latent_height,
latent_width,
device=accelerator.device,
dtype=weight_dtype,
generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None,
generator=generator,
)
# Prompts are paired positive/negative
noise = noise.repeat(gemma2_attn_mask.shape[0], 1, 1, 1)
timesteps = get_schedule(sample_steps, noise.shape[1], shift=True)
img_ids = lumina_util.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype)
# img_ids = lumina_util.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype)
gemma2_attn_mask = gemma2_attn_mask.to(accelerator.device)
# if controlnet_image is not None:
@@ -235,18 +279,18 @@ def sample_image_inference(
# controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1)
# controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device)
with accelerator.autocast(), torch.no_grad():
x = denoise(nextdit, noise, img_ids, gemma2_hidden_states, input_ids, None, timesteps=timesteps, guidance=scale, gemma2_attn_mask=gemma2_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image)
with accelerator.autocast():
x = denoise(nextdit, noise, gemma2_hidden_states, gemma2_attn_mask, timesteps=timesteps, guidance=guidance_scale)
x = lumina_util.unpack_latents(x, packed_latent_height, packed_latent_width)
# x = lumina_util.unpack_latents(x, packed_latent_height, packed_latent_width)
# latent to image
clean_memory_on_device(accelerator.device)
org_vae_device = ae.device # will be on cpu
ae.to(accelerator.device) # distributed_state.device is same as accelerator.device
with accelerator.autocast(), torch.no_grad():
x = ae.decode(x)
ae.to(org_vae_device)
org_vae_device = vae.device # will be on cpu
vae.to(accelerator.device) # distributed_state.device is same as accelerator.device
with accelerator.autocast():
x = vae.decode(x)
vae.to(org_vae_device)
clean_memory_on_device(accelerator.device)
x = x.clamp(-1, 1)
@@ -257,9 +301,9 @@ def sample_image_inference(
# but adding 'enum' to the filename should be enough
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{global_step:06d}"
seed_suffix = "" if seed is None else f"_{seed}"
i: int = prompt_dict["enum"]
i: int = int(prompt_dict.get("enum", 0))
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
image.save(os.path.join(save_dir, img_filename))
@@ -273,11 +317,34 @@ def sample_image_inference(
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption
def time_shift(mu: float, sigma: float, t: torch.Tensor):
def time_shift(mu: float, sigma: float, t: Tensor):
"""
Get time shift
Args:
mu (float): mu value.
sigma (float): sigma value.
t (Tensor): timestep.
Return:
float: time shift
"""
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
"""
Get linear function
Args:
x1 (float, optional): x1 value. Defaults to 256.
y1 (float, optional): y1 value. Defaults to 0.5.
x2 (float, optional): x2 value. Defaults to 4096.
y2 (float, optional): y2 value. Defaults to 1.15.
Return:
Callable[[float], float]: linear function
"""
m = (y2 - y1) / (x2 - x1)
b = y1 - m * x1
return lambda x: m * x + b
@@ -290,6 +357,19 @@ def get_schedule(
max_shift: float = 1.15,
shift: bool = True,
) -> list[float]:
"""
Get timesteps schedule
Args:
num_steps (int): Number of steps in the schedule.
image_seq_len (int): Sequence length of the image.
base_shift (float, optional): Base shift value. Defaults to 0.5.
max_shift (float, optional): Maximum shift value. Defaults to 1.15.
shift (bool, optional): Whether to shift the schedule. Defaults to True.
Return:
List[float]: timesteps schedule
"""
# extra step for zero
timesteps = torch.linspace(1, 0, num_steps + 1)
@@ -301,11 +381,63 @@ def get_schedule(
return timesteps.tolist()
def denoise(
model: lumina_models.NextDiT, img: Tensor, txt: Tensor, txt_mask: Tensor, timesteps: List[float], guidance: float = 4.0
):
"""
Denoise an image using the NextDiT model.
Args:
model (lumina_models.NextDiT): The NextDiT model instance.
img (Tensor): The input image tensor.
txt (Tensor): The input text tensor.
txt_mask (Tensor): The input text mask tensor.
timesteps (List[float]): A list of timesteps for the denoising process.
guidance (float, optional): The guidance scale for the denoising process. Defaults to 4.0.
Returns:
img (Tensor): Denoised tensor
"""
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
# model.prepare_block_swap_before_forward()
# block_samples = None
# block_single_samples = None
pred = model.forward_with_cfg(
x=img, # image latents (B, C, H, W)
t=t_vec / 1000, # timesteps需要除以1000来匹配模型预期
cap_feats=txt, # Gemma2的hidden states作为caption features
cap_mask=txt_mask.to(dtype=torch.int32), # Gemma2的attention mask
cfg_scale=guidance,
)
img = img + (t_prev - t_curr) * pred
# model.prepare_block_swap_before_forward()
return img
# endregion
# region train
def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
def get_sigmas(
noise_scheduler: FlowMatchEulerDiscreteScheduler, timesteps: Tensor, device: torch.device, n_dim=4, dtype=torch.float32
) -> Tensor:
"""
Get sigmas for timesteps
Args:
noise_scheduler (FlowMatchEulerDiscreteScheduler): The noise scheduler instance.
timesteps (Tensor): A tensor of timesteps for the denoising process.
device (torch.device): The device on which the tensors are stored.
n_dim (int, optional): The number of dimensions for the output tensor. Defaults to 4.
dtype (torch.dtype, optional): The data type for the output tensor. Defaults to torch.float32.
Returns:
sigmas (Tensor): The sigmas tensor.
"""
sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
schedule_timesteps = noise_scheduler.timesteps.to(device)
timesteps = timesteps.to(device)
@@ -320,11 +452,22 @@ def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32)
def compute_density_for_timestep_sampling(
weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
):
"""Compute the density for sampling the timesteps when doing SD3 training.
"""
Compute the density for sampling the timesteps when doing SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
Args:
weighting_scheme (str): The weighting scheme to use.
batch_size (int): The batch size for the sampling process.
logit_mean (float, optional): The mean of the logit distribution. Defaults to None.
logit_std (float, optional): The standard deviation of the logit distribution. Defaults to None.
mode_scale (float, optional): The mode scale for the mode weighting scheme. Defaults to None.
Returns:
u (Tensor): The sampled timesteps.
"""
if weighting_scheme == "logit_normal":
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
@@ -338,12 +481,19 @@ def compute_density_for_timestep_sampling(
return u
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None) -> Tensor:
"""Computes loss weighting scheme for SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
Args:
weighting_scheme (str): The weighting scheme to use.
sigmas (Tensor, optional): The sigmas tensor. Defaults to None.
Returns:
u (Tensor): The sampled timesteps.
"""
if weighting_scheme == "sigma_sqrt":
weighting = (sigmas**-2.0).float()
@@ -355,9 +505,24 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
return weighting
def get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents, noise, device, dtype
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
def get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) -> Tuple[Tensor, Tensor, Tensor]:
"""
Get noisy model input and timesteps.
Args:
args (argparse.Namespace): Arguments.
noise_scheduler (noise_scheduler): Noise scheduler.
latents (Tensor): Latents.
noise (Tensor): Latent noise.
device (torch.device): Device.
dtype (torch.dtype): Data type
Return:
Tuple[Tensor, Tensor, Tensor]:
noisy model input
timesteps
sigmas
"""
bsz, _, h, w = latents.shape
sigmas = None
@@ -412,7 +577,21 @@ def get_noisy_model_input_and_timesteps(
return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas
def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas):
def apply_model_prediction_type(
args, model_pred: Tensor, noisy_model_input: Tensor, sigmas: Tensor
) -> Tuple[Tensor, Optional[Tensor]]:
"""
Apply model prediction type to the model prediction and the sigmas.
Args:
args (argparse.Namespace): Arguments.
model_pred (Tensor): Model prediction.
noisy_model_input (Tensor): Noisy model input.
sigmas (Tensor): Sigmas.
Return:
Tuple[Tensor, Optional[Tensor]]:
"""
weighting = None
if args.model_prediction_type == "raw":
pass
@@ -433,10 +612,22 @@ def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas):
def save_models(
ckpt_path: str,
lumina: lumina_models.NextDiT,
sai_metadata: Optional[dict],
sai_metadata: Dict[str, Any],
save_dtype: Optional[torch.dtype] = None,
use_mem_eff_save: bool = False,
):
"""
Save the model to the checkpoint path.
Args:
ckpt_path (str): Path to the checkpoint.
lumina (lumina_models.NextDiT): NextDIT model.
sai_metadata (Optional[dict]): Metadata for the SAI model.
save_dtype (Optional[torch.dtype]): Data
Return:
None
"""
state_dict = {}
def update_sd(prefix, sd):
@@ -458,7 +649,9 @@ def save_lumina_model_on_train_end(
args: argparse.Namespace, save_dtype: torch.dtype, epoch: int, global_step: int, lumina: lumina_models.NextDiT
):
def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, lumina="lumina2")
sai_metadata = train_util.get_sai_model_spec(
None, args, False, False, False, is_stable_diffusion_ckpt=True, lumina="lumina2"
)
save_models(ckpt_file, lumina, sai_metadata, save_dtype, args.mem_eff_save)
train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None)
@@ -469,15 +662,29 @@ def save_lumina_model_on_train_end(
def save_lumina_model_on_epoch_end_or_stepwise(
args: argparse.Namespace,
on_epoch_end: bool,
accelerator,
accelerator: Accelerator,
save_dtype: torch.dtype,
epoch: int,
num_train_epochs: int,
global_step: int,
lumina: lumina_models.NextDiT,
):
def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, lumina="lumina2")
"""
Save the model to the checkpoint path.
Args:
args (argparse.Namespace): Arguments.
save_dtype (torch.dtype): Data type.
epoch (int): Epoch.
global_step (int): Global step.
lumina (lumina_models.NextDiT): NextDIT model.
Return:
None
"""
def sd_saver(ckpt_file: str, epoch_no: int, global_step: int):
sai_metadata = train_util.get_sai_model_spec({}, args, False, False, False, is_stable_diffusion_ckpt=True, lumina="lumina2")
save_models(ckpt_file, lumina, sai_metadata, save_dtype, args.mem_eff_save)
train_util.save_sd_model_on_epoch_end_or_stepwise_common(

View File

@@ -11,23 +11,33 @@ from safetensors.torch import load_file
from transformers import Gemma2Config, Gemma2Model
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
from library import lumina_models, flux_models
from library.utils import load_safetensors
import logging
setup_logging()
logger = logging.getLogger(__name__)
MODEL_VERSION_LUMINA_V2 = "lumina2"
def load_lumina_model(
ckpt_path: str,
dtype: torch.dtype,
device: Union[str, torch.device],
device: torch.device,
disable_mmap: bool = False,
):
"""
Load the Lumina model from the checkpoint path.
Args:
ckpt_path (str): Path to the checkpoint.
dtype (torch.dtype): The data type for the model.
device (torch.device): The device to load the model on.
disable_mmap (bool, optional): Whether to disable mmap. Defaults to False.
Returns:
model (lumina_models.NextDiT): The loaded model.
"""
logger.info("Building Lumina")
with torch.device("meta"):
model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner().to(dtype)
@@ -46,6 +56,18 @@ def load_ae(
device: Union[str, torch.device],
disable_mmap: bool = False,
) -> flux_models.AutoEncoder:
"""
Load the AutoEncoder model from the checkpoint path.
Args:
ckpt_path (str): Path to the checkpoint.
dtype (torch.dtype): The data type for the model.
device (Union[str, torch.device]): The device to load the model on.
disable_mmap (bool, optional): Whether to disable mmap. Defaults to False.
Returns:
ae (flux_models.AutoEncoder): The loaded model.
"""
logger.info("Building AutoEncoder")
with torch.device("meta"):
# dev and schnell have the same AE params
@@ -67,6 +89,19 @@ def load_gemma2(
disable_mmap: bool = False,
state_dict: Optional[dict] = None,
) -> Gemma2Model:
"""
Load the Gemma2 model from the checkpoint path.
Args:
ckpt_path (str): Path to the checkpoint.
dtype (torch.dtype): The data type for the model.
device (Union[str, torch.device]): The device to load the model on.
disable_mmap (bool, optional): Whether to disable mmap. Defaults to False.
state_dict (Optional[dict], optional): The state dict to load. Defaults to None.
Returns:
gemma2 (Gemma2Model): The loaded model
"""
logger.info("Building Gemma2")
GEMMA2_CONFIG = {
"_name_or_path": "google/gemma-2-2b",

View File

@@ -130,11 +130,6 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
return False
if "input_ids" not in npz:
return False
if "apply_gemma2_attn_mask" not in npz:
return False
npz_apply_gemma2_attn_mask = npz["apply_gemma2_attn_mask"]
if not npz_apply_gemma2_attn_mask:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
@@ -142,11 +137,17 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
"""
Load outputs from a npz file
Returns:
List[np.ndarray]: hidden_state, input_ids, attention_mask
"""
data = np.load(npz_path)
hidden_state = data["hidden_state"]
attention_mask = data["attention_mask"]
input_ids = data["input_ids"]
return [hidden_state, attention_mask, input_ids]
return [hidden_state, input_ids, attention_mask]
def cache_batch_outputs(
self,
@@ -193,8 +194,7 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
info.text_encoder_outputs_npz,
hidden_state=hidden_state_i,
attention_mask=attention_mask_i,
input_ids=input_ids_i,
apply_gemma2_attn_mask=True
input_ids=input_ids_i
)
else:
info.text_encoder_outputs = [hidden_state_i, attention_mask_i, input_ids_i]

View File

@@ -2,9 +2,10 @@ import argparse
import copy
import math
import random
from typing import Any, Optional, Union
from typing import Any, Optional, Union, Tuple
import torch
from torch import Tensor
from accelerate import Accelerator
from library.device_utils import clean_memory_on_device, init_ipex
@@ -165,36 +166,31 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}"
)
tokenize_strategy: strategy_lumina.LuminaTokenizeStrategy = (
strategy_base.TokenizeStrategy.get_strategy()
)
text_encoding_strategy: strategy_lumina.LuminaTextEncodingStrategy = (
strategy_base.TextEncodingStrategy.get_strategy()
)
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
assert isinstance(tokenize_strategy, strategy_lumina.LuminaTokenizeStrategy)
assert isinstance(text_encoding_strategy, strategy_lumina.LuminaTextEncodingStrategy)
prompts = train_util.load_prompts(args.sample_prompts)
sample_prompts = train_util.load_prompts(args.sample_prompts)
sample_prompts_te_outputs = (
{}
) # key: prompt, value: text encoder outputs
with accelerator.autocast(), torch.no_grad():
for prompt_dict in prompts:
for p in [
prompt_dict.get("prompt", ""),
prompt_dict.get("negative_prompt", ""),
]:
if p not in sample_prompts_te_outputs:
logger.info(
f"cache Text Encoder outputs for prompt: {p}"
)
tokens_and_masks = tokenize_strategy.tokenize(p)
sample_prompts_te_outputs[p] = (
text_encoding_strategy.encode_tokens(
tokenize_strategy,
text_encoders,
tokens_and_masks,
args.apply_t5_attn_mask,
)
)
for prompt_dict in sample_prompts:
prompts = [prompt_dict.get("prompt", ""),
prompt_dict.get("negative_prompt", "")]
logger.info(
f"cache Text Encoder outputs for prompt: {prompts[0]}"
)
tokens_and_masks = tokenize_strategy.tokenize(prompts)
sample_prompts_te_outputs[prompts[0]] = (
text_encoding_strategy.encode_tokens(
tokenize_strategy,
text_encoders,
tokens_and_masks,
)
)
self.sample_prompts_te_outputs = sample_prompts_te_outputs
accelerator.wait_for_everyone()
@@ -220,7 +216,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
epoch,
global_step,
device,
ae,
vae,
tokenizer,
text_encoder,
lumina,
@@ -231,7 +227,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
epoch,
global_step,
lumina,
ae,
vae,
self.get_models_for_text_encoding(args, accelerator, text_encoder),
self.sample_prompts_te_outputs,
)
@@ -258,12 +254,12 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
def get_noise_pred_and_target(
self,
args,
accelerator,
accelerator: Accelerator,
noise_scheduler,
latents,
batch,
text_encoder_conds,
unet: lumina_models.NextDiT,
text_encoder_conds: Tuple[Tensor, Tensor, Tensor], # (hidden_states, input_ids, attention_masks)
dit: lumina_models.NextDiT,
network,
weight_dtype,
train_unet,
@@ -296,7 +292,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
def call_dit(img, gemma2_hidden_states, timesteps, gemma2_attn_mask):
with torch.set_grad_enabled(is_train), accelerator.autocast():
# NextDiT forward expects (x, t, cap_feats, cap_mask)
model_pred = unet(
model_pred = dit(
x=img, # image latents (B, C, H, W)
t=timesteps / 1000, # timesteps需要除以1000来匹配模型预期
cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features
@@ -341,7 +337,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
network.set_multiplier(0.0)
with torch.no_grad():
model_pred_prior = call_dit(
img=packed_noisy_model_input[diff_output_pr_indices],
img=noisy_model_input[diff_output_pr_indices],
gemma2_hidden_states=gemma2_hidden_states[
diff_output_pr_indices
],
@@ -350,9 +346,9 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
)
network.set_multiplier(1.0)
model_pred_prior = lumina_util.unpack_latents(
model_pred_prior, packed_latent_height, packed_latent_width
)
# model_pred_prior = lumina_util.unpack_latents(
# model_pred_prior, packed_latent_height, packed_latent_width
# )
model_pred_prior, _ = flux_train_utils.apply_model_prediction_type(
args,
model_pred_prior,
@@ -404,7 +400,8 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
return super().prepare_unet_with_accelerator(args, accelerator, unet)
# if we doesn't swap blocks, we can move the model to device
nextdit: lumina_models.Nextdit = unet
nextdit = unet
assert isinstance(nextdit, lumina_models.NextDiT)
nextdit = accelerator.prepare(
nextdit, device_placement=[not self.is_swapping_blocks]
)