Optimal Gradient Descent

-- Not too fast, not too slow! Make it just so!

Let's start by describing gradient descent, and then look at what optimal gradient descent is and why it works. Gradient descent is a simple and elegant way of finding the minimum of a convex function (or, equivalently, the maximum of a concave function). If you are like me, you probably tend to forget which is which. Here is a simple trick for telling the two apart:

[Figure: convex function]
The trick is remembering that a concave function looks like a cave (it opens downward). This also helps you remember that concave functions are the ones we maximize, while convex functions are the ones we minimize.

When talking about gradient descent, it is common to analyze convex functions, and the "descent" in gradient descent refers to descending towards the minimum of the function. This restriction is not important, though, because any concave function can be optimized the same way: multiply the function by -1 to flip it upside down, and voilà, you have a convex function whose minimum sits exactly where the original function's maximum was.
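To make the flipping trick concrete, here is a minimal Python sketch that uses the gradient descent update described in the next paragraphs. It assumes we can write the derivative down by hand. The helper names (`gradient_descent_1d`, `g_prime`, `f_prime`), the example function g(x) = -(x - 3)², and the step size 0.1 are all illustrative choices, not taken from any library. We maximize the concave g by running plain gradient descent on f = -g:

```python
def gradient_descent_1d(grad, x0, step, num_steps):
    """Minimize a 1-D function, given its derivative `grad` (hypothetical helper)."""
    x = x0
    for _ in range(num_steps):
        x = x - step * grad(x)  # move against the derivative
    return x

# Illustrative concave function g(x) = -(x - 3)^2, whose maximum is at x = 3.
def g_prime(x):
    return -2.0 * (x - 3.0)

# To maximize g, minimize f(x) = -g(x). Its derivative is simply f'(x) = -g'(x).
def f_prime(x):
    return -g_prime(x)

x_max = gradient_descent_1d(f_prime, x0=0.0, step=0.1, num_steps=100)
print(x_max)  # very close to 3, the maximizer of g
```

Nothing about the descent loop changes. Only the sign of the derivative we feed it does.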

OK, so back to gradient descent. The idea is to repeatedly take a step in the direction of the negative gradient, -∇f, which is the direction in which f decreases most steeply. For a convex function, repeating this until convergence (with a suitably small step size) takes us to the minimum, and for a strictly convex function that minimum is unique. For an arbitrary function with many local minima, gradient descent can converge to any one of them, depending on where it starts. The entire algorithm can be described as follows:

    X(0) = [x1(0), ..., xi(0), ..., xn(0)]
    for t = 0 to T-1 do:
        for i = 1 to n do:
            xi(t+1) = xi(t) - λ ∂f(X(t))/∂xi
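The pseudocode above can be turned into runnable Python with NumPy. This is a minimal sketch under a few assumptions: the gradient is supplied as a function `grad_f`, the loop runs for a fixed number of iterations T, and the inner loop over coordinates is replaced by a single vectorized update. The vectorized update is equivalent because every xi(t+1) is computed from the same X(t). The names and the example function are illustrative only.

```python
import numpy as np

def gradient_descent(grad_f, x0, lr, num_iters):
    """Plain gradient descent: X(t+1) = X(t) - lr * grad_f(X(t)).

    grad_f:    function returning the gradient vector of f at X
    x0:        starting point X(0)
    lr:        step size (learning rate) lambda
    num_iters: number of iterations T
    """
    x = np.asarray(x0, dtype=float)
    for _ in range(num_iters):
        # Updates all coordinates at once, each using the gradient at X(t),
        # exactly like the inner "for i = 1 to n" loop in the pseudocode.
        x = x - lr * grad_f(x)
    return x

# Illustrative example: f(X) = (x1 - 1)^2 + 2 * (x2 + 2)^2, minimum at (1, -2).
def grad_f(x):
    return np.array([2.0 * (x[0] - 1.0), 4.0 * (x[1] + 2.0)])

x_min = gradient_descent(grad_f, x0=[0.0, 0.0], lr=0.1, num_iters=200)
print(x_min)
```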

If you've been paying close attention, you will have noticed that we introduced a new symbol, λ, into the algorithm. This parameter is the step size, or learning rate: it controls how far we move against the gradient at each step. If you take large steps, the algorithm is likely to overshoot and oscillate a lot before converging, and if the steps are too large it can diverge altogether. Conversely, taking small steps increases the time taken to converge. The trick is identifying the optimal learning rate. But before we go into what the optimal value for λ is, let's look at a simple example.

Consider using gradient descent to find the minimum of the convex function f(x) = x². Starting at any point x(0), we keep computing x(t+1) = x(t) - λf'(x(t)) = x(t) - 2λx(t) until we reach the minimum. We can see the effects of varying the step size very easily by examining the following plots:

[Figures: gradient descent on f(x) = x² with a small step size, a large step size, a good step size, and the perfect step size]
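If you want to reproduce the behaviour in the plots numerically, here is a small Python sketch. The step sizes 0.05, 0.9, 1.1 and 0.5 and the starting point x(0) = 1 are illustrative choices that are not necessarily the values behind the original figures. The helper name `run_gd_on_x_squared` is hypothetical.

```python
def run_gd_on_x_squared(lr, x0=1.0, num_iters=10):
    """Gradient descent on f(x) = x^2 (so f'(x) = 2x). Returns the whole trajectory."""
    xs = [x0]
    x = x0
    for _ in range(num_iters):
        # x(t+1) = x(t) - lambda * f'(x(t)) = (1 - 2 * lambda) * x(t)
        x = x - lr * 2.0 * x
        xs.append(x)
    return xs

for lr in [0.05, 0.9, 1.1, 0.5]:  # small, large, too large, "perfect"
    xs = run_gd_on_x_squared(lr)
    print(f"lr={lr}: first steps {[round(v, 3) for v in xs[:4]]}, after 10 steps {xs[-1]:.4g}")
```

For this function x(t) = (1 - 2λ)^t x(0), so the iterates shrink toward 0 only when |1 - 2λ| < 1, that is 0 < λ < 1. They flip sign every step (oscillate) once λ > 0.5, blow up for λ > 1, and land on the minimum in a single step when λ = 0.5.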

Getting the perfect step size allows us to reach the minimum in the most efficient manner possible. So how do we actually compute it? The perfect step size for gradient descent is given by λ = 1/f''(x) in 1-D. More generally, the scalar λ is replaced by the matrix H⁻¹(X), where H is the Hessian (the matrix of second partial derivatives), giving the update X(t+1) = X(t) - H⁻¹(X(t))∇f(X(t)). This update is known as Newton's method. Going back to our example, we see right away that the optimal step size for f(x) = x² is λ = 1/f''(x) = 1/2 = 0.5. With it, a single step x(1) = x(0) - 0.5 · 2x(0) = 0 lands exactly on the minimum.
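Here is a minimal NumPy sketch of the matrix version of the optimal step. It assumes a hypothetical 2-D quadratic f(X) = ½XᵀAX - bᵀX with a symmetric positive definite A (the values of A, b and the starting point are made up for illustration). Its gradient is AX - b and its Hessian is the constant matrix A, so one step with H⁻¹ should land exactly on the minimizer A⁻¹b.

```python
import numpy as np

# Hypothetical quadratic f(X) = 0.5 * X^T A X - b^T X, with A symmetric positive definite.
A = np.array([[3.0, 1.0],
              [1.0, 2.0]])
b = np.array([1.0, -1.0])

def grad_f(x):
    return A @ x - b   # gradient of the quadratic

def hessian_f(x):
    return A           # the Hessian of a quadratic is constant

x0 = np.array([5.0, -7.0])
# Optimal step: X(1) = X(0) - H^{-1}(X(0)) grad f(X(0)).
# np.linalg.solve(H, g) computes H^{-1} g without forming the inverse explicitly.
x1 = x0 - np.linalg.solve(hessian_f(x0), grad_f(x0))

print(x1)                     # after one step
print(np.linalg.solve(A, b))  # exact minimizer A^{-1} b, for comparison
```

Solving the linear system instead of calling an explicit inverse is a common numerical choice. Both give the same step in exact arithmetic.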

To understand why the second derivative plays such a critical role in determining the optimal learning rate, consider the Taylor series expansion of a function in 1-D around the point x0: y = f(x0) + f'(x0)(x - x0) + (1/2)f''(x0)(x - x0)² + higher-order terms. Close to x0, the higher-order terms shrink rapidly because they involve higher powers of the small quantity (x - x0), so we can drop them and keep only the quadratic approximation. The derivative of this approximation is y' = f'(x0) + f''(x0)(x - x0). Provided f''(x0) > 0, the approximation reaches its minimum where y' = 0. Solving for x gives x = x0 - f'(x0)/f''(x0). Compare this to the original update equation x(t+1) = x(t) - λf'(x(t)) and we see right away that λ = 1/f''(x(t))!
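For a function that is not quadratic, the second-order Taylor approximation is only accurate near x(t). In that case a single step no longer lands on the minimum, but recomputing λ = 1/f''(x(t)) at every iteration still converges very quickly. Here is a small Python sketch of that idea, assuming the illustrative strictly convex function f(x) = eˣ - 2x (minimum at x = ln 2) and a hypothetical helper `optimal_step_descent`:

```python
import math

# Illustrative strictly convex function f(x) = exp(x) - 2x, minimized at x = ln 2.
def f_prime(x):
    return math.exp(x) - 2.0

def f_double_prime(x):
    return math.exp(x)

def optimal_step_descent(x0, num_iters):
    """Gradient descent with lambda = 1 / f''(x(t)) recomputed at each iteration."""
    x = x0
    for _ in range(num_iters):
        lam = 1.0 / f_double_prime(x)  # "optimal" step from the local quadratic fit
        x = x - lam * f_prime(x)
    return x

x_star = optimal_step_descent(x0=0.0, num_iters=8)
print(x_star, math.log(2.0))
```

Starting from 0, the iterates go 1, about 0.736, about 0.694, and then agree with ln 2 to many digits within a few more steps.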