Implemented Probabilistic Baseline

This commit is contained in:
Victor Mylle
2023-11-19 20:20:31 +00:00
parent 1268af47a6
commit 166d3967e1
4 changed files with 182 additions and 36 deletions

View File

@@ -0,0 +1,93 @@
from utils.clearml import ClearMLHelper
from data.preprocessing import DataProcessor, DataConfig
import numpy as np
import plotly.graph_objects as go
class ProbabilisticBaselineTrainer:
def __init__(self, quantiles, data_processor: DataProcessor, clearml_helper: ClearMLHelper):
self.data_processor = data_processor
data_config = DataConfig()
self.data_processor.set_data_config(data_config)
self.clearml_helper = clearml_helper
self.quantiles = quantiles
def init_clearml_task(self):
if not self.clearml_helper:
return None
task_name = input("Enter a task name: ")
if task_name == "":
task_name = "Untitled Task"
task = self.clearml_helper.get_task(task_name=task_name)
change_description = input("Enter a change description: ")
if change_description:
task.set_comment(change_description)
task.add_tags(self.__class__.__name__)
task.connect(self.data_processor, name="data_processor")
return task
def train(self):
task = self.init_clearml_task()
try:
time_steps = [[] for _ in range(96)]
train_loader, test_loader = self.data_processor.get_dataloaders(predict_sequence_length=96)
for inputs, _ in train_loader:
for i in range(96):
time_steps[i].extend(inputs[:, i].numpy())
all_quantiles = []
for i, time_values in enumerate(time_steps):
quantiles = np.quantile(time_values, self.quantiles)
all_quantiles.append(quantiles)
all_quantiles = np.array(all_quantiles)
# create dictionary
quantile_dict = {}
quantile_dict["quantiles"] = self.quantiles
quantile_dict["quantile_values"] = all_quantiles
if task:
task.upload_artifact("dictionary", quantile_dict)
self.finish_training(quantile_values=all_quantiles, task=task)
task.close()
except Exception:
if task:
task.close()
task.set_archived(True)
raise
def finish_training(self, quantile_values, task):
fig = self.plot_quantiles(quantile_values)
task.get_logger().report_plotly(
title=f"Training Quantile Values",
series="Quantile Values",
figure=fig
)
def plot_quantiles(self, quantile_values):
fig = go.Figure()
for i, q in enumerate(self.quantiles):
values_for_quantile = quantile_values[:, i]
fig.add_trace(go.Scatter(x=np.arange(96), y=values_for_quantile, name=f"Prediction (Q={q})", line=dict(dash='dash')))
fig.update_layout(title="Quantile Values")
fig.update_yaxes(range=[-1, 1])
return fig