mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
Add --save_last_n_epochs option
This commit is contained in:
15
train_db.py
15
train_db.py
@@ -27,6 +27,7 @@ import itertools
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
|
import shutil
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torch
|
import torch
|
||||||
@@ -1101,16 +1102,28 @@ def train(args):
|
|||||||
ckpt_file = os.path.join(args.output_dir, model_util.get_epoch_ckpt_name(use_safetensors, epoch + 1))
|
ckpt_file = os.path.join(args.output_dir, model_util.get_epoch_ckpt_name(use_safetensors, epoch + 1))
|
||||||
model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, unwrap_model(text_encoder), unwrap_model(unet),
|
model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, unwrap_model(text_encoder), unwrap_model(unet),
|
||||||
src_stable_diffusion_ckpt, epoch + 1, global_step, save_dtype, vae)
|
src_stable_diffusion_ckpt, epoch + 1, global_step, save_dtype, vae)
|
||||||
|
if args.save_last_n_epochs is not None:
|
||||||
|
old_ckpt_file = os.path.join(args.output_dir, model_util.get_epoch_ckpt_name(use_safetensors, epoch + 1 - args.save_last_n_epochs))
|
||||||
|
if os.path.exists(old_ckpt_file):
|
||||||
|
os.remove(old_ckpt_file)
|
||||||
else:
|
else:
|
||||||
out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1))
|
out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1))
|
||||||
os.makedirs(out_dir, exist_ok=True)
|
os.makedirs(out_dir, exist_ok=True)
|
||||||
model_util.save_diffusers_checkpoint(args.v2, out_dir, unwrap_model(text_encoder),
|
model_util.save_diffusers_checkpoint(args.v2, out_dir, unwrap_model(text_encoder),
|
||||||
unwrap_model(unet), src_diffusers_model_path,
|
unwrap_model(unet), src_diffusers_model_path,
|
||||||
use_safetensors=use_safetensors)
|
use_safetensors=use_safetensors)
|
||||||
|
if args.save_last_n_epochs is not None:
|
||||||
|
out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1 - args.save_last_n_epochs))
|
||||||
|
if os.path.exists(out_dir_old):
|
||||||
|
shutil.rmtree(out_dir_old)
|
||||||
|
|
||||||
if args.save_state:
|
if args.save_state:
|
||||||
print("saving state.")
|
print("saving state.")
|
||||||
accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(epoch + 1)))
|
accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(epoch + 1)))
|
||||||
|
if args.save_last_n_epochs is not None:
|
||||||
|
state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(epoch + 1 - args.save_last_n_epochs))
|
||||||
|
if os.path.exists(state_dir_old):
|
||||||
|
shutil.rmtree(state_dir_old)
|
||||||
|
|
||||||
is_main_process = accelerator.is_main_process
|
is_main_process = accelerator.is_main_process
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
@@ -1173,6 +1186,8 @@ if __name__ == '__main__':
|
|||||||
help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)")
|
help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)")
|
||||||
parser.add_argument("--save_every_n_epochs", type=int, default=None,
|
parser.add_argument("--save_every_n_epochs", type=int, default=None,
|
||||||
help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
|
help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
|
||||||
|
parser.add_argument("--save_last_n_epochs", type=int, default=None,
|
||||||
|
help="save last N checkpoints / 最大Nエポック保存する")
|
||||||
parser.add_argument("--save_state", action="store_true",
|
parser.add_argument("--save_state", action="store_true",
|
||||||
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
|
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
|
||||||
parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
|
parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
|
||||||
|
|||||||
Reference in New Issue
Block a user