mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
Support SD3.5M multi-resolution training
This commit is contained in:
@@ -88,6 +88,78 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
||||
return emb
|
||||
|
||||
|
||||
def get_scaled_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, sample_size=64, base_size=16):
    """
    This function is contributed by KohakuBlueleaf. Thanks for the contribution!

    Creates scaled 2D sinusoidal positional embeddings that maintain consistent relative positions
    when the resolution differs from the training resolution.

    Args:
        embed_dim (int): Dimension of the positional embedding.
        grid_size (int or tuple): Size of the position grid (H, W). If int, assumes square grid.
        cls_token (bool): Whether to include class token. Defaults to False.
        extra_tokens (int): Number of extra tokens (e.g., cls_token). Defaults to 0.
        sample_size (int): Reference resolution (typically training resolution). Defaults to 64.
        base_size (int): Base grid size used during training. Defaults to 16.

    Returns:
        numpy.ndarray: Positional embeddings of shape (H*W, embed_dim) or
        (H*W + extra_tokens, embed_dim) if cls_token is True.
    """
    # Convert grid_size to tuple if it's an integer
    if isinstance(grid_size, int):
        grid_size = (grid_size, grid_size)

    # Create normalized grid coordinates (0 to 1)
    grid_h = np.arange(grid_size[0], dtype=np.float32) / grid_size[0]
    grid_w = np.arange(grid_size[1], dtype=np.float32) / grid_size[1]

    # Calculate scaling factors for height and width
    # This ensures that the central region matches the original resolution's embeddings
    scale_h = base_size * grid_size[0] / (sample_size)
    scale_w = base_size * grid_size[1] / (sample_size)

    # Calculate shift values to center the original resolution's embedding region
    # This ensures that the central sample_size x sample_size region has similar
    # positional embeddings to the original resolution
    # NOTE(review): the leading "1 *" is a no-op factor, presumably left over from
    # experimentation — confirm and drop in a follow-up.
    shift_h = 1 * scale_h * (grid_size[0] - sample_size) / (2 * grid_size[0])
    shift_w = 1 * scale_w * (grid_size[1] - sample_size) / (2 * grid_size[1])

    # Apply scaling and shifting to create the final grid coordinates
    grid_h = grid_h * scale_h - shift_h
    grid_w = grid_w * scale_w - shift_w

    # Create 2D grid using meshgrid (note: w goes first)
    grid = np.meshgrid(grid_w, grid_h)
    grid = np.stack(grid, axis=0)  # shape (2, grid_size[0], grid_size[1]): X then Y coordinate planes

    # # Calculate the starting indices for the central region
    # # This is used for debugging/visualization of the central region
    # st_h = (grid_size[0] - sample_size) // 2
    # st_w = (grid_size[1] - sample_size) // 2
    # print(grid[:, st_h : st_h + sample_size, st_w : st_w + sample_size])

    # Reshape grid for positional embedding calculation
    # NOTE(review): the (2, H, W) array is reshaped with the W and H dimensions swapped.
    # This mirrors diffusers' get_2d_sincos_pos_embed and is harmless for square grids
    # (the only case exercised here — callers pass an int grid_size), but it would
    # scramble the layout for a non-square grid_size tuple. Confirm before relying on
    # non-square grids.
    grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])

    # Generate the sinusoidal positional embeddings
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)

    # Add zeros for extra tokens (e.g., [CLS] token) if required
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)

    return pos_embed
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# # This is what you get when you load SD3.5 state dict
|
||||
# pos_emb = torch.from_numpy(get_scaled_2d_sincos_pos_embed(
|
||||
# 1536, [384, 384], sample_size=64, base_size=16
|
||||
# )).float().unsqueeze(0)
|
||||
|
||||
|
||||
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
||||
"""
|
||||
embed_dim: output dimension for each position
|
||||
@@ -669,6 +741,9 @@ class MMDiT(nn.Module):
|
||||
Diffusion model with a Transformer backbone.
|
||||
"""
|
||||
|
||||
# prepare pos_embed for latent size * 2
|
||||
POS_EMBED_MAX_RATIO = 1.5
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int = 32,
|
||||
@@ -697,6 +772,8 @@ class MMDiT(nn.Module):
|
||||
x_block_self_attn_layers: Optional[list[int]] = [],
|
||||
qkv_bias: bool = True,
|
||||
pos_emb_random_crop_rate: float = 0.0,
|
||||
use_scaled_pos_embed: bool = False,
|
||||
pos_embed_latent_sizes: Optional[list[int]] = None,
|
||||
model_type: str = "sd3m",
|
||||
):
|
||||
super().__init__()
|
||||
@@ -722,6 +799,8 @@ class MMDiT(nn.Module):
|
||||
|
||||
self.num_heads = num_heads
|
||||
|
||||
self.enable_scaled_pos_embed(use_scaled_pos_embed, pos_embed_latent_sizes)
|
||||
|
||||
self.x_embedder = PatchEmbed(
|
||||
input_size,
|
||||
patch_size,
|
||||
@@ -785,6 +864,43 @@ class MMDiT(nn.Module):
|
||||
self.blocks_to_swap = None
|
||||
self.thread_pool: Optional[ThreadPoolExecutor] = None
|
||||
|
||||
def enable_scaled_pos_embed(self, use_scaled_pos_embed: bool, latent_sizes: Optional[list[int]]):
    """Precompute resolution-scaled sinusoidal position embeddings for multi-resolution training.

    Args:
        use_scaled_pos_embed: Enable (True) or disable (False) the scaled-embedding path.
        latent_sizes: Latent edge sizes (in latent pixels) of the training resolutions.
            Duplicates are tolerated and deduplicated. Ignored when disabling.
    """
    self.use_scaled_pos_embed = use_scaled_pos_embed

    if self.use_scaled_pos_embed:
        # remove pos_embed to free up memory up to 0.4 GB
        self.pos_embed = None

        # unique patched sizes in ascending order; deduplicating avoids recomputing
        # identical embeddings when several datasets share a resolution (selection
        # semantics below are unchanged by the dedup)
        patched_sizes = sorted({latent_size // self.patch_size for latent_size in latent_sizes})

        # calculate value range for each latent area: this is used to determine the pos_emb
        # size from the latent shape. The upper bound for each size is the midpoint between
        # consecutive patched areas, so a latent is matched to the nearest precomputed size.
        max_areas = []
        for i in range(1, len(patched_sizes)):
            prev_area = patched_sizes[i - 1] ** 2
            area = patched_sizes[i] ** 2
            max_areas.append((prev_area + area) // 2)

        # area of the last latent size, if the latent size exceeds this, error will be raised
        max_areas.append(int((patched_sizes[-1] * MMDiT.POS_EMBED_MAX_RATIO) ** 2))

        # NOTE: despite the attribute name, the second element of each pair is the *patched* size
        self.resolution_area_to_latent_size = list(zip(max_areas, patched_sizes))

        # precompute one oversized embedding grid per patched size; the forward pass
        # crops a window out of it (see cropped_scaled_pos_embed)
        self.resolution_pos_embeds = {}
        for patched_size in patched_sizes:
            grid_size = int(patched_size * MMDiT.POS_EMBED_MAX_RATIO)
            pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, grid_size, sample_size=patched_size)
            pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0)
            self.resolution_pos_embeds[patched_size] = pos_embed
    else:
        self.resolution_area_to_latent_size = None
        self.resolution_pos_embeds = None
|
||||
|
||||
@property
def model_type(self):
    """str: Read-only model-variant identifier (e.g. "sd3m", per the constructor's
    ``model_type`` argument — presumably stored into ``_model_type`` there; confirm)."""
    return self._model_type
|
||||
@@ -884,6 +1000,54 @@ class MMDiT(nn.Module):
|
||||
spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
|
||||
return spatial_pos_embed
|
||||
|
||||
def cropped_scaled_pos_embed(self, h, w, device=None, dtype=None, random_crop: bool = False):
    """Crop an (h x w) window of position embeddings from a precomputed scaled grid.

    Picks the precomputed embedding whose patched area best matches the input
    (via ``resolution_area_to_latent_size``), then crops either the centered window
    or, when ``random_crop`` is set, a uniformly random window.

    Args:
        h: Latent height (pre-patchify); converted to patch units below.
        w: Latent width (pre-patchify); converted to patch units below.
        device: Target device; the cached embedding is migrated lazily and the cache updated.
        dtype: Target dtype; same lazy-migration policy as ``device``.
        random_crop: If True, crop a random window instead of the centered one.

    Returns:
        Tensor of shape (1, h*w, embed_dim) where h/w are in patch units.
    """
    p = self.x_embedder.patch_size
    # patched size
    # NOTE(review): (x + 1) // p rounds up only when p == 2; confirm patch_size is
    # always 2 here, otherwise this is not a general ceil-division.
    h = (h + 1) // p
    w = (w + 1) // p

    # select pos_embed size based on area
    area = h * w
    patched_size = None
    for area_, patched_size_ in self.resolution_area_to_latent_size:
        if area <= area_:
            patched_size = patched_size_
            break
    if patched_size is None:
        raise ValueError(f"Area {area} is too large for the given latent sizes {self.resolution_area_to_latent_size}.")

    # each cached grid is POS_EMBED_MAX_RATIO times the nominal size, leaving headroom
    # for cropping extreme aspect ratios; beyond that we cannot crop from the cache
    pos_embed_size = int(patched_size * MMDiT.POS_EMBED_MAX_RATIO)
    if h > pos_embed_size or w > pos_embed_size:
        # fallback to normal pos_embed
        # NOTE(review): h and w are already in patch units here, but the forward pass
        # calls cropped_pos_embed with latent-pixel H/W — if cropped_pos_embed divides
        # by patch_size again, this fallback double-divides. Verify against
        # cropped_pos_embed's expectations.
        logger.warning(
            f"Using normal pos_embed for size {h}x{w} as it exceeds the scaled pos_embed size {pos_embed_size}. Image is too tall or wide."
        )
        return self.cropped_pos_embed(h, w, device=device, random_crop=random_crop)

    if not random_crop:
        # centered crop: the middle of the oversized grid matches the reference resolution
        top = (pos_embed_size - h) // 2
        left = (pos_embed_size - w) // 2
    else:
        # random crop as augmentation; +1 makes the upper bound inclusive
        top = torch.randint(0, pos_embed_size - h + 1, (1,)).item()
        left = torch.randint(0, pos_embed_size - w + 1, (1,)).item()

    pos_embed = self.resolution_pos_embeds[patched_size]
    if pos_embed.device != device:
        pos_embed = pos_embed.to(device)
        # which is better to update device, or transfer every time to device? -> 64x64 emb is 96*96*1536*4=56MB. It's okay to update device.
        self.resolution_pos_embeds[patched_size] = pos_embed  # update device
    if pos_embed.dtype != dtype:
        pos_embed = pos_embed.to(dtype)
        self.resolution_pos_embeds[patched_size] = pos_embed  # update dtype

    # (1, S, S, D) -> crop window -> flatten back to (1, h*w, D)
    spatial_pos_embed = pos_embed.reshape(1, pos_embed_size, pos_embed_size, pos_embed.shape[-1])
    spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :]
    spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
    # print(
    #     f"patched size: {h}x{w}, pos_embed size: {pos_embed_size}, pos_embed shape: {pos_embed.shape}, top: {top}, left: {left}"
    # )
    return spatial_pos_embed
|
||||
|
||||
def enable_block_swap(self, num_blocks: int):
|
||||
self.blocks_to_swap = num_blocks
|
||||
|
||||
@@ -931,7 +1095,16 @@ class MMDiT(nn.Module):
|
||||
)
|
||||
|
||||
B, C, H, W = x.shape
|
||||
x = self.x_embedder(x) + self.cropped_pos_embed(H, W, device=x.device, random_crop=pos_emb_random_crop).to(dtype=x.dtype)
|
||||
|
||||
# x = self.x_embedder(x) + self.cropped_pos_embed(H, W, device=x.device, random_crop=pos_emb_random_crop).to(dtype=x.dtype)
|
||||
if not self.use_scaled_pos_embed:
|
||||
pos_embed = self.cropped_pos_embed(H, W, device=x.device, random_crop=pos_emb_random_crop).to(dtype=x.dtype)
|
||||
else:
|
||||
# print(f"Using scaled pos_embed for size {H}x{W}")
|
||||
pos_embed = self.cropped_scaled_pos_embed(H, W, device=x.device, dtype=x.dtype, random_crop=pos_emb_random_crop)
|
||||
x = self.x_embedder(x) + pos_embed
|
||||
del pos_embed
|
||||
|
||||
c = self.t_embedder(t, dtype=x.dtype) # (N, D)
|
||||
if y is not None and self.y_embedder is not None:
|
||||
y = self.y_embedder(y) # (N, D)
|
||||
|
||||
@@ -246,6 +246,12 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser):
|
||||
help="Random crop rate for positional embeddings, default is 0.0. Only for SD3.5M"
|
||||
" / 位置埋め込みのランダムクロップ率、デフォルトは0.0。SD3.5M以外では予期しない動作になります",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable_scaled_pos_embed",
|
||||
action="store_true",
|
||||
help="Scale position embeddings for each resolution during multi-resolution training. Only for SD3.5M"
|
||||
" / 複数解像度学習時に解像度ごとに位置埋め込みをスケーリングする。SD3.5M以外では予期しない動作になります",
|
||||
)
|
||||
|
||||
# copy from Diffusers
|
||||
parser.add_argument(
|
||||
|
||||
@@ -518,7 +518,7 @@ class LatentsCachingStrategy:
|
||||
self, npz_path: str, bucket_reso: Tuple[int, int]
|
||||
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
|
||||
"""
|
||||
for SD/SDXL/SD3.0
|
||||
for SD/SDXL
|
||||
"""
|
||||
return self._default_load_latents_from_disk(None, npz_path, bucket_reso)
|
||||
|
||||
|
||||
@@ -212,7 +212,7 @@ class FluxLatentsCachingStrategy(LatentsCachingStrategy):
|
||||
)
|
||||
|
||||
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
|
||||
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, True)
|
||||
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)
|
||||
|
||||
def load_latents_from_disk(
|
||||
self, npz_path: str, bucket_reso: Tuple[int, int]
|
||||
@@ -226,7 +226,7 @@ class FluxLatentsCachingStrategy(LatentsCachingStrategy):
|
||||
vae_dtype = vae.dtype
|
||||
|
||||
self._default_cache_batch_latents(
|
||||
encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, True
|
||||
encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True
|
||||
)
|
||||
|
||||
if not train_util.HIGH_VRAM:
|
||||
|
||||
@@ -399,7 +399,12 @@ class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
|
||||
)
|
||||
|
||||
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
|
||||
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask)
|
||||
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)
|
||||
|
||||
def load_latents_from_disk(
|
||||
self, npz_path: str, bucket_reso: Tuple[int, int]
|
||||
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
|
||||
return self._default_load_latents_from_disk(8, npz_path, bucket_reso) # support multi-resolution
|
||||
|
||||
# TODO remove circular dependency for ImageInfo
|
||||
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
|
||||
@@ -407,7 +412,9 @@ class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
|
||||
vae_device = vae.device
|
||||
vae_dtype = vae.dtype
|
||||
|
||||
self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop)
|
||||
self._default_cache_batch_latents(
|
||||
encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True
|
||||
)
|
||||
|
||||
if not train_util.HIGH_VRAM:
|
||||
train_util.clean_memory_on_device(vae.device)
|
||||
|
||||
@@ -2510,6 +2510,9 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
for dataset in self.datasets:
|
||||
dataset.verify_bucket_reso_steps(min_steps)
|
||||
|
||||
def get_resolutions(self) -> List[Tuple[int, int]]:
    """List the (width, height) resolution of each member dataset, in dataset order."""
    return [(ds.width, ds.height) for ds in self.datasets]
|
||||
|
||||
def is_latent_cacheable(self) -> bool:
    """Return True if every dataset in this group supports latent caching."""
    # generator expression instead of a list: no intermediate allocation, and
    # all() short-circuits on the first dataset that cannot cache latents
    return all(dataset.is_latent_cacheable() for dataset in self.datasets)
|
||||
|
||||
|
||||
@@ -362,6 +362,13 @@ def train(args):
|
||||
|
||||
mmdit.set_pos_emb_random_crop_rate(args.pos_emb_random_crop_rate)
|
||||
|
||||
# set resolutions for positional embeddings
|
||||
if args.enable_scaled_pos_embed:
|
||||
resolutions = train_dataset_group.get_resolutions()
|
||||
latent_sizes = [round(math.sqrt(res[0] * res[1])) // 8 for res in resolutions] # 8 is stride for latent
|
||||
logger.info(f"Prepare scaled positional embeddings for resolutions: {resolutions}, sizes: {latent_sizes}")
|
||||
mmdit.enable_scaled_pos_embed(True, latent_sizes)
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
mmdit.enable_gradient_checkpointing()
|
||||
|
||||
|
||||
@@ -26,8 +26,8 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
||||
super().__init__()
|
||||
self.sample_prompts_te_outputs = None
|
||||
|
||||
def assert_extra_args(self, args, train_dataset_group):
|
||||
super().assert_extra_args(args, train_dataset_group)
|
||||
def assert_extra_args(self, args, train_dataset_group: train_util.DatasetGroup):
|
||||
# super().assert_extra_args(args, train_dataset_group)
|
||||
# sdxl_train_util.verify_sdxl_training_args(args)
|
||||
|
||||
if args.fp8_base_unet:
|
||||
@@ -53,6 +53,9 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
||||
|
||||
train_dataset_group.verify_bucket_reso_steps(32) # TODO check this
|
||||
|
||||
# enumerate resolutions from dataset for positional embeddings
|
||||
self.resolutions = train_dataset_group.get_resolutions()
|
||||
|
||||
def load_target_model(self, args, weight_dtype, accelerator):
|
||||
# currently offload to cpu for some models
|
||||
|
||||
@@ -67,6 +70,12 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
||||
self.model_type = mmdit.model_type
|
||||
mmdit.set_pos_emb_random_crop_rate(args.pos_emb_random_crop_rate)
|
||||
|
||||
# set resolutions for positional embeddings
|
||||
if args.enable_scaled_pos_embed:
|
||||
latent_sizes = [round(math.sqrt(res[0] * res[1])) // 8 for res in self.resolutions] # 8 is stride for latent
|
||||
logger.info(f"Prepare scaled positional embeddings for resolutions: {self.resolutions}, sizes: {latent_sizes}")
|
||||
mmdit.enable_scaled_pos_embed(True, latent_sizes)
|
||||
|
||||
if args.fp8_base:
|
||||
# check dtype of model
|
||||
if mmdit.dtype == torch.float8_e4m3fnuz or mmdit.dtype == torch.float8_e5m2 or mmdit.dtype == torch.float8_e5m2fnuz:
|
||||
|
||||
Reference in New Issue
Block a user