Add fnmatch matching for scale_map keys. Run max-norm regularization under torch.no_grad()

Author: rockerBOO
Date: 2025-01-23 14:24:57 -05:00
parent b0d0d43bfa
commit dfe1da4d36


@@ -5,6 +5,7 @@
 import math
 import os
+from fnmatch import fnmatch
 from typing import Dict, List, Optional, Tuple, Type, Union

 from diffusers import AutoencoderKL
 from transformers import CLIPTextModel
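
Note: fnmatch does shell-style glob matching ("*" matches any run of characters, "?" a single character), so scale_map keys can now be patterns rather than exact substrings. A minimal sketch of the semantics, with made-up key names:

    from fnmatch import fnmatch

    # The whole string must match the pattern, not just contain it.
    fnmatch("lora_unet_down_blocks_0_attn1.lora_down.weight", "lora_unet_*attn*")  # True
    fnmatch("lora_te_encoder_layers_0.lora_down.weight", "lora_unet_*")            # False
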
@@ -1366,6 +1367,7 @@ class LoRANetwork(torch.nn.Module):
             org_module._lora_restored = False
             lora.enabled = False

+    @torch.no_grad()
     def apply_max_norm_regularization(self, max_norm, device, scale_map: dict[str, float]={}):
         downkeys = []
         upkeys = []
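
Note: the @torch.no_grad() decorator disables autograd tracking for the whole method. That matters here because the body rescales weight tensors in place; under grad tracking, an in-place op on a leaf tensor that requires grad raises a RuntimeError, and tracking would also record graph history for a pure bookkeeping step. A minimal reproduction with a hypothetical tensor, not from this repo:

    import torch

    w = torch.nn.Parameter(torch.ones(4))
    # w *= 0.5          # RuntimeError: a leaf Variable that requires grad
                        # is being used in an in-place operation
    with torch.no_grad():
        w *= 0.5        # fine: the scaling is not recorded by autograd
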
@@ -1383,7 +1385,7 @@
         for i in range(len(downkeys)):
             max_norm_value = max_norm
             for key in scale_map.keys():
-                if key in downkeys[i]:
+                if fnmatch(downkeys[i], key):
                     max_norm_value = scale_map[key]

             down = state_dict[downkeys[i]].to(device)
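
Note: this is a behavior change for existing scale_map entries. The old check was a substring test; fnmatch requires the pattern to match the entire key, so a bare key like "attn" only keeps its old effect if rewritten as "*attn*". Illustrative, with a made-up key:

    from fnmatch import fnmatch

    key = "lora_unet_mid_block_attn1.lora_down.weight"
    "attn" in key           # True  (old substring behavior)
    fnmatch(key, "attn")    # False (pattern must cover the whole string)
    fnmatch(key, "*attn*")  # True  (glob equivalent of the old check)
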
@@ -1409,7 +1411,7 @@
                 keys_scaled += 1
                 state_dict[upkeys[i]] *= sqrt_ratio
                 state_dict[downkeys[i]] *= sqrt_ratio
-            scalednorm = updown.norm() * ratio
+            scalednorm: torch.Tensor = updown.norm() * ratio
             norms.append(scalednorm.item())

         return keys_scaled, sum(norms) / len(norms), max(norms)
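
Note: for readers without the full file, the loop around this hunk follows the usual max-norm pattern: the effective LoRA update is up @ down, so scaling both factors by sqrt(ratio) scales the product's norm by ratio. A self-contained sketch under that assumption (clamp_updown_norm is hypothetical, not a function in this repo):

    import torch

    def clamp_updown_norm(up: torch.Tensor, down: torch.Tensor, max_norm_value: float):
        updown = up @ down                       # effective LoRA update
        norm = updown.norm().item()
        ratio = min(1.0, max_norm_value / max(norm, 1e-12))  # only shrink, never grow
        sqrt_ratio = ratio ** 0.5                # split evenly across both factors
        return up * sqrt_ratio, down * sqrt_ratio, norm * ratio
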