mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 21:52:27 +00:00
Add fnmatch. Make max_norm no_grad
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
|
||||
import math
|
||||
import os
|
||||
from fnmatch import fnmatch
|
||||
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||
from diffusers import AutoencoderKL
|
||||
from transformers import CLIPTextModel
|
||||
@@ -1366,6 +1367,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
org_module._lora_restored = False
|
||||
lora.enabled = False
|
||||
|
||||
@torch.no_grad()
|
||||
def apply_max_norm_regularization(self, max_norm, device, scale_map: dict[str, float]={}):
|
||||
downkeys = []
|
||||
upkeys = []
|
||||
@@ -1383,7 +1385,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
for i in range(len(downkeys)):
|
||||
max_norm_value = max_norm
|
||||
for key in scale_map.keys():
|
||||
if key in downkeys[i]:
|
||||
if fnmatch(downkeys[i], key):
|
||||
max_norm_value = scale_map[key]
|
||||
|
||||
down = state_dict[downkeys[i]].to(device)
|
||||
@@ -1409,7 +1411,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
keys_scaled += 1
|
||||
state_dict[upkeys[i]] *= sqrt_ratio
|
||||
state_dict[downkeys[i]] *= sqrt_ratio
|
||||
scalednorm = updown.norm() * ratio
|
||||
scalednorm: torch.Tensor = updown.norm() * ratio
|
||||
norms.append(scalednorm.item())
|
||||
|
||||
return keys_scaled, sum(norms) / len(norms), max(norms)
|
||||
|
||||
Reference in New Issue
Block a user