mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 21:52:27 +00:00
Merge c692dbe14c into 1dae34b0af
This commit is contained in:
@@ -435,10 +435,13 @@ class TextualInversionTrainer:
|
||||
text_encoders = [accelerator.prepare(text_encoder) for text_encoder in text_encoders]
|
||||
|
||||
index_no_updates_list = []
|
||||
index_updates_list = []
|
||||
orig_embeds_params_list = []
|
||||
for tokenizer, token_ids, text_encoder in zip(tokenizers, token_ids_list, text_encoders):
|
||||
index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
|
||||
index_no_updates_list.append(index_no_updates)
|
||||
index_updates = ~index_no_updates
|
||||
index_updates_list.append(index_updates)
|
||||
|
||||
# accelerator.print(len(index_no_updates), torch.sum(index_no_updates))
|
||||
orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
|
||||
@@ -634,8 +637,31 @@ class TextualInversionTrainer:
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
# Let's make sure we don't update any embedding weights besides the newly added token
|
||||
with torch.no_grad():
|
||||
# normalize embeddings
|
||||
if args.clip_ti_decay:
|
||||
for text_encoder, index_updates in zip(
|
||||
text_encoders, index_updates_list
|
||||
):
|
||||
pre_norm = (
|
||||
text_encoder.get_input_embeddings()
|
||||
.weight[index_updates, :]
|
||||
.norm(dim=-1, keepdim=True)
|
||||
)
|
||||
|
||||
lambda_ = min(1.0, 100 * lr_scheduler.get_last_lr()[0])
|
||||
text_encoder.get_input_embeddings().weight[
|
||||
index_updates
|
||||
] = torch.nn.functional.normalize(
|
||||
text_encoder.get_input_embeddings().weight[
|
||||
index_updates, :
|
||||
],
|
||||
dim=-1,
|
||||
) * (
|
||||
pre_norm + lambda_ * (args.clip_ti_decay - pre_norm)
|
||||
)
|
||||
|
||||
# Let's make sure we don't update any embedding weights besides the newly added token
|
||||
for text_encoder, orig_embeds_params, index_no_updates in zip(
|
||||
text_encoders, orig_embeds_params_list, index_no_updates_list
|
||||
):
|
||||
@@ -818,6 +844,12 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
action="store_true",
|
||||
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--clip_ti_decay",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Keep the norm of the textual inversion intact (0.4 is a good starting point)",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
Reference in New Issue
Block a user