In which we reproduce a baseline to train CIFAR10 in 6 minutes and then speed it up a little – we observe that there’s plenty of room for improvement before the GPU runs out of FLOPs
Over the past few months, I’ve been investigating how to train deep neural networks quickly. Or at least not as slowly.
My interest in the problem began earlier this year, during a project with Sam Davis at Myrtle. We were compressing large recurrent networks for automatic speech recognition to deploy on FPGAs and needed to retrain the models. The baseline implementation from Mozilla took a week to train on 16 GPUs. After some great work by Sam to remove bottlenecks and move to mixed-precision computation on Nvidia Volta GPUs, we were able to reduce training times more than 100-fold and bring iterations down below a day on a single GPU.
This was crucial to our ability to progress and got me wondering what else we could speed up and what applications this might enable.
At about the same time, folks at Stanford were thinking along similar lines and launched the DAWNBench competition to compare training speeds on a range of deep learning benchmarks. Of most interest were the benchmarks for training image classification models to 94% test accuracy on CIFAR10 and 93% top-5 accuracy on ImageNet. Image classification is a particularly popular field of deep learning research, but the initial entries didn’t reflect state-of-the-art practices on modern hardware and took multiple hours to train.
By the time the competition closed in April, the situation had changed: on CIFAR10, the fastest single-GPU entry, from fast.ai student Ben Johnson, reached 94% accuracy in under 6 minutes (341s). The main innovations were mixed-precision training, choosing a smaller network with sufficient capacity for the task, and employing higher learning rates to speed up stochastic gradient descent (SGD).
So an obvious question is: how good is 341s to train to 94% test accuracy on CIFAR10? The network used in the fastest submission was an 18-layer residual network, shown below. In this case the number of layers refers to the serial depth of (purple) convolutional and (blue) fully connected layers, although the terminology is by no means universal:
The network was trained for 35 epochs using SGD with momentum and the slightly odd learning rate schedule below:
Let’s see how long training should take in this setup assuming 100% compute efficiency on a single NVIDIA Volta V100 GPU, the top-of-the-line data centre GPU used by the winning DAWNBench entries. A forward and backward pass through the network on a 32×32×3 CIFAR10 image requires approximately 2.8×10⁹ FLOPs. Assuming that parameter update computations are essentially free, 35 epochs of training on the 50,000-image dataset should take approximately 5×10¹⁵ FLOPs (2.8×10⁹ × 50,000 images × 35 epochs ≈ 4.9×10¹⁵).
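The arithmetic is simple enough to check directly. Here is a minimal sketch using only the figures quoted above; the constant and function names are illustrative, and parameter updates are treated as free, as in the text.

```python
# Back-of-envelope FLOP count for the baseline, using the figures quoted in the text.
FLOPS_PER_IMAGE = 2.8e9   # approx. forward + backward pass on one 32×32×3 image
TRAIN_IMAGES = 50_000     # CIFAR10 training set size
EPOCHS = 35               # length of the baseline schedule


def total_training_flops(flops_per_image: float, images: int, epochs: int) -> float:
    """Total FLOPs for training, treating parameter updates as free."""
    return flops_per_image * images * epochs


total = total_training_flops(FLOPS_PER_IMAGE, TRAIN_IMAGES, EPOCHS)
print(f"{total:.2e} FLOPs")  # 4.90e+15 FLOPs, i.e. roughly 5×10¹⁵
```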
As NVIDIA’s product description puts it: “Equipped with 640 Tensor Cores, Tesla V100 delivers 125 TeraFLOPS of deep learning performance.”
Assuming that we could realise 100% compute efficiency, training should complete in… 40 seconds (5×10¹⁵ FLOPs ÷ 125×10¹² FLOP/s). Even under realistic assumptions, it seems there’s room to improve on the 341s state of the art.
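The same sketch can be extended to wall-clock time. Dividing the rounded FLOP count by NVIDIA’s quoted 125 TFLOPS peak gives the 40 seconds above, and comparing that with the 341s entry gives the fraction of peak it implicitly achieved. That efficiency figure is just this ratio, not a measured GPU utilisation; the names are again illustrative.

```python
PEAK_FLOPS_PER_S = 125e12  # V100 Tensor Core peak quoted by NVIDIA
TOTAL_FLOPS = 5e15         # rounded estimate for 35 epochs of training


def ideal_seconds(total_flops: float, peak: float, efficiency: float = 1.0) -> float:
    """Training time if the GPU sustained `efficiency` × its peak throughput."""
    return total_flops / (peak * efficiency)


best_case = ideal_seconds(TOTAL_FLOPS, PEAK_FLOPS_PER_S)
print(best_case)                 # 40.0
print(f"{best_case / 341:.1%}")  # 11.7%: fraction of peak implied by the 341s entry
```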
So with a target in mind, it’s time to start training. The first order of business is to reproduce the baseline CIFAR10 results with the network above. Since we’re planning to change things later, I built a version of the network in PyTorch and replicated the learning rate schedule and hyperparameters from the DAWNBench submission. Training on an AWS p3.2xlarge instance with a single V100 GPU, 3/5 runs reach a final test accuracy of 94% in 356s.
With baseline duly reproduced, the next step is to look for simple improvements that can be implemented right away. A first observation: the network starts with two consecutive (yellow-red) batch norm-ReLU groups after the first (purple) convolution. This was presumably not an intentional design, so let’s remove the duplication. Likewise, the strange kink in the learning rate at epoch 15 has to go, although this shouldn’t impact training time. With those changes in place, the network and learning rate schedule look slightly simpler and, more importantly, 4/5 runs reach 94% final test accuracy in a time of 323s. New record!
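A piecewise-linear schedule like the simplified one can be expressed as linear interpolation between a few (epoch, learning rate) knots. The sketch below is a minimal version with hypothetical knot values: a ramp up to a peak at epoch 15, then linear decay to zero at epoch 35. The peak of 0.4 is purely illustrative and not taken from the DAWNBench submission. Evaluating the schedule at fractional epochs lets the rate change every batch rather than once per epoch.

```python
from bisect import bisect_right
from typing import Callable, Sequence


def piecewise_linear(knots: Sequence[float], values: Sequence[float]) -> Callable[[float], float]:
    """Return t -> value, linearly interpolating between (knot, value) pairs.

    Knots must be strictly increasing; outside their range the end values are held.
    """
    def schedule(t: float) -> float:
        if t <= knots[0]:
            return values[0]
        if t >= knots[-1]:
            return values[-1]
        i = bisect_right(knots, t) - 1  # index of the segment containing t
        frac = (t - knots[i]) / (knots[i + 1] - knots[i])
        return values[i] + frac * (values[i + 1] - values[i])
    return schedule


# Hypothetical knots: ramp up to a peak at epoch 15, decay to zero by epoch 35.
lr = piecewise_linear([0, 15, 35], [0.0, 0.4, 0.0])

# In a training loop this would be called as lr(step / batches_per_epoch).
for epoch in (0, 7.5, 15, 25, 35):
    print(epoch, round(lr(epoch), 3))
```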
A second observation: some of the image preprocessing (padding, normalisation and transposition) produces exactly the same result on every pass through the training set, and yet this work is being repeated each epoch. Other preprocessing steps (random cropping and flipping) differ between epochs and it makes sense to delay applying these. Although the preprocessing overhead is being mitigated by using multiple CPU processes to do the work, it turns out that PyTorch dataloaders (as of version 0.4) launch fresh processes for each iteration through the dataset. The setup time for this is non-trivial, especially on a small dataset like CIFAR10. By doing the common work once before training, removing pressure from the preprocessing jobs, we can reduce the number of processes needed to keep up with the GPU down to one. In heavier tasks, requiring more preprocessing or feeding more than one GPU, an alternative solution could be to keep dataloader processes alive between epochs. In any case, the effect of removing the repeat work and reducing the number of dataloader processes is a further 15s saving in training time (almost half a second per epoch!) and a new training time of 308s.
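The split between one-off and per-epoch work can be sketched in plain Python on nested lists. This only illustrates the structure and is not the notebook’s implementation: it assumes zero padding (the padding mode isn’t stated above), and `preprocess_once` and `augment` are hypothetical names. In practice the same steps would be vectorised over the whole dataset.

```python
from typing import List

Image = List[List[List[float]]]  # nested lists; the layout is noted per function


def preprocess_once(images: List[Image], mean: List[float], std: List[float], pad: int) -> List[Image]:
    """Deterministic work, done a single time before training.

    Takes H x W x C images, normalises each channel, transposes to C x H x W,
    then zero-pads every channel by `pad` pixels on each side (zero padding is an assumption).
    """
    out = []
    for img in images:
        h, w, c = len(img), len(img[0]), len(img[0][0])
        # Normalise and transpose HWC -> CHW in one pass.
        chw = [[[(img[y][x][k] - mean[k]) / std[k] for x in range(w)]
                for y in range(h)] for k in range(c)]
        padded = []
        for ch in chw:
            top = [[0.0] * (w + 2 * pad) for _ in range(pad)]
            body = [[0.0] * pad + row + [0.0] * pad for row in ch]
            bottom = [[0.0] * (w + 2 * pad) for _ in range(pad)]
            padded.append(top + body + bottom)
        out.append(padded)
    return out


def augment(img: Image, size: int, dx: int, dy: int, flip: bool) -> Image:
    """Per-epoch random work on a padded C x H x W image:
    crop a size x size window at offset (dy, dx), then optionally mirror it."""
    crop = [[row[dx:dx + size] for row in ch[dy:dy + size]] for ch in img]
    if flip:
        crop = [[row[::-1] for row in ch] for ch in crop]
    return crop


# One 2x2 single-channel image, padded by 1 pixel to 4x4 once up front...
prepped = preprocess_once([[[[1.0], [2.0]], [[3.0], [4.0]]]], mean=[0.0], std=[1.0], pad=1)
# ...then cropped back to 2x2 with a (here fixed) offset and flip each epoch.
print(augment(prepped[0], size=2, dx=1, dy=1, flip=True))  # [[[2.0, 1.0], [4.0, 3.0]]]
```

Only `augment` needs to run inside the training loop; everything in `preprocess_once` is paid for once.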
A bit more digging reveals that most of the remaining preprocessing time is spent calling out to random number generators to select data augmentations rather than in the augmentations themselves. During a full training run we make several million individual calls to random number generators, and by combining these into a small number of bulk calls at the start of each epoch we can shave a further 7s of training time. Finally, at this point it turns out that the overhead of launching even a single process to perform the data augmentation outweighs the benefit, and we can save a further 4s by doing the work on the main thread, leading to a final training time for today of 297s. Code to reproduce this result can be found in this Python notebook.
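The bulk-sampling idea can be sketched as drawing every augmentation choice for an epoch up front, with one call per parameter instead of several calls per image. This stdlib version is a hypothetical illustration: `random.choices` still loops internally, so the real payoff comes from a vectorised generator such as NumPy’s, where each draw below becomes a single array-valued call. The 4-pixel padding, giving crop offsets 0 to 8, is also an assumption.

```python
import random
from typing import List, Tuple


def draw_epoch_augmentations(n: int, pad: int, rng: random.Random) -> List[Tuple[int, int, bool]]:
    """Draw (dx, dy, flip) for all n images in one go at the start of an epoch.

    With an image padded by `pad` pixels per side, a crop back to the original
    size can start at any offset 0..2*pad in each direction.
    """
    offsets = range(2 * pad + 1)
    dxs = rng.choices(offsets, k=n)          # one bulk draw for x offsets
    dys = rng.choices(offsets, k=n)          # one bulk draw for y offsets
    flips = rng.choices((False, True), k=n)  # one bulk draw for horizontal flips
    return list(zip(dxs, dys, flips))


rng = random.Random(0)  # seeded so an epoch's augmentations are reproducible
choices = draw_epoch_augmentations(50_000, pad=4, rng=rng)
print(len(choices))  # 50000
```

The per-image loop then just indexes into `choices`, with no further random number calls.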
In Part 2 we increase batch sizes. Training time continues to drop…