fix ControlNet not working

This commit is contained in:
Kohya S
2023-07-30 14:09:43 +09:00
parent 2a4ae88f18
commit 0eacadfa99

View File

@@ -4,8 +4,7 @@ import cv2
import torch
from safetensors.torch import load_file
from diffusers import UNet2DConditionModel
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
from library.original_unet import UNet2DConditionModel, SampleOutput
import library.model_util as model_util
@@ -235,10 +234,6 @@ def unet_forward(
print("Forward upsample size to force interpolation output size.")
forward_upsample_size = True
# 0. center input if necessary
if unet.config.center_input_sample:
sample = 2 * sample - 1.0
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
@@ -277,7 +272,7 @@ def unet_forward(
# 3. down
down_block_res_samples = (sample,)
for downsample_block in unet.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
if downsample_block.has_cross_attention:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
@@ -321,7 +316,7 @@ def unet_forward(
if not is_final_block and forward_upsample_size:
upsample_size = down_block_res_samples[-1].shape[2:]
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
if upsample_block.has_cross_attention:
sample = upsample_block(
hidden_states=sample,
temb=emb,
@@ -338,4 +333,4 @@ def unet_forward(
sample = unet.conv_act(sample)
sample = unet.conv_out(sample)
return UNet2DConditionOutput(sample=sample)
return SampleOutput(sample=sample)