This commit is contained in:
Guib Fuo
2026-04-01 01:26:27 +00:00
committed by GitHub

View File

@@ -435,10 +435,13 @@ class TextualInversionTrainer:
text_encoders = [accelerator.prepare(text_encoder) for text_encoder in text_encoders]
index_no_updates_list = []
index_updates_list = []
orig_embeds_params_list = []
for tokenizer, token_ids, text_encoder in zip(tokenizers, token_ids_list, text_encoders):
index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
index_no_updates_list.append(index_no_updates)
index_updates = ~index_no_updates
index_updates_list.append(index_updates)
# accelerator.print(len(index_no_updates), torch.sum(index_no_updates))
orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
@@ -634,8 +637,31 @@ class TextualInversionTrainer:
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
# Let's make sure we don't update any embedding weights besides the newly added token
with torch.no_grad():
# normalize embeddings
if args.clip_ti_decay:
for text_encoder, index_updates in zip(
text_encoders, index_updates_list
):
pre_norm = (
text_encoder.get_input_embeddings()
.weight[index_updates, :]
.norm(dim=-1, keepdim=True)
)
lambda_ = min(1.0, 100 * lr_scheduler.get_last_lr()[0])
text_encoder.get_input_embeddings().weight[
index_updates
] = torch.nn.functional.normalize(
text_encoder.get_input_embeddings().weight[
index_updates, :
],
dim=-1,
) * (
pre_norm + lambda_ * (args.clip_ti_decay - pre_norm)
)
# Let's make sure we don't update any embedding weights besides the newly added token
for text_encoder, orig_embeds_params, index_no_updates in zip(
text_encoders, orig_embeds_params_list, index_no_updates_list
):
@@ -818,6 +844,12 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
)
parser.add_argument(
"--clip_ti_decay",
type=float,
default=None,
help="Keep the norm of the textual inversion intact (0.4 is a good starting point)",
)
return parser