diff --git a/library/train_util.py b/library/train_util.py
index 85b58d7e..0946c31d 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -87,6 +87,7 @@ class BaseDataset(torch.utils.data.Dataset):
     self.enable_bucket = False
     self.min_bucket_reso = None
     self.max_bucket_reso = None
+    self.tag_frequency = {}
     self.bucket_info = None
 
     self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
@@ -545,6 +546,15 @@ class DreamBoothDataset(BaseDataset):
       for img_path in img_paths:
         cap_for_img = read_caption(img_path)
         captions.append(caption_by_folder if cap_for_img is None else cap_for_img)
+
+      frequency_for_dir = self.tag_frequency.get(os.path.basename(dir), {})
+      self.tag_frequency[os.path.basename(dir)] = frequency_for_dir
+      for caption in captions:
+        for tag in caption.split(","):
+          tag = tag.strip().lower()
+          if tag:
+            frequency = frequency_for_dir.get(tag, 0)
+            frequency_for_dir[tag] = frequency + 1
 
       return n_repeats, img_paths, captions
 
     print("prepare train images.")
diff --git a/train_network.py b/train_network.py
index 37a10f65..aebc4a40 100644
--- a/train_network.py
+++ b/train_network.py
@@ -335,6 +335,7 @@ def train(args):
     "ss_keep_tokens": args.keep_tokens,
     "ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
     "ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
+    "ss_tag_frequency": json.dumps(train_dataset.tag_frequency),
     "ss_bucket_info": json.dumps(train_dataset.bucket_info),
     "ss_training_comment": args.training_comment  # will not be updated after training
   }