Mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-06 13:47:06 +00:00
fix all trainer about vae
fine_tune.py (29 changed lines)
@@ -221,10 +221,18 @@ def train(args):
 
     # 学習ステップ数を計算する
     if args.max_train_epochs is not None:
-        args.max_train_steps = args.max_train_epochs * math.ceil(
-            len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
-        )
-        accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
+        if args.deepspeed:
+            args.max_train_steps = args.max_train_epochs * math.ceil(
+                len(train_dataloader) / args.gradient_accumulation_steps
+            )
+            accelerator.print(
+                f"[DeepSpeed] override steps not dividing by {accelerator.num_processes}. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
+            )
+        else:
+            args.max_train_steps = args.max_train_epochs * math.ceil(
+                len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
+            )
+            accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
     # データセット側にも学習ステップを送信
     train_dataset_group.set_max_train_steps(args.max_train_steps)
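The DeepSpeed branch stops dividing the dataloader length by `accelerator.num_processes`, as the new log message states. A minimal sketch of the arithmetic with a worked example (hypothetical helper name; the per-process rationale in the comment is a reading of the log message, not something the commit states):

import math

def compute_max_train_steps(num_batches, epochs, num_processes, grad_accum_steps, deepspeed):
    # Mirrors the override logic above. Under DeepSpeed, len(train_dataloader)
    # is treated as already per-process, so it is not divided by num_processes
    # again (a reading of the "[DeepSpeed] override steps not dividing by ..." message).
    if deepspeed:
        return epochs * math.ceil(num_batches / grad_accum_steps)
    return epochs * math.ceil(num_batches / num_processes / grad_accum_steps)

# 1000 batches, 2 processes, gradient accumulation 4, 10 epochs:
print(compute_max_train_steps(1000, 10, 2, 4, deepspeed=True))   # 2500
print(compute_max_train_steps(1000, 10, 2, 4, deepspeed=False))  # 1250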
@@ -244,21 +252,16 @@ def train(args):
 
     if args.deepspeed:
         # wrapping model
         class DeepSpeedModel(torch.nn.Module):
-            def __init__(self, unet, text_encoder, vae) -> None:
+            def __init__(self, unet, text_encoder) -> None:
                 super().__init__()
                 self.unet = unet
                 self.text_encoders = self.text_encoder = torch.nn.ModuleList(text_encoder)
-                self.vae = vae
 
             def get_models(self):
-                return self.unet, self.text_encoders, self.vae
+                return self.unet, self.text_encoders
 
         unet.to(accelerator.device, dtype=weight_dtype)
         [t_enc.to(accelerator.device, dtype=weight_dtype) for t_enc in text_encoders]
-        ds_model = DeepSpeedModel(unet, text_encoders, vae)
+        ds_model = DeepSpeedModel(unet, text_encoders)
         ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler)
         # Now, ds_model is an instance of DeepSpeedEngine.
-        unet, text_encoders, vae = ds_model.get_models()  # for compatiblility
-        vae.to(vae_dtype)
+        unet, text_encoders = ds_model.get_models()  # for compatiblility
         text_encoder = text_encoders
 
     else:  # acceleratorがなんかよろしくやってくれるらしい
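The substantive fix: the frozen VAE no longer rides inside the DeepSpeed wrapper, so it never passes through `accelerator.prepare` and keeps its own device/dtype handling (note the removed `vae.to(vae_dtype)` call). A minimal sketch of the pattern, assuming a diffusers-style `AutoencoderKL` interface (`encode(...).latent_dist.sample()`); the helper name and defaults are illustrative:

import torch

def encode_latents(vae, images, device, vae_dtype=torch.float32):
    # The VAE is frozen and only encodes images to latents, so it is moved to
    # the device manually instead of being managed by the DeepSpeed engine.
    vae.requires_grad_(False)
    vae.eval()
    vae.to(device, dtype=vae_dtype)  # e.g. float32 to avoid half-precision artifacts
    with torch.no_grad():
        latents = vae.encode(images.to(device, dtype=vae_dtype)).latent_dist.sample()
    return latents * 0.18215  # latent scaling factor for SD v1/v2 VAEs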
train_db.py (29 changed lines)
@@ -190,10 +190,18 @@ def train(args):
 
     # 学習ステップ数を計算する
     if args.max_train_epochs is not None:
-        args.max_train_steps = args.max_train_epochs * math.ceil(
-            len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
-        )
-        accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
+        if args.deepspeed:
+            args.max_train_steps = args.max_train_epochs * math.ceil(
+                len(train_dataloader) / args.gradient_accumulation_steps
+            )
+            accelerator.print(
+                f"[DeepSpeed] override steps not dividing by {accelerator.num_processes}. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
+            )
+        else:
+            args.max_train_steps = args.max_train_epochs * math.ceil(
+                len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
+            )
+            accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
     # データセット側にも学習ステップを送信
     train_dataset_group.set_max_train_steps(args.max_train_steps)
@@ -217,22 +225,17 @@ def train(args):
 
     if args.deepspeed:
         # wrapping model
         class DeepSpeedModel(torch.nn.Module):
-            def __init__(self, unet, text_encoder, vae) -> None:
+            def __init__(self, unet, text_encoder) -> None:
                 super().__init__()
                 self.unet = unet
                 self.text_encoders = self.text_encoder = torch.nn.ModuleList(text_encoder)
-                self.vae = vae
 
             def get_models(self):
-                return self.unet, self.text_encoders, self.vae
+                return self.unet, self.text_encoders
 
         unet.to(accelerator.device, dtype=weight_dtype)
         [t_enc.to(accelerator.device, dtype=weight_dtype) for t_enc in text_encoders]
-        ds_model = DeepSpeedModel(unet, text_encoders, vae)
+        ds_model = DeepSpeedModel(unet, text_encoders)
         ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler)
         # Now, ds_model is an instance of DeepSpeedEngine.
-        unet, text_encoders, vae = ds_model.get_models()  # for compatiblility
-        vae.to(vae_dtype)  # to avoid explicitly half-vae
+        unet, text_encoders = ds_model.get_models()  # for compatiblility
         text_encoder = text_encoders
     else:
         if train_text_encoder:
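train_db.py receives the identical change. The wrapper class itself exists because accelerate's DeepSpeed integration manages one model per engine, so every module DeepSpeed should handle is bundled into a single `nn.Module` before `prepare`. A standalone toy version of the pattern (stand-in `Linear` modules, not the real U-Net or text encoders):

import torch

class ToyWrapper(torch.nn.Module):
    # Bundle several modules so DeepSpeed sees a single model.
    def __init__(self, unet, text_encoders):
        super().__init__()
        self.unet = unet
        # ModuleList registers the encoders' parameters with the wrapper.
        self.text_encoders = torch.nn.ModuleList(text_encoders)

    def get_models(self):
        # Unpack for code that still expects separate models.
        return self.unet, self.text_encoders

unet = torch.nn.Linear(4, 4)    # stand-in for the U-Net
encs = [torch.nn.Linear(4, 4)]  # stand-in text encoder(s)
wrapper = ToyWrapper(unet, encs)
print(sum(p.numel() for p in wrapper.parameters()))  # 40: all params visible through one module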
train_network.py
@@ -364,7 +364,7 @@ class NetworkTrainer:
                     len(train_dataloader) / args.gradient_accumulation_steps
                 )
                 accelerator.print(
-                    f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
+                    f"[DeepSpeed] override steps not dividing by {accelerator.num_processes}. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
                 )
             else:
                 args.max_train_steps = args.max_train_epochs * math.ceil(
@@ -420,23 +420,18 @@ class NetworkTrainer:
 
         if args.deepspeed:
             # wrapping model
             class DeepSpeedModel(torch.nn.Module):
-                def __init__(self, unet, text_encoder, vae, network) -> None:
+                def __init__(self, unet, text_encoder, network) -> None:
                     super().__init__()
                     self.unet = unet
                     self.text_encoders = self.text_encoder = torch.nn.ModuleList(text_encoder)
-                    self.vae = vae
                     self.network = network
 
                 def get_models(self):
-                    return self.unet, self.text_encoders, self.vae, self.network
+                    return self.unet, self.text_encoders, self.network
 
             unet.to(accelerator.device, dtype=unet_weight_dtype)
             [t_enc.to(accelerator.device, dtype=te_weight_dtype) for t_enc in text_encoders]
-            ds_model = DeepSpeedModel(unet, text_encoders, vae, network)
+            ds_model = DeepSpeedModel(unet, text_encoders, network)
             ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler)
             # Now, ds_model is an instance of DeepSpeedEngine.
-            unet, text_encoders, vae, network = ds_model.get_models()  # for compatiblility
-            vae.to(vae_dtype)  # to avoid explicitly half-vae
+            unet, text_encoders, network = ds_model.get_models()  # for compatiblility
             text_encoder = text_encoders
         else:
             if train_unet:
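In train_network.py the trainable `network` (e.g. LoRA) stays inside the wrapper alongside the U-Net and text encoders, while the frozen VAE, which only encodes latents, leaves it. A toy check of that split (stand-in modules; the `requires_grad_` flags illustrate the frozen/trainable roles and are not code from the commit):

import torch

unet = torch.nn.Linear(8, 8).requires_grad_(False)  # frozen stand-in
network = torch.nn.Linear(8, 8)                     # trainable stand-in (LoRA-like)

wrapper = torch.nn.ModuleDict({"unet": unet, "network": network})
trainable = [p for p in wrapper.parameters() if p.requires_grad]
print(sum(p.numel() for p in trainable))  # 72: only the network's parameters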