From c06a86706a3bcd8ff0b5b8c816004c3c93e48e00 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 2 Sep 2023 14:54:42 +0900 Subject: [PATCH] support diffusers' new VAE --- library/slicing_vae.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/library/slicing_vae.py b/library/slicing_vae.py index 490b5a75..31b2bd0a 100644 --- a/library/slicing_vae.py +++ b/library/slicing_vae.py @@ -22,10 +22,10 @@ import torch.nn as nn from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.modeling_utils import ModelMixin -from diffusers.utils import BaseOutput -from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block, ResnetBlock2D -from diffusers.models.vae import DecoderOutput, Encoder, AutoencoderKLOutput, DiagonalGaussianDistribution +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block +from diffusers.models.vae import DecoderOutput, DiagonalGaussianDistribution +from diffusers.models.autoencoder_kl import AutoencoderKLOutput def slice_h(x, num_slices): @@ -209,7 +209,7 @@ class SlicingEncoder(nn.Module): downsample_padding=0, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, - attn_num_head_channels=None, + attention_head_dim=output_channel, temb_channels=None, ) self.down_blocks.append(down_block) @@ -221,7 +221,7 @@ class SlicingEncoder(nn.Module): resnet_act_fn=act_fn, output_scale_factor=1, resnet_time_scale_shift="default", - attn_num_head_channels=None, + attention_head_dim=block_out_channels[-1], resnet_groups=norm_num_groups, temb_channels=None, ) @@ -381,7 +381,7 @@ class SlicingDecoder(nn.Module): resnet_act_fn=act_fn, output_scale_factor=1, resnet_time_scale_shift="default", - attn_num_head_channels=None, + attention_head_dim=block_out_channels[-1], resnet_groups=norm_num_groups, temb_channels=None, ) @@ -406,7 +406,7 @@ class SlicingDecoder(nn.Module): resnet_eps=1e-6, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, - attn_num_head_channels=None, + attention_head_dim=output_channel, temb_channels=None, ) self.up_blocks.append(up_block)