Merge pull request #12 from rockerBOO/lumina-model-loading

Lumina 2 and Gemma 2 model loading
This commit is contained in:
青龍聖者@bdsqlsz
2025-02-17 00:47:08 +08:00
committed by GitHub
4 changed files with 65 additions and 40 deletions

View File

@@ -21,7 +21,8 @@ import torch.nn.functional as F
try:
from apex.normalization import FusedRMSNorm as RMSNorm
except ImportError:
except ModuleNotFoundError:
import warnings
warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")
memory_efficient_attention = None
@@ -39,17 +40,20 @@ except:
class LuminaParams:
"""Parameters for Lumina model configuration"""
patch_size: int = 2
dim: int = 2592
in_channels: int = 4
dim: int = 4096
n_layers: int = 30
n_refiner_layers: int = 2
n_heads: int = 24
n_kv_heads: int = 8
multiple_of: int = 256
axes_dims: List[int] = None
axes_lens: List[int] = None
qk_norm: bool = False,
ffn_dim_multiplier: Optional[float] = None,
norm_eps: float = 1e-5,
scaling_factor: float = 1.0,
cap_feat_dim: int = 32,
qk_norm: bool = False
ffn_dim_multiplier: Optional[float] = None
norm_eps: float = 1e-5
scaling_factor: float = 1.0
cap_feat_dim: int = 32
def __post_init__(self):
if self.axes_dims is None:
@@ -62,12 +66,15 @@ class LuminaParams:
"""Returns the configuration for the 2B parameter model"""
return cls(
patch_size=2,
dim=2592,
n_layers=30,
in_channels=16,
dim=2304,
n_layers=26,
n_heads=24,
n_kv_heads=8,
axes_dims=[36, 36, 36],
axes_lens=[300, 512, 512]
axes_dims=[32, 32, 32],
axes_lens=[300, 512, 512],
qk_norm=True,
cap_feat_dim=2304
)
@classmethod
@@ -696,8 +703,8 @@ class NextDiT(nn.Module):
norm_eps: float = 1e-5,
qk_norm: bool = False,
cap_feat_dim: int = 5120,
axes_dims: List[int] = (16, 56, 56),
axes_lens: List[int] = (1, 512, 512),
axes_dims: List[int] = [16, 56, 56],
axes_lens: List[int] = [1, 512, 512],
) -> None:
super().__init__()
self.in_channels = in_channels
@@ -1090,6 +1097,7 @@ def NextDiT_2B_GQA_patch2_Adaln_Refiner(params: Optional[LuminaParams] = None, *
return NextDiT(
patch_size=params.patch_size,
in_channels=params.in_channels,
dim=params.dim,
n_layers=params.n_layers,
n_heads=params.n_heads,
@@ -1099,7 +1107,6 @@ def NextDiT_2B_GQA_patch2_Adaln_Refiner(params: Optional[LuminaParams] = None, *
qk_norm=params.qk_norm,
ffn_dim_multiplier=params.ffn_dim_multiplier,
norm_eps=params.norm_eps,
scaling_factor=params.scaling_factor,
cap_feat_dim=params.cap_feat_dim,
**kwargs,
)

View File

@@ -27,14 +27,14 @@ def load_lumina_model(
dtype: torch.dtype,
device: Union[str, torch.device],
disable_mmap: bool = False,
) -> lumina_models.Lumina:
):
logger.info("Building Lumina")
with torch.device("meta"):
model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner().to(dtype)
logger.info(f"Loading state dict from {ckpt_path}")
state_dict = load_safetensors(
ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype
ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype
)
info = model.load_state_dict(state_dict, strict=False, assign=True)
logger.info(f"Loaded Lumina: {info}")
@@ -69,30 +69,39 @@ def load_gemma2(
) -> Gemma2Model:
logger.info("Building Gemma2")
GEMMA2_CONFIG = {
"_name_or_path": "google/gemma-2b",
"attention_bias": false,
"attention_dropout": 0.0,
"bos_token_id": 2,
"eos_token_id": 1,
"head_dim": 256,
"hidden_act": "gelu",
"hidden_size": 2048,
"initializer_range": 0.02,
"intermediate_size": 16384,
"max_position_embeddings": 8192,
"model_type": "gemma",
"num_attention_heads": 8,
"num_hidden_layers": 18,
"num_key_value_heads": 1,
"pad_token_id": 0,
"rms_norm_eps": 1e-06,
"rope_scaling": null,
"rope_theta": 10000.0,
"torch_dtype": "bfloat16",
"transformers_version": "4.38.0.dev0",
"use_cache": true,
"vocab_size": 256000
"_name_or_path": "google/gemma-2-2b",
"architectures": [
"Gemma2Model"
],
"attention_bias": False,
"attention_dropout": 0.0,
"attn_logit_softcapping": 50.0,
"bos_token_id": 2,
"cache_implementation": "hybrid",
"eos_token_id": 1,
"final_logit_softcapping": 30.0,
"head_dim": 256,
"hidden_act": "gelu_pytorch_tanh",
"hidden_activation": "gelu_pytorch_tanh",
"hidden_size": 2304,
"initializer_range": 0.02,
"intermediate_size": 9216,
"max_position_embeddings": 8192,
"model_type": "gemma2",
"num_attention_heads": 8,
"num_hidden_layers": 26,
"num_key_value_heads": 4,
"pad_token_id": 0,
"query_pre_attn_scalar": 256,
"rms_norm_eps": 1e-06,
"rope_theta": 10000.0,
"sliding_window": 4096,
"torch_dtype": "float32",
"transformers_version": "4.44.2",
"use_cache": True,
"vocab_size": 256000
}
config = Gemma2Config(**GEMMA2_CONFIG)
with init_empty_weights():
gemma2 = Gemma2Model._from_config(config)
@@ -104,6 +113,13 @@ def load_gemma2(
sd = load_safetensors(
ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype
)
for key in list(sd.keys()):
new_key = key.replace("model.", "")
if new_key == key:
break # the model doesn't have annoying prefix
sd[new_key] = sd.pop(key)
info = gemma2.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded Gemma2: {info}")
return gemma2

View File

@@ -9,7 +9,9 @@ from library.strategy_base import (
LatentsCachingStrategy,
TokenizeStrategy,
TextEncodingStrategy,
TextEncoderOutputsCachingStrategy
)
import numpy as np
from library.utils import setup_logging
setup_logging()

View File

@@ -345,7 +345,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
def setup_parser() -> argparse.ArgumentParser:
parser = train_network.setup_parser()
train_util.add_dit_training_arguments(parser)
lumina_train_utils.add_lumina_train_arguments(parser)
lumina_train_util.add_lumina_train_arguments(parser)
return parser