From cdffd19f61d8b6d8c22f2778a4e3eb3bd44aeb48 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Henrik=20Forst=C3=A9n?=
Date: Thu, 13 Jul 2023 12:45:28 +0300
Subject: [PATCH] Cast weights to correct precision before transferring them to GPU

---
 train_network.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/train_network.py b/train_network.py
index f7ee451b..e4c00524 100644
--- a/train_network.py
+++ b/train_network.py
@@ -358,6 +358,11 @@ class NetworkTrainer:
             accelerator.print("enable full fp16 training.")
             network.to(weight_dtype)
 
+        unet.requires_grad_(False)
+        unet.to(dtype=weight_dtype)
+        for t_enc in text_encoders:
+            t_enc.requires_grad_(False)
+
         # acceleratorがなんかよろしくやってくれるらしい
         # TODO めちゃくちゃ冗長なのでコードを整理する
         if train_unet and train_text_encoder:
@@ -397,11 +402,6 @@ class NetworkTrainer:
         text_encoders = train_util.transform_models_if_DDP(text_encoders)
         unet, network = train_util.transform_models_if_DDP([unet, network])
 
-        unet.requires_grad_(False)
-        unet.to(accelerator.device, dtype=weight_dtype)
-        for t_enc in text_encoders:
-            t_enc.requires_grad_(False)
-
         if args.gradient_checkpointing:
             # according to TI example in Diffusers, train is required
             unet.train()