Merge pull request #104 from space-nuko/caption-frequency-metadata

Add tag frequency metadata
This commit is contained in:
Kohya S
2023-01-31 20:56:15 +09:00
committed by GitHub
2 changed files with 11 additions and 0 deletions

View File

@@ -87,6 +87,7 @@ class BaseDataset(torch.utils.data.Dataset):
self.enable_bucket = False
self.min_bucket_reso = None
self.max_bucket_reso = None
self.tag_frequency = {}
self.bucket_info = None
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
@@ -545,6 +546,15 @@ class DreamBoothDataset(BaseDataset):
cap_for_img = read_caption(img_path)
captions.append(caption_by_folder if cap_for_img is None else cap_for_img)
frequency_for_dir = self.tag_frequency.get(os.path.basename(dir), {})
self.tag_frequency[os.path.basename(dir)] = frequency_for_dir
for caption in captions:
for tag in caption.split(","):
if tag and not tag.isspace():
tag = tag.lower()
frequency = frequency_for_dir.get(tag, 0)
frequency_for_dir[tag] = frequency + 1
return n_repeats, img_paths, captions
print("prepare train images.")

View File

@@ -335,6 +335,7 @@ def train(args):
"ss_keep_tokens": args.keep_tokens,
"ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
"ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
"ss_tag_frequency": json.dumps(train_dataset.tag_frequency),
"ss_bucket_info": json.dumps(train_dataset.bucket_info),
"ss_training_comment": args.training_comment # will not be updated after training
}