support sample generation in TI training

This commit is contained in:
Kohya S
2023-02-28 22:05:31 +09:00
parent 57c565c402
commit 82707654ad

View File

@@ -2074,7 +2074,7 @@ SCHEDULER_TIMESTEPS = 1000
SCHEDLER_SCHEDULE = 'scaled_linear'
def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet):
def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None):
"""
生成に使っている Diffusers の Pipeline がデフォルトなので、プロンプトの重みづけには対応していない
clip skipは対応した
@@ -2103,8 +2103,6 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v
if args.clip_skip is None:
text_encoder_or_wrapper = text_encoder
else:
print("create wrapper")
class Wrapper():
def __init__(self, tenc) -> None:
self.tenc = tenc
@@ -2116,7 +2114,7 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v
encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
encoder_hidden_states = self.tenc.text_model.final_layer_norm(encoder_hidden_states)
pooled_output = enc_out['pooler_output']
return encoder_hidden_states, pooled_output # 1st output is only used
return encoder_hidden_states, pooled_output # 1st output is only used
text_encoder_or_wrapper = Wrapper(text_encoder)
@@ -2229,12 +2227,17 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
if prompt_replacement is not None:
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
if negative_prompt is not None:
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
image = pipeline(prompt, height, width, sample_steps, scale, negative_prompt).images[0]
ts_str = time.strftime('%Y%m%d%H%M%S', time.localtime())
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
seed_suffix = "" if seed is None else f"_{seed}"
img_filename = f"{'' if args.output_name is None else args.output_name}_{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png"
image.save(os.path.join(save_dir, img_filename))