diff --git a/library/train_util.py b/library/train_util.py index 415f9b70..1a42d591 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1423,7 +1423,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)") parser.add_argument("--noise_offset", type=float, default=None, help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)") - + parser.add_argument("--lowram", action="store_true", + help="load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle)") + if support_dreambooth: # DreamBooth training parser.add_argument("--prior_loss_weight", type=float, default=1.0, diff --git a/train_network.py b/train_network.py index 5983a7ef..e29e0174 100644 --- a/train_network.py +++ b/train_network.py @@ -156,9 +156,10 @@ def train(args): # モデルを読み込む text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype) - # unnecessary, but work on low-ram device - text_encoder.to("cuda") - unet.to("cuda") + # work on low-ram device + if args.lowram: + text_encoder.to("cuda") + unet.to("cuda") # モデルに xformers とか memory efficient attention を組み込む train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)