Changed imports

This commit is contained in:
Victor Mylle
2023-11-26 00:29:35 +00:00
parent 360e9f4e8e
commit f0c6369dd3
7 changed files with 35 additions and 41 deletions

View File

@@ -1,13 +1,13 @@
from clearml import OutputModel
import torch
from data.preprocessing import DataProcessor
from utils.clearml import ClearMLHelper
from utils.autoregressive import predict_auto_regressive
from src.data.preprocessing import DataProcessor
from src.utils.clearml import ClearMLHelper
from src.utils.autoregressive import predict_auto_regressive
import plotly.graph_objects as go
import numpy as np
import plotly.subplots as sp
from plotly.subplots import make_subplots
from trainers.trainer import Trainer
from src.trainers.trainer import Trainer
from tqdm import tqdm

View File

@@ -1,9 +1,9 @@
from losses import CRPSLoss
from utils.clearml import ClearMLHelper
from data.preprocessing import DataProcessor, DataConfig
from src.losses import CRPSLoss
from src.utils.clearml import ClearMLHelper
from src.data.preprocessing import DataProcessor, DataConfig
import numpy as np
import plotly.graph_objects as go
from trainers.trainer import Trainer
from src.trainers.trainer import Trainer
import torch

View File

@@ -1,15 +1,11 @@
import torch
from utils.autoregressive import predict_auto_regressive_quantile
from scipy.interpolate import interp1d
from trainers.trainer import Trainer
from trainers.autoregressive_trainer import AutoRegressiveTrainer
from data.preprocessing import DataProcessor
from utils.clearml import ClearMLHelper
from losses import PinballLoss, NonAutoRegressivePinballLoss, CRPSLoss
from plotly.subplots import make_subplots
from src.trainers.trainer import Trainer
from src.trainers.autoregressive_trainer import AutoRegressiveTrainer
from src.data.preprocessing import DataProcessor
from src.utils.clearml import ClearMLHelper
from src.losses import PinballLoss, NonAutoRegressivePinballLoss, CRPSLoss
import plotly.graph_objects as go
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

View File

@@ -1,10 +1,8 @@
from clearml import OutputModel
import torch
from data.preprocessing import DataProcessor
from utils.clearml import ClearMLHelper
from src.data.preprocessing import DataProcessor
from src.utils.clearml import ClearMLHelper
import plotly.graph_objects as go
import numpy as np
import plotly.subplots as sp
from plotly.subplots import make_subplots
@@ -95,7 +93,7 @@ class Trainer:
indices = np.random.randint(0, len(loader.dataset) - 1, size=num_samples)
return indices
def train(self, epochs: int):
def train(self, epochs: int, remotely: bool = False):
try:
train_loader, test_loader = self.data_processor.get_dataloaders(
predict_sequence_length=self.model.output_size
@@ -106,6 +104,9 @@ class Trainer:
task = self.init_clearml_task()
if remotely:
task.execute_remotely(queue_name="default", exit_process=True)
self.best_score = None
counter = 0