mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
fix ControlNet with regional LoRA, add shuffle cap
This commit is contained in:
@@ -937,6 +937,17 @@ class PipelineLike:
|
||||
if self.control_nets:
|
||||
guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images)
|
||||
|
||||
if reginonal_network:
|
||||
num_sub_and_neg_prompts = len(text_embeddings) // batch_size
|
||||
# last subprompt and negative prompt
|
||||
text_emb_last = []
|
||||
for j in range(batch_size):
|
||||
text_emb_last.append(text_embeddings[(j + 1) * num_sub_and_neg_prompts - 2])
|
||||
text_emb_last.append(text_embeddings[(j + 1) * num_sub_and_neg_prompts - 1])
|
||||
text_emb_last = torch.stack(text_emb_last)
|
||||
else:
|
||||
text_emb_last = text_embeddings
|
||||
|
||||
for i, t in enumerate(tqdm(timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
|
||||
@@ -944,11 +955,6 @@ class PipelineLike:
|
||||
|
||||
# predict the noise residual
|
||||
if self.control_nets and self.control_net_enabled:
|
||||
if reginonal_network:
|
||||
num_sub_and_neg_prompts = len(text_embeddings) // batch_size
|
||||
text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts] # last subprompt
|
||||
else:
|
||||
text_emb_last = text_embeddings
|
||||
noise_pred = original_control_net.call_unet_and_control_net(
|
||||
i,
|
||||
num_latent_input,
|
||||
@@ -958,6 +964,7 @@ class PipelineLike:
|
||||
i / len(timesteps),
|
||||
latent_model_input,
|
||||
t,
|
||||
text_embeddings,
|
||||
text_emb_last,
|
||||
).sample
|
||||
else:
|
||||
@@ -2746,6 +2753,10 @@ def main(args):
|
||||
print(f"iteration {gen_iter+1}/{args.n_iter}")
|
||||
iter_seed = random.randint(0, 0x7FFFFFFF)
|
||||
|
||||
# shuffle prompt list
|
||||
if args.shuffle_prompts:
|
||||
random.shuffle(prompt_list)
|
||||
|
||||
# バッチ処理の関数
|
||||
def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
|
||||
batch_size = len(batch)
|
||||
@@ -3321,6 +3332,11 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
action="store_true",
|
||||
help="use same seed for all prompts in iteration if no seed specified / 乱数seedの指定がないとき繰り返し内はすべて同じseedを使う(プロンプト間の差異の比較用)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--shuffle_prompts",
|
||||
action="store_true",
|
||||
help="shuffle prompts in iteration / 繰り返し内のプロンプトをシャッフルする",
|
||||
)
|
||||
parser.add_argument("--fp16", action="store_true", help="use fp16 / fp16を指定し省メモリ化する")
|
||||
parser.add_argument("--bf16", action="store_true", help="use bfloat16 / bfloat16を指定し省メモリ化する")
|
||||
parser.add_argument("--xformers", action="store_true", help="use xformers / xformersを使用し高速化する")
|
||||
|
||||
Reference in New Issue
Block a user