@@ -9,11 +9,13 @@
import math
import os
from contextlib import contextmanager
from typing import Dict , List , Optional , Tuple , Type , Union
from diffusers import AutoencoderKL
from transformers import CLIPTextModel
import numpy as np
import torch
from torch import Tensor
import re
from library . utils import setup_logging
from library . sdxl_original_unet import SdxlUNet2DConditionModel
@@ -44,6 +46,8 @@ class LoRAModule(torch.nn.Module):
rank_dropout = None ,
module_dropout = None ,
split_dims : Optional [ List [ int ] ] = None ,
ggpo_beta : Optional [ float ] = None ,
ggpo_sigma : Optional [ float ] = None ,
) :
"""
if alpha == 0 or None, alpha is rank (no scaling).
@@ -103,9 +107,20 @@ class LoRAModule(torch.nn.Module):
self . rank_dropout = rank_dropout
self . module_dropout = module_dropout
self . ggpo_sigma = ggpo_sigma
self . ggpo_beta = ggpo_beta
if self . ggpo_beta is not None and self . ggpo_sigma is not None :
self . combined_weight_norms = None
self . grad_norms = None
self . perturbation_norm_factor = 1.0 / math . sqrt ( org_module . weight . shape [ 0 ] )
self . initialize_norm_cache ( org_module . weight )
self . org_module_shape : tuple [ int ] = org_module . weight . shape
def apply_to(self):
    """Hook this LoRA module into the wrapped layer's forward pass."""
    target = self.org_module
    self.org_forward = target.forward
    target.forward = self.forward
    # Drop the attribute; org_forward keeps the original callable alive.
    del self.org_module
def forward ( self , x ) :
@@ -140,7 +155,17 @@ class LoRAModule(torch.nn.Module):
lx = self . lora_up ( lx )
return org_forwarded + lx * self . multiplier * scale
# LoRA Gradient-Guided Perturbation Optimization
if self . training and self . ggpo_sigma is not None and self . ggpo_beta is not None and self . combined_weight_norms is not None and self . grad_norms is not None :
with torch . no_grad ( ) :
perturbation_scale = ( self . ggpo_sigma * torch . sqrt ( self . combined_weight_norms * * 2 ) ) + ( self . ggpo_beta * ( self . grad_norms * * 2 ) )
perturbation_scale_factor = ( perturbation_scale * self . perturbation_norm_factor ) . to ( self . device )
perturbation = torch . randn ( self . org_module_shape , dtype = self . dtype , device = self . device )
perturbation . mul_ ( perturbation_scale_factor )
perturbation_output = x @ perturbation . T # Result: (batch × n)
return org_forwarded + ( self . multiplier * scale * lx ) + perturbation_output
else :
return org_forwarded + lx * self . multiplier * scale
else :
lxs = [ lora_down ( x ) for lora_down in self . lora_down ]
@@ -167,6 +192,116 @@ class LoRAModule(torch.nn.Module):
return org_forwarded + torch . cat ( lxs , dim = - 1 ) * self . multiplier * scale
@torch.no_grad()
def initialize_norm_cache(self, org_module_weight: Tensor):
    """Cache an estimate of the base weight's mean row norm.

    Samples up to 1000 rows instead of scanning the whole matrix; the
    mean/std of the sampled row norms stand in for the full statistics.
    """
    total_rows = org_module_weight.shape[0]
    n_samples = min(1000, total_rows)  # small matrices use every row
    row_ids = torch.randperm(total_rows)[:n_samples]
    # Cast to float32 first (the storage dtype may not support fancy
    # indexing), then move the sampled rows onto the module's device.
    as_float = org_module_weight.to(dtype=torch.float32)
    sample = as_float[row_ids].to(device=self.device)
    norms = torch.norm(sample, dim=1, keepdim=True)
    self.org_weight_norm_estimate = norms.mean()
    # Spread of the sampled norms, usable for confidence intervals.
    self.org_weight_norm_std = norms.std()
    # Release the temporaries promptly.
    del sample, as_float
@torch.no_grad()
def validate_norm_approximation(self, org_module_weight: Tensor, verbose=True):
    """Compare the sampled row-norm estimate with the exact mean row norm.

    Scans the full weight in chunks (slow; intended for validation only)
    and returns a dict with the true mean norm, the cached estimate, and
    the absolute/relative error between them.
    """
    chunk = 1024  # bound peak memory while scanning the full weight
    per_chunk_norms = []
    total_rows = org_module_weight.shape[0]
    for start in range(0, total_rows, chunk):
        stop = min(start + chunk, total_rows)
        rows = org_module_weight[start:stop].to(device=self.device, dtype=self.dtype)
        per_chunk_norms.append(torch.norm(rows, dim=1, keepdim=True).cpu())
        del rows
    all_norms = torch.cat(per_chunk_norms, dim=0)
    true_mean_norm = all_norms.mean().item()
    estimated_norm = self.org_weight_norm_estimate.item()
    absolute_error = abs(true_mean_norm - estimated_norm)
    relative_error = absolute_error / true_mean_norm * 100  # percent
    if verbose:
        logger.info(f"True mean norm: {true_mean_norm:.6f}")
        logger.info(f"Estimated norm: {estimated_norm:.6f}")
        logger.info(f"Absolute error: {absolute_error:.6f}")
        logger.info(f"Relative error: {relative_error:.2f}%")
    return {
        "true_mean_norm": true_mean_norm,
        "estimated_norm": estimated_norm,
        "absolute_error": absolute_error,
        "relative_error": relative_error,
    }
@torch.no_grad()
def update_norms(self):
    """Refresh cached norms of the effective LoRA delta (GGPO bookkeeping).

    Computes scale * (lora_up @ lora_down), caches its per-row norms, and
    combines them with the cached base-weight norm estimate. No-op unless
    GGPO is enabled and the module is in training mode.
    """
    # Not running GGPO, so there is nothing to update.
    if self.ggpo_beta is None or self.ggpo_sigma is None:
        return
    # Norms are only needed while training.
    if self.training is False:
        return
    module_weights = self.lora_up.weight @ self.lora_down.weight
    # BUG FIX: the original called .mul() (out-of-place) and discarded the
    # result, so the LoRA scale was never applied; .mul_() scales in place.
    module_weights.mul_(self.scale)
    self.weight_norms = torch.norm(module_weights, dim=1, keepdim=True)
    # Row-wise ||W_base + dW|| approximated as sqrt(||W_base||^2 + ||dW||^2).
    self.combined_weight_norms = torch.sqrt(
        (self.org_weight_norm_estimate**2)
        + torch.sum(module_weights**2, dim=1, keepdim=True)
    )
@torch.no_grad()
def update_grad_norms(self):
    """Approximate per-row gradient norms of the merged LoRA delta (GGPO).

    Uses grad(W) ~ scale * (up @ grad_down + grad_up @ down), avoiding
    materialization of the dense merged-weight gradient.
    """
    if self.training is False:
        # Consistency fix: use the module logger like the rest of the file
        # instead of a bare print.
        logger.info(f"skipping update_grad_norms for {self.lora_name}")
        return

    lora_down_grad = None
    lora_up_grad = None
    for name, param in self.named_parameters():
        if name == "lora_down.weight":
            lora_down_grad = param.grad
        elif name == "lora_up.weight":
            lora_up_grad = param.grad

    # Only compute when both factor gradients are available.
    if lora_down_grad is not None and lora_up_grad is not None:
        with torch.autocast(self.device.type):
            approx_grad = self.scale * (
                (self.lora_up.weight @ lora_down_grad)
                + (lora_up_grad @ self.lora_down.weight)
            )
            self.grad_norms = torch.norm(approx_grad, dim=1, keepdim=True)
@property
def device(self):
    """Device of this module's parameters (taken from the first one)."""
    first_param = next(self.parameters())
    return first_param.device
@property
def dtype(self):
    """Dtype of this module's parameters (taken from the first one)."""
    first_param = next(self.parameters())
    return first_param.dtype
class LoRAInfModule ( LoRAModule ) :
def __init__ (
@@ -420,6 +555,16 @@ def create_network(
if split_qkv is not None :
split_qkv = True if split_qkv == " True " else False
ggpo_beta = kwargs . get ( " ggpo_beta " , None )
ggpo_sigma = kwargs . get ( " ggpo_sigma " , None )
if ggpo_beta is not None :
ggpo_beta = float ( ggpo_beta )
if ggpo_sigma is not None :
ggpo_sigma = float ( ggpo_sigma )
# train T5XXL
train_t5xxl = kwargs . get ( " train_t5xxl " , False )
if train_t5xxl is not None :
@@ -449,6 +594,8 @@ def create_network(
in_dims = in_dims ,
train_double_block_indices = train_double_block_indices ,
train_single_block_indices = train_single_block_indices ,
ggpo_beta = ggpo_beta ,
ggpo_sigma = ggpo_sigma ,
verbose = verbose ,
)
@@ -561,6 +708,8 @@ class LoRANetwork(torch.nn.Module):
in_dims : Optional [ List [ int ] ] = None ,
train_double_block_indices : Optional [ List [ bool ] ] = None ,
train_single_block_indices : Optional [ List [ bool ] ] = None ,
ggpo_beta : Optional [ float ] = None ,
ggpo_sigma : Optional [ float ] = None ,
verbose : Optional [ bool ] = False ,
) - > None :
super ( ) . __init__ ( )
@@ -599,10 +748,16 @@ class LoRANetwork(torch.nn.Module):
# logger.info(
# f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
# )
if ggpo_beta is not None and ggpo_sigma is not None :
logger . info ( f " LoRA-GGPO training sigma: { ggpo_sigma } beta: { ggpo_beta } " )
if self . split_qkv :
logger . info ( f " split qkv for LoRA " )
if self . train_blocks is not None :
logger . info ( f " train { self . train_blocks } blocks only " )
if train_t5xxl :
logger . info ( f " train T5XXL as well " )
@@ -722,6 +877,8 @@ class LoRANetwork(torch.nn.Module):
rank_dropout = rank_dropout ,
module_dropout = module_dropout ,
split_dims = split_dims ,
ggpo_beta = ggpo_beta ,
ggpo_sigma = ggpo_sigma ,
)
loras . append ( lora )
@@ -790,6 +947,36 @@ class LoRANetwork(torch.nn.Module):
for lora in self . text_encoder_loras + self . unet_loras :
lora . enabled = is_enabled
def update_norms(self):
    """Refresh the GGPO weight-norm caches on every attached LoRA module."""
    for group in (self.text_encoder_loras, self.unet_loras):
        for lora in group:
            lora.update_norms()
def update_grad_norms(self):
    """Refresh the GGPO gradient-norm caches on every attached LoRA module."""
    for group in (self.text_encoder_loras, self.unet_loras):
        for lora in group:
            lora.update_grad_norms()
def grad_norms(self) -> Tensor:
    """Stack the mean cached gradient norm of each LoRA that has one.

    Returns an empty tensor when no module has a cached gradient norm.
    """
    collected = [
        lora.grad_norms.mean(dim=0)
        for lora in self.text_encoder_loras + self.unet_loras
        if getattr(lora, "grad_norms", None) is not None
    ]
    return torch.stack(collected) if collected else torch.tensor([])
def weight_norms(self) -> Tensor:
    """Stack the mean cached LoRA-delta weight norm of each module that has one.

    Returns an empty tensor when no module has a cached weight norm.
    """
    collected = [
        lora.weight_norms.mean(dim=0)
        for lora in self.text_encoder_loras + self.unet_loras
        if getattr(lora, "weight_norms", None) is not None
    ]
    return torch.stack(collected) if collected else torch.tensor([])
def combined_weight_norms(self) -> Tensor:
    """Stack the mean cached combined (base + delta) norm of each module.

    Returns an empty tensor when no module has a cached combined norm.
    """
    collected = [
        lora.combined_weight_norms.mean(dim=0)
        for lora in self.text_encoder_loras + self.unet_loras
        if getattr(lora, "combined_weight_norms", None) is not None
    ]
    return torch.stack(collected) if collected else torch.tensor([])
def load_weights ( self , file ) :
if os . path . splitext ( file ) [ 1 ] == " .safetensors " :
from safetensors . torch import load_file