diff --git a/fine_tune.py b/fine_tune.py index e3cf247e..2b5255dc 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -408,24 +408,6 @@ if __name__ == "__main__": parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") args = parser.parse_args() - - if args.config_file: - config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file - if os.path.exists(config_path): - print(f"Loading settings from {config_path}...") - with open(config_path, "r") as f: - config_dict = toml.load(f) - - ignore_nesting_dict = {} - for section_name, section_dict in config_dict.items(): - for key, value in section_dict.items(): - ignore_nesting_dict[key] = value - - config_args = argparse.Namespace(**ignore_nesting_dict) - args = parser.parse_args(namespace=config_args) - args.config_file = args.config_file.split(".")[0] - print(args.config_file) - else: - print(f"{config_path} not found.") + args = train_util.read_config_from_file(args, parser) train(args) diff --git a/library/train_util.py b/library/train_util.py index 230985ef..9f541b6c 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3,6 +3,7 @@ import argparse import importlib import json +import pathlib import re import shutil import time @@ -23,6 +24,7 @@ import random import hashlib import subprocess from io import BytesIO +import toml from tqdm import tqdm import torch @@ -1889,7 +1891,15 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: help=f"sampler (scheduler) type for sample images / サンプル出力時のサンプラー(スケジューラ)の種類", ) - parser.add_argument("--config_file", type=str, default=None, help="using .toml instead of args to pass hyperparameter") + parser.add_argument( + "--config_file", + type=str, + default=None, + help="using .toml instead of args to pass hyperparameter / ハイパーパラメータを引数ではなく.tomlファイルで渡す", + ) + parser.add_argument( + "--output_config", action="store_true", help="output command line args to given .toml file / 引数を.tomlファイルに出力する" + ) if support_dreambooth: # DreamBooth training @@ -2016,6 +2026,66 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser): ) +def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentParser): + if not args.config_file: + return args + + config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file + + if args.output_config: + # check if config file exists + if os.path.exists(config_path): + print(f"Config file already exists. Aborting... / 出力先の設定ファイルが既に存在します: {config_path}") + exit(1) + + # convert args to dictionary + args_dict = vars(args) + + # remove unnecessary keys + for key in ["config_file", "output_config"]: + if key in args_dict: + del args_dict[key] + + # convert Path to str in dictionary + for key, value in args_dict.items(): + if isinstance(value, pathlib.Path): + args_dict[key] = str(value) + + # convert to toml and output to file + with open(config_path, "w") as f: + toml.dump(args_dict, f) + + print(f"Saved config file / 設定ファイルを保存しました: {config_path}") + exit(0) + + if not os.path.exists(config_path): + print(f"{config_path} not found.") + exit(1) + + print(f"Loading settings from {config_path}...") + with open(config_path, "r") as f: + config_dict = toml.load(f) + + # combine all sections into one + ignore_nesting_dict = {} + for section_name, section_dict in config_dict.items(): + # if value is not dict, save key and value as is + if not isinstance(section_dict, dict): + ignore_nesting_dict[section_name] = section_dict + continue + + # if value is dict, save all key and value into one dict + for key, value in section_dict.items(): + ignore_nesting_dict[key] = value + + config_args = argparse.Namespace(**ignore_nesting_dict) + args = parser.parse_args(namespace=config_args) + args.config_file = os.path.splitext(args.config_file)[0] + print(args.config_file) + + return args + + # endregion # region utils diff --git a/train_db.py b/train_db.py index 2ad9c69c..c812bbc7 100644 --- a/train_db.py +++ b/train_db.py @@ -411,24 +411,6 @@ if __name__ == "__main__": ) args = parser.parse_args() - - if args.config_file: - config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file - if os.path.exists(config_path): - print(f"Loading settings from {config_path}...") - with open(config_path, "r") as f: - config_dict = toml.load(f) - - ignore_nesting_dict = {} - for section_name, section_dict in config_dict.items(): - for key, value in section_dict.items(): - ignore_nesting_dict[key] = value - - config_args = argparse.Namespace(**ignore_nesting_dict) - args = parser.parse_args(namespace=config_args) - args.config_file = args.config_file.split(".")[0] - print(args.config_file) - else: - print(f"{config_path} not found.") + args = train_util.read_config_from_file(args, parser) train(args) diff --git a/train_network.py b/train_network.py index f78d8e47..ca0da112 100644 --- a/train_network.py +++ b/train_network.py @@ -695,24 +695,6 @@ if __name__ == "__main__": ) args = parser.parse_args() - - if args.config_file: - config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file - if os.path.exists(config_path): - print(f"Loading settings from {config_path}...") - with open(config_path, "r") as f: - config_dict = toml.load(f) - - ignore_nesting_dict = {} - for section_name, section_dict in config_dict.items(): - for key, value in section_dict.items(): - ignore_nesting_dict[key] = value - - config_args = argparse.Namespace(**ignore_nesting_dict) - args = parser.parse_args(namespace=config_args) - args.config_file = args.config_file.split(".")[0] - print(args.config_file) - else: - print(f"{config_path} not found.") + args = train_util.read_config_from_file(args, parser) train(args) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 0e9fba76..f591dea1 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -573,24 +573,6 @@ if __name__ == "__main__": ) args = parser.parse_args() - - if args.config_file: - config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file - if os.path.exists(config_path): - print(f"Loading settings from {config_path}...") - with open(config_path, "r") as f: - config_dict = toml.load(f) - - ignore_nesting_dict = {} - for section_name, section_dict in config_dict.items(): - for key, value in section_dict.items(): - ignore_nesting_dict[key] = value - - config_args = argparse.Namespace(**ignore_nesting_dict) - args = parser.parse_args(namespace=config_args) - args.config_file = args.config_file.split(".")[0] - print(args.config_file) - else: - print(f"{config_path} not found.") + args = train_util.read_config_from_file(args, parser) train(args)