mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
remove unnecessary code
This commit is contained in:
@@ -734,24 +734,6 @@ class Transformer2DModel(nn.Module):
|
||||
|
||||
return output
|
||||
|
||||
def forward_xxx(self, hidden_states, encoder_hidden_states=None, timestep=None):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
# print("Transformer2DModel: Using gradient checkpointing")
|
||||
|
||||
def create_custom_forward(func):
|
||||
def custom_forward(*inputs):
|
||||
return func(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
output = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.forward_body), hidden_states, encoder_hidden_states, timestep
|
||||
)
|
||||
else:
|
||||
output = self.forward_body(hidden_states, encoder_hidden_states, timestep)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class Upsample2D(nn.Module):
|
||||
def __init__(self, channels, out_channels):
|
||||
|
||||
Reference in New Issue
Block a user