Mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-06 13:47:06 +00:00
make deepspeed_utils
fine_tune.py · 35 changed lines (17 additions, 18 deletions)
@@ -10,7 +10,9 @@ import toml
 from tqdm import tqdm
 
 import torch
+from library import deepspeed_utils
 from library.device_utils import init_ipex, clean_memory_on_device
+
 init_ipex()
 
 from accelerate.utils import set_seed
@@ -42,6 +44,7 @@ from library.custom_train_functions import (
 def train(args):
     train_util.verify_training_args(args)
     train_util.prepare_dataset_args(args, True)
+    deepspeed_utils.prepare_deepspeed_args(args)
     setup_logging(args, reset=True)
 
     cache_latents = args.cache_latents
@@ -219,7 +222,7 @@ def train(args):
         batch_size=1,
         shuffle=True,
         collate_fn=collator,
-        num_workers=n_workers if not args.deepspeed else 1,  # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
+        num_workers=n_workers,
         persistent_workers=args.persistent_data_loader_workers,
     )
 
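The per-script DataLoader workaround (forcing a single worker under DeepSpeed) is dropped here; presumably deepspeed_utils.prepare_deepspeed_args, called earlier in train(), now normalizes the affected arguments in one place. A minimal sketch of what such a helper might do, reusing the existing max_data_loader_n_workers argument; this is an illustration under that assumption, not the library's actual implementation:

# Hypothetical sketch, not the actual library/deepspeed_utils.py source.
def prepare_deepspeed_args(args):
    if not args.deepspeed:
        return
    # Multi-worker DataLoaders crashed under DeepSpeed in this codebase
    # ("DataLoader worker exited unexpectedly"), so a helper like this
    # could cap the worker count once instead of patching every script.
    args.max_data_loader_n_workers = min(args.max_data_loader_n_workers, 1)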
@@ -231,7 +234,7 @@ def train(args):
         accelerator.print(
             f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
         )
-
+
     # also send the number of training steps to the dataset side
     train_dataset_group.set_max_train_steps(args.max_train_steps)
 
@@ -248,21 +251,16 @@ def train(args):
         text_encoder.to(weight_dtype)
 
     if args.deepspeed:
-        training_models_dict = {}
-        training_models_dict["unet"] = unet
-        if args.train_text_encoder: training_models_dict["text_encoder"] = text_encoder
-
-        ds_model = train_util.prepare_deepspeed_model(args, **training_models_dict)
-        ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler)
-
-        training_models = []
-        unet = ds_model.models["unet"]
-        training_models.append(unet)
         if args.train_text_encoder:
-            text_encoder = ds_model.models["text_encoder"]
-            training_models.append(text_encoder)
-
-    else:  # accelerator seems to take care of things somehow
+            ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
+        else:
+            ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
+        ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            ds_model, optimizer, train_dataloader, lr_scheduler
+        )
+        training_models = [ds_model]
+    else:
+        # accelerator seems to take care of things somehow
         if args.train_text_encoder:
             unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
                 unet, text_encoder, optimizer, train_dataloader, lr_scheduler
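The deleted branch built a dict of sub-models and read back ds_model.models["unet"], which implies prepare_deepspeed_model wraps everything in one container module so accelerator.prepare hands DeepSpeed a single model. A rough sketch of that wrapper pattern, inferred from the removed code rather than taken from library/deepspeed_utils.py:

import torch

# Sketch of a multi-model wrapper like the one prepare_deepspeed_model
# presumably returns (inferred from ds_model.models["unet"] above).
class DeepSpeedWrapper(torch.nn.Module):
    def __init__(self, **models: torch.nn.Module) -> None:
        super().__init__()
        # ModuleDict registers every sub-model, so DeepSpeed sees a single
        # module that owns all trainable parameters.
        self.models = torch.nn.ModuleDict(models)


def prepare_deepspeed_model(args, **models: torch.nn.Module) -> DeepSpeedWrapper:
    return DeepSpeedWrapper(**models)

With this shape, one accelerator.prepare call initializes a single DeepSpeed engine covering the UNet and, optionally, the text encoder, which is why the new code keeps training_models = [ds_model] instead of unpacking the sub-models again.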
@@ -327,13 +325,13 @@ def train(args):
 
         for step, batch in enumerate(train_dataloader):
             current_step.value = global_step
-            with accelerator.accumulate(training_models[0]):  # apparently doesn't support multiple models, but leave it this way for now
+            with accelerator.accumulate(*training_models):
                 with torch.no_grad():
                     if "latents" in batch and batch["latents"] is not None:
                         latents = batch["latents"].to(accelerator.device)  # .to(dtype=weight_dtype)
                     else:
                         # convert to latents
-                        latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
+                        latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(weight_dtype)
                     latents = latents * 0.18215
                 b_size = latents.shape[0]
 
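accelerator.accumulate(training_models[0]) only controlled gradient sync for the first model; Accelerate's accumulate accepts several models, which the new *training_models call relies on so the UNet and text encoder skip synchronization on the same accumulation steps. A minimal self-contained sketch of that pattern, using dummy stand-in models and data rather than the trainer's real ones:

import torch
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=4)
model_a = torch.nn.Linear(8, 8)
model_b = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(list(model_a.parameters()) + list(model_b.parameters()))
model_a, model_b, optimizer = accelerator.prepare(model_a, model_b, optimizer)

training_models = [model_a, model_b]
for step in range(8):
    batch = torch.randn(2, 8, device=accelerator.device)
    # accumulate(*models) delays gradient sync for every listed model until
    # an accumulation boundary, unlike accumulate(models[0]).
    with accelerator.accumulate(*training_models):
        loss = (model_b(model_a(batch)) ** 2).mean()
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()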
@@ -493,6 +491,7 @@ def setup_parser() -> argparse.ArgumentParser:
     train_util.add_sd_models_arguments(parser)
     train_util.add_dataset_arguments(parser, False, True, True)
     train_util.add_training_arguments(parser, False)
+    deepspeed_utils.add_deepspeed_arguments(parser)
     train_util.add_sd_saving_arguments(parser)
     train_util.add_optimizer_arguments(parser)
     config_util.add_config_arguments(parser)
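setup_parser now delegates the DeepSpeed flags to deepspeed_utils.add_deepspeed_arguments. Only --deepspeed itself is visible in this diff; a sketch of the likely argparse shape, with any further flag marked as a guess:

import argparse

# Hypothetical sketch of deepspeed_utils.add_deepspeed_arguments.
# --deepspeed is confirmed by args.deepspeed in the diff; --zero_stage is an
# assumed example of the kind of option such a helper would register.
def add_deepspeed_arguments(parser: argparse.ArgumentParser) -> None:
    parser.add_argument("--deepspeed", action="store_true", help="enable DeepSpeed training")
    parser.add_argument("--zero_stage", type=int, default=2, choices=[0, 1, 2, 3], help="ZeRO optimization stage")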