mirror of https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
Add LoRA-GGPO for Flux
@@ -9,6 +9,7 @@

import math
import os
from contextlib import contextmanager
from typing import Dict, List, Optional, Tuple, Type, Union
from diffusers import AutoencoderKL
from transformers import CLIPTextModel
@@ -27,6 +28,42 @@ logger = logging.getLogger(__name__)
NUM_DOUBLE_BLOCKS = 19
NUM_SINGLE_BLOCKS = 38

@contextmanager
def temp_random_seed(seed, device=None):
    """
    Context manager that temporarily sets a specific random seed and then
    restores the original RNG state afterward.

    Args:
        seed (int): The random seed to set temporarily
        device (torch.device, optional): The device to set the seed for.
            If None, will detect from the current context.
    """
    # Save original RNG states
    original_cpu_rng_state = torch.get_rng_state()
    original_cuda_rng_states = None
    if torch.cuda.is_available():
        original_cuda_rng_states = torch.cuda.get_rng_state_all()

    # Determine if we need to set CUDA seed
    set_cuda = False
    if device is not None:
        set_cuda = device.type == 'cuda'
    elif torch.cuda.is_available():
        set_cuda = True

    try:
        # Set the temporary seed
        torch.manual_seed(seed)
        if set_cuda:
            torch.cuda.manual_seed_all(seed)
        yield
    finally:
        # Restore original RNG states
        torch.set_rng_state(original_cpu_rng_state)
        if torch.cuda.is_available() and original_cuda_rng_states is not None:
            torch.cuda.set_rng_state_all(original_cuda_rng_states)


class LoRAModule(torch.nn.Module):
    """
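The temp_random_seed helper added above is what lets the forward pass re-draw exactly the perturbation that update_norms seeded, without disturbing the global RNG stream used elsewhere (dropout, noise sampling). A minimal sketch of the guarantee it provides, assuming the helper defined in this diff is in scope (illustrative values only):

import torch

with temp_random_seed(1234):
    a = torch.randn(4)
_ = torch.randn(4)  # unrelated draw; consumes the restored global RNG state
with temp_random_seed(1234):
    b = torch.randn(4)
assert torch.equal(a, b)  # the same seed reproduces the same draw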
@@ -44,6 +81,8 @@ class LoRAModule(torch.nn.Module):
        rank_dropout=None,
        module_dropout=None,
        split_dims: Optional[List[int]] = None,
        ggpo_beta: Optional[float] = None,
        ggpo_sigma: Optional[float] = None,
    ):
        """
        if alpha == 0 or None, alpha is rank (no scaling).
@@ -103,9 +142,16 @@ class LoRAModule(torch.nn.Module):
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout

        self.ggpo_sigma = ggpo_sigma
        self.ggpo_beta = ggpo_beta

        self.perturbation_norm_factor = 1.0 / math.sqrt(org_module.weight.shape[0])
        self._org_module_weight = self.org_module.weight.detach()

    def apply_to(self):
        self.org_forward = self.org_module.forward
        self.org_module.forward = self.forward

        del self.org_module

    def forward(self, x):
@@ -140,7 +186,15 @@ class LoRAModule(torch.nn.Module):

            lx = self.lora_up(lx)

            return org_forwarded + lx * self.multiplier * scale
            # LoRA Gradient-Guided Perturbation Optimization
            if self.training and hasattr(self, 'perturbation_seed') and self.ggpo_sigma is not None and self.ggpo_beta is not None:
                with torch.no_grad(), torch.autocast(self.device.type), temp_random_seed(self.perturbation_seed):
                    perturbation = torch.randn_like(self._org_module_weight, dtype=self.dtype, device=self.device)
                    perturbation.mul_(self.perturbation_scale_factor)
                    perturbation_output = x @ perturbation.T  # Result: (batch × n)
                return org_forwarded + (self.multiplier * scale * lx) + perturbation_output
            else:
                return org_forwarded + lx * self.multiplier * scale
        else:
            lxs = [lora_down(x) for lora_down in self.lora_down]
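The GGPO branch above replaces the previous unconditional return (the first return org_forwarded + lx * self.multiplier * scale line in this hunk). The perturbation is sampled with the same shape as the frozen weight, so x @ perturbation.T maps the input exactly like the wrapped Linear layer and can be added straight onto the output, as the "# Result: (batch × n)" comment notes. A small shape check with made-up sizes (not from the diff):

import torch

batch, k, n = 2, 8, 16                    # k = in_features, n = out_features
x = torch.randn(batch, k)
perturbation = torch.randn(n, k)          # same shape as org_module.weight
perturbation_output = x @ perturbation.T  # (batch, n)
assert perturbation_output.shape == (batch, n)  # matches org_forwarded's shape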
@@ -167,6 +221,58 @@ class LoRAModule(torch.nn.Module):

            return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale

    @torch.no_grad()
    def update_norms(self):
        # Not running GGPO so not currently running update norms
        if self.ggpo_beta is None or self.ggpo_sigma is None:
            return

        # only update norms when we are training
        if self.lora_down.weight.requires_grad is not True:
            print(f"skipping update_norms for {self.lora_name}")
            return

        lora_down_grad = None
        lora_up_grad = None

        for name, param in self.named_parameters():
            if name == "lora_down.weight":
                lora_down_grad = param.grad
            elif name == "lora_up.weight":
                lora_up_grad = param.grad

        with torch.autocast(self.device.type):
            module_weights = self.scale * (self.lora_up.weight @ self.lora_down.weight)
            org_device = self._org_module_weight.device
            org_dtype = self._org_module_weight.dtype
            org_weight = self._org_module_weight.to(device=self.device, dtype=self.dtype)
            combined_weight = org_weight + module_weights

            self.combined_weight_norms = torch.norm(combined_weight, dim=1, keepdim=True)

            self._org_module_weight.to(device=org_device, dtype=org_dtype)


        # Calculate gradient norms if we have both gradients
        if lora_down_grad is not None and lora_up_grad is not None:
            with torch.autocast(self.device.type):
                approx_grad = self.scale * ((self.lora_up.weight @ lora_down_grad) + (lora_up_grad @ self.lora_down.weight))
                self.grad_norms = torch.norm(approx_grad, dim=1, keepdim=True)

            self.perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms ** 2)) + (self.ggpo_beta * (self.grad_norms ** 2))
            self.perturbation_scale_factor = (self.perturbation_scale * self.perturbation_norm_factor).to(self.device)

            # LoRA Gradient-Guided Perturbation Optimization
            self.perturbation_seed = torch.randint(0, 2**32 - 1, (1,)).detach().item()

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def dtype(self):
        return next(self.parameters()).dtype


class LoRAInfModule(LoRAModule):
    def __init__(
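Per output row, update_norms above computes a perturbation scale of ggpo_sigma times the norm of the combined weight (frozen weight plus the scaled LoRA delta) plus ggpo_beta times the squared norm of an approximate gradient of that combined weight (torch.sqrt(norms ** 2) is simply the norm again, since norms are non-negative). A standalone numerical sketch of the same formula with made-up shapes and coefficients, names chosen only for illustration:

import torch

n, k, r = 16, 8, 4                       # out_features, in_features, LoRA rank
s, sigma, beta = 1.0, 0.03, 0.01         # scale, ggpo_sigma, ggpo_beta (made up)
W0 = torch.randn(n, k)                   # frozen original weight
B, A = torch.randn(n, r), torch.randn(r, k)    # lora_up / lora_down weights
dB, dA = torch.randn(n, r), torch.randn(r, k)  # their gradients

combined_weight_norms = torch.norm(W0 + s * (B @ A), dim=1, keepdim=True)
grad_norms = torch.norm(s * ((B @ dA) + (dB @ A)), dim=1, keepdim=True)
perturbation_scale = sigma * combined_weight_norms + beta * grad_norms ** 2  # shape (n, 1)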
@@ -420,6 +526,16 @@ def create_network(
    if split_qkv is not None:
        split_qkv = True if split_qkv == "True" else False

    ggpo_beta = kwargs.get("ggpo_beta", None)
    ggpo_sigma = kwargs.get("ggpo_sigma", None)

    if ggpo_beta is not None:
        ggpo_beta = float(ggpo_beta)

    if ggpo_sigma is not None:
        ggpo_sigma = float(ggpo_sigma)


    # train T5XXL
    train_t5xxl = kwargs.get("train_t5xxl", False)
    if train_t5xxl is not None:
@@ -449,6 +565,8 @@ def create_network(
        in_dims=in_dims,
        train_double_block_indices=train_double_block_indices,
        train_single_block_indices=train_single_block_indices,
        ggpo_beta=ggpo_beta,
        ggpo_sigma=ggpo_sigma,
        verbose=verbose,
    )
@@ -561,6 +679,8 @@ class LoRANetwork(torch.nn.Module):
        in_dims: Optional[List[int]] = None,
        train_double_block_indices: Optional[List[bool]] = None,
        train_single_block_indices: Optional[List[bool]] = None,
        ggpo_beta: Optional[float] = None,
        ggpo_sigma: Optional[float] = None,
        verbose: Optional[bool] = False,
    ) -> None:
        super().__init__()
@@ -599,10 +719,16 @@ class LoRANetwork(torch.nn.Module):
            # logger.info(
            #     f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
            # )

        if ggpo_beta is not None and ggpo_sigma is not None:
            logger.info(f"LoRA-GGPO training sigma: {ggpo_sigma} beta: {ggpo_beta}")

        if self.split_qkv:
            logger.info(f"split qkv for LoRA")
        if self.train_blocks is not None:
            logger.info(f"train {self.train_blocks} blocks only")


        if train_t5xxl:
            logger.info(f"train T5XXL as well")
@@ -722,6 +848,8 @@ class LoRANetwork(torch.nn.Module):
                rank_dropout=rank_dropout,
                module_dropout=module_dropout,
                split_dims=split_dims,
                ggpo_beta=ggpo_beta,
                ggpo_sigma=ggpo_sigma,
            )
            loras.append(lora)
@@ -790,6 +918,10 @@ class LoRANetwork(torch.nn.Module):
        for lora in self.text_encoder_loras + self.unet_loras:
            lora.enabled = is_enabled

    def update_norms(self):
        for lora in self.text_encoder_loras + self.unet_loras:
            lora.update_norms()

    def load_weights(self, file):
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file
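The network-level update_norms simply fans out to every LoRA module. Because the per-module version reads param.grad, a caller has to invoke it after backward and before gradients are cleared; this diff does not show the training-loop side, so the following is only a hypothetical sketch of where such a call could sit, with a stand-in network:

import torch

class TinyNetwork(torch.nn.Module):
    # Stand-in exposing the same update_norms() entry point; illustration only.
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 4)

    def update_norms(self):
        pass  # the real LoRANetwork delegates to each LoRAModule.update_norms()

network = TinyNetwork()
optimizer = torch.optim.AdamW(network.parameters(), lr=1e-4)
x, target = torch.randn(2, 8), torch.randn(2, 4)

loss = torch.nn.functional.mse_loss(network.linear(x), target)
loss.backward()
network.update_norms()   # called while .grad is still populated
optimizer.step()
optimizer.zero_grad()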