From 1640e533925f5f33a85230d159c33d4c1643096f Mon Sep 17 00:00:00 2001 From: Duoong Date: Thu, 12 Feb 2026 22:52:28 +0700 Subject: [PATCH] Fix bug and optimization Lumina training --- library/lumina_models.py | 97 +++++++++++++++++------------------- library/lumina_train_util.py | 39 +++++++++------ lumina_train.py | 24 +++++---- lumina_train_network.py | 17 ++++--- networks/lora_lumina.py | 68 +++++++++++++++---------- 5 files changed, 137 insertions(+), 108 deletions(-) diff --git a/library/lumina_models.py b/library/lumina_models.py index 7e925352..c51c900e 100644 --- a/library/lumina_models.py +++ b/library/lumina_models.py @@ -34,18 +34,18 @@ from library import custom_offloading_utils try: from flash_attn import flash_attn_varlen_func from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa -except: +except ImportError: # flash_attn may not be available but it is not required pass try: from sageattention import sageattn -except: +except ImportError: pass try: from apex.normalization import FusedRMSNorm as RMSNorm -except: +except ImportError: import warnings warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation") @@ -98,7 +98,7 @@ except: x_dtype = x.dtype # To handle float8 we need to convert the tensor to float x = x.float() - rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) + rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps) return ((x * rrms) * self.weight.float()).to(dtype=x_dtype) @@ -370,7 +370,7 @@ class JointAttention(nn.Module): if self.use_sage_attn: # Handle GQA (Grouped Query Attention) if needed n_rep = self.n_local_heads // self.n_local_kv_heads - if n_rep >= 1: + if n_rep > 1: xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) @@ -379,7 +379,7 @@ class JointAttention(nn.Module): output = self.flash_attn(xq, xk, xv, x_mask, softmax_scale) else: n_rep = self.n_local_heads // self.n_local_kv_heads - if n_rep >= 1: + if n_rep > 1: xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) @@ -456,51 +456,47 @@ class JointAttention(nn.Module): bsz = q.shape[0] seqlen = q.shape[1] - # Transpose tensors to match SageAttention's expected format (HND layout) - q_transposed = q.permute(0, 2, 1, 3) # [batch, heads, seq_len, head_dim] - k_transposed = k.permute(0, 2, 1, 3) # [batch, heads, seq_len, head_dim] - v_transposed = v.permute(0, 2, 1, 3) # [batch, heads, seq_len, head_dim] - - # Handle masking for SageAttention - # We need to filter out masked positions - this approach handles variable sequence lengths - outputs = [] - for b in range(bsz): - # Find valid token positions from the mask - valid_indices = torch.nonzero(x_mask[b], as_tuple=False).squeeze(-1) - if valid_indices.numel() == 0: - # If all tokens are masked, create a zero output - batch_output = torch.zeros( - seqlen, self.n_local_heads, self.head_dim, - device=q.device, dtype=q.dtype - ) - else: - # Extract only valid tokens for this batch - batch_q = q_transposed[b, :, valid_indices, :] - batch_k = k_transposed[b, :, valid_indices, :] - batch_v = v_transposed[b, :, valid_indices, :] - - # Run SageAttention on valid tokens only + # Transpose to SageAttention's expected HND layout: [batch, heads, seq_len, head_dim] + q_transposed = q.permute(0, 2, 1, 3) + k_transposed = k.permute(0, 2, 1, 3) + v_transposed = v.permute(0, 2, 1, 3) + + # Fast path: if all tokens are valid, run batched SageAttention directly + if x_mask.all(): + output = sageattn( + q_transposed, k_transposed, v_transposed, + tensor_layout="HND", is_causal=False, sm_scale=softmax_scale, + ) + # output: [batch, heads, seq_len, head_dim] -> [batch, seq_len, heads, head_dim] + output = output.permute(0, 2, 1, 3) + else: + # Slow path: per-batch loop to handle variable-length masking + # SageAttention does not support attention masks natively + outputs = [] + for b in range(bsz): + valid_indices = x_mask[b].nonzero(as_tuple=True)[0] + if valid_indices.numel() == 0: + outputs.append(torch.zeros( + seqlen, self.n_local_heads, self.head_dim, + device=q.device, dtype=q.dtype, + )) + continue + batch_output_valid = sageattn( - batch_q.unsqueeze(0), # Add batch dimension back - batch_k.unsqueeze(0), - batch_v.unsqueeze(0), - tensor_layout="HND", - is_causal=False, - sm_scale=softmax_scale + q_transposed[b:b+1, :, valid_indices, :], + k_transposed[b:b+1, :, valid_indices, :], + v_transposed[b:b+1, :, valid_indices, :], + tensor_layout="HND", is_causal=False, sm_scale=softmax_scale, ) - - # Create output tensor with zeros for masked positions + batch_output = torch.zeros( - seqlen, self.n_local_heads, self.head_dim, - device=q.device, dtype=q.dtype + seqlen, self.n_local_heads, self.head_dim, + device=q.device, dtype=q.dtype, ) - # Place valid outputs back in the right positions batch_output[valid_indices] = batch_output_valid.squeeze(0).permute(1, 0, 2) - - outputs.append(batch_output) - - # Stack batch outputs and reshape to expected format - output = torch.stack(outputs, dim=0) # [batch, seq_len, heads, head_dim] + outputs.append(batch_output) + + output = torch.stack(outputs, dim=0) except NameError as e: raise RuntimeError( f"Could not load Sage Attention. Please install https://github.com/thu-ml/SageAttention. / Sage Attention を読み込めませんでした。https://github.com/thu-ml/SageAttention をインストールしてください。 / {e}" @@ -1113,10 +1109,9 @@ class NextDiT(nn.Module): x = x.view(bsz, channels, height // pH, pH, width // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2) - x_mask = torch.zeros(bsz, image_seq_len, dtype=torch.bool, device=device) - for i in range(bsz): - x[i, :image_seq_len] = x[i] - x_mask[i, :image_seq_len] = True + # x.shape[1] == image_seq_len after patchify, so this was assigning to itself. + # The mask can be set without a loop since all samples have the same image_seq_len. + x_mask = torch.ones(bsz, image_seq_len, dtype=torch.bool, device=device) x = self.x_embedder(x) @@ -1389,4 +1384,4 @@ def NextDiT_7B_GQA_patch2_Adaln_Refiner(**kwargs): axes_dims=[40, 40, 40], axes_lens=[300, 512, 512], **kwargs, - ) + ) \ No newline at end of file diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py index 244d2360..afbfc241 100644 --- a/library/lumina_train_util.py +++ b/library/lumina_train_util.py @@ -334,32 +334,35 @@ def sample_image_inference( # No need to add system prompt here, as it has been handled in the tokenize_strategy - # Get sample prompts from cache + # Get sample prompts from cache, fallback to live encoding + gemma2_conds = None + neg_gemma2_conds = None + if sample_prompts_gemma2_outputs and prompt in sample_prompts_gemma2_outputs: gemma2_conds = sample_prompts_gemma2_outputs[prompt] logger.info(f"Using cached Gemma2 outputs for prompt: {prompt}") - if ( - sample_prompts_gemma2_outputs - and negative_prompt in sample_prompts_gemma2_outputs - ): + if sample_prompts_gemma2_outputs and negative_prompt in sample_prompts_gemma2_outputs: neg_gemma2_conds = sample_prompts_gemma2_outputs[negative_prompt] - logger.info( - f"Using cached Gemma2 outputs for negative prompt: {negative_prompt}" - ) + logger.info(f"Using cached Gemma2 outputs for negative prompt: {negative_prompt}") - # Load sample prompts from Gemma 2 - if gemma2_model is not None: + # Only encode if not found in cache + if gemma2_conds is None and gemma2_model is not None: tokens_and_masks = tokenize_strategy.tokenize(prompt) gemma2_conds = encoding_strategy.encode_tokens( tokenize_strategy, gemma2_model, tokens_and_masks ) + if neg_gemma2_conds is None and gemma2_model is not None: tokens_and_masks = tokenize_strategy.tokenize(negative_prompt, is_negative=True) neg_gemma2_conds = encoding_strategy.encode_tokens( tokenize_strategy, gemma2_model, tokens_and_masks ) + if gemma2_conds is None or neg_gemma2_conds is None: + logger.error(f"Cannot generate sample: no cached outputs and no text encoder available for prompt: {prompt}") + continue + # Unpack Gemma2 outputs gemma2_hidden_states, _, gemma2_attn_mask = gemma2_conds neg_gemma2_hidden_states, _, neg_gemma2_attn_mask = neg_gemma2_conds @@ -475,6 +478,7 @@ def sample_image_inference( def time_shift(mu: float, sigma: float, t: torch.Tensor): + """Apply time shifting to timesteps.""" t = math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) return t @@ -483,7 +487,7 @@ def get_lin_function( x1: float = 256, x2: float = 4096, y1: float = 0.5, y2: float = 1.15 ) -> Callable[[float], float]: """ - Get linear function + Get linear function for resolution-dependent shifting. Args: image_seq_len, @@ -528,6 +532,7 @@ def get_schedule( mu = get_lin_function(y1=base_shift, y2=max_shift, x1=256, x2=4096)( image_seq_len ) + timesteps = torch.clamp(timesteps, min=1e-7).to(timesteps.device) timesteps = time_shift(mu, 1.0, timesteps) return timesteps.tolist() @@ -689,15 +694,15 @@ def denoise( img_dtype = img.dtype - if img.dtype != img_dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - img = img.to(img_dtype) - # compute the previous noisy sample x_t -> x_t-1 noise_pred = -noise_pred img = scheduler.step(noise_pred, t, img, return_dict=False)[0] + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + if img.dtype != img_dtype: + if torch.backends.mps.is_available(): + img = img.to(img_dtype) + model.prepare_block_swap_before_forward() return img @@ -823,6 +828,7 @@ def get_noisy_model_input_and_timesteps( timesteps = sigmas * num_timesteps elif args.timestep_sampling == "nextdit_shift": sigmas = torch.rand((bsz,), device=device) + sigmas = torch.clamp(sigmas, min=1e-7).to(device) mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) sigmas = time_shift(mu, 1.0, sigmas) @@ -831,6 +837,7 @@ def get_noisy_model_input_and_timesteps( sigmas = torch.randn(bsz, device=device) sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling sigmas = sigmas.sigmoid() + sigmas = torch.clamp(sigmas, min=1e-7).to(device) mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size sigmas = time_shift(mu, 1.0, sigmas) timesteps = sigmas * num_timesteps diff --git a/lumina_train.py b/lumina_train.py index 580b170c..cf6e7fdb 100644 --- a/lumina_train.py +++ b/lumina_train.py @@ -370,19 +370,25 @@ def train(args): grouped_params = [] param_group = {} for group in params_to_optimize: - named_parameters = list(nextdit.named_parameters()) + named_parameters = [(n, p) for n, p in nextdit.named_parameters() if p.requires_grad] assert len(named_parameters) == len( group["params"] - ), "number of parameters does not match" + ), f"number of trainable parameters ({len(named_parameters)}) does not match optimizer group ({len(group['params'])})" for p, np in zip(group["params"], named_parameters): # determine target layer and block index for each parameter - block_type = "other" # double, single or other - if np[0].startswith("double_blocks"): + # Lumina NextDiT architecture: + # - "layers.{i}.*" : main transformer blocks (e.g. 32 blocks for 2B) + # - "context_refiner.{i}.*" : context refiner blocks (2 blocks) + # - "noise_refiner.{i}.*" : noise refiner blocks (2 blocks) + # - others: t_embedder, cap_embedder, x_embedder, norm_final, final_layer + block_type = "other" + if np[0].startswith("layers."): block_index = int(np[0].split(".")[1]) - block_type = "double" - elif np[0].startswith("single_blocks"): - block_index = int(np[0].split(".")[1]) - block_type = "single" + block_type = "main" + elif np[0].startswith("context_refiner.") or np[0].startswith("noise_refiner."): + # All refiner blocks (context + noise) grouped together + block_index = -1 + block_type = "refiner" else: block_index = -1 @@ -759,7 +765,7 @@ def train(args): # calculate loss huber_c = train_util.get_huber_threshold_if_needed( - args, timesteps, noise_scheduler + args, 1000 - timesteps, noise_scheduler ) loss = train_util.conditional_loss( model_pred.float(), target.float(), args.loss_type, "none", huber_c diff --git a/lumina_train_network.py b/lumina_train_network.py index ad29d2f2..58f9f4f1 100644 --- a/lumina_train_network.py +++ b/lumina_train_network.py @@ -43,9 +43,9 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): logger.warning("Enabling cache_text_encoder_outputs due to disk caching") args.cache_text_encoder_outputs = True - train_dataset_group.verify_bucket_reso_steps(32) + train_dataset_group.verify_bucket_reso_steps(16) if val_dataset_group is not None: - val_dataset_group.verify_bucket_reso_steps(32) + val_dataset_group.verify_bucket_reso_steps(16) self.train_gemma2 = not args.network_train_unet_only @@ -134,13 +134,16 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): # When TE is not be trained, it will not be prepared so we need to use explicit autocast logger.info("move text encoders to gpu") - text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8 + # Lumina uses a single text encoder (Gemma2) at index 0. + # Check original dtype BEFORE casting to preserve fp8 detection. + gemma2_original_dtype = text_encoders[0].dtype + text_encoders[0].to(accelerator.device) - if text_encoders[0].dtype == torch.float8_e4m3fn: - # if we load fp8 weights, the model is already fp8, so we use it as is - self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype) + if gemma2_original_dtype == torch.float8_e4m3fn: + # Model was loaded as fp8 — apply fp8 optimization + self.prepare_text_encoder_fp8(0, text_encoders[0], gemma2_original_dtype, weight_dtype) else: - # otherwise, we need to convert it to target dtype + # Otherwise, cast to target dtype text_encoders[0].to(weight_dtype) with accelerator.autocast(): diff --git a/networks/lora_lumina.py b/networks/lora_lumina.py index 0929e839..8e672091 100644 --- a/networks/lora_lumina.py +++ b/networks/lora_lumina.py @@ -227,19 +227,16 @@ class LoRAInfModule(LoRAModule): org_sd["weight"] = weight.to(dtype) self.org_module.load_state_dict(org_sd) else: - # split_dims - total_dims = sum(self.split_dims) + # split_dims: merge each split's LoRA into the correct slice of the fused QKV weight for i in range(len(self.split_dims)): # get up/down weight down_weight = sd[f"lora_down.{i}.weight"].to(torch.float).to(device) # (rank, in_dim) - up_weight = sd[f"lora_up.{i}.weight"].to(torch.float).to(device) # (split dim, rank) + up_weight = sd[f"lora_up.{i}.weight"].to(torch.float).to(device) # (split_dim, rank) - # pad up_weight -> (total_dims, rank) - padded_up_weight = torch.zeros((total_dims, up_weight.size(0)), device=device, dtype=torch.float) - padded_up_weight[sum(self.split_dims[:i]) : sum(self.split_dims[: i + 1])] = up_weight - - # merge weight - weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale + # merge into the correct slice of the fused weight + start = sum(self.split_dims[:i]) + end = sum(self.split_dims[:i + 1]) + weight[start:end] += self.multiplier * (up_weight @ down_weight) * self.scale # set weight to org_module org_sd["weight"] = weight.to(dtype) @@ -250,6 +247,17 @@ class LoRAInfModule(LoRAModule): if multiplier is None: multiplier = self.multiplier + # Handle split_dims case where lora_down/lora_up are ModuleList + if self.split_dims is not None: + # Each sub-module produces a partial weight; concatenate along output dim + weights = [] + for lora_up, lora_down in zip(self.lora_up, self.lora_down): + up_w = lora_up.weight.to(torch.float) + down_w = lora_down.weight.to(torch.float) + weights.append(up_w @ down_w) + weight = self.multiplier * torch.cat(weights, dim=0) * self.scale + return weight + # get up/down weight from module up_weight = self.lora_up.weight.to(torch.float) down_weight = self.lora_down.weight.to(torch.float) @@ -409,7 +417,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, lumina, wei weights_sd = load_file(file) else: - weights_sd = torch.load(file, map_location="cpu") + weights_sd = torch.load(file, map_location="cpu", weights_only=False) # get dim/alpha mapping, and train t5xxl modules_dim = {} @@ -634,20 +642,30 @@ class LoRANetwork(torch.nn.Module): skipped_te += skipped # create LoRA for U-Net + target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE + # Filter by block type using name-based filtering in create_modules + # All block types use JointTransformerBlock, so we filter by module path name + block_filter = None # None means no filtering (train all) if self.train_blocks == "all": - target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE - # TODO: limit different blocks + block_filter = None elif self.train_blocks == "transformer": - target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE - elif self.train_blocks == "refiners": - target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE + block_filter = "layers_" # main transformer blocks: "lora_unet_layers_N_..." elif self.train_blocks == "noise_refiner": - target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE - elif self.train_blocks == "cap_refiner": - target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE + block_filter = "noise_refiner" + elif self.train_blocks == "context_refiner": + block_filter = "context_refiner" + elif self.train_blocks == "refiners": + block_filter = None # handled below with two calls self.unet_loras: List[Union[LoRAModule, LoRAInfModule]] - self.unet_loras, skipped_un = create_modules(True, unet, target_replace_modules) + if self.train_blocks == "refiners": + # Refiners = noise_refiner + context_refiner, need two calls + noise_loras, skipped_noise = create_modules(True, unet, target_replace_modules, filter="noise_refiner") + context_loras, skipped_context = create_modules(True, unet, target_replace_modules, filter="context_refiner") + self.unet_loras = noise_loras + context_loras + skipped_un = skipped_noise + skipped_context + else: + self.unet_loras, skipped_un = create_modules(True, unet, target_replace_modules, filter=block_filter) # Handle embedders if self.embedder_dims: @@ -689,7 +707,7 @@ class LoRANetwork(torch.nn.Module): weights_sd = load_file(file) else: - weights_sd = torch.load(file, map_location="cpu") + weights_sd = torch.load(file, map_location="cpu", weights_only=False) info = self.load_state_dict(weights_sd, False) return info @@ -751,10 +769,10 @@ class LoRANetwork(torch.nn.Module): state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) new_state_dict = {} for key in list(state_dict.keys()): - if "double" in key and "qkv" in key: - split_dims = [3072] * 3 - elif "single" in key and "linear1" in key: - split_dims = [3072] * 3 + [12288] + if "qkv" in key: + # Lumina 2B: dim=2304, n_heads=24, n_kv_heads=8, head_dim=96 + # Q=24*96=2304, K=8*96=768, V=8*96=768 + split_dims = [2304, 768, 768] else: new_state_dict[key] = state_dict[key] continue @@ -1035,4 +1053,4 @@ class LoRANetwork(torch.nn.Module): scalednorm = updown.norm() * ratio norms.append(scalednorm.item()) - return keys_scaled, sum(norms) / len(norms), max(norms) + return keys_scaled, sum(norms) / len(norms), max(norms) \ No newline at end of file