mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
enable different prompt for text encoders
This commit is contained in:
@@ -86,6 +86,7 @@ if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--ckpt_path", type=str, required=True)
|
||||
parser.add_argument("--prompt", type=str, default="A photo of a cat")
|
||||
parser.add_argument("--prompt2", type=str, default=None)
|
||||
parser.add_argument("--negative_prompt", type=str, default="")
|
||||
parser.add_argument("--output_dir", type=str, default=".")
|
||||
parser.add_argument(
|
||||
@@ -98,6 +99,9 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--interactive", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.prompt2 is None:
|
||||
args.prompt2 = args.prompt
|
||||
|
||||
# HuggingFaceのmodel id
|
||||
text_encoder_1_name = "openai/clip-vit-large-patch14"
|
||||
text_encoder_2_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
||||
@@ -169,7 +173,7 @@ if __name__ == "__main__":
|
||||
beta_schedule=SCHEDLER_SCHEDULE,
|
||||
)
|
||||
|
||||
def generate_image(prompt, negative_prompt, seed=None):
|
||||
def generate_image(prompt, prompt2, negative_prompt, seed=None):
|
||||
# 将来的にサイズ情報も変えられるようにする / Make it possible to change the size information in the future
|
||||
# prepare embedding
|
||||
with torch.no_grad():
|
||||
@@ -184,7 +188,7 @@ if __name__ == "__main__":
|
||||
# crossattn
|
||||
|
||||
# Text Encoderを二つ呼ぶ関数 Function to call two Text Encoders
|
||||
def call_text_encoder(text):
|
||||
def call_text_encoder(text, text2):
|
||||
# text encoder 1
|
||||
batch_encoding = tokenizer1(
|
||||
text,
|
||||
@@ -203,7 +207,7 @@ if __name__ == "__main__":
|
||||
|
||||
# text encoder 2
|
||||
with torch.no_grad():
|
||||
tokens = tokenizer2(text).to(DEVICE)
|
||||
tokens = tokenizer2(text2).to(DEVICE)
|
||||
|
||||
enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True)
|
||||
text_embedding2_penu = enc_out["hidden_states"][-2]
|
||||
@@ -215,12 +219,12 @@ if __name__ == "__main__":
|
||||
return text_embedding, text_embedding2_pool
|
||||
|
||||
# cond
|
||||
c_ctx, c_ctx_pool = call_text_encoder(prompt)
|
||||
c_ctx, c_ctx_pool = call_text_encoder(prompt, prompt2)
|
||||
# print(c_ctx.shape, c_ctx_p.shape, c_vector.shape)
|
||||
c_vector = torch.cat([c_ctx_pool, c_vector], dim=1)
|
||||
|
||||
# uncond
|
||||
uc_ctx, uc_ctx_pool = call_text_encoder(negative_prompt)
|
||||
uc_ctx, uc_ctx_pool = call_text_encoder(negative_prompt, negative_prompt)
|
||||
uc_vector = torch.cat([uc_ctx_pool, uc_vector], dim=1)
|
||||
|
||||
text_embeddings = torch.cat([uc_ctx, c_ctx])
|
||||
@@ -295,19 +299,22 @@ if __name__ == "__main__":
|
||||
img.save(os.path.join(args.output_dir, f"image_{timestamp}_{i:03d}.png"))
|
||||
|
||||
if not args.interactive:
|
||||
generate_image(args.prompt, args.negative_prompt, seed)
|
||||
generate_image(args.prompt, args.prompt2, args.negative_prompt, seed)
|
||||
else:
|
||||
# loop for interactive
|
||||
while True:
|
||||
prompt = input("prompt: ")
|
||||
if prompt == "":
|
||||
break
|
||||
prompt2 = input("prompt2: ")
|
||||
if prompt2 == "":
|
||||
prompt2 = prompt
|
||||
negative_prompt = input("negative prompt: ")
|
||||
seed = input("seed: ")
|
||||
if seed == "":
|
||||
seed = None
|
||||
else:
|
||||
seed = int(seed)
|
||||
generate_image(prompt, negative_prompt, seed)
|
||||
generate_image(prompt, prompt2, negative_prompt, seed)
|
||||
|
||||
print("Done!")
|
||||
|
||||
Reference in New Issue
Block a user