# to run this script, you need torch 1.13.1 and torchvision 0.14.1
|
|
|
|
import numpy as np
|
|
import onnx
|
|
import torch
|
|
import torchvision
|
|
import os
|
|
|
|
from models.spoter_embedding_model import SPOTER_EMBEDDINGS
|
|
|
|
# --- model configuration ----------------------------------------------------
model_name = 'fingerspelling_embedding_model'

# Export runs on CPU only; no GPU is required to trace or serialize the model.
device = torch.device("cpu")

# --- load the trained PyTorch checkpoint ------------------------------------
CHECKPOINT_PATH = "checkpoints/fingerspelling_checkpoint.pth"
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)

# Rebuild the architecture from the hyper-parameters stored alongside the
# weights, then restore the trained state.
cfg = checkpoint["config_args"]
model = SPOTER_EMBEDDINGS(
    features=cfg.vector_length,
    hidden_dim=cfg.hidden_dim,
    norm_emb=cfg.normalize_embeddings,
).to(device)
model.load_state_dict(checkpoint["state_dict"])

# Inference mode so tracing/export sees stable (non-training) behavior.
model.eval()

# Fixed-shape example input used for tracing and ONNX export.
# NOTE(review): assumed layout is (batch, frames, landmarks, xy) — confirm
# against SPOTER_EMBEDDINGS' forward() signature.
dummy_input = torch.randn(1, 10, 54, 2)
|
|
|
|
# Ensure the output directory for exported models exists.
# exist_ok=True makes this atomic and race-free (no exists()-then-create gap).
os.makedirs('out-models', exist_ok=True)
|
|
|
|
# Export the model in both ONNX and Core ML formats.
for model_export in ["onnx", "coreml"]:
    # Keep the example input on the model's device before exporting.
    dummy_input = dummy_input.to(device)

    if model_export == "onnx":
        torch.onnx.export(
            model,                           # model being run
            dummy_input,                     # example input (tuple for multiple inputs)
            'out-models/' + model_name + '.onnx',  # destination path (or file-like object)
            export_params=True,              # embed trained weights in the file
            opset_version=9,                 # ONNX opset to target
            do_constant_folding=True,        # fold constants for optimization
            input_names=['X'],               # model input names
            output_names=['Y'],              # model output names
        )
    else:
        # Trace the model to TorchScript, then hand the trace to Core ML.
        traced_model = torch.jit.trace(model, dummy_input)
        out = traced_model(dummy_input)

        # Imported here so the ONNX path works without coremltools installed.
        import coremltools as ct

        # Convert the traced graph to Core ML and save it.
        coreml_model = ct.convert(
            traced_model,
            inputs=[ct.TensorType(name="input", shape=dummy_input.shape)],
        )
        coreml_model.save("out-models/" + model_name + ".mlmodel")