add --persistent_data_loader_workers option

This commit is contained in:
hitomi
2023-02-01 16:02:15 +08:00
parent 4cabb37977
commit 26a81d075c
5 changed files with 9 additions and 7 deletions

View File

@@ -163,7 +163,7 @@ def train(args):
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
# 学習ステップ数を計算する
if args.max_train_epochs is not None: