In this tutorial, we will guide you through the steps to take the BodyPoseNet model from the NVIDIA NGC hub, convert the ONNX model to TFLite format, quantize the converted models, and run them with the TFLite benchmark_model tool. We will demonstrate two conversion methods: the open-source onnx2tf tool and MediaTek’s proprietary mtk-converter. Finally, we will showcase a demo video of an application built with the same model.
This model has dynamic inputs (batch_size, width, height) and dynamic outputs. To make the inputs static, you can use the following onnxruntime command:
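A minimal sketch of that step, assuming onnxruntime's make_dynamic_shape_fixed utility, the 1x224x320x3 NHWC resolution used later in this tutorial, and a hypothetical input tensor name input_1:0 (check the actual name in Netron before running it):

python3 -m onnxruntime.tools.make_dynamic_shape_fixed \
    --input_name "input_1:0" \
    --input_shape 1,224,320,3 \
    bodyposenet.onnx bodyposenet_static.onnx
# Repeat with --input_name/--input_shape (or --dim_param/--dim_value) for any remaining dynamic inputs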
onnx2tf is an open-source tool that converts ONNX models to TensorFlow Lite models. It is designed to be easy to use and supports a wide range of ONNX operators. This method is suitable for users who prefer open-source solutions and want to leverage the flexibility of TensorFlow Lite.
Install onnx2tf:
pip3 install onnx2tf
pip3 install sor4onnx
Convert the Model to TF:
Note that the ONNX model already has the input shape in NHWC format. onnx2tf by default expects the input shape to be NCHW and will introduce a transpose operation to the input to convert it to NHWC. The steps below ensure that the generated TFLite input shape format remains NHWC. We start by removing the :0 suffix from the input since TensorFlow introduces its own suffixes.
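A sketch of those two steps, assuming a hypothetical input tensor named input_1:0 and flag names taken from the public sor4onnx and onnx2tf documentation (verify them against the versions you installed):

# Rename the input to drop the :0 suffix
sor4onnx \
    --input_onnx_file_path bodyposenet_static.onnx \
    --old_new "input_1:0" "input_1" \
    --mode inputs \
    --output_onnx_file_path bodyposenet_renamed.onnx

# Convert to a TF saved_model, keeping the NHWC input layout as-is
onnx2tf -i bodyposenet_renamed.onnx -o saved_model -kat input_1 -osd

Here -kat (--keep_shape_absolutely_input_names) keeps the listed input in its original NHWC layout instead of inserting a transpose, and -osd (--output_signaturedefs) writes signature definitions so the saved_model can be loaded by the TFLite converter in the next step.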
We can use the generated TF saved_model and convert it to TFLite using the TFLite converter.
import tensorflow as tf
import numpy as np

tf_model_path = '/path/to/saved_model'
tflite_model_path = 'bodyposenet.tflite'

# Generate representative dataset
def representative_dataset():
    data = tf.random.uniform((1, 224, 320, 3))
    yield [data]

converter = tf.lite.TFLiteConverter.from_saved_model(tf_model_path)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8  # Can be tf.uint8, tf.float32, or tf.float16
converter.inference_output_type = tf.float32  # Can be tf.uint8, tf.int8, or tf.float16. We keep it float32 for ease of post-processing output data

tflite_model = converter.convert()
with open(tflite_model_path, "wb") as f:
    f.write(tflite_model)
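As a quick sanity check before opening the model in Netron, you can load the quantized file with the TFLite Interpreter and confirm that the input is int8 NHWC and the outputs are float32 (a minimal sketch using the file name from the script above):

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path='bodyposenet.tflite')
interpreter.allocate_tensors()

# Print tensor names, shapes, and dtypes to confirm the converter settings took effect
for detail in interpreter.get_input_details():
    print('input :', detail['name'], detail['shape'], detail['dtype'])
for detail in interpreter.get_output_details():
    print('output:', detail['name'], detail['shape'], detail['dtype'])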
The Netron visualization of the generated TFLite model is below:
mtk-converter is MediaTek’s proprietary tool for converting ONNX models to TFLite format. It is optimized for MediaTek’s hardware and provides additional features for model optimization and deployment. This method is recommended for users who are deploying models on MediaTek platforms and want to take advantage of MediaTek’s AI-ML solutions.
Install mtk-converter
Follow the instructions in the official NeuroPilot documentation to download the mtk-converter tool, and install it inside a Python 3.11 virtual environment.
# Steps to start a Python virtual environment and install the mtk_converter wheel file in your Linux environment
sudo apt install python3.11-venv
mkdir NP8
python3.11 -m venv NP8/
source NP8/bin/activate  # This starts your virtual environment
pip install mtk_converter*.whl  # Install the compatible mtk-converter wheel file. It would be part of the downloaded NP package.
pip install <required_dependencies>  # Install the dependencies as mentioned in the NeuroPilot documentation
Convert ONNX Model to TFLite
The ONNX model already has NHWC shapes. We can use the below Python script to convert the ONNX model to TFLite.
The provided Python script performs two functions:
Convert ONNX IR version to v8 supported by NP8
Convert the generated ONNX model to TFLite
import onnx
from onnx import version_converter
import mtk_converter
import numpy as np
import argparse

# Set up argument parser
parser = argparse.ArgumentParser(description='Convert ONNX model to TFLite with quantization')
parser.add_argument('input_model', type=str, help='Path to the input ONNX model')
parser.add_argument('output_model', type=str, help='Path to the output TFLite model')
parser.add_argument('--input_names', type=str, default=None, help='Comma-separated list of input tensor names')
parser.add_argument('--input_shapes', type=str, default=None, help='Semicolon-separated list of comma-separated input tensor shapes, e.g., "1,224,224,3;1,10"')
parser.add_argument('--quantize', type=bool, default=False, help='Set to true if you want a quantized model. Default is False')
args = parser.parse_args()

input_model_path = args.input_model
output_model_path = args.output_model
quantize = args.quantize
input_names = args.input_names.split(',') if args.input_names else None
input_shapes = [list(map(int, shape.split(','))) for shape in args.input_shapes.split(';')] if args.input_shapes else None

# Step 1: Convert the IR version of the ONNX model
def convert_ir_version(input_model_path, target_ir_version=8):
    original_model = onnx.load(input_model_path)
    print(f"Original model IR version: {original_model.ir_version}")
    # Convert the IR version
    original_model.ir_version = target_ir_version
    output_model_path = 'converted_model.onnx'
    onnx.save(original_model, output_model_path)
    converted_model = onnx.load(output_model_path)
    print(f"Converted model IR version: {converted_model.ir_version}")
    return output_model_path

# Step 2: Convert the ONNX model to TFLite
def convert_to_tflite(input_model_path, output_model_path, input_shapes, input_names):
    # Load the ONNX model to get the input shape
    onnx_model = onnx.load(input_model_path)
    if input_shapes is None:
        input_shapes = []
        for input_tensor in onnx_model.graph.input:
            input_shape = [dim.dim_value for dim in input_tensor.type.tensor_type.shape.dim]
            input_shapes.append(input_shape)

    # Define representative dataset
    def data_gen(input_shapes):
        for i in range(10):
            yield [np.random.randn(*shape).astype(np.float32) for shape in input_shapes]

    # Create the converter
    if args.input_names or args.input_shapes:
        converter = mtk_converter.OnnxConverter.from_model_proto_file(input_model_path, input_names, input_shapes)
    else:
        converter = mtk_converter.OnnxConverter.from_model_proto_file(input_model_path)

    if quantize:
        converter.quantize = True
        converter.calibration_data_gen = lambda: data_gen(input_shapes)
        # Uncomment to set precision proportion to INT8
        # converter.precision_proportion = {'8W8A': 1.0}
        # Uncomment with precision proportion to set input data type as FP32
        # converter.prepend_input_quantize_ops = True
        # Uncomment with precision proportion to set output data type as FP32
        # converter.append_output_dequantize_ops = True
        converter.use_per_output_channel_quantization = False

    # Convert the model to TFLite
    _ = converter.convert_to_tflite(output_file=output_model_path)

# Convert the IR version of the ONNX model
converted_model_path = convert_ir_version(input_model_path)

# Convert the ONNX model to TFLite
convert_to_tflite(converted_model_path, output_model_path, input_shapes, input_names)
The provided conversion_script.py will convert your ONNX model to TFLite. Use the below command line arguments to execute the script:
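For example, a quantized conversion might be invoked as follows; the file names, the input tensor name input_1:0, and the 1x224x320x3 shape are placeholders to adapt to your model:

python3 conversion_script.py bodyposenet_static.onnx bodyposenet_mtk.tflite \
    --input_names "input_1:0" \
    --input_shapes "1,224,320,3" \
    --quantize True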
In this tutorial, we demonstrated how to take the BodyPoseNet model from the NVIDIA NGC hub, convert it to TFLite format using two different methods, quantize the model, and run it with the TFLite benchmark_model tool. We also showcased a demo video of an application built with the same model. This process highlights the flexibility and efficiency of using the NVIDIA TAO Toolkit and MediaTek’s NeuroPilot solutions for deploying AI-ML models at the edge.
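For reference, either generated .tflite file can be profiled on-device with TFLite's benchmark_model tool. A minimal sketch, assuming an Android target reachable over adb and a benchmark_model binary already present on the device (paths and run counts are placeholders):

adb push bodyposenet.tflite /data/local/tmp/
adb shell /data/local/tmp/benchmark_model \
    --graph=/data/local/tmp/bodyposenet.tflite \
    --num_threads=4 \
    --num_runs=50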