mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
Merge pull request #907 from shirayu/add_option_sample_at_first
Add option --sample_at_first
This commit is contained in:
@@ -374,6 +374,10 @@ classがひとつで対象が複数の場合、正則化画像フォルダはひ
|
||||
|
||||
サンプル出力するステップ数またはエポック数を指定します。この数ごとにサンプル出力します。両方指定するとエポック数が優先されます。
|
||||
|
||||
- `--sample_at_first`
|
||||
|
||||
学習開始前にサンプル出力します。学習前との比較ができます。
|
||||
|
||||
- `--sample_prompts`
|
||||
|
||||
サンプル出力用プロンプトのファイルを指定します。
|
||||
|
||||
@@ -303,6 +303,9 @@ def train(args):
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
# For --sample_at_first
|
||||
train_util.sample_images(accelerator, args, epoch, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
for m in training_models:
|
||||
m.train()
|
||||
|
||||
|
||||
@@ -2979,6 +2979,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
parser.add_argument(
|
||||
"--sample_every_n_steps", type=int, default=None, help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sample_at_first", action='store_true', help="generate sample images before training / 学習前にサンプル出力する"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sample_every_n_epochs",
|
||||
type=int,
|
||||
@@ -4576,15 +4579,19 @@ def sample_images_common(
|
||||
"""
|
||||
StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した
|
||||
"""
|
||||
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
|
||||
return
|
||||
if args.sample_every_n_epochs is not None:
|
||||
# sample_every_n_steps は無視する
|
||||
if epoch is None or epoch % args.sample_every_n_epochs != 0:
|
||||
if steps == 0:
|
||||
if not args.sample_at_first:
|
||||
return
|
||||
else:
|
||||
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
|
||||
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
|
||||
return
|
||||
if args.sample_every_n_epochs is not None:
|
||||
# sample_every_n_steps は無視する
|
||||
if epoch is None or epoch % args.sample_every_n_epochs != 0:
|
||||
return
|
||||
else:
|
||||
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
|
||||
return
|
||||
|
||||
print(f"\ngenerating sample images at step / サンプル画像生成 ステップ: {steps}")
|
||||
if not os.path.isfile(args.sample_prompts):
|
||||
|
||||
@@ -466,6 +466,19 @@ def train(args):
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
# For --sample_at_first
|
||||
sdxl_train_util.sample_images(
|
||||
accelerator,
|
||||
args,
|
||||
epoch,
|
||||
global_step,
|
||||
accelerator.device,
|
||||
vae,
|
||||
[tokenizer1, tokenizer2],
|
||||
[text_encoder1, text_encoder2],
|
||||
unet,
|
||||
)
|
||||
|
||||
for m in training_models:
|
||||
m.train()
|
||||
|
||||
|
||||
@@ -373,6 +373,20 @@ def train(args):
|
||||
|
||||
# training loop
|
||||
for epoch in range(num_train_epochs):
|
||||
# For --sample_at_first
|
||||
train_util.sample_images(
|
||||
accelerator,
|
||||
args,
|
||||
epoch,
|
||||
global_step,
|
||||
accelerator.device,
|
||||
vae,
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
unet,
|
||||
controlnet=controlnet,
|
||||
)
|
||||
|
||||
if is_main_process:
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
@@ -279,6 +279,8 @@ def train(args):
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
train_util.sample_images(accelerator, args, epoch, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
# 指定したステップ数までText Encoderを学習する:epoch最初の状態
|
||||
unet.train()
|
||||
# train==True is required to enable gradient_checkpointing
|
||||
|
||||
@@ -750,6 +750,8 @@ class NetworkTrainer:
|
||||
|
||||
metadata["ss_epoch"] = str(epoch + 1)
|
||||
|
||||
# For --sample_at_first
|
||||
self.sample_images(accelerator, args, epoch, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
network.on_epoch_start(text_encoder, unet)
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
|
||||
@@ -534,6 +534,20 @@ class TextualInversionTrainer:
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
# For --sample_at_first
|
||||
self.sample_images(
|
||||
accelerator,
|
||||
args,
|
||||
epoch,
|
||||
global_step,
|
||||
accelerator.device,
|
||||
vae,
|
||||
tokenizer_or_list,
|
||||
text_encoder_or_list,
|
||||
unet,
|
||||
prompt_replacement,
|
||||
)
|
||||
|
||||
for text_encoder in text_encoders:
|
||||
text_encoder.train()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user