mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 21:52:27 +00:00
Merge pull request #12 from rockerBOO/lumina-model-loading
Lumina 2 and Gemma 2 model loading
This commit is contained in:
@@ -21,7 +21,8 @@ import torch.nn.functional as F
|
||||
|
||||
try:
|
||||
from apex.normalization import FusedRMSNorm as RMSNorm
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
import warnings
|
||||
warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")
|
||||
|
||||
memory_efficient_attention = None
|
||||
@@ -39,17 +40,20 @@ except:
|
||||
class LuminaParams:
|
||||
"""Parameters for Lumina model configuration"""
|
||||
patch_size: int = 2
|
||||
dim: int = 2592
|
||||
in_channels: int = 4
|
||||
dim: int = 4096
|
||||
n_layers: int = 30
|
||||
n_refiner_layers: int = 2
|
||||
n_heads: int = 24
|
||||
n_kv_heads: int = 8
|
||||
multiple_of: int = 256
|
||||
axes_dims: List[int] = None
|
||||
axes_lens: List[int] = None
|
||||
qk_norm: bool = False,
|
||||
ffn_dim_multiplier: Optional[float] = None,
|
||||
norm_eps: float = 1e-5,
|
||||
scaling_factor: float = 1.0,
|
||||
cap_feat_dim: int = 32,
|
||||
qk_norm: bool = False
|
||||
ffn_dim_multiplier: Optional[float] = None
|
||||
norm_eps: float = 1e-5
|
||||
scaling_factor: float = 1.0
|
||||
cap_feat_dim: int = 32
|
||||
|
||||
def __post_init__(self):
|
||||
if self.axes_dims is None:
|
||||
@@ -62,12 +66,15 @@ class LuminaParams:
|
||||
"""Returns the configuration for the 2B parameter model"""
|
||||
return cls(
|
||||
patch_size=2,
|
||||
dim=2592,
|
||||
n_layers=30,
|
||||
in_channels=16,
|
||||
dim=2304,
|
||||
n_layers=26,
|
||||
n_heads=24,
|
||||
n_kv_heads=8,
|
||||
axes_dims=[36, 36, 36],
|
||||
axes_lens=[300, 512, 512]
|
||||
axes_dims=[32, 32, 32],
|
||||
axes_lens=[300, 512, 512],
|
||||
qk_norm=True,
|
||||
cap_feat_dim=2304
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -696,8 +703,8 @@ class NextDiT(nn.Module):
|
||||
norm_eps: float = 1e-5,
|
||||
qk_norm: bool = False,
|
||||
cap_feat_dim: int = 5120,
|
||||
axes_dims: List[int] = (16, 56, 56),
|
||||
axes_lens: List[int] = (1, 512, 512),
|
||||
axes_dims: List[int] = [16, 56, 56],
|
||||
axes_lens: List[int] = [1, 512, 512],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
@@ -1090,6 +1097,7 @@ def NextDiT_2B_GQA_patch2_Adaln_Refiner(params: Optional[LuminaParams] = None, *
|
||||
|
||||
return NextDiT(
|
||||
patch_size=params.patch_size,
|
||||
in_channels=params.in_channels,
|
||||
dim=params.dim,
|
||||
n_layers=params.n_layers,
|
||||
n_heads=params.n_heads,
|
||||
@@ -1099,7 +1107,6 @@ def NextDiT_2B_GQA_patch2_Adaln_Refiner(params: Optional[LuminaParams] = None, *
|
||||
qk_norm=params.qk_norm,
|
||||
ffn_dim_multiplier=params.ffn_dim_multiplier,
|
||||
norm_eps=params.norm_eps,
|
||||
scaling_factor=params.scaling_factor,
|
||||
cap_feat_dim=params.cap_feat_dim,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -27,14 +27,14 @@ def load_lumina_model(
|
||||
dtype: torch.dtype,
|
||||
device: Union[str, torch.device],
|
||||
disable_mmap: bool = False,
|
||||
) -> lumina_models.Lumina:
|
||||
):
|
||||
logger.info("Building Lumina")
|
||||
with torch.device("meta"):
|
||||
model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner().to(dtype)
|
||||
|
||||
logger.info(f"Loading state dict from {ckpt_path}")
|
||||
state_dict = load_safetensors(
|
||||
ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype
|
||||
ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype
|
||||
)
|
||||
info = model.load_state_dict(state_dict, strict=False, assign=True)
|
||||
logger.info(f"Loaded Lumina: {info}")
|
||||
@@ -69,30 +69,39 @@ def load_gemma2(
|
||||
) -> Gemma2Model:
|
||||
logger.info("Building Gemma2")
|
||||
GEMMA2_CONFIG = {
|
||||
"_name_or_path": "google/gemma-2b",
|
||||
"attention_bias": false,
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 2,
|
||||
"eos_token_id": 1,
|
||||
"head_dim": 256,
|
||||
"hidden_act": "gelu",
|
||||
"hidden_size": 2048,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 16384,
|
||||
"max_position_embeddings": 8192,
|
||||
"model_type": "gemma",
|
||||
"num_attention_heads": 8,
|
||||
"num_hidden_layers": 18,
|
||||
"num_key_value_heads": 1,
|
||||
"pad_token_id": 0,
|
||||
"rms_norm_eps": 1e-06,
|
||||
"rope_scaling": null,
|
||||
"rope_theta": 10000.0,
|
||||
"torch_dtype": "bfloat16",
|
||||
"transformers_version": "4.38.0.dev0",
|
||||
"use_cache": true,
|
||||
"vocab_size": 256000
|
||||
"_name_or_path": "google/gemma-2-2b",
|
||||
"architectures": [
|
||||
"Gemma2Model"
|
||||
],
|
||||
"attention_bias": False,
|
||||
"attention_dropout": 0.0,
|
||||
"attn_logit_softcapping": 50.0,
|
||||
"bos_token_id": 2,
|
||||
"cache_implementation": "hybrid",
|
||||
"eos_token_id": 1,
|
||||
"final_logit_softcapping": 30.0,
|
||||
"head_dim": 256,
|
||||
"hidden_act": "gelu_pytorch_tanh",
|
||||
"hidden_activation": "gelu_pytorch_tanh",
|
||||
"hidden_size": 2304,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 9216,
|
||||
"max_position_embeddings": 8192,
|
||||
"model_type": "gemma2",
|
||||
"num_attention_heads": 8,
|
||||
"num_hidden_layers": 26,
|
||||
"num_key_value_heads": 4,
|
||||
"pad_token_id": 0,
|
||||
"query_pre_attn_scalar": 256,
|
||||
"rms_norm_eps": 1e-06,
|
||||
"rope_theta": 10000.0,
|
||||
"sliding_window": 4096,
|
||||
"torch_dtype": "float32",
|
||||
"transformers_version": "4.44.2",
|
||||
"use_cache": True,
|
||||
"vocab_size": 256000
|
||||
}
|
||||
|
||||
config = Gemma2Config(**GEMMA2_CONFIG)
|
||||
with init_empty_weights():
|
||||
gemma2 = Gemma2Model._from_config(config)
|
||||
@@ -104,6 +113,13 @@ def load_gemma2(
|
||||
sd = load_safetensors(
|
||||
ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype
|
||||
)
|
||||
|
||||
for key in list(sd.keys()):
|
||||
new_key = key.replace("model.", "")
|
||||
if new_key == key:
|
||||
break # the model doesn't have annoying prefix
|
||||
sd[new_key] = sd.pop(key)
|
||||
|
||||
info = gemma2.load_state_dict(sd, strict=False, assign=True)
|
||||
logger.info(f"Loaded Gemma2: {info}")
|
||||
return gemma2
|
||||
|
||||
@@ -9,7 +9,9 @@ from library.strategy_base import (
|
||||
LatentsCachingStrategy,
|
||||
TokenizeStrategy,
|
||||
TextEncodingStrategy,
|
||||
TextEncoderOutputsCachingStrategy
|
||||
)
|
||||
import numpy as np
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
|
||||
@@ -345,7 +345,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = train_network.setup_parser()
|
||||
train_util.add_dit_training_arguments(parser)
|
||||
lumina_train_utils.add_lumina_train_arguments(parser)
|
||||
lumina_train_util.add_lumina_train_arguments(parser)
|
||||
return parser
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user