diff --git a/library/flux_utils.py b/library/flux_utils.py
index 3f0a0d63..22054854 100644
--- a/library/flux_utils.py
+++ b/library/flux_utils.py
@@ -220,8 +220,12 @@ class DummyTextModel(torch.nn.Module):
 class DummyCLIPL(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        self.output_shape = (77, 1) # Note: The original code had (77, 768), but we use (77, 1) for the dummy output
-        self.dummy_param = torch.nn.Parameter(torch.zeros(1)) # get dtype and device from this parameter
+        self.output_shape = (77, 1) # Note: The original code had (77, 768), but we use (77, 1) for the dummy output
+
+        # dtype and device from these parameters. train_network.py accesses them
+        self.dummy_param = torch.nn.Parameter(torch.zeros(1))
+        self.dummy_param_2 = torch.nn.Parameter(torch.zeros(1))
+        self.dummy_param_3 = torch.nn.Parameter(torch.zeros(1))
         self.text_model = DummyTextModel()
 
     @property