mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
Merge pull request #2262 from duongve13112002/fix_lumina
Fix bug and optimization for Lumina model
This commit is contained in:
@@ -34,18 +34,18 @@ from library import custom_offloading_utils
|
||||
try:
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||
except:
|
||||
except ImportError:
|
||||
# flash_attn may not be available but it is not required
|
||||
pass
|
||||
|
||||
try:
|
||||
from sageattention import sageattn
|
||||
except:
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from apex.normalization import FusedRMSNorm as RMSNorm
|
||||
except:
|
||||
except ImportError:
|
||||
import warnings
|
||||
|
||||
warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")
|
||||
@@ -98,7 +98,7 @@ except:
|
||||
x_dtype = x.dtype
|
||||
# To handle float8 we need to convert the tensor to float
|
||||
x = x.float()
|
||||
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
|
||||
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
|
||||
return ((x * rrms) * self.weight.float()).to(dtype=x_dtype)
|
||||
|
||||
|
||||
@@ -370,7 +370,7 @@ class JointAttention(nn.Module):
|
||||
if self.use_sage_attn:
|
||||
# Handle GQA (Grouped Query Attention) if needed
|
||||
n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||
if n_rep >= 1:
|
||||
if n_rep > 1:
|
||||
xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
||||
xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
||||
|
||||
@@ -379,7 +379,7 @@ class JointAttention(nn.Module):
|
||||
output = self.flash_attn(xq, xk, xv, x_mask, softmax_scale)
|
||||
else:
|
||||
n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||
if n_rep >= 1:
|
||||
if n_rep > 1:
|
||||
xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
||||
xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
||||
|
||||
@@ -456,51 +456,47 @@ class JointAttention(nn.Module):
|
||||
bsz = q.shape[0]
|
||||
seqlen = q.shape[1]
|
||||
|
||||
# Transpose tensors to match SageAttention's expected format (HND layout)
|
||||
q_transposed = q.permute(0, 2, 1, 3) # [batch, heads, seq_len, head_dim]
|
||||
k_transposed = k.permute(0, 2, 1, 3) # [batch, heads, seq_len, head_dim]
|
||||
v_transposed = v.permute(0, 2, 1, 3) # [batch, heads, seq_len, head_dim]
|
||||
# Transpose to SageAttention's expected HND layout: [batch, heads, seq_len, head_dim]
|
||||
q_transposed = q.permute(0, 2, 1, 3)
|
||||
k_transposed = k.permute(0, 2, 1, 3)
|
||||
v_transposed = v.permute(0, 2, 1, 3)
|
||||
|
||||
# Handle masking for SageAttention
|
||||
# We need to filter out masked positions - this approach handles variable sequence lengths
|
||||
# Fast path: if all tokens are valid, run batched SageAttention directly
|
||||
if x_mask.all():
|
||||
output = sageattn(
|
||||
q_transposed, k_transposed, v_transposed,
|
||||
tensor_layout="HND", is_causal=False, sm_scale=softmax_scale,
|
||||
)
|
||||
# output: [batch, heads, seq_len, head_dim] -> [batch, seq_len, heads, head_dim]
|
||||
output = output.permute(0, 2, 1, 3)
|
||||
else:
|
||||
# Slow path: per-batch loop to handle variable-length masking
|
||||
# SageAttention does not support attention masks natively
|
||||
outputs = []
|
||||
for b in range(bsz):
|
||||
# Find valid token positions from the mask
|
||||
valid_indices = torch.nonzero(x_mask[b], as_tuple=False).squeeze(-1)
|
||||
valid_indices = x_mask[b].nonzero(as_tuple=True)[0]
|
||||
if valid_indices.numel() == 0:
|
||||
# If all tokens are masked, create a zero output
|
||||
batch_output = torch.zeros(
|
||||
outputs.append(torch.zeros(
|
||||
seqlen, self.n_local_heads, self.head_dim,
|
||||
device=q.device, dtype=q.dtype
|
||||
)
|
||||
else:
|
||||
# Extract only valid tokens for this batch
|
||||
batch_q = q_transposed[b, :, valid_indices, :]
|
||||
batch_k = k_transposed[b, :, valid_indices, :]
|
||||
batch_v = v_transposed[b, :, valid_indices, :]
|
||||
device=q.device, dtype=q.dtype,
|
||||
))
|
||||
continue
|
||||
|
||||
# Run SageAttention on valid tokens only
|
||||
batch_output_valid = sageattn(
|
||||
batch_q.unsqueeze(0), # Add batch dimension back
|
||||
batch_k.unsqueeze(0),
|
||||
batch_v.unsqueeze(0),
|
||||
tensor_layout="HND",
|
||||
is_causal=False,
|
||||
sm_scale=softmax_scale
|
||||
q_transposed[b:b+1, :, valid_indices, :],
|
||||
k_transposed[b:b+1, :, valid_indices, :],
|
||||
v_transposed[b:b+1, :, valid_indices, :],
|
||||
tensor_layout="HND", is_causal=False, sm_scale=softmax_scale,
|
||||
)
|
||||
|
||||
# Create output tensor with zeros for masked positions
|
||||
batch_output = torch.zeros(
|
||||
seqlen, self.n_local_heads, self.head_dim,
|
||||
device=q.device, dtype=q.dtype
|
||||
device=q.device, dtype=q.dtype,
|
||||
)
|
||||
# Place valid outputs back in the right positions
|
||||
batch_output[valid_indices] = batch_output_valid.squeeze(0).permute(1, 0, 2)
|
||||
|
||||
outputs.append(batch_output)
|
||||
|
||||
# Stack batch outputs and reshape to expected format
|
||||
output = torch.stack(outputs, dim=0) # [batch, seq_len, heads, head_dim]
|
||||
output = torch.stack(outputs, dim=0)
|
||||
except NameError as e:
|
||||
raise RuntimeError(
|
||||
f"Could not load Sage Attention. Please install https://github.com/thu-ml/SageAttention. / Sage Attention を読み込めませんでした。https://github.com/thu-ml/SageAttention をインストールしてください。 / {e}"
|
||||
@@ -1113,10 +1109,9 @@ class NextDiT(nn.Module):
|
||||
|
||||
x = x.view(bsz, channels, height // pH, pH, width // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)
|
||||
|
||||
x_mask = torch.zeros(bsz, image_seq_len, dtype=torch.bool, device=device)
|
||||
for i in range(bsz):
|
||||
x[i, :image_seq_len] = x[i]
|
||||
x_mask[i, :image_seq_len] = True
|
||||
# x.shape[1] == image_seq_len after patchify, so this was assigning to itself.
|
||||
# The mask can be set without a loop since all samples have the same image_seq_len.
|
||||
x_mask = torch.ones(bsz, image_seq_len, dtype=torch.bool, device=device)
|
||||
|
||||
x = self.x_embedder(x)
|
||||
|
||||
|
||||
@@ -334,32 +334,35 @@ def sample_image_inference(
|
||||
|
||||
# No need to add system prompt here, as it has been handled in the tokenize_strategy
|
||||
|
||||
# Get sample prompts from cache
|
||||
# Get sample prompts from cache, fallback to live encoding
|
||||
gemma2_conds = None
|
||||
neg_gemma2_conds = None
|
||||
|
||||
if sample_prompts_gemma2_outputs and prompt in sample_prompts_gemma2_outputs:
|
||||
gemma2_conds = sample_prompts_gemma2_outputs[prompt]
|
||||
logger.info(f"Using cached Gemma2 outputs for prompt: {prompt}")
|
||||
|
||||
if (
|
||||
sample_prompts_gemma2_outputs
|
||||
and negative_prompt in sample_prompts_gemma2_outputs
|
||||
):
|
||||
if sample_prompts_gemma2_outputs and negative_prompt in sample_prompts_gemma2_outputs:
|
||||
neg_gemma2_conds = sample_prompts_gemma2_outputs[negative_prompt]
|
||||
logger.info(
|
||||
f"Using cached Gemma2 outputs for negative prompt: {negative_prompt}"
|
||||
)
|
||||
logger.info(f"Using cached Gemma2 outputs for negative prompt: {negative_prompt}")
|
||||
|
||||
# Load sample prompts from Gemma 2
|
||||
if gemma2_model is not None:
|
||||
# Only encode if not found in cache
|
||||
if gemma2_conds is None and gemma2_model is not None:
|
||||
tokens_and_masks = tokenize_strategy.tokenize(prompt)
|
||||
gemma2_conds = encoding_strategy.encode_tokens(
|
||||
tokenize_strategy, gemma2_model, tokens_and_masks
|
||||
)
|
||||
|
||||
if neg_gemma2_conds is None and gemma2_model is not None:
|
||||
tokens_and_masks = tokenize_strategy.tokenize(negative_prompt, is_negative=True)
|
||||
neg_gemma2_conds = encoding_strategy.encode_tokens(
|
||||
tokenize_strategy, gemma2_model, tokens_and_masks
|
||||
)
|
||||
|
||||
if gemma2_conds is None or neg_gemma2_conds is None:
|
||||
logger.error(f"Cannot generate sample: no cached outputs and no text encoder available for prompt: {prompt}")
|
||||
continue
|
||||
|
||||
# Unpack Gemma2 outputs
|
||||
gemma2_hidden_states, _, gemma2_attn_mask = gemma2_conds
|
||||
neg_gemma2_hidden_states, _, neg_gemma2_attn_mask = neg_gemma2_conds
|
||||
@@ -475,6 +478,7 @@ def sample_image_inference(
|
||||
|
||||
|
||||
def time_shift(mu: float, sigma: float, t: torch.Tensor):
|
||||
"""Apply time shifting to timesteps."""
|
||||
t = math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
||||
return t
|
||||
|
||||
@@ -483,7 +487,7 @@ def get_lin_function(
|
||||
x1: float = 256, x2: float = 4096, y1: float = 0.5, y2: float = 1.15
|
||||
) -> Callable[[float], float]:
|
||||
"""
|
||||
Get linear function
|
||||
Get linear function for resolution-dependent shifting.
|
||||
|
||||
Args:
|
||||
image_seq_len,
|
||||
@@ -528,6 +532,7 @@ def get_schedule(
|
||||
mu = get_lin_function(y1=base_shift, y2=max_shift, x1=256, x2=4096)(
|
||||
image_seq_len
|
||||
)
|
||||
timesteps = torch.clamp(timesteps, min=1e-7).to(timesteps.device)
|
||||
timesteps = time_shift(mu, 1.0, timesteps)
|
||||
|
||||
return timesteps.tolist()
|
||||
@@ -689,15 +694,15 @@ def denoise(
|
||||
|
||||
img_dtype = img.dtype
|
||||
|
||||
if img.dtype != img_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
img = img.to(img_dtype)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
noise_pred = -noise_pred
|
||||
img = scheduler.step(noise_pred, t, img, return_dict=False)[0]
|
||||
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
if img.dtype != img_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
img = img.to(img_dtype)
|
||||
|
||||
model.prepare_block_swap_before_forward()
|
||||
return img
|
||||
|
||||
@@ -823,6 +828,7 @@ def get_noisy_model_input_and_timesteps(
|
||||
timesteps = sigmas * num_timesteps
|
||||
elif args.timestep_sampling == "nextdit_shift":
|
||||
sigmas = torch.rand((bsz,), device=device)
|
||||
sigmas = torch.clamp(sigmas, min=1e-7).to(device)
|
||||
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
|
||||
sigmas = time_shift(mu, 1.0, sigmas)
|
||||
|
||||
@@ -831,6 +837,7 @@ def get_noisy_model_input_and_timesteps(
|
||||
sigmas = torch.randn(bsz, device=device)
|
||||
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
|
||||
sigmas = sigmas.sigmoid()
|
||||
sigmas = torch.clamp(sigmas, min=1e-7).to(device)
|
||||
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size
|
||||
sigmas = time_shift(mu, 1.0, sigmas)
|
||||
timesteps = sigmas * num_timesteps
|
||||
|
||||
@@ -370,19 +370,25 @@ def train(args):
|
||||
grouped_params = []
|
||||
param_group = {}
|
||||
for group in params_to_optimize:
|
||||
named_parameters = list(nextdit.named_parameters())
|
||||
named_parameters = [(n, p) for n, p in nextdit.named_parameters() if p.requires_grad]
|
||||
assert len(named_parameters) == len(
|
||||
group["params"]
|
||||
), "number of parameters does not match"
|
||||
), f"number of trainable parameters ({len(named_parameters)}) does not match optimizer group ({len(group['params'])})"
|
||||
for p, np in zip(group["params"], named_parameters):
|
||||
# determine target layer and block index for each parameter
|
||||
block_type = "other" # double, single or other
|
||||
if np[0].startswith("double_blocks"):
|
||||
# Lumina NextDiT architecture:
|
||||
# - "layers.{i}.*" : main transformer blocks (e.g. 32 blocks for 2B)
|
||||
# - "context_refiner.{i}.*" : context refiner blocks (2 blocks)
|
||||
# - "noise_refiner.{i}.*" : noise refiner blocks (2 blocks)
|
||||
# - others: t_embedder, cap_embedder, x_embedder, norm_final, final_layer
|
||||
block_type = "other"
|
||||
if np[0].startswith("layers."):
|
||||
block_index = int(np[0].split(".")[1])
|
||||
block_type = "double"
|
||||
elif np[0].startswith("single_blocks"):
|
||||
block_index = int(np[0].split(".")[1])
|
||||
block_type = "single"
|
||||
block_type = "main"
|
||||
elif np[0].startswith("context_refiner.") or np[0].startswith("noise_refiner."):
|
||||
# All refiner blocks (context + noise) grouped together
|
||||
block_index = -1
|
||||
block_type = "refiner"
|
||||
else:
|
||||
block_index = -1
|
||||
|
||||
@@ -759,7 +765,7 @@ def train(args):
|
||||
|
||||
# calculate loss
|
||||
huber_c = train_util.get_huber_threshold_if_needed(
|
||||
args, timesteps, noise_scheduler
|
||||
args, 1000 - timesteps, noise_scheduler
|
||||
)
|
||||
loss = train_util.conditional_loss(
|
||||
model_pred.float(), target.float(), args.loss_type, "none", huber_c
|
||||
|
||||
@@ -43,9 +43,9 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
||||
logger.warning("Enabling cache_text_encoder_outputs due to disk caching")
|
||||
args.cache_text_encoder_outputs = True
|
||||
|
||||
train_dataset_group.verify_bucket_reso_steps(32)
|
||||
train_dataset_group.verify_bucket_reso_steps(16)
|
||||
if val_dataset_group is not None:
|
||||
val_dataset_group.verify_bucket_reso_steps(32)
|
||||
val_dataset_group.verify_bucket_reso_steps(16)
|
||||
|
||||
self.train_gemma2 = not args.network_train_unet_only
|
||||
|
||||
@@ -134,13 +134,16 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
||||
|
||||
# When TE is not be trained, it will not be prepared so we need to use explicit autocast
|
||||
logger.info("move text encoders to gpu")
|
||||
text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8
|
||||
# Lumina uses a single text encoder (Gemma2) at index 0.
|
||||
# Check original dtype BEFORE casting to preserve fp8 detection.
|
||||
gemma2_original_dtype = text_encoders[0].dtype
|
||||
text_encoders[0].to(accelerator.device)
|
||||
|
||||
if text_encoders[0].dtype == torch.float8_e4m3fn:
|
||||
# if we load fp8 weights, the model is already fp8, so we use it as is
|
||||
self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype)
|
||||
if gemma2_original_dtype == torch.float8_e4m3fn:
|
||||
# Model was loaded as fp8 — apply fp8 optimization
|
||||
self.prepare_text_encoder_fp8(0, text_encoders[0], gemma2_original_dtype, weight_dtype)
|
||||
else:
|
||||
# otherwise, we need to convert it to target dtype
|
||||
# Otherwise, cast to target dtype
|
||||
text_encoders[0].to(weight_dtype)
|
||||
|
||||
with accelerator.autocast():
|
||||
|
||||
@@ -227,19 +227,16 @@ class LoRAInfModule(LoRAModule):
|
||||
org_sd["weight"] = weight.to(dtype)
|
||||
self.org_module.load_state_dict(org_sd)
|
||||
else:
|
||||
# split_dims
|
||||
total_dims = sum(self.split_dims)
|
||||
# split_dims: merge each split's LoRA into the correct slice of the fused QKV weight
|
||||
for i in range(len(self.split_dims)):
|
||||
# get up/down weight
|
||||
down_weight = sd[f"lora_down.{i}.weight"].to(torch.float).to(device) # (rank, in_dim)
|
||||
up_weight = sd[f"lora_up.{i}.weight"].to(torch.float).to(device) # (split dim, rank)
|
||||
up_weight = sd[f"lora_up.{i}.weight"].to(torch.float).to(device) # (split_dim, rank)
|
||||
|
||||
# pad up_weight -> (total_dims, rank)
|
||||
padded_up_weight = torch.zeros((total_dims, up_weight.size(0)), device=device, dtype=torch.float)
|
||||
padded_up_weight[sum(self.split_dims[:i]) : sum(self.split_dims[: i + 1])] = up_weight
|
||||
|
||||
# merge weight
|
||||
weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
|
||||
# merge into the correct slice of the fused weight
|
||||
start = sum(self.split_dims[:i])
|
||||
end = sum(self.split_dims[:i + 1])
|
||||
weight[start:end] += self.multiplier * (up_weight @ down_weight) * self.scale
|
||||
|
||||
# set weight to org_module
|
||||
org_sd["weight"] = weight.to(dtype)
|
||||
@@ -250,6 +247,17 @@ class LoRAInfModule(LoRAModule):
|
||||
if multiplier is None:
|
||||
multiplier = self.multiplier
|
||||
|
||||
# Handle split_dims case where lora_down/lora_up are ModuleList
|
||||
if self.split_dims is not None:
|
||||
# Each sub-module produces a partial weight; concatenate along output dim
|
||||
weights = []
|
||||
for lora_up, lora_down in zip(self.lora_up, self.lora_down):
|
||||
up_w = lora_up.weight.to(torch.float)
|
||||
down_w = lora_down.weight.to(torch.float)
|
||||
weights.append(up_w @ down_w)
|
||||
weight = self.multiplier * torch.cat(weights, dim=0) * self.scale
|
||||
return weight
|
||||
|
||||
# get up/down weight from module
|
||||
up_weight = self.lora_up.weight.to(torch.float)
|
||||
down_weight = self.lora_down.weight.to(torch.float)
|
||||
@@ -409,7 +417,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, lumina, wei
|
||||
|
||||
weights_sd = load_file(file)
|
||||
else:
|
||||
weights_sd = torch.load(file, map_location="cpu")
|
||||
weights_sd = torch.load(file, map_location="cpu", weights_only=False)
|
||||
|
||||
# get dim/alpha mapping, and train t5xxl
|
||||
modules_dim = {}
|
||||
@@ -634,20 +642,30 @@ class LoRANetwork(torch.nn.Module):
|
||||
skipped_te += skipped
|
||||
|
||||
# create LoRA for U-Net
|
||||
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
|
||||
# Filter by block type using name-based filtering in create_modules
|
||||
# All block types use JointTransformerBlock, so we filter by module path name
|
||||
block_filter = None # None means no filtering (train all)
|
||||
if self.train_blocks == "all":
|
||||
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
|
||||
# TODO: limit different blocks
|
||||
block_filter = None
|
||||
elif self.train_blocks == "transformer":
|
||||
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
|
||||
elif self.train_blocks == "refiners":
|
||||
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
|
||||
block_filter = "layers_" # main transformer blocks: "lora_unet_layers_N_..."
|
||||
elif self.train_blocks == "noise_refiner":
|
||||
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
|
||||
elif self.train_blocks == "cap_refiner":
|
||||
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
|
||||
block_filter = "noise_refiner"
|
||||
elif self.train_blocks == "context_refiner":
|
||||
block_filter = "context_refiner"
|
||||
elif self.train_blocks == "refiners":
|
||||
block_filter = None # handled below with two calls
|
||||
|
||||
self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
|
||||
self.unet_loras, skipped_un = create_modules(True, unet, target_replace_modules)
|
||||
if self.train_blocks == "refiners":
|
||||
# Refiners = noise_refiner + context_refiner, need two calls
|
||||
noise_loras, skipped_noise = create_modules(True, unet, target_replace_modules, filter="noise_refiner")
|
||||
context_loras, skipped_context = create_modules(True, unet, target_replace_modules, filter="context_refiner")
|
||||
self.unet_loras = noise_loras + context_loras
|
||||
skipped_un = skipped_noise + skipped_context
|
||||
else:
|
||||
self.unet_loras, skipped_un = create_modules(True, unet, target_replace_modules, filter=block_filter)
|
||||
|
||||
# Handle embedders
|
||||
if self.embedder_dims:
|
||||
@@ -689,7 +707,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
|
||||
weights_sd = load_file(file)
|
||||
else:
|
||||
weights_sd = torch.load(file, map_location="cpu")
|
||||
weights_sd = torch.load(file, map_location="cpu", weights_only=False)
|
||||
|
||||
info = self.load_state_dict(weights_sd, False)
|
||||
return info
|
||||
@@ -751,10 +769,10 @@ class LoRANetwork(torch.nn.Module):
|
||||
state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
|
||||
new_state_dict = {}
|
||||
for key in list(state_dict.keys()):
|
||||
if "double" in key and "qkv" in key:
|
||||
split_dims = [3072] * 3
|
||||
elif "single" in key and "linear1" in key:
|
||||
split_dims = [3072] * 3 + [12288]
|
||||
if "qkv" in key:
|
||||
# Lumina 2B: dim=2304, n_heads=24, n_kv_heads=8, head_dim=96
|
||||
# Q=24*96=2304, K=8*96=768, V=8*96=768
|
||||
split_dims = [2304, 768, 768]
|
||||
else:
|
||||
new_state_dict[key] = state_dict[key]
|
||||
continue
|
||||
|
||||
Reference in New Issue
Block a user