fix sampling gen fails in lora training

This commit is contained in:
Kohya S
2023-07-13 19:02:34 +09:00
parent 8fa5fb2816
commit 3bb80ebf20

View File

@@ -1005,6 +1005,7 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding)
noise_pred = noise_pred.to(dtype) # U-Net changes dtype in LoRA training
# perform guidance
if do_classifier_free_guidance: