diff --git a/fine_tune.py b/fine_tune.py index ca42a403..89c16d2d 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -231,7 +231,7 @@ def train(args): train_util.patch_accelerator_for_fp16_training(accelerator) # resumeする - train_util.resume(accelerator, args) + train_util.resume_from_local_or_hf_if_specified(accelerator, args) # epoch数を計算する num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) diff --git a/library/huggingface_util.py b/library/huggingface_util.py index 4431a208..41031b1f 100644 --- a/library/huggingface_util.py +++ b/library/huggingface_util.py @@ -30,7 +30,7 @@ def upload( repo_type = args.huggingface_repo_type token = args.huggingface_token path_in_repo = args.huggingface_path_in_repo + dest_suffix - private = args.huggingface_repo_visibility == "private" + private = args.huggingface_repo_visibility is None or args.huggingface_repo_visibility != "public" api = HfApi(token=token) if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token): api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private) diff --git a/library/train_util.py b/library/train_util.py index 98088c21..d5c5b0ef 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -490,7 +490,7 @@ class BaseDataset(torch.utils.data.Dataset): else: if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0: tokens = [t.strip() for t in caption.strip().split(",")] - if subset.token_warmup_step < 1: # 初回に上書きする + if subset.token_warmup_step < 1: # 初回に上書きする subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps) if subset.token_warmup_step and self.current_step < subset.token_warmup_step: tokens_len = ( @@ -1898,12 +1898,28 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool): parser.add_argument("--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ") parser.add_argument("--output_name", type=str, default=None, help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名") - parser.add_argument("--huggingface_repo_id", type=str, default=None, help="huggingface repo name to upload / huggingfaceにアップロードするリポジトリ名") - parser.add_argument("--huggingface_repo_type", type=str, default=None, help="huggingface repo type to upload / huggingfaceにアップロードするリポジトリの種類") - parser.add_argument("--huggingface_path_in_repo", type=str, default=None, help="huggingface model path to upload files / huggingfaceにアップロードするファイルのパス") + parser.add_argument( + "--huggingface_repo_id", type=str, default=None, help="huggingface repo name to upload / huggingfaceにアップロードするリポジトリ名" + ) + parser.add_argument( + "--huggingface_repo_type", type=str, default=None, help="huggingface repo type to upload / huggingfaceにアップロードするリポジトリの種類" + ) + parser.add_argument( + "--huggingface_path_in_repo", + type=str, + default=None, + help="huggingface model path to upload files / huggingfaceにアップロードするファイルのパス", + ) parser.add_argument("--huggingface_token", type=str, default=None, help="huggingface token / huggingfaceのトークン") - parser.add_argument("--huggingface_repo_visibility", type=str, default=None, help="huggingface repository visibility / huggingfaceにアップロードするリポジトリの公開設定") - parser.add_argument("--save_state_to_huggingface", action="store_true", help="save state to huggingface / huggingfaceにstateを保存する") + parser.add_argument( + "--huggingface_repo_visibility", + type=str, + default=None, + help="huggingface repository visibility ('public' for public, 'private' or None for private) / huggingfaceにアップロードするリポジトリの公開設定('public'で公開、'private'またはNoneで非公開)", + ) + parser.add_argument( + "--save_state_to_huggingface", action="store_true", help="save state to huggingface / huggingfaceにstateを保存する" + ) parser.add_argument( "--resume_from_huggingface", action="store_true", @@ -2278,55 +2294,56 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar # region utils -def resume(accelerator, args): - if args.resume: - print(f"resume training from state: {args.resume}") - if args.resume_from_huggingface: - repo_id = args.resume.split("/")[0] + "/" + args.resume.split("/")[1] - path_in_repo = "/".join(args.resume.split("/")[2:]) - revision = None - repo_type = None - if ":" in path_in_repo: - divided = path_in_repo.split(":") - if len(divided) == 2: - path_in_repo, revision = divided - repo_type = "model" - else: - path_in_repo, revision, repo_type = divided - print( - f"Downloading state from huggingface: {repo_id}/{path_in_repo}@{revision}" - ) - list_files = huggingface_util.list_dir( - repo_id=repo_id, - subfolder=path_in_repo, - revision=revision, - token=args.huggingface_token, - repo_type=repo_type, - ) +def resume_from_local_or_hf_if_specified(accelerator, args): + if not args.resume: + return - async def download(filename) -> str: - def task(): - return hf_hub_download( - repo_id=repo_id, - filename=filename, - revision=revision, - repo_type=repo_type, - token=args.huggingface_token, - ) + if not args.resume_from_huggingface: + print(f"resume training from local state: {args.resume}") + accelerator.load_state(args.resume) + return - return await asyncio.get_event_loop().run_in_executor(None, task) - - loop = asyncio.get_event_loop() - results = loop.run_until_complete( - asyncio.gather( - *[download(filename=filename.rfilename) for filename in list_files] - ) - ) - dirname = os.path.dirname(results[0]) - accelerator.load_state(dirname) + print(f"resume training from huggingface state: {args.resume}") + repo_id = args.resume.split("/")[0] + "/" + args.resume.split("/")[1] + path_in_repo = "/".join(args.resume.split("/")[2:]) + revision = None + repo_type = None + if ":" in path_in_repo: + divided = path_in_repo.split(":") + if len(divided) == 2: + path_in_repo, revision = divided + repo_type = "model" else: - accelerator.load_state(args.resume) + path_in_repo, revision, repo_type = divided + print(f"Downloading state from huggingface: {repo_id}/{path_in_repo}@{revision}") + + list_files = huggingface_util.list_dir( + repo_id=repo_id, + subfolder=path_in_repo, + revision=revision, + token=args.huggingface_token, + repo_type=repo_type, + ) + + async def download(filename) -> str: + def task(): + return hf_hub_download( + repo_id=repo_id, + filename=filename, + revision=revision, + repo_type=repo_type, + token=args.huggingface_token, + ) + + return await asyncio.get_event_loop().run_in_executor(None, task) + + loop = asyncio.get_event_loop() + results = loop.run_until_complete(asyncio.gather(*[download(filename=filename.rfilename) for filename in list_files])) + if len(results) == 0: + raise ValueError("No files found in the specified repo id/path/revision / 指定されたリポジトリID/パス/リビジョンにファイルが見つかりませんでした") + dirname = os.path.dirname(results[0]) + accelerator.load_state(dirname) def get_optimizer(args, trainable_params): @@ -2713,7 +2730,7 @@ def prepare_dtype(args: argparse.Namespace): return weight_dtype, save_dtype -def load_target_model(args: argparse.Namespace, weight_dtype, device='cpu'): +def load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"): name_or_path = args.pretrained_model_name_or_path name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers @@ -2883,6 +2900,7 @@ def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, e state_dir = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no)) accelerator.save_state(state_dir) if args.save_state_to_huggingface: + print("uploading state to huggingface.") huggingface_util.upload(args, state_dir, "/" + EPOCH_STATE_NAME.format(model_name, epoch_no)) last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs @@ -2894,6 +2912,17 @@ def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, e shutil.rmtree(state_dir_old) +def save_state_on_train_end(args: argparse.Namespace, accelerator): + print("saving last state.") + os.makedirs(args.output_dir, exist_ok=True) + model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name + state_dir = os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)) + accelerator.save_state(state_dir) + if args.save_state_to_huggingface: + print("uploading last state to huggingface.") + huggingface_util.upload(args, state_dir, "/" + LAST_STATE_NAME.format(model_name)) + + def save_sd_model_on_train_end( args: argparse.Namespace, src_path: str, @@ -2932,13 +2961,6 @@ def save_sd_model_on_train_end( huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True) -def save_state_on_train_end(args: argparse.Namespace, accelerator): - print("saving last state.") - os.makedirs(args.output_dir, exist_ok=True) - model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name - accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name))) - - # scheduler: SCHEDULER_LINEAR_START = 0.00085 SCHEDULER_LINEAR_END = 0.0120 @@ -3168,7 +3190,7 @@ class collater_class: def __init__(self, epoch, step, dataset): self.current_epoch = epoch self.current_step = step - self.dataset = dataset # not used if worker_info is not None, in case of multiprocessing + self.dataset = dataset # not used if worker_info is not None, in case of multiprocessing def __call__(self, examples): worker_info = torch.utils.data.get_worker_info() diff --git a/train_db.py b/train_db.py index 0b7f2d37..247256ad 100644 --- a/train_db.py +++ b/train_db.py @@ -202,7 +202,7 @@ def train(args): train_util.patch_accelerator_for_fp16_training(accelerator) # resumeする - train_util.resume(accelerator, args) + train_util.resume_from_local_or_hf_if_specified(accelerator, args) # epoch数を計算する num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) diff --git a/train_network.py b/train_network.py index 48ce73f7..e453d708 100644 --- a/train_network.py +++ b/train_network.py @@ -310,7 +310,7 @@ def train(args): train_util.patch_accelerator_for_fp16_training(accelerator) # resumeする - train_util.resume(accelerator, args) + train_util.resume_from_local_or_hf_if_specified(accelerator, args) # epoch数を計算する num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index e7d052ee..d8d803a4 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -305,7 +305,7 @@ def train(args): text_encoder.to(weight_dtype) # resumeする - train_util.resume(accelerator, args) + train_util.resume_from_local_or_hf_if_specified(accelerator, args) # epoch数を計算する num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 7e393bcd..9bd775ef 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -341,9 +341,7 @@ def train(args): text_encoder.to(weight_dtype) # resumeする - if args.resume is not None: - print(f"resume training from state: {args.resume}") - accelerator.load_state(args.resume) + train_util.resume_from_local_or_hf_if_specified(accelerator, args) # epoch数を計算する num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)