mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
sample images in training (not fully tested)
This commit is contained in:
@@ -282,6 +282,8 @@ def train(args):
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
|
||||
@@ -309,6 +311,8 @@ def train(args):
|
||||
train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
|
||||
save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
|
||||
|
||||
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
is_main_process = accelerator.is_main_process
|
||||
if is_main_process:
|
||||
unet = unwrap_model(unet)
|
||||
|
||||
@@ -3,12 +3,12 @@
|
||||
import argparse
|
||||
import importlib
|
||||
import json
|
||||
import re
|
||||
import shutil
|
||||
import time
|
||||
from typing import Dict, List, NamedTuple, Tuple
|
||||
from typing import Optional, Union
|
||||
from accelerate import Accelerator
|
||||
from torch.autograd.function import Function
|
||||
import glob
|
||||
import math
|
||||
import os
|
||||
@@ -25,7 +25,10 @@ from transformers import CLIPTokenizer
|
||||
import transformers
|
||||
import diffusers
|
||||
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
|
||||
from diffusers import DDPMScheduler, StableDiffusionPipeline
|
||||
from diffusers import (StableDiffusionPipeline, DDPMScheduler,
|
||||
EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler,
|
||||
LMSDiscreteScheduler, PNDMScheduler, DDIMScheduler, EulerDiscreteScheduler, HeunDiscreteScheduler,
|
||||
KDPM2DiscreteScheduler, KDPM2AncestralDiscreteScheduler)
|
||||
import albumentations as albu
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
@@ -1453,6 +1456,18 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
parser.add_argument("--lowram", action="store_true",
|
||||
help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)")
|
||||
|
||||
parser.add_argument("--sample_every_n_steps", type=int, default=None,
|
||||
help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する")
|
||||
parser.add_argument("--sample_every_n_epochs", type=int, default=None,
|
||||
help="generate sample images every N epochs (overwrites n_steps) / 学習中のモデルで指定エポックごとにサンプル出力する(ステップ数指定を上書きします)")
|
||||
parser.add_argument("--sample_prompts", type=str, default=None,
|
||||
help="file for prompts to generate sample images / 学習中モデルのサンプル出力用プロンプトのファイル")
|
||||
parser.add_argument('--sample_sampler', type=str, default='ddim',
|
||||
choices=['ddim', 'pndm', 'lms', 'euler', 'euler_a', 'heun', 'dpm_2', 'dpm_2_a', 'dpmsolver',
|
||||
'dpmsolver++', 'dpmsingle',
|
||||
'k_lms', 'k_euler', 'k_euler_a', 'k_dpm_2', 'k_dpm_2_a'],
|
||||
help=f'sampler (scheduler) type for sample images / サンプル出力時のサンプラー(スケジューラ)の種類')
|
||||
|
||||
if support_dreambooth:
|
||||
# DreamBooth training
|
||||
parser.add_argument("--prior_loss_weight", type=float, default=1.0,
|
||||
@@ -2051,6 +2066,182 @@ def save_state_on_train_end(args: argparse.Namespace, accelerator):
|
||||
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
|
||||
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
|
||||
|
||||
|
||||
# Noise-scheduler settings used when building the scheduler for sample image
# generation (see sample_images). They are passed to the diffusers scheduler
# constructor as beta_start / beta_end / num_train_timesteps / beta_schedule.
SCHEDULER_LINEAR_START = 0.00085
SCHEDULER_LINEAR_END = 0.0120
SCHEDULER_TIMESTEPS = 1000
SCHEDULER_SCHEDULE = 'scaled_linear'
# Backward-compatible alias: the constant was originally introduced under this
# misspelled name, so keep it available for any existing references.
SCHEDLER_SCHEDULE = SCHEDULER_SCHEDULE
|
||||
|
||||
|
||||
class _ClipSkipTextEncoder:
    """Callable stand-in for the text encoder that applies clip skip.

    It returns the hidden states of the ``clip_skip``-th layer from the end
    (passed through the encoder's final layer norm) instead of the last layer.
    """

    def __init__(self, tenc, clip_skip) -> None:
        self.tenc = tenc
        self.clip_skip = clip_skip
        # The pipeline looks attributes up on ``text_encoder.config``; an
        # empty dict is enough for that lookup to succeed.
        self.config = {}

    def __call__(self, input_ids, attention_mask=None):
        # attention_mask is accepted for call compatibility but is not used.
        enc_out = self.tenc(input_ids, output_hidden_states=True, return_dict=True)
        encoder_hidden_states = enc_out['hidden_states'][-self.clip_skip]
        encoder_hidden_states = self.tenc.text_model.final_layer_norm(encoder_hidden_states)
        pooled_output = enc_out['pooler_output']
        return encoder_hidden_states, pooled_output  # only the 1st output is used


def _build_sample_scheduler(sampler_name, v_parameterization):
    """Create the diffusers scheduler selected by ``--sample_sampler``."""
    scheduler_classes = {
        'ddim': DDIMScheduler,
        'ddpm': DDPMScheduler,  # gives broken results, so it is not offered as a CLI choice
        'pndm': PNDMScheduler,
        'lms': LMSDiscreteScheduler,
        'k_lms': LMSDiscreteScheduler,
        'euler': EulerDiscreteScheduler,
        'k_euler': EulerDiscreteScheduler,
        'euler_a': EulerAncestralDiscreteScheduler,
        'k_euler_a': EulerAncestralDiscreteScheduler,
        'dpmsolver': DPMSolverMultistepScheduler,
        'dpmsolver++': DPMSolverMultistepScheduler,
        'dpmsingle': DPMSolverSinglestepScheduler,
        'heun': HeunDiscreteScheduler,
        'dpm_2': KDPM2DiscreteScheduler,
        'k_dpm_2': KDPM2DiscreteScheduler,
        'dpm_2_a': KDPM2AncestralDiscreteScheduler,
        'k_dpm_2_a': KDPM2AncestralDiscreteScheduler,
    }
    # Unknown names fall back to DDIM.
    scheduler_cls = scheduler_classes.get(sampler_name, DDIMScheduler)

    sched_init_args = {}
    if sampler_name in ('dpmsolver', 'dpmsolver++'):
        # The multistep solver takes the sampler name as its algorithm type.
        sched_init_args['algorithm_type'] = sampler_name
    if v_parameterization:
        sched_init_args['prediction_type'] = 'v_prediction'

    scheduler = scheduler_cls(num_train_timesteps=SCHEDULER_TIMESTEPS,
                              beta_start=SCHEDULER_LINEAR_START, beta_end=SCHEDULER_LINEAR_END,
                              beta_schedule=SCHEDLER_SCHEDULE, **sched_init_args)

    # Force clip_sample=True on schedulers that expose the setting.
    if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
        scheduler.config.clip_sample = True
    return scheduler


def _parse_sample_prompt(line):
    """Split one prompt-file line into the prompt and its ``--x value`` options.

    Recognised options (a subset of gen_img_diffusers): ``--w`` width,
    ``--h`` height, ``--d`` seed, ``--s`` steps, ``--l`` guidance scale,
    ``--n`` negative prompt. Unknown options are ignored.

    Returns a dict with keys ``prompt``, ``negative_prompt``, ``sample_steps``,
    ``width``, ``height``, ``scale`` and ``seed``.
    """
    segments = line.split(' --')
    parsed = {
        'prompt': segments[0],
        'negative_prompt': None,
        'sample_steps': 30,
        'width': 512,
        'height': 512,
        'scale': 7.5,
        'seed': None,
    }

    # Skip segments[0]: it is the prompt text itself, and scanning it would
    # misread a prompt that happens to start with e.g. "n " as an option.
    for parg in segments[1:]:
        try:
            m = re.match(r'w (\d+)', parg, re.IGNORECASE)
            if m:
                parsed['width'] = int(m.group(1))
                continue

            m = re.match(r'h (\d+)', parg, re.IGNORECASE)
            if m:
                parsed['height'] = int(m.group(1))
                continue

            m = re.match(r'd (\d+)', parg, re.IGNORECASE)
            if m:
                parsed['seed'] = int(m.group(1))
                continue

            m = re.match(r's (\d+)', parg, re.IGNORECASE)
            if m:  # steps
                parsed['sample_steps'] = max(1, min(1000, int(m.group(1))))
                continue

            m = re.match(r'l ([\d\.]+)', parg, re.IGNORECASE)
            if m:  # scale
                parsed['scale'] = float(m.group(1))
                continue

            m = re.match(r'n (.+)', parg, re.IGNORECASE)
            if m:  # negative prompt
                parsed['negative_prompt'] = m.group(1)
                continue

        except ValueError as ex:
            print(f"Exception in parsing / 解析エラー: {parg}")
            print(ex)

    # The pipeline rejects sizes that are not divisible by 8, so round down
    # (with a small floor) instead of aborting in the middle of training.
    parsed['height'] = max(64, parsed['height'] - parsed['height'] % 8)
    parsed['width'] = max(64, parsed['width'] - parsed['width'] % 8)
    return parsed


def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet):
    """
    Generate sample images with the model being trained and save them as PNGs.

    The stock Diffusers StableDiffusionPipeline is used for generation, so
    prompt weighting is not supported. Clip skip is supported via a wrapper.

    Sampling runs only when it is due:
      * if ``args.sample_every_n_epochs`` is set, ``sample_every_n_steps`` is
        ignored and sampling happens on calls where ``epoch`` is a multiple of it;
      * otherwise sampling happens on per-step calls (``epoch is None``) where
        ``steps`` is a multiple of ``args.sample_every_n_steps``. Epoch-end
        calls are skipped in this mode so the same step is not sampled twice.

    Args:
        accelerator: object whose ``autocast()`` context is used for generation.
        args: parsed arguments; reads ``sample_every_n_steps``,
            ``sample_every_n_epochs``, ``sample_prompts``, ``sample_sampler``,
            ``clip_skip``, ``v_parameterization``, ``output_dir``, ``output_name``.
        epoch: 1-based epoch number for epoch-end calls, ``None`` for step calls.
        steps: current global step.
        device: device the VAE and pipeline are moved to for generation.
        vae, tokenizer, text_encoder, unet: model components for the pipeline.

    Images are written to ``<output_dir>/sample``. The torch RNG state and the
    VAE's original device are restored afterwards, even if generation fails.
    """
    if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
        return
    if args.sample_every_n_epochs is not None:
        # sample_every_n_steps is ignored in this mode
        if epoch is None or epoch % args.sample_every_n_epochs != 0:
            return
    else:
        # not a multiple of N, or an epoch-end call (already handled per step)
        if epoch is not None or steps % args.sample_every_n_steps != 0:
            return

    print(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
    # sample_prompts defaults to None; os.path.isfile(None) would raise TypeError.
    if args.sample_prompts is None or not os.path.isfile(args.sample_prompts):
        print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
        return

    # NOTE(review): it may be worth clearing the CUDA cache here.

    org_vae_device = vae.device  # expected to be the CPU at this point
    vae.to(device)
    try:
        # build a wrapper to support clip skip
        if args.clip_skip is None:
            text_encoder_or_wrapper = text_encoder
        else:
            print("create wrapper")
            text_encoder_or_wrapper = _ClipSkipTextEncoder(text_encoder, args.clip_skip)

        # read prompts
        with open(args.sample_prompts, 'rt', encoding='utf-8') as f:
            prompts = f.readlines()

        scheduler = _build_sample_scheduler(args.sample_sampler, args.v_parameterization)

        pipeline = StableDiffusionPipeline(text_encoder=text_encoder_or_wrapper, vae=vae, unet=unet, tokenizer=tokenizer,
                                           scheduler=scheduler, safety_checker=None, feature_extractor=None,
                                           requires_safety_checker=False)
        pipeline.to(device)

        save_dir = os.path.join(args.output_dir, "sample")
        os.makedirs(save_dir, exist_ok=True)

        # Save RNG state so that seeding for samples does not disturb training.
        rng_state = torch.get_rng_state()
        cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
        try:
            with torch.no_grad(), accelerator.autocast():
                for i, line in enumerate(prompts):
                    line = line.strip()
                    if len(line) == 0 or line[0] == '#':
                        continue  # blank line or comment

                    parsed = _parse_sample_prompt(line)
                    seed = parsed['seed']
                    if seed is not None:
                        torch.manual_seed(seed)
                        torch.cuda.manual_seed(seed)

                    image = pipeline(prompt=parsed['prompt'], height=parsed['height'], width=parsed['width'],
                                     num_inference_steps=parsed['sample_steps'], guidance_scale=parsed['scale'],
                                     negative_prompt=parsed['negative_prompt']).images[0]

                    ts_str = time.strftime('%Y%m%d%H%M%S', time.localtime())
                    num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
                    seed_suffix = "" if seed is None else f"_{seed}"
                    img_filename = f"{'' if args.output_name is None else args.output_name}_{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"

                    image.save(os.path.join(save_dir, img_filename))
        finally:
            torch.set_rng_state(rng_state)
            if cuda_rng_state is not None:
                torch.cuda.set_rng_state(cuda_rng_state)
    finally:
        vae.to(org_vae_device)
|
||||
|
||||
# endregion
|
||||
|
||||
# region 前処理用
|
||||
|
||||
@@ -278,6 +278,8 @@ def train(args):
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
|
||||
@@ -309,6 +311,8 @@ def train(args):
|
||||
train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
|
||||
save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
|
||||
|
||||
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
is_main_process = accelerator.is_main_process
|
||||
if is_main_process:
|
||||
unet = unwrap_model(unet)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from torch.cuda.amp import autocast
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
import importlib
|
||||
import argparse
|
||||
@@ -12,7 +11,6 @@ import json
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
from accelerate.utils import set_seed
|
||||
import diffusers
|
||||
from diffusers import DDPMScheduler
|
||||
|
||||
import library.train_util as train_util
|
||||
@@ -400,6 +398,8 @@ def train(args):
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
if epoch == 0:
|
||||
loss_list.append(current_loss)
|
||||
@@ -445,6 +445,8 @@ def train(args):
|
||||
if saving and args.save_state:
|
||||
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
|
||||
|
||||
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
# end of epoch
|
||||
|
||||
metadata["ss_epoch"] = str(num_train_epochs)
|
||||
|
||||
@@ -354,6 +354,8 @@ def train(args):
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
|
||||
@@ -376,8 +378,6 @@ def train(args):
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
|
||||
# d = updated_embs - bef_epo_embs
|
||||
# print(bef_epo_embs.size(), updated_embs.size(), d.mean(), d.min())
|
||||
|
||||
if args.save_every_n_epochs is not None:
|
||||
model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
|
||||
@@ -399,6 +399,8 @@ def train(args):
|
||||
if saving and args.save_state:
|
||||
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
|
||||
|
||||
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
# end of epoch
|
||||
|
||||
is_main_process = accelerator.is_main_process
|
||||
|
||||
Reference in New Issue
Block a user