Added new training scripts

This commit is contained in:
Victor Mylle
2023-11-27 14:55:22 +00:00
parent 5e87165dbb
commit c1152ff96c
7 changed files with 37 additions and 36 deletions

View File

@@ -19,7 +19,6 @@ class AutoRegressiveTrainer(Trainer):
criterion: torch.nn.Module,
data_processor: DataProcessor,
device: torch.device,
clearml_helper: ClearMLHelper = None,
debug: bool = True,
):
super().__init__(
@@ -28,7 +27,6 @@ class AutoRegressiveTrainer(Trainer):
criterion=criterion,
data_processor=data_processor,
device=device,
clearml_helper=clearml_helper,
debug=debug,
)
self.model.output_size = 1