Support SD3.5M multi-resolution training

This commit is contained in:
Kohya S
2024-10-31 19:58:22 +09:00
parent 70a179e446
commit 1434d8506f
8 changed files with 215 additions and 10 deletions

View File

@@ -88,6 +88,78 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
return emb
def get_scaled_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, sample_size=64, base_size=16):
    """
    Build scaled and centered 2D sinusoidal positional embeddings.

    This function is contributed by KohakuBlueleaf. Thanks for the contribution!

    Grid coordinates are normalized to [0, 1), multiplied by
    ``base_size * grid_len / sample_size`` and then shifted so that the central
    ``sample_size`` x ``sample_size`` window of the grid gets embeddings similar
    to those of the reference resolution.

    Args:
        embed_dim (int): Dimension of the positional embedding.
        grid_size (int or tuple): Size of the position grid (H, W). An int means a square grid.
        cls_token (bool): Whether extra (class) tokens are prepended. Defaults to False.
        extra_tokens (int): Number of zero rows prepended when ``cls_token`` is True. Defaults to 0.
        sample_size (int): Reference grid size (typically the training resolution). Defaults to 64.
        base_size (int): Base grid size used during training. Defaults to 16.

    Returns:
        numpy.ndarray: Positional embeddings of shape (H*W, embed_dim), or
        (H*W + extra_tokens, embed_dim) when ``cls_token`` is True and ``extra_tokens > 0``.
    """
    if isinstance(grid_size, int):
        grid_size = (grid_size, grid_size)
    rows, cols = grid_size[0], grid_size[1]

    def _axis_coords(length):
        # normalized coordinates in [0, 1)
        coords = np.arange(length, dtype=np.float32) / length
        # stretch so that one reference cell keeps its original spacing
        scale = base_size * length / (sample_size)
        # shift so that the central sample_size-wide window lines up with the reference grid
        shift = scale * (length - sample_size) / (2 * length)
        return coords * scale - shift

    coords_h = _axis_coords(rows)
    coords_w = _axis_coords(cols)

    # meshgrid takes the width axis first, so channel 0 holds the w coordinates
    grid = np.stack(np.meshgrid(coords_w, coords_h), axis=0)
    grid = grid.reshape([2, 1, cols, rows])

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)

    # extra tokens (e.g. [CLS]) get all-zero embeddings in front of the grid embeddings
    if cls_token and extra_tokens > 0:
        zero_rows = np.zeros([extra_tokens, embed_dim])
        pos_embed = np.concatenate([zero_rows, pos_embed], axis=0)
    return pos_embed
# if __name__ == "__main__":
# # This is what you get when you load SD3.5 state dict
# pos_emb = torch.from_numpy(get_scaled_2d_sincos_pos_embed(
# 1536, [384, 384], sample_size=64, base_size=16
# )).float().unsqueeze(0)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
@@ -617,7 +689,7 @@ class MMDiTBlock(nn.Module):
self.context_block = SingleDiTBlock(*args, pre_only=pre_only, **kwargs)
self.x_block = SingleDiTBlock(*args, pre_only=False, x_block_self_attn=x_block_self_attn, **kwargs)
self.head_dim = self.x_block.attn.head_dim
self.mode = self.x_block.attn_mode
self.gradient_checkpointing = False
@@ -669,6 +741,9 @@ class MMDiT(nn.Module):
Diffusion model with a Transformer backbone.
"""
# prepare pos_embed for latent size * POS_EMBED_MAX_RATIO (1.5x per side)
POS_EMBED_MAX_RATIO = 1.5
def __init__(
self,
input_size: int = 32,
@@ -697,6 +772,8 @@ class MMDiT(nn.Module):
x_block_self_attn_layers: Optional[list[int]] = [],
qkv_bias: bool = True,
pos_emb_random_crop_rate: float = 0.0,
use_scaled_pos_embed: bool = False,
pos_embed_latent_sizes: Optional[list[int]] = None,
model_type: str = "sd3m",
):
super().__init__()
@@ -722,6 +799,8 @@ class MMDiT(nn.Module):
self.num_heads = num_heads
self.enable_scaled_pos_embed(use_scaled_pos_embed, pos_embed_latent_sizes)
self.x_embedder = PatchEmbed(
input_size,
patch_size,
@@ -785,6 +864,43 @@ class MMDiT(nn.Module):
self.blocks_to_swap = None
self.thread_pool: Optional[ThreadPoolExecutor] = None
def enable_scaled_pos_embed(self, use_scaled_pos_embed: bool, latent_sizes: Optional[list[int]]):
    """
    Enable or disable per-resolution scaled positional embeddings.

    When enabled, one embedding table is precomputed for every distinct patched
    size derived from ``latent_sizes``, together with the area thresholds used to
    pick a table from a latent shape. When disabled, both lookup structures are
    reset to None.

    Args:
        use_scaled_pos_embed: Whether scaled positional embeddings are used.
        latent_sizes: Latent edge lengths (before patching) to prepare tables for.
            Must be a non-empty list when ``use_scaled_pos_embed`` is True;
            ignored otherwise.

    Raises:
        ValueError: If enabled without any latent size.
    """
    # validate before touching any state so a bad call leaves the model unchanged
    if use_scaled_pos_embed and (latent_sizes is None or len(latent_sizes) == 0):
        raise ValueError("latent_sizes must contain at least one size when use_scaled_pos_embed is True")

    self.use_scaled_pos_embed = use_scaled_pos_embed
    if not self.use_scaled_pos_embed:
        self.resolution_area_to_latent_size = None
        self.resolution_pos_embeds = None
        return

    # remove pos_embed to free up memory up to 0.4 GB
    self.pos_embed = None

    # distinct patched sizes in ascending order; several datasets often share a
    # resolution, and duplicates would only add redundant thresholds and tables
    patched_sizes = sorted({latent_size // self.patch_size for latent_size in latent_sizes})

    # upper area bound for each patched size: the midpoint between its own area
    # and the next larger one; used to choose the table from the latent shape
    max_areas = [(smaller**2 + larger**2) // 2 for smaller, larger in zip(patched_sizes, patched_sizes[1:])]
    # bound for the largest size; a latent whose patched area exceeds this is rejected
    max_areas.append(int((patched_sizes[-1] * MMDiT.POS_EMBED_MAX_RATIO) ** 2))
    self.resolution_area_to_latent_size = list(zip(max_areas, patched_sizes))

    # each table covers POS_EMBED_MAX_RATIO times the patched size per side so
    # that non-square latents of similar area can be cropped out of it
    self.resolution_pos_embeds = {}
    for patched_size in patched_sizes:
        grid_size = int(patched_size * MMDiT.POS_EMBED_MAX_RATIO)
        pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, grid_size, sample_size=patched_size)
        self.resolution_pos_embeds[patched_size] = torch.from_numpy(pos_embed).float().unsqueeze(0)
@property
def model_type(self):
    """Model type identifier stored in ``self._model_type`` (read-only)."""
    return self._model_type
@@ -884,6 +1000,54 @@ class MMDiT(nn.Module):
spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
return spatial_pos_embed
def cropped_scaled_pos_embed(self, h, w, device=None, dtype=None, random_crop: bool = False):
    """
    Return the scaled positional embedding cropped to the patched size of a latent.

    The embedding table is selected by patched area through
    ``self.resolution_area_to_latent_size``. The crop is taken from the center of
    the table, or from a random position when ``random_crop`` is True.

    Args:
        h: Latent height before patching.
        w: Latent width before patching.
        device: Target device; the cached table is moved there once. None leaves it as is.
        dtype: Target dtype; the cached table is converted once. None leaves it as is.
        random_crop: Pick a random crop position instead of the centered one.

    Returns:
        Tensor of shape (1, patched_h * patched_w, C).

    Raises:
        ValueError: If the patched area exceeds every prepared threshold.
    """
    p = self.x_embedder.patch_size
    # keep the unpatched sizes: the cropped_pos_embed fallback is called with the
    # same latent H/W that forward() passes to it on the non-scaled path
    latent_h, latent_w = h, w

    # patched size
    h = (h + 1) // p
    w = (w + 1) // p

    # select pos_embed size based on area: thresholds are in ascending order,
    # so the first one that fits wins
    area = h * w
    patched_size = next(
        (size for max_area, size in self.resolution_area_to_latent_size if area <= max_area),
        None,
    )
    if patched_size is None:
        raise ValueError(f"Area {area} is too large for the given latent sizes {self.resolution_area_to_latent_size}.")

    pos_embed_size = int(patched_size * MMDiT.POS_EMBED_MAX_RATIO)
    if h > pos_embed_size or w > pos_embed_size:
        # fallback to normal pos_embed
        logger.warning(
            f"Using normal pos_embed for size {h}x{w} as it exceeds the scaled pos_embed size {pos_embed_size}. Image is too tall or wide."
        )
        # NOTE(review): enable_scaled_pos_embed sets self.pos_embed = None; confirm the
        # regular table has been restored before this fallback can be reached.
        fallback = self.cropped_pos_embed(latent_h, latent_w, device=device, random_crop=random_crop)
        # forward() adds the result without casting on the scaled path, so cast here
        return fallback if dtype is None else fallback.to(dtype=dtype)

    if not random_crop:
        top = (pos_embed_size - h) // 2
        left = (pos_embed_size - w) // 2
    else:
        top = torch.randint(0, pos_embed_size - h + 1, (1,)).item()
        left = torch.randint(0, pos_embed_size - w + 1, (1,)).item()

    pos_embed = self.resolution_pos_embeds[patched_size]
    if device is not None and pos_embed.device != device:
        pos_embed = pos_embed.to(device)
        # which is better to update device, or transfer every time to device? -> 64x64 emb is 96*96*1536*4=56MB. It's okay to update device.
        self.resolution_pos_embeds[patched_size] = pos_embed  # update device
    if dtype is not None and pos_embed.dtype != dtype:
        pos_embed = pos_embed.to(dtype)
        self.resolution_pos_embeds[patched_size] = pos_embed  # update dtype

    # view the flat table as a square grid, cut out the h x w window, flatten again
    spatial_pos_embed = pos_embed.reshape(1, pos_embed_size, pos_embed_size, pos_embed.shape[-1])
    spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :]
    spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
    return spatial_pos_embed
def enable_block_swap(self, num_blocks: int):
self.blocks_to_swap = num_blocks
@@ -931,7 +1095,16 @@ class MMDiT(nn.Module):
)
B, C, H, W = x.shape
x = self.x_embedder(x) + self.cropped_pos_embed(H, W, device=x.device, random_crop=pos_emb_random_crop).to(dtype=x.dtype)
# x = self.x_embedder(x) + self.cropped_pos_embed(H, W, device=x.device, random_crop=pos_emb_random_crop).to(dtype=x.dtype)
if not self.use_scaled_pos_embed:
pos_embed = self.cropped_pos_embed(H, W, device=x.device, random_crop=pos_emb_random_crop).to(dtype=x.dtype)
else:
# print(f"Using scaled pos_embed for size {H}x{W}")
pos_embed = self.cropped_scaled_pos_embed(H, W, device=x.device, dtype=x.dtype, random_crop=pos_emb_random_crop)
x = self.x_embedder(x) + pos_embed
del pos_embed
c = self.t_embedder(t, dtype=x.dtype) # (N, D)
if y is not None and self.y_embedder is not None:
y = self.y_embedder(y) # (N, D)

View File

@@ -246,6 +246,12 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser):
help="Random crop rate for positional embeddings, default is 0.0. Only for SD3.5M"
" / 位置埋め込みのランダムクロップ率、デフォルトは0.0。SD3.5M以外では予期しない動作になります",
)
parser.add_argument(
"--enable_scaled_pos_embed",
action="store_true",
help="Scale position embeddings for each resolution during multi-resolution training. Only for SD3.5M"
" / 複数解像度学習時に解像度ごとに位置埋め込みをスケーリングする。SD3.5M以外では予期しない動作になります",
)
# copy from Diffusers
parser.add_argument(

View File

@@ -518,7 +518,7 @@ class LatentsCachingStrategy:
self, npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
"""
for SD/SDXL/SD3.0
for SD/SDXL
"""
return self._default_load_latents_from_disk(None, npz_path, bucket_reso)

View File

@@ -212,7 +212,7 @@ class FluxLatentsCachingStrategy(LatentsCachingStrategy):
)
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, True)
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)
def load_latents_from_disk(
self, npz_path: str, bucket_reso: Tuple[int, int]
@@ -226,7 +226,7 @@ class FluxLatentsCachingStrategy(LatentsCachingStrategy):
vae_dtype = vae.dtype
self._default_cache_batch_latents(
encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, True
encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True
)
if not train_util.HIGH_VRAM:

View File

@@ -399,7 +399,12 @@ class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
)
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask)
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)
def load_latents_from_disk(
    self, npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
    """
    Load cached latents for ``bucket_reso`` from ``npz_path``.

    Delegates to ``_default_load_latents_from_disk`` with a stride of 8 so that
    multi-resolution caches are supported; the result is returned unchanged.
    """
    # same value that is_disk_cached_latents_expected passes for this strategy
    latents_stride = 8
    return self._default_load_latents_from_disk(latents_stride, npz_path, bucket_reso)
# TODO remove circular dependency for ImageInfo
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
@@ -407,7 +412,9 @@ class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
vae_device = vae.device
vae_dtype = vae.dtype
self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop)
self._default_cache_batch_latents(
encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True
)
if not train_util.HIGH_VRAM:
train_util.clean_memory_on_device(vae.device)

View File

@@ -2510,6 +2510,9 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
for dataset in self.datasets:
dataset.verify_bucket_reso_steps(min_steps)
def get_resolutions(self) -> List[Tuple[int, int]]:
    """Return the ``(width, height)`` of every dataset in the group, in dataset order."""
    resolutions = [(ds.width, ds.height) for ds in self.datasets]
    return resolutions
def is_latent_cacheable(self) -> bool:
return all([dataset.is_latent_cacheable() for dataset in self.datasets])

View File

@@ -361,7 +361,14 @@ def train(args):
# ), f"attn_mode {attn_mode} is not supported yet. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。"
mmdit.set_pos_emb_random_crop_rate(args.pos_emb_random_crop_rate)
# set resolutions for positional embeddings
if args.enable_scaled_pos_embed:
resolutions = train_dataset_group.get_resolutions()
latent_sizes = [round(math.sqrt(res[0] * res[1])) // 8 for res in resolutions] # 8 is stride for latent
logger.info(f"Prepare scaled positional embeddings for resolutions: {resolutions}, sizes: {latent_sizes}")
mmdit.enable_scaled_pos_embed(True, latent_sizes)
if args.gradient_checkpointing:
mmdit.enable_gradient_checkpointing()

View File

@@ -26,8 +26,8 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
super().__init__()
self.sample_prompts_te_outputs = None
def assert_extra_args(self, args, train_dataset_group):
super().assert_extra_args(args, train_dataset_group)
def assert_extra_args(self, args, train_dataset_group: train_util.DatasetGroup):
# super().assert_extra_args(args, train_dataset_group)
# sdxl_train_util.verify_sdxl_training_args(args)
if args.fp8_base_unet:
@@ -53,6 +53,9 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
train_dataset_group.verify_bucket_reso_steps(32) # TODO check this
# enumerate resolutions from dataset for positional embeddings
self.resolutions = train_dataset_group.get_resolutions()
def load_target_model(self, args, weight_dtype, accelerator):
# currently offload to cpu for some models
@@ -67,6 +70,12 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
self.model_type = mmdit.model_type
mmdit.set_pos_emb_random_crop_rate(args.pos_emb_random_crop_rate)
# set resolutions for positional embeddings
if args.enable_scaled_pos_embed:
latent_sizes = [round(math.sqrt(res[0] * res[1])) // 8 for res in self.resolutions] # 8 is stride for latent
logger.info(f"Prepare scaled positional embeddings for resolutions: {self.resolutions}, sizes: {latent_sizes}")
mmdit.enable_scaled_pos_embed(True, latent_sizes)
if args.fp8_base:
# check dtype of model
if mmdit.dtype == torch.float8_e4m3fnuz or mmdit.dtype == torch.float8_e5m2 or mmdit.dtype == torch.float8_e5m2fnuz: