Plots to compare between quantile regression and diffusion
This commit is contained in:
@@ -86,7 +86,7 @@ class Trainer:
|
||||
|
||||
def random_samples(self, train: bool = True, num_samples: int = 10):
|
||||
train_loader, test_loader = self.data_processor.get_dataloaders(
|
||||
predict_sequence_length=self.model.output_size
|
||||
predict_sequence_length=96
|
||||
)
|
||||
|
||||
if train:
|
||||
@@ -94,7 +94,14 @@ class Trainer:
|
||||
else:
|
||||
loader = test_loader
|
||||
|
||||
indices = np.random.randint(0, len(loader.dataset) - 1, size=num_samples)
|
||||
np.random.seed(42)
|
||||
actual_indices = np.random.choice(loader.dataset.full_day_valid_indices, num_samples, replace=False)
|
||||
indices = {}
|
||||
for i in actual_indices:
|
||||
indices[i] = loader.dataset.valid_indices.index(i)
|
||||
|
||||
print(actual_indices)
|
||||
|
||||
return indices
|
||||
|
||||
def train(self, epochs: int, remotely: bool = False, task: Task = None):
|
||||
@@ -107,8 +114,8 @@ class Trainer:
|
||||
predict_sequence_length=self.model.output_size
|
||||
)
|
||||
|
||||
train_samples = self.random_samples(train=True)
|
||||
test_samples = self.random_samples(train=False)
|
||||
train_samples = self.random_samples(train=True, num_samples=5)
|
||||
test_samples = self.random_samples(train=False, num_samples=5)
|
||||
|
||||
self.init_clearml_task(task)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user