Implemented Non Autorgressive Quantile Regression

This commit is contained in:
Victor Mylle
2023-11-18 17:42:06 +00:00
parent 75f1f64c38
commit 1268af47a6
9 changed files with 196493 additions and 161 deletions

View File

@@ -19,7 +19,12 @@ class AutoRegressiveTrainer(Trainer):
fig = make_subplots(rows=rows, cols=cols, subplot_titles=[f'Sample {i+1}' for i in range(num_samples)])
for i, idx in enumerate(sample_indices):
initial, predictions, target = self.auto_regressive(data_loader, idx)
auto_regressive_output = self.auto_regressive(data_loader, idx)
if len(auto_regressive_output) == 3:
initial, predictions, target = auto_regressive_output
else:
initial, predictions, _, target = auto_regressive_output
sub_fig = self.get_plot(initial, target, predictions, show_legend=(i == 0))
row = i + 1