diff --git a/README.md b/README.md index 5919f08c..a505c0b3 100644 --- a/README.md +++ b/README.md @@ -249,6 +249,16 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum ## Change History +### Jan 17, 2024 / 2024/1/17: v0.8.1 + +- Fixed a bug where VRAM usage without Text Encoder training was larger than before in the training scripts for LoRA etc. (`train_network.py`, `sdxl_train_network.py`). + - The Text Encoders were left on the GPU instead of being moved to the CPU. +- Fixed typos. Thanks to akx! [PR #1053](https://github.com/kohya-ss/sd-scripts/pull/1053) + +- LoRA 等の学習スクリプト(`train_network.py`、`sdxl_train_network.py`)で、Text Encoder を学習しない場合の VRAM 使用量が以前に比べて大きくなっていた不具合を修正しました。 + - Text Encoder が GPU に保持されたままになっていました。 +- 誤字が修正されました。 [PR #1053](https://github.com/kohya-ss/sd-scripts/pull/1053) akx 氏に感謝します。 + ### Jan 15, 2024 / 2024/1/15: v0.8.0 - Diffusers, Accelerate, Transformers and other related libraries have been updated. Please update the libraries with [Upgrade](#upgrade). diff --git a/sdxl_train_network.py b/sdxl_train_network.py index a35779d0..d810ce7d 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -95,8 +95,8 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer): unet.to(org_unet_device) else: # Text Encoderから毎回出力を取得するので、GPUに乗せておく - text_encoders[0].to(accelerator.device) - text_encoders[1].to(accelerator.device) + text_encoders[0].to(accelerator.device, dtype=weight_dtype) + text_encoders[1].to(accelerator.device, dtype=weight_dtype) def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: diff --git a/train_network.py b/train_network.py index a75299cd..c2b7fbde 100644 --- a/train_network.py +++ b/train_network.py @@ -117,7 +117,7 @@ class NetworkTrainer: self, args, accelerator, unet, vae, tokenizers, text_encoders, data_loader, weight_dtype ): for t_enc in text_encoders: - t_enc.to(accelerator.device) + 
t_enc.to(accelerator.device, dtype=weight_dtype) def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): input_ids = batch["input_ids"].to(accelerator.device) @@ -278,6 +278,7 @@ class NetworkTrainer: accelerator.wait_for_everyone() # 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される + # cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu self.cache_text_encoder_outputs_if_needed( args, accelerator, unet, vae, tokenizers, text_encoders, train_dataset_group, weight_dtype ) @@ -394,8 +395,7 @@ class NetworkTrainer: for t_enc in text_encoders: t_enc.requires_grad_(False) - # acceleratorがなんかよろしくやってくれるらしい - # TODO めちゃくちゃ冗長なのでコードを整理する + # acceleratorがなんかよろしくやってくれるらしい / accelerator apparently takes care of this for us if train_unet: unet = accelerator.prepare(unet) else: @@ -407,8 +407,8 @@ class NetworkTrainer: text_encoder = accelerator.prepare(text_encoder) text_encoders = [text_encoder] else: - for t_enc in text_encoders: - t_enc.to(accelerator.device, dtype=weight_dtype) + pass # if text_encoder is not trained, there is no need to prepare it; device and dtype are already set + network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler) if args.gradient_checkpointing: @@ -685,7 +685,7 @@ class NetworkTrainer: if accelerator.is_main_process: init_kwargs = {} if args.wandb_run_name: - init_kwargs['wandb'] = {'name': args.wandb_run_name} + init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers(