From e3ccf8fbf73a0f728fc167a20b1e0648a3604f41 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 27 Feb 2024 21:30:46 +0900 Subject: [PATCH] make deepspeed_utils --- fine_tune.py | 35 +++++----- library/deepspeed_utils.py | 139 +++++++++++++++++++++++++++++++++++++ library/train_util.py | 110 ++--------------------------- sdxl_train.py | 66 ++++++++---------- train_db.py | 37 +++++----- train_network.py | 51 +++++++------- 6 files changed, 238 insertions(+), 200 deletions(-) create mode 100644 library/deepspeed_utils.py diff --git a/fine_tune.py b/fine_tune.py index c5e97d26..b018a933 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -10,7 +10,9 @@ import toml from tqdm import tqdm import torch +from library import deepspeed_utils from library.device_utils import init_ipex, clean_memory_on_device + init_ipex() from accelerate.utils import set_seed @@ -42,6 +44,7 @@ from library.custom_train_functions import ( def train(args): train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) + deepspeed_utils.prepare_deepspeed_args(args) setup_logging(args, reset=True) cache_latents = args.cache_latents @@ -219,7 +222,7 @@ def train(args): batch_size=1, shuffle=True, collate_fn=collator, - num_workers=n_workers if not args.deepspeed else 1, # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1. + num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers, ) @@ -231,7 +234,7 @@ def train(args): accelerator.print( f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" ) - + # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) @@ -248,21 +251,16 @@ def train(args): text_encoder.to(weight_dtype) if args.deepspeed: - training_models_dict = {} - training_models_dict["unet"] = unet - if args.train_text_encoder: training_models_dict["text_encoder"] = text_encoder - - ds_model = train_util.prepare_deepspeed_model(args, **training_models_dict) - ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler) - - training_models = [] - unet = ds_model.models["unet"] - training_models.append(unet) if args.train_text_encoder: - text_encoder = ds_model.models["text_encoder"] - training_models.append(text_encoder) - - else: # acceleratorがなんかよろしくやってくれるらしい + ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder) + else: + ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet) + ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + ds_model, optimizer, train_dataloader, lr_scheduler + ) + training_models = [ds_model] + else: + # acceleratorがなんかよろしくやってくれるらしい if args.train_text_encoder: unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( unet, text_encoder, optimizer, train_dataloader, lr_scheduler @@ -327,13 +325,13 @@ def train(args): for step, batch in enumerate(train_dataloader): current_step.value = global_step - with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく + with accelerator.accumulate(*training_models): with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device) # .to(dtype=weight_dtype) else: # latentに変換 - latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() + latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(weight_dtype) latents = latents * 0.18215 b_size = latents.shape[0] @@ -493,6 +491,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) + deepspeed_utils.add_deepspeed_arguments(parser) train_util.add_sd_saving_arguments(parser) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) diff --git a/library/deepspeed_utils.py b/library/deepspeed_utils.py new file mode 100644 index 00000000..99a7b2b3 --- /dev/null +++ b/library/deepspeed_utils.py @@ -0,0 +1,139 @@ +import os +import argparse +import torch +from accelerate import DeepSpeedPlugin, Accelerator + +from .utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +def add_deepspeed_arguments(parser: argparse.ArgumentParser): + # DeepSpeed Arguments. https://huggingface.co/docs/accelerate/usage_guides/deepspeed + parser.add_argument("--deepspeed", action="store_true", help="enable deepspeed training") + parser.add_argument("--zero_stage", type=int, default=2, choices=[0, 1, 2, 3], help="Possible options are 0,1,2,3.") + parser.add_argument( + "--offload_optimizer_device", + type=str, + default=None, + choices=[None, "cpu", "nvme"], + help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3.", + ) + parser.add_argument( + "--offload_optimizer_nvme_path", + type=str, + default=None, + help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.", + ) + parser.add_argument( + "--offload_param_device", + type=str, + default=None, + choices=[None, "cpu", "nvme"], + help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3.", + ) + parser.add_argument( + "--offload_param_nvme_path", + type=str, + default=None, + help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.", + ) + parser.add_argument( + "--zero3_init_flag", + action="store_true", + help="Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models." + "Only applicable with ZeRO Stage-3.", + ) + parser.add_argument( + "--zero3_save_16bit_model", + action="store_true", + help="Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3.", + ) + parser.add_argument( + "--fp16_master_weights_and_gradients", + action="store_true", + help="fp16_master_and_gradients requires optimizer to support keeping fp16 master and gradients while keeping the optimizer states in fp32.", + ) + + +def prepare_deepspeed_args(args: argparse.Namespace): + if not args.deepspeed: + return + + # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1. + args.max_data_loader_n_workers = 1 + + +def prepare_deepspeed_plugin(args: argparse.Namespace): + if not args.deepspeed: + return None + + try: + import deepspeed + except ImportError as e: + logger.error( + "deepspeed is not installed. please install deepspeed in your environment with following command. DS_BUILD_OPS=0 pip install deepspeed" + ) + exit(1) + + deepspeed_plugin = DeepSpeedPlugin( + zero_stage=args.zero_stage, + gradient_accumulation_steps=args.gradient_accumulation_steps, + gradient_clipping=args.max_grad_norm, + offload_optimizer_device=args.offload_optimizer_device, + offload_optimizer_nvme_path=args.offload_optimizer_nvme_path, + offload_param_device=args.offload_param_device, + offload_param_nvme_path=args.offload_param_nvme_path, + zero3_init_flag=args.zero3_init_flag, + zero3_save_16bit_model=args.zero3_save_16bit_model, + ) + deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size + deepspeed_plugin.deepspeed_config["train_batch_size"] = ( + args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"]) + ) + deepspeed_plugin.set_mixed_precision(args.mixed_precision) + if args.mixed_precision.lower() == "fp16": + deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow. + if args.full_fp16 or args.fp16_master_weights_and_gradients: + if args.offload_optimizer_device == "cpu" and args.zero_stage == 2: + deepspeed_plugin.deepspeed_config["fp16"]["fp16_master_weights_and_grads"] = True + logger.info("[DeepSpeed] full fp16 enable.") + else: + logger.info( + "[DeepSpeed]full fp16, fp16_master_weights_and_grads currently only supported using ZeRO-Offload with DeepSpeedCPUAdam on ZeRO-2 stage." + ) + + if args.offload_optimizer_device is not None: + logger.info("[DeepSpeed] start to manually build cpu_adam.") + deepspeed.ops.op_builder.CPUAdamBuilder().load() + logger.info("[DeepSpeed] building cpu_adam done.") + + return deepspeed_plugin + + +# Accelerate library does not support multiple models for deepspeed. So, we need to wrap multiple models into a single model. +def prepare_deepspeed_model(args: argparse.Namespace, **models): + # remove None from models + models = {k: v for k, v in models.items() if v is not None} + + class DeepSpeedWrapper(torch.nn.Module): + def __init__(self, **kw_models) -> None: + super().__init__() + self.models = torch.nn.ModuleDict() + + for key, model in kw_models.items(): + if isinstance(model, list): + model = torch.nn.ModuleList(model) + assert isinstance( + model, torch.nn.Module + ), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}" + self.models.update(torch.nn.ModuleDict({key: model})) + + def get_models(self): + return self.models + + ds_model = DeepSpeedWrapper(**models) + return ds_model diff --git a/library/train_util.py b/library/train_util.py index 3781dcde..38e1b458 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -21,7 +21,6 @@ from typing import ( Union, ) from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs -from accelerate import DeepSpeedPlugin import glob import math import os @@ -70,6 +69,7 @@ from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipel import library.model_util as model_util import library.huggingface_util as huggingface_util import library.sai_model_spec as sai_model_spec +import library.deepspeed_utils as deepspeed_utils from library.utils import setup_logging setup_logging() @@ -3243,52 +3243,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: "--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み" ) - # DeepSpeed Arguments. https://huggingface.co/docs/accelerate/usage_guides/deepspeed - parser.add_argument("--deepspeed", action="store_true", help="enable deepspeed training") - parser.add_argument( - "--zero_stage", - type=int, default=2, - choices=[0, 1, 2, 3], - help="Possible options are 0,1,2,3." - ) - parser.add_argument( - "--offload_optimizer_device", - type=str, default=None, - choices=[None, "cpu", "nvme"], - help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3." - ) - parser.add_argument( - "--offload_optimizer_nvme_path", - type=str, default=None, - help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3." - ) - parser.add_argument( - "--offload_param_device", - type=str, default=None, - choices=[None, "cpu", "nvme"], - help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3." - ) - parser.add_argument( - "--offload_param_nvme_path", - type=str, default=None, - help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3." - ) - parser.add_argument( - "--zero3_init_flag", - action="store_true", - help="Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models." - "Only applicable with ZeRO Stage-3." - ) - parser.add_argument( - "--zero3_save_16bit_model", - action="store_true", - help="Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3." - ) - parser.add_argument( - "--fp16_master_weights_and_gradients", - action="store_true", - help="fp16_master_and_gradients requires optimizer to support keeping fp16 master and gradients while keeping the optimizer states in fp32." - ) def verify_training_args(args: argparse.Namespace): r""" @@ -4090,6 +4044,10 @@ def load_tokenizer(args: argparse.Namespace): def prepare_accelerator(args: argparse.Namespace): + """ + this function also prepares deepspeed plugin + """ + if args.logging_dir is None: logging_dir = None else: @@ -4135,7 +4093,7 @@ def prepare_accelerator(args: argparse.Namespace): ), ) kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers)) - deepspeed_plugin = prepare_deepspeed_plugin(args) + deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, @@ -4149,62 +4107,6 @@ def prepare_accelerator(args: argparse.Namespace): print("accelerator device:", accelerator.device) return accelerator -def prepare_deepspeed_plugin(args: argparse.Namespace): - if args.deepspeed is None: return None - try: - import deepspeed - except ImportError as e: - print("deepspeed is not installed. please install deepspeed in your environment with following command. DS_BUILD_OPS=0 pip install deepspeed") - exit(1) - - deepspeed_plugin = DeepSpeedPlugin( - zero_stage=args.zero_stage, - gradient_accumulation_steps=args.gradient_accumulation_steps, gradient_clipping=args.max_grad_norm, - offload_optimizer_device=args.offload_optimizer_device, offload_optimizer_nvme_path=args.offload_optimizer_nvme_path, - offload_param_device=args.offload_param_device, offload_param_nvme_path=args.offload_param_nvme_path, - zero3_init_flag=args.zero3_init_flag, zero3_save_16bit_model=args.zero3_save_16bit_model, - ) - deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = args.train_batch_size - deepspeed_plugin.deepspeed_config['train_batch_size'] = \ - args.train_batch_size * args.gradient_accumulation_steps * int(os.environ['WORLD_SIZE']) - deepspeed_plugin.set_mixed_precision(args.mixed_precision) - if args.mixed_precision.lower() == "fp16": - deepspeed_plugin.deepspeed_config['fp16']['initial_scale_power'] = 0 # preventing overflow. - if args.full_fp16 or args.fp16_master_weights_and_gradients: - if args.offload_optimizer_device == "cpu" and args.zero_stage == 2: - deepspeed_plugin.deepspeed_config['fp16']['fp16_master_weights_and_grads'] = True - print("[DeepSpeed] full fp16 enable.") - else: - print("[DeepSpeed]full fp16, fp16_master_weights_and_grads currently only supported using ZeRO-Offload with DeepSpeedCPUAdam on ZeRO-2 stage.") - - if args.offload_optimizer_device is not None: - print('[DeepSpeed] start to manually build cpu_adam.') - deepspeed.ops.op_builder.CPUAdamBuilder().load() - print('[DeepSpeed] building cpu_adam done.') - - return deepspeed_plugin - -def prepare_deepspeed_model(args: argparse.Namespace, **models): - class DeepSpeedWrapper(torch.nn.Module): - def __init__(self, **kw_models) -> None: - super().__init__() - self.models = torch.nn.ModuleDict() - - for key, model in kw_models.items(): - if isinstance(model, list): - model = torch.nn.ModuleList(model) - assert isinstance(model, torch.nn.Module), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}" - self.models.update( - torch.nn.ModuleDict( - {key: model} - ) - ) - - def get_models(self): - return self.models - - ds_model = DeepSpeedWrapper(**models) - return ds_model def prepare_dtype(args: argparse.Namespace): weight_dtype = torch.float32 diff --git a/sdxl_train.py b/sdxl_train.py index 5e5e9f29..0feb4e36 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -11,11 +11,12 @@ from tqdm import tqdm import torch from library.device_utils import init_ipex, clean_memory_on_device + init_ipex() from accelerate.utils import set_seed from diffusers import DDPMScheduler -from library import sdxl_model_util +from library import deepspeed_utils, sdxl_model_util import library.train_util as train_util @@ -97,6 +98,7 @@ def train(args): train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) sdxl_train_util.verify_sdxl_training_args(args) + deepspeed_utils.prepare_deepspeed_args(args) setup_logging(args, reset=True) assert ( @@ -361,7 +363,7 @@ def train(args): batch_size=1, shuffle=True, collate_fn=collator, - num_workers=n_workers if not args.deepspeed else 1, # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1. + num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers, ) @@ -398,41 +400,31 @@ def train(args): text_encoder1.to(weight_dtype) text_encoder2.to(weight_dtype) - if args.deepspeed: - training_models_dict = {} - if train_unet: - training_models_dict["unet"] = unet - if train_text_encoder1: - text_encoder1.text_model.encoder.layers[-1].requires_grad_(False) - text_encoder1.text_model.final_layer_norm.requires_grad_(False) - training_models_dict["text_encoder1"] = text_encoder1 - if train_text_encoder2: - training_models_dict["text_encoder2"] = text_encoder2 - ds_model = train_util.prepare_deepspeed_model(args, **training_models_dict) - ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler) - - training_models = [] # override training_models - if train_unet: - unet = ds_model.models["unet"] - training_models.append(unet) - if train_text_encoder1: - text_encoder1 = ds_model.models["text_encoder1"] - training_models.append(text_encoder1) - if train_text_encoder2: - text_encoder2 = ds_model.models["text_encoder2"] - training_models.append(text_encoder2) + # freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer + if train_text_encoder1: + text_encoder1.text_model.encoder.layers[-1].requires_grad_(False) + text_encoder1.text_model.final_layer_norm.requires_grad_(False) - else: # acceleratorがなんかよろしくやってくれるらしい + if args.deepspeed: + ds_model = deepspeed_utils.prepare_deepspeed_model( + args, + unet=unet if train_unet else None, + text_encoder1=text_encoder1 if train_text_encoder1 else None, + text_encoder2=text_encoder2 if train_text_encoder2 else None, + ) + ds_model = accelerator.prepare(ds_model) + training_models = [ds_model] + + else: + # acceleratorがなんかよろしくやってくれるらしい if train_unet: unet = accelerator.prepare(unet) if train_text_encoder1: - # freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer - text_encoder1.text_model.encoder.layers[-1].requires_grad_(False) - text_encoder1.text_model.final_layer_norm.requires_grad_(False) text_encoder1 = accelerator.prepare(text_encoder1) if train_text_encoder2: text_encoder2 = accelerator.prepare(text_encoder2) - optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) + + optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) # TextEncoderの出力をキャッシュするときにはCPUへ移動する if args.cache_text_encoder_outputs: @@ -446,8 +438,9 @@ def train(args): text_encoder2.to(accelerator.device) # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする - if args.full_fp16 and not args.deepspeed: + if args.full_fp16: # During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do. + # -> But we think it's ok to patch accelerator even if deepspeed is enabled. train_util.patch_accelerator_for_fp16_training(accelerator) # resumeする @@ -508,10 +501,10 @@ def train(args): for step, batch in enumerate(train_dataloader): current_step.value = global_step with accelerator.accumulate(*training_models): - with torch.no_grad(): # why this block differ within train_network.py? - if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) - else: + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) + else: + with torch.no_grad(): # latentに変換 latents = vae.encode(batch["images"].to(vae_dtype)).latent_dist.sample().to(weight_dtype) @@ -519,7 +512,7 @@ def train(args): if torch.any(torch.isnan(latents)): accelerator.print("NaN found in latents, replacing with zeros") latents = torch.nan_to_num(latents, 0, out=latents) - latents = latents * sdxl_model_util.VAE_SCALE_FACTOR + latents = latents * sdxl_model_util.VAE_SCALE_FACTOR if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: input_ids1 = batch["input_ids"] @@ -768,6 +761,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, True, True, True) train_util.add_training_arguments(parser, False) + deepspeed_utils.add_deepspeed_arguments(parser) train_util.add_sd_saving_arguments(parser) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) diff --git a/train_db.py b/train_db.py index 66a83d1d..ea1cfeb8 100644 --- a/train_db.py +++ b/train_db.py @@ -11,7 +11,9 @@ import toml from tqdm import tqdm import torch +from library import deepspeed_utils from library.device_utils import init_ipex, clean_memory_on_device + init_ipex() from accelerate.utils import set_seed @@ -46,6 +48,7 @@ logger = logging.getLogger(__name__) def train(args): train_util.verify_training_args(args) train_util.prepare_dataset_args(args, False) + deepspeed_utils.prepare_deepspeed_args(args) setup_logging(args, reset=True) cache_latents = args.cache_latents @@ -187,7 +190,7 @@ def train(args): batch_size=1, shuffle=True, collate_fn=collator, - num_workers=n_workers if not args.deepspeed else 1, # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1. + num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers, ) @@ -220,30 +223,27 @@ def train(args): # acceleratorがなんかよろしくやってくれるらしい if args.deepspeed: - training_models_dict = {} - training_models_dict["unet"] = unet - if train_text_encoder: training_models_dict["text_encoder"] = text_encoder + if args.train_text_encoder: + ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder) + else: + ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet) + ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + ds_model, optimizer, train_dataloader, lr_scheduler + ) + training_models = [ds_model] - ds_model = train_util.prepare_deepspeed_model(args, **training_models_dict) - ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler) - - training_models = [] - unet = ds_model.models["unet"] - training_models.append(unet) - if train_text_encoder: - text_encoder = ds_model.models["text_encoder"] - training_models.append(text_encoder) - else: if train_text_encoder: unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( unet, text_encoder, optimizer, train_dataloader, lr_scheduler ) + training_models = [unet, text_encoder] else: unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) + training_models = [unet] - if not train_text_encoder: - text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error + if not train_text_encoder: + text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする if args.full_fp16: @@ -312,8 +312,10 @@ def train(args): if not args.gradient_checkpointing: text_encoder.train(False) text_encoder.requires_grad_(False) + if len(training_models) == 2: + training_models = training_models[0] # remove text_encoder from training_models - with accelerator.accumulate(unet): + with accelerator.accumulate(*training_models): with torch.no_grad(): # latentに変換 if cache_latents: @@ -480,6 +482,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, True, False, True) train_util.add_training_arguments(parser, True) + deepspeed_utils.add_deepspeed_arguments(parser) train_util.add_sd_saving_arguments(parser) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) diff --git a/train_network.py b/train_network.py index af1b7f63..a6ce169a 100644 --- a/train_network.py +++ b/train_network.py @@ -13,13 +13,14 @@ from tqdm import tqdm import torch from library.device_utils import init_ipex, clean_memory_on_device + init_ipex() from torch.nn.parallel import DistributedDataParallel as DDP from accelerate.utils import set_seed from diffusers import DDPMScheduler -from library import model_util +from library import deepspeed_utils, model_util import library.train_util as train_util from library.train_util import ( @@ -141,6 +142,7 @@ class NetworkTrainer: training_started_at = time.time() train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) + deepspeed_utils.prepare_deepspeed_args(args) setup_logging(args, reset=True) cache_latents = args.cache_latents @@ -357,7 +359,7 @@ class NetworkTrainer: batch_size=1, shuffle=True, collate_fn=collator, - num_workers=n_workers if not args.deepspeed else 1, # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1. + num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers, ) @@ -414,22 +416,17 @@ class NetworkTrainer: # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good if args.deepspeed: - training_models_dict = {} - if train_unet: training_models_dict["unet"] = unet - if train_text_encoder: training_models_dict["text_encoder"] = text_encoders - training_models_dict["network"] = network - - ds_model = train_util.prepare_deepspeed_model(args, **training_models_dict) - ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler) - - if train_unet: unet = ds_model.models["unet"] - if train_text_encoder: - text_encoder = ds_model.models["text_encoder"] - if len(ds_model.models["text_encoder"]) > 1: - text_encoders = text_encoder - else: - text_encoders = [text_encoder] - + ds_model = deepspeed_utils.prepare_deepspeed_model( + args, + unet=unet if train_unet else None, + text_encoder1=text_encoders[0] if train_text_encoder else None, + text_encoder2=text_encoders[1] if train_text_encoder and len(text_encoders) > 1 else None, + network=network, + ) + ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + ds_model, optimizer, train_dataloader, lr_scheduler + ) + training_model = ds_model else: if train_unet: unet = accelerator.prepare(unet) @@ -444,7 +441,10 @@ class NetworkTrainer: else: pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set - network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler) + network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + network, optimizer, train_dataloader, lr_scheduler + ) + training_model = network if args.gradient_checkpointing: # according to TI example in Diffusers, train is required @@ -777,13 +777,13 @@ class NetworkTrainer: for step, batch in enumerate(train_dataloader): current_step.value = global_step - with accelerator.accumulate(network): + with accelerator.accumulate(training_model): on_step_start(text_encoder, unet) - with torch.no_grad(): - if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device) - else: + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device) + else: + with torch.no_grad(): # latentに変換 latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample() @@ -791,7 +791,7 @@ class NetworkTrainer: if torch.any(torch.isnan(latents)): accelerator.print("NaN found in latents, replacing with zeros") latents = torch.nan_to_num(latents, 0, out=latents) - latents = latents * self.vae_scale_factor + latents = latents * self.vae_scale_factor # get multiplier for each sample if network_has_multiplier: @@ -976,6 +976,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, True, True, True) train_util.add_training_arguments(parser, True) + deepspeed_utils.add_deepspeed_arguments(parser) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) custom_train_functions.add_custom_train_arguments(parser)