Files
sign-predictor/src/export.py
Victor Mylle e13f365d81 Dev
2023-03-26 19:40:47 +00:00

44 lines
1.7 KiB
Python

import torch
import torchvision
import onnx
import numpy as np
from src.model import SPOTER
from src.identifiers import LANDMARKS
# Export a trained SPOTER checkpoint (models/<model_name>.pth) to ONNX
# (models/<model_name>.onnx) and validate the written file with the ONNX checker.

# set parameters of the model
model_name = 'model_A-Z'
num_classes = 26
# Two values per landmark (len(LANDMARKS) * 2). The same width sizes both the
# model and the dummy export input below, so the two cannot drift apart.
# NOTE(review): assumes SPOTER consumes frames of exactly hidden_dim features
# (as the reference SPOTER does) — confirm against src/model.py.
hidden_dim = len(LANDMARKS) * 2

# load PyTorch model from .pth file
model = SPOTER(num_classes=num_classes, hidden_dim=hidden_dim)
# The model is constructed (and exported) on the CPU, so map the checkpoint
# tensors straight there; loading them onto a GPU first would only add a copy.
state_dict = torch.load('models/' + model_name + '.pth', map_location=torch.device('cpu'))
model.load_state_dict(state_dict)
# set model to evaluation mode (layers such as dropout switch to inference
# behaviour before the graph is traced)
model.eval()

# create dummy input tensor: 10 rows of hidden_dim features. No dynamic_axes are
# given, so this shape is baked into the exported graph.
dummy_input = torch.randn(10, hidden_dim)

# export model to ONNX format (exported exactly once; the file is written to
# output_file and re-read below for validation)
output_file = 'models/' + model_name + '.onnx'
torch.onnx.export(
    model,                     # model being run
    dummy_input,               # model input (or a tuple for multiple inputs)
    output_file,               # where to save the model (file path or file-like object)
    export_params=True,        # store the trained parameter weights inside the model file
    opset_version=9,           # the ONNX opset version to export the model to
    do_constant_folding=True,  # execute constant folding for optimization
    input_names=['X'],         # the model's input names
    output_names=['Y'],        # the model's output names
)

# load exported ONNX model for verification; check_model raises if the graph is invalid
onnx_model = onnx.load(output_file)
onnx.checker.check_model(onnx_model)