From 6805cafa9be066e549b3aa2b9f53ec03a91ffda6 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 23 Jan 2024 20:17:19 +0900 Subject: [PATCH] fix TI training crashes in multigpu #1019 --- train_textual_inversion.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 0e3912b1..f1cf6fbd 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -505,7 +505,7 @@ class TextualInversionTrainer: if accelerator.is_main_process: init_kwargs = {} if args.wandb_run_name: - init_kwargs['wandb'] = {'name': args.wandb_run_name} + init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( @@ -730,14 +730,13 @@ class TextualInversionTrainer: is_main_process = accelerator.is_main_process if is_main_process: text_encoder = accelerator.unwrap_model(text_encoder) + updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone() accelerator.end_training() if args.save_state and is_main_process: train_util.save_state_on_train_end(args, accelerator) - updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone() - if is_main_process: ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) save_model(ckpt_name, updated_embs_list, global_step, num_train_epochs, force_sync_upload=True)