remove unnecessary code

This commit is contained in:
Kohya S
2023-07-12 21:53:02 +09:00
parent 3c67e595b8
commit 8df948565a

View File

@@ -734,24 +734,6 @@ class Transformer2DModel(nn.Module):
return output
def forward_xxx(self, hidden_states, encoder_hidden_states=None, timestep=None):
if self.training and self.gradient_checkpointing:
# print("Transformer2DModel: Using gradient checkpointing")
def create_custom_forward(func):
def custom_forward(*inputs):
return func(*inputs)
return custom_forward
output = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.forward_body), hidden_states, encoder_hidden_states, timestep
)
else:
output = self.forward_body(hidden_states, encoder_hidden_states, timestep)
return output
class Upsample2D(nn.Module):
def __init__(self, channels, out_channels):