Made more changes

This commit is contained in:
2024-04-16 22:07:53 +02:00
parent 937b6abc0b
commit 0edcc91e65
12 changed files with 214 additions and 36 deletions

View File

@@ -208,7 +208,7 @@ class DiffusionTrainer:
running_loss /= len(train_loader.dataset)
if epoch % 150 == 0 and epoch != 0:
if epoch % 75 == 0 and epoch != 0:
crps, _ = self.test(test_loader, epoch, task)
if best_crps is None or crps < best_crps:
@@ -217,7 +217,7 @@ class DiffusionTrainer:
else:
early_stopping += 1
if early_stopping > 15:
if early_stopping > 5:
break
if task:
@@ -249,7 +249,7 @@ class DiffusionTrainer:
test_loader=test_loader,
initial_penalty=900,
target_charge_cycles=283,
learning_rate=1,
initial_learning_rate=1,
max_iterations=50,
tolerance=1,
)
@@ -438,9 +438,10 @@ class DiffusionTrainer:
test_loader=test_loader,
initial_penalty=self.prev_optimal_penalty,
target_charge_cycles=283,
learning_rate=1,
initial_learning_rate=1,
max_iterations=50,
tolerance=1,
iteration=epoch,
)
)

View File

@@ -192,9 +192,10 @@ class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer):
test_loader=dataloader,
initial_penalty=900,
target_charge_cycles=283,
learning_rate=2,
initial_learning_rate=5,
max_iterations=100,
tolerance=1,
iteration=epoch,
)
)
@@ -823,7 +824,7 @@ class NonAutoRegressiveQuantileRegression(Trainer):
test_loader=dataloader,
initial_penalty=500,
target_charge_cycles=283,
learning_rate=2,
initial_learning_rate=2,
max_iterations=100,
tolerance=1,
)