diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py
index 988eae75..60a24972 100644
--- a/gen_img_diffusers.py
+++ b/gen_img_diffusers.py
@@ -2262,6 +2262,8 @@ def main(args):
     if args.network_module:
         networks = []
         network_default_muls = []
+        network_pre_calc = args.network_pre_calc
+
         for i, network_module in enumerate(args.network_module):
             print("import network module:", network_module)
             imported_module = importlib.import_module(network_module)
@@ -2298,11 +2300,11 @@ def main(args):
                 if network is None:
                     return
 
-                mergiable = hasattr(network, "merge_to")
-                if args.network_merge and not mergiable:
-                    print("network is not mergiable. ignore merge option.")
+                mergeable = network.is_mergeable()
+                if args.network_merge and not mergeable:
+                    print("network is not mergeable. ignoring merge option.")
 
-                if not args.network_merge or not mergiable:
+                if not args.network_merge or not mergeable:
                     network.apply_to(text_encoder, unet)
                     info = network.load_state_dict(weights_sd, False)  # using network.load_weights would be preferable
                     print(f"weights are loaded: {info}")
@@ -2311,6 +2313,10 @@ def main(args):
                         network.to(memory_format=torch.channels_last)
                     network.to(dtype).to(device)
 
+                    if network_pre_calc:
+                        print("backup original weights")
+                        network.backup_weights()
+
                     networks.append(network)
                 else:
                     network.merge_to(text_encoder, unet, weights_sd, dtype, device)
@@ -2815,11 +2821,19 @@ def main(args):
             # generate
             if networks:
+                # handle additional networks
                 shared = {}
                 for n, m in zip(networks, network_muls if network_muls else network_default_muls):
                     n.set_multiplier(m)
                     if regional_network:
                         n.set_current_generation(batch_size, num_sub_prompts, width, height, shared)
+
+                if not regional_network and network_pre_calc:
+                    for n in networks:
+                        n.restore_weights()
+                    for n in networks:
+                        n.pre_calculation()
+                    print("pre-calculation... done")
 
             images = pipe(
                 prompts,
@@ -3204,6 +3218,7 @@ def setup_parser() -> argparse.ArgumentParser:
     )
     parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
     parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする")
+    parser.add_argument("--network_pre_calc", action="store_true", help="pre-calculate network weights for generation / ネットワークの重みをあらかじめ計算して生成する")
     parser.add_argument(
         "--textual_inversion_embeddings",
         type=str,
diff --git a/networks/lora.py b/networks/lora.py
index 353b1f5a..898ffce9 100644
--- a/networks/lora.py
+++ b/networks/lora.py
@@ -66,6 +66,39 @@ class LoRAModule(torch.nn.Module):
         self.org_module.forward = self.forward
         del self.org_module
 
+    def forward(self, x):
+        return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
+
+
+class LoRAInfModule(LoRAModule):
+    def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
+        super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
+
+        self.org_module_ref = [org_module]  # keep a reference so it can be accessed later
+        self.enabled = True
+
+        # check regional or not by lora_name
+        self.text_encoder = False
+        if lora_name.startswith("lora_te_"):
+            self.regional = False
+            self.use_sub_prompt = True
+            self.text_encoder = True
+        elif "attn2_to_k" in lora_name or "attn2_to_v" in lora_name:
+            self.regional = False
+            self.use_sub_prompt = True
+        elif "time_emb" in lora_name:
+            self.regional = False
+            self.use_sub_prompt = False
+        else:
+            self.regional = True
+            self.use_sub_prompt = False
+
+        self.network: LoRANetwork = None
+
+    def set_network(self, network):
+        self.network = network
+
+    # freeze and merge into the original module
     def merge_to(self, sd, dtype, device):
         # get up/down weight
         up_weight = sd["lora_up.weight"].to(torch.float).to(device)
@@ -97,44 +130,45 @@ class LoRAModule(torch.nn.Module):
         org_sd["weight"] = weight.to(dtype)
         self.org_module.load_state_dict(org_sd)
 
+    # return this module's delta weight so that a merge can be undone later
+    def get_weight(self, multiplier=None):
+        if multiplier is None:
+            multiplier = self.multiplier
+
+        # get up/down weight from module
+        up_weight = self.lora_up.weight.to(torch.float)
+        down_weight = self.lora_down.weight.to(torch.float)
+
+        # pre-calculated weight
+        if len(down_weight.size()) == 2:
+            # linear
+            weight = multiplier * (up_weight @ down_weight) * self.scale
+        elif down_weight.size()[2:4] == (1, 1):
+            # conv2d 1x1
+            weight = (
+                multiplier
+                * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
+                * self.scale
+            )
+        else:
+            # conv2d 3x3
+            conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
+            weight = multiplier * conved * self.scale
+
+        return weight
+
     def set_region(self, region):
         self.region = region
         self.region_mask = None
 
-    def forward(self, x):
-        return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
-
-
-class LoRAInfModule(LoRAModule):
-    def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
-        super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
-
-        # check regional or not by lora_name
-        self.text_encoder = False
-        if lora_name.startswith("lora_te_"):
-            self.regional = False
-            self.use_sub_prompt = True
-            self.text_encoder = True
-        elif "attn2_to_k" in lora_name or "attn2_to_v" in lora_name:
-            self.regional = False
-            self.use_sub_prompt = True
-        elif "time_emb" in lora_name:
-            self.regional = False
-            self.use_sub_prompt = False
-        else:
-            self.regional = True
-            self.use_sub_prompt = False
-
-        self.network: LoRANetwork = None
-
-    def set_network(self, network):
-        self.network = network
-
     def default_forward(self, x):
         # print("default_forward", self.lora_name, x.size())
         return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
 
     def forward(self, x):
+        if not self.enabled:
+            return self.org_forward(x)
+
         if self.network is None or self.network.sub_prompt_index is None:
             return self.default_forward(x)
         if not self.regional and not self.use_sub_prompt:
@@ -769,6 +803,10 @@ class LoRANetwork(torch.nn.Module):
             lora.apply_to()
             self.add_module(lora.lora_name, lora)
 
+    # return whether this network can be merged
+    def is_mergeable(self):
+        return True
+
     # TODO refactor to common function with apply_to
     def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
         apply_text_encoder = apply_unet = False
@@ -955,3 +993,40 @@ class LoRANetwork(torch.nn.Module):
                 w = (w + 1) // 2
 
         self.mask_dic = mask_dic
+
+    def backup_weights(self):
+        # back up the original weights
+        loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
+        for lora in loras:
+            org_module = lora.org_module_ref[0]
+            if not hasattr(org_module, "_lora_org_weight"):
+                sd = org_module.state_dict()
+                org_module._lora_org_weight = sd["weight"].detach().clone()
+                org_module._lora_restored = True
+
+    def restore_weights(self):
+        # restore the original weights
+        loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
+        for lora in loras:
+            org_module = lora.org_module_ref[0]
+            if not org_module._lora_restored:
+                sd = org_module.state_dict()
+                sd["weight"] = org_module._lora_org_weight
+                org_module.load_state_dict(sd)
+                org_module._lora_restored = True
+
+    def pre_calculation(self):
+        # pre-calculate the merged weights
+        loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
+        for lora in loras:
+            org_module = lora.org_module_ref[0]
+            sd = org_module.state_dict()
+
+            org_weight = sd["weight"]
+            lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype)
+            sd["weight"] = org_weight + lora_weight
+            assert sd["weight"].shape == org_weight.shape
+            org_module.load_state_dict(sd)
+
+            org_module._lora_restored = False
+            lora.enabled = False
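
Note, for review purposes (not part of the patch): the pre-calculation path works because folding the delta W' = W + multiplier * scale * (up @ down) into the original weight is mathematically identical to the on-the-fly forward org_forward(x) + lora_up(lora_down(x)) * multiplier * scale. Below is a minimal standalone sketch of the linear case, mirroring LoRAInfModule.get_weight and LoRANetwork.backup_weights / pre_calculation / restore_weights; all names in it are illustrative, not taken from the patch.

    import torch

    torch.manual_seed(0)
    in_dim, out_dim, rank = 8, 8, 4
    multiplier, alpha = 1.0, 1.0
    scale = alpha / rank  # same scaling convention as LoRAModule

    org = torch.nn.Linear(in_dim, out_dim, bias=False)
    lora_down = torch.nn.Linear(in_dim, rank, bias=False)
    lora_up = torch.nn.Linear(rank, out_dim, bias=False)
    x = torch.randn(2, in_dim)

    # on-the-fly path, as in LoRAModule.forward
    y_dynamic = org(x) + lora_up(lora_down(x)) * multiplier * scale

    # pre-calculated path: back up, fold the delta into the weight, run the plain module
    backup = org.weight.detach().clone()  # analogous to backup_weights
    delta = multiplier * (lora_up.weight @ lora_down.weight) * scale  # get_weight, linear case
    org.weight.data += delta  # analogous to pre_calculation
    y_merged = org(x)
    assert torch.allclose(y_dynamic, y_merged, atol=1e-6)

    org.weight.data.copy_(backup)  # analogous to restore_weights

Keeping a clone of the original weight (as the patch does with _lora_org_weight) rather than subtracting the delta back presumably avoids accumulating floating-point error across repeated merge/restore cycles when multipliers change between generations.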