diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 8575698d..ad26870f 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -435,10 +435,13 @@ class TextualInversionTrainer: text_encoders = [accelerator.prepare(text_encoder) for text_encoder in text_encoders] index_no_updates_list = [] + index_updates_list = [] orig_embeds_params_list = [] for tokenizer, token_ids, text_encoder in zip(tokenizers, token_ids_list, text_encoders): index_no_updates = torch.arange(len(tokenizer)) < token_ids[0] index_no_updates_list.append(index_no_updates) + index_updates = ~index_no_updates + index_updates_list.append(index_updates) # accelerator.print(len(index_no_updates), torch.sum(index_no_updates)) orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone() @@ -634,8 +637,31 @@ class TextualInversionTrainer: lr_scheduler.step() optimizer.zero_grad(set_to_none=True) - # Let's make sure we don't update any embedding weights besides the newly added token with torch.no_grad(): + # normalize embeddings + if args.clip_ti_decay: + for text_encoder, index_updates in zip( + text_encoders, index_updates_list + ): + pre_norm = ( + text_encoder.get_input_embeddings() + .weight[index_updates, :] + .norm(dim=-1, keepdim=True) + ) + + lambda_ = min(1.0, 100 * lr_scheduler.get_last_lr()[0]) + text_encoder.get_input_embeddings().weight[ + index_updates + ] = torch.nn.functional.normalize( + text_encoder.get_input_embeddings().weight[ + index_updates, : + ], + dim=-1, + ) * ( + pre_norm + lambda_ * (args.clip_ti_decay - pre_norm) + ) + + # Let's make sure we don't update any embedding weights besides the newly added token for text_encoder, orig_embeds_params, index_no_updates in zip( text_encoders, orig_embeds_params_list, index_no_updates_list ): @@ -818,6 +844,12 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", ) + parser.add_argument( + "--clip_ti_decay", + type=float, + default=None, + help="Keep the norm of the textual inversion intact (0.4 is a good starting point)", + ) return parser