A major topic of this post series is the investigation of methods to reduce the number of required training epochs for ResNets, in particular with respect to image analysis. Our test case is defined by a ResNet56v2 neural network trained on the CIFAR10 dataset. For intermediate results of numerical experiments
- with a piecewise constant and a piecewise linear schedule for the learning rate LR
- and the application of Adam and AdamW in combination with L2-regularization
see the first two posts
- AdamW for a ResNet56v2 – I – a detailed look at results based on the Adam optimizer
- AdamW for a ResNet56v2 – II – linear LR-schedules, Adam, L2-regularization, weight decay and a reduction of training epochs
During the last week I performed a series of experiments with the optimizer AdamW. While I could bring the number of training epochs down below 30, I stumbled across an important topic rooted in a major difference between Adam and AdamW regarding two methods of weight regularization: “L2-regularization vs. (pure) weight decay“.
Removing L2-regularization completely and raising the strength of weight decay for AdamW had major consequences: Huge oscillations of the validation loss were triggered during certain training phases – but the approach nevertheless improved validation accuracy! Interesting, isn’t it?
This post will provide information to better understand the differences between the two regulatory methods and their relation to optimizers. As I will use formulas, it is a theoretical excursion. But we will need at least a basic understanding of pure weight decay for the discussion of the results of further numerical experiments.
L2 regularization …
The models trained so far applied explicit L2-regularization to the weights of the layers. Technically, this was done by providing a respective “l2”-parameter to the Keras functions for layer creation. (In the case of AdamW this L2-regularization came in addition to a very tiny amount of explicit pure weight decay; see the discussion below.) What is my understanding of L2-regularization?
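Just as a reminder of how such a layer definition looks in code – a minimal sketch assuming a Keras/TensorFlow setup; the value 1.e-4 is only a placeholder and not the value used in my experiments:

from tensorflow.keras import layers, regularizers

inputs = layers.Input(shape=(32, 32, 3))     # CIFAR10-like input
# Conv2D layer with an explicit L2-penalty (the "l2"-parameter) on its kernel weights
x = layers.Conv2D(16, kernel_size=3, padding="same",
                  kernel_regularizer=regularizers.l2(1.e-4))(inputs)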
Text books on ML introduce L2-regularization as one of the methods to avoid overfitting. An overfitting model adapts too strongly to specific properties of sub-sets or even individual samples of the training data (high variance in its predictions). As a result, overfitting leads to a big gap between the accuracy achieved on the training data and the accuracy achieved for yet unseen and varying sets of test samples.
It has been observed that overfitting comes with rather specific weight values. The reason is that a network uses all of its degrees of freedom to optimally fit the given training data via adjusting individual weights. This might lead to exceptionally big individual weight values for the mapping and distinction of “detected”, but highly specific correlations or properties of a few special samples.
What we want instead is a good generalization capability of our model with respect to yet unknown data sets. Somehow we have to reduce the degrees of freedom of a model and thus its complexity. Possible measures are e.g.: limiting the value range of trainable model parameters (regularization) or limiting the number of result-defining parameters (e.g. by a statistical dropout of neurons).
A typical approach to overcome overfitting is to add a constraining regulatory term R to the loss function L(yT,i, yi(xi), wi). Just for a brief repetition of some basics:
The results yi calculated by the model define the prediction error for training samples with given target label values yT,i . The loss function L() itself punishes the error for a given state of the model’s weights wi achieved during optimization iterations. Optimization is typically guided by a gradient descent method to a (hopefully) global minimum of the loss function which is defined over the multidimensional space spanned by coordinate axes for the weight (and other) parameter values.
Let us index an iteration step by “t“. Let us further define the calculated y-values as results of a function fNet that maps input to output vectors, and let us represent all weights wj,t by a vector wt. Then we can write (see e.g. [2])
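In a sketched notation (leaving the exact functional form of the loss open) this reads something like
\[ y_i \,=\, f_{Net}(x_i,\, w_t)\,, \qquad L_t \,=\, L\big(y_{T,i},\, f_{Net}(x_i,\, w_t),\, w_t\big)\,. \]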
To restrict the degrees of freedom and to filter the correlations for the general and really important ones, one approach is to restrict a measure of the total magnitude of all weight values, i.e. to restrict a defined norm of wt. This is e.g. achieved by adding an additional term proportional to the squared Euclidean l2-norm of all the network’s weights (expressed as a vector) to the loss function
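i.e. sketched as
\[ L_{L2}(w_t) \,=\, L\big(y_{T,i},\, f_{Net}(x_i,\, w_t),\, w_t\big) \,+\, \frac{\lambda_{L2}}{2}\, \lVert w_t \rVert_2^2\,. \]
(The factor 1/2 is a common convention that simplifies the derivative; the original formula may or may not contain it.)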
The parameter λL2 controls the strength of the regularization. During error back propagation we try to find a (global) minimum of the convexly shaped loss LL2(). We assume that such a minimum of LL2() would eventually also reflect small weight values. I.e., while approaching a (hopefully) global minimum of LL2() we hope for an effective reduction of the absolute values of the weights.
Another interpretation of L2-regularization is that we add some “noise” to the loss. The idea is that we thus force the network to detect the really important patterns and correlations in the data above the induced basic noise level.
In the L2-case the noise term εnoise can be associated with the regulatory term based on the l2-norm given above. Note that any grain of noise would also help to overcome gradient problems, e.g. at saddle points or local minima of the multidimensional loss hyper-surface (over the weights’ parameter space).
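In sketched form this association reads
\[ L_{L2} \,=\, L \,+\, \varepsilon_{noise}\,, \qquad \varepsilon_{noise} \,=\, \frac{\lambda_{L2}}{2}\, \lVert w_t \rVert_2^2\,. \]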
How small the L2-based “noise” term becomes in the end depends of course on the efficiency of the weight value reduction during training. This in turn depends on the impact of the quadratic term on the form and the variation of the loss hyper-surface – in particular around its (global) minimum. Furthermore: “Noise” is relative. The smaller the standard loss value in its undisturbed minimum, the more even small contributions of the final L2 regularization term will modify the surroundings of this minimum.
Anyway: The reduction of the weight values should occur such that the eventual L2-contribution does not smear out an originally sharp functional minimum of the unmodified loss too much. It should not turn a sharp minimum into a relatively flat trough area with only a small dip in the middle. However, my own experience with optimizers other than SGD and various image data sets can be summarized as follows:
- L2-regularization raises the basic level of the minimum. I.e. the additional loss term contributes substantially to the overall LL2() loss value at the minimum reached during the last phase of a training run.
- L2-regularization smears out and flattens the environment of the minimum on the LL2() loss hyper-surface in comparison to the respective minimum on the undisturbed L() hyper-surface.
In combination these two effects of L2-regularization may lead to a noticeable reduction in test accuracy – in comparison to runs employing other methods. We have to deal with a delicate balance between generalization and side effects such as the smearing-out of minima:
On the one hand we need regularization to improve the generalization capability of a model and thus also the accuracy on unseen new test data. On the other hand, side effects of the regularization (noise) may become counterproductive and even reduce the achievable eventual accuracy. Thus we end up with experiments and systematic parameter variations to find a model- and dataset-dependent optimal value of λL2.
Pure Weight Decay and SGD
Let us approach the whole topic of weight regularization from the perspective of a weight update step, including a schedule of the learning rate LR. Without L2-regularization and without momentum based optimization the weights are typically updated during an iteration (indexed by “t“) of gradient descent as follows:
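A plausible form of this standard update is
\[ w_{t+1} \,=\, w_t \,-\, \eta_t\, \alpha\, \nabla_w L(w_t)\,. \]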
“α” is the initial learning rate LR. “ηt” (ηt < 1) is a factor coming from a schedule for LR reduction. “t” is an index for the iteration steps during training.
In the extreme case of SGD (Stochastic Gradient Descent) the gradient is evaluated per sample, in other cases it is evaluated as an average over a mini-batch of samples (indexed by i). Below we ignore the respective subtleties of averaging in the mathematical formulas. If you want to look at them rigorously assume SGD and regard x and yT as vectors representing just one training sample (batch size BS=1).
Now, if we wanted to reduce the weights systematically whilst iterating, what would we do? An efficient solution is to add a reduction term proportional to the weight value itself:
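For example, something like the following (in the sketch below the decay term carries only the schedule factor ηt; whether it is additionally multiplied by α is a matter of convention):
\[ w_{t+1} \,=\, w_t \,-\, \eta_t\, \alpha\, \nabla_w L(w_t) \,-\, \eta_t\, \lambda_{wd}\, w_t\,. \]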
λwd is a parameter controlling the strength of this type of regularization. Regrouping of terms leads to
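Sketched in the same convention:
\[ w_{t+1} \,=\, \big(1 \,-\, \eta_t\, \lambda_{wd}\big)\, w_t \,-\, \eta_t\, \alpha\, \nabla_w L(w_t)\,. \]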
The factor applied to wt explains why this approach is called “weight decay“. We systematically diminish the length of the weight vector. This straightforward approach has two advantages:
- The proportionality to the weight itself ensures an efficient and strong reduction per iteration step. The reduction automatically adapts to the present size of the weight vector.
- The weight reduction is directly coupled to the LR-schedule. Thereby it is guaranteed that the weight decay impact is systematically reduced according to the schedule. In an eventual training phase it may therefore become smaller than the changes prescribed by the gradient into an undisturbed functional minimum of the loss.
By this approach we basically do not smear out an eventual deep minimum – provided that the weight decay has operated efficiently enough before we approach the bottom of the minimum’s environment.
To distinguish this approach from L2-regularization I call it “pure weight decay“. One of the problems in some discussions is that the expression “weight decay” was also used for L2-regularization in some literature. The reason becomes clear in the next section.
Equivalence of L2 regularization and weight decay for SGD
The reason why the term “weight decay” sometimes is also used to describe L2-regularization is that for SGD (Stochastic Gradient Descent) both methods are equivalent. We understand this by differentiating the LL2–loss to get the gradient for the weight update in SGD with a constant learning rate α (< 1):
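In the notation above this gives something like
\[ \nabla_w L_{L2}(w_t) \,=\, \nabla_w L(w_t) \,+\, \lambda_{L2}\, w_t \quad\Rightarrow\quad w_{t+1} \,=\, w_t \,-\, \alpha\, \nabla_w L(w_t) \,-\, \alpha\, \lambda_{L2}\, w_t\,. \]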
I.e. by adding the L2-norm-based regulatory term to the loss we punish large weights in the same way as pure “weight decay” does. Adding a schedule gives us:
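Presumably something like
\[ w_{t+1} \,=\, \big(1 \,-\, \eta_t\, \alpha\, \lambda_{L2}\big)\, w_t \,-\, \eta_t\, \alpha\, \nabla_w L(w_t)\,. \]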
By comparing this formula with (7) we find
We see that for SGD, the weight decay parameter λwd is a scaled version of the L2-regularization parameter λL2:
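In the convention sketched above (where the pure decay term carries ηt, but not α) this relation is presumably
\[ \lambda_{wd} \,=\, \alpha\, \lambda_{L2}\,. \]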
Otherwise the methods are fully equivalent. We can even extend the formalism to a simple type of momentum driven optimization. Let us define a very simple momentum approach via a momentum evolution of the form
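for instance, with a single momentum factor β:
\[ m_{t+1} \,=\, \beta\, m_t \,+\, \nabla_w L_{L2}(w_t) \,=\, \beta\, m_t \,+\, \nabla_w L(w_t) \,+\, \lambda_{L2}\, w_t\,. \]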
Inserting this momentum into the weight update, we get something like
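\[ w_{t+1} \,=\, w_t \,-\, \eta_t\, \alpha\, m_{t+1} \,=\, w_t \,-\, \eta_t\, \alpha\, \beta\, m_t \,-\, \eta_t\, \alpha\, \nabla_w L(w_t) \,-\, \eta_t\, \alpha\, \lambda_{L2}\, w_t\,. \]
(Here a plain update wt+1 = wt − ηt * α * mt+1 is assumed.)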
The L2-induced weight decay term remains separated.
Note: The scaling of a weight decay parameter as expressed in (10) should make us careful when we provide such parameters to numerical functions of Scikit Learn, Keras or PyTorch. We should find out whether the parameter is scaled with the learning rate by the function itself or not.
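A tiny numerical sketch (my own illustrative code, with arbitrary parameter values) demonstrates this equivalence for a single SGD step under the convention used above:

import numpy as np

def sgd_l2_step(w, grad, alpha, eta, lam_l2):
    # SGD step with L2-regularization folded into the gradient
    return w - eta * alpha * (grad + lam_l2 * w)

def sgd_wd_step(w, grad, alpha, eta, lam_wd):
    # SGD step with decoupled ("pure") weight decay
    return w - eta * alpha * grad - eta * lam_wd * w

w    = np.array([0.5, -1.2, 2.0])
grad = np.array([0.1, -0.3, 0.2])   # gradient of the plain loss L()
alpha, eta = 0.01, 0.8              # base LR and schedule factor
lam_l2 = 5.e-4
lam_wd = alpha * lam_l2             # the scaling relation discussed above

print(np.allclose(sgd_l2_step(w, grad, alpha, eta, lam_l2),
                  sgd_wd_step(w, grad, alpha, eta, lam_wd)))   # -> True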
Weight optimization with Adam: L2 regularization and weight decay are not equivalent
Modern optimizers track and normalize the evolution of a gradient. A very important, but complex momentum driven optimizer is, of course, Adam. In contrast to the simple form of momentum driven descent described above it uses normalized contributions of the so called 1st momentum and the 2nd momentum of the gradient evolution. These momentum terms are exponentially weighted moving averages (ema) of the gradient itself and of its square, respectively (see e.g. [3], [2] and [1]).
Therefore, the inclusion of regulatory terms in the loss is much more difficult to describe than for SGD. I follow a description outlined in [1]. Let us rewrite the sum of the gradient and an L2-induced regulatory term as
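something like (in the notation used so far)
\[ g_t \,=\, \nabla_w L(w_t) \,+\, \lambda_{L2}\, w_t\,. \]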
The momentum handling of Adam (with factors β1 for the 1st momentum and β2 for the second momentum) uses the following terms for the iterative momentum updates
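In standard Adam notation these terms should read
\[ m_t \,=\, \beta_1\, m_{t-1} \,+\, (1-\beta_1)\, g_t\,, \qquad v_t \,=\, \beta_2\, v_{t-1} \,+\, (1-\beta_2)\, g_t^{\,2}\,, \]
with the bias-corrected (“modified”) versions
\[ \hat{m}_t \,=\, \frac{m_t}{1-\beta_1^{\,t}}\,, \qquad \hat{v}_t \,=\, \frac{v_t}{1-\beta_2^{\,t}}\,. \]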
Without any further corrections Adam updates the weights like this:
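Presumably in the standard Adam form
\[ w_{t+1} \,=\, w_t \,-\, \eta_t\, \alpha\, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} \,+\, \varepsilon_{tiny}}\,. \]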
εtiny is a small parameter which avoids numerical errors (a division by zero). For large t-values the bias-correction denominators of the modified momentum terms go to 1 – and we get
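In the notation above this limit should read roughly
\[ w_{t+1} \,=\, w_t \,-\, \eta_t\, \alpha\, \frac{m_t}{\sqrt{v_t} \,+\, \varepsilon_{tiny}} \,=\, w_t \,-\, \eta_t\, \alpha\, \frac{\beta_1\, m_{t-1} \,+\, (1-\beta_1)\,\big(\nabla_w L(w_t) \,+\, \lambda_{L2}\, w_t\big)}{\sqrt{v_t} \,+\, \varepsilon_{tiny}}\,. \]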
This means that the L2 part λL2 * wt gets divided by the denominator sqrt(vt) + εtiny, too. This is obviously not consistent with the pure weight decay formalism described above, which should lead to a plain, separate decay term proportional to λwd * wt in the expression for the weight update. We cannot separate such a term out of the fraction expression in (15).
This division may have the consequence that the weights are not reduced as effectively as in the case of the SGD optimizer – in particular not in phases where the gradient is still big. This was explicitly criticized in [1].
The step forward which Loshchilov and Hutter achieved and tested experimentally in [1] was that they just corrected the momentum based optimizer by adding an explicit pure “weight decay” term:
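A sketch of such a decoupled update is
\[ w_{t+1} \,=\, w_t \,-\, \eta_t\, \Big( \alpha\, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} \,+\, \varepsilon_{tiny}} \,+\, \alpha\, \lambda_{wd}\, w_t \Big)\,. \]
Whether the decay term is scaled by α in addition to ηt differs between formulations; the algorithm in [1] multiplies the decay term by the schedule factor ηt only.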
Thereby they strongly indicated a logical difference in the “meaning” of L2-regularization and weight decay. The term at the end on the right side of (16) enforces a strong reduction of the weights – proportional to their size. Under perfect conditions the weights could get close to 0 in the end. A suitable LR schedule will in addition guarantee that the shape of the surroundings of the loss minimum is not modified by side effects. In the end, the minimum will be approached guided by the loss gradient and momentum only. Note the scaling of the factor λwd by α !
But is this approach consistent?
I would say, it depends on the final coding of AdamW.
The big question is what the first term in the update formula (16) really includes. If it only includes the gradient of the plain loss function L() without any regularizing term, i.e. gt = ∇w L(wt), then the adaptive momentum ignores the initially large contribution corresponding to the decay term λwd * wt. This term in turn would reflect some kind of gradient evolution due to a (functionally unclear) loss contribution of the form (λwd/2) * ||wt||² – if we included it somehow in the loss.
So, this kind of approach would be a bit confusing, to say the least: The momentum would ignore the impact of any such decaying loss contribution and of its evolving gradient. This might become problematic in special situations …
But if it is indeed done this way, the big “advantage” would be that we would – at last, after a strong decay of the learning rate – reach a minimum of the original undisturbed loss function L(). This might also improve the eventual accuracy.
An indication of this point in numerical experiments would be that the eventual loss achieved with Adam and some L2-regularization should be significantly bigger than the one we get with AdamW. We will clarify this question by the experiments in the next post.
The description in [1] of how the authors actually realized the AdamW optimizer for their experiments is a bit unclear in my opinion. And analyzing the code of the Keras implementation would be time consuming. For the time being we have to wait for the results of our own numerical experiments, which will come in the next post.
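For orientation, a minimal sketch of how such a setup could look with the Keras AdamW class – build_resnet56v2() is a hypothetical stand-in for the model construction used in this series, and the parameter values are placeholders:

from tensorflow.keras import optimizers

model = build_resnet56v2()   # hypothetical helper returning the ResNet56v2 Keras model

# AdamW with decoupled ("pure") weight decay; no l2-regularizers in the layers
opt = optimizers.AdamW(
    learning_rate=1.e-3,     # initial LR alpha; a LR schedule object could be passed instead
    weight_decay=5.e-4       # weight decay strength lambda_wd
)
model.compile(optimizer=opt, loss="categorical_crossentropy", metrics=["accuracy"])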
But it appears to be fair to assume that by pure weight decay and by applying a reasonable LR schedule, which in the end reduces the learning rate substantially, we would achieve three things:
- We would at the end of training approach the undisturbed real minimum of the loss without regularization.
- All side effects of ignoring a weight decay based loss contribution during momentum evolution would probably be damped out by a LR schedule with a strong reduction of the step size ahead of a final phase of the training.
- All side effects of the enforced weight decay during gradient descent – like a continuous drift towards a region of the loss hyper-surface at low weight values – would be damped out by a LR schedule with a strong reduction of the step size ahead of an eventual training phase.
Noise in the validation loss due to a continuous “side drift” to lower weight values?
Let us look at weight decay from the perspective of the model’s movement across the loss hyper-surface during weight optimization. This requires some imagination.
A well known source of noise during training is the overshooting effect of adaptive, momentum driven optimizers – in particular in bumpy regions of the loss hyper-surface with steep walls on the sides of narrow valleys. Such noise will in particular become visible in the validation loss. Why is this?
We should always be aware of the fact that the validation loss hyper-surface is not identical to the loss hyper-surface – although major structures will be very similar. If the structure of the validation loss surface deviates a bit from that of the training loss surface, then small differences in the length of the momentum vector and in the direction to which the optimized “momentum” vector points may make a huge difference for the calculated values of the validation or test loss. Such effects may e.g. occur when the direction of a steep valley in the loss hyper-surface changes and the momentum still follows the original direction.
With pure weight decay we also may run into a similar situation, but for a different reason. Weight decay obviously tries to drive us across the loss hyper-surface (over the multidimensional weight space) towards a region with lower weight values. The direction of this region may deviate from the direction to where the loss gradient points to. Let us call this tendency to move sideways from the direction given by the loss gradient (or its momentum) the “weight decay drift“.
Weight decay drift: A continuous overall tendency of a model – caused by pure weight decay – to move towards lower weight value regions in all w-dimensions – as far as the shape of the loss hyper-surface allows for it.
The basic idea behind “weight value regularization improves generalization” is that this drift at saddle or bifurcation points will in general lead us to a “better” overall loss minimum for validation and (yet unknown) test data.
As long as the side drift happens in a region with a dominant overall gradient of the loss, this gradient will keep us on an overall downward track. For big weight decay values λwd we may, however, sooner or later reach areas where the incremental shift towards lower values becomes significant in comparison to the movement into the direction of the loss gradient. This occurs in particular if and when the gradient flattens out in the main movement direction. In such a situation local differences between the validation loss hyper-surface and the loss hyper-surface may amplify secondary effects of side-way movements.
Imagine we move along a river bed in a relatively flat valley in flow direction, but the valley has steep walls at some limited and relatively small distance perpendicular to the river bed. Now imagine a situation where the main direction of the river bed changes and the walls “move” a bit closer to the river bed. Any sideway movement may then lead us up the walls.
In terms of the loss: While we still may have relatively small changes in the training loss, a slightly narrower valley or a closer wall in the validation surface – in drift direction – may drive the validation loss to much bigger values than the loss itself. Such a situation would of course depend on specific properties of the data set. But in principle it could lead to significant fluctuations in the validation loss.
We expect such situations to occur in the vicinity of systematic and abrupt direction changes of the gradient inside deep valleys, or at saddle or bifurcation points in a rough environment of the loss hyper-surface.
Such fluctuations in the validation loss (if they occur) would, however, be strongly damped as soon as the sideway drift induced by weight decay is significantly reduced by the LR schedule.
Therefore, the chance of hitting a local obstacle on the validation loss hyper-surface will be largest when the LR-schedule has not yet diminished the weight decay drift too much in relation to the gradient. I.e. after an initial phase where we typically follow large gradients down (on relatively broad walls) until we get into regions of major and pronounced variations of the loss surface, and before a final phase where ηt * α gets really small (i.e. orders of magnitude smaller than the initial α).
Summary: For AdamW we expect a relatively strong drift of the model towards lower weights. This drift may deviate from the direction of the loss gradient and in some regions of the loss hyper-surface the drift based movement may become comparable to the gradient induced movement per update step. If this happens in narrow or narrowing valleys disturbances in the evolution of the validation loss are possible, even if the effect of the drift (induced by weight decay) on the training loss remains relatively small. Note that such fluctuations of the validation loss would not prevent a convergence towards a (global) minimum of the loss hyper-surface if the LR schedule shrinks the learning rate and the weight decay significantly in a later phase.
Conclusion and outlook
In this post we have seen that the momentum handling of the Adam optimizer (in contrast to SGD) violates the equivalence of “pure weight decay” and standard “L2-regularization”. We have understood that adding an explicit “weight decay” term to the weight update could recover the advantages of pure weight decay. This would include an effective reduction of the weight values – as long as this is compatible with the general path of gradient descent over the loss hyper-surface during training. Due to its regularizing effect a drift induced by the weight decay towards regions of lower weight values may find paths to a loss minimum which marks a well generalizing weight distribution after training.
A possible side effect of the drift towards low weight values could be significant fluctuations in the validation loss – even when the fluctuations in the training loss remain small. We should look out for such effects in test runs with AdamW performed without any explicit L2-regularization and with relatively large values of the weight decay parameter λwd.
A LR schedule reducing the weight decay drift significantly in a final phase will allow the optimizer to follow the eventual gradient into the full depth of the found loss minimum. We may thus reach lower minimum values than in the case of L2 regularization where side effects of the additional noisy L2 loss typically would both broaden the area around a minimum and lift its bottom to higher values. Therefore, pure weight decay may result in an improved validation accuracy in comparison to L2-regularization.
While we are not yet completely sure that the momentum in AdamW is based on the evolution of the standard loss gradient, only, we are convinced that we would recognize such an implementation from very low loss values and a lift of the minimum value to higher values as soon as we add a standard L2-regularization term to the loss.
Respective numerical experiments will be the topic of the next posts.
We will also see that with the help of AdamW we can reduce the number of required training epochs for CIFAR10 to about 30 – and still get good validation accuracy values.
Links and literature
[1] I. Loshchilov, F. Hutter, 2018, “Fixing Weight Decay Regularization in Adam“, arXiv
[2] I. Drori, 2023, “The Science Of Deep Learning”, Cambridge University Press
[3] A. Geron, 2023, “Hands-On Machine Learning with Scikit-Learn, Keras & TensorFlow”, O’Reilly, Sebastopol, CA