Add guidance scale for prompt param and flux sampling

This commit is contained in:
Kohya S
2025-04-16 21:50:36 +09:00
parent 06df0377f9
commit 629073cd9d
2 changed files with 12 additions and 3 deletions

View File

@@ -154,6 +154,7 @@ def sample_image_inference(
sample_steps = prompt_dict.get("sample_steps", 20)
width = prompt_dict.get("width", 512)
height = prompt_dict.get("height", 512)
guidance_scale = prompt_dict.get("guidance_scale", args.guidance_scale)
scale = prompt_dict.get("scale", 1.0) # 1.0 means no guidance
seed = prompt_dict.get("seed")
controlnet_image = prompt_dict.get("controlnet_image")
@@ -180,9 +181,12 @@ def sample_image_inference(
logger.info(f"prompt: {prompt}")
if scale != 1.0:
logger.info(f"negative_prompt: {negative_prompt}")
elif negative_prompt != "":
logger.info(f"negative prompt is ignored because scale is 1.0")
logger.info(f"height: {height}")
logger.info(f"width: {width}")
logger.info(f"sample_steps: {sample_steps}")
logger.info(f"guidance_scale: {guidance_scale}")
if scale != 1.0:
logger.info(f"scale: {scale}")
# logger.info(f"sample_sampler: {sampler_name}")
@@ -256,7 +260,7 @@ def sample_image_inference(
txt_ids,
l_pooled,
timesteps=timesteps,
guidance=scale,
guidance=guidance_scale,
t5_attn_mask=t5_attn_mask,
controlnet=controlnet,
controlnet_img=controlnet_image,
@@ -489,7 +493,7 @@ def get_noisy_model_input_and_timesteps(
sigmas = torch.randn(bsz, device=device)
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
sigmas = sigmas.sigmoid()
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size
sigmas = time_shift(mu, 1.0, sigmas)
timesteps = sigmas * num_timesteps
else:
@@ -514,7 +518,7 @@ def get_noisy_model_input_and_timesteps(
if args.ip_noise_gamma:
xi = torch.randn_like(latents, device=latents.device, dtype=dtype)
if args.ip_noise_gamma_random_strength:
ip_noise_gamma = (torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma)
ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma
else:
ip_noise_gamma = args.ip_noise_gamma
noisy_model_input = (1.0 - sigmas) * latents + sigmas * (noise + ip_noise_gamma * xi)

View File

@@ -6178,6 +6178,11 @@ def line_to_prompt_dict(line: str) -> dict:
prompt_dict["scale"] = float(m.group(1))
continue
m = re.match(r"g ([\d\.]+)", parg, re.IGNORECASE)
if m: # guidance scale
prompt_dict["guidance_scale"] = float(m.group(1))
continue
m = re.match(r"n (.+)", parg, re.IGNORECASE)
if m: # negative prompt
prompt_dict["negative_prompt"] = m.group(1)