improve OFT implementation closes #944

This commit is contained in:
Kohya S
2024-09-13 22:37:11 +09:00
parent 43ad73860d
commit 93d9fbf607
4 changed files with 88 additions and 37 deletions

View File

@@ -143,7 +143,31 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
- transformers, accelerate and huggingface_hub are updated.
- If you encounter any issues, please report them.
- en: The INVERSE_SQRT, COSINE_WITH_MIN_LR, and WARMUP_STABLE_DECAY learning rate schedules are now available in the transformers library. See PR [#1393](https://github.com/kohya-ss/sd-scripts/pull/1393) for details. Thanks to sdbds!
- Improvements in OFT (Orthogonal Finetuning) Implementation
1. Optimization of Calculation Order:
- Changed the calculation order in the forward method from (Wx)R to W(xR).
- This has improved computational efficiency and processing speed.
2. Correction of Bias Application:
- In the previous implementation, R was incorrectly applied to the bias.
- The new implementation now correctly handles bias by using F.conv2d and F.linear.
3. Efficiency Enhancement in Matrix Operations:
- Introduced einsum in both the forward and merge_to methods.
- This has optimized matrix operations, resulting in further speed improvements.
4. Proper Handling of Data Types:
- Improved to use torch.float32 during calculations and convert results back to the original data type.
- This maintains precision while ensuring compatibility with the original model.
5. Unified Processing for Conv2d and Linear Layers:
- Implemented a consistent method for applying OFT to both layer types.
- These changes have made the OFT implementation more efficient and accurate, potentially leading to improved model performance and training stability.
- Additional Information
* Recommended α value for OFT constraint: We recommend using α values between 1e-4 and 1e-2. This differs slightly from the original implementation of "(α\*out_dim\*out_dim)". Our implementation uses "(α\*out_dim)", hence we recommend higher values than the 1e-5 suggested in the original implementation.
* Performance Improvement: Training speed has been improved by approximately 30%.
* Inference Environment: This implementation is compatible with and operates within Stable Diffusion web UI (SD1/2 and SDXL).
- The INVERSE_SQRT, COSINE_WITH_MIN_LR, and WARMUP_STABLE_DECAY learning rate schedules are now available in the transformers library. See PR [#1393](https://github.com/kohya-ss/sd-scripts/pull/1393) for details. Thanks to sdbds!
- See the [transformers documentation](https://huggingface.co/docs/transformers/v4.44.2/en/main_classes/optimizer_schedules#schedules) for details on each scheduler.
- `--lr_warmup_steps` and `--lr_decay_steps` can now be specified as a ratio of the number of training steps, not just the step value. Example: `--lr_warmup_steps=0.1` or `--lr_warmup_steps=10%`, etc.

View File

@@ -86,7 +86,8 @@ CLIP_VISION_MODEL = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
"""
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
# def replace_unet_modules(unet: diffusers.models.unets.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
def replace_unet_modules(unet, mem_eff_attn, xformers, sdpa):
if mem_eff_attn:
logger.info("Enable memory efficient attention for U-Net")

View File

@@ -18,7 +18,7 @@ def main(file):
keys = list(sd.keys())
for key in keys:
if "lora_up" in key or "lora_down" in key:
if "lora_up" in key or "lora_down" in key or "lora_A" in key or "lora_B" in key or "oft_" in key:
values.append((key, sd[key]))
print(f"number of LoRA modules: {len(values)}")

View File

@@ -4,13 +4,17 @@ import math
import os
from typing import Dict, List, Optional, Tuple, Type, Union
from diffusers import AutoencoderKL
import einops
from transformers import CLIPTextModel
import numpy as np
import torch
import torch.nn.functional as F
import re
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
@@ -45,11 +49,16 @@ class OFTModule(torch.nn.Module):
if type(alpha) == torch.Tensor:
alpha = alpha.detach().numpy()
self.constraint = alpha * out_dim
# constraint in original paper is alpha * out_dim * out_dim, but we use alpha * out_dim for backward compatibility
# original alpha is 1e-6, so we use 1e-3 or 1e-4 for alpha
self.constraint = alpha * out_dim
self.register_buffer("alpha", torch.tensor(alpha))
self.block_size = out_dim // self.num_blocks
self.oft_blocks = torch.nn.Parameter(torch.zeros(self.num_blocks, self.block_size, self.block_size))
self.I = torch.eye(self.block_size).unsqueeze(0).repeat(self.num_blocks, 1, 1) # cpu
self.out_dim = out_dim
self.shape = org_module.weight.shape
@@ -69,27 +78,36 @@ class OFTModule(torch.nn.Module):
norm_Q = torch.norm(block_Q.flatten())
new_norm_Q = torch.clamp(norm_Q, max=self.constraint)
block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1)
block_R = torch.matmul(I + block_Q, (I - block_Q).inverse())
block_R_weighted = self.multiplier * block_R + (1 - self.multiplier) * I
R = torch.block_diag(*block_R_weighted)
return R
if self.I.device != block_Q.device:
self.I = self.I.to(block_Q.device)
I = self.I
block_R = torch.matmul(I + block_Q, (I - block_Q).float().inverse())
block_R_weighted = self.multiplier * (block_R - I) + I
return block_R_weighted
def forward(self, x, scale=None):
x = self.org_forward(x)
if self.multiplier == 0.0:
return x
return self.org_forward(x)
org_module = self.org_module[0]
org_dtype = x.dtype
R = self.get_weight().to(x.device, dtype=x.dtype)
if x.dim() == 4:
x = x.permute(0, 2, 3, 1)
x = torch.matmul(x, R)
x = x.permute(0, 3, 1, 2)
else:
x = torch.matmul(x, R)
return x
R = self.get_weight().to(torch.float32)
W = org_module.weight.to(torch.float32)
if len(W.shape) == 4: # Conv2d
W_reshaped = einops.rearrange(W, "(k n) ... -> k n ...", k=self.num_blocks, n=self.block_size)
RW = torch.einsum("k n m, k n ... -> k m ...", R, W_reshaped)
RW = einops.rearrange(RW, "k m ... -> (k m) ...")
result = F.conv2d(
x, RW.to(org_dtype), org_module.bias, org_module.stride, org_module.padding, org_module.dilation, org_module.groups
)
else: # Linear
W_reshaped = einops.rearrange(W, "(k n) m -> k n m", k=self.num_blocks, n=self.block_size)
RW = torch.einsum("k n m, k n p -> k m p", R, W_reshaped)
RW = einops.rearrange(RW, "k m p -> (k m) p")
result = F.linear(x, RW.to(org_dtype), org_module.bias)
return result
class OFTInfModule(OFTModule):
@@ -115,18 +133,19 @@ class OFTInfModule(OFTModule):
return self.org_forward(x)
return super().forward(x, scale)
def merge_to(self, multiplier=None, sign=1):
R = self.get_weight(multiplier) * sign
def merge_to(self, multiplier=None):
# get org weight
org_sd = self.org_module[0].state_dict()
org_weight = org_sd["weight"]
R = R.to(org_weight.device, dtype=org_weight.dtype)
org_weight = org_sd["weight"].to(torch.float32)
if org_weight.dim() == 4:
weight = torch.einsum("oihw, op -> pihw", org_weight, R)
else:
weight = torch.einsum("oi, op -> pi", org_weight, R)
R = self.get_weight(multiplier).to(torch.float32)
weight = org_weight.reshape(self.num_blocks, self.block_size, -1)
weight = torch.einsum("k n m, k n ... -> k m ...", R, weight)
weight = weight.reshape(org_weight.shape)
# convert back to original dtype
weight = weight.to(org_sd["weight"].dtype)
# set weight to org_module
org_sd["weight"] = weight
@@ -145,8 +164,16 @@ def create_network(
):
if network_dim is None:
network_dim = 4 # default
if network_alpha is None:
network_alpha = 1.0
if network_alpha is None: # should be set
logger.info(
"network_alpha is not set, use default value 1e-3 / network_alphaが設定されていないのでデフォルト値 1e-3 を使用します"
)
network_alpha = 1e-3
elif network_alpha >= 1:
logger.warning(
"network_alpha is too large (>=1, maybe default value is too large), please consider to set smaller value like 1e-3"
" / network_alphaが大きすぎるようです(>=1, デフォルト値が大きすぎる可能性があります)。1e-3のような小さな値を推奨"
)
enable_all_linear = kwargs.get("enable_all_linear", None)
enable_conv = kwargs.get("enable_conv", None)
@@ -190,12 +217,11 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
else:
if dim is None:
dim = param.size()[0]
if has_conv2d is None and param.dim() == 4:
if has_conv2d is None and "in_layers_2" in name:
has_conv2d = True
if all_linear is None:
if param.dim() == 3 and "attn" not in name:
all_linear = True
if dim is not None and alpha is not None and has_conv2d is not None:
if all_linear is None and "_ff_" in name:
all_linear = True
if dim is not None and alpha is not None and has_conv2d is not None and all_linear is not None:
break
if has_conv2d is None:
has_conv2d = False
@@ -241,7 +267,7 @@ class OFTNetwork(torch.nn.Module):
self.alpha = alpha
logger.info(
f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}"
f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}, enable_all_linear: {enable_all_linear}"
)
# create module instances