Fix VAE type error during SDXL training

This commit is contained in:
BootsofLagrangian
2024-02-05 20:13:28 +09:00
parent 64873c1b43
commit 2824312d5e
3 changed files with 11 additions and 20 deletions

View File

@@ -17,7 +17,6 @@ TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
def load_target_model(args, accelerator, model_version: str, weight_dtype):
# load models for each process
model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16
for pi in range(accelerator.state.num_processes):
if pi == accelerator.state.local_process_index:

View File

@@ -4042,28 +4042,23 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une
def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False):
# load models for each process
for pi in range(accelerator.state.num_processes):
if pi == accelerator.state.local_process_index:
print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
text_encoder, vae, unet, load_stable_diffusion_format = _load_target_model(
args,
weight_dtype,
accelerator.device if args.lowram else "cpu",
unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2,
)
# work on low-ram device
if args.lowram:
text_encoder.to(accelerator.device)
unet.to(accelerator.device)
vae.to(accelerator.device)
gc.collect()
torch.cuda.empty_cache()
accelerator.wait_for_everyone()
return text_encoder, vae, unet, load_stable_diffusion_format

View File

@@ -392,23 +392,20 @@ def train(args):
if args.deepspeed:
# Wrapping model for DeepSpeed
class DeepSpeedModel(torch.nn.Module):
def __init__(self, unet, text_encoder, vae) -> None:
def __init__(self, unet, text_encoder) -> None:
super().__init__()
self.unet = unet
self.text_encoders = self.text_encoder = torch.nn.ModuleList(text_encoder)
self.vae = vae
def get_models(self):
return self.unet, self.text_encoders, self.vae
return self.unet, self.text_encoders
text_encoders = [text_encoder1, text_encoder2]
unet.to(accelerator.device, dtype=weight_dtype)
[t_enc.to(accelerator.device, dtype=weight_dtype) for t_enc in text_encoders]
ds_model = DeepSpeedModel(unet, text_encoders, vae)
ds_model = DeepSpeedModel(unet, text_encoders)
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler)
# Now, ds_model is an instance of DeepSpeedEngine.
unet, text_encoders, vae = ds_model.get_models() # for compatiblility
vae.to(vae_dtype) # to avoid explicitly half-vae
text_encoder1, text_encoder2 = text_encoders[0], text_encoders[1]
unet, text_encoders = ds_model.get_models() # for compatiblility
text_encoder1, text_encoder2 = text_encoder = text_encoders
training_models = [unet, text_encoder1, text_encoder2]
else: # acceleratorがなんかよろしくやってくれるらしい
if train_unet:
unet = accelerator.prepare(unet)
@@ -493,10 +490,10 @@ def train(args):
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(*training_models):
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
with torch.no_grad():
with torch.no_grad(): # why this block differ within train_network.py?
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
# latentに変換
latents = vae.encode(batch["images"].to(vae_dtype)).latent_dist.sample().to(weight_dtype)
@@ -504,7 +501,7 @@ def train(args):
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.nan_to_num(latents, 0, out=latents)
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
input_ids1 = batch["input_ids"]