diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 1f103e31..273c0dd8 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -937,6 +937,17 @@ class PipelineLike: if self.control_nets: guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images) + if reginonal_network: + num_sub_and_neg_prompts = len(text_embeddings) // batch_size + # last subprompt and negative prompt + text_emb_last = [] + for j in range(batch_size): + text_emb_last.append(text_embeddings[(j + 1) * num_sub_and_neg_prompts - 2]) + text_emb_last.append(text_embeddings[(j + 1) * num_sub_and_neg_prompts - 1]) + text_emb_last = torch.stack(text_emb_last) + else: + text_emb_last = text_embeddings + for i, t in enumerate(tqdm(timesteps)): # expand the latents if we are doing classifier free guidance latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) @@ -944,11 +955,6 @@ class PipelineLike: # predict the noise residual if self.control_nets and self.control_net_enabled: - if reginonal_network: - num_sub_and_neg_prompts = len(text_embeddings) // batch_size - text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts] # last subprompt - else: - text_emb_last = text_embeddings noise_pred = original_control_net.call_unet_and_control_net( i, num_latent_input, @@ -958,6 +964,7 @@ class PipelineLike: i / len(timesteps), latent_model_input, t, + text_embeddings, text_emb_last, ).sample else: @@ -2746,6 +2753,10 @@ def main(args): print(f"iteration {gen_iter+1}/{args.n_iter}") iter_seed = random.randint(0, 0x7FFFFFFF) + # shuffle prompt list + if args.shuffle_prompts: + random.shuffle(prompt_list) + # バッチ処理の関数 def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): batch_size = len(batch) @@ -3321,6 +3332,11 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="use same seed for all prompts in iteration if no seed specified / 乱数seedの指定がないとき繰り返し内はすべて同じseedを使う(プロンプト間の差異の比較用)", ) + parser.add_argument( + "--shuffle_prompts", + action="store_true", + help="shuffle prompts in iteration / 繰り返し内のプロンプトをシャッフルする", + ) parser.add_argument("--fp16", action="store_true", help="use fp16 / fp16を指定し省メモリ化する") parser.add_argument("--bf16", action="store_true", help="use bfloat16 / bfloat16を指定し省メモリ化する") parser.add_argument("--xformers", action="store_true", help="use xformers / xformersを使用し高速化する")