Body Pose Estimation

Introduction

In this tutorial, we will guide you through the steps to take the BodyPoseNet model from the NVIDIA NGC hub, convert the ONNX model to TFLite format, quantize the converted models, and run them with the TFLite benchmark_model tool. We will demonstrate two conversion methods: using the open-source onnx2tf tool and MediaTek’s proprietary mtk-converter. Finally, we will showcase a demo video of an application created using the same model.

Prerequisites

Before you begin, ensure you have the following installed:

  • Python 3.11

  • ONNX v1.13

  • onnxruntime v1.17.1

  • TensorFlow v2.14

  • onnx2tf

  • Neuropilot Converter 8.10

  • NVIDIA TAO Toolkit (v5.5)

Download the Model

First, download the BodyPoseNet model from the NVIDIA NGC hub.

wget --content-disposition 'https://api.ngc.nvidia.com/v2/models/org/nvidia/team/tao/bodyposenet/deployable_onnx_v1.0.1/files?redirect=true&path=model.onnx' -O bodyposenet.onnx
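
Before converting anything, it can be useful to print the model's IR and opset versions with the onnx Python package, since the mtk-converter method below expects IR version 8. This is an optional check; the exact versions depend on the NGC release you downloaded.

import onnx

# Inspect the downloaded model's IR version and opset imports
model = onnx.load("bodyposenet.onnx")
print("IR version:", model.ir_version)
print("Opsets:", [(imp.domain or "ai.onnx", imp.version) for imp in model.opset_import])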

Model Analysis

The downloaded model has an input with three dynamic dimensions, which must be made static before conversion. Below is a visualization of the model using Netron.

BodyPoseNet ONNX model (Netron visualization)

This model has dynamic inputs (batch_size, width, height) and dynamic outputs. To make the inputs static, you can use the following onnxruntime command:

python3 -m onnxruntime.tools.make_dynamic_shape_fixed --input_name input_1:0 --input_shape 1,224,320,3 bodyposenet.onnx bodyposenet_fixed.onnx

This will fix the input shape of the ONNX model to [1,224,320,3].

You can now see that the dynamic inputs have become static.

BodyPoseNet ONNX model with fixed input shape
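
If you prefer to verify the fixed shape programmatically rather than in Netron, a minimal check with the onnx package (using the file name from the command above) looks like this:

import onnx

# Print each graph input of the shape-fixed model
model = onnx.load("bodyposenet_fixed.onnx")
for inp in model.graph.input:
    dims = [d.dim_value if d.dim_value > 0 else d.dim_param
            for d in inp.type.tensor_type.shape.dim]
    print(inp.name, dims)  # expected: input_1:0 [1, 224, 320, 3]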

Convert ONNX Model to TFLite

We will demonstrate two methods to convert the ONNX model to TFLite format.

Method 1: Using onnx2tf

onnx2tf is an open-source tool that converts ONNX models to TensorFlow Lite models. It is designed to be easy to use and supports a wide range of ONNX operators. This method is suitable for users who prefer open-source solutions and want to leverage the flexibility of TensorFlow Lite.

  1. Install onnx2tf:

    pip3 install onnx2tf
    pip3 install sor4onnx
    
  2. Convert the Model to TF:

    Note that the ONNX model already has the input shape in NHWC format. onnx2tf by default expects the input shape to be NCHW and will introduce a transpose operation to the input to convert it to NHWC. The steps below ensure that the generated TFLite input shape format remains NHWC. We start by removing the :0 suffix from the input since TensorFlow introduces its own suffixes.

    sor4onnx --input_onnx_file_path bodyposenet_fixed.onnx --old_new ":0" "" --mode full --search_mode suffix_match --output_onnx_file_path bodyposenet_renamed.onnx
    

    You can now see that the input name has changed from input_1:0 to input_1.

    BodyPoseNet ONNX model with renamed input

    Now we convert the ONNX model to TensorFlow saved_model using onnx2tf.

    onnx2tf -i bodyposenet_renamed.onnx -ois input_1:1,224,320,3 -kat input_1
    

    This generates the TensorFlow saved_model.
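
    Optionally, you can inspect the serving signature of the generated saved_model before converting it, to confirm the input kept its NHWC shape. A minimal sketch, assuming onnx2tf wrote its output to the default saved_model directory and exposed the standard serving_default signature:

    import tensorflow as tf

    # Load the saved_model produced by onnx2tf and print its serving signature
    loaded = tf.saved_model.load("saved_model")
    infer = loaded.signatures["serving_default"]
    print(infer.structured_input_signature)  # expect a (1, 224, 320, 3) float32 input
    print(infer.structured_outputs)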

  3. Convert and Quantize the Model to TFLite:

    We can use the generated TF saved_model and convert it to TFLite using the TFLite converter.

    import tensorflow as tf
    import numpy as np
    
    tf_model_path = '/path/to/saved_model'
    tflite_model_path = 'bodyposenet.tflite'
    
    # Generate a representative dataset for calibration.
    # Random data is used here as a stand-in; for better quantization accuracy,
    # yield preprocessed frames from a real dataset instead.
    def representative_dataset():
        for _ in range(100):
            data = tf.random.uniform((1, 224, 320, 3))
            yield [data]
    
    converter = tf.lite.TFLiteConverter.from_saved_model(tf_model_path)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = representative_dataset
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    converter.inference_input_type = tf.int8  # Can also be tf.uint8 or tf.float32
    converter.inference_output_type = tf.float32  # Can also be tf.uint8 or tf.int8. We keep it float32 for ease of post-processing the output data
    
    tflite_model = converter.convert()
    
    with open(tflite_model_path, "wb") as f:
        f.write(tflite_model)
    

    The Netron visualization of the generated TFLite model is below:

    BodyPoseNet TFLite model
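
    As a quick sanity check, you can load the quantized model with the TFLite interpreter and run a single inference on random data (a minimal sketch, assuming the bodyposenet.tflite file produced above):

    import numpy as np
    import tensorflow as tf

    interpreter = tf.lite.Interpreter(model_path="bodyposenet.tflite")
    interpreter.allocate_tensors()

    inp = interpreter.get_input_details()[0]
    print("input:", inp["shape"], inp["dtype"])  # expect (1, 224, 320, 3), int8

    # Run one inference on random int8 data just to confirm the model executes
    dummy = np.random.randint(-128, 128, size=tuple(inp["shape"]), dtype=np.int8)
    interpreter.set_tensor(inp["index"], dummy)
    interpreter.invoke()
    for out in interpreter.get_output_details():
        print(out["name"], interpreter.get_tensor(out["index"]).shape)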

Method 2: Using mtk-converter

mtk-converter is MediaTek’s proprietary tool for converting ONNX models to TFLite format. It is optimized for MediaTek’s hardware and provides additional features for model optimization and deployment. This method is recommended for users who are deploying models on MediaTek platforms and want to take advantage of MediaTek’s AI-ML solutions.

  1. Install mtk-converter

    Follow the instructions to download and install the mtk-converter tool from the official NeuroPilot documentation. Install the mtk-converter tool inside a Python v3.11 virtual environment.

    # Steps to start a Python virtual environment and install the mtk_converter wheel file in your Linux environment
    sudo apt install python3.11-venv
    mkdir NP8
    python3.11 -m venv NP8/
    source NP8/bin/activate  # This starts your virtual environment
    pip install mtk_converter*.whl  # Install the compatible mtk-converter wheel file. It would be part of the downloaded NP package.
    pip install <required_dependencies>  # Install the dependencies as mentioned in the Neuropilot documentation
    
  2. Convert ONNX Model to TFLite

    The ONNX model already has NHWC input shapes, so we can use the Python script below to convert it to TFLite.

    The provided Python script performs two functions:

    1. Convert ONNX IR version to v8 supported by NP8

    2. Convert the generated ONNX model to TFLite

    import onnx
    from onnx import version_converter
    import mtk_converter
    import numpy as np
    import argparse
    
    # Set up argument parser
    parser = argparse.ArgumentParser(description='Convert ONNX model to TFLite with quantization')
    parser.add_argument('input_model', type=str, help='Path to the input ONNX model')
    parser.add_argument('output_model', type=str, help='Path to the output TFLite model')
    parser.add_argument('--input_names', type=str, default=None, help='Comma-separated list of input tensor names')
    parser.add_argument('--input_shapes', type=str, default=None, help='Semicolon-separated list of comma-separated input tensor shapes, e.g., "1,224,224,3;1,10"')
    parser.add_argument('--quantize', action='store_true', help='Quantize the converted model (default: False)')
    
    args = parser.parse_args()
    input_model_path = args.input_model
    output_model_path = args.output_model
    quantize = args.quantize
    input_names = args.input_names.split(',') if args.input_names else None
    input_shapes = [list(map(int, shape.split(','))) for shape in args.input_shapes.split(';')] if args.input_shapes else None
    
    # Step 1: Convert the IR version of the ONNX model
    def convert_ir_version(input_model_path, target_ir_version=8):
        original_model = onnx.load(input_model_path)
        print(f"Original model IR version: {original_model.ir_version}")
    
        # Convert the IR version
        original_model.ir_version = target_ir_version
        output_model_path = 'converted_model.onnx'
        onnx.save(original_model, output_model_path)
    
        converted_model = onnx.load(output_model_path)
        print(f"Converted model IR version: {converted_model.ir_version}")
        return output_model_path
    
    # Step 2: Convert the ONNX model to TFLite
    def convert_to_tflite(input_model_path, output_model_path, input_shapes, input_names):
    
        # Load the ONNX model to get the input shape
        onnx_model = onnx.load(input_model_path)
    
        if input_shapes is None:
            input_shapes = []
            for input_tensor in onnx_model.graph.input:
                input_shape = [dim.dim_value for dim in input_tensor.type.tensor_type.shape.dim]
                input_shapes.append(input_shape)
    
        # Define representative dataset
        def data_gen(input_shapes):
            for i in range(10):
                yield [np.random.randn(*shape).astype(np.float32) for shape in input_shapes]
    
        # Create the converter
        if args.input_names or args.input_shapes:
            converter = mtk_converter.OnnxConverter.from_model_proto_file(input_model_path, input_names, input_shapes)
        else:
            converter = mtk_converter.OnnxConverter.from_model_proto_file(input_model_path)
    
        if quantize:
            converter.quantize = True
            converter.calibration_data_gen = lambda: data_gen(input_shapes)
    
            # Uncomment to set precision proportion to INT8
            # converter.precision_proportion = {'8W8A': 1.0}
    
            # Uncomment with precision proportion to set input data type as FP32
            # converter.prepend_input_quantize_ops = True
    
            # Uncomment with precision proportion to set output data type as FP32
            # converter.append_output_dequantize_ops = True
    
        converter.use_per_output_channel_quantization = False
        # Convert the model to TFLite
        _ = converter.convert_to_tflite(output_file=output_model_path)
    
    # Convert the IR version of the ONNX model
    converted_model_path = convert_ir_version(input_model_path)
    
    # Convert the ONNX model to TFLite
    convert_to_tflite(converted_model_path, output_model_path, input_shapes, input_names)
    

    The provided conversion_script.py will convert your ONNX model to TFLite. Use the below command line arguments to execute the script:

    python3 conversion_script.py bodyposenet.onnx bodyposenet.tflite --input_names input_1:0 --input_shapes 1,224,320,3 --quantize
    

    Pass the --quantize flag to produce a quantized TFLite model; omit it to keep the model in floating point.
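
    When feeding real images into the quantized model, you will also need its input quantization parameters (scale and zero point). They can be read back from the generated file with the TFLite interpreter (a sketch, assuming the bodyposenet.tflite output path used above):

    import tensorflow as tf

    interpreter = tf.lite.Interpreter(model_path="bodyposenet.tflite")
    interpreter.allocate_tensors()

    inp = interpreter.get_input_details()[0]
    scale, zero_point = inp["quantization"]
    print("input dtype:", inp["dtype"], "scale:", scale, "zero point:", zero_point)
    # A float input value x is quantized as q = round(x / scale) + zero_point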

Model Execution

Use the TFLite benchmark model tool to execute the quantized model and measure its performance.

benchmark_model is provided as part of TensorFlow Performance Measurement and is used here for performance evaluation.

Commands for executing the benchmark tool with CPU and different delegates are as follows.

  • Execute on CPU (8 threads):

benchmark_model --graph=bodyposenet.tflite --num_threads=8 --num_runs=10
  • Execute on GPU, with GPU delegate:

benchmark_model --graph=bodyposenet.tflite --use_gpu=1 --allow_fp16=0 --gpu_precision_loss_allowed=0 --num_runs=10
  • Execute on GPU, with Arm NN delegate:

benchmark_model --graph=bodyposenet.tflite --external_delegate_path=/usr/lib/libarmnnDelegate.so.29 --external_delegate_options="backends:GpuAcc,CpuAcc" --num_runs=10
  • Execute on NPU, with NPU LiteRT delegate:

benchmark_model --graph=bodyposenet.tflite --stable_delegate_settings_file=/usr/share/label_image/stable_delegate_settings.json

The following table shows the inference times (in milliseconds) for the converted bodyposenet TFLite model running on the Genio 700 platform.

Inference Times on Genio 700 (ms)

Model        | CPU (1 thread) | GPU   | NPU LiteRT delegate
BodyPoseNet  | 576.2          | 294.3 | 21

Demo Example

Application Control Flow Stack Diagram

Below we show the Application Control Flow of the demo example. The example makes use of our advanced MDP, GPU, and NPU capabilities.

BodyPoseNet application control flow diagram

Application Demo Video

Below are a few video examples of body pose estimation using BodyPoseNet.

The first example estimates the body pose of a person exercising with a skipping rope.

The second example estimates the body pose of a person performing squats.

Conclusion

In this tutorial, we demonstrated how to take the BodyPoseNet model from the NVIDIA NGC hub, convert it to TFLite format using two different methods, quantize the model, and benchmark it with the TFLite benchmark_model tool. We also showcased a demo video of an application created using the same model. This process highlights the flexibility and efficiency of using the NVIDIA TAO Toolkit and MediaTek’s NeuroPilot solutions for deploying AI-ML models at the edge.