Changed imports

This commit is contained in:
Victor Mylle
2023-11-26 00:29:35 +00:00
parent 360e9f4e8e
commit f0c6369dd3
7 changed files with 35 additions and 41 deletions

View File

@@ -1,10 +1,8 @@
from clearml import OutputModel
import torch
from data.preprocessing import DataProcessor
from utils.clearml import ClearMLHelper
from src.data.preprocessing import DataProcessor
from src.utils.clearml import ClearMLHelper
import plotly.graph_objects as go
import numpy as np
import plotly.subplots as sp
from plotly.subplots import make_subplots
@@ -95,7 +93,7 @@ class Trainer:
indices = np.random.randint(0, len(loader.dataset) - 1, size=num_samples)
return indices
def train(self, epochs: int):
def train(self, epochs: int, remotely: bool = False):
try:
train_loader, test_loader = self.data_processor.get_dataloaders(
predict_sequence_length=self.model.output_size
@@ -106,6 +104,9 @@ class Trainer:
task = self.init_clearml_task()
if remotely:
task.execute_remotely(queue_name="default", exit_process=True)
self.best_score = None
counter = 0