Merge branch 'sd3' into val-loss-improvement

This commit is contained in:
Kohya S
2025-02-26 21:09:10 +09:00
3 changed files with 8 additions and 11 deletions

View File

@@ -643,16 +643,15 @@ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
# rename or add position_ids
# remove position_ids for newer transformer, which causes error :(
ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
if ANOTHER_POSITION_IDS_KEY in new_sd:
# waifu diffusion v1.4
position_ids = new_sd[ANOTHER_POSITION_IDS_KEY]
del new_sd[ANOTHER_POSITION_IDS_KEY]
else:
position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
new_sd["text_model.embeddings.position_ids"] = position_ids
if "text_model.embeddings.position_ids" in new_sd:
del new_sd["text_model.embeddings.position_ids"]
return new_sd

View File

@@ -344,8 +344,6 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser, support_text_en
def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"
if args.v_parameterization:
logger.warning("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります")
if args.clip_skip is not None:
logger.warning("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")

View File

@@ -1207,10 +1207,6 @@ class NetworkTrainer:
args.max_train_steps > initial_step
), f"max_train_steps should be greater than initial step / max_train_stepsは初期ステップより大きい必要があります: {args.max_train_steps} vs {initial_step}"
progress_bar = tqdm(
range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps"
)
epoch_to_start = 0
if initial_step > 0:
if args.skip_until_initial_step:
@@ -1309,6 +1305,10 @@ class NetworkTrainer:
clean_memory_on_device(accelerator.device)
progress_bar = tqdm(
range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps"
)
validation_steps = (
min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader)
)