mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
Fix emb_dim to work.
This commit is contained in:
@@ -307,6 +307,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
target_replace_modules: List[str],
|
||||
filter: Optional[str] = None,
|
||||
default_dim: Optional[int] = None,
|
||||
include_conv2d_if_filter: bool = False,
|
||||
) -> List[LoRAModule]:
|
||||
prefix = (
|
||||
self.LORA_PREFIX_SD3
|
||||
@@ -332,8 +333,11 @@ class LoRANetwork(torch.nn.Module):
|
||||
lora_name = prefix + "." + (name + "." if name else "") + child_name
|
||||
lora_name = lora_name.replace(".", "_")
|
||||
|
||||
if filter is not None and not filter in lora_name:
|
||||
continue
|
||||
force_incl_conv2d = False
|
||||
if filter is not None:
|
||||
if not filter in lora_name:
|
||||
continue
|
||||
force_incl_conv2d = include_conv2d_if_filter
|
||||
|
||||
dim = None
|
||||
alpha = None
|
||||
@@ -373,6 +377,10 @@ class LoRANetwork(torch.nn.Module):
|
||||
elif self.conv_lora_dim is not None:
|
||||
dim = self.conv_lora_dim
|
||||
alpha = self.conv_alpha
|
||||
elif force_incl_conv2d:
|
||||
# x_embedder
|
||||
dim = default_dim if default_dim is not None else self.lora_dim
|
||||
alpha = self.alpha
|
||||
|
||||
if dim is None or dim == 0:
|
||||
# skipした情報を出力
|
||||
@@ -428,7 +436,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
for filter, in_dim in zip(
|
||||
[
|
||||
"context_embedder",
|
||||
"_t_embedder", # don't use "t_embedder" because it's used in "context_embedder"
|
||||
"_t_embedder", # don't use "t_embedder" because it's used in "context_embedder"
|
||||
"x_embedder",
|
||||
"y_embedder",
|
||||
"final_layer_adaLN_modulation",
|
||||
@@ -436,7 +444,12 @@ class LoRANetwork(torch.nn.Module):
|
||||
],
|
||||
self.emb_dims,
|
||||
):
|
||||
loras, _ = create_modules(True, None, unet, None, filter=filter, default_dim=in_dim)
|
||||
# x_embedder is conv2d, so we need to include it
|
||||
loras, _ = create_modules(
|
||||
True, None, unet, None, filter=filter, default_dim=in_dim, include_conv2d_if_filter=filter == "x_embedder"
|
||||
)
|
||||
# if len(loras) > 0:
|
||||
# logger.info(f"create LoRA for {filter}: {len(loras)} modules.")
|
||||
self.unet_loras.extend(loras)
|
||||
|
||||
logger.info(f"create LoRA for SD3 MMDiT: {len(self.unet_loras)} modules.")
|
||||
|
||||
Reference in New Issue
Block a user