make deepspeed_utils

This commit is contained in:
Kohya S
2024-02-27 21:30:46 +09:00
parent 0e4a5738df
commit e3ccf8fbf7
6 changed files with 238 additions and 200 deletions

View File

@@ -10,7 +10,9 @@ import toml
from tqdm import tqdm
import torch
from library import deepspeed_utils
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from accelerate.utils import set_seed
@@ -42,6 +44,7 @@ from library.custom_train_functions import (
def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
deepspeed_utils.prepare_deepspeed_args(args)
setup_logging(args, reset=True)
cache_latents = args.cache_latents
@@ -219,7 +222,7 @@ def train(args):
batch_size=1,
shuffle=True,
collate_fn=collator,
num_workers=n_workers if not args.deepspeed else 1, # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
@@ -231,7 +234,7 @@ def train(args):
accelerator.print(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
@@ -248,21 +251,16 @@ def train(args):
text_encoder.to(weight_dtype)
if args.deepspeed:
training_models_dict = {}
training_models_dict["unet"] = unet
if args.train_text_encoder: training_models_dict["text_encoder"] = text_encoder
ds_model = train_util.prepare_deepspeed_model(args, **training_models_dict)
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler)
training_models = []
unet = ds_model.models["unet"]
training_models.append(unet)
if args.train_text_encoder:
text_encoder = ds_model.models["text_encoder"]
training_models.append(text_encoder)
else: # acceleratorがなんかよろしくやってくれるらしい
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
else:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
training_models = [ds_model]
else:
# acceleratorがなんかよろしくやってくれるらしい
if args.train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
@@ -327,13 +325,13 @@ def train(args):
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
with accelerator.accumulate(*training_models):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device) # .to(dtype=weight_dtype)
else:
# latentに変換
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(weight_dtype)
latents = latents * 0.18215
b_size = latents.shape[0]
@@ -493,6 +491,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)