diff --git a/library/train_util.py b/library/train_util.py index 96d32e3b..d2ab95a0 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3061,6 +3061,12 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser): default=None, help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)", ) + parser.add_argument( + "--num_last_layers_to_freeze", + type=int, + default=None, + help="num_last_layers_to_freeze", + ) def add_optimizer_arguments(parser: argparse.ArgumentParser): @@ -5598,6 +5604,20 @@ def sample_image_inference( pass +def freeze_blocks_lr(model, num_last_layers_to_freeze, base_lr, block_name="x_block"): + bottom_layers = list(model.children())[-num_last_layers_to_freeze:] + + params_to_optimize = [] + + for layer in reversed(bottom_layers): + for name, param in layer.named_parameters(): + if block_name in name: + params_to_optimize.append({"params": [param], "lr": 0.0}) + else: + params_to_optimize.append({"params": [param], "lr": base_lr}) + + return params_to_optimize + # endregion diff --git a/sd3_train.py b/sd3_train.py index 9a7de239..8bff476a 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -60,7 +60,6 @@ def train(args): assert ( not args.weighted_captions ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" - # assert ( # not args.train_text_encoder or not args.cache_text_encoder_outputs # ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" @@ -283,7 +282,8 @@ def train(args): # if train_unet: training_models.append(mmdit) # if block_lrs is None: - params_to_optimize.append({"params": list(mmdit.parameters()), "lr": args.learning_rate}) + params_to_optimize = train_util.freeze_blocks_lr(mmdit, args.num_last_layers_to_freeze,args.args.learning_rate) + # params_to_optimize.append({"params": list(mmdit.parameters()), "lr": args.learning_rate}) # else: # params_to_optimize.extend(get_block_params_to_optimize(mmdit, block_lrs))