refactor config parse, feature to output config

This commit is contained in:
Kohya S
2023-03-19 10:11:11 +09:00
parent c3f9eb10f1
commit 83e102c691
5 changed files with 75 additions and 77 deletions

View File

@@ -408,24 +408,6 @@ if __name__ == "__main__":
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
args = parser.parse_args()
if args.config_file:
config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file
if os.path.exists(config_path):
print(f"Loading settings from {config_path}...")
with open(config_path, "r") as f:
config_dict = toml.load(f)
ignore_nesting_dict = {}
for section_name, section_dict in config_dict.items():
for key, value in section_dict.items():
ignore_nesting_dict[key] = value
config_args = argparse.Namespace(**ignore_nesting_dict)
args = parser.parse_args(namespace=config_args)
args.config_file = args.config_file.split(".")[0]
print(args.config_file)
else:
print(f"{config_path} not found.")
args = train_util.read_config_from_file(args, parser)
train(args)

View File

@@ -3,6 +3,7 @@
import argparse
import importlib
import json
import pathlib
import re
import shutil
import time
@@ -23,6 +24,7 @@ import random
import hashlib
import subprocess
from io import BytesIO
import toml
from tqdm import tqdm
import torch
@@ -1889,7 +1891,15 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
help=f"sampler (scheduler) type for sample images / サンプル出力時のサンプラー(スケジューラ)の種類",
)
parser.add_argument("--config_file", type=str, default=None, help="using .toml instead of args to pass hyperparameter")
parser.add_argument(
"--config_file",
type=str,
default=None,
help="using .toml instead of args to pass hyperparameter / ハイパーパラメータを引数ではなく.tomlファイルで渡す",
)
parser.add_argument(
"--output_config", action="store_true", help="output command line args to given .toml file / 引数を.tomlファイルに出力する"
)
if support_dreambooth:
# DreamBooth training
@@ -2016,6 +2026,66 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser):
)
def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentParser):
if not args.config_file:
return args
config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file
if args.output_config:
# check if config file exists
if os.path.exists(config_path):
print(f"Config file already exists. Aborting... / 出力先の設定ファイルが既に存在します: {config_path}")
exit(1)
# convert args to dictionary
args_dict = vars(args)
# remove unnecessary keys
for key in ["config_file", "output_config"]:
if key in args_dict:
del args_dict[key]
# convert Path to str in dictionary
for key, value in args_dict.items():
if isinstance(value, pathlib.Path):
args_dict[key] = str(value)
# convert to toml and output to file
with open(config_path, "w") as f:
toml.dump(args_dict, f)
print(f"Saved config file / 設定ファイルを保存しました: {config_path}")
exit(0)
if not os.path.exists(config_path):
print(f"{config_path} not found.")
exit(1)
print(f"Loading settings from {config_path}...")
with open(config_path, "r") as f:
config_dict = toml.load(f)
# combine all sections into one
ignore_nesting_dict = {}
for section_name, section_dict in config_dict.items():
# if value is not dict, save key and value as is
if not isinstance(section_dict, dict):
ignore_nesting_dict[section_name] = section_dict
continue
# if value is dict, save all key and value into one dict
for key, value in section_dict.items():
ignore_nesting_dict[key] = value
config_args = argparse.Namespace(**ignore_nesting_dict)
args = parser.parse_args(namespace=config_args)
args.config_file = os.path.splitext(args.config_file)[0]
print(args.config_file)
return args
# endregion
# region utils

View File

@@ -411,24 +411,6 @@ if __name__ == "__main__":
)
args = parser.parse_args()
if args.config_file:
config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file
if os.path.exists(config_path):
print(f"Loading settings from {config_path}...")
with open(config_path, "r") as f:
config_dict = toml.load(f)
ignore_nesting_dict = {}
for section_name, section_dict in config_dict.items():
for key, value in section_dict.items():
ignore_nesting_dict[key] = value
config_args = argparse.Namespace(**ignore_nesting_dict)
args = parser.parse_args(namespace=config_args)
args.config_file = args.config_file.split(".")[0]
print(args.config_file)
else:
print(f"{config_path} not found.")
args = train_util.read_config_from_file(args, parser)
train(args)

View File

@@ -695,24 +695,6 @@ if __name__ == "__main__":
)
args = parser.parse_args()
if args.config_file:
config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file
if os.path.exists(config_path):
print(f"Loading settings from {config_path}...")
with open(config_path, "r") as f:
config_dict = toml.load(f)
ignore_nesting_dict = {}
for section_name, section_dict in config_dict.items():
for key, value in section_dict.items():
ignore_nesting_dict[key] = value
config_args = argparse.Namespace(**ignore_nesting_dict)
args = parser.parse_args(namespace=config_args)
args.config_file = args.config_file.split(".")[0]
print(args.config_file)
else:
print(f"{config_path} not found.")
args = train_util.read_config_from_file(args, parser)
train(args)

View File

@@ -573,24 +573,6 @@ if __name__ == "__main__":
)
args = parser.parse_args()
if args.config_file:
config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file
if os.path.exists(config_path):
print(f"Loading settings from {config_path}...")
with open(config_path, "r") as f:
config_dict = toml.load(f)
ignore_nesting_dict = {}
for section_name, section_dict in config_dict.items():
for key, value in section_dict.items():
ignore_nesting_dict[key] = value
config_args = argparse.Namespace(**ignore_nesting_dict)
args = parser.parse_args(namespace=config_args)
args.config_file = args.config_file.split(".")[0]
print(args.config_file)
else:
print(f"{config_path} not found.")
args = train_util.read_config_from_file(args, parser)
train(args)