mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
refactor config parse, feature to output config
This commit is contained in:
20
fine_tune.py
20
fine_tune.py
@@ -408,24 +408,6 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.config_file:
|
||||
config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file
|
||||
if os.path.exists(config_path):
|
||||
print(f"Loading settings from {config_path}...")
|
||||
with open(config_path, "r") as f:
|
||||
config_dict = toml.load(f)
|
||||
|
||||
ignore_nesting_dict = {}
|
||||
for section_name, section_dict in config_dict.items():
|
||||
for key, value in section_dict.items():
|
||||
ignore_nesting_dict[key] = value
|
||||
|
||||
config_args = argparse.Namespace(**ignore_nesting_dict)
|
||||
args = parser.parse_args(namespace=config_args)
|
||||
args.config_file = args.config_file.split(".")[0]
|
||||
print(args.config_file)
|
||||
else:
|
||||
print(f"{config_path} not found.")
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
|
||||
train(args)
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import argparse
|
||||
import importlib
|
||||
import json
|
||||
import pathlib
|
||||
import re
|
||||
import shutil
|
||||
import time
|
||||
@@ -23,6 +24,7 @@ import random
|
||||
import hashlib
|
||||
import subprocess
|
||||
from io import BytesIO
|
||||
import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
@@ -1889,7 +1891,15 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
help=f"sampler (scheduler) type for sample images / サンプル出力時のサンプラー(スケジューラ)の種類",
|
||||
)
|
||||
|
||||
parser.add_argument("--config_file", type=str, default=None, help="using .toml instead of args to pass hyperparameter")
|
||||
parser.add_argument(
|
||||
"--config_file",
|
||||
type=str,
|
||||
default=None,
|
||||
help="using .toml instead of args to pass hyperparameter / ハイパーパラメータを引数ではなく.tomlファイルで渡す",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_config", action="store_true", help="output command line args to given .toml file / 引数を.tomlファイルに出力する"
|
||||
)
|
||||
|
||||
if support_dreambooth:
|
||||
# DreamBooth training
|
||||
@@ -2016,6 +2026,66 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser):
|
||||
)
|
||||
|
||||
|
||||
def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentParser):
|
||||
if not args.config_file:
|
||||
return args
|
||||
|
||||
config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file
|
||||
|
||||
if args.output_config:
|
||||
# check if config file exists
|
||||
if os.path.exists(config_path):
|
||||
print(f"Config file already exists. Aborting... / 出力先の設定ファイルが既に存在します: {config_path}")
|
||||
exit(1)
|
||||
|
||||
# convert args to dictionary
|
||||
args_dict = vars(args)
|
||||
|
||||
# remove unnecessary keys
|
||||
for key in ["config_file", "output_config"]:
|
||||
if key in args_dict:
|
||||
del args_dict[key]
|
||||
|
||||
# convert Path to str in dictionary
|
||||
for key, value in args_dict.items():
|
||||
if isinstance(value, pathlib.Path):
|
||||
args_dict[key] = str(value)
|
||||
|
||||
# convert to toml and output to file
|
||||
with open(config_path, "w") as f:
|
||||
toml.dump(args_dict, f)
|
||||
|
||||
print(f"Saved config file / 設定ファイルを保存しました: {config_path}")
|
||||
exit(0)
|
||||
|
||||
if not os.path.exists(config_path):
|
||||
print(f"{config_path} not found.")
|
||||
exit(1)
|
||||
|
||||
print(f"Loading settings from {config_path}...")
|
||||
with open(config_path, "r") as f:
|
||||
config_dict = toml.load(f)
|
||||
|
||||
# combine all sections into one
|
||||
ignore_nesting_dict = {}
|
||||
for section_name, section_dict in config_dict.items():
|
||||
# if value is not dict, save key and value as is
|
||||
if not isinstance(section_dict, dict):
|
||||
ignore_nesting_dict[section_name] = section_dict
|
||||
continue
|
||||
|
||||
# if value is dict, save all key and value into one dict
|
||||
for key, value in section_dict.items():
|
||||
ignore_nesting_dict[key] = value
|
||||
|
||||
config_args = argparse.Namespace(**ignore_nesting_dict)
|
||||
args = parser.parse_args(namespace=config_args)
|
||||
args.config_file = os.path.splitext(args.config_file)[0]
|
||||
print(args.config_file)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region utils
|
||||
|
||||
20
train_db.py
20
train_db.py
@@ -411,24 +411,6 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.config_file:
|
||||
config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file
|
||||
if os.path.exists(config_path):
|
||||
print(f"Loading settings from {config_path}...")
|
||||
with open(config_path, "r") as f:
|
||||
config_dict = toml.load(f)
|
||||
|
||||
ignore_nesting_dict = {}
|
||||
for section_name, section_dict in config_dict.items():
|
||||
for key, value in section_dict.items():
|
||||
ignore_nesting_dict[key] = value
|
||||
|
||||
config_args = argparse.Namespace(**ignore_nesting_dict)
|
||||
args = parser.parse_args(namespace=config_args)
|
||||
args.config_file = args.config_file.split(".")[0]
|
||||
print(args.config_file)
|
||||
else:
|
||||
print(f"{config_path} not found.")
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
|
||||
train(args)
|
||||
|
||||
@@ -695,24 +695,6 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.config_file:
|
||||
config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file
|
||||
if os.path.exists(config_path):
|
||||
print(f"Loading settings from {config_path}...")
|
||||
with open(config_path, "r") as f:
|
||||
config_dict = toml.load(f)
|
||||
|
||||
ignore_nesting_dict = {}
|
||||
for section_name, section_dict in config_dict.items():
|
||||
for key, value in section_dict.items():
|
||||
ignore_nesting_dict[key] = value
|
||||
|
||||
config_args = argparse.Namespace(**ignore_nesting_dict)
|
||||
args = parser.parse_args(namespace=config_args)
|
||||
args.config_file = args.config_file.split(".")[0]
|
||||
print(args.config_file)
|
||||
else:
|
||||
print(f"{config_path} not found.")
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
|
||||
train(args)
|
||||
|
||||
@@ -573,24 +573,6 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.config_file:
|
||||
config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file
|
||||
if os.path.exists(config_path):
|
||||
print(f"Loading settings from {config_path}...")
|
||||
with open(config_path, "r") as f:
|
||||
config_dict = toml.load(f)
|
||||
|
||||
ignore_nesting_dict = {}
|
||||
for section_name, section_dict in config_dict.items():
|
||||
for key, value in section_dict.items():
|
||||
ignore_nesting_dict[key] = value
|
||||
|
||||
config_args = argparse.Namespace(**ignore_nesting_dict)
|
||||
args = parser.parse_args(namespace=config_args)
|
||||
args.config_file = args.config_file.split(".")[0]
|
||||
print(args.config_file)
|
||||
else:
|
||||
print(f"{config_path} not found.")
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
|
||||
train(args)
|
||||
|
||||
Reference in New Issue
Block a user