scale in rank dropout, check training in dropout

Kohya S
2023-06-02 07:29:59 +09:00
parent dde7807b00
commit 0f0158ddaa


@@ -83,18 +83,18 @@ class LoRAModule(torch.nn.Module):
         org_forwarded = self.org_forward(x)
 
         # module dropout
-        if self.module_dropout:
+        if self.module_dropout is not None and self.training:
             if torch.rand(1) < self.module_dropout:
                 return org_forwarded
 
         lx = self.lora_down(x)
 
         # normal dropout
-        if self.dropout:
+        if self.dropout is not None and self.training:
             lx = torch.nn.functional.dropout(lx, p=self.dropout)
 
         # rank dropout
-        if self.rank_dropout:
+        if self.rank_dropout is not None and self.training:
             mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
             if len(lx.size()) == 3:
                 mask = mask.unsqueeze(1)  # for Text Encoder
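
The self.training checks added above are the "check training in dropout" half of the commit: torch.nn.functional.dropout defaults to training=True, so a bare truthiness test on the rate would keep applying module_dropout and the plain dropout even in eval mode. A minimal sketch with a toy module (not from this repo) showing the gated behavior:

import torch

class TinyDropout(torch.nn.Module):
    # Hypothetical stand-in for the dropout branch in the diff above.
    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, x):
        # Gate on self.training, as the diff now does, so eval() disables dropout.
        if self.p is not None and self.training:
            x = torch.nn.functional.dropout(x, p=self.p)
        return x

m = TinyDropout()
x = torch.ones(4)
m.train()
print(m(x))  # some elements zeroed, survivors scaled by 1 / (1 - p)
m.eval()
print(m(x))  # tensor([1., 1., 1., 1.]) -- dropout is a no-op in eval mode
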
@@ -102,9 +102,15 @@ class LoRAModule(torch.nn.Module):
                 mask = mask.unsqueeze(-1).unsqueeze(-1)  # for Conv2d
             lx = lx * mask
 
+            # scaling for rank dropout: treat as if the rank is changed
+            # the scale could also be computed from the mask, but rank_dropout is used here, expecting an augmentation-like effect
+            scale = self.scale * (1.0 / (1.0 - self.rank_dropout))  # redundant for readability
+        else:
+            scale = self.scale
+
         lx = self.lora_up(lx)
 
-        return org_forwarded + lx * self.multiplier * self.scale
+        return org_forwarded + lx * self.multiplier * scale
 
 
 class LoRAInfModule(LoRAModule):
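
On the "scale in rank dropout" half: each of the lora_dim rank channels survives with probability 1 - rank_dropout, so the masked output shrinks by that factor in expectation, and multiplying by 1 / (1 - rank_dropout) compensates, treating the module as if it were trained at a smaller effective rank. The in-code comment notes the alternative of computing the exact compensation from the realized mask; using the fixed probability instead leaves per-sample variance, which is the intended augmentation-like effect. A self-contained sketch (hypothetical shapes and values, not repo code) that checks the expectation numerically:

import torch

torch.manual_seed(0)
lora_dim, rank_dropout = 8, 0.25
base_scale = 1.0               # stands in for self.scale
lx = torch.ones(1, lora_dim)   # pretend output of lora_down

# Average the rescaled rank-dropout output over many random masks:
# it should match lx * base_scale in expectation.
trials = 10_000
total = torch.zeros_like(lx)
for _ in range(trials):
    mask = torch.rand(lx.size(0), lora_dim) > rank_dropout
    total += lx * mask * (base_scale / (1.0 - rank_dropout))
print(total / trials)  # each entry is close to 1.0, i.e. lx * base_scale
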