Improved policy executer

This commit is contained in:
Victor Mylle
2024-01-16 23:22:05 +00:00
parent d1074281c4
commit b87ad1bf42
7 changed files with 1328 additions and 101 deletions

View File

@@ -2,12 +2,13 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"sys.path.append('../..')"
"sys.path.append('../..')\n",
"import torch"
]
},
{
@@ -68,57 +69,30 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Index(['datetime', 'nrv', 'load_forecast', 'total_load', 'wind_forecast',\n",
" 'wind_history', 'nominal_net_position', 'quarter', 'day_of_week'],\n",
" dtype='object')\n"
"ClearML Task: created new task id=b71216825809432682ea3c7841c07612\n",
"ClearML results page: http://192.168.1.182:8080/projects/2e46d4af6f1e4c399cf9f5aa30bc8795/experiments/b71216825809432682ea3c7841c07612/output/log\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"InsecureRequestWarning: Certificate verification is disabled! Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#ssl-warnings\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"ClearML Task: created new task id=348145474a1140a6bdf8c81553a358b2\n",
"ClearML results page: http://192.168.1.182:8080/projects/2e46d4af6f1e4c399cf9f5aa30bc8795/experiments/348145474a1140a6bdf8c81553a358b2/output/log\n",
"2023-12-28 20:57:29,259 - clearml.Task - INFO - Storing jupyter notebook directly as code\n",
"Index(['datetime', 'nrv', 'load_forecast', 'total_load', 'wind_forecast',\n",
" 'wind_history', 'nominal_net_position', 'quarter', 'day_of_week'],\n",
" dtype='object')\n",
"Index(['datetime', 'nrv', 'load_forecast', 'total_load', 'wind_forecast',\n",
" 'wind_history', 'nominal_net_position', 'quarter', 'day_of_week'],\n",
" dtype='object')\n",
"Index(['datetime', 'nrv', 'load_forecast', 'total_load', 'wind_forecast',\n",
" 'wind_history', 'nominal_net_position', 'quarter', 'day_of_week'],\n",
" dtype='object')\n",
"Index(['datetime', 'nrv', 'load_forecast', 'total_load', 'wind_forecast',\n",
" 'wind_history', 'nominal_net_position', 'quarter', 'day_of_week'],\n",
" dtype='object')\n",
"Index(['datetime', 'nrv', 'load_forecast', 'total_load', 'wind_forecast',\n",
" 'wind_history', 'nominal_net_position', 'quarter', 'day_of_week'],\n",
" dtype='object')\n",
"Index(['datetime', 'nrv', 'load_forecast', 'total_load', 'wind_forecast',\n",
" 'wind_history', 'nominal_net_position', 'quarter', 'day_of_week'],\n",
" dtype='object')\n"
"500 model found when searching for `file:///workspaces/Thesis/src/notebooks/checkpoint.pt`\n",
"Selected model `Autoregressive Non Linear Quantile Regression + Quarter + DoW + Net` (id=bc0cb0d7fc614e2e8b0edf5b85348646)\n"
]
}
],
"source": [
"inputDim = data_processor.get_input_size()\n",
"learningRate = 0.0001\n",
"epochs=5000\n",
"epochs=150\n",
"\n",
"#### Model ####\n",
"model = SimpleDiffusionModel(96, [512, 512, 512], other_inputs_dim=inputDim[1], time_dim=64)\n",
@@ -138,6 +112,116 @@
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'src.models.diffusion_model.SimpleDiffusionModel'>\n"
]
}
],
"source": [
"new_model = torch.load(\"checkpoint.pt\")\n",
"print(type(new_model))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Determine threshold based on predictions\n",
"from src.models.diffusion_model import DiffusionModel\n",
"\n",
"\n",
"def get_predicted_NRV(date):\n",
" idx = test_loader.dataset.get_idx_for_date(date.date())\n",
" initial, _, samples, target = auto_regressive(test_loader.dataset, [idx]*500, 96)\n",
" samples = samples.cpu().numpy()\n",
" target = target.cpu().numpy()\n",
"\n",
" # inverse using data_processor\n",
" samples = data_processor.inverse_transform(samples)\n",
" target = data_processor.inverse_transform(target)\n",
"\n",
" return initial.cpu().numpy()[0][-1], samples, target\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"def sample_diffusion(model: DiffusionModel, n: int, inputs: torch.tensor):\n",
" noise_steps = 1000\n",
" beta_start = 1e-4\n",
" beta_end = 0.02\n",
" ts_length = 96\n",
" \n",
" beta = torch.linspace(beta_start, beta_end, noise_steps).to(device)\n",
" alpha = 1. - beta\n",
" alpha_hat = torch.cumprod(alpha, dim=0)\n",
"\n",
" inputs = inputs.repeat(n, 1).to(device)\n",
" model.eval()\n",
" with torch.no_grad():\n",
" x = torch.randn(inputs.shape[0], ts_length).to(device)\n",
" for i in reversed(range(1, noise_steps)):\n",
" t = (torch.ones(inputs.shape[0]) * i).long().to(device)\n",
" predicted_noise = model(x, t, inputs)\n",
" _alpha = alpha[t][:, None]\n",
" _alpha_hat = alpha_hat[t][:, None]\n",
" _beta = beta[t][:, None]\n",
"\n",
" if i > 1:\n",
" noise = torch.randn_like(x)\n",
" else:\n",
" noise = torch.zeros_like(x)\n",
"\n",
" x = 1/torch.sqrt(_alpha) * (x-((1-_alpha) / (torch.sqrt(1 - _alpha_hat))) * predicted_noise) + torch.sqrt(_beta) * noise\n",
" return x\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-178.8835, -47.2518, -103.9158, -9.8302, 15.9751, 138.9138,\n",
" -56.8392, -128.0629, -128.3637, -83.1066, 56.6656, -200.4618,\n",
" 10.8563, -146.4262, 120.4816, -60.1130, -18.7972, -214.0427,\n",
" 148.1229, 136.0194, 33.7580, 85.7884, -164.5678, 53.8879,\n",
" 187.6217, -77.5978, 153.7462, -129.1419, -149.8551, 118.4640,\n",
" -29.4688, -37.3348, -104.4318, -16.1735, -29.9716, -1.4205,\n",
" -130.6785, 23.8387, 75.6755, 113.8617, -61.4832, -81.3838,\n",
" -15.3194, -63.5703, 215.4112, 8.0719, 26.4597, 72.4347,\n",
" -23.1216, 44.8453, -12.2994, 94.7612, -162.2193, 18.0694,\n",
" 31.2402, 78.6964, 35.1892, -105.0744, 38.7805, -27.5867,\n",
" 39.5985, 136.5500, -179.8039, 231.9039, 116.1411, -226.0043,\n",
" -149.2595, -14.5097, 123.5570, 162.4510, -62.9467, -82.3552,\n",
" 187.5180, 12.3145, -189.3492, -159.3642, -144.8646, 130.9768,\n",
" -79.4541, 53.5424, 35.7119, 134.5416, -87.5582, 70.4020,\n",
" -44.0516, 111.3181, 17.0087, -14.9322, -187.4202, -41.7765,\n",
" 11.2264, 221.0164, -106.3083, -123.9814, -12.2132, -121.7845]],\n",
" device='cuda:0')"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inputs = torch.randn(1, 672).to(device)\n",
"sample_diffusion(new_model, 1, inputs)"
]
},
{
"cell_type": "code",
"execution_count": null,