mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
fix vae type error during training sdxl
This commit is contained in:
@@ -17,7 +17,6 @@ TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
||||
|
||||
|
||||
def load_target_model(args, accelerator, model_version: str, weight_dtype):
|
||||
# load models for each process
|
||||
model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16
|
||||
for pi in range(accelerator.state.num_processes):
|
||||
if pi == accelerator.state.local_process_index:
|
||||
|
||||
@@ -4042,28 +4042,23 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une
|
||||
|
||||
|
||||
def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False):
|
||||
# load models for each process
|
||||
for pi in range(accelerator.state.num_processes):
|
||||
if pi == accelerator.state.local_process_index:
|
||||
print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
|
||||
|
||||
text_encoder, vae, unet, load_stable_diffusion_format = _load_target_model(
|
||||
args,
|
||||
weight_dtype,
|
||||
accelerator.device if args.lowram else "cpu",
|
||||
unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2,
|
||||
)
|
||||
|
||||
# work on low-ram device
|
||||
if args.lowram:
|
||||
text_encoder.to(accelerator.device)
|
||||
unet.to(accelerator.device)
|
||||
vae.to(accelerator.device)
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
return text_encoder, vae, unet, load_stable_diffusion_format
|
||||
|
||||
|
||||
|
||||
@@ -392,23 +392,20 @@ def train(args):
|
||||
if args.deepspeed:
|
||||
# Wrapping model for DeepSpeed
|
||||
class DeepSpeedModel(torch.nn.Module):
|
||||
def __init__(self, unet, text_encoder, vae) -> None:
|
||||
def __init__(self, unet, text_encoder) -> None:
|
||||
super().__init__()
|
||||
self.unet = unet
|
||||
self.text_encoders = self.text_encoder = torch.nn.ModuleList(text_encoder)
|
||||
self.vae = vae
|
||||
|
||||
def get_models(self):
|
||||
return self.unet, self.text_encoders, self.vae
|
||||
return self.unet, self.text_encoders
|
||||
text_encoders = [text_encoder1, text_encoder2]
|
||||
unet.to(accelerator.device, dtype=weight_dtype)
|
||||
[t_enc.to(accelerator.device, dtype=weight_dtype) for t_enc in text_encoders]
|
||||
ds_model = DeepSpeedModel(unet, text_encoders, vae)
|
||||
ds_model = DeepSpeedModel(unet, text_encoders)
|
||||
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler)
|
||||
# Now, ds_model is an instance of DeepSpeedEngine.
|
||||
unet, text_encoders, vae = ds_model.get_models() # for compatiblility
|
||||
vae.to(vae_dtype) # to avoid explicitly half-vae
|
||||
text_encoder1, text_encoder2 = text_encoders[0], text_encoders[1]
|
||||
unet, text_encoders = ds_model.get_models() # for compatiblility
|
||||
text_encoder1, text_encoder2 = text_encoder = text_encoders
|
||||
training_models = [unet, text_encoder1, text_encoder2]
|
||||
else: # acceleratorがなんかよろしくやってくれるらしい
|
||||
if train_unet:
|
||||
unet = accelerator.prepare(unet)
|
||||
@@ -493,10 +490,10 @@ def train(args):
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
current_step.value = global_step
|
||||
with accelerator.accumulate(*training_models):
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
with torch.no_grad(): # why this block differ within train_network.py?
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
|
||||
else:
|
||||
# latentに変換
|
||||
latents = vae.encode(batch["images"].to(vae_dtype)).latent_dist.sample().to(weight_dtype)
|
||||
|
||||
@@ -504,7 +501,7 @@ def train(args):
|
||||
if torch.any(torch.isnan(latents)):
|
||||
accelerator.print("NaN found in latents, replacing with zeros")
|
||||
latents = torch.nan_to_num(latents, 0, out=latents)
|
||||
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
|
||||
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
|
||||
|
||||
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
|
||||
input_ids1 = batch["input_ids"]
|
||||
|
||||
Reference in New Issue
Block a user