
Mohit Ayani, Solutions Architect, NVIDIA


Shang Zhang, Senior AI Developer Technology Engineer, NVIDIA


Jay Rodge, Product Marketing Manager-AI, NVIDIA

Transformer-based models have revolutionized the natural language processing (NLP) domain. Since its introduction, the transformer architecture has been integrated into models such as Bidirectional Encoder Representations from Transformers (BERT) and Generative Pre-trained Transformer (GPT). These models perform tasks such as text generation, summarization, and question answering. Newer models keep growing by stacking more transformer layers and supporting longer input sequences. This has improved model accuracy, but at the cost of longer inference times.

NVIDIA TensorRT is an SDK for high-performance deep learning inference on NVIDIA GPUs. It includes a deep learning inference optimizer and a runtime that deliver low latency and high throughput. One of the key features of TensorRT is that it allows models to be deployed in reduced precisions such as FP16 and INT8 without compromising accuracy. Bing recently announced support for running its transformer models on Azure T4 GPUs using TensorRT INT8 optimization. Starting with TensorRT 8.0, BERT Large inference latency can be as low as 1.2 ms with INT8 optimization.

Many transformer models from different frameworks (such as PyTorch and TensorFlow) can be converted to the Open Neural Network Exchange (ONNX) format for further optimization. ONNX is the open standard format for representing AI and deep learning models. ONNX Runtime is a high-performance inference engine for machine learning models. It offers multi-platform support and a flexible execution provider interface for integrating hardware-specific libraries. As shown in Figure 1, ONNX Runtime integrates TensorRT as an execution provider to accelerate model inference on NVIDIA GPUs with TensorRT optimizations. Based on what TensorRT supports, ONNX Runtime partitions the model graph. It then offloads the supported parts to the TensorRT execution provider for efficient execution on NVIDIA hardware.

Figure 1: Different execution providers supported by ONNX Runtime.
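The partitioning step can be pictured with a deliberately simplified sketch. It assumes a topologically ordered list of op types and a hypothetical set of TensorRT-supported ops. It also assumes that unsupported nodes fall back to the CUDA execution provider registered after TensorRT. ONNX Runtime's real partitioner works on the graph itself and applies additional rules, so treat this only as an illustration of the idea.

```python
from typing import List, Set, Tuple


def partition(nodes: List[str], supported: Set[str]) -> List[Tuple[str, List[str]]]:
    """Group a topologically ordered list of op types into contiguous runs,
    tagging each run with the provider that would execute it.

    Simplification (assumption): nodes form a chain, and anything the
    TensorRT EP does not support falls back to the CUDA EP.
    """
    groups: List[Tuple[str, List[str]]] = []
    for op in nodes:
        provider = "TensorRT" if op in supported else "CUDA"
        if groups and groups[-1][0] == provider:
            groups[-1][1].append(op)  # extend the current subgraph
        else:
            groups.append((provider, [op]))  # start a new subgraph
    return groups


# "CustomOp" is a hypothetical op that TensorRT cannot run.
print(partition(["MatMul", "Add", "Softmax", "CustomOp", "MatMul"],
                {"MatMul", "Add", "Softmax"}))
```

Each TensorRT run becomes a subgraph that TensorRT compiles into an engine. The fallback run executes on the next provider in priority order.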

In this post, we take the HuggingFace BERT model, apply TensorRT INT8 optimizations, and accelerate inference with ONNX Runtime using the TensorRT execution provider.

Setup

To get started, clone the transformers repository from the HuggingFace GitHub page.

$ git clone https://github.com/huggingface/transformers.git
$ cd transformers

Then build and launch the Docker container with the following commands. The Dockerfile uses the NGC PyTorch container image as its base.

$ docker build . -f examples/research_projects/quantization-qdqbert/Dockerfile -t bert_quantization:latest

$ docker run --gpus all --privileged --rm -it --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 bert_quantization:latest

Once inside the container, navigate to the quantization directory.

$ cd transformers/examples/research_projects/quantization-qdqbert/

INT8 optimization

Model quantization is becoming a popular deep learning optimization method. It performs calculations with 8-bit integers so that inference can run on the faster, cheaper 8-bit Tensor Cores. These cores compute convolution and matrix-multiplication operations with higher throughput, which is particularly effective for compute-bound layers.
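The core arithmetic behind INT8 quantization is shown in the minimal sketch below. It assumes symmetric per-tensor quantization with a zero point of 0 and a scale of amax / 127, where amax is a calibrated range. It also assumes clamping to [-127, 127]. Python's round() rounds halves to even, which matches the rounding mode the ONNX QuantizeLinear operator specifies.

```python
def quantize(x: float, scale: float) -> int:
    """Map a real value to a signed 8-bit integer (symmetric, zero point 0)."""
    q = round(x / scale)          # round half to even
    return max(-127, min(127, q))  # clamp to the assumed symmetric INT8 range


def dequantize(q: int, scale: float) -> float:
    """Map the integer back to the real line."""
    return q * scale


amax = 4.0           # assumed calibration range [-amax, amax]
scale = amax / 127   # size of one INT8 step in real units

for x in [0.5, -1.3, 3.99, 10.0]:
    q = quantize(x, scale)
    print(f"{x:6.2f} -> {q:4d} -> {dequantize(q, scale):.4f}")
```

Values inside the range come back with an error of at most half a step (scale / 2). Values outside it, such as 10.0, saturate at amax. This is why choosing amax (calibration) matters.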

Quantization Toolkit

The TensorRT Quantization Toolkit for PyTorch provides a convenient tool for training and evaluating PyTorch models with simulated quantization. The library can add quantization to PyTorch models automatically or manually. The quantized model can then be exported to ONNX and imported by TensorRT 8.0 and later.

If you already have an ONNX model, you can apply the ONNX Runtime quantization tool directly with Post Training Quantization (PTQ) and run the result with ONNX Runtime-TensorRT quantization. Please refer to this example for more details. This post focuses on starting from a PyTorch model.

HuggingFace QDQBERT model

The HuggingFace QDQBERT model starts from the HuggingFace BERT model and uses the TensorRT Quantization Toolkit for PyTorch to insert Q/DQ nodes into the network. Fake quantization operations (pairs of QuantizeLinear/DequantizeLinear ops) are added to three places in the BERT model:

(1) linear layer inputs and weights
(2) matmul inputs
(3) residual add inputs

The QDQBERT model is then exported to ONNX format, which TensorRT can import. The QDQBERT model can be loaded from any HuggingFace BERT checkpoint (for example, bert-large-uncased). It can then be quantized with Quantization Aware Training (QAT) or Post Training Quantization (PTQ).
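Conceptually, a QuantizeLinear/DequantizeLinear pair leaves a tensor in floating point but snaps it onto the INT8 grid. The pure-Python sketch below applies such pairs to a linear layer's input and weights and to both inputs of a residual add. The function names are hypothetical, and a single per-tensor amax is assumed for each tensor. The actual toolkit may use different granularity, for example per-channel weight scales.

```python
from typing import List


def fake_quant(values: List[float], amax: float) -> List[float]:
    """QuantizeLinear followed by DequantizeLinear: the output stays float
    but only takes values on the INT8 grid defined by amax."""
    scale = amax / 127
    return [max(-127, min(127, round(v / scale))) * scale for v in values]


def linear(x: List[float], w: List[List[float]]) -> List[float]:
    """y = W x for a single input vector (bias omitted for brevity)."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def qdq_linear(x, w, x_amax, w_amax):
    """Linear layer with Q/DQ pairs on both its input and its weights."""
    xq = fake_quant(x, x_amax)
    wq = [fake_quant(row, w_amax) for row in w]
    return linear(xq, wq)


def qdq_residual_add(a, b, amax):
    """Residual add with a Q/DQ pair on each input."""
    return [p + q for p, q in zip(fake_quant(a, amax), fake_quant(b, amax))]


x = [0.5, -1.0, 2.0]
w = [[0.1, 0.2, -0.3], [0.4, -0.5, 0.6]]
print(linear(x, w), qdq_linear(x, w, x_amax=2.0, w_amax=0.6))
```

The quantized output differs from the float output only by the accumulated rounding error. During training and calibration, the model learns to tolerate exactly this error.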

First, run the following command to perform calibration:

python3 run_quant_qa.py \
--model_name_or_path bert-large-uncased \
--dataset_name squad \
--max_seq_length 128 \
--doc_stride 32 \
--output_dir calib/bert-large-uncased \
--do_calib \
--calibrator percentile \
--percentile 99.99
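The --calibrator percentile --percentile 99.99 options choose each activation's amax so that outliers do not stretch the quantization range. The sketch below assumes a simple nearest-rank percentile over the absolute values of collected activations, and the function name is hypothetical. The toolkit's own calibrator may compute the percentile differently, for example from a histogram.

```python
import math
from typing import List


def percentile_amax(activations: List[float], percentile: float = 99.99) -> float:
    """Return the |x| value below which `percentile` percent of samples fall
    (nearest-rank method)."""
    mags = sorted(abs(a) for a in activations)
    rank = math.ceil(percentile / 100 * len(mags))
    return mags[max(rank, 1) - 1]


# 100 "typical" activations plus one large outlier (synthetic data).
acts = [float(v) for v in range(1, 101)] + [1000.0]
print(percentile_amax(acts, 99.0))   # ignores the outlier
print(percentile_amax(acts, 100.0))  # equivalent to a max calibrator
```

With a max calibrator, the single outlier would set amax to 1000, so each INT8 step would cover almost 8 units. The percentile choice keeps the step small for the values that actually occur, at the cost of clipping rare outliers.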

Then launch QAT by running the script again, this time with training enabled:

python3 run_quant_qa.py \
--model_name_or_path calib/bert-large-uncased \
--dataset_name squad \
--do_train \
--do_eval \
--per_device_train_batch_size 12 \
--learning_rate 4e-5 \
--num_train_epochs 2 \
--max_seq_length 128 \
--doc_stride 32 \
--output_dir finetuned_int8/bert-large-uncased \
--tokenizer_name bert-large-uncased \
--save_steps 0
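QAT fine-tunes the weights while the forward pass sees quantized values. Rounding has zero gradient almost everywhere, so fake quantization is commonly trained with a straight-through estimator (STE). The STE passes the gradient through unchanged inside [-amax, amax] and blocks it outside. The single-weight sketch below assumes that rule, and its function names are hypothetical. Whether the toolkit uses exactly this gradient rule is an assumption here.

```python
def fake_quant_forward(x: float, amax: float) -> float:
    """Quantize-dequantize a scalar onto the symmetric INT8 grid."""
    scale = amax / 127
    return max(-127, min(127, round(x / scale))) * scale


def fake_quant_backward(x: float, grad_out: float, amax: float) -> float:
    """Straight-through estimator: treat round() as identity inside the
    clip range, and pass no gradient outside it."""
    return grad_out if -amax <= x <= amax else 0.0


def qat_step(w: float, x: float, y: float, amax: float, lr: float) -> float:
    """One SGD step on loss = (fq(w) * x - y)^2 using the STE."""
    wq = fake_quant_forward(w, amax)
    grad_pred = 2 * (wq * x - y)      # d(loss)/d(pred)
    grad_wq = grad_pred * x           # d(loss)/d(wq)
    grad_w = fake_quant_backward(w, grad_wq, amax)
    return w - lr * grad_w


w = 0.0
for _ in range(100):
    w = qat_step(w, x=1.0, y=0.5, amax=1.0, lr=0.1)
print(w, fake_quant_forward(w, 1.0))
```

The float "shadow" weight keeps moving in small steps, even though the quantized weight only changes when the shadow weight crosses a grid boundary. As a result, training settles on the INT8 value closest to the target.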

At a high level, TensorRT processes ONNX models with Q/DQ operators much as it processes any other ONNX model. It imports the ONNX model containing Q/DQ operations and performs a set of optimizations dedicated to Q/DQ processing. It then continues with the general optimization passes and builds a platform-specific execution plan file for inference. This plan file contains the quantized operations and weights.
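TensorRT can turn "dequantize, then multiply in float" into a true INT8 kernel because of a simple identity. Multiplying two dequantized vectors equals an integer dot product rescaled once by the product of the two scales. The sketch below illustrates that identity only, with hypothetical names and values. TensorRT's actual fusion rules and kernels are internal to TensorRT.

```python
from typing import List


def quantize_vec(v: List[float], scale: float) -> List[int]:
    """Symmetric INT8 quantization of a vector (zero point 0)."""
    return [max(-127, min(127, round(a / scale))) for a in v]


def int8_dot(xq: List[int], wq: List[int]) -> int:
    """Integer multiply-accumulate, as an INT8 kernel would do with an
    INT32 accumulator."""
    return sum(a * b for a, b in zip(xq, wq))


x = [0.5, -1.0, 2.0]
w = [0.1, 0.2, -0.3]
sx, sw = 2.0 / 127, 0.3 / 127   # assumed calibrated scales

xq, wq = quantize_vec(x, sx), quantize_vec(w, sw)

# Path 1: what the Q/DQ graph literally describes (dequantize, then float dot).
dq_result = sum((a * sx) * (b * sw) for a, b in zip(xq, wq))
# Path 2: what a fused INT8 kernel can compute instead.
fused_result = int8_dot(xq, wq) * (sx * sw)

print(dq_result, fused_result)
```

Both paths give the same value up to floating-point rounding. The second path does all the heavy multiply-accumulate work in integers, which is what the 8-bit Tensor Cores accelerate.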

You can now export the fine-tuned model, including its Q/DQ operations, to the ONNX format with the following command:

python3 run_quant_qa.py \
--model_name_or_path finetuned_int8/bert-large-uncased \
--output_dir ./ \
--save_onnx \
--per_device_eval_batch_size 1 \
--max_seq_length 128 \
--doc_stride 32 \
--dataset_name squad \
--tokenizer_name bert-large-uncased

Starting with TensorRT 8.0, TensorRT applies new optimizations to Q/DQ networks. These optimizations increase Q/DQ model performance and provide predictable, user-controlled transitions between arithmetic precisions.

Results

The inference performance experiments were run on an NVIDIA A100 GPU using ONNX Runtime 1.11 and TensorRT 8.2 with the HuggingFace BERT-large model. The inference task is SQuAD, and INT8 quantization is performed with the HuggingFace QDQBERT-large model.

You can run the benchmark with trtexec:

trtexec --onnx=model.onnx --explicitBatch --workspace=16384 --int8 --shapes=input_ids:64x128,attention_mask:64x128,token_type_ids:64x128 --verbose

Alternatively, you can use the following Python script, which runs the model with ONNX Runtime and the TensorRT execution provider:

python3 ort-infer-benchmark.py
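If you want to write your own measurement loop instead, the sketch below shows one possible structure. It is not the contents of ort-infer-benchmark.py. benchmark() warms up first, so that one-time work such as the TensorRT engine build is not timed, and then reports latency statistics in milliseconds. make_trt_session() is a hypothetical helper that requires onnxruntime-gpu. It assumes the trt_int8_enable provider option, so check the option names against the ONNX Runtime version you use.

```python
import statistics
import time
from typing import Callable, Dict


def benchmark(run: Callable[[], object], warmup: int = 10, iters: int = 100) -> Dict[str, float]:
    """Time run() after a warm-up phase; return latency stats in ms."""
    for _ in range(warmup):
        run()  # not timed: engine build, memory allocation, caches
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        run()
        samples.append((time.perf_counter() - start) * 1000.0)
    samples.sort()
    return {
        "mean_ms": statistics.mean(samples),
        "p50_ms": statistics.median(samples),
        "p99_ms": samples[min(len(samples) - 1, int(0.99 * len(samples)))],
    }


def make_trt_session(model_path: str):
    """Hypothetical helper: an ONNX Runtime session that prefers the TensorRT
    EP with INT8 enabled and falls back to CUDA for unsupported nodes."""
    import onnxruntime as ort  # imported lazily; needs onnxruntime-gpu
    providers = [
        ("TensorrtExecutionProvider", {"trt_int8_enable": True}),
        "CUDAExecutionProvider",
    ]
    return ort.InferenceSession(model_path, providers=providers)
```

A possible usage, assuming the exported model.onnx and a dict of NumPy inputs named feeds, is benchmark(lambda: session.run(None, feeds)) after session = make_trt_session("model.onnx"). session.run returns host-side outputs, so wall-clock timing around it includes the full GPU round trip.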

With ONNX Runtime and the TensorRT EP optimizations, we see up to a 7x speedup over PyTorch inference for BERT Large and BERT Base. Latency is under 2 ms and 1 ms, respectively, at batch size 1. Figure 2 shows the inference latency comparison for BERT Large with sequence length 128 on an NVIDIA A100 GPU.

Figure 2: Compute latency comparison between ONNX Runtime-TensorRT and PyTorch for running BERT-Large on NVIDIA A100 GPU for sequence length 128.

You can also check the accuracy of the INT8 model using the following script:

python3 evaluate-hf-trt-qa.py \
--onnx_model_path=./model.onnx \
--output_dir ./ \
--per_device_eval_batch_size 64 \
--max_seq_length 128 \
--doc_stride 32 \
--dataset_name squad \
--tokenizer_name bert-large-uncased \
--int8 \
--seed 42

Accuracy metrics with ONNX Runtime-TensorRT 8.2 EP for the SQuAD task are:

Metric       INT8          FP16          FP32
F1 score     87.52263875   87.69072304   87.96610141
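The F1 score above measures token overlap between predicted and reference answers. The sketch below follows the logic of the official SQuAD v1.1 evaluation as we understand it. Answers are normalized (lowercased, punctuation and articles removed, whitespace collapsed), and a token-level F1 is computed. Each question takes its best match over the gold answers, and the dataset score is the mean scaled to 100. Whether evaluate-hf-trt-qa.py computes the metric in exactly this way is an assumption.

```python
import re
import string
from collections import Counter
from typing import List

_PUNCT = set(string.punctuation)


def normalize_answer(s: str) -> str:
    """Lowercase, strip punctuation and articles, collapse whitespace."""
    s = s.lower()
    s = "".join(ch for ch in s if ch not in _PUNCT)
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())


def f1_score(prediction: str, ground_truth: str) -> float:
    """Token-level F1 between one prediction and one gold answer."""
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(ground_truth).split()
    common = Counter(pred_tokens) & Counter(gold_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)


def dataset_f1(predictions: List[str], references: List[List[str]]) -> float:
    """references[i] lists all acceptable gold answers for question i."""
    total = sum(max(f1_score(p, g) for g in golds)
                for p, golds in zip(predictions, references))
    return 100.0 * total / len(predictions)
```

Because the metric gives partial credit for overlapping tokens, small numeric differences from INT8 rounding tend to shift the score only slightly. The table reflects this: the INT8 score is within 0.5 F1 points of FP32.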

Conclusion

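ONNX Runtime-TensorRT INT8 quantization shows very promising results on NVIDIA GPUs. It delivers up to a 7x speedup over PyTorch inference for BERT, and the INT8 F1 score on SQuAD stays within 0.5 points of FP32 (87.52 vs. 87.97). We’d love to hear your feedback or suggestions as you try it in your production scenarios. You can share feedback through our GitHub repositories (TensorRT and ONNX Runtime).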