Enable a different prompt for each text encoder

This commit is contained in:
Kohya S
2023-07-18 21:39:01 +09:00
parent 7e20c6d1a1
commit 3d66a234b0

View File

@@ -86,6 +86,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--ckpt_path", type=str, required=True)
parser.add_argument("--prompt", type=str, default="A photo of a cat")
parser.add_argument("--prompt2", type=str, default=None)
parser.add_argument("--negative_prompt", type=str, default="")
parser.add_argument("--output_dir", type=str, default=".")
parser.add_argument(
@@ -98,6 +99,9 @@ if __name__ == "__main__":
parser.add_argument("--interactive", action="store_true")
args = parser.parse_args()
if args.prompt2 is None:
args.prompt2 = args.prompt
# HuggingFaceのmodel id
text_encoder_1_name = "openai/clip-vit-large-patch14"
text_encoder_2_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
@@ -169,7 +173,7 @@ if __name__ == "__main__":
beta_schedule=SCHEDLER_SCHEDULE,
)
def generate_image(prompt, negative_prompt, seed=None):
def generate_image(prompt, prompt2, negative_prompt, seed=None):
# 将来的にサイズ情報も変えられるようにする / Make it possible to change the size information in the future
# prepare embedding
with torch.no_grad():
@@ -184,7 +188,7 @@ if __name__ == "__main__":
# crossattn
# Text Encoderを二つ呼ぶ関数 Function to call two Text Encoders
def call_text_encoder(text):
def call_text_encoder(text, text2):
# text encoder 1
batch_encoding = tokenizer1(
text,
@@ -203,7 +207,7 @@ if __name__ == "__main__":
# text encoder 2
with torch.no_grad():
tokens = tokenizer2(text).to(DEVICE)
tokens = tokenizer2(text2).to(DEVICE)
enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True)
text_embedding2_penu = enc_out["hidden_states"][-2]
@@ -215,12 +219,12 @@ if __name__ == "__main__":
return text_embedding, text_embedding2_pool
# cond
c_ctx, c_ctx_pool = call_text_encoder(prompt)
c_ctx, c_ctx_pool = call_text_encoder(prompt, prompt2)
# print(c_ctx.shape, c_ctx_p.shape, c_vector.shape)
c_vector = torch.cat([c_ctx_pool, c_vector], dim=1)
# uncond
uc_ctx, uc_ctx_pool = call_text_encoder(negative_prompt)
uc_ctx, uc_ctx_pool = call_text_encoder(negative_prompt, negative_prompt)
uc_vector = torch.cat([uc_ctx_pool, uc_vector], dim=1)
text_embeddings = torch.cat([uc_ctx, c_ctx])
@@ -295,19 +299,22 @@ if __name__ == "__main__":
img.save(os.path.join(args.output_dir, f"image_{timestamp}_{i:03d}.png"))
if not args.interactive:
generate_image(args.prompt, args.negative_prompt, seed)
generate_image(args.prompt, args.prompt2, args.negative_prompt, seed)
else:
# loop for interactive
while True:
prompt = input("prompt: ")
if prompt == "":
break
prompt2 = input("prompt2: ")
if prompt2 == "":
prompt2 = prompt
negative_prompt = input("negative prompt: ")
seed = input("seed: ")
if seed == "":
seed = None
else:
seed = int(seed)
generate_image(prompt, negative_prompt, seed)
generate_image(prompt, prompt2, negative_prompt, seed)
print("Done!")