diff --git a/library/sdxl_original_unet.py b/library/sdxl_original_unet.py index 8ba1c988..6ea4bc33 100644 --- a/library/sdxl_original_unet.py +++ b/library/sdxl_original_unet.py @@ -734,24 +734,6 @@ class Transformer2DModel(nn.Module): return output - def forward_xxx(self, hidden_states, encoder_hidden_states=None, timestep=None): - if self.training and self.gradient_checkpointing: - # print("Transformer2DModel: Using gradient checkpointing") - - def create_custom_forward(func): - def custom_forward(*inputs): - return func(*inputs) - - return custom_forward - - output = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.forward_body), hidden_states, encoder_hidden_states, timestep - ) - else: - output = self.forward_body(hidden_states, encoder_hidden_states, timestep) - - return output - class Upsample2D(nn.Module): def __init__(self, channels, out_channels):