This post just tries to explicate the claim in Deriving the Adjoint Equation for Neural ODEs Using Lagrange Multipliers that the vector-Jacobian product $\lambda^T \frac{\partial f}{\partial z}$ can be calculated efficiently without explicitly constructing the Jacobian $\frac{\partial f}{\partial z}$. The claim is made in the Solving $P_L$, $P_G$, $P_M$ with Good Lagrange Multiplier section.
This post is inspired by a question asked about this topic in the comments of that post.
In what follows, the variable $a$ will take the place of $\lambda$ from the earlier post.
torchdiffeq uses torch.autograd.grad’s vJp magic
To begin, let's see how torchdiffeq, the Neural ODEs implementation from the original authors, calls PyTorch autodiff's torch.autograd.grad function to calculate the vector-Jacobian product. The relevant chunk is in lines 37-46. Specifically, the call to torch.autograd.grad is made between lines 41-44.
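Since the embedded snippet isn't reproduced here, below is a stripped-down sketch of the call pattern (toy stand-ins of my own, not the verbatim torchdiffeq source; a simple linear layer plays the role of $f$):

```python
import torch

# Toy stand-ins (not the actual torchdiffeq variables): a linear layer plays f,
# y plays the state z(t), adj_y plays the adjoint vector a(t).
func = torch.nn.Linear(3, 3)
y = torch.randn(3, requires_grad=True)
adj_y = torch.randn(3)

func_eval = func(y)                                # f(z(t))
vjp_y, = torch.autograd.grad(func_eval, y, adj_y)  # a^T (df/dz), Jacobian never built
```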
Here’s a table to clarify the arguments and return values of the torch.autograd.grad function call. (For simplicity, I assume the loss depends on just the terminal timepoint and that the batch size is 1. Further, I ignore $f$'s other inputs, like $t$ and $\theta$.)
Code | Math | Description |
---|---|---|
func_eval | $f(z)$ | (arg) The output whose derivative you want. |
y | $z$ | (arg) The state w.r.t. which you want the output's derivative. |
adj_y | $a$ | (arg) The free-choice adjoint vector / Lagrange multiplier. |
vjp_y | $a^T \frac{\partial f}{\partial z}$ | (ret) The vector-Jacobian product. |
In this case, torch.autograd.grad(func_eval, y, adj_y) does all the magic under the hood to compute $a^T \frac{\partial f}{\partial z}$ without constructing the Jacobian $\frac{\partial f}{\partial z}$. The same principle applies to calculating $a^T \frac{\partial f}{\partial \theta}$.
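The parameter case can be sketched in the same way (again a toy stand-in of my own, not torchdiffeq's actual code): passing the network's parameters as the inputs argument of torch.autograd.grad yields $a^T \frac{\partial f}{\partial \theta}$ directly.

```python
import torch

# Same toy setup as above; now ask for the vJp w.r.t. the parameters of func.
func = torch.nn.Linear(3, 3)
y = torch.randn(3)
adj_y = torch.randn(3)

func_eval = func(y)
vjp_params = torch.autograd.grad(func_eval, tuple(func.parameters()), adj_y)
print([v.shape for v in vjp_params])  # one entry per parameter tensor (weight, bias)
```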
How does torch.autograd.grad compute the vector-Jacobian product without constructing the Jacobian? Could it really be... magic 😮?
torch.autograd.grad’s vJp is not really magic
No, torch.autograd.grad's vector-Jacobian product calculation is not really magic. 😶
To be sure, the underlying idea, called reverse mode (as opposed to forward mode) automatic differentiation, is ingenious and elegant. But it helps (me) to sometimes see it as just another case of optimization-by-upfront-work. Here is how we would implement reverse-mode-autodiff-based vector-Jacobian products.
- Precompute the simplified expressions of $a^T \frac{\partial f}{\partial z}$ for a whole bunch of $f$'s [1].
- Simply invoke the precomputed expression to calculate $a^T \frac{\partial f}{\partial z}$.
That’s really all there is to the reverse mode autodiff way of calculating vector-Jacobian products without Jacobians. Specifically, in step 1, the simplified expression will help exploit
- the sparsity of the Jacobian and
- the commonality in the expression for a function and its derivative
to avoid/reuse calculations so that we never have to hold the entire Jacobian $\frac{\partial f}{\partial z}$ in memory. We can just calculate the product $a^T \frac{\partial f}{\partial z}$ very easily from the input $z$ and the transposed vector $a^T$.
In fact we only need at most three times the cost in FLOPs needed for calculating the original function $f$, independent of the dimensionality of the input $z$ [2]. 😲
Let’s see the autodiff idea for vector-Jacobian products with a few examples.
Example 1 of vJp without Jacobian: $f(z) = \sin(z)$
In this example we'll explore how an autodiff vector-Jacobian product implementation exploits Jacobian sparsity.
It is a bit contrived, but imagine the state is $z = [z_1, z_2, z_3]^T$ and the NN is $f(z) = \sin(z)$, applied elementwise. Since both $z$ and $f(z)$ are 3d, the full Jacobian $\frac{\partial f}{\partial z}$ is a $3 \times 3$ matrix as follows.

$$\frac{\partial f}{\partial z} = \begin{bmatrix} \cos(z_1) & 0 & 0 \\ 0 & \cos(z_2) & 0 \\ 0 & 0 & \cos(z_3) \end{bmatrix}$$
Thus the vector-Jacobian product in this case would be

$$a^T \frac{\partial f}{\partial z} = \begin{bmatrix} a_1 \cos(z_1) & a_2 \cos(z_2) & a_3 \cos(z_3) \end{bmatrix} = (a \odot \cos(z))^T$$

where $\odot$ is the elementwise multiplication operator.
So we really don't need to store the 3×3 Jacobian matrix $\frac{\partial f}{\partial z}$ to calculate the vector-Jacobian product $a^T \frac{\partial f}{\partial z}$. Instead we can make a vector-Jacobian product routine that just exploits $a^T \frac{\partial f}{\partial z} = (a \odot \cos(z))^T$ to calculate the product.
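The embedded code snippet didn't survive here, so the following is a reconstruction of what vjp_sin might look like (NumPy; my own sketch, laid out so that the lambda sits on line 3, which the next paragraph refers to):

```python
import numpy as np

vjp_sin = lambda a, z: a * np.cos(z)  # a ⊙ cos(z): the vJp without the 3x3 Jacobian

# Quick sanity check against the explicit (wasteful) Jacobian route.
a, z = np.array([1.0, 2.0, 3.0]), np.array([0.1, 0.2, 0.3])
assert np.allclose(vjp_sin(a, z), a @ np.diag(np.cos(z)))
```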
Notice that vjp_sin never keeps around the full Jacobian $\frac{\partial f}{\partial z}$. Also notice that the lambda function on line 3 only takes about twice as many FLOPs as the original func: about the same for the cos and the multiply.
Example 2 of vJp without Jacobian: $h = \sin(Wz)$
In this example we'll explore how an autodiff vector-Jacobian product implementation exploits commonality in the expressions for a function and its derivative.
Suppose $h = \sin(g)$, where $g = Wz$ for some weight matrix $W$. We want $a^T \frac{\partial h}{\partial W}$. We'll assume that $z$ is 3d, $W$ is $2 \times 3$, and $g$ and $h$ are 2d.
From the chain rule we have

$$a^T \frac{\partial h}{\partial W} = a^T \frac{\partial \sin(g)}{\partial g} \frac{\partial g}{\partial W} = \left[ a^T \frac{\partial \sin(g)}{\partial g} \right] \frac{\partial g}{\partial W}$$

First, notice that the vector-Jacobian product in the brackets can be easily evaluated using the vjp_sin implementation of $a^T \frac{\partial \sin(g)}{\partial g}$ from the preceding section. Thus

$$a^T \frac{\partial h}{\partial W} = (a \odot \cos(g))^T \frac{\partial g}{\partial W}$$

Now the Jacobian $\frac{\partial g}{\partial W}$ is not exactly sparse. But it is easy (though laborious) to show that for $g = Wz$, the vector-Jacobian product of any 2d vector $b$ with it, $b^T \frac{\partial g}{\partial W}$, is just the outer product $b z^T$. See [3]. Thus

$$a^T \frac{\partial h}{\partial W} = (a \odot \cos(g))\, z^T = (a \odot \cos(Wz))\, z^T$$
Here’s how we could code this up.
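Again, the embedded snippet isn't preserved, so here is a sketch of how it might be coded up (NumPy; the helper names are my own):

```python
import numpy as np

# vJp pieces for h = sin(W z); vjp_sin is reused from Example 1.
vjp_sin = lambda a, g: a * np.cos(g)        # a^T d sin(g)/dg, elementwise
vjp_Wz_wrt_W = lambda b, z: np.outer(b, z)  # b^T d(W z)/dW = b z^T (outer product)

def vjp_h_wrt_W(a, W, z):
    g = W @ z                    # forward value, reused by the derivative expression
    b = vjp_sin(a, g)            # a ⊙ cos(W z)
    return vjp_Wz_wrt_W(b, z)    # (a ⊙ cos(W z)) z^T, same shape as W

W = np.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0]])  # 2 x 3
z = np.array([0.1, 0.2, 0.3])
a = np.array([1.0, -1.0])
print(vjp_h_wrt_W(a, W, z))      # 2 x 3 result, no Jacobians materialized
```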
Again, hopefully it is clear that nowhere are we instantiating full Jacobians.
Notes for making an actual reverse mode autodiff library
In an actual library, we would obviously want to make our calls a lot more natural and consistent. So we would do a few things like:
- Pair functions with their vector-Jacobian product implementations in a dictionary-like data structure, so as to execute our reverse pass in an automated fashion with a simple call to a high-level function like some_loss.reverse() (see the sketch after this list).
- Extend the above logic to many other types of functions, like binary functions, etc.
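To make the first bullet concrete, here is a toy sketch of the dictionary idea (all names are my own invention; a real library is far more elaborate):

```python
import numpy as np

# Each primitive function is paired with its vJp rule, keyed by name.
VJP_RULES = {
    "sin": lambda a, z: a * np.cos(z),  # from Example 1
    "exp": lambda a, z: a * np.exp(z),  # the derivative reuses the function itself
}

def reverse(tape, a):
    """Run the reverse pass over a recorded forward tape of (op_name, input) pairs."""
    for op_name, z in reversed(tape):
        a = VJP_RULES[op_name](a, z)
    return a

# Toy usage: forward pass y = sin(exp(z0)), recording inputs on a tape as we go.
z0 = np.array([0.1, 0.2, 0.3])
tape = [("exp", z0), ("sin", np.exp(z0))]
a = np.ones(3)           # seed vector, e.g. dL/dy for L = sum(y)
print(reverse(tape, a))  # dL/dz0, computed purely via vJp rules
```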
How reverse mode vJp helps for backprop in DL
Hopefully example 2 above is already starting to reveal why reverse mode autodiff is chosen for backprop implementations in deep learning libraries.
Reverse mode autodiff helps efficiently calculate $\frac{\partial L}{\partial z}$ when the quantity of interest is a deep composition like $L(f_3(f_2(f_1(z))))$ and the outermost loss function $L$ outputs a scalar.
By the chain rule we'd have

$$\frac{\partial L}{\partial z} = \frac{\partial L}{\partial f_3} \frac{\partial f_3}{\partial f_2} \frac{\partial f_2}{\partial f_1} \frac{\partial f_1}{\partial z} = \left( \left( \frac{\partial L}{\partial f_3} \frac{\partial f_3}{\partial f_2} \right) \frac{\partial f_2}{\partial f_1} \right) \frac{\partial f_1}{\partial z}$$

where we get the second expression by invoking the associativity of matrix multiplication.
Now notice that because $L$ is a scalar, the Jacobian $\frac{\partial L}{\partial f_3}$ is just a row vector! Let's call the transpose of this row vector $v$. Then the expression above becomes

$$\frac{\partial L}{\partial z} = \left( \left( v^T \frac{\partial f_3}{\partial f_2} \right) \frac{\partial f_2}{\partial f_1} \right) \frac{\partial f_1}{\partial z}$$
Starting from the innermost parenthesis, we can just invoke the appropriate vector-Jacobian product function and move out to the next enclosing parenthesis. At that iteration we’ll find ourselves with yet another task of left-multiplying a Jacobian by a row vector, which we can once again compute efficiently using the appropriate vector-Jacobian product function. So on and so forth till we’ve successfully finished all products in the chain.
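To make the left-to-right sweep concrete, here is a tiny hand-rolled example of my own: for $L = \operatorname{sum}(\sin(Wz))$, the gradient $\frac{\partial L}{\partial z}$ falls out of two successive vector-Jacobian products.

```python
import numpy as np

# L = sum(sin(W z)); compute dL/dz by sweeping vJps from the loss inwards.
W = np.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0]])
z = np.array([0.1, 0.2, 0.3])

v = np.ones(2)          # dL/dy for y = sin(W z) and L = sum(y): a row of ones
v = v * np.cos(W @ z)   # vJp through sin: v^T d sin(g)/dg
v = v @ W               # vJp through g = W z: v^T d(W z)/dz
print(v)                # equals cos(W z) @ W, i.e. dL/dz, with no Jacobians stored
```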
Hopefully the previous section and the expression above make the advantage and drawback of backprop abundantly clear.
- The advantage is that we can efficiently calculate the gradient of a scalar loss with respect to a very high-dimensional input (or parameter tensor).
- The drawback is that we need to store in memory all the function evaluations from our forward pass.
Back to Neural ODEs
Circling back to Neural ODEs, remember that we invoke reverse mode autodiff to calculate the vector-Jacobian product $a^T \frac{\partial f}{\partial z}$ between the adjoint vector $a$ and the Jacobian $\frac{\partial f}{\partial z}$ of the evolution function neural network $f$.
But from the backprop expression in the previous section, it looks like this would set off a chain of vector-Jacobian products, essentially giving us backprop, with the only difference being that the vector $v$ is replaced by the adjoint vector $a$.
What's going on here? Aren't Neural ODEs supposed to avoid backprop? 🤔
Well, it would actually be correct to say that this looks very much like backprop. The crucial difference is that we would not be backpropagating gradients from the loss all the way through the potentially huge computation graph created by the ODE solver steps.
Instead we would only be "backpropagating" the adjoint vector $a$ [4] from the output of a single evaluation of $f$ back to its input $z(t)$, which is a very tiny graph in comparison to the entire ODE solver's computation graph. 😀
Hence the claim in Deriving the Adjoint Equation for Neural ODEs Using Lagrange Multipliers that the vector-Jacobian product $a^T \frac{\partial f}{\partial z}$ can be calculated efficiently without explicitly constructing the Jacobian $\frac{\partial f}{\partial z}$.
References
1. For simplicity, I’m considering only unary functions.
2. Griewank A. A mathematical view of automatic differentiation. Acta Numerica. 2003;12:321–398.
3. Clark K. Computing Neural Network Gradients. Stanford CS224n: Natural Language Processing with Deep Learning. https://web.stanford.edu/class/cs224n/readings/gradient-notes.pdf. Published December 20, 2019. Accessed February 25, 2020.
4. I believe the actual term here is pulling back.