Fix VAE handling in all trainers

This commit is contained in:
BootsofLagrangian
2024-02-05 20:19:56 +09:00
parent 2824312d5e
commit 4295f91dcd
3 changed files with 37 additions and 36 deletions

View File

@@ -221,10 +221,18 @@ def train(args):
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
if args.deepspeed:
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / args.gradient_accumulation_steps
)
accelerator.print(
f"[DeepSpeed] override steps not dividing by {accelerator.num_processes}. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)
else:
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
@@ -244,21 +252,16 @@ def train(args):
if args.deepspeed:
# wrapping model
class DeepSpeedModel(torch.nn.Module):
def __init__(self, unet, text_encoder, vae) -> None:
def __init__(self, unet, text_encoder) -> None:
super().__init__()
self.unet = unet
self.text_encoders = self.text_encoder = torch.nn.ModuleList(text_encoder)
self.vae = vae
def get_models(self):
return self.unet, self.text_encoders, self.vae
unet.to(accelerator.device, dtype=weight_dtype)
[t_enc.to(accelerator.device, dtype=weight_dtype) for t_enc in text_encoders]
ds_model = DeepSpeedModel(unet, text_encoders, vae)
return self.unet, self.text_encoders
ds_model = DeepSpeedModel(unet, text_encoders)
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler)
# Now, ds_model is an instance of DeepSpeedEngine.
unet, text_encoders, vae = ds_model.get_models() # for compatibility
vae.to(vae_dtype)
unet, text_encoders = ds_model.get_models() # for compatibility
text_encoder = text_encoders
else: # acceleratorがなんかよろしくやってくれるらしい