Changed steps in diffusion model

This commit is contained in:
Victor Mylle
2024-01-20 09:44:14 +00:00
parent c6fa17fa40
commit acaad2710a
6 changed files with 106 additions and 25 deletions

View File

@@ -51,7 +51,7 @@ class DiffusionTrainer:
self.model = model
self.device = device
self.noise_steps = 1000
self.noise_steps = 20
self.beta_start = 1e-4
self.beta_end = 0.02
self.ts_length = 96
@@ -130,8 +130,8 @@ class DiffusionTrainer:
predict_sequence_length=self.ts_length
)
train_sample_indices = self.random_samples(train=True, num_samples=10)
test_sample_indices = self.random_samples(train=False, num_samples=10)
train_sample_indices = self.random_samples(train=True, num_samples=5)
test_sample_indices = self.random_samples(train=False, num_samples=5)
for epoch in range(epochs):
running_loss = 0.0
@@ -153,7 +153,7 @@ class DiffusionTrainer:
running_loss /= len(train_loader.dataset)
if epoch % 20 == 0 and epoch != 0:
if epoch % 40 == 0 and epoch != 0:
self.test(test_loader, epoch, task)
if task:
@@ -164,7 +164,7 @@ class DiffusionTrainer:
value=loss.item(),
)
if epoch % 100 == 0 and epoch != 0:
if epoch % 150 == 0 and epoch != 0:
self.debug_plots(task, True, train_loader, train_sample_indices, epoch)
self.debug_plots(task, False, test_loader, test_sample_indices, epoch)
@@ -177,6 +177,7 @@ class DiffusionTrainer:
features, target, _ = data_loader.dataset[idx]
features = features.to(self.device)
features = features.unsqueeze(0)
self.model.eval()
with torch.no_grad():