Kohya-ss-sd-scripts/library/original_unet.py (2023-06-24 09:35:33 +09:00)
# Pulls in only the parts of Diffusers 0.10.2 that are needed for Stable Diffusion.
# Parts made unnecessary by conditional branching etc. have been removed.
# Most of the code is copied from Diffusers.
# As a constraint, the model's state_dict must be in the same format as that of Diffusers 0.10.2.
"""
Differences between v1.5 and v2.1:
- attention_head_dim is an int (1.5) or a list[int] (2.1)
- cross_attention_dim is 768 (1.5) or 1024 (2.1)
- use_linear_projection is absent (= False, 1.5) or present and true (2.1)
- upcast_attention is False (1.5) or True (2.1)
- (the rest can probably be ignored)
- sample_size is 64 or 96
- dual_cross_attention present or absent
- num_class_embeds present or absent
- only_cross_attention present or absent
v1.5
{
"_class_name": "UNet2DConditionModel",
"_diffusers_version": "0.6.0",
"act_fn": "silu",
"attention_head_dim": 8,
"block_out_channels": [
320,
640,
1280,
1280
],
"center_input_sample": false,
"cross_attention_dim": 768,
"down_block_types": [
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D"
],
"downsample_padding": 1,
"flip_sin_to_cos": true,
"freq_shift": 0,
"in_channels": 4,
"layers_per_block": 2,
"mid_block_scale_factor": 1,
"norm_eps": 1e-05,
"norm_num_groups": 32,
"out_channels": 4,
"sample_size": 64,
"up_block_types": [
"UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D"
]
}
v2.1
{
"_class_name": "UNet2DConditionModel",
"_diffusers_version": "0.10.0.dev0",
"act_fn": "silu",
"attention_head_dim": [
5,
10,
20,
20
],
"block_out_channels": [
320,
640,
1280,
1280
],
"center_input_sample": false,
"cross_attention_dim": 1024,
"down_block_types": [
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D"
],
"downsample_padding": 1,
"dual_cross_attention": false,
"flip_sin_to_cos": true,
"freq_shift": 0,
"in_channels": 4,
"layers_per_block": 2,
"mid_block_scale_factor": 1,
"norm_eps": 1e-05,
"norm_num_groups": 32,
"num_class_embeds": null,
"only_cross_attention": false,
"out_channels": 4,
"sample_size": 96,
"up_block_types": [
"UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D"
],
"use_linear_projection": true,
"upcast_attention": true
}
"""
import math
from types import SimpleNamespace
from typing import Dict, Optional, Tuple, Union
import torch
from torch import nn
from torch.nn import functional as F
from einops import rearrange
BLOCK_OUT_CHANNELS: Tuple[int] = (320, 640, 1280, 1280)
TIMESTEP_INPUT_DIM = BLOCK_OUT_CHANNELS[0]
TIME_EMBED_DIM = BLOCK_OUT_CHANNELS[0] * 4
IN_CHANNELS: int = 4
OUT_CHANNELS: int = 4
LAYERS_PER_BLOCK: int = 2
LAYERS_PER_BLOCK_UP: int = LAYERS_PER_BLOCK + 1
TIME_EMBED_FLIP_SIN_TO_COS: bool = True
TIME_EMBED_FREQ_SHIFT: int = 0
NORM_GROUPS: int = 32
NORM_EPS: float = 1e-5
TRANSFORMER_NORM_NUM_GROUPS = 32
DOWN_BLOCK_TYPES = ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"]
UP_BLOCK_TYPES = ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"]
# region memory efficient attention
# CrossAttention that uses FlashAttention
# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
# constants
EPSILON = 1e-6
# helper functions
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
# flash attention forwards and backwards
# https://arxiv.org/abs/2205.14135
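# A sketch of the running ("online") softmax update that the tiled loops below implement:
# for each new key/value block with block max m_b and block sum l_b,
#   m_new = max(m_old, m_b)
#   l_new = exp(m_old - m_new) * l_old + exp(m_b - m_new) * l_b
#   o_new = (l_old / l_new) * exp(m_old - m_new) * o_old + (exp(m_b - m_new) / l_new) * (exp_weights @ v_block)
# so the full softmax result is recovered without ever materializing the complete attention matrix.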
class FlashAttentionFunction(torch.autograd.Function):
@staticmethod
@torch.no_grad()
def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
"""Algorithm 2 in the paper"""
device = q.device
dtype = q.dtype
max_neg_value = -torch.finfo(q.dtype).max
qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
o = torch.zeros_like(q)
all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)
scale = q.shape[-1] ** -0.5
if not exists(mask):
mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
else:
mask = rearrange(mask, "b n -> b 1 1 n")
mask = mask.split(q_bucket_size, dim=-1)
row_splits = zip(
q.split(q_bucket_size, dim=-2),
o.split(q_bucket_size, dim=-2),
mask,
all_row_sums.split(q_bucket_size, dim=-2),
all_row_maxes.split(q_bucket_size, dim=-2),
)
for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
q_start_index = ind * q_bucket_size - qk_len_diff
col_splits = zip(
k.split(k_bucket_size, dim=-2),
v.split(k_bucket_size, dim=-2),
)
for k_ind, (kc, vc) in enumerate(col_splits):
k_start_index = k_ind * k_bucket_size
attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
if exists(row_mask):
attn_weights.masked_fill_(~row_mask, max_neg_value)
if causal and q_start_index < (k_start_index + k_bucket_size - 1):
causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
q_start_index - k_start_index + 1
)
attn_weights.masked_fill_(causal_mask, max_neg_value)
block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
attn_weights -= block_row_maxes
exp_weights = torch.exp(attn_weights)
if exists(row_mask):
exp_weights.masked_fill_(~row_mask, 0.0)
block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)
new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
exp_values = torch.einsum("... i j, ... j d -> ... i d", exp_weights, vc)
exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums
oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)
row_maxes.copy_(new_row_maxes)
row_sums.copy_(new_row_sums)
ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
return o
@staticmethod
@torch.no_grad()
def backward(ctx, do):
"""Algorithm 4 in the paper"""
causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
q, k, v, o, l, m = ctx.saved_tensors
device = q.device
max_neg_value = -torch.finfo(q.dtype).max
qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
dq = torch.zeros_like(q)
dk = torch.zeros_like(k)
dv = torch.zeros_like(v)
row_splits = zip(
q.split(q_bucket_size, dim=-2),
o.split(q_bucket_size, dim=-2),
do.split(q_bucket_size, dim=-2),
mask,
l.split(q_bucket_size, dim=-2),
m.split(q_bucket_size, dim=-2),
dq.split(q_bucket_size, dim=-2),
)
for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
q_start_index = ind * q_bucket_size - qk_len_diff
col_splits = zip(
k.split(k_bucket_size, dim=-2),
v.split(k_bucket_size, dim=-2),
dk.split(k_bucket_size, dim=-2),
dv.split(k_bucket_size, dim=-2),
)
for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
k_start_index = k_ind * k_bucket_size
attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
if causal and q_start_index < (k_start_index + k_bucket_size - 1):
causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
q_start_index - k_start_index + 1
)
attn_weights.masked_fill_(causal_mask, max_neg_value)
exp_attn_weights = torch.exp(attn_weights - mc)
if exists(row_mask):
exp_attn_weights.masked_fill_(~row_mask, 0.0)
p = exp_attn_weights / lc
dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)
D = (doc * oc).sum(dim=-1, keepdims=True)
ds = p * scale * (dp - D)
dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)
dqc.add_(dq_chunk)
dkc.add_(dk_chunk)
dvc.add_(dv_chunk)
return dq, dk, dv, None, None, None, None
# endregion
def get_parameter_dtype(parameter: torch.nn.Module):
return next(parameter.parameters()).dtype
def get_parameter_device(parameter: torch.nn.Module):
return next(parameter.parameters()).device
def get_timestep_embedding(
timesteps: torch.Tensor,
embedding_dim: int,
flip_sin_to_cos: bool = False,
downscale_freq_shift: float = 1,
scale: float = 1,
max_period: int = 10000,
):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
embeddings. :return: an [N x dim] Tensor of positional embeddings.
"""
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
half_dim = embedding_dim // 2
exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
exponent = exponent / (half_dim - downscale_freq_shift)
emb = torch.exp(exponent)
emb = timesteps[:, None].float() * emb[None, :]
# scale embeddings
emb = scale * emb
# concat sine and cosine embeddings
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
# flip sine and cosine embeddings
if flip_sin_to_cos:
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
# zero pad
if embedding_dim % 2 == 1:
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
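# Usage sketch (assuming the module-level constants defined above):
#   t = torch.tensor([999, 500, 10])
#   emb = get_timestep_embedding(t, TIMESTEP_INPUT_DIM, flip_sin_to_cos=TIME_EMBED_FLIP_SIN_TO_COS,
#                                downscale_freq_shift=TIME_EMBED_FREQ_SHIFT)
#   emb.shape  # -> torch.Size([3, 320])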
class SampleOutput:
def __init__(self, sample):
self.sample = sample
class TimestepEmbedding(nn.Module):
def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None):
super().__init__()
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
self.act = None
if act_fn == "silu":
self.act = nn.SiLU()
elif act_fn == "mish":
self.act = nn.Mish()
if out_dim is not None:
time_embed_dim_out = out_dim
else:
time_embed_dim_out = time_embed_dim
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
def forward(self, sample):
sample = self.linear_1(sample)
if self.act is not None:
sample = self.act(sample)
sample = self.linear_2(sample)
return sample
class Timesteps(nn.Module):
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
def forward(self, timesteps):
t_emb = get_timestep_embedding(
timesteps,
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
)
return t_emb
class ResnetBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.norm1 = torch.nn.GroupNorm(num_groups=NORM_GROUPS, num_channels=in_channels, eps=NORM_EPS, affine=True)
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.time_emb_proj = torch.nn.Linear(TIME_EMBED_DIM, out_channels)
self.norm2 = torch.nn.GroupNorm(num_groups=NORM_GROUPS, num_channels=out_channels, eps=NORM_EPS, affine=True)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
# if non_linearity == "swish":
self.nonlinearity = lambda x: F.silu(x)
self.use_in_shortcut = self.in_channels != self.out_channels
self.conv_shortcut = None
if self.use_in_shortcut:
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, input_tensor, temb):
hidden_states = input_tensor
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
hidden_states = hidden_states + temb
hidden_states = self.norm2(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = input_tensor + hidden_states
return output_tensor
class DownBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
add_downsample=True,
):
super().__init__()
self.has_cross_attention = False
resnets = []
for i in range(LAYERS_PER_BLOCK):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=out_channels,
)
)
self.resnets = nn.ModuleList(resnets)
if add_downsample:
self.downsamplers = nn.ModuleList([Downsample2D(out_channels, out_channels=out_channels)])
else:
self.downsamplers = None
self.gradient_checkpointing = False
def set_use_memory_efficient_attention(self, xformers, mem_eff):
pass
def set_use_sdpa(self, sdpa):
pass
def forward(self, hidden_states, temb=None):
output_states = ()
for resnet in self.resnets:
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
else:
hidden_states = resnet(hidden_states, temb)
output_states += (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
output_states += (hidden_states,)
return hidden_states, output_states
class Downsample2D(nn.Module):
def __init__(self, channels, out_channels):
super().__init__()
self.channels = channels
self.out_channels = out_channels
self.conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=2, padding=1)
def forward(self, hidden_states):
assert hidden_states.shape[1] == self.channels
hidden_states = self.conv(hidden_states)
return hidden_states
class CrossAttention(nn.Module):
def __init__(
self,
query_dim: int,
cross_attention_dim: Optional[int] = None,
heads: int = 8,
dim_head: int = 64,
upcast_attention: bool = False,
):
super().__init__()
inner_dim = dim_head * heads
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.upcast_attention = upcast_attention
self.scale = dim_head**-0.5
self.heads = heads
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False)
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False)
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(inner_dim, query_dim))
# no dropout here
self.use_memory_efficient_attention_xformers = False
self.use_memory_efficient_attention_mem_eff = False
self.use_sdpa = False
def set_use_memory_efficient_attention(self, xformers, mem_eff):
self.use_memory_efficient_attention_xformers = xformers
self.use_memory_efficient_attention_mem_eff = mem_eff
def set_use_sdpa(self, sdpa):
self.use_sdpa = sdpa
def reshape_heads_to_batch_dim(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
return tensor
def reshape_batch_dim_to_heads(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor
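# Shape sketch (assuming batch=2, seq_len=77, heads=8, dim_head=40, so dim=320):
#   reshape_heads_to_batch_dim: (2, 77, 320) -> (16, 77, 40)
#   reshape_batch_dim_to_heads: (16, 77, 40) -> (2, 77, 320)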
def forward(self, hidden_states, context=None, mask=None):
if self.use_memory_efficient_attention_xformers:
return self.forward_memory_efficient_xformers(hidden_states, context, mask)
if self.use_memory_efficient_attention_mem_eff:
return self.forward_memory_efficient_mem_eff(hidden_states, context, mask)
if self.use_sdpa:
return self.forward_sdpa(hidden_states, context, mask)
query = self.to_q(hidden_states)
context = context if context is not None else hidden_states
key = self.to_k(context)
value = self.to_v(context)
query = self.reshape_heads_to_batch_dim(query)
key = self.reshape_heads_to_batch_dim(key)
value = self.reshape_heads_to_batch_dim(value)
hidden_states = self._attention(query, key, value)
# linear proj
hidden_states = self.to_out[0](hidden_states)
# hidden_states = self.to_out[1](hidden_states) # no dropout
return hidden_states
def _attention(self, query, key, value):
if self.upcast_attention:
query = query.float()
key = key.float()
attention_scores = torch.baddbmm(
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
query,
key.transpose(-1, -2),
beta=0,
alpha=self.scale,
)
attention_probs = attention_scores.softmax(dim=-1)
# cast back to the original dtype
attention_probs = attention_probs.to(value.dtype)
# compute attention output
hidden_states = torch.bmm(attention_probs, value)
# reshape hidden_states
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states
# TODO support Hypernetworks
def forward_memory_efficient_xformers(self, x, context=None, mask=None):
import xformers.ops
h = self.heads
q_in = self.to_q(x)
context = context if context is not None else x
context = context.to(x.dtype)
k_in = self.to_k(context)
v_in = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)  # picks the best available kernel automatically
out = rearrange(out, "b n h d -> b n (h d)", h=h)
out = self.to_out[0](out)
return out
def forward_memory_efficient_mem_eff(self, x, context=None, mask=None):
flash_func = FlashAttentionFunction
q_bucket_size = 512
k_bucket_size = 1024
h = self.heads
q = self.to_q(x)
context = context if context is not None else x
context = context.to(x.dtype)
k = self.to_k(context)
v = self.to_v(context)
del context, x
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)
out = rearrange(out, "b h n d -> b n (h d)")
out = self.to_out[0](out)
return out
def forward_sdpa(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
context = context if context is not None else x
context = context.to(x.dtype)
k_in = self.to_k(context)
v_in = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
out = rearrange(out, "b h n d -> b n (h d)", h=h)
out = self.to_out[0](out)
return out
# feedforward
class GEGLU(nn.Module):
r"""
A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
Parameters:
dim_in (`int`): The number of channels in the input.
dim_out (`int`): The number of channels in the output.
"""
def __init__(self, dim_in: int, dim_out: int):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def gelu(self, gate):
if gate.device.type != "mps":
return F.gelu(gate)
# mps: gelu is not implemented for float16
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
def forward(self, hidden_states):
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
return hidden_states * self.gelu(gate)
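# Example (sketch): the projection doubles the output dim and the second half gates the first half:
#   geglu = GEGLU(320, 1280)
#   geglu(torch.randn(2, 77, 320)).shape  # -> torch.Size([2, 77, 1280]) (proj to 2560, chunked into two 1280 halves)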
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
):
super().__init__()
inner_dim = int(dim * 4) # mult is always 4
self.net = nn.ModuleList([])
# project in
self.net.append(GEGLU(dim, inner_dim))
# project dropout
self.net.append(nn.Identity()) # nn.Dropout(0)) # dummy for dropout with 0
# project out
self.net.append(nn.Linear(inner_dim, dim))
def forward(self, hidden_states):
for module in self.net:
hidden_states = module(hidden_states)
return hidden_states
class BasicTransformerBlock(nn.Module):
def __init__(
self, dim: int, num_attention_heads: int, attention_head_dim: int, cross_attention_dim: int, upcast_attention: bool = False
):
super().__init__()
# 1. Self-Attn
self.attn1 = CrossAttention(
query_dim=dim,
cross_attention_dim=None,
heads=num_attention_heads,
dim_head=attention_head_dim,
upcast_attention=upcast_attention,
)
self.ff = FeedForward(dim)
# 2. Cross-Attn
self.attn2 = CrossAttention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
upcast_attention=upcast_attention,
)
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
# 3. Feed-forward
self.norm3 = nn.LayerNorm(dim)
def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool):
self.attn1.set_use_memory_efficient_attention(xformers, mem_eff)
self.attn2.set_use_memory_efficient_attention(xformers, mem_eff)
def set_use_sdpa(self, sdpa: bool):
self.attn1.set_use_sdpa(sdpa)
self.attn2.set_use_sdpa(sdpa)
def forward(self, hidden_states, context=None, timestep=None):
# 1. Self-Attention
norm_hidden_states = self.norm1(hidden_states)
hidden_states = self.attn1(norm_hidden_states) + hidden_states
# 2. Cross-Attention
norm_hidden_states = self.norm2(hidden_states)
hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states
# 3. Feed-forward
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
return hidden_states
class Transformer2DModel(nn.Module):
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: Optional[int] = None,
cross_attention_dim: Optional[int] = None,
use_linear_projection: bool = False,
upcast_attention: bool = False,
):
super().__init__()
self.in_channels = in_channels
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
self.use_linear_projection = use_linear_projection
self.norm = torch.nn.GroupNorm(num_groups=TRANSFORMER_NORM_NUM_GROUPS, num_channels=in_channels, eps=1e-6, affine=True)
if use_linear_projection:
self.proj_in = nn.Linear(in_channels, inner_dim)
else:
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
)
]
)
if use_linear_projection:
self.proj_out = nn.Linear(in_channels, inner_dim)
else:
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
def set_use_memory_efficient_attention(self, xformers, mem_eff):
for transformer in self.transformer_blocks:
transformer.set_use_memory_efficient_attention(xformers, mem_eff)
def set_use_sdpa(self, sdpa):
for transformer in self.transformer_blocks:
transformer.set_use_sdpa(sdpa)
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
# 1. Input
batch, _, height, width = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
if not self.use_linear_projection:
hidden_states = self.proj_in(hidden_states)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
else:
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
hidden_states = self.proj_in(hidden_states)
# 2. Blocks
for block in self.transformer_blocks:
hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep)
# 3. Output
if not self.use_linear_projection:
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
hidden_states = self.proj_out(hidden_states)
else:
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
output = hidden_states + residual
if not return_dict:
return (output,)
return SampleOutput(sample=output)
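# Shape sketch (assumed values: the first v1.5 down block, 320 channels, 8 heads x 40 dims, 77 CLIP tokens):
#   x = torch.randn(2, 320, 64, 64)
#   ctx = torch.randn(2, 77, 768)
#   Transformer2DModel(8, 40, in_channels=320, cross_attention_dim=768)(x, ctx).sample.shape
#   # -> torch.Size([2, 320, 64, 64])  (spatial shape is preserved; attention runs over 64*64 tokens)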
class CrossAttnDownBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
add_downsample=True,
cross_attention_dim=1280,
attn_num_head_channels=1,
use_linear_projection=False,
upcast_attention=False,
):
super().__init__()
self.has_cross_attention = True
resnets = []
attentions = []
self.attn_num_head_channels = attn_num_head_channels
for i in range(LAYERS_PER_BLOCK):
in_channels = in_channels if i == 0 else out_channels
resnets.append(ResnetBlock2D(in_channels=in_channels, out_channels=out_channels))
attentions.append(
Transformer2DModel(
attn_num_head_channels,
out_channels // attn_num_head_channels,
in_channels=out_channels,
cross_attention_dim=cross_attention_dim,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
if add_downsample:
self.downsamplers = nn.ModuleList([Downsample2D(out_channels, out_channels)])
else:
self.downsamplers = None
self.gradient_checkpointing = False
def set_use_memory_efficient_attention(self, xformers, mem_eff):
for attn in self.attentions:
attn.set_use_memory_efficient_attention(xformers, mem_eff)
def set_use_sdpa(self, sdpa):
for attn in self.attentions:
attn.set_use_sdpa(sdpa)
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
output_states = ()
for resnet, attn in zip(self.resnets, self.attentions):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
)[0]
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
output_states += (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
output_states += (hidden_states,)
return hidden_states, output_states
class UNetMidBlock2DCrossAttn(nn.Module):
def __init__(
self,
in_channels: int,
attn_num_head_channels=1,
cross_attention_dim=1280,
use_linear_projection=False,
):
super().__init__()
self.has_cross_attention = True
self.attn_num_head_channels = attn_num_head_channels
# Middle block has two resnets and one attention
resnets = [
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
),
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
),
]
attentions = [
Transformer2DModel(
attn_num_head_channels,
in_channels // attn_num_head_channels,
in_channels=in_channels,
cross_attention_dim=cross_attention_dim,
use_linear_projection=use_linear_projection,
)
]
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
self.gradient_checkpointing = False
def set_use_memory_efficient_attention(self, xformers, mem_eff):
for attn in self.attentions:
attn.set_use_memory_efficient_attention(xformers, mem_eff)
def set_use_sdpa(self, sdpa):
for attn in self.attentions:
attn.set_use_sdpa(sdpa)
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
for i, resnet in enumerate(self.resnets):
attn = None if i == 0 else self.attentions[i - 1]
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
if attn is not None:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
)[0]
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
else:
if attn is not None:
hidden_states = attn(hidden_states, encoder_hidden_states).sample
hidden_states = resnet(hidden_states, temb)
return hidden_states
class Upsample2D(nn.Module):
def __init__(self, channels, out_channels):
super().__init__()
self.channels = channels
self.out_channels = out_channels
self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
def forward(self, hidden_states, output_size):
assert hidden_states.shape[1] == self.channels
# Cast to float32 because the 'upsample_nearest2d_out_frame' op does not support bfloat16
# TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
# https://github.com/pytorch/pytorch/issues/86679
dtype = hidden_states.dtype
if dtype == torch.bfloat16:
hidden_states = hidden_states.to(torch.float32)
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
if hidden_states.shape[0] >= 64:
hidden_states = hidden_states.contiguous()
# if `output_size` is passed we force the interpolation output size and do not make use of `scale_factor=2`
if output_size is None:
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
else:
hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
# If the input is bfloat16, we cast back to bfloat16
if dtype == torch.bfloat16:
hidden_states = hidden_states.to(dtype)
hidden_states = self.conv(hidden_states)
return hidden_states
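# Example (sketch): with output_size=None the spatial dims double; an explicit size overrides that,
# which the UNet uses when the latent is not a multiple of 2**num_upsamplers:
#   up = Upsample2D(320, 320)
#   up(torch.randn(1, 320, 32, 32), None).shape      # -> torch.Size([1, 320, 64, 64])
#   up(torch.randn(1, 320, 32, 32), (65, 65)).shape  # -> torch.Size([1, 320, 65, 65])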
class UpBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
prev_output_channel: int,
out_channels: int,
add_upsample=True,
):
super().__init__()
self.has_cross_attention = False
resnets = []
for i in range(LAYERS_PER_BLOCK_UP):
res_skip_channels = in_channels if (i == LAYERS_PER_BLOCK_UP - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
)
)
self.resnets = nn.ModuleList(resnets)
if add_upsample:
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)])
else:
self.upsamplers = None
self.gradient_checkpointing = False
def set_use_memory_efficient_attention(self, xformers, mem_eff):
pass
def set_use_sdpa(self, sdpa):
pass
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
for resnet in self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
else:
hidden_states = resnet(hidden_states, temb)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
return hidden_states
class CrossAttnUpBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
prev_output_channel: int,
attn_num_head_channels=1,
cross_attention_dim=1280,
add_upsample=True,
use_linear_projection=False,
upcast_attention=False,
):
super().__init__()
resnets = []
attentions = []
self.has_cross_attention = True
self.attn_num_head_channels = attn_num_head_channels
for i in range(LAYERS_PER_BLOCK_UP):
res_skip_channels = in_channels if (i == LAYERS_PER_BLOCK_UP - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
)
)
attentions.append(
Transformer2DModel(
attn_num_head_channels,
out_channels // attn_num_head_channels,
in_channels=out_channels,
cross_attention_dim=cross_attention_dim,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
if add_upsample:
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)])
else:
self.upsamplers = None
self.gradient_checkpointing = False
def set_use_memory_efficient_attention(self, xformers, mem_eff):
for attn in self.attentions:
attn.set_use_memory_efficient_attention(xformers, mem_eff)
def set_use_sdpa(self, sdpa):
for attn in self.attentions:
attn.set_use_sdpa(sdpa)
def forward(
self,
hidden_states,
res_hidden_states_tuple,
temb=None,
encoder_hidden_states=None,
upsample_size=None,
):
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
)[0]
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
return hidden_states
def get_down_block(
down_block_type,
in_channels,
out_channels,
add_downsample,
attn_num_head_channels,
cross_attention_dim,
use_linear_projection,
upcast_attention,
):
if down_block_type == "DownBlock2D":
return DownBlock2D(
in_channels=in_channels,
out_channels=out_channels,
add_downsample=add_downsample,
)
elif down_block_type == "CrossAttnDownBlock2D":
return CrossAttnDownBlock2D(
in_channels=in_channels,
out_channels=out_channels,
add_downsample=add_downsample,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
)
def get_up_block(
up_block_type,
in_channels,
out_channels,
prev_output_channel,
add_upsample,
attn_num_head_channels,
cross_attention_dim=None,
use_linear_projection=False,
upcast_attention=False,
):
if up_block_type == "UpBlock2D":
return UpBlock2D(
in_channels=in_channels,
prev_output_channel=prev_output_channel,
out_channels=out_channels,
add_upsample=add_upsample,
)
elif up_block_type == "CrossAttnUpBlock2D":
return CrossAttnUpBlock2D(
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
attn_num_head_channels=attn_num_head_channels,
cross_attention_dim=cross_attention_dim,
add_upsample=add_upsample,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
)
class UNet2DConditionModel(nn.Module):
_supports_gradient_checkpointing = True
def __init__(
self,
sample_size: Optional[int] = None,
attention_head_dim: Union[int, Tuple[int]] = 8,
cross_attention_dim: int = 1280,
use_linear_projection: bool = False,
upcast_attention: bool = False,
**kwargs,
):
super().__init__()
assert sample_size is not None, "sample_size must be specified"
print(
f"UNet2DConditionModel: {sample_size}, {attention_head_dim}, {cross_attention_dim}, {use_linear_projection}, {upcast_attention}"
)
# defined here so they can be referenced from outside
self.in_channels = IN_CHANNELS
self.out_channels = OUT_CHANNELS
self.sample_size = sample_size
self.prepare_config()
# the way modules are held cannot be changed, because doing so would change the state_dict format
# input
self.conv_in = nn.Conv2d(IN_CHANNELS, BLOCK_OUT_CHANNELS[0], kernel_size=3, padding=(1, 1))
# time
self.time_proj = Timesteps(BLOCK_OUT_CHANNELS[0], TIME_EMBED_FLIP_SIN_TO_COS, TIME_EMBED_FREQ_SHIFT)
self.time_embedding = TimestepEmbedding(TIMESTEP_INPUT_DIM, TIME_EMBED_DIM)
self.down_blocks = nn.ModuleList([])
self.mid_block = None
self.up_blocks = nn.ModuleList([])
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * 4
# down
output_channel = BLOCK_OUT_CHANNELS[0]
for i, down_block_type in enumerate(DOWN_BLOCK_TYPES):
input_channel = output_channel
output_channel = BLOCK_OUT_CHANNELS[i]
is_final_block = i == len(BLOCK_OUT_CHANNELS) - 1
down_block = get_down_block(
down_block_type,
in_channels=input_channel,
out_channels=output_channel,
add_downsample=not is_final_block,
attn_num_head_channels=attention_head_dim[i],
cross_attention_dim=cross_attention_dim,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
)
self.down_blocks.append(down_block)
# mid
self.mid_block = UNetMidBlock2DCrossAttn(
in_channels=BLOCK_OUT_CHANNELS[-1],
attn_num_head_channels=attention_head_dim[-1],
cross_attention_dim=cross_attention_dim,
use_linear_projection=use_linear_projection,
)
# count how many layers upsample the images
self.num_upsamplers = 0
# up
reversed_block_out_channels = list(reversed(BLOCK_OUT_CHANNELS))
reversed_attention_head_dim = list(reversed(attention_head_dim))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(UP_BLOCK_TYPES):
is_final_block = i == len(BLOCK_OUT_CHANNELS) - 1
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
input_channel = reversed_block_out_channels[min(i + 1, len(BLOCK_OUT_CHANNELS) - 1)]
# add upsample block for all BUT final layer
if not is_final_block:
add_upsample = True
self.num_upsamplers += 1
else:
add_upsample = False
up_block = get_up_block(
up_block_type,
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
add_upsample=add_upsample,
attn_num_head_channels=reversed_attention_head_dim[i],
cross_attention_dim=cross_attention_dim,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
self.conv_norm_out = nn.GroupNorm(num_channels=BLOCK_OUT_CHANNELS[0], num_groups=NORM_GROUPS, eps=NORM_EPS)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1)
# region diffusers compatibility
def prepare_config(self):
self.config = SimpleNamespace()
@property
def dtype(self) -> torch.dtype:
# `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
return get_parameter_dtype(self)
@property
def device(self) -> torch.device:
# `torch.device`: The device on which the module is (assuming that all the module parameters are on the same device).
return get_parameter_device(self)
def set_attention_slice(self, slice_size):
raise NotImplementedError("Attention slicing is not supported for this model.")
def is_gradient_checkpointing(self) -> bool:
return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
def enable_gradient_checkpointing(self):
self.set_gradient_checkpointing(value=True)
def disable_gradient_checkpointing(self):
self.set_gradient_checkpointing(value=False)
def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool) -> None:
modules = self.down_blocks + [self.mid_block] + self.up_blocks
for module in modules:
module.set_use_memory_efficient_attention(xformers, mem_eff)
def set_use_sdpa(self, sdpa: bool) -> None:
modules = self.down_blocks + [self.mid_block] + self.up_blocks
for module in modules:
module.set_use_sdpa(sdpa)
def set_gradient_checkpointing(self, value=False):
modules = self.down_blocks + [self.mid_block] + self.up_blocks
for module in modules:
print(module.__class__.__name__, module.gradient_checkpointing, "->", value)
module.gradient_checkpointing = value
# endregion
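# Usage sketch (assumption: a typical training-script setup, not code from this file):
#   unet.set_use_memory_efficient_attention(True, False)  # xformers attention
#   unet.set_use_sdpa(True)                                # or PyTorch 2.x scaled_dot_product_attention
#   unet.enable_gradient_checkpointing()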
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
class_labels: Optional[torch.Tensor] = None,
return_dict: bool = True,
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
mid_block_additional_residual: Optional[torch.Tensor] = None,
) -> Union[Dict, Tuple]:
r"""
Args:
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a dict instead of a plain tuple.
Returns:
`SampleOutput` or `tuple`:
`SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
"""
# By default samples have to be at least a multiple of the overall upsampling factor.
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
# However, the upsampling interpolation output size can be forced to fit any upsampling size
# on the fly if necessary.
# (This probably hurts image quality, so it is better to keep sizes divisible by 64.)
default_overall_up_factor = 2**self.num_upsamplers
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
# when the size is not divisible by 64, pass the output size to the upsamplers
forward_upsample_size = False
upsample_size = None
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
# logger.info("Forward upsample size to force interpolation output size.")
forward_upsample_size = True
# 1. time
timesteps = timestep
timesteps = self.handle_unusual_timesteps(sample, timesteps)  # only converts unusual inputs
t_emb = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
# (casting inside time_proj would also work)
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb)
# 2. pre-process
sample = self.conv_in(sample)
# 3. down
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
# the down blocks could simply always take encoder_hidden_states in forward,
# but this way is probably easier to follow
if downsample_block.has_cross_attention:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
down_block_res_samples += res_samples
# add the ControlNet outputs to the skip connections
if down_block_additional_residuals is not None:
down_block_res_samples = list(down_block_res_samples)
for i in range(len(down_block_res_samples)):
down_block_res_samples[i] += down_block_additional_residuals[i]
down_block_res_samples = tuple(down_block_res_samples)
# 4. mid
sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
# add the ControlNet output
if mid_block_additional_residual is not None:
sample += mid_block_additional_residual
# 5. up
for i, upsample_block in enumerate(self.up_blocks):
is_final_block = i == len(self.up_blocks) - 1
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # skip connection
# if we have not reached the final block and need to forward the upsample size, we do it here
# as noted above, pass upsample_size to every block except the final one
if not is_final_block and forward_upsample_size:
upsample_size = down_block_res_samples[-1].shape[2:]
if upsample_block.has_cross_attention:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
upsample_size=upsample_size,
)
else:
sample = upsample_block(
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
)
# 6. post-process
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
if not return_dict:
return (sample,)
return SampleOutput(sample=sample)
def handle_unusual_timesteps(self, sample, timesteps):
r"""
If timesteps is not a Tensor, convert it to one, and broadcast it to the batch size so it is compatible with ONNX / Core ML.
"""
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if isinstance(timesteps, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
return timesteps
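if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original file): build a v1.5-shaped UNet
    # and run one forward pass on random tensors. Shapes follow the v1.5 config quoted at the top.
    unet = UNet2DConditionModel(
        sample_size=64,
        attention_head_dim=8,
        cross_attention_dim=768,
        use_linear_projection=False,
        upcast_attention=False,
    )
    latents = torch.randn(1, IN_CHANNELS, 64, 64)       # (batch, channels, height, width) noisy latents
    timesteps = torch.tensor([999], dtype=torch.int64)  # one timestep per batch element
    context = torch.randn(1, 77, 768)                   # CLIP text encoder hidden states for v1.5
    out = unet(latents, timesteps, context).sample
    print(out.shape)  # expected: torch.Size([1, 4, 64, 64])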