Gradient Descent

Gradient Descent Algorithm

Given a loss function $\ell(\theta)$ that we want to minimize, the gradient descent (GD) algorithm works as follows:

  1. Initialize $\theta$ with $\theta^0$.
  2. At the current iterate $\theta^k$, calculate the gradient $\nabla\ell(\theta^k)$.
  3. Take the gradient step with step size $\eta$: $\theta^{k+1} \leftarrow \theta^k - \eta\,\nabla\ell(\theta^k)$.
  4. Loop until convergence.
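
As a minimal sketch (not part of the original notes), the loop can be written in a few lines of Python; `grad_loss` stands in for any callable returning $\nabla\ell(\theta)$, and the stopping criterion is an assumed gradient-norm threshold:

```python
import numpy as np

def gradient_descent(grad_loss, theta0, eta=0.1, tol=1e-8, max_iter=10_000):
    """Minimal GD sketch: iterate theta <- theta - eta * grad until the
    gradient norm falls below tol (a simple convergence criterion)."""
    theta = np.asarray(theta0, dtype=float)
    for _ in range(max_iter):
        g = grad_loss(theta)
        if np.linalg.norm(g) < tol:
            break
        theta = theta - eta * g
    return theta
```

For instance, `gradient_descent(lambda th: 2 * th, np.array([3.0]))` minimizes $\ell(\theta) = \theta^2$.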

Convex Quadratics

Let us first consider the case of quadratic problems, with an objective of the type

$$\ell(\theta) = \frac{1}{2}\theta^T Q\,\theta - q^T\theta,$$

where $Q$ is positive definite.

Diagonalizing $Q = U \Lambda U^T$ with orthogonal $U$, and substituting $\theta \leftarrow U^T\theta$ and $q \leftarrow U^T q$, the objective becomes

$$\ell(\theta) = \sum_{i=1}^{d} \ell_i(\theta_i), \qquad \ell_i(\vartheta) = \frac{\lambda_i}{2}\vartheta^2 - q_i\vartheta,$$

where $\lambda_i > 0$.

Taking derivatives $\ell_i'(\vartheta) = \lambda_i\vartheta - q_i$, we obtain the optima

$$\theta_i^* = \frac{q_i}{\lambda_i}, \qquad \ell_i^* = \ell_i(\theta_i^*) = -\frac{q_i^2}{2\lambda_i}.$$

Behavior of GD on Convex Quadratics

In order to further simplify the analysis, we shift $\ell_i$ by a constant such that $\min_\vartheta \ell_i(\vartheta) = 0$:

$$\ell_i(\vartheta) = \frac{1}{2\lambda_i}\left(\lambda_i\vartheta - q_i\right)^2.$$

Taking a gradient step:

$$\ell_i\big(\vartheta - \eta(\lambda_i\vartheta - q_i)\big) = (1 - \lambda_i\eta)^2\,\ell_i(\vartheta).$$

For $\eta > 0$, the condition $(1 - \lambda_i\eta)^2 < 1$ holds exactly when $\eta < \frac{2}{\lambda_i}$; requiring this for all $i$ gives

$$\eta < \frac{2}{\lambda_{\max}}, \qquad \lambda_{\max} = \max_i\{\lambda_i\}.$$

With an appropriately chosen step size $\eta < \frac{2}{\lambda_{\max}}$, gradient descent converges exponentially fast to the minimum of a convex quadratic.
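
To make this concrete, here is a small numerical check (illustrative only; the eigenvalues, offsets, and step sizes below are arbitrary choices): per coordinate, the shifted loss contracts by exactly $(1 - \eta\lambda_i)^2$ per step, and GD diverges once $\eta \geq 2/\lambda_{\max}$:

```python
import numpy as np

lam = np.array([1.0, 10.0])    # eigenvalues of a diagonal Q
q = np.array([2.0, 5.0])
loss = lambda th: 0.5 * (lam * th - q) ** 2 / lam   # shifted per-coordinate loss

theta0 = np.array([5.0, -3.0])
for eta in [0.15, 0.21]:       # 0.15 < 2/lam_max = 0.2, while 0.21 > 0.2
    th = theta0.copy()
    for _ in range(100):
        th = th - eta * (lam * th - q)   # gradient step
    print(eta, loss(th))       # converges for 0.15; second coordinate blows up for 0.21
```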

Optimal Convergence Rate

Considering all $i$, to optimize the rate of convergence we need to minimize the largest $(1 - \eta\lambda_i)^2$ term, which is where convergence is slowest:

$$\eta^* = \arg\min_\eta \max_i\,(1 - \eta\lambda_i)^2 = \arg\min_\eta \max\{\eta\lambda_{\max} - 1,\; 1 - \eta\lambda_{\min}\},$$

which is attained at $\eta\lambda_{\max} - 1 = 1 - \eta\lambda_{\min}$, i.e. when

$$\eta^* = \frac{2}{\lambda_{\max} + \lambda_{\min}} < \frac{2}{\lambda_{\max}}.$$

This results in the worst-case (slowest) rate of convergence

$$\rho = (1 - \lambda_{\max}\eta^*)^2 = \left(\frac{\lambda_{\max} - \lambda_{\min}}{\lambda_{\max} + \lambda_{\min}}\right)^2 = \left(\frac{\kappa - 1}{\kappa + 1}\right)^2,$$

where $\kappa = \frac{\lambda_{\max}}{\lambda_{\min}}$ is the condition number of $Q$.
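
As an illustrative calculation (not from the notes): with $\kappa = 100$, $\rho = (99/101)^2 \approx 0.96$, so each step removes only about $4\%$ of the remaining loss and roughly $115$ iterations are needed to reduce it by a factor of $100$ (since $0.96^{115} \approx 0.01$); with $\kappa = 2$, $\rho = (1/3)^2 \approx 0.11$ and a handful of steps suffice.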

With a bad condition number, convergence will be slow in one direction but fast in another, causing the optimization to oscillate.

Smoothness

Gradient descent can only work if the gradient does not change too much relative to the step size.

Smooth Functions

$\ell : \mathbb{R}^d \to \mathbb{R}$ is $L$-smooth for some $L > 0$ if

$$\|\nabla\ell(\theta) - \nabla\ell(\theta')\| \leq L\,\|\theta - \theta'\|, \qquad \forall\,\theta, \theta'.$$

Namely, the difference of the gradients at two points in parameter space is bounded by $L$ times the distance between the points; in other words, $\nabla\ell$ is a Lipschitz continuous function.

Lipschitz Continuous Functions

A function $f$ is Lipschitz continuous with constant $L$ if for every two points $x$ and $x'$, the slope between them is bounded:

$$\left|\frac{f(x) - f(x')}{x - x'}\right| \leq L.$$
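
As a sanity check (an ad hoc sketch; the quadratic and the random sampling are arbitrary choices), one can estimate the Lipschitz constant of $\nabla\ell$ empirically; for a quadratic $\ell(\theta) = \frac{1}{2}\theta^T Q\theta - q^T\theta$ the observed ratios should approach $\lambda_{\max}(Q)$:

```python
import numpy as np

rng = np.random.default_rng(0)
Q = np.array([[3.0, 1.0], [1.0, 2.0]])   # positive definite
grad = lambda th: Q @ th - np.array([1.0, 1.0])

# Ratio ||grad(a) - grad(b)|| / ||a - b|| over random point pairs.
ratios = []
for _ in range(10_000):
    a, b = rng.normal(size=2), rng.normal(size=2)
    ratios.append(np.linalg.norm(grad(a) - grad(b)) / np.linalg.norm(a - b))

print(max(ratios), np.linalg.eigvalsh(Q).max())  # both close to lambda_max
```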

Implication of Smoothness

If $\ell$ is twice differentiable, taking the Taylor expansion of $\ell(\theta')$ around $\theta$ (with $\tilde\theta$ some point on the segment between $\theta$ and $\theta'$):

$$\ell(\theta') - \ell(\theta) = (\theta' - \theta)^T\nabla\ell(\theta) + \frac{1}{2}(\theta' - \theta)^T H_\ell(\tilde\theta)\,(\theta' - \theta) \leq (\theta' - \theta)^T\nabla\ell(\theta) + \frac{L}{2}\|\theta' - \theta\|^2.$$

In gradient descent we set $\theta' = \theta - \eta\,\nabla\ell(\theta)$, giving

$$\ell(\theta') - \ell(\theta) \leq -\eta\left(1 - \frac{L\eta}{2}\right)\|\nabla\ell(\theta)\|^2.$$

By selecting $\eta = \frac{1}{L}$:

$$\ell(\theta') - \ell(\theta) \leq -\frac{1}{2L}\|\nabla\ell(\theta)\|^2.$$

One can see that with this step size a gradient step never increases the objective, and the guaranteed decrease is proportional to the squared norm of the gradient.
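
A small numerical check of this descent guarantee (illustrative; the quadratic below is an arbitrary $L$-smooth example with $L = \lambda_{\max}(Q)$):

```python
import numpy as np

Q = np.array([[3.0, 1.0], [1.0, 2.0]])
q = np.array([1.0, -1.0])
loss = lambda th: 0.5 * th @ Q @ th - q @ th
grad = lambda th: Q @ th - q
L = np.linalg.eigvalsh(Q).max()          # smoothness constant of a quadratic

theta = np.array([4.0, -2.0])
g = grad(theta)
new = theta - g / L                      # gradient step with eta = 1/L
print(loss(new) - loss(theta))           # actual decrease
print(-np.linalg.norm(g) ** 2 / (2 * L)) # guaranteed upper bound on it
# The first number is <= the second, as the descent inequality predicts.
```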

Gradient Norm

With a small gradient norm, convergence becomes prohibitively slow. It is thus reasonable to look for points $\theta$ where the gradient norm is small enough. Let $\ell$ be differentiable at $\theta$; then $\theta$ is an $\epsilon$-critical point if

$$\|\nabla\ell(\theta)\| \leq \epsilon.$$
theorem

Gradient descent on an $L$-smooth, differentiable function finds an $\epsilon$-critical point in at most $k = \frac{2L\left(\ell(\theta^0) - \ell^*\right)}{\epsilon^2}$ steps, where $\ell^* = \min_\theta \ell(\theta)$. Namely, smoothness is sufficient to find $\epsilon$-critical points with $O(\epsilon^{-2})$ steps of gradient descent.

proof

Let $C = \ell(\theta^0) - \ell^*$. Summing the per-step decrease $\ell(\theta^{r+1}) \leq \ell(\theta^r) - \frac{1}{2L}\|\nabla\ell(\theta^r)\|^2$ over the first $k$ steps:

$$C \geq \ell(\theta^0) - \ell(\theta^k) \geq \frac{1}{2L}\sum_{r=0}^{k-1}\|\nabla\ell(\theta^r)\|^2 \quad\Longrightarrow\quad \frac{1}{k}\sum_{r=0}^{k-1}\|\nabla\ell(\theta^r)\|^2 \leq \frac{2LC}{k}.$$

This means that in at least one of the iterations, with iterate $\vartheta$, we have

$$\|\nabla\ell(\vartheta)\|^2 \leq \frac{2LC}{k}.$$

To guarantee that $\vartheta$ is an $\epsilon$-critical point, it suffices that

$$\|\nabla\ell(\vartheta)\|^2 \leq \frac{2LC}{k} \leq \epsilon^2.$$

Therefore, running

$$k \geq \frac{2LC}{\epsilon^2}$$

steps guarantees an $\epsilon$-critical iterate, which matches the claimed bound.
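
As an illustrative experiment (function and constants chosen ad hoc): $\ell(\theta) = \sin(\theta)$ is $1$-smooth, so with $\eta = 1/L = 1$ the theorem bounds the number of steps before some iterate is $\epsilon$-critical. In practice GD stops far earlier than the bound:

```python
import numpy as np

grad = np.cos                  # derivative of sin
L, eps = 1.0, 1e-2
theta = 1.0                    # starting point
C = np.sin(theta) - (-1.0)     # loss gap C = loss(theta0) - min loss

k = 0
while abs(grad(theta)) > eps:  # stop at the first eps-critical iterate
    theta = theta - grad(theta) / L
    k += 1
print(k, 2 * L * C / eps**2)   # steps actually used vs. theoretical bound
```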

Strong Convexity and the PL-condition

Polyak-Łojasiewicz Condition

A differentiable function $\ell$ obeys the Polyak-Łojasiewicz condition (PL condition) with parameter $\mu > 0$ if and only if

$$\frac{1}{2}\|\nabla\ell(\theta)\|^2 \geq \mu\left(\ell(\theta) - \ell^*\right), \qquad \forall\,\theta.$$
theorem

Let $\ell$ be differentiable, $L$-smooth and $\mu$-PL. Then gradient descent with step size $\eta = \frac{1}{L}$ converges at a geometric rate:

$$\ell(\theta^k) - \ell^* \leq \left(1 - \frac{\mu}{L}\right)^k\left(\ell(\theta^0) - \ell^*\right).$$

The PL condition is a fundamental property that directly implies geometric convergence to the minimum.
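
The proof is a one-line combination of the two ingredients above (a sketch, using the notation of the preceding sections): the descent lemma bounds the per-step decrease, and the PL condition converts the gradient norm back into a loss gap,

$$\ell(\theta^{k+1}) - \ell^* \leq \ell(\theta^k) - \ell^* - \frac{1}{2L}\|\nabla\ell(\theta^k)\|^2 \leq \left(1 - \frac{\mu}{L}\right)\left(\ell(\theta^k) - \ell^*\right),$$

and iterating this contraction $k$ times gives the stated rate.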

Strongly Convex Functions

A differentiable function $\ell$ is $\mu$-strongly convex for some $\mu > 0$ if

$$\ell(\theta') \geq \ell(\theta) + (\theta' - \theta)^T\nabla\ell(\theta) + \frac{\mu}{2}\|\theta' - \theta\|^2, \qquad \forall\,\theta, \theta'.$$

If $\ell$ is additionally twice differentiable and $L$-smooth, this corresponds to the Hessian bound $0 \prec \mu I \preceq H_\ell(\theta) \preceq L I$.

The PL condition is implied by strong convexity:

theorem

Let $\ell$ be $\mu$-strongly convex; then it fulfills the PL condition with the same $\mu$.

proof

Minimizing both sides of the strong convexity condition over $\theta'$:

$$\ell^* - \ell(\theta) = \min_{\theta'}\,\ell(\theta') - \ell(\theta) \geq \min_{\theta'}\left\{(\theta' - \theta)^T\nabla\ell(\theta) + \frac{\mu}{2}\|\theta' - \theta\|^2\right\} = -\frac{1}{2\mu}\|\nabla\ell(\theta)\|^2,$$

where the minimum on the right is attained at $\theta' = \theta - \frac{1}{\mu}\nabla\ell(\theta)$.

Therefore,

$$\frac{1}{2}\|\nabla\ell(\theta)\|^2 \geq \mu\left(\ell(\theta) - \ell^*\right),$$

which gives the PL condition.

In DNNs, the PL condition will typically not hold globally, but possibly over a domain around a local minimum. There it ensures fast local convergence to this critical point, without making claims about its sub-optimality.

Momentum and Acceleration

Saddle Points

The training objective of a DNN is usually non-convex. It may thus contain saddle points, which can slow GD down in their neighborhood. It is therefore useful to modify the GD update, e.g. by adding momentum, so that small gradients near saddle points do not stall progress.

Heavy Ball Method

In the heavy ball method, we add a $\beta$-weighted term that includes the change made in the previous update:

$$\theta^{k+1} \leftarrow \theta^k - \eta\,\nabla\ell(\theta^k) + \beta\left(\theta^k - \theta^{k-1}\right), \qquad \beta \in (0,1).$$

With a constant gradient $\nabla\ell$, one can show that

$$\lim_{k\to\infty}\left(\theta^k - \theta^{k-1}\right) = -\eta\,\nabla\ell\,\sum_{i=0}^{\infty}\beta^i = -\frac{\eta}{1-\beta}\,\nabla\ell.$$

Therefore, by using large momentum, i.e. $\beta \to 1$, one can boost the effective step size by an arbitrarily large factor; e.g., $\beta = 0.9$ yields a $10\times$ boost.

Practically, as the gradient is not constant, too large a $\beta$ will create oscillations and instabilities. Therefore, $\beta$ is usually selected in the range $[0.9, 0.95]$.

Nesterov Acceleration

Nesterov acceleration pursues the same idea as the heavy ball method, but evaluates the gradient at the extrapolated point:

$$\vartheta^{k+1} = \theta^k + \beta\left(\theta^k - \theta^{k-1}\right), \qquad \theta^{k+1} = \vartheta^{k+1} - \eta\,\nabla\ell(\vartheta^{k+1}).$$
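
A side-by-side sketch of the two updates (illustrative Python; `grad_loss` is an assumed callable as before, and the hyperparameters are arbitrary):

```python
import numpy as np

def heavy_ball(grad_loss, theta0, eta=0.01, beta=0.9, steps=1000):
    """Heavy ball: gradient at the current point plus a momentum term."""
    theta = prev = np.asarray(theta0, dtype=float)
    for _ in range(steps):
        theta, prev = theta - eta * grad_loss(theta) + beta * (theta - prev), theta
    return theta

def nesterov(grad_loss, theta0, eta=0.01, beta=0.9, steps=1000):
    """Nesterov: same momentum, but the gradient is evaluated at the
    extrapolated point theta + beta * (theta - prev)."""
    theta = prev = np.asarray(theta0, dtype=float)
    for _ in range(steps):
        look = theta + beta * (theta - prev)   # extrapolated point
        theta, prev = look - eta * grad_loss(look), theta
    return theta
```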

AdaGrad

AdaGrad uses an adaptive learning rate for each dimension. It uses the history of gradients from previous iterations to influence the effective step size. Defining

$$\gamma_i^k \leftarrow \gamma_i^{k-1} + \left[\nabla_i\,\ell(\theta^k)\right]^2, \qquad \nabla_i \equiv \frac{\partial}{\partial\theta_i},$$

which is the sum of squares of the $i$-th parameter's partial derivatives, we can use these estimates to adapt the step size of each parameter:

$$\theta_i^{k+1} \leftarrow \theta_i^k - \eta_i^k\,\nabla_i\,\ell(\theta^k), \qquad \eta_i^k \equiv \frac{\eta}{\sqrt{\gamma_i^k} + \delta},$$

where $\delta$ is a small positive constant for numerical stability. Parameters with historically smaller magnitudes of their partial derivatives are updated with an effectively larger step size.
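
A compact sketch of the per-coordinate update (illustrative; `grad_loss` and the constants are assumptions, not part of the notes):

```python
import numpy as np

def adagrad(grad_loss, theta0, eta=0.1, delta=1e-8, steps=1000):
    """AdaGrad: accumulate squared partial derivatives in gamma and
    scale each coordinate's step by eta / (sqrt(gamma) + delta)."""
    theta = np.asarray(theta0, dtype=float)
    gamma = np.zeros_like(theta)
    for _ in range(steps):
        g = grad_loss(theta)
        gamma += g ** 2                       # per-coordinate gradient history
        theta = theta - eta / (np.sqrt(gamma) + delta) * g
    return theta
```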

Adam and RMSprop

Adaptive moment estimation, also known as Adam, is a state-of-the-art learning algorithm. It combines the benefits of momentum and AdaGrad, using exponentially weighted moving averages to estimate the mean and the (uncentered) second moment of each partial derivative:

$$g_i^k = \beta\,g_i^{k-1} + (1-\beta)\,\nabla_i\,\ell(\theta^k), \qquad g_i^0 \leftarrow \nabla_i\,\ell(\theta^0),$$

$$h_i^k = \alpha\,h_i^{k-1} + (1-\alpha)\left[\nabla_i\,\ell(\theta^k)\right]^2, \qquad h_i^0 \leftarrow \left[\nabla_i\,\ell(\theta^0)\right]^2.$$

The update rule becomes

$$\theta_i^{k+1} = \theta_i^k - \eta_i^k\,g_i^k, \qquad \eta_i^k \equiv \frac{\eta}{\sqrt{h_i^k} + \delta}.$$

Adam without the momentum term, i.e. using the raw gradient $\nabla_i\,\ell(\theta^k)$ in place of $g_i^k$, is called RMSprop.
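
A minimal sketch following the update equations above (note: this mirrors the variant described here, which initializes the averages with the first gradient rather than using the zero initialization and bias correction of the original Adam paper; `grad_loss` is again an assumed callable):

```python
import numpy as np

def adam(grad_loss, theta0, eta=1e-3, beta=0.9, alpha=0.999,
         delta=1e-8, steps=1000):
    """Adam as described above: exponentially weighted mean (g) and
    second moment (h) of the gradient, with a per-coordinate step size."""
    theta = np.asarray(theta0, dtype=float)
    g0 = grad_loss(theta)
    g, h = g0.copy(), g0 ** 2                  # initialize with first gradient
    for _ in range(steps):
        grad = grad_loss(theta)
        g = beta * g + (1 - beta) * grad               # mean estimate
        h = alpha * h + (1 - alpha) * grad ** 2        # second-moment estimate
        theta = theta - eta / (np.sqrt(h) + delta) * g
    return theta
```

Setting $\beta = 0$, so that $g$ is just the current gradient, recovers RMSprop.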