From 5c020bed4932b5a147a4e7f84eff9f792eee8e48 Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Thu, 6 Apr 2023 08:11:54 +0900
Subject: [PATCH] Add attention couple + regional LoRA

---
 gen_img_diffusers.py |  85 ++++++++++---
 networks/lora.py     | 275 +++++++++++++++++++++++++++++++++++--------
 2 files changed, 296 insertions(+), 64 deletions(-)

diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py
index af83ce47..313e4048 100644
--- a/gen_img_diffusers.py
+++ b/gen_img_diffusers.py
@@ -92,6 +92,7 @@ from PIL.PngImagePlugin import PngInfo
 
 import library.model_util as model_util
 import library.train_util as train_util
+from networks.lora import LoRANetwork
 import tools.original_control_net as original_control_net
 from tools.original_control_net import ControlNetInfo
 
@@ -634,6 +635,7 @@ class PipelineLike:
         img2img_noise=None,
         clip_prompts=None,
         clip_guide_images=None,
+        networks: Optional[List[LoRANetwork]] = None,
         **kwargs,
     ):
         r"""
@@ -717,6 +719,7 @@ class PipelineLike:
             batch_size = len(prompt)
         else:
             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+        regional_network = " AND " in prompt[0]
 
         vae_batch_size = (
             batch_size
@@ -1010,6 +1013,11 @@ class PipelineLike:
 
                 # predict the noise residual
                 if self.control_nets:
+                    if regional_network:
+                        num_sub_and_neg_prompts = len(text_embeddings) // batch_size
+                        text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts]  # last subprompt
+                    else:
+                        text_emb_last = text_embeddings
                     noise_pred = original_control_net.call_unet_and_control_net(
                         i,
                         num_latent_input,
@@ -1019,7 +1027,7 @@ class PipelineLike:
                         i / len(timesteps),
                         latent_model_input,
                         t,
-                        text_embeddings,
+                        text_emb_last,
                     ).sample
                 else:
                     noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
@@ -1890,6 +1898,12 @@ def get_weighted_text_embeddings(
     if isinstance(prompt, str):
         prompt = [prompt]
 
+    # split the prompts with "AND"; each prompt must have the same number of splits
+    new_prompts = []
+    for p in prompt:
+        new_prompts.extend(p.split(" AND "))
+    prompt = new_prompts
+
     if not skip_parsing:
         prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2, layer=layer)
         if uncond_prompt is not None:
@@ -2059,6 +2073,7 @@ class BatchDataExt(NamedTuple):
     negative_scale: float
     strength: float
    network_muls: Tuple[float]
+    num_sub_prompts: int
 
 
 class BatchData(NamedTuple):
@@ -2276,16 +2291,20 @@ def main(args):
                     print(f"metadata for: {network_weight}: {metadata}")
 
                 network, weights_sd = imported_module.create_network_from_weights(
-                    network_mul, network_weight, vae, text_encoder, unet, **net_kwargs
+                    network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs
                 )
             else:
                 raise ValueError("No weight. Weight is required.")
             if network is None:
                 return
 
-            if not args.network_merge:
+            mergeable = hasattr(network, "merge_to")
+            if args.network_merge and not mergeable:
+                print("network is not mergeable; ignoring merge option.")
+
+            if not args.network_merge or not mergeable:
                 network.apply_to(text_encoder, unet)
-                info = network.load_state_dict(weights_sd, False)  # network.load_weightsを使うようにするとよい
+                info = network.load_state_dict(weights_sd, False)  # it would be better to use network.load_weights
                 print(f"weights are loaded: {info}")
 
         if args.opt_channels_last:
@@ -2349,12 +2368,12 @@ def main(args):
     if args.diffusers_xformers:
         pipe.enable_xformers_memory_efficient_attention()
 
+    # handle Extended Textual Inversion and Textual Inversion
    if args.XTI_embeddings:
         diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI
         diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D.forward = downblock_forward_XTI
         diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D.forward = upblock_forward_XTI
 
-    # Textual Inversionを処理する
     if args.textual_inversion_embeddings:
         token_ids_embeds = []
         for embeds_file in args.textual_inversion_embeddings:
@@ -2558,16 +2577,22 @@ def main(args):
            print(f"resize img2img mask images to {args.W}*{args.H}")
            mask_images = resize_images(mask_images, (args.W, args.H))
 
+    regional_network = False
     if networks and mask_images:
-        # mask を領域情報として流用する、現在は1枚だけ対応
-        # TODO 複数のnetwork classの混在時の考慮
+        # reuse the mask as region information; currently only one mask per command invocation is supported
+        regional_network = True
         print("use mask as region")
-        # import cv2
-        # for i in range(3):
-        #     cv2.imshow("msk", np.array(mask_images[0])[:,:,i])
-        #     cv2.waitKey()
-        #     cv2.destroyAllWindows()
-        networks[0].__class__.set_regions(networks, np.array(mask_images[0]))
+
+        size = None
+        for i, network in enumerate(networks):
+            if i < 3:
+                np_mask = np.array(mask_images[0])
+                np_mask = np_mask[:, :, i]
+                size = np_mask.shape
+            else:
+                np_mask = np.full(size, 255, dtype=np.uint8)
+            mask = torch.from_numpy(np_mask.astype(np.float32) / 255.0)
+            network.set_region(i, i == len(networks) - 1, mask)
         mask_images = None
 
     prev_image = None  # for VGG16 guided
@@ -2623,7 +2648,14 @@ def main(args):
             height_1st = height_1st - height_1st % 32
 
             ext_1st = BatchDataExt(
-                width_1st, height_1st, args.highres_fix_steps, ext.scale, ext.negative_scale, ext.strength, ext.network_muls
+                width_1st,
+                height_1st,
+                args.highres_fix_steps,
+                ext.scale,
+                ext.negative_scale,
+                ext.strength,
+                ext.network_muls,
+                ext.num_sub_prompts,
             )
             batch_1st.append(BatchData(args.highres_fix_latents_upscaling, base, ext_1st))
         images_1st = process_batch(batch_1st, True, True)
@@ -2651,7 +2683,7 @@ def main(args):
        (
             return_latents,
             (step_first, _, _, _, init_image, mask_image, _, guide_image),
-            (width, height, steps, scale, negative_scale, strength, network_muls),
+            (width, height, steps, scale, negative_scale, strength, network_muls, num_sub_prompts),
        ) = batch[0]
        noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
@@ -2743,8 +2775,11 @@ def main(args):
 
         # generate
         if networks:
+            shared = {}
             for n, m in zip(networks, network_muls if network_muls else network_default_muls):
                 n.set_multiplier(m)
+                if regional_network:
+                    n.set_current_generation(batch_size, num_sub_prompts, width, height, shared)
 
         images = pipe(
             prompts,
@@ -2969,11 +3004,26 @@ def main(args):
                            print("Use previous image as guide image.")
                            guide_image = prev_image
 
+                if regional_network:
+                    num_sub_prompts = len(prompt.split(" AND "))
+                    assert (
+                        len(networks) <= num_sub_prompts
+                    ), "Number of networks must be less than or equal to the number of sub prompts."
+                else:
+                    num_sub_prompts = None
+
                 b1 = BatchData(
                     False,
                     BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
                     BatchDataExt(
-                        width, height, steps, scale, negative_scale, strength, tuple(network_muls) if network_muls else None
+                        width,
+                        height,
+                        steps,
+                        scale,
+                        negative_scale,
+                        strength,
+                        tuple(network_muls) if network_muls else None,
+                        num_sub_prompts,
                     ),
                 )
                 if len(batch_data) > 0 and batch_data[-1].ext != b1.ext:  # バッチ分割必要?
@@ -3197,6 +3247,9 @@ def setup_parser() -> argparse.ArgumentParser:
         nargs="*",
         help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率",
     )
+    # parser.add_argument(
+    #     "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像"
+    # )
 
     return parser
 
diff --git a/networks/lora.py b/networks/lora.py
index 4e0573d0..353b1f5a 100644
--- a/networks/lora.py
+++ b/networks/lora.py
@@ -10,7 +10,6 @@ import numpy as np
 import torch
 import re
 
-from library import train_util
 
 RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
 
@@ -61,8 +60,6 @@ class LoRAModule(torch.nn.Module):
 
         self.multiplier = multiplier
         self.org_module = org_module  # remove in applying
-        self.region = None
-        self.region_mask = None
 
     def apply_to(self):
         self.org_forward = self.org_module.forward
@@ -105,39 +102,187 @@ class LoRAModule(torch.nn.Module):
         self.region_mask = None
 
     def forward(self, x):
-        if self.region is None:
-            return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
+        return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
 
-        # regional LoRA FIXME same as additional-network extension
-        if x.size()[1] % 77 == 0:
-            # print(f"LoRA for context: {self.lora_name}")
-            self.region = None
-            return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
 
-        # calculate region mask first time
-        if self.region_mask is None:
-            if len(x.size()) == 4:
-                h, w = x.size()[2:4]
-            else:
-                seq_len = x.size()[1]
-                ratio = math.sqrt((self.region.size()[0] * self.region.size()[1]) / seq_len)
-                h = int(self.region.size()[0] / ratio + 0.5)
-                w = seq_len // h
+class LoRAInfModule(LoRAModule):
+    def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
+        super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
 
-            r = self.region.to(x.device)
-            if r.dtype == torch.bfloat16:
-                r = r.to(torch.float)
-            r = r.unsqueeze(0).unsqueeze(1)
-            # print(self.lora_name, self.region.size(), x.size(), r.size(), h, w)
-            r = torch.nn.functional.interpolate(r, (h, w), mode="bilinear")
-            r = r.to(x.dtype)
+        # check regional or not by lora_name
+        self.text_encoder = False
+        if lora_name.startswith("lora_te_"):
+            self.regional = False
+            self.use_sub_prompt = True
+            self.text_encoder = True
+        elif "attn2_to_k" in lora_name or "attn2_to_v" in lora_name:
+            self.regional = False
+            self.use_sub_prompt = True
+        elif "time_emb" in lora_name:
+            self.regional = False
+            self.use_sub_prompt = False
+        else:
+            self.regional = True
+            self.use_sub_prompt = False
 
-            if len(x.size()) == 3:
-                r = torch.reshape(r, (1, x.size()[1], -1))
+        self.network: LoRANetwork = None
 
-            self.region_mask = r
+    def set_network(self, network):
+        self.network = network
 
-        return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale * self.region_mask
+    def default_forward(self, x):
+        # print("default_forward", self.lora_name, x.size())
+        return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
+
+    def forward(self, x):
+        if self.network is None or self.network.sub_prompt_index is None:
+            return self.default_forward(x)
+        if not self.regional and not self.use_sub_prompt:
+            return self.default_forward(x)
+
+        if self.regional:
+            return self.regional_forward(x)
+        else:
+            return self.sub_prompt_forward(x)
+
+    def get_mask_for_x(self, x):
+        # calculate size from shape of x
+        if len(x.size()) == 4:
+            h, w = x.size()[2:4]
+            area = h * w
+        else:
+            area = x.size()[1]
+
+        mask = self.network.mask_dic[area]
+        if mask is None:
+            raise ValueError(f"mask is None for resolution {area}")
+        if len(x.size()) != 4:
+            mask = torch.reshape(mask, (1, -1, 1))
+        return mask
+
+    def regional_forward(self, x):
+        if "attn2_to_out" in self.lora_name:
+            return self.to_out_forward(x)
+
+        if self.network.mask_dic is None:  # sub_prompt_index >= 3
+            return self.default_forward(x)
+
+        # apply mask for LoRA result
+        lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
+        mask = self.get_mask_for_x(lx)
+        # print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size())
+        lx = lx * mask
+
+        x = self.org_forward(x)
+        x = x + lx
+
+        if "attn2_to_q" in self.lora_name and self.network.is_last_network:
+            x = self.postp_to_q(x)
+
+        return x
+
+    def postp_to_q(self, x):
+        # repeat x to num_sub_prompts
+        has_real_uncond = x.size()[0] // self.network.batch_size == 3
+        qc = self.network.batch_size  # uncond
+        qc += self.network.batch_size * self.network.num_sub_prompts  # cond
+        if has_real_uncond:
+            qc += self.network.batch_size  # real_uncond
+
+        query = torch.zeros((qc, x.size()[1], x.size()[2]), device=x.device, dtype=x.dtype)
+        query[: self.network.batch_size] = x[: self.network.batch_size]
+
+        for i in range(self.network.batch_size):
+            qi = self.network.batch_size + i * self.network.num_sub_prompts
+            query[qi : qi + self.network.num_sub_prompts] = x[self.network.batch_size + i]
+
+        if has_real_uncond:
+            query[-self.network.batch_size :] = x[-self.network.batch_size :]
+
+        # print("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts)
+        return query
+
+    def sub_prompt_forward(self, x):
+        if x.size()[0] == self.network.batch_size:  # if uncond in text_encoder, do not apply LoRA
+            return self.org_forward(x)
+
+        emb_idx = self.network.sub_prompt_index
+        if not self.text_encoder:
+            emb_idx += self.network.batch_size
+
+        # apply sub prompt of X
+        lx = x[emb_idx :: self.network.num_sub_prompts]
+        lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale
+
+        # print("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx)
+
+        x = self.org_forward(x)
+        x[emb_idx :: self.network.num_sub_prompts] += lx
+
+        return x
+
+    def to_out_forward(self, x):
+        # print("to_out_forward", self.lora_name, x.size(), self.network.is_last_network)
+
+        if self.network.is_last_network:
+            masks = [None] * self.network.num_sub_prompts
+            self.network.shared[self.lora_name] = (None, masks)
+        else:
+            lx, masks = self.network.shared[self.lora_name]
+
+        # call own LoRA
+        x1 = x[self.network.batch_size + self.network.sub_prompt_index :: self.network.num_sub_prompts]
+        lx1 = self.lora_up(self.lora_down(x1)) * self.multiplier * self.scale
+
+        if self.network.is_last_network:
+            lx = torch.zeros(
+                (self.network.num_sub_prompts * self.network.batch_size, *lx1.size()[1:]), device=lx1.device, dtype=lx1.dtype
+            )
+            self.network.shared[self.lora_name] = (lx, masks)
+
+        # print("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts)
print("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts) + lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1 + masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1) + + # if not last network, return x and masks + x = self.org_forward(x) + if not self.network.is_last_network: + return x + + lx, masks = self.network.shared.pop(self.lora_name) + + # if last network, combine separated x with mask weighted sum + has_real_uncond = x.size()[0] // self.network.batch_size == self.network.num_sub_prompts + 2 + + out = torch.zeros((self.network.batch_size * (3 if has_real_uncond else 2), *x.size()[1:]), device=x.device, dtype=x.dtype) + out[: self.network.batch_size] = x[: self.network.batch_size] # uncond + if has_real_uncond: + out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond + + # print("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts) + # for i in range(len(masks)): + # if masks[i] is None: + # masks[i] = torch.zeros_like(masks[-1]) + + mask = torch.cat(masks) + mask_sum = torch.sum(mask, dim=0) + 1e-4 + for i in range(self.network.batch_size): + # 1枚の画像ごとに処理する + lx1 = lx[i * self.network.num_sub_prompts : (i + 1) * self.network.num_sub_prompts] + lx1 = lx1 * mask + lx1 = torch.sum(lx1, dim=0) + + xi = self.network.batch_size + i * self.network.num_sub_prompts + x1 = x[xi : xi + self.network.num_sub_prompts] + x1 = x1 * mask + x1 = torch.sum(x1, dim=0) + x1 = x1 / mask_sum + + x1 = x1 + lx1 + out[self.network.batch_size + i] = x1 + + # print("to_out_forward", x.size(), out.size(), has_real_uncond) + return out def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs): @@ -421,7 +566,7 @@ def get_block_index(lora_name: str) -> int: # Create network from weights for inference, weights are not loaded here (because can be merged) -def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, **kwargs): +def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs): if weights_sd is None: if os.path.splitext(file)[1] == ".safetensors": from safetensors.torch import load_file, safe_open @@ -450,7 +595,11 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh if key not in modules_alpha: modules_alpha = modules_dim[key] - network = LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha) + module_class = LoRAInfModule if for_inference else LoRAModule + + network = LoRANetwork( + text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class + ) return network, weights_sd @@ -479,6 +628,7 @@ class LoRANetwork(torch.nn.Module): conv_block_alphas=None, modules_dim=None, modules_alpha=None, + module_class=LoRAModule, varbose=False, ) -> None: """ @@ -554,7 +704,7 @@ class LoRANetwork(torch.nn.Module): skipped.append(lora_name) continue - lora = LoRAModule(lora_name, child_module, self.multiplier, dim, alpha) + lora = module_class(lora_name, child_module, self.multiplier, dim, alpha) loras.append(lora) return loras, skipped @@ -570,7 +720,7 @@ class LoRANetwork(torch.nn.Module): print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") skipped = skipped_te + skipped_un - if varbose and len(skipped) > 0: + if varbose and len(skipped) > 0: print( f"because block_lr_weight is 0 or dim 
(rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" ) @@ -600,7 +750,7 @@ class LoRANetwork(torch.nn.Module): weights_sd = load_file(file) else: weights_sd = torch.load(file, map_location="cpu") - + info = self.load_state_dict(weights_sd, False) return info @@ -750,6 +900,7 @@ class LoRANetwork(torch.nn.Module): if os.path.splitext(file)[1] == ".safetensors": from safetensors.torch import save_file + from library import train_util # Precalculate model hashes to save time on indexing if metadata is None: @@ -762,17 +913,45 @@ class LoRANetwork(torch.nn.Module): else: torch.save(state_dict, file) - @staticmethod - def set_regions(networks, image): - image = image.astype(np.float32) / 255.0 - for i, network in enumerate(networks[:3]): - # NOTE: consider averaging overwrapping area - region = image[:, :, i] - if region.max() == 0: - continue - region = torch.tensor(region) - network.set_region(region) + # mask is a tensor with values from 0 to 1 + def set_region(self, sub_prompt_index, is_last_network, mask): + if mask.max() == 0: + mask = torch.ones_like(mask) - def set_region(self, region): - for lora in self.unet_loras: - lora.set_region(region) + self.mask = mask + self.sub_prompt_index = sub_prompt_index + self.is_last_network = is_last_network + + for lora in self.text_encoder_loras + self.unet_loras: + lora.set_network(self) + + def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared): + self.batch_size = batch_size + self.num_sub_prompts = num_sub_prompts + self.current_size = (height, width) + self.shared = shared + + # create masks + mask = self.mask + mask_dic = {} + mask = mask.unsqueeze(0).unsqueeze(1) # b(1),c(1),h,w + ref_weight = self.text_encoder_loras[0].lora_down.weight if self.text_encoder_loras else self.unet_loras[0].lora_down.weight + dtype = ref_weight.dtype + device = ref_weight.device + + def resize_add(mh, mw): + # print(mh, mw, mh * mw) + m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16 + m = m.to(device, dtype=dtype) + mask_dic[mh * mw] = m + + h = height // 8 + w = width // 8 + for _ in range(4): + resize_add(h, w) + if h % 2 == 1 or w % 2 == 1: # add extra shape if h/w is not divisible by 2 + resize_add(h + h % 2, w + w % 2) + h = (h + 1) // 2 + w = (w + 1) // 2 + + self.mask_dic = mask_dic
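
Note: the sketch below is not part of the patch. It is a minimal standalone illustration of the mask pyramid that set_current_generation builds and get_mask_for_x consumes: because the attention modules only see a flattened sequence length, each mask is stored keyed by its area (h * w), one entry per latent resolution the U-Net visits. The helper name build_mask_dic and the 512x512 size are assumptions for this example only.

    import torch


    def build_mask_dic(mask: torch.Tensor, width: int, height: int, dtype=torch.float32):
        # mask: (H, W) tensor with values in [0, 1], as passed to LoRANetwork.set_region
        mask = mask.unsqueeze(0).unsqueeze(1)  # -> (1, 1, H, W) for interpolate
        mask_dic = {}

        def resize_add(mh, mw):
            # bilinear resize to one latent resolution; keyed by area, mirroring get_mask_for_x
            m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear")
            mask_dic[mh * mw] = m.to(dtype)

        h, w = height // 8, width // 8  # SD latent space is 1/8 of pixel space
        for _ in range(4):  # the four U-Net resolution levels
            resize_add(h, w)
            if h % 2 == 1 or w % 2 == 1:  # odd sizes also occur after downsampling
                resize_add(h + h % 2, w + w % 2)
            h, w = (h + 1) // 2, (w + 1) // 2
        return mask_dic


    # e.g. the left half of a 512x512 image assigned to this network's sub prompt
    m = torch.zeros(512, 512)
    m[:, :256] = 1.0
    print(sorted(build_mask_dic(m, 512, 512).keys()))  # [64, 256, 1024, 4096]

A regional LoRAInfModule then multiplies its low-rank delta by the mask looked up for the current area, which is how each network's effect is confined to its sub prompt's region.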