diff --git a/train_network.py b/train_network.py index f7ee451b..e4c00524 100644 --- a/train_network.py +++ b/train_network.py @@ -358,6 +358,11 @@ class NetworkTrainer: accelerator.print("enable full fp16 training.") network.to(weight_dtype) + unet.requires_grad_(False) + unet.to(dtype=weight_dtype) + for t_enc in text_encoders: + t_enc.requires_grad_(False) + # acceleratorがなんかよろしくやってくれるらしい # TODO めちゃくちゃ冗長なのでコードを整理する if train_unet and train_text_encoder: @@ -397,11 +402,6 @@ class NetworkTrainer: text_encoders = train_util.transform_models_if_DDP(text_encoders) unet, network = train_util.transform_models_if_DDP([unet, network]) - unet.requires_grad_(False) - unet.to(accelerator.device, dtype=weight_dtype) - for t_enc in text_encoders: - t_enc.requires_grad_(False) - if args.gradient_checkpointing: # according to TI example in Diffusers, train is required unet.train()