add tag dropout

This commit is contained in:
Kohya S
2023-02-09 21:35:27 +09:00
parent f7b5abb595
commit 3a72e6f003
5 changed files with 60 additions and 46 deletions

View File

@@ -38,7 +38,7 @@ def train(args):
args.dataset_repeats, args.debug_dataset)
# 学習データのdropout率を設定する
train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs)
train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
train_dataset.make_buckets()
@@ -230,8 +230,7 @@ def train(args):
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
train_dataset.epoch_current = epoch + 1
train_dataset.set_current_epoch(epoch + 1)
for m in training_models:
m.train()