Fix batch generation not working

This commit is contained in:
Kohya S
2023-07-30 14:19:25 +09:00
parent 0eacadfa99
commit 8856c19c76

View File

@@ -511,7 +511,7 @@ class PipelineLike:
emb1 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256)
emb2 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256)
emb3 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([height, width]).unsqueeze(0), 256)
c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype)
c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1)
uc_vector = c_vector.clone().to(self.device, dtype=text_embeddings.dtype)
c_vector = torch.cat([text_pool, c_vector], dim=1)