diff --git a/src/trainers/diffusion_trainer.py b/src/trainers/diffusion_trainer.py index 6f35ae3..5d882c7 100644 --- a/src/trainers/diffusion_trainer.py +++ b/src/trainers/diffusion_trainer.py @@ -19,7 +19,13 @@ def sample_diffusion(model: DiffusionModel, n: int, inputs: torch.tensor, noise_ alpha = 1. - beta alpha_hat = torch.cumprod(alpha, dim=0) - inputs = inputs.repeat(n, 1).to(device) + # inputs: (num_features) -> (batch_size, num_features) + # inputs: (time_steps, num_features) -> (batch_size, time_steps, num_features) + if len(inputs.shape) == 2: + inputs = inputs.repeat(n, 1) + elif len(inputs.shape) == 3: + inputs = inputs.repeat(n, 1, 1) + model.eval() with torch.no_grad(): x = torch.randn(inputs.shape[0], ts_length).to(device) @@ -101,8 +107,12 @@ class DiffusionTrainer: input_data = torch.randn(1024, 96).to(self.device) time_steps = torch.randn(1024).long().to(self.device) - other_input_data = torch.randn(1024, self.model.other_inputs_dim).to(self.device) + if self.data_processor.lstm: + inputDim = self.data_processor.get_input_size() + other_input_data = torch.randn(1024, inputDim[1], self.model.other_inputs_dim).to(self.device) + else: + other_input_data = torch.randn(1024, self.model.other_inputs_dim).to(self.device) task.set_configuration_object("model", str(summary(self.model, input_data=[input_data, time_steps, other_input_data]))) self.data_processor = task.connect(self.data_processor, name="data_processor") @@ -222,7 +232,7 @@ class DiffusionTrainer: number_of_samples = 100 sample = self.sample(self.model, number_of_samples, inputs) - + # reduce samples from (batch_size*number_of_samples, time_steps) to (batch_size, number_of_samples, time_steps) samples_batched = sample.reshape(inputs.shape[0], number_of_samples, 96) diff --git a/src/training_scripts/diffusion_training.py b/src/training_scripts/diffusion_training.py index 210e6ef..c206c2a 100644 --- a/src/training_scripts/diffusion_training.py +++ b/src/training_scripts/diffusion_training.py @@ -38,10 +38,11 @@ data_config.NOMINAL_NET_POSITION = True data_config = task.connect(data_config, name="data_features") data_processor = DataProcessor(data_config, path="", lstm=True) -data_processor.set_batch_size(8192) +data_processor.set_batch_size(128) data_processor.set_full_day_skip(True) inputDim = data_processor.get_input_size() +print("Input dim: ", inputDim) model_parameters = { "epochs": 5000, @@ -54,7 +55,7 @@ model_parameters = task.connect(model_parameters, name="model_parameters") #### Model #### # model = SimpleDiffusionModel(96, model_parameters["hidden_sizes"], other_inputs_dim=inputDim[1], time_dim=model_parameters["time_dim"]) -model = GRUDiffusionModel(96, [256, 256], other_inputs_dim=inputDim[2], time_dim=64, gru_hidden_size=128) +model = GRUDiffusionModel(96, model_parameters["hidden_sizes"], other_inputs_dim=inputDim[2], time_dim=model_parameters["time_dim"], gru_hidden_size=256) print("Starting training ...")