From dd523c94ff400e94a294acaa8c7587c513020dfa Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Mon, 27 Feb 2023 17:48:32 +0900
Subject: [PATCH] sample images in training (not fully tested)

---
 fine_tune.py               |   4 +
 library/train_util.py      | 195 ++++++++++++++++++++++++++++++++++++-
 train_db.py                |   4 +
 train_network.py           |   6 +-
 train_textual_inversion.py |   6 +-
 5 files changed, 209 insertions(+), 6 deletions(-)

diff --git a/fine_tune.py b/fine_tune.py
index 80290e72..a9db2c4b 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -282,6 +282,8 @@ def train(args):
                 progress_bar.update(1)
                 global_step += 1
 
+            train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+
             current_loss = loss.detach().item()  # this is an average, so the batch size should not matter
             if args.logging_dir is not None:
                 logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
@@ -309,6 +311,8 @@ def train(args):
         train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
                                               save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
 
+        train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+
     is_main_process = accelerator.is_main_process
     if is_main_process:
         unet = unwrap_model(unet)
diff --git a/library/train_util.py b/library/train_util.py
index 9f13baf2..848182fb 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -3,12 +3,12 @@
 import argparse
 import importlib
 import json
+import re
 import shutil
 import time
 from typing import Dict, List, NamedTuple, Tuple
 from typing import Optional, Union
 from accelerate import Accelerator
-from torch.autograd.function import Function
 import glob
 import math
 import os
@@ -25,7 +25,10 @@ from transformers import CLIPTokenizer
 import transformers
 import diffusers
 from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
-from diffusers import DDPMScheduler, StableDiffusionPipeline
+from diffusers import (StableDiffusionPipeline, DDPMScheduler,
+                       EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler,
+                       LMSDiscreteScheduler, PNDMScheduler, DDIMScheduler, EulerDiscreteScheduler, HeunDiscreteScheduler,
+                       KDPM2DiscreteScheduler, KDPM2AncestralDiscreteScheduler)
 import albumentations as albu
 import numpy as np
 from PIL import Image
@@ -1453,6 +1456,18 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
     parser.add_argument("--lowram", action="store_true",
                         help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)")
 
+    parser.add_argument("--sample_every_n_steps", type=int, default=None,
+                        help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する")
+    parser.add_argument("--sample_every_n_epochs", type=int, default=None,
+                        help="generate sample images every N epochs (overwrites n_steps) / 学習中のモデルで指定エポックごとにサンプル出力する(ステップ数指定を上書きします)")
+    parser.add_argument("--sample_prompts", type=str, default=None,
+                        help="file for prompts to generate sample images / 学習中モデルのサンプル出力用プロンプトのファイル")
+    parser.add_argument('--sample_sampler', type=str, default='ddim',
+                        choices=['ddim', 'pndm', 'lms', 'euler', 'euler_a', 'heun', 'dpm_2', 'dpm_2_a', 'dpmsolver',
+                                 'dpmsolver++', 'dpmsingle',
+                                 'k_lms', 'k_euler', 'k_euler_a', 'k_dpm_2', 'k_dpm_2_a'],
+                        help='sampler (scheduler) type for sample images / サンプル出力時のサンプラー(スケジューラ)の種類')
+
     if support_dreambooth:
         # DreamBooth training
         parser.add_argument("--prior_loss_weight", type=float, default=1.0,
@@ -2051,6 +2066,182 @@ def save_state_on_train_end(args: argparse.Namespace, accelerator):
     model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
     accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
 
+
+# scheduler:
+SCHEDULER_LINEAR_START = 0.00085
+SCHEDULER_LINEAR_END = 0.0120
+SCHEDULER_TIMESTEPS = 1000
+SCHEDULER_SCHEDULE = 'scaled_linear'
+
+
+def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet):
+    """
+    generation uses the default Diffusers pipeline, so prompt weighting is not supported
+    clip skip is supported
+    """
+    if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
+        return
+    if args.sample_every_n_epochs is not None:
+        # ignore sample_every_n_steps
+        if epoch is None or epoch % args.sample_every_n_epochs != 0:
+            return
+    else:
+        if steps % args.sample_every_n_steps != 0:
+            return
+
+    print(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
+    if not os.path.isfile(args.sample_prompts):
+        print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
+        return
+
+    # maybe the CUDA cache should be cleared here...
+
+    org_vae_device = vae.device  # it should be on the CPU
+    vae.to(device)
+
+    # create a wrapper for clip skip support
+    if args.clip_skip is None:
+        text_encoder_or_wrapper = text_encoder
+    else:
+        print("create wrapper")
+
+        class Wrapper():
+            def __init__(self, tenc) -> None:
+                self.tenc = tenc
+                self.config = {}
+                super().__init__()
+
+            def __call__(self, input_ids, attention_mask):
+                enc_out = self.tenc(input_ids, output_hidden_states=True, return_dict=True)
+                encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
+                encoder_hidden_states = self.tenc.text_model.final_layer_norm(encoder_hidden_states)
+                pooled_output = enc_out['pooler_output']
+                return encoder_hidden_states, pooled_output  # only the 1st output is used
+
+        text_encoder_or_wrapper = Wrapper(text_encoder)
+
+    # read prompts
+    with open(args.sample_prompts, 'rt', encoding='utf-8') as f:
+        prompts = f.readlines()
+
+    # prepare the scheduler
+    sched_init_args = {}
+    if args.sample_sampler == "ddim":
+        scheduler_cls = DDIMScheduler
+    elif args.sample_sampler == "ddpm":  # ddpm produces broken images, so it is excluded from the choices
+        scheduler_cls = DDPMScheduler
+    elif args.sample_sampler == "pndm":
+        scheduler_cls = PNDMScheduler
+    elif args.sample_sampler == 'lms' or args.sample_sampler == 'k_lms':
+        scheduler_cls = LMSDiscreteScheduler
+    elif args.sample_sampler == 'euler' or args.sample_sampler == 'k_euler':
+        scheduler_cls = EulerDiscreteScheduler
+    elif args.sample_sampler == 'euler_a' or args.sample_sampler == 'k_euler_a':
+        scheduler_cls = EulerAncestralDiscreteScheduler
+    elif args.sample_sampler == "dpmsolver" or args.sample_sampler == "dpmsolver++":
+        scheduler_cls = DPMSolverMultistepScheduler
+        sched_init_args['algorithm_type'] = args.sample_sampler
+    elif args.sample_sampler == "dpmsingle":
+        scheduler_cls = DPMSolverSinglestepScheduler
+    elif args.sample_sampler == "heun":
+        scheduler_cls = HeunDiscreteScheduler
+    elif args.sample_sampler == 'dpm_2' or args.sample_sampler == 'k_dpm_2':
+        scheduler_cls = KDPM2DiscreteScheduler
+    elif args.sample_sampler == 'dpm_2_a' or args.sample_sampler == 'k_dpm_2_a':
+        scheduler_cls = KDPM2AncestralDiscreteScheduler
+    else:
+        scheduler_cls = DDIMScheduler
+
+    if args.v_parameterization:
+        sched_init_args['prediction_type'] = 'v_prediction'
+
+    scheduler = scheduler_cls(num_train_timesteps=SCHEDULER_TIMESTEPS,
+                              beta_start=SCHEDULER_LINEAR_START, beta_end=SCHEDULER_LINEAR_END,
+                              beta_schedule=SCHEDULER_SCHEDULE, **sched_init_args)
+
+    # force clip_sample=True
+    if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
+        # print("set clip_sample to True")
+        scheduler.config.clip_sample = True
+
+    pipeline = StableDiffusionPipeline(text_encoder=text_encoder_or_wrapper, vae=vae, unet=unet, tokenizer=tokenizer,
+                                       scheduler=scheduler, safety_checker=None, feature_extractor=None, requires_safety_checker=False)
+    pipeline.to(device)
+
+    save_dir = args.output_dir + "/sample"
+    os.makedirs(save_dir, exist_ok=True)
+
+    rng_state = torch.get_rng_state()
+    cuda_rng_state = torch.cuda.get_rng_state()
+
+    with torch.no_grad():
+        with accelerator.autocast():
+            for i, prompt in enumerate(prompts):
+                prompt = prompt.strip()
+                if len(prompt) == 0 or prompt[0] == '#':
+                    continue
+
+                # subset of gen_img_diffusers
+                prompt_args = prompt.split(' --')
+                prompt = prompt_args[0]
+                negative_prompt = None
+                sample_steps = 30
+                width = height = 512
+                scale = 7.5
+                seed = None
+                for parg in prompt_args[1:]:  # prompt_args[0] is the prompt itself, so parse only the options
+                    try:
+                        m = re.match(r'w (\d+)', parg, re.IGNORECASE)
+                        if m:
+                            width = int(m.group(1))
+                            continue
+
+                        m = re.match(r'h (\d+)', parg, re.IGNORECASE)
+                        if m:
+                            height = int(m.group(1))
+                            continue
+
+                        m = re.match(r'd (\d+)', parg, re.IGNORECASE)
+                        if m:
+                            seed = int(m.group(1))
+                            continue
+
+                        m = re.match(r's (\d+)', parg, re.IGNORECASE)
+                        if m:  # steps
+                            sample_steps = max(1, min(1000, int(m.group(1))))
+                            continue
+
+                        m = re.match(r'l ([\d\.]+)', parg, re.IGNORECASE)
+                        if m:  # scale
+                            scale = float(m.group(1))
+                            continue
+
+                        m = re.match(r'n (.+)', parg, re.IGNORECASE)
+                        if m:  # negative prompt
+                            negative_prompt = m.group(1)
+                            continue
+
+                    except ValueError as ex:
+                        print(f"Exception in parsing / 解析エラー: {parg}")
+                        print(ex)
+
+                if seed is not None:
+                    torch.manual_seed(seed)
+                    torch.cuda.manual_seed(seed)
+
+                image = pipeline(prompt, height, width, sample_steps, scale, negative_prompt).images[0]
+
+                ts_str = time.strftime('%Y%m%d%H%M%S', time.localtime())
+                num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
+                seed_suffix = "" if seed is None else f"_{seed}"
+                img_filename = f"{'' if args.output_name is None else args.output_name}_{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
+
+                image.save(os.path.join(save_dir, img_filename))
+
+    torch.set_rng_state(rng_state)
+    torch.cuda.set_rng_state(cuda_rng_state)
+    vae.to(org_vae_device)
+
 # endregion
 
 # region preprocessing
diff --git a/train_db.py b/train_db.py
index 03fba1a6..755e98f2 100644
--- a/train_db.py
+++ b/train_db.py
@@ -278,6 +278,8 @@ def train(args):
                 progress_bar.update(1)
                 global_step += 1
 
+            train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+
             current_loss = loss.detach().item()
             if args.logging_dir is not None:
                 logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
@@ -309,6 +311,8 @@ def train(args):
         train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
                                               save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
 
+        train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+
     is_main_process = accelerator.is_main_process
     if is_main_process:
         unet = unwrap_model(unet)
diff --git a/train_network.py b/train_network.py
index 0ba290a7..292a6701 100644
--- a/train_network.py
+++ b/train_network.py
@@ -1,4 +1,3 @@
-from torch.cuda.amp import autocast
 from torch.nn.parallel import DistributedDataParallel as DDP
 import importlib
 import argparse
@@ -12,7 +11,6 @@ import json
 from tqdm import tqdm
 import torch
 from accelerate.utils import set_seed
-import diffusers
 from diffusers import DDPMScheduler
 
 import library.train_util as train_util
@@ -400,6 +398,8 @@ def train(args):
                 progress_bar.update(1)
                 global_step += 1
 
+            train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+
             current_loss = loss.detach().item()
             if epoch == 0:
                 loss_list.append(current_loss)
@@ -445,6 +445,8 @@ def train(args):
             if saving and args.save_state:
                 train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
 
+        train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+
         # end of epoch
 
     metadata["ss_epoch"] = str(num_train_epochs)
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index b4ddd763..df718133 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -354,6 +354,8 @@ def train(args):
                 progress_bar.update(1)
                 global_step += 1
 
+            train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+
             current_loss = loss.detach().item()
             if args.logging_dir is not None:
                 logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
@@ -376,8 +378,6 @@ def train(args):
         accelerator.wait_for_everyone()
 
         updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
-        # d = updated_embs - bef_epo_embs
-        # print(bef_epo_embs.size(), updated_embs.size(), d.mean(), d.min())
 
         if args.save_every_n_epochs is not None:
             model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
@@ -399,6 +399,8 @@ def train(args):
             if saving and args.save_state:
                 train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
 
+        train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+
         # end of epoch
 
     is_main_process = accelerator.is_main_process
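
Usage sketch (illustrative only, not part of the patch; the file name sample_prompts.txt and all prompt and option values below are made up). sample_images() reads the --sample_prompts file line by line: empty lines and lines starting with '#' are skipped, and each remaining line is split on ' --', so a prompt may be followed by w (width), h (height), d (seed), s (sampling steps, clamped to 1-1000), l (CFG scale) and n (negative prompt):

    # sample_prompts.txt -- lines starting with '#' are skipped
    masterpiece, best quality, 1girl, upper body --w 512 --h 768 --d 42 --s 28 --l 7.5 --n lowres, bad anatomy

A training run could then enable sampling with, for example:

    accelerate launch train_network.py \
        --sample_prompts=sample_prompts.txt \
        --sample_every_n_epochs=1 \
        --sample_sampler=k_euler_a \
        <other training options>

Images are written to <output_dir>/sample; if both intervals are given, --sample_every_n_epochs takes precedence over --sample_every_n_steps.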