mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
supprot dynamic prompt variants
This commit is contained in:
@@ -46,6 +46,7 @@ VGG(
|
||||
)
|
||||
"""
|
||||
|
||||
import itertools
|
||||
import json
|
||||
from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable
|
||||
import glob
|
||||
@@ -2159,6 +2160,102 @@ def preprocess_mask(mask):
|
||||
return mask
|
||||
|
||||
|
||||
# regular expression for dynamic prompt:
|
||||
# starts and ends with "{" and "}"
|
||||
# contains at least one variant divided by "|"
|
||||
# optional framgments divided by "$$" at start
|
||||
# if the first fragment is "E" or "e", enumerate all variants
|
||||
# if the second fragment is a number or two numbers, repeat the variants in the range
|
||||
# if the third fragment is a string, use it as a separator
|
||||
|
||||
RE_DYNAMIC_PROMPT = re.compile(r"\{((e|E)\$\$)?(([\d\-]+)\$\$)?(([^\|\}]+?)\$\$)?(.+?((\|).+?)*?)\}")
|
||||
|
||||
|
||||
def handle_dynamic_prompt_variants(prompt, repeat_count):
|
||||
founds = list(RE_DYNAMIC_PROMPT.finditer(prompt))
|
||||
if not founds:
|
||||
return [prompt]
|
||||
|
||||
# make each replacement for each variant
|
||||
enumerating = False
|
||||
replacers = []
|
||||
for found in founds:
|
||||
found_enumerating = found.group(2) is not None
|
||||
enumerating = enumerating or found_enumerating
|
||||
|
||||
separater = ", " if found.group(6) is None else found.group(6)
|
||||
variants = found.group(7).split("|")
|
||||
|
||||
count_range = found.group(4)
|
||||
if count_range is None:
|
||||
count_range = [1, 1]
|
||||
else:
|
||||
count_range = count_range.split("-")
|
||||
if len(count_range) == 1:
|
||||
count_range = [int(count_range[0]), int(count_range[0])]
|
||||
elif len(count_range) == 2:
|
||||
count_range = [int(count_range[0]), int(count_range[1])]
|
||||
else:
|
||||
print(f"invalid count range: {count_range}")
|
||||
count_range = [1, 1]
|
||||
if count_range[0] > count_range[1]:
|
||||
count_range = [count_range[1], count_range[0]]
|
||||
if count_range[0] < 0:
|
||||
count_range[0] = 0
|
||||
if count_range[1] > len(variants):
|
||||
count_range[1] = len(variants)
|
||||
|
||||
if found_enumerating:
|
||||
# make all combinations
|
||||
def make_replacer_enum(vari, cr, sep):
|
||||
def replacer():
|
||||
values = []
|
||||
for count in range(cr[0], cr[1] + 1):
|
||||
for comb in itertools.combinations(vari, count):
|
||||
values.append(sep.join(comb))
|
||||
return values
|
||||
|
||||
return replacer
|
||||
|
||||
replacers.append(make_replacer_enum(variants, count_range, separater))
|
||||
else:
|
||||
# make random combinations
|
||||
def make_replacer_single(vari, cr, sep):
|
||||
def replacer():
|
||||
count = random.randint(cr[0], cr[1])
|
||||
comb = random.sample(vari, count)
|
||||
return [sep.join(comb)]
|
||||
|
||||
return replacer
|
||||
|
||||
replacers.append(make_replacer_single(variants, count_range, separater))
|
||||
|
||||
# make each prompt
|
||||
if not enumerating:
|
||||
prompts = []
|
||||
for _ in range(repeat_count):
|
||||
current = prompt
|
||||
for found, replacer in zip(founds, replacers):
|
||||
current = current.replace(found.group(0), replacer()[0])
|
||||
prompts.append(current)
|
||||
else:
|
||||
prompts = [prompt]
|
||||
for found, replacer in zip(founds, replacers):
|
||||
if found.group(2) is not None: # enumerating
|
||||
new_prompts = []
|
||||
for current in prompts:
|
||||
replecements = replacer()
|
||||
for replecement in replecements:
|
||||
new_prompts.append(current.replace(found.group(0), replecement))
|
||||
prompts = new_prompts
|
||||
for found, replacer in zip(founds, replacers):
|
||||
if found.group(2) is None:
|
||||
for i in range(len(prompts)):
|
||||
prompts[i] = prompts[i].replace(found.group(0), replacer()[0])
|
||||
|
||||
return prompts
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
@@ -2776,6 +2873,7 @@ def main(args):
|
||||
|
||||
# seed指定時はseedを決めておく
|
||||
if args.seed is not None:
|
||||
# dynamic promptを使うと足りなくなる→images_per_promptを適当に大きくしておいてもらう
|
||||
random.seed(args.seed)
|
||||
predefined_seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.n_iter * len(prompt_list) * args.images_per_prompt)]
|
||||
if len(predefined_seeds) == 1:
|
||||
@@ -3058,121 +3156,134 @@ def main(args):
|
||||
while not valid:
|
||||
print("\nType prompt:")
|
||||
try:
|
||||
prompt = input()
|
||||
raw_prompt = input()
|
||||
except EOFError:
|
||||
break
|
||||
|
||||
valid = len(prompt.strip().split(" --")[0].strip()) > 0
|
||||
valid = len(raw_prompt.strip().split(" --")[0].strip()) > 0
|
||||
if not valid: # EOF, end app
|
||||
break
|
||||
else:
|
||||
prompt = prompt_list[prompt_index]
|
||||
raw_prompt = prompt_list[prompt_index]
|
||||
|
||||
# parse prompt
|
||||
width = args.W
|
||||
height = args.H
|
||||
scale = args.scale
|
||||
negative_scale = args.negative_scale
|
||||
steps = args.steps
|
||||
seeds = None
|
||||
strength = 0.8 if args.strength is None else args.strength
|
||||
negative_prompt = ""
|
||||
clip_prompt = None
|
||||
network_muls = None
|
||||
# sd-dynamic-prompts like variants: count is 1 or images_per_prompt or arbitrary
|
||||
raw_prompts = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt)
|
||||
|
||||
prompt_args = prompt.strip().split(" --")
|
||||
prompt = prompt_args[0]
|
||||
print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")
|
||||
# repeat prompt
|
||||
for prompt_index in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)):
|
||||
raw_prompt = raw_prompts[prompt_index] if len(raw_prompts) > 1 else raw_prompts[0]
|
||||
|
||||
for parg in prompt_args[1:]:
|
||||
try:
|
||||
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
|
||||
if m:
|
||||
width = int(m.group(1))
|
||||
print(f"width: {width}")
|
||||
continue
|
||||
if prompt_index == 0 or len(raw_prompts) > 1:
|
||||
# parse prompt: if prompt is not changed, skip parsing
|
||||
width = args.W
|
||||
height = args.H
|
||||
scale = args.scale
|
||||
negative_scale = args.negative_scale
|
||||
steps = args.steps
|
||||
seed = None
|
||||
seeds = None
|
||||
strength = 0.8 if args.strength is None else args.strength
|
||||
negative_prompt = ""
|
||||
clip_prompt = None
|
||||
network_muls = None
|
||||
|
||||
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
|
||||
if m:
|
||||
height = int(m.group(1))
|
||||
print(f"height: {height}")
|
||||
continue
|
||||
prompt_args = raw_prompt.strip().split(" --")
|
||||
prompt = prompt_args[0]
|
||||
print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")
|
||||
|
||||
m = re.match(r"s (\d+)", parg, re.IGNORECASE)
|
||||
if m: # steps
|
||||
steps = max(1, min(1000, int(m.group(1))))
|
||||
print(f"steps: {steps}")
|
||||
continue
|
||||
for parg in prompt_args[1:]:
|
||||
try:
|
||||
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
|
||||
if m:
|
||||
width = int(m.group(1))
|
||||
print(f"width: {width}")
|
||||
continue
|
||||
|
||||
m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE)
|
||||
if m: # seed
|
||||
seeds = [int(d) for d in m.group(1).split(",")]
|
||||
print(f"seeds: {seeds}")
|
||||
continue
|
||||
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
|
||||
if m:
|
||||
height = int(m.group(1))
|
||||
print(f"height: {height}")
|
||||
continue
|
||||
|
||||
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # scale
|
||||
scale = float(m.group(1))
|
||||
print(f"scale: {scale}")
|
||||
continue
|
||||
m = re.match(r"s (\d+)", parg, re.IGNORECASE)
|
||||
if m: # steps
|
||||
steps = max(1, min(1000, int(m.group(1))))
|
||||
print(f"steps: {steps}")
|
||||
continue
|
||||
|
||||
m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE)
|
||||
if m: # negative scale
|
||||
if m.group(1).lower() == "none":
|
||||
negative_scale = None
|
||||
else:
|
||||
negative_scale = float(m.group(1))
|
||||
print(f"negative scale: {negative_scale}")
|
||||
continue
|
||||
m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE)
|
||||
if m: # seed
|
||||
seeds = [int(d) for d in m.group(1).split(",")]
|
||||
print(f"seeds: {seeds}")
|
||||
continue
|
||||
|
||||
m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # strength
|
||||
strength = float(m.group(1))
|
||||
print(f"strength: {strength}")
|
||||
continue
|
||||
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # scale
|
||||
scale = float(m.group(1))
|
||||
print(f"scale: {scale}")
|
||||
continue
|
||||
|
||||
m = re.match(r"n (.+)", parg, re.IGNORECASE)
|
||||
if m: # negative prompt
|
||||
negative_prompt = m.group(1)
|
||||
print(f"negative prompt: {negative_prompt}")
|
||||
continue
|
||||
m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE)
|
||||
if m: # negative scale
|
||||
if m.group(1).lower() == "none":
|
||||
negative_scale = None
|
||||
else:
|
||||
negative_scale = float(m.group(1))
|
||||
print(f"negative scale: {negative_scale}")
|
||||
continue
|
||||
|
||||
m = re.match(r"c (.+)", parg, re.IGNORECASE)
|
||||
if m: # clip prompt
|
||||
clip_prompt = m.group(1)
|
||||
print(f"clip prompt: {clip_prompt}")
|
||||
continue
|
||||
m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # strength
|
||||
strength = float(m.group(1))
|
||||
print(f"strength: {strength}")
|
||||
continue
|
||||
|
||||
m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE)
|
||||
if m: # network multiplies
|
||||
network_muls = [float(v) for v in m.group(1).split(",")]
|
||||
while len(network_muls) < len(networks):
|
||||
network_muls.append(network_muls[-1])
|
||||
print(f"network mul: {network_muls}")
|
||||
continue
|
||||
m = re.match(r"n (.+)", parg, re.IGNORECASE)
|
||||
if m: # negative prompt
|
||||
negative_prompt = m.group(1)
|
||||
print(f"negative prompt: {negative_prompt}")
|
||||
continue
|
||||
|
||||
except ValueError as ex:
|
||||
print(f"Exception in parsing / 解析エラー: {parg}")
|
||||
print(ex)
|
||||
m = re.match(r"c (.+)", parg, re.IGNORECASE)
|
||||
if m: # clip prompt
|
||||
clip_prompt = m.group(1)
|
||||
print(f"clip prompt: {clip_prompt}")
|
||||
continue
|
||||
|
||||
if seeds is not None:
|
||||
# 数が足りないなら繰り返す
|
||||
if len(seeds) < args.images_per_prompt:
|
||||
seeds = seeds * int(math.ceil(args.images_per_prompt / len(seeds)))
|
||||
seeds = seeds[: args.images_per_prompt]
|
||||
else:
|
||||
if predefined_seeds is not None:
|
||||
seeds = predefined_seeds[-args.images_per_prompt :]
|
||||
predefined_seeds = predefined_seeds[: -args.images_per_prompt]
|
||||
elif args.iter_same_seed:
|
||||
seeds = [iter_seed] * args.images_per_prompt
|
||||
m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE)
|
||||
if m: # network multiplies
|
||||
network_muls = [float(v) for v in m.group(1).split(",")]
|
||||
while len(network_muls) < len(networks):
|
||||
network_muls.append(network_muls[-1])
|
||||
print(f"network mul: {network_muls}")
|
||||
continue
|
||||
|
||||
except ValueError as ex:
|
||||
print(f"Exception in parsing / 解析エラー: {parg}")
|
||||
print(ex)
|
||||
|
||||
# prepare seed
|
||||
if seeds is not None: # given in prompt
|
||||
# 数が足りないなら前のをそのまま使う
|
||||
if len(seeds) > 0:
|
||||
seed = seeds.pop(0)
|
||||
else:
|
||||
seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.images_per_prompt)]
|
||||
if predefined_seeds is not None:
|
||||
if len(predefined_seeds) > 0:
|
||||
seed = predefined_seeds.pop(0)
|
||||
else:
|
||||
print("predefined seeds are exhausted")
|
||||
seed = None
|
||||
elif args.iter_same_seed:
|
||||
seeds = iter_seed
|
||||
if seed is None:
|
||||
seed = random.randint(0, 0x7FFFFFFF)
|
||||
if args.interactive:
|
||||
print(f"seed: {seeds}")
|
||||
print(f"seed: {seed}")
|
||||
|
||||
# prepare init image, guide image and mask
|
||||
init_image = mask_image = guide_image = None
|
||||
|
||||
init_image = mask_image = guide_image = None
|
||||
for seed in seeds: # images_per_promptの数だけ
|
||||
# 同一イメージを使うとき、本当はlatentに変換しておくと無駄がないが面倒なのでとりあえず毎回処理する
|
||||
if init_images is not None:
|
||||
init_image = init_images[global_step % len(init_images)]
|
||||
|
||||
Reference in New Issue
Block a user