From 938bd71844e24f08b6483717b1a13aab9cb83657 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Wed, 8 Feb 2023 18:31:27 +0800 Subject: [PATCH 1/9] lower ram usage --- train_network.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index 3e8f4e7d..710055e0 100644 --- a/train_network.py +++ b/train_network.py @@ -150,7 +150,9 @@ def train(args): # モデルを読み込む text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype) - + # unnecessary, but work on low-ram device + text_encoder.to("cuda") + unet.to("cuda") # モデルに xformers とか memory efficient attention を組み込む train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) From fb312acb7f07605e46662099f62181de197fb490 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Wed, 8 Feb 2023 18:54:55 +0800 Subject: [PATCH 2/9] support DistributedDataParallel --- train_network.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/train_network.py b/train_network.py index 710055e0..c2f9cbf6 100644 --- a/train_network.py +++ b/train_network.py @@ -267,6 +267,14 @@ def train(args): unet.eval() text_encoder.eval() + # support DistributedDataParallel + try: + text_encoder = text_encoder.module + unet = unet.module + network = network.module + except: + pass + network.prepare_grad_etc(text_encoder, unet) if not cache_latents: From c0be52a7731b58e305ae22ed26e57af6b7d61f5a Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Wed, 8 Feb 2023 19:05:39 +0800 Subject: [PATCH 3/9] ignore get_hidden_states expected scalar Error --- library/train_util.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 6f809deb..dc0724d7 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1560,9 +1560,10 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod else: enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True) encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip] - if weight_dtype is not None: - # this is required for additional network training - encoder_hidden_states = encoder_hidden_states.to(weight_dtype) + # uncomment code may raise expected scalar type Half but found Float when using DDP + # if weight_dtype is not None: + # # this is required for additional network training + # encoder_hidden_states = encoder_hidden_states.to(weight_dtype) encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) # bs*3, 77, 768 or 1024 From 5e96e1369da410e52725e2c7af8b9f9def956c4b Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Wed, 8 Feb 2023 19:23:39 +0800 Subject: [PATCH 4/9] fix get_hidden_states expected scalar Error --- library/train_util.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index dc0724d7..4c410567 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1560,10 +1560,6 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod else: enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True) encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip] - # uncomment code may raise expected scalar type Half but found Float when using DDP - # if weight_dtype is not None: - # # this is required for additional network training - # encoder_hidden_states = encoder_hidden_states.to(weight_dtype) encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) # bs*3, 77, 768 or 1024 @@ -1589,6 +1585,10 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod states_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2]) # の後から の前まで states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # encoder_hidden_states = torch.cat(states_list, dim=1) + + if weight_dtype is not None: + # this is required for additional network training + encoder_hidden_states = encoder_hidden_states.to(weight_dtype) return encoder_hidden_states From b599adc938de5a9d728dd73dc311ab95e1f786e0 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Wed, 8 Feb 2023 19:34:03 +0800 Subject: [PATCH 5/9] fix Input type error when using DDP --- train_network.py | 1 + 1 file changed, 1 insertion(+) diff --git a/train_network.py b/train_network.py index c2f9cbf6..fc387bc3 100644 --- a/train_network.py +++ b/train_network.py @@ -1,5 +1,6 @@ from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION from torch.optim import Optimizer +from torch.cuda.amp import autocast from typing import Optional, Union import importlib import argparse From 6473aa1dd7cd927852957b977cfdf4f86b4f17b9 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Wed, 8 Feb 2023 21:32:21 +0800 Subject: [PATCH 6/9] fix Input type error in noise_pred when using DDP --- train_network.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index fc387bc3..59f74211 100644 --- a/train_network.py +++ b/train_network.py @@ -417,7 +417,8 @@ def train(args): # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + with autocast(): + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Predict the noise residual noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample From 9a9ac79edff44e9c0cdb28d127f62fe5ce0cca07 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Wed, 8 Feb 2023 22:30:20 +0800 Subject: [PATCH 7/9] correct wrong inserted code for noise_pred fix --- train_network.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/train_network.py b/train_network.py index 59f74211..f247c74e 100644 --- a/train_network.py +++ b/train_network.py @@ -417,11 +417,11 @@ def train(args): # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) - with autocast(): - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Predict the noise residual - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + with autocast(): + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample if args.v_parameterization: # v-parameterization training From b8ad17902f91d0293daf6a685224d2bb59f9b301 Mon Sep 17 00:00:00 2001 From: Isotr0py <41363108+Isotr0py@users.noreply.github.com> Date: Wed, 8 Feb 2023 23:09:59 +0800 Subject: [PATCH 8/9] fix get_hidden_states expected scalar Error again --- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index f247c74e..928ad321 100644 --- a/train_network.py +++ b/train_network.py @@ -257,7 +257,7 @@ def train(args): unet.requires_grad_(False) unet.to(accelerator.device, dtype=weight_dtype) text_encoder.requires_grad_(False) - text_encoder.to(accelerator.device, dtype=weight_dtype) + text_encoder.to(accelerator.device) if args.gradient_checkpointing: # according to TI example in Diffusers, train is required unet.train() text_encoder.train() From 2b1a3080e7ddc329e3a3bf59126d8ccac80d0dae Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sun, 12 Feb 2023 15:32:38 +0800 Subject: [PATCH 9/9] Add type checking --- train_network.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/train_network.py b/train_network.py index 852aea8e..90771b31 100644 --- a/train_network.py +++ b/train_network.py @@ -267,18 +267,19 @@ def train(args): text_encoder.train() # set top parameter requires_grad = True for gradient checkpointing works - text_encoder.text_model.embeddings.requires_grad_(True) + if type(text_encoder) == DDP: + text_encoder.module.text_model.embeddings.requires_grad_(True) + else: + text_encoder.text_model.embeddings.requires_grad_(True) else: unet.eval() text_encoder.eval() # support DistributedDataParallel - try: - text_encoder = text_encoder.module - unet = unet.module - network = network.module - except: - pass + if type(text_encoder) == DDP: + text_encoder = text_encoder.module + unet = unet.module + network = network.module network.prepare_grad_etc(text_encoder, unet)