diff --git a/fine_tune.py b/fine_tune.py
index 50549878..ca42a403 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -231,9 +231,7 @@ def train(args):
         train_util.patch_accelerator_for_fp16_training(accelerator)
 
     # resumeする
-    if args.resume is not None:
-        print(f"resume training from state: {args.resume}")
-        accelerator.load_state(args.resume)
+    train_util.resume(accelerator, args)
 
     # epoch数を計算する
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
diff --git a/library/huggingface_util.py b/library/huggingface_util.py
new file mode 100644
index 00000000..4431a208
--- /dev/null
+++ b/library/huggingface_util.py
@@ -0,0 +1,78 @@
+from typing import *
+from huggingface_hub import HfApi
+from pathlib import Path
+import argparse
+import os
+
+from library.utils import fire_in_thread
+
+
+def exists_repo(
+    repo_id: str, repo_type: str, revision: str = "main", token: Optional[str] = None
+):
+    api = HfApi(
+        token=token,
+    )
+    try:
+        api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
+        return True
+    except Exception:  # narrowed from bare except: don't swallow KeyboardInterrupt/SystemExit
+        return False
+
+
+def upload(
+    args: argparse.Namespace,
+    src: Union[str, Path, bytes, BinaryIO],
+    dest_suffix: str = "",
+    force_sync_upload: bool = False,
+):
+    repo_id = args.huggingface_repo_id
+    repo_type = args.huggingface_repo_type
+    token = args.huggingface_token
+    path_in_repo = args.huggingface_path_in_repo + dest_suffix
+    private = args.huggingface_repo_visibility == "private"
+    api = HfApi(token=token)
+    if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token):
+        api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)
+
+    is_folder = (isinstance(src, str) and os.path.isdir(src)) or (
+        isinstance(src, Path) and src.is_dir()
+    )
+
+    def uploader():
+        if is_folder:
+            api.upload_folder(
+                repo_id=repo_id,
+                repo_type=repo_type,
+                folder_path=src,
+                path_in_repo=path_in_repo,
+            )
+        else:
+            api.upload_file(
+                repo_id=repo_id,
+                repo_type=repo_type,
+                path_or_fileobj=src,
+                path_in_repo=path_in_repo,
+            )
+
+    if args.async_upload and not force_sync_upload:
+        fire_in_thread(uploader)
+    else:
+        uploader()
+
+
+def list_dir(
+    repo_id: str,
+    subfolder: str,
+    repo_type: str,
+    revision: str = "main",
+    token: Optional[str] = None,
+):
+    api = HfApi(
+        token=token,
+    )
+    repo_info = api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
+    file_list = [
+        file for file in repo_info.siblings if file.rfilename.startswith(subfolder)
+    ]
+    return file_list
diff --git a/library/train_util.py b/library/train_util.py
index a195faac..98088c21 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -2,6 +2,7 @@
 import argparse
 import ast
+import asyncio
 import importlib
 import json
 import pathlib
 
@@ -49,6 +50,7 @@ from diffusers import (
     KDPM2DiscreteScheduler,
     KDPM2AncestralDiscreteScheduler,
 )
+from huggingface_hub import hf_hub_download
 import albumentations as albu
 import numpy as np
 from PIL import Image
@@ -58,6 +60,7 @@ from torch import einsum
 import safetensors.torch
 from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
 import library.model_util as model_util
+import library.huggingface_util as huggingface_util
 
 # Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
 TOKENIZER_PATH = "openai/clip-vit-large-patch14"
@@ -1441,7 +1444,6 @@ def glob_images_pathlib(dir_path, recursive):
 
 # endregion
-
 # region モジュール入れ替え部
 
 """
 高速化のためのモジュール入れ替え
@@ -1896,6 +1898,22 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
 def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
     parser.add_argument("--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ")
     parser.add_argument("--output_name", type=str, default=None, help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名")
+    parser.add_argument("--huggingface_repo_id", type=str, default=None, help="huggingface repo name to upload / huggingfaceにアップロードするリポジトリ名")
+    parser.add_argument("--huggingface_repo_type", type=str, default=None, help="huggingface repo type to upload / huggingfaceにアップロードするリポジトリの種類")
+    parser.add_argument("--huggingface_path_in_repo", type=str, default=None, help="huggingface model path to upload files / huggingfaceにアップロードするファイルのパス")
+    parser.add_argument("--huggingface_token", type=str, default=None, help="huggingface token / huggingfaceのトークン")
+    parser.add_argument("--huggingface_repo_visibility", type=str, default=None, help="huggingface repository visibility / huggingfaceにアップロードするリポジトリの公開設定")
+    parser.add_argument("--save_state_to_huggingface", action="store_true", help="save state to huggingface / huggingfaceにstateを保存する")
+    parser.add_argument(
+        "--resume_from_huggingface",
+        action="store_true",
+        help="resume from huggingface (ex: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type}) / huggingfaceから学習を再開する(例: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type})",
+    )
+    parser.add_argument(
+        "--async_upload",
+        action="store_true",
+        help="upload to huggingface asynchronously / huggingfaceに非同期でアップロードする",
+    )
     parser.add_argument(
         "--save_precision",
         type=str,
@@ -2260,6 +2278,56 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar
 
 
 # region utils
+def resume(accelerator, args):  # load training state from a local dir or from HuggingFace Hub
+    if args.resume:
+        print(f"resume training from state: {args.resume}")
+        if args.resume_from_huggingface:
+            repo_id = args.resume.split("/")[0] + "/" + args.resume.split("/")[1]
+            path_in_repo = "/".join(args.resume.split("/")[2:])
+            revision = None
+            repo_type = None
+            if ":" in path_in_repo:
+                divided = path_in_repo.split(":")
+                if len(divided) == 2:
+                    path_in_repo, revision = divided
+                    repo_type = "model"
+                else:
+                    path_in_repo, revision, repo_type = divided
+            print(
+                f"Downloading state from huggingface: {repo_id}/{path_in_repo}@{revision}"
+            )
+
+            list_files = huggingface_util.list_dir(
+                repo_id=repo_id,
+                subfolder=path_in_repo,
+                revision=revision,
+                token=args.huggingface_token,
+                repo_type=repo_type,
+            )
+
+            async def download(filename) -> str:
+                def task():
+                    return hf_hub_download(
+                        repo_id=repo_id,
+                        filename=filename,
+                        revision=revision,
+                        repo_type=repo_type,
+                        token=args.huggingface_token,
+                    )
+
+                return await asyncio.get_event_loop().run_in_executor(None, task)
+
+            loop = asyncio.get_event_loop()  # NOTE(review): deprecated on 3.10+ outside a running loop; consider asyncio.run
+            results = loop.run_until_complete(
+                asyncio.gather(
+                    *[download(filename=filename.rfilename) for filename in list_files]
+                )
+            )
+            dirname = os.path.dirname(results[0])
+            accelerator.load_state(dirname)
+        else:
+            accelerator.load_state(args.resume)
+
 
 def get_optimizer(args, trainable_params):
     # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, Adafactor"
@@ -2772,6 +2840,8 @@ def save_sd_model_on_epoch_end(
             model_util.save_stable_diffusion_checkpoint(
                 args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae
             )
+            if args.huggingface_repo_id is not None:
+                huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)
 
         def remove_sd(old_epoch_no):
             _, old_ckpt_name = get_epoch_ckpt_name(args, use_safetensors, old_epoch_no)
@@ -2791,6 +2861,8 @@
             model_util.save_diffusers_checkpoint(
                 args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
             )
+            if args.huggingface_repo_id is not None:
+                huggingface_util.upload(args, out_dir, "/" + model_name)
 
         def remove_du(old_epoch_no):
             out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, old_epoch_no))
@@ -2808,7 +2880,10 @@
 
 def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, epoch_no):
     print("saving state.")
-    accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no)))
+    state_dir = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no))
+    accelerator.save_state(state_dir)
+    if args.save_state_to_huggingface:
+        huggingface_util.upload(args, state_dir, "/" + EPOCH_STATE_NAME.format(model_name, epoch_no))
 
     last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs
     if last_n_epochs is not None:
@@ -2843,6 +2918,8 @@
         model_util.save_stable_diffusion_checkpoint(
             args.v2, ckpt_file, text_encoder, unet, src_path, epoch, global_step, save_dtype, vae
         )
+        if args.huggingface_repo_id is not None:
+            huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True)
     else:
         out_dir = os.path.join(args.output_dir, model_name)
         os.makedirs(out_dir, exist_ok=True)
@@ -2851,6 +2928,8 @@
         model_util.save_diffusers_checkpoint(
             args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
         )
+        if args.huggingface_repo_id is not None:
+            huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)
 
 
 def save_state_on_train_end(args: argparse.Namespace, accelerator):
diff --git a/library/utils.py b/library/utils.py
new file mode 100644
index 00000000..7d801a67
--- /dev/null
+++ b/library/utils.py
@@ -0,0 +1,6 @@
+import threading
+from typing import *
+
+
+def fire_in_thread(f, *args, **kwargs):
+    threading.Thread(target=f, args=args, kwargs=kwargs).start()
diff --git a/requirements.txt b/requirements.txt
index eea1c663..d3164894 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -21,6 +21,6 @@ fairscale==0.4.13
 # for WD14 captioning
 # tensorflow<2.11
 tensorflow==2.10.1
-huggingface-hub==0.12.0
+huggingface-hub==0.13.3
 # for kohya_ss library
 .
diff --git a/train_db.py b/train_db.py
index b3eead94..0b7f2d37 100644
--- a/train_db.py
+++ b/train_db.py
@@ -202,9 +202,7 @@ def train(args):
         train_util.patch_accelerator_for_fp16_training(accelerator)
 
     # resumeする
-    if args.resume is not None:
-        print(f"resume training from state: {args.resume}")
-        accelerator.load_state(args.resume)
+    train_util.resume(accelerator, args)
 
     # epoch数を計算する
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
diff --git a/train_network.py b/train_network.py
index 9956a905..48ce73f7 100644
--- a/train_network.py
+++ b/train_network.py
@@ -24,6 +24,7 @@ from library.config_util import (
     ConfigSanitizer,
     BlueprintGenerator,
 )
+import library.huggingface_util as huggingface_util
 import library.custom_train_functions as custom_train_functions
 from library.custom_train_functions import apply_snr_weight
 
@@ -71,8 +72,9 @@ def train(args):
     use_dreambooth_method = args.in_json is None
     use_user_config = args.dataset_config is not None
 
-    if args.seed is not None:
-        set_seed(args.seed)
+    if args.seed is None:
+        args.seed = random.randint(0, 2**32 - 1)  # randint is inclusive; np.random.seed rejects 2**32
+    set_seed(args.seed)
 
     tokenizer = train_util.load_tokenizer(args)
 
@@ -308,9 +310,7 @@ def train(args):
         train_util.patch_accelerator_for_fp16_training(accelerator)
 
     # resumeする
-    if args.resume is not None:
-        print(f"resume training from state: {args.resume}")
-        accelerator.load_state(args.resume)
+    train_util.resume(accelerator, args)
 
     # epoch数を計算する
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -650,6 +650,8 @@ def train(args):
             metadata["ss_training_finished_at"] = str(time.time())
             print(f"saving checkpoint: {ckpt_file}")
             unwrap_model(network).save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
+            if args.huggingface_repo_id is not None:
+                huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)
 
         def remove_old_func(old_epoch_no):
             old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
@@ -689,6 +691,8 @@ def train(args):
 
         print(f"save trained model to {ckpt_file}")
         network.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
+        if args.huggingface_repo_id is not None:
+            huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True)
         print("model saved.")
 
 
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index f279370a..e7d052ee 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -13,6 +13,7 @@ import diffusers
 from diffusers import DDPMScheduler
 
 import library.train_util as train_util
+import library.huggingface_util as huggingface_util
 import library.config_util as config_util
 from library.config_util import (
     ConfigSanitizer,
@@ -304,9 +305,7 @@ def train(args):
         text_encoder.to(weight_dtype)
 
     # resumeする
-    if args.resume is not None:
-        print(f"resume training from state: {args.resume}")
-        accelerator.load_state(args.resume)
+    train_util.resume(accelerator, args)
 
     # epoch数を計算する
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -452,6 +451,8 @@ def train(args):
             ckpt_file = os.path.join(args.output_dir, ckpt_name)
             print(f"saving checkpoint: {ckpt_file}")
             save_weights(ckpt_file, updated_embs, save_dtype)
+            if args.huggingface_repo_id is not None:
+                huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)
 
         def remove_old_func(old_epoch_no):
             old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
@@ -492,6 +493,8 @@ def train(args):
 
         print(f"save trained model to {ckpt_file}")
         save_weights(ckpt_file, updated_embs, save_dtype)
+        if args.huggingface_repo_id is not None:
+            huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True)
         print("model saved.")
 
 
diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py
index 74e9bc2e..7e393bcd 100644
--- a/train_textual_inversion_XTI.py
+++ b/train_textual_inversion_XTI.py
@@ -13,6 +13,7 @@ import diffusers
 from diffusers import DDPMScheduler
 
 import library.train_util as train_util
+import library.huggingface_util as huggingface_util
 import library.config_util as config_util
 from library.config_util import (
     ConfigSanitizer,
@@ -493,6 +494,8 @@ def train(args):
             ckpt_file = os.path.join(args.output_dir, ckpt_name)
             print(f"saving checkpoint: {ckpt_file}")
             save_weights(ckpt_file, updated_embs, save_dtype)
+            if args.huggingface_repo_id is not None:
+                huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)
 
         def remove_old_func(old_epoch_no):
             old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
@@ -534,6 +537,8 @@ def train(args):
 
         print(f"save trained model to {ckpt_file}")
         save_weights(ckpt_file, updated_embs, save_dtype)
+        if args.huggingface_repo_id is not None:
+            huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True)
         print("model saved.")
 
 