Myrtle AI

How to Train Your ResNet 4: Architecture

In which we try out some different networks and discover that we’ve been working too hard

So far, we’ve been training a fixed network architecture, taken from the fastest single-GPU DAWNBench entry on CIFAR10. With some simple changes, we’ve reduced the time taken to reach 94% test accuracy from 341s to 154s. Today we’re going to investigate alternative architectures.

Let’s review our current network:


The pink residual blocks contain an identity shortcut and preserve the spatial and channel dimensions of the input.

Light-green downsampling blocks reduce spatial resolution by a factor of two and double the number of output channels.
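A minimal PyTorch sketch of the two block types (the module names, the projection shortcut and the 64-channel example width are my assumptions, not code from the repo):

```python
import torch
from torch import nn

def conv_bn_relu(c_in, c_out, stride=1):
    # 3x3 convolution followed by batch norm and ReLU
    return nn.Sequential(
        nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1, bias=False),
        nn.BatchNorm2d(c_out),
        nn.ReLU(inplace=True),
    )

class ResidualBlock(nn.Module):
    # identity shortcut; preserves spatial and channel dimensions
    def __init__(self, channels):
        super().__init__()
        self.branch = nn.Sequential(
            conv_bn_relu(channels, channels),
            conv_bn_relu(channels, channels),
        )

    def forward(self, x):
        return x + self.branch(x)

class DownsampleBlock(nn.Module):
    # halves spatial resolution and doubles channels;
    # a 1x1 strided projection matches shapes on the shortcut
    def __init__(self, c_in):
        super().__init__()
        self.branch = nn.Sequential(
            conv_bn_relu(c_in, 2 * c_in, stride=2),
            conv_bn_relu(2 * c_in, 2 * c_in),
        )
        self.shortcut = nn.Conv2d(c_in, 2 * c_in, kernel_size=1, stride=2)

    def forward(self, x):
        return self.shortcut(x) + self.branch(x)

x = torch.randn(1, 64, 32, 32)
print(ResidualBlock(64)(x).shape)    # torch.Size([1, 64, 32, 32])
print(DownsampleBlock(64)(x).shape)  # torch.Size([1, 128, 16, 16])
```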

The motivation for including residual blocks is to ease optimisation by creating shortcuts through the network. The hope is that the shorter paths represent shallow sub-networks which are relatively easy to train, whilst longer paths add capacity and computational depth. It seems reasonable to study how the shortest path through the network trains in isolation and to take steps to improve this before adding back the longer branches.

Eliminating the long branches yields the following backbone network, in which all convolutions, except for the initial one, have a stride of two.
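Sketched in PyTorch (the channel widths follow the standard ResNet layout, and the simple max pool + linear head is my assumption), the shortest-path backbone looks roughly like:

```python
import torch
from torch import nn

def bn_relu(c):
    return nn.Sequential(nn.BatchNorm2d(c), nn.ReLU(inplace=True))

# Shortest path through the network: the initial 3x3 convolution,
# then the 1x1 stride-2 projection convolutions inherited from the
# downsampling shortcuts.
backbone = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1, bias=False), bn_relu(64),    # stride 1
    nn.Conv2d(64, 128, 1, stride=2, bias=False), bn_relu(128),
    nn.Conv2d(128, 256, 1, stride=2, bias=False), bn_relu(256),
    nn.Conv2d(256, 512, 1, stride=2, bias=False), bn_relu(512),
    nn.AdaptiveMaxPool2d(1), nn.Flatten(), nn.Linear(512, 10),
)

logits = backbone(torch.randn(2, 3, 32, 32))
print(logits.shape)  # torch.Size([2, 10])
```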

In the following experiments we will train for 20 epochs, using an accelerated version of the previous learning rate schedule, since the networks are small and converge more quickly as a result. For anyone who’d like to follow along, code to reproduce the main results is available here.

Training the shortest path network for 20 epochs yields an unimpressive test accuracy of 55.9% in 36 seconds. Removing repeated batch norm-ReLU groups reduces training time to 32s but leaves test accuracy approximately unchanged.

A serious shortcoming of this backbone network is that the downsampling convolutions have 1×1 kernels and a stride of two, so that rather than enlarging the receptive field they are simply discarding information. If we replace these with 3×3 convolutions, things improve considerably and test accuracy after 20 epochs is 85.6% in a time of 36s.

We can further improve the downsampling stages by applying 3×3 convolutions of stride one followed by a pooling layer instead of using strided convolutions. We choose max pooling with a 2×2 window size, leading to a final test accuracy of 89.7% after 43s. Using average pooling gives a similar result but takes slightly longer.
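As a sketch (helper name and example widths are mine), the improved downsampling stage pairs a stride-one convolution with 2×2 max pooling:

```python
import torch
from torch import nn

def downsample(c_in, c_out):
    # 3x3 convolution at stride 1 to enlarge the receptive field,
    # then 2x2 max pooling to halve the spatial resolution
    return nn.Sequential(
        nn.Conv2d(c_in, c_out, 3, padding=1, bias=False),
        nn.BatchNorm2d(c_out),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2),
    )

y = downsample(64, 128)(torch.randn(1, 64, 32, 32))
print(y.shape)  # torch.Size([1, 128, 16, 16])
```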

The final pooling layer before the classifier is a concatenation of global average pooling and max pooling layers, inherited from the original network. We replace this with a more standard global max pooling layer and double the output dimension of the final convolution to compensate for the reduction in input dimension to the classifier, leading to a final test accuracy of 90.7% in 47s. Note that average pooling at this stage underperforms max pooling significantly.
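A sketch of the two pooling heads being compared (module names are mine, not from the original code):

```python
import torch
from torch import nn

class ConcatPool(nn.Module):
    # original head: concatenation of global average and max pooling
    def __init__(self):
        super().__init__()
        self.avg = nn.AdaptiveAvgPool2d(1)
        self.max = nn.AdaptiveMaxPool2d(1)

    def forward(self, x):
        return torch.cat([self.avg(x), self.max(x)], dim=1).flatten(1)

x = torch.randn(2, 512, 4, 4)
print(ConcatPool()(x).shape)  # torch.Size([2, 1024])

# replacement: plain global max pooling; the final convolution's output
# width is doubled so the classifier sees the same input dimension
global_max = nn.Sequential(nn.AdaptiveMaxPool2d(1), nn.Flatten())
print(global_max(torch.randn(2, 1024, 4, 4)).shape)  # torch.Size([2, 1024])
```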

By default in PyTorch (0.4), initial batch norm scales are chosen uniformly at random from the interval [0,1]. Channels which are initialised near zero could be wasted so we replace this with a constant initialisation at 1. This leads to a larger signal through the network and to compensate we introduce an overall constant multiplicative rescaling of the final classifier. A rough manual optimisation of this extra hyperparameter suggests that 0.125 is a reasonable value. (The low value makes predictions less certain and appears to ease optimisation.) With these changes in place, 20 epoch training reaches a test accuracy of 91.1% in 47s.
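Both changes are easy to express in PyTorch. In this sketch the helper names are mine and the 0.125 scale is the hand-tuned value from above; it assumes a modern PyTorch rather than 0.4:

```python
import torch
from torch import nn

def init_bn_weights(model):
    # constant initialisation at 1, replacing the U[0, 1] default of PyTorch 0.4
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

class ScaledClassifier(nn.Module):
    # rescale the classifier output by a constant multiplicative factor
    def __init__(self, classifier, scale=0.125):
        super().__init__()
        self.classifier, self.scale = classifier, scale

    def forward(self, x):
        return self.scale * self.classifier(x)

net = nn.Sequential(nn.Conv2d(3, 64, 3, padding=1, bias=False), nn.BatchNorm2d(64))
init_bn_weights(net)
head = ScaledClassifier(nn.Linear(512, 10))
print(head(torch.randn(2, 512)).shape)  # torch.Size([2, 10])
```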

Here is a recap of the steps we’ve taken so far:

Network                Test acc   Train time
Original backbone      55.9%      36s
No repeat BN-ReLU      56.0%      32s
3×3 convolutions       85.6%      36s
Max pool downsample    89.7%      43s
Global max pool        90.7%      47s
Better BN scale init   91.1%      47s

The backbone network seems in reasonable shape now and we’re hitting diminishing returns. It’s time to add back some layers. The network is only 5 layers deep (four convolutional, one fully connected), so it’s unclear whether we need residual branches, or whether extra layers along the backbone would get us to the 94% target.

One approach that doesn’t seem promising is just to add width to the 5 layer network. If we double the channel dimensions and train for 60 epochs we can reach 93.5% test accuracy but training takes all of 321s.

In extending the depth of the network we are faced with a plethora of choices, such as different residual branch types, depths and widths, as well as new hyperparameters, such as initial scales and biases for the residual branches. For the sake of making progress, we restrict ourselves to a manageable search space and don’t tune any new hyperparameters.

Specifically we shall consider two classes of networks. The first is constructed by optionally adding a convolutional layer (with batch norm-ReLU) after each max pooling layer. The second class is constructed by optionally adding a residual block consisting of two serial 3×3 convolutions with an identity shortcut, after the same max pooling layers.

We insert an additional 2×2 max pooling layer after the final convolutional block and before the global max pooling so that there are 3 locations to add new layers. The choice to include or not include a new layer is made independently in each case, leading to 7 new networks in each class. We also considered mixtures of the two classes but these did not lead to further improvement so we won’t describe them here.
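The 7 variants per class come from an independent yes/no choice at each of the 3 insertion points, i.e. 2³ − 1 non-empty subsets. A quick enumeration (the site labels are my shorthand for the three locations):

```python
from itertools import product

sites = ["L1", "L2", "L3"]  # after each of the three max pooling layers

variants = [
    [site for site, used in zip(sites, mask) if used]
    for mask in product([False, True], repeat=len(sites))
    if any(mask)  # the all-False choice is just the unmodified backbone
]

print(len(variants))  # 7
print(variants[-1])   # ['L1', 'L2', 'L3']
```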

Here is an example of a network in the first class, in which we’ve added an extra convolution after the second max pool layer.

Here is an example network of the second class, in which we’ve added residual branches after the first and third layers.

Now it’s time for some brute force architecture search! We train each of the 15 networks (improved backbone + 7 variations in each class) for 20 epochs and also for 22 epochs to understand the benefit of training for longer versus using a deeper architecture. If we ran each experiment once, this would correspond to a full 30 minutes of computation. Unfortunately, the standard deviation of each final test accuracy is around 0.15%, so to have any hope of drawing accurate conclusions, we run each experiment 10 times, leading to ~0.05% standard deviations for each of the data points. Even so, the variation between architectures in rate of improvement going from 20 to 22 epochs is probably mostly noise.
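The quoted ~0.05% is just the standard error of the mean over 10 runs:

```python
from math import sqrt

run_std = 0.15                 # std of a single run's final test accuracy, in %
n_runs = 10
sem = run_std / sqrt(n_runs)   # standard error of the mean over 10 runs
print(round(sem, 3))           # 0.047
```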

Here are the results. Points indicate 20 epoch times and accuracies, whilst lines extend to the corresponding 22 epoch result.

The rate of improvement from training for longer seems slow compared to the improvements achievable by using deeper architectures. Of the architectures tested, perhaps the most promising is Residual:L1+L3 which we fortuitously chose to illustrate above. This network achieves 93.8% test accuracy in 66s for a 20 epoch run. If we extend training to 24 epochs, 7 out of 10 runs reach 94% with a mean accuracy of 94.08% and training time of 79s!

This would seem a good place to stop for today. We have found a 9 layer deep residual network which trains to 94% accuracy in 79s, cutting training time almost in half. One remaining question: did we really need the residual branches to reach 94% test accuracy? The answer is a clear no. For example, the single branch network Extra:L1+L2+L3 reaches 95% accuracy in 180s with 60 epoch training and extra regularisation (12×12 cutout), and wider versions go higher still. But at least for now, the fastest network to 94% is a residual network (which is fortunate, given the title of the series).

Before closing for today, I’d like to reflect briefly on the motivations for this work. There is a reasonable point of view that says that training a model to 94% test accuracy on CIFAR10 is a meaningless exercise since state-of-the-art is above 98%. (There is a less reasonable viewpoint that says that ImageNet is The One True Dataset and that anything else is a waste of time, but I won’t address this except to say that some subset of lessons learnt on CIFAR10 should also transfer to The One.)

The fact that we could reach 94% accuracy in 24 epochs with a 9 layer network adds some weight to the viewpoint that we are targeting too low a threshold. On the other hand, human performance on CIFAR10 is estimated at around 94% (based on one human classifying 400 images!) so the case is not clear cut.

State-of-the-art accuracy is an ill-conditioned target in the sense that throwing a larger model, more hyperparameter tuning, more data augmentation or longer training time at the problem will typically lead to accuracy gains, making fair comparison between works a delicate task. There is a danger that innovations in training or architectural design introduce additional hyperparameter dimensions and that tuning these may lead to better implicit optimisation of aspects of training that are otherwise unrelated to the extension under study. Ablation studies, which are generally considered best practice, cannot resolve this issue if the base model has a lower dimensional space of explicit hyperparameters to optimise. A result of this situation is that state-of-the-art models can seem like a menagerie of isolated results which are hard to compare, fragile to reproduce and difficult to build upon.

Given these problems, we propose that anything that makes it easier to compare within and between experimental works is a good thing. We believe that producing competitive benchmarks for resource constrained training is a way to alleviate some of these difficulties.

The introduction of resource constraints encourages a fairer comparison between works, reducing the need to adjust for the effort that has gone into training. Additional model complexity, which allows for higher dimensional optimisation of implicit parameters, will typically be penalised by resource constrained benchmarks. Methods which explicitly control the relevant parameters will tend to win out.

Resource constrained benchmarks allow for a state-of-the-art frontier rather than a single point. This lets one target studies to the appropriate regime, allows a more nuanced understanding of the role of different components, and perhaps even suggests a phase diagram of successful architectures with a series of transitions to greater complexity.

A simple technique can be proven against the lower thresholds without the need to incorporate all the tricks required to improve an unconstrained state-of-the-art. Shorter training times and reduced model complexity should make the lower resource benchmarks easier to investigate and better optimised than their unconstrained cousins. Better optimised baselines in turn make it easier to accept or reject proposed changes.

Recently, there has been a positive trend in publishing curves of state-of-the-art models according to inference time or model size. These are both important practical targets for optimisation and they tackle some of the issues above, but we believe that additionally regularising by training time brings further benefits. On the flip side, optimising for training time with no concern over inference cost would be suboptimal, which is why our training time results always include the time to evaluate the test set at each epoch and we’ve avoided techniques such as test time augmentation which reduce training time at the expense of inference.

In Part 5 we take a break from speeding up training and develop some heuristics for hyperparameter tuning.
