mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 16:39:42 +00:00
add stage c tmp training code
This commit is contained in:
@@ -6,8 +6,10 @@ import os
|
||||
from typing import List, Optional, Tuple, Union
|
||||
import safetensors
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
r"""
|
||||
@@ -55,11 +57,13 @@ ARCH_SD_V1 = "stable-diffusion-v1"
|
||||
ARCH_SD_V2_512 = "stable-diffusion-v2-512"
|
||||
ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v"
|
||||
ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base"
|
||||
ARCH_STABLE_CASCADE = "stable-cascade"
|
||||
|
||||
ADAPTER_LORA = "lora"
|
||||
ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
|
||||
|
||||
IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models"
|
||||
IMPL_STABILITY_AI_STABLE_CASCADE = "https://github.com/Stability-AI/StableCascade"
|
||||
IMPL_DIFFUSERS = "diffusers"
|
||||
|
||||
PRED_TYPE_EPSILON = "epsilon"
|
||||
@@ -113,6 +117,7 @@ def build_metadata(
|
||||
merged_from: Optional[str] = None,
|
||||
timesteps: Optional[Tuple[int, int]] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
stable_cascade: Optional[bool] = None,
|
||||
):
|
||||
# if state_dict is None, hash is not calculated
|
||||
|
||||
@@ -124,7 +129,9 @@ def build_metadata(
|
||||
# hash = precalculate_safetensors_hashes(state_dict)
|
||||
# metadata["modelspec.hash_sha256"] = hash
|
||||
|
||||
if sdxl:
|
||||
if stable_cascade:
|
||||
arch = ARCH_STABLE_CASCADE
|
||||
elif sdxl:
|
||||
arch = ARCH_SD_XL_V1_BASE
|
||||
elif v2:
|
||||
if v_parameterization:
|
||||
@@ -142,9 +149,11 @@ def build_metadata(
|
||||
metadata["modelspec.architecture"] = arch
|
||||
|
||||
if not lora and not textual_inversion and is_stable_diffusion_ckpt is None:
|
||||
is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion
|
||||
is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion
|
||||
|
||||
if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
|
||||
if stable_cascade:
|
||||
impl = IMPL_STABILITY_AI_STABLE_CASCADE
|
||||
elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
|
||||
# Stable Diffusion ckpt, TI, SDXL LoRA
|
||||
impl = IMPL_STABILITY_AI
|
||||
else:
|
||||
@@ -236,7 +245,7 @@ def build_metadata(
|
||||
# assert all([v is not None for v in metadata.values()]), metadata
|
||||
if not all([v is not None for v in metadata.values()]):
|
||||
logger.error(f"Internal error: some metadata values are None: {metadata}")
|
||||
|
||||
|
||||
return metadata
|
||||
|
||||
|
||||
@@ -250,7 +259,7 @@ def get_title(metadata: dict) -> Optional[str]:
|
||||
def load_metadata_from_safetensors(model: str) -> dict:
|
||||
if not model.endswith(".safetensors"):
|
||||
return {}
|
||||
|
||||
|
||||
with safetensors.safe_open(model, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
if metadata is None:
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
# https://github.com/Stability-AI/StableCascade
|
||||
|
||||
import math
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -901,15 +901,19 @@ class StageC(nn.Module):
|
||||
self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)
|
||||
|
||||
|
||||
def get_clip_conditions(captions: List[str], tokenizer, text_model):
|
||||
def get_clip_conditions(captions: Optional[List[str]], input_ids, tokenizer, text_model):
|
||||
# self, batch: dict, tokenizer, text_model, is_eval=False, is_unconditional=False, eval_image_embeds=False, return_fields=None
|
||||
# is_eval の処理をここでやるのは微妙なので別のところでやる
|
||||
# is_unconditional もここでやるのは微妙なので別のところでやる
|
||||
# clip_image はとりあえずサポートしない
|
||||
clip_tokens_unpooled = tokenizer(
|
||||
captions, truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
|
||||
).to(text_model.device)
|
||||
text_encoder_output = text_model(**clip_tokens_unpooled, output_hidden_states=True)
|
||||
if captions is not None:
|
||||
clip_tokens_unpooled = tokenizer(
|
||||
captions, truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
|
||||
).to(text_model.device)
|
||||
text_encoder_output = text_model(**clip_tokens_unpooled, output_hidden_states=True)
|
||||
else:
|
||||
text_encoder_output = text_model(input_ids, output_hidden_states=True)
|
||||
|
||||
text_embeddings = text_encoder_output.hidden_states[-1]
|
||||
text_pooled_embeddings = text_encoder_output.text_embeds.unsqueeze(1)
|
||||
|
||||
@@ -1262,4 +1266,108 @@ class CosineTNoiseCond(BaseNoiseCond):
|
||||
return t
|
||||
|
||||
|
||||
# --- Loss Weighting
|
||||
class BaseLossWeight:
    """Base class for logSNR-dependent diffusion loss weighting schemes."""

    def weight(self, logSNR):
        """Return the raw per-element weight for *logSNR*; subclasses must override."""
        raise NotImplementedError("this method needs to be overridden")

    def __call__(self, logSNR, *args, shift=1, clamp_range=None, **kwargs):
        """Evaluate the weight, optionally shifting the schedule and clamping the result.

        A schedule shift of `shift` moves logSNR by 2*log(shift); the default clamp
        range is effectively unbounded.
        """
        if clamp_range is None:
            clamp_range = [-1e9, 1e9]
        if shift != 1:
            logSNR = logSNR.clone() + 2 * np.log(shift)
        raw = self.weight(logSNR, *args, **kwargs)
        return raw.clamp(*clamp_range)
|
||||
|
||||
|
||||
# class ComposedLossWeight(BaseLossWeight):
|
||||
# def __init__(self, div, mul):
|
||||
# self.mul = [mul] if isinstance(mul, BaseLossWeight) else mul
|
||||
# self.div = [div] if isinstance(div, BaseLossWeight) else div
|
||||
|
||||
# def weight(self, logSNR):
|
||||
# prod, div = 1, 1
|
||||
# for m in self.mul:
|
||||
# prod *= m.weight(logSNR)
|
||||
# for d in self.div:
|
||||
# div *= d.weight(logSNR)
|
||||
# return prod/div
|
||||
|
||||
# class ConstantLossWeight(BaseLossWeight):
|
||||
# def __init__(self, v=1):
|
||||
# self.v = v
|
||||
|
||||
# def weight(self, logSNR):
|
||||
# return torch.ones_like(logSNR) * self.v
|
||||
|
||||
# class SNRLossWeight(BaseLossWeight):
|
||||
# def weight(self, logSNR):
|
||||
# return logSNR.exp()
|
||||
|
||||
|
||||
class P2LossWeight(BaseLossWeight):
    """Perception-prioritized (P2) weighting: (k + exp(logSNR * s)) ** -gamma."""

    def __init__(self, k=1.0, gamma=1.0, s=1.0):
        self.k = k
        self.gamma = gamma
        self.s = s

    def weight(self, logSNR):
        # exp(logSNR * s) is the (scaled) SNR; higher SNR gets a smaller weight.
        snr = (logSNR * self.s).exp()
        return (self.k + snr) ** -self.gamma
|
||||
|
||||
|
||||
# class SNRPlusOneLossWeight(BaseLossWeight):
|
||||
# def weight(self, logSNR):
|
||||
# return logSNR.exp() + 1
|
||||
|
||||
# class MinSNRLossWeight(BaseLossWeight):
|
||||
# def __init__(self, max_snr=5):
|
||||
# self.max_snr = max_snr
|
||||
|
||||
# def weight(self, logSNR):
|
||||
# return logSNR.exp().clamp(max=self.max_snr)
|
||||
|
||||
# class MinSNRPlusOneLossWeight(BaseLossWeight):
|
||||
# def __init__(self, max_snr=5):
|
||||
# self.max_snr = max_snr
|
||||
|
||||
# def weight(self, logSNR):
|
||||
# return (logSNR.exp() + 1).clamp(max=self.max_snr)
|
||||
|
||||
# class TruncatedSNRLossWeight(BaseLossWeight):
|
||||
# def __init__(self, min_snr=1):
|
||||
# self.min_snr = min_snr
|
||||
|
||||
# def weight(self, logSNR):
|
||||
# return logSNR.exp().clamp(min=self.min_snr)
|
||||
|
||||
# class SechLossWeight(BaseLossWeight):
|
||||
# def __init__(self, div=2):
|
||||
# self.div = div
|
||||
|
||||
# def weight(self, logSNR):
|
||||
# return 1/(logSNR/self.div).cosh()
|
||||
|
||||
# class DebiasedLossWeight(BaseLossWeight):
|
||||
# def weight(self, logSNR):
|
||||
# return 1/logSNR.exp().sqrt()
|
||||
|
||||
# class SigmoidLossWeight(BaseLossWeight):
|
||||
# def __init__(self, s=1):
|
||||
# self.s = s
|
||||
|
||||
# def weight(self, logSNR):
|
||||
# return (logSNR * self.s).sigmoid()
|
||||
|
||||
|
||||
class AdaptiveLossWeight(BaseLossWeight):
    """Adaptive weighting: tracks an EMA of the loss per logSNR bucket and weights
    each sample by the inverse of its bucket's running loss.

    Fix: the default arguments were mutable lists shared across all instances
    (and ``weight_range`` was stored as-is, so one instance mutating it would
    affect future defaults); immutable tuples are backward-compatible.
    """

    def __init__(self, logsnr_range=(-10, 10), buckets=300, weight_range=(1e-7, 1e7)):
        # bucket_ranges holds the `buckets - 1` boundaries between the buckets.
        self.bucket_ranges = torch.linspace(logsnr_range[0], logsnr_range[1], buckets - 1)
        self.bucket_losses = torch.ones(buckets)
        self.weight_range = weight_range

    def weight(self, logSNR):
        """Return 1 / (EMA loss of the bucket containing each logSNR), clamped."""
        indices = torch.searchsorted(self.bucket_ranges.to(logSNR.device), logSNR)
        return (1 / self.bucket_losses.to(logSNR.device)[indices]).clamp(*self.weight_range)

    def update_buckets(self, logSNR, loss, beta=0.99):
        """Fold a batch of per-sample losses into the per-bucket EMA.

        NOTE(review): with duplicate bucket indices in one batch, fancy-index
        assignment keeps only the last write per bucket — confirm this is acceptable.
        """
        indices = torch.searchsorted(self.bucket_ranges.to(logSNR.device), logSNR).cpu()
        self.bucket_losses[indices] = self.bucket_losses[indices] * beta + loss.detach().cpu() * (1 - beta)
|
||||
|
||||
|
||||
# endregion gdf
|
||||
|
||||
504
library/stable_cascade_utils.py
Normal file
504
library/stable_cascade_utils.py
Normal file
@@ -0,0 +1,504 @@
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
from typing import List
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
from safetensors.torch import load_file, save_file
|
||||
from tqdm import tqdm
|
||||
from transformers import CLIPTokenizer, CLIPTextModelWithProjection, CLIPTextConfig
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
from library import stable_cascade as sc
|
||||
from library.train_util import (
|
||||
ImageInfo,
|
||||
load_image,
|
||||
trim_and_resize_if_required,
|
||||
save_latents_to_disk,
|
||||
HIGH_VRAM,
|
||||
save_text_encoder_outputs_to_disk,
|
||||
)
|
||||
from library.sdxl_model_util import _load_state_dict_on_device
|
||||
from library.device_utils import clean_memory_on_device
|
||||
from library.train_util import save_sd_model_on_epoch_end_or_stepwise_common, save_sd_model_on_train_end_common
|
||||
from library import sai_model_spec
|
||||
|
||||
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
CLIP_TEXT_MODEL_NAME: str = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
||||
|
||||
EFFNET_PREPROCESS = torchvision.transforms.Compose(
|
||||
[torchvision.transforms.ToTensor(), torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))]
|
||||
)
|
||||
|
||||
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_sc_te_outputs.npz"
|
||||
LATENTS_CACHE_SUFFIX = "_sc_latents.npz"
|
||||
|
||||
|
||||
def load_effnet(effnet_checkpoint_path, loading_device="cpu") -> sc.EfficientNetEncoder:
    """Load the EfficientNet encoder weights from a safetensors checkpoint.

    NOTE(review): ``loading_device`` is currently unused — the caller moves the
    model afterwards; confirm whether loading directly on device was intended.
    """
    logger.info(f"Loading EfficientNet encoder from {effnet_checkpoint_path}")
    effnet = sc.EfficientNetEncoder()
    checkpoint = load_file(effnet_checkpoint_path)
    # Some checkpoints nest the weights under a "state_dict" key.
    state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
    result = effnet.load_state_dict(state_dict)
    logger.info(result)
    del checkpoint
    return effnet
|
||||
|
||||
|
||||
def load_tokenizer(args: argparse.Namespace):
    """Load the CLIP tokenizer, preferring a local cache when --tokenizer_cache_dir is set.

    When the cache directory is given but empty, the tokenizer is fetched from the
    hub and then saved to the cache for later offline use.
    """
    # TODO commonize with sdxl_train_util.load_tokenizers
    logger.info("prepare tokenizers")

    loaded = []
    for original_path in [CLIP_TEXT_MODEL_NAME]:
        local_tokenizer_path = None
        tokenizer: CLIPTokenizer = None

        if args.tokenizer_cache_dir:
            local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_"))
            if os.path.exists(local_tokenizer_path):
                logger.info(f"load tokenizer from cache: {local_tokenizer_path}")
                tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path)

        if tokenizer is None:
            tokenizer = CLIPTokenizer.from_pretrained(original_path)

        if local_tokenizer_path is not None and not os.path.exists(local_tokenizer_path):
            logger.info(f"save Tokenizer to cache: {local_tokenizer_path}")
            tokenizer.save_pretrained(local_tokenizer_path)

        loaded.append(tokenizer)

    if getattr(args, "max_token_length", None) is not None:
        # NOTE(review): this only logs — presumably max_token_length is applied by the
        # dataset/tokenization side; confirm nothing else needs updating here.
        logger.info(f"update token length: {args.max_token_length}")

    return loaded[0]
|
||||
|
||||
|
||||
def load_stage_c_model(stage_c_checkpoint_path, dtype=None, device="cpu") -> sc.StageC:
    """Instantiate the Stage C generator on the meta device and load its weights onto *device*."""
    logger.info(f"Instantiating Stage C generator")
    # init_empty_weights avoids allocating the full model before the real weights arrive.
    with init_empty_weights():
        generator_c = sc.StageC()

    logger.info(f"Loading Stage C generator from {stage_c_checkpoint_path}")
    checkpoint = load_file(stage_c_checkpoint_path)

    logger.info(f"Loading state dict")
    result = _load_state_dict_on_device(generator_c, checkpoint, device, dtype=dtype)
    logger.info(result)
    return generator_c
|
||||
|
||||
|
||||
def load_stage_b_model(stage_b_checkpoint_path, dtype=None, device="cpu") -> sc.StageB:
    """Instantiate the Stage B generator on the meta device and load its weights onto *device*."""
    logger.info(f"Instantiating Stage B generator")
    # Build the skeleton without allocating real storage.
    with init_empty_weights():
        generator_b = sc.StageB()

    logger.info(f"Loading Stage B generator from {stage_b_checkpoint_path}")
    checkpoint = load_file(stage_b_checkpoint_path)

    logger.info(f"Loading state dict")
    result = _load_state_dict_on_device(generator_b, checkpoint, device, dtype=dtype)
    logger.info(result)
    return generator_b
|
||||
|
||||
|
||||
def load_clip_text_model(text_model_checkpoint_path, dtype=None, device="cpu", save_text_model=False):
    """Load the CLIP text encoder (ViT-bigG with projection).

    When ``save_text_model`` is True or no checkpoint path is given, weights come
    from the Hugging Face hub; with ``save_text_model`` they are also written to
    ``text_model_checkpoint_path`` for later offline loading. Otherwise the model
    is built empty and filled from the local safetensors checkpoint.
    """
    logger.info(f"Loading CLIP text model")
    if save_text_model or text_model_checkpoint_path is None:
        logger.info(f"Loading CLIP text model from {CLIP_TEXT_MODEL_NAME}")
        text_model = CLIPTextModelWithProjection.from_pretrained(CLIP_TEXT_MODEL_NAME)

        if save_text_model:
            sd = text_model.state_dict()
            logger.info(f"Saving CLIP text model to {text_model_checkpoint_path}")
            save_file(sd, text_model_checkpoint_path)
    else:
        logger.info(f"Loading CLIP text model from {text_model_checkpoint_path}")

        # Config values copied from sdxl_model_util.py (OpenCLIP ViT-bigG text tower).
        config = CLIPTextConfig(
            vocab_size=49408,
            hidden_size=1280,
            intermediate_size=5120,
            num_hidden_layers=32,
            num_attention_heads=20,
            max_position_embeddings=77,
            hidden_act="gelu",
            layer_norm_eps=1e-05,
            dropout=0.0,
            attention_dropout=0.0,
            initializer_range=0.02,
            initializer_factor=1.0,
            pad_token_id=1,
            bos_token_id=0,
            eos_token_id=2,
            model_type="clip_text_model",
            projection_dim=1280,
            # torch_dtype="float32",
            # transformers_version="4.25.0.dev0",
        )
        # Build the skeleton without real storage, then stream weights onto device.
        with init_empty_weights():
            text_model = CLIPTextModelWithProjection(config)

        checkpoint = load_file(text_model_checkpoint_path)
        result = _load_state_dict_on_device(text_model, checkpoint, device, dtype=dtype)
        logger.info(result)

    return text_model
|
||||
|
||||
|
||||
def load_stage_a_model(stage_a_checkpoint_path, dtype=None, device="cpu") -> sc.StageA:
    """Load the Stage A vqGAN weights from a safetensors checkpoint onto *device*.

    NOTE(review): ``dtype`` is accepted but never applied here — confirm whether
    casting after load was intended.
    """
    logger.info(f"Loading Stage A vqGAN from {stage_a_checkpoint_path}")
    stage_a = sc.StageA().to(device)
    checkpoint = load_file(stage_a_checkpoint_path)
    # Some checkpoints nest the weights under a "state_dict" key.
    state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
    result = stage_a.load_state_dict(state_dict)
    logger.info(result)
    return stage_a
|
||||
|
||||
|
||||
def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool):
    """Return True if the cached latents npz at *npz_path* matches the bucket resolution.

    ``reso`` is (width, height); latents are expected at 1/32 of the image size.
    When ``flip_aug`` is set, flipped latents must also be present and sized correctly.
    """
    # bucket reso is WxH, latent arrays are (c, h, w) — note the axis swap
    expected_hw = (reso[1] // 32, reso[0] // 32)

    if not os.path.exists(npz_path):
        return False

    npz = np.load(npz_path)
    if any(key not in npz for key in ("latents", "original_size", "crop_ltrb")):  # old format?
        return False
    if npz["latents"].shape[1:3] != expected_hw:
        return False

    if flip_aug:
        if "latents_flipped" not in npz:
            return False
        if npz["latents_flipped"].shape[1:3] != expected_hw:
            return False

    return True
|
||||
|
||||
|
||||
def cache_batch_latents(
    effnet: sc.EfficientNetEncoder,
    cache_to_disk: bool,
    image_infos: List[ImageInfo],
    flip_aug: bool,
    random_crop: bool,
    device,
    dtype,
) -> None:
    r"""Encode one batch of images with the EfficientNet encoder and cache the latents.

    requires image_infos to have: absolute_path, bucket_reso, resized_size, latents_npz
    optionally requires image_infos to have: image
    if cache_to_disk is True, set info.latents_npz
        flipped latents is also saved if flip_aug is True
    if cache_to_disk is False, set info.latents
        latents_flipped is also set if flip_aug is True
    latents_original_size and latents_crop_ltrb are also set
    """
    images = []
    for info in image_infos:
        image = load_image(info.absolute_path) if info.image is None else np.array(info.image, np.uint8)
        # TODO: image metadata can be broken so that the bucket assigned from metadata does
        # not match the actual image size; a sanity check should be added here.
        image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
        image = EFFNET_PREPROCESS(image)
        images.append(image)

        info.latents_original_size = original_size
        info.latents_crop_ltrb = crop_ltrb

    img_tensors = torch.stack(images, dim=0)
    img_tensors = img_tensors.to(device=device, dtype=dtype)

    with torch.no_grad():
        latents = effnet(img_tensors).to("cpu")

    if flip_aug:
        img_tensors = torch.flip(img_tensors, dims=[3])
        with torch.no_grad():
            flipped_latents = effnet(img_tensors).to("cpu")
    else:
        flipped_latents = [None] * len(latents)

    for info, latent, flipped_latent in zip(image_infos, latents, flipped_latents):
        # Fix: check the per-image `latent`, not the whole-batch `latents` — the old
        # check flagged every image in the batch (with a misleading path in the error)
        # whenever any single latent contained a NaN. Also removed a stray debug print.
        if torch.isnan(latent).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()):
            raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")

        if cache_to_disk:
            save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent)
        else:
            info.latents = latent
            if flip_aug:
                info.latents_flipped = flipped_latent

    # Free GPU memory between batches unless the user opted into keeping it resident.
    if not HIGH_VRAM:
        clean_memory_on_device(device)
|
||||
|
||||
|
||||
def cache_batch_text_encoder_outputs(image_infos, tokenizers, text_encoders, max_token_length, cache_to_disk, input_ids, dtype):
    """Encode one batch of pre-tokenized captions and cache the hidden state / pooled output.

    Captions longer than 75 tokens are not supported yet (the original comment);
    ``input_ids`` are fed straight to the first text encoder via get_clip_conditions.
    """
    input_ids = input_ids.to(text_encoders[0].device)

    with torch.no_grad():
        b_hidden_state, b_pool = sc.get_clip_conditions(None, input_ids, tokenizers[0], text_encoders[0])

        b_hidden_state = b_hidden_state.detach().to("cpu")  # b,n*75+2,768
        b_pool = b_pool.detach().to("cpu")  # b,1280

    for info, hidden_state, pool in zip(image_infos, b_hidden_state, b_pool):
        if not cache_to_disk:
            info.text_encoder_outputs1 = hidden_state
            info.text_encoder_pool2 = pool
        else:
            save_text_encoder_outputs_to_disk(info.text_encoder_outputs_npz, None, hidden_state, pool)
|
||||
|
||||
|
||||
def add_effnet_arguments(parser):
    """Register the EfficientNet encoder checkpoint option on *parser* and return it."""
    options = dict(
        type=str,
        required=True,
        help="path to EfficientNet checkpoint / EfficientNetのチェックポイントのパス",
    )
    parser.add_argument("--effnet_checkpoint_path", **options)
    return parser
|
||||
|
||||
|
||||
def add_text_model_arguments(parser):
    """Register the CLIP text model checkpoint options on *parser* and return it."""
    parser.add_argument(
        "--text_model_checkpoint_path",
        required=True,
        type=str,
        help="path to CLIP text model checkpoint / CLIPテキストモデルのチェックポイントのパス",
    )
    # Opt-in flag: write the hub-downloaded weights to the checkpoint path.
    parser.add_argument("--save_text_model", action="store_true", help="if specified, save text model to corresponding path")
    return parser
|
||||
|
||||
|
||||
def add_stage_a_arguments(parser):
    """Register the Stage A checkpoint option on *parser* and return it."""
    options = dict(
        type=str,
        required=True,
        help="path to Stage A checkpoint / Stage Aのチェックポイントのパス",
    )
    parser.add_argument("--stage_a_checkpoint_path", **options)
    return parser
|
||||
|
||||
|
||||
def add_stage_b_arguments(parser):
    """Register the Stage B checkpoint option on *parser* and return it."""
    options = dict(
        type=str,
        required=True,
        help="path to Stage B checkpoint / Stage Bのチェックポイントのパス",
    )
    parser.add_argument("--stage_b_checkpoint_path", **options)
    return parser
|
||||
|
||||
|
||||
def add_stage_c_arguments(parser):
    """Register the Stage C checkpoint option on *parser* and return it."""
    options = dict(
        type=str,
        required=True,
        help="path to Stage C checkpoint / Stage Cのチェックポイントのパス",
    )
    parser.add_argument("--stage_c_checkpoint_path", **options)
    return parser
|
||||
|
||||
|
||||
def get_sai_model_spec(args):
    """Build SAI modelspec metadata for a Stage C checkpoint from the CLI args."""
    timestamp = time.time()

    reso = args.resolution

    # Fall back to the output name when no explicit title was given.
    title = args.output_name if args.metadata_title is None else args.metadata_title

    timesteps = None
    if args.min_timestep is not None or args.max_timestep is not None:
        lower = 0 if args.min_timestep is None else args.min_timestep
        upper = 1000 if args.max_timestep is None else args.max_timestep
        timesteps = (lower, upper)

    return sai_model_spec.build_metadata(
        None,  # state_dict omitted: hash is not calculated
        False,  # v2
        False,  # v_parameterization
        False,  # sdxl
        False,  # lora
        False,  # textual_inversion
        timestamp,
        title=title,
        reso=reso,
        is_stable_diffusion_ckpt=False,
        author=args.metadata_author,
        description=args.metadata_description,
        license=args.metadata_license,
        tags=args.metadata_tags,
        timesteps=timesteps,
        clip_skip=args.clip_skip,  # None or int
        stable_cascade=True,
    )
|
||||
|
||||
|
||||
def save_stage_c_model_on_epoch_end_or_stepwise(
    args: argparse.Namespace,
    on_epoch_end: bool,
    accelerator,
    save_dtype: torch.dtype,
    epoch: int,
    num_train_epochs: int,
    global_step: int,
    stage_c,
):
    """Save the Stage C model at an epoch/step boundary via the shared save helper."""

    def stage_c_saver(ckpt_file, epoch_no, global_step):
        # Attach SAI modelspec metadata and optionally cast before writing safetensors.
        sai_metadata = get_sai_model_spec(args)

        state_dict = stage_c.state_dict()
        if save_dtype is not None:
            state_dict = {key: tensor.to(save_dtype) for key, tensor in state_dict.items()}

        save_file(state_dict, ckpt_file, metadata=sai_metadata)

    save_sd_model_on_epoch_end_or_stepwise_common(
        args, on_epoch_end, accelerator, True, True, epoch, num_train_epochs, global_step, stage_c_saver, None
    )
|
||||
|
||||
|
||||
def save_stage_c_model_on_end(
    args: argparse.Namespace,
    save_dtype: torch.dtype,
    epoch: int,
    global_step: int,
    stage_c,
):
    """Save the final Stage C model at the end of training via the shared save helper."""

    def stage_c_saver(ckpt_file, epoch_no, global_step):
        # Attach SAI modelspec metadata and optionally cast before writing safetensors.
        sai_metadata = get_sai_model_spec(args)

        state_dict = stage_c.state_dict()
        if save_dtype is not None:
            state_dict = {key: tensor.to(save_dtype) for key, tensor in state_dict.items()}

        save_file(state_dict, ckpt_file, metadata=sai_metadata)

    save_sd_model_on_train_end_common(args, True, True, epoch, global_step, stage_c_saver, None)
|
||||
|
||||
|
||||
def cache_latents(self, effnet, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
    """Cache EfficientNet latents for every image in the dataset (Stage C training).

    Single-process only; for multi-GPU use tools/cache_latents.py instead.
    Images are sorted and batched by bucket resolution so each encoder forward
    pass sees a uniform image size.
    """
    logger.info("caching latents.")

    image_infos = list(self.image_data.values())

    # sort by resolution so same-resolution images end up adjacent
    image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1])

    # split into batches of identical resolution
    batches = []
    batch = []
    subset = None
    logger.info("checking cache validity...")
    for info in tqdm(image_infos):
        subset = self.image_to_subset[info.image_key]

        if info.latents_npz is not None:  # fine tuning dataset
            continue

        # check whether a valid disk cache already exists (and record the path)
        if cache_to_disk:
            info.latents_npz = os.path.splitext(info.absolute_path)[0] + LATENTS_CACHE_SUFFIX
            if not is_main_process:  # store path to info only
                continue

            cache_available = is_disk_cached_latents_is_expected(info.bucket_reso, info.latents_npz, subset.flip_aug)
            if cache_available:  # do not add to batch
                continue

        # if the last member of the batch has a different resolution, flush the batch
        if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso:
            batches.append(batch)
            batch = []

        batch.append(info)

        # if the batch is full, flush it
        if len(batch) >= vae_batch_size:
            batches.append(batch)
            batch = []

    if len(batch) > 0:
        batches.append(batch)

    if cache_to_disk and not is_main_process:  # non-main processes only record npz paths
        return

    # Fix: cache_batch_latents requires device and dtype arguments; the previous call
    # omitted both and would raise a TypeError. Derive them from the model itself.
    model_param = next(effnet.parameters())

    # NOTE(review): `subset` is whatever the last iterated image belonged to; with
    # multiple subsets whose flip_aug/random_crop differ this applies the wrong
    # settings to earlier batches — confirm whether per-batch subsets are needed.
    logger.info("caching latents...")
    for batch in tqdm(batches, smoothing=1, total=len(batches)):
        cache_batch_latents(
            effnet, cache_to_disk, batch, subset.flip_aug, subset.random_crop, model_param.device, model_param.dtype
        )
|
||||
|
||||
|
||||
# weight_dtypeを指定するとText Encoderそのもの、およひ出力がweight_dtypeになる
|
||||
def cache_text_encoder_outputs(self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True):
    """Cache CLIP text-encoder outputs for every caption in the dataset.

    Mirrors the latents cache: supports caching to disk, single-process only
    (for multi-GPU use tools/cache_latents.py). When ``weight_dtype`` is given,
    both the text encoder and its outputs use that dtype.
    """
    logger.info("caching text encoder outputs.")
    image_infos = list(self.image_data.values())

    logger.info("checking cache existence...")
    image_infos_to_cache = []
    for info in tqdm(image_infos):
        # subset = self.image_to_subset[info.image_key]
        if cache_to_disk:
            te_out_npz = os.path.splitext(info.absolute_path)[0] + TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX
            info.text_encoder_outputs_npz = te_out_npz

            if not is_main_process:  # store path to info only
                continue

            if os.path.exists(te_out_npz):
                continue

        image_infos_to_cache.append(info)

    if cache_to_disk and not is_main_process:  # non-main processes only record npz paths
        return

    # move text encoders to the target device (and dtype, when requested)
    for text_encoder in text_encoders:
        text_encoder.to(device)
        if weight_dtype is not None:
            text_encoder.to(dtype=weight_dtype)

    # tokenize up-front and group captions into batches
    batches = []
    batch = []
    for info in image_infos_to_cache:
        input_ids1 = self.get_input_ids(info.caption, tokenizers[0])
        batch.append((info, input_ids1, None))

        if len(batch) >= self.batch_size:
            batches.append(batch)
            batch = []

    if len(batch) > 0:
        batches.append(batch)

    # run the text encoder per batch and cache to memory or disk
    logger.info("caching text encoder outputs...")
    for batch in tqdm(batches):
        infos, input_ids1, input_ids2 = zip(*batch)
        input_ids1 = torch.stack(input_ids1, dim=0)
        input_ids2 = torch.stack(input_ids2, dim=0) if input_ids2[0] is not None else None
        cache_batch_text_encoder_outputs(
            infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, input_ids1, weight_dtype
        )
|
||||
@@ -909,7 +909,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
)
|
||||
|
||||
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
|
||||
# マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
|
||||
# マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
|
||||
logger.info("caching latents.")
|
||||
|
||||
image_infos = list(self.image_data.values())
|
||||
@@ -1325,7 +1325,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
if self.caching_mode == "text":
|
||||
input_ids1 = self.get_input_ids(caption, self.tokenizers[0])
|
||||
input_ids2 = self.get_input_ids(caption, self.tokenizers[1])
|
||||
input_ids2 = self.get_input_ids(caption, self.tokenizers[1]) if len(self.tokenizers) > 1 else None
|
||||
else:
|
||||
input_ids1 = None
|
||||
input_ids2 = None
|
||||
@@ -2328,7 +2328,7 @@ def cache_batch_text_encoder_outputs(
|
||||
def save_text_encoder_outputs_to_disk(npz_path, hidden_state1, hidden_state2, pool2):
|
||||
np.savez(
|
||||
npz_path,
|
||||
hidden_state1=hidden_state1.cpu().float().numpy(),
|
||||
hidden_state1=hidden_state1.cpu().float().numpy() if hidden_state1 is not None else None,
|
||||
hidden_state2=hidden_state2.cpu().float().numpy(),
|
||||
pool2=pool2.cpu().float().numpy(),
|
||||
)
|
||||
@@ -2684,6 +2684,14 @@ def get_sai_model_spec(
|
||||
return metadata
|
||||
|
||||
|
||||
def add_tokenizer_arguments(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--tokenizer_cache_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)",
|
||||
)
|
||||
|
||||
def add_sd_models_arguments(parser: argparse.ArgumentParser):
|
||||
# for pretrained models
|
||||
parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む")
|
||||
@@ -2696,12 +2704,7 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
|
||||
default=None,
|
||||
help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_cache_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)",
|
||||
)
|
||||
add_tokenizer_arguments(parser)
|
||||
|
||||
|
||||
def add_optimizer_arguments(parser: argparse.ArgumentParser):
|
||||
@@ -3150,18 +3153,22 @@ def verify_training_args(args: argparse.Namespace):
|
||||
print("highvram is enabled / highvramが有効です")
|
||||
global HIGH_VRAM
|
||||
HIGH_VRAM = True
|
||||
|
||||
if args.v_parameterization and not args.v2:
|
||||
logger.warning("v_parameterization should be with v2 not v1 or sdxl / v1やsdxlでv_parameterizationを使用することは想定されていません")
|
||||
if args.v2 and args.clip_skip is not None:
|
||||
logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
|
||||
|
||||
|
||||
if args.cache_latents_to_disk and not args.cache_latents:
|
||||
args.cache_latents = True
|
||||
logger.warning(
|
||||
"cache_latents_to_disk is enabled, so cache_latents is also enabled / cache_latents_to_diskが有効なため、cache_latentsを有効にします"
|
||||
)
|
||||
|
||||
if not hasattr(args, "v_parameterization"):
|
||||
# Stable Cascade: skip following checks
|
||||
return
|
||||
|
||||
if args.v_parameterization and not args.v2:
|
||||
logger.warning("v_parameterization should be with v2 not v1 or sdxl / v1やsdxlでv_parameterizationを使用することは想定されていません")
|
||||
if args.v2 and args.clip_skip is not None:
|
||||
logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
|
||||
|
||||
# noise_offset, perlin_noise, multires_noise_iterations cannot be enabled at the same time
|
||||
# # Listを使って数えてもいいけど並べてしまえ
|
||||
# if args.noise_offset is not None and args.multires_noise_iterations is not None:
|
||||
|
||||
@@ -11,11 +11,11 @@ from PIL import Image
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
import library.stable_cascade as sc
|
||||
import library.stable_cascade_utils as sc_utils
|
||||
import library.device_utils as device_utils
|
||||
from library import train_util
|
||||
from library.sdxl_model_util import _load_state_dict_on_device
|
||||
|
||||
clip_text_model_name: str = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
||||
|
||||
|
||||
def calculate_latent_sizes(height=1024, width=1024, batch_size=4, compression_factor_b=42.67, compression_factor_a=4.0):
|
||||
resolution_multiple = 42.67
|
||||
@@ -45,94 +45,31 @@ def main(args):
|
||||
text_model_dtype = torch.float32
|
||||
|
||||
# EfficientNet encoder
|
||||
print(f"Loading EfficientNet encoder from {args.effnet_checkpoint_path}")
|
||||
effnet = sc.EfficientNetEncoder()
|
||||
effnet_checkpoint = load_file(args.effnet_checkpoint_path)
|
||||
info = effnet.load_state_dict(effnet_checkpoint if "state_dict" not in effnet_checkpoint else effnet_checkpoint["state_dict"])
|
||||
print(info)
|
||||
effnet = sc_utils.load_effnet(args.effnet_checkpoint_path, loading_device)
|
||||
effnet.eval().requires_grad_(False).to(loading_device)
|
||||
del effnet_checkpoint
|
||||
|
||||
# Generator
|
||||
print(f"Instantiating Stage C generator")
|
||||
with init_empty_weights():
|
||||
generator_c = sc.StageC()
|
||||
print(f"Loading Stage C generator from {args.stage_c_checkpoint_path}")
|
||||
stage_c_checkpoint = load_file(args.stage_c_checkpoint_path)
|
||||
print(f"Loading state dict")
|
||||
info = _load_state_dict_on_device(generator_c, stage_c_checkpoint, loading_device, dtype=dtype)
|
||||
print(info)
|
||||
generator_c = sc_utils.load_stage_c_model(args.stage_c_checkpoint_path, dtype=dtype, device=loading_device)
|
||||
generator_c.eval().requires_grad_(False).to(loading_device)
|
||||
|
||||
print(f"Instantiating Stage B generator")
|
||||
with init_empty_weights():
|
||||
generator_b = sc.StageB()
|
||||
print(f"Loading Stage B generator from {args.stage_b_checkpoint_path}")
|
||||
stage_b_checkpoint = load_file(args.stage_b_checkpoint_path)
|
||||
print(f"Loading state dict")
|
||||
info = _load_state_dict_on_device(generator_b, stage_b_checkpoint, loading_device, dtype=dtype)
|
||||
print(info)
|
||||
generator_b = sc_utils.load_stage_b_model(args.stage_b_checkpoint_path, dtype=dtype, device=loading_device)
|
||||
generator_b.eval().requires_grad_(False).to(loading_device)
|
||||
|
||||
# CLIP encoders
|
||||
print(f"Loading CLIP text model")
|
||||
|
||||
# TODO 完全にオフラインで動かすには tokenizer もローカルに保存できるようにする必要がある
|
||||
tokenizer = AutoTokenizer.from_pretrained(clip_text_model_name)
|
||||
|
||||
if args.save_text_model or args.text_model_checkpoint_path is None:
|
||||
print(f"Loading CLIP text model from {clip_text_model_name}")
|
||||
text_model = CLIPTextModelWithProjection.from_pretrained(clip_text_model_name)
|
||||
|
||||
if args.save_text_model:
|
||||
sd = text_model.state_dict()
|
||||
print(f"Saving CLIP text model to {args.text_model_checkpoint_path}")
|
||||
save_file(sd, args.text_model_checkpoint_path)
|
||||
else:
|
||||
print(f"Loading CLIP text model from {args.text_model_checkpoint_path}")
|
||||
|
||||
# copy from sdxl_model_util.py
|
||||
text_model2_cfg = CLIPTextConfig(
|
||||
vocab_size=49408,
|
||||
hidden_size=1280,
|
||||
intermediate_size=5120,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=20,
|
||||
max_position_embeddings=77,
|
||||
hidden_act="gelu",
|
||||
layer_norm_eps=1e-05,
|
||||
dropout=0.0,
|
||||
attention_dropout=0.0,
|
||||
initializer_range=0.02,
|
||||
initializer_factor=1.0,
|
||||
pad_token_id=1,
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
model_type="clip_text_model",
|
||||
projection_dim=1280,
|
||||
# torch_dtype="float32",
|
||||
# transformers_version="4.25.0.dev0",
|
||||
)
|
||||
with init_empty_weights():
|
||||
text_model = CLIPTextModelWithProjection(text_model2_cfg)
|
||||
|
||||
text_model_checkpoint = load_file(args.text_model_checkpoint_path)
|
||||
info = _load_state_dict_on_device(text_model, text_model_checkpoint, text_model_device, dtype=text_model_dtype)
|
||||
print(info)
|
||||
tokenizer = sc_utils.load_tokenizer(args)
|
||||
|
||||
text_model = sc_utils.load_clip_text_model(
|
||||
args.text_model_checkpoint_path, text_model_dtype, text_model_device, args.save_text_model
|
||||
)
|
||||
text_model = text_model.requires_grad_(False).to(text_model_dtype).to(text_model_device)
|
||||
|
||||
# image_model = (
|
||||
# CLIPVisionModelWithProjection.from_pretrained(clip_image_model_name).requires_grad_(False).to(dtype).to(device)
|
||||
# )
|
||||
|
||||
# vqGAN
|
||||
print(f"Loading Stage A vqGAN from {args.stage_a_checkpoint_path}")
|
||||
stage_a = sc.StageA().to(loading_device)
|
||||
stage_a_checkpoint = load_file(args.stage_a_checkpoint_path)
|
||||
info = stage_a.load_state_dict(
|
||||
stage_a_checkpoint if "state_dict" not in stage_a_checkpoint else stage_a_checkpoint["state_dict"]
|
||||
)
|
||||
print(info)
|
||||
stage_a = sc_utils.load_stage_a_model(args.stage_a_checkpoint_path, dtype=dtype, device=loading_device)
|
||||
stage_a.eval().requires_grad_(False)
|
||||
|
||||
caption = "Cinematic photo of an anthropomorphic penguin sitting in a cafe reading a book and having a coffee"
|
||||
@@ -169,19 +106,19 @@ def main(args):
|
||||
# extras_b.sampling_configs["t_start"] = 1.0
|
||||
|
||||
# PREPARE CONDITIONS
|
||||
cond_text, cond_pooled = sc.get_clip_conditions([caption], tokenizer, text_model)
|
||||
cond_text, cond_pooled = sc.get_clip_conditions([caption], None, tokenizer, text_model)
|
||||
cond_text = cond_text.to(device, dtype=dtype)
|
||||
cond_pooled = cond_pooled.to(device, dtype=dtype)
|
||||
|
||||
uncond_text, uncond_pooled = sc.get_clip_conditions([""], tokenizer, text_model)
|
||||
uncond_text, uncond_pooled = sc.get_clip_conditions([""], None, tokenizer, text_model)
|
||||
uncond_text = uncond_text.to(device, dtype=dtype)
|
||||
uncond_pooled = uncond_pooled.to(device, dtype=dtype)
|
||||
|
||||
img_emb = torch.zeros(1, 768, device=device)
|
||||
zero_img_emb = torch.zeros(1, 768, device=device)
|
||||
|
||||
# 辞書にしたくないけど GDF から先の変更が面倒だからとりあえず辞書にしておく
|
||||
conditions = {"clip_text_pooled": cond_pooled, "clip": cond_pooled, "clip_text": cond_text, "clip_img": img_emb}
|
||||
unconditions = {"clip_text_pooled": uncond_pooled, "clip": uncond_pooled, "clip_text": uncond_text, "clip_img": img_emb}
|
||||
conditions = {"clip_text_pooled": cond_pooled, "clip": cond_pooled, "clip_text": cond_text, "clip_img": zero_img_emb}
|
||||
unconditions = {"clip_text_pooled": uncond_pooled, "clip": uncond_pooled, "clip_text": uncond_text, "clip_img": zero_img_emb}
|
||||
conditions_b = {}
|
||||
conditions_b.update(conditions)
|
||||
unconditions_b = {}
|
||||
@@ -249,14 +186,13 @@ def main(args):
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--effnet_checkpoint_path", type=str, required=True)
|
||||
parser.add_argument("--stage_a_checkpoint_path", type=str, required=True)
|
||||
parser.add_argument("--stage_b_checkpoint_path", type=str, required=True)
|
||||
parser.add_argument("--stage_c_checkpoint_path", type=str, required=True)
|
||||
parser.add_argument(
|
||||
"--text_model_checkpoint_path", type=str, required=False, default=None, help="if omitted, download from HuggingFace"
|
||||
)
|
||||
parser.add_argument("--save_text_model", action="store_true", help="if specified, save text model to corresponding path")
|
||||
|
||||
sc_utils.add_effnet_arguments(parser)
|
||||
train_util.add_tokenizer_arguments(parser)
|
||||
sc_utils.add_stage_a_arguments(parser)
|
||||
sc_utils.add_stage_b_arguments(parser)
|
||||
sc_utils.add_stage_c_arguments(parser)
|
||||
sc_utils.add_text_model_arguments(parser)
|
||||
parser.add_argument("--bf16", action="store_true")
|
||||
parser.add_argument("--fp16", action="store_true")
|
||||
parser.add_argument("--outdir", type=str, default="../outputs", help="dir to write results to / 生成画像の出力先")
|
||||
|
||||
526
stable_cascade_train_stage_c.py
Normal file
526
stable_cascade_train_stage_c.py
Normal file
@@ -0,0 +1,526 @@
|
||||
# training with captions
|
||||
|
||||
import argparse
|
||||
import math
|
||||
import os
|
||||
from multiprocessing import Value
|
||||
from typing import List
|
||||
import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
init_ipex()
|
||||
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler
|
||||
|
||||
import library.train_util as train_util
|
||||
from library.sdxl_train_util import add_sdxl_training_arguments
|
||||
import library.stable_cascade_utils as sc_utils
|
||||
import library.stable_cascade as sc
|
||||
|
||||
from library.utils import setup_logging, add_logging_arguments
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
import library.config_util as config_util
|
||||
from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
)
|
||||
|
||||
|
||||
def train(args):
    """Fine-tune the Stable Cascade Stage C generator with captions.

    ``args`` must carry the options registered by ``setup_parser``. In this
    temporary version, in-process latent / text-encoder-output caching is
    intentionally disabled and raises ``NotImplementedError`` when requested;
    caches must be produced beforehand with the dedicated tool scripts.
    """
    train_util.verify_training_args(args)
    train_util.prepare_dataset_args(args, True)
    setup_logging(args, reset=True)

    # assert (
    #     not args.weighted_captions
    # ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"

    # TODO add assertions for other unsupported options

    cache_latents = args.cache_latents
    use_dreambooth_method = args.in_json is None  # no metadata json -> DreamBooth-style folder layout

    if args.seed is not None:
        set_seed(args.seed)  # 乱数系列を初期化する

    tokenizer = sc_utils.load_tokenizer(args)

    # データセットを準備する / build the dataset from a config file, DreamBooth dirs, or a metadata json
    if args.dataset_class is None:
        blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
        if args.dataset_config is not None:
            logger.info(f"Load dataset config from {args.dataset_config}")
            user_config = config_util.load_user_config(args.dataset_config)
            ignored = ["train_data_dir", "in_json"]
            if any(getattr(args, attr) is not None for attr in ignored):
                logger.warning(
                    "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
                        ", ".join(ignored)
                    )
                )
        else:
            if use_dreambooth_method:
                logger.info("Using DreamBooth method.")
                user_config = {
                    "datasets": [
                        {
                            "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
                                args.train_data_dir, args.reg_data_dir
                            )
                        }
                    ]
                }
            else:
                logger.info("Training with captions.")
                user_config = {
                    "datasets": [
                        {
                            "subsets": [
                                {
                                    "image_dir": args.train_data_dir,
                                    "metadata_file": args.in_json,
                                }
                            ]
                        }
                    ]
                }

        blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer])
        train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
    else:
        train_dataset_group = train_util.load_arbitrary_dataset(args, [tokenizer])

    current_epoch = Value("i", 0)
    current_step = Value("i", 0)
    ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
    collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)

    train_dataset_group.verify_bucket_reso_steps(32)

    if args.debug_dataset:
        train_util.debug_dataset(train_dataset_group, True)
        return
    if len(train_dataset_group) == 0:
        logger.error(
            "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
        )
        return

    if cache_latents:
        assert (
            train_dataset_group.is_latent_cacheable()
        ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"

    if args.cache_text_encoder_outputs:
        assert (
            train_dataset_group.is_text_encoder_output_cacheable()
        ), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"

    # acceleratorを準備する
    logger.info("prepare accelerator")
    accelerator = train_util.prepare_accelerator(args)

    # mixed precisionに対応した型を用意しておき適宜castする
    weight_dtype, save_dtype = train_util.prepare_dtype(args)
    effnet_dtype = torch.float32 if args.no_half_vae else weight_dtype

    # モデルを読み込む
    loading_device = accelerator.device if args.lowram else "cpu"
    effnet = sc_utils.load_effnet(args.effnet_checkpoint_path, loading_device)
    stage_c = sc_utils.load_stage_c_model(args.stage_c_checkpoint_path, dtype=weight_dtype, device=loading_device)
    text_encoder1 = sc_utils.load_clip_text_model(args.text_model_checkpoint_path, dtype=weight_dtype, device=loading_device)

    # 学習を準備する
    if cache_latents:
        # NOTE: in-process latent caching is intentionally disabled in this
        # temporary version; the code below the raise is kept for reference and
        # is unreachable.
        raise NotImplementedError("Caching latents is not supported in this version / latentのキャッシュはサポートされていません")
        logger.info(
            "Please make sure that the latents are cached before training with `stable_cascade_cache_latents.py`."
            + " / 学習前に`stable_cascade_cache_latents.py`でlatentをキャッシュしてください。"
        )
        # effnet.to(accelerator.device, dtype=effnet_dtype)
        effnet.requires_grad_(False)
        effnet.eval()
        with torch.no_grad():
            train_dataset_group.cache_latents(
                effnet,
                args.vae_batch_size,
                args.cache_latents_to_disk,
                accelerator.is_main_process,
                cache_func=sc_utils.cache_batch_latents,
            )
        effnet.to("cpu")
        clean_memory_on_device(accelerator.device)

    accelerator.wait_for_everyone()

    # 学習を準備する:モデルを適切な状態にする
    if args.gradient_checkpointing:
        # fix: `logger.warn` is a deprecated alias of `logger.warning`
        logger.warning("Gradient checkpointing is not supported for stage_c. Ignoring the option.")
        # stage_c.enable_gradient_checkpointing()

    text_encoder1.to(weight_dtype)
    text_encoder1.requires_grad_(False)
    text_encoder1.eval()

    # TextEncoderの出力をキャッシュする
    if args.cache_text_encoder_outputs:
        # NOTE: as with latents, in-process caching of text encoder outputs is
        # disabled here; the code after the raise is unreachable, kept for reference.
        raise NotImplementedError(
            "Caching text encoder outputs is not supported in this version / text encoderの出力のキャッシュはサポートされていません"
        )
        print(
            f"Please make sure that the text encoder outputs are cached before training with `stable_cascade_cache_text_encoder_outputs.py`."
            + " / 学習前に`stable_cascade_cache_text_encoder_outputs.py`でtext encoderの出力をキャッシュしてください。"
        )
        # Text Encodes are eval and no grad
        with torch.no_grad(), accelerator.autocast():
            train_dataset_group.cache_text_encoder_outputs(
                (tokenizer),
                (text_encoder1),
                accelerator.device,
                None,
                args.cache_text_encoder_outputs_to_disk,
                accelerator.is_main_process,
            )
        accelerator.wait_for_everyone()

    if not cache_latents:
        # latents are computed on the fly, so EfficientNet stays on the training device
        effnet.requires_grad_(False)
        effnet.eval()
        effnet.to(accelerator.device, dtype=effnet_dtype)

    stage_c.requires_grad_(True)

    training_models = []
    params_to_optimize = []
    training_models.append(stage_c)
    params_to_optimize.append({"params": list(stage_c.parameters()), "lr": args.learning_rate})

    # calculate number of trainable parameters
    n_params = 0
    for params in params_to_optimize:
        for p in params["params"]:
            n_params += p.numel()

    accelerator.print(f"number of models: {len(training_models)}")
    accelerator.print(f"number of trainable parameters: {n_params}")

    # 学習に必要なクラスを準備する
    accelerator.print("prepare optimizer, data loader etc.")
    _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)

    # dataloaderを準備する
    # DataLoaderのプロセス数:0はメインプロセスになる
    n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1 ただし最大で指定された数まで
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset_group,
        batch_size=1,
        shuffle=True,
        collate_fn=collator,
        num_workers=n_workers,
        persistent_workers=args.persistent_data_loader_workers,
    )

    # 学習ステップ数を計算する
    if args.max_train_epochs is not None:
        args.max_train_steps = args.max_train_epochs * math.ceil(
            len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
        )
        accelerator.print(
            f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
        )

    # データセット側にも学習ステップを送信
    train_dataset_group.set_max_train_steps(args.max_train_steps)

    # lr schedulerを用意する
    lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)

    # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
    if args.full_fp16:
        assert (
            args.mixed_precision == "fp16"
        ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
        accelerator.print("enable full fp16 training.")
        stage_c.to(weight_dtype)
        text_encoder1.to(weight_dtype)
    elif args.full_bf16:
        assert (
            args.mixed_precision == "bf16"
        ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
        accelerator.print("enable full bf16 training.")
        stage_c.to(weight_dtype)
        text_encoder1.to(weight_dtype)

    # acceleratorがなんかよろしくやってくれるらしい
    stage_c = accelerator.prepare(stage_c)

    optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)

    # TextEncoderの出力をキャッシュするときにはCPUへ移動する
    if args.cache_text_encoder_outputs:
        # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
        text_encoder1.to("cpu", dtype=torch.float32)
        clean_memory_on_device(accelerator.device)
    else:
        # make sure Text Encoders are on GPU
        text_encoder1.to(accelerator.device)

    # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
    if args.full_fp16:
        train_util.patch_accelerator_for_fp16_training(accelerator)

    # resumeする
    train_util.resume_from_local_or_hf_if_specified(accelerator, args)

    # epoch数を計算する
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
    if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
        args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1

    # 学習する
    # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
    accelerator.print("running training / 学習開始")
    accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
    accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
    accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
    accelerator.print(
        f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
    )
    # accelerator.print(
    #     f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
    # )
    accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
    accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")

    progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
    global_step = 0

    # 謎のクラス GDF
    gdf = sc.GDF(
        schedule=sc.CosineSchedule(clamp_range=[0.0001, 0.9999]),
        input_scaler=sc.VPScaler(),
        target=sc.EpsilonTarget(),
        noise_cond=sc.CosineTNoiseCond(),
        loss_weight=sc.AdaptiveLossWeight(),
    )

    # 以下2つの変数は、どうもデフォルトのままっぽい
    # gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges'])
    # gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses'])

    # noise_scheduler = DDPMScheduler(
    #     beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
    # )
    # prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
    # if args.zero_terminal_snr:
    #     custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)

    if accelerator.is_main_process:
        init_kwargs = {}
        if args.wandb_run_name:
            init_kwargs["wandb"] = {"name": args.wandb_run_name}
        if args.log_tracker_config is not None:
            init_kwargs = toml.load(args.log_tracker_config)
        accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)

    # # For --sample_at_first
    # sdxl_train_util.sample_images(
    #     accelerator,
    #     args,
    #     0,
    #     global_step,
    #     accelerator.device,
    #     effnet,
    #     [tokenizer1, tokenizer2],
    #     [text_encoder1, text_encoder2],
    #     stage_c,
    # )

    loss_recorder = train_util.LossRecorder()
    for epoch in range(num_train_epochs):
        accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
        current_epoch.value = epoch + 1

        for m in training_models:
            m.train()

        for step, batch in enumerate(train_dataloader):
            current_step.value = global_step
            with accelerator.accumulate(*training_models):
                if "latents" in batch and batch["latents"] is not None:
                    latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
                else:
                    with torch.no_grad():
                        # latentに変換
                        latents = effnet(batch["images"].to(effnet_dtype)).to(weight_dtype)

                        # NaNが含まれていれば警告を表示し0に置き換える
                        if torch.any(torch.isnan(latents)):
                            accelerator.print("NaN found in latents, replacing with zeros")
                            latents = torch.nan_to_num(latents, 0, out=latents)

                if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
                    input_ids1 = batch["input_ids"]
                    with torch.no_grad():
                        # Get the text embedding for conditioning
                        # TODO support weighted captions
                        input_ids1 = input_ids1.to(accelerator.device)
                        # unwrap_model is fine for models not wrapped by accelerator
                        encoder_hidden_states, pool = sc.get_clip_conditions(None, input_ids1, tokenizer, text_encoder1)
                else:
                    # NOTE(review): key names mix "outputs1" and "pool2" — presumably inherited
                    # from the SDXL dataset keys; confirm against the dataset implementation.
                    encoder_hidden_states = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
                    pool = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)

                # FORWARD PASS
                with torch.no_grad():
                    noised, noise, target, logSNR, noise_cond, loss_weight = gdf.diffuse(latents, shift=1, loss_shift=1)

                # image-embedding conditioning is unused for text-only training; feed zeros
                zero_img_emb = torch.zeros(noised.shape[0], 768, device=accelerator.device)
                with accelerator.autocast():
                    pred = stage_c(
                        noised, noise_cond, clip_text=encoder_hidden_states, clip_text_pooled=pool, clip_img=zero_img_emb
                    )
                # per-sample loss (reduced over C/H/W only) so the adaptive weight can be applied per sample
                loss = torch.nn.functional.mse_loss(pred, target, reduction="none").mean(dim=[1, 2, 3])
                loss_adjusted = (loss * loss_weight).mean()

                gdf.loss_weight.update_buckets(logSNR, loss)

                accelerator.backward(loss_adjusted)
                if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                    params_to_clip = []
                    for m in training_models:
                        params_to_clip.extend(m.parameters())
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                # sdxl_train_util.sample_images(
                #     accelerator,
                #     args,
                #     None,
                #     global_step,
                #     accelerator.device,
                #     effnet,
                #     [tokenizer1, tokenizer2],
                #     [text_encoder1, text_encoder2],
                #     stage_c,
                # )

                # 指定ステップごとにモデルを保存
                if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
                    accelerator.wait_for_everyone()
                    if accelerator.is_main_process:
                        sc_utils.save_stage_c_model_on_epoch_end_or_stepwise(
                            args,
                            False,
                            accelerator,
                            save_dtype,
                            epoch,
                            num_train_epochs,
                            global_step,
                            # fix: was `accelerator.accelerator.unwrap_model(...)`, which
                            # raises AttributeError (cf. the epoch-end save below)
                            accelerator.unwrap_model(stage_c),
                        )

            # fix: `loss` is a per-sample vector here, so `.item()` would raise for
            # batch size > 1 — average over the batch first
            current_loss = loss.detach().mean().item()
            if args.logging_dir is not None:
                logs = {"loss": current_loss}
                train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)

                accelerator.log(logs, step=global_step)

            loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
            avr_loss: float = loss_recorder.moving_average
            logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            if global_step >= args.max_train_steps:
                break

        if args.logging_dir is not None:
            logs = {"loss/epoch": loss_recorder.moving_average}
            accelerator.log(logs, step=epoch + 1)

        accelerator.wait_for_everyone()

        if args.save_every_n_epochs is not None:
            if accelerator.is_main_process:
                sc_utils.save_stage_c_model_on_epoch_end_or_stepwise(
                    args, True, accelerator, save_dtype, epoch, num_train_epochs, global_step, accelerator.unwrap_model(stage_c)
                )

        # sdxl_train_util.sample_images(
        #     accelerator,
        #     args,
        #     epoch + 1,
        #     global_step,
        #     accelerator.device,
        #     effnet,
        #     [tokenizer1, tokenizer2],
        #     [text_encoder1, text_encoder2],
        #     stage_c,
        # )

    is_main_process = accelerator.is_main_process
    # if is_main_process:
    stage_c = accelerator.unwrap_model(stage_c)

    accelerator.end_training()

    if args.save_state:  # and is_main_process:
        train_util.save_state_on_train_end(args, accelerator)

    del accelerator  # この後メモリを使うのでこれは消す

    if is_main_process:
        sc_utils.save_stage_c_model_on_end(args, save_dtype, epoch, global_step, stage_c)
        logger.info("model saved.")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
    """Build the CLI parser for Stage C training (model paths, dataset, optimizer, logging)."""
    arg_parser = argparse.ArgumentParser()

    # single-argument registrars; order only affects --help output
    for register in (
        add_logging_arguments,
        sc_utils.add_effnet_arguments,
        sc_utils.add_stage_c_arguments,
        sc_utils.add_text_model_arguments,
        train_util.add_tokenizer_arguments,
    ):
        register(arg_parser)

    train_util.add_dataset_arguments(arg_parser, True, True, True)
    train_util.add_training_arguments(arg_parser, False)
    train_util.add_optimizer_arguments(arg_parser)
    config_util.add_config_arguments(arg_parser)
    add_sdxl_training_arguments(arg_parser)  # cache text encoder outputs

    arg_parser.add_argument(
        "--no_half_vae",
        action="store_true",
        help="do not use fp16/bf16 Effnet in mixed precision (use float Effnet) / mixed precisionでも fp16/bf16 Effnetを使わずfloat Effnetを使う",
    )

    return arg_parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Entry point: parse CLI arguments, merge in values from an optional
    # config file, then run Stage C training.
    parser = setup_parser()

    args = parser.parse_args()
    args = train_util.read_config_from_file(args, parser)

    train(args)
|
||||
191
tools/stable_cascade_cache_latents.py
Normal file
191
tools/stable_cascade_cache_latents.py
Normal file
@@ -0,0 +1,191 @@
|
||||
# Stable Cascadeのlatentsをdiskにキャッシュする
|
||||
# cache latents of Stable Cascade to disk
|
||||
|
||||
import argparse
|
||||
import math
|
||||
from multiprocessing import Value
|
||||
import os
|
||||
|
||||
from accelerate.utils import set_seed
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from library import stable_cascade_utils as sc_utils
|
||||
from library import config_util
|
||||
from library import train_util
|
||||
from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
)
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def cache_to_disk(args: argparse.Namespace) -> None:
    """Pre-compute Stable Cascade latents with the EfficientNet encoder and cache them to disk.

    Iterates the configured dataset in caching mode and writes one ``.npz``
    latent file next to each image. Requires ``--cache_latents_to_disk``;
    ``--skip_existing`` skips images whose cache file already matches.
    """
    train_util.prepare_dataset_args(args, True)

    # check cache latents arg
    assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります"

    use_dreambooth_method = args.in_json is None  # no metadata json -> DreamBooth-style folder layout

    if args.seed is not None:
        set_seed(args.seed)  # 乱数系列を初期化する

    # tokenizerを準備する:datasetを動かすために必要
    tokenizer = sc_utils.load_tokenizer(args)
    tokenizers = [tokenizer]

    # データセットを準備する
    if args.dataset_class is None:
        blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
        if args.dataset_config is not None:
            logger.info(f"Load dataset config from {args.dataset_config}")
            user_config = config_util.load_user_config(args.dataset_config)
            ignored = ["train_data_dir", "in_json"]
            if any(getattr(args, attr) is not None for attr in ignored):
                logger.warning(
                    "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
                        ", ".join(ignored)
                    )
                )
        else:
            if use_dreambooth_method:
                logger.info("Using DreamBooth method.")
                user_config = {
                    "datasets": [
                        {
                            "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
                                args.train_data_dir, args.reg_data_dir
                            )
                        }
                    ]
                }
            else:
                logger.info("Training with captions.")
                user_config = {
                    "datasets": [
                        {
                            "subsets": [
                                {
                                    "image_dir": args.train_data_dir,
                                    "metadata_file": args.in_json,
                                }
                            ]
                        }
                    ]
                }

        blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers)
        train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
    else:
        train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers)

    # datasetのcache_latentsを呼ばなければ、生の画像が返る

    current_epoch = Value("i", 0)
    current_step = Value("i", 0)
    ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
    collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)

    # acceleratorを準備する
    logger.info("prepare accelerator")
    accelerator = train_util.prepare_accelerator(args)

    # mixed precisionに対応した型を用意しておき適宜castする
    weight_dtype, _ = train_util.prepare_dtype(args)
    effnet_dtype = torch.float32 if args.no_half_vae else weight_dtype

    # モデルを読み込む
    logger.info("load model")
    effnet = sc_utils.load_effnet(args.effnet_checkpoint_path, accelerator.device)
    effnet.to(accelerator.device, dtype=effnet_dtype)
    effnet.requires_grad_(False)
    effnet.eval()

    # dataloaderを準備する
    train_dataset_group.set_caching_mode("latents")

    # DataLoaderのプロセス数:0はメインプロセスになる
    n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1 ただし最大で指定された数まで

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset_group,
        batch_size=1,
        shuffle=True,
        collate_fn=collator,
        num_workers=n_workers,
        persistent_workers=args.persistent_data_loader_workers,
    )

    # acceleratorを使ってモデルを準備する:マルチGPUで使えるようになるはず
    train_dataloader = accelerator.prepare(train_dataloader)

    # データ取得のためのループ
    for batch in tqdm(train_dataloader):
        b_size = len(batch["images"])
        vae_batch_size = b_size if args.vae_batch_size is None else args.vae_batch_size
        flip_aug = batch["flip_aug"]
        random_crop = batch["random_crop"]
        bucket_reso = batch["bucket_reso"]

        # バッチを分割して処理する / encode in chunks of vae_batch_size
        for i in range(0, b_size, vae_batch_size):
            images = batch["images"][i : i + vae_batch_size]
            absolute_paths = batch["absolute_paths"][i : i + vae_batch_size]
            resized_sizes = batch["resized_sizes"][i : i + vae_batch_size]

            image_infos = []
            # fix: the inner loop previously used `enumerate` and rebound `i`,
            # shadowing the chunk index above; the index was unused, so iterate with zip
            for image, absolute_path, resized_size in zip(images, absolute_paths, resized_sizes):
                image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path)
                image_info.image = image
                image_info.bucket_reso = bucket_reso
                image_info.resized_size = resized_size
                image_info.latents_npz = os.path.splitext(absolute_path)[0] + sc_utils.LATENTS_CACHE_SUFFIX

                if args.skip_existing:
                    if sc_utils.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug):
                        logger.warning(f"Skipping {image_info.latents_npz} because it already exists.")
                        continue

                image_infos.append(image_info)

            if len(image_infos) > 0:
                sc_utils.cache_batch_latents(effnet, True, image_infos, flip_aug, random_crop, accelerator.device, effnet_dtype)

    accelerator.wait_for_everyone()
    accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the latent caching script."""
    parser = argparse.ArgumentParser()

    # Argument groups shared with the training scripts.
    train_util.add_tokenizer_arguments(parser)
    sc_utils.add_effnet_arguments(parser)
    train_util.add_training_arguments(parser, True)
    train_util.add_dataset_arguments(parser, True, True, True)
    config_util.add_config_arguments(parser)

    # Script-specific boolean flags (flag -> help text).
    script_flags = {
        "--no_half_vae": "do not use fp16/bf16 Effnet in mixed precision (use float Effnet) / mixed precisionでも fp16/bf16 Effnetを使わずfloat Effnetを使う",
        "--skip_existing": "skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)",
    }
    for flag, help_text in script_flags.items():
        parser.add_argument(flag, action="store_true", help=help_text)

    return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Parse CLI arguments, merge in any config-file overrides, then run caching.
    cli_parser = setup_parser()
    cli_args = cli_parser.parse_args()
    cli_args = train_util.read_config_from_file(cli_args, cli_parser)
    cache_to_disk(cli_args)
||||
183
tools/stable_cascade_cache_text_encoder_outputs.py
Normal file
183
tools/stable_cascade_cache_text_encoder_outputs.py
Normal file
@@ -0,0 +1,183 @@
|
||||
# text encoder出力のdiskへの事前キャッシュを行う / cache text encoder outputs to disk in advance
|
||||
|
||||
import argparse
|
||||
import math
|
||||
from multiprocessing import Value
|
||||
import os
|
||||
|
||||
from accelerate.utils import set_seed
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from library import config_util
|
||||
from library import train_util
|
||||
from library import sdxl_train_util
|
||||
from library import stable_cascade_utils as sc_utils
|
||||
from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
)
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def cache_to_disk(args: argparse.Namespace) -> None:
    """Pre-compute text encoder outputs for every caption and cache them to disk
    as npz files, so training can run without the text encoder in memory.

    Requires ``args.cache_text_encoder_outputs_to_disk`` to be set. Dataset,
    accelerator and model options come from the shared training arguments.
    """
    train_util.prepare_dataset_args(args, True)

    # caching to disk must be explicitly requested
    assert (
        args.cache_text_encoder_outputs_to_disk
    ), "cache_text_encoder_outputs_to_disk must be True / cache_text_encoder_outputs_to_diskはTrueである必要があります"

    # DreamBooth-style (folder-based) dataset when no metadata JSON is given
    use_dreambooth_method = args.in_json is None

    if args.seed is not None:
        set_seed(args.seed)  # initialize the random number sequence

    # prepare the tokenizer: required to drive the dataset
    tokenizer = sc_utils.load_tokenizer(args)
    tokenizers = [tokenizer]

    # prepare the dataset
    if args.dataset_class is None:
        blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
        if args.dataset_config is not None:
            logger.info(f"Load dataset config from {args.dataset_config}")
            user_config = config_util.load_user_config(args.dataset_config)
            ignored = ["train_data_dir", "in_json"]
            if any(getattr(args, attr) is not None for attr in ignored):
                logger.warning(
                    "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
                        ", ".join(ignored)
                    )
                )
        else:
            if use_dreambooth_method:
                logger.info("Using DreamBooth method.")
                user_config = {
                    "datasets": [
                        {
                            "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
                                args.train_data_dir, args.reg_data_dir
                            )
                        }
                    ]
                }
            else:
                logger.info("Training with captions.")
                user_config = {
                    "datasets": [
                        {
                            "subsets": [
                                {
                                    "image_dir": args.train_data_dir,
                                    "metadata_file": args.in_json,
                                }
                            ]
                        }
                    ]
                }

        blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers)
        train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
    else:
        train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers)

    # shared counters for the collator (epoch/step bookkeeping)
    current_epoch = Value("i", 0)
    current_step = Value("i", 0)
    ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
    collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)

    # prepare the accelerator
    logger.info("prepare accelerator")
    accelerator = train_util.prepare_accelerator(args)

    # dtype matching the mixed precision setting, to cast as needed
    weight_dtype, _ = train_util.prepare_dtype(args)

    # load the model
    logger.info("load model")
    text_encoder = sc_utils.load_clip_text_model(
        args.text_model_checkpoint_path, weight_dtype, accelerator.device, args.save_text_model
    )
    text_encoders = [text_encoder]
    for text_encoder in text_encoders:
        text_encoder.to(accelerator.device, dtype=weight_dtype)
        text_encoder.requires_grad_(False)
        text_encoder.eval()

    # prepare the dataloader: switch the dataset to text caching mode
    train_dataset_group.set_caching_mode("text")

    # number of DataLoader workers: 0 means the main process
    n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1, capped at the requested number

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset_group,
        batch_size=1,
        shuffle=True,
        collate_fn=collator,
        num_workers=n_workers,
        persistent_workers=args.persistent_data_loader_workers,
    )

    # prepare with accelerator so multi-GPU sharding should work
    train_dataloader = accelerator.prepare(train_dataloader)

    # main loop: fetch batches and cache their text encoder outputs
    for batch in tqdm(train_dataloader):
        absolute_paths = batch["absolute_paths"]
        input_ids1_list = batch["input_ids1_list"]

        image_infos = []
        for absolute_path, input_ids1 in zip(absolute_paths, input_ids1_list):
            image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path)
            image_info.text_encoder_outputs_npz = os.path.splitext(absolute_path)[0] + sc_utils.TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX
            # (removed a stray no-op expression statement `image_info` here)

            if args.skip_existing:
                if os.path.exists(image_info.text_encoder_outputs_npz):
                    logger.warning(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.")
                    continue

            image_info.input_ids1 = input_ids1
            image_infos.append(image_info)

        if len(image_infos) > 0:
            b_input_ids1 = torch.stack([image_info.input_ids1 for image_info in image_infos])
            sc_utils.cache_batch_text_encoder_outputs(
                image_infos, tokenizers, text_encoders, args.max_token_length, True, b_input_ids1, weight_dtype
            )

    accelerator.wait_for_everyone()
    # fixed copy-paste message: this script caches text encoder outputs, not latents
    accelerator.print(f"Finished caching text encoder outputs for {len(train_dataset_group)} batches.")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the text-encoder-output caching script."""
    parser = argparse.ArgumentParser()

    # Registration table for the argument groups shared with the training
    # scripts: (registration function, extra positional arguments).
    registrations = (
        (train_util.add_tokenizer_arguments, ()),
        (sc_utils.add_text_model_arguments, ()),
        (train_util.add_training_arguments, (True,)),
        (train_util.add_dataset_arguments, (True, True, True)),
        (config_util.add_config_arguments, ()),
        (sdxl_train_util.add_sdxl_training_arguments, ()),
    )
    for register, extra in registrations:
        register(parser, *extra)

    # Script-specific flag.
    parser.add_argument(
        "--skip_existing",
        action="store_true",
        help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)",
    )
    return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Entry point: build the parser, resolve args (CLI plus optional config
    # file), then run the caching pass.
    arg_parser = setup_parser()
    cache_to_disk(train_util.read_config_from_file(arg_parser.parse_args(), arg_parser))
|
||||
Reference in New Issue
Block a user