Kohya-ss-sd-scripts/library/anima_models.py
duongve13112002 e21a7736f8 Support Anima model (#2260)
* Support Anima model

* Update document and fix bug

* Fix latent normalization

* Fix typo

* Fix cache embedding

* fix typo in tests/test_anima_cache.py

* Remove redundant argument apply_t5_attn_mask

* Improving caching with argument caption_dropout_rate

* Fix W&B logging bugs

* Fix discrete_flow_shift default value
2026-02-08 10:18:55 +09:00


# Anima Model Architecture
# Original code: NVIDIA CORPORATION & AFFILIATES, licensed under Apache-2.0
import math
from typing import Any, Callable, List, Optional, Tuple, Union
import numpy as np
import torch
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from torch import nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint as torch_checkpoint
from library import custom_offloading_utils
from library.device_utils import clean_memory_on_device
def to_device(x, device):
if isinstance(x, torch.Tensor):
return x.to(device)
elif isinstance(x, (list, tuple)):
return type(x)(to_device(elem, device) for elem in x)
elif isinstance(x, dict):
return {k: to_device(v, device) for k, v in x.items()}
else:
return x
def to_cpu(x):
if isinstance(x, torch.Tensor):
return x.cpu()
elif isinstance(x, (list, tuple)):
return [to_cpu(elem) for elem in x]
elif isinstance(x, dict):
return {k: to_cpu(v) for k, v in x.items()}
else:
return x
# Unsloth Offloaded Gradient Checkpointing
# Based on Unsloth Zoo by Daniel Han-Chen & the Unsloth team
try:
from deepspeed.runtime.activation_checkpointing.checkpointing import detach_variable
except ImportError:
def detach_variable(inputs, device=None):
"""Detach tensors from computation graph, optionally moving to a device.
Reimplementation of deepspeed.runtime.activation_checkpointing.checkpointing.detach_variable
for environments without DeepSpeed installed.
"""
if isinstance(inputs, tuple):
out = []
for inp in inputs:
if not isinstance(inp, torch.Tensor):
out.append(inp)
continue
requires_grad = inp.requires_grad
if device is not None:
x = inp.to(device=device)
else:
x = inp
x = x.detach()
x.requires_grad = requires_grad
out.append(x)
return tuple(out)
else:
raise RuntimeError(f"Only a tuple of tensors is supported. Got unsupported input type: {type(inputs).__name__}")
class UnslothOffloadedGradientCheckpointer(torch.autograd.Function):
"""Saves VRAM by offloading activations to CPU RAM using non-blocking transfers.
Compared to standard cpu_offload_checkpointing which uses blocking transfers,
this uses non_blocking=True to hide CPU<->GPU transfer latency behind compute.
"""
@staticmethod
@torch.amp.custom_fwd(device_type='cuda')
def forward(ctx, forward_function, hidden_states, *args):
# Remember the original device for backward pass (multi-GPU support)
ctx.input_device = hidden_states.device
saved_hidden_states = hidden_states.to('cpu', non_blocking=True)
with torch.no_grad():
output = forward_function(hidden_states, *args)
ctx.save_for_backward(saved_hidden_states)
ctx.forward_function = forward_function
# NOTE: args stored directly on ctx (not via save_for_backward) because
# the training loop holds references to these tensors, preventing GC.
# Using save_for_backward for all args would add complexity for no benefit.
ctx.args = args
return output
@staticmethod
@torch.amp.custom_bwd(device_type='cuda')
def backward(ctx, *grads):
(hidden_states,) = ctx.saved_tensors
hidden_states = hidden_states.to(ctx.input_device, non_blocking=True).detach()
hidden_states.requires_grad_(True)
args = detach_variable(ctx.args)
inputs = (hidden_states,) + args
with torch.enable_grad():
outputs = ctx.forward_function(*inputs)
output_tensors = []
grad_tensors = []
for out, grad in zip(outputs if isinstance(outputs, tuple) else (outputs,),
grads if isinstance(grads, tuple) else (grads,)):
if isinstance(out, torch.Tensor) and out.requires_grad:
output_tensors.append(out)
grad_tensors.append(grad)
torch.autograd.backward(output_tensors, grad_tensors)
return (None,) + tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in inputs)
@torch._disable_dynamo
def unsloth_checkpoint(function, *args):
"""Wrapper for UnslothOffloadedGradientCheckpointer."""
return UnslothOffloadedGradientCheckpointer.apply(function, *args)
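# Minimal usage sketch (illustrative only): wrap a sub-module's forward with the offloaded
# checkpointer. The activation is copied to CPU RAM during forward, then copied back and
# recomputed during backward.
def _example_unsloth_checkpoint():
    layer = nn.Linear(8, 8)
    x = torch.randn(2, 8, requires_grad=True)
    y = unsloth_checkpoint(layer.forward, x)
    y.sum().backward()
    return x.grad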
# Flash Attention support
try:
from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func
FLASH_ATTN_AVAILABLE = True
except ImportError:
_flash_attn_func = None
FLASH_ATTN_AVAILABLE = False
def flash_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor:
"""Computes multi-head attention using Flash Attention.
Input format: (batch, seq_len, n_heads, head_dim)
Output format: (batch, seq_len, n_heads * head_dim) — matches torch_attention_op output.
"""
# flash_attn_func expects (B, S, H, D) and returns (B, S, H, D)
out = _flash_attn_func(q_B_S_H_D, k_B_S_H_D, v_B_S_H_D)
return rearrange(out, "b s h d -> b s (h d)")
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# Utility functions: RoPE for DiT
def _rotate_half(x: torch.Tensor, interleaved: bool) -> torch.Tensor:
if not interleaved:
x1, x2 = torch.chunk(x, 2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
x1 = x[:, :, :, ::2]
x2 = x[:, :, :, 1::2]
x_new = torch.stack((-x2, x1), dim=-1)
return x_new.view(x_new.shape[0], x_new.shape[1], x_new.shape[2], -1)
def _apply_rotary_pos_emb_base(
t: torch.Tensor,
freqs: torch.Tensor,
start_positions: torch.Tensor = None,
tensor_format: str = "sbhd",
interleaved: bool = False,
) -> torch.Tensor:
max_seq_len = freqs.shape[0]
cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0]
if start_positions is not None:
max_offset = torch.max(start_positions)
assert (
max_offset + cur_seq_len <= max_seq_len
), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
freqs = torch.concatenate([freqs[i : i + cur_seq_len] for i in start_positions], dim=1)
assert (
cur_seq_len <= max_seq_len
), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
freqs = freqs[:cur_seq_len]
if tensor_format == "bshd":
freqs = freqs.transpose(0, 1)
cos_ = torch.cos(freqs).to(t.dtype)
sin_ = torch.sin(freqs).to(t.dtype)
rot_dim = freqs.shape[-1]
t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
t = (t * cos_) + (_rotate_half(t, interleaved) * sin_)
return torch.cat((t, t_pass), dim=-1)
def apply_rotary_pos_emb(
t: torch.Tensor,
freqs: torch.Tensor,
tensor_format: str = "sbhd",
start_positions: Union[torch.Tensor, None] = None,
interleaved: bool = False,
fused: bool = False,
cu_seqlens: Union[torch.Tensor, None] = None,
cp_size: int = 1,
) -> torch.Tensor:
assert not (
cp_size > 1 and start_positions is not None
), "start_positions != None with CP SIZE > 1 is not supported!"
assert (
tensor_format != "thd" or cu_seqlens is not None
), "cu_seqlens must not be None when tensor_format is 'thd'."
assert fused == False
if tensor_format == "thd":
cu_seqlens = cu_seqlens // cp_size
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
return torch.cat(
[
_apply_rotary_pos_emb_base(
x.unsqueeze(1),
freqs,
start_positions=(
start_positions[idx : idx + 1] if start_positions is not None else None
),
interleaved=interleaved,
)
for idx, x in enumerate(torch.split(t, seqlens))
]
).squeeze(1)
if tensor_format == "sbhd":
seqlen = t.size(0)
elif tensor_format == "bshd":
seqlen = t.size(1)
else:
raise ValueError(f"Unsupported tensor_format: {tensor_format}.")
return _apply_rotary_pos_emb_base(
t,
freqs,
start_positions,
tensor_format,
interleaved=interleaved,
)
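# Minimal sketch (illustrative shapes): apply RoPE frequencies to a query tensor in the
# "bshd" layout used by Attention. The frequency tensor has the (L, 1, 1, head_dim) shape
# produced by VideoRopePosition3DEmb.
def _example_apply_rotary_pos_emb():
    B, S, H, D = 2, 64, 16, 128      # batch, sequence, heads, head_dim (illustrative)
    q = torch.randn(B, S, H, D)
    rope_emb = torch.randn(S, 1, 1, D)
    q_rot = apply_rotary_pos_emb(q, rope_emb, tensor_format="bshd", fused=False)
    assert q_rot.shape == q.shape
    return q_rot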
# Basic building blocks
class RMSNorm(torch.nn.Module):
"""RMS Normalization for DiT blocks."""
def __init__(self, dim: int, eps: float = 1e-5) -> None:
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def reset_parameters(self) -> None:
torch.nn.init.ones_(self.weight)
def _norm(self, x: torch.Tensor) -> torch.Tensor:
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
@torch.amp.autocast(device_type='cuda', dtype=torch.float32)
def forward(self, x: torch.Tensor) -> torch.Tensor:
output = self._norm(x.float()).type_as(x)
return output * self.weight
class GPT2FeedForward(nn.Module):
"""GELU feedforward network."""
def __init__(self, d_model: int, d_ff: int) -> None:
super().__init__()
self.activation = nn.GELU()
self.layer1 = nn.Linear(d_model, d_ff, bias=False)
self.layer2 = nn.Linear(d_ff, d_model, bias=False)
self._layer_id = None
self._dim = d_model
self._hidden_dim = d_ff
self.init_weights()
def init_weights(self) -> None:
std = 1.0 / math.sqrt(self._dim)
torch.nn.init.trunc_normal_(self.layer1.weight, std=std, a=-3 * std, b=3 * std)
std = 1.0 / math.sqrt(self._hidden_dim)
if self._layer_id is not None:
std = std / math.sqrt(2 * (self._layer_id + 1))
torch.nn.init.trunc_normal_(self.layer2.weight, std=std, a=-3 * std, b=3 * std)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.layer1(x)
x = self.activation(x)
x = self.layer2(x)
return x
def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor:
"""Computes multi-head attention using PyTorch's native scaled_dot_product_attention.
Input/output format: (batch, seq_len, n_heads, head_dim)
"""
in_q_shape = q_B_S_H_D.shape
in_k_shape = k_B_S_H_D.shape
q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1])
k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
result_B_S_HD = rearrange(
F.scaled_dot_product_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D), "b h ... l -> b ... (h l)"
)
return result_B_S_HD
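# Minimal sketch of the shape contract shared by torch_attention_op and flash_attention_op:
# (B, S, H, D) in for q/k/v, (B, S, H*D) out.
def _example_attention_op_shapes():
    B, S, H, D = 1, 16, 8, 64        # illustrative sizes
    q, k, v = (torch.randn(B, S, H, D) for _ in range(3))
    out = torch_attention_op(q, k, v)
    assert out.shape == (B, S, H * D)
    return out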
# Attention module for DiT
class Attention(nn.Module):
"""Multi-head attention supporting both self-attention and cross-attention.
Uses QK-norm (RMSNorm on q/k) and optional RoPE (only for self-attention).
"""
def __init__(
self,
query_dim: int,
context_dim: Optional[int] = None,
n_heads: int = 8,
head_dim: int = 64,
dropout: float = 0.0,
qkv_format: str = "bshd",
) -> None:
super().__init__()
self.is_selfattn = context_dim is None
context_dim = query_dim if context_dim is None else context_dim
inner_dim = head_dim * n_heads
self.n_heads = n_heads
self.head_dim = head_dim
self.qkv_format = qkv_format
self.query_dim = query_dim
self.context_dim = context_dim
self.q_proj = nn.Linear(query_dim, inner_dim, bias=False)
self.q_norm = RMSNorm(self.head_dim, eps=1e-6)
self.k_proj = nn.Linear(context_dim, inner_dim, bias=False)
self.k_norm = RMSNorm(self.head_dim, eps=1e-6)
self.v_proj = nn.Linear(context_dim, inner_dim, bias=False)
self.v_norm = nn.Identity()
self.output_proj = nn.Linear(inner_dim, query_dim, bias=False)
self.output_dropout = nn.Dropout(dropout) if dropout > 1e-4 else nn.Identity()
self.attn_op = torch_attention_op
self._query_dim = query_dim
self._context_dim = context_dim
self._inner_dim = inner_dim
self.init_weights()
def init_weights(self) -> None:
std = 1.0 / math.sqrt(self._query_dim)
torch.nn.init.trunc_normal_(self.q_proj.weight, std=std, a=-3 * std, b=3 * std)
std = 1.0 / math.sqrt(self._context_dim)
torch.nn.init.trunc_normal_(self.k_proj.weight, std=std, a=-3 * std, b=3 * std)
torch.nn.init.trunc_normal_(self.v_proj.weight, std=std, a=-3 * std, b=3 * std)
std = 1.0 / math.sqrt(self._inner_dim)
torch.nn.init.trunc_normal_(self.output_proj.weight, std=std, a=-3 * std, b=3 * std)
for layer in self.q_norm, self.k_norm, self.v_norm:
if hasattr(layer, "reset_parameters"):
layer.reset_parameters()
def compute_qkv(
self,
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
rope_emb: Optional[torch.Tensor] = None,
) -> tuple:
q = self.q_proj(x)
context = x if context is None else context
k = self.k_proj(context)
v = self.v_proj(context)
q, k, v = map(
lambda t: rearrange(t, "b ... (h d) -> b ... h d", h=self.n_heads, d=self.head_dim),
(q, k, v),
)
q = self.q_norm(q)
k = self.k_norm(k)
v = self.v_norm(v)
if self.is_selfattn and rope_emb is not None:
q = apply_rotary_pos_emb(q, rope_emb, tensor_format=self.qkv_format, fused=False)
k = apply_rotary_pos_emb(k, rope_emb, tensor_format=self.qkv_format, fused=False)
return q, k, v
def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
result = self.attn_op(q, k, v) # [B, S, H, D]
return self.output_dropout(self.output_proj(result))
def forward(
self,
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
rope_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb)
return self.compute_attention(q, k, v)
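# Minimal usage sketch (illustrative dimensions): context_dim=None gives self-attention
# (where RoPE would be applied); passing a context tensor gives cross-attention.
def _example_attention_usage():
    x = torch.randn(2, 64, 512)                      # (B, L, query_dim)
    ctx = torch.randn(2, 77, 1024)                   # (B, N, context_dim), e.g. text embeddings
    self_attn = Attention(query_dim=512, context_dim=None, n_heads=8, head_dim=64)
    cross_attn = Attention(query_dim=512, context_dim=1024, n_heads=8, head_dim=64)
    y = self_attn(x)
    y = cross_attn(y, context=ctx)
    assert y.shape == x.shape
    return y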
# Positional Embeddings
class VideoPositionEmb(nn.Module):
def __init__(self) -> None:
super().__init__()
@property
def seq_dim(self) -> int:
return 1
def forward(self, x_B_T_H_W_C: torch.Tensor, fps: Optional[torch.Tensor]) -> torch.Tensor:
B_T_H_W_C = x_B_T_H_W_C.shape
embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps)
return embeddings
def generate_embeddings(self, B_T_H_W_C: torch.Size, fps: Optional[torch.Tensor]) -> Any:
raise NotImplementedError
class VideoRopePosition3DEmb(VideoPositionEmb):
"""3D Rotary Position Embedding for video (T, H, W) dimensions."""
def __init__(
self,
*,
head_dim: int,
len_h: int,
len_w: int,
len_t: int,
base_fps: int = 24,
h_extrapolation_ratio: float = 1.0,
w_extrapolation_ratio: float = 1.0,
t_extrapolation_ratio: float = 1.0,
enable_fps_modulation: bool = True,
**kwargs,
):
del kwargs
super().__init__()
self.register_buffer("seq", torch.arange(max(len_h, len_w, len_t), dtype=torch.float))
self.base_fps = base_fps
self.max_h = len_h
self.max_w = len_w
self.max_t = len_t
self.enable_fps_modulation = enable_fps_modulation
dim = head_dim
dim_h = dim // 6 * 2
dim_w = dim_h
dim_t = dim - 2 * dim_h
assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}"
self.register_buffer(
"dim_spatial_range",
torch.arange(0, dim_h, 2)[: (dim_h // 2)].float() / dim_h,
persistent=True,
)
self.register_buffer(
"dim_temporal_range",
torch.arange(0, dim_t, 2)[: (dim_t // 2)].float() / dim_t,
persistent=True,
)
self._dim_h = dim_h
self._dim_t = dim_t
self.h_ntk_factor = h_extrapolation_ratio ** (dim_h / (dim_h - 2))
self.w_ntk_factor = w_extrapolation_ratio ** (dim_w / (dim_w - 2))
self.t_ntk_factor = t_extrapolation_ratio ** (dim_t / (dim_t - 2))
self.reset_parameters()
def reset_parameters(self) -> None:
dim_h = self._dim_h
dim_t = self._dim_t
self.seq = torch.arange(max(self.max_h, self.max_w, self.max_t)).float().to(self.dim_spatial_range.device)
self.dim_spatial_range = (
torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().to(self.dim_spatial_range.device) / dim_h
)
self.dim_temporal_range = (
torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().to(self.dim_spatial_range.device) / dim_t
)
def generate_embeddings(
self,
B_T_H_W_C: torch.Size,
fps: Optional[torch.Tensor] = None,
h_ntk_factor: Optional[float] = None,
w_ntk_factor: Optional[float] = None,
t_ntk_factor: Optional[float] = None,
) -> torch.Tensor:
h_ntk_factor = h_ntk_factor if h_ntk_factor is not None else self.h_ntk_factor
w_ntk_factor = w_ntk_factor if w_ntk_factor is not None else self.w_ntk_factor
t_ntk_factor = t_ntk_factor if t_ntk_factor is not None else self.t_ntk_factor
h_theta = 10000.0 * h_ntk_factor
w_theta = 10000.0 * w_ntk_factor
t_theta = 10000.0 * t_ntk_factor
h_spatial_freqs = 1.0 / (h_theta**self.dim_spatial_range)
w_spatial_freqs = 1.0 / (w_theta**self.dim_spatial_range)
temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range)
B, T, H, W, _ = B_T_H_W_C
assert (
H <= self.max_h and W <= self.max_w
), f"Input dimensions (H={H}, W={W}) exceed the maximum dimensions (max_h={self.max_h}, max_w={self.max_w})"
half_emb_h = torch.outer(self.seq[:H], h_spatial_freqs)
half_emb_w = torch.outer(self.seq[:W], w_spatial_freqs)
if self.enable_fps_modulation:
uniform_fps = (fps is None) or (fps.min() == fps.max())
assert (
uniform_fps or B == 1 or T == 1
), "For video batch, batch size should be 1 for non-uniform fps. For image batch, T should be 1"
if fps is None:
assert T == 1, "T should be 1 for image batch."
half_emb_t = torch.outer(self.seq[:T], temporal_freqs)
else:
half_emb_t = torch.outer(self.seq[:T] / fps[:1] * self.base_fps, temporal_freqs)
else:
half_emb_t = torch.outer(self.seq[:T], temporal_freqs)
em_T_H_W_D = torch.cat(
[
repeat(half_emb_t, "t d -> t h w d", h=H, w=W),
repeat(half_emb_h, "h d -> t h w d", t=T, w=W),
repeat(half_emb_w, "w d -> t h w d", t=T, h=H),
]
* 2,
dim=-1,
)
return rearrange(em_T_H_W_D, "t h w d -> (t h w) 1 1 d").float()
@property
def seq_dim(self) -> int:
return 0
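# Minimal sketch (illustrative sizes): build 3D RoPE frequencies for a patchified
# (T, H, W) latent grid. Only the input shape is read; the values are ignored.
def _example_rope3d_embeddings():
    pos_emb = VideoRopePosition3DEmb(head_dim=128, len_h=64, len_w=64, len_t=16, enable_fps_modulation=False)
    dummy_x = torch.zeros(1, 4, 32, 32, 128)         # (B, T, H, W, C)
    rope = pos_emb(dummy_x, fps=None)
    assert rope.shape == (4 * 32 * 32, 1, 1, 128)    # one frequency row per (t, h, w) position
    return rope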
class LearnablePosEmbAxis(VideoPositionEmb):
"""Learnable axis-decomposed positional embeddings."""
def __init__(
self,
*,
interpolation: str,
model_channels: int,
len_h: int,
len_w: int,
len_t: int,
**kwargs,
):
del kwargs
super().__init__()
self.interpolation = interpolation
assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}"
self.model_channels = model_channels
self.pos_emb_h = nn.Parameter(torch.zeros(len_h, model_channels))
self.pos_emb_w = nn.Parameter(torch.zeros(len_w, model_channels))
self.pos_emb_t = nn.Parameter(torch.zeros(len_t, model_channels))
self.reset_parameters()
def reset_parameters(self) -> None:
std = 1.0 / math.sqrt(self.model_channels)
torch.nn.init.trunc_normal_(self.pos_emb_h, std=std, a=-3 * std, b=3 * std)
torch.nn.init.trunc_normal_(self.pos_emb_w, std=std, a=-3 * std, b=3 * std)
torch.nn.init.trunc_normal_(self.pos_emb_t, std=std, a=-3 * std, b=3 * std)
def generate_embeddings(self, B_T_H_W_C: torch.Size, fps: Optional[torch.Tensor]) -> torch.Tensor:
B, T, H, W, _ = B_T_H_W_C
if self.interpolation == "crop":
emb_h_H = self.pos_emb_h[:H]
emb_w_W = self.pos_emb_w[:W]
emb_t_T = self.pos_emb_t[:T]
emb = (
repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W)
+ repeat(emb_h_H, "h d-> b t h w d", b=B, t=T, w=W)
+ repeat(emb_w_W, "w d-> b t h w d", b=B, t=T, h=H)
)
assert list(emb.shape)[:4] == [B, T, H, W], f"bad shape: {list(emb.shape)[:4]} != {B, T, H, W}"
else:
raise ValueError(f"Unknown interpolation method {self.interpolation}")
norm = torch.linalg.vector_norm(emb, dim=-1, keepdim=True, dtype=torch.float32)
norm = torch.add(1e-6, norm, alpha=np.sqrt(norm.numel() / emb.numel()))
return emb / norm.to(emb.dtype)
# Timestep Embedding
class Timesteps(nn.Module):
"""Sinusoidal timestep features."""
def __init__(self, num_channels: int):
super().__init__()
self.num_channels = num_channels
def forward(self, timesteps_B_T: torch.Tensor) -> torch.Tensor:
assert timesteps_B_T.ndim == 2, f"Expected 2D input, got {timesteps_B_T.ndim}"
in_dtype = timesteps_B_T.dtype
timesteps = timesteps_B_T.flatten().float()
half_dim = self.num_channels // 2
exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
exponent = exponent / (half_dim - 0.0)
emb = torch.exp(exponent)
emb = timesteps[:, None].float() * emb[None, :]
sin_emb = torch.sin(emb)
cos_emb = torch.cos(emb)
emb = torch.cat([cos_emb, sin_emb], dim=-1)
return rearrange(emb.to(dtype=in_dtype), "(b t) d -> b t d", b=timesteps_B_T.shape[0], t=timesteps_B_T.shape[1])
class TimestepEmbedding(nn.Module):
"""Projects timestep features to model dimension, with optional AdaLN-LoRA."""
def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False):
super().__init__()
self.in_dim = in_features
self.out_dim = out_features
self.linear_1 = nn.Linear(in_features, out_features, bias=not use_adaln_lora)
self.activation = nn.SiLU()
self.use_adaln_lora = use_adaln_lora
if use_adaln_lora:
self.linear_2 = nn.Linear(out_features, 3 * out_features, bias=False)
else:
self.linear_2 = nn.Linear(out_features, out_features, bias=False)
self.init_weights()
def init_weights(self) -> None:
std = 1.0 / math.sqrt(self.in_dim)
torch.nn.init.trunc_normal_(self.linear_1.weight, std=std, a=-3 * std, b=3 * std)
std = 1.0 / math.sqrt(self.out_dim)
torch.nn.init.trunc_normal_(self.linear_2.weight, std=std, a=-3 * std, b=3 * std)
def forward(self, sample: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
emb = self.linear_1(sample)
emb = self.activation(emb)
emb = self.linear_2(emb)
if self.use_adaln_lora:
adaln_lora_B_T_3D = emb
emb_B_T_D = sample
else:
adaln_lora_B_T_3D = None
emb_B_T_D = emb
return emb_B_T_D, adaln_lora_B_T_3D
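# Minimal sketch (model_channels=256 is illustrative): the t_embedder pipeline used by
# MiniTrainDIT, i.e. sinusoidal features followed by the projection MLP, with AdaLN-LoRA
# enabled. Note that with use_adaln_lora=True the first return value is the unprojected
# sinusoidal feature and the second is the 3*D AdaLN-LoRA term.
def _example_timestep_embedding():
    t_embedder = nn.Sequential(
        Timesteps(256),
        TimestepEmbedding(256, 256, use_adaln_lora=True),
    )
    timesteps = torch.rand(2, 1)                     # (B, T)
    emb, adaln_lora = t_embedder(timesteps)
    assert emb.shape == (2, 1, 256) and adaln_lora.shape == (2, 1, 3 * 256)
    return emb, adaln_lora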
class FourierFeatures(nn.Module):
"""Fourier feature transform: [B] -> [B, D]."""
def __init__(self, num_channels: int, bandwidth: int = 1, normalize: bool = False):
super().__init__()
self.register_buffer("freqs", 2 * np.pi * bandwidth * torch.randn(num_channels), persistent=True)
self.register_buffer("phases", 2 * np.pi * torch.rand(num_channels), persistent=True)
self.gain = np.sqrt(2) if normalize else 1
self.bandwidth = bandwidth
self.num_channels = num_channels
self.reset_parameters()
def reset_parameters(self) -> None:
generator = torch.Generator()
generator.manual_seed(0)
self.freqs = (
2 * np.pi * self.bandwidth * torch.randn(self.num_channels, generator=generator).to(self.freqs.device)
)
self.phases = 2 * np.pi * torch.rand(self.num_channels, generator=generator).to(self.freqs.device)
def forward(self, x: torch.Tensor, gain: float = 1.0) -> torch.Tensor:
in_dtype = x.dtype
x = x.to(torch.float32).ger(self.freqs.to(torch.float32)).add(self.phases.to(torch.float32))
x = x.cos().mul(self.gain * gain).to(in_dtype)
return x
# Patch Embedding
class PatchEmbed(nn.Module):
"""Patch embedding: (B, C, T, H, W) -> (B, T', H', W', D)"""
def __init__(
self,
spatial_patch_size: int,
temporal_patch_size: int,
in_channels: int = 3,
out_channels: int = 768,
):
super().__init__()
self.spatial_patch_size = spatial_patch_size
self.temporal_patch_size = temporal_patch_size
self.proj = nn.Sequential(
Rearrange(
"b c (t r) (h m) (w n) -> b t h w (c r m n)",
r=temporal_patch_size,
m=spatial_patch_size,
n=spatial_patch_size,
),
nn.Linear(
in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=False
),
)
self.dim = in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size
self.init_weights()
def init_weights(self) -> None:
std = 1.0 / math.sqrt(self.dim)
torch.nn.init.trunc_normal_(self.proj[1].weight, std=std, a=-3 * std, b=3 * std)
def forward(self, x: torch.Tensor) -> torch.Tensor:
assert x.dim() == 5
_, _, T, H, W = x.shape
assert (
H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0
), f"H,W {(H, W)} should be divisible by spatial_patch_size {self.spatial_patch_size}"
assert T % self.temporal_patch_size == 0
x = self.proj(x)
return x
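# Minimal sketch of the patchify shape contract (illustrative sizes): 17 input channels
# correspond to 16 latent channels plus the concatenated padding-mask channel.
def _example_patch_embed():
    embed = PatchEmbed(spatial_patch_size=2, temporal_patch_size=1, in_channels=17, out_channels=256)
    x = torch.randn(1, 17, 1, 64, 64)                # (B, C, T, H, W)
    y = embed(x)
    assert y.shape == (1, 1, 32, 32, 256)            # (B, T', H', W', D)
    return y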
# Final Layer
class FinalLayer(nn.Module):
"""Final layer with AdaLN modulation + unpatchify."""
def __init__(
self,
hidden_size: int,
spatial_patch_size: int,
temporal_patch_size: int,
out_channels: int,
use_adaln_lora: bool = False,
adaln_lora_dim: int = 256,
):
super().__init__()
self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(
hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False
)
self.hidden_size = hidden_size
self.n_adaln_chunks = 2
self.use_adaln_lora = use_adaln_lora
self.adaln_lora_dim = adaln_lora_dim
if use_adaln_lora:
self.adaln_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, adaln_lora_dim, bias=False),
nn.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False),
)
else:
self.adaln_modulation = nn.Sequential(
nn.SiLU(), nn.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False)
)
self.init_weights()
def init_weights(self) -> None:
std = 1.0 / math.sqrt(self.hidden_size)
torch.nn.init.trunc_normal_(self.linear.weight, std=std, a=-3 * std, b=3 * std)
if self.use_adaln_lora:
torch.nn.init.trunc_normal_(self.adaln_modulation[1].weight, std=std, a=-3 * std, b=3 * std)
torch.nn.init.zeros_(self.adaln_modulation[2].weight)
else:
torch.nn.init.zeros_(self.adaln_modulation[1].weight)
self.layer_norm.reset_parameters()
def forward(
self,
x_B_T_H_W_D: torch.Tensor,
emb_B_T_D: torch.Tensor,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
):
if self.use_adaln_lora:
assert adaln_lora_B_T_3D is not None
shift_B_T_D, scale_B_T_D = (
self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size]
).chunk(2, dim=-1)
else:
shift_B_T_D, scale_B_T_D = self.adaln_modulation(emb_B_T_D).chunk(2, dim=-1)
shift_B_T_1_1_D = rearrange(shift_B_T_D, "b t d -> b t 1 1 d")
scale_B_T_1_1_D = rearrange(scale_B_T_D, "b t d -> b t 1 1 d")
x_B_T_H_W_D = self.layer_norm(x_B_T_H_W_D) * (1 + scale_B_T_1_1_D) + shift_B_T_1_1_D
x_B_T_H_W_O = self.linear(x_B_T_H_W_D)
return x_B_T_H_W_O
# Transformer Block (DiT Block)
class Block(nn.Module):
"""Transformer block with self-attention + cross-attention + MLP, each modulated by AdaLN.
Each sublayer: x = x + gate * sublayer(norm(x) * (1 + scale) + shift)
"""
def __init__(
self,
x_dim: int,
context_dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
use_adaln_lora: bool = False,
adaln_lora_dim: int = 256,
):
super().__init__()
self.x_dim = x_dim
self.layer_norm_self_attn = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6)
self.self_attn = Attention(
x_dim,
None,
num_heads,
x_dim // num_heads,
qkv_format="bshd",
)
self.layer_norm_cross_attn = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6)
self.cross_attn = Attention(
x_dim, context_dim, num_heads, x_dim // num_heads, qkv_format="bshd",
)
self.layer_norm_mlp = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6)
self.mlp = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio))
self.use_adaln_lora = use_adaln_lora
if self.use_adaln_lora:
self.adaln_modulation_self_attn = nn.Sequential(
nn.SiLU(),
nn.Linear(x_dim, adaln_lora_dim, bias=False),
nn.Linear(adaln_lora_dim, 3 * x_dim, bias=False),
)
self.adaln_modulation_cross_attn = nn.Sequential(
nn.SiLU(),
nn.Linear(x_dim, adaln_lora_dim, bias=False),
nn.Linear(adaln_lora_dim, 3 * x_dim, bias=False),
)
self.adaln_modulation_mlp = nn.Sequential(
nn.SiLU(),
nn.Linear(x_dim, adaln_lora_dim, bias=False),
nn.Linear(adaln_lora_dim, 3 * x_dim, bias=False),
)
else:
self.adaln_modulation_self_attn = nn.Sequential(nn.SiLU(), nn.Linear(x_dim, 3 * x_dim, bias=False))
self.adaln_modulation_cross_attn = nn.Sequential(nn.SiLU(), nn.Linear(x_dim, 3 * x_dim, bias=False))
self.adaln_modulation_mlp = nn.Sequential(nn.SiLU(), nn.Linear(x_dim, 3 * x_dim, bias=False))
self.gradient_checkpointing = False
self.cpu_offload_checkpointing = False
self.unsloth_offload_checkpointing = False
def enable_gradient_checkpointing(self, cpu_offload: bool = False, unsloth_offload: bool = False):
self.gradient_checkpointing = True
self.cpu_offload_checkpointing = cpu_offload if not unsloth_offload else False
self.unsloth_offload_checkpointing = unsloth_offload
def disable_gradient_checkpointing(self):
self.gradient_checkpointing = False
self.cpu_offload_checkpointing = False
self.unsloth_offload_checkpointing = False
def reset_parameters(self) -> None:
self.layer_norm_self_attn.reset_parameters()
self.layer_norm_cross_attn.reset_parameters()
self.layer_norm_mlp.reset_parameters()
if self.use_adaln_lora:
std = 1.0 / math.sqrt(self.x_dim)
torch.nn.init.trunc_normal_(self.adaln_modulation_self_attn[1].weight, std=std, a=-3 * std, b=3 * std)
torch.nn.init.trunc_normal_(self.adaln_modulation_cross_attn[1].weight, std=std, a=-3 * std, b=3 * std)
torch.nn.init.trunc_normal_(self.adaln_modulation_mlp[1].weight, std=std, a=-3 * std, b=3 * std)
torch.nn.init.zeros_(self.adaln_modulation_self_attn[2].weight)
torch.nn.init.zeros_(self.adaln_modulation_cross_attn[2].weight)
torch.nn.init.zeros_(self.adaln_modulation_mlp[2].weight)
else:
torch.nn.init.zeros_(self.adaln_modulation_self_attn[1].weight)
torch.nn.init.zeros_(self.adaln_modulation_cross_attn[1].weight)
torch.nn.init.zeros_(self.adaln_modulation_mlp[1].weight)
def init_weights(self) -> None:
self.reset_parameters()
self.self_attn.init_weights()
self.cross_attn.init_weights()
self.mlp.init_weights()
def _forward(
self,
x_B_T_H_W_D: torch.Tensor,
emb_B_T_D: torch.Tensor,
crossattn_emb: torch.Tensor,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if extra_per_block_pos_emb is not None:
x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb
# Compute AdaLN modulation parameters
if self.use_adaln_lora:
shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = (
self.adaln_modulation_self_attn(emb_B_T_D) + adaln_lora_B_T_3D
).chunk(3, dim=-1)
shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = (
self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D
).chunk(3, dim=-1)
shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = (
self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D
).chunk(3, dim=-1)
else:
shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn(
emb_B_T_D
).chunk(3, dim=-1)
shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = self.adaln_modulation_cross_attn(
emb_B_T_D
).chunk(3, dim=-1)
shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = self.adaln_modulation_mlp(emb_B_T_D).chunk(3, dim=-1)
# Reshape for broadcasting: (B, T, D) -> (B, T, 1, 1, D)
shift_self_attn_B_T_1_1_D = rearrange(shift_self_attn_B_T_D, "b t d -> b t 1 1 d")
scale_self_attn_B_T_1_1_D = rearrange(scale_self_attn_B_T_D, "b t d -> b t 1 1 d")
gate_self_attn_B_T_1_1_D = rearrange(gate_self_attn_B_T_D, "b t d -> b t 1 1 d")
shift_cross_attn_B_T_1_1_D = rearrange(shift_cross_attn_B_T_D, "b t d -> b t 1 1 d")
scale_cross_attn_B_T_1_1_D = rearrange(scale_cross_attn_B_T_D, "b t d -> b t 1 1 d")
gate_cross_attn_B_T_1_1_D = rearrange(gate_cross_attn_B_T_D, "b t d -> b t 1 1 d")
shift_mlp_B_T_1_1_D = rearrange(shift_mlp_B_T_D, "b t d -> b t 1 1 d")
scale_mlp_B_T_1_1_D = rearrange(scale_mlp_B_T_D, "b t d -> b t 1 1 d")
gate_mlp_B_T_1_1_D = rearrange(gate_mlp_B_T_D, "b t d -> b t 1 1 d")
B, T, H, W, D = x_B_T_H_W_D.shape
def _adaln_fn(_x, _norm_layer, _scale, _shift):
return _norm_layer(_x) * (1 + _scale) + _shift
# 1. Self-attention
normalized_x = _adaln_fn(x_B_T_H_W_D, self.layer_norm_self_attn, scale_self_attn_B_T_1_1_D, shift_self_attn_B_T_1_1_D)
result = rearrange(
self.self_attn(
rearrange(normalized_x, "b t h w d -> b (t h w) d"),
None,
rope_emb=rope_emb_L_1_1_D,
),
"b (t h w) d -> b t h w d",
t=T, h=H, w=W,
)
x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result
# 2. Cross-attention
normalized_x = _adaln_fn(x_B_T_H_W_D, self.layer_norm_cross_attn, scale_cross_attn_B_T_1_1_D, shift_cross_attn_B_T_1_1_D)
result = rearrange(
self.cross_attn(
rearrange(normalized_x, "b t h w d -> b (t h w) d"),
crossattn_emb,
rope_emb=rope_emb_L_1_1_D,
),
"b (t h w) d -> b t h w d",
t=T, h=H, w=W,
)
x_B_T_H_W_D = result * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D
# 3. MLP
normalized_x = _adaln_fn(x_B_T_H_W_D, self.layer_norm_mlp, scale_mlp_B_T_1_1_D, shift_mlp_B_T_1_1_D)
result = self.mlp(normalized_x)
x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D * result
return x_B_T_H_W_D
def forward(
self,
x_B_T_H_W_D: torch.Tensor,
emb_B_T_D: torch.Tensor,
crossattn_emb: torch.Tensor,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if self.training and self.gradient_checkpointing:
if self.unsloth_offload_checkpointing:
# Unsloth: async non-blocking CPU RAM offload (fastest offload method)
return unsloth_checkpoint(
self._forward,
x_B_T_H_W_D, emb_B_T_D, crossattn_emb,
rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb,
)
elif self.cpu_offload_checkpointing:
# Standard cpu offload: blocking transfers
def create_custom_forward(func):
def custom_forward(*inputs):
# Determine original device from first tensor input
device = next(t.device for t in inputs if isinstance(t, torch.Tensor))
device_inputs = to_device(inputs, device)
outputs = func(*device_inputs)
return to_cpu(outputs)
return custom_forward
return torch_checkpoint(
create_custom_forward(self._forward),
x_B_T_H_W_D, emb_B_T_D, crossattn_emb,
rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb,
use_reentrant=False,
)
else:
# Standard gradient checkpointing (no offload)
return torch_checkpoint(
self._forward,
x_B_T_H_W_D, emb_B_T_D, crossattn_emb,
rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb,
use_reentrant=False,
)
else:
return self._forward(
x_B_T_H_W_D, emb_B_T_D, crossattn_emb,
rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb,
)
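# Minimal sketch (toy dimensions, not a released Anima config): one DiT block applied to a
# (B, T, H, W, D) latent with AdaLN-LoRA. During training, enable_gradient_checkpointing()
# selects between plain checkpointing, blocking CPU offload, and the unsloth offload path.
def _example_block_forward():
    block = Block(x_dim=256, context_dim=128, num_heads=8, use_adaln_lora=True, adaln_lora_dim=64)
    x = torch.randn(1, 2, 8, 8, 256)                 # (B, T, H, W, D)
    emb = torch.randn(1, 2, 256)                     # (B, T, D) timestep embedding
    ctx = torch.randn(1, 77, 128)                    # (B, N, context_dim)
    adaln = torch.randn(1, 2, 3 * 256)               # AdaLN-LoRA term from TimestepEmbedding
    y = block(x, emb, ctx, adaln_lora_B_T_3D=adaln)
    assert y.shape == x.shape
    return y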
# Main DiT Model: MiniTrainDIT
class MiniTrainDIT(nn.Module):
"""Cosmos-Predict2 DiT model for image/video generation.
28 transformer blocks with AdaLN-LoRA modulation, 3D RoPE, and optional LLM Adapter.
"""
def __init__(
self,
max_img_h: int,
max_img_w: int,
max_frames: int,
in_channels: int,
out_channels: int,
patch_spatial: int,
patch_temporal: int,
concat_padding_mask: bool = True,
model_channels: int = 768,
num_blocks: int = 10,
num_heads: int = 16,
mlp_ratio: float = 4.0,
crossattn_emb_channels: int = 1024,
pos_emb_cls: str = "sincos",
pos_emb_learnable: bool = False,
pos_emb_interpolation: str = "crop",
min_fps: int = 1,
max_fps: int = 30,
use_adaln_lora: bool = False,
adaln_lora_dim: int = 256,
rope_h_extrapolation_ratio: float = 1.0,
rope_w_extrapolation_ratio: float = 1.0,
rope_t_extrapolation_ratio: float = 1.0,
extra_per_block_abs_pos_emb: bool = False,
extra_h_extrapolation_ratio: float = 1.0,
extra_w_extrapolation_ratio: float = 1.0,
extra_t_extrapolation_ratio: float = 1.0,
rope_enable_fps_modulation: bool = True,
use_llm_adapter: bool = False,
) -> None:
super().__init__()
self.max_img_h = max_img_h
self.max_img_w = max_img_w
self.max_frames = max_frames
self.in_channels = in_channels
self.out_channels = out_channels
self.patch_spatial = patch_spatial
self.patch_temporal = patch_temporal
self.num_heads = num_heads
self.num_blocks = num_blocks
self.model_channels = model_channels
self.concat_padding_mask = concat_padding_mask
self.pos_emb_cls = pos_emb_cls
self.pos_emb_learnable = pos_emb_learnable
self.pos_emb_interpolation = pos_emb_interpolation
self.min_fps = min_fps
self.max_fps = max_fps
self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio
self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio
self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio
self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb
self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio
self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio
self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio
self.rope_enable_fps_modulation = rope_enable_fps_modulation
self.use_llm_adapter = use_llm_adapter
# Block swap support
self.blocks_to_swap = None
self.offloader: Optional[custom_offloading_utils.ModelOffloader] = None
self.build_patch_embed()
self.build_pos_embed()
self.use_adaln_lora = use_adaln_lora
self.adaln_lora_dim = adaln_lora_dim
self.t_embedder = nn.Sequential(
Timesteps(model_channels),
TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora),
)
if self.use_llm_adapter:
self.llm_adapter = LLMAdapter(
source_dim=1024,
target_dim=1024,
model_dim=1024,
num_layers=6,
self_attn=True,
)
self.blocks = nn.ModuleList(
[
Block(
x_dim=model_channels,
context_dim=crossattn_emb_channels,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
use_adaln_lora=use_adaln_lora,
adaln_lora_dim=adaln_lora_dim,
)
for _ in range(num_blocks)
]
)
self.final_layer = FinalLayer(
hidden_size=self.model_channels,
spatial_patch_size=self.patch_spatial,
temporal_patch_size=self.patch_temporal,
out_channels=self.out_channels,
use_adaln_lora=self.use_adaln_lora,
adaln_lora_dim=self.adaln_lora_dim,
)
self.t_embedding_norm = RMSNorm(model_channels, eps=1e-6)
self.init_weights()
def init_weights(self) -> None:
self.x_embedder.init_weights()
self.pos_embedder.reset_parameters()
if self.extra_per_block_abs_pos_emb:
self.extra_pos_embedder.reset_parameters()
self.t_embedder[1].init_weights()
for block in self.blocks:
block.init_weights()
self.final_layer.init_weights()
self.t_embedding_norm.reset_parameters()
def enable_gradient_checkpointing(self, cpu_offload: bool = False, unsloth_offload: bool = False):
for block in self.blocks:
block.enable_gradient_checkpointing(cpu_offload=cpu_offload, unsloth_offload=unsloth_offload)
def disable_gradient_checkpointing(self):
for block in self.blocks:
block.disable_gradient_checkpointing()
@property
def device(self):
return next(self.parameters()).device
def set_flash_attn(self, use_flash_attn: bool):
"""Toggle flash attention for all DiT blocks (self-attn + cross-attn).
LLM Adapter attention is NOT affected (it uses attention masks incompatible with flash_attn).
"""
if use_flash_attn and not FLASH_ATTN_AVAILABLE:
raise ImportError("flash_attn package is required for --flash_attn but is not installed")
attn_op = flash_attention_op if use_flash_attn else torch_attention_op
for block in self.blocks:
block.self_attn.attn_op = attn_op
block.cross_attn.attn_op = attn_op
def build_patch_embed(self) -> None:
in_channels = self.in_channels + 1 if self.concat_padding_mask else self.in_channels
self.x_embedder = PatchEmbed(
spatial_patch_size=self.patch_spatial,
temporal_patch_size=self.patch_temporal,
in_channels=in_channels,
out_channels=self.model_channels,
)
def build_pos_embed(self) -> None:
if self.pos_emb_cls == "rope3d":
cls_type = VideoRopePosition3DEmb
else:
raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}")
kwargs = dict(
model_channels=self.model_channels,
len_h=self.max_img_h // self.patch_spatial,
len_w=self.max_img_w // self.patch_spatial,
len_t=self.max_frames // self.patch_temporal,
max_fps=self.max_fps,
min_fps=self.min_fps,
is_learnable=self.pos_emb_learnable,
interpolation=self.pos_emb_interpolation,
head_dim=self.model_channels // self.num_heads,
h_extrapolation_ratio=self.rope_h_extrapolation_ratio,
w_extrapolation_ratio=self.rope_w_extrapolation_ratio,
t_extrapolation_ratio=self.rope_t_extrapolation_ratio,
enable_fps_modulation=self.rope_enable_fps_modulation,
)
self.pos_embedder = cls_type(**kwargs)
if self.extra_per_block_abs_pos_emb:
kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio
kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio
kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio
self.extra_pos_embedder = LearnablePosEmbAxis(**kwargs)
def prepare_embedded_sequence(
self,
x_B_C_T_H_W: torch.Tensor,
fps: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
from torchvision import transforms
if self.concat_padding_mask:
padding_mask = transforms.functional.resize(
padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
)
x_B_C_T_H_W = torch.cat(
[x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1
)
x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)
if self.extra_per_block_abs_pos_emb:
extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps)
else:
extra_pos_emb = None
if "rope" in self.pos_emb_cls.lower():
return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps), extra_pos_emb
x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D)
return x_B_T_H_W_D, None, extra_pos_emb
def unpatchify(self, x_B_T_H_W_M: torch.Tensor) -> torch.Tensor:
x_B_C_Tt_Hp_Wp = rearrange(
x_B_T_H_W_M,
"B T H W (p1 p2 t C) -> B C (T t) (H p1) (W p2)",
p1=self.patch_spatial,
p2=self.patch_spatial,
t=self.patch_temporal,
)
return x_B_C_Tt_Hp_Wp
def enable_block_swap(self, num_blocks: int, device: torch.device):
self.blocks_to_swap = num_blocks
assert (
self.blocks_to_swap <= self.num_blocks - 2
), f"Cannot swap more than {self.num_blocks - 2} blocks. Requested: {self.blocks_to_swap} blocks."
self.offloader = custom_offloading_utils.ModelOffloader(
self.blocks, self.blocks_to_swap, device
)
logger.info(f"Anima: Block swap enabled. Swapping {num_blocks} blocks, total blocks: {self.num_blocks}, device: {device}.")
def move_to_device_except_swap_blocks(self, device: torch.device):
# Move all modules to device except blocks (which are managed by offloader)
if self.blocks_to_swap:
save_blocks = self.blocks
self.blocks = None # Use None to skip .to() on blocks (consistent with flux_models.py)
self.to(device)
if self.blocks_to_swap:
self.blocks = save_blocks
def prepare_block_swap_before_forward(self):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return
self.offloader.prepare_block_devices_before_forward(self.blocks)
def forward(
self,
x_B_C_T_H_W: torch.Tensor,
timesteps_B_T: torch.Tensor,
crossattn_emb: torch.Tensor,
fps: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
source_attention_mask: Optional[torch.Tensor] = None,
t5_input_ids: Optional[torch.Tensor] = None,
t5_attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Args:
x_B_C_T_H_W: (B, C, T, H, W) noisy latents
timesteps_B_T: (B,) or (B, T) timesteps
crossattn_emb: (B, N, D) cross-attention embeddings (or raw Qwen3 prompt_embeds if t5_input_ids provided)
fps: Optional frames per second
padding_mask: Optional padding mask
source_attention_mask: Optional attention mask for Qwen3 embeddings (used with LLM adapter)
t5_input_ids: Optional T5 token IDs (triggers LLM adapter when provided)
t5_attn_mask: Optional T5 attention mask
"""
# Run LLM adapter inside forward for correct DDP gradient synchronization
if t5_input_ids is not None and self.use_llm_adapter and hasattr(self, 'llm_adapter'):
crossattn_emb = self.llm_adapter(
source_hidden_states=crossattn_emb,
target_input_ids=t5_input_ids,
target_attention_mask=t5_attn_mask,
source_attention_mask=source_attention_mask,
)
if t5_attn_mask is not None:
crossattn_emb[~t5_attn_mask.bool()] = 0
x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb = self.prepare_embedded_sequence(
x_B_C_T_H_W,
fps=fps,
padding_mask=padding_mask,
)
if timesteps_B_T.ndim == 1:
timesteps_B_T = timesteps_B_T.unsqueeze(1)
t_embedding_B_T_D, adaln_lora_B_T_3D = self.t_embedder(timesteps_B_T)
t_embedding_B_T_D = self.t_embedding_norm(t_embedding_B_T_D)
block_kwargs = {
"rope_emb_L_1_1_D": rope_emb_L_1_1_D,
"adaln_lora_B_T_3D": adaln_lora_B_T_3D,
"extra_per_block_pos_emb": extra_pos_emb,
}
for block_idx, block in enumerate(self.blocks):
if self.blocks_to_swap:
self.offloader.wait_for_block(block_idx)
x_B_T_H_W_D = block(
x_B_T_H_W_D,
t_embedding_B_T_D,
crossattn_emb,
**block_kwargs,
)
if self.blocks_to_swap:
self.offloader.submit_move_blocks(self.blocks, block_idx)
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)
return x_B_C_Tt_Hp_Wp
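# Minimal end-to-end sketch with a deliberately tiny, toy configuration (it matches no
# released Anima checkpoint; see get_dit_config for real configurations).
def _example_minitraindit_forward():
    model = MiniTrainDIT(
        max_img_h=64, max_img_w=64, max_frames=4,
        in_channels=16, out_channels=16,
        patch_spatial=2, patch_temporal=1,
        model_channels=128, num_blocks=2, num_heads=8,
        crossattn_emb_channels=64,
        pos_emb_cls="rope3d", use_adaln_lora=True, adaln_lora_dim=32,
    )
    x = torch.randn(1, 16, 1, 32, 32)                # (B, C, T, H, W) noisy latents
    t = torch.rand(1)                                # one timestep per sample
    ctx = torch.randn(1, 8, 64)                      # (B, N, crossattn_emb_channels)
    pad = torch.zeros(1, 32, 32)                     # padding mask: no padding
    out = model(x, t, ctx, padding_mask=pad)
    assert out.shape == x.shape
    return out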
# LLM Adapter: Bridges Qwen3 embeddings to T5-compatible space
class LLMAdapterRMSNorm(nn.Module):
"""RMSNorm specifically for the LLM Adapter (T5-style, no mean subtraction)."""
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states
def _adapter_rotate_half(x):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def _adapter_apply_rotary_pos_emb(x, cos, sin, unsqueeze_dim=1):
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
x_embed = (x * cos) + (_adapter_rotate_half(x) * sin)
return x_embed
class AdapterRotaryEmbedding(nn.Module):
"""Rotary embedding for LLM Adapter."""
def __init__(self, head_dim):
super().__init__()
self.rope_theta = 10000
inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.int64).to(dtype=torch.float) / head_dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
@torch.no_grad()
def forward(self, x, position_ids):
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
position_ids_expanded = position_ids[:, None, :].float()
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class LLMAdapterAttention(nn.Module):
"""Attention module for LLM Adapter with QK-norm and separate RoPE for query/key."""
def __init__(self, query_dim, context_dim, n_heads, head_dim):
super().__init__()
inner_dim = head_dim * n_heads
self.n_heads = n_heads
self.head_dim = head_dim
self.query_dim = query_dim
self.context_dim = context_dim
self.q_proj = nn.Linear(query_dim, inner_dim, bias=False)
self.q_norm = LLMAdapterRMSNorm(self.head_dim)
self.k_proj = nn.Linear(context_dim, inner_dim, bias=False)
self.k_norm = LLMAdapterRMSNorm(self.head_dim)
self.v_proj = nn.Linear(context_dim, inner_dim, bias=False)
self.o_proj = nn.Linear(inner_dim, query_dim, bias=False)
def forward(self, x, mask=None, context=None, position_embeddings=None, position_embeddings_context=None):
context = x if context is None else context
input_shape = x.shape[:-1]
q_shape = (*input_shape, self.n_heads, self.head_dim)
context_shape = context.shape[:-1]
kv_shape = (*context_shape, self.n_heads, self.head_dim)
query_states = self.q_norm(self.q_proj(x).view(q_shape)).transpose(1, 2)
key_states = self.k_norm(self.k_proj(context).view(kv_shape)).transpose(1, 2)
value_states = self.v_proj(context).view(kv_shape).transpose(1, 2)
if position_embeddings is not None:
assert position_embeddings_context is not None
cos, sin = position_embeddings
query_states = _adapter_apply_rotary_pos_emb(query_states, cos, sin)
cos, sin = position_embeddings_context
key_states = _adapter_apply_rotary_pos_emb(key_states, cos, sin)
attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=mask)
attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output
class LLMAdapterTransformerBlock(nn.Module):
"""Transformer block for LLM Adapter: optional self-attn + cross-attn + MLP."""
def __init__(self, source_dim, model_dim, num_heads=16, mlp_ratio=4.0, self_attn=False, layer_norm=False):
super().__init__()
self.has_self_attn = self_attn
if self.has_self_attn:
self.norm_self_attn = nn.LayerNorm(model_dim) if layer_norm else LLMAdapterRMSNorm(model_dim)
self.self_attn = LLMAdapterAttention(
query_dim=model_dim,
context_dim=model_dim,
n_heads=num_heads,
head_dim=model_dim // num_heads,
)
self.norm_cross_attn = nn.LayerNorm(model_dim) if layer_norm else LLMAdapterRMSNorm(model_dim)
self.cross_attn = LLMAdapterAttention(
query_dim=model_dim,
context_dim=source_dim,
n_heads=num_heads,
head_dim=model_dim // num_heads,
)
self.norm_mlp = nn.LayerNorm(model_dim) if layer_norm else LLMAdapterRMSNorm(model_dim)
self.mlp = nn.Sequential(
nn.Linear(model_dim, int(model_dim * mlp_ratio)),
nn.GELU(),
nn.Linear(int(model_dim * mlp_ratio), model_dim)
)
def forward(self, x, context, target_attention_mask=None, source_attention_mask=None,
position_embeddings=None, position_embeddings_context=None):
if self.has_self_attn:
normed = self.norm_self_attn(x)
attn_out = self.self_attn(normed, mask=target_attention_mask,
position_embeddings=position_embeddings,
position_embeddings_context=position_embeddings)
x = x + attn_out
normed = self.norm_cross_attn(x)
attn_out = self.cross_attn(normed, mask=source_attention_mask, context=context,
position_embeddings=position_embeddings,
position_embeddings_context=position_embeddings_context)
x = x + attn_out
x = x + self.mlp(self.norm_mlp(x))
return x
def init_weights(self):
torch.nn.init.zeros_(self.mlp[2].weight)
class LLMAdapter(nn.Module):
"""Bridge module: Qwen3 embeddings (source) → T5-compatible space (target).
Uses T5 token IDs as target input, embeds them, and cross-attends to Qwen3 hidden states.
"""
def __init__(self, source_dim, target_dim, model_dim, num_layers=6, num_heads=16,
embed=None, self_attn=False, layer_norm=False):
super().__init__()
if embed is not None:
self.embed = nn.Embedding.from_pretrained(embed.weight)
else:
self.embed = nn.Embedding(32128, target_dim)
if model_dim != target_dim:
self.in_proj = nn.Linear(target_dim, model_dim)
else:
self.in_proj = nn.Identity()
self.rotary_emb = AdapterRotaryEmbedding(model_dim // num_heads)
self.blocks = nn.ModuleList([
LLMAdapterTransformerBlock(source_dim, model_dim, num_heads=num_heads,
self_attn=self_attn, layer_norm=layer_norm)
for _ in range(num_layers)
])
self.out_proj = nn.Linear(model_dim, target_dim)
self.norm = LLMAdapterRMSNorm(target_dim)
def forward(self, source_hidden_states, target_input_ids, target_attention_mask=None, source_attention_mask=None):
if target_attention_mask is not None:
target_attention_mask = target_attention_mask.to(torch.bool)
if target_attention_mask.ndim == 2:
target_attention_mask = target_attention_mask.unsqueeze(1).unsqueeze(1)
if source_attention_mask is not None:
source_attention_mask = source_attention_mask.to(torch.bool)
if source_attention_mask.ndim == 2:
source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1)
x = self.in_proj(self.embed(target_input_ids))
context = source_hidden_states
position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0)
position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0)
position_embeddings = self.rotary_emb(x, position_ids)
position_embeddings_context = self.rotary_emb(x, position_ids_context)
for block in self.blocks:
x = block(x, context, target_attention_mask=target_attention_mask,
source_attention_mask=source_attention_mask,
position_embeddings=position_embeddings,
position_embeddings_context=position_embeddings_context)
return self.norm(self.out_proj(x))
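# Minimal sketch (illustrative dimensions): map source hidden states (e.g. Qwen3) onto a
# sequence aligned with T5 token IDs. The adapter instantiated inside MiniTrainDIT uses
# 1024-dim source/target/model dims and 6 layers.
def _example_llm_adapter():
    adapter = LLMAdapter(source_dim=64, target_dim=48, model_dim=32, num_layers=2, num_heads=4, self_attn=True)
    source = torch.randn(2, 10, 64)                  # (B, S_src, source_dim)
    t5_ids = torch.randint(0, 32128, (2, 12))        # (B, S_tgt) T5 token IDs
    out = adapter(source, t5_ids)
    assert out.shape == (2, 12, 48)                  # (B, S_tgt, target_dim)
    return out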
# VAE Wrapper
# VAE normalization constants
ANIMA_VAE_MEAN = [
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
]
ANIMA_VAE_STD = [
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
]
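# Minimal sketch, assuming the constants above are applied as a per-channel
# standardization of 16-channel latents: z_norm = (z - mean) / std.
def _example_normalize_latents(latents: torch.Tensor) -> torch.Tensor:
    # latents: (B, 16, T, H, W)
    mean = torch.tensor(ANIMA_VAE_MEAN, dtype=latents.dtype, device=latents.device).view(1, -1, 1, 1, 1)
    std = torch.tensor(ANIMA_VAE_STD, dtype=latents.dtype, device=latents.device).view(1, -1, 1, 1, 1)
    return (latents - mean) / std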
# DiT config detection from state_dict
KEEP_IN_HIGH_PRECISION = ['x_embedder', 't_embedder', 't_embedding_norm', 'final_layer']
def get_dit_config(state_dict, key_prefix=''):
"""Derive DiT configuration from state_dict weight shapes."""
dit_config = {}
dit_config["max_img_h"] = 512
dit_config["max_img_w"] = 512
dit_config["max_frames"] = 128
concat_padding_mask = True
dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask)
dit_config["out_channels"] = 16
dit_config["patch_spatial"] = 2
dit_config["patch_temporal"] = 1
dit_config["model_channels"] = state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[0]
dit_config["concat_padding_mask"] = concat_padding_mask
dit_config["crossattn_emb_channels"] = 1024
dit_config["pos_emb_cls"] = "rope3d"
dit_config["pos_emb_learnable"] = True
dit_config["pos_emb_interpolation"] = "crop"
dit_config["min_fps"] = 1
dit_config["max_fps"] = 30
dit_config["use_adaln_lora"] = True
dit_config["adaln_lora_dim"] = 256
if dit_config["model_channels"] == 2048:
dit_config["num_blocks"] = 28
dit_config["num_heads"] = 16
elif dit_config["model_channels"] == 5120:
dit_config["num_blocks"] = 36
dit_config["num_heads"] = 40
elif dit_config["model_channels"] == 1280:
dit_config["num_blocks"] = 20
dit_config["num_heads"] = 20
if dit_config["in_channels"] == 16:
dit_config["extra_per_block_abs_pos_emb"] = False
dit_config["rope_h_extrapolation_ratio"] = 4.0
dit_config["rope_w_extrapolation_ratio"] = 4.0
dit_config["rope_t_extrapolation_ratio"] = 1.0
elif dit_config["in_channels"] == 17:
dit_config["extra_per_block_abs_pos_emb"] = False
dit_config["rope_h_extrapolation_ratio"] = 3.0
dit_config["rope_w_extrapolation_ratio"] = 3.0
dit_config["rope_t_extrapolation_ratio"] = 1.0
dit_config["extra_h_extrapolation_ratio"] = 1.0
dit_config["extra_w_extrapolation_ratio"] = 1.0
dit_config["extra_t_extrapolation_ratio"] = 1.0
dit_config["rope_enable_fps_modulation"] = False
return dit_config
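# Minimal sketch, assuming `state_dict` holds DiT weights without a key prefix: detect the
# configuration from weight shapes and build the matching MiniTrainDIT.
def _example_build_model_from_state_dict(state_dict: dict) -> MiniTrainDIT:
    config = get_dit_config(state_dict)
    model = MiniTrainDIT(**config)
    model.load_state_dict(state_dict, strict=False)
    return model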