update grad hook creation to fix TE lr in sd3 fine tuning

Kohya S
2024-11-14 19:33:12 +09:00
parent 2cb7a6db02
commit 2bb0f547d7
3 changed files with 22 additions and 13 deletions

View File

@@ -80,7 +80,9 @@ def train(args):
     assert (
         args.blocks_to_swap is None or args.blocks_to_swap == 0
-    ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません"
+    ) or not args.cpu_offload_checkpointing, (
+        "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません"
+    )

     cache_latents = args.cache_latents
     use_dreambooth_method = args.in_json is None
@@ -480,13 +482,16 @@ def train(args):
             for parameter, param_name in zip(param_group["params"], param_name_group):
                 if parameter.requires_grad:

-                    def grad_hook(tensor: torch.Tensor, param_group=param_group):
-                        if accelerator.sync_gradients and args.max_grad_norm != 0.0:
-                            accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
-                        optimizer.step_param(tensor, param_group)
-                        tensor.grad = None
+                    def create_grad_hook(p_name, p_group):
+                        def grad_hook(tensor: torch.Tensor):
+                            if accelerator.sync_gradients and args.max_grad_norm != 0.0:
+                                accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
+                            optimizer.step_param(tensor, p_group)
+                            tensor.grad = None

-                    parameter.register_post_accumulate_grad_hook(grad_hook)
+                        return grad_hook
+
+                    parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group))

     elif args.blockwise_fused_optimizers:
         # prepare for additional optimizers and lr schedulers

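For reference, the factory wrapper gives each registered hook its own explicit binding of the param group (and thus its own learning rate) for optimizer.step_param during the fused backward pass. Below is a minimal, self-contained sketch of the same pattern, assuming PyTorch >= 2.1 for register_post_accumulate_grad_hook; the names (make_hook, groups) and the plain SGD-style update standing in for step_param are illustrative, not from this repository.

import torch

model = torch.nn.Linear(4, 4)
# hypothetical per-group settings, e.g. different lrs for the unet and the text encoders
groups = [
    {"params": [model.weight], "lr": 1e-4},
    {"params": [model.bias], "lr": 1e-5},
]

def make_hook(group):
    # factory: binds `group` for this parameter, independently of the loop variables below
    def hook(param: torch.Tensor):
        with torch.no_grad():
            param -= group["lr"] * param.grad  # stand-in for optimizer.step_param(param, group)
        param.grad = None  # drop the gradient right away, as the fused path does
    return hook

for group in groups:
    for param in group["params"]:
        param.register_post_accumulate_grad_hook(make_hook(group))

loss = model(torch.randn(2, 4)).sum()
loss.backward()  # each hook fires right after its parameter's gradient has been accumulated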
View File

@@ -5913,6 +5913,7 @@ def append_lr_to_logs(logs, lr_scheduler, optimizer_type, including_unet=True):
     names.append("unet")
     names.append("text_encoder1")
     names.append("text_encoder2")
+    names.append("text_encoder3")  # SD3
     append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names)

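SD3 has three text encoders (CLIP-L, CLIP-G, T5-XXL), so a third name is needed for the learning-rate logging to cover every param group. As a rough sketch of what the downstream helper plausibly does (illustrative only; this is not the repository's actual implementation of append_lr_to_logs_with_names), each name is paired index-for-index with the scheduler's per-group learning rates:

def append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names):
    # sketch only: one learning rate per optimizer param group, keyed by module name,
    # e.g. logs["lr/unet"], logs["lr/text_encoder1"], ...
    lrs = lr_scheduler.get_last_lr()
    for i, name in enumerate(names):
        if i < len(lrs):
            logs[f"lr/{name}"] = float(lrs[i])
    # optimizer_type is kept for signature parity; not used in this sketch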
View File

@@ -606,13 +606,16 @@ def train(args):
             for parameter, param_name in zip(param_group["params"], param_name_group):
                 if parameter.requires_grad:

-                    def grad_hook(tensor: torch.Tensor, param_group=param_group):
-                        if accelerator.sync_gradients and args.max_grad_norm != 0.0:
-                            accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
-                        optimizer.step_param(tensor, param_group)
-                        tensor.grad = None
+                    def create_grad_hook(p_name, p_group):
+                        def grad_hook(tensor: torch.Tensor):
+                            if accelerator.sync_gradients and args.max_grad_norm != 0.0:
+                                accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
+                            optimizer.step_param(tensor, p_group)
+                            tensor.grad = None

-                    parameter.register_post_accumulate_grad_hook(grad_hook)
+                        return grad_hook
+
+                    parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group))

     elif args.blockwise_fused_optimizers:
         # prepare for additional optimizers and lr schedulers