From e8e53ab185fbb7943aec5b3c8a5ea9380d9c5866 Mon Sep 17 00:00:00 2001 From: Victor Mylle Date: Thu, 18 Jan 2024 23:21:57 +0000 Subject: [PATCH] Updated training script for GRU model --- src/training_scripts/diffusion_training.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/training_scripts/diffusion_training.py b/src/training_scripts/diffusion_training.py index dc5452d..210e6ef 100644 --- a/src/training_scripts/diffusion_training.py +++ b/src/training_scripts/diffusion_training.py @@ -10,7 +10,7 @@ from torch.nn import MSELoss, L1Loss from datetime import datetime import torch.nn as nn from src.models.time_embedding_layer import TimeEmbedding -from src.models.diffusion_model import SimpleDiffusionModel +from src.models.diffusion_model import GRUDiffusionModel, SimpleDiffusionModel from src.trainers.diffusion_trainer import DiffusionTrainer @@ -37,7 +37,7 @@ data_config.NOMINAL_NET_POSITION = True data_config = task.connect(data_config, name="data_features") -data_processor = DataProcessor(data_config, path="", lstm=False) +data_processor = DataProcessor(data_config, path="", lstm=True) data_processor.set_batch_size(8192) data_processor.set_full_day_skip(True) @@ -53,7 +53,8 @@ model_parameters = { model_parameters = task.connect(model_parameters, name="model_parameters") #### Model #### -model = SimpleDiffusionModel(96, model_parameters["hidden_sizes"], other_inputs_dim=inputDim[1], time_dim=model_parameters["time_dim"]) +# model = SimpleDiffusionModel(96, model_parameters["hidden_sizes"], other_inputs_dim=inputDim[1], time_dim=model_parameters["time_dim"]) +model = GRUDiffusionModel(96, [256, 256], other_inputs_dim=inputDim[2], time_dim=64, gru_hidden_size=128) print("Starting training ...")