mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
fix seed for each dataset to make shuffling same
This commit is contained in:
@@ -4,6 +4,7 @@ from dataclasses import (
|
||||
dataclass,
|
||||
)
|
||||
import functools
|
||||
import random
|
||||
from textwrap import dedent, indent
|
||||
import json
|
||||
from pathlib import Path
|
||||
@@ -428,9 +429,12 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
|
||||
print(info)
|
||||
|
||||
# make buckets first because it determines the length of dataset
|
||||
# and set the same seed for all datasets
|
||||
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
|
||||
for i, dataset in enumerate(datasets):
|
||||
print(f"[Dataset {i}]")
|
||||
dataset.make_buckets()
|
||||
dataset.set_seed(seed)
|
||||
|
||||
return DatasetGroup(datasets)
|
||||
|
||||
|
||||
@@ -277,7 +277,7 @@ class BaseSubset:
|
||||
caption_dropout_every_n_epochs: int,
|
||||
caption_tag_dropout_rate: float,
|
||||
token_warmup_min: int,
|
||||
token_warmup_step: Union[float,int],
|
||||
token_warmup_step: Union[float, int],
|
||||
) -> None:
|
||||
self.image_dir = image_dir
|
||||
self.num_repeats = num_repeats
|
||||
@@ -419,6 +419,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
self.current_step: int = 0
|
||||
self.max_train_steps: int = 0
|
||||
self.seed: int = 0
|
||||
|
||||
# augmentation
|
||||
self.aug_helper = AugHelper()
|
||||
@@ -435,8 +436,11 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
self.replacements = {}
|
||||
|
||||
def set_seed(self, seed):
|
||||
self.seed = seed
|
||||
|
||||
def set_current_epoch(self, epoch):
|
||||
if not self.current_epoch == epoch:
|
||||
if not self.current_epoch == epoch: # epochが切り替わったらバケツをシャッフルする
|
||||
self.shuffle_buckets()
|
||||
self.current_epoch = epoch
|
||||
|
||||
@@ -476,12 +480,15 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
caption = ""
|
||||
else:
|
||||
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
|
||||
|
||||
tokens = [t.strip() for t in caption.strip().split(",")]
|
||||
print(subset.token_warmup_min, subset.token_warmup_step)
|
||||
if subset.token_warmup_step < 1:
|
||||
subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
|
||||
if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
|
||||
tokens_len = math.floor((self.current_step)*((len(tokens)-subset.token_warmup_min)/(subset.token_warmup_step)))+subset.token_warmup_min
|
||||
tokens_len = (
|
||||
math.floor((self.current_step) * ((len(tokens) - subset.token_warmup_min) / (subset.token_warmup_step)))
|
||||
+ subset.token_warmup_min
|
||||
)
|
||||
tokens = tokens[:tokens_len]
|
||||
|
||||
def dropout_tags(tokens):
|
||||
@@ -667,6 +674,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
self._length = len(self.buckets_indices)
|
||||
|
||||
def shuffle_buckets(self):
|
||||
# set random seed for this epoch
|
||||
random.seed(self.seed + self.current_epoch)
|
||||
|
||||
random.shuffle(self.buckets_indices)
|
||||
self.bucket_manager.shuffle()
|
||||
|
||||
@@ -1073,7 +1083,7 @@ class DreamBoothDataset(BaseDataset):
|
||||
self.register_image(info, subset)
|
||||
n += info.num_repeats
|
||||
else:
|
||||
info.num_repeats += 1
|
||||
info.num_repeats += 1 # rewrite registered info
|
||||
n += 1
|
||||
if n >= num_train_images:
|
||||
break
|
||||
@@ -1134,6 +1144,8 @@ class FineTuningDataset(BaseDataset):
|
||||
# path情報を作る
|
||||
if os.path.exists(image_key):
|
||||
abs_path = image_key
|
||||
elif os.path.exists(os.path.splitext(image_key)[0] + ".npz"):
|
||||
abs_path = os.path.splitext(image_key)[0] + ".npz"
|
||||
else:
|
||||
npz_path = os.path.join(subset.image_dir, image_key + ".npz")
|
||||
if os.path.exists(npz_path):
|
||||
@@ -1330,9 +1342,13 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
|
||||
def debug_dataset(train_dataset, show_input_ids=False):
|
||||
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
|
||||
print("Escape for exit. / Escキーで中断、終了します")
|
||||
print("`E` for increment (pseudo) epoch no. , Escape for exit. / Eキーで疑似的にエポック番号を+1、Escキーで中断、終了します")
|
||||
|
||||
epoch = 1
|
||||
steps = 1
|
||||
train_dataset.set_current_epoch(epoch)
|
||||
train_dataset.set_current_step(steps)
|
||||
|
||||
train_dataset.set_current_epoch(1)
|
||||
k = 0
|
||||
indices = list(range(len(train_dataset)))
|
||||
random.shuffle(indices)
|
||||
@@ -1358,6 +1374,15 @@ def debug_dataset(train_dataset, show_input_ids=False):
|
||||
cv2.destroyAllWindows()
|
||||
if k == 27:
|
||||
break
|
||||
if k == ord("e"):
|
||||
epoch += 1
|
||||
steps = len(train_dataset) * (epoch - 1)
|
||||
train_dataset.set_current_epoch(epoch)
|
||||
print(f"epoch: {epoch}")
|
||||
|
||||
steps += 1
|
||||
train_dataset.set_current_step(steps)
|
||||
|
||||
if k == 27 or (example["images"] is None and i >= 8):
|
||||
break
|
||||
|
||||
@@ -2001,7 +2026,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
parser.add_argument(
|
||||
"--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み"
|
||||
)
|
||||
|
||||
|
||||
|
||||
def verify_training_args(args: argparse.Namespace):
|
||||
if args.v_parameterization and not args.v2:
|
||||
@@ -2089,7 +2114,7 @@ def add_dataset_arguments(
|
||||
default=0,
|
||||
help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / N(N<1ならN*max_train_steps)ステップでタグ長が最大になる。デフォルトは0(最初から最大)",
|
||||
)
|
||||
|
||||
|
||||
if support_caption_dropout:
|
||||
# Textual Inversion はcaptionのdropoutをsupportしない
|
||||
# いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
|
||||
@@ -3025,13 +3050,15 @@ class ImageLoadingDataset(torch.utils.data.Dataset):
|
||||
|
||||
# endregion
|
||||
|
||||
# colalte_fn用 epoch,stepはmultiprocessing.Value
|
||||
|
||||
# collate_fn用 epoch,stepはmultiprocessing.Value
|
||||
class collater_class:
|
||||
def __init__(self,epoch,step):
|
||||
self.current_epoch=epoch
|
||||
self.current_step=step
|
||||
def __init__(self, epoch, step):
|
||||
self.current_epoch = epoch
|
||||
self.current_step = step
|
||||
|
||||
def __call__(self, examples):
|
||||
dataset = torch.utils.data.get_worker_info().dataset
|
||||
dataset.set_current_epoch(self.current_epoch.value)
|
||||
dataset.set_current_step(self.current_step.value)
|
||||
return examples[0]
|
||||
return examples[0]
|
||||
|
||||
Reference in New Issue
Block a user