diff --git a/library/config_util.py b/library/config_util.py index b1543f63..97bbb4a8 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -4,6 +4,7 @@ from dataclasses import ( dataclass, ) import functools +import random from textwrap import dedent, indent import json from pathlib import Path @@ -428,9 +429,12 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu print(info) # make buckets first because it determines the length of dataset + # and set the same seed for all datasets + seed = random.randint(0, 2**31) # actual seed is seed + epoch_no for i, dataset in enumerate(datasets): print(f"[Dataset {i}]") dataset.make_buckets() + dataset.set_seed(seed) return DatasetGroup(datasets) diff --git a/library/train_util.py b/library/train_util.py index 6a5679d3..2d93b126 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -277,7 +277,7 @@ class BaseSubset: caption_dropout_every_n_epochs: int, caption_tag_dropout_rate: float, token_warmup_min: int, - token_warmup_step: Union[float,int], + token_warmup_step: Union[float, int], ) -> None: self.image_dir = image_dir self.num_repeats = num_repeats @@ -419,6 +419,7 @@ class BaseDataset(torch.utils.data.Dataset): self.current_step: int = 0 self.max_train_steps: int = 0 + self.seed: int = 0 # augmentation self.aug_helper = AugHelper() @@ -435,8 +436,11 @@ class BaseDataset(torch.utils.data.Dataset): self.replacements = {} + def set_seed(self, seed): + self.seed = seed + def set_current_epoch(self, epoch): - if not self.current_epoch == epoch: + if not self.current_epoch == epoch: # epochが切り替わったらバケツをシャッフルする self.shuffle_buckets() self.current_epoch = epoch @@ -476,12 +480,15 @@ class BaseDataset(torch.utils.data.Dataset): caption = "" else: if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0: - tokens = [t.strip() for t in caption.strip().split(",")] + print(subset.token_warmup_min, subset.token_warmup_step) if subset.token_warmup_step < 1: subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps) if subset.token_warmup_step and self.current_step < subset.token_warmup_step: - tokens_len = math.floor((self.current_step)*((len(tokens)-subset.token_warmup_min)/(subset.token_warmup_step)))+subset.token_warmup_min + tokens_len = ( + math.floor((self.current_step) * ((len(tokens) - subset.token_warmup_min) / (subset.token_warmup_step))) + + subset.token_warmup_min + ) tokens = tokens[:tokens_len] def dropout_tags(tokens): @@ -667,6 +674,9 @@ class BaseDataset(torch.utils.data.Dataset): self._length = len(self.buckets_indices) def shuffle_buckets(self): + # set random seed for this epoch + random.seed(self.seed + self.current_epoch) + random.shuffle(self.buckets_indices) self.bucket_manager.shuffle() @@ -1073,7 +1083,7 @@ class DreamBoothDataset(BaseDataset): self.register_image(info, subset) n += info.num_repeats else: - info.num_repeats += 1 + info.num_repeats += 1 # rewrite registered info n += 1 if n >= num_train_images: break @@ -1134,6 +1144,8 @@ class FineTuningDataset(BaseDataset): # path情報を作る if os.path.exists(image_key): abs_path = image_key + elif os.path.exists(os.path.splitext(image_key)[0] + ".npz"): + abs_path = os.path.splitext(image_key)[0] + ".npz" else: npz_path = os.path.join(subset.image_dir, image_key + ".npz") if os.path.exists(npz_path): @@ -1330,9 +1342,13 @@ class DatasetGroup(torch.utils.data.ConcatDataset): def debug_dataset(train_dataset, show_input_ids=False): print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") - print("Escape for exit. / Escキーで中断、終了します") + print("`E` for increment (pseudo) epoch no. , Escape for exit. / Eキーで疑似的にエポック番号を+1、Escキーで中断、終了します") + + epoch = 1 + steps = 1 + train_dataset.set_current_epoch(epoch) + train_dataset.set_current_step(steps) - train_dataset.set_current_epoch(1) k = 0 indices = list(range(len(train_dataset))) random.shuffle(indices) @@ -1358,6 +1374,15 @@ def debug_dataset(train_dataset, show_input_ids=False): cv2.destroyAllWindows() if k == 27: break + if k == ord("e"): + epoch += 1 + steps = len(train_dataset) * (epoch - 1) + train_dataset.set_current_epoch(epoch) + print(f"epoch: {epoch}") + + steps += 1 + train_dataset.set_current_step(steps) + if k == 27 or (example["images"] is None and i >= 8): break @@ -2001,7 +2026,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument( "--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み" ) - + def verify_training_args(args: argparse.Namespace): if args.v_parameterization and not args.v2: @@ -2089,7 +2114,7 @@ def add_dataset_arguments( default=0, help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / N(N<1ならN*max_train_steps)ステップでタグ長が最大になる。デフォルトは0(最初から最大)", ) - + if support_caption_dropout: # Textual Inversion はcaptionのdropoutをsupportしない # いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに @@ -3025,13 +3050,15 @@ class ImageLoadingDataset(torch.utils.data.Dataset): # endregion -# colalte_fn用 epoch,stepはmultiprocessing.Value + +# collate_fn用 epoch,stepはmultiprocessing.Value class collater_class: - def __init__(self,epoch,step): - self.current_epoch=epoch - self.current_step=step + def __init__(self, epoch, step): + self.current_epoch = epoch + self.current_step = step + def __call__(self, examples): dataset = torch.utils.data.get_worker_info().dataset dataset.set_current_epoch(self.current_epoch.value) dataset.set_current_step(self.current_step.value) - return examples[0] \ No newline at end of file + return examples[0]