# Kohya-ss-sd-scripts/tests/test_optimizer.py
# Tests that train_network's argument parser + library.train_util.get_optimizer
# resolve each supported optimizer alias to the expected class.
from unittest.mock import patch
from library.train_util import get_optimizer
from train_network import setup_parser
import torch
from torch.nn import Parameter
# Optimizer libraries
import bitsandbytes as bnb
from lion_pytorch import lion_pytorch
import schedulefree
import dadaptation
import dadaptation.experimental as dadapt_experimental
import prodigyopt
import schedulefree as sf
import transformers
def test_default_get_optimizer():
    """With no CLI flags, get_optimizer must fall back to torch AdamW.

    Patches sys.argv to an empty command line so setup_parser() yields pure
    defaults, then checks the resolved name, the (empty) optimizer args, and
    the concrete optimizer class.
    """
    with patch("sys.argv", [""]):
        parser = setup_parser()
        args = parser.parse_args()

        # A single trainable parameter is enough to instantiate the optimizer.
        params_t = torch.tensor([1.5, 1.5])
        param = Parameter(params_t)
        optimizer_name, optimizer_args, optimizer = get_optimizer(args, [param])

        assert optimizer_name == "torch.optim.adamw.AdamW"
        assert optimizer_args == ""
        assert isinstance(optimizer, torch.optim.AdamW)
def test_get_schedulefree_optimizer():
    """--optimizer_type AdamWScheduleFree must resolve to schedulefree's AdamW.

    Verifies the fully-qualified optimizer name, that no extra optimizer args
    are produced, and that the returned instance is the schedulefree class.
    """
    with patch("sys.argv", ["", "--optimizer_type", "AdamWScheduleFree"]):
        parser = setup_parser()
        args = parser.parse_args()

        # A single trainable parameter is enough to instantiate the optimizer.
        params_t = torch.tensor([1.5, 1.5])
        param = Parameter(params_t)
        optimizer_name, optimizer_args, optimizer = get_optimizer(args, [param])

        assert optimizer_name == "schedulefree.adamw_schedulefree.AdamWScheduleFree"
        assert optimizer_args == ""
        assert isinstance(optimizer, schedulefree.adamw_schedulefree.AdamWScheduleFree)
def test_all_supported_optimizers():
    """Every supported --optimizer_type alias resolves to its expected class.

    Each table entry maps a CLI alias to the fully-qualified optimizer name
    that get_optimizer reports and the concrete class it must instantiate.
    """
    optimizers = [
        {
            "name": "bitsandbytes.optim.adamw.AdamW8bit",
            "alias": "AdamW8bit",
            "instance": bnb.optim.AdamW8bit,
        },
        {
            "name": "lion_pytorch.lion_pytorch.Lion",
            "alias": "Lion",
            "instance": lion_pytorch.Lion,
        },
        {
            "name": "torch.optim.adamw.AdamW",
            "alias": "AdamW",
            "instance": torch.optim.AdamW,
        },
        {
            "name": "bitsandbytes.optim.lion.Lion8bit",
            "alias": "Lion8bit",
            "instance": bnb.optim.Lion8bit,
        },
        {
            "name": "bitsandbytes.optim.adamw.PagedAdamW8bit",
            "alias": "PagedAdamW8bit",
            "instance": bnb.optim.PagedAdamW8bit,
        },
        {
            "name": "bitsandbytes.optim.lion.PagedLion8bit",
            "alias": "PagedLion8bit",
            "instance": bnb.optim.PagedLion8bit,
        },
        {
            "name": "bitsandbytes.optim.adamw.PagedAdamW",
            "alias": "PagedAdamW",
            "instance": bnb.optim.PagedAdamW,
        },
        {
            "name": "bitsandbytes.optim.adamw.PagedAdamW32bit",
            "alias": "PagedAdamW32bit",
            "instance": bnb.optim.PagedAdamW32bit,
        },
        {"name": "torch.optim.sgd.SGD", "alias": "SGD", "instance": torch.optim.SGD},
        {
            "name": "dadaptation.experimental.dadapt_adam_preprint.DAdaptAdamPreprint",
            "alias": "DAdaptAdamPreprint",
            "instance": dadapt_experimental.DAdaptAdamPreprint,
        },
        {
            "name": "dadaptation.dadapt_adagrad.DAdaptAdaGrad",
            "alias": "DAdaptAdaGrad",
            "instance": dadaptation.DAdaptAdaGrad,
        },
        {
            "name": "dadaptation.dadapt_adan.DAdaptAdan",
            "alias": "DAdaptAdan",
            "instance": dadaptation.DAdaptAdan,
        },
        {
            "name": "dadaptation.experimental.dadapt_adan_ip.DAdaptAdanIP",
            "alias": "DAdaptAdanIP",
            "instance": dadapt_experimental.DAdaptAdanIP,
        },
        {
            "name": "dadaptation.dadapt_lion.DAdaptLion",
            "alias": "DAdaptLion",
            "instance": dadaptation.DAdaptLion,
        },
        {
            "name": "dadaptation.dadapt_sgd.DAdaptSGD",
            "alias": "DAdaptSGD",
            "instance": dadaptation.DAdaptSGD,
        },
        {
            "name": "prodigyopt.prodigy.Prodigy",
            "alias": "Prodigy",
            "instance": prodigyopt.Prodigy,
        },
        {
            "name": "transformers.optimization.Adafactor",
            "alias": "Adafactor",
            "instance": transformers.optimization.Adafactor,
        },
        {
            "name": "schedulefree.adamw_schedulefree.AdamWScheduleFree",
            "alias": "AdamWScheduleFree",
            "instance": sf.AdamWScheduleFree,
        },
        {
            "name": "schedulefree.sgd_schedulefree.SGDScheduleFree",
            "alias": "SGDScheduleFree",
            "instance": sf.SGDScheduleFree,
        },
    ]

    for opt in optimizers:
        with patch("sys.argv", ["", "--optimizer_type", opt.get("alias")]):
            parser = setup_parser()
            args = parser.parse_args()

            # A single trainable parameter is enough to instantiate each optimizer.
            params_t = torch.tensor([1.5, 1.5])
            param = Parameter(params_t)
            optimizer_name, _, optimizer = get_optimizer(args, [param])

            assert optimizer_name == opt.get("name")
            instance = opt.get("instance")
            assert instance is not None
            assert isinstance(optimizer, instance)