Changed imports
This commit is contained in:
@@ -1,10 +1,8 @@
|
||||
from clearml import OutputModel
|
||||
import torch
|
||||
from data.preprocessing import DataProcessor
|
||||
from utils.clearml import ClearMLHelper
|
||||
from src.data.preprocessing import DataProcessor
|
||||
from src.utils.clearml import ClearMLHelper
|
||||
import plotly.graph_objects as go
|
||||
import numpy as np
|
||||
import plotly.subplots as sp
|
||||
from plotly.subplots import make_subplots
|
||||
|
||||
|
||||
@@ -95,7 +93,7 @@ class Trainer:
|
||||
indices = np.random.randint(0, len(loader.dataset) - 1, size=num_samples)
|
||||
return indices
|
||||
|
||||
def train(self, epochs: int):
|
||||
def train(self, epochs: int, remotely: bool = False):
|
||||
try:
|
||||
train_loader, test_loader = self.data_processor.get_dataloaders(
|
||||
predict_sequence_length=self.model.output_size
|
||||
@@ -106,6 +104,9 @@ class Trainer:
|
||||
|
||||
task = self.init_clearml_task()
|
||||
|
||||
if remotely:
|
||||
task.execute_remotely(queue_name="default", exit_process=True)
|
||||
|
||||
self.best_score = None
|
||||
counter = 0
|
||||
|
||||
|
||||
Reference in New Issue
Block a user