Update number of classes
This commit is contained in:
@@ -34,7 +34,7 @@ def train():
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
spoter_model = SPOTER(num_classes=5, hidden_dim=len(LANDMARKS) *2)
|
||||
spoter_model = SPOTER(num_classes=12, hidden_dim=len(LANDMARKS) *2)
|
||||
spoter_model.train(True)
|
||||
spoter_model.to(device)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user