diff --git a/fine_tune.py b/fine_tune.py index 4de57b45..b6a8d1d7 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -275,7 +275,7 @@ def train(args): with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device) # .to(dtype=weight_dtype) + latents = batch["latents"].to(accelerator.device) # .to(dtype=weight_dtype) else: # latentに変換 latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() @@ -285,18 +285,19 @@ def train(args): with torch.set_grad_enabled(args.train_text_encoder): # Get the text embedding for conditioning if args.weighted_captions: - encoder_hidden_states = get_weighted_text_embeddings(tokenizer, - text_encoder, - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, + encoder_hidden_states = get_weighted_text_embeddings( + tokenizer, + text_encoder, + batch["captions"], + accelerator.device, + args.max_token_length // 75 if args.max_token_length else 1, + clip_skip=args.clip_skip, ) else: - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states( - args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype - ) + input_ids = batch["input_ids"].to(accelerator.device) + encoder_hidden_states = train_util.get_hidden_states( + args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype + ) # Sample noise that we'll add to the latents noise = torch.randn_like(latents, device=latents.device) @@ -351,6 +352,27 @@ def train(args): accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet ) + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + train_util.save_sd_model_on_epoch_end_or_stepwise( + args, + False, + accelerator, + src_path, + save_stable_diffusion_format, + use_safetensors, + save_dtype, + epoch, + num_train_epochs, + global_step, + unwrap_model(text_encoder), + unwrap_model(unet), + vae, + ) + current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず if args.logging_dir is not None: logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} @@ -376,21 +398,23 @@ def train(args): accelerator.wait_for_everyone() if args.save_every_n_epochs is not None: - src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path - train_util.save_sd_model_on_epoch_end( - args, - accelerator, - src_path, - save_stable_diffusion_format, - use_safetensors, - save_dtype, - epoch, - num_train_epochs, - global_step, - unwrap_model(text_encoder), - unwrap_model(unet), - vae, - ) + if accelerator.is_main_process: + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + train_util.save_sd_model_on_epoch_end_or_stepwise( + args, + True, + accelerator, + src_path, + save_stable_diffusion_format, + use_safetensors, + save_dtype, + epoch, + num_train_epochs, + global_step, + unwrap_model(text_encoder), + unwrap_model(unet), + vae, + ) train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) @@ -401,7 +425,7 @@ def train(args): accelerator.end_training() - if args.save_state: + if args.save_state and is_main_process: train_util.save_state_on_train_end(args, accelerator) del accelerator # この後メモリを使うのでこれは消す @@ -437,4 +461,4 @@ if __name__ == "__main__": args = parser.parse_args() args = train_util.read_config_from_file(args, parser) - train(args) \ No newline at end of file + train(args) diff --git a/library/train_util.py b/library/train_util.py index 40119c77..ec17e11c 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -74,6 +74,11 @@ LAST_STATE_NAME = "{}-state" DEFAULT_EPOCH_NAME = "epoch" DEFAULT_LAST_OUTPUT_NAME = "last" +DEFAULT_STEP_NAME = "at" +STEP_STATE_NAME = "{}-step{:08d}-state" +STEP_FILE_NAME = "{}-step{:08d}" +STEP_DIFFUSERS_DIR_NAME = "{}-step{:08d}" + # region dataset IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"] @@ -1986,18 +1991,38 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument( "--save_every_n_epochs", type=int, default=None, help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する" ) + parser.add_argument( + "--save_every_n_steps", type=int, default=None, help="save checkpoint every N steps / 学習中のモデルを指定ステップごとに保存する" + ) parser.add_argument( "--save_n_epoch_ratio", type=int, default=None, help="save checkpoint N epoch ratio (for example 5 means save at least 5 files total) / 学習中のモデルを指定のエポック割合で保存する(たとえば5を指定すると最低5個のファイルが保存される)", ) - parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する") + parser.add_argument( + "--save_last_n_epochs", + type=int, + default=None, + help="save last N checkpoints when saving every N epochs (remove older checkpoints) / 指定エポックごとにモデルを保存するとき最大Nエポック保存する(古いチェックポイントは削除する)", + ) parser.add_argument( "--save_last_n_epochs_state", type=int, default=None, - help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)", + help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きする)", + ) + parser.add_argument( + "--save_last_n_steps", + type=int, + default=None, + help="save checkpoints until N steps elapsed (remove older checkpoints if N steps elapsed) / 指定ステップごとにモデルを保存するとき、このステップ数経過するまで保存する(このステップ数経過したら削除する)", + ) + parser.add_argument( + "--save_last_n_steps_state", + type=int, + default=None, + help="save states until N steps elapsed (remove older states if N steps elapsed, overrides --save_last_n_steps) / 指定ステップごとにstateを保存するとき、このステップ数経過するまで保存する(このステップ数経過したら削除する。--save_last_n_stepsを上書きする)", ) parser.add_argument( "--save_state", @@ -2903,26 +2928,53 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod return encoder_hidden_states -def get_epoch_ckpt_name(args: argparse.Namespace, use_safetensors, epoch): - model_name = DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name - ckpt_name = EPOCH_FILE_NAME.format(model_name, epoch) + (".safetensors" if use_safetensors else ".ckpt") - return model_name, ckpt_name +def default_if_none(value, default): + return default if value is None else value -def save_on_epoch_end(args: argparse.Namespace, save_func, remove_old_func, epoch_no: int, num_train_epochs: int): - saving = epoch_no % args.save_every_n_epochs == 0 and epoch_no < num_train_epochs - if saving: - os.makedirs(args.output_dir, exist_ok=True) - save_func() - - if args.save_last_n_epochs is not None: - remove_epoch_no = epoch_no - args.save_every_n_epochs * args.save_last_n_epochs - remove_old_func(remove_epoch_no) - return saving +def get_epoch_ckpt_name(args: argparse.Namespace, ext: str, epoch_no: int): + model_name = default_if_none(args.output_name, DEFAULT_EPOCH_NAME) + return EPOCH_FILE_NAME.format(model_name, epoch_no) + ext -def save_sd_model_on_epoch_end( +def get_step_ckpt_name(args: argparse.Namespace, ext: str, step_no: int): + model_name = default_if_none(args.output_name, DEFAULT_STEP_NAME) + return STEP_FILE_NAME.format(model_name, step_no) + ext + + +def get_last_ckpt_name(args: argparse.Namespace, ext: str): + model_name = default_if_none(args.output_name, DEFAULT_LAST_OUTPUT_NAME) + return model_name + ext + + +def get_remove_epoch_no(args: argparse.Namespace, epoch_no: int): + if args.save_last_n_epochs is None: + return None + + remove_epoch_no = epoch_no - args.save_every_n_epochs * args.save_last_n_epochs + if remove_epoch_no < 0: + return None + return remove_epoch_no + + +def get_remove_step_no(args: argparse.Namespace, step_no: int): + if args.save_last_n_steps is None: + return None + + # last_n_steps前のstep_noから、save_every_n_stepsの倍数のstep_noを計算して削除する + # save_every_n_steps=10, save_last_n_steps=30の場合、50step目には30step分残し、10step目を削除する + remove_step_no = step_no - args.save_last_n_steps - 1 + remove_step_no = remove_step_no - (remove_step_no % args.save_every_n_steps) + if remove_step_no < 0: + return None + return remove_step_no + + +# epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合している +# on_epoch_end: Trueならepoch終了時、Falseならstep経過時 +def save_sd_model_on_epoch_end_or_stepwise( args: argparse.Namespace, + on_epoch_end: bool, accelerator, src_path: str, save_stable_diffusion_format: bool, @@ -2935,57 +2987,87 @@ def save_sd_model_on_epoch_end( unet, vae, ): - epoch_no = epoch + 1 - model_name, ckpt_name = get_epoch_ckpt_name(args, use_safetensors, epoch_no) + if on_epoch_end: + epoch_no = epoch + 1 + saving = epoch_no % args.save_every_n_epochs == 0 and epoch_no < num_train_epochs + if not saving: + return - if save_stable_diffusion_format: - - def save_sd(): - ckpt_file = os.path.join(args.output_dir, ckpt_name) - print(f"saving checkpoint: {ckpt_file}") - model_util.save_stable_diffusion_checkpoint( - args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae - ) - if args.huggingface_repo_id is not None: - huggingface_util.upload(args, ckpt_file, "/" + ckpt_name) - - def remove_sd(old_epoch_no): - _, old_ckpt_name = get_epoch_ckpt_name(args, use_safetensors, old_epoch_no) - old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) - if os.path.exists(old_ckpt_file): - print(f"removing old checkpoint: {old_ckpt_file}") - os.remove(old_ckpt_file) - - save_func = save_sd - remove_old_func = remove_sd + model_name = default_if_none(args.output_name, DEFAULT_EPOCH_NAME) + remove_no = get_remove_epoch_no(args, epoch_no) else: + # 保存するか否かは呼び出し側で判断済み - def save_du(): + model_name = default_if_none(args.output_name, DEFAULT_STEP_NAME) + epoch_no = epoch # 例: 最初のepochの途中で保存したら0になる、SDモデルに保存される + remove_no = get_remove_step_no(args, global_step) + + os.makedirs(args.output_dir, exist_ok=True) + if save_stable_diffusion_format: + ext = ".safetensors" if use_safetensors else ".ckpt" + + if on_epoch_end: + ckpt_name = get_epoch_ckpt_name(args, ext, epoch_no) + else: + ckpt_name = get_step_ckpt_name(args, ext, global_step) + + ckpt_file = os.path.join(args.output_dir, ckpt_name) + print(f"saving checkpoint: {ckpt_file}") + model_util.save_stable_diffusion_checkpoint( + args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae + ) + + if args.huggingface_repo_id is not None: + huggingface_util.upload(args, ckpt_file, "/" + ckpt_name) + + # remove older checkpoints + if remove_no is not None: + if on_epoch_end: + remove_ckpt_name = get_epoch_ckpt_name(args, ext, remove_no) + else: + remove_ckpt_name = get_step_ckpt_name(args, ext, remove_no) + + remove_ckpt_file = os.path.join(args.output_dir, remove_ckpt_name) + if os.path.exists(remove_ckpt_file): + print(f"removing old checkpoint: {remove_ckpt_file}") + os.remove(remove_ckpt_file) + + else: + if on_epoch_end: out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, epoch_no)) - print(f"saving model: {out_dir}") - os.makedirs(out_dir, exist_ok=True) - model_util.save_diffusers_checkpoint( - args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors - ) - if args.huggingface_repo_id is not None: - huggingface_util.upload(args, out_dir, "/" + model_name) + else: + out_dir = os.path.join(args.output_dir, STEP_DIFFUSERS_DIR_NAME.format(model_name, global_step)) - def remove_du(old_epoch_no): - out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, old_epoch_no)) - if os.path.exists(out_dir_old): - print(f"removing old model: {out_dir_old}") - shutil.rmtree(out_dir_old) + print(f"saving model: {out_dir}") + model_util.save_diffusers_checkpoint( + args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors + ) + if args.huggingface_repo_id is not None: + huggingface_util.upload(args, out_dir, "/" + model_name) - save_func = save_du - remove_old_func = remove_du + # remove older checkpoints + if remove_no is not None: + if on_epoch_end: + remove_out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, remove_no)) + else: + remove_out_dir = os.path.join(args.output_dir, STEP_DIFFUSERS_DIR_NAME.format(model_name, remove_no)) - saving = save_on_epoch_end(args, save_func, remove_old_func, epoch_no, num_train_epochs) - if saving and args.save_state: - save_state_on_epoch_end(args, accelerator, model_name, epoch_no) + if os.path.exists(remove_out_dir): + print(f"removing old model: {remove_out_dir}") + shutil.rmtree(remove_out_dir) + + if on_epoch_end: + save_and_remove_state_on_epoch_end(args, accelerator, epoch_no) + else: + save_and_remove_state_stepwise(args, accelerator, global_step) -def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, epoch_no): - print("saving state.") +def save_and_remove_state_on_epoch_end(args: argparse.Namespace, accelerator, epoch_no): + model_name = default_if_none(args.output_name, DEFAULT_EPOCH_NAME) + + print(f"saving state at epoch {epoch_no}") + os.makedirs(args.output_dir, exist_ok=True) + state_dir = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no)) accelerator.save_state(state_dir) if args.save_state_to_huggingface: @@ -3001,12 +3083,40 @@ def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, e shutil.rmtree(state_dir_old) +def save_and_remove_state_stepwise(args: argparse.Namespace, accelerator, step_no): + model_name = default_if_none(args.output_name, DEFAULT_STEP_NAME) + + print(f"saving state at step {step_no}") + os.makedirs(args.output_dir, exist_ok=True) + + state_dir = os.path.join(args.output_dir, STEP_STATE_NAME.format(model_name, step_no)) + accelerator.save_state(state_dir) + if args.save_state_to_huggingface: + print("uploading state to huggingface.") + huggingface_util.upload(args, state_dir, "/" + STEP_STATE_NAME.format(model_name, step_no)) + + last_n_steps = args.save_last_n_steps_state if args.save_last_n_steps_state else args.save_last_n_steps + if last_n_steps is not None: + # last_n_steps前のstep_noから、save_every_n_stepsの倍数のstep_noを計算して削除する + remove_step_no = step_no - last_n_steps - 1 + remove_step_no = remove_step_no - (remove_step_no % args.save_every_n_steps) + + if remove_step_no > 0: + state_dir_old = os.path.join(args.output_dir, STEP_STATE_NAME.format(model_name, remove_step_no)) + if os.path.exists(state_dir_old): + print(f"removing old state: {state_dir_old}") + shutil.rmtree(state_dir_old) + + def save_state_on_train_end(args: argparse.Namespace, accelerator): + model_name = default_if_none(args.output_name, DEFAULT_LAST_OUTPUT_NAME) + print("saving last state.") os.makedirs(args.output_dir, exist_ok=True) - model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name + state_dir = os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)) accelerator.save_state(state_dir) + if args.save_state_to_huggingface: print("uploading last state to huggingface.") huggingface_util.upload(args, state_dir, "/" + LAST_STATE_NAME.format(model_name)) @@ -3024,7 +3134,7 @@ def save_sd_model_on_train_end( unet, vae, ): - model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name + model_name = default_if_none(args.output_name, DEFAULT_LAST_OUTPUT_NAME) if save_stable_diffusion_format: os.makedirs(args.output_dir, exist_ok=True) diff --git a/train_db.py b/train_db.py index 5c4202a6..178d5cb4 100644 --- a/train_db.py +++ b/train_db.py @@ -25,6 +25,7 @@ from library.config_util import ( import library.custom_train_functions as custom_train_functions from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings + def train(args): train_util.verify_training_args(args) train_util.prepare_dataset_args(args, False) @@ -273,18 +274,19 @@ def train(args): # Get the text embedding for conditioning with torch.set_grad_enabled(global_step < args.stop_text_encoder_training): if args.weighted_captions: - encoder_hidden_states = get_weighted_text_embeddings(tokenizer, - text_encoder, - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, + encoder_hidden_states = get_weighted_text_embeddings( + tokenizer, + text_encoder, + batch["captions"], + accelerator.device, + args.max_token_length // 75 if args.max_token_length else 1, + clip_skip=args.clip_skip, ) else: - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states( - args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype - ) + input_ids = batch["input_ids"].to(accelerator.device) + encoder_hidden_states = train_util.get_hidden_states( + args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype + ) # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) @@ -335,6 +337,27 @@ def train(args): accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet ) + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + train_util.save_sd_model_on_epoch_end_or_stepwise( + args, + False, + accelerator, + src_path, + save_stable_diffusion_format, + use_safetensors, + save_dtype, + epoch, + num_train_epochs, + global_step, + unwrap_model(text_encoder), + unwrap_model(unet), + vae, + ) + current_loss = loss.detach().item() if args.logging_dir is not None: logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} @@ -364,21 +387,24 @@ def train(args): accelerator.wait_for_everyone() if args.save_every_n_epochs is not None: - src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path - train_util.save_sd_model_on_epoch_end( - args, - accelerator, - src_path, - save_stable_diffusion_format, - use_safetensors, - save_dtype, - epoch, - num_train_epochs, - global_step, - unwrap_model(text_encoder), - unwrap_model(unet), - vae, - ) + if accelerator.is_main_process: + # checking for saving is in util + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + train_util.save_sd_model_on_epoch_end_or_stepwise( + args, + True, + accelerator, + src_path, + save_stable_diffusion_format, + use_safetensors, + save_dtype, + epoch, + num_train_epochs, + global_step, + unwrap_model(text_encoder), + unwrap_model(unet), + vae, + ) train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) @@ -389,7 +415,7 @@ def train(args): accelerator.end_training() - if args.save_state: + if args.save_state and is_main_process: train_util.save_state_on_train_end(args, accelerator) del accelerator # この後メモリを使うのでこれは消す @@ -434,4 +460,4 @@ if __name__ == "__main__": args = parser.parse_args() args = train_util.read_config_from_file(args, parser) - train(args) \ No newline at end of file + train(args) diff --git a/train_network.py b/train_network.py index 8b6f2c83..5c4d5ad1 100644 --- a/train_network.py +++ b/train_network.py @@ -549,6 +549,27 @@ def train(args): # else: # on_step_start = lambda *args, **kwargs: None + # function for saving/removing + def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False): + os.makedirs(args.output_dir, exist_ok=True) + ckpt_file = os.path.join(args.output_dir, ckpt_name) + + print(f"saving checkpoint: {ckpt_file}") + metadata["ss_training_finished_at"] = str(time.time()) + metadata["ss_steps"] = str(steps) + metadata["ss_epoch"] = str(epoch_no) + + unwrapped_nw.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata) + if args.huggingface_repo_id is not None: + huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) + + def remove_model(old_ckpt_name): + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + print(f"removing old checkpoint: {old_ckpt_file}") + os.remove(old_ckpt_file) + + # training loop for epoch in range(num_train_epochs): if is_main_process: print(f"epoch {epoch+1}/{num_train_epochs}") @@ -638,6 +659,21 @@ def train(args): accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet ) + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) + save_model(ckpt_name, unwrap_model(network), global_step, epoch) + + if args.save_state: + train_util.save_and_remove_state_stepwise(args, accelerator, global_step) + + remove_step_no = train_util.get_remove_step_no(args, global_step) + if remove_step_no is not None: + remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no) + remove_model(remove_ckpt_name) + current_loss = loss.detach().item() if epoch == 0: loss_list.append(current_loss) @@ -662,35 +698,26 @@ def train(args): accelerator.wait_for_everyone() + # 指定エポックごとにモデルを保存 if args.save_every_n_epochs is not None: - model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name + saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs + if is_main_process and saving: + ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1) + save_model(ckpt_name, unwrap_model(network), global_step, epoch + 1) - def save_func(): - ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + "." + args.save_model_as - ckpt_file = os.path.join(args.output_dir, ckpt_name) - metadata["ss_training_finished_at"] = str(time.time()) - print(f"saving checkpoint: {ckpt_file}") - unwrap_model(network).save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata) - if args.huggingface_repo_id is not None: - huggingface_util.upload(args, ckpt_file, "/" + ckpt_name) + remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) + if remove_epoch_no is not None: + remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no) + remove_model(remove_ckpt_name) - def remove_old_func(old_epoch_no): - old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as - old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) - if os.path.exists(old_ckpt_file): - print(f"removing old checkpoint: {old_ckpt_file}") - os.remove(old_ckpt_file) - - if is_main_process: - saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs) - if saving and args.save_state: - train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1) + if args.save_state: + train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) # end of epoch - metadata["ss_epoch"] = str(num_train_epochs) + # metadata["ss_epoch"] = str(num_train_epochs) metadata["ss_training_finished_at"] = str(time.time()) if is_main_process: @@ -698,22 +725,15 @@ def train(args): accelerator.end_training() - if args.save_state: + if is_main_process and args.save_state: train_util.save_state_on_train_end(args, accelerator) del accelerator # この後メモリを使うのでこれは消す if is_main_process: - os.makedirs(args.output_dir, exist_ok=True) - - model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name - ckpt_name = model_name + "." + args.save_model_as - ckpt_file = os.path.join(args.output_dir, ckpt_name) - - print(f"save trained model to {ckpt_file}") - network.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata) - if args.huggingface_repo_id is not None: - huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True) + ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) + save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True) + print("model saved.") diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 2042a618..fb6b6053 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -339,6 +339,23 @@ def train(args): if accelerator.is_main_process: accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name) + # function for saving/removing + def save_model(ckpt_name, embs, steps, epoch_no, force_sync_upload=False): + os.makedirs(args.output_dir, exist_ok=True) + ckpt_file = os.path.join(args.output_dir, ckpt_name) + + print(f"saving checkpoint: {ckpt_file}") + save_weights(ckpt_file, embs, save_dtype) + if args.huggingface_repo_id is not None: + huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) + + def remove_model(old_ckpt_name): + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + print(f"removing old checkpoint: {old_ckpt_file}") + os.remove(old_ckpt_file) + + # training loop for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 @@ -423,6 +440,23 @@ def train(args): accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement ) + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone() + + ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) + save_model(ckpt_name, updated_embs, global_step, epoch) + + if args.save_state: + train_util.save_and_remove_state_stepwise(args, accelerator, global_step) + + remove_step_no = train_util.get_remove_step_no(args, global_step) + if remove_step_no is not None: + remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no) + remove_model(remove_ckpt_name) + current_loss = loss.detach().item() if args.logging_dir is not None: logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} @@ -449,26 +483,18 @@ def train(args): updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone() if args.save_every_n_epochs is not None: - model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name + saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs + if accelerator.is_main_process and saving: + ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1) + save_model(ckpt_name, updated_embs, epoch + 1, global_step) - def save_func(): - ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + "." + args.save_model_as - ckpt_file = os.path.join(args.output_dir, ckpt_name) - print(f"saving checkpoint: {ckpt_file}") - save_weights(ckpt_file, updated_embs, save_dtype) - if args.huggingface_repo_id is not None: - huggingface_util.upload(args, ckpt_file, "/" + ckpt_name) + remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) + if remove_epoch_no is not None: + remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no) + remove_model(remove_ckpt_name) - def remove_old_func(old_epoch_no): - old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as - old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) - if os.path.exists(old_ckpt_file): - print(f"removing old checkpoint: {old_ckpt_file}") - os.remove(old_ckpt_file) - - saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs) - if saving and args.save_state: - train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1) + if args.save_state: + train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) train_util.sample_images( accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement @@ -482,7 +508,7 @@ def train(args): accelerator.end_training() - if args.save_state: + if args.save_state and is_main_process: train_util.save_state_on_train_end(args, accelerator) updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone() @@ -490,16 +516,9 @@ def train(args): del accelerator # この後メモリを使うのでこれは消す if is_main_process: - os.makedirs(args.output_dir, exist_ok=True) + ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) + save_model(ckpt_name, updated_embs, global_step, num_train_epochs, force_sync_upload=True) - model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name - ckpt_name = model_name + "." + args.save_model_as - ckpt_file = os.path.join(args.output_dir, ckpt_name) - - print(f"save trained model to {ckpt_file}") - save_weights(ckpt_file, updated_embs, save_dtype) - if args.huggingface_repo_id is not None: - huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True) print("model saved.") diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index c2ebf7cb..69ec3eb1 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -373,6 +373,23 @@ def train(args): if accelerator.is_main_process: accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name) + # function for saving/removing + def save_model(ckpt_name, embs, steps, epoch_no, force_sync_upload=False): + os.makedirs(args.output_dir, exist_ok=True) + ckpt_file = os.path.join(args.output_dir, ckpt_name) + + print(f"saving checkpoint: {ckpt_file}") + save_weights(ckpt_file, embs, save_dtype) + if args.huggingface_repo_id is not None: + huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) + + def remove_model(old_ckpt_name): + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + print(f"removing old checkpoint: {old_ckpt_file}") + os.remove(old_ckpt_file) + + # training loop for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 @@ -462,6 +479,23 @@ def train(args): # accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement # ) + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids_XTI].data.detach().clone() + + ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) + save_model(ckpt_name, updated_embs, global_step, epoch) + + if args.save_state: + train_util.save_and_remove_state_stepwise(args, accelerator, global_step) + + remove_step_no = train_util.get_remove_step_no(args, global_step) + if remove_step_no is not None: + remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no) + remove_model(remove_ckpt_name) + current_loss = loss.detach().item() if args.logging_dir is not None: logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} @@ -488,26 +522,18 @@ def train(args): updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids_XTI].data.detach().clone() if args.save_every_n_epochs is not None: - model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name + saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs + if accelerator.is_main_process and saving: + ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1) + save_model(ckpt_name, updated_embs, epoch + 1, global_step) - def save_func(): - ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + "." + args.save_model_as - ckpt_file = os.path.join(args.output_dir, ckpt_name) - print(f"saving checkpoint: {ckpt_file}") - save_weights(ckpt_file, updated_embs, save_dtype) - if args.huggingface_repo_id is not None: - huggingface_util.upload(args, ckpt_file, "/" + ckpt_name) + remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) + if remove_epoch_no is not None: + remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no) + remove_model(remove_ckpt_name) - def remove_old_func(old_epoch_no): - old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as - old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) - if os.path.exists(old_ckpt_file): - print(f"removing old checkpoint: {old_ckpt_file}") - os.remove(old_ckpt_file) - - saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs) - if saving and args.save_state: - train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1) + if args.save_state: + train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) # TODO: fix sample_images # train_util.sample_images( @@ -522,7 +548,7 @@ def train(args): accelerator.end_training() - if args.save_state: + if args.save_state and is_main_process: train_util.save_state_on_train_end(args, accelerator) updated_embs = text_encoder.get_input_embeddings().weight[token_ids_XTI].data.detach().clone() @@ -530,16 +556,9 @@ def train(args): del accelerator # この後メモリを使うのでこれは消す if is_main_process: - os.makedirs(args.output_dir, exist_ok=True) + ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) + save_model(ckpt_name, updated_embs, global_step, num_train_epochs, force_sync_upload=True) - model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name - ckpt_name = model_name + "." + args.save_model_as - ckpt_file = os.path.join(args.output_dir, ckpt_name) - - print(f"save trained model to {ckpt_file}") - save_weights(ckpt_file, updated_embs, save_dtype) - if args.huggingface_repo_id is not None: - huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True) print("model saved.")