mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 21:52:27 +00:00
support sample generation in TI training
This commit is contained in:
@@ -2074,7 +2074,7 @@ SCHEDULER_TIMESTEPS = 1000
|
||||
SCHEDLER_SCHEDULE = 'scaled_linear'
|
||||
|
||||
|
||||
def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet):
|
||||
def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None):
|
||||
"""
|
||||
生成に使っている Diffusers の Pipeline がデフォルトなので、プロンプトの重みづけには対応していない
|
||||
clip skipは対応した
|
||||
@@ -2103,8 +2103,6 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v
|
||||
if args.clip_skip is None:
|
||||
text_encoder_or_wrapper = text_encoder
|
||||
else:
|
||||
print("create wrapper")
|
||||
|
||||
class Wrapper():
|
||||
def __init__(self, tenc) -> None:
|
||||
self.tenc = tenc
|
||||
@@ -2116,7 +2114,7 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v
|
||||
encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
|
||||
encoder_hidden_states = self.tenc.text_model.final_layer_norm(encoder_hidden_states)
|
||||
pooled_output = enc_out['pooler_output']
|
||||
return encoder_hidden_states, pooled_output # 1st output is only used
|
||||
return encoder_hidden_states, pooled_output # 1st output is only used
|
||||
|
||||
text_encoder_or_wrapper = Wrapper(text_encoder)
|
||||
|
||||
@@ -2229,12 +2227,17 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
if prompt_replacement is not None:
|
||||
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
||||
if negative_prompt is not None:
|
||||
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
||||
|
||||
image = pipeline(prompt, height, width, sample_steps, scale, negative_prompt).images[0]
|
||||
|
||||
ts_str = time.strftime('%Y%m%d%H%M%S', time.localtime())
|
||||
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
|
||||
seed_suffix = "" if seed is None else f"_{seed}"
|
||||
img_filename = f"{'' if args.output_name is None else args.output_name}_{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
|
||||
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png"
|
||||
|
||||
image.save(os.path.join(save_dir, img_filename))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user