mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 21:52:27 +00:00
Cast weights to correct precision before transferring them to GPU
This commit is contained in:
@@ -358,6 +358,11 @@ class NetworkTrainer:
|
||||
accelerator.print("enable full fp16 training.")
|
||||
network.to(weight_dtype)
|
||||
|
||||
unet.requires_grad_(False)
|
||||
unet.to(dtype=weight_dtype)
|
||||
for t_enc in text_encoders:
|
||||
t_enc.requires_grad_(False)
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
# TODO めちゃくちゃ冗長なのでコードを整理する
|
||||
if train_unet and train_text_encoder:
|
||||
@@ -397,11 +402,6 @@ class NetworkTrainer:
|
||||
text_encoders = train_util.transform_models_if_DDP(text_encoders)
|
||||
unet, network = train_util.transform_models_if_DDP([unet, network])
|
||||
|
||||
unet.requires_grad_(False)
|
||||
unet.to(accelerator.device, dtype=weight_dtype)
|
||||
for t_enc in text_encoders:
|
||||
t_enc.requires_grad_(False)
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
# according to TI example in Diffusers, train is required
|
||||
unet.train()
|
||||
|
||||
Reference in New Issue
Block a user