"""Export the trained SPOTER checkpoint ``models/<model_name>.pth`` to ONNX.

The exported graph is written to ``models/<model_name>.onnx``. It is then
re-loaded and validated with ``onnx.checker.check_model``.
"""
import torch
import torchvision
import onnx
import numpy as np

from src.model import SPOTER
from src.identifiers import LANDMARKS

model_name = 'Fingerspelling_AE'

# load PyTorch model from .pth file
# NOTE(review): hidden_dim = 2 * number of landmarks, presumably one (x, y)
# pair per landmark -- confirm against src.identifiers.
model = SPOTER(num_classes=5, hidden_dim=len(LANDMARKS) * 2)
# map_location='cpu' lets a checkpoint that was saved on a CUDA device load on
# a CPU-only machine. The export below needs only CPU tensors.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
state_dict = torch.load('models/' + model_name + '.pth', map_location='cpu')
model.load_state_dict(state_dict)

# set model to evaluation mode so training-only layers (e.g. dropout) behave
# deterministically while the graph is traced
model.eval()

# create dummy input tensor; torch.onnx.export traces the model with it
batch_size = 1
num_of_frames = 1
# NOTE(review): 108 presumably equals len(LANDMARKS) * 2 (the hidden_dim
# above) -- TODO confirm and consider deriving it instead of hard-coding.
input_shape = (108, num_of_frames)
dummy_input = torch.randn(batch_size, *input_shape)

# export model to ONNX format
# NOTE: no dynamic_axes are passed. The exported 'input' is therefore fixed
# to the dummy shape (batch_size, 108, num_of_frames).
output_file = 'models/' + model_name + '.onnx'
torch.onnx.export(
    model,
    dummy_input,
    output_file,
    input_names=['input'],
    output_names=['output'],
)

# load exported ONNX model for verification (raises if the graph is invalid)
onnx_model = onnx.load(output_file)
onnx.checker.check_model(onnx_model)