Merge branch 'WES-184-New-letter-variants' into 'dev'
WES-184 Train the SPOTER model on the new letter variants See merge request wesign/sign-predictor!18
This commit was merged in pull request #18.
This commit is contained in:
BIN
models/model_A-Z_v2.onnx
Normal file
BIN
models/model_A-Z_v2.onnx
Normal file
Binary file not shown.
BIN
models/model_A-Z_v2.pth
Normal file
BIN
models/model_A-Z_v2.pth
Normal file
Binary file not shown.
@@ -7,7 +7,7 @@ from src.model import SPOTER
|
||||
from src.identifiers import LANDMARKS
|
||||
|
||||
# set parameters of the model
|
||||
model_name = 'model_A-Z'
|
||||
model_name = 'model_A-Z_v2'
|
||||
num_classes = 26
|
||||
|
||||
# load PyTorch model from .pth file
|
||||
|
||||
@@ -128,9 +128,9 @@ def train():
|
||||
if val_acc > best_val_acc:
|
||||
best_val_acc = val_acc
|
||||
epochs_without_improvement = 0
|
||||
if epoch > 55:
|
||||
if epoch > 45:
|
||||
top_val_acc = val_acc
|
||||
top_train_acc = train_acc
|
||||
top_train_acc = pred_correct / pred_all
|
||||
checkpoint_index = epoch
|
||||
torch.save(spoter_model.state_dict(), f"checkpoints/spoter_{epoch}.pth")
|
||||
else:
|
||||
|
||||
@@ -27,7 +27,7 @@ frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
keypoints = []
|
||||
|
||||
spoter_model = SPOTER(num_classes=26, hidden_dim=len(LANDMARKS) * 2)
|
||||
spoter_model.load_state_dict(torch.load('models/spoter_76.pth', map_location=torch.device('cpu')))
|
||||
spoter_model.load_state_dict(torch.load('models/model_A-Z_v2.pth', map_location=torch.device('cpu')))
|
||||
|
||||
# get values of the landmarks as a list of integers
|
||||
values = []
|
||||
|
||||
Reference in New Issue
Block a user