# Mirror of https://github.com/kohya-ss/sd-scripts.git
# Synced 2026-04-06 13:47:06 +00:00
# 387 lines, 16 KiB, Python
import argparse
|
|
import copy
|
|
from typing import Any, Tuple
|
|
|
|
import torch
|
|
|
|
from library.device_utils import clean_memory_on_device, init_ipex
|
|
|
|
init_ipex()
|
|
|
|
from torch import Tensor
|
|
from accelerate import Accelerator
|
|
|
|
|
|
import train_network
|
|
from library import (
|
|
lumina_models,
|
|
lumina_util,
|
|
lumina_train_util,
|
|
sd3_train_utils,
|
|
strategy_base,
|
|
strategy_lumina,
|
|
train_util,
|
|
)
|
|
from library.utils import setup_logging
|
|
|
|
setup_logging()
|
|
import logging
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
|
def __init__(self):
    """Initialise the base network trainer, then the Lumina-specific state."""
    super().__init__()

    # Set to True by load_target_model when block swapping is requested.
    self.is_swapping_blocks: bool = False
    # Filled by cache_text_encoder_outputs_if_needed when sample prompts are cached.
    self.sample_prompts_te_outputs = None
|
|
|
|
def assert_extra_args(self, args, train_dataset_group, val_dataset_group):
    """Validate and normalise Lumina-specific arguments after the base checks.

    Args:
        args: Parsed training arguments; ``cache_text_encoder_outputs`` may be
            switched on in place.
        train_dataset_group: Training dataset group (always checked).
        val_dataset_group: Validation dataset group, or ``None``.
    """
    super().assert_extra_args(args, train_dataset_group, val_dataset_group)

    # Caching to disk only makes sense when caching is enabled at all.
    disk_cache_without_cache = args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs
    if disk_cache_without_cache:
        logger.warning("Enabling cache_text_encoder_outputs due to disk caching")
        args.cache_text_encoder_outputs = True

    # Both dataset groups are asked to verify their bucket resolution steps against 16.
    reso_steps = 16
    train_dataset_group.verify_bucket_reso_steps(reso_steps)
    if val_dataset_group is not None:
        val_dataset_group.verify_bucket_reso_steps(reso_steps)

    # Gemma2 is trained unless the run is restricted to the U-Net/DiT side.
    self.train_gemma2 = not args.network_train_unet_only
|
|
|
|
def load_target_model(self, args, weight_dtype, accelerator):
    """Load the Lumina 2 DiT, the Gemma2 text encoder and the autoencoder.

    All three models are loaded onto the CPU. With ``args.fp8_base`` the DiT
    is loaded without an explicit dtype and ends up as ``float8_e4m3fn``;
    with ``args.blocks_to_swap`` block swapping is enabled on the DiT.

    Returns:
        ``(model_version, [gemma2], ae, model)``.

    Raises:
        ValueError: if ``fp8_base`` is set and the loaded DiT uses an fp8
            dtype other than ``float8_e4m3fn``.
    """
    # No dtype is forced when fp8_base is requested; the dtype is inspected below instead.
    if args.fp8_base:
        loading_dtype = None
    else:
        loading_dtype = weight_dtype

    cpu = torch.device("cpu")
    model = lumina_util.load_lumina_model(
        args.pretrained_model_name_or_path,
        loading_dtype,
        cpu,
        disable_mmap=args.disable_mmap_load_safetensors,
        use_flash_attn=args.use_flash_attn,
        use_sage_attn=args.use_sage_attn,
    )

    if args.fp8_base:
        # Only float8_e4m3fn is accepted; the other fp8 variants are rejected outright.
        rejected_fp8_dtypes = (torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz)
        if model.dtype in rejected_fp8_dtypes:
            raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}")

        if model.dtype == torch.float8_e4m3fn:
            logger.info("Loaded fp8 Lumina 2 model")
        else:
            logger.info(
                "Cast Lumina 2 model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint."
                " / Lumina 2モデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。"
            )
            model.to(torch.float8_e4m3fn)

    if args.blocks_to_swap:
        logger.info(f"Lumina 2: Enabling block swap: {args.blocks_to_swap}")
        model.enable_block_swap(args.blocks_to_swap, accelerator.device)
        # Remembered so later hooks take the block-swap code path.
        self.is_swapping_blocks = True

    gemma2 = lumina_util.load_gemma2(args.gemma2, weight_dtype, "cpu")
    gemma2.eval()

    ae = lumina_util.load_ae(args.ae, weight_dtype, "cpu")

    text_encoders = [gemma2]
    return lumina_util.MODEL_VERSION_LUMINA_V2, text_encoders, ae, model
|
|
|
|
def get_tokenize_strategy(self, args):
    """Build the Lumina tokenize strategy from the system prompt, max token length and tokenizer cache dir."""
    strategy = strategy_lumina.LuminaTokenizeStrategy(
        args.system_prompt,
        args.gemma2_max_token_length,
        args.tokenizer_cache_dir,
    )
    return strategy
|
|
|
|
def get_tokenizers(self, tokenize_strategy: strategy_lumina.LuminaTokenizeStrategy):
    """Return the strategy's single tokenizer wrapped in a one-element list."""
    tokenizer = tokenize_strategy.tokenizer
    return [tokenizer]
|
|
|
|
def get_latents_caching_strategy(self, args):
    """Build the Lumina latents caching strategy from the disk-cache flag and the VAE batch size."""
    cache_to_disk = args.cache_latents_to_disk
    vae_batch_size = args.vae_batch_size
    # The third positional argument is always False here.
    # NOTE(review): its meaning is defined by LuminaLatentsCachingStrategy; confirm there.
    return strategy_lumina.LuminaLatentsCachingStrategy(cache_to_disk, vae_batch_size, False)
|
|
|
|
def get_text_encoding_strategy(self, args):
    """Return a fresh Lumina text encoding strategy; ``args`` is not consulted."""
    strategy = strategy_lumina.LuminaTextEncodingStrategy()
    return strategy
|
|
|
|
def get_text_encoders_train_flags(self, args, text_encoders):
    """Return one train flag per text encoder.

    Lumina has a single text encoder, so the result is a one-element list
    holding ``self.train_gemma2`` (assigned in ``assert_extra_args``).
    """
    gemma2_flag = self.train_gemma2
    return [gemma2_flag]
|
|
|
|
def get_text_encoder_outputs_caching_strategy(self, args):
    """Return the text-encoder-output caching strategy, or ``None`` when caching is off."""
    if not args.cache_text_encoder_outputs:
        return None

    # When the text encoder is trained, tokenization is still needed, so the cache is partial.
    return strategy_lumina.LuminaTextEncoderOutputsCachingStrategy(
        args.cache_text_encoder_outputs_to_disk,
        args.text_encoder_batch_size,
        args.skip_cache_check,
        is_partial=self.train_gemma2,
    )
|
|
|
|
def cache_text_encoder_outputs_if_needed(
    self,
    args,
    accelerator: Accelerator,
    unet,
    vae,
    text_encoders,
    dataset,
    weight_dtype,
):
    """Cache Gemma2 outputs ahead of training, or keep Gemma2 on the training device.

    Without ``args.cache_text_encoder_outputs`` the text encoder is simply
    moved to the accelerator device in ``weight_dtype``. Otherwise:

    1. unless ``args.lowram``, the VAE and U-Net are parked on the CPU;
    2. Gemma2 is moved to the accelerator device and prepared (fp8 path or
       cast to ``weight_dtype``);
    3. the dataset caches its text encoder outputs;
    4. outputs for the sample prompts are stored on
       ``self.sample_prompts_te_outputs`` (keyed by prompt text);
    5. Gemma2 goes back to the CPU when it is not trained, and the VAE and
       U-Net return to their original devices.
    """
    if not args.cache_text_encoder_outputs:
        # Outputs are fetched from the text encoder on every step, so keep it on the GPU.
        text_encoders[0].to(accelerator.device, dtype=weight_dtype)
        return

    if not args.lowram:
        # Reduce memory consumption while the text encoder runs.
        logger.info("move vae and unet to cpu to save memory")
        vae_home_device, unet_home_device = vae.device, unet.device
        vae.to("cpu")
        unet.to("cpu")
        clean_memory_on_device(accelerator.device)

    # When the text encoder is not trained it is never prepared by the
    # accelerator, so explicit autocast is used below.
    logger.info("move text encoders to gpu")
    # Lumina uses a single text encoder (Gemma2) at index 0.
    gemma2 = text_encoders[0]
    # Read the dtype BEFORE any cast so an fp8-loaded model is still detectable.
    loaded_dtype = gemma2.dtype
    gemma2.to(accelerator.device)

    if loaded_dtype == torch.float8_e4m3fn:
        # The model was loaded as fp8: apply the fp8 preparation.
        self.prepare_text_encoder_fp8(0, gemma2, loaded_dtype, weight_dtype)
    else:
        # Otherwise cast to the target dtype.
        gemma2.to(weight_dtype)

    with accelerator.autocast():
        dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)

    # Cache the sample prompts as well.
    if args.sample_prompts is not None:
        logger.info(f"cache Text Encoder outputs for sample prompts: {args.sample_prompts}")

        tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
        text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()

        assert isinstance(tokenize_strategy, strategy_lumina.LuminaTokenizeStrategy)
        assert isinstance(text_encoding_strategy, strategy_lumina.LuminaTextEncodingStrategy)

        prompt_dicts = train_util.load_prompts(args.sample_prompts)
        cached_outputs = {}  # key: prompt, value: text encoder outputs
        with accelerator.autocast(), torch.no_grad():
            for prompt_dict in prompt_dicts:
                positive = prompt_dict.get("prompt", "")
                negative = prompt_dict.get("negative_prompt", "")
                # NOTE(review): the cache is keyed by text only, so a string first seen as a
                # positive prompt is reused if it later appears as a negative prompt (and vice
                # versa), although tokenize() receives a different flag. Confirm this is intended.
                for is_negative, prompt in ((False, positive), (True, negative)):
                    if prompt in cached_outputs:
                        continue

                    logger.info(f"cache Text Encoder outputs for prompt: {prompt}")
                    tokens_and_masks = tokenize_strategy.tokenize(prompt, is_negative)
                    cached_outputs[prompt] = text_encoding_strategy.encode_tokens(
                        tokenize_strategy,
                        text_encoders,
                        tokens_and_masks,
                    )

        self.sample_prompts_te_outputs = cached_outputs

    accelerator.wait_for_everyone()

    # Move Gemma2 back to the CPU when it will not be trained.
    if not self.is_train_text_encoder(args):
        logger.info("move Gemma 2 back to cpu")
        text_encoders[0].to("cpu")
        clean_memory_on_device(accelerator.device)

    if not args.lowram:
        logger.info("move vae and unet back to original device")
        vae.to(vae_home_device)
        unet.to(unet_home_device)
|
|
|
|
def sample_images(
    self,
    accelerator,
    args,
    epoch,
    global_step,
    device,
    vae,
    tokenizer,
    text_encoder,
    lumina,
):
    """Delegate sample image generation to ``lumina_train_util.sample_images``.

    ``device`` and ``tokenizer`` are accepted but not forwarded. The cached
    sample-prompt outputs stored on the trainer are passed along.
    """
    text_encoding_models = self.get_models_for_text_encoding(args, accelerator, text_encoder)
    cached_prompt_outputs = self.sample_prompts_te_outputs

    lumina_train_util.sample_images(
        accelerator,
        args,
        epoch,
        global_step,
        lumina,
        vae,
        text_encoding_models,
        cached_prompt_outputs,
    )
|
|
|
|
# Remaining methods maintain similar structure to flux implementation
|
|
# with Lumina-specific model calls and strategies
|
|
|
|
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
    """Create the flow-match Euler scheduler and keep a deep copy on the trainer.

    The scheduler uses 1000 training timesteps and ``args.discrete_flow_shift``
    as its shift. ``device`` is not used here.
    """
    scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(
        num_train_timesteps=1000,
        shift=args.discrete_flow_shift,
    )
    # An independent copy is stored on the trainer.
    self.noise_scheduler_copy = copy.deepcopy(scheduler)
    return scheduler
|
|
|
|
def encode_images_to_latents(self, args, vae, images):
    """Encode ``images`` with ``vae.encode`` and return the result unchanged."""
    latents = vae.encode(images)
    return latents
|
|
|
|
# not sure, they use same flux vae
|
|
def shift_scale_latents(self, args, latents):
    """Return ``latents`` as-is; no shift or scale is applied for Lumina."""
    return latents
|
|
|
|
def get_noise_pred_and_target(
    self,
    args,
    accelerator: Accelerator,
    noise_scheduler,
    latents,
    batch,
    text_encoder_conds: Tuple[Tensor, Tensor, Tensor],  # (hidden_states, input_ids, attention_masks)
    dit: lumina_models.NextDiT,
    network,
    weight_dtype,
    train_unet,
    is_train=True,
):
    """Run the DiT on noised latents and build the flow-matching target.

    Args:
        args: Training arguments (``gradient_checkpointing`` is read here; the
            rest is forwarded to the ``lumina_train_util`` helpers).
        accelerator: Provides the device and the autocast context.
        noise_scheduler: Must be a ``FlowMatchEulerDiscreteScheduler``.
        latents: Clean latents for the batch.
        batch: Batch dict; an optional ``"custom_attributes"`` entry selects
            samples for differential output preservation.
        text_encoder_conds: ``(hidden_states, input_ids, attention_masks)``
            from Gemma2; ``input_ids`` is not used here.
        dit: The NextDiT model.
        network: The trained network; its multiplier is set to 0.0 while the
            preservation prior is computed.
        weight_dtype: Forwarded to the noisy-input helper.
        train_unet: Unused here.
        is_train: Whether autograd is enabled for the main forward pass.

    Returns:
        ``(model_pred, target, timesteps, weighting)``.
    """
    assert isinstance(noise_scheduler, sd3_train_utils.FlowMatchEulerDiscreteScheduler)
    noise = torch.randn_like(latents)

    # Get the noisy model input and the timesteps.
    noisy_model_input, timesteps, sigmas = lumina_train_util.get_noisy_model_input_and_timesteps(
        args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
    )

    # Ensure the hidden states require grad.
    if args.gradient_checkpointing:
        noisy_model_input.requires_grad_(True)
        for t in text_encoder_conds:
            if t is not None and t.dtype.is_floating_point:
                t.requires_grad_(True)

    # Unpack the Gemma2 outputs (input_ids are not needed by the DiT call).
    gemma2_hidden_states, _input_ids, gemma2_attn_mask = text_encoder_conds

    def call_dit(img, gemma2_hidden_states, gemma2_attn_mask, timesteps, enable_grad=is_train):
        # The grad mode is set explicitly here. A surrounding torch.no_grad() would be
        # overridden by this context manager, so callers that need a grad-free forward
        # must pass enable_grad=False instead.
        with torch.set_grad_enabled(enable_grad), accelerator.autocast():
            # NextDiT forward expects (x, t, cap_feats, cap_mask).
            return dit(
                x=img,  # image latents (B, C, H, W)
                t=1 - timesteps / 1000,  # timesteps are divided by 1000 to match what the model expects
                cap_feats=gemma2_hidden_states,  # Gemma2 hidden states as caption features
                cap_mask=gemma2_attn_mask.to(dtype=torch.int32),  # Gemma2 attention mask
            )

    model_pred = call_dit(
        img=noisy_model_input,
        gemma2_hidden_states=gemma2_hidden_states,
        gemma2_attn_mask=gemma2_attn_mask,
        timesteps=timesteps,
    )

    # Apply the model prediction type.
    model_pred, weighting = lumina_train_util.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)

    # Flow matching loss target.
    target = latents - noise

    # Differential output preservation: for flagged samples the target becomes the
    # prediction obtained with the network multiplier set to 0.0.
    if "custom_attributes" in batch:
        diff_output_pr_indices = [
            i
            for i, custom_attributes in enumerate(batch["custom_attributes"])
            if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]
        ]

        if diff_output_pr_indices:
            network.set_multiplier(0.0)
            try:
                # The prior is a fixed target, so it is computed without autograd.
                model_pred_prior = call_dit(
                    img=noisy_model_input[diff_output_pr_indices],
                    gemma2_hidden_states=gemma2_hidden_states[diff_output_pr_indices],
                    timesteps=timesteps[diff_output_pr_indices],
                    gemma2_attn_mask=gemma2_attn_mask[diff_output_pr_indices],
                    enable_grad=False,
                )
            finally:
                # Restore the multiplier even if the forward pass raised.
                network.set_multiplier(1.0)

            model_pred_prior, _ = lumina_train_util.apply_model_prediction_type(
                args,
                model_pred_prior,
                noisy_model_input[diff_output_pr_indices],
                sigmas[diff_output_pr_indices] if sigmas is not None else None,
            )
            target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)

    return model_pred, target, timesteps, weighting
|
|
|
|
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
    """Return ``loss`` untouched; no post-processing is applied for Lumina."""
    return loss
|
|
|
|
def get_sai_model_spec(self, args):
    """Build the SAI model spec metadata for a Lumina 2 network."""
    # NOTE(review): the positional flags are defined by train_util.get_sai_model_spec;
    # confirm their meaning against that signature.
    model_spec = train_util.get_sai_model_spec(None, args, False, True, False, lumina="lumina2")
    return model_spec
|
|
|
|
def update_metadata(self, metadata, args):
    """Record the Lumina-specific training options in ``metadata`` (in place)."""
    # Metadata key -> name of the argument it is copied from, in insertion order.
    recorded_options = {
        "ss_weighting_scheme": "weighting_scheme",
        "ss_logit_mean": "logit_mean",
        "ss_logit_std": "logit_std",
        "ss_mode_scale": "mode_scale",
        "ss_timestep_sampling": "timestep_sampling",
        "ss_sigmoid_scale": "sigmoid_scale",
        "ss_model_prediction_type": "model_prediction_type",
        "ss_discrete_flow_shift": "discrete_flow_shift",
    }
    for key, option in recorded_options.items():
        metadata[key] = getattr(args, option)
|
|
|
|
def is_text_encoder_not_needed_for_training(self, args):
    """Tell whether training can run without the text encoder.

    That is the case when its outputs are cached and it is not being trained.
    A falsy ``args.cache_text_encoder_outputs`` is returned as-is.
    """
    outputs_cached = args.cache_text_encoder_outputs
    if not outputs_cached:
        return outputs_cached
    return not self.is_train_text_encoder(args)
|
|
|
|
def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
    """Turn on ``requires_grad`` for the text encoder's token embeddings.

    ``index`` is accepted for interface compatibility and is not used.
    """
    token_embeddings = text_encoder.embed_tokens
    token_embeddings.requires_grad_(True)
|
|
|
|
def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
    """Cast Gemma2 to its fp8 weight dtype while keeping the embeddings in ``weight_dtype``.

    ``index`` is accepted for interface compatibility and is not used.
    """
    logger.info(f"prepare Gemma2 for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}")

    # Whole model first (fp8), then the token embeddings back to the working dtype.
    text_encoder.to(te_weight_dtype)
    token_embeddings = text_encoder.embed_tokens
    token_embeddings.to(dtype=weight_dtype)
|
|
|
|
def prepare_unet_with_accelerator(
    self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
) -> torch.nn.Module:
    """Prepare the DiT with the accelerator, honouring block swapping.

    Without block swapping the base-class implementation is used. With block
    swapping the model is prepared without automatic device placement, then
    everything except the swapped blocks is moved to the accelerator device.
    """
    if self.is_swapping_blocks:
        assert isinstance(unet, lumina_models.NextDiT)

        # Device placement is disabled: the swapped blocks must not be moved wholesale.
        nextdit = accelerator.prepare(unet, device_placement=[False])

        unwrapped = accelerator.unwrap_model(nextdit)
        unwrapped.move_to_device_except_swap_blocks(accelerator.device)  # reduce peak memory usage
        unwrapped.prepare_block_swap_before_forward()
        return nextdit

    # No block swapping: the default preparation may move the whole model to the device.
    return super().prepare_unet_with_accelerator(args, accelerator, unet)
|
|
|
|
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
    """Re-arm block swapping after a validation step; does nothing without block swapping."""
    if not self.is_swapping_blocks:
        return

    # No backward pass runs during validation, so the next forward is prepared here.
    unwrapped_unet = accelerator.unwrap_model(unet)
    unwrapped_unet.prepare_block_swap_before_forward()
|
|
|
|
|
|
def setup_parser() -> argparse.ArgumentParser:
    """Return the network-training argument parser extended with DiT and Lumina options."""
    parser = train_network.setup_parser()

    # Each helper registers its own group of options on the shared parser.
    train_util.add_dit_training_arguments(parser)
    lumina_train_util.add_lumina_train_arguments(parser)

    return parser
|
|
|
|
|
|
if __name__ == "__main__":
    # Parse the command line, validate it, then merge values from the config file.
    parser = setup_parser()
    args = parser.parse_args()
    train_util.verify_command_line_training_args(args)
    args = train_util.read_config_from_file(args, parser)

    # Run the Lumina network training with the final arguments.
    trainer = LuminaNetworkTrainer()
    trainer.train(args)
|