Set dtype if text encoder is not trained at all

Author: Kohya S
Date:   2023-01-07 21:43:27 +09:00
Parent: 9f1d3aca24
Commit: 80af4c0c42


@@ -98,6 +98,8 @@ def train(args):
   train_text_encoder = args.stop_text_encoder_training >= 0
   unet.requires_grad_(True)                   # added just in case
   text_encoder.requires_grad_(train_text_encoder)
+  if not train_text_encoder:
+    print("Text Encoder is not trained.")

   if args.gradient_checkpointing:
     unet.enable_gradient_checkpointing()
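For context, this freezing pattern works because requires_grad_ toggles gradient tracking on every parameter of the module, and a frozen text encoder then never reaches the optimizer. Below is a minimal, self-contained sketch of the same idea; the Linear stand-ins and the negative-value sentinel for stop_text_encoder_training are illustrative assumptions, not the script's real models or argument semantics.

    import torch

    # Stand-ins for the real Stable Diffusion components; the actual
    # script loads a CLIP text encoder and a U-Net from a checkpoint.
    text_encoder = torch.nn.Linear(8, 8)
    unet = torch.nn.Linear(8, 8)

    stop_text_encoder_training = -1   # assumed sentinel: negative means "never train it"
    train_text_encoder = stop_text_encoder_training >= 0

    unet.requires_grad_(True)         # mirrors the diff: enabled just in case
    text_encoder.requires_grad_(train_text_encoder)
    if not train_text_encoder:
        print("Text Encoder is not trained.")

    # Only parameters with requires_grad=True are handed to the optimizer,
    # so the frozen encoder is never updated during training.
    trainable = [p for p in unet.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable, lr=1e-5)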
@@ -153,6 +155,9 @@ def train(args):
         unet, text_encoder, optimizer, train_dataloader, lr_scheduler)
   else:
     unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
+
+  if not train_text_encoder:
+    text_encoder.to(accelerator.device, dtype=weight_dtype)  # to avoid 'cpu' vs 'cuda' error

   # experimental feature: train in fp16 including gradients; patch PyTorch to enable grad scaling in fp16
   if args.full_fp16:
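The fix itself is the explicit .to() call: a module that is excluded from accelerator.prepare() is left on the CPU in float32, so its outputs later collide with the fp16/bf16 CUDA tensors from the rest of the pipeline. The sketch below reproduces that idea without accelerate; the Embedding stand-in, the dtype choice, and the token ids are assumptions for illustration only.

    import torch

    weight_dtype = torch.float16     # e.g. when training with mixed precision fp16
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Stand-in for the frozen CLIP text encoder that prepare() never touches.
    text_encoder = torch.nn.Embedding(49408, 768)
    text_encoder.requires_grad_(False)

    # Without this line the module stays on CPU in float32 and later mixes
    # 'cpu' float32 tensors into a 'cuda' fp16 pipeline; this is the mismatch
    # the commit message refers to.
    text_encoder.to(device, dtype=weight_dtype)

    token_ids = torch.tensor([[101, 202, 303]], device=device)
    with torch.no_grad():
        hidden_states = text_encoder(token_ids)
    print(hidden_states.dtype, hidden_states.device)  # torch.float16 on the chosen device

Casting once up front is cheap here because the frozen encoder only ever runs inference; trainable modules, by contrast, are left in float32 and handled by the mixed-precision machinery inside accelerate.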