From 0edcc91e658f094cf7fcf34ecaa217ad387cced1 Mon Sep 17 00:00:00 2001 From: Victor Mylle Date: Tue, 16 Apr 2024 22:07:53 +0200 Subject: [PATCH] Made more changes --- Reports/Thesis/sections/nrv_prediction.tex | 2 +- Reports/Thesis/verslag.pdf | Bin 802281 -> 802284 bytes Reports/Thesis/verslag.synctex.gz | Bin 110374 -> 110394 bytes src/notebooks/thesis-visualizations.ipynb | 2 +- src/policies/PolicyEvaluator.py | 98 ++++++++++++++-- src/policies/baselines/PerfectBaseline.py | 14 ++- .../YesterdayBaselinePolicyExecutor.py | 105 ++++++++++++++++-- .../baselines/global_threshold_baseline.py | 11 +- src/policies/baselines/perfect_baseline.py | 2 +- src/trainers/diffusion_trainer.py | 9 +- src/trainers/quantile_trainer.py | 5 +- .../autoregressive_quantiles.py | 2 +- 12 files changed, 214 insertions(+), 36 deletions(-) diff --git a/Reports/Thesis/sections/nrv_prediction.tex b/Reports/Thesis/sections/nrv_prediction.tex index cb6316c..a8464d6 100644 --- a/Reports/Thesis/sections/nrv_prediction.tex +++ b/Reports/Thesis/sections/nrv_prediction.tex @@ -162,7 +162,7 @@ where: \item \(\beta_{0,\tau}, \beta_{1,\tau}, \beta_{2,\tau}, \ldots, \beta_{n,\tau} \) are the coefficients \end{itemize} -The linear model outputs the values for the chosen quantiles. The total amount of parameters depends on the input features and the number of chosen quantiles. Assuming the input features are the 96 previous NRV values and 13 quantiles are chosen, the total amount of parameters is 96 * 13 = 1248. +The linear model outputs the values for the chosen quantiles. The total amount of parameters depends on the input features and the number of chosen quantiles. Assuming the input features are the 96 previous NRV values and 13 quantiles are chosen, the total amount of parameters is 96 * 13 + 13 = 1261. TODO: add results for this model diff --git a/Reports/Thesis/verslag.pdf b/Reports/Thesis/verslag.pdf index 70546f90cfe484ee28feb302a298d7c5527fd840..968d43c072c5238e07d50766e604d9e4ec231e4e 100644 GIT binary patch delta 1743 zcmV;=1~B>Q`!MYLFo1*sgaU*Ev;C=s7cCsqV;fm6;-ZiLzK}S%NpSMp1@`05>xHc=emLMM`IgyGM6ll+seGoHTr<;{yMryQprvYgig>GHmWR_>3>7n z*KOl}cDAkmwXd71^m9NoeZ8;T>X@mty0PD`n+_gYLu*s*vwzpWuPzx|ibg0!k|jA+ zhzsdN`)dM9q$MT>Cy|sS$qAC0In)&{m5dta6 zN5oYjum_1Hh$04t^`Lh#aTH>LjffRkM5^?ERn!s+2{s|GFqyTR%^*-w$iNIZVyps5 zJSq3`K0^V&A~4kRs8|>={1~xBAQpLwSOG(hv?5EfK1UdGB&3cqq$fTUNR%OYmGyN<%$r7m@1MW**O+} z@UdP_a>g)H97q)=OFib(uCk@)g~3eaaCIp|XKO?WIFd3DSbOS};FCjlA_TK3&wqG* z`Vy~CoT?|)r_b^Dp);9P%pDR@HWk9Ub49)1RWY%t28N4N8ch&V*JcJSBMa*b+ zQa86l@7ros_JeDgZo`CnMkI7#=*Z{>!LFwBQ7znfYtLjd{aCj{H>+zWAItqvbz{EV zuRbM+WBOIe@Nl^-&NnLggpS}zy@cC}koU0$`83x7Z}JUF_2YS~fH(b?dPnPjiJ()Y z*{UCJahlB>@94nJ+izL+PXe7HrS($C(go&e@|rtfWvE=&RnvE_juR>nFb(2wZO#-} zO!rJPpYM=%*$zDF)_O?>fPMlR?TGkd%bJBs&VtN)O3(tOyH3x&}#eA^C%DP z(Da_)){VF2u{>0{6&{250tw=O)rOn1$wI|I%3G^!v+*0YJi+L!q4<>kuotF3Y?RF= zK(n+N4*q(6^z^oF95pf|ml_qCrLQwW)2K%EMqETWiQDaPZlX$ugH zcM1@ITmgcgqbn>Gs}cMqh1-~9b=xzS!mIR?S;(2C_lhOvybmGVRU66C`39kFfueSG%+?ZF)}wWFt?>K44Mg1ML9V!GdDOwGB7bSH#0CYI5asi zF-9;#MK?t`L^wk*J|H|rIXN*iH#kBvFflVXGcYnZG&wObMleD}H$^!_I72W#T?#Ku zWo~D5XdpE*GnbG%3@Lvl(#>y7VHC&lbMAAiGj*$@?P#@SsOp=lc3Q9r z*V>4Lgru<|u@OlVB+~8k2Uv>4&R?Nvh-hPNXCV>c`5v23e&;^tzRct&BJwv9iA2+d zvXE3Is)$NYpN%OMsk5{Jssp|j5@_OD1WqM0YkM(NsSLK+xDkKSDsxY4tl_pzq|K11 zRF>?3hemc#$4Z7|G(x?~oi*EQfUS^KxnHttTezd1WFOQF+o45e@u#h{vci+2HrNKc zAgA*5o~`VFozTuQFY>(H^QGS2(X8@n(Ux~ZH*~1HcxTr-p-WxP1&bbSRo9Rgtce(xeK> zdU}myO}tKWq#JNkH%RY3*L->(XS^Q!Nwli#RTh06aP z%D!%Y8^5z_!>>c#RArb0Mbp=Z%CAl^Wli9l24H7Y?A~)cK zvlvo%GVb+bjskfsa@5yJwJ_rNIcmvBE%FkzB7p)KO_pkXN(dB4Dw7mQ&tfRae`O;5 zEEDN-Ow>t%PnlRyO7k1am7gVTW{|l=Cv4^yv#oj5CWA~eYW!l<8c!9&6(JmP&9bO} z*aepGaZ%0+&N0#gSTS3rf#&1Ba?tb2(oAr;zE*Lt4Wa}BNjWHN0(VOA*`Ws+g85YE z|2sbYlg1}O^^5T7Q#ya>%x0YmheDK3Kv;Jk)cbvvQkxnOxG1I31u;i~gqb)N{XW%8 z+TAS3O>Wy}TibS29{k3(?)p>xewbT-)Ppm>{Pz3BWGFNgtYS}QL+sQz8{VZ;7;(W; zPIHsGxgYw_R;#ife9QD3Hq}=|K?jD8jP4NZ8af{}(vSD<%u3}S>UQX6bKT@ac^Im0 zte3~lrvwR1KdTs?hD&z2QK=Vnguv=G+*YJ|Of=|cgo$L6F96lgm#rh&^h?@*ox~GK zr%ZFGpKfu6&k`Tlz%To63HxV>PLVcdsbuLB^SpTN9k8-oZtJS)JCEapN(4fq{M*FT#hXaaGJfi^i6ns-aV?xw~w-yl*)eoDie*uRXVkwtVS_Bh^ zci#kuci#nvci#qwci#txci#wyci#zzci#%Pci#&IQYthuF)%eSG%GMNConK4DGD!5 zZ)8MabY&nYL^?7sGBhwWF*GtUFf}kVx27=+nh8-hL^wn?I6^W)G%`0vGe$QxL^v=s zH#0LdF*PtbF+?~%AUrliI7Bu$LNY=$GB-vuMmIG?I50FfGcz$Bo67PpTWM#7Cr2Hy6?Q^i z<>_5p*$z9PlVx7ydAa9Hy}hGV<<+t+?}A?FQhD*tu609?#)X$Q?u9;$M{_pz!)_RW zJunCb7}9v`uYc_gYh3BIbO82gT>fO^2<(SZjlVn-jep#?gTru8<8QBRJOo8`r8_nr zhcR{K$2J~;qcF}d%8`zd>_0sNC3TZ)wtNCkLK#lMX_!&Buw*M0m{2!2V&hpj2b1bP zzqe}_;Jmu8<2Fvgw7Tz)Y`h4wFvn+DU*_QwV~TW{M1S?7D^^bP?E}(0k ia`_%jh6{d&=M`)69|0y^cbCIG3?T?K3MC~)Peuy*7bNQd diff --git a/Reports/Thesis/verslag.synctex.gz b/Reports/Thesis/verslag.synctex.gz index cac3cd398854b3af6712ced1499f1d0b7c375a71..401bedfa8fd2f31bd5c641998883e7eda339cb58 100644 GIT binary patch delta 624 zcmV-$0+0Qs-Uhnf27t5y+N^)V$(4^3H6I0`{1wAsN5H(1*4Mj8Wz(@jJ=(%4lD zA{dJqZ#!dUT374I->hsy^GsJeAMzw!{Nj)&b<}Ssd6P8ix3lDalr-wMYfQ1hf;38F ziiaPG43j)^Dv87(%>xvbT3uyr0!1Z3ffEUIeFOzxh(uTLWnV-2*NU`icfz^nr34 zi&Pwzp5;xFP;)ah$^9hp*@M7FV+!akgC!WO(qzmK3=&O{ZbU`6J*q^aPaHwOef13G z@af4jE5s+J%Gb|OV2*!9htnc&@d2PPg}$-|B!+?`u-w9c(;AE~tuTUEJ5VzsPWVEm zDBPM8X&cJ2BH&bkh-7gmyL)OzMk1kX0VdJIsROpYId&*93-r+TlvzPaoD*68BD3L~ zm(nMCd|%;n7RF)<%ejP3V5x%dOJ9*CXC7|;XX4Gz&tG1Sa6^B$t6vYB{bBWWdw1OJ zpT0dlZdSi+Kke81Z>#r@yO+-&ZD@3IjnyDY4dzsz27~_12_BMKi+LO&v%=Z-Msp9z22|>{m168 zJ*@t<`Lf#|594IdcdM^r)Z_a8-TLJ}ld<{lYV2Wiciirt53B9-aliZgQh!~TUjFu` ze_G!?-0ycU&!4N|uS1A8|J)sp>!(kTo7@qjn4ZK{%2hr9ju%kkyu_U8YO KxDh?qWCsBBZA8%k delta 604 zcmV-i0;B!9-Ug=L27t5y+N^(a2vGpT>nI{GcY1xYa2k*_d_P;Jz5bq%&C1p=&%FNg z!AsJ`FAiQ(N5^)OH%X&oJ4^0INuy)C#uOVYh@3R0c#Q4HFv%m&kw^?uFF+fq)m0`X z&_)szc#1%wM^NxjNG>KQ@b&si{H9RyHu6x3-xMm~HiVRzC?el0cJ+UH7++z>_jZ0M zrpW$8cH84aNtxOQE@G9jT$d)$J>c^2t!ThSA1Jr6NX234S>7ZGH8(?(+)ompJxFCV zrhsxXSc1XIDaV0eko;3BAe{2rh!|Gp~FT4HmF#deE`YJ{}uJ7Ni zU;Z;0oByuH9yU;S$L;R#{eJiI{J9$bI)r%h&)wm; qe){ydxxINWlH+YIKHj{xsY+fQ?)KX+$CszuoBso%KN`wwWCsB8OGC^6 diff --git a/src/notebooks/thesis-visualizations.ipynb b/src/notebooks/thesis-visualizations.ipynb index c812f82..76725a4 100644 --- a/src/notebooks/thesis-visualizations.ipynb +++ b/src/notebooks/thesis-visualizations.ipynb @@ -171,7 +171,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.6" + "version": "3.undefined.undefined" } }, "nbformat": 4, diff --git a/src/policies/PolicyEvaluator.py b/src/policies/PolicyEvaluator.py index 2fef398..2a97854 100644 --- a/src/policies/PolicyEvaluator.py +++ b/src/policies/PolicyEvaluator.py @@ -44,8 +44,8 @@ class PolicyEvaluator: date, idx_samples, test_loader, - charge_thresholds=np.arange(-1500, 1500, 50), - discharge_thresholds=np.arange(-1500, 1500, 50), + charge_thresholds=np.arange(-1000, 1000, 5), + discharge_thresholds=np.arange(-1000, 1000, 5), penalty: int = 0, state_of_charge: float = 0.0, ): @@ -96,6 +96,7 @@ class PolicyEvaluator: max_iterations=10, tolerance=10, learning_rate_decay=0.9, # Factor to reduce the learning rate after each iteration + iteration=0, ): self.cache = {} penalty = initial_penalty @@ -139,7 +140,7 @@ class PolicyEvaluator: # Re-calculate profit and charge cycles for the final penalty to return accurate results profit, charge_cycles = self.evaluate_test_set_for_penalty( - idx_samples, test_loader, penalty + idx_samples, test_loader, penalty, log_metrics=True, iteration=iteration ) return penalty, profit, charge_cycles @@ -232,25 +233,45 @@ class PolicyEvaluator: ], ) - def evaluate_test_set_for_penalty(self, idx_samples, test_loader, penalty): + def evaluate_test_set_for_penalty( + self, idx_samples, test_loader, penalty, log_metrics=False, iteration: int = 0 + ): total_profit = 0 total_charge_cycles = 0 state_of_charge = 0.0 + loggings = [] + for date in tqdm(self.dates): try: - profit, charge_cycles, _, _, new_state_of_charge = ( - self.evaluate_for_date( - date, - idx_samples, - test_loader, - penalty=penalty, - state_of_charge=state_of_charge, - ) + ( + profit, + charge_cycles, + charge_thresholds, + discharge_thresholds, + new_state_of_charge, + ) = self.evaluate_for_date( + date, + idx_samples, + test_loader, + penalty=penalty, + state_of_charge=state_of_charge, ) state_of_charge = new_state_of_charge total_profit += profit total_charge_cycles += charge_cycles + + new_info = { + "Date": date, + "Profit": profit, + "Charge Cycles": charge_cycles, + "State of Charge": state_of_charge, + "Charge Threshold": charge_thresholds, + "Discharge Threshold": discharge_thresholds, + } + + loggings.append(new_info) + except KeyboardInterrupt: print("Interrupted") raise KeyboardInterrupt @@ -259,6 +280,59 @@ class PolicyEvaluator: print(e) pass + if log_metrics: + log_df = pd.DataFrame(loggings) + + fig = px.line( + log_df, + x="Date", + y="Profit", + title="Profit over time", + labels={"Profit": "Profit (€)", "Date": "Date"}, + ) + + self.task.get_logger().report_plotly( + "Profit", "Profit", iteration=iteration, figure=fig + ) + + fig = px.line( + log_df, + x="Date", + y="Charge Cycles", + title="Charge Cycles over time", + labels={"Charge Cycles": "Charge Cycles", "Date": "Date"}, + ) + + self.task.get_logger().report_plotly( + "Charge Cycles", "Charge Cycles", iteration=iteration, figure=fig + ) + + fig = px.line( + log_df, + x="Date", + y="State of Charge", + title="State of Charge over time", + labels={"State of Charge": "State of Charge", "Date": "Date"}, + ) + + self.task.get_logger().report_plotly( + "State of Charge", "State of Charge", iteration=iteration, figure=fig + ) + + fig = px.line( + log_df, + x="Date", + y=["Charge Threshold", "Discharge Threshold"], + title="Charge and Discharge Thresholds per Day", + ) + + self.task.get_logger().report_plotly( + "Thresholds per Day", + "Thresholds per Day", + iteration=iteration, + figure=fig, + ) + return total_profit, total_charge_cycles def plot_profits_table(self): diff --git a/src/policies/baselines/PerfectBaseline.py b/src/policies/baselines/PerfectBaseline.py index b2b928b..549f1aa 100644 --- a/src/policies/baselines/PerfectBaseline.py +++ b/src/policies/baselines/PerfectBaseline.py @@ -1,5 +1,5 @@ from clearml import Task -from policies.simple_baseline import BaselinePolicy +from src.policies.simple_baseline import BaselinePolicy from src.policies.baselines.YesterdayBaselinePolicyExecutor import ( YesterdayBaselinePolicyEvaluator, ) @@ -14,17 +14,21 @@ class PerfectBaseline(YesterdayBaselinePolicyEvaluator): def evaluate_for_date( self, date, - charge_thresholds=np.arange(-100, 250, 25), - discharge_thresholds=np.arange(-100, 250, 25), + charge_thresholds=np.arange(-300, 300, 5), + discharge_thresholds=np.arange(-300, 300, 5), penalty: int = 0, current_state_of_charge=0.0, ): real_imbalance_prices = self.get_imbanlance_prices_for_date(date.date()) + real_imbalance_prices_tensor = torch.tensor( + np.array([real_imbalance_prices]), device="cpu" + ) + best_charge_thresholds, best_discharge_thresholds = ( self.baseline_policy.get_optimal_thresholds( - real_imbalance_prices, + real_imbalance_prices_tensor, charge_thresholds, discharge_thresholds, penalty, @@ -45,4 +49,6 @@ class PerfectBaseline(YesterdayBaselinePolicyEvaluator): best_profit[0][0].item(), best_charge_cycles[0][0].item(), new_state_of_charge.squeeze(0).item(), + best_charge_thresholds.mean(axis=0), + best_discharge_thresholds.mean(axis=0), ) diff --git a/src/policies/baselines/YesterdayBaselinePolicyExecutor.py b/src/policies/baselines/YesterdayBaselinePolicyExecutor.py index f849a68..a38d7c0 100644 --- a/src/policies/baselines/YesterdayBaselinePolicyExecutor.py +++ b/src/policies/baselines/YesterdayBaselinePolicyExecutor.py @@ -6,6 +6,7 @@ import numpy as np import pandas as pd from tqdm import tqdm import torch +import plotly.express as px class YesterdayBaselinePolicyEvaluator(PolicyEvaluator): @@ -15,8 +16,8 @@ class YesterdayBaselinePolicyEvaluator(PolicyEvaluator): def evaluate_for_date( self, date, - charge_thresholds=np.arange(-100, 250, 25), - discharge_thresholds=np.arange(-100, 250, 25), + charge_thresholds=np.arange(-500, 500, 5), + discharge_thresholds=np.arange(-500, 500, 5), penalty: int = 0, current_state_of_charge=0.0, ): @@ -52,9 +53,13 @@ class YesterdayBaselinePolicyEvaluator(PolicyEvaluator): yesterday_profit[0][0].item(), yesterday_charge_cycles[0][0].item(), new_state_of_charge.squeeze(0).item(), + yesterday_charge_thresholds.mean(axis=0), + yesterday_discharge_thresholds.mean(axis=0), ) - def evaluate_test_set_for_penalty(self, data_processor, penalty: int = 0): + def evaluate_test_set_for_penalty( + self, data_processor, penalty: int = 0, log_metrics=False + ): if data_processor: filtered_dates = [] @@ -71,20 +76,89 @@ class YesterdayBaselinePolicyEvaluator(PolicyEvaluator): charge_cycles = 0 state_of_charge = 0.0 + loggings = [] + for date in tqdm(self.dates): try: - new_profit, new_charge_cycles, new_state_of_charge = ( - self.evaluate_for_date( - date, penalty=penalty, current_state_of_charge=state_of_charge - ) + ( + new_profit, + new_charge_cycles, + new_state_of_charge, + charge_threshold, + discharge_threshold, + ) = self.evaluate_for_date( + date, penalty=penalty, current_state_of_charge=state_of_charge ) + profit += new_profit charge_cycles += new_charge_cycles state_of_charge = new_state_of_charge + + new_info = { + "Date": date, + "Profit": profit, + "Charge Cycles": charge_cycles, + "State of Charge": state_of_charge, + "Charge Threshold": charge_threshold.item(), + "Discharge Threshold": discharge_threshold.item(), + } + + loggings.append(new_info) + except Exception as e: print(e) pass + if log_metrics: + log_df = pd.DataFrame(loggings) + + fig = px.line( + log_df, + x="Date", + y="Profit", + title="Profit over time", + labels={"Profit": "Profit (€)", "Date": "Date"}, + ) + + self.task.get_logger().report_plotly( + "Profit", "Profit", iteration=0, figure=fig + ) + + fig = px.line( + log_df, + x="Date", + y="Charge Cycles", + title="Charge Cycles over time", + labels={"Charge Cycles": "Charge Cycles", "Date": "Date"}, + ) + + self.task.get_logger().report_plotly( + "Charge Cycles", "Charge Cycles", iteration=0, figure=fig + ) + + fig = px.line( + log_df, + x="Date", + y="State of Charge", + title="State of Charge over time", + labels={"State of Charge": "State of Charge", "Date": "Date"}, + ) + + self.task.get_logger().report_plotly( + "State of Charge", "State of Charge", iteration=0, figure=fig + ) + + fig = px.line( + log_df, + x="Date", + y=["Charge Threshold", "Discharge Threshold"], + title="Charge and Discharge Thresholds per Day", + ) + + self.task.get_logger().report_plotly( + "Thresholds per Day", "Thresholds per Day", iteration=0, figure=fig + ) + return profit, charge_cycles def optimize_penalty_for_target_charge_cycles( @@ -108,6 +182,21 @@ class YesterdayBaselinePolicyEvaluator(PolicyEvaluator): f"Penalty: {penalty}, Charge Cycles: {simulated_charge_cycles}, Profit: {simulated_profit}" ) + self.task.get_logger().report_scalar( + "Penalty", "Penalty", penalty, iteration=iteration + ) + + self.task.get_logger().report_scalar( + "Charge Cycles", + "Charge Cycles", + simulated_charge_cycles, + iteration=iteration, + ) + + self.task.get_logger().report_scalar( + "Profit", "Profit", simulated_profit, iteration=iteration + ) + # Calculate the gradient (difference) between the simulated and target charge cycles gradient = simulated_charge_cycles - target_charge_cycles @@ -125,7 +214,7 @@ class YesterdayBaselinePolicyEvaluator(PolicyEvaluator): # Re-calculate profit and charge cycles for the final penalty to return accurate results profit, charge_cycles = self.evaluate_test_set_for_penalty( - data_processor, penalty + data_processor, penalty, log_metrics=True ) return penalty, profit, charge_cycles diff --git a/src/policies/baselines/global_threshold_baseline.py b/src/policies/baselines/global_threshold_baseline.py index fb7ea79..998c743 100644 --- a/src/policies/baselines/global_threshold_baseline.py +++ b/src/policies/baselines/global_threshold_baseline.py @@ -55,8 +55,15 @@ charge_discharge_threshold, total_profit, total_charge_cycles = ( policy_evaluator.determine_best_thresholds_test_set(data_processor) ) -task.get_logger().report_single_value(name="Optimal Profit", value=profit) -task.get_logger().report_single_value(name="Optimal Charge Cycles", value=charge_cycles) +print("Thresholds determined on test set") +print(f"Best Charge Discharge Threshold: {charge_discharge_threshold}") +print(f"Total Profit: {total_profit}") +print(f"Total Charge Cycles: {total_charge_cycles}") + +task.get_logger().report_single_value(name="Optimal Profit", value=total_profit) +task.get_logger().report_single_value( + name="Optimal Charge Cycles", value=total_charge_cycles +) task.get_logger().report_single_value( name="Optimal Charge Threshold", value=charge_discharge_threshold[0] ) diff --git a/src/policies/baselines/perfect_baseline.py b/src/policies/baselines/perfect_baseline.py index fee6ccb..10e68a2 100644 --- a/src/policies/baselines/perfect_baseline.py +++ b/src/policies/baselines/perfect_baseline.py @@ -7,7 +7,7 @@ task.execute_remotely(queue_name="default", exit_process=True) from src.policies.simple_baseline import BaselinePolicy, Battery from src.data import DataProcessor, DataConfig -from policies.baselines.PerfectBaseline import PerfectBaseline +from src.policies.baselines.PerfectBaseline import PerfectBaseline ### Data Processor ### data_config = DataConfig() diff --git a/src/trainers/diffusion_trainer.py b/src/trainers/diffusion_trainer.py index 7d69e13..b51671b 100644 --- a/src/trainers/diffusion_trainer.py +++ b/src/trainers/diffusion_trainer.py @@ -208,7 +208,7 @@ class DiffusionTrainer: running_loss /= len(train_loader.dataset) - if epoch % 150 == 0 and epoch != 0: + if epoch % 75 == 0 and epoch != 0: crps, _ = self.test(test_loader, epoch, task) if best_crps is None or crps < best_crps: @@ -217,7 +217,7 @@ class DiffusionTrainer: else: early_stopping += 1 - if early_stopping > 15: + if early_stopping > 5: break if task: @@ -249,7 +249,7 @@ class DiffusionTrainer: test_loader=test_loader, initial_penalty=900, target_charge_cycles=283, - learning_rate=1, + initial_learning_rate=1, max_iterations=50, tolerance=1, ) @@ -438,9 +438,10 @@ class DiffusionTrainer: test_loader=test_loader, initial_penalty=self.prev_optimal_penalty, target_charge_cycles=283, - learning_rate=1, + initial_learning_rate=1, max_iterations=50, tolerance=1, + iteration=epoch, ) ) diff --git a/src/trainers/quantile_trainer.py b/src/trainers/quantile_trainer.py index 378dbbd..2f4c748 100644 --- a/src/trainers/quantile_trainer.py +++ b/src/trainers/quantile_trainer.py @@ -192,9 +192,10 @@ class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer): test_loader=dataloader, initial_penalty=900, target_charge_cycles=283, - learning_rate=2, + initial_learning_rate=5, max_iterations=100, tolerance=1, + iteration=epoch, ) ) @@ -823,7 +824,7 @@ class NonAutoRegressiveQuantileRegression(Trainer): test_loader=dataloader, initial_penalty=500, target_charge_cycles=283, - learning_rate=2, + initial_learning_rate=2, max_iterations=100, tolerance=1, ) diff --git a/src/training_scripts/autoregressive_quantiles.py b/src/training_scripts/autoregressive_quantiles.py index 949898c..425e267 100644 --- a/src/training_scripts/autoregressive_quantiles.py +++ b/src/training_scripts/autoregressive_quantiles.py @@ -113,7 +113,7 @@ trainer = AutoRegressiveQuantileTrainer( data_processor, quantiles, "cuda", - policy_evaluator=policy_evaluator, + policy_evaluator=None, debug=False, )