diff --git a/library/train_util.py b/library/train_util.py index 24e15d1f..f67a576d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1620,9 +1620,6 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod else: enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True) encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip] - if weight_dtype is not None: - # this is required for additional network training - encoder_hidden_states = encoder_hidden_states.to(weight_dtype) encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) # bs*3, 77, 768 or 1024 @@ -1648,6 +1645,10 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod states_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2]) # の後から の前まで states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # encoder_hidden_states = torch.cat(states_list, dim=1) + + if weight_dtype is not None: + # this is required for additional network training + encoder_hidden_states = encoder_hidden_states.to(weight_dtype) return encoder_hidden_states diff --git a/train_network.py b/train_network.py index bb3159fd..90771b31 100644 --- a/train_network.py +++ b/train_network.py @@ -1,5 +1,6 @@ from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION from torch.optim import Optimizer +from torch.cuda.amp import autocast from typing import Optional, Union import importlib import argparse @@ -154,7 +155,9 @@ def train(args): # モデルを読み込む text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype) - + # unnecessary, but work on low-ram device + text_encoder.to("cuda") + unet.to("cuda") # モデルに xformers とか memory efficient attention を組み込む train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) @@ -258,17 +261,26 @@ def train(args): unet.requires_grad_(False) unet.to(accelerator.device, dtype=weight_dtype) text_encoder.requires_grad_(False) - text_encoder.to(accelerator.device, dtype=weight_dtype) + text_encoder.to(accelerator.device) if args.gradient_checkpointing: # according to TI example in Diffusers, train is required unet.train() text_encoder.train() # set top parameter requires_grad = True for gradient checkpointing works - text_encoder.text_model.embeddings.requires_grad_(True) + if type(text_encoder) == DDP: + text_encoder.module.text_model.embeddings.requires_grad_(True) + else: + text_encoder.text_model.embeddings.requires_grad_(True) else: unet.eval() text_encoder.eval() + # support DistributedDataParallel + if type(text_encoder) == DDP: + text_encoder = text_encoder.module + unet = unet.module + network = network.module + network.prepare_grad_etc(text_encoder, unet) if not cache_latents: @@ -415,7 +427,8 @@ def train(args): noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Predict the noise residual - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + with autocast(): + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample if args.v_parameterization: # v-parameterization training