From 0af4edd8a63d7fcdf02bdcbd11b8770fd1cae162 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Tue, 29 Oct 2024 21:51:56 +0900 Subject: [PATCH] Fix split_qkv --- networks/lora_sd3.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/networks/lora_sd3.py b/networks/lora_sd3.py index cbabf8da..249298b3 100644 --- a/networks/lora_sd3.py +++ b/networks/lora_sd3.py @@ -540,8 +540,8 @@ class LoRANetwork(torch.nn.Module): down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim) # merge up weight (sum of split_dim, rank*3) - qkv_dim, rank = up_weights[0].size() - split_dim = qkv_dim // 3 + split_dim, rank = up_weights[0].size() + qkv_dim = split_dim * 3 up_weight = torch.zeros((qkv_dim, down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype) i = 0 for j in range(3):