Getting Started¶
This tutorial walks through the three stages of automatic sparse differentiation: detection, coloring, and decompression.
The Problem¶
For a function \(f: \mathbb{R}^n \to \mathbb{R}^m\), computing the full Jacobian \(J \in \mathbb{R}^{m \times n}\) requires \(n\) forward-mode or \(m\) reverse-mode AD passes. In practice, for example in scientific machine learning, many Jacobians are sparse (i.e., most entries are structurally zero, regardless of the input).
asdex exploits this sparsity in three steps:
1. Detect the sparsity pattern by tracing the computation graph
2. Color the pattern so that structurally orthogonal rows (or columns) share a color
3. Decompress one AD pass per color into the sparse Jacobian or Hessian
This reduces the computational cost from \(m\) (or \(n\)) AD passes to just the number of colors, yielding significant speedups on large sparse problems, especially when the cost of detection and coloring can be amortized over repeated evaluations. The same approach applies to sparse Hessians via forward-over-reverse AD.
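The easiest case makes step 3 concrete: for a purely elementwise function the Jacobian is diagonal, so all columns are structurally orthogonal and share a single color. A minimal sketch in plain JAX (not using asdex) shows that one JVP with an all-ones seed reads off the entire diagonal:

```python
import jax
import jax.numpy as jnp

def g(x):
    return x ** 2  # elementwise: output i depends only on input i

x = jnp.array([1.0, 2.0, 3.0])

# The Jacobian is diagonal, so all columns share one color:
# a single JVP with an all-ones seed sums each row of J,
# which here just reads off the diagonal.
_, compressed = jax.jvp(g, (x,), (jnp.ones_like(x),))
assert jnp.allclose(compressed, 2 * x)  # diagonal of J is 2x
```

With \(n\) colors this would have cost \(n\) JVPs; with one color it costs exactly one, independent of the input dimension.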
Sparse Jacobians¶
Consider the squared differences function \(f(x)_i = (x_{i+1} - x_i)^2\), which has a banded Jacobian. Detect the sparsity pattern and color it in one step:
```python
import jax.numpy as jnp
from asdex import jacobian_coloring

def f(x):
    return (x[1:] - x[:-1]) ** 2

x = jnp.ones(50)
coloring = jacobian_coloring(f, input_shape=x.shape)
```
```
ColoredPattern(49×50, nnz=98, sparsity=96.0%, JVP, 2 colors)
2 JVPs (instead of 49 VJPs or 50 JVPs)
⎡⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎤   ⎡⣿⎤
⎢⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥   ⎢⣿⎥
⎢⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥   ⎢⣿⎥
⎢⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥   ⎢⣿⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥   ⎢⣿⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥   ⎢⣿⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ → ⎢⣿⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⎥   ⎢⣿⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⎥   ⎢⣿⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⎥   ⎢⣿⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⎥   ⎢⣿⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⎥   ⎢⣿⎥
⎣⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⎦   ⎣⠉⎦
```
The print-out shows the original sparsity pattern (left) compressed into just two colors (right). asdex automatically ran multiple coloring algorithms and selected column coloring (2 JVPs) over row coloring (2 VJPs), since JVPs are cheaper. This reduces the cost from 49 VJPs or 50 JVPs without coloring to just 2 JVPs. Note that on small problems this does not translate directly into a ~25x speedup, as the decompression overhead dominates.
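The two-color compression can be reproduced by hand with plain JAX (independent of asdex). Each row of this banded Jacobian touches columns \(i\) and \(i+1\), exactly one even and one odd, so seeding one JVP with the even columns and one with the odd columns captures every nonzero. The sketch below also picks an input with nonzero differences, since at `x = jnp.ones(50)` this particular Jacobian is identically zero:

```python
import jax
import jax.numpy as jnp

def f(x):
    return (x[1:] - x[:-1]) ** 2

n = 50
x = jnp.linspace(0.0, 1.0, n) ** 2  # nonzero differences between neighbors

cols = jnp.arange(n)
seed_even = (cols % 2 == 0).astype(x.dtype)
seed_odd = (cols % 2 == 1).astype(x.dtype)

# One JVP per color instead of 50 JVPs
_, c_even = jax.jvp(f, (x,), (seed_even,))
_, c_odd = jax.jvp(f, (x,), (seed_odd,))

# Decompress: row i has nonzeros at columns i and i+1,
# exactly one of which is even and one odd.
rows = jnp.arange(n - 1)
diag = jnp.where(rows % 2 == 0, c_even, c_odd)       # entries J[i, i]
superdiag = jnp.where(rows % 2 == 1, c_even, c_odd)  # entries J[i, i+1]

J_rec = (
    jnp.zeros((n - 1, n))
    .at[rows, rows].set(diag)
    .at[rows, rows + 1].set(superdiag)
)
assert jnp.allclose(J_rec, jax.jacobian(f)(x), atol=1e-5)
```

This is exactly what decompression does internally, except that asdex scatters the compressed entries into a sparse matrix instead of a dense one.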
Global Sparsity Patterns
The detected pattern is a global sparsity pattern: it depends only on the function's structure, not on any particular input. This means it may contain extra nonzeros compared to the sparsity at a specific point, but it is guaranteed to be correct everywhere and can therefore be reused.
If you encounter overly conservative patterns, please open an issue. These reports directly drive improvements and are one of the most impactful ways to contribute.
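The difference between a global and a local pattern can be seen with plain JAX (no asdex involved): for a product, both partial derivatives are structurally nonzero, yet at a point where one factor happens to be zero the local Jacobian contains a numerical zero that the global pattern must still keep:

```python
import jax
import jax.numpy as jnp

def h(x):
    return jnp.array([x[0] * x[1]])

# Globally, both entries of the 1x2 Jacobian are nonzero: [x[1], x[0]].
# At x = (0, 3) the second entry vanishes numerically, but only
# at this particular input, so the global pattern keeps it.
J_local = jax.jacobian(h)(jnp.array([0.0, 3.0]))
assert jnp.allclose(J_local, jnp.array([[3.0, 0.0]]))
```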
Now we can compute the sparse Jacobian using the coloring. This is the jacobian_from_coloring step; asdex.jacobian, used below, performs both steps at once. The result J is a JAX BCOO sparse matrix.
We can verify that asdex produces the same result as jax.jacobian:
```python
import jax
import numpy as np

J_asdex = J.todense()
J_jax = jax.jacobian(f)(x)
np.testing.assert_allclose(J_asdex, J_jax, atol=1e-6)
```
asdex also provides check_jacobian_correctness
as a convenience for this comparison —
see Verifying Results.
On larger problems, the speedup from coloring becomes significant. Let's benchmark on a 5000-dimensional input (timings are measured during the documentation build and may vary between runs). This time, we use asdex.jacobian, which calls jacobian_coloring and jacobian_from_coloring:
```python
import timeit

import asdex
import jax
import jax.numpy as jnp

n = 5000
x = jnp.ones(n)

jac_fn_asdex = asdex.jacobian(f, input_shape=n)
jac_fn_jax = jax.jacobian(f)

# Warm up JIT caches
_ = jac_fn_asdex(x)
_ = jac_fn_jax(x)

t_asdex = timeit.timeit(lambda: jac_fn_asdex(x).block_until_ready(), number=10) / 10
t_jax = timeit.timeit(lambda: jac_fn_jax(x).block_until_ready(), number=10) / 10
print(f"asdex: {t_asdex:.4f}s per call, dense jax.jacobian: {t_jax:.4f}s per call")
```
Precompute for Repeated Evaluations
The coloring depends only on the function structure, not the input values. When computing Jacobians at many different inputs, precompute the coloring once with jacobian_coloring and reuse it via jacobian_from_coloring.
Sparse Hessians¶
For scalar-valued functions \(f: \mathbb{R}^n \to \mathbb{R}\),
asdex can detect Hessian sparsity and compute sparse Hessians:
```python
import jax.numpy as jnp
from asdex import hessian

def g(x):
    return jnp.sum(x ** 2)

hess_fn = hessian(g, input_shape=20)

# The detected pattern and coloring inside hess_fn are reused for every input
inputs = [jnp.ones(20), jnp.linspace(0.0, 1.0, 20)]
for x in inputs:
    H = hess_fn(x)
```
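Under the hood, sparse Hessian computation rests on Hessian-vector products via forward-over-reverse AD, i.e. a JVP of the gradient. A pure-JAX sketch of this composition (none of asdex's API is used here) shows one HVP recovering the entire diagonal of a diagonal Hessian:

```python
import jax
import jax.numpy as jnp

def g(x):
    return jnp.sum(jnp.sin(x))  # elementwise sum: the Hessian is diagonal

def hvp(f, x, v):
    # Forward-over-reverse: differentiate the gradient along direction v
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

x = jnp.linspace(0.0, 1.0, 20)
# A diagonal Hessian needs a single color: one HVP with an all-ones seed
diag = hvp(g, x, jnp.ones_like(x))
assert jnp.allclose(diag, -jnp.sin(x), atol=1e-6)
```

A general sparse Hessian works the same way, with one HVP per color followed by decompression into the sparse matrix.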
Next Steps¶
- Computing Sparse Jacobians — Guide on Jacobian computation
- Computing Sparse Hessians — Guide on Hessian computation
- Sparsity Detection — Explanation of how sparsity patterns are detected
- Graph Coloring — Explanation of how coloring reduces cost