diff --git a/networks/lora.py b/networks/lora.py
index 1a665fc4..aa1c9331 100644
--- a/networks/lora.py
+++ b/networks/lora.py
@@ -83,18 +83,18 @@ class LoRAModule(torch.nn.Module):
         org_forwarded = self.org_forward(x)
 
         # module dropout
-        if self.module_dropout:
+        if self.module_dropout is not None and self.training:
             if torch.rand(1) < self.module_dropout:
                 return org_forwarded
 
         lx = self.lora_down(x)
 
         # normal dropout
-        if self.dropout:
+        if self.dropout is not None and self.training:
             lx = torch.nn.functional.dropout(lx, p=self.dropout)
 
         # rank dropout
-        if self.rank_dropout:
+        if self.rank_dropout is not None and self.training:
             mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
             if len(lx.size()) == 3:
                 mask = mask.unsqueeze(1)  # for Text Encoder
@@ -102,9 +102,15 @@ class LoRAModule(torch.nn.Module):
                 mask = mask.unsqueeze(-1).unsqueeze(-1)  # for Conv2d
             lx = lx * mask
 
+            # scaling for rank dropout: treat as if the rank is changed
+            # the scale could also be computed from the mask, but rank_dropout is used here expecting an augmentation-like effect
+            scale = self.scale * (1.0 / (1.0 - self.rank_dropout))  # redundant for readability
+        else:
+            scale = self.scale
+
         lx = self.lora_up(lx)
 
-        return org_forwarded + lx * self.multiplier * self.scale
+        return org_forwarded + lx * self.multiplier * scale
 
 
 class LoRAInfModule(LoRAModule):
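
For context (not part of the patch), the added `1.0 / (1.0 - self.rank_dropout)` factor works like inverted dropout: zeroing a fraction `p` of the ranks and rescaling by `1 / (1 - p)` keeps the expected magnitude of the LoRA update roughly unchanged. Below is a minimal standalone sketch of that scaling; the shapes, seed, and print check are illustrative assumptions, not code from the repository.

    import torch

    torch.manual_seed(0)

    batch, seq, rank = 4, 16, 8          # hypothetical shapes (e.g. a Text Encoder activation)
    p = 0.25                             # rank_dropout probability
    lx = torch.randn(batch, seq, rank)   # stand-in for lora_down(x)

    mask = torch.rand(batch, rank) > p   # keep each rank with probability 1 - p
    mask = mask.unsqueeze(1)             # broadcast the mask over the sequence dimension
    scale = 1.0 / (1.0 - p)              # compensate for the dropped ranks

    dropped = lx * mask * scale
    print(lx.abs().mean(), dropped.abs().mean())  # similar expected magnitudes

The patch precomputes this into `scale` only during training (`self.training`), so inference with `LoRAInfModule` keeps the plain `self.scale` path.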