diff --git a/networks/lora.py b/networks/lora.py index 9243f1e1..b936bfb2 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -31,7 +31,7 @@ class LoRAModule(torch.nn.Module): self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False) if type(alpha) == torch.Tensor: - alpha = alpha.detach().numpy() + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error alpha = lora_dim if alpha is None or alpha == 0 else alpha self.scale = alpha / self.lora_dim self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える