diff --git a/models/model_A-Z_v2.onnx b/models/model_A-Z_v2.onnx
new file mode 100644
index 0000000..1a9ab21
Binary files /dev/null and b/models/model_A-Z_v2.onnx differ
diff --git a/models/model_A-Z_v2.pth b/models/model_A-Z_v2.pth
new file mode 100644
index 0000000..34792a8
Binary files /dev/null and b/models/model_A-Z_v2.pth differ
diff --git a/src/export.py b/src/export.py
index 217513b..f35722e 100644
--- a/src/export.py
+++ b/src/export.py
@@ -7,7 +7,7 @@ from src.model import SPOTER
 from src.identifiers import LANDMARKS
 
 # set parameters of the model
-model_name = 'model_A-Z'
+model_name = 'model_A-Z_v2'
 num_classes = 26
 
 # load PyTorch model from .pth file
diff --git a/src/train.py b/src/train.py
index 3d527d1..11e4ade 100644
--- a/src/train.py
+++ b/src/train.py
@@ -128,9 +128,9 @@ def train():
         if val_acc > best_val_acc:
             best_val_acc = val_acc
             epochs_without_improvement = 0
-            if epoch > 55:
+            if epoch > 45:
                 top_val_acc = val_acc
-                top_train_acc = train_acc
+                top_train_acc = pred_correct / pred_all
                 checkpoint_index = epoch
                 torch.save(spoter_model.state_dict(), f"checkpoints/spoter_{epoch}.pth")
         else:
diff --git a/visualizations/webcam_view.py b/visualizations/webcam_view.py
index 99bf2bb..4f665f3 100644
--- a/visualizations/webcam_view.py
+++ b/visualizations/webcam_view.py
@@ -27,7 +27,7 @@ frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
 keypoints = []
 
 spoter_model = SPOTER(num_classes=26, hidden_dim=len(LANDMARKS) * 2)
-spoter_model.load_state_dict(torch.load('models/spoter_76.pth', map_location=torch.device('cpu')))
+spoter_model.load_state_dict(torch.load('models/model_A-Z_v2.pth', map_location=torch.device('cpu')))
 
 # get values of the landmarks as a list of integers
 values = []
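
The src/train.py hunk makes two changes: checkpointing now starts at epoch 46 instead of 56, and the recorded top training accuracy is computed directly from the epoch's running counters (`pred_correct / pred_all`) rather than from a `train_acc` variable, which suggests `train_acc` did not always reflect the epoch being checkpointed. Below is a minimal sketch of the loop this block lives in; only the checkpoint block comes from the diff, and everything around it (the loaders, the `evaluate()` helper, the counter updates) is assumed for illustration.

```python
import torch


def evaluate(model, loader):
    """Assumed helper: fraction of correct predictions on a loader."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            preds = model(inputs).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total


def train(spoter_model, train_loader, val_loader, optimizer, criterion,
          num_epochs=100):
    """Sketch of the training loop around the changed checkpoint block."""
    best_val_acc, epochs_without_improvement = 0.0, 0
    top_val_acc = top_train_acc = checkpoint_index = None

    for epoch in range(num_epochs):
        spoter_model.train()
        pred_correct, pred_all = 0, 0  # running counters for this epoch
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = spoter_model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            pred_correct += (outputs.argmax(dim=1) == labels).sum().item()
            pred_all += labels.size(0)

        val_acc = evaluate(spoter_model, val_loader)

        # The block below mirrors the diff; the rest of this loop is assumed.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            epochs_without_improvement = 0
            if epoch > 45:
                top_val_acc = val_acc
                # Recompute from this epoch's counters instead of reusing
                # train_acc, so the logged accuracy matches the saved weights.
                top_train_acc = pred_correct / pred_all
                checkpoint_index = epoch
                torch.save(spoter_model.state_dict(),
                           f"checkpoints/spoter_{epoch}.pth")
        else:
            epochs_without_improvement += 1

    return top_val_acc, top_train_acc, checkpoint_index
```

Deriving the accuracy from the counters at save time ties the logged value to the exact epoch whose weights land in the checkpoint file.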
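Since the change adds both models/model_A-Z_v2.pth and its ONNX export, a quick parity check can catch export mismatches before the webcam demo relies on either artifact. This is a sketch, not part of the repo: the dummy input shape is an assumption and should be replaced with whatever export.py passes to the exporter, and the SPOTER constructor arguments and forward signature are taken on faith from the webcam_view.py lines above.

```python
# Hypothetical parity check between the .pth and .onnx artifacts added here.
# The input layout (batch, frames, len(LANDMARKS) * 2) is an assumption.
import numpy as np
import onnxruntime as ort
import torch

from src.model import SPOTER
from src.identifiers import LANDMARKS

spoter_model = SPOTER(num_classes=26, hidden_dim=len(LANDMARKS) * 2)
spoter_model.load_state_dict(
    torch.load('models/model_A-Z_v2.pth', map_location=torch.device('cpu')))
spoter_model.eval()

# Assumed input: one sequence of 100 frames of flattened landmark coordinates.
dummy = torch.randn(1, 100, len(LANDMARKS) * 2)

with torch.no_grad():
    torch_logits = spoter_model(dummy).numpy()

session = ort.InferenceSession('models/model_A-Z_v2.onnx')
input_name = session.get_inputs()[0].name
onnx_logits = session.run(None, {input_name: dummy.numpy()})[0]

# Small numeric drift is expected from the export; exact equality is not.
np.testing.assert_allclose(torch_logits, onnx_logits, rtol=1e-3, atol=1e-5)
```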