Implemented Non Autorgressive Quantile Regression
This commit is contained in:
@@ -19,7 +19,12 @@ class AutoRegressiveTrainer(Trainer):
|
||||
fig = make_subplots(rows=rows, cols=cols, subplot_titles=[f'Sample {i+1}' for i in range(num_samples)])
|
||||
|
||||
for i, idx in enumerate(sample_indices):
|
||||
initial, predictions, target = self.auto_regressive(data_loader, idx)
|
||||
auto_regressive_output = self.auto_regressive(data_loader, idx)
|
||||
if len(auto_regressive_output) == 3:
|
||||
initial, predictions, target = auto_regressive_output
|
||||
else:
|
||||
initial, predictions, _, target = auto_regressive_output
|
||||
|
||||
sub_fig = self.get_plot(initial, target, predictions, show_legend=(i == 0))
|
||||
|
||||
row = i + 1
|
||||
|
||||
Reference in New Issue
Block a user