fix crash gen script, change to network_dropout

This commit is contained in:
Kohya S
2023-06-01 20:07:04 +09:00
parent f4c9276336
commit f8e8df5a04
2 changed files with 21 additions and 19 deletions

View File

@@ -69,14 +69,17 @@ class LoRAModule(torch.nn.Module):
def forward(self, x):
if self.dropout:
return self.org_forward(x) + self.lora_up(torch.nn.functional.dropout(self.lora_down(x),p=self.dropout)) * self.multiplier * self.scale
return (
self.org_forward(x)
+ self.lora_up(torch.nn.functional.dropout(self.lora_down(x), p=self.dropout)) * self.multiplier * self.scale
)
else:
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
class LoRAInfModule(LoRAModule):
def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1, dropout=None):
super().__init__(lora_name, org_module, multiplier, lora_dim, alpha, dropout)
self.org_module_ref = [org_module] # 後から参照できるように
self.enabled = True
@@ -382,7 +385,6 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
)
# remove block dim/alpha without learning rate
block_dims, block_alphas, conv_block_dims, conv_block_alphas = remove_block_dims_and_alphas(
block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
@@ -400,6 +402,7 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
multiplier=multiplier,
lora_dim=network_dim,
alpha=network_alpha,
dropout=dropout,
conv_lora_dim=conv_dim,
conv_alpha=conv_alpha,
block_dims=block_dims,
@@ -407,7 +410,6 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
conv_block_dims=conv_block_dims,
conv_block_alphas=conv_block_alphas,
varbose=True,
dropout=dropout,
)
if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
@@ -676,6 +678,7 @@ class LoRANetwork(torch.nn.Module):
multiplier=1.0,
lora_dim=4,
alpha=1,
dropout=None,
conv_lora_dim=None,
conv_alpha=None,
block_dims=None,
@@ -686,7 +689,6 @@ class LoRANetwork(torch.nn.Module):
modules_alpha=None,
module_class=LoRAModule,
varbose=False,
dropout=None
) -> None:
"""
LoRA network: すごく引数が多いが、パターンは以下の通り
@@ -704,19 +706,18 @@ class LoRANetwork(torch.nn.Module):
self.conv_lora_dim = conv_lora_dim
self.conv_alpha = conv_alpha
self.dropout = dropout
print(f"Neuron dropout: p={self.dropout}")
if modules_dim is not None:
print(f"create LoRA network from weights")
elif block_dims is not None:
print(f"create LoRA network from block_dims")
print(f"create LoRA network from block_dims, neuron dropout: p={self.dropout}")
print(f"block_dims: {block_dims}")
print(f"block_alphas: {block_alphas}")
if conv_block_dims is not None:
print(f"conv_block_dims: {conv_block_dims}")
print(f"conv_block_alphas: {conv_block_alphas}")
else:
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, neuron dropout: p={self.dropout}")
if self.conv_lora_dim is not None:
print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")

View File

@@ -209,8 +209,9 @@ def train(args):
if args.dim_from_weights:
network, _ = network_module.create_network_from_weights(1, args.network_weights, vae, text_encoder, unet, **net_kwargs)
else:
# LyCORIS will work with this...
network = network_module.create_network(
1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, args.dropout, **net_kwargs
1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, dropout=args.network_dropout, **net_kwargs
)
if network is None:
return
@@ -367,7 +368,8 @@ def train(args):
"ss_lr_scheduler": args.lr_scheduler,
"ss_network_module": args.network_module,
"ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim
"ss_network_alpha": args.network_alpha, # some networks may not use this value
"ss_network_alpha": args.network_alpha, # some networks may not have alpha
"ss_network_dropout": args.network_dropout, # some networks may not have dropout
"ss_mixed_precision": args.mixed_precision,
"ss_full_fp16": bool(args.full_fp16),
"ss_v2": bool(args.v2),
@@ -391,7 +393,6 @@ def train(args):
"ss_prior_loss_weight": args.prior_loss_weight,
"ss_min_snr_gamma": args.min_snr_gamma,
"ss_scale_weight_norms": args.scale_weight_norms,
"ss_dropout": args.dropout,
}
if use_user_config:
@@ -798,6 +799,12 @@ def setup_parser() -> argparse.ArgumentParser:
default=1,
help="alpha for LoRA weight scaling, default 1 (same as network_dim for same behavior as old version) / LoRaの重み調整のalpha値、デフォルト1旧バージョンと同じ動作をするにはnetwork_dimと同じ値を指定",
)
parser.add_argument(
"--network_dropout",
type=float,
default=None,
help="Drops neurons out of training every step (0 or None is default behavior (no dropout), 1 would drop all neurons) / 訓練時に毎ステップでニューロンをdropする0またはNoneはdropoutなし、1は全ニューロンをdropout",
)
parser.add_argument(
"--network_args", type=str, default=None, nargs="*", help="additional argmuments for network (key=value) / ネットワークへの追加の引数"
)
@@ -819,12 +826,6 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="Scale the weight of each key pair to help prevent overtraing via exploding gradients. (1 is a good starting point) / 重みの値をスケーリングして勾配爆発を防ぐ1が初期値としては適当",
)
parser.add_argument(
"--dropout",
type=float,
default=None,
help="Drops neurons out of training every step (0 is default behavior, 1 would drop all neurons)",
)
parser.add_argument(
"--base_weights",
type=str,