Update train_network.py

This commit is contained in:
gesen2egee
2024-03-11 19:15:55 +08:00
parent befbec5335
commit 63e58f78e3

View File

@@ -178,8 +178,7 @@ class NetworkTrainer:
with torch.set_grad_enabled(is_train), accelerator.autocast():
noise = torch.randn_like(latents, device=latents.device)
b_size = latents.shape[0]
timesteps = torch.randint(fixed_timesteps, fixed_timesteps, (b_size,), device=latents.device)
timesteps = timesteps.long()
timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device=latents.device)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
noise_pred = self.call_unet(
args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype