diff --git a/fine_tune.py b/fine_tune.py index a75c4673..86f5b253 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -18,9 +18,11 @@ init_ipex() from accelerate.utils import set_seed from diffusers import DDPMScheduler -from library.utils import setup_logging +from library.utils import setup_logging, add_logging_arguments + setup_logging() import logging + logger = logging.getLogger(__name__) import library.train_util as train_util import library.config_util as config_util @@ -41,6 +43,7 @@ from library.custom_train_functions import ( def train(args): train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) + setup_logging(args, reset=True) cache_latents = args.cache_latents @@ -227,7 +230,9 @@ def train(args): args.max_train_steps = args.max_train_epochs * math.ceil( len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) - accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) @@ -291,7 +296,7 @@ def train(args): if accelerator.is_main_process: init_kwargs = {} if args.wandb_run_name: - init_kwargs['wandb'] = {'name': args.wandb_run_name} + init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) @@ -471,6 +476,7 @@ def train(args): def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) # モジュール指定がないのがちょっと気持ち悪い / bit weird that this does not have module prefix train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) @@ -479,7 +485,9 @@ def setup_parser() -> argparse.ArgumentParser: config_util.add_config_arguments(parser) custom_train_functions.add_custom_train_arguments(parser) - parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する") + parser.add_argument( + "--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する" + ) parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") parser.add_argument( "--learning_rate_te", diff --git a/library/utils.py b/library/utils.py index 1b5d7eb6..af0563df 100644 --- a/library/utils.py +++ b/library/utils.py @@ -8,18 +8,53 @@ def fire_in_thread(f, *args, **kwargs): threading.Thread(target=f, args=args, kwargs=kwargs).start() -def setup_logging(log_level=logging.INFO): - if logging.root.handlers: # Already configured - return +def add_logging_arguments(parser): + parser.add_argument( + "--console_log_level", + type=str, + default=None, + choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + help="Set the logging level, default is INFO / ログレベルを設定する。デフォルトはINFO", + ) + parser.add_argument( + "--console_log_file", + type=str, + default=None, + help="Log to a file instead of stdout / 標準出力ではなくファイルにログを出力する", + ) + parser.add_argument("--console_log_simple", action="store_true", help="Simple log output / シンプルなログ出力") - try: - from rich.logging import RichHandler - handler = RichHandler() - except ImportError: - print("rich is not installed, using basic logging") - handler = logging.StreamHandler(sys.stdout) # same as print - handler.propagate = False +def setup_logging(args=None, log_level=None, reset=False): + if logging.root.handlers: + if reset: + # remove all handlers + for handler in logging.root.handlers[:]: + logging.root.removeHandler(handler) + else: + return + + if log_level is None and args is not None: + log_level = args.console_log_level + if log_level is None: + log_level = "INFO" + log_level = getattr(logging, log_level) + + if args is not None and args.console_log_file: + handler = logging.FileHandler(args.console_log_file, mode="w") + else: + handler = None + if not args or not args.console_log_simple: + try: + from rich.logging import RichHandler + + handler = RichHandler() + except ImportError: + print("rich is not installed, using basic logging") + + if handler is None: + handler = logging.StreamHandler(sys.stdout) # same as print + handler.propagate = False formatter = logging.Formatter( fmt="%(message)s",