This post just tries to explicate the claim in Deriving the Adjoint Equation for Neural ODEs Using Lagrange Multipliers that the vector-Jacobian product $\lambda^\intercal \frac{\partial f}{\partial z}$ can be calculated efficiently without explicitly constructing the Jacobian $\frac{\partial f}{\partial z}$. The claim is made in the Solving PL, PG, PM with Good Lagrange Multiplier section.

This post is inspired by a question about this topic asked in the comments of that post.

In what follows, the variable $y$ will take the place of $z$ from the earlier post.

torchdiffeq uses torch.autograd.grad’s vJp magic

To begin, let’s see how torchdiffeq, the Neural ODEs implementation from the original authors, calls PyTorch autodiff’s torch.autograd.grad function to calculate the vector-Jacobian product. The relevant chunk is in lines 37-46; specifically, the call to torch.autograd.grad is made between lines 41-44.

with torch.set_grad_enabled(True):
    t = t.to(y[0].device).detach().requires_grad_(True)
    y = tuple(y_.detach().requires_grad_(True) for y_ in y)
    func_eval = func(t, y)
    vjp_t, *vjp_y_and_params = torch.autograd.grad(
        func_eval, (t,) + y + f_params,
        tuple(-adj_y_ for adj_y_ in adj_y), allow_unused=True, retain_graph=True
    )
vjp_y = vjp_y_and_params[:n_tensors]
vjp_params = vjp_y_and_params[n_tensors:]

Here’s a table to clarify the arguments and return values of the torch.autograd.grad function call. (For simplicity, I assume the loss depends on just the terminal timepoint $t_1$ and that the batch size is $1$. Further, I ignore $f$’s other inputs $t, \theta$.)

| Code | Math | Description |
| --- | --- | --- |
| func_eval | $f(t_1, y(t_1), \theta)$ | (arg) The output whose derivative you want. |
| y | $y(t_1)$ | (arg) The state w.r.t. which you want the output’s derivative. |
| adj_y | $\lambda(t_1)^\intercal$ | (arg) The freely chosen adjoint vector / Lagrange multiplier. |
| vjp_y | $\lambda(t_1)^\intercal \frac{\partial f}{\partial y(t_1)}$ | (ret) The vector-Jacobian product. |

In this case, torch.autograd.grad(func_eval, y, adj_y) does all the magic under the hood to compute $\lambda(t_1)^\intercal \frac{\partial f}{\partial y(t_1)}$ without constructing the Jacobian $\frac{\partial f}{\partial y(t_1)}$. The same principle applies to calculating $\lambda(t_1)^\intercal \frac{\partial f}{\partial \theta}$.
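For intuition, here is a minimal, self-contained sketch of the same call pattern on a made-up toy function (this is not torchdiffeq’s code; the function and shapes are chosen purely for illustration).

import torch

# Toy stand-in for the dynamics f: f(y) = sin(y), so df/dy = diag(cos(y)).
y = torch.randn(3, requires_grad=True)
func_eval = torch.sin(y)

# Stand-in for the adjoint vector lambda(t1).
adj_y = torch.randn(3)

# vector-Jacobian product, without ever materializing the 3x3 Jacobian
vjp_y, = torch.autograd.grad(func_eval, y, grad_outputs=adj_y)

# sanity check against the explicit Jacobian (fine at this tiny size)
jacobian = torch.diag(torch.cos(y))
assert torch.allclose(vjp_y, adj_y @ jacobian)

Here grad_outputs plays the same role as the tuple(-adj_y_ ...) argument in the torchdiffeq snippet above.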

How does torch.autograd.grad compute the vector-Jacobian product without constructing the Jacobian? Could it really be… magic? 😮

torch.autograd.grad’s vJp is not really magic

No, torch.autograd.grad’s vector-Jacobian product calculation is not really magic.😶

To be sure, the underlying idea, called reverse mode (as opposed to forward mode) automatic differentiation, is ingenious and elegant. But it helps (me) to sometimes see it as just another case of optimization-by-upfront-work. Here is how we would implement reverse mode autodiff based vector-Jacobian products.

  1. Precompute the simplified expressions of $\text{vjp}_{(f, y)}(v) = v^\intercal \frac{\partial f(y)}{\partial y}$ for a whole bunch of $f(y)$’s [1].
  2. Simply invoke $\text{vjp}_{(f, y)}(\lambda)$ to calculate $\lambda^\intercal \frac{\partial f(y)}{\partial y}$.

That’s really all there is to the reverse mode autodiff way of calculating vector-Jacobian products without Jacobians. Specifically, in step 1, the simplified expression $\text{vjp}_{(f, y)}(v)$ will help exploit

  • the sparsity of the Jacobian and
  • the commonality in the expression for a function and its derivative

to avoid/reuse calculations so that we never have to hold the entire Jacobian in memory. We can just calculate the product very easily from $f$, $y$ and the transposed vector $v$.

In fact we only need at most three times the cost in FLOPS of evaluating the original function $f$, independent of the dimensionality of the input [2]. 😲

Let’s see the autodiff idea for vector-Jacobian products with a few examples.

Example 1 of vJp without Jacobian: $\sin(y)$

In this example we’ll explore how autodiff vector-Jacobian product implementation exploits Jacobian sparsity.

It is a bit contrived, but imagine the state is $y(t) = [y_1(t), y_2(t), y_3(t)]$ and the NN is $f(t, y(t), \theta) = \sin(y(t)) = [\sin(y_1(t)), \sin(y_2(t)), \sin(y_3(t))]$. Since both $f$ and $y$ are 3d, the full Jacobian $\frac{\partial f}{\partial y}$ is a $3 \times 3$ matrix as follows.

$$\begin{aligned} \frac{\partial f}{\partial y} &= \begin{bmatrix} \frac{\partial f_1}{\partial y_1} & \frac{\partial f_1}{\partial y_2} & \frac{\partial f_1}{\partial y_3} \\ \frac{\partial f_2}{\partial y_1} & \frac{\partial f_2}{\partial y_2} & \frac{\partial f_2}{\partial y_3} \\ \frac{\partial f_3}{\partial y_1} & \frac{\partial f_3}{\partial y_2} & \frac{\partial f_3}{\partial y_3} \end{bmatrix} = \begin{bmatrix} \cos(y_1) & 0 & 0 \\ 0 & \cos(y_2) & 0 \\ 0 & 0 & \cos(y_3) \end{bmatrix} \\ &= \operatorname{diag}(\cos(y)) \end{aligned}$$

Thus the vector-Jacobian product in this case would be

$$\begin{aligned} \lambda^\intercal \frac{\partial f}{\partial y} &= \begin{bmatrix} \lambda_1 \cdot \cos(y_1) & \lambda_2 \cdot \cos(y_2) & \lambda_3 \cdot \cos(y_3) \end{bmatrix} \\ &= (\lambda \odot \cos(y))^\intercal \end{aligned} \tag{1}$$

where \odot is the elementwise multiplication operator.

So we really don’t need to store the $3 \times 3$ Jacobian matrix $\frac{\partial f}{\partial y}$ to calculate the vector-Jacobian product $\lambda^\intercal \frac{\partial f}{\partial y}$. Instead we can make a vector-Jacobian product routine that just exploits $\text{(1)}$ to calculate the product.

import numpy as np

def vjp_sin(f, y):
  """For f = sin(y), returns v.T * df/dy without making df/dy."""
  return lambda v: np.multiply(v, np.cos(y))

# Suppose
# 1. y is the input vector
# 2. adj_y is the adjoint vector

# forward pass
func_eval = np.sin(y) # (store func_eval for later)

# vjp from reverse mode autodiff
vjp_result = vjp_sin(func_eval, y)(adj_y)

Notice that vjp_sin never keeps around the full Jacobian $\frac{\partial f}{\partial y}$. Also notice that the lambda function returned by vjp_sin only takes about twice as many FLOPS as the original $\sin$ func – about the same for the cos and the multiply.
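To convince ourselves, here is a quick numerical check of $\text{(1)}$ against the explicit Jacobian, with made-up values for y and adj_y (vjp_sin is repeated from the block above so the snippet stands alone).

import numpy as np

def vjp_sin(f, y):
  """Same as above: for f = sin(y), returns v -> v.T * df/dy."""
  return lambda v: np.multiply(v, np.cos(y))

# made-up values, just for the check
y = np.array([0.1, 0.2, 0.3])
adj_y = np.array([1.0, -2.0, 3.0])

vjp_fast = vjp_sin(np.sin(y), y)(adj_y)   # no Jacobian built
jacobian = np.diag(np.cos(y))             # the full 3x3 Jacobian
vjp_slow = adj_y @ jacobian               # explicit vector-Jacobian product

assert np.allclose(vjp_fast, vjp_slow)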

Example 2 of vJp without Jacobian: $\sin(Wy)$

In this example we’ll explore how autodiff vector-Jacobian product implementation exploits commonality in the expressions for a function and its derivative.

Suppose $f(z) = \sin(z)$, where $z = Wy$ for some weight matrix $W$. We want $v^\intercal \frac{\partial f}{\partial W}$. We’ll assume that $y$ is 3d, $W$ is $2 \times 3$ and $z$, $f(z)$, $v$ are 2d.

From the chain rule we have

$$v^\intercal \frac{\partial f}{\partial W} = v^\intercal \frac{\partial f}{\partial z} \frac{\partial z}{\partial W} = \left( v^\intercal \frac{\partial f}{\partial z} \right) \frac{\partial z}{\partial W}$$

First, notice that the vector-Jacobian product in the brackets can be easily evaluated using the vjp_sin implementation of $\text{(1)}$ from the preceding section. Thus

$$v^\intercal \frac{\partial f}{\partial W} = (v \odot \cos(z))^\intercal \frac{\partial z}{\partial W}$$

Now the Jacobian $\frac{\partial z}{\partial W}$ is not exactly sparse. But it is easy (though laborious) to show that for $z = Wy$, the vector-Jacobian product ${v^\prime}^\intercal \frac{\partial z}{\partial W}$ is just the outer product $v^\prime \otimes y$. See [3]. Thus

$$v^\intercal \frac{\partial f}{\partial W} = (v \odot \cos(z)) \otimes y$$
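In case it helps, here is a sketch of why that holds, just the index bookkeeping under the shapes assumed above. Since $z_i = \sum_k W_{ik} y_k$, we have $\frac{\partial z_i}{\partial W_{jk}} = \delta_{ij} y_k$, so

$$\left( {v^\prime}^\intercal \frac{\partial z}{\partial W} \right)_{jk} = \sum_i v^\prime_i \frac{\partial z_i}{\partial W_{jk}} = \sum_i v^\prime_i \, \delta_{ij} \, y_k = v^\prime_j \, y_k,$$

which is exactly the $(j, k)$ entry of the outer product $v^\prime \otimes y$.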

Here’s how we could code this up.

import numpy as np

def vjp_sin(f, z):
  """Just copied from the preceding section. Remember f = sin(z)."""
  return lambda v: np.multiply(v, np.cos(z))

def vjp_matmul_W(z, W, y):
  """This is the vjp for v.T dz/dW = v.T d(Wy)/dW"""
  return lambda v: np.outer(v, y) # No need to transpose v or y. Numpy smart.

# Suppose
# 1. y is the input vector
# 2. W is a weight matrix
# 3. adj_y is the adjoint vector

# forward pass
z = np.dot(W, y) # (store z for later)
func_eval = np.sin(z) # (store func_eval for later)

# vjp from reverse mode autodiff
vjp_intermediate = vjp_sin(func_eval, z)(adj_y)
vjp_result = vjp_matmul_W(z, W, y)(vjp_intermediate)

Again, hopefully it is clear that nowhere are we instantiating full Jacobians.
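As a sanity check, with made-up values for y, W and adj_y, the vjp_result above agrees with a brute-force finite-difference estimate of $v^\intercal \frac{\partial f}{\partial W}$.

import numpy as np

# made-up values, just for the check
y = np.array([0.1, 0.2, 0.3])
W = np.arange(6.0).reshape(2, 3)
adj_y = np.array([1.0, -1.0])

z = np.dot(W, y)
vjp_fast = np.outer(adj_y * np.cos(z), y)   # same value as vjp_result above

# central finite differences of adj_y.T sin(Wy) w.r.t. each entry of W
eps = 1e-6
vjp_fd = np.zeros_like(W)
for j in range(W.shape[0]):
  for k in range(W.shape[1]):
    W_plus, W_minus = W.copy(), W.copy()
    W_plus[j, k] += eps
    W_minus[j, k] -= eps
    vjp_fd[j, k] = adj_y @ (np.sin(W_plus @ y) - np.sin(W_minus @ y)) / (2 * eps)

assert np.allclose(vjp_fast, vjp_fd, atol=1e-6)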

Notes for making an actual reverse mode autodiff library

In an actual library, we would obviously want to make our calls a lot more natural and consistent. So we would do a few things like

  1. Pair functions with their vector-Jacobian product implementations in a dictionary-like data structure, so that the reverse pass can be executed automatically with a single call to a high-level function like some_loss.reverse() (see the sketch after this list).
  2. Extend the above logic to many other types of functions, like binary functions etc.
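Here is a rough sketch of what point 1 could look like; the names record and reverse are made up, and a real library such as PyTorch does considerably more bookkeeping (routing vjps to every input, handling multiple consumers, etc.).

import numpy as np

# Pair each primitive with its vjp rule, reusing the rules from the examples above.
vjp_rules = {
  np.sin: lambda v, x: v * np.cos(x),        # Example 1
  np.dot: lambda v, W, y: np.outer(v, y),    # Example 2's vjp w.r.t. W
}

tape = []  # records what the forward pass did

def record(fn, *args):
  """Run a primitive and remember how to run its vjp later."""
  out = fn(*args)
  tape.append((vjp_rules[fn], args))
  return out

def reverse(seed):
  """Walk the tape backwards, threading the vector through each vjp."""
  v = seed
  for rule, args in reversed(tape):
    v = rule(v, *args)
  return v

# Reproducing example 2 with made-up values:
y = np.array([0.1, 0.2, 0.3])
W = np.arange(6.0).reshape(2, 3)
adj_y = np.array([1.0, -1.0])

z = record(np.dot, W, y)      # forward pass, recording as we go
func_eval = record(np.sin, z)
vjp_result = reverse(adj_y)   # same value as example 2's vjp_result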

How reverse mode vJp helps for backprop in DL

Hopefully example 2 above is already starting to reveal why reverse mode autodiff is chosen for backprop implementations in deep learning libraries.

Reverse mode autodiff helps efficiently calculate $\frac{\partial f}{\partial y}$ when $f$ is a deep composition like $f(y) = L(f_l(f_{l-1}(\ldots(f_1(y)))))$ and the outermost loss function $L$ outputs a scalar.

By the chain rule we’d have

$$\begin{aligned}\frac{\partial f}{\partial y} &= \frac{\partial L}{\partial f_l} \frac{\partial f_l}{\partial f_{l-1}} \frac{\partial f_{l-1}}{\partial f_{l-2}} \cdots \frac{\partial f_{2}}{\partial f_{1}} \frac{\partial f_{1}}{\partial y} \\ &= \left(\left(\left(\left( \frac{\partial L}{\partial f_l} \frac{\partial f_l}{\partial f_{l-1}} \right) \frac{\partial f_{l-1}}{\partial f_{l-2}}\right) \cdots\right) \frac{\partial f_{2}}{\partial f_{1}}\right) \frac{\partial f_{1}}{\partial y} \end{aligned} \tag{2}$$

where we get the second expression by invoking the associativity of matrix multiplication.

Now notice that because $L$ is a scalar, the Jacobian $\frac{\partial L}{\partial f_l}$ is just a row vector! Let’s call the transpose of this row vector $g$. Then $\text{(2)}$ becomes

$$\begin{aligned}\frac{\partial f}{\partial y} &= \left(\left(\left(\left( g^\intercal \frac{\partial f_l}{\partial f_{l-1}} \right) \frac{\partial f_{l-1}}{\partial f_{l-2}}\right) \cdots\right) \frac{\partial f_{2}}{\partial f_{1}}\right) \frac{\partial f_{1}}{\partial y} \end{aligned} \tag{3}$$

Starting from the innermost parenthesis, we can just invoke the appropriate vector-Jacobian product function and move out to the next enclosing parenthesis. At that iteration we’ll find ourselves with yet another task of left-multiplying a Jacobian by a row vector, which we can once again compute efficiently using the appropriate vector-Jacobian product function. So on and so forth till we’ve successfully finished all products in the chain.
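In code, that iteration is just a fold over per-layer vjp functions. Here is a minimal sketch, assuming each layer stored a vjp closure (like the ones in the examples above) during the forward pass:

def backprop(g, layer_vjps):
  """layer_vjps holds one vjp closure per layer f_1 ... f_l, in forward order.
  Each closure maps a vector v to v.T times that layer's Jacobian w.r.t. its input."""
  v = g                             # dL/df_l as a 1-D array (the row vector in (3))
  for vjp in reversed(layer_vjps):  # innermost parenthesis of (3) first
    v = vjp(v)
  return v                          # df/dy, without ever forming a full Jacobian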

Hopefully the previous section and $\text{(3)}$ make the advantage and drawback of backprop abundantly clear.

  • The advantage is that we can efficiently calculate the gradient of a scalar loss with respect to a very high-dimensional input (or parameter tensor).
  • The drawback is that we need to store in memory all the function evaluations from our forward pass.

Back to Neural ODEs

Circling back to Neural ODEs, remember that we invoke reverse mode autodiff to calculate the vector-Jacobian product between the adjoint vector $\lambda(t_1)$ and the Jacobian $\frac{\partial f}{\partial y(t_1)}$ of the evolution function neural network $f(y)$.

But from $\text{(3)}$ it looks like this would set off a chain of vector-Jacobian products, essentially giving us backprop, with the only difference being that the vector $g$ is replaced by the adjoint vector $\lambda(t_1)$.

What’s going on here? Aren’t Neural ODEs supposed to not use backprop? 🤔

Well, it is true that this looks very much like backprop. The crucial difference is that we would not be backpropagating gradients from the loss all the way through the potentially huge computation graph created by the ODE solver’s steps.

Instead we would only be “backpropagating” the adjoint vector [4] from $f(y(t_1))$ to $y(t_1)$, which is a very tiny graph compared to the entire ODE solver’s computation graph. 😀

Hence the claim in Deriving the Adjoint Equation for Neural ODEs Using Lagrange Multipliers that the vector-Jacobian product $\lambda^\intercal \frac{\partial f}{\partial z}$ can be calculated efficiently without explicitly constructing the Jacobian $\frac{\partial f}{\partial z}$.

References


  1. For simplicity, I’m considering only unary functions. ↩︎

  2. Griewank A. A mathematical view of automatic differentiation. Acta Numerica. 2003;12:321–398. ↩︎

  3. Clark K. Computing Neural Network Gradients. Stanford CS224n: Natural Language Processing with Deep Learning. https://web.stanford.edu/class/cs224n/readings/gradient-notes.pdf. Published December 20, 2019. Accessed February 25, 2020. ↩︎

  4. I believe the actual term here is pulling back. ↩︎