add stage c tmp training code

This commit is contained in:
Kohya S
2024-02-17 23:59:20 +09:00
parent fa440208b7
commit 319bbf8057
8 changed files with 1577 additions and 113 deletions

View File

@@ -6,8 +6,10 @@ import os
from typing import List, Optional, Tuple, Union
import safetensors
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
r"""
@@ -55,11 +57,13 @@ ARCH_SD_V1 = "stable-diffusion-v1"
ARCH_SD_V2_512 = "stable-diffusion-v2-512"
ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v"
ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base"
ARCH_STABLE_CASCADE = "stable-cascade"
ADAPTER_LORA = "lora"
ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models"
IMPL_STABILITY_AI_STABLE_CASCADE = "https://github.com/Stability-AI/StableCascade"
IMPL_DIFFUSERS = "diffusers"
PRED_TYPE_EPSILON = "epsilon"
@@ -113,6 +117,7 @@ def build_metadata(
merged_from: Optional[str] = None,
timesteps: Optional[Tuple[int, int]] = None,
clip_skip: Optional[int] = None,
stable_cascade: Optional[bool] = None,
):
# if state_dict is None, hash is not calculated
@@ -124,7 +129,9 @@ def build_metadata(
# hash = precalculate_safetensors_hashes(state_dict)
# metadata["modelspec.hash_sha256"] = hash
if sdxl:
if stable_cascade:
arch = ARCH_STABLE_CASCADE
elif sdxl:
arch = ARCH_SD_XL_V1_BASE
elif v2:
if v_parameterization:
@@ -142,9 +149,11 @@ def build_metadata(
metadata["modelspec.architecture"] = arch
if not lora and not textual_inversion and is_stable_diffusion_ckpt is None:
is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion
is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion
if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
if stable_cascade:
impl = IMPL_STABILITY_AI_STABLE_CASCADE
elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
# Stable Diffusion ckpt, TI, SDXL LoRA
impl = IMPL_STABILITY_AI
else:
@@ -236,7 +245,7 @@ def build_metadata(
# assert all([v is not None for v in metadata.values()]), metadata
if not all([v is not None for v in metadata.values()]):
logger.error(f"Internal error: some metadata values are None: {metadata}")
return metadata
@@ -250,7 +259,7 @@ def get_title(metadata: dict) -> Optional[str]:
def load_metadata_from_safetensors(model: str) -> dict:
if not model.endswith(".safetensors"):
return {}
with safetensors.safe_open(model, framework="pt") as f:
metadata = f.metadata()
if metadata is None:

View File

@@ -3,7 +3,7 @@
# https://github.com/Stability-AI/StableCascade
import math
from typing import List
from typing import List, Optional
import numpy as np
import torch
import torch.nn as nn
@@ -901,15 +901,19 @@ class StageC(nn.Module):
self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)
def get_clip_conditions(captions: List[str], tokenizer, text_model):
def get_clip_conditions(captions: Optional[List[str]], input_ids, tokenizer, text_model):
# self, batch: dict, tokenizer, text_model, is_eval=False, is_unconditional=False, eval_image_embeds=False, return_fields=None
# is_eval の処理をここでやるのは微妙なので別のところでやる
# is_unconditional もここでやるのは微妙なので別のところでやる
# clip_image はとりあえずサポートしない
clip_tokens_unpooled = tokenizer(
captions, truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
).to(text_model.device)
text_encoder_output = text_model(**clip_tokens_unpooled, output_hidden_states=True)
if captions is not None:
clip_tokens_unpooled = tokenizer(
captions, truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
).to(text_model.device)
text_encoder_output = text_model(**clip_tokens_unpooled, output_hidden_states=True)
else:
text_encoder_output = text_model(input_ids, output_hidden_states=True)
text_embeddings = text_encoder_output.hidden_states[-1]
text_pooled_embeddings = text_encoder_output.text_embeds.unsqueeze(1)
@@ -1262,4 +1266,108 @@ class CosineTNoiseCond(BaseNoiseCond):
return t
# --- Loss Weighting
class BaseLossWeight:
    """Base class for loss weightings expressed as a function of log-SNR.

    Subclasses implement :meth:`weight`. Calling an instance optionally shifts
    the log-SNR, evaluates :meth:`weight`, and clamps the result.
    """

    def weight(self, logSNR):
        """Return the unclamped weight for ``logSNR``; subclasses must override."""
        raise NotImplementedError("this method needs to be overridden")

    def __call__(self, logSNR, *args, shift=1, clamp_range=None, **kwargs):
        """Evaluate the clamped weight.

        Args:
            logSNR: value passed to :meth:`weight` (must support ``.clone()``
                when ``shift != 1``).
            *args, **kwargs: forwarded unchanged to :meth:`weight`.
            shift: when different from 1, ``2 * log(shift)`` is added to a
                copy of ``logSNR`` before weighting.
            clamp_range: bounds unpacked into ``.clamp``; defaults to
                ``[-1e9, 1e9]``.
        """
        bounds = clamp_range if clamp_range is not None else [-1e9, 1e9]
        # Work on a copy so the caller's value is not modified by the shift.
        shifted = logSNR if shift == 1 else logSNR.clone() + 2 * np.log(shift)
        raw_weight = self.weight(shifted, *args, **kwargs)
        return raw_weight.clamp(*bounds)
# class ComposedLossWeight(BaseLossWeight):
# def __init__(self, div, mul):
# self.mul = [mul] if isinstance(mul, BaseLossWeight) else mul
# self.div = [div] if isinstance(div, BaseLossWeight) else div
# def weight(self, logSNR):
# prod, div = 1, 1
# for m in self.mul:
# prod *= m.weight(logSNR)
# for d in self.div:
# div *= d.weight(logSNR)
# return prod/div
# class ConstantLossWeight(BaseLossWeight):
# def __init__(self, v=1):
# self.v = v
# def weight(self, logSNR):
# return torch.ones_like(logSNR) * self.v
# class SNRLossWeight(BaseLossWeight):
# def weight(self, logSNR):
# return logSNR.exp()
class P2LossWeight(BaseLossWeight):
    """Loss weight of the form ``(k + exp(s * logSNR)) ** -gamma``."""

    def __init__(self, k=1.0, gamma=1.0, s=1.0):
        self.k = k
        self.gamma = gamma
        self.s = s

    def weight(self, logSNR):
        """Return the unclamped weight for ``logSNR``."""
        scaled_snr = (logSNR * self.s).exp()
        denominator_base = self.k + scaled_snr
        return denominator_base ** -self.gamma
# class SNRPlusOneLossWeight(BaseLossWeight):
# def weight(self, logSNR):
# return logSNR.exp() + 1
# class MinSNRLossWeight(BaseLossWeight):
# def __init__(self, max_snr=5):
# self.max_snr = max_snr
# def weight(self, logSNR):
# return logSNR.exp().clamp(max=self.max_snr)
# class MinSNRPlusOneLossWeight(BaseLossWeight):
# def __init__(self, max_snr=5):
# self.max_snr = max_snr
# def weight(self, logSNR):
# return (logSNR.exp() + 1).clamp(max=self.max_snr)
# class TruncatedSNRLossWeight(BaseLossWeight):
# def __init__(self, min_snr=1):
# self.min_snr = min_snr
# def weight(self, logSNR):
# return logSNR.exp().clamp(min=self.min_snr)
# class SechLossWeight(BaseLossWeight):
# def __init__(self, div=2):
# self.div = div
# def weight(self, logSNR):
# return 1/(logSNR/self.div).cosh()
# class DebiasedLossWeight(BaseLossWeight):
# def weight(self, logSNR):
# return 1/logSNR.exp().sqrt()
# class SigmoidLossWeight(BaseLossWeight):
# def __init__(self, s=1):
# self.s = s
# def weight(self, logSNR):
# return (logSNR * self.s).sigmoid()
class AdaptiveLossWeight(BaseLossWeight):
    """Loss weight that tracks a running loss per log-SNR bucket.

    The weight for a sample is the reciprocal of the running loss stored for
    the bucket its log-SNR falls into, clamped to ``weight_range``.

    Args:
        logsnr_range: ``(low, high)`` span covered by the bucket boundaries.
        buckets: number of buckets; ``buckets - 1`` boundaries are created so
            that ``searchsorted`` yields indices in ``[0, buckets - 1]``.
        weight_range: ``(min, max)`` clamp applied to the returned weight.
    """

    # Defaults are immutable tuples and ``weight_range`` is copied, so
    # mutating one instance's range can no longer alter the shared default.
    def __init__(self, logsnr_range=(-10, 10), buckets=300, weight_range=(1e-7, 1e7)):
        self.bucket_ranges = torch.linspace(logsnr_range[0], logsnr_range[1], buckets - 1)
        self.bucket_losses = torch.ones(buckets)
        self.weight_range = list(weight_range)

    def weight(self, logSNR):
        """Return ``1 / running_loss`` of each sample's bucket, clamped."""
        indices = torch.searchsorted(self.bucket_ranges.to(logSNR.device), logSNR)
        return (1 / self.bucket_losses.to(logSNR.device)[indices]).clamp(*self.weight_range)

    def update_buckets(self, logSNR, loss, beta=0.99):
        """Blend ``loss`` into the running loss of the matching buckets.

        The running values are updated as ``old * beta + loss * (1 - beta)``.
        ``loss`` is detached and moved to the CPU, and the bucket indices are
        moved to the CPU, before the update.
        """
        indices = torch.searchsorted(self.bucket_ranges.to(logSNR.device), logSNR).cpu()
        self.bucket_losses[indices] = self.bucket_losses[indices] * beta + loss.detach().cpu() * (1 - beta)
# endregion gdf

View File

@@ -0,0 +1,504 @@
import argparse
import os
import time
from typing import List
import numpy as np
import torch
import torchvision
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from transformers import CLIPTokenizer, CLIPTextModelWithProjection, CLIPTextConfig
from accelerate import init_empty_weights
from library import stable_cascade as sc
from library.train_util import (
ImageInfo,
load_image,
trim_and_resize_if_required,
save_latents_to_disk,
HIGH_VRAM,
save_text_encoder_outputs_to_disk,
)
from library.sdxl_model_util import _load_state_dict_on_device
from library.device_utils import clean_memory_on_device
from library.train_util import save_sd_model_on_epoch_end_or_stepwise_common, save_sd_model_on_train_end_common
from library import sai_model_spec
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
CLIP_TEXT_MODEL_NAME: str = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
EFFNET_PREPROCESS = torchvision.transforms.Compose(
[torchvision.transforms.ToTensor(), torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))]
)
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_sc_te_outputs.npz"
LATENTS_CACHE_SUFFIX = "_sc_latents.npz"
def load_effnet(effnet_checkpoint_path, loading_device="cpu") -> sc.EfficientNetEncoder:
    """Build an EfficientNet encoder and load its weights from a safetensors file.

    Args:
        effnet_checkpoint_path: path handed to ``safetensors.torch.load_file``.
        loading_device: device the encoder is moved to before it is returned.
            Previously this argument was accepted but silently ignored.

    Returns:
        The encoder with the checkpoint weights loaded, on ``loading_device``.
    """
    logger.info(f"Loading EfficientNet encoder from {effnet_checkpoint_path}")
    effnet = sc.EfficientNetEncoder()
    effnet_checkpoint = load_file(effnet_checkpoint_path)
    # Some checkpoints nest the weights under a "state_dict" key.
    state_dict = effnet_checkpoint["state_dict"] if "state_dict" in effnet_checkpoint else effnet_checkpoint
    info = effnet.load_state_dict(state_dict)
    logger.info(info)
    del state_dict, effnet_checkpoint
    effnet.to(loading_device)
    return effnet
def load_tokenizer(args: argparse.Namespace):
    """Load the CLIP tokenizer, optionally through a local cache directory.

    When ``args.tokenizer_cache_dir`` is set, the tokenizer is read from that
    directory if a copy exists; otherwise it is fetched by model name and then
    saved there.

    Returns:
        The ``CLIPTokenizer`` for ``CLIP_TEXT_MODEL_NAME``.
    """
    # TODO commonize with sdxl_train_util.load_tokenizers
    logger.info("prepare tokenizers")

    cache_dir = args.tokenizer_cache_dir
    cached_path = os.path.join(cache_dir, CLIP_TEXT_MODEL_NAME.replace("/", "_")) if cache_dir else None

    tokenizer: CLIPTokenizer = None
    if cached_path is not None and os.path.exists(cached_path):
        logger.info(f"load tokenizer from cache: {cached_path}")
        tokenizer = CLIPTokenizer.from_pretrained(cached_path)

    if tokenizer is None:
        tokenizer = CLIPTokenizer.from_pretrained(CLIP_TEXT_MODEL_NAME)

    if cached_path is not None and not os.path.exists(cached_path):
        logger.info(f"save Tokenizer to cache: {cached_path}")
        tokenizer.save_pretrained(cached_path)

    # NOTE: only logs the requested length; the tokenizer itself is not modified here.
    if getattr(args, "max_token_length", None) is not None:
        logger.info(f"update token length: {args.max_token_length}")

    return tokenizer
def load_stage_c_model(stage_c_checkpoint_path, dtype=None, device="cpu") -> sc.StageC:
    """Instantiate the Stage C generator and load its weights.

    The module is created under ``init_empty_weights()`` and the checkpoint is
    then applied with ``_load_state_dict_on_device`` using ``device`` and
    ``dtype``.

    Args:
        stage_c_checkpoint_path: safetensors file with the Stage C weights.
        dtype: forwarded to ``_load_state_dict_on_device``.
        device: forwarded to ``_load_state_dict_on_device``.
    """
    logger.info("Instantiating Stage C generator")
    with init_empty_weights():
        stage_c = sc.StageC()

    logger.info(f"Loading Stage C generator from {stage_c_checkpoint_path}")
    weights = load_file(stage_c_checkpoint_path)

    logger.info("Loading state dict")
    load_result = _load_state_dict_on_device(stage_c, weights, device, dtype=dtype)
    logger.info(load_result)
    return stage_c
def load_stage_b_model(stage_b_checkpoint_path, dtype=None, device="cpu") -> sc.StageB:
    """Instantiate the Stage B generator and load its weights.

    The module is created under ``init_empty_weights()`` and the checkpoint is
    then applied with ``_load_state_dict_on_device`` using ``device`` and
    ``dtype``.

    Args:
        stage_b_checkpoint_path: safetensors file with the Stage B weights.
        dtype: forwarded to ``_load_state_dict_on_device``.
        device: forwarded to ``_load_state_dict_on_device``.
    """
    logger.info("Instantiating Stage B generator")
    with init_empty_weights():
        stage_b = sc.StageB()

    logger.info(f"Loading Stage B generator from {stage_b_checkpoint_path}")
    weights = load_file(stage_b_checkpoint_path)

    logger.info("Loading state dict")
    load_result = _load_state_dict_on_device(stage_b, weights, device, dtype=dtype)
    logger.info(load_result)
    return stage_b
def load_clip_text_model(text_model_checkpoint_path, dtype=None, device="cpu", save_text_model=False):
    """Load the CLIP text encoder (with projection) used for conditioning.

    Two sources are supported:

    * ``save_text_model`` is true or no checkpoint path is given: the model is
      loaded with ``from_pretrained(CLIP_TEXT_MODEL_NAME)``; when
      ``save_text_model`` is true its state dict is also written to
      ``text_model_checkpoint_path``.
    * otherwise: the model is built from a hard-coded ``CLIPTextConfig`` under
      ``init_empty_weights()`` and the local checkpoint is applied.

    In both cases the returned model is placed on ``device`` and, when given,
    cast to ``dtype``.

    Args:
        text_model_checkpoint_path: safetensors path to read from, or to write
            to when ``save_text_model`` is true. May be ``None`` only when
            ``save_text_model`` is false.
        dtype: target dtype, or ``None`` to keep the loaded dtype.
        device: target device.
        save_text_model: save the pretrained weights to the checkpoint path.

    Raises:
        ValueError: ``save_text_model`` is true but no path was given.
    """
    # Fail before the (large) model is fetched instead of inside save_file().
    if save_text_model and text_model_checkpoint_path is None:
        raise ValueError("text_model_checkpoint_path must be specified when save_text_model is enabled")

    # CLIP encoders
    logger.info(f"Loading CLIP text model")
    if save_text_model or text_model_checkpoint_path is None:
        logger.info(f"Loading CLIP text model from {CLIP_TEXT_MODEL_NAME}")
        text_model = CLIPTextModelWithProjection.from_pretrained(CLIP_TEXT_MODEL_NAME)

        if save_text_model:
            # Save before any dtype cast so the file keeps the loaded weights.
            sd = text_model.state_dict()
            logger.info(f"Saving CLIP text model to {text_model_checkpoint_path}")
            save_file(sd, text_model_checkpoint_path)

        # Honour device/dtype here too, matching the checkpoint branch below.
        text_model.to(device)
        if dtype is not None:
            text_model.to(dtype)
        return text_model

    logger.info(f"Loading CLIP text model from {text_model_checkpoint_path}")

    # copy from sdxl_model_util.py
    text_model2_cfg = CLIPTextConfig(
        vocab_size=49408,
        hidden_size=1280,
        intermediate_size=5120,
        num_hidden_layers=32,
        num_attention_heads=20,
        max_position_embeddings=77,
        hidden_act="gelu",
        layer_norm_eps=1e-05,
        dropout=0.0,
        attention_dropout=0.0,
        initializer_range=0.02,
        initializer_factor=1.0,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        model_type="clip_text_model",
        projection_dim=1280,
        # torch_dtype="float32",
        # transformers_version="4.25.0.dev0",
    )
    with init_empty_weights():
        text_model = CLIPTextModelWithProjection(text_model2_cfg)

    text_model_checkpoint = load_file(text_model_checkpoint_path)
    info = _load_state_dict_on_device(text_model, text_model_checkpoint, device, dtype=dtype)
    logger.info(info)

    return text_model
def load_stage_a_model(stage_a_checkpoint_path, dtype=None, device="cpu") -> sc.StageA:
    """Create the Stage A model on ``device`` and load its weights.

    Args:
        stage_a_checkpoint_path: safetensors file with the Stage A weights.
        dtype: accepted for signature parity with the other loaders; it is
            not applied to the model by this function.
        device: device the freshly constructed model is moved to.
    """
    logger.info(f"Loading Stage A vqGAN from {stage_a_checkpoint_path}")
    stage_a = sc.StageA().to(device)

    checkpoint = load_file(stage_a_checkpoint_path)
    # Some checkpoints nest the weights under a "state_dict" key.
    if "state_dict" in checkpoint:
        weights = checkpoint["state_dict"]
    else:
        weights = checkpoint

    load_result = stage_a.load_state_dict(weights)
    logger.info(load_result)
    return stage_a
def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool):
    """Check whether a latents cache file exists and matches the bucket size.

    Args:
        reso: bucket resolution as ``(width, height)``.
        npz_path: path of the ``.npz`` cache file.
        flip_aug: when true, the file must also hold matching flipped latents.

    Returns:
        ``True`` only if the file exists, contains ``latents``,
        ``original_size`` and ``crop_ltrb`` (plus ``latents_flipped`` when
        ``flip_aug`` is set), and the latents' dimensions 1 and 2 equal
        ``(height // 32, width // 32)``.
    """
    # NOTE: bucket_reso is WxH, so the expected size swaps the two entries.
    expected_latents_size = (reso[1] // 32, reso[0] // 32)

    if not os.path.exists(npz_path):
        return False

    # np.load keeps the archive open; the context manager closes it for us.
    with np.load(npz_path) as npz:
        required_keys = ("latents", "original_size", "crop_ltrb")
        if any(key not in npz for key in required_keys):  # old ver?
            return False
        if npz["latents"].shape[1:3] != expected_latents_size:
            return False

        if flip_aug:
            if "latents_flipped" not in npz:
                return False
            if npz["latents_flipped"].shape[1:3] != expected_latents_size:
                return False

    return True
def cache_batch_latents(
    effnet: sc.EfficientNetEncoder,
    cache_to_disk: bool,
    image_infos: List[ImageInfo],
    flip_aug: bool,
    random_crop: bool,
    device,
    dtype,
) -> None:
    r"""
    Encode one batch of images with the EfficientNet encoder and cache the result.

    requires image_infos to have: absolute_path, bucket_reso, resized_size, latents_npz
    optionally requires image_infos to have: image
    if cache_to_disk is True, the latents are written to info.latents_npz
        flipped latents is also saved if flip_aug is True
    if cache_to_disk is False, set info.latents
        latents_flipped is also set if flip_aug is True
    latents_original_size and latents_crop_ltrb are also set

    Raises:
        RuntimeError: an image's latents (or flipped latents) contain NaN.
    """
    images = []
    for info in image_infos:
        image = load_image(info.absolute_path) if info.image is None else np.array(info.image, np.uint8)
        # TODO: image metadata can be corrupt, so the bucket assigned from the metadata may not
        # match the actual image size; a check should be added here.
        image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
        images.append(EFFNET_PREPROCESS(image))

        info.latents_original_size = original_size
        info.latents_crop_ltrb = crop_ltrb

    img_tensors = torch.stack(images, dim=0).to(device=device, dtype=dtype)

    with torch.no_grad():
        latents = effnet(img_tensors).to("cpu")

        if flip_aug:
            # Flip along the last dimension of the stacked batch.
            flipped_latents = effnet(torch.flip(img_tensors, dims=[3])).to("cpu")
        else:
            flipped_latents = [None] * len(latents)

    for info, latent, flipped_latent in zip(image_infos, latents, flipped_latents):
        # Check this image's own latent (not the whole batch) so that the
        # error message names the image that actually produced NaN.
        if torch.isnan(latent).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()):
            raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")

        if cache_to_disk:
            save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent)
        else:
            info.latents = latent
            if flip_aug:
                info.latents_flipped = flipped_latent

    if not HIGH_VRAM:
        clean_memory_on_device(device)
def cache_batch_text_encoder_outputs(image_infos, tokenizers, text_encoders, max_token_length, cache_to_disk, input_ids, dtype):
    """Run the first text encoder on ``input_ids`` and cache the outputs per image.

    Only ``tokenizers[0]`` and ``text_encoders[0]`` are used;
    ``max_token_length`` and ``dtype`` are not read by this function.

    With ``cache_to_disk`` the outputs are written to
    ``info.text_encoder_outputs_npz``; otherwise they are stored on
    ``info.text_encoder_outputs1`` and ``info.text_encoder_pool2``.
    """
    # Prompts longer than 75 tokens are not supported.
    text_encoder = text_encoders[0]
    input_ids = input_ids.to(text_encoder.device)

    with torch.no_grad():
        hidden_states, pooled = sc.get_clip_conditions(None, input_ids, tokenizers[0], text_encoder)

    # Keep the cached tensors on the CPU.
    hidden_states = hidden_states.detach().to("cpu")
    pooled = pooled.detach().to("cpu")

    for info, hidden_state, pool in zip(image_infos, hidden_states, pooled):
        if not cache_to_disk:
            info.text_encoder_outputs1 = hidden_state
            info.text_encoder_pool2 = pool
            continue
        save_text_encoder_outputs_to_disk(info.text_encoder_outputs_npz, None, hidden_state, pool)
def add_effnet_arguments(parser):
    """Register the required ``--effnet_checkpoint_path`` option and return ``parser``."""
    option_settings = {
        "type": str,
        "required": True,
        "help": "path to EfficientNet checkpoint / EfficientNetのチェックポイントのパス",
    }
    parser.add_argument("--effnet_checkpoint_path", **option_settings)
    return parser
def add_text_model_arguments(parser):
    """Register the CLIP text model options and return ``parser``.

    ``--text_model_checkpoint_path`` is optional (default ``None``) so that
    ``load_clip_text_model`` can fall back to fetching the model by name when
    no local checkpoint is supplied, as the inline argparse code this helper
    replaced allowed.
    """
    parser.add_argument(
        "--text_model_checkpoint_path",
        type=str,
        required=False,
        default=None,
        help="path to CLIP text model checkpoint / CLIPテキストモデルのチェックポイントのパス",
    )
    parser.add_argument("--save_text_model", action="store_true", help="if specified, save text model to corresponding path")
    return parser
def add_stage_a_arguments(parser):
    """Register the required ``--stage_a_checkpoint_path`` option and return ``parser``."""
    option_settings = {
        "type": str,
        "required": True,
        "help": "path to Stage A checkpoint / Stage Aのチェックポイントのパス",
    }
    parser.add_argument("--stage_a_checkpoint_path", **option_settings)
    return parser
def add_stage_b_arguments(parser):
    """Register the required ``--stage_b_checkpoint_path`` option and return ``parser``."""
    option_settings = {
        "type": str,
        "required": True,
        "help": "path to Stage B checkpoint / Stage Bのチェックポイントのパス",
    }
    parser.add_argument("--stage_b_checkpoint_path", **option_settings)
    return parser
def add_stage_c_arguments(parser):
    """Register the required ``--stage_c_checkpoint_path`` option and return ``parser``."""
    option_settings = {
        "type": str,
        "required": True,
        "help": "path to Stage C checkpoint / Stage Cのチェックポイントのパス",
    }
    parser.add_argument("--stage_c_checkpoint_path", **option_settings)
    return parser
def get_sai_model_spec(args):
    """Build the model-spec metadata dict for a Stable Cascade checkpoint.

    The title falls back to ``args.output_name`` when no metadata title is
    given. A timestep range is recorded only if ``min_timestep`` or
    ``max_timestep`` is set, with 0 and 1000 filling in a missing bound.
    """
    timestamp = time.time()
    reso = args.resolution

    title = args.output_name if args.metadata_title is None else args.metadata_title

    if args.min_timestep is None and args.max_timestep is None:
        timesteps = None
    else:
        lower_bound = 0 if args.min_timestep is None else args.min_timestep
        upper_bound = 1000 if args.max_timestep is None else args.max_timestep
        timesteps = (lower_bound, upper_bound)

    # The positional flags are passed exactly as before; stable_cascade=True
    # is what selects the Stable Cascade architecture in build_metadata.
    return sai_model_spec.build_metadata(
        None,
        False,
        False,
        False,
        False,
        False,
        timestamp,
        title=title,
        reso=reso,
        is_stable_diffusion_ckpt=False,
        author=args.metadata_author,
        description=args.metadata_description,
        license=args.metadata_license,
        tags=args.metadata_tags,
        timesteps=timesteps,
        clip_skip=args.clip_skip,  # None or int
        stable_cascade=True,
    )
def save_stage_c_model_on_epoch_end_or_stepwise(
    args: argparse.Namespace,
    on_epoch_end: bool,
    accelerator,
    save_dtype: torch.dtype,
    epoch: int,
    num_train_epochs: int,
    global_step: int,
    stage_c,
):
    """Save a Stage C checkpoint via the shared epoch-end / stepwise save helper.

    The actual write is delegated to a callback handed to
    ``save_sd_model_on_epoch_end_or_stepwise_common``.
    """

    def write_checkpoint(ckpt_file, epoch_no, global_step):
        # Writes the Stage C weights as safetensors with model-spec metadata.
        metadata = get_sai_model_spec(args)
        weights = stage_c.state_dict()
        if save_dtype is not None:
            # Cast every tensor to the requested save dtype.
            weights = {name: tensor.to(save_dtype) for name, tensor in weights.items()}
        save_file(weights, ckpt_file, metadata=metadata)

    save_sd_model_on_epoch_end_or_stepwise_common(
        args, on_epoch_end, accelerator, True, True, epoch, num_train_epochs, global_step, write_checkpoint, None
    )
def save_stage_c_model_on_end(
    args: argparse.Namespace,
    save_dtype: torch.dtype,
    epoch: int,
    global_step: int,
    stage_c,
):
    """Save the final Stage C checkpoint via the shared end-of-training helper.

    The actual write is delegated to a callback handed to
    ``save_sd_model_on_train_end_common``.
    """

    def write_checkpoint(ckpt_file, epoch_no, global_step):
        # Writes the Stage C weights as safetensors with model-spec metadata.
        metadata = get_sai_model_spec(args)
        weights = stage_c.state_dict()
        if save_dtype is not None:
            # Cast every tensor to the requested save dtype.
            weights = {name: tensor.to(save_dtype) for name, tensor in weights.items()}
        save_file(weights, ckpt_file, metadata=metadata)

    save_sd_model_on_train_end_common(args, True, True, epoch, global_step, write_checkpoint, None)
def cache_latents(self, effnet, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
    """Cache EfficientNet latents for every image of the dataset ``self``.

    Images are sorted by bucket area and grouped into batches of at most
    ``vae_batch_size`` images sharing one bucket resolution, then encoded with
    ``cache_batch_latents``.

    Args:
        self: dataset object providing ``image_data`` and ``image_to_subset``.
        effnet: encoder; its first parameter supplies the device and dtype
            used for the image tensors (assumes a ``torch.nn.Module``).
        vae_batch_size: maximum number of images per encoder call.
        cache_to_disk: write latents to ``*_sc_latents.npz`` files instead of
            keeping them on the image infos.
        is_main_process: when false and ``cache_to_disk`` is set, only the
            cache paths are recorded and nothing is encoded.
    """
    # Multi-GPU is not supported here; use tools/cache_latents.py for that.
    logger.info("caching latents.")

    image_infos = list(self.image_data.values())

    # sort by resolution
    image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1])

    # split by resolution
    batches = []
    batch = []
    logger.info("checking cache validity...")
    for info in tqdm(image_infos):
        subset = self.image_to_subset[info.image_key]

        if info.latents_npz is not None:  # fine tuning dataset
            continue

        # check disk cache exists and size of latents
        if cache_to_disk:
            info.latents_npz = os.path.splitext(info.absolute_path)[0] + LATENTS_CACHE_SUFFIX
            if not is_main_process:  # store to info only
                continue

            cache_available = is_disk_cached_latents_is_expected(info.bucket_reso, info.latents_npz, subset.flip_aug)

            if cache_available:  # do not add to batch
                continue

        # if last member of batch has different resolution, flush the batch
        if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso:
            batches.append(batch)
            batch = []

        batch.append(info)

        # if number of data in batch is enough, flush the batch
        if len(batch) >= vae_batch_size:
            batches.append(batch)
            batch = []

    if len(batch) > 0:
        batches.append(batch)

    if cache_to_disk and not is_main_process:  # if cache to disk, don't cache latents in non-main process, set to info only
        return

    # cache_batch_latents requires the device and dtype for the image tensors;
    # the call previously omitted both and raised TypeError on the first batch.
    effnet_device = effnet_dtype = None
    if batches:
        reference_param = next(effnet.parameters())
        effnet_device = reference_param.device
        effnet_dtype = reference_param.dtype

    # iterate batches: batch doesn't have image, image will be loaded in cache_batch_latents and discarded
    logger.info("caching latents...")
    for batch in tqdm(batches, smoothing=1, total=len(batches)):
        # NOTE: `subset` is the one seen last in the loop above, so its
        # flip_aug / random_crop settings are applied to every batch.
        cache_batch_latents(effnet, cache_to_disk, batch, subset.flip_aug, subset.random_crop, effnet_device, effnet_dtype)
# When weight_dtype is given, the text encoders themselves are cast to weight_dtype.
def cache_text_encoder_outputs(self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True):
    """Cache text encoder outputs for every image of the dataset ``self``.

    With ``cache_to_disk`` the outputs go to ``*_sc_te_outputs.npz`` files and
    images whose file already exists are skipped; a non-main process only
    records the file paths on the image infos.
    """
    # Supports caching to disk in the same way as the latents cache.
    # Multi-GPU is not supported here; use tools/cache_latents.py for that.
    logger.info("caching text encoder outputs.")
    image_infos = list(self.image_data.values())

    logger.info("checking cache existence...")
    pending_infos = []
    for info in tqdm(image_infos):
        # subset = self.image_to_subset[info.image_key]
        if cache_to_disk:
            npz_path = os.path.splitext(info.absolute_path)[0] + TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX
            info.text_encoder_outputs_npz = npz_path

            # store to info only on non-main processes; skip files that already exist
            if not is_main_process or os.path.exists(npz_path):
                continue

        pending_infos.append(info)

    if cache_to_disk and not is_main_process:  # if cache to disk, don't cache latents in non-main process, set to info only
        return

    # prepare tokenizers and text encoders
    for encoder in text_encoders:
        encoder.to(device)
        if weight_dtype is not None:
            encoder.to(dtype=weight_dtype)

    # create batch
    batches = []
    current = []
    for info in pending_infos:
        token_ids = self.get_input_ids(info.caption, tokenizers[0])
        current.append((info, token_ids))

        if len(current) >= self.batch_size:
            batches.append(current)
            current = []

    if current:
        batches.append(current)

    # iterate batches: call text encoder and cache outputs for memory or disk
    logger.info("caching text encoder outputs...")
    for entries in tqdm(batches):
        infos = tuple(entry[0] for entry in entries)
        stacked_ids = torch.stack([entry[1] for entry in entries], dim=0)
        cache_batch_text_encoder_outputs(
            infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, stacked_ids, weight_dtype
        )

View File

@@ -909,7 +909,7 @@ class BaseDataset(torch.utils.data.Dataset):
)
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
# マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
# マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
logger.info("caching latents.")
image_infos = list(self.image_data.values())
@@ -1325,7 +1325,7 @@ class BaseDataset(torch.utils.data.Dataset):
if self.caching_mode == "text":
input_ids1 = self.get_input_ids(caption, self.tokenizers[0])
input_ids2 = self.get_input_ids(caption, self.tokenizers[1])
input_ids2 = self.get_input_ids(caption, self.tokenizers[1]) if len(self.tokenizers) > 1 else None
else:
input_ids1 = None
input_ids2 = None
@@ -2328,7 +2328,7 @@ def cache_batch_text_encoder_outputs(
def save_text_encoder_outputs_to_disk(npz_path, hidden_state1, hidden_state2, pool2):
np.savez(
npz_path,
hidden_state1=hidden_state1.cpu().float().numpy(),
hidden_state1=hidden_state1.cpu().float().numpy() if hidden_state1 is not None else None,
hidden_state2=hidden_state2.cpu().float().numpy(),
pool2=pool2.cpu().float().numpy(),
)
@@ -2684,6 +2684,14 @@ def get_sai_model_spec(
return metadata
def add_tokenizer_arguments(parser: argparse.ArgumentParser):
    """Register the optional ``--tokenizer_cache_dir`` option on ``parser``."""
    option_settings = {
        "type": str,
        "default": None,
        "help": "directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリネット接続なしでの学習のため",
    }
    parser.add_argument("--tokenizer_cache_dir", **option_settings)
def add_sd_models_arguments(parser: argparse.ArgumentParser):
# for pretrained models
parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む")
@@ -2696,12 +2704,7 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
default=None,
help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル",
)
parser.add_argument(
"--tokenizer_cache_dir",
type=str,
default=None,
help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリネット接続なしでの学習のため",
)
add_tokenizer_arguments(parser)
def add_optimizer_arguments(parser: argparse.ArgumentParser):
@@ -3150,18 +3153,22 @@ def verify_training_args(args: argparse.Namespace):
print("highvram is enabled / highvramが有効です")
global HIGH_VRAM
HIGH_VRAM = True
if args.v_parameterization and not args.v2:
logger.warning("v_parameterization should be with v2 not v1 or sdxl / v1やsdxlでv_parameterizationを使用することは想定されていません")
if args.v2 and args.clip_skip is not None:
logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
if args.cache_latents_to_disk and not args.cache_latents:
args.cache_latents = True
logger.warning(
"cache_latents_to_disk is enabled, so cache_latents is also enabled / cache_latents_to_diskが有効なため、cache_latentsを有効にします"
)
if not hasattr(args, "v_parameterization"):
# Stable Cascade: skip following checks
return
if args.v_parameterization and not args.v2:
logger.warning("v_parameterization should be with v2 not v1 or sdxl / v1やsdxlでv_parameterizationを使用することは想定されていません")
if args.v2 and args.clip_skip is not None:
logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
# noise_offset, perlin_noise, multires_noise_iterations cannot be enabled at the same time
# # Listを使って数えてもいいけど並べてしまえ
# if args.noise_offset is not None and args.multires_noise_iterations is not None:

View File

@@ -11,11 +11,11 @@ from PIL import Image
from accelerate import init_empty_weights
import library.stable_cascade as sc
import library.stable_cascade_utils as sc_utils
import library.device_utils as device_utils
from library import train_util
from library.sdxl_model_util import _load_state_dict_on_device
clip_text_model_name: str = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
def calculate_latent_sizes(height=1024, width=1024, batch_size=4, compression_factor_b=42.67, compression_factor_a=4.0):
resolution_multiple = 42.67
@@ -45,94 +45,31 @@ def main(args):
text_model_dtype = torch.float32
# EfficientNet encoder
print(f"Loading EfficientNet encoder from {args.effnet_checkpoint_path}")
effnet = sc.EfficientNetEncoder()
effnet_checkpoint = load_file(args.effnet_checkpoint_path)
info = effnet.load_state_dict(effnet_checkpoint if "state_dict" not in effnet_checkpoint else effnet_checkpoint["state_dict"])
print(info)
effnet = sc_utils.load_effnet(args.effnet_checkpoint_path, loading_device)
effnet.eval().requires_grad_(False).to(loading_device)
del effnet_checkpoint
# Generator
print(f"Instantiating Stage C generator")
with init_empty_weights():
generator_c = sc.StageC()
print(f"Loading Stage C generator from {args.stage_c_checkpoint_path}")
stage_c_checkpoint = load_file(args.stage_c_checkpoint_path)
print(f"Loading state dict")
info = _load_state_dict_on_device(generator_c, stage_c_checkpoint, loading_device, dtype=dtype)
print(info)
generator_c = sc_utils.load_stage_c_model(args.stage_c_checkpoint_path, dtype=dtype, device=loading_device)
generator_c.eval().requires_grad_(False).to(loading_device)
print(f"Instantiating Stage B generator")
with init_empty_weights():
generator_b = sc.StageB()
print(f"Loading Stage B generator from {args.stage_b_checkpoint_path}")
stage_b_checkpoint = load_file(args.stage_b_checkpoint_path)
print(f"Loading state dict")
info = _load_state_dict_on_device(generator_b, stage_b_checkpoint, loading_device, dtype=dtype)
print(info)
generator_b = sc_utils.load_stage_b_model(args.stage_b_checkpoint_path, dtype=dtype, device=loading_device)
generator_b.eval().requires_grad_(False).to(loading_device)
# CLIP encoders
print(f"Loading CLIP text model")
# TODO 完全にオフラインで動かすには tokenizer もローカルに保存できるようにする必要がある
tokenizer = AutoTokenizer.from_pretrained(clip_text_model_name)
if args.save_text_model or args.text_model_checkpoint_path is None:
print(f"Loading CLIP text model from {clip_text_model_name}")
text_model = CLIPTextModelWithProjection.from_pretrained(clip_text_model_name)
if args.save_text_model:
sd = text_model.state_dict()
print(f"Saving CLIP text model to {args.text_model_checkpoint_path}")
save_file(sd, args.text_model_checkpoint_path)
else:
print(f"Loading CLIP text model from {args.text_model_checkpoint_path}")
# copy from sdxl_model_util.py
text_model2_cfg = CLIPTextConfig(
vocab_size=49408,
hidden_size=1280,
intermediate_size=5120,
num_hidden_layers=32,
num_attention_heads=20,
max_position_embeddings=77,
hidden_act="gelu",
layer_norm_eps=1e-05,
dropout=0.0,
attention_dropout=0.0,
initializer_range=0.02,
initializer_factor=1.0,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
model_type="clip_text_model",
projection_dim=1280,
# torch_dtype="float32",
# transformers_version="4.25.0.dev0",
)
with init_empty_weights():
text_model = CLIPTextModelWithProjection(text_model2_cfg)
text_model_checkpoint = load_file(args.text_model_checkpoint_path)
info = _load_state_dict_on_device(text_model, text_model_checkpoint, text_model_device, dtype=text_model_dtype)
print(info)
tokenizer = sc_utils.load_tokenizer(args)
text_model = sc_utils.load_clip_text_model(
args.text_model_checkpoint_path, text_model_dtype, text_model_device, args.save_text_model
)
text_model = text_model.requires_grad_(False).to(text_model_dtype).to(text_model_device)
# image_model = (
# CLIPVisionModelWithProjection.from_pretrained(clip_image_model_name).requires_grad_(False).to(dtype).to(device)
# )
# vqGAN
print(f"Loading Stage A vqGAN from {args.stage_a_checkpoint_path}")
stage_a = sc.StageA().to(loading_device)
stage_a_checkpoint = load_file(args.stage_a_checkpoint_path)
info = stage_a.load_state_dict(
stage_a_checkpoint if "state_dict" not in stage_a_checkpoint else stage_a_checkpoint["state_dict"]
)
print(info)
stage_a = sc_utils.load_stage_a_model(args.stage_a_checkpoint_path, dtype=dtype, device=loading_device)
stage_a.eval().requires_grad_(False)
caption = "Cinematic photo of an anthropomorphic penguin sitting in a cafe reading a book and having a coffee"
@@ -169,19 +106,19 @@ def main(args):
# extras_b.sampling_configs["t_start"] = 1.0
# PREPARE CONDITIONS
cond_text, cond_pooled = sc.get_clip_conditions([caption], tokenizer, text_model)
cond_text, cond_pooled = sc.get_clip_conditions([caption], None, tokenizer, text_model)
cond_text = cond_text.to(device, dtype=dtype)
cond_pooled = cond_pooled.to(device, dtype=dtype)
uncond_text, uncond_pooled = sc.get_clip_conditions([""], tokenizer, text_model)
uncond_text, uncond_pooled = sc.get_clip_conditions([""], None, tokenizer, text_model)
uncond_text = uncond_text.to(device, dtype=dtype)
uncond_pooled = uncond_pooled.to(device, dtype=dtype)
img_emb = torch.zeros(1, 768, device=device)
zero_img_emb = torch.zeros(1, 768, device=device)
# 辞書にしたくないけど GDF から先の変更が面倒だからとりあえず辞書にしておく
conditions = {"clip_text_pooled": cond_pooled, "clip": cond_pooled, "clip_text": cond_text, "clip_img": img_emb}
unconditions = {"clip_text_pooled": uncond_pooled, "clip": uncond_pooled, "clip_text": uncond_text, "clip_img": img_emb}
conditions = {"clip_text_pooled": cond_pooled, "clip": cond_pooled, "clip_text": cond_text, "clip_img": zero_img_emb}
unconditions = {"clip_text_pooled": uncond_pooled, "clip": uncond_pooled, "clip_text": uncond_text, "clip_img": zero_img_emb}
conditions_b = {}
conditions_b.update(conditions)
unconditions_b = {}
@@ -249,14 +186,13 @@ def main(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--effnet_checkpoint_path", type=str, required=True)
parser.add_argument("--stage_a_checkpoint_path", type=str, required=True)
parser.add_argument("--stage_b_checkpoint_path", type=str, required=True)
parser.add_argument("--stage_c_checkpoint_path", type=str, required=True)
parser.add_argument(
"--text_model_checkpoint_path", type=str, required=False, default=None, help="if omitted, download from HuggingFace"
)
parser.add_argument("--save_text_model", action="store_true", help="if specified, save text model to corresponding path")
sc_utils.add_effnet_arguments(parser)
train_util.add_tokenizer_arguments(parser)
sc_utils.add_stage_a_arguments(parser)
sc_utils.add_stage_b_arguments(parser)
sc_utils.add_stage_c_arguments(parser)
sc_utils.add_text_model_arguments(parser)
parser.add_argument("--bf16", action="store_true")
parser.add_argument("--fp16", action="store_true")
parser.add_argument("--outdir", type=str, default="../outputs", help="dir to write results to / 生成画像の出力先")

View File

@@ -0,0 +1,526 @@
# training with captions
import argparse
import math
import os
from multiprocessing import Value
from typing import List
import toml
from tqdm import tqdm
import torch
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
import library.train_util as train_util
from library.sdxl_train_util import add_sdxl_training_arguments
import library.stable_cascade_utils as sc_utils
import library.stable_cascade as sc
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
import library.config_util as config_util
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
def train(args):
    """Fine-tune the Stable Cascade Stage C model on an image/caption dataset.

    The function builds the dataset group from ``args``, loads the EfficientNet
    encoder, the Stage C model and the CLIP text model through ``sc_utils``,
    then optimizes Stage C with the ``sc.GDF`` diffusion framework (epsilon
    target, adaptive loss weighting).  Checkpoints are written every
    ``save_every_n_steps`` steps, every epoch when ``save_every_n_epochs`` is
    set, and once more when training finishes.

    Args:
        args: Parsed command-line namespace produced by ``setup_parser``.

    Raises:
        NotImplementedError: if ``args.cache_latents`` or
            ``args.cache_text_encoder_outputs`` is set; in-process caching is
            not available in this version.
        AssertionError: if the dataset options are incompatible with the
            requested caching mode, or ``full_fp16``/``full_bf16`` does not
            match ``mixed_precision``.
    """
    train_util.verify_training_args(args)
    train_util.prepare_dataset_args(args, True)
    setup_logging(args, reset=True)

    # assert (
    #     not args.weighted_captions
    # ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"

    # TODO add assertions for other unsupported options

    cache_latents = args.cache_latents
    use_dreambooth_method = args.in_json is None

    if args.seed is not None:
        set_seed(args.seed)  # initialize the random number sequence

    tokenizer = sc_utils.load_tokenizer(args)

    # prepare the dataset
    if args.dataset_class is None:
        blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
        if args.dataset_config is not None:
            logger.info(f"Load dataset config from {args.dataset_config}")
            user_config = config_util.load_user_config(args.dataset_config)
            ignored = ["train_data_dir", "in_json"]
            if any(getattr(args, attr) is not None for attr in ignored):
                logger.warning(
                    "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
                        ", ".join(ignored)
                    )
                )
        else:
            if use_dreambooth_method:
                logger.info("Using DreamBooth method.")
                user_config = {
                    "datasets": [
                        {
                            "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
                                args.train_data_dir, args.reg_data_dir
                            )
                        }
                    ]
                }
            else:
                logger.info("Training with captions.")
                user_config = {
                    "datasets": [
                        {
                            "subsets": [
                                {
                                    "image_dir": args.train_data_dir,
                                    "metadata_file": args.in_json,
                                }
                            ]
                        }
                    ]
                }

        blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer])
        train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
    else:
        train_dataset_group = train_util.load_arbitrary_dataset(args, [tokenizer])

    # shared counters that let the collator/dataset see training progress
    current_epoch = Value("i", 0)
    current_step = Value("i", 0)
    ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
    collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)

    train_dataset_group.verify_bucket_reso_steps(32)

    if args.debug_dataset:
        train_util.debug_dataset(train_dataset_group, True)
        return
    if len(train_dataset_group) == 0:
        logger.error(
            "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
        )
        return

    if cache_latents:
        assert (
            train_dataset_group.is_latent_cacheable()
        ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"

    if args.cache_text_encoder_outputs:
        assert (
            train_dataset_group.is_text_encoder_output_cacheable()
        ), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"

    # prepare the accelerator
    logger.info("prepare accelerator")
    accelerator = train_util.prepare_accelerator(args)

    # dtypes matching the mixed-precision setting; tensors are cast as needed
    weight_dtype, save_dtype = train_util.prepare_dtype(args)
    effnet_dtype = torch.float32 if args.no_half_vae else weight_dtype

    # load the models
    loading_device = accelerator.device if args.lowram else "cpu"
    effnet = sc_utils.load_effnet(args.effnet_checkpoint_path, loading_device)
    stage_c = sc_utils.load_stage_c_model(args.stage_c_checkpoint_path, dtype=weight_dtype, device=loading_device)
    text_encoder1 = sc_utils.load_clip_text_model(args.text_model_checkpoint_path, dtype=weight_dtype, device=loading_device)

    # prepare for training
    if cache_latents:
        raise NotImplementedError("Caching latents is not supported in this version / latentのキャッシュはサポートされていません")
        # NOTE: everything below in this branch is unreachable while the raise
        # above is in place; it is kept for when in-process caching is enabled.
        logger.info(
            "Please make sure that the latents are cached before training with `stable_cascade_cache_latents.py`."
            + " / 学習前に`stable_cascade_cache_latents.py`でlatentをキャッシュしてください。"
        )
        # effnet.to(accelerator.device, dtype=effnet_dtype)
        effnet.requires_grad_(False)
        effnet.eval()
        with torch.no_grad():
            train_dataset_group.cache_latents(
                effnet,
                args.vae_batch_size,
                args.cache_latents_to_disk,
                accelerator.is_main_process,
                cache_func=sc_utils.cache_batch_latents,
            )
        effnet.to("cpu")
        clean_memory_on_device(accelerator.device)

        accelerator.wait_for_everyone()

    # prepare for training: put the models into the proper state
    if args.gradient_checkpointing:
        # logger.warn is deprecated; logger.warning is the supported spelling
        logger.warning("Gradient checkpointing is not supported for stage_c. Ignoring the option.")
        # stage_c.enable_gradient_checkpointing()

    text_encoder1.to(weight_dtype)
    text_encoder1.requires_grad_(False)
    text_encoder1.eval()

    # cache the text encoder outputs
    if args.cache_text_encoder_outputs:
        raise NotImplementedError(
            "Caching text encoder outputs is not supported in this version / text encoderの出力のキャッシュはサポートされていません"
        )
        # NOTE: unreachable while the raise above is in place (see above).
        print(
            f"Please make sure that the text encoder outputs are cached before training with `stable_cascade_cache_text_encoder_outputs.py`."
            + " / 学習前に`stable_cascade_cache_text_encoder_outputs.py`でtext encoderの出力をキャッシュしてください。"
        )

        # Text Encodes are eval and no grad
        with torch.no_grad(), accelerator.autocast():
            train_dataset_group.cache_text_encoder_outputs(
                (tokenizer),
                (text_encoder1),
                accelerator.device,
                None,
                args.cache_text_encoder_outputs_to_disk,
                accelerator.is_main_process,
            )
        accelerator.wait_for_everyone()

    if not cache_latents:
        # latents are computed on the fly, so the encoder must live on the device
        effnet.requires_grad_(False)
        effnet.eval()
        effnet.to(accelerator.device, dtype=effnet_dtype)

    stage_c.requires_grad_(True)

    training_models = []
    params_to_optimize = []
    training_models.append(stage_c)
    params_to_optimize.append({"params": list(stage_c.parameters()), "lr": args.learning_rate})

    # calculate number of trainable parameters
    n_params = 0
    for params in params_to_optimize:
        for p in params["params"]:
            n_params += p.numel()

    accelerator.print(f"number of models: {len(training_models)}")
    accelerator.print(f"number of trainable parameters: {n_params}")

    # prepare the classes needed for training
    accelerator.print("prepare optimizer, data loader etc.")
    _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)

    # prepare the dataloader
    # num_workers == 0 means the data is loaded in the main process
    # at most cpu_count - 1 workers, capped by the requested number;
    # os.cpu_count() may return None, which is treated as a single CPU
    n_workers = min(args.max_data_loader_n_workers, (os.cpu_count() or 1) - 1)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset_group,
        batch_size=1,
        shuffle=True,
        collate_fn=collator,
        num_workers=n_workers,
        persistent_workers=args.persistent_data_loader_workers,
    )

    # compute the number of training steps
    if args.max_train_epochs is not None:
        args.max_train_steps = args.max_train_epochs * math.ceil(
            len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
        )
        accelerator.print(
            f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
        )

    # also tell the dataset how many steps will be run
    train_dataset_group.set_max_train_steps(args.max_train_steps)

    # prepare the lr scheduler
    lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)

    # experimental: full fp16/bf16 training including gradients (whole model in fp16/bf16)
    if args.full_fp16:
        assert (
            args.mixed_precision == "fp16"
        ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
        accelerator.print("enable full fp16 training.")
        stage_c.to(weight_dtype)
        text_encoder1.to(weight_dtype)
    elif args.full_bf16:
        assert (
            args.mixed_precision == "bf16"
        ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
        accelerator.print("enable full bf16 training.")
        stage_c.to(weight_dtype)
        text_encoder1.to(weight_dtype)

    # let the accelerator wrap the model, optimizer, dataloader and scheduler
    stage_c = accelerator.prepare(stage_c)
    optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)

    # when text encoder outputs are cached, move the text encoder to CPU
    if args.cache_text_encoder_outputs:
        # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
        text_encoder1.to("cpu", dtype=torch.float32)
        clean_memory_on_device(accelerator.device)
    else:
        # make sure Text Encoders are on GPU
        text_encoder1.to(accelerator.device)

    # experimental: full fp16 training; patch the accelerator to enable grad scaling in fp16
    if args.full_fp16:
        train_util.patch_accelerator_for_fp16_training(accelerator)

    # resume if requested
    train_util.resume_from_local_or_hf_if_specified(accelerator, args)

    # compute the number of epochs
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
    if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
        args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1

    # start training
    # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
    accelerator.print("running training / 学習開始")
    accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
    accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
    accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
    accelerator.print(
        f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
    )
    # accelerator.print(
    #     f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
    # )
    accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
    accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")

    progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
    global_step = 0

    # GDF: the diffusion framework object from the Stable Cascade code
    gdf = sc.GDF(
        schedule=sc.CosineSchedule(clamp_range=[0.0001, 0.9999]),
        input_scaler=sc.VPScaler(),
        target=sc.EpsilonTarget(),
        noise_cond=sc.CosineTNoiseCond(),
        loss_weight=sc.AdaptiveLossWeight(),
    )

    # the following two values appear to be left at their defaults upstream
    # gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges'])
    # gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses'])

    # noise_scheduler = DDPMScheduler(
    #     beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
    # )
    # prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
    # if args.zero_terminal_snr:
    #     custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)

    if accelerator.is_main_process:
        init_kwargs = {}
        if args.wandb_run_name:
            init_kwargs["wandb"] = {"name": args.wandb_run_name}
        if args.log_tracker_config is not None:
            init_kwargs = toml.load(args.log_tracker_config)
        accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)

    # TODO: --sample_at_first image sampling is disabled for now; the SDXL
    # version called sdxl_train_util.sample_images(accelerator, args, 0, global_step, ...) here

    loss_recorder = train_util.LossRecorder()
    for epoch in range(num_train_epochs):
        accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
        current_epoch.value = epoch + 1

        for m in training_models:
            m.train()

        for step, batch in enumerate(train_dataloader):
            current_step.value = global_step
            with accelerator.accumulate(*training_models):
                if "latents" in batch and batch["latents"] is not None:
                    latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
                else:
                    with torch.no_grad():
                        # encode the images into latents
                        latents = effnet(batch["images"].to(effnet_dtype)).to(weight_dtype)

                        # if NaNs are present, warn and replace them with zeros
                        if torch.any(torch.isnan(latents)):
                            accelerator.print("NaN found in latents, replacing with zeros")
                            latents = torch.nan_to_num(latents, 0, out=latents)

                if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
                    input_ids1 = batch["input_ids"]
                    with torch.no_grad():
                        # Get the text embedding for conditioning
                        # TODO support weighted captions
                        input_ids1 = input_ids1.to(accelerator.device)
                        # unwrap_model is fine for models not wrapped by accelerator
                        encoder_hidden_states, pool = sc.get_clip_conditions(None, input_ids1, tokenizer, text_encoder1)
                else:
                    encoder_hidden_states = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
                    pool = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)

                # FORWARD PASS
                with torch.no_grad():
                    noised, noise, target, logSNR, noise_cond, loss_weight = gdf.diffuse(latents, shift=1, loss_shift=1)

                # no image conditioning is used during training: pass an all-zero image embedding
                zero_img_emb = torch.zeros(noised.shape[0], 768, device=accelerator.device)
                with accelerator.autocast():
                    pred = stage_c(
                        noised, noise_cond, clip_text=encoder_hidden_states, clip_text_pooled=pool, clip_img=zero_img_emb
                    )
                    # reduced over dims 1-3 only, so `loss` keeps one value per sample
                    loss = torch.nn.functional.mse_loss(pred, target, reduction="none").mean(dim=[1, 2, 3])
                    loss_adjusted = (loss * loss_weight).mean()

                # the adaptive weighting is updated with the unweighted per-sample loss
                gdf.loss_weight.update_buckets(logSNR, loss)

                accelerator.backward(loss_adjusted)
                if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                    params_to_clip = []
                    for m in training_models:
                        params_to_clip.extend(m.parameters())
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                # TODO: per-step image sampling is disabled for now; the SDXL
                # version called sdxl_train_util.sample_images(accelerator, args, None, global_step, ...) here

                # save the model every specified number of steps
                if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
                    accelerator.wait_for_everyone()
                    if accelerator.is_main_process:
                        sc_utils.save_stage_c_model_on_epoch_end_or_stepwise(
                            args,
                            False,
                            accelerator,
                            save_dtype,
                            epoch,
                            num_train_epochs,
                            global_step,
                            # was `accelerator.accelerator.unwrap_model`, which raises AttributeError
                            accelerator.unwrap_model(stage_c),
                        )

            # `loss` holds one value per sample, so reduce it before .item();
            # calling .item() on it directly fails for batches larger than one
            current_loss = loss.detach().mean().item()
            if args.logging_dir is not None:
                logs = {"loss": current_loss}
                train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
                accelerator.log(logs, step=global_step)

            loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
            avr_loss: float = loss_recorder.moving_average
            logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            if global_step >= args.max_train_steps:
                break

        if args.logging_dir is not None:
            logs = {"loss/epoch": loss_recorder.moving_average}
            accelerator.log(logs, step=epoch + 1)

        accelerator.wait_for_everyone()

        if args.save_every_n_epochs is not None:
            if accelerator.is_main_process:
                sc_utils.save_stage_c_model_on_epoch_end_or_stepwise(
                    args, True, accelerator, save_dtype, epoch, num_train_epochs, global_step, accelerator.unwrap_model(stage_c)
                )

        # TODO: per-epoch image sampling is disabled for now; the SDXL
        # version called sdxl_train_util.sample_images(accelerator, args, epoch + 1, global_step, ...) here

    is_main_process = accelerator.is_main_process
    # if is_main_process:
    stage_c = accelerator.unwrap_model(stage_c)

    accelerator.end_training()

    if args.save_state:  # and is_main_process:
        train_util.save_state_on_train_end(args, accelerator)

    del accelerator  # memory is needed from here on, so drop the accelerator

    if is_main_process:
        sc_utils.save_stage_c_model_on_end(args, save_dtype, epoch, global_step, stage_c)
        logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the Stage C training script.

    Registers the logging, effnet, Stage C, text model, tokenizer, dataset,
    training, optimizer, config and SDXL training argument groups, then adds
    the script-specific ``--no_half_vae`` flag.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    no_half_vae_help = "do not use fp16/bf16 Effnet in mixed precision (use float Effnet) / mixed precisionでも fp16/bf16 Effnetを使わずfloat Effnetを使う"

    arg_parser = argparse.ArgumentParser()

    # model-related and logging options
    add_logging_arguments(arg_parser)
    sc_utils.add_effnet_arguments(arg_parser)
    sc_utils.add_stage_c_arguments(arg_parser)
    sc_utils.add_text_model_arguments(arg_parser)

    # shared training options
    train_util.add_tokenizer_arguments(arg_parser)
    train_util.add_dataset_arguments(arg_parser, True, True, True)
    train_util.add_training_arguments(arg_parser, False)
    train_util.add_optimizer_arguments(arg_parser)
    config_util.add_config_arguments(arg_parser)
    add_sdxl_training_arguments(arg_parser)  # cache text encoder outputs

    arg_parser.add_argument("--no_half_vae", action="store_true", help=no_half_vae_help)

    return arg_parser
if __name__ == "__main__":
    # Script entry point: parse the CLI, hand the namespace to
    # train_util.read_config_from_file, then start training.
    parser = setup_parser()
    args = train_util.read_config_from_file(parser.parse_args(), parser)

    train(args)

View File

@@ -0,0 +1,191 @@
# Stable Cascadeのlatentsをdiskにキャッシュする
# cache latents of Stable Cascade to disk
import argparse
import math
from multiprocessing import Value
import os
from accelerate.utils import set_seed
import torch
from tqdm import tqdm
from library import stable_cascade_utils as sc_utils
from library import config_util
from library import train_util
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def cache_to_disk(args: argparse.Namespace) -> None:
    """Encode every dataset image with the EfficientNet encoder and cache the latents to disk.

    The dataset group is built from ``args`` and put into ``"latents"``
    caching mode.  Each batch is then split into chunks of ``vae_batch_size``
    images, and ``sc_utils.cache_batch_latents`` is called for each chunk.
    The cache file path for an image is its path with the extension replaced
    by ``sc_utils.LATENTS_CACHE_SUFFIX``.

    Args:
        args: Parsed command-line namespace produced by ``setup_parser``.

    Raises:
        AssertionError: if ``args.cache_latents_to_disk`` is not set.
    """
    train_util.prepare_dataset_args(args, True)

    # check cache latents arg
    assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります"

    use_dreambooth_method = args.in_json is None

    if args.seed is not None:
        set_seed(args.seed)  # initialize the random number sequence

    # prepare the tokenizer: needed to make the dataset work
    tokenizer = sc_utils.load_tokenizer(args)
    tokenizers = [tokenizer]

    # prepare the dataset
    if args.dataset_class is None:
        blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
        if args.dataset_config is not None:
            logger.info(f"Load dataset config from {args.dataset_config}")
            user_config = config_util.load_user_config(args.dataset_config)
            ignored = ["train_data_dir", "in_json"]
            if any(getattr(args, attr) is not None for attr in ignored):
                logger.warning(
                    "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
                        ", ".join(ignored)
                    )
                )
        else:
            if use_dreambooth_method:
                logger.info("Using DreamBooth method.")
                user_config = {
                    "datasets": [
                        {
                            "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
                                args.train_data_dir, args.reg_data_dir
                            )
                        }
                    ]
                }
            else:
                logger.info("Training with captions.")
                user_config = {
                    "datasets": [
                        {
                            "subsets": [
                                {
                                    "image_dir": args.train_data_dir,
                                    "metadata_file": args.in_json,
                                }
                            ]
                        }
                    ]
                }

        blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers)
        train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
    else:
        train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers)

    # raw images are returned unless the dataset's cache_latents is called
    current_epoch = Value("i", 0)
    current_step = Value("i", 0)
    ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
    collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)

    # prepare the accelerator
    logger.info("prepare accelerator")
    accelerator = train_util.prepare_accelerator(args)

    # dtypes matching the mixed-precision setting; tensors are cast as needed
    weight_dtype, _ = train_util.prepare_dtype(args)
    effnet_dtype = torch.float32 if args.no_half_vae else weight_dtype

    # load the model
    logger.info("load model")
    effnet = sc_utils.load_effnet(args.effnet_checkpoint_path, accelerator.device)
    effnet.to(accelerator.device, dtype=effnet_dtype)
    effnet.requires_grad_(False)
    effnet.eval()

    # prepare the dataloader
    train_dataset_group.set_caching_mode("latents")

    # num_workers == 0 means the data is loaded in the main process
    # at most cpu_count - 1 workers, capped by the requested number;
    # os.cpu_count() may return None, which is treated as a single CPU
    n_workers = min(args.max_data_loader_n_workers, (os.cpu_count() or 1) - 1)

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset_group,
        batch_size=1,
        shuffle=True,
        collate_fn=collator,
        num_workers=n_workers,
        persistent_workers=args.persistent_data_loader_workers,
    )

    # prepare the dataloader with the accelerator; this should make multi-GPU work
    train_dataloader = accelerator.prepare(train_dataloader)

    # loop that fetches the data
    for batch in tqdm(train_dataloader):
        b_size = len(batch["images"])
        vae_batch_size = b_size if args.vae_batch_size is None else args.vae_batch_size
        flip_aug = batch["flip_aug"]
        random_crop = batch["random_crop"]
        bucket_reso = batch["bucket_reso"]

        # process the batch in chunks of vae_batch_size images
        for start in range(0, b_size, vae_batch_size):
            images = batch["images"][start : start + vae_batch_size]
            absolute_paths = batch["absolute_paths"][start : start + vae_batch_size]
            resized_sizes = batch["resized_sizes"][start : start + vae_batch_size]

            image_infos = []
            # the inner loop used to reuse `i` through an unused enumerate(),
            # shadowing the chunk offset of the outer loop
            for image, absolute_path, resized_size in zip(images, absolute_paths, resized_sizes):
                image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path)
                image_info.image = image
                image_info.bucket_reso = bucket_reso
                image_info.resized_size = resized_size
                image_info.latents_npz = os.path.splitext(absolute_path)[0] + sc_utils.LATENTS_CACHE_SUFFIX

                if args.skip_existing:
                    if sc_utils.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug):
                        logger.warning(f"Skipping {image_info.latents_npz} because it already exists.")
                        continue

                image_infos.append(image_info)

            if len(image_infos) > 0:
                sc_utils.cache_batch_latents(effnet, True, image_infos, flip_aug, random_crop, accelerator.device, effnet_dtype)

    accelerator.wait_for_everyone()
    accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.")
def setup_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the latents caching script.

    Registers the tokenizer, effnet, training, dataset and config argument
    groups, then adds the script-specific ``--no_half_vae`` and
    ``--skip_existing`` flags.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    no_half_vae_help = "do not use fp16/bf16 Effnet in mixed precision (use float Effnet) / mixed precisionでも fp16/bf16 Effnetを使わずfloat Effnetを使う"
    skip_existing_help = "skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップするflip_aug有効時は通常、反転の両方が存在する画像をスキップ"

    arg_parser = argparse.ArgumentParser()

    # shared option groups
    train_util.add_tokenizer_arguments(arg_parser)
    sc_utils.add_effnet_arguments(arg_parser)
    train_util.add_training_arguments(arg_parser, True)
    train_util.add_dataset_arguments(arg_parser, True, True, True)
    config_util.add_config_arguments(arg_parser)

    # options specific to this script
    arg_parser.add_argument("--no_half_vae", action="store_true", help=no_half_vae_help)
    arg_parser.add_argument("--skip_existing", action="store_true", help=skip_existing_help)

    return arg_parser
if __name__ == "__main__":
    # Script entry point: parse the CLI, hand the namespace to
    # train_util.read_config_from_file, then cache the latents.
    parser = setup_parser()
    args = train_util.read_config_from_file(parser.parse_args(), parser)

    cache_to_disk(args)

View File

@@ -0,0 +1,183 @@
# text encoder出力のdiskへの事前キャッシュを行う / cache text encoder outputs to disk in advance
import argparse
import math
from multiprocessing import Value
import os
from accelerate.utils import set_seed
import torch
from tqdm import tqdm
from library import config_util
from library import train_util
from library import sdxl_train_util
from library import stable_cascade_utils as sc_utils
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def cache_to_disk(args: argparse.Namespace) -> None:
    """Run the CLIP text model over every caption and cache its outputs to disk.

    The dataset group is built from ``args`` and put into ``"text"`` caching
    mode.  For each batch, ``sc_utils.cache_batch_text_encoder_outputs`` is
    called with the stacked token ids of the images that still need caching.
    The cache file path for an image is its path with the extension replaced
    by ``sc_utils.TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX``.

    Args:
        args: Parsed command-line namespace produced by ``setup_parser``.

    Raises:
        AssertionError: if ``args.cache_text_encoder_outputs_to_disk`` is not set.
    """
    train_util.prepare_dataset_args(args, True)

    # check cache arg
    assert (
        args.cache_text_encoder_outputs_to_disk
    ), "cache_text_encoder_outputs_to_disk must be True / cache_text_encoder_outputs_to_diskはTrueである必要があります"

    use_dreambooth_method = args.in_json is None

    if args.seed is not None:
        set_seed(args.seed)  # initialize the random number sequence

    # prepare the tokenizer: needed to make the dataset work
    tokenizer = sc_utils.load_tokenizer(args)
    tokenizers = [tokenizer]

    # prepare the dataset
    if args.dataset_class is None:
        blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
        if args.dataset_config is not None:
            logger.info(f"Load dataset config from {args.dataset_config}")
            user_config = config_util.load_user_config(args.dataset_config)
            ignored = ["train_data_dir", "in_json"]
            if any(getattr(args, attr) is not None for attr in ignored):
                logger.warning(
                    "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
                        ", ".join(ignored)
                    )
                )
        else:
            if use_dreambooth_method:
                logger.info("Using DreamBooth method.")
                user_config = {
                    "datasets": [
                        {
                            "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
                                args.train_data_dir, args.reg_data_dir
                            )
                        }
                    ]
                }
            else:
                logger.info("Training with captions.")
                user_config = {
                    "datasets": [
                        {
                            "subsets": [
                                {
                                    "image_dir": args.train_data_dir,
                                    "metadata_file": args.in_json,
                                }
                            ]
                        }
                    ]
                }

        blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers)
        train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
    else:
        train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers)

    current_epoch = Value("i", 0)
    current_step = Value("i", 0)
    ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
    collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)

    # prepare the accelerator
    logger.info("prepare accelerator")
    accelerator = train_util.prepare_accelerator(args)

    # dtype matching the mixed-precision setting; tensors are cast as needed
    weight_dtype, _ = train_util.prepare_dtype(args)

    # load the model
    logger.info("load model")
    text_encoder = sc_utils.load_clip_text_model(
        args.text_model_checkpoint_path, weight_dtype, accelerator.device, args.save_text_model
    )
    text_encoders = [text_encoder]
    # a distinct loop name avoids rebinding `text_encoder`
    for encoder in text_encoders:
        encoder.to(accelerator.device, dtype=weight_dtype)
        encoder.requires_grad_(False)
        encoder.eval()

    # prepare the dataloader
    train_dataset_group.set_caching_mode("text")

    # num_workers == 0 means the data is loaded in the main process
    # at most cpu_count - 1 workers, capped by the requested number;
    # os.cpu_count() may return None, which is treated as a single CPU
    n_workers = min(args.max_data_loader_n_workers, (os.cpu_count() or 1) - 1)

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset_group,
        batch_size=1,
        shuffle=True,
        collate_fn=collator,
        num_workers=n_workers,
        persistent_workers=args.persistent_data_loader_workers,
    )

    # prepare the dataloader with the accelerator; this should make multi-GPU work
    train_dataloader = accelerator.prepare(train_dataloader)

    # loop that fetches the data
    for batch in tqdm(train_dataloader):
        absolute_paths = batch["absolute_paths"]
        input_ids1_list = batch["input_ids1_list"]

        image_infos = []
        for absolute_path, input_ids1 in zip(absolute_paths, input_ids1_list):
            image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path)
            image_info.text_encoder_outputs_npz = os.path.splitext(absolute_path)[0] + sc_utils.TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX

            if args.skip_existing:
                if os.path.exists(image_info.text_encoder_outputs_npz):
                    logger.warning(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.")
                    continue

            image_info.input_ids1 = input_ids1
            image_infos.append(image_info)

        if len(image_infos) > 0:
            b_input_ids1 = torch.stack([image_info.input_ids1 for image_info in image_infos])
            sc_utils.cache_batch_text_encoder_outputs(
                image_infos, tokenizers, text_encoders, args.max_token_length, True, b_input_ids1, weight_dtype
            )

    accelerator.wait_for_everyone()
    # the message used to say "latents", copied from the latents caching script
    accelerator.print(f"Finished caching text encoder outputs for {len(train_dataset_group)} batches.")
def setup_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the text encoder output caching script.

    Registers the tokenizer, text model, training, dataset, config and SDXL
    training argument groups, then adds the script-specific
    ``--skip_existing`` flag.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    skip_existing_help = "skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップするflip_aug有効時は通常、反転の両方が存在する画像をスキップ"

    arg_parser = argparse.ArgumentParser()

    # shared option groups
    train_util.add_tokenizer_arguments(arg_parser)
    sc_utils.add_text_model_arguments(arg_parser)
    train_util.add_training_arguments(arg_parser, True)
    train_util.add_dataset_arguments(arg_parser, True, True, True)
    config_util.add_config_arguments(arg_parser)
    sdxl_train_util.add_sdxl_training_arguments(arg_parser)

    # option specific to this script
    arg_parser.add_argument("--skip_existing", action="store_true", help=skip_existing_help)

    return arg_parser
if __name__ == "__main__":
    # Script entry point: parse the CLI, hand the namespace to
    # train_util.read_config_from_file, then cache the text encoder outputs.
    parser = setup_parser()
    args = train_util.read_config_from_file(parser.parse_args(), parser)

    cache_to_disk(args)