Merge branch 'WES-184-New-letter-variants' into 'dev'

WES-184 Train the SPOTER model on the new letter variants

See merge request wesign/sign-predictor!18
This commit was merged in pull request #18.
This commit is contained in:
Tibe Habils
2023-05-06 19:20:57 +00:00
5 changed files with 4 additions and 4 deletions

BIN
models/model_A-Z_v2.onnx Normal file

Binary file not shown.

BIN
models/model_A-Z_v2.pth Normal file

Binary file not shown.

View File

@@ -7,7 +7,7 @@ from src.model import SPOTER
from src.identifiers import LANDMARKS
# set parameters of the model
model_name = 'model_A-Z'
model_name = 'model_A-Z_v2'
num_classes = 26
# load PyTorch model from .pth file

View File

@@ -128,9 +128,9 @@ def train():
if val_acc > best_val_acc:
best_val_acc = val_acc
epochs_without_improvement = 0
if epoch > 55:
if epoch > 45:
top_val_acc = val_acc
top_train_acc = train_acc
top_train_acc = pred_correct / pred_all
checkpoint_index = epoch
torch.save(spoter_model.state_dict(), f"checkpoints/spoter_{epoch}.pth")
else:

View File

@@ -27,7 +27,7 @@ frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
keypoints = []
spoter_model = SPOTER(num_classes=26, hidden_dim=len(LANDMARKS) * 2)
spoter_model.load_state_dict(torch.load('models/spoter_76.pth', map_location=torch.device('cpu')))
spoter_model.load_state_dict(torch.load('models/model_A-Z_v2.pth', map_location=torch.device('cpu')))
# get values of the landmarks as a list of integers
values = []