Adrian Hill

The chain rule is all you need: an introduction to mathematical foundations of automatic differentiation.


Motivation

To apply gradient-based optimization methods such as stochastic gradient descent to a neural network, we need to compute the gradient of its loss function with respect to its parameters.

Since deep learning models can get large and complicated, it would be nice to have machinery that can take an arbitrary function $f: \mathbb{R}^{n} \rightarrow \mathbb{R}^{m}$ and return its derivative. This is called automatic differentiation (AD).

The Julia AD ecosystem

Julia has more than a dozen AD systems. A summary of available packages can be found at juliadiff.org. The list is sorted by type:

  • Reverse-mode
  • Forward-mode
  • Symbolic
  • Finite differencing

and other more exotic approaches. These are already abstract-sounding terms, but within these categories, there are further differences:

  • Is the AD system based on operator overloading or on source-to-source transformation?
  • Which representation level does it operate on?
  • Does it only work on scalar functions?
  • Does it allow higher-order AD?

As you may not be familiar with these terms, the goal of this lecture is to explain how the approaches of the various AD packages differ and to outline their pros and cons.

For this purpose, we will take a step back and start with a recapitulation of two fundamental mathematical concepts: linear maps and derivatives.

Linear maps

Properties

Linear maps, also called linear transformations, are functions with the following properties:

| Property | Equation |
|---|---|
| Additivity | $f(v_1+v_2) = f(v_1) + f(v_2)$ |
| Homogeneity | $f(\lambda v) = \lambda f(v)$ |

(Figures: for each property, one example in which it is satisfied and one in which it is not. Visualizations by Stephan Kulla, CC0.)

A mathematically more rigorous definition:

Given two vector spaces $V, W$ over a field $K$, a function $f:V\rightarrow W$ is called a linear map if additivity and homogeneity hold for all vectors $v_1, v_2, v \in V$ and all scalars $\lambda \in K$.
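The two properties can be spot-checked numerically. The following is a minimal sketch in plain Python; the helper `looks_linear` and the example maps `rotate_scale` and `shifted` are hypothetical names chosen for illustration. Passing the check on a few samples does not prove linearity, but failing it disproves it.

```python
def add(u, v):
    return [a + b for a, b in zip(u, v)]

def scale(lam, v):
    return [lam * a for a in v]

def close(u, v, tol=1e-9):
    return all(abs(a - b) <= tol for a, b in zip(u, v))

def looks_linear(f, samples, scalars):
    # Additivity: f(v1 + v2) == f(v1) + f(v2)
    additive = all(
        close(f(add(v1, v2)), add(f(v1), f(v2)))
        for v1 in samples for v2 in samples
    )
    # Homogeneity: f(lam * v) == lam * f(v)
    homogeneous = all(
        close(f(scale(lam, v)), scale(lam, f(v)))
        for lam in scalars for v in samples
    )
    return additive and homogeneous

def rotate_scale(v):
    # Hypothetical linear map R^2 -> R^2
    x, y = v
    return [2 * x - y, x + 3 * y]

def shifted(v):
    # Hypothetical affine map: not linear, since shifted([0, 0]) != [0, 0]
    x, y = v
    return [x + 1, y]

samples = [[1.0, 0.0], [0.0, 1.0], [2.0, -3.0]]
scalars = [0.0, -1.5, 4.0]

print(looks_linear(rotate_scale, samples, scalars))  # True
print(looks_linear(shifted, samples, scalars))       # False
```

The affine map fails both properties, which is why a function such as $x \mapsto x + 1$ is not a linear map in this sense, even though its graph is a straight line.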

Connection to matrices

Every linear map $f$ between two finite-dimensional vector spaces $V, W$ can be represented as a matrix, given a basis for each vector space (e.g. the standard basis).

A linear map $f: \mathbb{R}^{n} \rightarrow \mathbb{R}^{m}$ can be represented as

$$ f(x) = Ax $$

where $A$ is an $m \times n$ matrix and $x \in \mathbb{R}^{n}$.
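One way to make this concrete: in the standard basis, the $j$-th column of $A$ is the image $f(e_j)$ of the $j$-th standard basis vector $e_j$. The sketch below recovers $A$ this way in plain Python. The names `matrix_of`, `matvec` and the example map `f` are assumptions made for illustration only.

```python
def matrix_of(f, n):
    """Return the matrix of a linear map f: R^n -> R^m as a list of rows."""
    columns = []
    for j in range(n):
        e_j = [1.0 if i == j else 0.0 for i in range(n)]
        columns.append(f(e_j))  # the j-th column of A is f(e_j)
    m = len(columns[0])
    # Transpose the list of columns into a list of rows
    return [[columns[j][i] for j in range(n)] for i in range(m)]

def matvec(A, x):
    return [sum(a * b for a, b in zip(row, x)) for row in A]

def f(v):
    # Hypothetical linear map R^3 -> R^2
    x, y, z = v
    return [x + 2 * y, 3 * z - y]

A = matrix_of(f, 3)
x = [1.0, 2.0, 3.0]

print(A)                       # rows of the 2 x 3 matrix
print(matvec(A, x) == f(x))    # True: f(x) = Ax
```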

Composition

Connection to matrix multiplication

The composition $h(x) = g(f(x))$ of two linear maps $f: V \rightarrow W$, $g: W \rightarrow Z$ is also a linear map $h: V \rightarrow Z$.

In finite-dimensional vector spaces, the composition of linear maps corresponds to matrix multiplication:

$$ \begin{aligned} f(x) &= Fx &, \enspace &f: \mathbb{R}^{m} \rightarrow \mathbb{R}^{n} &, \enspace &F \in \mathbb{R}^{n \times m} \\ g(x) &= Gx &, \enspace &g: \mathbb{R}^{n} \rightarrow \mathbb{R}^{p} &, \enspace &G \in \mathbb{R}^{p \times n} \\ h(x) = (g \circ f)(x) &= (G \cdot F) x = Hx &, \enspace &h: \mathbb{R}^{m} \rightarrow \mathbb{R}^{p} &, \enspace &H \in \mathbb{R}^{p \times m} \\ \end{aligned} $$
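As a small numerical sketch of this correspondence, the following Python snippet uses $m = 2$, $n = 3$, $p = 2$ and checks that applying $F$ and then $G$ gives the same result as applying the product $H = G \cdot F$ once. The matrices and the helpers `matmul` and `matvec` are made up for illustration.

```python
def matvec(A, x):
    return [sum(a * b for a, b in zip(row, x)) for row in A]

def matmul(A, B):
    inner, cols = len(B), len(B[0])
    return [
        [sum(A[i][k] * B[k][j] for k in range(inner)) for j in range(cols)]
        for i in range(len(A))
    ]

F = [[1.0, 2.0], [0.0, 1.0], [3.0, -1.0]]   # F in R^{3x2}, f: R^2 -> R^3
G = [[1.0, 0.0, 2.0], [-1.0, 1.0, 0.0]]     # G in R^{2x3}, g: R^3 -> R^2
H = matmul(G, F)                            # H in R^{2x2}, h: R^2 -> R^2

x = [2.0, -1.0]
composed = matvec(G, matvec(F, x))   # g(f(x))
direct = matvec(H, x)                # h(x) = Hx

print(H)
print(composed == direct)  # True
```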

Connection to matrix addition

The sum of two linear maps $f_1: V \rightarrow W$, $f_2: V \rightarrow W$ is also a linear map:

$$ (f_1 + f_2)(x) = f_1(x) + f_2(x) \quad . $$

In finite-dimensional vector spaces, the addition of linear maps corresponds to matrix addition.

For $f_1, f_2: \mathbb{R}^{n} \rightarrow \mathbb{R}^{m}$ and $A, B \in \mathbb{R}^{m \times n}$,

$$ \begin{aligned} f_1(x) &= Ax \\ f_2(x) &= Bx \\ (f_1 + f_2)(x) &= (A+B)x \quad . \end{aligned} $$
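As a quick worked example with arbitrarily chosen numbers, take $n = m = 2$:

$$ A = \begin{bmatrix} 1 & 2 \\ 0 & 1 \end{bmatrix}, \quad B = \begin{bmatrix} 0 & -1 \\ 3 & 1 \end{bmatrix}, \quad x = \begin{bmatrix} 1 \\ 2 \end{bmatrix} \quad . $$

Evaluating both maps separately and adding the results gives

$$ f_1(x) + f_2(x) = Ax + Bx = \begin{bmatrix} 5 \\ 2 \end{bmatrix} + \begin{bmatrix} -2 \\ 5 \end{bmatrix} = \begin{bmatrix} 3 \\ 7 \end{bmatrix} \quad , $$

which agrees with adding the matrices first:

$$ (A+B)x = \begin{bmatrix} 1 & 1 \\ 3 & 2 \end{bmatrix} \begin{bmatrix} 1 \\ 2 \end{bmatrix} = \begin{bmatrix} 3 \\ 7 \end{bmatrix} \quad . $$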

Derivatives

What is a derivative?

The (total) derivative of a function $f: \mathbb{R}^{n} \rightarrow \mathbb{R}^{m}$ at a point $\tilde{x} \in \mathbb{R}^{n}$ is the linear map that best approximates the change in $f$ near the point $\tilde{x}$.

We give the derivative the symbol $\mathcal{D}f_{\tilde{x}}$. You can read this as "$\mathcal{D}$erivative of $f$ at $\tilde{x}$".

Most importantly, the derivative is a linear map

$$ \mathcal{D}f_{\tilde{x}}: \mathbb{R}^{n} \rightarrow \mathbb{R}^{m} \quad . $$

Let's visualize this on a simple scalar function:

$$ f(x) = x^2 - 5 \sin(x) - 10 $$

Plot showing the function f(x) = x² - 5sin(x) - 10 in blue, its derivative as an orange line through the origin, and the first-order Taylor approximation in green tangent to f at the expansion point.

The orange line $\mathcal{D}f_{\tilde{x}}$ is of greatest interest to us. Notice that, being a linear map, the derivative always passes through the origin $(x,y)=(0,0)$: homogeneity with $\lambda = 0$ forces $\mathcal{D}f_{\tilde{x}}(0) = 0$.

Using $\mathcal{D}f_{\tilde{x}}$, we can construct the first-order Taylor approximation of $f$ around $\tilde{x}$ (shown in green). For points close to $\tilde{x},$

$$ f(x) \approx f(\tilde{x}) + \mathcal{D}f_{\tilde{x}}(x-\tilde{x}) \quad . $$
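For the scalar function above, the derivative at $\tilde{x}$ is the linear map $v \mapsto (2\tilde{x} - 5\cos(\tilde{x}))\, v$. The following Python sketch builds this map and the resulting first-order approximation, then measures the approximation error at a few distances from $\tilde{x}$. The expansion point `x_tilde = 2.0` and the names `df_at` and `taylor1` are assumptions made for illustration; the lecture does not fix a particular point.

```python
import math

def f(x):
    return x**2 - 5 * math.sin(x) - 10

def df_at(x_tilde):
    """Derivative of f at x_tilde, returned as a linear map v -> slope * v."""
    slope = 2 * x_tilde - 5 * math.cos(x_tilde)
    return lambda v: slope * v

def taylor1(x_tilde):
    """First-order approximation x -> f(x_tilde) + Df(x - x_tilde)."""
    D = df_at(x_tilde)
    return lambda x: f(x_tilde) + D(x - x_tilde)

x_tilde = 2.0  # hypothetical expansion point
approx = taylor1(x_tilde)

steps = (1.0, 0.1, 0.01)
errors = [abs(f(x_tilde + dx) - approx(x_tilde + dx)) for dx in steps]

for dx, err in zip(steps, errors):
    print(f"dx = {dx:<5} error = {err:.2e}")
```

The error shrinks roughly quadratically with the distance to $\tilde{x}$, which is what "linear approximation near $\tilde{x}$" means in practice.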

Differentiability

From your calculus classes, you might recall that a function $f: \mathbb{R} \rightarrow \mathbb{R}$ is differentiable at $\tilde{x}$ if there is a number $f'(\tilde{x})$ such that

$$ \lim_{h \rightarrow 0} \frac{f(\tilde{x} + h) - f(\tilde{x})}{h} = f'(\tilde{x}) \quad . $$

This number $f'(\tilde{x})$ is called the derivative of $f$ at $\tilde{x}$.

We can now extend this notion to multivariate functions: A function $f: \mathbb{R}^{n} \rightarrow \mathbb{R}^{m}$ is totally differentiable at a point $\tilde{x}$ if there exists a linear map $\mathcal{D}f_{\tilde{x}}$ such that

$$ \lim_{h \rightarrow 0} \frac{|f(\tilde{x} + h) - f(\tilde{x}) - \mathcal{D}f_{\tilde{x}}(h)|}{|h|} = 0 \quad . $$
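To get a feeling for this limit, the sketch below evaluates the quotient for a hand-picked function $f: \mathbb{R}^{2} \rightarrow \mathbb{R}^{2}$ and shrinking steps $h$ along one fixed direction. The function, its hand-derived derivative `Df`, and the helper `remainder_ratio` are hypothetical examples. Checking one direction only illustrates the definition; it does not establish differentiability.

```python
import math

def f(x):
    # Hypothetical example function R^2 -> R^2
    x1, x2 = x
    return [x1 * x2, x1 + x2**2]

def Df(x_tilde):
    """Candidate derivative at x_tilde, as a linear map h -> J h."""
    x1, x2 = x_tilde
    J = [[x2, x1], [1.0, 2 * x2]]  # partial derivatives worked out by hand
    return lambda h: [sum(a * b for a, b in zip(row, h)) for row in J]

def norm(v):
    return math.sqrt(sum(a * a for a in v))

def remainder_ratio(x_tilde, h):
    """|f(x_tilde + h) - f(x_tilde) - Df(h)| / |h|"""
    shifted = [a + b for a, b in zip(x_tilde, h)]
    remainder = [
        a - b - c for a, b, c in zip(f(shifted), f(x_tilde), Df(x_tilde)(h))
    ]
    return norm(remainder) / norm(h)

x_tilde = [1.0, 2.0]
direction = [0.6, -0.8]  # unit vector
scales = (1e-1, 1e-2, 1e-3)
ratios = [
    remainder_ratio(x_tilde, [t * d for d in direction]) for t in scales
]

for t, r in zip(scales, ratios):
    print(f"|h| = {t:.0e}  ratio = {r:.2e}")
```

For this example the ratio shrinks in proportion to $|h|$, consistent with the limit being zero.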

Jacobians

Linear maps $f: \mathbb{R}^{n} \rightarrow \mathbb{R}^{m}$ can be represented as $m \times n$ matrices.

In the standard basis, the matrix corresponding to $\mathcal{D}f$ is called the Jacobian:

$$ J_f = \begin{bmatrix} \dfrac{\partial f_1}{\partial x_1} & \cdots & \dfrac{\partial f_1}{\partial x_n}\\ \vdots & \ddots & \vdots\\ \dfrac{\partial f_m}{\partial x_1} & \cdots & \dfrac{\partial f_m}{\partial x_n} \end{bmatrix} $$

Note that every entry $[J_f]_{ij}=\frac{\partial f_i}{\partial x_j}$ in this matrix is a scalar-valued function $\mathbb{R}^{n} \rightarrow \mathbb{R}$ of the input $x$.

If we evaluate the Jacobian at a specific point $\tilde{x}$, we get the matrix corresponding to $\mathcal{D}f_{\tilde{x}}$:

$$ J_f\big|_{\tilde{x}} = \begin{bmatrix} \dfrac{\partial f_1}{\partial x_1}\Bigg|_{\tilde{x}} & \cdots & \dfrac{\partial f_1}{\partial x_n}\Bigg|_{\tilde{x}}\\ \vdots & \ddots & \vdots\\ \dfrac{\partial f_m}{\partial x_1}\Bigg|_{\tilde{x}} & \cdots & \dfrac{\partial f_m}{\partial x_n}\Bigg|_{\tilde{x}} \end{bmatrix} \in \mathbb{R}^{m \times n} $$
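As a worked example with an arbitrarily chosen function, consider $f: \mathbb{R}^{2} \rightarrow \mathbb{R}^{2}$ with

$$ f(x) = \begin{bmatrix} x_1 x_2 \\ x_1 + x_2^2 \end{bmatrix} \quad . $$

Taking the partial derivatives of each output with respect to each input gives the Jacobian

$$ J_f = \begin{bmatrix} x_2 & x_1 \\ 1 & 2 x_2 \end{bmatrix} \quad , $$

whose entries are still functions of $x$. Evaluating it at the point $\tilde{x} = (1, 2)$ yields a matrix of plain numbers:

$$ J_f\big|_{\tilde{x}} = \begin{bmatrix} 2 & 1 \\ 1 & 4 \end{bmatrix} \in \mathbb{R}^{2 \times 2} \quad . $$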

Jacobian-Vector products

As we have seen in the example above, the total derivative[2]

$$ \mathcal{D}f_{\tilde{x}}(v) = J_f\big|_{\tilde{x}} \cdot v $$

computes a Jacobian-Vector product. It is also called the "pushforward" and is one of the two core primitives behind AD systems:

  1. Jacobian-Vector products (JVPs) computed by the pushforward, used in forward-mode AD
  2. Vector-Jacobian products (VJPs) computed by the pullback, used in reverse-mode AD
[2] In our notation, all vectors $v$ are column vectors and row vectors are written as $v^T$.
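The two primitives can be sketched in Python with an explicit Jacobian matrix. This is for illustration only: the matrix `J` and the helpers `jvp` and `vjp` are made-up examples, and AD systems typically compute these products without ever materializing the full Jacobian.

```python
def jvp(J, v):
    """Jacobian-vector product J v: maps v in R^n to a vector in R^m."""
    return [sum(a * b for a, b in zip(row, v)) for row in J]

def vjp(w, J):
    """Vector-Jacobian product w^T J: maps w in R^m to a vector in R^n."""
    m, n = len(J), len(J[0])
    return [sum(w[i] * J[i][j] for i in range(m)) for j in range(n)]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

# Hypothetical Jacobian of some f: R^3 -> R^2 at a fixed point
J = [[2.0, 1.0, 0.0],
     [1.0, 4.0, -1.0]]

# A JVP with a standard basis vector of R^3 extracts a column of J ...
first_column = jvp(J, [1.0, 0.0, 0.0])
# ... while a VJP with a standard basis vector of R^2 extracts a row.
first_row = vjp([1.0, 0.0], J)

v = [1.0, 2.0, 3.0]
w = [1.0, 2.0]
# Both groupings of w^T J v give the same scalar.
print(dot(w, jvp(J, v)) == dot(vjp(w, J), v))  # True
```

The column/row observation hints at why the two modes have different costs: recovering a full $m \times n$ Jacobian takes $n$ JVPs (one per input) or $m$ VJPs (one per output).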

Chain rule

Let's look at a function $h(x)=g(f(x))$ composed of two differentiable functions $f: \mathbb{R}^{n} \rightarrow \mathbb{R}^{m}$ and $g: \mathbb{R}^{m} \rightarrow \mathbb{R}^{p}$:

$$ h = g \circ f \quad . $$

Since derivatives are linear maps, we can obtain the derivative of $h$ by composing the derivatives of $g$ and $f$ using the chain rule:

$$ \mathcal{D}h_{\tilde{x}} = \mathcal{D}(g \circ f)_{\tilde{x}} = \mathcal{D}g_{f(\tilde{x})} \circ \mathcal{D}f_{\tilde{x}} $$

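As we have seen in the section on linear maps, this composition of linear maps is also a linear map. In the standard basis, it corresponds to a simple matrix multiplication of Jacobians, $J_h\big|_{\tilde{x}} = J_g\big|_{f(\tilde{x})} \cdot J_f\big|_{\tilde{x}}$.[3]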
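The chain rule can be checked numerically on a small example. In the Python sketch below, `f`, `g` and their hand-derived Jacobians `jac_f` and `jac_g` are hypothetical, with $n = 2$, $m = 2$, $p = 1$. The product of the two Jacobians is compared against a central finite-difference estimate of the Jacobian of $h = g \circ f$, which serves only as an independent check.

```python
import math

def f(x):
    x1, x2 = x
    return [x1 * x2, x1 + x2**2]

def jac_f(x):
    x1, x2 = x
    return [[x2, x1], [1.0, 2 * x2]]

def g(y):
    y1, y2 = y
    return [math.sin(y1) + y1 * y2]

def jac_g(y):
    y1, y2 = y
    return [[math.cos(y1) + y2, y1]]

def h(x):
    return g(f(x))

def matmul(A, B):
    inner, cols = len(B), len(B[0])
    return [
        [sum(A[i][k] * B[k][j] for k in range(inner)) for j in range(cols)]
        for i in range(len(A))
    ]

def jac_h_chain(x):
    # Chain rule: Jacobian of g evaluated at f(x), times Jacobian of f at x
    return matmul(jac_g(f(x)), jac_f(x))

def jac_fd(func, x, eps=1e-6):
    # Central finite differences, one input coordinate at a time
    columns = []
    for j in range(len(x)):
        xp, xm = list(x), list(x)
        xp[j] += eps
        xm[j] -= eps
        columns.append(
            [(a - b) / (2 * eps) for a, b in zip(func(xp), func(xm))]
        )
    return [[columns[j][i] for j in range(len(x))]
            for i in range(len(columns[0]))]

x_tilde = [1.0, 2.0]
J_chain = jac_h_chain(x_tilde)   # 1 x 2 matrix
J_fd = jac_fd(h, x_tilde)        # 1 x 2 matrix

max_diff = max(
    abs(a - b) for ra, rb in zip(J_chain, J_fd) for a, b in zip(ra, rb)
)
print(J_chain)
print(max_diff < 1e-5)  # True
```

Note that `jac_g` is evaluated at `f(x)`, not at `x`, mirroring the subscript $f(\tilde{x})$ in the chain rule.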

[3] A proof of the chain rule can be found on page 19 of Spivak's Calculus on Manifolds.