Files
sign-predictor/export.py
2023-03-12 19:34:04 +00:00

31 lines
841 B
Python

import torch
import torchvision
import onnx
import numpy as np
from src.model import SPOTER
from src.identifiers import LANDMARKS
# Export a trained SPOTER checkpoint (models/<model_name>.pth) to ONNX
# (models/<model_name>.onnx) and validate the written file.

# Stem shared by the input .pth checkpoint and the output .onnx file under models/.
model_name = 'Fingerspelling_AE'

# Rebuild the architecture, then load the trained weights into it.
# NOTE(review): num_classes=5 and hidden_dim=len(LANDMARKS) * 2 must match the
# values used when the checkpoint was trained, otherwise load_state_dict raises
# on mismatched keys/shapes.
model = SPOTER(num_classes=5, hidden_dim=len(LANDMARKS) * 2)
# map_location='cpu' lets a checkpoint saved from a GPU run be loaded on a
# CPU-only machine; the model and the export below live on CPU regardless.
state_dict = torch.load('models/' + model_name + '.pth', map_location='cpu')
model.load_state_dict(state_dict)
# Switch to inference mode before tracing, so training-only behaviour (dropout,
# batch-norm updates, if the model uses them) is not captured in the graph.
model.eval()

# Dummy input used only to trace the graph: its values are irrelevant, but its
# shape is baked into the exported model because no dynamic_axes are declared.
batch_size = 1
num_of_frames = 1
# NOTE(review): 108 is assumed to be the per-frame feature size SPOTER.forward
# expects (possibly len(LANDMARKS) * 2) — TODO confirm against the model.
input_shape = (108, num_of_frames)
dummy_input = torch.randn(batch_size, *input_shape)

# Trace the model with the dummy input and write the ONNX graph; the graph's
# input/output tensors are named 'input' and 'output' for downstream runtimes.
output_file = 'models/' + model_name + '.onnx'
torch.onnx.export(
    model,
    dummy_input,
    output_file,
    input_names=['input'],
    output_names=['output'],
)

# Reload the written file and run ONNX's structural validator; it raises if the
# exported graph is malformed.
onnx_model = onnx.load(output_file)
onnx.checker.check_model(onnx_model)