Merge branch 'dev' into sd3

This commit is contained in:
Kohya S
2024-10-27 10:20:14 +09:00
2 changed files with 14 additions and 4 deletions

View File

@@ -301,10 +301,10 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, mer
# make original weight if not exist
if lora_module_name not in merged_sd:
weight = torch.zeros((out_dim, in_dim, *kernel_size) if conv2d else (out_dim, in_dim), dtype=merge_dtype)
if device:
weight = weight.to(device)
else:
weight = merged_sd[lora_module_name]
if device:
weight = weight.to(device)
# merge to weight
if device:
@@ -336,13 +336,16 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, mer
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
weight = weight + ratio * conved * scale
merged_sd[lora_module_name] = weight
merged_sd[lora_module_name] = weight.to("cpu")
# extract from merged weights
logger.info("extract new lora...")
merged_lora_sd = {}
with torch.no_grad():
for lora_module_name, mat in tqdm(list(merged_sd.items())):
if device:
mat = mat.to(device)
conv2d = len(mat.size()) == 4
kernel_size = None if not conv2d else mat.size()[2:4]
conv2d_3x3 = conv2d and kernel_size != (1, 1)
@@ -381,7 +384,7 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, mer
merged_lora_sd[lora_module_name + ".lora_up.weight"] = up_weight.to("cpu").contiguous()
merged_lora_sd[lora_module_name + ".lora_down.weight"] = down_weight.to("cpu").contiguous()
merged_lora_sd[lora_module_name + ".alpha"] = torch.tensor(module_new_rank)
merged_lora_sd[lora_module_name + ".alpha"] = torch.tensor(module_new_rank, device="cpu")
# build minimum metadata
dims = f"{new_rank}"