
Optimizing the T5 Model for Fast Inference

Model deployment is a crucial stage in developing a machine learning product. Every product starts with an idea, and the first step is to check whether it is feasible. Once we are sure it is, we gather data and start developing a model. During development we track the model's accuracy, and also its size and inference time, because all of these matter for deployment. Even so, we sometimes end up with a model that is accurate but slow at inference. That hurts both the cost and the user experience of the product, so such models need to be optimized to bring their inference time as low as possible.

In this blog, we will explore how to optimize a T5 model for fast inference on a GPU.

Motivation Behind T5 Optimization

We were using the T5 model in our product PrepAI, a question generation platform. Users upload documents or videos, or paste in text, and the platform automatically generates different kinds of questions from that content. Because inference speed directly affects the cost and user experience of such a product, we decided to work on optimizing the T5 model. In this tutorial, we assume the T5 model is trained for translation, and all comparisons are made on translation tasks only.

Optimization Using TensorRT

Whenever developers think of optimizing a model for NVIDIA GPUs, TensorRT is the obvious choice, because it transforms the model's graph into a form that takes advantage of the GPU's architecture. It does this partly by replacing certain operations in the graph with equivalent operations that have less computational overhead. It also fuses certain operations into a single operation, which reduces the overhead further. This process is known as graph fusion.
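To make graph fusion concrete, here is a toy sketch in plain Python. It is not TensorRT's implementation, and the list-of-nodes graph format is a hypothetical simplification. It scans a topologically ordered list of operations and fuses each MatMul that feeds straight into a bias Add into one Gemm-style node. After fusion, the intermediate MatMul result never has to be written out and read back by a second kernel.

```python
from dataclasses import dataclass


@dataclass
class Node:
    op: str        # operation type, e.g. "MatMul"
    inputs: list   # names of input tensors
    output: str    # name of the single output tensor


def fuse_matmul_add(nodes):
    """Replace MatMul -> Add pairs with a single Gemm-style node.

    Toy assumptions: nodes are topologically ordered, each has one output,
    and graph outputs are not tracked separately.
    """
    # Count consumers of each tensor: fusing is only safe if nothing
    # else needs the MatMul's intermediate result.
    uses = {}
    for n in nodes:
        for name in n.inputs:
            uses[name] = uses.get(name, 0) + 1

    fused, i = [], 0
    while i < len(nodes):
        cur = nodes[i]
        nxt = nodes[i + 1] if i + 1 < len(nodes) else None
        if (cur.op == "MatMul" and nxt is not None and nxt.op == "Add"
                and cur.output in nxt.inputs and uses.get(cur.output, 0) == 1):
            bias = [x for x in nxt.inputs if x != cur.output]
            fused.append(Node("Gemm", cur.inputs + bias, nxt.output))
            i += 2  # both original nodes are replaced by the fused one
        else:
            fused.append(cur)
            i += 1
    return fused


graph = [
    Node("MatMul", ["x", "W1"], "t0"),
    Node("Add", ["t0", "b1"], "t1"),
    Node("Relu", ["t1"], "t2"),
    Node("MatMul", ["t2", "W2"], "t3"),
    Node("Add", ["t3", "b2"], "y"),
]
print([n.op for n in fuse_matmul_add(graph)])  # ['Gemm', 'Relu', 'Gemm']
```

The consumer count is the important guard. If another node also reads the MatMul output, the pair is left alone, because fusing would destroy a tensor that is still needed.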

We decided to try TensorRT for optimizing our T5 model. We used the code from NVIDIA's official repository as-is, since we only wanted to see how much performance that implementation could deliver. The model was around 3-4x faster for short sequences (under 100 tokens). Beyond that it started to slow down, and at 512 tokens it was slower than the original PyTorch model. According to some developers at the time, this was caused by a bug in TensorRT. The bug was fixed later, after we had already moved on to the next phase of optimization with ONNX.

We later found that this implementation also lacks past key values caching, which is essential for speeding up generation on longer sequences. This may change in the future, and TensorRT might then beat other implementations, but at the time of writing it has not been implemented.

Optimization Using ONNX Runtime

After trying TensorRT, we decided to optimize the model with ONNX Runtime. Simply converting a model to ONNX and applying basic optimizations already gives a small speedup. During this optimization, ONNX Runtime performs basic graph rewrites such as removing unused nodes and constant folding, which precomputes parts of the graph whose inputs are all constants.
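The two basic rewrites mentioned above, constant folding and removal of unused nodes, can be shown on a toy graph. This is a hypothetical pure-Python sketch, not ONNX Runtime's code. In ONNX Runtime itself, passes of this kind run when a session is created with `SessionOptions.graph_optimization_level` set to `ORT_ENABLE_BASIC` or higher.

```python
from dataclasses import dataclass


@dataclass
class Node:
    op: str        # "Add" or "Mul" in this toy
    inputs: list   # input tensor names
    output: str    # output tensor name


FOLDABLE = {"Add": lambda a, b: a + b, "Mul": lambda a, b: a * b}


def constant_fold(nodes, constants):
    """Evaluate ahead of time every node whose inputs are all constants."""
    constants = dict(constants)
    remaining = []
    for n in nodes:  # assumes topological order
        if n.op in FOLDABLE and all(name in constants for name in n.inputs):
            a, b = (constants[name] for name in n.inputs)
            constants[n.output] = FOLDABLE[n.op](a, b)  # result becomes a constant
        else:
            remaining.append(n)
    return remaining, constants


def remove_unused(nodes, graph_outputs):
    """Keep only nodes that contribute to a graph output (backward sweep)."""
    needed, kept = set(graph_outputs), []
    for n in reversed(nodes):
        if n.output in needed:
            kept.append(n)
            needed.update(n.inputs)
    return kept[::-1]


graph = [
    Node("Mul", ["scale", "shift"], "k"),   # both inputs constant -> foldable
    Node("Mul", ["x", "k"], "y"),           # depends on the runtime input x
    Node("Add", ["x", "scale"], "debug"),   # never reaches an output -> removable
]
folded, consts = constant_fold(graph, {"scale": 2.0, "shift": 3.0})
optimized = remove_unused(folded, graph_outputs=["y"])
print(consts["k"], [(n.op, n.inputs) for n in optimized])
# 6.0 [('Mul', ['x', 'k'])]
```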

For models like BERT, BART, GPT-2, and RoBERTa, ONNX Runtime also implements graph fusion, fusing parts of their graphs much as TensorRT does. Unfortunately, this was not yet available for T5. With graph fusion off the table, the obvious next step was to convert the model to float16.

Converting Encoder Into float16

T5 is an encoder-decoder model, so we optimized the encoder first and then the decoder. For this we used the transformers optimization package of ONNX Runtime. We first converted all the nodes of the ONNX encoder graph to float16 and evaluated the model's speed and accuracy. Converting every node destabilized the encoder, and it produced only NaN values.

This happens because the encoder contains many operations that do not work well in float16. T5 was also never designed to be fully compatible with float16; it was designed for bfloat16 and float32. We therefore identified the unstable nodes and kept only those in float32. This kept the model stable and still improved its speed.
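Why would some nodes fail in float16? One plausible culprit, offered here as an illustration rather than a trace of our exported graph, is T5's layer normalization. It squares the hidden states (Pow), averages them (ReduceMean), and takes a square root (Sqrt). float16 cannot represent values above 65504, so any activation with magnitude of about 256 or more overflows to inf when squared. The NumPy sketch below assumes a T5-style RMS norm with unit weights. It shows that overflow silently corrupts the output (here to zeros; elsewhere an inf can turn into NaN, for example via inf - inf or 0 × inf). It also shows that computing only those sensitive steps in float32 avoids the problem.

```python
import numpy as np


def t5_rms_norm(x, weight, eps=1e-6, compute_dtype=None):
    """T5-style layer norm: divide by the root mean square, no mean subtraction, no bias.

    compute_dtype sets the dtype of the Pow/ReduceMean/Sqrt steps;
    None keeps x's dtype (the naive "convert every node" case).
    """
    h = x if compute_dtype is None else x.astype(compute_dtype)
    variance = np.mean(np.power(h, 2), axis=-1, keepdims=True)  # Pow + ReduceMean
    normed = h / np.sqrt(variance + eps)                          # Sqrt + Div
    return normed.astype(x.dtype) * weight                        # back to x's dtype


x16 = np.array([[300.0, -200.0, 50.0, 10.0]], dtype=np.float16)
w16 = np.ones(4, dtype=np.float16)

with np.errstate(over="ignore"):
    naive = t5_rms_norm(x16, w16)                          # every step in float16
mixed = t5_rms_norm(x16, w16, compute_dtype=np.float32)   # sensitive steps in float32
reference = t5_rms_norm(x16.astype(np.float32), w16.astype(np.float32))

print(naive)  # zeros: 300**2 overflowed to inf, so every value was divided by inf
print(mixed)  # matches the float32 reference to float16 precision
```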

Conversion of Decoder Into float16

The next step was to optimize the decoder. This matters even more than the encoder. During generation the encoder runs once per input, while the decoder runs n times, where n is the length of the target sentence. As with the encoder, we started by converting all the decoder nodes to float16 and then evaluated the effect on accuracy and speed. The model gained little speed on short sequences and became slower on longer ones.

Digging into ONNX Runtime and some open-source libraries, we found the cause of the slowdown: data movement between GPU memory and host RAM. The decoder takes decoder_input_ids and encoder_outputs to generate the next token. Each time it makes a prediction, its inputs are copied from RAM to GPU memory. After the computation, the output (an ORT value) has to be converted to a NumPy array, which can only happen after it is moved back to the CPU. These tensors are large, so moving them takes a lot of time. Because the number of transfers grows with the length of the target output, the model gets slower as outputs get longer. We solved this problem by using IO bindings in ONNX Runtime.
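A back-of-the-envelope estimate shows why these transfers hurt. The following assumptions are illustrative and not measurements from our setup: t5-base (hidden size 768), float32 activations, a batch of 8, a 512-token source, and 512 decoding steps, with encoder_outputs re-sent to the GPU at every step. Under those assumptions, that one tensor alone adds up to several gigabytes of host-to-device traffic per batch, before counting the outputs copied back.

```python
MiB, GiB = 2**20, 2**30


def tensor_bytes(*shape, bytes_per_element=4):
    """Size in bytes of a dense tensor (float32 by default)."""
    n = 1
    for dim in shape:
        n *= dim
    return n * bytes_per_element


batch_size, src_len, d_model = 8, 512, 768   # assumed: t5-base, 512-token inputs
decode_steps = 512                           # assumed target length

per_step = tensor_bytes(batch_size, src_len, d_model)  # encoder_outputs in float32
total = per_step * decode_steps                         # re-sent on every step

print(per_step / MiB, "MiB per step")  # 12.0 MiB per step
print(total / GiB, "GiB in total")     # 6.0 GiB in total
```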

ONNX With IO Bindings

ONNX Runtime lets you bind inputs and outputs to existing memory using IO bindings. With this approach, inputs are created as CUDA tensors stored in GPU memory. For the output, we create an empty tensor with the shape the computation's output will have. For example, the encoder's output shape is determined by the batch size and sequence length, so we can predict it as (batch_size, seq_length, 768), where 768 is the hidden size of t5-base. We initialize an empty tensor of this shape and bind ONNX Runtime's output to it using the tensor's data pointer.
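The binding step described above might look roughly like the sketch below. It uses ONNX Runtime's IOBinding API (`io_binding`, `bind_input`, `bind_output`, `run_with_iobinding`) together with PyTorch CUDA tensors. The input and output names (`input_ids`, `attention_mask`, `hidden_states`), the float32 output dtype, and `d_model=768` are assumptions based on a fastT5-style t5-base encoder export. Check them against your own model, for example with `session.get_inputs()` and `session.get_outputs()`.

```python
import numpy as np


def encoder_output_shape(batch_size, seq_length, d_model=768):
    """Encoder hidden states are (batch, seq_len, d_model); 768 is t5-base's d_model."""
    return (batch_size, seq_length, d_model)


def run_encoder_with_io_binding(session, input_ids, attention_mask, d_model=768):
    """Run an ONNX encoder so that inputs and outputs both stay in GPU memory.

    session: an onnxruntime.InferenceSession using the CUDA execution provider.
    input_ids, attention_mask: int64 torch tensors already on the GPU.
    The tensor names used below are assumptions about the exported model.
    """
    import torch  # imported lazily so the shape helper has no torch dependency

    input_ids = input_ids.contiguous()
    attention_mask = attention_mask.contiguous()
    # Pre-allocate the output on the GPU; ONNX Runtime writes straight into it.
    shape = encoder_output_shape(*input_ids.shape, d_model=d_model)
    hidden_states = torch.empty(shape, dtype=torch.float32, device=input_ids.device)

    binding = session.io_binding()
    for name, tensor in (("input_ids", input_ids), ("attention_mask", attention_mask)):
        binding.bind_input(
            name=name,
            device_type="cuda",
            device_id=tensor.device.index or 0,
            element_type=np.int64,
            shape=tuple(tensor.shape),
            buffer_ptr=tensor.data_ptr(),
        )
    binding.bind_output(
        name="hidden_states",
        device_type="cuda",
        device_id=hidden_states.device.index or 0,
        element_type=np.float32,
        shape=tuple(hidden_states.shape),
        buffer_ptr=hidden_states.data_ptr(),
    )
    session.run_with_iobinding(binding)
    return hidden_states  # still a CUDA tensor: no NumPy conversion, no CPU copy
```

The returned tensor already lives on the GPU, so it can be bound as an input to the decoder session in the same way.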

When ONNX Runtime finishes executing, it writes the result directly into this tensor. Because the result is already a tensor, there is no need to convert it to NumPy to read it or process it further, so the expensive data movement during inference is avoided. With this change, the model became around 2x faster for short sequences. For 512-token sequences it was around 1.2-1.5x faster, but for larger batches the speedup sometimes dropped to just 1.05x. That was not good enough, so we continued exploring further optimizations.

The Past Key Values

Researching the architecture of T5 further, we learned from various sources that past_key_values can improve the model's performance on long sequences. past_key_values are the cached key and value tensors of the attention layers; queries are not cached. During generation, the keys and values computed for tokens that have already been generated do not change. In cross-attention, the keys and values derived from the encoder output stay the same at every step. Caching these values and reusing them at each step therefore saves a lot of computation.
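The effect of caching can be checked with a toy single-head causal self-attention in NumPy. This is an illustrative sketch with random weights, not T5's actual attention, which also has multiple heads, relative position bias, and cross-attention. It shows that appending each new token's key and value to a cache gives the same outputs as recomputing keys and values for the whole prefix at every step, while each token is projected only once.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 16  # toy hidden size
Wq, Wk, Wv = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))


def attend(q, K, V):
    """Single-head attention of one query (d,) over keys/values of shape (t, d)."""
    scores = K @ q / np.sqrt(d)
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    return weights @ V


def decode_without_cache(xs):
    """At step t, recompute K and V for all t + 1 tokens seen so far."""
    outs, projections = [], 0
    for t in range(len(xs)):
        prefix = xs[: t + 1]
        K, V = prefix @ Wk, prefix @ Wv
        projections += len(prefix)
        outs.append(attend(xs[t] @ Wq, K, V))
    return np.array(outs), projections


def decode_with_cache(xs):
    """Project only the newest token and append its key/value to the cache."""
    K_cache, V_cache = np.empty((0, d)), np.empty((0, d))
    outs, projections = [], 0
    for x in xs:
        K_cache = np.vstack([K_cache, x @ Wk])
        V_cache = np.vstack([V_cache, x @ Wv])
        projections += 1
        outs.append(attend(x @ Wq, K_cache, V_cache))
    return np.array(outs), projections


xs = rng.standard_normal((10, d))  # 10 decoder positions
slow, slow_proj = decode_without_cache(xs)
fast, fast_proj = decode_with_cache(xs)
print(np.allclose(slow, fast))  # True
print(slow_proj, fast_proj)     # 55 10
```

Without the cache, the number of key/value projections grows quadratically with the output length (1 + 2 + ... + n). With the cache it grows linearly, which is why caching matters most on long sequences.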

Exporting the Models With Past Key Values

To avoid the performance drop on longer sequences, we decided to export the model to ONNX with past key values. For this, we took inspiration from the fastT5 library and used its code base for the conversion. fastT5 splits the model into three separate models: an encoder and two decoders. We disabled fastT5's quantization option during conversion, because it quantizes the weights to int8, which does not benefit GPU inference.
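Two decoders are needed because the first decoding step has no cache yet, while every later step both consumes and extends one. A greedy decoding loop over the three models might be organized as in the sketch below. The three callables are hypothetical stand-ins for the encoder, the first-step decoder, and the cached decoder sessions. Their signatures are assumptions, not fastT5's API, and the stubs exist only so that the control flow can run.

```python
def greedy_generate(encoder, init_decoder, decoder, input_ids,
                    start_id=0, eos_id=1, max_new_tokens=20):
    """Greedy decoding over an encoder, a first-step decoder and a cached decoder.

    Assumed (hypothetical) signatures:
      encoder(input_ids) -> encoder_out
      init_decoder(token, encoder_out) -> (next_token, past)
      decoder(token, encoder_out, past) -> (next_token, past)
    start_id=0 and eos_id=1 match T5's decoder start (pad) and </s> token ids.
    """
    encoder_out = encoder(input_ids)                   # the encoder runs once
    token, past = init_decoder(start_id, encoder_out)  # first step creates the cache
    output = [token]
    while token != eos_id and len(output) < max_new_tokens:
        # Later steps pass only the newest token plus the cache, which they extend.
        token, past = decoder(token, encoder_out, past)
        output.append(token)
    return output


# Toy stand-ins that "translate" by copying the source; they also count calls.
calls = {"encoder": 0, "init_decoder": 0, "decoder": 0}


def fake_encoder(ids):
    calls["encoder"] += 1
    return list(ids)


def fake_init_decoder(token, enc):
    calls["init_decoder"] += 1
    return enc[0], [enc[0]]  # cache stand-in: the tokens emitted so far


def fake_decoder(token, enc, past):
    calls["decoder"] += 1
    nxt = enc[len(past)]
    return nxt, past + [nxt]


print(greedy_generate(fake_encoder, fake_init_decoder, fake_decoder, [5, 7, 9, 1]))
print(calls)
# [5, 7, 9, 1]
# {'encoder': 1, 'init_decoder': 1, 'decoder': 3}
```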

We then converted these models to float16. This time we did not convert every node in any of the models. Instead, we kept nodes such as 'Pow', 'ReduceMean', 'Sqrt', 'Softmax', and 'Relu' in float32. These operations are numerically sensitive: powers and exponentials can overflow float16's limited range, and a slight error at any stage can have a large impact on the results.
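Keeping some op types in float32 means Cast nodes have to be inserted wherever a float16 tensor flows into a float32 node, or the other way round. The toy sketch below assigns a precision to each node from a block list like ours and inserts those casts. It uses a hypothetical graph format and is not ONNX Runtime's converter. For real ONNX models, `onnxruntime.transformers.float16.convert_float_to_float16` accepts an `op_block_list` argument for this purpose.

```python
from dataclasses import dataclass

# Op types kept in float32, following the list above.
FP32_OPS = {"Pow", "ReduceMean", "Sqrt", "Softmax", "Relu"}


@dataclass
class Node:
    name: str
    op: str
    inputs: list
    output: str


def mixed_precision(nodes, input_dtypes):
    """Assign float16 or float32 to each node and insert Cast nodes at boundaries.

    nodes: topologically ordered; input_dtypes: dtype of each graph input.
    Returns the rewritten node list and the dtype of every tensor.
    """
    dtype_of = dict(input_dtypes)
    casts = {}  # (tensor, dtype) -> name of the cast tensor, so casts are reused
    result = []
    for n in nodes:
        target = "float32" if n.op in FP32_OPS else "float16"
        new_inputs = []
        for t in n.inputs:
            if dtype_of[t] != target:
                if (t, target) not in casts:
                    cast_out = f"{t}_to_{target}"
                    result.append(Node(f"cast_{cast_out}", "Cast", [t], cast_out))
                    dtype_of[cast_out] = target
                    casts[(t, target)] = cast_out
                t = casts[(t, target)]
            new_inputs.append(t)
        result.append(Node(n.name, n.op, new_inputs, n.output))
        dtype_of[n.output] = target
    return result, dtype_of


# A T5-style RMS norm subgraph; h, eps and w are float16 graph inputs.
rms_norm = [
    Node("pow", "Pow", ["h"], "h_sq"),  # exponent input omitted for brevity
    Node("mean", "ReduceMean", ["h_sq"], "var"),
    Node("add", "Add", ["var", "eps"], "var_eps"),
    Node("sqrt", "Sqrt", ["var_eps"], "rms"),
    Node("div", "Div", ["h", "rms"], "normed"),
    Node("mul", "Mul", ["normed", "w"], "y"),
]
graph, dtypes = mixed_precision(rms_norm, {"h": "float16", "eps": "float16", "w": "float16"})
print([n.op for n in graph])
# ['Cast', 'Pow', 'ReduceMean', 'Cast', 'Add', 'Cast', 'Sqrt', 'Cast', 'Div', 'Mul']
```

Even this small subgraph needs four casts, including a round trip around the float16 Add. Each cast is extra work and memory traffic on the GPU, which is one reason it can pay to experiment with exactly which nodes stay in float32.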

We then experimented with converting different sets of nodes, checking accuracy and performance for each. Once we were satisfied with the results, we implemented IO bindings for these ONNX models.

Implementing the IO Bindings for the New Models

We implemented IO bindings for past_key_values slightly differently from our previous implementation. With IO bindings you must know the shape of the output tensor you bind to, and the output shapes of the past key values were harder to predict and implement. Instead, we converted the ORT values directly to CUDA tensors using the ort_val_to_cuda_tensor function from the ONNX Runtime training module. This function uses DLPack to turn ORT values into CUDA tensors without transferring the data to the CPU, which saved us considerable implementation effort.
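DLPack is a shared in-memory tensor format. The consumer receives a view of the producer's existing buffer instead of a copy. The CPU-only NumPy sketch below shows the same principle, assuming NumPy 1.22 or newer, where `np.from_dlpack` is available. On the GPU, `torch.utils.dlpack.from_dlpack` plays the consumer role for CUDA buffers, which is what lets ONNX Runtime outputs reach PyTorch without a round trip through host memory.

```python
import numpy as np

producer = np.arange(6, dtype=np.float32).reshape(2, 3)  # stands in for an ORT output
consumer = np.from_dlpack(producer)                       # imported via the DLPack protocol

# Both arrays view the same memory: nothing was copied.
print(np.shares_memory(producer, consumer))  # True
producer[0, 0] = 42.0
print(consumer[0, 0])                        # 42.0
```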


Our new optimized model was more than 2x faster for batch sizes up to 4 and around 1.3x faster for batch sizes of 15 and above. These tests used t5-base with 512-token sequences. The speedup shrinks as the batch size grows because the past key values cache grows with the batch, and how much of it can be cached depends on the available GPU memory. A GPU with more memory should therefore help maintain the speedup for larger batches as well.
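A rough estimate shows how quickly the cache grows with batch size. The estimate assumes t5-base's decoder configuration (12 layers, 12 heads, 64 dimensions per head), float32 storage, 512 source tokens, and 512 generated tokens. Each layer caches keys and values for self-attention over the generated tokens and for cross-attention over the source tokens. The batch sizes below are illustrative and are not the ones from our benchmark.

```python
GiB = 2**30


def kv_cache_bytes(batch, src_len, tgt_len, layers=12, heads=12, d_kv=64,
                   bytes_per_element=4):
    """Upper bound on T5 past_key_values size once tgt_len tokens are generated.

    Per layer: keys and values (the factor 2) for self-attention over tgt_len
    tokens plus cross-attention over src_len encoder positions.
    Defaults assume t5-base (12 decoder layers, 12 heads, d_kv=64) in float32.
    """
    per_layer = 2 * batch * heads * d_kv * (tgt_len + src_len)
    return layers * per_layer * bytes_per_element


for batch in (4, 16):
    size = kv_cache_bytes(batch, src_len=512, tgt_len=512)
    print(f"batch {batch:>2}: {size / GiB:.3f} GiB")
# batch  4: 0.281 GiB
# batch 16: 1.125 GiB
```

The cache scales linearly with batch size, on top of the activations and weights already held in GPU memory. Storing it in float16 (`bytes_per_element=2`) would halve these figures.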

Some Outputs From Our Model

We checked the model's accuracy using exact match; some examples are shown in the table below. On 100 translation examples, every output was an exact match, giving an exact-match score of 100%.
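Exact match simply counts how often the optimized model's output string is identical to the reference. A minimal version of the metric might look like the sketch below, assuming the references are either the original model's outputs or gold translations. The strings are hypothetical placeholders, not our data.

```python
def exact_match(predictions, references):
    """Percentage of predictions identical to their reference (ignoring surrounding whitespace)."""
    if len(predictions) != len(references):
        raise ValueError("predictions and references must have the same length")
    hits = sum(p.strip() == r.strip() for p, r in zip(predictions, references))
    return 100.0 * hits / len(references)


# Hypothetical example strings, for illustration only.
original = ["Das Haus ist wunderbar.", "Ich bin ein Student."]
optimized = ["Das Haus ist wunderbar.", "Ich bin Student."]
print(exact_match(optimized, original))  # 50.0
```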

Output comparison

Further Improvements

  • More experimentation can be done to convert the model into a mixed precision format.
  • The model can be converted to bfloat16 and tested, since T5 was designed to work with bfloat16 rather than float16.
  • Graph fusion can be implemented to further improve the performance of the model.
