From b5ff4e816f7b69f0ab0e8081a9b099fd5bb1a8f0 Mon Sep 17 00:00:00 2001 From: ddPn08 Date: Thu, 30 Mar 2023 23:36:42 +0900 Subject: [PATCH] resume from huggingface repository --- fine_tune.py | 4 +-- library/huggingface_util.py | 71 +++++++++++++++++++++++++++++++++++++ library/train_util.py | 61 +++++++++++++++++++++++++++++-- library/utils.py | 60 +------------------------------ requirements.txt | 2 +- train_db.py | 4 +-- train_network.py | 10 +++--- train_textual_inversion.py | 4 +-- 8 files changed, 139 insertions(+), 77 deletions(-) create mode 100644 library/huggingface_util.py diff --git a/fine_tune.py b/fine_tune.py index 637a729a..289fbeb8 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -231,9 +231,7 @@ def train(args): train_util.patch_accelerator_for_fp16_training(accelerator) # resumeする - if args.resume is not None: - print(f"resume training from state: {args.resume}") - accelerator.load_state(args.resume) + train_util.resume(accelerator, args) # epoch数を計算する num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) diff --git a/library/huggingface_util.py b/library/huggingface_util.py new file mode 100644 index 00000000..353189c0 --- /dev/null +++ b/library/huggingface_util.py @@ -0,0 +1,71 @@ +from typing import * +from huggingface_hub import HfApi +from pathlib import Path +import argparse +import os + +from library.utils import fire_in_thread + + +def exists_repo( + repo_id: str, repo_type: str, revision: str = "main", token: str = None +): + api = HfApi( + token=token, + ) + try: + api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type) + return True + except: + return False + + +@fire_in_thread +def upload( + src: Union[str, Path, bytes, BinaryIO], + args: argparse.Namespace, + dest_suffix: str = "", +): + repo_id = args.huggingface_repo_id + repo_type = args.huggingface_repo_type + token = args.huggingface_token + path_in_repo = args.huggingface_path_in_repo + dest_suffix + private = args.huggingface_repo_visibility == "private" + api = HfApi(token=token) + if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token): + api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private) + + is_folder = (type(src) == str and os.path.isdir(src)) or ( + isinstance(src, Path) and src.is_dir() + ) + if is_folder: + api.upload_folder( + repo_id=repo_id, + repo_type=repo_type, + folder_path=src, + path_in_repo=path_in_repo, + ) + else: + api.upload_file( + repo_id=repo_id, + repo_type=repo_type, + path_or_fileobj=src, + path_in_repo=path_in_repo, + ) + + +def list_dir( + repo_id: str, + subfolder: str, + repo_type: str, + revision: str = "main", + token: str = None, +): + api = HfApi( + token=token, + ) + repo_info = api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type) + file_list = [ + file for file in repo_info.siblings if file.rfilename.startswith(subfolder) + ] + return file_list diff --git a/library/train_util.py b/library/train_util.py index 179f23e4..e4e91ee2 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2,6 +2,7 @@ import argparse import ast +import asyncio import importlib import json import pathlib @@ -49,6 +50,7 @@ from diffusers import ( KDPM2DiscreteScheduler, KDPM2AncestralDiscreteScheduler, ) +from huggingface_hub import hf_hub_download import albumentations as albu import numpy as np from PIL import Image @@ -58,7 +60,7 @@ from torch import einsum import safetensors.torch from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline import library.model_util as model_util -import library.utils as utils +import library.huggingface_util as huggingface_util # Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う TOKENIZER_PATH = "openai/clip-vit-large-patch14" @@ -1902,6 +1904,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument("--huggingface_token", type=str, default=None, help="huggingface token to upload model / huggingfaceにアップロードするモデルのトークン") parser.add_argument("--huggingface_repo_visibility", type=str, default=None, help="huggingface model visibility / huggingfaceにアップロードするモデルの公開設定") parser.add_argument("--save_state_to_huggingface", action="store_true", help="save state to huggingface / huggingfaceにstateを保存する") + parser.add_argument( + "--resume_from_huggingface", + action="store_true", + help="resume from huggingface (ex: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type}) / huggingfaceから学習を再開する(例: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type})", + ) parser.add_argument( "--save_precision", type=str, @@ -2266,6 +2273,56 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar # region utils +def resume(accelerator, args): + if args.resume: + print(f"resume training from state: {args.resume}") + if args.resume_from_huggingface: + repo_id = args.resume.split("/")[0] + "/" + args.resume.split("/")[1] + path_in_repo = "/".join(args.resume.split("/")[2:]) + revision = None + repo_type = None + if ":" in path_in_repo: + divided = path_in_repo.split(":") + if len(divided) == 2: + path_in_repo, revision = divided + repo_type = "model" + else: + path_in_repo, revision, repo_type = divided + print( + f"Downloading state from huggingface: {repo_id}/{path_in_repo}@{revision}" + ) + + list_files = huggingface_util.list_dir( + repo_id=repo_id, + subfolder=path_in_repo, + revision=revision, + token=args.huggingface_token, + repo_type=repo_type, + ) + + async def download(filename) -> str: + def task(): + return hf_hub_download( + repo_id=repo_id, + filename=filename, + revision=revision, + repo_type=repo_type, + token=args.huggingface_token, + ) + + return await asyncio.get_event_loop().run_in_executor(None, task) + + loop = asyncio.get_event_loop() + results = loop.run_until_complete( + asyncio.gather( + *[download(filename=filename.rfilename) for filename in list_files] + ) + ) + dirname = os.path.dirname(results[0]) + accelerator.load_state(dirname) + else: + accelerator.load_state(args.resume) + def get_optimizer(args, trainable_params): # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, Adafactor" @@ -2812,7 +2869,7 @@ def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, e state_dir = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no)) accelerator.save_state(state_dir) if args.save_state_to_huggingface: - utils.huggingface_upload(state_dir, args, "/" + EPOCH_STATE_NAME.format(model_name, epoch_no)) + huggingface_util.upload(state_dir, args, "/" + EPOCH_STATE_NAME.format(model_name, epoch_no)) last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs if last_n_epochs is not None: diff --git a/library/utils.py b/library/utils.py index 3c3727d2..a6b05917 100644 --- a/library/utils.py +++ b/library/utils.py @@ -1,66 +1,8 @@ -import argparse -import os -from pathlib import Path import threading from typing import * -from huggingface_hub import HfApi - def fire_in_thread(f): def wrapped(*args, **kwargs): threading.Thread(target=f, args=args, kwargs=kwargs).start() - return wrapped - - -def huggingface_exists_repo( - repo_id: str, repo_type: str, revision: str = "main", token: str = None -): - api = HfApi() - try: - api.repo_info( - repo_id=repo_id, token=token, revision=revision, repo_type=repo_type - ) - return True - except: - return False - - -@fire_in_thread -def huggingface_upload( - src: Union[str, Path, bytes, BinaryIO], - args: argparse.Namespace, - dest_suffix: str = "", -): - repo_id = args.huggingface_repo_id - repo_type = args.huggingface_repo_type - token = args.huggingface_token - path_in_repo = args.huggingface_path_in_repo + dest_suffix - private = args.huggingface_repo_visibility == "private" - api = HfApi() - if not huggingface_exists_repo( - repo_id=repo_id, repo_type=repo_type, token=token - ): - api.create_repo( - token=token, repo_id=repo_id, repo_type=repo_type, private=private - ) - - is_folder = (type(src) == str and os.path.isdir(src)) or ( - isinstance(src, Path) and src.is_dir() - ) - if is_folder: - api.upload_folder( - repo_id=repo_id, - repo_type=repo_type, - folder_path=src, - path_in_repo=path_in_repo, - token=token, - ) - else: - api.upload_file( - repo_id=repo_id, - repo_type=repo_type, - path_or_fileobj=src, - path_in_repo=path_in_repo, - token=token, - ) + return wrapped \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index eea1c663..d3164894 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,6 +21,6 @@ fairscale==0.4.13 # for WD14 captioning # tensorflow<2.11 tensorflow==2.10.1 -huggingface-hub==0.12.0 +huggingface-hub==0.13.3 # for kohya_ss library . diff --git a/train_db.py b/train_db.py index b3eead94..0b7f2d37 100644 --- a/train_db.py +++ b/train_db.py @@ -202,9 +202,7 @@ def train(args): train_util.patch_accelerator_for_fp16_training(accelerator) # resumeする - if args.resume is not None: - print(f"resume training from state: {args.resume}") - accelerator.load_state(args.resume) + train_util.resume(accelerator, args) # epoch数を計算する num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) diff --git a/train_network.py b/train_network.py index c951b150..8cfe1ab8 100644 --- a/train_network.py +++ b/train_network.py @@ -24,7 +24,7 @@ from library.config_util import ( ConfigSanitizer, BlueprintGenerator, ) -import library.utils as utils +import library.huggingface_util as huggingface_util import library.custom_train_functions as custom_train_functions from library.custom_train_functions import apply_snr_weight @@ -285,9 +285,7 @@ def train(args): train_util.patch_accelerator_for_fp16_training(accelerator) # resumeする - if args.resume is not None: - print(f"resume training from state: {args.resume}") - accelerator.load_state(args.resume) + train_util.resume(accelerator, args) # epoch数を計算する num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) @@ -628,7 +626,7 @@ def train(args): metadata["ss_training_finished_at"] = str(time.time()) print(f"saving checkpoint: {ckpt_file}") unwrap_model(network).save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata) - utils.huggingface_upload(ckpt_file, args, "/" + ckpt_name) + huggingface_util.upload(ckpt_file, args, "/" + ckpt_name) def remove_old_func(old_epoch_no): old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as @@ -668,7 +666,7 @@ def train(args): print(f"save trained model to {ckpt_file}") network.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata) - utils.huggingface_upload(ckpt_file, args, "/" + ckpt_name) + huggingface_util.upload(ckpt_file, args, "/" + ckpt_name) print("model saved.") diff --git a/train_textual_inversion.py b/train_textual_inversion.py index f279370a..c5bacf3b 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -304,9 +304,7 @@ def train(args): text_encoder.to(weight_dtype) # resumeする - if args.resume is not None: - print(f"resume training from state: {args.resume}") - accelerator.load_state(args.resume) + train_util.resume(accelerator, args) # epoch数を計算する num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)