mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
Commit summary: add get_my_logger(); use logger instead of print; fix log level; remove line breaks for readability; use setup_logging(); add rich to requirements.txt; simplify; use logger instead of print. Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
129 lines
4.2 KiB
Python
129 lines
4.2 KiB
Python
# Convert LoRA to different rank approximation (should only be used to go to lower rank)
|
|
# This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
|
|
# Thanks to cloneofsimo
|
|
|
|
import argparse
|
|
import math
|
|
import os
|
|
import torch
|
|
from safetensors.torch import load_file, save_file, safe_open
|
|
from tqdm import tqdm
|
|
from library import train_util, model_util
|
|
import numpy as np
|
|
from library.utils import setup_logging
|
|
setup_logging()
|
|
import logging
|
|
logger = logging.getLogger(__name__)
|
|
|
|
def load_state_dict(file_name):
    """Load a LoRA state dict plus its metadata from disk.

    Safetensors files carry embedded metadata; plain torch checkpoints do
    not, so the metadata slot is None for those.
    """
    if not model_util.is_safetensors(file_name):
        # Plain torch checkpoint: no embedded metadata available.
        return torch.load(file_name, map_location="cpu"), None

    state_dict = load_file(file_name)
    # load_file() does not expose metadata, so reopen the file to read it.
    with safe_open(file_name, framework="pt") as handle:
        return state_dict, handle.metadata()

|
def save_to_file(file_name, model, metadata):
    """Write a state dict to disk, picking the format from the extension.

    Metadata is only persisted for safetensors output; torch.save() has no
    slot for it and it is silently dropped.
    """
    if not model_util.is_safetensors(file_name):
        torch.save(model, file_name)
        return
    save_file(model, file_name, metadata)

|
def split_lora_model(lora_sd, unit):
    """Slice a DyLoRA state dict into fixed-rank sub-models.

    Returns ``(max_rank, split_models)`` where ``split_models`` is a list of
    ``(state_dict, rank, alpha)`` tuples, one for every multiple of ``unit``
    strictly below the rank the model was trained at. ``alpha`` is always
    None (see note below).
    """
    # The trained rank is the first dimension of any lora_down weight.
    max_rank = max(
        (tensor.size()[0] for key, tensor in lora_sd.items() if "lora_down" in key),
        default=0,
    )
    logger.info(f"Max rank: {max_rank}")

    split_models = []
    for rank in range(unit, max_rank, unit):
        logger.info(f"Splitting rank {rank}")
        truncated = {}
        for key, tensor in lora_sd.items():
            if "lora_down" in key:
                # down weight: keep the first `rank` rows.
                truncated[key] = tensor[:rank].contiguous()
            elif "lora_up" in key:
                # up weight: keep the first `rank` columns.
                truncated[key] = tensor[:, :rank].contiguous()
            else:
                # NOTE(review): rescaling alpha by sqrt(this_rank / rank)
                # was tried upstream but produced broken results, so alpha
                # is passed through unchanged.
                truncated[key] = tensor
        # Alpha is never recomputed, hence the third tuple slot stays None.
        split_models.append((truncated, rank, None))

    return max_rank, split_models

|
def split(args):
    """Split a DyLoRA checkpoint into one file per rank multiple of --unit.

    Each sub-model is saved next to ``args.save_to`` with a ``-NNNN`` rank
    suffix, carrying updated metadata and freshly computed safetensors
    hashes.
    """
    logger.info("loading Model...")
    lora_sd, metadata = load_state_dict(args.model)

    logger.info("Splitting Model...")
    original_rank, split_models = split_lora_model(lora_sd, args.unit)

    # .ckpt / .pt inputs have no metadata (load_state_dict returns None);
    # treat that as an empty comment instead of crashing on None.get().
    comment = metadata.get("ss_training_comment", "") if metadata is not None else ""

    for state_dict, new_rank, new_alpha in split_models:
        # Start from a copy of the source metadata so the original dict is
        # never mutated across iterations.
        new_metadata = {} if metadata is None else metadata.copy()

        new_metadata["ss_training_comment"] = f"split from DyLoRA, rank {original_rank} to {new_rank}; {comment}"
        new_metadata["ss_network_dim"] = str(new_rank)
        # new_metadata["ss_network_alpha"] = str(new_alpha.float().numpy())

        # Hash the exact metadata that will be written, and store the hashes
        # in it. (Previously the hashes were computed against the old
        # metadata and stored back into it, so they were dropped from the
        # saved file and leaked into later iterations via metadata.copy().)
        model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, new_metadata)
        new_metadata["sshs_model_hash"] = model_hash
        new_metadata["sshs_legacy_hash"] = legacy_hash

        filename, ext = os.path.splitext(args.save_to)
        model_file_name = filename + f"-{new_rank:04d}{ext}"

        logger.info(f"saving model to: {model_file_name}")
        save_to_file(model_file_name, state_dict, new_metadata)

|
def setup_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the DyLoRA split script."""
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--unit",
        type=int,
        default=None,
        help="size of rank to split into / rankを分割するサイズ",
    )
    parser.add_argument(
        "--save_to",
        type=str,
        default=None,
        help="destination base file name: ckpt or safetensors file / 保存先のファイル名のbase、ckptまたはsafetensors",
    )
    parser.add_argument(
        "--model",
        type=str,
        default=None,
        help="DyLoRA model to resize at to new rank: ckpt or safetensors file / 読み込むDyLoRAモデル、ckptまたはsafetensors",
    )

    return parser

|
if __name__ == "__main__":
|
|
parser = setup_parser()
|
|
|
|
args = parser.parse_args()
|
|
split(args)
|