Enable comments in the prompt file; record the raw prompt in metadata

This commit is contained in:
Kohya S
2023-12-12 08:20:36 +09:00
parent 07ef03d340
commit d61ecb26fd
2 changed files with 32 additions and 12 deletions

View File

@@ -2184,6 +2184,7 @@ class BatchDataBase(NamedTuple):
mask_image: Any
clip_prompt: str
guide_image: Any
raw_prompt: str
class BatchDataExt(NamedTuple):
@@ -2710,7 +2711,7 @@ def main(args):
print(f"reading prompts from {args.from_file}")
with open(args.from_file, "r", encoding="utf-8") as f:
prompt_list = f.read().splitlines()
prompt_list = [d for d in prompt_list if len(d.strip()) > 0]
prompt_list = [d for d in prompt_list if len(d.strip()) > 0 and d[0] != "#"]
elif args.prompt is not None:
prompt_list = [args.prompt]
else:
@@ -2954,13 +2955,14 @@ def main(args):
# このバッチの情報を取り出す
(
return_latents,
(step_first, _, _, _, init_image, mask_image, _, guide_image),
(step_first, _, _, _, init_image, mask_image, _, guide_image, _),
(width, height, steps, scale, negative_scale, strength, network_muls, num_sub_prompts),
) = batch[0]
noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
prompts = []
negative_prompts = []
raw_prompts = []
start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype)
noises = [
torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype)
@@ -2991,11 +2993,16 @@ def main(args):
all_images_are_same = True
all_masks_are_same = True
all_guide_images_are_same = True
for i, (_, (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch):
for i, (
_,
(_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt),
_,
) in enumerate(batch):
prompts.append(prompt)
negative_prompts.append(negative_prompt)
seeds.append(seed)
clip_prompts.append(clip_prompt)
raw_prompts.append(raw_prompt)
if init_image is not None:
init_images.append(init_image)
@@ -3087,8 +3094,8 @@ def main(args):
# save image
highres_prefix = ("0" if highres_1st else "1") if highres_fix else ""
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
for i, (image, prompt, negative_prompts, seed, clip_prompt) in enumerate(
zip(images, prompts, negative_prompts, seeds, clip_prompts)
for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate(
zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts)
):
if highres_fix:
seed -= 1 # record original seed
@@ -3104,6 +3111,8 @@ def main(args):
metadata.add_text("negative-scale", str(negative_scale))
if clip_prompt is not None:
metadata.add_text("clip-prompt", clip_prompt)
if raw_prompt is not None:
metadata.add_text("raw-prompt", raw_prompt)
if args.use_original_file_name and init_images is not None:
if type(init_images) is list:
@@ -3438,7 +3447,9 @@ def main(args):
b1 = BatchData(
False,
BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
BatchDataBase(
global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt
),
BatchDataExt(
width,
height,

View File

@@ -1449,6 +1449,7 @@ class BatchDataBase(NamedTuple):
mask_image: Any
clip_prompt: str
guide_image: Any
raw_prompt: str
class BatchDataExt(NamedTuple):
@@ -1918,7 +1919,7 @@ def main(args):
print(f"reading prompts from {args.from_file}")
with open(args.from_file, "r", encoding="utf-8") as f:
prompt_list = f.read().splitlines()
prompt_list = [d for d in prompt_list if len(d.strip()) > 0]
prompt_list = [d for d in prompt_list if len(d.strip()) > 0 and d[0] != "#"]
elif args.prompt is not None:
prompt_list = [args.prompt]
else:
@@ -2190,7 +2191,7 @@ def main(args):
# このバッチの情報を取り出す
(
return_latents,
(step_first, _, _, _, init_image, mask_image, _, guide_image),
(step_first, _, _, _, init_image, mask_image, _, guide_image, _),
(
width,
height,
@@ -2212,6 +2213,7 @@ def main(args):
prompts = []
negative_prompts = []
raw_prompts = []
start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype)
noises = [
torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype)
@@ -2242,11 +2244,16 @@ def main(args):
all_images_are_same = True
all_masks_are_same = True
all_guide_images_are_same = True
for i, (_, (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch):
for i, (
_,
(_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt),
_,
) in enumerate(batch):
prompts.append(prompt)
negative_prompts.append(negative_prompt)
seeds.append(seed)
clip_prompts.append(clip_prompt)
raw_prompts.append(raw_prompt)
if init_image is not None:
init_images.append(init_image)
@@ -2344,8 +2351,8 @@ def main(args):
# save image
highres_prefix = ("0" if highres_1st else "1") if highres_fix else ""
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
for i, (image, prompt, negative_prompts, seed, clip_prompt) in enumerate(
zip(images, prompts, negative_prompts, seeds, clip_prompts)
for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate(
zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts )
):
if highres_fix:
seed -= 1 # record original seed
@@ -2361,6 +2368,8 @@ def main(args):
metadata.add_text("negative-scale", str(negative_scale))
if clip_prompt is not None:
metadata.add_text("clip-prompt", clip_prompt)
if raw_prompt is not None:
metadata.add_text("raw-prompt", raw_prompt)
metadata.add_text("original-height", str(original_height))
metadata.add_text("original-width", str(original_width))
metadata.add_text("original-height-negative", str(original_height_negative))
@@ -2736,7 +2745,7 @@ def main(args):
b1 = BatchData(
False,
BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt),
BatchDataExt(
width,
height,