mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 21:52:27 +00:00
Refactor code to ensure args.guidance_scale is always a float #1525
This commit is contained in:
@@ -688,8 +688,8 @@ def train(args):
|
||||
packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
|
||||
img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)
|
||||
|
||||
# get guidance
|
||||
guidance_vec = torch.full((bsz,), args.guidance_scale, device=accelerator.device)
|
||||
# get guidance: ensure args.guidance_scale is float
|
||||
guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)
|
||||
|
||||
# call model
|
||||
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
|
||||
|
||||
Reference in New Issue
Block a user