Implemented Probabilistic Baseline
This commit is contained in:
@@ -3,9 +3,11 @@ from torch.utils.data import Dataset, DataLoader
|
||||
import pandas as pd
|
||||
|
||||
class NrvDataset(Dataset):
|
||||
def __init__(self, dataframe, data_config, sequence_length=96, predict_sequence_length=96):
|
||||
def __init__(self, dataframe, data_config, full_day_skip: bool = False, sequence_length=96, predict_sequence_length=96):
|
||||
self.data_config = data_config
|
||||
self.dataframe = dataframe
|
||||
self.full_day_skip = full_day_skip
|
||||
|
||||
# reset dataframe index
|
||||
self.dataframe.reset_index(drop=True, inplace=True)
|
||||
|
||||
@@ -22,6 +24,9 @@ class NrvDataset(Dataset):
|
||||
total_indices = set(range(len(self.nrv) - self.sequence_length - self.predict_sequence_length))
|
||||
self.valid_indices = sorted(list(total_indices - set(self.samples_to_skip)))
|
||||
|
||||
### TODO: Option to only use full day samples ###
|
||||
### skip all samples between is the easiest way I think (not most efficient though) ###
|
||||
|
||||
def skip_samples(self):
|
||||
nan_rows = self.dataframe[self.dataframe.isnull().any(axis=1)]
|
||||
nan_indices = nan_rows.index
|
||||
@@ -30,6 +35,14 @@ class NrvDataset(Dataset):
|
||||
skip_indices = [item for sublist in skip_indices for item in sublist]
|
||||
skip_indices = list(set(skip_indices))
|
||||
skip_indices.sort()
|
||||
|
||||
# add indices that are not the start of a day (00:15) to the skip indices (use datetime column)
|
||||
# get indices of all 00:15 timestamps
|
||||
if self.full_day_skip:
|
||||
start_of_day_indices = self.dataframe[self.dataframe['datetime'].dt.time == pd.Timestamp('00:15:00').time()].index
|
||||
skip_indices.extend(start_of_day_indices)
|
||||
skip_indices = list(set(skip_indices))
|
||||
|
||||
return skip_indices
|
||||
|
||||
def __len__(self):
|
||||
|
||||
@@ -17,7 +17,7 @@ class DataConfig:
|
||||
self.NRV_HISTORY: bool = True
|
||||
|
||||
### LOAD ###
|
||||
self.LOAD_FORECAST: bool = True
|
||||
self.LOAD_FORECAST: bool = False
|
||||
self.LOAD_HISTORY: bool = False
|
||||
|
||||
### PV ###
|
||||
@@ -51,6 +51,13 @@ class DataProcessor:
|
||||
self.nrv_scaler = MinMaxScaler(feature_range=(-1, 1))
|
||||
self.load_forecast_scaler = MinMaxScaler(feature_range=(-1, 1))
|
||||
|
||||
self.full_day_skip = False
|
||||
|
||||
def set_data_config(self, data_config: DataConfig):
|
||||
self.data_config = data_config
|
||||
|
||||
def set_full_day_skip(self, full_day_skip: bool):
|
||||
self.full_day_skip = full_day_skip
|
||||
|
||||
def set_train_range(self, train_range: tuple):
|
||||
self.train_range = train_range
|
||||
@@ -115,7 +122,8 @@ class DataProcessor:
|
||||
self.batch_size = batch_size
|
||||
|
||||
def get_dataloader(self, dataset, shuffle: bool = True):
|
||||
return torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle, num_workers=4)
|
||||
batch_size = len(dataset) if self.batch_size is None else self.batch_size
|
||||
return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=4)
|
||||
|
||||
def get_train_dataloader(self, transform: bool = True, predict_sequence_length: int = 96):
|
||||
train_df = self.all_features.copy()
|
||||
@@ -131,7 +139,7 @@ class DataProcessor:
|
||||
train_df['load_forecast'] = self.load_forecast_scaler.fit_transform(train_df['load_forecast'].values.reshape(-1, 1)).reshape(-1)
|
||||
train_df['total_load'] = self.load_forecast_scaler.transform(train_df['total_load'].values.reshape(-1, 1)).reshape(-1)
|
||||
|
||||
train_dataset = NrvDataset(train_df, data_config=self.data_config, predict_sequence_length=predict_sequence_length)
|
||||
train_dataset = NrvDataset(train_df, data_config=self.data_config, full_day_skip=self.full_day_skip, predict_sequence_length=predict_sequence_length)
|
||||
return self.get_dataloader(train_dataset)
|
||||
|
||||
def get_test_dataloader(self, transform: bool = True, predict_sequence_length: int = 96):
|
||||
@@ -149,7 +157,7 @@ class DataProcessor:
|
||||
test_df['load_forecast'] = self.load_forecast_scaler.transform(test_df['load_forecast'].values.reshape(-1, 1)).reshape(-1)
|
||||
test_df['total_load'] = self.load_forecast_scaler.transform(test_df['total_load'].values.reshape(-1, 1)).reshape(-1)
|
||||
|
||||
test_dataset = NrvDataset(test_df, data_config=self.data_config, predict_sequence_length=predict_sequence_length)
|
||||
test_dataset = NrvDataset(test_df, data_config=self.data_config, full_day_skip=self.full_day_skip, predict_sequence_length=predict_sequence_length)
|
||||
return self.get_dataloader(test_dataset, shuffle=False)
|
||||
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
"sys.path.append('..')\n",
|
||||
"from data import DataProcessor, DataConfig\n",
|
||||
"from trainers.quantile_trainer import AutoRegressiveQuantileTrainer, NonAutoRegressiveQuantileRegression\n",
|
||||
"from trainers.probabilistic_baseline import ProbabilisticBaselineTrainer\n",
|
||||
"from trainers.autoregressive_trainer import AutoRegressiveTrainer\n",
|
||||
"from trainers.trainer import Trainer\n",
|
||||
"from utils.clearml import ClearMLHelper\n",
|
||||
@@ -45,9 +46,44 @@
|
||||
"data_config.WIND_FORECAST = False\n",
|
||||
"data_config.WIND_HISTORY = False\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"data_processor = DataProcessor(data_config)\n",
|
||||
"data_processor.set_batch_size(1024)"
|
||||
"data_processor.set_batch_size(1024)\n",
|
||||
"data_processor.set_full_day_skip(True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Probabilistic Baseline"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"ClearML Task: created new task id=07ad9f41dfbb43ada3c15ec33a85050d\n",
|
||||
"ClearML results page: http://192.168.1.182:8080/projects/2e46d4af6f1e4c399cf9f5aa30bc8795/experiments/07ad9f41dfbb43ada3c15ec33a85050d/output/log\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Can't get url information for git repo in /workspaces/Thesis/src/notebooks\n",
|
||||
"JSON serialization of artifact 'dictionary' failed, reverting to pickle\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"quantiles = [0.01, 0.05, 0.1, 0.15, 0.4, 0.5, 0.6, 0.85, 0.9, 0.95, 0.99]\n",
|
||||
"trainer = ProbabilisticBaselineTrainer(quantiles=quantiles, data_processor=data_processor, clearml_helper=clearml_helper)\n",
|
||||
"trainer.train()"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -59,28 +95,33 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"ClearML Task: created new task id=c6bc2cb556b84fed81fa04f5b4a323ea\n",
|
||||
"ClearML results page: http://192.168.1.182:8080/projects/2e46d4af6f1e4c399cf9f5aa30bc8795/experiments/c6bc2cb556b84fed81fa04f5b4a323ea/output/log\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Can't get url information for git repo in /workspaces/Thesis/src/notebooks\n"
|
||||
"InsecureRequestWarning: Certificate verification is disabled! Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#ssl-warnings\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"ClearML Task: created new task id=11553d672a2744479de07c9ac0a9dbde\n",
|
||||
"ClearML results page: http://192.168.1.182:8080/projects/2e46d4af6f1e4c399cf9f5aa30bc8795/experiments/11553d672a2744479de07c9ac0a9dbde/output/log\n",
|
||||
"2023-11-19 18:06:57,539 - clearml.Task - INFO - Storing jupyter notebook directly as code\n",
|
||||
"2023-11-19 18:06:57,543 - clearml.Repository Detection - WARNING - Can't get url information for git repo in /workspaces/Thesis/src/notebooks\n",
|
||||
"2023-11-19 18:07:14,402 - clearml.model - WARNING - 500 model found when searching for `file:///workspaces/Thesis/src/notebooks/checkpoint.pt`\n",
|
||||
"2023-11-19 18:07:14,403 - clearml.model - WARNING - Selected model `Non Autoregressive Quantile Regression` (id=bc0cb0d7fc614e2e8b0edf5b85348646)\n",
|
||||
"2023-11-19 18:07:14,412 - clearml.frameworks - INFO - Found existing registered model id=bc0cb0d7fc614e2e8b0edf5b85348646 [/workspaces/Thesis/src/notebooks/checkpoint.pt] reusing it.\n",
|
||||
"2023-11-19 18:07:14,974 - clearml.Task - INFO - Completed model upload to http://192.168.1.182:8081/Thesis/NrvForecast/Non%20Autoregressive%20Model%20%28Non%20Linear%29%20using%20full%20day%20skip%20for%20training%20samples.11553d672a2744479de07c9ac0a9dbde/models/checkpoint.pt\n",
|
||||
"2023-11-19 18:07:16,827 - clearml.Task - INFO - Completed model upload to http://192.168.1.182:8081/Thesis/NrvForecast/Non%20Autoregressive%20Model%20%28Non%20Linear%29%20using%20full%20day%20skip%20for%20training%20samples.11553d672a2744479de07c9ac0a9dbde/models/checkpoint.pt\n",
|
||||
"2023-11-19 18:07:18,465 - clearml.Task - INFO - Completed model upload to http://192.168.1.182:8081/Thesis/NrvForecast/Non%20Autoregressive%20Model%20%28Non%20Linear%29%20using%20full%20day%20skip%20for%20training%20samples.11553d672a2744479de07c9ac0a9dbde/models/checkpoint.pt\n",
|
||||
"2023-11-19 18:07:20,045 - clearml.Task - INFO - Completed model upload to http://192.168.1.182:8081/Thesis/NrvForecast/Non%20Autoregressive%20Model%20%28Non%20Linear%29%20using%20full%20day%20skip%20for%20training%20samples.11553d672a2744479de07c9ac0a9dbde/models/checkpoint.pt\n",
|
||||
"2023-11-19 18:07:21,843 - clearml.Task - INFO - Completed model upload to http://192.168.1.182:8081/Thesis/NrvForecast/Non%20Autoregressive%20Model%20%28Non%20Linear%29%20using%20full%20day%20skip%20for%20training%20samples.11553d672a2744479de07c9ac0a9dbde/models/checkpoint.pt\n",
|
||||
"2023-11-19 18:07:28,812 - clearml.Task - INFO - Completed model upload to http://192.168.1.182:8081/Thesis/NrvForecast/Non%20Autoregressive%20Model%20%28Non%20Linear%29%20using%20full%20day%20skip%20for%20training%20samples.11553d672a2744479de07c9ac0a9dbde/models/checkpoint.pt\n",
|
||||
"Early stopping triggered\n"
|
||||
]
|
||||
}
|
||||
@@ -187,36 +228,27 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/workspaces/Thesis/src/notebooks/../trainers/quantile_trainer.py:18: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
||||
" quantiles_tensor = torch.tensor(quantiles)\n",
|
||||
"/workspaces/Thesis/src/notebooks/../losses/pinball_loss.py:7: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
||||
" self.quantiles_tensor = torch.tensor(quantiles, dtype=torch.float32)\n",
|
||||
"InsecureRequestWarning: Certificate verification is disabled! Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#ssl-warnings\n"
|
||||
"/workspaces/Thesis/src/notebooks/../trainers/quantile_trainer.py:18: UserWarning:\n",
|
||||
"\n",
|
||||
"To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
||||
"\n",
|
||||
"/workspaces/Thesis/src/notebooks/../losses/pinball_loss.py:7: UserWarning:\n",
|
||||
"\n",
|
||||
"To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"ClearML Task: created new task id=36976b1159074e698e2c19eb6a3bc290\n",
|
||||
"ClearML results page: http://192.168.1.182:8080/projects/2e46d4af6f1e4c399cf9f5aa30bc8795/experiments/36976b1159074e698e2c19eb6a3bc290/output/log\n",
|
||||
"2023-11-16 08:46:07,715 - clearml.Task - INFO - Storing jupyter notebook directly as code\n",
|
||||
"2023-11-16 08:46:07,719 - clearml.Repository Detection - WARNING - Can't get url information for git repo in /workspaces/Thesis/src/notebooks\n",
|
||||
"2023-11-16 08:46:15,693 - clearml.model - WARNING - 500 model found when searching for `file:///workspaces/Thesis/src/notebooks/checkpoint.pt`\n",
|
||||
"2023-11-16 08:46:15,694 - clearml.model - WARNING - Selected model `Quantile Regression: Non Linear with test score` (id=bc0cb0d7fc614e2e8b0edf5b85348646)\n",
|
||||
"2023-11-16 08:46:15,702 - clearml.frameworks - INFO - Found existing registered model id=bc0cb0d7fc614e2e8b0edf5b85348646 [/workspaces/Thesis/src/notebooks/checkpoint.pt] reusing it.\n",
|
||||
"2023-11-16 08:46:16,218 - clearml.Task - INFO - Completed model upload to http://192.168.1.182:8081/Thesis/NrvForecast/Quantile%20Regression%253A%20Non%20Linear%20Debugging%20%28plot%20every%201%20epoch%29.36976b1159074e698e2c19eb6a3bc290/models/checkpoint.pt\n",
|
||||
"2023-11-16 08:46:21,062 - clearml.Task - INFO - Completed model upload to http://192.168.1.182:8081/Thesis/NrvForecast/Quantile%20Regression%253A%20Non%20Linear%20Debugging%20%28plot%20every%201%20epoch%29.36976b1159074e698e2c19eb6a3bc290/models/checkpoint.pt\n",
|
||||
"2023-11-16 08:46:33,228 - clearml.Task - INFO - Completed model upload to http://192.168.1.182:8081/Thesis/NrvForecast/Quantile%20Regression%253A%20Non%20Linear%20Debugging%20%28plot%20every%201%20epoch%29.36976b1159074e698e2c19eb6a3bc290/models/checkpoint.pt\n",
|
||||
"2023-11-16 08:46:42,236 - clearml.Task - INFO - Completed model upload to http://192.168.1.182:8081/Thesis/NrvForecast/Quantile%20Regression%253A%20Non%20Linear%20Debugging%20%28plot%20every%201%20epoch%29.36976b1159074e698e2c19eb6a3bc290/models/checkpoint.pt\n",
|
||||
"2023-11-16 08:46:50,541 - clearml.Task - INFO - Completed model upload to http://192.168.1.182:8081/Thesis/NrvForecast/Quantile%20Regression%253A%20Non%20Linear%20Debugging%20%28plot%20every%201%20epoch%29.36976b1159074e698e2c19eb6a3bc290/models/checkpoint.pt\n",
|
||||
"Early stopping triggered\n"
|
||||
]
|
||||
},
|
||||
@@ -224,7 +256,7 @@
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 25804/25804 [22:46<00:00, 18.88it/s]\n"
|
||||
"100%|██████████| 25804/25804 [20:02<00:00, 21.45it/s]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
||||
93
src/trainers/probabilistic_baseline.py
Normal file
93
src/trainers/probabilistic_baseline.py
Normal file
@@ -0,0 +1,93 @@
|
||||
from utils.clearml import ClearMLHelper
|
||||
from data.preprocessing import DataProcessor, DataConfig
|
||||
import numpy as np
|
||||
import plotly.graph_objects as go
|
||||
|
||||
|
||||
class ProbabilisticBaselineTrainer:
|
||||
def __init__(self, quantiles, data_processor: DataProcessor, clearml_helper: ClearMLHelper):
|
||||
self.data_processor = data_processor
|
||||
|
||||
data_config = DataConfig()
|
||||
self.data_processor.set_data_config(data_config)
|
||||
|
||||
self.clearml_helper = clearml_helper
|
||||
self.quantiles = quantiles
|
||||
|
||||
def init_clearml_task(self):
|
||||
if not self.clearml_helper:
|
||||
return None
|
||||
|
||||
task_name = input("Enter a task name: ")
|
||||
if task_name == "":
|
||||
task_name = "Untitled Task"
|
||||
task = self.clearml_helper.get_task(task_name=task_name)
|
||||
|
||||
change_description = input("Enter a change description: ")
|
||||
if change_description:
|
||||
task.set_comment(change_description)
|
||||
|
||||
task.add_tags(self.__class__.__name__)
|
||||
task.connect(self.data_processor, name="data_processor")
|
||||
|
||||
return task
|
||||
|
||||
def train(self):
|
||||
task = self.init_clearml_task()
|
||||
try:
|
||||
time_steps = [[] for _ in range(96)]
|
||||
|
||||
train_loader, test_loader = self.data_processor.get_dataloaders(predict_sequence_length=96)
|
||||
|
||||
for inputs, _ in train_loader:
|
||||
for i in range(96):
|
||||
time_steps[i].extend(inputs[:, i].numpy())
|
||||
|
||||
|
||||
all_quantiles = []
|
||||
for i, time_values in enumerate(time_steps):
|
||||
quantiles = np.quantile(time_values, self.quantiles)
|
||||
all_quantiles.append(quantiles)
|
||||
|
||||
all_quantiles = np.array(all_quantiles)
|
||||
|
||||
# create dictionary
|
||||
quantile_dict = {}
|
||||
quantile_dict["quantiles"] = self.quantiles
|
||||
quantile_dict["quantile_values"] = all_quantiles
|
||||
|
||||
if task:
|
||||
task.upload_artifact("dictionary", quantile_dict)
|
||||
self.finish_training(quantile_values=all_quantiles, task=task)
|
||||
task.close()
|
||||
except Exception:
|
||||
if task:
|
||||
task.close()
|
||||
task.set_archived(True)
|
||||
raise
|
||||
|
||||
def finish_training(self, quantile_values, task):
|
||||
|
||||
fig = self.plot_quantiles(quantile_values)
|
||||
task.get_logger().report_plotly(
|
||||
title=f"Training Quantile Values",
|
||||
series="Quantile Values",
|
||||
figure=fig
|
||||
)
|
||||
|
||||
|
||||
def plot_quantiles(self, quantile_values):
|
||||
fig = go.Figure()
|
||||
|
||||
for i, q in enumerate(self.quantiles):
|
||||
values_for_quantile = quantile_values[:, i]
|
||||
fig.add_trace(go.Scatter(x=np.arange(96), y=values_for_quantile, name=f"Prediction (Q={q})", line=dict(dash='dash')))
|
||||
|
||||
fig.update_layout(title="Quantile Values")
|
||||
fig.update_yaxes(range=[-1, 1])
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user