fix seed for each dataset to make shuffling same

This commit is contained in:
Kohya S
2023-03-26 22:17:03 +09:00
parent 559a1aeeda
commit 14891523ce
2 changed files with 45 additions and 14 deletions

View File

@@ -4,6 +4,7 @@ from dataclasses import (
dataclass,
)
import functools
import random
from textwrap import dedent, indent
import json
from pathlib import Path
@@ -428,9 +429,12 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
print(info)
# make buckets first because it determines the length of dataset
# and set the same seed for all datasets
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
for i, dataset in enumerate(datasets):
print(f"[Dataset {i}]")
dataset.make_buckets()
dataset.set_seed(seed)
return DatasetGroup(datasets)

View File

@@ -277,7 +277,7 @@ class BaseSubset:
caption_dropout_every_n_epochs: int,
caption_tag_dropout_rate: float,
token_warmup_min: int,
token_warmup_step: Union[float,int],
token_warmup_step: Union[float, int],
) -> None:
self.image_dir = image_dir
self.num_repeats = num_repeats
@@ -419,6 +419,7 @@ class BaseDataset(torch.utils.data.Dataset):
self.current_step: int = 0
self.max_train_steps: int = 0
self.seed: int = 0
# augmentation
self.aug_helper = AugHelper()
@@ -435,8 +436,11 @@ class BaseDataset(torch.utils.data.Dataset):
self.replacements = {}
def set_seed(self, seed):
self.seed = seed
def set_current_epoch(self, epoch):
if not self.current_epoch == epoch:
if not self.current_epoch == epoch: # epochが切り替わったらバケツをシャッフルする
self.shuffle_buckets()
self.current_epoch = epoch
@@ -476,12 +480,15 @@ class BaseDataset(torch.utils.data.Dataset):
caption = ""
else:
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
tokens = [t.strip() for t in caption.strip().split(",")]
print(subset.token_warmup_min, subset.token_warmup_step)
if subset.token_warmup_step < 1:
subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
tokens_len = math.floor((self.current_step)*((len(tokens)-subset.token_warmup_min)/(subset.token_warmup_step)))+subset.token_warmup_min
tokens_len = (
math.floor((self.current_step) * ((len(tokens) - subset.token_warmup_min) / (subset.token_warmup_step)))
+ subset.token_warmup_min
)
tokens = tokens[:tokens_len]
def dropout_tags(tokens):
@@ -667,6 +674,9 @@ class BaseDataset(torch.utils.data.Dataset):
self._length = len(self.buckets_indices)
def shuffle_buckets(self):
# set random seed for this epoch
random.seed(self.seed + self.current_epoch)
random.shuffle(self.buckets_indices)
self.bucket_manager.shuffle()
@@ -1073,7 +1083,7 @@ class DreamBoothDataset(BaseDataset):
self.register_image(info, subset)
n += info.num_repeats
else:
info.num_repeats += 1
info.num_repeats += 1 # rewrite registered info
n += 1
if n >= num_train_images:
break
@@ -1134,6 +1144,8 @@ class FineTuningDataset(BaseDataset):
# path情報を作る
if os.path.exists(image_key):
abs_path = image_key
elif os.path.exists(os.path.splitext(image_key)[0] + ".npz"):
abs_path = os.path.splitext(image_key)[0] + ".npz"
else:
npz_path = os.path.join(subset.image_dir, image_key + ".npz")
if os.path.exists(npz_path):
@@ -1330,9 +1342,13 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
def debug_dataset(train_dataset, show_input_ids=False):
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
print("Escape for exit. / Escキーで中断、終了します")
print("`E` for increment (pseudo) epoch no. , Escape for exit. / Eキーで疑似的にエポック番号を+1、Escキーで中断、終了します")
epoch = 1
steps = 1
train_dataset.set_current_epoch(epoch)
train_dataset.set_current_step(steps)
train_dataset.set_current_epoch(1)
k = 0
indices = list(range(len(train_dataset)))
random.shuffle(indices)
@@ -1358,6 +1374,15 @@ def debug_dataset(train_dataset, show_input_ids=False):
cv2.destroyAllWindows()
if k == 27:
break
if k == ord("e"):
epoch += 1
steps = len(train_dataset) * (epoch - 1)
train_dataset.set_current_epoch(epoch)
print(f"epoch: {epoch}")
steps += 1
train_dataset.set_current_step(steps)
if k == 27 or (example["images"] is None and i >= 8):
break
@@ -2001,7 +2026,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
parser.add_argument(
"--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み"
)
def verify_training_args(args: argparse.Namespace):
if args.v_parameterization and not args.v2:
@@ -2089,7 +2114,7 @@ def add_dataset_arguments(
default=0,
help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / NN<1ならN*max_train_stepsステップでタグ長が最大になる。デフォルトは0最初から最大",
)
if support_caption_dropout:
# Textual Inversion はcaptionのdropoutをsupportしない
# いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
@@ -3025,13 +3050,15 @@ class ImageLoadingDataset(torch.utils.data.Dataset):
# endregion
# colalte_fn用 epoch,stepはmultiprocessing.Value
# collate_fn用 epoch,stepはmultiprocessing.Value
class collater_class:
def __init__(self,epoch,step):
self.current_epoch=epoch
self.current_step=step
def __init__(self, epoch, step):
self.current_epoch = epoch
self.current_step = step
def __call__(self, examples):
dataset = torch.utils.data.get_worker_info().dataset
dataset.set_current_epoch(self.current_epoch.value)
dataset.set_current_step(self.current_step.value)
return examples[0]
return examples[0]