Merge pull request #799 from kohya-ss/dev

support diffusers' new VAE
This commit is contained in:
Kohya S
2023-09-02 14:56:37 +09:00
committed by GitHub

View File

@@ -22,10 +22,10 @@ import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.modeling_utils import ModelMixin
from diffusers.utils import BaseOutput
from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block, ResnetBlock2D
from diffusers.models.vae import DecoderOutput, Encoder, AutoencoderKLOutput, DiagonalGaussianDistribution
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
from diffusers.models.vae import DecoderOutput, DiagonalGaussianDistribution
from diffusers.models.autoencoder_kl import AutoencoderKLOutput
def slice_h(x, num_slices):
@@ -209,7 +209,7 @@ class SlicingEncoder(nn.Module):
downsample_padding=0,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attn_num_head_channels=None,
attention_head_dim=output_channel,
temb_channels=None,
)
self.down_blocks.append(down_block)
@@ -221,7 +221,7 @@ class SlicingEncoder(nn.Module):
resnet_act_fn=act_fn,
output_scale_factor=1,
resnet_time_scale_shift="default",
attn_num_head_channels=None,
attention_head_dim=block_out_channels[-1],
resnet_groups=norm_num_groups,
temb_channels=None,
)
@@ -381,7 +381,7 @@ class SlicingDecoder(nn.Module):
resnet_act_fn=act_fn,
output_scale_factor=1,
resnet_time_scale_shift="default",
attn_num_head_channels=None,
attention_head_dim=block_out_channels[-1],
resnet_groups=norm_num_groups,
temb_channels=None,
)
@@ -406,7 +406,7 @@ class SlicingDecoder(nn.Module):
resnet_eps=1e-6,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attn_num_head_channels=None,
attention_head_dim=output_channel,
temb_channels=None,
)
self.up_blocks.append(up_block)