galore impl

Kohya S
2024-03-07 23:31:57 +09:00
parent 14c9372a38
commit 9de721198a
3 changed files with 816 additions and 2 deletions

library/galore_optimizer.py Normal file

@@ -0,0 +1,798 @@
# copy from https://github.com/jiaweizzhao/GaLore/blob/master/galore_torch/
# original license is Apache License 2.0
import ast
import math
import warnings
from typing import Callable, Dict, Iterable, Tuple
import torch
from torch import nn
from torch.optim import Optimizer
from transformers.utils.versions import require_version
from library import train_util
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class GaLoreProjector:
def __init__(self, rank, verbose=False, update_proj_gap=200, scale=1.0, proj_type="std"):
self.rank = rank
self.verbose = verbose
self.update_proj_gap = update_proj_gap
self.scale = scale
self.ortho_matrix = None
self.proj_type = proj_type
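# project(): maps the full-rank gradient into a rank-`rank` subspace; the orthogonal
# basis is recomputed via SVD every `update_proj_gap` steps. With proj_type "std" the
# projection side is chosen from the gradient's shape so the stored matrix stays small.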
def project(self, full_rank_grad, iter):
if self.proj_type == "std":
if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type="right")
low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
else:
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type="left")
low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
elif self.proj_type == "reverse_std":
if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type="left")
low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
else:
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type="right")
low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
elif self.proj_type == "right":
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type="right")
low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
elif self.proj_type == "left":
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type="left")
low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
elif self.proj_type == "full":
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type="full")
low_rank_grad = torch.matmul(self.ortho_matrix[0].t(), full_rank_grad) @ self.ortho_matrix[1].t()
return low_rank_grad
def project_back(self, low_rank_grad):
if self.proj_type == "std":
if low_rank_grad.shape[0] >= low_rank_grad.shape[1]:
full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
else:
full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
elif self.proj_type == "reverse_std":
if low_rank_grad.shape[0] <= low_rank_grad.shape[1]: # note this is different from std
full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
else:
full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
elif self.proj_type == "right":
full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
elif self.proj_type == "left":
full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
elif self.proj_type == "full":
full_rank_grad = torch.matmul(self.ortho_matrix[0], low_rank_grad) @ self.ortho_matrix[1]
return full_rank_grad * self.scale
# svd decomposition
def get_orthogonal_matrix(self, weights, rank, type):
module_params = weights
if module_params.data.dtype != torch.float:
float_data = False
original_type = module_params.data.dtype
original_device = module_params.data.device
matrix = module_params.data.float()
else:
float_data = True
matrix = module_params.data
U, s, Vh = torch.linalg.svd(matrix)
# always make the smaller factor the orthogonal matrix
if type == "right":
A = U[:, :rank] @ torch.diag(s[:rank])
B = Vh[:rank, :]
if not float_data:
B = B.to(original_device).type(original_type)
return B
elif type == "left":
A = U[:, :rank]
B = torch.diag(s[:rank]) @ Vh[:rank, :]
if not float_data:
A = A.to(original_device).type(original_type)
return A
elif type == "full":
A = U[:, :rank]
B = Vh[:rank, :]
if not float_data:
A = A.to(original_device).type(original_type)
B = B.to(original_device).type(original_type)
return [A, B]
else:
raise ValueError("type should be left, right or full")
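# illustrative usage sketch (not called from this file; `grad` and `step` stand in for
# a 2-D gradient tensor and the current optimizer step):
#   projector = GaLoreProjector(rank=128, update_proj_gap=200, scale=1.0, proj_type="std")
#   low_rank_grad = projector.project(grad, step)              # project into the rank-128 subspace
#   full_rank_update = projector.project_back(low_rank_grad)   # map back to full rank, scaled by `scale`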
class GaLoreAdamW(Optimizer):
"""
Implements Adam algorithm with weight decay fix as introduced in [Decoupled Weight Decay
Regularization](https://arxiv.org/abs/1711.05101).
Parameters:
params (`Iterable[nn.parameter.Parameter]`):
Iterable of parameters to optimize or dictionaries defining parameter groups.
lr (`float`, *optional*, defaults to 0.001):
The learning rate to use.
betas (`Tuple[float,float]`, *optional*, defaults to `(0.9, 0.999)`):
Adam's betas parameters (b1, b2).
eps (`float`, *optional*, defaults to 1e-06):
Adam's epsilon for numerical stability.
weight_decay (`float`, *optional*, defaults to 0.0):
Decoupled weight decay to apply.
correct_bias (`bool`, *optional*, defaults to `True`):
Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`).
no_deprecation_warning (`bool`, *optional*, defaults to `False`):
A flag used to disable the deprecation warning (set to `True` to disable the warning).
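Example (an illustrative sketch; `regular_params` and `galore_params` are placeholder parameter
lists, and the GaLore-specific group keys `rank`, `update_proj_gap`, `scale` and `proj_type`
enable low-rank projection for that group):
```python
optimizer = GaLoreAdamW(
[
{"params": regular_params},
{"params": galore_params, "rank": 128, "update_proj_gap": 200, "scale": 1.0, "proj_type": "std"},
],
lr=1e-3,
weight_decay=0.0,
)
```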
"""
def __init__(
self,
params: Iterable[nn.parameter.Parameter],
lr: float = 1e-3,
betas: Tuple[float, float] = (0.9, 0.999),
eps: float = 1e-6,
weight_decay: float = 0.0,
correct_bias: bool = True,
no_deprecation_warning: bool = False,
):
if not no_deprecation_warning:
warnings.warn(
"This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch"
" implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this"
" warning",
FutureWarning,
)
require_version("torch>=1.5.0") # add_ with alpha
if lr < 0.0:
raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
if not 0.0 <= betas[0] < 1.0:
raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)")
if not 0.0 <= betas[1] < 1.0:
raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
defaults = {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "correct_bias": correct_bias}
super().__init__(params, defaults)
@torch.no_grad()
def step(self, closure: Callable = None):
"""
Performs a single optimization step.
Arguments:
closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad
if grad.is_sparse:
raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")
state = self.state[p]
if "step" not in state:
state["step"] = 0
# GaLore Projection
if "rank" in group:
if "projector" not in state:
state["projector"] = GaLoreProjector(
group["rank"],
update_proj_gap=group["update_proj_gap"],
scale=group["scale"],
proj_type=group["proj_type"],
)
grad = state["projector"].project(grad, state["step"])
# State initialization
if "exp_avg" not in state:
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(grad)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(grad)
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
beta1, beta2 = group["betas"]
state["step"] += 1
# Decay the first and second moment running average coefficient
# In-place operations to update the averages at the same time
exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
denom = exp_avg_sq.sqrt().add_(group["eps"])
step_size = group["lr"]
if group["correct_bias"]: # No bias correction for Bert
bias_correction1 = 1.0 - beta1 ** state["step"]
bias_correction2 = 1.0 - beta2 ** state["step"]
step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
# compute norm gradient
norm_grad = exp_avg / denom
# GaLore Projection Back
if "rank" in group:
norm_grad = state["projector"].project_back(norm_grad)
p.add_(norm_grad, alpha=-step_size)
# Just adding the square of the weights to the loss function is *not*
# the correct way of using L2 regularization/weight decay with Adam,
# since that will interact with the m and v parameters in strange ways.
#
# Instead we want to decay the weights in a manner that doesn't interact
# with the m/v parameters. This is equivalent to adding the square
# of the weights to the loss with plain (non-momentum) SGD.
# Add weight decay at the end (fixed version)
if group["weight_decay"] > 0.0:
p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
return loss
class GaLoreAdafactor(Optimizer):
"""
AdaFactor PyTorch implementation that can be used as a drop-in replacement for Adam. Original fairseq code:
https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py
Paper: *Adafactor: Adaptive Learning Rates with Sublinear Memory Cost* https://arxiv.org/abs/1804.04235 Note that
this optimizer internally adjusts the learning rate depending on the `scale_parameter`, `relative_step` and
`warmup_init` options. To use a manual (external) learning rate schedule you should set `scale_parameter=False` and
`relative_step=False`.
Arguments:
params (`Iterable[nn.parameter.Parameter]`):
Iterable of parameters to optimize or dictionaries defining parameter groups.
lr (`float`, *optional*):
The external learning rate.
eps (`Tuple[float, float]`, *optional*, defaults to `(1e-30, 0.001)`):
Regularization constants for square gradient and parameter scale respectively
clip_threshold (`float`, *optional*, defaults to 1.0):
Threshold of root mean square of final gradient update
decay_rate (`float`, *optional*, defaults to -0.8):
Coefficient used to compute running averages of square
beta1 (`float`, *optional*):
Coefficient used for computing running averages of gradient
weight_decay (`float`, *optional*, defaults to 0.0):
Weight decay (L2 penalty)
scale_parameter (`bool`, *optional*, defaults to `True`):
If True, learning rate is scaled by root mean square
relative_step (`bool`, *optional*, defaults to `True`):
If True, time-dependent learning rate is computed instead of external learning rate
warmup_init (`bool`, *optional*, defaults to `False`):
Time-dependent learning rate computation depends on whether warm-up initialization is being used
This implementation handles low-precision (FP16, bfloat16) values, but it has not been thoroughly tested.
Recommended T5 finetuning settings (https://discuss.huggingface.co/t/t5-finetuning-tips/684/3):
- Training without LR warmup or clip_threshold is not recommended.
- Use a scheduled LR warm-up to a fixed LR.
- Use clip_threshold=1.0 (https://arxiv.org/abs/1804.04235).
- Disable relative updates.
- Use scale_parameter=False.
- Additional optimizer operations like gradient clipping should not be used alongside Adafactor.
Example:
```python
Adafactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False, lr=1e-3)
```
Others reported the following combination to work well:
```python
Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
```
When using `lr=None` with [`Trainer`] you will most likely need to use [`~optimization.AdafactorSchedule`]
scheduler as following:
```python
from transformers.optimization import Adafactor, AdafactorSchedule
optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
lr_scheduler = AdafactorSchedule(optimizer)
trainer = Trainer(..., optimizers=(optimizer, lr_scheduler))
```
Usage:
```python
# replace AdamW with Adafactor
optimizer = Adafactor(
model.parameters(),
lr=1e-3,
eps=(1e-30, 1e-3),
clip_threshold=1.0,
decay_rate=-0.8,
beta1=None,
weight_decay=0.0,
relative_step=False,
scale_parameter=False,
warmup_init=False,
)
```"""
# make the defaults the same as those used by the Trainer
def __init__(
self,
params,
lr=None,
eps=(1e-30, 1e-3),
clip_threshold=1.0,
decay_rate=-0.8,
beta1=None,
weight_decay=0.0,
scale_parameter=False,
relative_step=False,
warmup_init=False,
):
# scale_parameter=True,
# relative_step=True,
require_version("torch>=1.5.0") # add_ with alpha
if lr is not None and relative_step:
raise ValueError("Cannot combine manual `lr` and `relative_step=True` options")
if warmup_init and not relative_step:
raise ValueError("`warmup_init=True` requires `relative_step=True`")
defaults = {
"lr": lr,
"eps": eps,
"clip_threshold": clip_threshold,
"decay_rate": decay_rate,
"beta1": beta1,
"weight_decay": weight_decay,
"scale_parameter": scale_parameter,
"relative_step": relative_step,
"warmup_init": warmup_init,
}
super().__init__(params, defaults)
@staticmethod
def _get_lr(param_group, param_state):
rel_step_sz = param_group["lr"]
if param_group["relative_step"]:
min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2
rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"]))
param_scale = 1.0
if param_group["scale_parameter"]:
param_scale = max(param_group["eps"][1], param_state["RMS"])
return param_scale * rel_step_sz
@staticmethod
def _get_options(param_group, param_shape):
factored = len(param_shape) >= 2
use_first_moment = param_group["beta1"] is not None
return factored, use_first_moment
@staticmethod
def _rms(tensor):
return tensor.norm(2) / (tensor.numel() ** 0.5)
@staticmethod
def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
# copied from fairseq's adafactor implementation, via transformers:
# https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505
r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
return torch.mul(r_factor, c_factor)
@torch.no_grad()
def step(self, closure=None):
"""
Performs a single optimization step
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad
if grad.dtype in {torch.float16, torch.bfloat16}:
grad = grad.float()
if grad.is_sparse:
raise RuntimeError("Adafactor does not support sparse gradients.")
state = self.state[p]
if "step" not in state:
state["step"] = 0
# GaLore Projection
if "rank" in group:
if "projector" not in state:
state["projector"] = GaLoreProjector(
group["rank"],
update_proj_gap=group["update_proj_gap"],
scale=group["scale"],
proj_type=group["proj_type"],
)
grad = state["projector"].project(grad, state["step"])
grad_shape = grad.shape
factored, use_first_moment = self._get_options(group, grad_shape)
# State Initialization
if "RMS" not in state:
state["step"] = 0
if use_first_moment:
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(grad)
if factored:
state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
else:
state["exp_avg_sq"] = torch.zeros_like(grad)
state["RMS"] = 0
else:
if use_first_moment:
state["exp_avg"] = state["exp_avg"].to(grad)
if factored:
state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
else:
state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
p_data_fp32 = p
if p.dtype in {torch.float16, torch.bfloat16}:
p_data_fp32 = p_data_fp32.float()
state["step"] += 1
state["RMS"] = self._rms(p_data_fp32)
lr = self._get_lr(group, state)
beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
update = (grad**2) + group["eps"][0]
if factored:
exp_avg_sq_row = state["exp_avg_sq_row"]
exp_avg_sq_col = state["exp_avg_sq_col"]
exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))
# Approximation of exponential moving average of square of gradient
update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
update.mul_(grad)
else:
exp_avg_sq = state["exp_avg_sq"]
exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
update = exp_avg_sq.rsqrt().mul_(grad)
update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
update.mul_(lr)
if use_first_moment:
exp_avg = state["exp_avg"]
exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
update = exp_avg
# GaLore Projection Back
if "rank" in group:
update = state["projector"].project_back(update)
if group["weight_decay"] != 0:
p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))
p_data_fp32.add_(-update)
if p.dtype in {torch.float16, torch.bfloat16}:
p.copy_(p_data_fp32)
return loss
try:
from bitsandbytes.optim.optimizer import Optimizer2State
except ImportError:
# define a dummy Optimizer2State class
class Optimizer2State(Optimizer):
def __init__(self, *args, **kwargs):
raise ImportError("Please install bitsandbytes to use this optimizer")
def step(self, *args, **kwargs):
pass
def prefetch_state(self, *args, **kwargs):
pass
def init_state(self, *args, **kwargs):
pass
def update_step(self, *args, **kwargs):
pass
def check_overrides(self, *args, **kwargs):
pass
def to_gpu(self, *args, **kwargs):
pass
def to_cpu(self, *args, **kwargs):
pass
class GaLoreAdamW8bit(Optimizer2State):
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=1e-2,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
super().__init__(
"adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged
)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
overflows = []
if not self.initialized:
self.check_overrides()
self.to_gpu() # needed for fairseq pure fp16 training
self.initialized = True
# if self.is_paged: self.page_mng.prefetch_all()
for gindex, group in enumerate(self.param_groups):
for pindex, p in enumerate(group["params"]):
if p.grad is None:
continue
state = self.state[p]
if "step" not in state:
state["step"] = 0
# GaLore Projection
if "rank" in group:
if "projector" not in state:
state["projector"] = GaLoreProjector(
group["rank"],
update_proj_gap=group["update_proj_gap"],
scale=group["scale"],
proj_type=group["proj_type"],
)
if "weight_decay" in group and group["weight_decay"] > 0:
# ensure that the weight decay is not applied to the norm grad
group["weight_decay_saved"] = group["weight_decay"]
group["weight_decay"] = 0
grad = state["projector"].project(p.grad, state["step"])
# suboptimal implementation
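# (the parameter tensor is temporarily replaced by a zeroed tensor of the projected
# gradient's shape so that the 8-bit state buffers are created and updated in the
# low-rank space; the original weights are restored from `saved_data` further below)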
p.saved_data = p.data.clone()
p.data = grad.clone().to(p.data.dtype).to(p.data.device)
p.data.zero_()
p.grad = grad
if "state1" not in state:
self.init_state(group, p, gindex, pindex)
self.prefetch_state(p)
self.update_step(group, p, gindex, pindex)
torch.cuda.synchronize()
# GaLore Projection Back
if "rank" in group:
p.data = p.saved_data.add_(state["projector"].project_back(p.data))
# apply weight decay
if "weight_decay_saved" in group:
p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay_saved"])
group["weight_decay"] = group["weight_decay_saved"]
del group["weight_decay_saved"]
if self.is_paged:
# all paged operations are asynchronous; we need
# to sync to make sure all tensors are in the right state
torch.cuda.synchronize()
return loss
def get_optimizer(args, optimizer_type, trainable_params, training_models, num_processes):
# trainable_params is a list of dicts; each dict contains "params" and "lr"
# the list may contain multiple dicts: [unet] or [unet, te1] or [unet, te1, te2]
# block lr is not supported
assert len(trainable_params) == len(training_models), "block lr is not supported"
lr = args.learning_rate
optimizer_kwargs = {}
if args.optimizer_args is not None and len(args.optimizer_args) > 0:
for arg in args.optimizer_args:
key, value = arg.split("=")
value = ast.literal_eval(value)
optimizer_kwargs[key] = value
rank = optimizer_kwargs.pop("rank", 128)
update_proj_gap = optimizer_kwargs.pop("update_proj_gap", 50)
galore_scale = optimizer_kwargs.pop("galore_scale", 1.0)
proj_type = optimizer_kwargs.pop("proj_type", "std")
weight_decay = optimizer_kwargs.get("weight_decay", 0.0) # do not pop, as it is used in the optimizer
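# illustrative example (placeholder values): the GaLore settings can be passed via
# --optimizer_args, e.g. --optimizer_args "rank=128" "update_proj_gap=50" "galore_scale=1.0" "weight_decay=0.01"
# string values such as proj_type must be quoted Python literals (e.g. 'proj_type="std"')
# because each value is parsed with ast.literal_eval above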
# put parameters that will use GaLore ("rank") into their own group: Linear weights whose module name matches target_modules_list
# target_modules_list = ["attn", "mlp"]
target_modules_list = ["attn", "mlp", "ff"] # for SDXL U-Net
param_groups = []
param_lr = {}
id_galore_params = []  # accumulated across all models, used by the per-layer optimizer below
for model, params in zip(training_models, trainable_params):
logger.info(f"model: {model.__class__.__name__}")
galore_params = []
group_lr = params.get("lr", lr)
for module_name, module in model.named_modules():
if not isinstance(module, nn.Linear):
continue
if not any(target_key in module_name for target_key in target_modules_list):
continue
logger.info("enable GaLore for weights in module: " + module_name)
galore_params.append(module.weight)
id_galore_params += [id(p) for p in galore_params]  # accumulate so the per-layer branch sees all models' GaLore params
# put parameters without "rank" into another group
regular_params = [p for p in params["params"] if id(p) not in id_galore_params]
# these groups are then passed to the GaLore optimizer
param_groups.append({"params": regular_params, "lr": group_lr})
param_groups.append(
{
"params": galore_params,
"rank": rank,
"update_proj_gap": update_proj_gap,
"scale": galore_scale,
"proj_type": proj_type,
"lr": group_lr,
}
)
# record lr
for p in regular_params + galore_params:
param_lr[id(p)] = group_lr
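# resulting structure (illustrative): two groups are appended per model,
#   {"params": [<non-GaLore trainable params>], "lr": group_lr}
#   {"params": [<attn/mlp/ff Linear weights>], "rank": rank, "update_proj_gap": update_proj_gap,
#    "scale": galore_scale, "proj_type": proj_type, "lr": group_lr}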
# select optimizer
scheduler = None
if optimizer_type == "galore_adamw":
optimizer = GaLoreAdamW(param_groups, lr=lr, **optimizer_kwargs)
elif optimizer_type == "galore_adafactor":
beta1 = optimizer_kwargs.pop("beta1", None) or None  # treat 0.0 as None (disable the first moment)
optimizer = GaLoreAdafactor(param_groups, lr=lr, beta1=beta1, **optimizer_kwargs)
elif optimizer_type == "galore_adamw8bit":
optimizer = GaLoreAdamW8bit(param_groups, lr=lr, **optimizer_kwargs)
elif optimizer_type == "galore_adamw8bit_per_layer":
# TODO: the scheduler seems to be called twice per update step; this needs checking. For now, double num_training_steps, warmup_steps and update_proj_gap
optimizer_dict = {}
all_params = []
for params in trainable_params:
all_params.extend(params["params"])
for p in all_params:
if p.requires_grad:
if id(p) in id_galore_params:
optimizer_dict[p] = GaLoreAdamW8bit(
[
{
"params": [p],
"rank": rank,
"update_proj_gap": update_proj_gap * 2,
"scale": galore_scale,
"proj_type": proj_type,
}
],
lr=param_lr[id(p)],
weight_decay=weight_decay,
)
else:
import bitsandbytes as bnb
optimizer_dict[p] = bnb.optim.Adam8bit([p], lr=param_lr[id(p)], weight_decay=weight_decay)
# get scheduler dict
# scheduler needs accelerate.prepare?
scheduler_dict = {}
for p in all_params:
if p.requires_grad:
scheduler_dict[p] = train_util.get_scheduler_fix(args, optimizer_dict[p], num_processes)
def optimizer_hook(p):
if p.grad is None:
return
optimizer_dict[p].step()
optimizer_dict[p].zero_grad()
scheduler_dict[p].step()
# Register the hook onto every parameter
for p in all_params:
if p.requires_grad:
p.register_post_accumulate_grad_hook(optimizer_hook)
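# each parameter is now stepped by its own optimizer (and scheduler) as soon as its
# gradient has been accumulated, so the training loop's optimizer.step() / zero_grad()
# and scheduler.step() must become no-ops; hence the dummy classes below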
# make dummy scheduler and optimizer
class DummyScheduler:
def __init__(self, optimizer=None):
self.optimizer = optimizer
def step(self):
pass
class DummyOptimizer:
def __init__(self, optimizer_dict):
self.optimizer_dict = optimizer_dict
def step(self):
pass
def zero_grad(self, set_to_none=False):
pass
scheduler = DummyScheduler(optimizer_dict[all_params[0]])
optimizer = DummyOptimizer(optimizer_dict)
else:
raise ValueError(f"Unsupported optimizer type: {optimizer_type}")
return optimizer, scheduler


@@ -3671,7 +3671,7 @@ def get_optimizer(args, trainable_params):
optimizer_class = lion_pytorch.Lion
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
elif optimizer_type.endswith("8bit".lower()):
elif optimizer_type.endswith("8bit".lower()) and not optimizer_type.startswith("GaLore".lower()):
try:
import bitsandbytes as bnb
except ImportError:
@@ -3880,6 +3880,11 @@ def get_optimizer(args, trainable_params):
optimizer_class = torch.optim.AdamW
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
elif optimizer_type.startswith("GaLore".lower()):
logger.info(f"use GaLore optimizer | {optimizer_kwargs}")
optimizer = "galore"
return None, None, optimizer
if optimizer is None:
# use an arbitrary optimizer
optimizer_type = args.optimizer_type # using the non-lowercased value is a bit awkward


@@ -11,6 +11,7 @@ from tqdm import tqdm
import torch
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from accelerate.utils import set_seed
@@ -378,7 +379,17 @@ def train(args):
train_dataset_group.set_max_train_steps(args.max_train_steps)
# prepare the lr scheduler
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
lr_scheduler = None
if optimizer == "galore":
from library import galore_optimizer
# if not layer-wise, lr_scheduler is None; if layer-wise, it is a dummy scheduler
optimizer, lr_scheduler = galore_optimizer.get_optimizer(
args, args.optimizer_type, params_to_optimize, training_models, accelerator.num_processes
)
if lr_scheduler is None:
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# experimental feature: perform fp16/bf16 training including gradients; cast the entire model to fp16/bf16
if args.full_fp16: