remove workaround for accelerate==0.15, fix XTI

This commit is contained in:
ykume
2023-06-11 18:32:14 +09:00
parent 33a6234b52
commit 0315611b11
7 changed files with 153 additions and 159 deletions

View File

@@ -91,7 +91,7 @@ def train(args):
# acceleratorを準備する
print("prepare accelerator")
accelerator, unwrap_model = train_util.prepare_accelerator(args)
accelerator = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args)
@@ -385,8 +385,8 @@ def train(args):
epoch,
num_train_epochs,
global_step,
unwrap_model(text_encoder),
unwrap_model(unet),
accelerator.unwrap_model(text_encoder),
accelerator.unwrap_model(unet),
vae,
)
@@ -428,8 +428,8 @@ def train(args):
epoch,
num_train_epochs,
global_step,
unwrap_model(text_encoder),
unwrap_model(unet),
accelerator.unwrap_model(text_encoder),
accelerator.unwrap_model(unet),
vae,
)
@@ -437,8 +437,8 @@ def train(args):
is_main_process = accelerator.is_main_process
if is_main_process:
unet = unwrap_model(unet)
text_encoder = unwrap_model(text_encoder)
unet = accelerator.unwrap_model(unet)
text_encoder = accelerator.unwrap_model(text_encoder)
accelerator.end_training()