License Plate Detection

Introduction

In this tutorial, we will guide you through the steps to take the LPDNet model from the NVIDIA NGC hub, convert the ONNX model to TFLite format, quantize the converted models, and run them with the TFLite benchmark_model tool. We will demonstrate two conversion methods: using the open-source onnx2tf tool and MediaTek’s proprietary mtk-converter. Finally, we will showcase a demo video of an application built with the same model.

Prerequisites

Before you begin, ensure you have the following installed:

  • Python 3.11

  • ONNX v1.13

  • onnxruntime v1.17.1

  • TensorFlow v2.14

  • onnx2tf

  • Neuropilot Converter 8.10

  • NVIDIA TAO Toolkit (v5.5)

Download the Model

First, download the LPDNet_USA pruned model from the NVIDIA NGC hub.

wget --content-disposition 'https://api.ngc.nvidia.com/v2/models/org/nvidia/team/tao/lpdnet/pruned_v2.2/files?redirect=true&path=LPDNet_usa_pruned_tao5.onnx' -O LPDNet_usa_pruned_tao5.onnx

Model Analysis

The downloaded model has dynamic inputs, which need to be made static for conversion. Below is a visualization of the model using Netron.

LPDNet_USA pruned ONNX model (Netron visualization)

To make the inputs static, you can use the following onnxruntime command:

python -m onnxruntime.tools.make_dynamic_shape_fixed --input_name input_1:0 --input_shape 1,3,480,640 LPDNet_usa_pruned_tao5.onnx LPDNet_usa_pruned_tao5_fixed.onnx

You can now see that the dynamic inputs have become static.

LPDNet_USA pruned ONNX model with fixed input shape (Netron visualization)
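If you prefer to verify this without Netron, a quick check with the onnx Python package prints the resolved input shape (a minimal sketch; the file name matches the output of the command above):

import onnx

model = onnx.load('LPDNet_usa_pruned_tao5_fixed.onnx')
for inp in model.graph.input:
    dims = [d.dim_value if d.dim_value > 0 else d.dim_param
            for d in inp.type.tensor_type.shape.dim]
    print(inp.name, dims)  # Expected: input_1:0 [1, 3, 480, 640]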

Convert ONNX Model to TFLite

We will demonstrate two methods to convert the ONNX model to TFLite format.

Method 1: Using onnx2tf

onnx2tf is an open-source tool that converts ONNX models to TensorFlow Lite models. It is designed to be easy to use and supports a wide range of ONNX operators. This method is suitable for users who prefer open-source solutions and want to leverage the flexibility of TensorFlow Lite.

  1. Install onnx2tf:

    pip3 install onnx2tf
    
  2. Remove suffixes from ONNX model:

    import onnx
    import sys
    
    def remove_suffix_from_names(model_path, output_model_path, suffix=':0'):
        # Load the ONNX model
        onnx_model = onnx.load(model_path)
    
        # Get input and output names to remove the suffix from
        graph_input_names = [input.name for input in onnx_model.graph.input]
        graph_output_names = [output.name for output in onnx_model.graph.output]
    
        print('graph_input_names =', graph_input_names)
        print('graph_output_names =', graph_output_names)
    
        # Remove suffix from input names
        for input in onnx_model.graph.input:
            input.name = input.name.removesuffix(suffix)
    
        # Remove suffix from output names
        for output in onnx_model.graph.output:
            output.name = output.name.removesuffix(suffix)
    
        # Remove suffix from node input and output names
        for node in onnx_model.graph.node:
            for i in range(len(node.input)):
                if node.input[i] in graph_input_names:
                    node.input[i] = node.input[i].removesuffix(suffix)
    
            for i in range(len(node.output)):
                if node.output[i] in graph_output_names:
                    node.output[i] = node.output[i].removesuffix(suffix)
    
        # Save the modified ONNX model
        onnx.save(onnx_model, output_model_path)
    
    if __name__ == "__main__":
        if len(sys.argv) != 3:
            print("Usage: python3 script.py <input_model.onnx> <output_model.onnx>")
            sys.exit(1)
    
        input_model_path = sys.argv[1]
        output_model_path = sys.argv[2]
    
        remove_suffix_from_names(input_model_path, output_model_path)
    

    Run the script above as: python3 script.py LPDNet_usa_pruned_tao5_fixed.onnx LPDNet_usa_pruned_tao5_mod.onnx

  3. Convert the Model to TF:

    onnx2tf -i LPDNet_usa_pruned_tao5_mod.onnx -oiqt
    

    This generates a TensorFlow saved_model, which can then be quantized using the MLIR-based TFLite converter.

  4. Convert and Quantize the Model to TFLite:

    We can use the generated TF saved_model and convert it to TFLite using the TFLite converter.

    import tensorflow as tf
    import numpy as np
    
    tf_model_path = '/path/to/saved_model'
    tflite_model_path = 'LPDNet_usa_pruned_tao5.tflite'
    
    # Generate representative dataset
    def representative_dataset():
        data = tf.random.uniform((1,480,640,3))
        yield [data]
    
    converter = tf.lite.TFLiteConverter.from_saved_model(tf_model_path)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = representative_dataset
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    converter.inference_input_type = tf.int8  # Can also be tf.uint8 or tf.float32
    converter.inference_output_type = tf.float32  # Can be tf.int8 or tf.uint8. We keep it float32 for ease of post-processing output data
    
    tflite_model = converter.convert()
    
    with open(tflite_model_path, "wb") as f:
        f.write(tflite_model)
    
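    Note that the snippet above calibrates with random data, which exercises the quantization flow but typically costs accuracy. A sketch of a representative dataset generator that feeds real frames instead is shown below; the calibration_images/ folder and the plain 0-1 scaling are assumptions, so adjust the preprocessing to match your training pipeline.

    import glob

    import tensorflow as tf

    def representative_dataset():
        # Hypothetical folder of representative frames captured from your use case
        for path in glob.glob('calibration_images/*.jpg')[:100]:
            img = tf.io.decode_jpeg(tf.io.read_file(path), channels=3)
            img = tf.image.resize(img, (480, 640))       # model input is 480x640
            img = tf.cast(img, tf.float32) / 255.0       # assumed 0-1 scaling
            yield [tf.expand_dims(img, axis=0)]          # NHWC batch of 1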

Method 2: Using mtk-converter

mtk-converter is MediaTek’s proprietary tool for converting ONNX models to TFLite format. It is optimized for MediaTek’s hardware and provides additional features for model optimization and deployment. This method is recommended for users who are deploying models on MediaTek platforms and want to take advantage of MediaTek’s AI-ML solutions.

  1. Install mtk-converter:

    Follow the instructions to download and install the mtk-converter tool from the official NeuroPilot documentation. Install the mtk-converter tool inside a Python v3.11 virtual environment.

    # Steps to start a Python virtual environment and install the mtk_converter wheel file in your Linux environment
    sudo apt install python3.11-venv
    mkdir NP8
    python3.11 -m venv NP8/
    source NP8/bin/activate  # This starts your virtual environment
    pip install mtk_converter*.whl  # Install the compatible mtk-converter wheel file. It would be part of the downloaded NP package.
    pip install <required_dependencies>  # Install the dependencies as mentioned in the Neuropilot documentation
    
  2. Convert Input and Output to NHWC Data Format:

    The mtk-converter by default introduces transpose operations in the converted TFLite model to convert the ONNX NCHW data format to the NHWC format preferred by TensorFlow. This can add overhead during execution. To avoid it, we first insert ‘Transpose’ operations at the model’s input and outputs ourselves, and then continue with model conversion and quantization.

    The script below inserts the required Transpose operations at the input and at both outputs:

    import onnx
    
    # The provided model has one input and two outputs
    def add_transpose(model_proto, nhwc_input_shape, nhwc_output1_shape, nhwc_output2_shape):
    
        input_nhwc = onnx.helper.make_tensor_value_info('input_nhwc', onnx.TensorProto.FLOAT, nhwc_input_shape)
        output1_nhwc = onnx.helper.make_tensor_value_info('output1_nhwc', onnx.TensorProto.FLOAT, nhwc_output1_shape)
        output2_nhwc = onnx.helper.make_tensor_value_info('output2_nhwc', onnx.TensorProto.FLOAT, nhwc_output2_shape)
    
        input_transpose_node = onnx.helper.make_node('Transpose', ['input_nhwc'], ['input_1:0'], perm=[0,3,1,2])
        model_proto.graph.node.insert(0, input_transpose_node)
    
        output1_transpose_node = onnx.helper.make_node('Transpose', ['output_cov/Sigmoid:0'], ['output1_nhwc'], perm=[0,2,3,1])
        model_proto.graph.node.append(output1_transpose_node)
    
        output2_transpose_node = onnx.helper.make_node('Transpose', ['output_bbox/BiasAdd:0'], ['output2_nhwc'], perm=[0,2,3,1])
        model_proto.graph.node.append(output2_transpose_node)
    
        model_proto.graph.input.pop(0)
        model_proto.graph.input.insert(0, input_nhwc)
    
        model_proto.graph.output.pop(0)
        model_proto.graph.output.insert(0, output1_nhwc)
    
        model_proto.graph.output.pop(1)
        model_proto.graph.output.insert(1, output2_nhwc)
    
        onnx.checker.check_model(model_proto)
        return model_proto
    
    model_proto = onnx.load('LPDNet_usa_pruned_tao5.onnx')
    nhwc_input_shape = [1,480,640,3]
    nhwc_output1_shape = [1,30,40,1]
    nhwc_output2_shape = [1,30,40,4]
    
    model_proto_nhwc = add_transpose(model_proto, nhwc_input_shape, nhwc_output1_shape, nhwc_output2_shape)
    
    onnx.save(model_proto_nhwc, 'LPDNet_usa_pruned_tao5_nhwc.onnx')
    
  3. Convert NHWC ONNX Model to TFLite

    This part of the conversion consists of two steps:

    1. Version Conversion: TAO ONNX models use IR version 9 or above, whereas mtk-converter supports ONNX models with IR versions 3 to 8, so we downgrade the IR version to 8.

    # The provided code snippet by default will convert to IR8. You can choose to convert it to any IR version between versions 3 to 8
    
    import onnx

    def convert_ir_version(input_model_path, target_ir_version=8):

        # The IR version is a top-level metadata field on the model, so we simply overwrite it
        original_model = onnx.load(input_model_path)
        original_model.ir_version = target_ir_version
        output_model_path = 'converted_onnx.onnx'
        onnx.save(original_model, output_model_path)

    # Use the NHWC model generated in the previous step
    convert_ir_version('LPDNet_usa_pruned_tao5_nhwc.onnx')
    
    2. Convert to TFLite: The following code snippet converts the generated ONNX model with IR version 8 to TFLite.

    import onnx
    import mtk_converter
    import numpy as np
    
    def convert_to_tflite(input_model_path, output_model_path):
    
        onnx_model = onnx.load(input_model_path)
    
        # Load ONNX model to obtain input shape of model to create representative dataset for quantization.
        # We quantize as part of model conversion by default in this code snippet. You can choose to not quantize.
    
        input_tensor = onnx_model.graph.input[0]
        input_shape = [dim.dim_value for dim in input_tensor.type.tensor_type.shape.dim]
    
        def data_gen(input_shape):
            for i in range(10):
                yield [np.random.randn(*input_shape).astype(np.float32)]
    
        # Create the converter
        converter = mtk_converter.OnnxConverter.from_model_proto_file(input_model_path)
        # Set quantizer to True
        converter.quantize = True
        converter.calibration_data_gen = lambda: data_gen(input_shape)
    
        # If you want the model to remain in float, you can choose to uncomment precision_proportion
        # =====================================================
        # converter.precision_proportion = {'FP': 1.0}
        # =====================================================
    
        # If you want your model to be INT8 quantized but outputs dequantized to float (fp32), use below snippet
        # ======================================================
        converter.precision_proportion = {'8W8A': 1.0}
        converter.append_output_dequantize_ops = True
        # ======================================================
    
        converter.use_per_output_channel_quantization = False
        _ = converter.convert_to_tflite(output_file=output_model_path)
    
    input_model_path = 'converted_onnx.onnx'
    output_model_path = 'LPDNet_usa_pruned_tao5.tflite'
    
    # Convert to TFLite and save to the provided output path
    convert_to_tflite(input_model_path, output_model_path)
    

Model Execution

Use the TFLite benchmark_model tool to execute the quantized model and measure its performance.

The benchmark_model binary is provided as part of TensorFlow’s performance measurement tooling for TFLite.

Commands for executing the benchmark tool with CPU and different delegates are as follows.

  • Execute on CPU (8 threads):

benchmark_model --graph=LPDNet_usa_pruned_tao5.tflite --num_threads=8 --num_runs=10
  • Execute on GPU, with GPU delegate:

benchmark_model --graph=LPDNet_usa_pruned_tao5.tflite --use_gpu=1 --allow_fp16=0 --gpu_precision_loss_allowed=0 --num_runs=10
  • Execute on GPU, with Arm NN delegate:

benchmark_model --graph=LPDNet_usa_pruned_tao5.tflite --external_delegate_path=/usr/lib/libarmnnDelegate.so.29 --external_delegate_options="backends:GpuAcc,CpuAcc" --num_runs=10
  • Execute on NPU, with NPU LiteRT delegate:

benchmark_model --graph=LPDNet_usa_pruned_tao5.tflite --stable_delegate_settings_file=/usr/share/label_image/stable_delegate_settings.json
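Before running on the target, it can also help to sanity-check the converted model on the host with the Python TFLite interpreter. The following is a minimal sketch that pushes one random frame through the quantized model and prints the output shapes (tensor names and dtypes may differ depending on which conversion method you used):

import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path='LPDNet_usa_pruned_tao5.tflite')
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Build a random frame matching the input shape and (possibly quantized) dtype
shape = input_details[0]['shape']
dtype = input_details[0]['dtype']
if dtype == np.int8:
    frame = np.random.randint(-128, 128, size=shape, dtype=np.int8)
else:
    frame = np.random.random_sample(shape).astype(dtype)

interpreter.set_tensor(input_details[0]['index'], frame)
interpreter.invoke()

for out in output_details:
    print(out['name'], interpreter.get_tensor(out['index']).shape)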

The following table shows the inference times (in milliseconds) for the converted LPDNet_usa_pruned_tao5 TFLite model running on the Genio 700 platform.

Inference Times on Genio 700

Model               CPU (1 thread)   GPU    NPU LiteRT delegate
LPDNet_usa_pruned   54.2             53.9   2.6

Demo Example

Application Control Flow Stack Diagram

Below we show the application control flow of the demo example. The example makes use of our advanced MDP, GPU, and NPU capabilities.

LPDNet application control flow

Application Demo Video

The video below was captured in a parking lot. The model was trained on an NVIDIA-owned US license plate dataset. The application detects unique license plates in the video.

License Plate Detection
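For reference, the two model outputs (a 1x30x40x1 coverage map and a 1x30x40x4 bbox grid in NHWC) can be decoded into license-plate boxes with DetectNet_v2-style grid post-processing. The sketch below is not taken from the demo application; it assumes the common DetectNet_v2 defaults of a 16-pixel stride, a 0.5-cell center offset, and a bbox normalization of 35.0, so verify these against the TAO experiment spec and apply NMS or clustering on the resulting boxes.

import numpy as np

def decode_lpdnet(cov, bbox, cov_threshold=0.4, stride=16.0, offset=0.5, bbox_norm=35.0):
    # cov: (1, 30, 40, 1) coverage map, bbox: (1, 30, 40, 4) box grid (NHWC).
    # Constants are assumed DetectNet_v2 defaults, not values confirmed by this tutorial.
    cov = cov[0, :, :, 0]
    bbox = bbox[0]
    boxes = []
    rows, cols = cov.shape
    for i in range(rows):
        for j in range(cols):
            if cov[i, j] < cov_threshold:
                continue
            cx = (j + offset) * stride  # grid-cell center in input-image pixels
            cy = (i + offset) * stride
            x1 = cx - bbox[i, j, 0] * bbox_norm
            y1 = cy - bbox[i, j, 1] * bbox_norm
            x2 = cx + bbox[i, j, 2] * bbox_norm
            y2 = cy + bbox[i, j, 3] * bbox_norm
            boxes.append((x1, y1, x2, y2, float(cov[i, j])))
    return boxes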

Conclusion

In this tutorial, we demonstrated how to take the LPDNet pruned model from the NVIDIA NGC hub, convert it to TFLite format using two different methods, quantize the model, and benchmark it with the TFLite benchmark_model tool. We also showcased a demo video of an application built with the same model. This process highlights the flexibility and efficiency of using the NVIDIA TAO Toolkit and MediaTek’s NeuroPilot solutions for deploying AI-ML models at the edge.