Changed imports
This commit is contained in:
@@ -2,15 +2,15 @@ import pandas as pd
|
||||
import numpy as np
|
||||
from sklearn.preprocessing import MinMaxScaler
|
||||
import torch
|
||||
from data.dataset import NrvDataset
|
||||
from src.data.dataset import NrvDataset
|
||||
from datetime import datetime
|
||||
import pytz
|
||||
|
||||
|
||||
history_data_path = "../../data/history-quarter-hour-data.csv"
|
||||
forecast_data_path = "../../data/load_forecast.csv"
|
||||
pv_forecast_data_path = "../../data/pv_gen_forecast.csv"
|
||||
wind_forecast_data_path = "../../data/wind_gen_forecast.csv"
|
||||
history_data_path = "data/history-quarter-hour-data.csv"
|
||||
forecast_data_path = "data/load_forecast.csv"
|
||||
pv_forecast_data_path = "data/pv_gen_forecast.csv"
|
||||
wind_forecast_data_path = "data/wind_gen_forecast.csv"
|
||||
|
||||
|
||||
class DataConfig:
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
import sys
|
||||
sys.path.append('..')
|
||||
from data import DataProcessor, DataConfig
|
||||
from trainers.quantile_trainer import AutoRegressiveQuantileTrainer, NonAutoRegressiveQuantileRegression
|
||||
from trainers.probabilistic_baseline import ProbabilisticBaselineTrainer
|
||||
from trainers.autoregressive_trainer import AutoRegressiveTrainer
|
||||
from trainers.trainer import Trainer
|
||||
from utils.clearml import ClearMLHelper
|
||||
from models import *
|
||||
from losses import *
|
||||
from src.data import DataProcessor, DataConfig
|
||||
from src.trainers.quantile_trainer import AutoRegressiveQuantileTrainer, NonAutoRegressiveQuantileRegression
|
||||
from src.trainers.probabilistic_baseline import ProbabilisticBaselineTrainer
|
||||
from src.trainers.autoregressive_trainer import AutoRegressiveTrainer
|
||||
from src.trainers.trainer import Trainer
|
||||
from src.utils.clearml import ClearMLHelper
|
||||
from src.models import *
|
||||
from src.losses import *
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch.nn import MSELoss, L1Loss
|
||||
@@ -15,10 +13,6 @@ from datetime import datetime
|
||||
import pytz
|
||||
import torch.nn as nn
|
||||
|
||||
# auto reload
|
||||
%load_ext autoreload
|
||||
%autoreload 2
|
||||
|
||||
#### ClearML ####
|
||||
clearml_helper = ClearMLHelper(project_name="Thesis/NrvForecast")
|
||||
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
from clearml import OutputModel
|
||||
import torch
|
||||
from data.preprocessing import DataProcessor
|
||||
from utils.clearml import ClearMLHelper
|
||||
from utils.autoregressive import predict_auto_regressive
|
||||
from src.data.preprocessing import DataProcessor
|
||||
from src.utils.clearml import ClearMLHelper
|
||||
from src.utils.autoregressive import predict_auto_regressive
|
||||
import plotly.graph_objects as go
|
||||
import numpy as np
|
||||
import plotly.subplots as sp
|
||||
from plotly.subplots import make_subplots
|
||||
from trainers.trainer import Trainer
|
||||
from src.trainers.trainer import Trainer
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from losses import CRPSLoss
|
||||
from utils.clearml import ClearMLHelper
|
||||
from data.preprocessing import DataProcessor, DataConfig
|
||||
from src.losses import CRPSLoss
|
||||
from src.utils.clearml import ClearMLHelper
|
||||
from src.data.preprocessing import DataProcessor, DataConfig
|
||||
import numpy as np
|
||||
import plotly.graph_objects as go
|
||||
from trainers.trainer import Trainer
|
||||
from src.trainers.trainer import Trainer
|
||||
import torch
|
||||
|
||||
|
||||
|
||||
@@ -1,15 +1,11 @@
|
||||
import torch
|
||||
from utils.autoregressive import predict_auto_regressive_quantile
|
||||
from scipy.interpolate import interp1d
|
||||
from trainers.trainer import Trainer
|
||||
from trainers.autoregressive_trainer import AutoRegressiveTrainer
|
||||
from data.preprocessing import DataProcessor
|
||||
from utils.clearml import ClearMLHelper
|
||||
from losses import PinballLoss, NonAutoRegressivePinballLoss, CRPSLoss
|
||||
from plotly.subplots import make_subplots
|
||||
from src.trainers.trainer import Trainer
|
||||
from src.trainers.autoregressive_trainer import AutoRegressiveTrainer
|
||||
from src.data.preprocessing import DataProcessor
|
||||
from src.utils.clearml import ClearMLHelper
|
||||
from src.losses import PinballLoss, NonAutoRegressivePinballLoss, CRPSLoss
|
||||
import plotly.graph_objects as go
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
from clearml import OutputModel
|
||||
import torch
|
||||
from data.preprocessing import DataProcessor
|
||||
from utils.clearml import ClearMLHelper
|
||||
from src.data.preprocessing import DataProcessor
|
||||
from src.utils.clearml import ClearMLHelper
|
||||
import plotly.graph_objects as go
|
||||
import numpy as np
|
||||
import plotly.subplots as sp
|
||||
from plotly.subplots import make_subplots
|
||||
|
||||
|
||||
@@ -95,7 +93,7 @@ class Trainer:
|
||||
indices = np.random.randint(0, len(loader.dataset) - 1, size=num_samples)
|
||||
return indices
|
||||
|
||||
def train(self, epochs: int):
|
||||
def train(self, epochs: int, remotely: bool = False):
|
||||
try:
|
||||
train_loader, test_loader = self.data_processor.get_dataloaders(
|
||||
predict_sequence_length=self.model.output_size
|
||||
@@ -106,6 +104,9 @@ class Trainer:
|
||||
|
||||
task = self.init_clearml_task()
|
||||
|
||||
if remotely:
|
||||
task.execute_remotely(queue_name="default", exit_process=True)
|
||||
|
||||
self.best_score = None
|
||||
counter = 0
|
||||
|
||||
|
||||
@@ -5,6 +5,9 @@ class ClearMLHelper:
|
||||
self.project_name = project_name
|
||||
|
||||
def get_task(self, task_name: str = "Model Training"):
|
||||
Task.ignore_requirements("torch")
|
||||
Task.ignore_requirements("torchvision")
|
||||
Task.ignore_requirements("tensorboard")
|
||||
task = Task.init(project_name=self.project_name, task_name=task_name, continue_last_task=False)
|
||||
task.set_base_docker(f"docker.io/clearml/pytorch-cuda-gcc:2.0.0-cuda11.7-cudnn8-runtime --env GIT_SSL_NO_VERIFY=true --env CLEARML_AGENT_GIT_USER=VictorMylle --env CLEARML_AGENT_GIT_PASS=Voetballer1" )
|
||||
return task
|
||||
Reference in New Issue
Block a user