"""Export the SPOTER embeddings model from a PyTorch checkpoint to ONNX or Core ML.

Rebuilds the model from a saved checkpoint, runs it through tracing/export with
a fixed dummy input, and verifies the exported ONNX graph(s).

Run as a script; configuration lives in the module-level constants below.
"""
import numpy as np  # NOTE(review): unused in this script — confirm nothing else needs it
import onnx
import torch
import torchvision  # NOTE(review): unused in this script — confirm nothing else needs it

from models.spoter_embedding_model import SPOTER_EMBEDDINGS

# --- configuration -----------------------------------------------------------
MODEL_NAME = "embedding_model"
CHECKPOINT_PATH = "checkpoints/checkpoint_embed_1105.pth"
MODEL_EXPORT = "onnx"  # "onnx" (default) or "coreml"
# Dummy input shape used for tracing/export: presumably
# (batch, frames, landmarks, xy) — TODO confirm against the model's forward().
DUMMY_INPUT_SHAPE = (1, 10, 54, 2)


def _load_model(device):
    """Rebuild SPOTER_EMBEDDINGS from CHECKPOINT_PATH and return it in eval mode."""
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
    model = SPOTER_EMBEDDINGS(
        features=checkpoint["config_args"].vector_length,
        hidden_dim=checkpoint["config_args"].hidden_dim,
        norm_emb=checkpoint["config_args"].normalize_embeddings,
    ).to(device)
    model.load_state_dict(checkpoint["state_dict"])
    # Evaluation mode: disables dropout, uses running batch-norm statistics.
    model.eval()
    return model


def _export_coreml(model, dummy_input):
    """Trace the model and convert the traced graph to a Core ML .mlmodel."""
    traced_model = torch.jit.trace(model, dummy_input)
    traced_model(dummy_input)  # sanity-run the traced graph once before conversion

    # Local import: coremltools is only needed on this export path.
    import coremltools as ct

    coreml_model = ct.convert(
        traced_model,
        inputs=[ct.TensorType(name="input", shape=dummy_input.shape)],
    )
    coreml_model.save("out-models/" + MODEL_NAME + ".mlmodel")


def _export_onnx(model, dummy_input):
    """Export the model to both ONNX output paths and verify each file.

    The original script wrote two ONNX files — ``models/<name>.onnx`` with
    default settings and input/output tensor names, and
    ``out-models/<name>.onnx`` with opset 9 and X/Y tensor names — but only
    verified the first. Both exports are kept with identical parameters for
    backward compatibility, and BOTH are now checked.
    """
    # Export 1: default opset, generic tensor names ("input"/"output").
    legacy_file = "models/" + MODEL_NAME + ".onnx"
    torch.onnx.export(
        model,
        dummy_input,
        legacy_file,
        input_names=["input"],
        output_names=["output"],
    )

    # Export 2: deployment copy with pinned opset and X/Y tensor names.
    deploy_file = "out-models/" + MODEL_NAME + ".onnx"
    torch.onnx.export(
        model,                     # model being run
        dummy_input,               # model input (or a tuple for multiple inputs)
        deploy_file,               # where to save the model
        export_params=True,        # store trained weights inside the model file
        opset_version=9,           # ONNX opset to target
        do_constant_folding=True,  # fold constant subgraphs during export
        input_names=["X"],
        output_names=["Y"],
    )

    # BUG FIX: verify every exported file, not just the first one.
    for path in (legacy_file, deploy_file):
        onnx.checker.check_model(onnx.load(path))


def main():
    """Load the checkpointed model and export it per MODEL_EXPORT."""
    # CPU export keeps the serialized graph device-agnostic (CUDA branch was
    # already disabled in the original script).
    device = torch.device("cpu")
    model = _load_model(device)
    dummy_input = torch.randn(*DUMMY_INPUT_SHAPE).to(device)

    if MODEL_EXPORT == "coreml":
        _export_coreml(model, dummy_input)
    else:
        _export_onnx(model, dummy_input)


if __name__ == "__main__":
    main()