mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 21:52:27 +00:00
Add documentation to model, use SDPA attention, sample images
This commit is contained in:
@@ -13,6 +13,7 @@ import math
|
||||
from typing import List, Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
|
||||
from einops import rearrange
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||
import torch
|
||||
@@ -23,24 +24,16 @@ import torch.nn.functional as F
|
||||
|
||||
try:
|
||||
from apex.normalization import FusedRMSNorm as RMSNorm
|
||||
except ModuleNotFoundError:
|
||||
except:
|
||||
import warnings
|
||||
|
||||
warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")
|
||||
|
||||
memory_efficient_attention = None
|
||||
try:
|
||||
import xformers
|
||||
except:
|
||||
pass
|
||||
|
||||
try:
|
||||
from xformers.ops import memory_efficient_attention
|
||||
except:
|
||||
memory_efficient_attention = None
|
||||
|
||||
@dataclass
|
||||
class LuminaParams:
|
||||
"""Parameters for Lumina model configuration"""
|
||||
|
||||
patch_size: int = 2
|
||||
in_channels: int = 4
|
||||
dim: int = 4096
|
||||
@@ -68,7 +61,7 @@ class LuminaParams:
|
||||
"""Returns the configuration for the 2B parameter model"""
|
||||
return cls(
|
||||
patch_size=2,
|
||||
in_channels=16,
|
||||
in_channels=16, # VAE channels
|
||||
dim=2304,
|
||||
n_layers=26,
|
||||
n_heads=24,
|
||||
@@ -76,21 +69,13 @@ class LuminaParams:
|
||||
axes_dims=[32, 32, 32],
|
||||
axes_lens=[300, 512, 512],
|
||||
qk_norm=True,
|
||||
cap_feat_dim=2304
|
||||
cap_feat_dim=2304, # Gemma 2 hidden_size
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_7b_config(cls) -> "LuminaParams":
|
||||
"""Returns the configuration for the 7B parameter model"""
|
||||
return cls(
|
||||
patch_size=2,
|
||||
dim=4096,
|
||||
n_layers=32,
|
||||
n_heads=32,
|
||||
n_kv_heads=8,
|
||||
axes_dims=[64, 64, 64],
|
||||
axes_lens=[300, 512, 512]
|
||||
)
|
||||
return cls(patch_size=2, dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, axes_dims=[64, 64, 64], axes_lens=[300, 512, 512])
|
||||
|
||||
|
||||
class GradientCheckpointMixin(nn.Module):
|
||||
@@ -112,6 +97,7 @@ class GradientCheckpointMixin(nn.Module):
|
||||
else:
|
||||
return self._forward(*args, **kwargs)
|
||||
|
||||
|
||||
#############################################################################
|
||||
# RMSNorm #
|
||||
#############################################################################
|
||||
@@ -148,9 +134,18 @@ class RMSNorm(torch.nn.Module):
|
||||
"""
|
||||
return x * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
"""
|
||||
Apply RMSNorm to the input tensor.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The normalized tensor.
|
||||
"""
|
||||
x_dtype = x.dtype
|
||||
# To handle float8 we need to convert the tensor to float
|
||||
x = x.float()
|
||||
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
|
||||
return ((x * rrms) * self.weight.float()).to(dtype=x_dtype)
|
||||
@@ -204,17 +199,11 @@ class TimestepEmbedder(GradientCheckpointMixin):
|
||||
"""
|
||||
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period)
|
||||
* torch.arange(start=0, end=half, dtype=torch.float32)
|
||||
/ half
|
||||
).to(device=t.device)
|
||||
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat(
|
||||
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
|
||||
)
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
def _forward(self, t):
|
||||
@@ -222,6 +211,7 @@ class TimestepEmbedder(GradientCheckpointMixin):
|
||||
t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
|
||||
return t_emb
|
||||
|
||||
|
||||
def to_cuda(x):
|
||||
if isinstance(x, torch.Tensor):
|
||||
return x.cuda()
|
||||
@@ -266,6 +256,7 @@ class JointAttention(nn.Module):
|
||||
dim (int): Number of input dimensions.
|
||||
n_heads (int): Number of heads.
|
||||
n_kv_heads (Optional[int]): Number of kv heads, if using GQA.
|
||||
qk_norm (bool): Whether to use normalization for queries and keys.
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
@@ -295,6 +286,14 @@ class JointAttention(nn.Module):
|
||||
else:
|
||||
self.q_norm = self.k_norm = nn.Identity()
|
||||
|
||||
self.flash_attn = False
|
||||
|
||||
# self.attention_processor = xformers.ops.memory_efficient_attention
|
||||
self.attention_processor = F.scaled_dot_product_attention
|
||||
|
||||
def set_attention_processor(self, attention_processor):
|
||||
self.attention_processor = attention_processor
|
||||
|
||||
@staticmethod
|
||||
def apply_rotary_emb(
|
||||
x_in: torch.Tensor,
|
||||
@@ -326,16 +325,12 @@ class JointAttention(nn.Module):
|
||||
return x_out.type_as(x_in)
|
||||
|
||||
# copied from huggingface modeling_llama.py
|
||||
def _upad_input(
|
||||
self, query_layer, key_layer, value_layer, attention_mask, query_length
|
||||
):
|
||||
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
|
||||
def _get_unpad_data(attention_mask):
|
||||
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
||||
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
||||
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
||||
cu_seqlens = F.pad(
|
||||
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
|
||||
)
|
||||
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
||||
return (
|
||||
indices,
|
||||
cu_seqlens,
|
||||
@@ -355,9 +350,7 @@ class JointAttention(nn.Module):
|
||||
)
|
||||
if query_length == kv_seq_len:
|
||||
query_layer = index_first_axis(
|
||||
query_layer.reshape(
|
||||
batch_size * kv_seq_len, self.n_local_heads, head_dim
|
||||
),
|
||||
query_layer.reshape(batch_size * kv_seq_len, self.n_local_heads, head_dim),
|
||||
indices_k,
|
||||
)
|
||||
cu_seqlens_q = cu_seqlens_k
|
||||
@@ -373,9 +366,7 @@ class JointAttention(nn.Module):
|
||||
else:
|
||||
# The -q_len: slice assumes left padding.
|
||||
attention_mask = attention_mask[:, -query_length:]
|
||||
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
|
||||
query_layer, attention_mask
|
||||
)
|
||||
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
|
||||
|
||||
return (
|
||||
query_layer,
|
||||
@@ -388,10 +379,10 @@ class JointAttention(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_mask: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
x: Tensor,
|
||||
x_mask: Tensor,
|
||||
freqs_cis: Tensor,
|
||||
) -> Tensor:
|
||||
"""
|
||||
|
||||
Args:
|
||||
@@ -425,7 +416,7 @@ class JointAttention(nn.Module):
|
||||
|
||||
softmax_scale = math.sqrt(1 / self.head_dim)
|
||||
|
||||
if dtype in [torch.float16, torch.bfloat16]:
|
||||
if self.flash_attn:
|
||||
# begin var_len flash attn
|
||||
(
|
||||
query_states,
|
||||
@@ -459,14 +450,13 @@ class JointAttention(nn.Module):
|
||||
if n_rep >= 1:
|
||||
xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
||||
xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
||||
|
||||
output = (
|
||||
F.scaled_dot_product_attention(
|
||||
self.attention_processor(
|
||||
xq.permute(0, 2, 1, 3),
|
||||
xk.permute(0, 2, 1, 3),
|
||||
xv.permute(0, 2, 1, 3),
|
||||
attn_mask=x_mask.bool()
|
||||
.view(bsz, 1, 1, seqlen)
|
||||
.expand(-1, self.n_local_heads, seqlen, -1),
|
||||
attn_mask=x_mask.bool().view(bsz, 1, 1, seqlen).expand(-1, self.n_local_heads, seqlen, -1),
|
||||
scale=softmax_scale,
|
||||
)
|
||||
.permute(0, 2, 1, 3)
|
||||
@@ -474,10 +464,47 @@ class JointAttention(nn.Module):
|
||||
)
|
||||
|
||||
output = output.flatten(-2)
|
||||
|
||||
return self.out(output)
|
||||
|
||||
|
||||
def attention(q: Tensor, k: Tensor, v: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor:
|
||||
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
x = rearrange(x, "B H L D -> B L (H D)")
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def apply_rope(
|
||||
x_in: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Apply rotary embeddings to input tensors using the given frequency
|
||||
tensor.
|
||||
|
||||
This function applies rotary embeddings to the given query 'xq' and
|
||||
key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The
|
||||
input tensors are reshaped as complex numbers, and the frequency tensor
|
||||
is reshaped for broadcasting compatibility. The resulting tensors
|
||||
contain rotary embeddings and are returned as real tensors.
|
||||
|
||||
Args:
|
||||
x_in (torch.Tensor): Query or Key tensor to apply rotary embeddings.
|
||||
freqs_cis (torch.Tensor): Precomputed frequency tensor for complex
|
||||
exponentials.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor
|
||||
and key tensor with rotary embeddings.
|
||||
"""
|
||||
with torch.autocast("cuda", enabled=False):
|
||||
x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
|
||||
freqs_cis = freqs_cis.unsqueeze(2)
|
||||
x_out = torch.view_as_real(x * freqs_cis).flatten(3)
|
||||
|
||||
return x_out.type_as(x_in)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -554,10 +581,13 @@ class JointTransformerBlock(GradientCheckpointMixin):
|
||||
n_kv_heads (Optional[int]): Number of attention heads in key and
|
||||
value features (if using GQA), or set to None for the same as
|
||||
query.
|
||||
multiple_of (int):
|
||||
ffn_dim_multiplier (Optional[float]):
|
||||
norm_eps (float):
|
||||
|
||||
multiple_of (int): Number of multiple of the hidden dimension.
|
||||
ffn_dim_multiplier (Optional[float]): Dimension multiplier for the
|
||||
feedforward layer.
|
||||
norm_eps (float): Epsilon value for normalization.
|
||||
qk_norm (bool): Whether to use normalization for queries and keys.
|
||||
modulation (bool): Whether to use modulation for the attention
|
||||
layer.
|
||||
"""
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
@@ -593,32 +623,30 @@ class JointTransformerBlock(GradientCheckpointMixin):
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_mask: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
pe: torch.Tensor,
|
||||
adaln_input: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""
|
||||
Perform a forward pass through the TransformerBlock.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
|
||||
x (Tensor): Input tensor.
|
||||
pe (Tensor): Rope position embedding.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor after applying attention and
|
||||
Tensor: Output tensor after applying attention and
|
||||
feedforward layers.
|
||||
|
||||
"""
|
||||
if self.modulation:
|
||||
assert adaln_input is not None
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(
|
||||
adaln_input
|
||||
).chunk(4, dim=1)
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
|
||||
|
||||
x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
|
||||
self.attention(
|
||||
modulate(self.attention_norm1(x), scale_msa),
|
||||
x_mask,
|
||||
freqs_cis,
|
||||
pe,
|
||||
)
|
||||
)
|
||||
x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
|
||||
@@ -632,7 +660,7 @@ class JointTransformerBlock(GradientCheckpointMixin):
|
||||
self.attention(
|
||||
self.attention_norm1(x),
|
||||
x_mask,
|
||||
freqs_cis,
|
||||
pe,
|
||||
)
|
||||
)
|
||||
x = x + self.ffn_norm2(
|
||||
@@ -649,6 +677,14 @@ class FinalLayer(GradientCheckpointMixin):
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, patch_size, out_channels):
|
||||
"""
|
||||
Initialize the FinalLayer.
|
||||
|
||||
Args:
|
||||
hidden_size (int): Hidden size of the input features.
|
||||
patch_size (int): Patch size of the input features.
|
||||
out_channels (int): Number of output channels.
|
||||
"""
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(
|
||||
hidden_size,
|
||||
@@ -682,39 +718,21 @@ class FinalLayer(GradientCheckpointMixin):
|
||||
|
||||
|
||||
class RopeEmbedder:
|
||||
def __init__(
|
||||
self,
|
||||
theta: float = 10000.0,
|
||||
axes_dims: List[int] = (16, 56, 56),
|
||||
axes_lens: List[int] = (1, 512, 512),
|
||||
):
|
||||
def __init__(self, theta: float = 10000.0, axes_dims: List[int] = [16, 56, 56], axes_lens: List[int] = [1, 512, 512]):
|
||||
super().__init__()
|
||||
self.theta = theta
|
||||
self.axes_dims = axes_dims
|
||||
self.axes_lens = axes_lens
|
||||
self.freqs_cis = NextDiT.precompute_freqs_cis(
|
||||
self.axes_dims, self.axes_lens, theta=self.theta
|
||||
)
|
||||
self.freqs_cis = NextDiT.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
|
||||
|
||||
def get_freqs_cis(self, ids: torch.Tensor):
|
||||
def __call__(self, ids: torch.Tensor):
|
||||
device = ids.device
|
||||
self.freqs_cis = [freqs_cis.to(ids.device) for freqs_cis in self.freqs_cis]
|
||||
result = []
|
||||
for i in range(len(self.axes_dims)):
|
||||
index = (
|
||||
ids[:, :, i : i + 1]
|
||||
.repeat(1, 1, self.freqs_cis[i].shape[-1])
|
||||
.to(torch.int64)
|
||||
)
|
||||
|
||||
axes = self.freqs_cis[i].unsqueeze(0).repeat(index.shape[0], 1, 1)
|
||||
|
||||
result.append(
|
||||
torch.gather(
|
||||
axes,
|
||||
dim=1,
|
||||
index=index,
|
||||
)
|
||||
)
|
||||
freqs = self.freqs_cis[i].to(ids.device)
|
||||
index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
|
||||
result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
|
||||
return torch.cat(result, dim=-1)
|
||||
|
||||
|
||||
@@ -740,11 +758,63 @@ class NextDiT(nn.Module):
|
||||
axes_dims: List[int] = [16, 56, 56],
|
||||
axes_lens: List[int] = [1, 512, 512],
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the NextDiT model.
|
||||
|
||||
Args:
|
||||
patch_size (int): Patch size of the input features.
|
||||
in_channels (int): Number of input channels.
|
||||
dim (int): Hidden size of the input features.
|
||||
n_layers (int): Number of Transformer layers.
|
||||
n_refiner_layers (int): Number of refiner layers.
|
||||
n_heads (int): Number of attention heads.
|
||||
n_kv_heads (Optional[int]): Number of attention heads in key and
|
||||
value features (if using GQA), or set to None for the same as
|
||||
query.
|
||||
multiple_of (int): Multiple of the hidden size.
|
||||
ffn_dim_multiplier (Optional[float]): Dimension multiplier for the
|
||||
feedforward layer.
|
||||
norm_eps (float): Epsilon value for normalization.
|
||||
qk_norm (bool): Whether to use query key normalization.
|
||||
cap_feat_dim (int): Dimension of the caption features.
|
||||
axes_dims (List[int]): List of dimensions for the axes.
|
||||
axes_lens (List[int]): List of lengths for the axes.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.t_embedder = TimestepEmbedder(min(dim, 1024))
|
||||
self.cap_embedder = nn.Sequential(
|
||||
RMSNorm(cap_feat_dim, eps=norm_eps),
|
||||
nn.Linear(
|
||||
cap_feat_dim,
|
||||
dim,
|
||||
bias=True,
|
||||
),
|
||||
)
|
||||
|
||||
self.context_refiner = nn.ModuleList(
|
||||
[
|
||||
JointTransformerBlock(
|
||||
layer_id,
|
||||
dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
multiple_of,
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
modulation=False,
|
||||
)
|
||||
for layer_id in range(n_refiner_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.x_embedder = nn.Linear(
|
||||
in_features=patch_size * patch_size * in_channels,
|
||||
out_features=dim,
|
||||
@@ -769,32 +839,7 @@ class NextDiT(nn.Module):
|
||||
for layer_id in range(n_refiner_layers)
|
||||
]
|
||||
)
|
||||
self.context_refiner = nn.ModuleList(
|
||||
[
|
||||
JointTransformerBlock(
|
||||
layer_id,
|
||||
dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
multiple_of,
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
modulation=False,
|
||||
)
|
||||
for layer_id in range(n_refiner_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.t_embedder = TimestepEmbedder(min(dim, 1024))
|
||||
self.cap_embedder = nn.Sequential(
|
||||
RMSNorm(cap_feat_dim, eps=norm_eps),
|
||||
nn.Linear(
|
||||
cap_feat_dim,
|
||||
dim,
|
||||
bias=True,
|
||||
),
|
||||
)
|
||||
nn.init.trunc_normal_(self.cap_embedder[1].weight, std=0.02)
|
||||
# nn.init.zeros_(self.cap_embedder[1].weight)
|
||||
nn.init.zeros_(self.cap_embedder[1].bias)
|
||||
@@ -864,15 +909,26 @@ class NextDiT(nn.Module):
|
||||
|
||||
def unpatchify(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x: Tensor,
|
||||
width: int,
|
||||
height: int,
|
||||
encoder_seq_lengths: List[int],
|
||||
seq_lengths: List[int],
|
||||
) -> torch.Tensor:
|
||||
) -> Tensor:
|
||||
"""
|
||||
Unpatchify the input tensor and embed the caption features.
|
||||
x: (N, T, patch_size**2 * C)
|
||||
imgs: (N, H, W, C)
|
||||
|
||||
Args:
|
||||
x (Tensor): Input tensor.
|
||||
width (int): Width of the input tensor.
|
||||
height (int): Height of the input tensor.
|
||||
encoder_seq_lengths (List[int]): List of encoder sequence lengths.
|
||||
seq_lengths (List[int]): List of sequence lengths
|
||||
|
||||
Returns:
|
||||
output: (N, C, H, W)
|
||||
"""
|
||||
pH = pW = self.patch_size
|
||||
|
||||
@@ -891,13 +947,27 @@ class NextDiT(nn.Module):
|
||||
|
||||
def patchify_and_embed(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cap_feats: torch.Tensor,
|
||||
cap_mask: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
) -> Tuple[
|
||||
torch.Tensor, torch.Tensor, torch.Tensor, List[int], List[int]
|
||||
]:
|
||||
x: Tensor,
|
||||
cap_feats: Tensor,
|
||||
cap_mask: Tensor,
|
||||
t: Tensor,
|
||||
) -> Tuple[Tensor, Tensor, Tensor, List[int], List[int]]:
|
||||
"""
|
||||
Patchify and embed the input image and caption features.
|
||||
|
||||
Args:
|
||||
x: (N, C, H, W) image latents
|
||||
cap_feats: (N, C, D) caption features
|
||||
cap_mask: (N, C, D) caption attention mask
|
||||
t: (N), T timesteps
|
||||
|
||||
Returns:
|
||||
Tuple[Tensor, Tensor, Tensor, List[int], List[int]]:
|
||||
|
||||
return x, attention_mask, freqs_cis, l_effective_cap_len, seq_lengths
|
||||
|
||||
|
||||
"""
|
||||
bsz, channels, height, width = x.shape
|
||||
pH = pW = self.patch_size
|
||||
device = x.device
|
||||
@@ -915,40 +985,35 @@ class NextDiT(nn.Module):
|
||||
H_tokens, W_tokens = height // pH, width // pW
|
||||
|
||||
position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
|
||||
position_ids[i, cap_len : cap_len + seq_len, 0] = cap_len
|
||||
row_ids = (
|
||||
torch.arange(H_tokens, dtype=torch.int32, device=device)
|
||||
.view(-1, 1)
|
||||
.repeat(1, W_tokens)
|
||||
.flatten()
|
||||
)
|
||||
col_ids = (
|
||||
torch.arange(W_tokens, dtype=torch.int32, device=device)
|
||||
.view(1, -1)
|
||||
.repeat(H_tokens, 1)
|
||||
.flatten()
|
||||
)
|
||||
position_ids[i, cap_len : cap_len + seq_len, 1] = row_ids
|
||||
position_ids[i, cap_len : cap_len + seq_len, 2] = col_ids
|
||||
position_ids[i, cap_len:seq_len, 0] = cap_len
|
||||
|
||||
freqs_cis = self.rope_embedder.get_freqs_cis(position_ids)
|
||||
row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
|
||||
col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
|
||||
|
||||
position_ids[i, cap_len:seq_len, 1] = row_ids
|
||||
position_ids[i, cap_len:seq_len, 2] = col_ids
|
||||
|
||||
# Get combinded rotary embeddings
|
||||
freqs_cis = self.rope_embedder(position_ids)
|
||||
|
||||
# Create separate rotary embeddings for captions and images
|
||||
cap_freqs_cis = torch.zeros(bsz, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype)
|
||||
img_freqs_cis = torch.zeros(bsz, image_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype)
|
||||
|
||||
for i, (cap_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
|
||||
cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
|
||||
img_freqs_cis[i, :seq_len] = freqs_cis[i, cap_len : cap_len + seq_len]
|
||||
img_freqs_cis[i, :image_seq_len] = freqs_cis[i, cap_len:seq_len]
|
||||
|
||||
x = x.view(bsz, channels, height // pH, pH, width // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)
|
||||
x_mask = torch.zeros(bsz, max_seq_len, dtype=torch.bool, device=device)
|
||||
|
||||
# refine context
|
||||
# Refine caption context
|
||||
for layer in self.context_refiner:
|
||||
cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis)
|
||||
|
||||
x = x.view(bsz, channels, height // pH, pH, width // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)
|
||||
x_mask = torch.zeros(bsz, image_seq_len, dtype=torch.bool, device=device)
|
||||
|
||||
x = self.x_embedder(x)
|
||||
|
||||
# Refine image context
|
||||
for layer in self.noise_refiner:
|
||||
x = layer(x, x_mask, img_freqs_cis, t)
|
||||
|
||||
@@ -963,19 +1028,23 @@ class NextDiT(nn.Module):
|
||||
|
||||
return x, attention_mask, freqs_cis, l_effective_cap_len, seq_lengths
|
||||
|
||||
def forward(self, x, t, cap_feats, cap_mask):
|
||||
def forward(self, x: Tensor, t: Tensor, cap_feats: Tensor, cap_mask: Tensor) -> Tensor:
|
||||
"""
|
||||
Forward pass of NextDiT.
|
||||
t: (N,) tensor of diffusion timesteps
|
||||
y: (N,) tensor of text tokens/features
|
||||
Args:
|
||||
x: (N, C, H, W) image latents
|
||||
t: (N,) tensor of diffusion timesteps
|
||||
cap_feats: (N, L, D) caption features
|
||||
cap_mask: (N, L) caption attention mask
|
||||
|
||||
Returns:
|
||||
x: (N, C, H, W) denoised latents
|
||||
"""
|
||||
_, _, height, width = x.shape # B, C, H, W
|
||||
_, _, height, width = x.shape # B, C, H, W
|
||||
t = self.t_embedder(t) # (N, D)
|
||||
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
|
||||
|
||||
x, mask, freqs_cis, l_effective_cap_len, seq_lengths = self.patchify_and_embed(
|
||||
x, cap_feats, cap_mask, t
|
||||
)
|
||||
x, mask, freqs_cis, l_effective_cap_len, seq_lengths = self.patchify_and_embed(x, cap_feats, cap_mask, t)
|
||||
|
||||
for layer in self.layers:
|
||||
x = layer(x, mask, freqs_cis, t)
|
||||
@@ -986,7 +1055,14 @@ class NextDiT(nn.Module):
|
||||
return x
|
||||
|
||||
def forward_with_cfg(
|
||||
self, x, t, cap_feats, cap_mask, cfg_scale, cfg_trunc=100, renorm_cfg=1
|
||||
self,
|
||||
x: Tensor,
|
||||
t: Tensor,
|
||||
cap_feats: Tensor,
|
||||
cap_mask: Tensor,
|
||||
cfg_scale: float,
|
||||
cfg_trunc: int = 100,
|
||||
renorm_cfg: float = 1.0,
|
||||
):
|
||||
"""
|
||||
Forward pass of NextDiT, but also batches the unconditional forward pass
|
||||
@@ -996,9 +1072,10 @@ class NextDiT(nn.Module):
|
||||
half = x[: len(x) // 2]
|
||||
if t[0] < cfg_trunc:
|
||||
combined = torch.cat([half, half], dim=0) # [2, 16, 128, 128]
|
||||
model_out = self.forward(
|
||||
combined, t, cap_feats, cap_mask
|
||||
) # [2, 16, 128, 128]
|
||||
assert (
|
||||
cap_mask.shape[0] == combined.shape[0]
|
||||
), f"caption attention mask shape: {cap_mask.shape[0]} latents shape: {combined.shape[0]}"
|
||||
model_out = self.forward(x, t, cap_feats, cap_mask) # [2, 16, 128, 128]
|
||||
# For exact reproducibility reasons, we apply classifier-free guidance on only
|
||||
# three channels by default. The standard approach to cfg applies it to all channels.
|
||||
# This can be done by uncommenting the following line and commenting-out the line following that.
|
||||
@@ -1009,13 +1086,9 @@ class NextDiT(nn.Module):
|
||||
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
|
||||
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
|
||||
if float(renorm_cfg) > 0.0:
|
||||
ori_pos_norm = torch.linalg.vector_norm(
|
||||
cond_eps, dim=tuple(range(1, len(cond_eps.shape))), keepdim=True
|
||||
)
|
||||
ori_pos_norm = torch.linalg.vector_norm(cond_eps, dim=tuple(range(1, len(cond_eps.shape))), keepdim=True)
|
||||
max_new_norm = ori_pos_norm * float(renorm_cfg)
|
||||
new_pos_norm = torch.linalg.vector_norm(
|
||||
half_eps, dim=tuple(range(1, len(half_eps.shape))), keepdim=True
|
||||
)
|
||||
new_pos_norm = torch.linalg.vector_norm(half_eps, dim=tuple(range(1, len(half_eps.shape))), keepdim=True)
|
||||
if new_pos_norm >= max_new_norm:
|
||||
half_eps = half_eps * (max_new_norm / new_pos_norm)
|
||||
else:
|
||||
@@ -1040,7 +1113,7 @@ class NextDiT(nn.Module):
|
||||
dim: List[int],
|
||||
end: List[int],
|
||||
theta: float = 10000.0,
|
||||
):
|
||||
) -> List[Tensor]:
|
||||
"""
|
||||
Precompute the frequency tensor for complex exponentials (cis) with
|
||||
given dimensions.
|
||||
@@ -1057,19 +1130,17 @@ class NextDiT(nn.Module):
|
||||
Defaults to 10000.0.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Precomputed frequency tensor with complex
|
||||
List[torch.Tensor]: Precomputed frequency tensor with complex
|
||||
exponentials.
|
||||
"""
|
||||
freqs_cis = []
|
||||
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
|
||||
|
||||
for i, (d, e) in enumerate(zip(dim, end)):
|
||||
freqs = 1.0 / (
|
||||
theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d)
|
||||
)
|
||||
timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
|
||||
freqs = torch.outer(timestep, freqs).float()
|
||||
freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(
|
||||
torch.complex64
|
||||
) # complex64
|
||||
pos = torch.arange(e, dtype=freqs_dtype, device="cpu")
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=freqs_dtype, device="cpu") / d))
|
||||
freqs = torch.outer(pos, freqs)
|
||||
freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs) # [S, D/2]
|
||||
freqs_cis.append(freqs_cis_i)
|
||||
|
||||
return freqs_cis
|
||||
@@ -1102,7 +1173,7 @@ class NextDiT(nn.Module):
|
||||
def NextDiT_2B_GQA_patch2_Adaln_Refiner(params: Optional[LuminaParams] = None, **kwargs):
|
||||
if params is None:
|
||||
params = LuminaParams.get_2b_config()
|
||||
|
||||
|
||||
return NextDiT(
|
||||
patch_size=params.patch_size,
|
||||
in_channels=params.in_channels,
|
||||
|
||||
@@ -2,20 +2,20 @@ import argparse
|
||||
import math
|
||||
import os
|
||||
import numpy as np
|
||||
import toml
|
||||
import json
|
||||
import time
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Any
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from accelerate import Accelerator, PartialState
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from transformers import Gemma2Model
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from library import lumina_models, lumina_util, strategy_base, train_util
|
||||
from library import lumina_models, lumina_util, strategy_base, strategy_lumina, train_util
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
|
||||
|
||||
init_ipex()
|
||||
|
||||
@@ -30,19 +30,38 @@ logger = logging.getLogger(__name__)
|
||||
# region sample images
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_images(
|
||||
accelerator: Accelerator,
|
||||
args: argparse.Namespace,
|
||||
epoch,
|
||||
steps,
|
||||
nextdit,
|
||||
ae,
|
||||
gemma2_model,
|
||||
sample_prompts_gemma2_outputs,
|
||||
prompt_replacement=None,
|
||||
controlnet=None
|
||||
epoch: int,
|
||||
global_step: int,
|
||||
nextdit: lumina_models.NextDiT,
|
||||
vae: torch.nn.Module,
|
||||
gemma2_model: Gemma2Model,
|
||||
sample_prompts_gemma2_outputs: List[Tuple[Tensor, Tensor, Tensor]],
|
||||
prompt_replacement: Optional[Tuple[str, str]] = None,
|
||||
controlnet=None,
|
||||
):
|
||||
if steps == 0:
|
||||
"""
|
||||
Generate sample images using the NextDiT model.
|
||||
|
||||
Args:
|
||||
accelerator (Accelerator): Accelerator instance.
|
||||
args (argparse.Namespace): Command-line arguments.
|
||||
epoch (int): Current epoch number.
|
||||
global_step (int): Current global step number.
|
||||
nextdit (lumina_models.NextDiT): The NextDiT model instance.
|
||||
vae (torch.nn.Module): The VAE module.
|
||||
gemma2_model (Gemma2Model): The Gemma2 model instance.
|
||||
sample_prompts_gemma2_outputs (List[Tuple[Tensor, Tensor, Tensor]]): List of tuples containing the encoded prompts, text masks, and timestep for each sample.
|
||||
prompt_replacement (Optional[Tuple[str, str]], optional): Tuple containing the prompt and negative prompt replacements. Defaults to None.
|
||||
controlnet:: ControlNet model
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
if global_step == 0:
|
||||
if not args.sample_at_first:
|
||||
return
|
||||
else:
|
||||
@@ -53,11 +72,15 @@ def sample_images(
|
||||
if epoch is None or epoch % args.sample_every_n_epochs != 0:
|
||||
return
|
||||
else:
|
||||
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
|
||||
if global_step % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
|
||||
return
|
||||
|
||||
assert (
|
||||
args.sample_prompts is not None
|
||||
), "No sample prompts found. Provide `--sample_prompts` / サンプルプロンプトが見つかりません。`--sample_prompts` を指定してください"
|
||||
|
||||
logger.info("")
|
||||
logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
|
||||
logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {global_step}")
|
||||
if not os.path.isfile(args.sample_prompts) and sample_prompts_gemma2_outputs is None:
|
||||
logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
|
||||
return
|
||||
@@ -87,22 +110,21 @@ def sample_images(
|
||||
|
||||
if distributed_state.num_processes <= 1:
|
||||
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
|
||||
with torch.no_grad(), accelerator.autocast():
|
||||
for prompt_dict in prompts:
|
||||
sample_image_inference(
|
||||
accelerator,
|
||||
args,
|
||||
nextdit,
|
||||
gemma2_model,
|
||||
ae,
|
||||
save_dir,
|
||||
prompt_dict,
|
||||
epoch,
|
||||
steps,
|
||||
sample_prompts_gemma2_outputs,
|
||||
prompt_replacement,
|
||||
controlnet
|
||||
)
|
||||
for prompt_dict in prompts:
|
||||
sample_image_inference(
|
||||
accelerator,
|
||||
args,
|
||||
nextdit,
|
||||
gemma2_model,
|
||||
vae,
|
||||
save_dir,
|
||||
prompt_dict,
|
||||
epoch,
|
||||
global_step,
|
||||
sample_prompts_gemma2_outputs,
|
||||
prompt_replacement,
|
||||
controlnet,
|
||||
)
|
||||
else:
|
||||
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
|
||||
# prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
|
||||
@@ -110,23 +132,22 @@ def sample_images(
|
||||
for i in range(distributed_state.num_processes):
|
||||
per_process_prompts.append(prompts[i :: distributed_state.num_processes])
|
||||
|
||||
with torch.no_grad():
|
||||
with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
|
||||
for prompt_dict in prompt_dict_lists[0]:
|
||||
sample_image_inference(
|
||||
accelerator,
|
||||
args,
|
||||
nextdit,
|
||||
gemma2_model,
|
||||
ae,
|
||||
save_dir,
|
||||
prompt_dict,
|
||||
epoch,
|
||||
steps,
|
||||
sample_prompts_gemma2_outputs,
|
||||
prompt_replacement,
|
||||
controlnet
|
||||
)
|
||||
with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
|
||||
for prompt_dict in prompt_dict_lists[0]:
|
||||
sample_image_inference(
|
||||
accelerator,
|
||||
args,
|
||||
nextdit,
|
||||
gemma2_model,
|
||||
vae,
|
||||
save_dir,
|
||||
prompt_dict,
|
||||
epoch,
|
||||
global_step,
|
||||
sample_prompts_gemma2_outputs,
|
||||
prompt_replacement,
|
||||
controlnet,
|
||||
)
|
||||
|
||||
torch.set_rng_state(rng_state)
|
||||
if cuda_rng_state is not None:
|
||||
@@ -135,43 +156,60 @@ def sample_images(
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_image_inference(
|
||||
accelerator: Accelerator,
|
||||
args: argparse.Namespace,
|
||||
nextdit,
|
||||
gemma2_model,
|
||||
ae,
|
||||
save_dir,
|
||||
prompt_dict,
|
||||
epoch,
|
||||
steps,
|
||||
sample_prompts_gemma2_outputs,
|
||||
prompt_replacement,
|
||||
# controlnet
|
||||
nextdit: lumina_models.NextDiT,
|
||||
gemma2_model: Gemma2Model,
|
||||
vae: torch.nn.Module,
|
||||
save_dir: str,
|
||||
prompt_dict: Dict[str, str],
|
||||
epoch: int,
|
||||
global_step: int,
|
||||
sample_prompts_gemma2_outputs: List[Tuple[Tensor, Tensor, Tensor]],
|
||||
prompt_replacement: Optional[Tuple[str, str]] = None,
|
||||
controlnet=None,
|
||||
):
|
||||
"""
|
||||
Generates sample images
|
||||
|
||||
Args:
|
||||
accelerator (Accelerator): Accelerator object
|
||||
args (argparse.Namespace): Arguments object
|
||||
nextdit (lumina_models.NextDiT): NextDiT model
|
||||
gemma2_model (Gemma2Model): Gemma2 model
|
||||
vae (torch.nn.Module): VAE model
|
||||
save_dir (str): Directory to save images
|
||||
prompt_dict (Dict[str, str]): Prompt dictionary
|
||||
epoch (int): Epoch number
|
||||
steps (int): Number of steps to run
|
||||
sample_prompts_gemma2_outputs (List[Tuple[Tensor, Tensor, Tensor]]): List of tuples containing gemma2 outputs
|
||||
prompt_replacement (Optional[Tuple[str, str]], optional): Replacement for positive and negative prompt. Defaults to None.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
assert isinstance(prompt_dict, dict)
|
||||
# negative_prompt = prompt_dict.get("negative_prompt")
|
||||
sample_steps = prompt_dict.get("sample_steps", 20)
|
||||
width = prompt_dict.get("width", 512)
|
||||
height = prompt_dict.get("height", 512)
|
||||
scale = prompt_dict.get("scale", 3.5)
|
||||
seed = prompt_dict.get("seed")
|
||||
sample_steps = prompt_dict.get("sample_steps", 38)
|
||||
width = prompt_dict.get("width", 1024)
|
||||
height = prompt_dict.get("height", 1024)
|
||||
guidance_scale: int = prompt_dict.get("scale", 3.5)
|
||||
seed: int = prompt_dict.get("seed", None)
|
||||
controlnet_image = prompt_dict.get("controlnet_image")
|
||||
prompt: str = prompt_dict.get("prompt", "")
|
||||
negative_prompt: str = prompt_dict.get("negative_prompt", "")
|
||||
# sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
|
||||
|
||||
if prompt_replacement is not None:
|
||||
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
||||
# if negative_prompt is not None:
|
||||
# negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
||||
if negative_prompt is not None:
|
||||
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
||||
|
||||
generator = torch.Generator(device=accelerator.device)
|
||||
if seed is not None:
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
else:
|
||||
# True random sample image generation
|
||||
torch.seed()
|
||||
torch.cuda.seed()
|
||||
generator.manual_seed(seed)
|
||||
|
||||
# if negative_prompt is None:
|
||||
# negative_prompt = ""
|
||||
@@ -182,7 +220,7 @@ def sample_image_inference(
|
||||
logger.info(f"height: {height}")
|
||||
logger.info(f"width: {width}")
|
||||
logger.info(f"sample_steps: {sample_steps}")
|
||||
logger.info(f"scale: {scale}")
|
||||
logger.info(f"scale: {guidance_scale}")
|
||||
# logger.info(f"sample_sampler: {sampler_name}")
|
||||
if seed is not None:
|
||||
logger.info(f"seed: {seed}")
|
||||
@@ -191,14 +229,16 @@ def sample_image_inference(
|
||||
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
|
||||
encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
|
||||
|
||||
assert isinstance(tokenize_strategy, strategy_lumina.LuminaTokenizeStrategy)
|
||||
assert isinstance(encoding_strategy, strategy_lumina.LuminaTextEncodingStrategy)
|
||||
|
||||
gemma2_conds = []
|
||||
if sample_prompts_gemma2_outputs and prompt in sample_prompts_gemma2_outputs:
|
||||
gemma2_conds = sample_prompts_gemma2_outputs[prompt]
|
||||
print(f"Using cached Gemma2 outputs for prompt: {prompt}")
|
||||
logger.info(f"Using cached Gemma2 outputs for prompt: {prompt}")
|
||||
if gemma2_model is not None:
|
||||
print(f"Encoding prompt with Gemma2: {prompt}")
|
||||
logger.info(f"Encoding prompt with Gemma2: {prompt}")
|
||||
tokens_and_masks = tokenize_strategy.tokenize(prompt)
|
||||
# strategy has apply_gemma2_attn_mask option
|
||||
encoded_gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks)
|
||||
|
||||
# if gemma2_conds is not cached, use encoded_gemma2_conds
|
||||
@@ -211,22 +251,26 @@ def sample_image_inference(
|
||||
gemma2_conds[i] = encoded_gemma2_conds[i]
|
||||
|
||||
# Unpack Gemma2 outputs
|
||||
gemma2_hidden_states, gemma2_attn_mask, input_ids = gemma2_conds
|
||||
gemma2_hidden_states, input_ids, gemma2_attn_mask = gemma2_conds
|
||||
|
||||
# sample image
|
||||
weight_dtype = ae.dtype # TOFO give dtype as argument
|
||||
packed_latent_height = height // 16
|
||||
packed_latent_width = width // 16
|
||||
weight_dtype = vae.dtype # TOFO give dtype as argument
|
||||
latent_height = height // 8
|
||||
latent_width = width // 8
|
||||
noise = torch.randn(
|
||||
1,
|
||||
packed_latent_height * packed_latent_width,
|
||||
16 * 2 * 2,
|
||||
16,
|
||||
latent_height,
|
||||
latent_width,
|
||||
device=accelerator.device,
|
||||
dtype=weight_dtype,
|
||||
generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None,
|
||||
generator=generator,
|
||||
)
|
||||
# Prompts are paired positive/negative
|
||||
noise = noise.repeat(gemma2_attn_mask.shape[0], 1, 1, 1)
|
||||
|
||||
timesteps = get_schedule(sample_steps, noise.shape[1], shift=True)
|
||||
img_ids = lumina_util.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype)
|
||||
# img_ids = lumina_util.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype)
|
||||
gemma2_attn_mask = gemma2_attn_mask.to(accelerator.device)
|
||||
|
||||
# if controlnet_image is not None:
|
||||
@@ -235,18 +279,18 @@ def sample_image_inference(
|
||||
# controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1)
|
||||
# controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device)
|
||||
|
||||
with accelerator.autocast(), torch.no_grad():
|
||||
x = denoise(nextdit, noise, img_ids, gemma2_hidden_states, input_ids, None, timesteps=timesteps, guidance=scale, gemma2_attn_mask=gemma2_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image)
|
||||
with accelerator.autocast():
|
||||
x = denoise(nextdit, noise, gemma2_hidden_states, gemma2_attn_mask, timesteps=timesteps, guidance=guidance_scale)
|
||||
|
||||
x = lumina_util.unpack_latents(x, packed_latent_height, packed_latent_width)
|
||||
# x = lumina_util.unpack_latents(x, packed_latent_height, packed_latent_width)
|
||||
|
||||
# latent to image
|
||||
clean_memory_on_device(accelerator.device)
|
||||
org_vae_device = ae.device # will be on cpu
|
||||
ae.to(accelerator.device) # distributed_state.device is same as accelerator.device
|
||||
with accelerator.autocast(), torch.no_grad():
|
||||
x = ae.decode(x)
|
||||
ae.to(org_vae_device)
|
||||
org_vae_device = vae.device # will be on cpu
|
||||
vae.to(accelerator.device) # distributed_state.device is same as accelerator.device
|
||||
with accelerator.autocast():
|
||||
x = vae.decode(x)
|
||||
vae.to(org_vae_device)
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
x = x.clamp(-1, 1)
|
||||
@@ -257,9 +301,9 @@ def sample_image_inference(
|
||||
# but adding 'enum' to the filename should be enough
|
||||
|
||||
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
|
||||
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
|
||||
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{global_step:06d}"
|
||||
seed_suffix = "" if seed is None else f"_{seed}"
|
||||
i: int = prompt_dict["enum"]
|
||||
i: int = int(prompt_dict.get("enum", 0))
|
||||
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
|
||||
image.save(os.path.join(save_dir, img_filename))
|
||||
|
||||
@@ -273,11 +317,34 @@ def sample_image_inference(
|
||||
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption
|
||||
|
||||
|
||||
def time_shift(mu: float, sigma: float, t: torch.Tensor):
|
||||
def time_shift(mu: float, sigma: float, t: Tensor):
|
||||
"""
|
||||
Get time shift
|
||||
|
||||
Args:
|
||||
mu (float): mu value.
|
||||
sigma (float): sigma value.
|
||||
t (Tensor): timestep.
|
||||
|
||||
Return:
|
||||
float: time shift
|
||||
"""
|
||||
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
||||
|
||||
|
||||
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
|
||||
"""
|
||||
Get linear function
|
||||
|
||||
Args:
|
||||
x1 (float, optional): x1 value. Defaults to 256.
|
||||
y1 (float, optional): y1 value. Defaults to 0.5.
|
||||
x2 (float, optional): x2 value. Defaults to 4096.
|
||||
y2 (float, optional): y2 value. Defaults to 1.15.
|
||||
|
||||
Return:
|
||||
Callable[[float], float]: linear function
|
||||
"""
|
||||
m = (y2 - y1) / (x2 - x1)
|
||||
b = y1 - m * x1
|
||||
return lambda x: m * x + b
|
||||
@@ -290,6 +357,19 @@ def get_schedule(
|
||||
max_shift: float = 1.15,
|
||||
shift: bool = True,
|
||||
) -> list[float]:
|
||||
"""
|
||||
Get timesteps schedule
|
||||
|
||||
Args:
|
||||
num_steps (int): Number of steps in the schedule.
|
||||
image_seq_len (int): Sequence length of the image.
|
||||
base_shift (float, optional): Base shift value. Defaults to 0.5.
|
||||
max_shift (float, optional): Maximum shift value. Defaults to 1.15.
|
||||
shift (bool, optional): Whether to shift the schedule. Defaults to True.
|
||||
|
||||
Return:
|
||||
List[float]: timesteps schedule
|
||||
"""
|
||||
# extra step for zero
|
||||
timesteps = torch.linspace(1, 0, num_steps + 1)
|
||||
|
||||
@@ -301,11 +381,63 @@ def get_schedule(
|
||||
|
||||
return timesteps.tolist()
|
||||
|
||||
|
||||
def denoise(
|
||||
model: lumina_models.NextDiT, img: Tensor, txt: Tensor, txt_mask: Tensor, timesteps: List[float], guidance: float = 4.0
|
||||
):
|
||||
"""
|
||||
Denoise an image using the NextDiT model.
|
||||
|
||||
Args:
|
||||
model (lumina_models.NextDiT): The NextDiT model instance.
|
||||
img (Tensor): The input image tensor.
|
||||
txt (Tensor): The input text tensor.
|
||||
txt_mask (Tensor): The input text mask tensor.
|
||||
timesteps (List[float]): A list of timesteps for the denoising process.
|
||||
guidance (float, optional): The guidance scale for the denoising process. Defaults to 4.0.
|
||||
|
||||
Returns:
|
||||
img (Tensor): Denoised tensor
|
||||
"""
|
||||
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
|
||||
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||
# model.prepare_block_swap_before_forward()
|
||||
# block_samples = None
|
||||
# block_single_samples = None
|
||||
pred = model.forward_with_cfg(
|
||||
x=img, # image latents (B, C, H, W)
|
||||
t=t_vec / 1000, # timesteps需要除以1000来匹配模型预期
|
||||
cap_feats=txt, # Gemma2的hidden states作为caption features
|
||||
cap_mask=txt_mask.to(dtype=torch.int32), # Gemma2的attention mask
|
||||
cfg_scale=guidance,
|
||||
)
|
||||
|
||||
img = img + (t_prev - t_curr) * pred
|
||||
|
||||
# model.prepare_block_swap_before_forward()
|
||||
return img
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region train
|
||||
def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
|
||||
def get_sigmas(
|
||||
noise_scheduler: FlowMatchEulerDiscreteScheduler, timesteps: Tensor, device: torch.device, n_dim=4, dtype=torch.float32
|
||||
) -> Tensor:
|
||||
"""
|
||||
Get sigmas for timesteps
|
||||
|
||||
Args:
|
||||
noise_scheduler (FlowMatchEulerDiscreteScheduler): The noise scheduler instance.
|
||||
timesteps (Tensor): A tensor of timesteps for the denoising process.
|
||||
device (torch.device): The device on which the tensors are stored.
|
||||
n_dim (int, optional): The number of dimensions for the output tensor. Defaults to 4.
|
||||
dtype (torch.dtype, optional): The data type for the output tensor. Defaults to torch.float32.
|
||||
|
||||
Returns:
|
||||
sigmas (Tensor): The sigmas tensor.
|
||||
"""
|
||||
sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
|
||||
schedule_timesteps = noise_scheduler.timesteps.to(device)
|
||||
timesteps = timesteps.to(device)
|
||||
@@ -320,11 +452,22 @@ def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32)
|
||||
def compute_density_for_timestep_sampling(
|
||||
weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
|
||||
):
|
||||
"""Compute the density for sampling the timesteps when doing SD3 training.
|
||||
"""
|
||||
Compute the density for sampling the timesteps when doing SD3 training.
|
||||
|
||||
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
|
||||
|
||||
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
|
||||
|
||||
Args:
|
||||
weighting_scheme (str): The weighting scheme to use.
|
||||
batch_size (int): The batch size for the sampling process.
|
||||
logit_mean (float, optional): The mean of the logit distribution. Defaults to None.
|
||||
logit_std (float, optional): The standard deviation of the logit distribution. Defaults to None.
|
||||
mode_scale (float, optional): The mode scale for the mode weighting scheme. Defaults to None.
|
||||
|
||||
Returns:
|
||||
u (Tensor): The sampled timesteps.
|
||||
"""
|
||||
if weighting_scheme == "logit_normal":
|
||||
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
|
||||
@@ -338,12 +481,19 @@ def compute_density_for_timestep_sampling(
|
||||
return u
|
||||
|
||||
|
||||
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
|
||||
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None) -> Tensor:
|
||||
"""Computes loss weighting scheme for SD3 training.
|
||||
|
||||
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
|
||||
|
||||
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
|
||||
|
||||
Args:
|
||||
weighting_scheme (str): The weighting scheme to use.
|
||||
sigmas (Tensor, optional): The sigmas tensor. Defaults to None.
|
||||
|
||||
Returns:
|
||||
u (Tensor): The sampled timesteps.
|
||||
"""
|
||||
if weighting_scheme == "sigma_sqrt":
|
||||
weighting = (sigmas**-2.0).float()
|
||||
@@ -355,9 +505,24 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
|
||||
return weighting
|
||||
|
||||
|
||||
def get_noisy_model_input_and_timesteps(
|
||||
args, noise_scheduler, latents, noise, device, dtype
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
def get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
"""
|
||||
Get noisy model input and timesteps.
|
||||
|
||||
Args:
|
||||
args (argparse.Namespace): Arguments.
|
||||
noise_scheduler (noise_scheduler): Noise scheduler.
|
||||
latents (Tensor): Latents.
|
||||
noise (Tensor): Latent noise.
|
||||
device (torch.device): Device.
|
||||
dtype (torch.dtype): Data type
|
||||
|
||||
Return:
|
||||
Tuple[Tensor, Tensor, Tensor]:
|
||||
noisy model input
|
||||
timesteps
|
||||
sigmas
|
||||
"""
|
||||
bsz, _, h, w = latents.shape
|
||||
sigmas = None
|
||||
|
||||
@@ -412,7 +577,21 @@ def get_noisy_model_input_and_timesteps(
|
||||
return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas
|
||||
|
||||
|
||||
def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas):
|
||||
def apply_model_prediction_type(
|
||||
args, model_pred: Tensor, noisy_model_input: Tensor, sigmas: Tensor
|
||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||
"""
|
||||
Apply model prediction type to the model prediction and the sigmas.
|
||||
|
||||
Args:
|
||||
args (argparse.Namespace): Arguments.
|
||||
model_pred (Tensor): Model prediction.
|
||||
noisy_model_input (Tensor): Noisy model input.
|
||||
sigmas (Tensor): Sigmas.
|
||||
|
||||
Return:
|
||||
Tuple[Tensor, Optional[Tensor]]:
|
||||
"""
|
||||
weighting = None
|
||||
if args.model_prediction_type == "raw":
|
||||
pass
|
||||
@@ -433,10 +612,22 @@ def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas):
|
||||
def save_models(
|
||||
ckpt_path: str,
|
||||
lumina: lumina_models.NextDiT,
|
||||
sai_metadata: Optional[dict],
|
||||
sai_metadata: Dict[str, Any],
|
||||
save_dtype: Optional[torch.dtype] = None,
|
||||
use_mem_eff_save: bool = False,
|
||||
):
|
||||
"""
|
||||
Save the model to the checkpoint path.
|
||||
|
||||
Args:
|
||||
ckpt_path (str): Path to the checkpoint.
|
||||
lumina (lumina_models.NextDiT): NextDIT model.
|
||||
sai_metadata (Optional[dict]): Metadata for the SAI model.
|
||||
save_dtype (Optional[torch.dtype]): Data
|
||||
|
||||
Return:
|
||||
None
|
||||
"""
|
||||
state_dict = {}
|
||||
|
||||
def update_sd(prefix, sd):
|
||||
@@ -458,7 +649,9 @@ def save_lumina_model_on_train_end(
|
||||
args: argparse.Namespace, save_dtype: torch.dtype, epoch: int, global_step: int, lumina: lumina_models.NextDiT
|
||||
):
|
||||
def sd_saver(ckpt_file, epoch_no, global_step):
|
||||
sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, lumina="lumina2")
|
||||
sai_metadata = train_util.get_sai_model_spec(
|
||||
None, args, False, False, False, is_stable_diffusion_ckpt=True, lumina="lumina2"
|
||||
)
|
||||
save_models(ckpt_file, lumina, sai_metadata, save_dtype, args.mem_eff_save)
|
||||
|
||||
train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None)
|
||||
@@ -469,15 +662,29 @@ def save_lumina_model_on_train_end(
|
||||
def save_lumina_model_on_epoch_end_or_stepwise(
|
||||
args: argparse.Namespace,
|
||||
on_epoch_end: bool,
|
||||
accelerator,
|
||||
accelerator: Accelerator,
|
||||
save_dtype: torch.dtype,
|
||||
epoch: int,
|
||||
num_train_epochs: int,
|
||||
global_step: int,
|
||||
lumina: lumina_models.NextDiT,
|
||||
):
|
||||
def sd_saver(ckpt_file, epoch_no, global_step):
|
||||
sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, lumina="lumina2")
|
||||
"""
|
||||
Save the model to the checkpoint path.
|
||||
|
||||
Args:
|
||||
args (argparse.Namespace): Arguments.
|
||||
save_dtype (torch.dtype): Data type.
|
||||
epoch (int): Epoch.
|
||||
global_step (int): Global step.
|
||||
lumina (lumina_models.NextDiT): NextDIT model.
|
||||
|
||||
Return:
|
||||
None
|
||||
"""
|
||||
|
||||
def sd_saver(ckpt_file: str, epoch_no: int, global_step: int):
|
||||
sai_metadata = train_util.get_sai_model_spec({}, args, False, False, False, is_stable_diffusion_ckpt=True, lumina="lumina2")
|
||||
save_models(ckpt_file, lumina, sai_metadata, save_dtype, args.mem_eff_save)
|
||||
|
||||
train_util.save_sd_model_on_epoch_end_or_stepwise_common(
|
||||
|
||||
@@ -11,23 +11,33 @@ from safetensors.torch import load_file
|
||||
from transformers import Gemma2Config, Gemma2Model
|
||||
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from library import lumina_models, flux_models
|
||||
from library.utils import load_safetensors
|
||||
import logging
|
||||
|
||||
setup_logging()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MODEL_VERSION_LUMINA_V2 = "lumina2"
|
||||
|
||||
def load_lumina_model(
|
||||
ckpt_path: str,
|
||||
dtype: torch.dtype,
|
||||
device: Union[str, torch.device],
|
||||
device: torch.device,
|
||||
disable_mmap: bool = False,
|
||||
):
|
||||
"""
|
||||
Load the Lumina model from the checkpoint path.
|
||||
|
||||
Args:
|
||||
ckpt_path (str): Path to the checkpoint.
|
||||
dtype (torch.dtype): The data type for the model.
|
||||
device (torch.device): The device to load the model on.
|
||||
disable_mmap (bool, optional): Whether to disable mmap. Defaults to False.
|
||||
|
||||
Returns:
|
||||
model (lumina_models.NextDiT): The loaded model.
|
||||
"""
|
||||
logger.info("Building Lumina")
|
||||
with torch.device("meta"):
|
||||
model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner().to(dtype)
|
||||
@@ -46,6 +56,18 @@ def load_ae(
|
||||
device: Union[str, torch.device],
|
||||
disable_mmap: bool = False,
|
||||
) -> flux_models.AutoEncoder:
|
||||
"""
|
||||
Load the AutoEncoder model from the checkpoint path.
|
||||
|
||||
Args:
|
||||
ckpt_path (str): Path to the checkpoint.
|
||||
dtype (torch.dtype): The data type for the model.
|
||||
device (Union[str, torch.device]): The device to load the model on.
|
||||
disable_mmap (bool, optional): Whether to disable mmap. Defaults to False.
|
||||
|
||||
Returns:
|
||||
ae (flux_models.AutoEncoder): The loaded model.
|
||||
"""
|
||||
logger.info("Building AutoEncoder")
|
||||
with torch.device("meta"):
|
||||
# dev and schnell have the same AE params
|
||||
@@ -67,6 +89,19 @@ def load_gemma2(
|
||||
disable_mmap: bool = False,
|
||||
state_dict: Optional[dict] = None,
|
||||
) -> Gemma2Model:
|
||||
"""
|
||||
Load the Gemma2 model from the checkpoint path.
|
||||
|
||||
Args:
|
||||
ckpt_path (str): Path to the checkpoint.
|
||||
dtype (torch.dtype): The data type for the model.
|
||||
device (Union[str, torch.device]): The device to load the model on.
|
||||
disable_mmap (bool, optional): Whether to disable mmap. Defaults to False.
|
||||
state_dict (Optional[dict], optional): The state dict to load. Defaults to None.
|
||||
|
||||
Returns:
|
||||
gemma2 (Gemma2Model): The loaded model
|
||||
"""
|
||||
logger.info("Building Gemma2")
|
||||
GEMMA2_CONFIG = {
|
||||
"_name_or_path": "google/gemma-2-2b",
|
||||
|
||||
@@ -130,11 +130,6 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
|
||||
return False
|
||||
if "input_ids" not in npz:
|
||||
return False
|
||||
if "apply_gemma2_attn_mask" not in npz:
|
||||
return False
|
||||
npz_apply_gemma2_attn_mask = npz["apply_gemma2_attn_mask"]
|
||||
if not npz_apply_gemma2_attn_mask:
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading file: {npz_path}")
|
||||
raise e
|
||||
@@ -142,11 +137,17 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
|
||||
return True
|
||||
|
||||
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
|
||||
"""
|
||||
Load outputs from a npz file
|
||||
|
||||
Returns:
|
||||
List[np.ndarray]: hidden_state, input_ids, attention_mask
|
||||
"""
|
||||
data = np.load(npz_path)
|
||||
hidden_state = data["hidden_state"]
|
||||
attention_mask = data["attention_mask"]
|
||||
input_ids = data["input_ids"]
|
||||
return [hidden_state, attention_mask, input_ids]
|
||||
return [hidden_state, input_ids, attention_mask]
|
||||
|
||||
def cache_batch_outputs(
|
||||
self,
|
||||
@@ -193,8 +194,7 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
|
||||
info.text_encoder_outputs_npz,
|
||||
hidden_state=hidden_state_i,
|
||||
attention_mask=attention_mask_i,
|
||||
input_ids=input_ids_i,
|
||||
apply_gemma2_attn_mask=True
|
||||
input_ids=input_ids_i
|
||||
)
|
||||
else:
|
||||
info.text_encoder_outputs = [hidden_state_i, attention_mask_i, input_ids_i]
|
||||
|
||||
@@ -2,9 +2,10 @@ import argparse
|
||||
import copy
|
||||
import math
|
||||
import random
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Optional, Union, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from accelerate import Accelerator
|
||||
|
||||
from library.device_utils import clean_memory_on_device, init_ipex
|
||||
@@ -165,36 +166,31 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
||||
f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}"
|
||||
)
|
||||
|
||||
tokenize_strategy: strategy_lumina.LuminaTokenizeStrategy = (
|
||||
strategy_base.TokenizeStrategy.get_strategy()
|
||||
)
|
||||
text_encoding_strategy: strategy_lumina.LuminaTextEncodingStrategy = (
|
||||
strategy_base.TextEncodingStrategy.get_strategy()
|
||||
)
|
||||
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
|
||||
text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
|
||||
|
||||
assert isinstance(tokenize_strategy, strategy_lumina.LuminaTokenizeStrategy)
|
||||
assert isinstance(text_encoding_strategy, strategy_lumina.LuminaTextEncodingStrategy)
|
||||
|
||||
prompts = train_util.load_prompts(args.sample_prompts)
|
||||
sample_prompts = train_util.load_prompts(args.sample_prompts)
|
||||
sample_prompts_te_outputs = (
|
||||
{}
|
||||
) # key: prompt, value: text encoder outputs
|
||||
with accelerator.autocast(), torch.no_grad():
|
||||
for prompt_dict in prompts:
|
||||
for p in [
|
||||
prompt_dict.get("prompt", ""),
|
||||
prompt_dict.get("negative_prompt", ""),
|
||||
]:
|
||||
if p not in sample_prompts_te_outputs:
|
||||
logger.info(
|
||||
f"cache Text Encoder outputs for prompt: {p}"
|
||||
)
|
||||
tokens_and_masks = tokenize_strategy.tokenize(p)
|
||||
sample_prompts_te_outputs[p] = (
|
||||
text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy,
|
||||
text_encoders,
|
||||
tokens_and_masks,
|
||||
args.apply_t5_attn_mask,
|
||||
)
|
||||
)
|
||||
for prompt_dict in sample_prompts:
|
||||
prompts = [prompt_dict.get("prompt", ""),
|
||||
prompt_dict.get("negative_prompt", "")]
|
||||
logger.info(
|
||||
f"cache Text Encoder outputs for prompt: {prompts[0]}"
|
||||
)
|
||||
tokens_and_masks = tokenize_strategy.tokenize(prompts)
|
||||
sample_prompts_te_outputs[prompts[0]] = (
|
||||
text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy,
|
||||
text_encoders,
|
||||
tokens_and_masks,
|
||||
)
|
||||
)
|
||||
self.sample_prompts_te_outputs = sample_prompts_te_outputs
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
@@ -220,7 +216,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
||||
epoch,
|
||||
global_step,
|
||||
device,
|
||||
ae,
|
||||
vae,
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
lumina,
|
||||
@@ -231,7 +227,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
||||
epoch,
|
||||
global_step,
|
||||
lumina,
|
||||
ae,
|
||||
vae,
|
||||
self.get_models_for_text_encoding(args, accelerator, text_encoder),
|
||||
self.sample_prompts_te_outputs,
|
||||
)
|
||||
@@ -258,12 +254,12 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
||||
def get_noise_pred_and_target(
|
||||
self,
|
||||
args,
|
||||
accelerator,
|
||||
accelerator: Accelerator,
|
||||
noise_scheduler,
|
||||
latents,
|
||||
batch,
|
||||
text_encoder_conds,
|
||||
unet: lumina_models.NextDiT,
|
||||
text_encoder_conds: Tuple[Tensor, Tensor, Tensor], # (hidden_states, input_ids, attention_masks)
|
||||
dit: lumina_models.NextDiT,
|
||||
network,
|
||||
weight_dtype,
|
||||
train_unet,
|
||||
@@ -296,7 +292,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
||||
def call_dit(img, gemma2_hidden_states, timesteps, gemma2_attn_mask):
|
||||
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
||||
# NextDiT forward expects (x, t, cap_feats, cap_mask)
|
||||
model_pred = unet(
|
||||
model_pred = dit(
|
||||
x=img, # image latents (B, C, H, W)
|
||||
t=timesteps / 1000, # timesteps需要除以1000来匹配模型预期
|
||||
cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features
|
||||
@@ -341,7 +337,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
||||
network.set_multiplier(0.0)
|
||||
with torch.no_grad():
|
||||
model_pred_prior = call_dit(
|
||||
img=packed_noisy_model_input[diff_output_pr_indices],
|
||||
img=noisy_model_input[diff_output_pr_indices],
|
||||
gemma2_hidden_states=gemma2_hidden_states[
|
||||
diff_output_pr_indices
|
||||
],
|
||||
@@ -350,9 +346,9 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
||||
)
|
||||
network.set_multiplier(1.0)
|
||||
|
||||
model_pred_prior = lumina_util.unpack_latents(
|
||||
model_pred_prior, packed_latent_height, packed_latent_width
|
||||
)
|
||||
# model_pred_prior = lumina_util.unpack_latents(
|
||||
# model_pred_prior, packed_latent_height, packed_latent_width
|
||||
# )
|
||||
model_pred_prior, _ = flux_train_utils.apply_model_prediction_type(
|
||||
args,
|
||||
model_pred_prior,
|
||||
@@ -404,7 +400,8 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
||||
return super().prepare_unet_with_accelerator(args, accelerator, unet)
|
||||
|
||||
# if we doesn't swap blocks, we can move the model to device
|
||||
nextdit: lumina_models.Nextdit = unet
|
||||
nextdit = unet
|
||||
assert isinstance(nextdit, lumina_models.NextDiT)
|
||||
nextdit = accelerator.prepare(
|
||||
nextdit, device_placement=[not self.is_swapping_blocks]
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user