From 8856c19c76b7b2b4f8b288ac8d28034f810224c2 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 30 Jul 2023 14:19:25 +0900 Subject: [PATCH] fix batch generation not working --- sdxl_gen_img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index d2f59a33..6578b9a8 100644 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -511,7 +511,7 @@ class PipelineLike: emb1 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256) emb2 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256) emb3 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([height, width]).unsqueeze(0), 256) - c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype) + c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1) uc_vector = c_vector.clone().to(self.device, dtype=text_embeddings.dtype) c_vector = torch.cat([text_pool, c_vector], dim=1)