Implemented Probabilistic Baseline

This commit is contained in:
Victor Mylle
2023-11-19 20:20:31 +00:00
parent 1268af47a6
commit 166d3967e1
4 changed files with 182 additions and 36 deletions

View File

@@ -10,6 +10,7 @@
"sys.path.append('..')\n",
"from data import DataProcessor, DataConfig\n",
"from trainers.quantile_trainer import AutoRegressiveQuantileTrainer, NonAutoRegressiveQuantileRegression\n",
"from trainers.probabilistic_baseline import ProbabilisticBaselineTrainer\n",
"from trainers.autoregressive_trainer import AutoRegressiveTrainer\n",
"from trainers.trainer import Trainer\n",
"from utils.clearml import ClearMLHelper\n",
@@ -45,9 +46,44 @@
"data_config.WIND_FORECAST = False\n",
"data_config.WIND_HISTORY = False\n",
"\n",
"\n",
"data_processor = DataProcessor(data_config)\n",
"data_processor.set_batch_size(1024)"
"data_processor.set_batch_size(1024)\n",
"data_processor.set_full_day_skip(True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Probabilistic Baseline"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ClearML Task: created new task id=07ad9f41dfbb43ada3c15ec33a85050d\n",
"ClearML results page: http://192.168.1.182:8080/projects/2e46d4af6f1e4c399cf9f5aa30bc8795/experiments/07ad9f41dfbb43ada3c15ec33a85050d/output/log\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Can't get url information for git repo in /workspaces/Thesis/src/notebooks\n",
"JSON serialization of artifact 'dictionary' failed, reverting to pickle\n"
]
}
],
"source": [
"quantiles = [0.01, 0.05, 0.1, 0.15, 0.4, 0.5, 0.6, 0.85, 0.9, 0.95, 0.99]\n",
"trainer = ProbabilisticBaselineTrainer(quantiles=quantiles, data_processor=data_processor, clearml_helper=clearml_helper)\n",
"trainer.train()"
]
},
{
@@ -59,28 +95,33 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ClearML Task: created new task id=c6bc2cb556b84fed81fa04f5b4a323ea\n",
"ClearML results page: http://192.168.1.182:8080/projects/2e46d4af6f1e4c399cf9f5aa30bc8795/experiments/c6bc2cb556b84fed81fa04f5b4a323ea/output/log\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Can't get url information for git repo in /workspaces/Thesis/src/notebooks\n"
"InsecureRequestWarning: Certificate verification is disabled! Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#ssl-warnings\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"ClearML Task: created new task id=11553d672a2744479de07c9ac0a9dbde\n",
"ClearML results page: http://192.168.1.182:8080/projects/2e46d4af6f1e4c399cf9f5aa30bc8795/experiments/11553d672a2744479de07c9ac0a9dbde/output/log\n",
"2023-11-19 18:06:57,539 - clearml.Task - INFO - Storing jupyter notebook directly as code\n",
"2023-11-19 18:06:57,543 - clearml.Repository Detection - WARNING - Can't get url information for git repo in /workspaces/Thesis/src/notebooks\n",
"2023-11-19 18:07:14,402 - clearml.model - WARNING - 500 model found when searching for `file:///workspaces/Thesis/src/notebooks/checkpoint.pt`\n",
"2023-11-19 18:07:14,403 - clearml.model - WARNING - Selected model `Non Autoregressive Quantile Regression` (id=bc0cb0d7fc614e2e8b0edf5b85348646)\n",
"2023-11-19 18:07:14,412 - clearml.frameworks - INFO - Found existing registered model id=bc0cb0d7fc614e2e8b0edf5b85348646 [/workspaces/Thesis/src/notebooks/checkpoint.pt] reusing it.\n",
"2023-11-19 18:07:14,974 - clearml.Task - INFO - Completed model upload to http://192.168.1.182:8081/Thesis/NrvForecast/Non%20Autoregressive%20Model%20%28Non%20Linear%29%20using%20full%20day%20skip%20for%20training%20samples.11553d672a2744479de07c9ac0a9dbde/models/checkpoint.pt\n",
"2023-11-19 18:07:16,827 - clearml.Task - INFO - Completed model upload to http://192.168.1.182:8081/Thesis/NrvForecast/Non%20Autoregressive%20Model%20%28Non%20Linear%29%20using%20full%20day%20skip%20for%20training%20samples.11553d672a2744479de07c9ac0a9dbde/models/checkpoint.pt\n",
"2023-11-19 18:07:18,465 - clearml.Task - INFO - Completed model upload to http://192.168.1.182:8081/Thesis/NrvForecast/Non%20Autoregressive%20Model%20%28Non%20Linear%29%20using%20full%20day%20skip%20for%20training%20samples.11553d672a2744479de07c9ac0a9dbde/models/checkpoint.pt\n",
"2023-11-19 18:07:20,045 - clearml.Task - INFO - Completed model upload to http://192.168.1.182:8081/Thesis/NrvForecast/Non%20Autoregressive%20Model%20%28Non%20Linear%29%20using%20full%20day%20skip%20for%20training%20samples.11553d672a2744479de07c9ac0a9dbde/models/checkpoint.pt\n",
"2023-11-19 18:07:21,843 - clearml.Task - INFO - Completed model upload to http://192.168.1.182:8081/Thesis/NrvForecast/Non%20Autoregressive%20Model%20%28Non%20Linear%29%20using%20full%20day%20skip%20for%20training%20samples.11553d672a2744479de07c9ac0a9dbde/models/checkpoint.pt\n",
"2023-11-19 18:07:28,812 - clearml.Task - INFO - Completed model upload to http://192.168.1.182:8081/Thesis/NrvForecast/Non%20Autoregressive%20Model%20%28Non%20Linear%29%20using%20full%20day%20skip%20for%20training%20samples.11553d672a2744479de07c9ac0a9dbde/models/checkpoint.pt\n",
"Early stopping triggered\n"
]
}
@@ -187,36 +228,27 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspaces/Thesis/src/notebooks/../trainers/quantile_trainer.py:18: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
" quantiles_tensor = torch.tensor(quantiles)\n",
"/workspaces/Thesis/src/notebooks/../losses/pinball_loss.py:7: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
" self.quantiles_tensor = torch.tensor(quantiles, dtype=torch.float32)\n",
"InsecureRequestWarning: Certificate verification is disabled! Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#ssl-warnings\n"
"/workspaces/Thesis/src/notebooks/../trainers/quantile_trainer.py:18: UserWarning:\n",
"\n",
"To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
"\n",
"/workspaces/Thesis/src/notebooks/../losses/pinball_loss.py:7: UserWarning:\n",
"\n",
"To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"ClearML Task: created new task id=36976b1159074e698e2c19eb6a3bc290\n",
"ClearML results page: http://192.168.1.182:8080/projects/2e46d4af6f1e4c399cf9f5aa30bc8795/experiments/36976b1159074e698e2c19eb6a3bc290/output/log\n",
"2023-11-16 08:46:07,715 - clearml.Task - INFO - Storing jupyter notebook directly as code\n",
"2023-11-16 08:46:07,719 - clearml.Repository Detection - WARNING - Can't get url information for git repo in /workspaces/Thesis/src/notebooks\n",
"2023-11-16 08:46:15,693 - clearml.model - WARNING - 500 model found when searching for `file:///workspaces/Thesis/src/notebooks/checkpoint.pt`\n",
"2023-11-16 08:46:15,694 - clearml.model - WARNING - Selected model `Quantile Regression: Non Linear with test score` (id=bc0cb0d7fc614e2e8b0edf5b85348646)\n",
"2023-11-16 08:46:15,702 - clearml.frameworks - INFO - Found existing registered model id=bc0cb0d7fc614e2e8b0edf5b85348646 [/workspaces/Thesis/src/notebooks/checkpoint.pt] reusing it.\n",
"2023-11-16 08:46:16,218 - clearml.Task - INFO - Completed model upload to http://192.168.1.182:8081/Thesis/NrvForecast/Quantile%20Regression%253A%20Non%20Linear%20Debugging%20%28plot%20every%201%20epoch%29.36976b1159074e698e2c19eb6a3bc290/models/checkpoint.pt\n",
"2023-11-16 08:46:21,062 - clearml.Task - INFO - Completed model upload to http://192.168.1.182:8081/Thesis/NrvForecast/Quantile%20Regression%253A%20Non%20Linear%20Debugging%20%28plot%20every%201%20epoch%29.36976b1159074e698e2c19eb6a3bc290/models/checkpoint.pt\n",
"2023-11-16 08:46:33,228 - clearml.Task - INFO - Completed model upload to http://192.168.1.182:8081/Thesis/NrvForecast/Quantile%20Regression%253A%20Non%20Linear%20Debugging%20%28plot%20every%201%20epoch%29.36976b1159074e698e2c19eb6a3bc290/models/checkpoint.pt\n",
"2023-11-16 08:46:42,236 - clearml.Task - INFO - Completed model upload to http://192.168.1.182:8081/Thesis/NrvForecast/Quantile%20Regression%253A%20Non%20Linear%20Debugging%20%28plot%20every%201%20epoch%29.36976b1159074e698e2c19eb6a3bc290/models/checkpoint.pt\n",
"2023-11-16 08:46:50,541 - clearml.Task - INFO - Completed model upload to http://192.168.1.182:8081/Thesis/NrvForecast/Quantile%20Regression%253A%20Non%20Linear%20Debugging%20%28plot%20every%201%20epoch%29.36976b1159074e698e2c19eb6a3bc290/models/checkpoint.pt\n",
"Early stopping triggered\n"
]
},
@@ -224,7 +256,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 25804/25804 [22:46<00:00, 18.88it/s]\n"
"100%|██████████| 25804/25804 [20:02<00:00, 21.45it/s]\n"
]
}
],