From 055f02e1e16f2b5d71d48e9aba7e751b825e6b46 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 8 Feb 2024 21:16:42 +0900 Subject: [PATCH] add logging args for training scripts --- fine_tune.py | 3 +- sdxl_train.py | 25 +++++++++++---- sdxl_train_control_net_lllite.py | 28 ++++++++++++----- sdxl_train_control_net_lllite_old.py | 26 ++++++++++++---- train_controlnet.py | 16 +++++++--- train_db.py | 11 +++++-- train_network.py | 46 +++++++++++++++++++++------- train_textual_inversion.py | 13 ++++++-- train_textual_inversion_XTI.py | 27 +++++++++++----- 9 files changed, 147 insertions(+), 48 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 86f5b253..66f07288 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -24,6 +24,7 @@ setup_logging() import logging logger = logging.getLogger(__name__) + import library.train_util as train_util import library.config_util as config_util from library.config_util import ( @@ -476,7 +477,7 @@ def train(args): def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() - add_logging_arguments(parser) # モジュール指定がないのがちょっと気持ち悪い / bit weird that this does not have module prefix + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) diff --git a/sdxl_train.py b/sdxl_train.py index 8f1740af..3a09e4c8 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -20,10 +20,14 @@ from diffusers import DDPMScheduler from library import sdxl_model_util import library.train_util as train_util -from library.utils import setup_logging + +from library.utils import setup_logging, add_logging_arguments + setup_logging() import logging + logger = logging.getLogger(__name__) + import library.config_util as config_util import library.sdxl_train_util as sdxl_train_util from library.config_util import ( @@ -96,7 +100,9 @@ def train(args): train_util.prepare_dataset_args(args, True) sdxl_train_util.verify_sdxl_training_args(args) - assert not args.weighted_captions, "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" + assert ( + not args.weighted_captions + ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" assert ( not args.train_text_encoder or not args.cache_text_encoder_outputs ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" @@ -367,7 +373,9 @@ def train(args): args.max_train_steps = args.max_train_epochs * math.ceil( len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) - accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) @@ -437,7 +445,9 @@ def train(args): accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}") accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") accelerator.print(f" num epochs / epoch数: {num_train_epochs}") - accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) # accelerator.print( # f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" # ) @@ -457,7 +467,7 @@ def train(args): if accelerator.is_main_process: init_kwargs = {} if args.wandb_run_name: - init_kwargs['wandb'] = {'name': args.wandb_run_name} + init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) @@ -734,6 +744,7 @@ def train(args): def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, True, True, True) train_util.add_training_arguments(parser, False) @@ -756,7 +767,9 @@ def setup_parser() -> argparse.ArgumentParser: help="learning rate for text encoder 2 (BiG-G) / text encoder 2 (BiG-G)の学習率", ) - parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する") + parser.add_argument( + "--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する" + ) parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") parser.add_argument( "--no_half_vae", diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 36dd72e6..15ee2f85 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -45,11 +45,14 @@ from library.custom_train_functions import ( apply_debiased_estimation, ) import networks.control_net_lllite_for_train as control_net_lllite_for_train -from library.utils import setup_logging +from library.utils import setup_logging, add_logging_arguments + setup_logging() import logging + logger = logging.getLogger(__name__) + # TODO 他のスクリプトと共通化する def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): logs = { @@ -127,7 +130,9 @@ def train(args): train_dataset_group.is_latent_cacheable() ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" else: - logger.warning("WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません") + logger.warning( + "WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません" + ) if args.cache_text_encoder_outputs: assert ( @@ -257,7 +262,9 @@ def train(args): args.max_train_steps = args.max_train_epochs * math.ceil( len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) - accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) @@ -326,7 +333,9 @@ def train(args): accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") accelerator.print(f" num epochs / epoch数: {num_train_epochs}") - accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) # logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") @@ -344,7 +353,7 @@ def train(args): if accelerator.is_main_process: init_kwargs = {} if args.wandb_run_name: - init_kwargs['wandb'] = {'name': args.wandb_run_name} + init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( @@ -557,6 +566,7 @@ def train(args): def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) @@ -572,8 +582,12 @@ def setup_parser() -> argparse.ArgumentParser: choices=[None, "ckpt", "pt", "safetensors"], help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)", ) - parser.add_argument("--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数") - parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み") + parser.add_argument( + "--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数" + ) + parser.add_argument( + "--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み" + ) parser.add_argument("--network_dim", type=int, default=None, help="network dimensions (rank) / モジュールの次元数") parser.add_argument( "--network_dropout", diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index 554b79ce..542cd734 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -41,11 +41,14 @@ from library.custom_train_functions import ( apply_debiased_estimation, ) import networks.control_net_lllite as control_net_lllite -from library.utils import setup_logging +from library.utils import setup_logging, add_logging_arguments + setup_logging() import logging + logger = logging.getLogger(__name__) + # TODO 他のスクリプトと共通化する def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): logs = { @@ -123,7 +126,9 @@ def train(args): train_dataset_group.is_latent_cacheable() ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" else: - logger.warning("WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません") + logger.warning( + "WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません" + ) if args.cache_text_encoder_outputs: assert ( @@ -225,7 +230,9 @@ def train(args): args.max_train_steps = args.max_train_epochs * math.ceil( len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) - accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) @@ -299,7 +306,9 @@ def train(args): accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") accelerator.print(f" num epochs / epoch数: {num_train_epochs}") - accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) # logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") @@ -525,6 +534,7 @@ def train(args): def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) @@ -540,8 +550,12 @@ def setup_parser() -> argparse.ArgumentParser: choices=[None, "ckpt", "pt", "safetensors"], help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)", ) - parser.add_argument("--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数") - parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み") + parser.add_argument( + "--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数" + ) + parser.add_argument( + "--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み" + ) parser.add_argument("--network_dim", type=int, default=None, help="network dimensions (rank) / モジュールの次元数") parser.add_argument( "--network_dropout", diff --git a/train_controlnet.py b/train_controlnet.py index fbeb7afd..7489c48f 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -35,11 +35,14 @@ from library.custom_train_functions import ( pyramid_noise_like, apply_noise_offset, ) -from library.utils import setup_logging +from library.utils import setup_logging, add_logging_arguments + setup_logging() import logging + logger = logging.getLogger(__name__) + # TODO 他のスクリプトと共通化する def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): logs = { @@ -256,7 +259,9 @@ def train(args): args.max_train_steps = args.max_train_epochs * math.ceil( len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) - accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) @@ -312,7 +317,9 @@ def train(args): accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") accelerator.print(f" num epochs / epoch数: {num_train_epochs}") - accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) # logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") @@ -335,7 +342,7 @@ def train(args): if accelerator.is_main_process: init_kwargs = {} if args.wandb_run_name: - init_kwargs['wandb'] = {'name': args.wandb_run_name} + init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( @@ -576,6 +583,7 @@ def train(args): def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) diff --git a/train_db.py b/train_db.py index 98438666..09cb6900 100644 --- a/train_db.py +++ b/train_db.py @@ -35,9 +35,11 @@ from library.custom_train_functions import ( scale_v_prediction_loss_like_noise_prediction, apply_debiased_estimation, ) -from library.utils import setup_logging +from library.utils import setup_logging, add_logging_arguments + setup_logging() import logging + logger = logging.getLogger(__name__) # perlin_noise, @@ -197,7 +199,9 @@ def train(args): args.max_train_steps = args.max_train_epochs * math.ceil( len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) - accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) @@ -268,7 +272,7 @@ def train(args): if accelerator.is_main_process: init_kwargs = {} if args.wandb_run_name: - init_kwargs['wandb'] = {'name': args.wandb_run_name} + init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) @@ -459,6 +463,7 @@ def train(args): def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, True, False, True) train_util.add_training_arguments(parser, True) diff --git a/train_network.py b/train_network.py index 99c95aac..bd5e3693 100644 --- a/train_network.py +++ b/train_network.py @@ -41,11 +41,14 @@ from library.custom_train_functions import ( add_v_prediction_like_loss, apply_debiased_estimation, ) -from library.utils import setup_logging +from library.utils import setup_logging, add_logging_arguments + setup_logging() import logging + logger = logging.getLogger(__name__) + class NetworkTrainer: def __init__(self): self.vae_scale_factor = 0.18215 @@ -947,6 +950,7 @@ class NetworkTrainer: def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, True, True, True) train_util.add_training_arguments(parser, True) @@ -954,7 +958,9 @@ def setup_parser() -> argparse.ArgumentParser: config_util.add_config_arguments(parser) custom_train_functions.add_custom_train_arguments(parser) - parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない") + parser.add_argument( + "--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない" + ) parser.add_argument( "--save_model_as", type=str, @@ -966,10 +972,17 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率") parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率") - parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み") - parser.add_argument("--network_module", type=str, default=None, help="network module to train / 学習対象のネットワークのモジュール") parser.add_argument( - "--network_dim", type=int, default=None, help="network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)" + "--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み" + ) + parser.add_argument( + "--network_module", type=str, default=None, help="network module to train / 学習対象のネットワークのモジュール" + ) + parser.add_argument( + "--network_dim", + type=int, + default=None, + help="network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)", ) parser.add_argument( "--network_alpha", @@ -984,14 +997,25 @@ def setup_parser() -> argparse.ArgumentParser: help="Drops neurons out of training every step (0 or None is default behavior (no dropout), 1 would drop all neurons) / 訓練時に毎ステップでニューロンをdropする(0またはNoneはdropoutなし、1は全ニューロンをdropout)", ) parser.add_argument( - "--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数" - ) - parser.add_argument("--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する") - parser.add_argument( - "--network_train_text_encoder_only", action="store_true", help="only training Text Encoder part / Text Encoder関連部分のみ学習する" + "--network_args", + type=str, + default=None, + nargs="*", + help="additional arguments for network (key=value) / ネットワークへの追加の引数", ) parser.add_argument( - "--training_comment", type=str, default=None, help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列" + "--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する" + ) + parser.add_argument( + "--network_train_text_encoder_only", + action="store_true", + help="only training Text Encoder part / Text Encoder関連部分のみ学習する", + ) + parser.add_argument( + "--training_comment", + type=str, + default=None, + help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列", ) parser.add_argument( "--dim_from_weights", diff --git a/train_textual_inversion.py b/train_textual_inversion.py index d3d10f17..5900ad19 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -32,9 +32,11 @@ from library.custom_train_functions import ( add_v_prediction_like_loss, apply_debiased_estimation, ) -from library.utils import setup_logging +from library.utils import setup_logging, add_logging_arguments + setup_logging() import logging + logger = logging.getLogger(__name__) imagenet_templates_small = [ @@ -746,6 +748,7 @@ class TextualInversionTrainer: def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, True, True, False) train_util.add_training_arguments(parser, True) @@ -761,7 +764,9 @@ def setup_parser() -> argparse.ArgumentParser: help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)", ) - parser.add_argument("--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み") + parser.add_argument( + "--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み" + ) parser.add_argument( "--num_vectors_per_token", type=int, default=1, help="number of vectors per token / トークンに割り当てるembeddingsの要素数" ) @@ -771,7 +776,9 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること", ) - parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可") + parser.add_argument( + "--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可" + ) parser.add_argument( "--use_object_template", action="store_true", diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 45ced3cf..9b7e04c8 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -36,9 +36,11 @@ from library.custom_train_functions import ( ) import library.original_unet as original_unet from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI -from library.utils import setup_logging +from library.utils import setup_logging, add_logging_arguments + setup_logging() import logging + logger = logging.getLogger(__name__) imagenet_templates_small = [ @@ -322,7 +324,9 @@ def train(args): args.max_train_steps = args.max_train_epochs * math.ceil( len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) - logger.info(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + logger.info( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) @@ -380,7 +384,9 @@ def train(args): logger.info(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") logger.info(f" num epochs / epoch数: {num_train_epochs}") logger.info(f" batch size per device / バッチサイズ: {args.train_batch_size}") - logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + logger.info( + f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" + ) logger.info(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") logger.info(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") @@ -397,10 +403,12 @@ def train(args): if accelerator.is_main_process: init_kwargs = {} if args.wandb_run_name: - init_kwargs['wandb'] = {'name': args.wandb_run_name} + init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) + accelerator.init_trackers( + "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs + ) # function for saving/removing def save_model(ckpt_name, embs, steps, epoch_no, force_sync_upload=False): @@ -653,6 +661,7 @@ def load_weights(file): def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, True, True, False) train_util.add_training_arguments(parser, True) @@ -668,7 +677,9 @@ def setup_parser() -> argparse.ArgumentParser: help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)", ) - parser.add_argument("--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み") + parser.add_argument( + "--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み" + ) parser.add_argument( "--num_vectors_per_token", type=int, default=1, help="number of vectors per token / トークンに割り当てるembeddingsの要素数" ) @@ -678,7 +689,9 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること", ) - parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可") + parser.add_argument( + "--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可" + ) parser.add_argument( "--use_object_template", action="store_true",