mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
passing filtered hyperparameters to accelerate
This commit is contained in:
@@ -310,7 +310,7 @@ def train(args):
|
||||
init_kwargs["wandb"] = {"name": args.wandb_run_name}
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
|
||||
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs)
|
||||
|
||||
# For --sample_at_first
|
||||
train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
@@ -3378,6 +3378,20 @@ def add_masked_loss_arguments(parser: argparse.ArgumentParser):
|
||||
help="apply mask for calculating loss. conditioning_data_dir is required for dataset. / 損失計算時にマスクを適用する。datasetにはconditioning_data_dirが必要",
|
||||
)
|
||||
|
||||
def filter_sensitive_args(args: argparse.Namespace):
|
||||
sensitive_args = ["wandb_api_key", "huggingface_token"]
|
||||
sensitive_path_args = [
|
||||
"pretrained_model_name_or_path",
|
||||
"vae",
|
||||
"tokenizer_cache_dir",
|
||||
"train_data_dir",
|
||||
"conditioning_data_dir",
|
||||
"reg_data_dir",
|
||||
"output_dir",
|
||||
"logging_dir",
|
||||
]
|
||||
filtered_args = {k: v for k, v in vars(args).items() if k not in sensitive_args + sensitive_path_args}
|
||||
return filtered_args
|
||||
|
||||
# verify command line args for training
|
||||
def verify_command_line_training_args(args: argparse.Namespace):
|
||||
|
||||
@@ -487,7 +487,7 @@ def train(args):
|
||||
init_kwargs["wandb"] = {"name": args.wandb_run_name}
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
|
||||
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs)
|
||||
|
||||
# For --sample_at_first
|
||||
sdxl_train_util.sample_images(
|
||||
|
||||
@@ -353,7 +353,7 @@ def train(args):
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers(
|
||||
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
|
||||
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
|
||||
)
|
||||
|
||||
loss_recorder = train_util.LossRecorder()
|
||||
|
||||
@@ -324,7 +324,7 @@ def train(args):
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers(
|
||||
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
|
||||
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
|
||||
)
|
||||
|
||||
loss_recorder = train_util.LossRecorder()
|
||||
|
||||
@@ -344,7 +344,7 @@ def train(args):
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers(
|
||||
"controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
|
||||
"controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
|
||||
)
|
||||
|
||||
loss_recorder = train_util.LossRecorder()
|
||||
|
||||
@@ -290,7 +290,7 @@ def train(args):
|
||||
init_kwargs["wandb"] = {"name": args.wandb_run_name}
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
|
||||
accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs)
|
||||
|
||||
# For --sample_at_first
|
||||
train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
@@ -753,7 +753,7 @@ class NetworkTrainer:
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers(
|
||||
"network_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
|
||||
"network_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
|
||||
)
|
||||
|
||||
loss_recorder = train_util.LossRecorder()
|
||||
|
||||
@@ -510,7 +510,7 @@ class TextualInversionTrainer:
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers(
|
||||
"textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
|
||||
"textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
|
||||
)
|
||||
|
||||
# function for saving/removing
|
||||
|
||||
@@ -407,7 +407,7 @@ def train(args):
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers(
|
||||
"textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
|
||||
"textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
|
||||
)
|
||||
|
||||
# function for saving/removing
|
||||
|
||||
Reference in New Issue
Block a user