From b0335044af577b745dacb30353da3d23de371502 Mon Sep 17 00:00:00 2001 From: Victor Mylle Date: Wed, 8 Mar 2023 14:44:16 +0000 Subject: [PATCH] Update number of classes --- src/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/train.py b/src/train.py index de50b9a..ead0ca7 100644 --- a/src/train.py +++ b/src/train.py @@ -34,7 +34,7 @@ def train(): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - spoter_model = SPOTER(num_classes=5, hidden_dim=len(LANDMARKS) *2) + spoter_model = SPOTER(num_classes=12, hidden_dim=len(LANDMARKS) *2) spoter_model.train(True) spoter_model.to(device)