Implemented Probabilistic Baseline
This commit is contained in:
93
src/trainers/probabilistic_baseline.py
Normal file
93
src/trainers/probabilistic_baseline.py
Normal file
@@ -0,0 +1,93 @@
|
||||
from utils.clearml import ClearMLHelper
|
||||
from data.preprocessing import DataProcessor, DataConfig
|
||||
import numpy as np
|
||||
import plotly.graph_objects as go
|
||||
|
||||
|
||||
class ProbabilisticBaselineTrainer:
|
||||
def __init__(self, quantiles, data_processor: DataProcessor, clearml_helper: ClearMLHelper):
|
||||
self.data_processor = data_processor
|
||||
|
||||
data_config = DataConfig()
|
||||
self.data_processor.set_data_config(data_config)
|
||||
|
||||
self.clearml_helper = clearml_helper
|
||||
self.quantiles = quantiles
|
||||
|
||||
def init_clearml_task(self):
|
||||
if not self.clearml_helper:
|
||||
return None
|
||||
|
||||
task_name = input("Enter a task name: ")
|
||||
if task_name == "":
|
||||
task_name = "Untitled Task"
|
||||
task = self.clearml_helper.get_task(task_name=task_name)
|
||||
|
||||
change_description = input("Enter a change description: ")
|
||||
if change_description:
|
||||
task.set_comment(change_description)
|
||||
|
||||
task.add_tags(self.__class__.__name__)
|
||||
task.connect(self.data_processor, name="data_processor")
|
||||
|
||||
return task
|
||||
|
||||
def train(self):
|
||||
task = self.init_clearml_task()
|
||||
try:
|
||||
time_steps = [[] for _ in range(96)]
|
||||
|
||||
train_loader, test_loader = self.data_processor.get_dataloaders(predict_sequence_length=96)
|
||||
|
||||
for inputs, _ in train_loader:
|
||||
for i in range(96):
|
||||
time_steps[i].extend(inputs[:, i].numpy())
|
||||
|
||||
|
||||
all_quantiles = []
|
||||
for i, time_values in enumerate(time_steps):
|
||||
quantiles = np.quantile(time_values, self.quantiles)
|
||||
all_quantiles.append(quantiles)
|
||||
|
||||
all_quantiles = np.array(all_quantiles)
|
||||
|
||||
# create dictionary
|
||||
quantile_dict = {}
|
||||
quantile_dict["quantiles"] = self.quantiles
|
||||
quantile_dict["quantile_values"] = all_quantiles
|
||||
|
||||
if task:
|
||||
task.upload_artifact("dictionary", quantile_dict)
|
||||
self.finish_training(quantile_values=all_quantiles, task=task)
|
||||
task.close()
|
||||
except Exception:
|
||||
if task:
|
||||
task.close()
|
||||
task.set_archived(True)
|
||||
raise
|
||||
|
||||
def finish_training(self, quantile_values, task):
|
||||
|
||||
fig = self.plot_quantiles(quantile_values)
|
||||
task.get_logger().report_plotly(
|
||||
title=f"Training Quantile Values",
|
||||
series="Quantile Values",
|
||||
figure=fig
|
||||
)
|
||||
|
||||
|
||||
def plot_quantiles(self, quantile_values):
|
||||
fig = go.Figure()
|
||||
|
||||
for i, q in enumerate(self.quantiles):
|
||||
values_for_quantile = quantile_values[:, i]
|
||||
fig.add_trace(go.Scatter(x=np.arange(96), y=values_for_quantile, name=f"Prediction (Q={q})", line=dict(dash='dash')))
|
||||
|
||||
fig.update_layout(title="Quantile Values")
|
||||
fig.update_yaxes(range=[-1, 1])
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user