Add logging arguments and update logging setup

This commit is contained in:
Kohya S
2024-02-04 20:44:10 +09:00
parent 6279b33736
commit efd3b58973
2 changed files with 57 additions and 14 deletions

View File

@@ -18,9 +18,11 @@ init_ipex()
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from library.utils import setup_logging
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
import library.train_util as train_util
import library.config_util as config_util
@@ -41,6 +43,7 @@ from library.custom_train_functions import (
def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
setup_logging(args, reset=True)
cache_latents = args.cache_latents
@@ -227,7 +230,9 @@ def train(args):
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
accelerator.print(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
@@ -291,7 +296,7 @@ def train(args):
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs['wandb'] = {'name': args.wandb_run_name}
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
@@ -471,6 +476,7 @@ def train(args):
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser) # モジュール指定がないのがちょっと気持ち悪い / bit weird that this does not have module prefix
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
@@ -479,7 +485,9 @@ def setup_parser() -> argparse.ArgumentParser:
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
parser.add_argument(
"--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する"
)
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
parser.add_argument(
"--learning_rate_te",

View File

@@ -8,18 +8,53 @@ def fire_in_thread(f, *args, **kwargs):
threading.Thread(target=f, args=args, kwargs=kwargs).start()
def setup_logging(log_level=logging.INFO):
if logging.root.handlers: # Already configured
return
def add_logging_arguments(parser):
parser.add_argument(
"--console_log_level",
type=str,
default=None,
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
help="Set the logging level, default is INFO / ログレベルを設定する。デフォルトはINFO",
)
parser.add_argument(
"--console_log_file",
type=str,
default=None,
help="Log to a file instead of stdout / 標準出力ではなくファイルにログを出力する",
)
parser.add_argument("--console_log_simple", action="store_true", help="Simple log output / シンプルなログ出力")
try:
from rich.logging import RichHandler
handler = RichHandler()
except ImportError:
print("rich is not installed, using basic logging")
handler = logging.StreamHandler(sys.stdout) # same as print
handler.propagate = False
def setup_logging(args=None, log_level=None, reset=False):
if logging.root.handlers:
if reset:
# remove all handlers
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
else:
return
if log_level is None and args is not None:
log_level = args.console_log_level
if log_level is None:
log_level = "INFO"
log_level = getattr(logging, log_level)
if args is not None and args.console_log_file:
handler = logging.FileHandler(args.console_log_file, mode="w")
else:
handler = None
if not args or not args.console_log_simple:
try:
from rich.logging import RichHandler
handler = RichHandler()
except ImportError:
print("rich is not installed, using basic logging")
if handler is None:
handler = logging.StreamHandler(sys.stdout) # same as print
handler.propagate = False
formatter = logging.Formatter(
fmt="%(message)s",