mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
scale in rank dropout, check training in dropout
This commit is contained in:
@@ -83,18 +83,18 @@ class LoRAModule(torch.nn.Module):
|
||||
org_forwarded = self.org_forward(x)
|
||||
|
||||
# module dropout
|
||||
if self.module_dropout:
|
||||
if self.module_dropout is not None and self.training:
|
||||
if torch.rand(1) < self.module_dropout:
|
||||
return org_forwarded
|
||||
|
||||
lx = self.lora_down(x)
|
||||
|
||||
# normal dropout
|
||||
if self.dropout:
|
||||
if self.dropout is not None and self.training:
|
||||
lx = torch.nn.functional.dropout(lx, p=self.dropout)
|
||||
|
||||
# rank dropout
|
||||
if self.rank_dropout:
|
||||
if self.rank_dropout is not None and self.training:
|
||||
mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
|
||||
if len(lx.size()) == 3:
|
||||
mask = mask.unsqueeze(1) # for Text Encoder
|
||||
@@ -102,9 +102,15 @@ class LoRAModule(torch.nn.Module):
|
||||
mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
|
||||
lx = lx * mask
|
||||
|
||||
# scaling for rank dropout: treat as if the rank is changed
|
||||
# maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる
|
||||
scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
|
||||
else:
|
||||
scale = self.scale
|
||||
|
||||
lx = self.lora_up(lx)
|
||||
|
||||
return org_forwarded + lx * self.multiplier * self.scale
|
||||
return org_forwarded + lx * self.multiplier * scale
|
||||
|
||||
|
||||
class LoRAInfModule(LoRAModule):
|
||||
|
||||
Reference in New Issue
Block a user