Add --save_last_n_epochs option

This commit is contained in:
Yuta Hayashibe
2023-01-01 21:46:38 +09:00
parent bda0e8333c
commit 61a61c51ee

View File

@@ -27,6 +27,7 @@ import itertools
import math import math
import os import os
import random import random
import shutil
from tqdm import tqdm from tqdm import tqdm
import torch import torch
@@ -1101,16 +1102,28 @@ def train(args):
ckpt_file = os.path.join(args.output_dir, model_util.get_epoch_ckpt_name(use_safetensors, epoch + 1)) ckpt_file = os.path.join(args.output_dir, model_util.get_epoch_ckpt_name(use_safetensors, epoch + 1))
model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, unwrap_model(text_encoder), unwrap_model(unet), model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, unwrap_model(text_encoder), unwrap_model(unet),
src_stable_diffusion_ckpt, epoch + 1, global_step, save_dtype, vae) src_stable_diffusion_ckpt, epoch + 1, global_step, save_dtype, vae)
if args.save_last_n_epochs is not None:
old_ckpt_file = os.path.join(args.output_dir, model_util.get_epoch_ckpt_name(use_safetensors, epoch + 1 - args.save_last_n_epochs))
if os.path.exists(old_ckpt_file):
os.remove(old_ckpt_file)
else: else:
out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1)) out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1))
os.makedirs(out_dir, exist_ok=True) os.makedirs(out_dir, exist_ok=True)
model_util.save_diffusers_checkpoint(args.v2, out_dir, unwrap_model(text_encoder), model_util.save_diffusers_checkpoint(args.v2, out_dir, unwrap_model(text_encoder),
unwrap_model(unet), src_diffusers_model_path, unwrap_model(unet), src_diffusers_model_path,
use_safetensors=use_safetensors) use_safetensors=use_safetensors)
if args.save_last_n_epochs is not None:
out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1 - args.save_last_n_epochs))
if os.path.exists(out_dir_old):
shutil.rmtree(out_dir_old)
if args.save_state: if args.save_state:
print("saving state.") print("saving state.")
accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(epoch + 1))) accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(epoch + 1)))
if args.save_last_n_epochs is not None:
state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(epoch + 1 - args.save_last_n_epochs))
if os.path.exists(state_dir_old):
shutil.rmtree(state_dir_old)
is_main_process = accelerator.is_main_process is_main_process = accelerator.is_main_process
if is_main_process: if is_main_process:
@@ -1173,6 +1186,8 @@ if __name__ == '__main__':
help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存するsave_model_as未指定時") help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存するsave_model_as未指定時")
parser.add_argument("--save_every_n_epochs", type=int, default=None, parser.add_argument("--save_every_n_epochs", type=int, default=None,
help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する") help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
parser.add_argument("--save_last_n_epochs", type=int, default=None,
help="save last N checkpoints / 最大Nエポック保存する")
parser.add_argument("--save_state", action="store_true", parser.add_argument("--save_state", action="store_true",
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する") help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate") parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")