On hyperparameter tuning and how to avoid it.

There is a common belief that hyperparameter selection for neural networks is hard. If one subscribes to this view, two divergent courses of action present themselves. The first – roughly our approach so far – is to ignore the problem and hope for the best, perhaps tweaking the occasional learning rate schedule. An alternative, popular at some of the larger cloud providers, is to launch full parameter sweeps for each new experimental setting.

It’s time to address this issue and ask what our hyperparameter neglect has cost. Have we missed an opportunity for optimisation or, worse, reached invalid conclusions? Or have the importance and difficulty of hyperparameter tuning been overstated?

In today’s post we identify, experimentally and theoretically, a number of nearly flat directions in hyperparameter space. This greatly simplifies the search problem. Optimising along these directions helps a little, but the penalty for doing so badly is small. In the following post, we will argue that, in favourable circumstances, whole families of architectures share similar optimal hyperparameter values.

We conclude that hyperparameter search isn’t always difficult and can sometimes be skipped altogether. The work on weight decay was largely completed during Thomas Read’s internship at Myrtle in the summer. Apologies to Thomas for taking so long to write this up.

Let’s get started. To motivate things, we’ll begin with some experimental results supporting our claim that there are almost flat directions in hyperparameter space. We plot final CIFAR10 test accuracy of the network from the previous post over various two-dimensional slices in hyperparameter space.

In each case, the learning rate on the x-axis refers to the maximal learning rate in the schedule from the previous post. A second hyperparameter (batch size, 1-momentum or weight decay) is varied on the y-axis of each plot. Other hyperparameters are fixed at: batch size=512, momentum=0.9, weight decay=5e-4. Plots have log-log axes and test accuracies have been floored at 91% to preserve dynamic range. Each datapoint is a mean of three runs.

Each plot has a ridge of maximal test accuracy oriented at 45° to the axes and spanning a wide range in log-parameter space. Let’s denote the maximal learning rate by $\lambda$, batch size by $N$, momentum by $\rho$ and weight decay by $\alpha$. The plots provide striking evidence of almost-flat directions in which $\frac{\lambda}{N}$, $\frac{\lambda}{1-\rho}$ or $\lambda \alpha$ is held constant.
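As a small sketch (the helper names are hypothetical, not from the original code), each flat direction can be read as a rule for adjusting the learning rate when one other hyperparameter changes, so that the relevant combination $\frac{\lambda}{N}$, $\frac{\lambda}{1-\rho}$ or $\lambda \alpha$ stays fixed. These helpers only encode the invariants; the plots show that accuracy is nearly, not exactly, constant along them.

```python
def lr_for_batch_size(lr: float, batch_size: int, new_batch_size: int) -> float:
    """Keep lr / N fixed (the linear scaling rule)."""
    return lr * new_batch_size / batch_size


def lr_for_momentum(lr: float, momentum: float, new_momentum: float) -> float:
    """Keep lr / (1 - rho) fixed."""
    return lr * (1.0 - new_momentum) / (1.0 - momentum)


def lr_for_weight_decay(lr: float, weight_decay: float, new_weight_decay: float) -> float:
    """Keep lr * alpha fixed."""
    return lr * weight_decay / new_weight_decay


# Starting point: the hand-chosen settings lambda=0.4, N=512, rho=0.9, alpha=5e-4.
print(lr_for_batch_size(0.4, 512, 1024))     # doubling N doubles lambda
print(lr_for_momentum(0.4, 0.9, 0.95))       # halving 1 - rho halves lambda
print(lr_for_weight_decay(0.4, 5e-4, 1e-3))  # doubling alpha halves lambda
```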

We discussed holding $\frac{\lambda}{N}$ fixed, whilst varying $N$, in the second post of the series. This is known as the linear scaling rule and has been discussed many times elsewhere. We didn’t plot test accuracy over hyperparameter space previously and do so now for completeness and to underline how well this heuristic applies here. We also explained previously why such a rule is expected in the current regime, in which curvature effects are not dominant.

We will provide similar explanations for the other two flat directions shortly. The arguments are straightforward and give one confidence in applying these heuristics in practice. Before we do that, let’s see how knowledge of these flat directions simplifies hyperparameter search.

Suppose that we forgot the hyperparameter settings that we used to reach 94% test accuracy in 24 epochs with our current network. We fix the choice of network, set batch size to 512 and assume a learning rate schedule that increases linearly from zero for the first 5 epochs and decays linearly for the remainder. The hyperparameters that we aim to recover are the maximal learning rate $\lambda$, Nesterov momentum $\rho$, and weight decay $\alpha$. We assume that we know nothing about reasonable values for these hyperparameters and start with arbitrary choices $\lambda = 0.001$, $\rho=0.5$, $\alpha=0.01$ which achieve a test accuracy of 30.6% after 24 epochs.
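The learning rate schedule described above can be written as a piecewise-linear function of (possibly fractional) epoch. This is a minimal sketch that assumes the linear decay reaches zero exactly at epoch 24; the function name is illustrative.

```python
import numpy as np


def lr_schedule(epoch: float, max_lr: float,
                warmup_epochs: float = 5.0, total_epochs: float = 24.0) -> float:
    """Linear warm-up from 0 to max_lr over warmup_epochs, then linear decay
    to 0 at total_epochs (assumed end point of the decay)."""
    return float(np.interp(epoch,
                           [0.0, warmup_epochs, total_epochs],
                           [0.0, max_lr, 0.0]))


# Evaluate per iteration by passing fractional epochs, e.g. step / steps_per_epoch.
for e in (0, 2.5, 5, 14.5, 24):
    print(e, lr_schedule(e, max_lr=0.4))
```

Using `np.interp` keeps the schedule declarative: changing the warm-up length or total epochs only moves the breakpoints.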

To optimise these hyperparameters we will follow a type of cyclical coordinate descent in which we tune one parameter at a time with a very crude line search (doubling or halving and retraining until things get worse). Coordinate descent in $(\lambda, \rho, \alpha)$ space would be a bad idea because of the nearly flat directions at $\pm 45^\circ$ to the coordinate axes. A much better idea is to perform the descent in $(\frac{\lambda \alpha}{1-\rho}, \rho, \alpha)$ space, aligning the almost flat directions with the axes. A picture may help:

We optimise the learning rate $\lambda$ first, then momentum $\rho$ (we halve/double $1-\rho$ rather than $\rho$), then weight decay $\alpha$, before cycling through again. Importantly, when we optimise $\rho$ or $\alpha$ we keep $\frac{\lambda \alpha}{1-\rho}$ fixed by adjusting $\lambda$ appropriately. We halt the whole process when things stop improving.
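The search procedure can be sketched as follows. `evaluate(lr, momentum, weight_decay)` is a hypothetical stand-in for a full 24-epoch training run returning test accuracy; to keep the sketch runnable we plug in a made-up toy objective peaked near $\frac{\lambda \alpha}{1-\rho} = 0.002$, which is not real CIFAR10 data. The key detail is the reparametrisation: steps 2 and 3 rescale $\lambda$ so that $c = \frac{\lambda \alpha}{1-\rho}$ stays fixed.

```python
import math
from typing import Callable, Tuple

# (lr, momentum, weight_decay) -> test accuracy; hypothetical stand-in for a training run.
Evaluate = Callable[[float, float, float], float]


def line_search(score: Callable[[float], float], x: float, best: float,
                upper: float = math.inf) -> Tuple[float, float, int]:
    """Crude 1-D search: keep doubling x while the score improves; if the first
    doubling does not help, keep halving instead. Returns (x, best score, runs)."""
    runs = 0
    for factor in (2.0, 0.5):
        moved = False
        while x * factor <= upper:
            s = score(x * factor)
            runs += 1
            if s <= best:
                break
            x, best, moved = x * factor, s, True
        if moved:
            break
    return x, best, runs


def tune(evaluate: Evaluate, lr: float, momentum: float, wd: float,
         max_cycles: int = 10):
    """Cyclical coordinate descent over (c, 1 - rho, alpha), c = lr * wd / (1 - rho)."""
    best, runs = evaluate(lr, momentum, wd), 1
    for _ in range(max_cycles):
        start = best
        # 1. Learning rate: the only step that changes c.
        lr, best, n = line_search(lambda x: evaluate(x, momentum, wd), lr, best)
        runs += n
        c = lr * wd / (1.0 - momentum)
        # 2. Halve/double 1 - rho, adjusting lr so that c is unchanged.
        om, best, n = line_search(lambda x: evaluate(c * x / wd, 1.0 - x, wd),
                                  1.0 - momentum, best, upper=1.0)
        runs += n
        momentum, lr = 1.0 - om, c * om / wd
        # 3. Halve/double weight decay, again adjusting lr so that c is unchanged.
        wd, best, n = line_search(lambda x: evaluate(c * (1.0 - momentum) / x, momentum, x),
                                  wd, best)
        runs += n
        lr = c * (1.0 - momentum) / wd
        if best <= start:  # a full cycle without improvement: stop
            break
    return lr, momentum, wd, best, runs


def toy_accuracy(lr: float, momentum: float, wd: float) -> float:
    """Made-up objective: steep in log c, nearly flat in log(1 - rho) and log alpha."""
    c = lr * wd / (1.0 - momentum)
    return (0.94 - 0.01 * math.log2(c / 0.002) ** 2
            - 0.001 * math.log2((1.0 - momentum) / 0.1) ** 2
            - 0.001 * math.log2(wd / 5e-4) ** 2)


lr, momentum, wd, acc, runs = tune(toy_accuracy, lr=0.001, momentum=0.5, wd=0.01)
print(lr, momentum, wd, lr * wd / (1 - momentum), acc, runs)
```

On this toy the search settles at $\frac{\lambda \alpha}{1-\rho} \approx 0.00256$, within a factor of two of 0.002, because the doubling/halving grid starting from $\lambda = 0.001$ cannot land on 0.002 exactly.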

Here are results based on a search that converged after a total of 22 training runs. The process could be straightforwardly automated and made more efficient. With more reasonable starting values we’d have been done in fewer than 10 runs. The first line in the table shows the initial parameter settings and subsequent lines are the result of optimising $\lambda$, $\rho$ or $\alpha$ in turn with $\frac{\lambda \alpha}{1-\rho}$ held fixed in the latter two cases:

Table (columns): train run | $\lambda$ | $\rho$ | $\alpha$ | $\frac{\lambda \alpha}{1-\rho}$ | test acc

Note that after the first step, $\frac{\lambda \alpha}{1-\rho}$ has already stabilised to within a factor of two of its final value. The optimised hyperparameters are rather close to our hand-chosen ones from before $(\lambda=0.4, \rho=0.9, \alpha=0.0005)$, and the combination $\frac{\lambda \alpha}{1-\rho}$ is as close to the previous value of $\frac{0.4 \times 0.0005}{1-0.9} = 0.002$ as the resolution of our doubling/halving scheme allows. The final test accuracy of 94.2% is based on a single run, and the improvement over our previous 94.08% is not statistically significant.

One might hope that optimising further at a higher parameter resolution, and using multiple training runs to reduce noise, would improve on our baseline training. We have not succeeded in doing so. Two of the directions in hyperparameter space are extremely flat at this point, and the other is close enough to optimal to be almost flat also. In conclusion, there is a large and readily identifiable region of hyperparameter space whose settings perform indistinguishably to within statistical noise.

So far, we have presented experimental evidence for the existence of nearly flat directions in hyperparameter space and demonstrated the utility of knowing these when tuning hyperparameters. It remains to explain where the flat directions come from.

In the case of momentum $\rho$, the explanation is simple and mirrors the one that we gave before regarding batch size. We present the argument in the case of SGD with ordinary momentum rather than Nesterov momentum since the equations are slightly simpler. Similar logic applies in either case.

Ignoring weight decay for now (it can be incorporated into a change of loss function), vanilla SGD + momentum updates the parameters $w$ in two steps. First we update $v$, an (unnormalised) exponential moving average of the gradient $g$ of the loss with respect to $w$:

$v \leftarrow \rho v + g$

Next we use $v$ to update the weights $w$:

$ w \leftarrow w -\lambda v$

$\lambda$ and $\rho$ are learning rate and momentum parameters as before.
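The two-step update can be written directly in NumPy. This is a minimal sketch of ordinary (heavy-ball) momentum without weight decay, matching the equations above; the function name is illustrative.

```python
import numpy as np


def sgd_momentum_step(w: np.ndarray, v: np.ndarray, g: np.ndarray,
                      lr: float, rho: float):
    """One step of SGD with ordinary momentum: v <- rho*v + g, then w <- w - lr*v."""
    v = rho * v + g
    w = w - lr * v
    return w, v


# With a constant gradient, v grows towards g / (1 - rho).
w, v = np.zeros(3), np.zeros(3)
g = np.ones(3)
for _ in range(200):
    w, v = sgd_momentum_step(w, v, g, lr=0.1, rho=0.9)
print(v)
```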

Let’s focus on the gradient $g_t$ computed at a single timestep $t$. It is first added to $w_t$ with weight $-\lambda_t$. At the next timestep it is added to $w_{t+1}$, via its contribution to $v_{t+1}$ with the weight $-\rho \lambda_{t+1}$. This continues over time so that the total contribution of the gradient $g_t$ to the updated weights approaches

$-\lambda_t - \rho \lambda_{t+1} - \rho^2 \lambda_{t+2} - \cdots \approx -\frac{\lambda_t}{1-\rho}$

where we have assumed that $\lambda_t$ is approximately constant over the relevant timescale and summed the geometric series.

Now suppose that we make a small change to $\lambda$ and $\rho$, keeping the combination $\frac{\lambda}{1-\rho}$ fixed. This leads to the same total contribution from a given gradient $g_t$ but slightly changes the timescale over which the update is applied. If we can assume that delaying updates has only a small effect, as will be the case for small enough learning rates, the dynamics will be similar in either case.

Curvature effects are sub-dominant in the current training regime, since learning rates are limited by other effects, as explained previously. As a result, changing $\rho$ whilst keeping $\frac{\lambda}{1-\rho}$ fixed has only a weak effect on training, as seen experimentally.
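The geometric-series argument can be checked by feeding a single unit gradient into the momentum update and recording how much of it reaches the weights at each later step, assuming a constant learning rate. The sketch compares three hypothetical $(\lambda, \rho)$ pairs sharing $\frac{\lambda}{1-\rho} = 4$: the total contribution is the same, while the mean delay, $\frac{\rho}{1-\rho}$ steps, varies. For scale, 24 epochs at batch size 512 on CIFAR10's 50,000 training images is roughly 2,350 steps.

```python
def impulse_response(lr: float, rho: float, steps: int = 2000) -> list:
    """Amount of a single unit gradient g_0 applied to w at steps 0, 1, 2, ...
    under v <- rho*v + g, w <- w - lr*v."""
    v, contributions = 0.0, []
    for t in range(steps):
        g = 1.0 if t == 0 else 0.0  # one gradient at t = 0, nothing afterwards
        v = rho * v + g
        contributions.append(lr * v)  # amount subtracted from w at step t
    return contributions


def total_and_mean_delay(lr: float, rho: float):
    c = impulse_response(lr, rho)
    total = sum(c)
    mean_delay = sum(t * ct for t, ct in enumerate(c)) / total
    return total, mean_delay


for lr, rho in [(0.8, 0.8), (0.4, 0.9), (0.2, 0.95)]:
    total, delay = total_and_mean_delay(lr, rho)
    print(f"lr={lr}, rho={rho}: total={total:.4f} "
          f"(lr/(1-rho)={lr / (1 - rho):.4f}), mean delay={delay:.2f} steps")
```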

So much for momentum. Next we would like to explain why training changes slowly when we vary weight decay $\alpha$ and learning rate $\lambda$, keeping $\lambda \alpha$ fixed.

The argument will apply precisely to parameters $w$ which can be rescaled without changing the loss function. This is true of all the convolutional weights in our network, since each convolutional layer is followed directly by a batch normalisation layer, which neutralises the effect of rescaling the weights. It is not true of the final classifier or of some of the later batch norm layers, but rescaling these has a mild enough effect that the conclusions continue to hold fairly well in practice.
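A quick numerical check of the scaling symmetry: a linear map followed by batch normalisation (here without the learned scale and shift, and with `eps=0` so that the invariance is exact) gives identical outputs when the weights are multiplied by a positive factor. A dense layer stands in for a convolution, which is also linear in its weights.

```python
import numpy as np


def linear_then_batchnorm(x: np.ndarray, w: np.ndarray, eps: float = 0.0) -> np.ndarray:
    """Linear layer followed by batch normalisation over the batch axis
    (no learned affine parameters). With eps > 0 the invariance is only approximate."""
    z = x @ w
    return (z - z.mean(axis=0)) / np.sqrt(z.var(axis=0) + eps)


rng = np.random.default_rng(0)
x = rng.standard_normal((64, 16))  # a batch of 64 inputs
w = rng.standard_normal((16, 8))   # 8 output channels
out = linear_then_batchnorm(x, w)
out_scaled = linear_then_batchnorm(x, 3.0 * w)
print(np.max(np.abs(out - out_scaled)))  # zero up to floating-point rounding
```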

Let $w$ be a set of parameters such that the loss is unchanged under rescaling of $w$. Let’s consider pure SGD without momentum and with weight decay $\alpha$ (momentum is irrelevant to the main discussion and we want to keep things simple). The parameter update splits into a weight decay step:

$w \leftarrow (1 - \lambda \alpha) w$

and a gradient descent step:

$w \leftarrow w -\lambda g.$

(The careful reader will have observed that the weight decay step is just a rescaling of $w$ and thus a no-op from the point of view of the loss function. More on that shortly.)

Suppose that we use our freedom in choosing the scale of $w$ to carry out training with rescaled weights. Say we work with $3w$ instead. The gradients $g$ that we compute become 3 times smaller (since a fixed change $\delta w$ in the parameters has less impact at the new scale). On the other hand, to update $w$ as before, we need the update step to be 3 times bigger than previously. This happens automatically for the weight decay step, which scales with $w$, but for the gradient step we need to make $\lambda$ 9 times larger, to account for both the smaller gradients and the larger update step. We then need to reduce $\alpha$ by a factor of 9 so that the product $\lambda \alpha$ in the weight decay step is unchanged. In conclusion, for parameters with a scaling symmetry, there is an exact symmetry of the loss and training dynamics if we rescale the weights $w$ by a factor $r$, the learning rate $\lambda$ by $r^2$ and the weight decay $\alpha$ by $r^{-2}$, thereby keeping $\lambda \alpha$ fixed.
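The exact symmetry can be checked on a toy problem. The sketch uses a made-up scale-invariant loss $L(w) = -\frac{c \cdot w}{|w|}$ for a fixed vector $c$ (an assumption standing in for a network with batch norm), whose gradient is orthogonal to $w$ and shrinks like $1/|w|$. Training from $(w_0, \lambda, \alpha)$ and from $(r w_0, r^2 \lambda, \alpha / r^2)$ with the two-step update above gives trajectories that differ only by the factor $r$.

```python
import numpy as np


def grad(w: np.ndarray, c: np.ndarray) -> np.ndarray:
    """Gradient of the toy scale-invariant loss L(w) = -c.w / |w|."""
    n = np.linalg.norm(w)
    u = w / n
    return -(c - (c @ u) * u) / n


def train(w: np.ndarray, c: np.ndarray, lr: float, wd: float, steps: int) -> np.ndarray:
    """Plain SGD: weight decay step w <- (1 - lr*wd) w, then gradient step
    w <- w - lr*g, with g evaluated at the weights from the start of the step."""
    for _ in range(steps):
        g = grad(w, c)
        w = (1.0 - lr * wd) * w
        w = w - lr * g
    return w


rng = np.random.default_rng(0)
c, w0 = rng.standard_normal(10), rng.standard_normal(10)
lr, wd, r, steps = 0.1, 0.01, 3.0, 100

w_a = train(w0, c, lr, wd, steps)
w_b = train(r * w0, c, r**2 * lr, wd / r**2, steps)
print(np.max(np.abs(w_b - r * w_a)))  # ~0: the rescaled run tracks r times the original
```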

How does this help us? Actually, we’re nearly done. We’ve seen that changing $\lambda$ and $\alpha$ while holding $\lambda \alpha$ fixed is equivalent to a change in the initialisation scale of the weights (for parameters with scaling symmetry). Such changes in initialisation don’t change the loss and, if we’re lucky, should have a diminishing effect on training over time. Let’s understand this in more detail.

Suppose we initialise weights at some large scale. What happens next? The large scale on the weights is equivalent to a very small learning rate as far as the gradient update step is concerned. The weight decay step proceeds as normal and gradually shrinks the weights. Once the weights reach a small enough scale, the gradient updates start to be important and balance the shrinking effect of the weight decay. If the weights start too small, the opposite dynamics take place and the gradient updates dominate the weight decay term in importance until the scale of weights returns to equilibrium.
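The settling-in behaviour can be illustrated with a toy model (our assumption, not the network from the post): each "minibatch gradient" is a random Gaussian vector projected orthogonal to $w$ and scaled by $1/|w|$, which is how gradients of a scale-invariant loss behave. Starting from very large or very small weights, the norm relaxes to the same equilibrium; a short calculation in this model gives $|w|^4 \approx \frac{\lambda (d-1)}{\alpha (2 - \lambda \alpha)}$ for dimension $d$.

```python
import numpy as np


def norm_trajectory(init_norm: float, lr: float, wd: float,
                    steps: int = 4000, dim: int = 100, seed: int = 0) -> np.ndarray:
    """|w| after each step of weight decay + noisy scale-invariant gradient updates."""
    rng = np.random.default_rng(seed)
    w = rng.standard_normal(dim)
    w *= init_norm / np.linalg.norm(w)
    norms = np.empty(steps)
    for t in range(steps):
        n = np.linalg.norm(w)
        u = w / n
        z = rng.standard_normal(dim)
        g = (z - (z @ u) * u) / n         # toy gradient: orthogonal to w, size ~ 1/|w|
        w = (1.0 - lr * wd) * w - lr * g  # weight decay step, then gradient step
        norms[t] = np.linalg.norm(w)
    return norms


lr, wd, dim = 0.1, 0.05, 100
predicted = (lr * (dim - 1) / (wd * (2 - lr * wd))) ** 0.25
from_large = norm_trajectory(100.0, lr, wd, dim=dim)
from_small = norm_trajectory(0.01, lr, wd, dim=dim, seed=1)
print(predicted, from_large[-500:].mean(), from_small[-500:].mean())
```

Starting very small, the first gradient step is huge and overshoots; weight decay then brings the norm back down, as described above.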

It’s important for this argument that gradient updates lead (on average) to an increase in the scale of the weights, since otherwise there would be nothing to stop weight decay shrinking the weights to zero and destabilising the dynamics. In fact, since we are assuming that the loss is invariant to rescaling weights, the gradient $g$ has no component parallel to $w$, i.e. $g$ is orthogonal to $w$. By Pythagoras’s theorem, the gradient updates therefore increase the norm of $w$: $|w - \lambda g|^2 = |w|^2 + \lambda^2 |g|^2$. The argument in the case of SGD + momentum is different, but the result is the same. More on this (and the relation to LARS) in the next post.
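The orthogonality claim can be checked numerically, without an autodiff library, using central finite differences on a small made-up loss built from a linear layer followed by batch normalisation (no learned affine parameters, no `eps`). The cosine between $g$ and $w$ comes out at rounding-error level, and a gradient step increases $|w|^2$ by $\lambda^2 |g|^2$ up to rounding.

```python
import numpy as np


def bn_loss(w: np.ndarray, x: np.ndarray, y: np.ndarray) -> float:
    """Squared error after linear layer + batch norm; unchanged under w -> r*w, r > 0."""
    z = x @ w
    z = (z - z.mean(axis=0)) / z.std(axis=0)
    return float(np.sum((z - y) ** 2))


def numerical_grad(f, w: np.ndarray, h: float = 1e-6) -> np.ndarray:
    """Central finite differences, one coordinate at a time."""
    g = np.zeros_like(w)
    for idx in np.ndindex(w.shape):
        e = np.zeros_like(w)
        e[idx] = h
        g[idx] = (f(w + e) - f(w - e)) / (2.0 * h)
    return g


rng = np.random.default_rng(0)
x, y = rng.standard_normal((32, 6)), rng.standard_normal((32, 4))
w = rng.standard_normal((6, 4))
g = numerical_grad(lambda v: bn_loss(v, x, y), w)

cosine = np.sum(g * w) / (np.linalg.norm(g) * np.linalg.norm(w))
lr = 0.1
lhs = np.linalg.norm(w - lr * g) ** 2
rhs = np.linalg.norm(w) ** 2 + lr**2 * np.linalg.norm(g) ** 2
print(cosine, lhs, rhs)  # cosine ~ 0 and lhs ~ rhs (Pythagoras)
```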

The conclusion of all this is the following. Weight decay in the presence of batch normalisation acts as a stable control mechanism on the effective step size. If gradient updates get too small, weight decay shrinks the weights and boosts gradient step sizes until equilibrium is restored. The reverse happens when gradient updates grow too large.

The choice of initial weight scale and raw learning rate $\lambda$ does not control the dynamics for long. After a settling-in period, the size of the gradient update steps, relative to the scale of the weights, is determined by the rate of weight decay, which is a function of $\lambda \alpha$ only. (Similar arguments in the presence of momentum lead to a dependence on $\frac{\lambda \alpha}{1-\rho}$, as we shall review next time.) If we vary $\lambda$ and $\alpha$ holding the product fixed, then the learning rate dynamics for most of training are unaffected (or only weakly affected for the few layers without scaling symmetry), and this gives rise to the corresponding almost flat directions in hyperparameter space.
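In a toy model (our assumption: each minibatch gradient is a random Gaussian vector projected orthogonal to $w$ and scaled by $1/|w|$), we can check that the equilibrium step size relative to the weights depends on $\lambda$ and $\alpha$ only through their product: in this model it settles near $\sqrt{\lambda \alpha (2 - \lambda \alpha)} \approx \sqrt{2 \lambda \alpha}$. The three hypothetical settings below differ in $\lambda$ by a factor of 100 but share $\lambda \alpha = 0.005$.

```python
import numpy as np


def mean_relative_step(lr: float, wd: float, steps: int = 3000, tail: int = 1000,
                       dim: int = 100, seed: int = 0) -> float:
    """Average of |lr * g| / |w| over the last `tail` steps of the toy dynamics."""
    rng = np.random.default_rng(seed)
    w = rng.standard_normal(dim)
    w /= np.linalg.norm(w)  # every setting starts from unit-norm weights
    ratios = np.empty(steps)
    for t in range(steps):
        n = np.linalg.norm(w)
        u = w / n
        z = rng.standard_normal(dim)
        g = (z - (z @ u) * u) / n         # toy gradient: orthogonal to w, size ~ 1/|w|
        ratios[t] = lr * np.linalg.norm(g) / n
        w = (1.0 - lr * wd) * w - lr * g  # weight decay step, then gradient step
    return float(ratios[-tail:].mean())


settings = [(1.0, 0.005), (0.1, 0.05), (0.01, 0.5)]  # all with lr * wd = 0.005
results = {s: mean_relative_step(*s) for s in settings}
print(results, np.sqrt(2 * 0.005))
```

The raw norm $|w|$ at equilibrium differs greatly between these settings; only the relative step, which is what drives the dynamics of scale-invariant parameters, is shared.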

In Part 6 we continue to investigate weight decay and ask why hyperparameter settings seem to transfer well between architectures.