Cast weights to correct precision before transferring them to GPU

This commit is contained in:
Henrik Forstén
2023-07-13 12:45:28 +03:00
parent 8fa5fb2816
commit cdffd19f61

View File

@@ -358,6 +358,11 @@ class NetworkTrainer:
accelerator.print("enable full fp16 training.")
network.to(weight_dtype)
+unet.requires_grad_(False)
+unet.to(dtype=weight_dtype)
+for t_enc in text_encoders:
+    t_enc.requires_grad_(False)
# acceleratorがなんかよろしくやってくれるらしい
# TODO めちゃくちゃ冗長なのでコードを整理する
if train_unet and train_text_encoder:
@@ -397,11 +402,6 @@ class NetworkTrainer:
text_encoders = train_util.transform_models_if_DDP(text_encoders)
unet, network = train_util.transform_models_if_DDP([unet, network])
-unet.requires_grad_(False)
-unet.to(accelerator.device, dtype=weight_dtype)
-for t_enc in text_encoders:
-    t_enc.requires_grad_(False)
if args.gradient_checkpointing:
# according to TI example in Diffusers, train is required
unet.train()