Training large models with your GPU

Algorithm Engineer

As a Z by HP Global Data Science Ambassador, Yuanhao Wu's content is sponsored and he was provided with HP products.

In the last post, I shared my story of the Kaggle Jigsaw Multilingual Toxic Comment Classification competition. At that time, I only had a 1080Ti with 11 GB of VRAM, which made it impossible for me to train the SOTA XLM-Roberta Large model, since it needs more VRAM than I had. In this post, I want to share some tips on how to reduce VRAM usage so that you can train larger deep neural networks with your GPU.

The first tip is to use mixed-precision training. When training a model, all parameters, along with the activations and gradients computed from them, are generally stored in VRAM. The total VRAM usage is therefore roughly the number of stored values times the memory used by a single value. A bigger model usually means better performance, but also more VRAM usage. Since performance is quite important, as in Kaggle competitions, we do not want to reduce the model capacity. That leaves reducing the memory used by every single variable. By default, 32-bit floating point is used, so one variable consumes 4 bytes. Fortunately, people found that we can use 16-bit floating point for most variables without much accuracy loss. This means those variables take half the memory! In addition, using lower precision also improves training speed, especially on GPUs with Tensor Core support.
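To make the arithmetic concrete, here is a tiny sketch of the "number of values times bytes per value" estimate. The count of 500 million values is a made-up figure for illustration, not a measurement of any particular model, and the helper name is my own.

```python
def tensor_memory_gib(num_values: int, bytes_per_value: int) -> float:
    """Memory needed to store `num_values` numbers, in GiB."""
    return num_values * bytes_per_value / 1024**3


# Hypothetical count of stored values (assumption, for illustration only).
num_values = 500_000_000

fp32_gib = tensor_memory_gib(num_values, 4)  # 32-bit float: 4 bytes each
fp16_gib = tensor_memory_gib(num_values, 2)  # 16-bit float: 2 bytes each

print(f"FP32: {fp32_gib:.2f} GiB, FP16: {fp16_gib:.2f} GiB")
```

In a real training run the saving is usually smaller than this simple estimate suggests, because some values are still kept in 32-bit precision.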

Starting with version 1.6, PyTorch supports automatic mixed precision (AMP) training natively. The framework identifies operations that require full precision and runs them in 32-bit floating point, while running the others in 16-bit floating point. Below is sample code from the official PyTorch documentation.
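A minimal sketch of an AMP training loop, following the usage pattern described in the PyTorch AMP documentation rather than quoting it. The tiny model, the random data, and the hyperparameters are placeholders I made up, and AMP is switched off automatically when no CUDA GPU is available so the snippet still runs on a CPU.

```python
import torch
from torch import nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = device.type == "cuda"  # only enable mixed precision on a GPU

torch.manual_seed(0)
# Placeholder model and optimizer (assumptions, not from the original post).
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 2)).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

# GradScaler scales the loss so that small FP16 gradients do not underflow.
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

losses = []
for step in range(5):
    # Random stand-in data for one batch.
    x = torch.randn(8, 16, device=device)
    y = torch.randint(0, 2, (8,), device=device)

    optimizer.zero_grad()
    # Inside autocast, eligible ops run in 16-bit, the rest stay in 32-bit.
    with torch.autocast(device_type=device.type, enabled=use_amp):
        logits = model(x)
        loss = loss_fn(logits, y)

    scaler.scale(loss).backward()  # backward pass on the scaled loss
    scaler.step(optimizer)         # unscales gradients, then optimizer.step()
    scaler.update()                # adjusts the scale factor for the next step
    losses.append(loss.item())
```

Note that only the forward pass and the loss computation sit inside the `autocast` block; the backward pass and the optimizer step stay outside it.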

The second tip is to use gradient accumulation. The idea behind gradient accumulation is simple: perform several backward passes with the same model parameters before the optimizer updates them. The gradients calculated in each backward pass are accumulated (summed). If your actual batch size is N and you accumulate gradients over M steps, your equivalent batch size is N*M. However, the training result will not be strictly identical, because some quantities, such as batch normalization statistics, cannot be perfectly accumulated.
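Here is a small sketch that checks the N*M equivalence numerically. The linear model and random data are placeholders of my own, with N = 8 and M = 4. Because the loss here averages over the batch, I divide each small-batch loss by M so that the summed gradients match the gradient of one big batch; that division is a design choice of this sketch.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Placeholder model and data (assumptions for illustration).
model = nn.Linear(10, 1)
loss_fn = nn.MSELoss()  # averages over the samples in a batch
inputs = torch.randn(32, 10)
targets = torch.randn(32, 1)

micro_batch = 8   # N: what actually fits in memory
accum_steps = 4   # M: number of backward passes per update

# Reference: gradient from one batch of size N*M = 32.
model.zero_grad()
loss_fn(model(inputs), targets).backward()
full_grad = model.weight.grad.clone()

# Accumulation: M backward passes on batches of size N, no optimizer step in between.
model.zero_grad()
for xb, yb in zip(inputs.split(micro_batch), targets.split(micro_batch)):
    loss = loss_fn(model(xb), yb) / accum_steps  # rescale so the sum is a mean
    loss.backward()                              # gradients add up in .grad
accum_grad = model.weight.grad.clone()

grads_match = torch.allclose(full_grad, accum_grad, atol=1e-6)
print("Gradients match:", grads_match)
```

A model with batch normalization would not pass this check exactly, since its statistics are computed per small batch.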

There are some things to note about gradient accumulation:

    1. When you use gradient accumulation with mixed-precision training, the loss scale should be calibrated for the effective batch, so scale updates should happen once per effective batch rather than after every backward pass.

    2. When you use gradient accumulation with distributed data parallel (DDP) training, use the no_sync() context manager to disable the gradient all-reduce for the first M-1 steps.
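Both notes can be combined in one loop, roughly like this. It is a sketch under my own assumptions: the model and data are placeholders, M = 4, and the model here is not actually wrapped in DDP, so the `no_sync()` branch is only shown for the case where it would be. AMP is disabled automatically on machines without a CUDA GPU.

```python
import contextlib

import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = device.type == "cuda"

torch.manual_seed(0)
# Placeholder model, optimizer, and data (assumptions for illustration).
model = nn.Linear(16, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

accum_steps = 4  # M
batches = [(torch.randn(8, 16), torch.randint(0, 2, (8,))) for _ in range(8)]
optimizer_steps = 0

for i, (x, y) in enumerate(batches):
    is_update_step = (i + 1) % accum_steps == 0

    # Note 2: with DDP, skip the gradient all-reduce on the first M-1 steps.
    if isinstance(model, DDP) and not is_update_step:
        sync_ctx = model.no_sync()
    else:
        sync_ctx = contextlib.nullcontext()

    with sync_ctx:
        with torch.autocast(device_type=device.type, enabled=use_amp):
            loss = loss_fn(model(x.to(device)), y.to(device)) / accum_steps
        scaler.scale(loss).backward()  # scaled gradients accumulate in .grad

    # Note 1: step the optimizer and update the scale once per effective batch.
    if is_update_step:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
        optimizer_steps += 1
```

With 8 small batches and M = 4, the optimizer and the scaler are each updated twice.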

Last but not least, the third tip is to use gradient checkpointing. The basic idea of gradient checkpointing is to save the intermediate results of only some nodes as checkpoints and to re-compute the parts in between those nodes during backpropagation. According to the authors of gradient-checkpointing, it lets them fit 10x larger models onto their GPUs at the cost of only a 20% increase in computation time. PyTorch has officially supported this since version 0.4.0, and some very commonly used libraries such as Huggingface Transformers also support this feature.
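To show the idea in plain PyTorch, here is a sketch using `torch.utils.checkpoint.checkpoint`. The `Block` and `Stack` classes are toy modules I made up, and the `use_reentrant=False` argument assumes a reasonably recent PyTorch release. Each checkpointed block keeps only its input and re-computes its internal activations during the backward pass, so the gradients stay the same while less memory is held.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)


class Block(nn.Module):
    """Toy residual feed-forward block (hypothetical, for illustration)."""

    def __init__(self, dim: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)
        )

    def forward(self, x):
        return x + self.net(x)


class Stack(nn.Module):
    def __init__(self, dim: int = 32, depth: int = 4):
        super().__init__()
        self.blocks = nn.ModuleList(Block(dim) for _ in range(depth))
        self.use_checkpoint = False

    def forward(self, x):
        for block in self.blocks:
            if self.use_checkpoint and self.training:
                # Store only the block input; redo the block's forward in backward.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x


model = Stack()
x = torch.randn(8, 32)


def compute_grads(use_checkpoint: bool):
    model.use_checkpoint = use_checkpoint
    model.zero_grad()
    model(x).sum().backward()
    return [p.grad.clone() for p in model.parameters()]


plain_grads = compute_grads(use_checkpoint=False)
ckpt_grads = compute_grads(use_checkpoint=True)
same_grads = all(torch.allclose(a, b) for a, b in zip(plain_grads, ckpt_grads))
print("Same gradients with checkpointing:", same_grads)
```

The price is the extra forward computation inside every checkpointed block during backpropagation.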

Below is an example of how to turn on gradient checkpointing with a transformer model:
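With a recent version of Huggingface Transformers, this can be sketched as a single method call on the model. The tiny configuration below is made up so that the snippet runs without downloading any weights; in practice you would load a pretrained model with `from_pretrained` instead. Older releases of the library exposed this feature through a configuration flag, so treat the exact call as an assumption about your installed version.

```python
from transformers import XLMRobertaConfig, XLMRobertaForSequenceClassification

# Tiny, randomly initialized model (hypothetical sizes, for illustration only).
config = XLMRobertaConfig(
    vocab_size=1000,
    hidden_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=128,
    num_labels=2,
)
model = XLMRobertaForSequenceClassification(config)

# Turn on gradient checkpointing for the transformer layers.
model.gradient_checkpointing_enable()

print("Gradient checkpointing on:", model.is_gradient_checkpointing)
```

After this call, the training loop itself does not need to change.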

To close this post, I want to share a simple benchmark I ran on the HP Z4 workstation. As I mentioned before, the workstation is equipped with two RTX6000 GPUs with 24 GB of VRAM each, but I only used one GPU in these experiments. I trained XLM-Roberta Base/Large with different configurations on the training set of the Kaggle Jigsaw Multilingual Toxic Comment Classification competition.

As we can see, mixed-precision training not only reduces memory consumption but also provides a significant speedup. Gradient checkpointing is also very powerful: it reduces the VRAM usage from 23.5G to 11.8G! I wish I had known about this technique earlier, so that I could have trained the XLM-Roberta Large model even with my previous 11G 2080Ti GPU :)