Update number of classes

This commit is contained in:
2023-03-08 14:44:16 +00:00
parent 7653b9b35c
commit b0335044af

View File

@@ -34,7 +34,7 @@ def train():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
spoter_model = SPOTER(num_classes=5, hidden_dim=len(LANDMARKS) *2)
spoter_model = SPOTER(num_classes=12, hidden_dim=len(LANDMARKS) *2)
spoter_model.train(True)
spoter_model.to(device)