resume from huggingface repository

This commit is contained in:
ddPn08
2023-03-30 23:36:42 +09:00
parent a7d302e196
commit b5ff4e816f
8 changed files with 139 additions and 77 deletions

View File

@@ -231,9 +231,7 @@ def train(args):
train_util.patch_accelerator_for_fp16_training(accelerator)
# resume from a saved training state
if args.resume is not None:
print(f"resume training from state: {args.resume}")
accelerator.load_state(args.resume)
train_util.resume(accelerator, args)
# epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)

View File

@@ -0,0 +1,71 @@
from typing import *
from huggingface_hub import HfApi
from pathlib import Path
import argparse
import os
from library.utils import fire_in_thread
def exists_repo(
    repo_id: str, repo_type: str, revision: str = "main", token: Optional[str] = None
):
    """Return True if the given repository/revision exists on the Hugging Face Hub.

    Args:
        repo_id: full repository id, e.g. "user/repo".
        repo_type: hub repository type ("model", "dataset", ...).
        revision: branch, tag or commit to check (default "main").
        token: optional access token for private repositories.
    """
    api = HfApi(
        token=token,
    )
    try:
        api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
        return True
    except Exception:
        # repo_info raises (e.g. RepositoryNotFoundError / RevisionNotFoundError)
        # when the repo is missing or inaccessible; treat any failure as "absent".
        # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt propagate.
        return False
@fire_in_thread
def upload(
    src: Union[str, Path, bytes, BinaryIO],
    args: argparse.Namespace,
    dest_suffix: str = "",
):
    """Upload a file or a directory to the Hugging Face Hub (fired in a background thread).

    Creates the target repository first if it does not exist yet.

    Args:
        src: local path (str/Path) or in-memory payload (bytes/file object) to upload.
        args: parsed CLI options; reads huggingface_repo_id, huggingface_repo_type,
            huggingface_token, huggingface_path_in_repo and huggingface_repo_visibility.
        dest_suffix: appended to args.huggingface_path_in_repo to form the destination
            path inside the repository.
    """
    repo_id = args.huggingface_repo_id
    repo_type = args.huggingface_repo_type
    token = args.huggingface_token
    path_in_repo = args.huggingface_path_in_repo + dest_suffix
    private = args.huggingface_repo_visibility == "private"
    api = HfApi(token=token)
    if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token):
        api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)

    # isinstance (rather than `type(src) == str`) also accepts str subclasses;
    # bytes / file-like objects are never paths and fall through to upload_file.
    is_folder = isinstance(src, (str, Path)) and Path(src).is_dir()
    if is_folder:
        api.upload_folder(
            repo_id=repo_id,
            repo_type=repo_type,
            folder_path=src,
            path_in_repo=path_in_repo,
        )
    else:
        api.upload_file(
            repo_id=repo_id,
            repo_type=repo_type,
            path_or_fileobj=src,
            path_in_repo=path_in_repo,
        )
def list_dir(
    repo_id: str,
    subfolder: str,
    repo_type: str,
    revision: str = "main",
    token: str = None,
):
    """List the files of a hub repository that live under *subfolder*.

    Returns the sibling entries of the repo whose path (rfilename) starts
    with the given subfolder prefix.
    """
    api = HfApi(token=token)
    info = api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
    # Siblings carry the full in-repo path; a simple prefix match selects the folder.
    matches = [entry for entry in info.siblings if entry.rfilename.startswith(subfolder)]
    return matches

View File

@@ -2,6 +2,7 @@
import argparse
import ast
import asyncio
import importlib
import json
import pathlib
@@ -49,6 +50,7 @@ from diffusers import (
KDPM2DiscreteScheduler,
KDPM2AncestralDiscreteScheduler,
)
from huggingface_hub import hf_hub_download
import albumentations as albu
import numpy as np
from PIL import Image
@@ -58,7 +60,7 @@ from torch import einsum
import safetensors.torch
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
import library.model_util as model_util
import library.utils as utils
import library.huggingface_util as huggingface_util
# Tokenizer: use the pre-provided tokenizer instead of loading it from the checkpoint
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
@@ -1902,6 +1904,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
parser.add_argument("--huggingface_token", type=str, default=None, help="huggingface token to upload model / huggingfaceにアップロードするモデルのトークン")
parser.add_argument("--huggingface_repo_visibility", type=str, default=None, help="huggingface model visibility / huggingfaceにアップロードするモデルの公開設定")
parser.add_argument("--save_state_to_huggingface", action="store_true", help="save state to huggingface / huggingfaceにstateを保存する")
parser.add_argument(
"--resume_from_huggingface",
action="store_true",
help="resume from huggingface (ex: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type}) / huggingfaceから学習を再開する(例: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type})",
)
parser.add_argument(
"--save_precision",
type=str,
@@ -2266,6 +2273,56 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar
# region utils
def resume(accelerator, args):
    """Restore accelerator training state, locally or from the Hugging Face Hub.

    If --resume_from_huggingface is set, args.resume is parsed as
    "{repo_id}/{path_in_repo}[:{revision}[:{repo_type}]]", the state files are
    downloaded, and the state is loaded from the download directory. Otherwise
    args.resume is treated as a local state directory.

    Args:
        accelerator: accelerate Accelerator whose load_state() is invoked.
        args: parsed CLI options; reads resume, resume_from_huggingface and
            huggingface_token.
    """
    if not args.resume:
        return
    print(f"resume training from state: {args.resume}")

    if not args.resume_from_huggingface:
        accelerator.load_state(args.resume)
        return

    # Parse "{repo_id}/{path_in_repo}[:{revision}[:{repo_type}]]".
    # repo_id is always the first two path components ("user/repo").
    parts = args.resume.split("/", 2)
    if len(parts) < 2:
        raise ValueError(
            f"--resume must look like {{repo_id}}/{{path_in_repo}} when "
            f"--resume_from_huggingface is set, got: {args.resume}"
        )
    repo_id = parts[0] + "/" + parts[1]
    path_in_repo = parts[2] if len(parts) == 3 else ""
    revision = None
    repo_type = None
    if ":" in path_in_repo:
        divided = path_in_repo.split(":")
        if len(divided) == 2:
            path_in_repo, revision = divided
            repo_type = "model"
        else:
            path_in_repo, revision, repo_type = divided
    print(
        f"Downloading state from huggingface: {repo_id}/{path_in_repo}@{revision}"
    )

    list_files = huggingface_util.list_dir(
        repo_id=repo_id,
        subfolder=path_in_repo,
        revision=revision,
        token=args.huggingface_token,
        repo_type=repo_type,
    )

    # hf_hub_download is a blocking call; download the state files in parallel
    # with a thread pool instead of the deprecated asyncio.get_event_loop()
    # + run_until_complete pattern (which warns/raises on modern Python).
    from concurrent.futures import ThreadPoolExecutor

    def _download(filename: str) -> str:
        return hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            revision=revision,
            repo_type=repo_type,
            token=args.huggingface_token,
        )

    with ThreadPoolExecutor() as pool:
        results = list(pool.map(_download, (f.rfilename for f in list_files)))

    if not results:
        raise ValueError(
            f"no state files found in huggingface repo: {repo_id}/{path_in_repo}"
        )
    # All files land in the same snapshot directory; load the state from there.
    dirname = os.path.dirname(results[0])
    accelerator.load_state(dirname)
def get_optimizer(args, trainable_params):
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, Adafactor"
@@ -2812,7 +2869,7 @@ def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, e
state_dir = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no))
accelerator.save_state(state_dir)
if args.save_state_to_huggingface:
utils.huggingface_upload(state_dir, args, "/" + EPOCH_STATE_NAME.format(model_name, epoch_no))
huggingface_util.upload(state_dir, args, "/" + EPOCH_STATE_NAME.format(model_name, epoch_no))
last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs
if last_n_epochs is not None:

View File

@@ -1,66 +1,8 @@
import argparse
import os
from pathlib import Path
import threading
from typing import *
from huggingface_hub import HfApi
def fire_in_thread(f):
    """Decorator: run *f* in a new background thread on each call (fire-and-forget).

    The wrapper starts the thread and returns immediately (None); the caller
    cannot observe the result or exceptions of *f*.
    """
    import functools

    @functools.wraps(f)  # preserve the wrapped function's name/docstring
    def wrapped(*args, **kwargs):
        threading.Thread(target=f, args=args, kwargs=kwargs).start()

    return wrapped
def huggingface_exists_repo(
    repo_id: str, repo_type: str, revision: str = "main", token: str = None
):
    """Return True if the given repository/revision exists on the Hugging Face Hub.

    Args:
        repo_id: full repository id, e.g. "user/repo".
        repo_type: hub repository type ("model", "dataset", ...).
        revision: branch, tag or commit to check (default "main").
        token: optional access token for private repositories.
    """
    api = HfApi()
    try:
        api.repo_info(
            repo_id=repo_id, token=token, revision=revision, repo_type=repo_type
        )
        return True
    except Exception:
        # repo_info raises when the repo is missing or inaccessible; treat any
        # failure as "absent". Narrowed from a bare `except:` so that
        # SystemExit/KeyboardInterrupt are not swallowed.
        return False
@fire_in_thread
def huggingface_upload(
    src: Union[str, Path, bytes, BinaryIO],
    args: argparse.Namespace,
    dest_suffix: str = "",
):
    """Upload a file or a directory to the Hugging Face Hub (fired in a background thread).

    Creates the target repository first if it does not exist yet.

    Args:
        src: local path (str/Path) or in-memory payload (bytes/file object) to upload.
        args: parsed CLI options; reads huggingface_repo_id, huggingface_repo_type,
            huggingface_token, huggingface_path_in_repo and huggingface_repo_visibility.
        dest_suffix: appended to args.huggingface_path_in_repo to form the destination
            path inside the repository.
    """
    repo_id = args.huggingface_repo_id
    repo_type = args.huggingface_repo_type
    token = args.huggingface_token
    path_in_repo = args.huggingface_path_in_repo + dest_suffix
    private = args.huggingface_repo_visibility == "private"
    api = HfApi()
    if not huggingface_exists_repo(
        repo_id=repo_id, repo_type=repo_type, token=token
    ):
        api.create_repo(
            token=token, repo_id=repo_id, repo_type=repo_type, private=private
        )

    # isinstance (rather than `type(src) == str`) also accepts str subclasses;
    # bytes / file-like objects are never paths and fall through to upload_file.
    is_folder = isinstance(src, (str, Path)) and Path(src).is_dir()
    if is_folder:
        api.upload_folder(
            repo_id=repo_id,
            repo_type=repo_type,
            folder_path=src,
            path_in_repo=path_in_repo,
            token=token,
        )
    else:
        api.upload_file(
            repo_id=repo_id,
            repo_type=repo_type,
            path_or_fileobj=src,
            path_in_repo=path_in_repo,
            token=token,
        )

View File

@@ -21,6 +21,6 @@ fairscale==0.4.13
# for WD14 captioning
# tensorflow<2.11
tensorflow==2.10.1
huggingface-hub==0.12.0
huggingface-hub==0.13.3
# for kohya_ss library
.

View File

@@ -202,9 +202,7 @@ def train(args):
train_util.patch_accelerator_for_fp16_training(accelerator)
# resumeする
if args.resume is not None:
print(f"resume training from state: {args.resume}")
accelerator.load_state(args.resume)
train_util.resume(accelerator, args)
# epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)

View File

@@ -24,7 +24,7 @@ from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
import library.utils as utils
import library.huggingface_util as huggingface_util
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight
@@ -285,9 +285,7 @@ def train(args):
train_util.patch_accelerator_for_fp16_training(accelerator)
# resumeする
if args.resume is not None:
print(f"resume training from state: {args.resume}")
accelerator.load_state(args.resume)
train_util.resume(accelerator, args)
# epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -628,7 +626,7 @@ def train(args):
metadata["ss_training_finished_at"] = str(time.time())
print(f"saving checkpoint: {ckpt_file}")
unwrap_model(network).save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
utils.huggingface_upload(ckpt_file, args, "/" + ckpt_name)
huggingface_util.upload(ckpt_file, args, "/" + ckpt_name)
def remove_old_func(old_epoch_no):
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
@@ -668,7 +666,7 @@ def train(args):
print(f"save trained model to {ckpt_file}")
network.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
utils.huggingface_upload(ckpt_file, args, "/" + ckpt_name)
huggingface_util.upload(ckpt_file, args, "/" + ckpt_name)
print("model saved.")

View File

@@ -304,9 +304,7 @@ def train(args):
text_encoder.to(weight_dtype)
# resumeする
if args.resume is not None:
print(f"resume training from state: {args.resume}")
accelerator.load_state(args.resume)
train_util.resume(accelerator, args)
# epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)