diff --git a/networks/lora.py b/networks/lora.py index 6af1a1f2..171db455 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -5,6 +5,7 @@ import math import os +from fnmatch import fnmatch from typing import Dict, List, Optional, Tuple, Type, Union from diffusers import AutoencoderKL from transformers import CLIPTextModel @@ -1366,6 +1367,7 @@ class LoRANetwork(torch.nn.Module): org_module._lora_restored = False lora.enabled = False + @torch.no_grad() def apply_max_norm_regularization(self, max_norm, device, scale_map: dict[str, float]={}): downkeys = [] upkeys = [] @@ -1383,7 +1385,7 @@ class LoRANetwork(torch.nn.Module): for i in range(len(downkeys)): max_norm_value = max_norm for key in scale_map.keys(): - if key in downkeys[i]: + if fnmatch(downkeys[i], key): max_norm_value = scale_map[key] down = state_dict[downkeys[i]].to(device) @@ -1409,7 +1411,7 @@ class LoRANetwork(torch.nn.Module): keys_scaled += 1 state_dict[upkeys[i]] *= sqrt_ratio state_dict[downkeys[i]] *= sqrt_ratio - scalednorm = updown.norm() * ratio + scalednorm: torch.Tensor = updown.norm() * ratio norms.append(scalednorm.item()) return keys_scaled, sum(norms) / len(norms), max(norms)