Improve Inference Efficiency with Batch Inference

Algorithm Engineer

As an algorithm engineer, you will inevitably need to put models into production as online services. For less demanding scenarios, a web framework is enough: for each user request, call the model to run inference and return the result. However, this straightforward implementation often fails to make full use of the GPU and falls short in scenarios with high performance requirements.

There are many ways to optimize this, and one useful trick is to switch from running inference on each request individually to running inference on multiple requests at once. Around this time last year, I wrote a small tool to do this and gave it the rather grandiose name InferLight. Honestly, that tool was not very well implemented. Recently, I refactored it with reference to Shannon Technology's Service-Streamer.

This feature seems simple, but implementing it teaches you a lot about asynchronous programming in Python and gives you a feel for the parallel computing power of modern GPUs.


Architecture

First, to improve the model's online inference throughput, you should make the inference service asynchronous. For web services, asynchronous means that the program can handle other requests while the model is computing. In Python, asynchronous services can be built with good asyncio-based frameworks, such as Sanic, which I commonly use. Inference itself, however, is compute-intensive, so the goal is to aggregate multiple inference requests, make efficient use of the GPU's parallel computing power, and correctly return each result of the batched inference to the requester that submitted it.
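To make the aggregation idea concrete, here is a minimal single-process sketch using only asyncio. `MicroBatcher`, `max_batch_size` and `max_delay` are hypothetical names, not InferLight's API. Each caller awaits its own future, while a background loop waits briefly for more requests and then runs one batched call. Note that the batched function runs on the event-loop thread here, which would block the loop for a real GPU model; that is one reason the design described next moves inference into a separate worker process.

```python
import asyncio
from typing import Any, Callable, List, Tuple


class MicroBatcher:
    """Groups concurrent requests into one batched call (single-process sketch)."""

    def __init__(self, batch_fn: Callable[[List[Any]], List[Any]],
                 max_batch_size: int = 8, max_delay: float = 0.01) -> None:
        self.batch_fn = batch_fn          # runs inference on a list of inputs
        self.max_batch_size = max_batch_size
        self.max_delay = max_delay        # seconds to wait for more requests
        # Construct the batcher inside a running event loop (see usage below).
        self.queue: "asyncio.Queue[Tuple[Any, asyncio.Future]]" = asyncio.Queue()

    async def predict(self, item: Any) -> Any:
        """Called once per request; resolves when this item's result is ready."""
        fut = asyncio.get_running_loop().create_future()
        await self.queue.put((item, fut))
        return await fut

    async def run(self) -> None:
        """Background loop: collect a batch, run it once, fan results back out."""
        while True:
            batch = [await self.queue.get()]      # block for the first request
            await asyncio.sleep(self.max_delay)   # let other requests arrive
            while len(batch) < self.max_batch_size and not self.queue.empty():
                batch.append(self.queue.get_nowait())
            futures = [fut for _, fut in batch]
            try:
                outputs = self.batch_fn([item for item, _ in batch])
            except Exception as exc:              # report failure to every caller
                for fut in futures:
                    if not fut.done():
                        fut.set_exception(exc)
                continue
            for fut, out in zip(futures, outputs):
                if not fut.done():                # caller may have given up
                    fut.set_result(out)
```

Usage: inside an `async def main()`, create the batcher, start `asyncio.create_task(batcher.run())`, and have each request handler `await batcher.predict(x)`. Waiting a fixed `max_delay` after the first request trades a little latency for larger batches.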

To achieve the above goal, the following modules are needed:
1. Front-end service: receives requests and returns results. It can use various protocols, such as HTTP or RPC, and runs as an independent process.
2. Inference worker: responsible for model initialization, batch data construction, and inference computation. It also runs as an independent process.
3. Task queue: the front-end service receives a request and sends the computation task to the task queue; the inference worker listens to the queue and takes out a small batch at a time for model inference.
4. Result queue: after inference is done, the inference worker sends the results to the result queue; the front-end service listens to the queue and retrieves the inference results.
5. Result distribution: before a task is sent to the task queue, a unique identifier must be generated for it; after results are retrieved from the result queue, each result is matched back to its task using that identifier.
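The worker side of modules 3 to 5 can be sketched as follows. This is an illustration, not InferLight's code: `make_task`, `collect_batch` and `process_once` are hypothetical helpers, and tasks are assumed to be `(task_id, payload)` tuples. Any queue with a blocking `get(timeout=...)` that raises `queue.Empty` works, which includes `multiprocessing.Queue`.

```python
import queue
import time
import uuid
from typing import Any, Callable, List, Tuple

Task = Tuple[str, Any]  # (task_id, payload): a hypothetical message format


def make_task(payload: Any) -> Task:
    """Front-end side: tag a request with a unique id before enqueuing it."""
    return (uuid.uuid4().hex, payload)


def collect_batch(task_queue: Any, max_batch_size: int,
                  max_delay: float) -> List[Task]:
    """Worker side: block for one task, then keep taking tasks until the batch
    is full or `max_delay` seconds have passed since the first one arrived."""
    batch = [task_queue.get()]
    deadline = time.monotonic() + max_delay
    while len(batch) < max_batch_size:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        try:
            batch.append(task_queue.get(timeout=remaining))
        except queue.Empty:  # multiprocessing.Queue raises queue.Empty too
            break
    return batch


def process_once(task_queue: Any, result_queue: Any,
                 model_fn: Callable[[List[Any]], List[Any]],
                 max_batch_size: int = 16, max_delay: float = 0.05) -> int:
    """One iteration of the worker loop; returns the batch size it handled."""
    batch = collect_batch(task_queue, max_batch_size, max_delay)
    outputs = model_fn([payload for _, payload in batch])  # one batched call
    for (task_id, _), output in zip(batch, outputs):
        result_queue.put((task_id, output))  # the id lets the front end route it
    return len(batch)
```

A real worker process would call `process_once` inside `while True`, with `model_fn` wrapping the GPU model.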

There are many ways to implement the task queue and result queue, including mature middleware such as Kafka and Redis. To avoid external dependencies, I chose Python's native multiprocessing queue this time. A sub-thread of the front-end service process listens to the result queue and distributes the results.

Implementation

The inference worker is relatively simple. Since models and their data-processing steps vary widely, I designed the inference worker as a base class: to use it, you inherit from it and implement a few specific methods.
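The base-class design is a template-method pattern. The sketch below is hypothetical and does not reproduce InferLight's `BaseInferLightWorker`; the class name `BaseBatchWorker` and the hook names `load_model`, `build_batch` and `inference` are assumptions chosen for illustration. The base class keeps the bookkeeping that maps outputs back to task ids, so subclasses only deal with the model.

```python
from abc import ABC, abstractmethod
from typing import Any, List, Tuple


class BaseBatchWorker(ABC):
    """Hypothetical base class: subclasses supply model-specific hooks."""

    def __init__(self, model_args: dict) -> None:
        self.model = self.load_model(model_args)

    @abstractmethod
    def load_model(self, model_args: dict) -> Any:
        """Create and return the model (e.g. move it to the GPU here)."""

    @abstractmethod
    def build_batch(self, payloads: List[Any]) -> Any:
        """Turn raw request payloads into one model-ready batch."""

    @abstractmethod
    def inference(self, batch: Any) -> List[Any]:
        """Run the model once on the batch; return one output per payload."""

    def handle(self, tasks: List[Tuple[str, Any]]) -> List[Tuple[str, Any]]:
        """Template method: (id, payload) pairs in, (id, output) pairs out."""
        ids = [task_id for task_id, _ in tasks]
        outputs = self.inference(self.build_batch([p for _, p in tasks]))
        if len(outputs) != len(ids):
            raise ValueError("inference must return one output per task")
        return list(zip(ids, outputs))


class WordCountWorker(BaseBatchWorker):
    """Toy subclass: a word counter standing in for a real network."""

    def load_model(self, model_args):
        return lambda texts: [len(t.split()) for t in texts]

    def build_batch(self, payloads):
        return [str(p).strip() for p in payloads]

    def inference(self, batch):
        return self.model(batch)
```

Because the abstract hooks must be overridden, instantiating the base class directly raises `TypeError`, which catches incomplete subclasses early.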

Alongside it is a Wrapper class, used in the front-end service, that receives inference requests and collects and distributes their results.
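The front-end half can be sketched as below, assuming the same `(task_id, payload)` message format. `BatchedClient` and its methods are hypothetical names, not the actual Wrapper class. Each request gets a `uuid`, registers an asyncio future, and is put on the task queue; a daemon thread blocks on the result queue and hands each result back to the event loop with `loop.call_soon_threadsafe`, since asyncio futures must not be completed from another thread.

```python
import asyncio
import threading
import uuid
from typing import Any, Dict, Optional


class BatchedClient:
    """Front-end side: tag each request, enqueue it, and route results back."""

    def __init__(self, task_queue: Any, result_queue: Any) -> None:
        self.task_queue = task_queue      # e.g. a multiprocessing.Queue
        self.result_queue = result_queue  # filled by the worker with (id, output)
        self.pending: Dict[str, asyncio.Future] = {}
        self.lock = threading.Lock()      # guards `pending` across threads
        self.loop: Optional[asyncio.AbstractEventLoop] = None

    def start(self) -> None:
        """Call from inside the running event loop (e.g. a server startup hook)."""
        self.loop = asyncio.get_running_loop()
        threading.Thread(target=self._listen, daemon=True).start()

    def _listen(self) -> None:
        # Background thread: block on the result queue and dispatch by task id.
        while True:
            task_id, output = self.result_queue.get()
            with self.lock:
                fut = self.pending.pop(task_id, None)
            if fut is not None:
                # Complete the future on the loop's own thread.
                self.loop.call_soon_threadsafe(self._resolve, fut, output)

    @staticmethod
    def _resolve(fut: asyncio.Future, output: Any) -> None:
        if not fut.done():  # the caller may already have timed out
            fut.set_result(output)

    async def predict(self, payload: Any, timeout: float = 5.0) -> Any:
        task_id = uuid.uuid4().hex
        fut = self.loop.create_future()
        with self.lock:
            self.pending[task_id] = fut   # register before enqueuing
        self.task_queue.put((task_id, payload))
        try:
            return await asyncio.wait_for(fut, timeout)
        finally:
            with self.lock:
                self.pending.pop(task_id, None)
```

Registering the future before putting the task on the queue avoids a race in which the result arrives before anyone is waiting for it.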

Some of the data structures used are defined as follows
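The original definitions are not reproduced in this version of the post. As a hedged stand-in, messages like these could be modelled with dataclasses; `InferenceTask`, `InferenceResult` and their fields are hypothetical. Module-level dataclasses are picklable, which matters because `multiprocessing.Queue` pickles everything it carries.

```python
import uuid
from dataclasses import dataclass, field
from typing import Any


@dataclass
class InferenceTask:
    """One request travelling from the front end to the worker."""
    data: Any
    # A fresh unique id per task, used later to route the result back.
    task_id: str = field(default_factory=lambda: uuid.uuid4().hex)


@dataclass
class InferenceResult:
    """One result travelling from the worker back to the front end."""
    task_id: str
    result: Any = None
    succeeded: bool = True
    error: str = ""   # filled in when inference fails for this task
```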

Use Case and Test Result

Here we show how the above components can be used with a sentiment analysis BERT model.

First, define the model:

Then inherit from BaseInferLightWorker and implement three methods to get a complete worker class:

Finally, build the service:

I ran some tests with Apache's well-known ab (ApacheBench) tool. I started the above app on my HP Z4 Workstation and made sure the worker process was running on an RTX 6000 GPU.

With ab -n 1000 -c 32 http://localhost:8888/batched_predict, I got the following result.

For comparison, here is the test result of a straightforward implementation without batch inference:

As you can see, batch inference delivered about 2.5 times the throughput! During the benchmark, I also observed that GPU utilization was much higher with batch inference.

I have open-sourced InferLight; it can be found at https://github.com/thuwyh/InferLight. Hope you love it :)