L-infinity

Computational fluid dynamics, multiphase flows, machine learning, OpenFOAM

24 Feb 2021

Some calculus with torch::autograd

DOI

Most neural networks learn by minimizing a scalar-valued error. There is ample information about the back-propagation calculus in neural networks, like this great video. Interestingly, back-propagation in neural networks is, in fact, reverse-mode Automatic Differentiation (AD). Automatic differentiation is a computational technique that calculates exact derivatives of arithmetic expressions: it has no truncation error and is accurate up to floating-point round-off. Exact derivatives are crucial for convergence when training neural networks because, as the network learns, the differences in the values of its weights between iterations become smaller. If these barely different values are used to compute network gradients with finite differences, floating-point cancellation errors quickly become catastrophic, destroying convergence. As an alternative to finite differences, exact derivatives of arithmetic expressions can also be calculated symbolically using “sympy” or similar symbolic computation packages. Unfortunately, symbolic derivatives of non-trivial arithmetic expressions soon become intractably complex and hard to translate into source code. Reverse-mode AD comes to the rescue: in this post, it automagically pulls exact derivatives in PyTorch out of its hat. Besides the documentation, AD is covered in detail in (1), which is referenced in (2). There is also a great video describing AD.
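The cancellation argument can be made concrete with a small sketch in plain C++ (no LibTorch; the function names are hypothetical). It approximates $\sin'(1)$ with central finite differences and compares the result against the exact $\cos(1)$. With a moderate step the error stays tiny. When the step shrinks further, the subtraction of two nearly equal values of $\sin$ dominates, and the error grows by several orders of magnitude.

```cpp
#include <cassert>
#include <cmath>

// Central finite-difference approximation of d/dx sin(x) with step h.
// This is the approach AD replaces, not what autograd does.
double central_difference_sin(double x, double h)
{
    return (std::sin(x + h) - std::sin(x - h)) / (2.0 * h);
}

// Absolute error with respect to the exact derivative cos(x).
double central_difference_error(double x, double h)
{
    return std::fabs(central_difference_sin(x, h) - std::cos(x));
}
```

If you call central_difference_error(1.0, 1e-5) and central_difference_error(1.0, 1e-13) from a small driver, the effect is visible: the smaller step is the less accurate one, because the two sine values agree in almost all of their bits.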

Thanks to Alban D from the PyTorch forum for helping me figure out how autograd::grad uses the Jacobian.

In a nutshell, reverse-mode AD in autograd works with Jacobians: matrices that contain the partial derivatives of a function with respect to the components of a vector. These partial derivatives can be combined with each other to construct differential operators. Instead of picking elements from the Jacobian “by hand”, matrix-vector and vector-vector inner products can be used to “select” and combine them. This is what autograd does. autograd::grad does not “only” compute the gradient $\nabla f$ of a function $f$. Instead, it computes the product of the Jacobian and some tensor $\mathbf{v}$, namely $J\cdot\mathbf{v}$. Here $J$ is the Jacobian of $f$ with respect to some tensor, and the contents of $\mathbf{v}$ determine whether $\nabla$, $\nabla\cdot$, $\nabla\times$, or something else entirely is computed. (Strictly speaking, reverse mode forms the vector-Jacobian product $\mathbf{v}^T J$. The two coincide for the symmetric Jacobians used in this post.) When $f$ is not real-valued, for example when it is the vector-valued output of a network rather than a scalar loss, $\mathbf{v}$ has to be chosen explicitly.

The same conclusion is stated in the documentation:

The graph is differentiated using the chain rule. If any of tensors are non-scalar (i.e. their data has more than one element) and require gradient, then the Jacobian-vector product would be computed, in this case the function additionally requires specifying grad_tensors. It should be a sequence of matching length, that contains the “vector” in the Jacobian-vector product, usually the gradient of the differentiated function w.r.t. corresponding tensors (None is an acceptable value for all tensors that don’t need gradient tensors).

In other words, for the gradient of a scalar function we simply call grad. For other operators, or for vector-valued functions, we have to decide which elements of the Jacobian to combine and how, and design the tensor $\mathbf{v}$ in $J \cdot \mathbf{v}$ accordingly. There is nothing special about the Jacobian itself. For example, if $y = f(\mathbf{x}), f : \mathbb{R}^n \to \mathbb{R}$, then

$$ J = [\partial_{x_1} f \ \ \partial_{x_2} f \ \ \dots \ \ \partial_{x_n} f] $$

is a (row) vector in $\mathbb{R}^n$. If $\mathbf{y} = f(\mathbf{x}), f : \mathbb{R}^n \to \mathbb{R}^n$, then $J$ is an $n \times n$ matrix.
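The following plain C++ sketch (no LibTorch) shows how the contents of $\mathbf{v}$ select elements of the Jacobian. It stores the Jacobian of a hypothetical function $f(\mathbf{x}) = [x_1 x_2, \ x_1 + x_2]$ explicitly and forms $\mathbf{v}^T J$, which is the product that reverse mode evaluates. autograd never builds $J$ as a matrix. The explicit matrix is only there to make the selection visible, and the code uses 0-based indices, so x[0], x[1] stand for $x_1, x_2$.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

using Vec2 = std::array<double, 2>;
using Mat2 = std::array<Vec2, 2>; // J[i][j] = d f_i / d x_j

// Jacobian of the illustrative f(x) = [x[0] * x[1], x[0] + x[1]].
Mat2 jacobian(const Vec2& x)
{
    return {{{x[1], x[0]},
             {1.0, 1.0}}};
}

// Vector-Jacobian product v^T J: entry j is sum_i v[i] * d f_i / d x_j.
Vec2 vjp(const Mat2& J, const Vec2& v)
{
    Vec2 r{0.0, 0.0};
    for (std::size_t i = 0; i < 2; ++i)
        for (std::size_t j = 0; j < 2; ++j)
            r[j] += v[i] * J[i][j];
    return r;
}
```

At $\mathbf{x} = [2, 3]$, $\mathbf{v} = [1 \ 0]$ returns $[3 \ 2]$, which is the gradient of the first component. $\mathbf{v} = [1 \ 1]$ returns $[4 \ 3]$, which is the gradient of the sum of both components.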

Example: real-valued function of a real variable

In this simple case, $\sin'(x)$ and $\sin''(x)$ are calculated exactly with torch::autograd.

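    #include <torch/torch.h>
    #include <cassert>
    #include <iomanip>
    #include <iostream>

    int main()
    {
        // x is a leaf tensor whose operations are recorded by autograd.
        auto x = torch::zeros({1}, torch::requires_grad());
        auto sinx = torch::sin(x);

        // retain_graph = true, create_graph = true: keep the graph and record
        // the derivative computation so that it can be differentiated again.
        auto dsinx = torch::autograd::grad({sinx}, {x}, {}, true, true)[0];
        auto dsinx_e = torch::cos(x);

        auto dsinx_error = torch::abs(dsinx_e - dsinx).item<double>();
        std::cout << std::setprecision(20)
            << "dsinx_error = " << dsinx_error << "\n";
        assert(dsinx_error == 0);

        // Second derivative: differentiate the recorded graph of dsinx.
        auto ddsinx = torch::autograd::grad({dsinx}, {x})[0];
        auto ddsinx_e = -torch::sin(x);

        auto ddsinx_error = torch::abs(ddsinx_e - ddsinx).item<double>();
        std::cout << "ddsinx_error = " << ddsinx_error << "\n";
        assert(ddsinx_error == 0);
    }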

The first call to torch::autograd::grad has some additional arguments that need an explanation. Without going into the details of AD, the thing to remember is that AD builds the final arithmetic expression from sub-expressions. Built this way, the expression is represented as a directed acyclic graph. The leaves of the graph are the inputs (here x) and the root is the final expression (here sinx). During the forward evaluation, the information needed for the local partial derivatives of each operation is recorded at the graph nodes. The backward pass then traverses the graph from the root back to the leaves and combines these partial derivatives with the chain rule. The third argument, {}, is the list of tensors $\mathbf{v}$. Since $\sin(x)$ is a real-valued function of a real variable, no $\mathbf{v}$ is needed for the product with the Jacobian. The last two arguments, retain_graph and create_graph, keep the computation graph alive and record the derivative computation itself as a graph. We need both because we also want to compute $\sin''(x)$. From the documentation:

retain_graph: If false, the graph used to compute the grad will be freed. Note that in nearly all cases setting this option to true is not needed and often can be worked around in a much more efficient way. Defaults to the value of create_graph.

create_graph: If true, graph of the derivative will be constructed, allowing to compute higher order derivative products. Default: false.

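The exact derivatives, marked with the _e suffix, are computed by hand (trivially here). The first and second derivatives computed by autograd are exactly equal to their exact counterparts.

The == 0 checks on the derivative errors make this really interesting for anyone who has been bitten by IEEE 754 floating-point arithmetic. Looking at reverse-mode AD as a black box, it delivers the same results as symbolic calculation! This is not a coincidence. AD does not approximate a limit the way finite differences do. It evaluates the known derivative rules of the elementary operations (here $\sin' = \cos$ and $\cos' = -\sin$) in floating point. For more complex expressions, the results therefore agree with symbolic differentiation up to floating-point round-off.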
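Dual numbers give a compact way to see why the error is exactly zero here. Note that dual numbers implement forward-mode AD, which is a different technique from the reverse mode used by torch::autograd. They are used here only because they fit in a few lines. The derivative is obtained by evaluating the rule $\sin' = \cos$ alongside the value, so with a seed of 1 the result is bit-for-bit the same as calling std::cos. The names Dual, dual_sin and dual_sin_derivative are hypothetical.

```cpp
#include <cassert>
#include <cmath>

// Forward-mode dual number: a value together with its derivative.
struct Dual
{
    double value;      // u(x)
    double derivative; // u'(x)
};

// Chain rule for sine: (sin u)' = cos(u) * u'.
Dual dual_sin(const Dual& u)
{
    return {std::sin(u.value), std::cos(u.value) * u.derivative};
}

// Seeding the input with derivative 1 yields d/dx sin(x) at x.
double dual_sin_derivative(double x)
{
    return dual_sin(Dual{x, 1.0}).derivative;
}
```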

Another example from the video on AD shows how partial derivatives are calculated:

    // https://youtu.be/R_m4kanPy6Q?t=458
    // f(x,y) = (x + y) * (y + 3)
    // \partial_x f(x,y) = y + 3 
    // \partial_y f(x,y) = 2y + x + 3 
    // for x = 1, y = 2, 
    // \partial_x f(x,y) = 2 + 3 = 5
    // \partial_y f(x,y) = 2*2 + 1 + 3 = 8
    {
        auto x = torch::ones(1, torch::requires_grad());
        auto y = torch::full_like(x, 2., torch::requires_grad());
        auto f = (x + y) * (y + 3);

        auto partial_x_f = torch::autograd::grad({f}, {x}, {}, true, true)[0];
        assert((partial_x_f.item<double>() == 5));

        auto partial_y_f = torch::autograd::grad({f}, {y})[0];
        assert((partial_y_f.item<double>() == 8));
    }

Note: in the first call to autograd::grad, retain_graph and create_graph are set to true, so that partial_y_f can be calculated by traversing the existing graph. Setting retain_graph alone would suffice here, since no higher-order derivative is taken.
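The graph traversal described above can be sketched with a minimal reverse-mode tape (a Wengert list) in plain C++. This is a from-scratch illustration with hypothetical names (Tape, Node), not how torch::autograd is implemented. Each node stores its value and the local partial derivatives with respect to its parents. The backward pass walks the tape from the root to the leaves and accumulates adjoints with the chain rule. The sketch reproduces the partial derivatives 5 and 8 of $f(x,y) = (x+y)(y+3)$ at $x = 1$, $y = 2$.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// One recorded operation: its value and the local partial derivatives
// d(node)/d(parent) with respect to at most two parent nodes.
struct Node
{
    double value;
    std::size_t parents[2];
    double partials[2];
    std::size_t n_parents;
};

// Minimal reverse-mode tape: nodes are appended in evaluation order, so
// every parent index is smaller than the index of its child.
struct Tape
{
    std::vector<Node> nodes;

    std::size_t record(const Node& node)
    {
        nodes.push_back(node);
        return nodes.size() - 1;
    }
    std::size_t variable(double v) { return record({v, {0, 0}, {0.0, 0.0}, 0}); }
    std::size_t add(std::size_t a, std::size_t b)
    {
        return record({nodes[a].value + nodes[b].value, {a, b}, {1.0, 1.0}, 2});
    }
    std::size_t add_const(std::size_t a, double c)
    {
        return record({nodes[a].value + c, {a, 0}, {1.0, 0.0}, 1});
    }
    std::size_t mul(std::size_t a, std::size_t b)
    {
        // d(ab)/da = b, d(ab)/db = a
        return record({nodes[a].value * nodes[b].value, {a, b},
                       {nodes[b].value, nodes[a].value}, 2});
    }

    // Backward pass: seed the root adjoint with 1, then walk from the root
    // back to the leaves, pushing adjoint * local partial to each parent.
    std::vector<double> gradient(std::size_t root) const
    {
        std::vector<double> adjoint(nodes.size(), 0.0);
        adjoint[root] = 1.0;
        for (std::size_t k = root + 1; k-- > 0;)
            for (std::size_t p = 0; p < nodes[k].n_parents; ++p)
                adjoint[nodes[k].parents[p]] += nodes[k].partials[p] * adjoint[k];
        return adjoint;
    }
};
```

Because $y$ appears in both factors, its adjoint receives two contributions: $5$ through $x + y$ and $3$ through $y + 3$. This is exactly the accumulation the chain rule prescribes.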

Example: real-valued function of a vector variable

In this example, the inner (dot) product of two vectors is used. It is a real-valued function of two vector variables, $f : \mathbb{R}^n \times \mathbb{R}^n \to \mathbb{R}$, namely

$$f(\mathbf{x},\mathbf{y}) = \sum_{i = 1}^{n} x_i y_i$$

The gradient of this real-valued function with respect to one of its input vectors is a vector $\mathbf{g} \in \mathbb{R}^n$. An interesting special case is $||\mathbf{x}||_2^2 = \mathbf{x}\cdot\mathbf{x}$, or

$$f(\mathbf{x},\mathbf{x}) = \sum_{i = 1}^{n} x_i x_i$$

For $n = 3$, the gradient is equal to the Jacobian, a vector in $\mathbb{R}^3$,

$$ \mathbf{g} = J = [\partial_{x_1} f \ \partial_{x_2} f \ \partial_{x_3} f ] = 2[x_1 \ x_2 \ x_3]. $$

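torch::autograd computes this as expected:

    // f(x) = x . x for x = [1 1 1]
    auto x = torch::ones(3, torch::requires_grad());
    auto f = torch::dot(x, x);
    // grad returns one tensor per input, [0] selects the gradient w.r.t. x;
    // create_graph = true records it so that it can be differentiated again.
    auto grad_f_x = torch::autograd::grad({f}, {x}, {}, true, true)[0];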

Since $f = \mathbf{x}\cdot\mathbf{x}$ is scalar-valued, there is no need to specify $\mathbf{v}$ in $J \cdot \mathbf{v}$. This is no longer the case when torch::autograd is used to compute the divergence $\nabla \cdot \mathbf{g}$, because $\mathbf{g}$ is vector-valued. The Jacobian of $\mathbf{g}$ (the Hessian of $f$) is

$$ J_\mathbf{g} = [\partial_{x_1} \mathbf{g} \ \partial_{x_2} \mathbf{g} \ \partial_{x_3} \mathbf{g}] = \begin{bmatrix} \partial^2_{x_1} f & \partial_{x_2} \partial_{x_1} f & \partial_{x_3} \partial_{x_1}f \\\
\partial_{x_1} \partial_{x_2} f & \partial^2_{x_2} f & \partial_{x_3} \partial_{x_2} f \\\
\partial_{x_1} \partial_{x_3} f & \partial_{x_2} \partial_{x_3} f & \partial^2_{x_3} f \end{bmatrix} $$

The divergence $\nabla \cdot \mathbf{g}$ is the trace of $J_\mathbf{g}$, the sum of its diagonal elements. In this example, it can be computed like this

$$ \nabla \cdot \mathbf{g} = \text{trace}(J_\mathbf{g}) = J_g \cdot \begin{bmatrix} 1 \\\
1 \\\
1 \end{bmatrix} \cdot \begin{bmatrix} 1 \\\
1 \\\
1 \end{bmatrix} = 2\begin{bmatrix} 1 & 0 & 0 \\\
0 & 1 & 0 \\\
0 & 0 & 1 \end{bmatrix} \cdot \begin{bmatrix} 1 \\\
1 \\\
1 \end{bmatrix} \cdot \begin{bmatrix} 1 \\\
1 \\\
1 \end{bmatrix} = 6, $$

In torch::autograd this is

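    // v = [1 1 1]; retain_graph and create_graph keep the graph of
    // grad_f_x alive for the next call.
    auto div_grad_f_x_v = torch::autograd::grad(
        {grad_f_x}, {x}, {torch::ones(3)}, true, true
    );
    std::cout << "div(grad(f)) = (J_{grad_f} . [1 1 1]) . [1 1 1] = "
        << div_grad_f_x_v[0].sum() << std::endl;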

Important: using $\mathbf{v} = [1 \ 1 \ 1]^T$ like this only works in this specific example because $J_\mathbf{g}$ is diagonal!

In other words, {torch::ones(3)} can be used in the code snippet above for $\mathbf{v}$ in $J_\mathbf{g} \cdot \mathbf{v}$ only if $J_\mathbf{g}$ is diagonal. Otherwise, the off-diagonal partial derivatives are summed in as well.
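The warning can be checked with explicit $3 \times 3$ matrices standing in for $J_\mathbf{g}$ (plain C++, no LibTorch; function names hypothetical). For $f = \mathbf{x}\cdot\mathbf{x}$, the shortcut $(J_\mathbf{g} \cdot [1 \ 1 \ 1]) \cdot [1 \ 1 \ 1]$ and the trace agree. For the illustrative choice $f = x_1 x_2 + x_3^2$, however, the Hessian has off-diagonal entries, and the shortcut also sums the mixed partial derivatives.

```cpp
#include <array>
#include <cassert>

using Mat3 = std::array<std::array<double, 3>, 3>;

// Divergence of g = grad f: the trace of J_g, i.e. the sum of diagonal entries.
double trace(const Mat3& J)
{
    return J[0][0] + J[1][1] + J[2][2];
}

// The shortcut (J . [1 1 1]) . [1 1 1]: it adds up every entry of J,
// so it equals the trace only when the off-diagonal entries vanish.
double sum_of_j_times_ones(const Mat3& J)
{
    double s = 0.0;
    for (const auto& row : J)
        for (double entry : row)
            s += entry;
    return s;
}
```

For the Hessian $2I$ both functions return 6. For the Hessian of $x_1 x_2 + x_3^2$, the trace is 2 but the shortcut returns 4.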

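A general solution that picks up only the diagonal elements of $J_\mathbf{g}$ uses the columns $\mathbf{e}_i$ of the identity matrix $I$, one at a time. Each product $\mathbf{e}_i^T J_\mathbf{g}$ returns the $i$-th row of $J_\mathbf{g}$, and the $i$-th entry of that row is the diagonal element $\partial_{x_i} g_i$. Multiplying by the whole identity at once does not help. Because $J_\mathbf{g} \cdot I = J_\mathbf{g}$, the product $(J_\mathbf{g} \cdot I) \cdot [1 \ 1 \ 1]^T = J_\mathbf{g} \cdot [1 \ 1 \ 1]^T$ still has entries that add up to the sum of all entries of $J_\mathbf{g}$, not only the diagonal ones.

$$ \nabla \cdot \mathbf{g} = \text{trace}(J_\mathbf{g}) = \sum_{i=1}^{3} \mathbf{e}_i^T J_\mathbf{g} \, \mathbf{e}_i = 2 + 2 + 2 = 6. $$

In torch::autograd, this takes one call per basis vector:

    // e_i^T J_g, with e_i = I[i], is the i-th row of J_g; keep its i-th entry.
    auto I = torch::eye(3);
    auto div_grad_f_x = torch::zeros({1});
    for (int64_t i = 0; i < 3; ++i) {
        auto row_i = torch::autograd::grad({grad_f_x}, {x}, {I[i]}, true)[0];
        div_grad_f_x += row_i[i];
    }
    std::cout << "div(grad(f)) = sum_i (e_i^T . J_{grad_f})_i = "
        << div_grad_f_x << std::endl;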
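Reverse mode returns only $\mathbf{v}^T J$ for one chosen $\mathbf{v}$ per backward pass. For a general Jacobian, the diagonal therefore has to be collected one basis vector $\mathbf{e}_i$ at a time, keeping the $i$-th entry of each $\mathbf{e}_i^T J$. Below is a matrix-free plain C++ sketch of this idea. A hypothetical hand-written VJP stands in for autograd::grad, and the same illustrative $f = x_1 x_2 + x_3^2$ serves as a non-diagonal case.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

using Vec = std::vector<double>;

// Matrix-free vector-Jacobian product: given v, return v^T J.
// This stands in for one call to a reverse-mode AD tool; J is never formed.
using VJP = std::function<Vec(const Vec&)>;

// trace(J) = sum_i (e_i^T J)_i: one VJP per basis vector e_i.
// If n_calls is given, it counts the backward passes performed.
double trace_from_vjps(const VJP& vjp, std::size_t n, std::size_t* n_calls = nullptr)
{
    double trace = 0.0;
    for (std::size_t i = 0; i < n; ++i)
    {
        Vec e(n, 0.0);
        e[i] = 1.0;         // i-th column of the identity
        trace += vjp(e)[i]; // keep only the diagonal entry J_ii
        if (n_calls != nullptr)
            ++*n_calls;
    }
    return trace;
}

// Illustrative VJP of g = grad f for f = x1*x2 + x3^2, whose Jacobian
// (the Hessian of f) is [[0 1 0], [1 0 0], [0 0 2]].
Vec hessian_vjp_example(const Vec& v)
{
    return {v[1], v[0], 2.0 * v[2]};
}
```

trace_from_vjps(hessian_vjp_example, 3, &calls) returns 2 after 3 calls, while a single product with $[1 \ 1 \ 1]$ sums to 4. The price of generality is $n$ backward passes instead of one.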

Example: vector-valued function of a vector variable

Say we have a tensor of input values $\mathbf{x} \in \mathbb{R}^n$ and evaluate a function $g$ at each $x_i$, so that $y_i = g(x_i)$ and $\mathbf{y} \in \mathbb{R}^n$. This is a vector-valued function of a vector variable, $f : \mathbb{R}^n \to \mathbb{R}^n$, that applies $g$ to each $x_i$. How can tensors (sequences) of the individual derivatives $y_i', y_i'', \dots$ be computed with torch::autograd?

Let's use, for example, $g(x_i) = \sin(x_i)$ and $\mathbf{x} = [0, 0.001, 0.002, \dots, 1]$, so that $n = 1001$.

Since

$$\mathbf{y} = f(\mathbf{x}) = [\sin(x_1) \ \sin(x_2) \ \dots \ \sin(x_n)],$$

the Jacobian

$$J_f = [\partial_{x_1} \mathbf{y} \ \partial_{x_2} \mathbf{y} \ \dots \ \partial_{x_n} \mathbf{y}] $$

is again diagonal, namely

$$ J_f = \begin{bmatrix} \cos(x_1) & 0 & \cdots & 0 \\\
0 & \cos(x_2) & \cdots & 0 \\\
\vdots & \vdots & \ddots & \vdots \\\
0 & 0 & \cdots & \cos(x_n) \end{bmatrix}. $$

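Because $J_f$ is diagonal, a single product with $\mathbf{v} = [1 \ 1 \ \dots \ 1]^T$, exactly as in the previous example, already returns the whole diagonal. In other words, it returns all the individual derivatives at once:

$$ J_f \cdot [1 \ 1 \ \dots \ 1]^T = [\cos(x_1) \ \cos(x_2) \ \dots \ \cos(x_n)]^T = [y_1' \ y_2' \ \dots \ y_n']^T. $$

Repeating the call on this result (with create_graph set to true in the first call) gives the $y_i''$. The general approach, one product $\mathbf{e}_i^T J_f$ per column of the identity, returns the same diagonal but needs $n$ backward passes instead of one. Depending on $n$ (here $n = 1001$), the difference in computational time can be significant.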

Important: when working with torch::autograd, know your Jacobian.
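The cost argument can be sketched in plain C++ (no LibTorch) with a hand-written VJP of the element-wise sine. All names below are hypothetical. Both functions return the same vector $[\cos(x_1) \ \dots \ \cos(x_n)]$ for the 1001 points from the text. The per-basis-vector version, however, performs $n$ passes of $O(n)$ work each, which is $O(n^2)$ instead of $O(n)$.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Vec = std::vector<double>;

// VJP of the element-wise map y_i = sin(x_i): its Jacobian is
// diag(cos(x_1), ..., cos(x_n)), so (v^T J)_i = v_i * cos(x_i).
Vec sin_vjp(const Vec& x, const Vec& v)
{
    Vec r(x.size());
    for (std::size_t i = 0; i < x.size(); ++i)
        r[i] = v[i] * std::cos(x[i]);
    return r;
}

// Exploit the diagonal Jacobian: one pass with v = [1 ... 1].
Vec derivatives_one_pass(const Vec& x)
{
    return sin_vjp(x, Vec(x.size(), 1.0));
}

// General approach: one pass per basis vector e_i, keeping entry i.
// Same result, but n passes of O(n) work each.
Vec derivatives_per_basis_vector(const Vec& x)
{
    Vec d(x.size());
    for (std::size_t i = 0; i < x.size(); ++i)
    {
        Vec e(x.size(), 0.0);
        e[i] = 1.0;
        d[i] = sin_vjp(x, e)[i];
    }
    return d;
}

// x = [0, 0.001, ..., 1]: n = 1001 points as in the text.
Vec sample_points()
{
    Vec x(1001);
    for (std::size_t i = 0; i < x.size(); ++i)
        x[i] = 0.001 * static_cast<double>(i);
    return x;
}
```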

Summary

torch::autograd uses reverse-mode Automatic Differentiation to compute derivatives exactly, up to floating-point round-off. Each call computes the product of the Jacobian of an expression with a user-defined tensor $\mathbf{v}$, $J\cdot \mathbf{v}$, without forming $J$ explicitly. The contents of $\mathbf{v}$, together with the matrix-vector operations applied to the result, determine which differential operator is computed, such as $\nabla$, $\nabla \cdot$ and $\nabla \times$.

Data

The code is available on GitLab.

References

(1) Griewank, A., & Walther, A. (2008). Evaluating derivatives: principles and techniques of algorithmic differentiation. Society for Industrial and Applied Mathematics.

(2) Paszke, A., Gross, S., Massa, F., Lerer, A., Bradbury, J., Chanan, G., … Chintala, S. (2019). PyTorch: An imperative style, high-performance deep learning library. Advances in Neural Information Processing Systems, 32(NeurIPS).