diff --git a/fine_tune.py b/fine_tune.py
index e2f2c8a9..a0350ce1 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -328,7 +328,7 @@ def train(args):
             with accelerator.accumulate(*training_models):
                 with torch.no_grad():
                     if "latents" in batch and batch["latents"] is not None:
-                        latents = batch["latents"].to(accelerator.device)  # .to(dtype=weight_dtype)
+                        latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
                     else:
                         # latentに変換
                         latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(weight_dtype)
@@ -507,6 +507,11 @@ def setup_parser() -> argparse.ArgumentParser:
         default=None,
         help="learning rate for text encoder, default is same as unet / Text Encoderの学習率、デフォルトはunetと同じ",
     )
+    parser.add_argument(
+        "--no_half_vae",
+        action="store_true",
+        help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
+    )
 
     return parser
 
diff --git a/library/train_util.py b/library/train_util.py
index e10e38da..6ca5d559 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -20,7 +20,7 @@ from typing import (
     Tuple,
     Union,
 )
-from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
+from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState
 import glob
 import math
 import os
diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py
index e99b4e35..9eaaa19f 100644
--- a/sdxl_train_control_net_lllite.py
+++ b/sdxl_train_control_net_lllite.py
@@ -22,7 +22,7 @@ from accelerate.utils import set_seed
 import accelerate
 from diffusers import DDPMScheduler, ControlNetModel
 from safetensors.torch import load_file
-from library import sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util
+from library import deepspeed_utils, sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util
 
 import library.model_util as model_util
 import library.train_util as train_util
@@ -394,10 +394,10 @@ def train(args):
             with accelerator.accumulate(unet):
                 with torch.no_grad():
                     if "latents" in batch and batch["latents"] is not None:
-                        latents = batch["latents"].to(accelerator.device)
+                        latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
                     else:
                         # latentに変換
-                        latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample()
+                        latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
 
                     # NaNが含まれていれば警告を表示し0に置き換える
                     if torch.any(torch.isnan(latents)):
@@ -566,6 +566,7 @@ def setup_parser() -> argparse.ArgumentParser:
     train_util.add_sd_models_arguments(parser)
     train_util.add_dataset_arguments(parser, False, True, True)
     train_util.add_training_arguments(parser, False)
+    deepspeed_utils.add_deepspeed_arguments(parser)
     train_util.add_optimizer_arguments(parser)
     config_util.add_config_arguments(parser)
     custom_train_functions.add_custom_train_arguments(parser)
diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py
index dac56eed..e55a5889 100644
--- a/sdxl_train_control_net_lllite_old.py
+++ b/sdxl_train_control_net_lllite_old.py
@@ -18,7 +18,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from accelerate.utils import set_seed
 from diffusers import DDPMScheduler, ControlNetModel
 from safetensors.torch import load_file
-from library import sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util
+from library import deepspeed_utils, sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util
 
 import library.model_util as model_util
 import library.train_util as train_util
@@ -361,10 +361,10 @@ def train(args):
             with accelerator.accumulate(network):
                 with torch.no_grad():
                     if "latents" in batch and batch["latents"] is not None:
-                        latents = batch["latents"].to(accelerator.device)
+                        latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
                     else:
                         # latentに変換
-                        latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample()
+                        latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
 
                     # NaNが含まれていれば警告を表示し0に置き換える
                     if torch.any(torch.isnan(latents)):
@@ -534,6 +534,7 @@ def setup_parser() -> argparse.ArgumentParser:
     train_util.add_sd_models_arguments(parser)
     train_util.add_dataset_arguments(parser, False, True, True)
     train_util.add_training_arguments(parser, False)
+    deepspeed_utils.add_deepspeed_arguments(parser)
     train_util.add_optimizer_arguments(parser)
     config_util.add_config_arguments(parser)
     custom_train_functions.add_custom_train_arguments(parser)
diff --git a/train_controlnet.py b/train_controlnet.py
index e44f0885..0cb0405f 100644
--- a/train_controlnet.py
+++ b/train_controlnet.py
@@ -11,6 +11,7 @@ import toml
 from tqdm import tqdm
 
 import torch
+from library import deepspeed_utils
 from library.device_utils import init_ipex, clean_memory_on_device
 
 init_ipex()
@@ -396,7 +397,7 @@ def train(args):
             with accelerator.accumulate(controlnet):
                 with torch.no_grad():
                     if "latents" in batch and batch["latents"] is not None:
-                        latents = batch["latents"].to(accelerator.device)
+                        latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
                     else:
                         # latentに変換
                         latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
@@ -584,6 +585,7 @@ def setup_parser() -> argparse.ArgumentParser:
     train_util.add_sd_models_arguments(parser)
     train_util.add_dataset_arguments(parser, False, True, True)
     train_util.add_training_arguments(parser, False)
+    deepspeed_utils.add_deepspeed_arguments(parser)
     train_util.add_optimizer_arguments(parser)
     config_util.add_config_arguments(parser)
     custom_train_functions.add_custom_train_arguments(parser)
diff --git a/train_db.py b/train_db.py
index e27bc7ab..bc7ef4ed 100644
--- a/train_db.py
+++ b/train_db.py
@@ -319,7 +319,7 @@ def train(args):
                 with torch.no_grad():
                     # latentに変換
                     if cache_latents:
-                        latents = batch["latents"].to(accelerator.device)
+                        latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
                     else:
                         latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
                     latents = latents * 0.18215
diff --git a/train_network.py b/train_network.py
index a74a3caf..f2451f41 100644
--- a/train_network.py
+++ b/train_network.py
@@ -471,8 +471,7 @@ class NetworkTrainer:
             vae.to(accelerator.device, dtype=vae_dtype)
 
         # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
-        if args.full_fp16 and not args.deepspeed:
-            # During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do.
+        if args.full_fp16:
             train_util.patch_accelerator_for_fp16_training(accelerator)
 
         # resumeする
@@ -781,11 +780,11 @@ class NetworkTrainer:
                     on_step_start(text_encoder, unet)
 
                     if "latents" in batch and batch["latents"] is not None:
-                        latents = batch["latents"].to(accelerator.device)
+                        latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
                     else:
                         with torch.no_grad():
                             # latentに変換
-                            latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample()
+                            latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
 
                     # NaNが含まれていれば警告を表示し0に置き換える
                     if torch.any(torch.isnan(latents)):
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index 0266bc14..72e26a45 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -8,12 +8,13 @@ from tqdm import tqdm
 
 import torch
 from library.device_utils import init_ipex, clean_memory_on_device
+
 init_ipex()
 
 from accelerate.utils import set_seed
 from diffusers import DDPMScheduler
 from transformers import CLIPTokenizer
-from library import model_util
+from library import deepspeed_utils, model_util
 
 import library.train_util as train_util
 import library.huggingface_util as huggingface_util
@@ -558,10 +559,10 @@ class TextualInversionTrainer:
                 with accelerator.accumulate(text_encoders[0]):
                     with torch.no_grad():
                         if "latents" in batch and batch["latents"] is not None:
-                            latents = batch["latents"].to(accelerator.device)
+                            latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
                         else:
                             # latentに変換
-                            latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample()
+                            latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
                     latents = latents * self.vae_scale_factor
 
                     # Get the text embedding for conditioning
@@ -749,6 +750,7 @@ def setup_parser() -> argparse.ArgumentParser:
     train_util.add_sd_models_arguments(parser)
     train_util.add_dataset_arguments(parser, True, True, False)
     train_util.add_training_arguments(parser, True)
+    deepspeed_utils.add_deepspeed_arguments(parser)
     train_util.add_optimizer_arguments(parser)
     config_util.add_config_arguments(parser)
     custom_train_functions.add_custom_train_arguments(parser, False)
diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py
index ad7c267e..00b251ba 100644
--- a/train_textual_inversion_XTI.py
+++ b/train_textual_inversion_XTI.py
@@ -8,6 +8,7 @@ from multiprocessing import Value
 from tqdm import tqdm
 
 import torch
+from library import deepspeed_utils
 from library.device_utils import init_ipex, clean_memory_on_device
 
 init_ipex()
@@ -439,7 +440,7 @@ def train(args):
             with accelerator.accumulate(text_encoder):
                 with torch.no_grad():
                     if "latents" in batch and batch["latents"] is not None:
-                        latents = batch["latents"].to(accelerator.device)
+                        latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
                     else:
                         # latentに変換
                         latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
@@ -662,6 +663,7 @@ def setup_parser() -> argparse.ArgumentParser:
     train_util.add_sd_models_arguments(parser)
     train_util.add_dataset_arguments(parser, True, True, False)
     train_util.add_training_arguments(parser, True)
+    deepspeed_utils.add_deepspeed_arguments(parser)
     train_util.add_optimizer_arguments(parser)
     config_util.add_config_arguments(parser)
     custom_train_functions.add_custom_train_arguments(parser, False)