Merge pull request #2262 from duongve13112002/fix_lumina

Fix bug and optimization for Lumina model
This commit is contained in:
Kohya S.
2026-02-16 07:54:49 +09:00
committed by GitHub
5 changed files with 137 additions and 108 deletions

View File

@@ -34,18 +34,18 @@ from library import custom_offloading_utils
try:
from flash_attn import flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
except:
except ImportError:
# flash_attn may not be available but it is not required
pass
try:
from sageattention import sageattn
except:
except ImportError:
pass
try:
from apex.normalization import FusedRMSNorm as RMSNorm
except:
except ImportError:
import warnings
warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")
@@ -98,7 +98,7 @@ except:
x_dtype = x.dtype
# To handle float8 we need to convert the tensor to float
x = x.float()
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
return ((x * rrms) * self.weight.float()).to(dtype=x_dtype)
@@ -370,7 +370,7 @@ class JointAttention(nn.Module):
if self.use_sage_attn:
# Handle GQA (Grouped Query Attention) if needed
n_rep = self.n_local_heads // self.n_local_kv_heads
if n_rep >= 1:
if n_rep > 1:
xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
@@ -379,7 +379,7 @@ class JointAttention(nn.Module):
output = self.flash_attn(xq, xk, xv, x_mask, softmax_scale)
else:
n_rep = self.n_local_heads // self.n_local_kv_heads
if n_rep >= 1:
if n_rep > 1:
xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
@@ -456,51 +456,47 @@ class JointAttention(nn.Module):
bsz = q.shape[0]
seqlen = q.shape[1]
# Transpose tensors to match SageAttention's expected format (HND layout)
q_transposed = q.permute(0, 2, 1, 3) # [batch, heads, seq_len, head_dim]
k_transposed = k.permute(0, 2, 1, 3) # [batch, heads, seq_len, head_dim]
v_transposed = v.permute(0, 2, 1, 3) # [batch, heads, seq_len, head_dim]
# Transpose to SageAttention's expected HND layout: [batch, heads, seq_len, head_dim]
q_transposed = q.permute(0, 2, 1, 3)
k_transposed = k.permute(0, 2, 1, 3)
v_transposed = v.permute(0, 2, 1, 3)
# Handle masking for SageAttention
# We need to filter out masked positions - this approach handles variable sequence lengths
# Fast path: if all tokens are valid, run batched SageAttention directly
if x_mask.all():
output = sageattn(
q_transposed, k_transposed, v_transposed,
tensor_layout="HND", is_causal=False, sm_scale=softmax_scale,
)
# output: [batch, heads, seq_len, head_dim] -> [batch, seq_len, heads, head_dim]
output = output.permute(0, 2, 1, 3)
else:
# Slow path: per-batch loop to handle variable-length masking
# SageAttention does not support attention masks natively
outputs = []
for b in range(bsz):
# Find valid token positions from the mask
valid_indices = torch.nonzero(x_mask[b], as_tuple=False).squeeze(-1)
valid_indices = x_mask[b].nonzero(as_tuple=True)[0]
if valid_indices.numel() == 0:
# If all tokens are masked, create a zero output
batch_output = torch.zeros(
outputs.append(torch.zeros(
seqlen, self.n_local_heads, self.head_dim,
device=q.device, dtype=q.dtype
)
else:
# Extract only valid tokens for this batch
batch_q = q_transposed[b, :, valid_indices, :]
batch_k = k_transposed[b, :, valid_indices, :]
batch_v = v_transposed[b, :, valid_indices, :]
device=q.device, dtype=q.dtype,
))
continue
# Run SageAttention on valid tokens only
batch_output_valid = sageattn(
batch_q.unsqueeze(0), # Add batch dimension back
batch_k.unsqueeze(0),
batch_v.unsqueeze(0),
tensor_layout="HND",
is_causal=False,
sm_scale=softmax_scale
q_transposed[b:b+1, :, valid_indices, :],
k_transposed[b:b+1, :, valid_indices, :],
v_transposed[b:b+1, :, valid_indices, :],
tensor_layout="HND", is_causal=False, sm_scale=softmax_scale,
)
# Create output tensor with zeros for masked positions
batch_output = torch.zeros(
seqlen, self.n_local_heads, self.head_dim,
device=q.device, dtype=q.dtype
device=q.device, dtype=q.dtype,
)
# Place valid outputs back in the right positions
batch_output[valid_indices] = batch_output_valid.squeeze(0).permute(1, 0, 2)
outputs.append(batch_output)
# Stack batch outputs and reshape to expected format
output = torch.stack(outputs, dim=0) # [batch, seq_len, heads, head_dim]
output = torch.stack(outputs, dim=0)
except NameError as e:
raise RuntimeError(
f"Could not load Sage Attention. Please install https://github.com/thu-ml/SageAttention. / Sage Attention を読み込めませんでした。https://github.com/thu-ml/SageAttention をインストールしてください。 / {e}"
@@ -1113,10 +1109,9 @@ class NextDiT(nn.Module):
x = x.view(bsz, channels, height // pH, pH, width // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)
x_mask = torch.zeros(bsz, image_seq_len, dtype=torch.bool, device=device)
for i in range(bsz):
x[i, :image_seq_len] = x[i]
x_mask[i, :image_seq_len] = True
# x.shape[1] == image_seq_len after patchify, so this was assigning to itself.
# The mask can be set without a loop since all samples have the same image_seq_len.
x_mask = torch.ones(bsz, image_seq_len, dtype=torch.bool, device=device)
x = self.x_embedder(x)

View File

@@ -334,32 +334,35 @@ def sample_image_inference(
# No need to add system prompt here, as it has been handled in the tokenize_strategy
# Get sample prompts from cache
# Get sample prompts from cache, fallback to live encoding
gemma2_conds = None
neg_gemma2_conds = None
if sample_prompts_gemma2_outputs and prompt in sample_prompts_gemma2_outputs:
gemma2_conds = sample_prompts_gemma2_outputs[prompt]
logger.info(f"Using cached Gemma2 outputs for prompt: {prompt}")
if (
sample_prompts_gemma2_outputs
and negative_prompt in sample_prompts_gemma2_outputs
):
if sample_prompts_gemma2_outputs and negative_prompt in sample_prompts_gemma2_outputs:
neg_gemma2_conds = sample_prompts_gemma2_outputs[negative_prompt]
logger.info(
f"Using cached Gemma2 outputs for negative prompt: {negative_prompt}"
)
logger.info(f"Using cached Gemma2 outputs for negative prompt: {negative_prompt}")
# Load sample prompts from Gemma 2
if gemma2_model is not None:
# Only encode if not found in cache
if gemma2_conds is None and gemma2_model is not None:
tokens_and_masks = tokenize_strategy.tokenize(prompt)
gemma2_conds = encoding_strategy.encode_tokens(
tokenize_strategy, gemma2_model, tokens_and_masks
)
if neg_gemma2_conds is None and gemma2_model is not None:
tokens_and_masks = tokenize_strategy.tokenize(negative_prompt, is_negative=True)
neg_gemma2_conds = encoding_strategy.encode_tokens(
tokenize_strategy, gemma2_model, tokens_and_masks
)
if gemma2_conds is None or neg_gemma2_conds is None:
logger.error(f"Cannot generate sample: no cached outputs and no text encoder available for prompt: {prompt}")
continue
# Unpack Gemma2 outputs
gemma2_hidden_states, _, gemma2_attn_mask = gemma2_conds
neg_gemma2_hidden_states, _, neg_gemma2_attn_mask = neg_gemma2_conds
@@ -475,6 +478,7 @@ def sample_image_inference(
def time_shift(mu: float, sigma: float, t: torch.Tensor):
"""Apply time shifting to timesteps."""
t = math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
return t
@@ -483,7 +487,7 @@ def get_lin_function(
x1: float = 256, x2: float = 4096, y1: float = 0.5, y2: float = 1.15
) -> Callable[[float], float]:
"""
Get linear function
Get linear function for resolution-dependent shifting.
Args:
image_seq_len,
@@ -528,6 +532,7 @@ def get_schedule(
mu = get_lin_function(y1=base_shift, y2=max_shift, x1=256, x2=4096)(
image_seq_len
)
timesteps = torch.clamp(timesteps, min=1e-7).to(timesteps.device)
timesteps = time_shift(mu, 1.0, timesteps)
return timesteps.tolist()
@@ -689,15 +694,15 @@ def denoise(
img_dtype = img.dtype
if img.dtype != img_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
img = img.to(img_dtype)
# compute the previous noisy sample x_t -> x_t-1
noise_pred = -noise_pred
img = scheduler.step(noise_pred, t, img, return_dict=False)[0]
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
if img.dtype != img_dtype:
if torch.backends.mps.is_available():
img = img.to(img_dtype)
model.prepare_block_swap_before_forward()
return img
@@ -823,6 +828,7 @@ def get_noisy_model_input_and_timesteps(
timesteps = sigmas * num_timesteps
elif args.timestep_sampling == "nextdit_shift":
sigmas = torch.rand((bsz,), device=device)
sigmas = torch.clamp(sigmas, min=1e-7).to(device)
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
sigmas = time_shift(mu, 1.0, sigmas)
@@ -831,6 +837,7 @@ def get_noisy_model_input_and_timesteps(
sigmas = torch.randn(bsz, device=device)
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
sigmas = sigmas.sigmoid()
sigmas = torch.clamp(sigmas, min=1e-7).to(device)
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size
sigmas = time_shift(mu, 1.0, sigmas)
timesteps = sigmas * num_timesteps

View File

@@ -370,19 +370,25 @@ def train(args):
grouped_params = []
param_group = {}
for group in params_to_optimize:
named_parameters = list(nextdit.named_parameters())
named_parameters = [(n, p) for n, p in nextdit.named_parameters() if p.requires_grad]
assert len(named_parameters) == len(
group["params"]
), "number of parameters does not match"
), f"number of trainable parameters ({len(named_parameters)}) does not match optimizer group ({len(group['params'])})"
for p, np in zip(group["params"], named_parameters):
# determine target layer and block index for each parameter
block_type = "other" # double, single or other
if np[0].startswith("double_blocks"):
# Lumina NextDiT architecture:
# - "layers.{i}.*" : main transformer blocks (e.g. 32 blocks for 2B)
# - "context_refiner.{i}.*" : context refiner blocks (2 blocks)
# - "noise_refiner.{i}.*" : noise refiner blocks (2 blocks)
# - others: t_embedder, cap_embedder, x_embedder, norm_final, final_layer
block_type = "other"
if np[0].startswith("layers."):
block_index = int(np[0].split(".")[1])
block_type = "double"
elif np[0].startswith("single_blocks"):
block_index = int(np[0].split(".")[1])
block_type = "single"
block_type = "main"
elif np[0].startswith("context_refiner.") or np[0].startswith("noise_refiner."):
# All refiner blocks (context + noise) grouped together
block_index = -1
block_type = "refiner"
else:
block_index = -1
@@ -759,7 +765,7 @@ def train(args):
# calculate loss
huber_c = train_util.get_huber_threshold_if_needed(
args, timesteps, noise_scheduler
args, 1000 - timesteps, noise_scheduler
)
loss = train_util.conditional_loss(
model_pred.float(), target.float(), args.loss_type, "none", huber_c

View File

@@ -43,9 +43,9 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
logger.warning("Enabling cache_text_encoder_outputs due to disk caching")
args.cache_text_encoder_outputs = True
train_dataset_group.verify_bucket_reso_steps(32)
train_dataset_group.verify_bucket_reso_steps(16)
if val_dataset_group is not None:
val_dataset_group.verify_bucket_reso_steps(32)
val_dataset_group.verify_bucket_reso_steps(16)
self.train_gemma2 = not args.network_train_unet_only
@@ -134,13 +134,16 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
# When TE is not be trained, it will not be prepared so we need to use explicit autocast
logger.info("move text encoders to gpu")
text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8
# Lumina uses a single text encoder (Gemma2) at index 0.
# Check original dtype BEFORE casting to preserve fp8 detection.
gemma2_original_dtype = text_encoders[0].dtype
text_encoders[0].to(accelerator.device)
if text_encoders[0].dtype == torch.float8_e4m3fn:
# if we load fp8 weights, the model is already fp8, so we use it as is
self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype)
if gemma2_original_dtype == torch.float8_e4m3fn:
# Model was loaded as fp8 — apply fp8 optimization
self.prepare_text_encoder_fp8(0, text_encoders[0], gemma2_original_dtype, weight_dtype)
else:
# otherwise, we need to convert it to target dtype
# Otherwise, cast to target dtype
text_encoders[0].to(weight_dtype)
with accelerator.autocast():

View File

@@ -227,19 +227,16 @@ class LoRAInfModule(LoRAModule):
org_sd["weight"] = weight.to(dtype)
self.org_module.load_state_dict(org_sd)
else:
# split_dims
total_dims = sum(self.split_dims)
# split_dims: merge each split's LoRA into the correct slice of the fused QKV weight
for i in range(len(self.split_dims)):
# get up/down weight
down_weight = sd[f"lora_down.{i}.weight"].to(torch.float).to(device) # (rank, in_dim)
up_weight = sd[f"lora_up.{i}.weight"].to(torch.float).to(device) # (split dim, rank)
up_weight = sd[f"lora_up.{i}.weight"].to(torch.float).to(device) # (split_dim, rank)
# pad up_weight -> (total_dims, rank)
padded_up_weight = torch.zeros((total_dims, up_weight.size(0)), device=device, dtype=torch.float)
padded_up_weight[sum(self.split_dims[:i]) : sum(self.split_dims[: i + 1])] = up_weight
# merge weight
weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
# merge into the correct slice of the fused weight
start = sum(self.split_dims[:i])
end = sum(self.split_dims[:i + 1])
weight[start:end] += self.multiplier * (up_weight @ down_weight) * self.scale
# set weight to org_module
org_sd["weight"] = weight.to(dtype)
@@ -250,6 +247,17 @@ class LoRAInfModule(LoRAModule):
if multiplier is None:
multiplier = self.multiplier
# Handle split_dims case where lora_down/lora_up are ModuleList
if self.split_dims is not None:
# Each sub-module produces a partial weight; concatenate along output dim
weights = []
for lora_up, lora_down in zip(self.lora_up, self.lora_down):
up_w = lora_up.weight.to(torch.float)
down_w = lora_down.weight.to(torch.float)
weights.append(up_w @ down_w)
weight = self.multiplier * torch.cat(weights, dim=0) * self.scale
return weight
# get up/down weight from module
up_weight = self.lora_up.weight.to(torch.float)
down_weight = self.lora_down.weight.to(torch.float)
@@ -409,7 +417,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, lumina, wei
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
weights_sd = torch.load(file, map_location="cpu", weights_only=False)
# get dim/alpha mapping, and train t5xxl
modules_dim = {}
@@ -634,20 +642,30 @@ class LoRANetwork(torch.nn.Module):
skipped_te += skipped
# create LoRA for U-Net
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
# Filter by block type using name-based filtering in create_modules
# All block types use JointTransformerBlock, so we filter by module path name
block_filter = None # None means no filtering (train all)
if self.train_blocks == "all":
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
# TODO: limit different blocks
block_filter = None
elif self.train_blocks == "transformer":
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
elif self.train_blocks == "refiners":
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
block_filter = "layers_" # main transformer blocks: "lora_unet_layers_N_..."
elif self.train_blocks == "noise_refiner":
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
elif self.train_blocks == "cap_refiner":
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
block_filter = "noise_refiner"
elif self.train_blocks == "context_refiner":
block_filter = "context_refiner"
elif self.train_blocks == "refiners":
block_filter = None # handled below with two calls
self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
self.unet_loras, skipped_un = create_modules(True, unet, target_replace_modules)
if self.train_blocks == "refiners":
# Refiners = noise_refiner + context_refiner, need two calls
noise_loras, skipped_noise = create_modules(True, unet, target_replace_modules, filter="noise_refiner")
context_loras, skipped_context = create_modules(True, unet, target_replace_modules, filter="context_refiner")
self.unet_loras = noise_loras + context_loras
skipped_un = skipped_noise + skipped_context
else:
self.unet_loras, skipped_un = create_modules(True, unet, target_replace_modules, filter=block_filter)
# Handle embedders
if self.embedder_dims:
@@ -689,7 +707,7 @@ class LoRANetwork(torch.nn.Module):
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
weights_sd = torch.load(file, map_location="cpu", weights_only=False)
info = self.load_state_dict(weights_sd, False)
return info
@@ -751,10 +769,10 @@ class LoRANetwork(torch.nn.Module):
state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
new_state_dict = {}
for key in list(state_dict.keys()):
if "double" in key and "qkv" in key:
split_dims = [3072] * 3
elif "single" in key and "linear1" in key:
split_dims = [3072] * 3 + [12288]
if "qkv" in key:
# Lumina 2B: dim=2304, n_heads=24, n_kv_heads=8, head_dim=96
# Q=24*96=2304, K=8*96=768, V=8*96=768
split_dims = [2304, 768, 768]
else:
new_state_dict[key] = state_dict[key]
continue