diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index b04b86fb..a0202ad4 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -214,6 +214,24 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser): action="store_true", help="apply attention mask (zero embs) to T5-XXL / T5-XXLにアテンションマスク(ゼロ埋め)を適用する", ) + parser.add_argument( + "--clip_l_dropout_rate", + type=float, + default=0.0, + help="Dropout rate for CLIP-L encoder, default is 0.0 / CLIP-Lエンコーダのドロップアウト率、デフォルトは0.0", + ) + parser.add_argument( + "--clip_g_dropout_rate", + type=float, + default=0.0, + help="Dropout rate for CLIP-G encoder, default is 0.0 / CLIP-Gエンコーダのドロップアウト率、デフォルトは0.0", + ) + parser.add_argument( + "--t5_dropout_rate", + type=float, + default=0.0, + help="Dropout rate for T5 encoder, default is 0.0 / T5エンコーダのドロップアウト率、デフォルトは0.0", + ) # copy from Diffusers parser.add_argument( diff --git a/library/strategy_sd3.py b/library/strategy_sd3.py index a27e99e6..d87ad7d1 100644 --- a/library/strategy_sd3.py +++ b/library/strategy_sd3.py @@ -1,5 +1,6 @@ import os import glob +import random from typing import Any, List, Optional, Tuple, Union import torch import numpy as np @@ -48,13 +49,23 @@ class Sd3TokenizeStrategy(TokenizeStrategy): class Sd3TextEncodingStrategy(TextEncodingStrategy): - def __init__(self, apply_lg_attn_mask: Optional[bool] = None, apply_t5_attn_mask: Optional[bool] = None) -> None: + def __init__( + self, + apply_lg_attn_mask: Optional[bool] = None, + apply_t5_attn_mask: Optional[bool] = None, + l_dropout_rate: float = 0.0, + g_dropout_rate: float = 0.0, + t5_dropout_rate: float = 0.0, + ) -> None: """ Args: apply_t5_attn_mask: Default value for apply_t5_attn_mask. """ self.apply_lg_attn_mask = apply_lg_attn_mask self.apply_t5_attn_mask = apply_t5_attn_mask + self.l_dropout_rate = l_dropout_rate + self.g_dropout_rate = g_dropout_rate + self.t5_dropout_rate = t5_dropout_rate def encode_tokens( self, @@ -63,6 +74,7 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy): tokens: List[torch.Tensor], apply_lg_attn_mask: Optional[bool] = False, apply_t5_attn_mask: Optional[bool] = False, + enable_dropout: bool = True, ) -> List[torch.Tensor]: """ returned embeddings are not masked @@ -91,37 +103,92 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy): g_attn_mask = None t5_attn_mask = None + # dropout: if enable_dropout is False, dropout is not applied. dropout means zeroing out embeddings + if l_tokens is None or clip_l is None: assert g_tokens is None, "g_tokens must be None if l_tokens is None" lg_out = None lg_pooled = None else: - with torch.no_grad(): - assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None" + assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None" + + drop_l = enable_dropout and (self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate) + if drop_l: + l_pooled = torch.zeros((l_tokens.shape[0], 768), device=l_tokens.device, dtype=l_tokens.dtype) + l_out = torch.zeros((l_tokens.shape[0], l_tokens.shape[1], 768), device=l_tokens.device, dtype=l_tokens.dtype) + if l_attn_mask is not None: + l_attn_mask = torch.zeros_like(l_attn_mask) + else: l_attn_mask = l_attn_mask.to(clip_l.device) if l_attn_mask is not None else None - g_attn_mask = g_attn_mask.to(clip_g.device) if g_attn_mask is not None else None - prompt_embeds = clip_l(l_tokens.to(clip_l.device), l_attn_mask, output_hidden_states=True) l_pooled = prompt_embeds[0] l_out = prompt_embeds.hidden_states[-2] + drop_g = enable_dropout and (self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate) + if drop_g: + g_pooled = torch.zeros((g_tokens.shape[0], 1280), device=g_tokens.device, dtype=g_tokens.dtype) + g_out = torch.zeros((g_tokens.shape[0], g_tokens.shape[1], 1280), device=g_tokens.device, dtype=g_tokens.dtype) + if g_attn_mask is not None: + g_attn_mask = torch.zeros_like(g_attn_mask) + else: + g_attn_mask = g_attn_mask.to(clip_g.device) if g_attn_mask is not None else None prompt_embeds = clip_g(g_tokens.to(clip_g.device), g_attn_mask, output_hidden_states=True) g_pooled = prompt_embeds[0] g_out = prompt_embeds.hidden_states[-2] - lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) if l_tokens is not None else None - lg_out = torch.cat([l_out, g_out], dim=-1) + lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) if l_tokens is not None else None + lg_out = torch.cat([l_out, g_out], dim=-1) if t5xxl is None or t5_tokens is None: t5_out = None else: - t5_attn_mask = t5_attn_mask.to(t5xxl.device) if t5_attn_mask is not None else None - with torch.no_grad(): + drop_t5 = enable_dropout and (self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate) + if drop_t5: + t5_out = torch.zeros((t5_tokens.shape[0], t5_tokens.shape[1], 4096), device=t5_tokens.device, dtype=t5_tokens.dtype) + if t5_attn_mask is not None: + t5_attn_mask = torch.zeros_like(t5_attn_mask) + else: + t5_attn_mask = t5_attn_mask.to(t5xxl.device) if t5_attn_mask is not None else None t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), t5_attn_mask, return_dict=False, output_hidden_states=True) # masks are used for attention masking in transformer return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask] + def drop_cached_text_encoder_outputs( + self, + lg_out: torch.Tensor, + t5_out: torch.Tensor, + lg_pooled: torch.Tensor, + l_attn_mask: torch.Tensor, + g_attn_mask: torch.Tensor, + t5_attn_mask: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # dropout: if enable_dropout is True, dropout is not applied. dropout means zeroing out embeddings + if lg_out is not None: + for i in range(lg_out.shape[0]): + drop_l = self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate + if drop_l: + lg_out[i, :, :768] = torch.zeros_like(lg_out[i, :, :768]) + lg_pooled[i, :768] = torch.zeros_like(lg_pooled[i, :768]) + if l_attn_mask is not None: + l_attn_mask[i] = torch.zeros_like(l_attn_mask[i]) + drop_g = self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate + if drop_g: + lg_out[i, :, 768:] = torch.zeros_like(lg_out[i, :, 768:]) + lg_pooled[i, 768:] = torch.zeros_like(lg_pooled[i, 768:]) + if g_attn_mask is not None: + g_attn_mask[i] = torch.zeros_like(g_attn_mask[i]) + + if t5_out is not None: + for i in range(t5_out.shape[0]): + drop_t5 = self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate + if drop_t5: + t5_out[i] = torch.zeros_like(t5_out[i]) + if t5_attn_mask is not None: + t5_attn_mask[i] = torch.zeros_like(t5_attn_mask[i]) + + return lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask + def concat_encodings( self, lg_out: torch.Tensor, t5_out: Optional[torch.Tensor], lg_pooled: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -207,8 +274,14 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): tokens_and_masks = tokenize_strategy.tokenize(captions) with torch.no_grad(): + # always disable dropout during caching lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = sd3_text_encoding_strategy.encode_tokens( - tokenize_strategy, models, tokens_and_masks, self.apply_lg_attn_mask, self.apply_t5_attn_mask + tokenize_strategy, + models, + tokens_and_masks, + apply_lg_attn_mask=self.apply_lg_attn_mask, + apply_t5_attn_mask=self.apply_t5_attn_mask, + enable_dropout=False, ) if lg_out.dtype == torch.bfloat16: diff --git a/sd3_train.py b/sd3_train.py index d12f7f56..cdac945e 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -69,6 +69,11 @@ def train(args): # assert ( # not args.train_text_encoder or not args.cache_text_encoder_outputs # ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + logger.warning( + "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" + ) + args.cache_text_encoder_outputs = True assert not args.train_text_encoder or (args.use_t5xxl_cache_only or not args.cache_text_encoder_outputs), ( "when training text encoder, text encoder outputs must not be cached (except for T5XXL)" @@ -232,7 +237,9 @@ def train(args): assert clip_l is not None and clip_g is not None and t5xxl is not None, "clip_l, clip_g, t5xxl must be specified" # prepare text encoding strategy - text_encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy(args.apply_lg_attn_mask, args.apply_t5_attn_mask) + text_encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy( + args.apply_lg_attn_mask, args.apply_t5_attn_mask, args.clip_l_dropout_rate, args.clip_g_dropout_rate, args.t5_dropout_rate + ) strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) # 学習を準備する:モデルを適切な状態にする @@ -311,6 +318,7 @@ def train(args): tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask, + enable_dropout=False, ) accelerator.wait_for_everyone() @@ -863,6 +871,7 @@ def train(args): text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: + text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list) lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = text_encoder_outputs_list if args.use_t5xxl_cache_only: lg_out = None @@ -878,7 +887,7 @@ def train(args): if lg_out is None: # not cached or training, so get from text encoders input_ids_clip_l, input_ids_clip_g, _, l_attn_mask, g_attn_mask, _ = batch["input_ids_list"] - with torch.set_grad_enabled(args.train_text_encoder): + with torch.set_grad_enabled(train_clip): # TODO support weighted captions # text models in sd3_models require "cpu" for input_ids input_ids_clip_l = input_ids_clip_l.to("cpu") @@ -891,7 +900,7 @@ def train(args): if t5_out is None: _, _, input_ids_t5xxl, _, _, t5_attn_mask = batch["input_ids_list"] - with torch.no_grad(): + with torch.set_grad_enabled(train_t5xxl): input_ids_t5xxl = input_ids_t5xxl.to("cpu") if t5_out is None else None _, t5_out, _, _, _, t5_attn_mask = text_encoding_strategy.encode_tokens( sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl, None, None, t5_attn_mask] diff --git a/sd3_train_network.py b/sd3_train_network.py index 129afed5..7b547127 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -120,7 +120,13 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer): return latents_caching_strategy def get_text_encoding_strategy(self, args): - return strategy_sd3.Sd3TextEncodingStrategy(args.apply_lg_attn_mask, args.apply_t5_attn_mask) + return strategy_sd3.Sd3TextEncodingStrategy( + args.apply_lg_attn_mask, + args.apply_t5_attn_mask, + args.clip_l_dropout_rate, + args.clip_g_dropout_rate, + args.t5xxl_dropout_rate, + ) def post_process_network(self, args, accelerator, network, text_encoders, unet): # check t5xxl is trained or not @@ -408,6 +414,14 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer): text_encoder.to(te_weight_dtype) # fp8 prepare_fp8(text_encoder, weight_dtype) + def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): + # drop cached text encoder outputs + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(text_encoder_outputs_list) + batch["text_encoder_outputs_list"] = text_encoder_outputs_list + def setup_parser() -> argparse.ArgumentParser: parser = train_network.setup_parser() diff --git a/train_network.py b/train_network.py index aab1d84b..9d78a4ef 100644 --- a/train_network.py +++ b/train_network.py @@ -272,6 +272,9 @@ class NetworkTrainer: def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): text_encoder.text_model.embeddings.to(dtype=weight_dtype) + def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): + pass + # endregion def train(self, args): @@ -1030,9 +1033,9 @@ class NetworkTrainer: # callback for step start if hasattr(accelerator.unwrap_model(network), "on_step_start"): - on_step_start = accelerator.unwrap_model(network).on_step_start + on_step_start_for_network = accelerator.unwrap_model(network).on_step_start else: - on_step_start = lambda *args, **kwargs: None + on_step_start_for_network = lambda *args, **kwargs: None # function for saving/removing def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False): @@ -1113,7 +1116,10 @@ class NetworkTrainer: continue with accelerator.accumulate(training_model): - on_step_start(text_encoder, unet) + on_step_start_for_network(text_encoder, unet) + + # temporary, for batch processing + self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) @@ -1146,6 +1152,7 @@ class NetworkTrainer: if text_encoder_outputs_list is not None: text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder: + # TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached' with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): # Get the text embedding for conditioning if args.weighted_captions: