From 5602e0e5fc09c38b5705ca80eae3baff8ec1b115 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 2 Mar 2023 21:51:58 +0900 Subject: [PATCH] change dataset config option to dataset_config --- config_README-ja.md | 2 +- fine_tune.py | 6 +++--- library/config_util.py | 6 +++--- train_db.py | 6 +++--- train_network.py | 6 +++--- train_textual_inversion.py | 6 +++--- 6 files changed, 16 insertions(+), 16 deletions(-) diff --git a/config_README-ja.md b/config_README-ja.md index 91381904..7f2b6c4c 100644 --- a/config_README-ja.md +++ b/config_README-ja.md @@ -1,6 +1,6 @@ For non-Japanese speakers: this README is provided only in Japanese in the current state. Sorry for inconvenience. We will provide English version in the near future. -`--config_file` で渡すことができる設定ファイルに関する説明です。 +`--dataset_config` で渡すことができる設定ファイルに関する説明です。 ## 概要 diff --git a/fine_tune.py b/fine_tune.py index ddc518f5..12557597 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -35,9 +35,9 @@ def train(args): tokenizer = train_util.load_tokenizer(args) blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True)) - if args.config_file is not None: - print(f"Load config file from {args.config_file}") - user_config = config_util.load_user_config(args.config_file) + if args.dataset_config is not None: + print(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored))) diff --git a/library/config_util.py b/library/config_util.py index 10961c4b..a253bfb1 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -38,7 +38,7 @@ from .train_util import ( def add_config_arguments(parser: argparse.ArgumentParser): - parser.add_argument("--config_file", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル") + parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル") # TODO: inherit Params class in Subset, Dataset @@ -495,7 +495,7 @@ if __name__ == "__main__": parser.add_argument("--support_dreambooth", action="store_true") parser.add_argument("--support_finetuning", action="store_true") parser.add_argument("--support_dropout", action="store_true") - parser.add_argument("config_file") + parser.add_argument("dataset_config") config_args, remain = parser.parse_known_args() parser = argparse.ArgumentParser() @@ -507,7 +507,7 @@ if __name__ == "__main__": print("[argparse_namespace]") print(vars(argparse_namespace)) - user_config = load_user_config(config_args.config_file) + user_config = load_user_config(config_args.dataset_config) print("\n[user_config]") print(user_config) diff --git a/train_db.py b/train_db.py index 6ce1367e..a3021177 100644 --- a/train_db.py +++ b/train_db.py @@ -38,9 +38,9 @@ def train(args): tokenizer = train_util.load_tokenizer(args) blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True)) - if args.config_file is not None: - print(f"Load config file from {args.config_file}") - user_config = config_util.load_user_config(args.config_file) + if args.dataset_config is not None: + print(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "reg_data_dir"] if any(getattr(args, attr) is not None for attr in ignored): print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored))) diff --git a/train_network.py b/train_network.py index 28c2c769..7aee6514 100644 --- a/train_network.py +++ b/train_network.py @@ -54,7 +54,7 @@ def train(args): cache_latents = args.cache_latents use_dreambooth_method = args.in_json is None - use_user_config = args.config_file is not None + use_user_config = args.dataset_config is not None if args.seed is not None: set_seed(args.seed) @@ -64,8 +64,8 @@ def train(args): # データセットを準備する blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True)) if use_user_config: - print(f"Load config file from {args.config_file}") - user_config = config_util.load_user_config(args.config_file) + print(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "reg_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): print( diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 669be7e0..d91a78ff 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -143,9 +143,9 @@ def train(args): # データセットを準備する blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False)) - if args.config_file is not None: - print(f"Load config file from {args.config_file}") - user_config = config_util.load_user_config(args.config_file) + if args.dataset_config is not None: + print(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "reg_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))