mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 21:52:27 +00:00
Compare commits
1 Commits
feature-ch
...
galore
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9de721198a |
798
library/galore_optimizer.py
Normal file
798
library/galore_optimizer.py
Normal file
@@ -0,0 +1,798 @@
|
||||
# copy from https://github.com/jiaweizzhao/GaLore/blob/master/galore_torch/
|
||||
# original license is Apache License 2.0
|
||||
import ast
|
||||
import math
|
||||
import warnings
|
||||
from typing import Callable, Dict, Iterable, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from library import train_util
|
||||
|
||||
from .utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GaLoreProjector:
    """Low-rank gradient projector used by the GaLore optimizers.

    A truncated SVD of the gradient provides an orthogonal matrix that maps
    full-rank 2D gradients into a rank-`rank` subspace (`project`) and back
    (`project_back`). The projector is refreshed every `update_proj_gap` steps.
    """

    def __init__(self, rank, verbose=False, update_proj_gap=200, scale=1.0, proj_type="std"):
        # rank: size of the low-rank subspace
        # update_proj_gap: recompute the SVD projector every this many steps
        # scale: factor applied to the gradient when projecting back
        # proj_type: "std", "reverse_std", "right", "left" or "full"
        self.rank = rank
        self.verbose = verbose
        self.update_proj_gap = update_proj_gap
        self.scale = scale
        self.ortho_matrix = None
        self.proj_type = proj_type

    def project(self, full_rank_grad, iter):
        """Project a 2D gradient into the low-rank subspace.

        The projector is computed on the first call and refreshed whenever
        `iter` is a multiple of `update_proj_gap`.
        """
        refresh = self.ortho_matrix is None or iter % self.update_proj_gap == 0

        if self.proj_type == "std":
            # choose the side so the stored projector is the smaller factor
            if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
                if refresh:
                    self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type="right")
                low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
            else:
                if refresh:
                    self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type="left")
                low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
        elif self.proj_type == "reverse_std":
            if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
                if refresh:
                    self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type="left")
                low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
            else:
                if refresh:
                    self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type="right")
                low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
        elif self.proj_type == "right":
            if refresh:
                self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type="right")
            low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
        elif self.proj_type == "left":
            if refresh:
                self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type="left")
            low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
        elif self.proj_type == "full":
            if refresh:
                self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type="full")
            low_rank_grad = torch.matmul(self.ortho_matrix[0].t(), full_rank_grad) @ self.ortho_matrix[1].t()

        return low_rank_grad

    def project_back(self, low_rank_grad):
        """Map a low-rank gradient back to full rank, scaled by `self.scale`."""
        if self.proj_type == "std":
            if low_rank_grad.shape[0] >= low_rank_grad.shape[1]:
                full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
            else:
                full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
        elif self.proj_type == "reverse_std":
            if low_rank_grad.shape[0] <= low_rank_grad.shape[1]:  # note this is different from std
                full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
            else:
                full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
        elif self.proj_type == "right":
            full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
        elif self.proj_type == "left":
            full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
        elif self.proj_type == "full":
            full_rank_grad = torch.matmul(self.ortho_matrix[0], low_rank_grad) @ self.ortho_matrix[1]

        return full_rank_grad * self.scale

    # svd decomposition
    def get_orthogonal_matrix(self, weights, rank, type):
        """Compute rank-`rank` orthogonal SVD factor(s) of `weights`.

        type="right" returns the top right-singular rows (rank x n),
        type="left" returns the top left-singular columns (m x rank),
        type="full" returns [left, right]. The SVD always runs in fp32;
        the result is cast back to the input dtype/device.
        """
        module_params = weights

        float_data = module_params.data.dtype == torch.float
        if float_data:
            matrix = module_params.data
        else:
            original_type = module_params.data.dtype
            original_device = module_params.data.device
            matrix = module_params.data.float()

        # fix: reduced SVD is sufficient (and much cheaper) since only the
        # leading `rank` singular vectors are kept below
        U, s, Vh = torch.linalg.svd(matrix, full_matrices=False)

        # make the smaller matrix always to be orthogonal matrix
        # (fix: the unused opposite factor is no longer materialized)
        if type == "right":
            B = Vh[:rank, :]
            if not float_data:
                B = B.to(original_device).type(original_type)
            return B
        elif type == "left":
            A = U[:, :rank]
            if not float_data:
                A = A.to(original_device).type(original_type)
            return A
        elif type == "full":
            A = U[:, :rank]
            B = Vh[:rank, :]
            if not float_data:
                A = A.to(original_device).type(original_type)
                B = B.to(original_device).type(original_type)
            return [A, B]
        else:
            raise ValueError("type should be left, right or full")
|
||||
|
||||
|
||||
class GaLoreAdamW(Optimizer):
    """
    AdamW — Adam with decoupled weight decay
    ([Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101)) —
    extended with optional GaLore low-rank gradient projection for parameter
    groups that carry a "rank" key.

    Parameters:
        params (`Iterable[nn.parameter.Parameter]`):
            Iterable of parameters to optimize or dictionaries defining parameter groups.
        lr (`float`, *optional*, defaults to 0.001):
            The learning rate to use.
        betas (`Tuple[float,float]`, *optional*, defaults to `(0.9, 0.999)`):
            Adam's betas parameters (b1, b2).
        eps (`float`, *optional*, defaults to 1e-06):
            Adam's epsilon for numerical stability.
        weight_decay (`float`, *optional*, defaults to 0.0):
            Decoupled weight decay to apply.
        correct_bias (`bool`, *optional*, defaults to `True`):
            Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`).
        no_deprecation_warning (`bool`, *optional*, defaults to `False`):
            A flag used to disable the deprecation warning (set to `True` to disable the warning).
    """

    def __init__(
        self,
        params: Iterable[nn.parameter.Parameter],
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-6,
        weight_decay: float = 0.0,
        correct_bias: bool = True,
        no_deprecation_warning: bool = False,
    ):
        if not no_deprecation_warning:
            warnings.warn(
                "This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch"
                " implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this"
                " warning",
                FutureWarning,
            )
        require_version("torch>=1.5.0")  # add_ with alpha
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
        super().__init__(
            params,
            {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "correct_bias": correct_bias},
        )

    @torch.no_grad()
    def step(self, closure: Callable = None):
        """
        Performs a single optimization step.

        Arguments:
            closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
        """
        loss = closure() if closure is not None else None

        for group in self.param_groups:
            beta1, beta2 = group["betas"]
            for param in group["params"]:
                if param.grad is None:
                    continue
                grad = param.grad
                if grad.is_sparse:
                    raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")

                state = self.state[param]
                if "step" not in state:
                    state["step"] = 0

                # GaLore: replace the gradient by its low-rank projection so the
                # Adam moments are tracked in the projected space
                if "rank" in group:
                    if "projector" not in state:
                        state["projector"] = GaLoreProjector(
                            group["rank"],
                            update_proj_gap=group["update_proj_gap"],
                            scale=group["scale"],
                            proj_type=group["proj_type"],
                        )
                    grad = state["projector"].project(grad, state["step"])

                # Lazy state initialization (shapes follow the possibly-projected grad)
                if "exp_avg" not in state:
                    # Exponential moving averages of gradient and squared gradient
                    state["exp_avg"] = torch.zeros_like(grad)
                    state["exp_avg_sq"] = torch.zeros_like(grad)

                exp_avg = state["exp_avg"]
                exp_avg_sq = state["exp_avg_sq"]

                state["step"] += 1

                # Update biased first and second moment estimates in place
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                denom = exp_avg_sq.sqrt().add_(group["eps"])

                step_size = group["lr"]
                if group["correct_bias"]:  # No bias correction for Bert
                    bias_correction1 = 1.0 - beta1 ** state["step"]
                    bias_correction2 = 1.0 - beta2 ** state["step"]
                    step_size = step_size * math.sqrt(bias_correction2) / bias_correction1

                # Normalized gradient (the actual Adam update direction)
                norm_grad = exp_avg / denom

                # GaLore: map the update back to the full-rank space
                if "rank" in group:
                    norm_grad = state["projector"].project_back(norm_grad)

                param.add_(norm_grad, alpha=-step_size)

                # Decoupled weight decay: applied directly to the weights rather
                # than through the Adam moments (the "fix" of AdamW).
                if group["weight_decay"] > 0.0:
                    param.add_(param, alpha=-group["lr"] * group["weight_decay"])

        return loss
|
||||
|
||||
|
||||
class GaLoreAdafactor(Optimizer):
    """
    Adafactor (*Adafactor: Adaptive Learning Rates with Sublinear Memory Cost*,
    https://arxiv.org/abs/1804.04235) with optional GaLore low-rank gradient
    projection, ported from the fairseq implementation:
    https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py

    Note that this optimizer internally adjusts the learning rate depending on
    the `scale_parameter`, `relative_step` and `warmup_init` options. To use a
    manual (external) learning rate schedule you should set
    `scale_parameter=False` and `relative_step=False`.

    Arguments:
        params (`Iterable[nn.parameter.Parameter]`):
            Iterable of parameters to optimize or dictionaries defining parameter groups.
        lr (`float`, *optional*):
            The external learning rate.
        eps (`Tuple[float, float]`, *optional*, defaults to `(1e-30, 0.001)`):
            Regularization constants for square gradient and parameter scale respectively.
        clip_threshold (`float`, *optional*, defaults to 1.0):
            Threshold of root mean square of final gradient update.
        decay_rate (`float`, *optional*, defaults to -0.8):
            Coefficient used to compute running averages of square.
        beta1 (`float`, *optional*):
            Coefficient used for computing running averages of gradient.
        weight_decay (`float`, *optional*, defaults to 0.0):
            Weight decay (L2 penalty).
        scale_parameter (`bool`, *optional*, defaults to `True`):
            If True, learning rate is scaled by root mean square.
        relative_step (`bool`, *optional*, defaults to `True`):
            If True, time-dependent learning rate is computed instead of external learning rate.
        warmup_init (`bool`, *optional*, defaults to `False`):
            Time-dependent learning rate computation depends on whether warm-up
            initialization is being used.

    This implementation handles low-precision (FP16, bfloat) values, but we
    have not thoroughly tested.

    Recommended T5 finetuning settings
    (https://discuss.huggingface.co/t/t5-finetuning-tips/684/3): training
    without LR warmup or clip_threshold is not recommended; use scheduled LR
    warm-up to a fixed LR, clip_threshold=1.0
    (https://arxiv.org/abs/1804.04235), disable relative updates, use
    scale_parameter=False, and do not combine with additional optimizer
    operations like gradient clipping.

    Example:

    ```python
    # replace AdamW with Adafactor
    optimizer = Adafactor(
        model.parameters(),
        lr=1e-3,
        eps=(1e-30, 1e-3),
        clip_threshold=1.0,
        decay_rate=-0.8,
        beta1=None,
        weight_decay=0.0,
        relative_step=False,
        scale_parameter=False,
        warmup_init=False,
    )
    ```"""

    # make default to be the same as trainer
    def __init__(
        self,
        params,
        lr=None,
        eps=(1e-30, 1e-3),
        clip_threshold=1.0,
        decay_rate=-0.8,
        beta1=None,
        weight_decay=0.0,
        scale_parameter=False,
        relative_step=False,
        warmup_init=False,
    ):
        require_version("torch>=1.5.0")  # add_ with alpha
        if lr is not None and relative_step:
            raise ValueError("Cannot combine manual `lr` and `relative_step=True` options")
        if warmup_init and not relative_step:
            raise ValueError("`warmup_init=True` requires `relative_step=True`")

        super().__init__(
            params,
            {
                "lr": lr,
                "eps": eps,
                "clip_threshold": clip_threshold,
                "decay_rate": decay_rate,
                "beta1": beta1,
                "weight_decay": weight_decay,
                "scale_parameter": scale_parameter,
                "relative_step": relative_step,
                "warmup_init": warmup_init,
            },
        )

    @staticmethod
    def _get_lr(param_group, param_state):
        # Relative-step mode derives the lr from the step count; otherwise the
        # externally supplied lr is used, optionally scaled by the weight RMS.
        step_lr = param_group["lr"]
        if param_group["relative_step"]:
            min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2
            step_lr = min(min_step, 1.0 / math.sqrt(param_state["step"]))
        scale = 1.0
        if param_group["scale_parameter"]:
            scale = max(param_group["eps"][1], param_state["RMS"])
        return scale * step_lr

    @staticmethod
    def _get_options(param_group, param_shape):
        # Factor the second moment only for >=2D params; track the first
        # moment only when beta1 is given.
        return len(param_shape) >= 2, param_group["beta1"] is not None

    @staticmethod
    def _rms(tensor):
        # Root mean square over all elements.
        return tensor.norm(2) / (tensor.numel() ** 0.5)

    @staticmethod
    def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
        # Rank-1 reconstruction of the second moment from its row/col averages.
        # copy from fairseq's adafactor implementation:
        # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505
        row_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
        col_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
        return torch.mul(row_factor, col_factor)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Performs a single optimization step

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = closure() if closure is not None else None

        for group in self.param_groups:
            for param in group["params"]:
                if param.grad is None:
                    continue
                grad = param.grad
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    # all statistics and updates are computed in fp32
                    grad = grad.float()
                if grad.is_sparse:
                    raise RuntimeError("Adafactor does not support sparse gradients.")

                state = self.state[param]

                if "step" not in state:
                    state["step"] = 0

                # GaLore: project the gradient first so Adafactor's factored
                # statistics live in the low-rank space
                if "rank" in group:
                    if "projector" not in state:
                        state["projector"] = GaLoreProjector(
                            group["rank"],
                            update_proj_gap=group["update_proj_gap"],
                            scale=group["scale"],
                            proj_type=group["proj_type"],
                        )
                    grad = state["projector"].project(grad, state["step"])

                grad_shape = grad.shape
                factored, use_first_moment = self._get_options(group, grad_shape)

                # Lazy state initialization, or re-alignment of stored state
                # to the gradient's dtype/device
                if "RMS" not in state:
                    state["step"] = 0
                    if use_first_moment:
                        # Exponential moving average of gradient values
                        state["exp_avg"] = torch.zeros_like(grad)
                    if factored:
                        state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
                        state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
                    else:
                        state["exp_avg_sq"] = torch.zeros_like(grad)
                    state["RMS"] = 0
                else:
                    if use_first_moment:
                        state["exp_avg"] = state["exp_avg"].to(grad)
                    if factored:
                        state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
                        state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
                    else:
                        state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)

                p_data_fp32 = param
                if param.dtype in {torch.float16, torch.bfloat16}:
                    p_data_fp32 = p_data_fp32.float()

                state["step"] += 1
                state["RMS"] = self._rms(p_data_fp32)
                lr = self._get_lr(group, state)

                beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
                update = (grad**2) + group["eps"][0]
                if factored:
                    exp_avg_sq_row = state["exp_avg_sq_row"]
                    exp_avg_sq_col = state["exp_avg_sq_col"]

                    exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=1.0 - beta2t)
                    exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=1.0 - beta2t)

                    # Approximation of exponential moving average of square of gradient
                    update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
                    update.mul_(grad)
                else:
                    exp_avg_sq = state["exp_avg_sq"]
                    exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t)
                    update = exp_avg_sq.rsqrt().mul_(grad)

                # clip by the RMS of the update, then scale by the lr
                update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
                update.mul_(lr)

                if use_first_moment:
                    exp_avg = state["exp_avg"]
                    exp_avg.mul_(group["beta1"]).add_(update, alpha=1 - group["beta1"])
                    update = exp_avg

                # GaLore: map the update back to the full-rank space
                if "rank" in group:
                    update = state["projector"].project_back(update)

                if group["weight_decay"] != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * lr)

                p_data_fp32.add_(-update)

                if param.dtype in {torch.float16, torch.bfloat16}:
                    param.copy_(p_data_fp32)

        return loss
|
||||
|
||||
|
||||
try:
    from bitsandbytes.optim.optimizer import Optimizer2State
except ImportError:
    # bitsandbytes is optional: provide a stand-in that fails loudly on use.
    class Optimizer2State(Optimizer):
        # Constructing the stub always fails so callers get a clear message.
        def __init__(self, *args, **kwargs):
            raise ImportError("Please install bitsandbytes to use this optimizer")

        # The remaining no-op methods mirror the bitsandbytes Optimizer2State
        # interface that GaLoreAdamW8bit.step() relies on.
        def step(self, *args, **kwargs):
            pass

        def prefetch_state(self, *args, **kwargs):
            pass

        def init_state(self, *args, **kwargs):
            pass

        def update_step(self, *args, **kwargs):
            pass

        def check_overrides(self, *args, **kwargs):
            pass

        def to_gpu(self, *args, **kwargs):
            pass

        def to_cpu(self, *args, **kwargs):
            pass
|
||||
|
||||
|
||||
class GaLoreAdamW8bit(Optimizer2State):
    """8-bit AdamW (bitsandbytes `Optimizer2State`) with optional GaLore
    low-rank gradient projection for parameter groups that carry a "rank" key.

    For GaLore groups, the parameter's data/grad are temporarily swapped for
    the projected gradient so the 8-bit update runs in the low-rank space,
    then the update is projected back and added to the saved full-rank weight.
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=1e-2,
        amsgrad=False,
        optim_bits=32,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
        is_paged=False,
    ):
        # optimizer name "adam" with 8 bits; see bitsandbytes Optimizer2State
        super().__init__(
            "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged
        )

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # NOTE: the unused `overflows` accumulator from the original
        # bitsandbytes template was removed.
        if not self.initialized:
            self.check_overrides()
            self.to_gpu()  # needed for fairseq pure fp16 training
            self.initialized = True

        # if self.is_paged: self.page_mng.prefetch_all()
        for gindex, group in enumerate(self.param_groups):
            for pindex, p in enumerate(group["params"]):
                if p.grad is None:
                    continue
                state = self.state[p]

                if "step" not in state:
                    state["step"] = 0

                # GaLore Projection
                if "rank" in group:
                    if "projector" not in state:
                        state["projector"] = GaLoreProjector(
                            group["rank"],
                            update_proj_gap=group["update_proj_gap"],
                            scale=group["scale"],
                            proj_type=group["proj_type"],
                        )

                    if "weight_decay" in group and group["weight_decay"] > 0:
                        # ensure that the weight decay is not applied to the norm grad
                        group["weight_decay_saved"] = group["weight_decay"]
                        group["weight_decay"] = 0

                    grad = state["projector"].project(p.grad, state["step"])

                    # suboptimal implementation: temporarily replace p.data with a
                    # zero tensor of the projected shape; the real weight is kept
                    # in p.saved_data and restored after the 8-bit update
                    p.saved_data = p.data.clone()
                    p.data = grad.clone().to(p.data.dtype).to(p.data.device)
                    p.data.zero_()
                    p.grad = grad

                if "state1" not in state:
                    self.init_state(group, p, gindex, pindex)

                self.prefetch_state(p)
                self.update_step(group, p, gindex, pindex)
                torch.cuda.synchronize()

                # GaLore Projection Back: restore the full-rank weight and add the update
                if "rank" in group:
                    p.data = p.saved_data.add_(state["projector"].project_back(p.data))

                # apply weight decay that was deferred above
                if "weight_decay_saved" in group:
                    p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay_saved"])
                    group["weight_decay"] = group["weight_decay_saved"]
                    del group["weight_decay_saved"]

        if self.is_paged:
            # all paged operation are asynchronous, we need
            # to sync to make sure all tensors are in the right state
            torch.cuda.synchronize()

        return loss
|
||||
|
||||
|
||||
def get_optimizer(args, optimizer_type, trainable_params, training_models, num_processes):
    """Build a GaLore optimizer (and, for the layer-wise variant, a dummy scheduler).

    trainable_params is a list of dicts, each containing "params" and "lr".
    The list may contain multiple dicts: [unet] or [unet, te1] or [unet, te1, te2].
    Block lr is not supported.

    Returns (optimizer, scheduler). `scheduler` is None except for
    "galore_adamw8bit_per_layer", which registers per-parameter optimizer hooks
    and returns dummy optimizer/scheduler objects for the training loop.
    """
    assert len(trainable_params) == len(training_models), "block lr is not supported"

    lr = args.learning_rate

    # parse "key=value" strings from the command line
    optimizer_kwargs = {}
    if args.optimizer_args is not None and len(args.optimizer_args) > 0:
        for arg in args.optimizer_args:
            key, value = arg.split("=")
            optimizer_kwargs[key] = ast.literal_eval(value)

    rank = optimizer_kwargs.pop("rank", 128)
    update_proj_gap = optimizer_kwargs.pop("update_proj_gap", 50)
    galore_scale = optimizer_kwargs.pop("galore_scale", 1.0)
    proj_type = optimizer_kwargs.pop("proj_type", "std")
    weight_decay = optimizer_kwargs.get("weight_decay", 0.0)  # do not pop, as it is used in the optimizer

    # make parameters with "rank" to a single group, if param_name has "mlp" or "attn"
    # target_modules_list = ["attn", "mlp"]
    target_modules_list = ["attn", "mlp", "ff"]  # for SDXL U-Net

    param_groups = []
    param_lr = {}
    # fix: collect GaLore param ids across *all* models (previously only the
    # last model's ids were visible to the per-layer branch below, and the
    # name was undefined when trainable_params was empty)
    id_galore_params = []
    for model, params in zip(training_models, trainable_params):
        logger.info(f"model: {model.__class__.__name__}")
        galore_params = []
        group_lr = params.get("lr", lr)

        for module_name, module in model.named_modules():
            if not isinstance(module, nn.Linear):
                continue
            if not any(target_key in module_name for target_key in target_modules_list):
                continue

            logger.info("enable GaLore for weights in module: " + module_name)
            galore_params.append(module.weight)

        model_galore_ids = [id(p) for p in galore_params]
        id_galore_params.extend(model_galore_ids)
        # make parameters without "rank" to another group
        regular_params = [p for p in params["params"] if id(p) not in model_galore_ids]

        # then call galore_adamw
        param_groups.append({"params": regular_params, "lr": group_lr})
        param_groups.append(
            {
                "params": galore_params,
                "rank": rank,
                "update_proj_gap": update_proj_gap,
                "scale": galore_scale,
                "proj_type": proj_type,
                "lr": group_lr,
            }
        )

        # record lr per parameter for the layer-wise variant
        for p in regular_params + galore_params:
            param_lr[id(p)] = group_lr

    # select optimizer
    scheduler = None
    if optimizer_type == "galore_adamw":
        optimizer = GaLoreAdamW(param_groups, lr=lr, **optimizer_kwargs)
    elif optimizer_type == "galore_adafactor":
        # fix: pop beta1 unconditionally so it is never passed both explicitly
        # and via **optimizer_kwargs (which raised a duplicate-keyword TypeError
        # when beta1=0.0 was given); 0.0 is still normalized to None
        beta1 = optimizer_kwargs.pop("beta1", None)
        if beta1 == 0.0:
            beta1 = None
        optimizer = GaLoreAdafactor(param_groups, lr=lr, beta1=beta1, **optimizer_kwargs)
    elif optimizer_type == "galore_adamw8bit":
        optimizer = GaLoreAdamW8bit(param_groups, lr=lr, **optimizer_kwargs)
    elif optimizer_type == "galore_adamw8bit_per_layer":
        # TODO: seems scheduler call twice in one update step, need to check, for now double the num_training_steps, warmup_steps and update_proj_gap
        optimizer_dict = {}
        all_params = []
        for params in trainable_params:
            all_params.extend(params["params"])

        for p in all_params:
            if not p.requires_grad:
                continue
            if id(p) in id_galore_params:
                optimizer_dict[p] = GaLoreAdamW8bit(
                    [
                        {
                            "params": [p],
                            "rank": rank,
                            "update_proj_gap": update_proj_gap * 2,
                            "scale": galore_scale,
                            "proj_type": proj_type,
                        }
                    ],
                    lr=param_lr[id(p)],
                    weight_decay=weight_decay,
                )
            else:
                import bitsandbytes as bnb

                optimizer_dict[p] = bnb.optim.Adam8bit([p], lr=param_lr[id(p)], weight_decay=weight_decay)

        # get scheduler dict
        # scheduler needs accelerate.prepare?
        scheduler_dict = {}
        for p in all_params:
            if p.requires_grad:
                scheduler_dict[p] = train_util.get_scheduler_fix(args, optimizer_dict[p], num_processes)

        def optimizer_hook(p):
            # runs after gradient accumulation for p: step + zero grad + lr schedule
            if p.grad is None:
                return
            optimizer_dict[p].step()
            optimizer_dict[p].zero_grad()
            scheduler_dict[p].step()

        # Register the hook onto every parameter
        for p in all_params:
            if p.requires_grad:
                p.register_post_accumulate_grad_hook(optimizer_hook)

        # make dummy scheduler and optimizer so the caller's training loop
        # can keep calling step()/zero_grad() as usual
        class DummyScheduler:
            def __init__(self, *args, **kwargs):
                # fix: accept (and ignore) the optimizer argument passed below;
                # previously no __init__ existed and construction raised TypeError
                pass

            def step(self):
                pass

        class DummyOptimizer:
            def __init__(self, optimizer_dict):
                self.optimizer_dict = optimizer_dict

            def step(self):
                pass

            def zero_grad(self, set_to_none=False):
                pass

        scheduler = DummyScheduler(optimizer_dict[all_params[0]])
        optimizer = DummyOptimizer(optimizer_dict)

    else:
        raise ValueError(f"Unsupported optimizer type: {optimizer_type}")

    return optimizer, scheduler
|
||||
@@ -3671,7 +3671,7 @@ def get_optimizer(args, trainable_params):
|
||||
optimizer_class = lion_pytorch.Lion
|
||||
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
||||
|
||||
elif optimizer_type.endswith("8bit".lower()):
|
||||
elif optimizer_type.endswith("8bit".lower()) and not optimizer_type.startswith("GaLore".lower()):
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
except ImportError:
|
||||
@@ -3880,6 +3880,11 @@ def get_optimizer(args, trainable_params):
|
||||
optimizer_class = torch.optim.AdamW
|
||||
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
||||
|
||||
elif optimizer_type.startswith("GaLore".lower()):
|
||||
logger.info(f"use GaLore optimizer | {optimizer_kwargs}")
|
||||
optimizer = "galore"
|
||||
return None, None, optimizer
|
||||
|
||||
if optimizer is None:
|
||||
# 任意のoptimizerを使う
|
||||
optimizer_type = args.optimizer_type # lowerでないやつ(微妙)
|
||||
|
||||
@@ -11,6 +11,7 @@ from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
init_ipex()
|
||||
|
||||
from accelerate.utils import set_seed
|
||||
@@ -378,7 +379,17 @@ def train(args):
|
||||
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
||||
|
||||
# lr schedulerを用意する
|
||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||
lr_scheduler = None
|
||||
if optimizer == "galore":
|
||||
from library import galore_optimizer
|
||||
|
||||
# if lr_scheduler is not layerwise, it is None. if layerwise, it is a dummy scheduler
|
||||
optimizer, lr_scheduler = galore_optimizer.get_optimizer(
|
||||
args, args.optimizer_type, params_to_optimize, training_models, accelerator.num_processes
|
||||
)
|
||||
|
||||
if lr_scheduler is None:
|
||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||
|
||||
# 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
|
||||
if args.full_fp16:
|
||||
|
||||
Reference in New Issue
Block a user