From a7ef6422b658660dd4c4685397f9b56e24996eb2 Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Sat, 20 Jan 2024 10:00:30 +0900
Subject: [PATCH] fix to work with torch 2.0

---
 train_network.py | 28 ++++++++++++++++------------
 1 file changed, 16 insertions(+), 12 deletions(-)

diff --git a/train_network.py b/train_network.py
index 5f28a5e0..b1291ed1 100644
--- a/train_network.py
+++ b/train_network.py
@@ -393,11 +393,9 @@ class NetworkTrainer:
         unet_weight_dtype = te_weight_dtype = weight_dtype
         # Experimental Feature: Put base model into fp8 to save vram
         if args.fp8_base:
+            assert torch.__version__ >= "2.1.0", "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。"
             assert (
-                torch.__version__ >= '2.1.0'
-            ), "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。"
-            assert (
-                args.mixed_precision != 'no'
+                args.mixed_precision != "no"
             ), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。"
             accelerator.print("enable fp8 training.")
             unet_weight_dtype = torch.float8_e4m3fn
@@ -407,13 +405,12 @@ class NetworkTrainer:
             unet.to(dtype=unet_weight_dtype)
         for t_enc in text_encoders:
             t_enc.requires_grad_(False)
-            t_enc.to(dtype=te_weight_dtype)
-            # nn.Embedding not support FP8
-            t_enc.text_model.embeddings.to(dtype=(
-                weight_dtype
-                if te_weight_dtype == torch.float8_e4m3fn
-                else te_weight_dtype
-            ))
+
+            # in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16
+            if t_enc.device.type != "cpu":
+                t_enc.to(dtype=te_weight_dtype)
+                # nn.Embedding not support FP8
+                t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))

         # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
         if train_unet:
@@ -805,7 +802,14 @@ class NetworkTrainer:
                 # Predict the noise residual
                 with accelerator.autocast():
                     noise_pred = self.call_unet(
-                        args, accelerator, unet, noisy_latents.requires_grad_(train_unet), timesteps, text_encoder_conds, batch, weight_dtype
+                        args,
+                        accelerator,
+                        unet,
+                        noisy_latents.requires_grad_(train_unet),
+                        timesteps,
+                        text_encoder_conds,
+                        batch,
+                        weight_dtype,
                     )

                 if args.v_parameterization:
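
For context, the casting logic the second hunk installs can be exercised standalone. A minimal sketch, assuming a CLIPTextModel-style encoder that exposes a text_model.embeddings submodule; cast_text_encoder is a hypothetical helper (not part of this patch), and the device check via parameters() stands in for the model-level .device attribute the patch uses:

import torch
import torch.nn as nn


def cast_text_encoder(t_enc: nn.Module, te_weight_dtype: torch.dtype, weight_dtype: torch.dtype) -> None:
    # Text encoders are frozen during network (LoRA) training.
    t_enc.requires_grad_(False)
    # On CPU the encoder stays fp32, since CPU does not support fp8/fp16/bf16 here.
    if next(t_enc.parameters()).device.type == "cpu":
        return
    t_enc.to(dtype=te_weight_dtype)
    # nn.Embedding cannot run in fp8, so whenever the encoder body uses a
    # different dtype (i.e. fp8), the embeddings fall back to the
    # mixed-precision dtype.
    emb_dtype = weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype
    t_enc.text_model.embeddings.to(dtype=emb_dtype)

A quick way to see the limitation: casting an nn.Embedding to torch.float8_e4m3fn and calling it raises an error in torch 2.1, since embedding lookups have no fp8 kernels. This also connects to the assertions in the first hunk: torch.float8_e4m3fn only exists in torch>=2.1.0, and mixed_precision must not be "no" because the actual math on fp8-stored weights still runs in fp16/bf16 under autocast.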