Fix split_qkv

This commit is contained in:
kohya-ss
2024-10-29 21:51:56 +09:00
parent 1065dd1b56
commit 0af4edd8a6

View File

@@ -540,8 +540,8 @@ class LoRANetwork(torch.nn.Module):
down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim)
# merge up weight (sum of split_dim, rank*3)
qkv_dim, rank = up_weights[0].size()
split_dim = qkv_dim // 3
split_dim, rank = up_weights[0].size()
qkv_dim = split_dim * 3
up_weight = torch.zeros((qkv_dim, down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype)
i = 0
for j in range(3):