# Mirror of https://github.com/kohya-ss/sd-scripts.git
# Synced 2026-04-06 13:47:06 +00:00
# 421 lines, 19 KiB, Python
import os
|
|
import glob
|
|
import random
|
|
from typing import Any, List, Optional, Tuple, Union
|
|
import torch
|
|
import numpy as np
|
|
from transformers import CLIPTokenizer, T5TokenizerFast, CLIPTextModel, CLIPTextModelWithProjection, T5EncoderModel
|
|
|
|
from library import sd3_utils, train_util
|
|
from library import sd3_models
|
|
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
|
|
|
|
from library.utils import setup_logging
|
|
|
|
setup_logging()
|
|
import logging
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# Identifiers handed to `_load_tokenizer` for each of the three SD3 tokenizers.
CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14"  # used for the CLIP-L tokenizer
CLIP_G_TOKENIZER_ID = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"  # used for the CLIP-G tokenizer
T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl"  # used for the T5-XXL tokenizer
|
|
|
|
|
|
class Sd3TokenizeStrategy(TokenizeStrategy):
    """Tokenize captions with the CLIP-L, CLIP-G and T5-XXL tokenizers used by SD3."""

    def __init__(self, t5xxl_max_length: int = 256, tokenizer_cache_dir: Optional[str] = None) -> None:
        """
        Args:
            t5xxl_max_length: length every T5-XXL sequence is padded/truncated to in `tokenize`.
            tokenizer_cache_dir: forwarded unchanged to `_load_tokenizer` for all three tokenizers.
        """
        self.t5xxl_max_length = t5xxl_max_length

        load = self._load_tokenizer
        self.clip_l = load(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
        self.clip_g = load(CLIPTokenizer, CLIP_G_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
        self.t5xxl = load(T5TokenizerFast, T5_XXL_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)

        # use 0 as pad token for clip_g
        self.clip_g.pad_token_id = 0

    def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
        """
        Tokenize one caption or a list of captions with all three tokenizers.

        Args:
            text: a single string (treated as a one-element batch) or a list of strings.

        Returns:
            `[l_tokens, g_tokens, t5_tokens, l_attn_mask, g_attn_mask, t5_attn_mask]`, i.e. the
            `input_ids` of each tokenizer followed by the matching `attention_mask` entries.
        """
        if isinstance(text, str):
            text = [text]

        # both CLIP tokenizers pad/truncate to 77 tokens; T5 uses the configured length
        clip_kwargs = dict(max_length=77, padding="max_length", truncation=True, return_tensors="pt")
        l_encoded = self.clip_l(text, **clip_kwargs)
        g_encoded = self.clip_g(text, **clip_kwargs)
        t5_encoded = self.t5xxl(
            text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt"
        )

        return [
            l_encoded["input_ids"],
            g_encoded["input_ids"],
            t5_encoded["input_ids"],
            l_encoded["attention_mask"],
            g_encoded["attention_mask"],
            t5_encoded["attention_mask"],
        ]
|
|
|
|
|
|
class Sd3TextEncodingStrategy(TextEncodingStrategy):
    """Encode SD3 prompts with CLIP-L, CLIP-G and T5-XXL, with optional per-sample dropout."""

    def __init__(
        self,
        apply_lg_attn_mask: Optional[bool] = None,
        apply_t5_attn_mask: Optional[bool] = None,
        l_dropout_rate: float = 0.0,
        g_dropout_rate: float = 0.0,
        t5_dropout_rate: float = 0.0,
    ) -> None:
        """
        Args:
            apply_lg_attn_mask: Default value for apply_lg_attn_mask in `encode_tokens`.
            apply_t5_attn_mask: Default value for apply_t5_attn_mask in `encode_tokens`.
            l_dropout_rate: per-sample probability of zeroing the CLIP-L output.
            g_dropout_rate: per-sample probability of zeroing the CLIP-G output.
            t5_dropout_rate: per-sample probability of zeroing the T5-XXL output.
        """
        self.apply_lg_attn_mask = apply_lg_attn_mask
        self.apply_t5_attn_mask = apply_t5_attn_mask
        self.l_dropout_rate = l_dropout_rate
        self.g_dropout_rate = g_dropout_rate
        self.t5_dropout_rate = t5_dropout_rate

    def encode_tokens(
        self,
        tokenize_strategy: TokenizeStrategy,
        models: List[Any],
        tokens: List[torch.Tensor],
        apply_lg_attn_mask: Optional[bool] = None,
        apply_t5_attn_mask: Optional[bool] = None,
        enable_dropout: bool = True,
    ) -> List[torch.Tensor]:
        """
        Run the text encoders on tokenized prompts. Returned embeddings are not masked.

        Args:
            tokenize_strategy: not used by this implementation.
            models: `[clip_l, clip_g, t5xxl]`; an entry may be None to skip that encoder.
            tokens: `[l_tokens, g_tokens, t5_tokens, l_attn_mask, g_attn_mask, t5_attn_mask]`.
            apply_lg_attn_mask: pass the CLIP attention masks to clip_l / clip_g. When None, the
                value given to the constructor is used. (The defaults used to be False, which made
                the constructor values unreachable for callers that omit these arguments.)
            apply_t5_attn_mask: same as above, for t5xxl.
            enable_dropout: if False, dropout is not applied regardless of the configured rates.

        Returns:
            `[lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]`. The CLIP entries
            are None when CLIP encoding is skipped; the T5 entries are None when T5 is skipped.
            Batch members selected for dropout have all-zero embeddings and all-zero masks.

        Raises:
            AssertionError: if g_tokens is None while CLIP encoding runs, or is not None while
                CLIP encoding is skipped.
        """
        clip_l, clip_g, t5xxl = models
        clip_l: Optional[CLIPTextModel]
        clip_g: Optional[CLIPTextModelWithProjection]
        t5xxl: Optional[T5EncoderModel]

        # fall back to the values configured on the instance
        if apply_lg_attn_mask is None:
            apply_lg_attn_mask = self.apply_lg_attn_mask
        if apply_t5_attn_mask is None:
            apply_t5_attn_mask = self.apply_t5_attn_mask

        l_tokens, g_tokens, t5_tokens, l_attn_mask, g_attn_mask, t5_attn_mask = tokens

        # dropout: if enable_dropout is False, dropout is not applied. dropout means zeroing out embeddings

        if l_tokens is None or clip_l is None:
            assert g_tokens is None, "g_tokens must be None if l_tokens is None"
            lg_out = None
            lg_pooled = None
            l_attn_mask = None
            g_attn_mask = None
        else:
            assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None"

            # drop some members of the batch: we do not call clip_l and clip_g for dropped members
            batch_size, l_seq_len = l_tokens.shape
            g_seq_len = g_tokens.shape[1]

            non_drop_l_indices = []
            non_drop_g_indices = []
            for i in range(batch_size):
                # random.random() is only consumed when the corresponding rate is positive
                drop_l = enable_dropout and (self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate)
                drop_g = enable_dropout and (self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate)
                if not drop_l:
                    non_drop_l_indices.append(i)
                if not drop_g:
                    non_drop_g_indices.append(i)

            # filter out dropped members
            if len(non_drop_l_indices) > 0 and len(non_drop_l_indices) < batch_size:
                l_tokens = l_tokens[non_drop_l_indices]
                l_attn_mask = l_attn_mask[non_drop_l_indices]
            if len(non_drop_g_indices) > 0 and len(non_drop_g_indices) < batch_size:
                g_tokens = g_tokens[non_drop_g_indices]
                g_attn_mask = g_attn_mask[non_drop_g_indices]

            # call clip_l for non-dropped members
            if len(non_drop_l_indices) > 0:
                nd_l_attn_mask = l_attn_mask.to(clip_l.device)
                prompt_embeds = clip_l(
                    l_tokens.to(clip_l.device), nd_l_attn_mask if apply_lg_attn_mask else None, output_hidden_states=True
                )
                nd_l_pooled = prompt_embeds[0]
                nd_l_out = prompt_embeds.hidden_states[-2]  # second-to-last hidden state
            # call clip_g for non-dropped members
            if len(non_drop_g_indices) > 0:
                nd_g_attn_mask = g_attn_mask.to(clip_g.device)
                prompt_embeds = clip_g(
                    g_tokens.to(clip_g.device), nd_g_attn_mask if apply_lg_attn_mask else None, output_hidden_states=True
                )
                nd_g_pooled = prompt_embeds[0]
                nd_g_out = prompt_embeds.hidden_states[-2]  # second-to-last hidden state

            # fill in the dropped members
            if len(non_drop_l_indices) == batch_size:
                l_pooled = nd_l_pooled
                l_out = nd_l_out
            else:
                # model output is always float32 because the models are wrapped with Accelerator
                l_pooled = torch.zeros((batch_size, 768), device=clip_l.device, dtype=torch.float32)
                l_out = torch.zeros((batch_size, l_seq_len, 768), device=clip_l.device, dtype=torch.float32)
                l_attn_mask = torch.zeros((batch_size, l_seq_len), device=clip_l.device, dtype=l_attn_mask.dtype)
                if len(non_drop_l_indices) > 0:
                    l_pooled[non_drop_l_indices] = nd_l_pooled
                    l_out[non_drop_l_indices] = nd_l_out
                    l_attn_mask[non_drop_l_indices] = nd_l_attn_mask

            if len(non_drop_g_indices) == batch_size:
                g_pooled = nd_g_pooled
                g_out = nd_g_out
            else:
                g_pooled = torch.zeros((batch_size, 1280), device=clip_g.device, dtype=torch.float32)
                g_out = torch.zeros((batch_size, g_seq_len, 1280), device=clip_g.device, dtype=torch.float32)
                g_attn_mask = torch.zeros((batch_size, g_seq_len), device=clip_g.device, dtype=g_attn_mask.dtype)
                if len(non_drop_g_indices) > 0:
                    g_pooled[non_drop_g_indices] = nd_g_pooled
                    g_out[non_drop_g_indices] = nd_g_out
                    g_attn_mask[non_drop_g_indices] = nd_g_attn_mask

            # CLIP-L features come first, CLIP-G features second, along the last dimension
            lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1)
            lg_out = torch.cat([l_out, g_out], dim=-1)

        if t5xxl is None or t5_tokens is None:
            t5_out = None
            t5_attn_mask = None
        else:
            # drop some members of the batch: we do not call t5xxl for dropped members
            batch_size, t5_seq_len = t5_tokens.shape
            non_drop_t5_indices = []
            for i in range(batch_size):
                drop_t5 = enable_dropout and (self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate)
                if not drop_t5:
                    non_drop_t5_indices.append(i)

            # filter out dropped members
            if len(non_drop_t5_indices) > 0 and len(non_drop_t5_indices) < batch_size:
                t5_tokens = t5_tokens[non_drop_t5_indices]
                t5_attn_mask = t5_attn_mask[non_drop_t5_indices]

            # call t5xxl for non-dropped members
            if len(non_drop_t5_indices) > 0:
                nd_t5_attn_mask = t5_attn_mask.to(t5xxl.device)
                nd_t5_out, _ = t5xxl(
                    t5_tokens.to(t5xxl.device),
                    nd_t5_attn_mask if apply_t5_attn_mask else None,
                    return_dict=False,
                    output_hidden_states=True,
                )

            # fill in the dropped members
            if len(non_drop_t5_indices) == batch_size:
                t5_out = nd_t5_out
            else:
                t5_out = torch.zeros((batch_size, t5_seq_len, 4096), device=t5xxl.device, dtype=torch.float32)
                t5_attn_mask = torch.zeros((batch_size, t5_seq_len), device=t5xxl.device, dtype=t5_attn_mask.dtype)
                if len(non_drop_t5_indices) > 0:
                    t5_out[non_drop_t5_indices] = nd_t5_out
                    t5_attn_mask[non_drop_t5_indices] = nd_t5_attn_mask

        # masks are used for attention masking in transformer
        return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]

    def drop_cached_text_encoder_outputs(
        self,
        lg_out: Optional[torch.Tensor],
        t5_out: Optional[torch.Tensor],
        lg_pooled: Optional[torch.Tensor],
        l_attn_mask: Optional[torch.Tensor],
        g_attn_mask: Optional[torch.Tensor],
        t5_attn_mask: Optional[torch.Tensor],
    ) -> List[Optional[torch.Tensor]]:
        """
        Apply dropout to already computed (cached) text encoder outputs.

        Dropout is always applied here according to the configured rates; there is no switch to
        disable it. Dropping a sample means zeroing its embeddings and, when a mask is given, its
        attention mask. The tensors passed in are modified in place.

        Args:
            lg_out: concatenated CLIP-L / CLIP-G hidden states, or None to skip CLIP dropout.
                The first 768 features of the last dimension are treated as CLIP-L, the rest as CLIP-G.
            t5_out: T5-XXL hidden states, or None to skip T5 dropout.
            lg_pooled: concatenated pooled CLIP outputs, split at feature 768 like `lg_out`.
            l_attn_mask: CLIP-L attention mask, or None.
            g_attn_mask: CLIP-G attention mask, or None.
            t5_attn_mask: T5-XXL attention mask, or None.

        Returns:
            `[lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]` (the same objects
            that were passed in).
        """
        if lg_out is not None:
            for i in range(lg_out.shape[0]):
                drop_l = self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate
                if drop_l:
                    lg_out[i, :, :768] = torch.zeros_like(lg_out[i, :, :768])
                    lg_pooled[i, :768] = torch.zeros_like(lg_pooled[i, :768])
                    if l_attn_mask is not None:
                        l_attn_mask[i] = torch.zeros_like(l_attn_mask[i])
                drop_g = self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate
                if drop_g:
                    lg_out[i, :, 768:] = torch.zeros_like(lg_out[i, :, 768:])
                    lg_pooled[i, 768:] = torch.zeros_like(lg_pooled[i, 768:])
                    if g_attn_mask is not None:
                        g_attn_mask[i] = torch.zeros_like(g_attn_mask[i])

        if t5_out is not None:
            for i in range(t5_out.shape[0]):
                drop_t5 = self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate
                if drop_t5:
                    t5_out[i] = torch.zeros_like(t5_out[i])
                    if t5_attn_mask is not None:
                        t5_attn_mask[i] = torch.zeros_like(t5_attn_mask[i])

        return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]

    def concat_encodings(
        self, lg_out: torch.Tensor, t5_out: Optional[torch.Tensor], lg_pooled: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Join the CLIP and T5 hidden states into a single sequence.

        `lg_out` is zero-padded on its last dimension up to 4096 features and concatenated with
        `t5_out` along the second-to-last dimension. When `t5_out` is None, a zero tensor of shape
        (batch, 77, 4096) is used in its place.

        Returns:
            `(concatenated hidden states, lg_pooled)`; `lg_pooled` is passed through unchanged.
        """
        lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
        if t5_out is None:
            t5_out = torch.zeros((lg_out.shape[0], 77, 4096), device=lg_out.device, dtype=lg_out.dtype)
        return torch.cat([lg_out, t5_out], dim=-2), lg_pooled
|
|
|
|
|
|
class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
    """Cache SD3 text encoder outputs either in memory (on the info objects) or in `.npz` files."""

    SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_sd3_te.npz"

    def __init__(
        self,
        cache_to_disk: bool,
        batch_size: int,
        skip_disk_cache_validity_check: bool,
        is_partial: bool = False,
        apply_lg_attn_mask: bool = False,
        apply_t5_attn_mask: bool = False,
    ) -> None:
        """
        Args:
            cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial: forwarded to the
                base class constructor.
            apply_lg_attn_mask: whether the CLIP attention masks are applied when encoding for the
                cache; also stored in, and compared against, the cache file.
            apply_t5_attn_mask: same as above, for T5-XXL.
        """
        super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
        self.apply_lg_attn_mask = apply_lg_attn_mask
        self.apply_t5_attn_mask = apply_t5_attn_mask

    def get_outputs_npz_path(self, image_abs_path: str) -> str:
        """Return the cache file path: the image path without extension plus the SD3 suffix."""
        return os.path.splitext(image_abs_path)[0] + Sd3TextEncoderOutputsCachingStrategy.SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX

    def is_disk_cached_outputs_expected(self, npz_path: str):
        """
        Check whether `npz_path` holds a usable cache for the current settings.

        Returns False when disk caching is off, the file is missing, a required entry is missing,
        or a stored attention-mask flag differs from this strategy's setting. Returns True without
        opening the file when `skip_disk_cache_validity_check` is set.

        Raises:
            Exception: any error raised while reading the file is logged and re-raised.
        """
        if not self.cache_to_disk:
            return False
        if not os.path.exists(npz_path):
            return False
        if self.skip_disk_cache_validity_check:
            return True

        try:
            # the context manager closes the underlying file handle on every exit path
            with np.load(npz_path) as npz:
                required_keys = (
                    "lg_out",
                    "lg_pooled",
                    "clip_l_attn_mask",  # masks are necessary even if not used
                    "clip_g_attn_mask",
                    "apply_lg_attn_mask",
                    "t5_out",
                    "t5_attn_mask",
                )
                if any(key not in npz for key in required_keys):
                    return False

                npz_apply_lg_attn_mask = npz["apply_lg_attn_mask"]
                if npz_apply_lg_attn_mask != self.apply_lg_attn_mask:
                    return False

                if "apply_t5_attn_mask" not in npz:
                    return False
                npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"]
                if npz_apply_t5_attn_mask != self.apply_t5_attn_mask:
                    return False
        except Exception:
            logger.error(f"Error loading file: {npz_path}")
            raise

        return True

    def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
        """
        Load cached outputs from `npz_path`.

        Returns:
            `[lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]`.
        """
        # read every array inside the context manager so the file handle is closed before returning
        with np.load(npz_path) as data:
            lg_out = data["lg_out"]
            lg_pooled = data["lg_pooled"]
            t5_out = data["t5_out"]

            l_attn_mask = data["clip_l_attn_mask"]
            g_attn_mask = data["clip_g_attn_mask"]
            t5_attn_mask = data["t5_attn_mask"]

        # apply_t5_attn_mask and apply_lg_attn_mask are same as self.apply_t5_attn_mask and self.apply_lg_attn_mask
        return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]

    def cache_batch_outputs(
        self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
    ):
        """
        Encode the captions of `infos` and store the results.

        Each info's outputs are written to `info.text_encoder_outputs_npz` when `cache_to_disk` is
        set, otherwise assigned to `info.text_encoder_outputs` as a tuple of numpy arrays.

        Args:
            tokenize_strategy: used to tokenize `info.caption` for every info.
            models: text encoder models, passed through to `encode_tokens`.
            text_encoding_strategy: expected to be an `Sd3TextEncodingStrategy`.
            infos: objects providing `caption` and `text_encoder_outputs_npz` / `text_encoder_outputs`.
        """
        sd3_text_encoding_strategy: Sd3TextEncodingStrategy = text_encoding_strategy
        captions = [info.caption for info in infos]

        tokens_and_masks = tokenize_strategy.tokenize(captions)
        with torch.no_grad():
            # always disable dropout during caching
            # the masks returned by encode_tokens are ignored; the tokenizer's masks are stored below
            lg_out, t5_out, lg_pooled, _, _, _ = sd3_text_encoding_strategy.encode_tokens(
                tokenize_strategy,
                models,
                tokens_and_masks,
                apply_lg_attn_mask=self.apply_lg_attn_mask,
                apply_t5_attn_mask=self.apply_t5_attn_mask,
                enable_dropout=False,
            )

        # bfloat16 tensors are converted to float32 before the numpy conversion
        if lg_out.dtype == torch.bfloat16:
            lg_out = lg_out.float()
        if lg_pooled.dtype == torch.bfloat16:
            lg_pooled = lg_pooled.float()
        if t5_out.dtype == torch.bfloat16:
            t5_out = t5_out.float()

        lg_out = lg_out.cpu().numpy()
        lg_pooled = lg_pooled.cpu().numpy()
        t5_out = t5_out.cpu().numpy()

        l_attn_mask = tokens_and_masks[3].cpu().numpy()
        g_attn_mask = tokens_and_masks[4].cpu().numpy()
        t5_attn_mask = tokens_and_masks[5].cpu().numpy()

        for i, info in enumerate(infos):
            lg_out_i = lg_out[i]
            t5_out_i = t5_out[i]
            lg_pooled_i = lg_pooled[i]
            l_attn_mask_i = l_attn_mask[i]
            g_attn_mask_i = g_attn_mask[i]
            t5_attn_mask_i = t5_attn_mask[i]
            apply_lg_attn_mask = self.apply_lg_attn_mask
            apply_t5_attn_mask = self.apply_t5_attn_mask

            if self.cache_to_disk:
                np.savez(
                    info.text_encoder_outputs_npz,
                    lg_out=lg_out_i,
                    lg_pooled=lg_pooled_i,
                    t5_out=t5_out_i,
                    clip_l_attn_mask=l_attn_mask_i,
                    clip_g_attn_mask=g_attn_mask_i,
                    t5_attn_mask=t5_attn_mask_i,
                    apply_lg_attn_mask=apply_lg_attn_mask,
                    apply_t5_attn_mask=apply_t5_attn_mask,
                )
            else:
                # it's fine that attn mask is not None. it's overwritten before calling the model if necessary
                info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i, l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i)
|
|
|
|
|
|
class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
    """Latents caching for SD3; delegates to the base class helpers with multi-resolution enabled."""

    SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz"

    def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
        super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)

    @property
    def cache_suffix(self) -> str:
        """Suffix appended to latents cache file names."""
        return Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX

    def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
        """Return `<path without extension>_<size0>x<size1><suffix>`, each size zero-padded to 4 digits."""
        stem = os.path.splitext(absolute_path)[0]
        size_tag = f"_{image_size[0]:04d}x{image_size[1]:04d}"
        return stem + size_tag + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX

    def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
        """Delegate the validity check to the base class helper (first argument 8, multi-resolution)."""
        return self._default_is_disk_cached_latents_expected(
            8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True
        )

    def load_latents_from_disk(
        self, npz_path: str, bucket_reso: Tuple[int, int]
    ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
        """Delegate loading to the base class helper, which supports multi-resolution files."""
        return self._default_load_latents_from_disk(8, npz_path, bucket_reso)

    # TODO remove circular dependency for ImageInfo
    def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
        """Encode a batch of images with `vae` and cache the latents via the base class helper."""

        def encode_by_vae(img_tensor):
            # results are moved to CPU as soon as they are encoded
            return vae.encode(img_tensor).to("cpu")

        self._default_cache_batch_latents(
            encode_by_vae,
            vae.device,
            vae.dtype,
            image_infos,
            flip_aug,
            alpha_mask,
            random_crop,
            multi_resolution=True,
        )

        # release device memory after each batch unless running in high-VRAM mode
        if not train_util.HIGH_VRAM:
            train_util.clean_memory_on_device(vae.device)
|