From cebee0269834217c9f271224b2d355b65afd9d4f Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 13 Feb 2023 23:38:38 +0900 Subject: [PATCH] Official weights to LoRA --- gen_img_diffusers_control_net.py | 42 +-- networks/control_net_lora.py | 415 ++++++++++++++++++++------- networks/extract_control_net_lora.py | 206 +++++++++++++ tools/canny.py | 24 ++ train_control_net.py | 2 +- 5 files changed, 559 insertions(+), 130 deletions(-) create mode 100644 networks/extract_control_net_lora.py create mode 100644 tools/canny.py diff --git a/gen_img_diffusers_control_net.py b/gen_img_diffusers_control_net.py index 53eb50e7..1c873f5e 100644 --- a/gen_img_diffusers_control_net.py +++ b/gen_img_diffusers_control_net.py @@ -826,14 +826,14 @@ class PipelineLike(): if isinstance(mask_image[0], PIL.Image.Image): mask_image = torch.cat([preprocess_mask(im) for im in mask_image]) # H*W, 0 for repaint - # encode the init image into latents and scale the latents - init_image = init_image.to(device=self.device, dtype=latents_dtype) - init_latent_dist = self.vae.encode(init_image).latent_dist - init_latents = init_latent_dist.sample(generator=generator) - init_latents = 0.18215 * init_latents - if len(init_latents) == 1: - init_latents = init_latents.repeat((batch_size, 1, 1, 1)) - init_latents_orig = init_latents + # # encode the init image into latents and scale the latents + # init_image = init_image.to(device=self.device, dtype=latents_dtype) + # init_latent_dist = self.vae.encode(init_image).latent_dist + # init_latents = init_latent_dist.sample(generator=generator) + # init_latents = 0.18215 * init_latents + # if len(init_latents) == 1: + # init_latents = init_latents.repeat((batch_size, 1, 1, 1)) + # init_latents_orig = init_latents # # preprocess mask # if mask_image is not None: @@ -846,7 +846,8 @@ class PipelineLike(): # raise ValueError("The mask and init_image should be the same size!") # init imageをhintとして使う - hint_latents = init_latents + hint = init_image + # hint_latents = init_latents # org_dtype = init_image.dtype # hint = torch.nn.functional.interpolate(init_image.to(torch.float32), scale_factor=(1/8, 1/8), mode="bilinear") # hint = hint[:, 0].unsqueeze(1) # RGB -> BW @@ -876,7 +877,7 @@ class PipelineLike(): if accepts_eta: extra_step_kwargs["eta"] = eta - hint_latents = torch.cat([hint_latents, hint_latents]) + # hint_latents = torch.cat([hint_latents, hint_latents]) num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1 for i, t in enumerate(tqdm(timesteps)): @@ -885,11 +886,9 @@ class PipelineLike(): latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual - self.lora_network.set_as_control_path(True) - # self.unet(latent_model_input * hint, t, encoder_hidden_states=text_embeddings).sample - self.unet(hint_latents, t, encoder_hidden_states=text_embeddings) - self.lora_network.set_as_control_path(False) - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + noise_pred = self.lora_network.call_unet(self.unet, hint, latent_model_input, t, encoder_hidden_states=text_embeddings)[0] # .sample + + # noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample # perform guidance if do_classifier_free_guidance: @@ -1812,7 +1811,8 @@ def preprocess_image(image): image = np.array(image).astype(np.float32) / 255.0 image = image[None].transpose(0, 3, 1, 2) image = torch.from_numpy(image) - return 2.0 * image - 1.0 + # return 2.0 * image - 1.0 + return image # ControlNet def preprocess_mask(mask): @@ -2016,6 +2016,7 @@ def main(args): if args.network_module: networks = [] for i, network_module in enumerate(args.network_module): + # control_net_lora固定なのでimportする必要はないがとりあえず print("import network module:", network_module) imported_module = importlib.import_module(network_module) @@ -2040,14 +2041,19 @@ def main(args): metadata = f.metadata() if metadata is not None: print(f"metadata for: {network_weight}: {metadata}") + + from safetensors.torch import load_file + sd = load_file(network_weight) - network = imported_module.create_network_from_weights(network_mul, network_weight, vae, text_encoder, unet, **net_kwargs) + network = imported_module.ControlLoRANetwork(unet, sd, network_mul) else: raise ValueError("No weight. Weight is required.") if network is None: return - network.apply_to(text_encoder, unet) + network.apply_to() # text_encoder, unet) + info = network.load_state_dict(sd) + print(f"loading network: {info}") if args.opt_channels_last: network.to(memory_format=torch.channels_last) diff --git a/networks/control_net_lora.py b/networks/control_net_lora.py index 58771922..6a98260e 100644 --- a/networks/control_net_lora.py +++ b/networks/control_net_lora.py @@ -7,11 +7,12 @@ import math import os from typing import List import torch +from diffusers import UNet2DConditionModel from library import train_util -class LoRAModule(torch.nn.Module): +class ControlLoRAModule(torch.nn.Module): """ replaces forward method of the original Linear, instead of replacing the original Linear module. """ @@ -25,17 +26,25 @@ class LoRAModule(torch.nn.Module): if org_module.__class__.__name__ == 'Conv2d': in_dim = org_module.in_channels out_dim = org_module.out_channels - self.lora_down = torch.nn.Conv2d(in_dim, lora_dim, (1, 1), bias=False) - self.lora_up = torch.nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False) + + self.lora_dim = min(self.lora_dim, in_dim, out_dim) + if self.lora_dim != lora_dim: + print(f"{lora_name} dim (rank) is changed: {self.lora_dim}") + + kernel_size = org_module.kernel_size + stride = org_module.stride + padding = org_module.padding + self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) + self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) else: in_dim = org_module.in_features out_dim = org_module.out_features - self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False) - self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False) + self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) + self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) if type(alpha) == torch.Tensor: alpha = alpha.detach().float().numpy() # without casting, bf16 causes error - alpha = lora_dim if alpha is None or alpha == 0 else alpha + alpha = self.lora_dim if alpha is None or alpha == 0 else alpha self.scale = alpha / self.lora_dim self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える @@ -55,138 +64,322 @@ class LoRAModule(torch.nn.Module): self.is_control_path = control_path def forward(self, x): - if self.is_control_path: - lora_x = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale - self.previous_lora_x = lora_x - else: - lora_x = self.previous_lora_x - del self.previous_lora_x - return self.org_forward(x) + lora_x + if not self.is_control_path: + return self.org_forward(x) + return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale -def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs): - if network_dim is None: - network_dim = 4 # default - network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha) - return network - - -def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwargs): - if os.path.splitext(file)[1] == '.safetensors': - from safetensors.torch import load_file, safe_open - weights_sd = load_file(file) - else: - weights_sd = torch.load(file, map_location='cpu') - - # get dim (rank) - network_alpha = None - network_dim = None - for key, value in weights_sd.items(): - if network_alpha is None and 'alpha' in key: - network_alpha = value - if network_dim is None and 'lora_down' in key and len(value.size()) == 2: - network_dim = value.size()[0] - - if network_alpha is None: - network_alpha = network_dim - - network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha) - network.weights_sd = weights_sd - return network - - -class LoRANetwork(torch.nn.Module): - UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"] - TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] +class ControlLoRANetwork(torch.nn.Module): + # UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"] + # TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] LORA_PREFIX_UNET = 'lora_unet' LORA_PREFIX_TEXT_ENCODER = 'lora_te' - def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1) -> None: + def __init__(self, unet, weights_sd, multiplier=1.0, lora_dim=4, alpha=1) -> None: super().__init__() self.multiplier = multiplier self.lora_dim = lora_dim self.alpha = alpha # create module instances - def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]: + def create_modules(prefix, root_module: torch.nn.Module) -> List[ControlLoRAModule]: # , target_replace_modules loras = [] for name, module in root_module.named_modules(): - if module.__class__.__name__ in target_replace_modules: - for child_name, child_module in module.named_modules(): - if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)): - lora_name = prefix + '.' + name + '.' + child_name - lora_name = lora_name.replace('.', '_') - lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim, self.alpha) - loras.append(lora) + # # if module.__class__.__name__ in target_replace_modules: + # for child_name, child_module in module.named_modules(): + if module.__class__.__name__ == "Linear" or module.__class__.__name__ == "Conv2d": # and module.kernel_size == (1, 1)): + lora_name = prefix + '.' + name # + '.' + child_name + lora_name = lora_name.replace('.', '_') + + if weights_sd is None: + dim, alpha = self.lora_dim, self.alpha + else: + down_weight = weights_sd.get(lora_name + ".lora_down.weight", None) + if down_weight is None: + continue + dim = down_weight.size()[0] + alpha = weights_sd.get(lora_name + ".alpha", dim) + + lora = ControlLoRAModule(lora_name, module, self.multiplier, dim, alpha) + loras.append(lora) return loras - self.text_encoder_loras = create_modules(LoRANetwork.LORA_PREFIX_TEXT_ENCODER, - text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) - print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") - - self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, LoRANetwork.UNET_TARGET_REPLACE_MODULE) + self.unet_loras = create_modules(ControlLoRANetwork.LORA_PREFIX_UNET, unet) # , LoRANetwork.UNET_TARGET_REPLACE_MODULE) print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") - self.weights_sd = None + # make control model + self.control_model = torch.nn.Module() - # assertion - names = set() - for lora in self.text_encoder_loras + self.unet_loras: - assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" - names.add(lora.lora_name) + dims = [320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280, 1280] + zero_convs = torch.nn.ModuleList() + for i, dim in enumerate(dims): + sub_list = torch.nn.ModuleList([torch.nn.Conv2d(dim, dim, 1)]) + zero_convs.append(sub_list) + self.control_model.add_module("zero_convs", zero_convs) - def load_weights(self, file): - if os.path.splitext(file)[1] == '.safetensors': - from safetensors.torch import load_file, safe_open - self.weights_sd = load_file(file) - else: - self.weights_sd = torch.load(file, map_location='cpu') + middle_block_out = torch.nn.Conv2d(1280, 1280, 1) + self.control_model.add_module("middle_block_out", torch.nn.ModuleList([middle_block_out])) - def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None): - if self.weights_sd: - weights_has_text_encoder = weights_has_unet = False - for key in self.weights_sd.keys(): - if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER): - weights_has_text_encoder = True - elif key.startswith(LoRANetwork.LORA_PREFIX_UNET): - weights_has_unet = True + dims = [16, 16, 32, 32, 96, 96, 256, 320] + strides = [1, 1, 2, 1, 2, 1, 2, 1] + prev_dim = 3 + input_hint_block = torch.nn.Sequential() + for i, (dim, stride) in enumerate(zip(dims, strides)): + input_hint_block.append(torch.nn.Conv2d(prev_dim, dim, 3, stride, 1)) + if i < len(dims) - 1: + input_hint_block.append(torch.nn.SiLU()) + prev_dim = dim + self.control_model.add_module("input_hint_block", input_hint_block) - if apply_text_encoder is None: - apply_text_encoder = weights_has_text_encoder - else: - assert apply_text_encoder == weights_has_text_encoder, f"text encoder weights: {weights_has_text_encoder} but text encoder flag: {apply_text_encoder} / 重みとText Encoderのフラグが矛盾しています" - if apply_unet is None: - apply_unet = weights_has_unet - else: - assert apply_unet == weights_has_unet, f"u-net weights: {weights_has_unet} but u-net flag: {apply_unet} / 重みとU-Netのフラグが矛盾しています" - else: - assert apply_text_encoder is not None and apply_unet is not None, f"internal error: flag not set" + # def load_weights(self, file): + # if os.path.splitext(file)[1] == '.safetensors': + # from safetensors.torch import load_file, safe_open + # self.weights_sd = load_file(file) + # else: + # self.weights_sd = torch.load(file, map_location='cpu') - assert not apply_text_encoder, "ControlNet does not support for text encoder" - - if apply_text_encoder: - print("enable LoRA for text encoder") - else: - self.text_encoder_loras = [] - - if apply_unet: - print("enable LoRA for U-Net") - else: - self.unet_loras = [] - - for lora in self.text_encoder_loras + self.unet_loras: + def apply_to(self): + for lora in self.unet_loras: lora.apply_to() self.add_module(lora.lora_name, lora) - if self.weights_sd: - # if some weights are not in state dict, it is ok because initial LoRA does nothing (lora_up is initialized by zeros) - info = self.load_state_dict(self.weights_sd, False) - print(f"weights are loaded: {info}") + def call_unet(self, unet, hint, sample, timestep, encoder_hidden_states): + # control path + hint = hint.to(sample.dtype).to(sample.device) + guided_hint = self.control_model.input_hint_block(hint) - def set_as_control_path(self, control_path): - for lora in self.text_encoder_loras + self.unet_loras: - lora.set_as_control_path(control_path) + for lora_module in self.unet_loras: + lora_module.set_as_control_path(True) + + outs = self.unet_forward(unet, guided_hint, None, sample, timestep, encoder_hidden_states) + + # U-Net + for lora_module in self.unet_loras: + lora_module.set_as_control_path(False) + + sample = self.unet_forward(unet, None, outs, sample, timestep, encoder_hidden_states) + + return sample + + def unet_forward(self, unet: UNet2DConditionModel, guided_hint, ctrl_outs, sample, timestep, encoder_hidden_states): + # copy from UNet2DConditionModel + default_overall_up_factor = 2**unet.num_upsamplers + + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + print("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # 0. center input if necessary + if unet.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = unet.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=unet.dtype) + emb = unet.time_embedding(t_emb) + + if ctrl_outs is None: + outs = [] # control path + + # 2. pre-process + sample = unet.conv_in(sample) + if guided_hint is not None: + sample += guided_hint + if ctrl_outs is None: + outs.append(self.control_model.zero_convs[0][0](sample)) # , emb, encoder_hidden_states)) + + # 3. down + zc_idx = 1 + down_block_res_samples = (sample,) + for downsample_block in unet.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + if ctrl_outs is None: + for rs in res_samples: + print("zc", zc_idx, rs.size()) + outs.append(self.control_model.zero_convs[zc_idx][0](rs)) # , emb, encoder_hidden_states)) + zc_idx += 1 + + down_block_res_samples += res_samples + + # 4. mid + sample = unet.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states) + if ctrl_outs is None: + outs.append(self.control_model.middle_block_out[0](sample)) + return outs + if ctrl_outs is not None: + sample += ctrl_outs.pop() + + # 5. up + for i, upsample_block in enumerate(unet.up_blocks): + is_final_block = i == len(unet.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets):] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + if ctrl_outs is not None and len(ctrl_outs) > 0: + res_samples = list(res_samples) + apply_ctrl_outs = ctrl_outs[-len(res_samples):] + ctrl_outs = ctrl_outs[:-len(res_samples)] + for j in range(len(res_samples)): + print(i, j) + res_samples[j] = res_samples[j] + apply_ctrl_outs[j] + res_samples = tuple(res_samples) + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + # 6. post-process + sample = unet.conv_norm_out(sample) + sample = unet.conv_act(sample) + sample = unet.conv_out(sample) + + return (sample,) + + """ + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) + + if self.config.num_class_embeds is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + + # 2. pre-process + sample = self.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. mid + sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states) + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + # 6. post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + + return UNet2DConditionOutput(sample=sample) + """ def enable_gradient_checkpointing(self): # not supported diff --git a/networks/extract_control_net_lora.py b/networks/extract_control_net_lora.py new file mode 100644 index 00000000..d2c460ba --- /dev/null +++ b/networks/extract_control_net_lora.py @@ -0,0 +1,206 @@ +# extract approximating LoRA by svd from SD 1.5 vs ControlNet +# https://github.com/lllyasviel/ControlNet/blob/main/tool_transfer_control.py +# +# The code is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py +# Thanks to cloneofsimo! + +import argparse +import os +import torch +from safetensors.torch import load_file, save_file +from tqdm import tqdm +from diffusers import UNet2DConditionModel + +import library.model_util as model_util +import control_net_lora + + +CLAMP_QUANTILE = 0.99 +MIN_DIFF = 1e-6 + + +def save_to_file(file_name, model, state_dict, dtype): + if dtype is not None: + for key in list(state_dict.keys()): + if type(state_dict[key]) == torch.Tensor: + state_dict[key] = state_dict[key].to(dtype) + + if os.path.splitext(file_name)[1] == '.safetensors': + save_file(model, file_name) + else: + torch.save(model, file_name) + + +def svd(args): + def str_to_dtype(p): + if p == 'float': + return torch.float + if p == 'fp16': + return torch.float16 + if p == 'bf16': + return torch.bfloat16 + return None + + save_dtype = str_to_dtype(args.save_precision) + + # Diffusersのキーに変換するため、original sdとcontrol sdからU-Netに重みを読み込む ############### + + # original sdをDiffusersに読み込む + print(f"loading original SD model : {args.model_org}") + org_text_encoder, _, org_unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org) + + org_sd = torch.load(args.model_org, map_location='cpu') + if 'state_dict' in org_sd: + org_sd = org_sd['state_dict'] + + # control sdからキー変換しつつU-Netに対応する部分のみ取り出す + print(f"loading control SD model : {args.model_tuned}") + + ctrl_sd = torch.load(args.model_tuned, map_location='cpu') + ctrl_unet_sd = org_sd # あらかじめloadしておくことでcontrol sdにない部分はoriginal sdと同じにする + for key in list(ctrl_sd.keys()): + if key.startswith("control_"): + unet_key = "model.diffusion_" + key[len("control_"):] + if unet_key not in ctrl_unet_sd: # zero conv + continue + ctrl_unet_sd[unet_key] = ctrl_sd[key] + + unet_config = model_util.create_unet_diffusers_config(False) + ctrl_unet_sd_du = model_util.convert_ldm_unet_checkpoint(False, ctrl_unet_sd, unet_config) + + # load weights to U-Net + ctrl_unet = UNet2DConditionModel(**unet_config) + info = ctrl_unet.load_state_dict(ctrl_unet_sd_du) + print("loading control u-net:", info) + + # LoRAに対応する部分のU-Netの重みを読み込む ################################# + + org_unet_sd_du = org_unet.state_dict() + + diffs = {} + for (org_name, org_module), (ctrl_name, ctrl_module) in zip(org_unet.named_modules(), ctrl_unet.named_modules()): + if org_module.__class__.__name__ != "Linear" and org_module.__class__.__name__ != "Conv2d": + continue + assert org_name == ctrl_name + + lora_name = control_net_lora.ControlLoRANetwork.LORA_PREFIX_UNET + '.' + org_name # + '.' + child_name + lora_name = lora_name.replace('.', '_') + + diff = ctrl_module.weight - org_module.weight + diff = diff.float() + + if torch.max(torch.abs(diff)) < 1e-5: + # print(f"weights are same: {lora_name}") + continue + print(lora_name) + + if args.device: + diff = diff.to(args.device) + + diffs[lora_name] = diff + + # make LoRA with svd + print("calculating by svd") + rank = args.dim + ctrl_lora_sd = {} + with torch.no_grad(): + for lora_name, mat in tqdm(list(diffs.items())): + conv2d = (len(mat.size()) == 4) + kernel_size = None if not conv2d else mat.size()[2:] + + if not conv2d or kernel_size == (1, 1): + if conv2d: + mat = mat.squeeze() + + U, S, Vh = torch.linalg.svd(mat) + + U = U[:, :rank] + S = S[:rank] + U = U @ torch.diag(S) + + Vh = Vh[:rank, :] + + dist = torch.cat([U.flatten(), Vh.flatten()]) + hi_val = torch.quantile(dist, CLAMP_QUANTILE) + low_val = -hi_val + + U = U.clamp(low_val, hi_val) + Vh = Vh.clamp(low_val, hi_val) + + if conv2d: + U = U.unsqueeze(2).unsqueeze(3) + Vh = Vh.unsqueeze(2).unsqueeze(3) + else: + # conv2d kernel != (1,1) + in_channels = mat.size()[1] + current_rank = min(rank, in_channels, mat.size()[0]) + if current_rank != rank: + print(f"channels of conv2d is too small. rank is changed to {current_rank} @ {lora_name}: {mat.size()}") + + mat = mat.flatten(start_dim=1) + + U, S, Vh = torch.linalg.svd(mat) + + U = U[:, :current_rank] + S = S[:current_rank] + U = U @ torch.diag(S) + + Vh = Vh[:current_rank, :] + + dist = torch.cat([U.flatten(), Vh.flatten()]) + hi_val = torch.quantile(dist, CLAMP_QUANTILE) + low_val = -hi_val + + U = U.clamp(low_val, hi_val) + Vh = Vh.clamp(low_val, hi_val) + + # U is (out_channels, rank) with 1x1 conv. So, + U = U.reshape(U.shape[0], U.shape[1], 1, 1) + # V is (rank, in_channels * kernel_size1 * kernel_size2) + # now reshape: + Vh = Vh.reshape(Vh.shape[0], in_channels, *kernel_size) + + ctrl_lora_sd[lora_name + ".lora_up.weight"] = U + ctrl_lora_sd[lora_name + ".lora_down.weight"] = Vh + ctrl_lora_sd[lora_name + ".alpha"] = torch.tensor(current_rank) + + # create LoRA from sd + lora_network = control_net_lora.ControlLoRANetwork(org_unet, ctrl_lora_sd, 1.0) + lora_network.apply_to() + + for key, value in ctrl_sd.items(): + if 'zero_convs' in key or 'input_hint_block' in key or 'middle_block_out' in key: + ctrl_lora_sd[key] = value + + info = lora_network.load_state_dict(ctrl_lora_sd) + print(f"loading control lora sd: {info}") + + dir_name = os.path.dirname(args.save_to) + if dir_name and not os.path.exists(dir_name): + os.makedirs(dir_name, exist_ok=True) + + # # minimum metadata + # metadata = {"ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)} + + # lora_network.save_weights(args.save_to, save_dtype, metadata) + save_file(ctrl_lora_sd, args.save_to) + print(f"LoRA weights are saved to: {args.save_to}") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--v2", action='store_true', + help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む') + parser.add_argument("--save_precision", type=str, default=None, + choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はfloat") + parser.add_argument("--model_org", type=str, default=None, + help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors") + parser.add_argument("--model_tuned", type=str, default=None, + help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors") + parser.add_argument("--save_to", type=str, default=None, + help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors") + parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)") + parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") + + args = parser.parse_args() + svd(args) diff --git a/tools/canny.py b/tools/canny.py new file mode 100644 index 00000000..2f01bbf9 --- /dev/null +++ b/tools/canny.py @@ -0,0 +1,24 @@ +import argparse +import cv2 + + +def canny(args): + img = cv2.imread(args.input) + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + + canny_img = cv2.Canny(img, args.thres1, args.thres2) + # canny_img = 255 - canny_img + + cv2.imwrite(args.output, canny_img) + print("done!") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--input", type=str, default=None, help="input path") + parser.add_argument("--output", type=str, default=None, help="output path") + parser.add_argument("--thres1", type=int, default=32, help="thres1") + parser.add_argument("--thres2", type=int, default=224, help="thres2") + + args = parser.parse_args() + canny(args) diff --git a/train_control_net.py b/train_control_net.py index 74d384bf..465d032c 100644 --- a/train_control_net.py +++ b/train_control_net.py @@ -380,7 +380,7 @@ def train(args): net_kwargs[key] = value # if a new network is added in future, add if ~ then blocks for each network (;'∀') - network: control_net_rola.LoRANetwork = network_module.create_network( + network: control_net_rola.ControlLoRANetwork = network_module.create_network( 1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs) if network is None: return