Added sample_images() for --sample_at_first

This commit is contained in:
Yuta Hayashibe
2023-10-29 22:08:42 +09:00
parent 5c150675bf
commit 2c731418ad
5 changed files with 36 additions and 1 deletions

View File

@@ -303,6 +303,9 @@ def train(args):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
# For --sample_at_first
train_util.sample_images(accelerator, args, epoch, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
for m in training_models:
m.train()

View File

@@ -373,6 +373,20 @@ def train(args):
# training loop
for epoch in range(num_train_epochs):
# For --sample_at_first
train_util.sample_images(
accelerator,
args,
epoch,
global_step,
accelerator.device,
vae,
tokenizer,
text_encoder,
unet,
controlnet=controlnet,
)
if is_main_process:
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1

View File

@@ -279,6 +279,8 @@ def train(args):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
train_util.sample_images(accelerator, args, epoch, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
# 指定したステップ数までText Encoderを学習するepoch最初の状態
unet.train()
# train==True is required to enable gradient_checkpointing

View File

@@ -750,6 +750,8 @@ class NetworkTrainer:
metadata["ss_epoch"] = str(epoch + 1)
# For --sample_at_first
self.sample_images(accelerator, args, epoch, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
network.on_epoch_start(text_encoder, unet)
for step, batch in enumerate(train_dataloader):

View File

@@ -534,6 +534,20 @@ class TextualInversionTrainer:
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
# For --sample_at_first
self.sample_images(
accelerator,
args,
epoch,
global_step,
accelerator.device,
vae,
tokenizer_or_list,
text_encoder_or_list,
unet,
prompt_replacement,
)
for text_encoder in text_encoders:
text_encoder.train()