Refactor code to ensure args.guidance_scale is always a float #1525

This commit is contained in:
Kohya S
2024-08-29 22:10:57 +09:00
parent 930d709e3d
commit 8ecf0fc4bf

View File

@@ -688,8 +688,8 @@ def train(args):
packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)
# get guidance
guidance_vec = torch.full((bsz,), args.guidance_scale, device=accelerator.device)
# get guidance: ensure args.guidance_scale is float
guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)
# call model
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds