mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
202 lines
8.0 KiB
Python
202 lines
8.0 KiB
Python
import torch
|
||
from typing import Union, List, Optional, Dict, Any, Tuple
|
||
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
|
||
|
||
from library.original_unet import SampleOutput
|
||
|
||
|
||
def unet_forward_XTI(
    self,
    sample: torch.FloatTensor,
    timestep: Union[torch.Tensor, float, int],
    encoder_hidden_states: torch.Tensor,
    class_labels: Optional[torch.Tensor] = None,
    return_dict: bool = True,
) -> Union["SampleOutput", Tuple]:
    r"""
    UNet forward pass that feeds a different text conditioning to each cross-attention layer.

    Args:
        sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
        timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
        encoder_hidden_states: indexable collection of per-layer encoder hidden states, indexed along
            its first axis. Down blocks with cross attention consume two consecutive entries each
            (starting at index 0), the mid block uses entry 6, and up blocks with cross attention
            consume three consecutive entries each (starting at index 7).
        class_labels (`torch.Tensor`, *optional*): accepted for signature compatibility; not used here.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a `SampleOutput` instead of a plain tuple.

    Returns:
        `SampleOutput` or `tuple`:
        `SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
    """
    # By default samples have to be at least a multiple of the overall upsampling factor.
    # The overall upsampling factor is equal to 2 ** (number of upsampling layers).
    # However, the upsampling interpolation output size can be forced to fit any upsampling size
    # on the fly if necessary.
    # Image quality probably degrades in that case, so keeping sizes divisible by the factor is preferable.
    default_overall_up_factor = 2**self.num_upsamplers

    # The upsample size must be forwarded to the upsamplers whenever the spatial size of the
    # sample is not a multiple of `default_overall_up_factor`.
    forward_upsample_size = any(s % default_overall_up_factor != 0 for s in sample.shape[-2:])
    upsample_size = None

    # 1. time
    # `handle_unusual_timesteps` only does work for unusual timestep inputs.
    timesteps = self.handle_unusual_timesteps(sample, timestep)
    t_emb = self.time_proj(timesteps)

    # time_proj has no weights and always returns f32 tensors, but time_embedding might be
    # running in fp16, so cast here. (Casting inside time_proj might be a cleaner alternative.)
    t_emb = t_emb.to(dtype=self.dtype)
    emb = self.time_embedding(t_emb)

    # 2. pre-process
    sample = self.conv_in(sample)

    # 3. down
    down_block_res_samples = (sample,)
    down_i = 0  # index of the next per-layer conditioning entry for the down path
    for downsample_block in self.down_blocks:
        # Down blocks could always accept encoder_hidden_states in forward, but branching
        # explicitly on has_cross_attention is arguably easier to follow.
        if downsample_block.has_cross_attention:
            sample, res_samples = downsample_block(
                hidden_states=sample,
                temb=emb,
                encoder_hidden_states=encoder_hidden_states[down_i : down_i + 2],
            )
            down_i += 2
        else:
            sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

        down_block_res_samples += res_samples

    # 4. mid
    # NOTE(review): index 6 is hard-coded; it lines up with `down_i` only when exactly three
    # down blocks have cross attention -- confirm against the model configuration.
    sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states[6])

    # 5. up
    up_i = 7  # first per-layer conditioning entry for the up path (right after the mid block's)
    last_block_index = len(self.up_blocks) - 1
    for i, upsample_block in enumerate(self.up_blocks):
        is_final_block = i == last_block_index

        # Take this block's skip connections off the end of the residual stack.
        num_resnets = len(upsample_block.resnets)
        res_samples = down_block_res_samples[-num_resnets:]
        down_block_res_samples = down_block_res_samples[:-num_resnets]

        # For every block except the final one, forward the target upsample size when needed.
        # The value is not reset afterwards, so the final block receives the last computed size.
        if not is_final_block and forward_upsample_size:
            upsample_size = down_block_res_samples[-1].shape[2:]

        if upsample_block.has_cross_attention:
            sample = upsample_block(
                hidden_states=sample,
                temb=emb,
                res_hidden_states_tuple=res_samples,
                encoder_hidden_states=encoder_hidden_states[up_i : up_i + 3],
                upsample_size=upsample_size,
            )
            up_i += 3
        else:
            sample = upsample_block(
                hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
            )

    # 6. post-process
    sample = self.conv_norm_out(sample)
    sample = self.conv_act(sample)
    sample = self.conv_out(sample)

    if not return_dict:
        return (sample,)

    return SampleOutput(sample=sample)
|
||
|
||
|
||
def downblock_forward_XTI(
    self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
):
    """Run a cross-attention down block, giving each attention layer its own conditioning.

    Args:
        hidden_states: input passed through each (resnet, attention) pair in turn.
        temb: time embedding handed to every resnet.
        encoder_hidden_states: indexable collection; entry ``i`` is passed to the ``i``-th
            attention layer of this block.
        attention_mask: accepted for signature compatibility; not used here.
        cross_attention_kwargs: accepted for signature compatibility; not used here.

    Returns:
        A ``(hidden_states, output_states)`` pair, where ``output_states`` is a tuple holding the
        output of every (resnet, attention) pair followed, when downsamplers are present, by the
        downsampled result.
    """

    # Defined once instead of once per layer: wraps a module so it can be called through
    # torch.utils.checkpoint with positional inputs only.
    def create_custom_forward(module, return_dict=None):
        def custom_forward(*inputs):
            if return_dict is not None:
                return module(*inputs, return_dict=return_dict)
            return module(*inputs)

        return custom_forward

    use_checkpointing = self.training and self.gradient_checkpointing
    collected_states = []

    for layer_index, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
        if use_checkpointing:
            # NOTE(review): recent PyTorch releases ask for `use_reentrant` to be passed
            # explicitly to checkpoint() -- confirm against the pinned torch version.
            hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
            hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[layer_index]
            )[0]
        else:
            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[layer_index]).sample

        collected_states.append(hidden_states)

    if self.downsamplers is not None:
        for downsampler in self.downsamplers:
            hidden_states = downsampler(hidden_states)

        # Only the final downsampled tensor is recorded, not each intermediate one.
        collected_states.append(hidden_states)

    return hidden_states, tuple(collected_states)
|
||
|
||
|
||
def upblock_forward_XTI(
    self,
    hidden_states,
    res_hidden_states_tuple,
    temb=None,
    encoder_hidden_states=None,
    upsample_size=None,
):
    """Run a cross-attention up block, giving each attention layer its own conditioning.

    Args:
        hidden_states: input tensor; concatenated with one skip connection along dim 1 before
            each resnet.
        res_hidden_states_tuple: skip connections, consumed from the last entry backwards
            (one per (resnet, attention) pair).
        temb: time embedding handed to every resnet.
        encoder_hidden_states: indexable collection; entry ``i`` is passed to the ``i``-th
            attention layer of this block.
        upsample_size: forwarded as the second argument to every upsampler.

    Returns:
        The hidden states after all (resnet, attention) pairs and, when present, the upsamplers.
    """

    # Defined once instead of once per layer: wraps a module so it can be called through
    # torch.utils.checkpoint with positional inputs only.
    def create_custom_forward(module, return_dict=None):
        def custom_forward(*inputs):
            if return_dict is not None:
                return module(*inputs, return_dict=return_dict)
            return module(*inputs)

        return custom_forward

    for layer_index, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
        # Take skip connections from the end of the tuple: last entry first, then second to last, ...
        res_hidden_states = res_hidden_states_tuple[-(layer_index + 1)]
        hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

        if self.training and self.gradient_checkpointing:
            # NOTE(review): recent PyTorch releases ask for `use_reentrant` to be passed
            # explicitly to checkpoint() -- confirm against the pinned torch version.
            hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
            hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[layer_index]
            )[0]
        else:
            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[layer_index]).sample

    if self.upsamplers is not None:
        for upsampler in self.upsamplers:
            hidden_states = upsampler(hidden_states, upsample_size)

    return hidden_states
|