In this article in our series about using portable neural networks in 2020, you’ll learn how to convert a PyTorch model to the portable ONNX format.
Since ONNX is not a framework for building and training models, I will start with a brief introduction to PyTorch. This will be useful for engineers who are starting from scratch and are considering PyTorch as a framework to build and train their models.
A Brief Introduction to PyTorch
PyTorch was released in 2016 and was developed by Facebook’s AI Research lab (FAIR). It has become the preferred framework for researchers experimenting with natural language processing and computer vision, even though TensorFlow is more widely used in production environments. This split between what is preferred in the research laboratory and what is used in production highlights the value of a standard like ONNX, which provides a common format for models and a runtime that can be used from many popular programming languages. For example, suppose an organization does not want to support every possible framework in its production environment and instead wants to standardize on one. Without ONNX, a model built in another framework would need to be reimplemented in the framework chosen for production before it could be deployed, which is a non-trivial engineering task. With ONNX, the PyTorch model can be exported with just a few lines of code and consumed from any language that ONNX Runtime supports. Only ONNX Runtime is needed in production.
Importing the Converter
The maintainers of PyTorch have integrated the ONNX converter into PyTorch itself, so you do not need to install a separate converter package. Once PyTorch is installed, you can access the PyTorch-to-ONNX converter by including the following import in your modules:
import torch
Once the torch module is imported, you can access the conversion function as follows:
torch.onnx.export
Hopefully, this is a practice that other frameworks will adopt. Packaging and versioning the converter with the framework itself makes for one less package to install and also prevents version mismatches between the framework and converter.
A Quick Look at a Model
Before converting a PyTorch model, we need to look at the code that creates the model in order to determine the shape of the input. The code below creates a PyTorch model that predicts the digits in the MNIST dataset. A detailed description of the model layers is beyond the scope of this article, but we do need to note the shape of the input, which here is 784. More specifically, this code creates a model whose input is a flattened tensor: an array of 784 floats. Why 784? Each image in the MNIST dataset is 28 × 28 pixels, and 28 × 28 = 784. Once flattened, the input is 784 floats, where each float represents the gray level of one pixel. The bottom line: this model expects 784 floats from a single image. It is not expecting an unflattened 28 × 28 array, and it is not expecting a batch of images; it makes only one prediction at a time. This fact is important when converting the model to ONNX.
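from torch import nn
input_size = 784
hidden_sizes = [128, 64]
output_size = 10
# Layer sizes come from the variables above; the ReLU/LogSoftmax activations are assumed.
model = nn.Sequential(nn.Linear(input_size, hidden_sizes[0]), nn.ReLU(),
                      nn.Linear(hidden_sizes[0], hidden_sizes[1]), nn.ReLU(),
                      nn.Linear(hidden_sizes[1], output_size), nn.LogSoftmax(dim=1))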
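To make the shape requirement concrete, here is a minimal sketch that flattens one 28 × 28 image into 784 floats, adds a batch dimension of 1, and runs it through a model with the same layer sizes. The random image stands in for real MNIST data, and the ReLU/LogSoftmax activations are assumptions, since only the layer sizes are given in the code above.

```python
import torch
from torch import nn

# Same layer sizes as the MNIST model; the activation functions are assumed.
model = nn.Sequential(
    nn.Linear(784, 128), nn.ReLU(),
    nn.Linear(128, 64), nn.ReLU(),
    nn.Linear(64, 10), nn.LogSoftmax(dim=1),
)
model.eval()

# A random 28 x 28 grayscale image stands in for a real MNIST digit.
image = torch.rand(28, 28)

# Flatten to 784 floats, then add a leading batch dimension of 1.
flat = image.flatten()          # shape: (784,)
single = flat.reshape(1, 784)   # shape: (1, 784), one image per prediction

with torch.no_grad():
    log_probs = model(single)   # shape: (1, 10), one log-probability per digit
predicted_digit = int(log_probs.argmax(dim=1).item())

print(flat.shape, single.shape, log_probs.shape, predicted_digit)
```

The reshape to (1, 784) is what turns a single flattened image into a tensor with the same shape as a one-image sample input.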
Converting PyTorch Models to ONNX
The function below shows how to use the torch.onnx.export function. There are a few tricks to using this function correctly. The first and most important is to set up your sample input correctly. The exporter runs the model on the sample input (the second argument to torch.onnx.export), and the shape of that input becomes the input shape of the ONNX model. torch.onnx.export will accept any sample input that the PyTorch model can process, and the conversion will complete without error. However, if the sample input has the wrong shape, you will get an error when you try to run the ONNX model in ONNX Runtime.
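import datetime
import onnx
import torch

ONNX_MODEL_FILE = 'mnist.onnx'  # assumed file name; use your own path

def export_to_onnx(model):
    # One flattened 28 x 28 image: a batch of 1 with 784 floats.
    sample_input = torch.randn(1, 784)
    torch.onnx.export(model, sample_input, ONNX_MODEL_FILE,
                      input_names=['input'],
                      output_names=['output'])
    # Reopen the exported file and add metadata.
    onnx_model = onnx.load(ONNX_MODEL_FILE)
    meta = onnx_model.metadata_props.add()
    meta.key = "creation_date"
    meta.value = datetime.datetime.now().strftime("%m/%d/%Y, %H:%M:%S")
    meta = onnx_model.metadata_props.add()
    meta.key = "author"
    meta.value = 'keithpij'
    onnx_model.doc_string = 'MNIST model converted from PyTorch'
    onnx_model.model_version = 3
    # Save the model again so the metadata is written to the file.
    onnx.save(onnx_model, ONNX_MODEL_FILE)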
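By contrast, a sample input representing a batch of 100 images, such as torch.randn(100, 784), would be fine if the original PyTorch model were designed to accept batches of 100 images. However, as previously stated, our model was designed to accept only one image at a time when making predictions. If you export the model with a batch-of-100 sample input, the exporter records 100 as a fixed dimension of the ONNX model's input, and ONNX Runtime will raise an error when you pass it a single image.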
Adding metadata to the model is a best practice. As the data you use to train your model evolves, so will your model, so it is a good idea to add metadata that distinguishes it from previous versions. The example above adds a brief description of the model to the doc_string property and sets the version. The creation_date and author properties are custom properties added to the metadata_props property bag, where you are free to create as many custom properties as you need. Unfortunately, the model_version property requires an integer, so you cannot version models with the major.minor.revision syntax you might use for your services. Additionally, the export function saves the model to a file automatically, so to add this metadata you need to reopen the file with onnx.load and resave it with onnx.save. These two functions come from the onnx package, which is installed separately from PyTorch (for example, with pip install onnx).
Summary and Next Steps
In this article, I provided a brief overview of PyTorch for those looking for a deep learning framework for building and training neural networks. I then showed how to convert PyTorch models to the ONNX format using the conversion tool which is already a part of PyTorch itself. I also showed the best practice of adding metadata to the exported model.
Since the purpose of this article was to demonstrate converting PyTorch models to the ONNX format, I did not go into detail about building and training PyTorch models. The code sample for this post contains code that explores PyTorch itself, including a full end-to-end demo that shows how to load the data, explore the images, and train the model.
Next, we’ll look at converting a TensorFlow model to ONNX.