From 82707654add89ffeef6d98a67a0dcce0cc440cf3 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 28 Feb 2023 22:05:31 +0900 Subject: [PATCH] support sample generation in TI training --- library/train_util.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 848182fb..e4d87fce 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2074,7 +2074,7 @@ SCHEDULER_TIMESTEPS = 1000 SCHEDLER_SCHEDULE = 'scaled_linear' -def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet): +def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None): """ 生成に使っている Diffusers の Pipeline がデフォルトなので、プロンプトの重みづけには対応していない clip skipは対応した @@ -2103,8 +2103,6 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v if args.clip_skip is None: text_encoder_or_wrapper = text_encoder else: - print("create wrapper") - class Wrapper(): def __init__(self, tenc) -> None: self.tenc = tenc @@ -2116,7 +2114,7 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip] encoder_hidden_states = self.tenc.text_model.final_layer_norm(encoder_hidden_states) pooled_output = enc_out['pooler_output'] - return encoder_hidden_states, pooled_output # 1st output is only used + return encoder_hidden_states, pooled_output # 1st output is only used text_encoder_or_wrapper = Wrapper(text_encoder) @@ -2229,12 +2227,17 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v torch.manual_seed(seed) torch.cuda.manual_seed(seed) + if prompt_replacement is not None: + prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) + if negative_prompt is not None: + negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) + image = pipeline(prompt, height, width, sample_steps, scale, negative_prompt).images[0] ts_str = time.strftime('%Y%m%d%H%M%S', time.localtime()) num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" seed_suffix = "" if seed is None else f"_{seed}" - img_filename = f"{'' if args.output_name is None else args.output_name}_{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" + img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png" image.save(os.path.join(save_dir, img_filename))