change async uploading to optional

This commit is contained in:
ddPn08
2023-04-02 17:45:26 +09:00
parent 8bfa50e283
commit 16ba1cec69
5 changed files with 38 additions and 26 deletions

View File

@@ -20,11 +20,11 @@ def exists_repo(
return False
@fire_in_thread
def upload(
src: Union[str, Path, bytes, BinaryIO],
args: argparse.Namespace,
src: Union[str, Path, bytes, BinaryIO],
dest_suffix: str = "",
force_sync_upload: bool = False,
):
repo_id = args.huggingface_repo_id
repo_type = args.huggingface_repo_type
@@ -38,20 +38,27 @@ def upload(
is_folder = (type(src) == str and os.path.isdir(src)) or (
isinstance(src, Path) and src.is_dir()
)
if is_folder:
api.upload_folder(
repo_id=repo_id,
repo_type=repo_type,
folder_path=src,
path_in_repo=path_in_repo,
)
def uploader():
if is_folder:
api.upload_folder(
repo_id=repo_id,
repo_type=repo_type,
folder_path=src,
path_in_repo=path_in_repo,
)
else:
api.upload_file(
repo_id=repo_id,
repo_type=repo_type,
path_or_fileobj=src,
path_in_repo=path_in_repo,
)
if args.async_upload and not force_sync_upload:
fire_in_thread(uploader)
else:
api.upload_file(
repo_id=repo_id,
repo_type=repo_type,
path_or_fileobj=src,
path_in_repo=path_in_repo,
)
uploader()
def list_dir(

View File

@@ -1909,6 +1909,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
action="store_true",
help="resume from huggingface (ex: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type}) / huggingfaceから学習を再開する(例: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type})",
)
parser.add_argument(
"--async_upload",
action="store_true",
help="upload to huggingface asynchronously / huggingfaceに非同期でアップロードする",
)
parser.add_argument(
"--save_precision",
type=str,
@@ -2831,7 +2836,7 @@ def save_sd_model_on_epoch_end(
args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae
)
if args.huggingface_repo_id is not None:
huggingface_util.upload(ckpt_file, args, "/" + ckpt_name)
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)
def remove_sd(old_epoch_no):
_, old_ckpt_name = get_epoch_ckpt_name(args, use_safetensors, old_epoch_no)
@@ -2852,7 +2857,7 @@ def save_sd_model_on_epoch_end(
args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
)
if args.huggingface_repo_id is not None:
huggingface_util.upload(out_dir, args, "/" + model_name)
huggingface_util.upload(args, out_dir, "/" + model_name)
def remove_du(old_epoch_no):
out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, old_epoch_no))
@@ -2873,7 +2878,7 @@ def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, e
state_dir = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no))
accelerator.save_state(state_dir)
if args.save_state_to_huggingface:
huggingface_util.upload(state_dir, args, "/" + EPOCH_STATE_NAME.format(model_name, epoch_no))
huggingface_util.upload(args, state_dir, "/" + EPOCH_STATE_NAME.format(model_name, epoch_no))
last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs
if last_n_epochs is not None:
@@ -2909,7 +2914,7 @@ def save_sd_model_on_train_end(
args.v2, ckpt_file, text_encoder, unet, src_path, epoch, global_step, save_dtype, vae
)
if args.huggingface_repo_id is not None:
huggingface_util.upload(ckpt_file, args, "/" + ckpt_name)
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True)
else:
out_dir = os.path.join(args.output_dir, model_name)
os.makedirs(out_dir, exist_ok=True)
@@ -2919,7 +2924,7 @@ def save_sd_model_on_train_end(
args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
)
if args.huggingface_repo_id is not None:
huggingface_util.upload(out_dir, args, "/" + model_name)
huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)
def save_state_on_train_end(args: argparse.Namespace, accelerator):

View File

@@ -627,7 +627,7 @@ def train(args):
print(f"saving checkpoint: {ckpt_file}")
unwrap_model(network).save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
if args.huggingface_repo_id is not None:
huggingface_util.upload(ckpt_file, args, "/" + ckpt_name)
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)
def remove_old_func(old_epoch_no):
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
@@ -668,7 +668,7 @@ def train(args):
print(f"save trained model to {ckpt_file}")
network.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
if args.huggingface_repo_id is not None:
huggingface_util.upload(ckpt_file, args, "/" + ckpt_name)
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True)
print("model saved.")

View File

@@ -452,7 +452,7 @@ def train(args):
print(f"saving checkpoint: {ckpt_file}")
save_weights(ckpt_file, updated_embs, save_dtype)
if args.huggingface_repo_id is not None:
huggingface_util.upload(ckpt_file, args, "/" + ckpt_name)
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)
def remove_old_func(old_epoch_no):
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
@@ -494,7 +494,7 @@ def train(args):
print(f"save trained model to {ckpt_file}")
save_weights(ckpt_file, updated_embs, save_dtype)
if args.huggingface_repo_id is not None:
huggingface_util.upload(ckpt_file, args, "/" + ckpt_name)
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True)
print("model saved.")

View File

@@ -495,7 +495,7 @@ def train(args):
print(f"saving checkpoint: {ckpt_file}")
save_weights(ckpt_file, updated_embs, save_dtype)
if args.huggingface_repo_id is not None:
huggingface_util.upload(ckpt_file, args, "/" + ckpt_name)
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)
def remove_old_func(old_epoch_no):
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
@@ -538,7 +538,7 @@ def train(args):
print(f"save trained model to {ckpt_file}")
save_weights(ckpt_file, updated_embs, save_dtype)
if args.huggingface_repo_id is not None:
huggingface_util.upload(ckpt_file, args, "/" + ckpt_name)
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True)
print("model saved.")