mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
add freeze block lr
This commit is contained in:
@@ -3061,6 +3061,12 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
|
||||
default=None,
|
||||
help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_last_layers_to_freeze",
|
||||
type=int,
|
||||
default=None,
|
||||
help="number of last layers (child modules) whose x_block parameters are frozen (lr=0) during training / 学習時にx_blockパラメータを凍結する(lr=0にする)末尾のレイヤー数",
|
||||
)
|
||||
|
||||
|
||||
def add_optimizer_arguments(parser: argparse.ArgumentParser):
|
||||
@@ -5598,6 +5604,20 @@ def sample_image_inference(
|
||||
pass
|
||||
|
||||
|
||||
def freeze_blocks_lr(model, num_last_layers_to_freeze, base_lr, block_name="x_block"):
    """Build per-parameter optimizer groups, freezing `block_name` params in the last N child modules.

    Every parameter of `model` is returned in its own param group. Parameters whose
    name contains `block_name` AND that live in one of the last `num_last_layers_to_freeze`
    child modules get lr=0.0 (effectively frozen for standard optimizers: step and
    decoupled weight decay both scale with lr); all other parameters get `base_lr`.

    Args:
        model: module whose direct children are treated as "layers"
            (assumes `model.children()` order matches layer depth — TODO confirm for mmdit).
        num_last_layers_to_freeze: number of trailing child modules to apply the
            freeze filter to. None or 0 means "freeze nothing" (matches the
            argparse default of None for --num_last_layers_to_freeze).
        base_lr: learning rate for all non-frozen parameters.
        block_name: substring matched against each parameter's name *relative to
            its child module* (so it matches submodules like `<child>.x_block.*`).

    Returns:
        list of {"params": [param], "lr": float} dicts covering EVERY parameter
        of `model`, suitable for passing directly to a torch optimizer.

    NOTE(review): lr=0.0 still computes gradients for the frozen params; if that
    memory/compute matters, `param.requires_grad_(False)` would be the stronger fix.
    """
    children = list(model.children())

    # Guard: argparse default is None; `children[-None:]` would raise TypeError,
    # and `children[-0:]` would wrongly select *all* children.
    if not num_last_layers_to_freeze:
        return [{"params": [p], "lr": base_lr} for p in model.parameters()]

    split = len(children) - num_last_layers_to_freeze

    params_to_optimize = []
    # Iterate ALL children so no parameter is dropped from the optimizer:
    # the caller replaces its whole param list with this function's return value.
    for idx, child in enumerate(children):
        for name, param in child.named_parameters():
            frozen = idx >= split and block_name in name
            params_to_optimize.append({"params": [param], "lr": 0.0 if frozen else base_lr})

    return params_to_optimize
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
|
||||
@@ -60,7 +60,6 @@ def train(args):
|
||||
assert (
|
||||
not args.weighted_captions
|
||||
), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
|
||||
|
||||
# assert (
|
||||
# not args.train_text_encoder or not args.cache_text_encoder_outputs
|
||||
# ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません"
|
||||
@@ -283,7 +282,8 @@ def train(args):
|
||||
# if train_unet:
|
||||
training_models.append(mmdit)
|
||||
# if block_lrs is None:
|
||||
params_to_optimize.append({"params": list(mmdit.parameters()), "lr": args.learning_rate})
|
||||
params_to_optimize = train_util.freeze_blocks_lr(mmdit, args.num_last_layers_to_freeze, args.learning_rate)
|
||||
# params_to_optimize.append({"params": list(mmdit.parameters()), "lr": args.learning_rate})
|
||||
# else:
|
||||
# params_to_optimize.extend(get_block_params_to_optimize(mmdit, block_lrs))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user