In which we delve deeper into the learning rate dynamics.

The reader may be feeling a little uneasy at this point. Last time we presented experimental results and theoretical explanations for three almost flat directions in the space of neural network training hyperparameters, but the three explanations don’t quite hang together…

The first explained the observed weak dependence of training trajectories on the choice of learning rate $\lambda$ and batch size $N$ when the ratio $\frac{\lambda}{N}$ was held fixed. A similar argument applied to momentum $\rho$ when $\frac{\lambda}{1-\rho}$ was fixed. Both relied on a simple matching of first order terms in the weight update equations of SGD + momentum.

A third argument, regarding the weight decay $\alpha$, was rather different and distinctly not first order in nature. We argued that for weights with a scaling symmetry, the gradient with respect to the weight norm vanishes and a second order effect (Pythagorean growth from orthogonal gradient updates) appears at leading order in the dynamics. What is worse, we argued that, although weight norms are irrelevant to the forward computation of the network, they determine the effective learning rate for the other parameters.

What then of our first order arguments? Do they hold water when we take proper account of the weight norm dynamics? The experimental results suggest that the conclusions at least are correct.

Our main task for today will be to shore up the two first order arguments with a more careful study of the weight norm dynamics. In the process we will find a close relation with Layer-wise Adaptive Rate Scaling (LARS), a technique introduced recently in the context of large batch training on ImageNet. We study the implications of this relation and propose that it may be behind a remarkable stability in the optimal learning rate across different architectures.

Today’s post is going to be more theoretical than the ones before. It contains equations and not much in the way of practical advice. I hope that it will be of some interest nonetheless. The maths is all of the high school variety and shouldn’t slow things down too much. On the plus side, it allows us to be more precise in our arguments than we have been – and, for that matter, than various discussions in the literature.

The connection to LARS is presented mostly as a curiosity, but it should lead to practical training improvements given more effort. Rather than wait for those results, I’ve decided to put the post out as is. It’s long overdue, the work having been largely completed during Thomas Read’s productive internship at Myrtle in the summer.

Let’s commence by looking in more detail at the dynamics of weight norms in the presence of weight decay. Before launching into equations, here is a recap of the main points from last time.

For weights with a scaling symmetry – which includes all the convolutional layers of our network because of subsequent batch normalisation – gradients are orthogonal to weights. As a result, gradient updates lead to an increase in weight norm whilst weight decay leads to a decrease.

For small weights, the growth term dominates and vice versa for large weights. This leads to a stable control mechanism whereby the weight norms approach a fixed point (for fixed learning rate and other hyperparameters) such that the first order shrinking effect of weight decay balances the second order growth from orthogonal gradient updates.

Perhaps some pictures will make this more concrete. First we plot the amount by which the squared weight norm $|w|^2$ of each convolutional layer shrinks under weight decay at each step of training:

Next we plot the amount by which $|w|^2$ grows because of gradient updates. This looks rather like a noisy version of the first plot, as it should if the two effects are in equilibrium:

Here are the two plots superimposed:

Let’s also plot the evolution of $|w|^2$ itself. Weight norms for most layers grow initially, whilst the gradient updates dominate, and then level out or decay slowly once equilibrium is reached:

Plots are normalised so that $|w|^2 = 1$ for each layer initially.

Let’s put this into equations, starting with the simple case of SGD without momentum. In the notation of last time, the SGD update splits into two pieces, a weight decay term:

$w \leftarrow w - \lambda \alpha w \tag{1}$

and a gradient term:

$w \leftarrow w - \lambda g \tag{2}$

In terms of weight norms, we have:

$|w|^2 \leftarrow |w|^2 - 2 \lambda \alpha |w|^2 + O(\lambda^2 \alpha^2)\tag{1′}$

and:

$|w|^2 \leftarrow |w|^2 + \lambda^2 |g|^2 \tag{2′}$

where we have used the fact that for weights with a scaling symmetry, $w \cdot g = 0.$
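As a quick numerical check of $(1′)$ and $(2′)$, here is a minimal sketch in plain Python. The weight vector is random, and the gradient is constructed by projecting out its component along $w$, standing in for the orthogonality that the scaling symmetry guarantees. The values of `lam` and `alpha` are illustrative assumptions.

```python
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def sq_norm(a):
    return dot(a, a)

random.seed(0)
lam, alpha = 0.1, 5e-4  # illustrative learning rate and weight decay
w = [random.gauss(0.0, 1.0) for _ in range(10)]

# Build a gradient orthogonal to w by projecting out the w component,
# mimicking what a scaling symmetry guarantees (w . g = 0).
z = [random.gauss(0.0, 1.0) for _ in range(10)]
coef = dot(z, w) / sq_norm(w)
g = [zi - coef * wi for zi, wi in zip(z, w)]

# (1'): weight decay shrinks |w|^2 by 2*lam*alpha*|w|^2 to first order.
w_decayed = [(1 - lam * alpha) * wi for wi in w]
decay_exact = sq_norm(w) - sq_norm(w_decayed)
decay_first_order = 2 * lam * alpha * sq_norm(w)

# (2'): the gradient step grows |w|^2 by exactly lam^2 * |g|^2 (Pythagoras).
w_stepped = [wi - lam * gi for wi, gi in zip(w, g)]
growth_exact = sq_norm(w_stepped) - sq_norm(w)
growth_predicted = lam ** 2 * sq_norm(g)

print(decay_exact, decay_first_order)
print(growth_exact, growth_predicted)
```

The two decay numbers differ only by the $O(\lambda^2\alpha^2)$ term, while the two growth numbers agree to rounding error.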

We expect that training will reach an equilibrium where the shrinking effect of the first update balances the growth of the second so that:

$2 \lambda \alpha |w|^2 \approx \lambda ^2 |g|^2 \tag{3}$

From last time, we recall that although the forward computation is invariant under rescaling $w$, the SGD update is not. Indeed, if we rescale $w$ by a factor $r$, then gradients $g$ scale by $r^{-1}$ rather than the factor $r$ that would be required for invariance. As a result, weight norms matter and determine the effective step size for the remaining weights.
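To make the scaling argument concrete, here is a sketch with a toy scale-invariant function $f(w) = \frac{a \cdot w}{|w|}$. This is our own illustrative choice, not the network's loss. Its gradient has the closed form $\frac{a}{|w|} - \frac{(a \cdot w)\, w}{|w|^3}$. Rescaling $w$ by $r$ leaves $f$ unchanged, keeps the gradient orthogonal to $w$, and divides the gradient by $r$. The names `f` and `grad_f` are hypothetical.

```python
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def norm(a):
    return dot(a, a) ** 0.5

def f(w, a):
    # Hypothetical toy loss that depends only on the direction of w.
    return dot(a, w) / norm(w)

def grad_f(w, a):
    # Analytic gradient: a/|w| - (a.w) w / |w|^3
    n = norm(w)
    c = dot(a, w) / n ** 3
    return [ai / n - c * wi for ai, wi in zip(a, w)]

random.seed(1)
a = [random.gauss(0.0, 1.0) for _ in range(5)]
w = [random.gauss(0.0, 1.0) for _ in range(5)]
r = 3.0
w_scaled = [r * wi for wi in w]

g = grad_f(w, a)
g_scaled = grad_f(w_scaled, a)

print(f(w, a), f(w_scaled, a))   # equal: the forward computation is scale invariant
print(dot(w, g))                 # ~0: the gradient is orthogonal to the weights
print(norm(g) / norm(g_scaled))  # ~r: the gradient shrinks by 1/r instead of growing by r
```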

One way to restore invariance would be to use a different optimiser such as LARS. In the absence of momentum, the LARS update is:

$w \leftarrow w - \lambda_{LARS} \frac{|w|}{|g|} g \tag{4}$
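A minimal sketch of update $(4)$ in plain Python. The helper name `lars_step` and the numbers are ours, for illustration only. The step has length $\lambda_{LARS}|w|$ whatever the size of $|g|$. If we rescale $w$ by $r$ and $g$ by $\frac{1}{r}$, as the scaling symmetry dictates, the step grows by $r$, so the relative step size is unchanged.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def norm(a):
    return dot(a, a) ** 0.5

def lars_step(w, g, lr_lars):
    """One LARS update without momentum, as in equation (4)."""
    scale = lr_lars * norm(w) / norm(g)
    return [wi - scale * gi for wi, gi in zip(w, g)]

w = [3.0, 4.0]          # |w| = 5
g = [-0.016, 0.012]     # orthogonal to w, |g| = 0.02
lr_lars = 0.01

w_new = lars_step(w, g, lr_lars)
step = [a - b for a, b in zip(w, w_new)]
print(norm(step))       # ~0.05 = lr_lars * |w|, independent of |g|

# Rescale the weights by r; for scale-invariant weights g scales by 1/r.
r = 2.0
w_r = [r * x for x in w]
g_r = [x / r for x in g]
step_r = [a - b for a, b in zip(w_r, lars_step(w_r, g_r, lr_lars))]
print(norm(step_r) / norm(w_r))  # relative step size unchanged, ~0.01
```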

Scale invariance is enforced by rescaling gradients by a batch dependent factor of $\frac{|w|}{|g|}$. Now equation $(3)$ tells us that (for scale invariant weights) the dynamics of ordinary SGD with weight decay drives this batch dependent factor towards a fixed value anyway!

$\frac{|w|}{|g|} \approx \sqrt{\frac{\lambda}{2 \alpha}} \tag{5}$

If this ratio were really given by its equilibrium value at each batch then SGD and LARS dynamics would be identical with:

$\lambda_{LARS} := \lambda \frac{|g|}{|w|} = \sqrt{2 \lambda \alpha}\tag{6}$

In the plot below we compare the two quantities on the right hand side of this equation for a training run using SGD without momentum. The noisy batch dependent values of $\lambda \frac{|g|}{|w|}$ are plotted as circles whilst the black line shows $\sqrt{2\lambda\alpha}$:

We see that, up to noise and a minor deviation towards the end of training when equilibrium breaks down, SGD makes very similar updates to LARS with an effective learning rate of $\sqrt{2 \lambda \alpha}$. This is somewhat unexpected and raises the question of what LARS is doing in that case: perhaps just providing a scale invariant way to clip noisy gradients?
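The pull towards this equilibrium can be seen in a toy model of the norm dynamics alone. We assume, as a simplification of our own, that each gradient is exactly orthogonal to $w$ and that, by the scaling symmetry, $|g| = \frac{c}{|w|}$ for a constant $c$. Applying $(1)$ and $(2)$ together in one SGD step, $u = |w|^2$ then evolves exactly as $u \leftarrow (1-\lambda\alpha)^2 u + \frac{\lambda^2 c^2}{u}$. The sketch starts from a tiny and from a huge norm. Both settle at the same fixed point, where $\lambda\frac{|g|}{|w|}$ is close to $\sqrt{2\lambda\alpha}$. The constants are illustrative, not taken from any training run.

```python
import math

def norm_map(u, lam, alpha, c):
    """One SGD step for u = |w|^2 when g is orthogonal to w and |g| = c/|w|."""
    return (1 - lam * alpha) ** 2 * u + lam ** 2 * c ** 2 / u

lam, alpha, c = 0.2, 0.05, 1.0  # illustrative values only

finals = []
for u in (0.01, 100.0):          # start with tiny and with huge weights
    for _ in range(2000):
        u = norm_map(u, lam, alpha, c)
    finals.append(u)

u_eq = finals[0]
effective_lr = lam * c / u_eq    # lam * |g| / |w|, since |g| = c / |w|
print(finals)                    # both runs reach the same fixed point
print(effective_lr, math.sqrt(2 * lam * alpha))
```

The small gap between the two printed rates is the neglected $O(\lambda^2\alpha^2)$ term. At the fixed point the exact value is $\sqrt{2\lambda\alpha - \lambda^2\alpha^2}$.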

Where does this leave the linear scaling rule, which states that we should keep $\frac{\lambda}{N}$ approximately fixed when varying learning rates and batch sizes? The effective learning rate parameter in the scale invariant (LARS) form of the dynamics is $\sqrt{2\lambda \alpha}$.

If we rescale batch size by a factor $n$ and $\lambda$ by the same amount then the effective step size grows by a factor of $\sqrt{n}$. This no longer looks very linear!

In fact this is what we need in the case when batch gradients are dominated by noise – so that gradients for different timesteps are orthogonal to one another on average. In that case, $|g|$ scales like $\frac{1}{\sqrt{n}}$ (from averaging $n$ orthogonal contributions). The two factors of $\sqrt{n}$ combine so that one large batch update is once again the same as $n$ small batch updates at leading order.
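The bookkeeping in this argument can be checked with idealised gradients. We assume, for illustration, that the $n$ small-batch gradients are exactly mutually orthogonal with equal norms. Their mean, the large-batch gradient, then has a norm smaller by $\sqrt{n}$. One LARS step on the mean gradient, with the LARS learning rate scaled up by $\sqrt{n}$ (as $\sqrt{2\lambda\alpha}$ is when $\lambda \to n\lambda$), reproduces the sum of $n$ small-batch LARS steps. Here $w$ is held fixed across the small steps, i.e. at leading order.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def norm(a):
    return math.sqrt(dot(a, a))

def lars_update(w, g, lr_lars):
    # The change subtracted from w by equation (4): lr_lars * |w| / |g| * g
    s = lr_lars * norm(w) / norm(g)
    return [s * gi for gi in g]

n, dim = 4, 8
w = [1.0] * dim                      # |w| held fixed (leading order)
# n mutually orthogonal small-batch gradients, each of norm 0.5
small = [[0.5 if j == i else 0.0 for j in range(dim)] for i in range(n)]
big = [sum(g[j] for g in small) / n for j in range(dim)]  # large-batch mean

print(norm(small[0]) / norm(big))    # sqrt(n) = 2.0

lr = 0.01
updates = [lars_update(w, g, lr) for g in small]
many_small = [sum(u[j] for u in updates) for j in range(dim)]
one_big = lars_update(w, big, lr * math.sqrt(n))
print(max(abs(a - b) for a, b in zip(many_small, one_big)))  # ~0
```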

Note that the two $\sqrt{n}$ factors in the argument above came from very different sources. The first traces back to orthogonality between gradients and weights because of an exact symmetry of the model. The second is a consequence of mutual orthogonality between the gradients of subsequent batches because of SGD noise and needs to be empirically verified.

To that end, here is a plot of the correlation between gradients $g_t$ and an exponential moving average of previous gradients $v_{t-1}$ (as appears in the update step of SGD+momentum with $\rho=0.9$). Correlations are close to zero for most of training, confirming that orthogonality on average between nearby gradients is a reasonable approximation.
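For reference, the statistic in this plot can be computed along the following lines. This sketch assumes the moving average is the momentum buffer $v_t = \rho v_{t-1} + g_t$; its normalisation does not affect a cosine. The helper names are hypothetical. The gradients here are synthetic i.i.d. noise, which gives near-zero correlation by construction. The code shows how to measure the quantity, not what a real network produces.

```python
import math
import random

def cosine(a, b):
    num = sum(x * y for x, y in zip(a, b))
    den = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return num / den

def gradient_momentum_correlations(grads, rho=0.9):
    """Cosine between each gradient g_t and the momentum buffer v_{t-1}."""
    v = [0.0] * len(grads[0])
    corrs = []
    for t, g in enumerate(grads):
        if t > 0:                    # v_{t-1} is all zeros at t = 0
            corrs.append(cosine(g, v))
        v = [rho * vi + gi for vi, gi in zip(v, g)]
    return corrs

random.seed(0)
dim, steps = 500, 200
noise_grads = [[random.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(steps)]
corrs = gradient_momentum_correlations(noise_grads)
print(sum(corrs) / len(corrs))       # close to zero for pure noise
```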

So far we have only treated the case of SGD without momentum. Rather than delaying the reader with another lengthy exposition let’s summarise as briefly as possible. Once again, we will need to assume that gradients of different batches are orthogonal on average.

SGD+momentum applies the update corresponding to a single gradient $g_t$ over multiple timesteps (weighted by powers of $\rho$). If we could resum these updates and apply them at once, then the effect of summing the geometric series would be simply to replace the learning rate $\lambda$ with $\frac{\lambda}{1-\rho}$. We could then repeat the equilibrium argument from above with this substitution (remembering that weight decay terms are summed in the same way) and find that:

$\frac{|w|}{|g|} \approx \sqrt{\frac{\lambda}{2 \alpha (1-\rho)}}.\tag{7}$
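The geometric-series resummation can be checked directly. Feed a single gradient into an SGD+momentum buffer, follow it with zeros, and add up the updates it produces. The total approaches $\frac{\lambda}{1-\rho}$ times the gradient. This is a scalar sketch with illustrative values, and `total_displacement` is a hypothetical helper.

```python
def total_displacement(lam, rho, steps=500):
    """Sum of the SGD+momentum updates produced by one unit gradient at step 0."""
    v, total = 0.0, 0.0
    for t in range(steps):
        g = 1.0 if t == 0 else 0.0   # a single gradient, then nothing
        v = rho * v + g              # momentum buffer
        total += lam * v             # update applied at this step
    return total

lam, rho = 0.1, 0.9
print(total_displacement(lam, rho), lam / (1 - rho))  # both ~1.0
```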

The difficulty with this argument is as follows. Consider ordinary SGD+momentum dynamics (without this resummation). During the delay between computing a gradient $g$ and applying it in full, subsequent gradients from other batches are added to $w$ and might break the orthogonality between $g$ and $w$, so that equation $(2′)$ no longer holds. (The loss of orthogonality caused by adding $g$ itself is taken care of by the resummation argument.) Mutual orthogonality of gradients from different batches saves the day and allows the resummation argument to go through.

The LARS update in the presence of momentum is:

$v \leftarrow \rho v + \frac{|w|}{|g|} g$

$w \leftarrow w - \lambda_{LARS} v$
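A minimal sketch of these two lines in plain Python, with the $\frac{|w|}{|g|}$ rescaling applied before the gradient enters the momentum buffer. The function name is our own. Practical LARS implementations differ in details, for example in how weight decay enters the scaling factor. Starting from an empty buffer, the first step coincides with plain LARS $(4)$.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def norm(a):
    return dot(a, a) ** 0.5

def lars_momentum_step(w, v, g, lr_lars, rho):
    """One step of v <- rho*v + (|w|/|g|) g ; w <- w - lr_lars * v."""
    s = norm(w) / norm(g)
    v = [rho * vi + s * gi for vi, gi in zip(v, g)]
    w = [wi - lr_lars * vi for wi, vi in zip(w, v)]
    return w, v

w, v = [3.0, 4.0], [0.0, 0.0]
g = [-0.016, 0.012]                  # orthogonal to w, |g| = 0.02
w, v = lars_momentum_step(w, v, g, lr_lars=0.01, rho=0.9)
print(w, v)                          # ~[3.04, 3.97] and ~[-4.0, 3.0]
```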

We have just seen that, subject to the assumptions on scale invariance and orthogonality of gradients between batches, SGD+momentum dynamics drive $\frac{|w|}{|g|}$ to an equilibrium value given by $(7)$. We can then relate ordinary SGD+momentum dynamics to LARS with:

$\lambda_{LARS} := \lambda \frac{|g|}{|w|} \approx \sqrt{2 \lambda \alpha (1-\rho)}.$

Here is a plot of the two quantities appearing on the right hand side of this equation during a training run with $\rho=0.9:$

There’s a slightly larger discrepancy than before because we needed to make the approximation of orthogonal gradients between batches. This assumption was not used at this stage for pure SGD. Nonetheless, there is, once again, a close similarity between the steps taken by SGD+momentum and LARS with learning rate $\sqrt{ 2 \lambda \alpha (1-\rho) }$, with the main difference being some gradient clipping in the latter case.
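A toy simulation in the spirit of this comparison can be set up as follows. The assumptions are ours and are for illustration only. Each fresh gradient is a random direction orthogonal to the current weights, with $|g| = \frac{c}{|w|}$ as dictated by the scaling symmetry, and weight decay enters the momentum buffer alongside the gradient. Unlike the momentum-free case, older gradients in the buffer are no longer exactly orthogonal to the current weights. The match with $\sqrt{2\lambda\alpha(1-\rho)}$ is therefore only approximate. The helper `simulate` is hypothetical.

```python
import math
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def simulate(lam, alpha, rho, c=1.0, dim=100, steps=2000, seed=0):
    """SGD+momentum with weight decay on synthetic scale-invariant gradients.

    Returns the mean of lam * |g| / |w| over the second half of the run.
    """
    rng = random.Random(seed)
    w = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    scale = 1.0 / math.sqrt(dot(w, w))
    w = [wi * scale for wi in w]               # start from |w| = 1
    v = [0.0] * dim
    ratios = []
    for t in range(steps):
        w_sq = dot(w, w)
        # Random direction with its component along w projected out.
        z = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        proj = dot(z, w) / w_sq
        z = [zi - proj * wi for zi, wi in zip(z, w)]
        g_norm = c / math.sqrt(w_sq)           # |g| = c / |w|
        z_scale = g_norm / math.sqrt(dot(z, z))
        g = [zi * z_scale for zi in z]
        # Momentum buffer accumulates gradient plus weight decay.
        v = [rho * vi + gi + alpha * wi for vi, gi, wi in zip(v, g, w)]
        w = [wi - lam * vi for wi, vi in zip(w, v)]
        if t >= steps // 2:
            ratios.append(lam * g_norm / math.sqrt(w_sq))
    return sum(ratios) / len(ratios)

lam, alpha, rho = 0.05, 0.02, 0.9              # illustrative values only
measured = simulate(lam, alpha, rho)
predicted = math.sqrt(2 * lam * alpha * (1 - rho))
print(measured, predicted)
```

Setting `rho=0.0` recovers plain SGD. In that case the measured ratio should instead approach $\sqrt{2\lambda\alpha}$.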

We applied the ‘resummation argument’ above to combine together all the weight updates corresponding to a single gradient, relying on orthogonality of gradients between batches to argue that this didn’t affect the weight norm dynamics. Last time we applied a similar argument, relying instead on low curvature of the loss, to say that training dynamics depend primarily on $\frac{ \lambda }{ 1-\rho}$ rather than either parameter individually. If we combine these two arguments, we learn that the effective contribution of a single gradient $g$ in scale invariant form is approximately:

$w \leftarrow w - \sqrt{\frac{2 \lambda \alpha}{1-\rho}} \frac{|w|}{|g|} g$
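The algebra behind this step is short. The resummed SGD+momentum contribution of $g$ is $\frac{\lambda}{1-\rho} g$. Substituting the equilibrium ratio $(7)$ gives:

$$\frac{\lambda}{1-\rho}\, g = \frac{\lambda}{1-\rho}\,\frac{|g|}{|w|}\cdot\frac{|w|}{|g|}\, g \approx \frac{\lambda}{1-\rho}\sqrt{\frac{2\alpha(1-\rho)}{\lambda}}\;\frac{|w|}{|g|}\, g = \sqrt{\frac{2\lambda\alpha}{1-\rho}}\;\frac{|w|}{|g|}\, g$$

In scale invariant units, therefore, the step depends on $\lambda$, $\alpha$ and $\rho$ only through the combination $\frac{\lambda\alpha}{1-\rho}$.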

This confirms our result from last time, now with a proper treatment of the weight norm dynamics.

So what have we learnt after all of this hard work? One lesson is that one should be wary of plausible sounding theory! If we have a model with a symmetry and an optimiser which is not invariant, then any treatment which ignores the dynamics along orbits of the group action (of which one can find several examples in the literature!) should be treated with great caution.

Perhaps more interestingly, the connection to LARS opens up a number of research questions and the possibility of learning something useful! LARS is employed in most of the recent successful attempts to scale ImageNet training to very large batch sizes. Given the close relation between LARS and ordinary SGD with weight decay, it would be very interesting to understand what leads to the superior performance of LARS.

Is it just that large gradients are being clipped? Or that, at the crucial stage early in training, SGD weight norms have not yet reached equilibrium? Either way, approaches suggest themselves to bring the two techniques closer together and isolate the important differences. If the early part of training is problematic in SGD, one could adjust weight initialisation scales to remove the initial out-of-equilibrium stage. If gradient noise is the issue, then employing lower learning rates and/or higher momenta for selected layers might improve things, as might simply smoothing gradient norms between batches.

One unresolved issue from initial investigations along these lines is that we have not yet found a setting where LARS training leads to a clear improvement for the current network and the CIFAR10 dataset. The results from ImageNet suggest that there should be something to find here.

Another intriguing possibility coming out of the connection to LARS is that it could explain a remarkable stability of the optimal learning rate between architectures. Although we didn’t comment on this at the time, we relied on this property during our architecture search in an earlier post. There we compared the training accuracy of different architectures under a fixed learning rate schedule. One might question whether such an approach would select the best architecture, or rather the architecture whose optimal learning rate is closest to the value used in the experiment.

In fact, if we plot test accuracies over a range of different learning rates for architectures from the fourth post, we see that selecting according to performance at $\lambda=0.4$ is perfectly reasonable and that all the different architectures share a similar optimal learning rate.

Note, the results in the plot above are an average of three runs and have been smoothed to highlight trends. The raw results look like this and illustrate why we are so keen to avoid hyperparameter tuning!

The stability in optimal learning rates across architectures is somewhat surprising given that we are using ordinary SGD and the same learning rate for all layers. One might think that layers with different sizes, initialisations and positions in the network might require different learning rates and that the optimal learning rate might then vary significantly between architectures because of their different layer compositions.

A possible explanation is that the LARS-like dynamics of SGD with weight decay provide a useful type of adaptive scaling for the different layers, so that each receives the same step size in scale invariant units, rendering manual tuning of learning rates per layer unnecessary. One might probe this experimentally by manually optimising learning rates separately for each layer and seeing to what extent the optimal values coincide.

In part 7 we investigate another role of batch normalisation in protecting against covariate shift during training.