diff --git a/README.md b/README.md index b32014f6..168482fd 100644 --- a/README.md +++ b/README.md @@ -127,6 +127,11 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ## Change History +- 31 Mar. 2023, 2023/3/31: + - Fix an issue where VRAM usage temporarily increases when loading a model in `train_network.py`. + - Fix an issue where an error occurs when loading a `.safetensors` model in `train_network.py`. [#354](https://github.com/kohya-ss/sd-scripts/issues/354) + - `train_network.py` でモデル読み込み時にVRAM使用量が一時的に大きくなる不具合を修正しました。 + - `train_network.py` で `.safetensors` 形式のモデルを読み込むとエラーになる不具合を修正しました。[#354](https://github.com/kohya-ss/sd-scripts/issues/354) - 30 Mar. 2023, 2023/3/30: - Support [P+](https://prompt-plus.github.io/) training. Thank you jakaline-dev! - See [#327](https://github.com/kohya-ss/sd-scripts/pull/327) for details. diff --git a/library/model_util.py b/library/model_util.py index f3f236af..9b4405eb 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -841,7 +841,7 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"): if is_safetensors(ckpt_path): checkpoint = None - state_dict = load_file(ckpt_path, device) + state_dict = load_file(ckpt_path) # , device) # may cause an error else: checkpoint = torch.load(ckpt_path, map_location=device) if "state_dict" in checkpoint: diff --git a/train_network.py b/train_network.py index eb5301e2..2b824018 100644 --- a/train_network.py +++ b/train_network.py @@ -25,7 +25,7 @@ from library.config_util import ( BlueprintGenerator, ) import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import apply_snr_weight +from library.custom_train_functions import apply_snr_weight # TODO 他のスクリプトと共通化する @@ -131,16 +131,21 @@ def train(args): # TODO: modify other training scripts as well if pi == accelerator.state.local_process_index: print(f"loading model for process 
{accelerator.state.local_process_index}/{accelerator.state.num_processes}") - text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator.device) + + text_encoder, vae, unet, _ = train_util.load_target_model( + args, weight_dtype, accelerator.device if args.lowram else "cpu" + ) + + # work on low-ram device + if args.lowram: + text_encoder.to(accelerator.device) + unet.to(accelerator.device) + vae.to(accelerator.device) + gc.collect() torch.cuda.empty_cache() accelerator.wait_for_everyone() - # work on low-ram device - # NOTE: this may not be necessary because we already load them on gpu - if args.lowram: - text_encoder.to(accelerator.device) - unet.to(accelerator.device) # モデルに xformers とか memory efficient attention を組み込む train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) @@ -197,7 +202,7 @@ def train(args): # dataloaderを準備する # DataLoaderのプロセス数:0はメインプロセスになる n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで - + train_dataloader = torch.utils.data.DataLoader( train_dataset_group, batch_size=1, @@ -564,9 +569,9 @@ def train(args): loss_weights = batch["loss_weights"] # 各sampleごとのweight loss = loss * loss_weights - + if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし