From 26a81d075cd347f3c38a1178a7a8d68cd99f7060 Mon Sep 17 00:00:00 2001
From: hitomi
Date: Wed, 1 Feb 2023 16:02:15 +0800
Subject: [PATCH] add --persistent_data_loader_workers option

---
 fine_tune.py               | 2 +-
 library/train_util.py      | 6 ++++--
 train_db.py                | 2 +-
 train_network.py           | 4 ++--
 train_textual_inversion.py | 2 +-
 5 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/fine_tune.py b/fine_tune.py
index 8e615203..a0ef978e 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -163,7 +163,7 @@ def train(args):
   # DataLoaderのプロセス数:0はメインプロセスになる
   n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)      # cpu_count-1 ただし最大で指定された数まで
   train_dataloader = torch.utils.data.DataLoader(
-      train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)
+      train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)

   # 学習ステップ数を計算する
   if args.max_train_epochs is not None:
diff --git a/library/train_util.py b/library/train_util.py
index 85b58d7e..22b84574 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -140,7 +140,7 @@ class BaseDataset(torch.utils.data.Dataset):
         if type(str_to) == list:
           caption = random.choice(str_to)
         else:
-          caption = str_to
+          caption = str_to
       else:
         caption = caption.replace(str_from, str_to)

@@ -246,7 +246,7 @@ class BaseDataset(torch.utils.data.Dataset):
       mean_img_ar_error = np.mean(np.abs(img_ar_errors))
       self.bucket_info["mean_img_ar_error"] = mean_img_ar_error
       print(f"mean ar error (without repeats): {mean_img_ar_error}")
-
+
     # 参照用indexを作る
     self.buckets_indices: list(BucketBatchIndex) = []

@@ -1154,6 +1154,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
                       help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
   parser.add_argument("--max_data_loader_n_workers", type=int, default=8,
                       help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)")
+  parser.add_argument("--persistent_data_loader_workers", action="store_true",
+                      help="persistent DataLoader workers (useful to reduce the time gap between epochs, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)")
   parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
   parser.add_argument("--gradient_checkpointing", action="store_true",
                       help="enable gradient checkpointing / grandient checkpointingを有効にする")
diff --git a/train_db.py b/train_db.py
index fe6fd4e6..bf25aae4 100644
--- a/train_db.py
+++ b/train_db.py
@@ -133,7 +133,7 @@ def train(args):
   # DataLoaderのプロセス数:0はメインプロセスになる
   n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)      # cpu_count-1 ただし最大で指定された数まで
   train_dataloader = torch.utils.data.DataLoader(
-      train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)
+      train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)

   # 学習ステップ数を計算する
   if args.max_train_epochs is not None:
diff --git a/train_network.py b/train_network.py
index 37a10f65..afcc71b7 100644
--- a/train_network.py
+++ b/train_network.py
@@ -214,7 +214,7 @@ def train(args):
   # DataLoaderのプロセス数:0はメインプロセスになる
   n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)      # cpu_count-1 ただし最大で指定された数まで
   train_dataloader = torch.utils.data.DataLoader(
-      train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)
+      train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)

   # 学習ステップ数を計算する
   if args.max_train_epochs is not None:
@@ -224,7 +224,7 @@ def train(args):
   # lr schedulerを用意する
   # lr_scheduler = diffusers.optimization.get_scheduler(
   lr_scheduler = get_scheduler_fix(
-      args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
+      args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
       num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
       num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)

diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index 35b4ede6..ea70195b 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -217,7 +217,7 @@ def train(args):
   # DataLoaderのプロセス数:0はメインプロセスになる
   n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)      # cpu_count-1 ただし最大で指定された数まで
   train_dataloader = torch.utils.data.DataLoader(
-      train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)
+      train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)

   # 学習ステップ数を計算する
   if args.max_train_epochs is not None:
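
Note for reviewers: `persistent_workers` is a standard `torch.utils.data.DataLoader`
keyword argument (available since PyTorch 1.7) that keeps worker processes alive
across epochs instead of shutting them down and respawning them each time the
loader is re-iterated, which is where the per-epoch start-up cost this patch
targets comes from. Below is a minimal, self-contained sketch of that behavior,
not code from this repo: ToyDataset is a hypothetical stand-in for the scripts'
train_dataset, and the timings are only illustrative.

  # Sketch only: compares epoch turnaround with and without persistent workers.
  import time
  import torch
  from torch.utils.data import DataLoader, Dataset

  class ToyDataset(Dataset):  # hypothetical stand-in for train_dataset
      def __len__(self):
          return 256

      def __getitem__(self, idx):
          return torch.tensor(idx)

  if __name__ == "__main__":
      ds = ToyDataset()
      for persistent in (False, True):
          # persistent_workers requires num_workers > 0, matching the patch,
          # which only takes effect when max_data_loader_n_workers yields
          # a nonzero n_workers.
          dl = DataLoader(ds, batch_size=16, num_workers=2,
                          persistent_workers=persistent)
          start = time.time()
          for _ in range(3):          # three "epochs"
              for _ in dl:
                  pass
          print(f"persistent_workers={persistent}: {time.time() - start:.2f}s")

The trade-off named in the help string follows directly: the workers (and any
per-worker dataset state they hold) stay resident between epochs, so RAM usage
is higher in exchange for skipping worker respawn at each epoch boundary.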