sample images in training (not fully tested)

Kohya S
2023-02-27 17:48:32 +09:00
parent a28f9ae7a3
commit dd523c94ff
5 changed files with 209 additions and 6 deletions

View File

@@ -282,6 +282,8 @@ def train(args):
progress_bar.update(1)
global_step += 1
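# epoch is None for this per-step call, so sample_images uses --sample_every_n_steps; per-epoch sampling is handled by the call at epoch end, which passes epoch + 1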
train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
current_loss = loss.detach().item()  # this is an average, so batch size should not matter
if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
@@ -309,6 +311,8 @@ def train(args):
train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
is_main_process = accelerator.is_main_process
if is_main_process:
unet = unwrap_model(unet)

View File

@@ -3,12 +3,12 @@
import argparse
import importlib
import json
import re
import shutil
import time
from typing import Dict, List, NamedTuple, Tuple
from typing import Optional, Union
from accelerate import Accelerator
from torch.autograd.function import Function
import glob
import math
import os
@@ -25,7 +25,10 @@ from transformers import CLIPTokenizer
import transformers
import diffusers
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
from diffusers import DDPMScheduler, StableDiffusionPipeline
from diffusers import (StableDiffusionPipeline, DDPMScheduler,
EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler,
LMSDiscreteScheduler, PNDMScheduler, DDIMScheduler, EulerDiscreteScheduler, HeunDiscreteScheduler,
KDPM2DiscreteScheduler, KDPM2AncestralDiscreteScheduler)
import albumentations as albu
import numpy as np
from PIL import Image
@@ -1453,6 +1456,18 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
parser.add_argument("--lowram", action="store_true",
help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなどColabやKaggleなどRAMに比べてVRAMが多い環境向け")
parser.add_argument("--sample_every_n_steps", type=int, default=None,
help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する")
parser.add_argument("--sample_every_n_epochs", type=int, default=None,
help="generate sample images every N epochs (overwrites n_steps) / 学習中のモデルで指定エポックごとにサンプル出力する(ステップ数指定を上書きします)")
parser.add_argument("--sample_prompts", type=str, default=None,
help="file for prompts to generate sample images / 学習中モデルのサンプル出力用プロンプトのファイル")
parser.add_argument('--sample_sampler', type=str, default='ddim',
choices=['ddim', 'pndm', 'lms', 'euler', 'euler_a', 'heun', 'dpm_2', 'dpm_2_a', 'dpmsolver',
'dpmsolver++', 'dpmsingle',
'k_lms', 'k_euler', 'k_euler_a', 'k_dpm_2', 'k_dpm_2_a'],
help=f'sampler (scheduler) type for sample images / サンプル出力時のサンプラー(スケジューラ)の種類')
if support_dreambooth:
# DreamBooth training
parser.add_argument("--prior_loss_weight", type=float, default=1.0,
@@ -2051,6 +2066,182 @@ def save_state_on_train_end(args: argparse.Namespace, accelerator):
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
# scheduler:
SCHEDULER_LINEAR_START = 0.00085
SCHEDULER_LINEAR_END = 0.0120
SCHEDULER_TIMESTEPS = 1000
SCHEDULER_SCHEDULE = 'scaled_linear'
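# these constants match the noise schedule of standard Stable Diffusion checkpoints: 1000 timesteps with scaled_linear betas from 0.00085 to 0.012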
def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet):
"""
The Diffusers pipeline used for generation is the default one, so prompt weighting is not supported.
clip skip is supported.
"""
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
return
if args.sample_every_n_epochs is not None:
# sample_every_n_steps is ignored in this case
if epoch is None or epoch % args.sample_every_n_epochs != 0:
return
else:
if steps % args.sample_every_n_steps != 0:
return
print(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
if not os.path.isfile(args.sample_prompts):
print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
return
# maybe the CUDA cache should be cleared here...
org_vae_device = vae.device  # it should be on the CPU
vae.to(device)
# create a wrapper to support clip skip
if args.clip_skip is None:
text_encoder_or_wrapper = text_encoder
else:
print("create wrapper")
class Wrapper():
def __init__(self, tenc) -> None:
self.tenc = tenc
self.config = {}
super().__init__()
def __call__(self, input_ids, attention_mask):
enc_out = self.tenc(input_ids, output_hidden_states=True, return_dict=True)
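# clip skip: take hidden states from the clip_skip-th layer counted from the end and re-apply the final layer norm, instead of using the last layer's output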
encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
encoder_hidden_states = self.tenc.text_model.final_layer_norm(encoder_hidden_states)
pooled_output = enc_out['pooler_output']
return encoder_hidden_states, pooled_output  # only the first output is used
text_encoder_or_wrapper = Wrapper(text_encoder)
# read prompts
with open(args.sample_prompts, 'rt', encoding='utf-8') as f:
prompts = f.readlines()
# prepare the scheduler
sched_init_args = {}
if args.sample_sampler == "ddim":
scheduler_cls = DDIMScheduler
elif args.sample_sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある
scheduler_cls = DDPMScheduler
elif args.sample_sampler == "pndm":
scheduler_cls = PNDMScheduler
elif args.sample_sampler == 'lms' or args.sample_sampler == 'k_lms':
scheduler_cls = LMSDiscreteScheduler
elif args.sample_sampler == 'euler' or args.sample_sampler == 'k_euler':
scheduler_cls = EulerDiscreteScheduler
elif args.sample_sampler == 'euler_a' or args.sample_sampler == 'k_euler_a':
scheduler_cls = EulerAncestralDiscreteScheduler
elif args.sample_sampler == "dpmsolver" or args.sample_sampler == "dpmsolver++":
scheduler_cls = DPMSolverMultistepScheduler
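# diffusers' DPMSolverMultistepScheduler selects between "dpmsolver" and "dpmsolver++" via its algorithm_type argument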
sched_init_args['algorithm_type'] = args.sample_sampler
elif args.sample_sampler == "dpmsingle":
scheduler_cls = DPMSolverSinglestepScheduler
elif args.sample_sampler == "heun":
scheduler_cls = HeunDiscreteScheduler
elif args.sample_sampler == 'dpm_2' or args.sample_sampler == 'k_dpm_2':
scheduler_cls = KDPM2DiscreteScheduler
elif args.sample_sampler == 'dpm_2_a' or args.sample_sampler == 'k_dpm_2_a':
scheduler_cls = KDPM2AncestralDiscreteScheduler
else:
scheduler_cls = DDIMScheduler
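# models trained with v-parameterization (e.g. Stable Diffusion 2.x 768px) require prediction_type='v_prediction' in the scheduler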
if args.v_parameterization:
sched_init_args['prediction_type'] = 'v_prediction'
scheduler = scheduler_cls(num_train_timesteps=SCHEDULER_TIMESTEPS,
beta_start=SCHEDULER_LINEAR_START, beta_end=SCHEDULER_LINEAR_END,
beta_schedule=SCHEDULER_SCHEDULE, **sched_init_args)
# force clip_sample=True
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
# print("set clip_sample to True")
scheduler.config.clip_sample = True
pipeline = StableDiffusionPipeline(text_encoder=text_encoder_or_wrapper, vae=vae, unet=unet, tokenizer=tokenizer,
scheduler=scheduler, safety_checker=None, feature_extractor=None, requires_safety_checker=False)
pipeline.to(device)
save_dir = args.output_dir + "/sample"
os.makedirs(save_dir, exist_ok=True)
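# save the RNG state so that sample generation does not perturb the training run's random sequence; it is restored after sampling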
rng_state = torch.get_rng_state()
cuda_rng_state = torch.cuda.get_rng_state()
with torch.no_grad():
with accelerator.autocast():
for i, prompt in enumerate(prompts):
prompt = prompt.strip()
if len(prompt) == 0 or prompt[0] == '#':
continue
# subset of gen_img_diffusers
prompt_args = prompt.split(' --')
prompt = prompt_args[0]
negative_prompt = None
sample_steps = 30
width = height = 512
scale = 7.5
seed = None
for parg in prompt_args:
try:
m = re.match(r'w (\d+)', parg, re.IGNORECASE)
if m:
width = int(m.group(1))
continue
m = re.match(r'h (\d+)', parg, re.IGNORECASE)
if m:
height = int(m.group(1))
continue
m = re.match(r'd (\d+)', parg, re.IGNORECASE)
if m:
seed = int(m.group(1))
continue
m = re.match(r's (\d+)', parg, re.IGNORECASE)
if m: # steps
sample_steps = max(1, min(1000, int(m.group(1))))
continue
m = re.match(r'l ([\d\.]+)', parg, re.IGNORECASE)
if m: # scale
scale = float(m.group(1))
continue
m = re.match(r'n (.+)', parg, re.IGNORECASE)
if m: # negative prompt
negative_prompt = m.group(1)
continue
except ValueError as ex:
print(f"Exception in parsing / 解析エラー: {parg}")
print(ex)
if seed is not None:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
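# positional arguments follow StableDiffusionPipeline.__call__: prompt, height, width, num_inference_steps, guidance_scale, negative_prompt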
image = pipeline(prompt, height, width, sample_steps, scale, negative_prompt).images[0]
ts_str = time.strftime('%Y%m%d%H%M%S', time.localtime())
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
seed_suffix = "" if seed is None else f"_{seed}"
img_filename = f"{'' if args.output_name is None else args.output_name}_{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
image.save(os.path.join(save_dir, img_filename))
torch.set_rng_state(rng_state)
torch.cuda.set_rng_state(cuda_rng_state)
vae.to(org_vae_device)
# endregion
# region preprocessing
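For reference, a minimal sketch of a prompt file accepted by the parser above (file name, prompts and values are illustrative): each non-empty line is one prompt, lines beginning with # are skipped, and options are appended with a leading " --". Supported options are --w (width), --h (height), --d (seed), --s (steps, clamped to 1-1000), --l (CFG scale) and --n (negative prompt).

a portrait photo of a cat, best quality --w 512 --h 512 --s 28 --l 7.5 --d 42 --n lowres, bad anatomy
# lines starting with # are ignored
a girl standing in a field of flowers --d 1

Pointing --sample_prompts at such a file and setting --sample_every_n_steps or --sample_every_n_epochs (optionally with --sample_sampler) enables the periodic sample output added in this commit.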

View File

@@ -278,6 +278,8 @@ def train(args):
progress_bar.update(1)
global_step += 1
train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
current_loss = loss.detach().item()
if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
@@ -309,6 +311,8 @@ def train(args):
train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
is_main_process = accelerator.is_main_process
if is_main_process:
unet = unwrap_model(unet)

View File

@@ -1,4 +1,3 @@
from torch.cuda.amp import autocast
from torch.nn.parallel import DistributedDataParallel as DDP
import importlib
import argparse
@@ -12,7 +11,6 @@ import json
from tqdm import tqdm
import torch
from accelerate.utils import set_seed
import diffusers
from diffusers import DDPMScheduler
import library.train_util as train_util
@@ -400,6 +398,8 @@ def train(args):
progress_bar.update(1)
global_step += 1
train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
current_loss = loss.detach().item()
if epoch == 0:
loss_list.append(current_loss)
@@ -445,6 +445,8 @@ def train(args):
if saving and args.save_state:
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
# end of epoch
metadata["ss_epoch"] = str(num_train_epochs)

View File

@@ -354,6 +354,8 @@ def train(args):
progress_bar.update(1)
global_step += 1
train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
current_loss = loss.detach().item()
if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
@@ -376,8 +378,6 @@ def train(args):
accelerator.wait_for_everyone()
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
# d = updated_embs - bef_epo_embs
# print(bef_epo_embs.size(), updated_embs.size(), d.mean(), d.min())
if args.save_every_n_epochs is not None:
model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
@@ -399,6 +399,8 @@ def train(args):
if saving and args.save_state:
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
# end of epoch
is_main_process = accelerator.is_main_process