Getting Started

This tutorial walks through the three stages of automatic sparse differentiation: detection, coloring, and decompression.

The Problem

For a function \(f: \mathbb{R}^n \to \mathbb{R}^m\), computing the full Jacobian \(J \in \mathbb{R}^{m \times n}\) requires \(n\) forward-mode or \(m\) reverse-mode AD passes. In practice, for example in scientific machine learning, many Jacobians are sparse (i.e., most entries are structurally zero, regardless of the input).
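To make the pass counts concrete, the pure-JAX sketch below builds a dense Jacobian both ways. It uses the squared-differences function introduced below, with an arbitrary illustrative input. Forward mode pushes one basis vector \(e_j\) through jax.jvp per input, which gives column \(j\) (so \(n\) passes). Reverse mode pulls one basis vector \(e_i\) back through jax.vjp per output, which gives row \(i\) (so \(m\) passes). This illustrates the principle only; it is not how asdex is implemented.

```python
import jax
import jax.numpy as jnp

def f(x):
    return (x[1:] - x[:-1]) ** 2

n = 50
x = jnp.linspace(0.0, 1.0, n) ** 2  # arbitrary illustrative input
m = f(x).shape[0]

# Forward mode: one JVP per input basis vector e_j yields column j of J (n passes).
jvp_column = lambda e: jax.jvp(f, (x,), (e,))[1]
J_fwd = jax.vmap(jvp_column, out_axes=1)(jnp.eye(n))

# Reverse mode: one VJP per output basis vector e_i yields row i of J (m passes).
_, vjp_fn = jax.vjp(f, x)
J_rev = jax.vmap(lambda e: vjp_fn(e)[0])(jnp.eye(m))

print(J_fwd.shape, J_rev.shape)  # (49, 50) (49, 50)
```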

asdex exploits this sparsity in three steps:

1. Detect the sparsity pattern by tracing the computation graph.
2. Color the pattern so that structurally orthogonal rows (or columns), i.e. rows (or columns) that never have a nonzero in the same position, share a color.
3. Decompress: run one AD pass per color and decompress the results into the sparse Jacobian or Hessian.

This reduces the computational cost from \(n\) forward-mode (or \(m\) reverse-mode) AD passes to the number of colors, which for sparse problems is often far smaller. This yields significant speedups on large sparse problems, especially when the cost of detection and coloring can be amortized over repeated evaluations. The same approach applies to sparse Hessians via forward-over-reverse AD.
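The three steps can be sketched by hand in pure JAX for the banded example used below. This is an illustration of column coloring and decompression under simplifying assumptions, not asdex's implementation. The sparsity pattern and the coloring \(j \bmod 2\) are hard-coded here rather than detected, and decompression writes into a dense NumPy array instead of a sparse matrix.

```python
import jax
import jax.numpy as jnp
import numpy as np

def f(x):
    return (x[1:] - x[:-1]) ** 2

n, m = 50, 49
x = jnp.linspace(0.0, 1.0, n) ** 2

# Step 1 (hard-coded here): row i of J has nonzeros in columns i and i + 1.
rows = np.repeat(np.arange(m), 2)
cols = np.stack([np.arange(m), np.arange(m) + 1], axis=1).ravel()

# Step 2: columns j and j + 2 never share a row, so color(j) = j % 2 is valid.
colors = np.arange(n) % 2
num_colors = 2

# Step 3a: one JVP per color, seeded with the sum of that color's basis vectors.
seeds = [jnp.asarray(colors == c, dtype=x.dtype) for c in range(num_colors)]
compressed = np.stack([np.asarray(jax.jvp(f, (x,), (s,))[1]) for s in seeds], axis=1)  # (m, 2)

# Step 3b: decompress; the nonzero J[i, j] is stored in compressed[i, colors[j]].
J = np.zeros((m, n), dtype=compressed.dtype)
J[rows, cols] = compressed[rows, colors[cols]]

np.testing.assert_allclose(J, jax.jacobian(f)(x), rtol=1e-6, atol=1e-6)
```

Two JVPs recover all 98 nonzeros because, within any row, the two nonzero columns have different colors, so no entries collide when the columns of a color are summed.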

Sparse Jacobians

Consider the squared differences function \(f(x)_i = (x_{i+1} - x_i)^2\), which has a banded Jacobian. Detect the sparsity pattern and color it in one step:

import jax.numpy as jnp
from asdex import jacobian_coloring

def f(x):
    return (x[1:] - x[:-1]) ** 2

x = jnp.ones(50)

coloring = jacobian_coloring(f, input_shape=x.shape)
print(coloring)
ColoredPattern(49×50, nnz=98, sparsity=96.0%, JVP, 2 colors)
  2 JVPs (instead of 49 VJPs or 50 JVPs)
⎡⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎤   ⎡⣿⎤
⎢⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥   ⎢⣿⎥
⎢⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥   ⎢⣿⎥
⎢⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥   ⎢⣿⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥   ⎢⣿⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥   ⎢⣿⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ → ⎢⣿⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⎥   ⎢⣿⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⎥   ⎢⣿⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⎥   ⎢⣿⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⎥   ⎢⣿⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⎥   ⎢⣿⎥
⎣⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⎦   ⎣⠉⎦

The print-out shows the original sparsity pattern (left) compressed into just two colors (right). asdex automatically ran multiple coloring algorithms and selected column coloring (2 JVPs) over row coloring (2 VJPs), since JVPs are cheaper. This reduces the cost from 49 VJPs or 50 JVPs without coloring to just 2 JVPs. Note that on small problems, this does not translate directly into a 25x speedup, because the decompression overhead dominates.
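To see why both orientations need two colors here, the sketch below applies a simple greedy (partial distance-2) coloring to the banded pattern. It gives each column the smallest color not already used by any column that shares a nonzero row with it. Coloring the transposed pattern colors the rows instead, which corresponds to VJPs. This greedy routine is an illustrative stand-in chosen for brevity. It is not one of the algorithms asdex actually runs, and the function name greedy_column_coloring is hypothetical.

```python
import numpy as np

def greedy_column_coloring(pattern):
    """Hypothetical helper: greedy partial distance-2 coloring of the columns.
    Each column gets the smallest color unused by columns sharing a nonzero row."""
    n = pattern.shape[1]
    p = pattern.astype(int)
    overlap = (p.T @ p) > 0  # overlap[j, k]: columns j and k share at least one row
    colors = np.full(n, -1)
    for j in range(n):
        taken = {colors[k] for k in np.flatnonzero(overlap[j]) if colors[k] >= 0}
        c = 0
        while c in taken:
            c += 1
        colors[j] = c
    return colors

# Banded pattern of the squared-differences Jacobian: J[i, i] and J[i, i + 1].
m, n = 49, 50
pattern = np.zeros((m, n), dtype=bool)
pattern[np.arange(m), np.arange(m)] = True
pattern[np.arange(m), np.arange(m) + 1] = True

col_colors = greedy_column_coloring(pattern)    # one JVP per column color
row_colors = greedy_column_coloring(pattern.T)  # rows of J are columns of J^T: one VJP per color
print(col_colors.max() + 1, row_colors.max() + 1)  # 2 2
```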

Global Sparsity Patterns

The detected pattern is a global sparsity pattern: it depends only on the function's structure, not on any particular input. This means it may contain extra nonzeros compared to the sparsity at a specific point, but it is guaranteed to be correct everywhere and can therefore be reused.
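A quick pure-JAX check illustrates the difference between global and local sparsity for the squared-differences function. At the all-ones input, every difference \(x_{i+1} - x_i\) is zero, so every Jacobian entry evaluates to zero. The global pattern still correctly reports the 98 structural nonzeros, which appear numerically at a generic point. The random input below is an arbitrary choice for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np

def f(x):
    return (x[1:] - x[:-1]) ** 2

# At x = ones, every difference x[i+1] - x[i] is zero, so every Jacobian entry is zero.
J_ones = np.asarray(jax.jacobian(f)(jnp.ones(50)))

# At a generic point, all 98 structural nonzeros show up numerically.
x_random = jax.random.normal(jax.random.PRNGKey(0), (50,))
J_random = np.asarray(jax.jacobian(f)(x_random))

print(np.count_nonzero(J_ones), np.count_nonzero(J_random))  # 0 98
```

A pattern detected only at the all-ones point would therefore be empty and wrong almost everywhere else. This is why a pattern that is valid for every input is the safe one to reuse.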

If you encounter overly conservative patterns, please open an issue. These reports directly drive improvements and are one of the most impactful ways to contribute.

Now we can compute the sparse Jacobian using the coloring:

from asdex import jacobian_from_coloring

jac_fn = jacobian_from_coloring(f, coloring)
J = jac_fn(x)

The result is a JAX BCOO sparse matrix. We can verify that asdex produces the same result as jax.jacobian:

import jax
import numpy as np

J_asdex = J.todense()
J_jax = jax.jacobian(f)(x)

np.testing.assert_allclose(J_asdex, J_jax, atol=1e-6)

asdex also provides check_jacobian_correctness as a convenience for this comparison; see Verifying Results.
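Because the result is a BCOO matrix from jax.experimental.sparse, it can be inspected and used without converting it to a dense array. The sketch below builds a BCOO with BCOO.fromdense from a dense jax.jacobian result, purely so that it runs without asdex. It then shows the stored values and indices, and applies the matrix to a vector.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

def f(x):
    return (x[1:] - x[:-1]) ** 2

x = jnp.linspace(0.0, 1.0, 50) ** 2

# For illustration only, build a BCOO from a dense Jacobian; asdex returns BCOO directly.
J = sparse.BCOO.fromdense(jax.jacobian(f)(x))

print(J.shape, J.nse)                 # (49, 50) 98
print(J.data.shape, J.indices.shape)  # (98,) (98, 2): stored values and their (row, col) indices

# BCOO supports sparse-dense products, so J can be applied without densifying it.
v = jnp.arange(50.0)
Jv = J @ v
```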

On larger problems, the speedup from coloring becomes significant. Let's benchmark on a 5000-dimensional input (timings are measured while the docs are built and may vary). This time, we use asdex.jacobian, which combines jacobian_coloring and jacobian_from_coloring in a single call:

import asdex
import jax
import timeit

n = 5000
x = jnp.ones(n)

jac_fn_asdex = asdex.jacobian(f, input_shape=n)
jac_fn_jax = jax.jacobian(f)

# Warm up JIT caches
_ = jac_fn_asdex(x)
_ = jac_fn_jax(x)

t_asdex = timeit.timeit(lambda: jac_fn_asdex(x).block_until_ready(), number=10) / 10
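t_jax = timeit.timeit(lambda: jac_fn_jax(x).block_until_ready(), number=10) / 10

print(f"{'asdex.jacobian:':<15}{t_asdex * 1e3:>10.2f} ms")
print(f"{'jax.jacobian:':<15}{t_jax * 1e3:>10.2f} ms")
print(f"{'speedup:':<15}{t_jax / t_asdex:>10.1f}x")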
asdex.jacobian:      4.82 ms
jax.jacobian:      112.08 ms
speedup:             23.2x

Precompute for Repeated Evaluations

The coloring depends only on the function structure, not the input values. When computing Jacobians at many different inputs, precompute the coloring once and reuse it:

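from asdex import jacobian

jac_fn = jacobian(f, input_shape=5000)  # detection and coloring run once, here

inputs = [jnp.linspace(0.0, 1.0, 5000) ** p for p in (1, 2, 3)]  # example inputs
for x in inputs:
    J = jac_fn(x)  # reuses the precomputed coloring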
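The amortization idea can also be sketched in pure JAX. The pattern and coloring are fixed once, and a jitted function then evaluates only the nonzero Jacobian values at each new input, with all colored JVPs batched by jax.vmap. The helper make_colored_jacobian is hypothetical and is not part of asdex. It assumes a valid column coloring is already known and returns the nonzero values as a flat array aligned with (rows, cols), rather than a BCOO matrix.

```python
import jax
import jax.numpy as jnp
import numpy as np

def make_colored_jacobian(f, rows, cols, colors):
    """Hypothetical helper (not part of asdex): compile a function returning the
    nonzero Jacobian values for a fixed pattern (rows, cols) and column coloring."""
    num_colors = int(colors.max()) + 1
    # Seed matrix of shape (n, num_colors): column c flags the inputs with color c.
    seeds = jnp.asarray(colors[:, None] == np.arange(num_colors)[None, :], dtype=jnp.float32)
    rows_idx = jnp.asarray(rows)
    color_idx = jnp.asarray(colors[cols])

    @jax.jit
    def jac_values(x):
        # All colored JVPs at once: vmap over the seed columns, stacking results as columns.
        push = lambda s: jax.jvp(f, (x,), (s.astype(x.dtype),))[1]
        compressed = jax.vmap(push, in_axes=1, out_axes=1)(seeds)
        # The nonzero J[rows[k], cols[k]] is stored in compressed[rows[k], colors[cols[k]]].
        return compressed[rows_idx, color_idx]

    return jac_values

def f(x):
    return (x[1:] - x[:-1]) ** 2

n = 5000
m = n - 1
rows = np.repeat(np.arange(m), 2)
cols = np.stack([np.arange(m), np.arange(m) + 1], axis=1).ravel()
colors = np.arange(n) % 2

jac_values = make_colored_jacobian(f, rows, cols, colors)  # pattern and coloring fixed once

for scale in (1.0, 2.0, 3.0):
    x = scale * jnp.linspace(0.0, 1.0, n) ** 2
    vals = jac_values(x)  # traced and compiled on the first call, reused afterwards
```

Setup work (here the hard-coded pattern and coloring, in asdex the detection and coloring) is paid once, while each later call only runs the compiled colored JVPs.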

Sparse Hessians

For scalar-valued functions \(f: \mathbb{R}^n \to \mathbb{R}\), asdex can detect Hessian sparsity and compute sparse Hessians:

from asdex import hessian

def g(x):
    return jnp.sum(x ** 2)

hess_fn = hessian(g, input_shape=20)

inputs = [jnp.linspace(-1.0, 1.0, 20) * s for s in (1.0, 2.0)]  # example inputs of shape (20,)
for x in inputs:
    H = hess_fn(x)
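The forward-over-reverse idea can be sketched in pure JAX. A Hessian-vector product is the JVP of the gradient, so one HVP per color plays the role that one JVP per color plays for Jacobians. For the g above, \(H = 2I\) is diagonal, so a single HVP with an all-ones seed already recovers it. The sketch below instead swaps in a hypothetical g with a tridiagonal Hessian (a sum of squared differences) so that the coloring is non-trivial. It uses plain column coloring with \(j \bmod 3\). This is an illustration, not asdex's implementation; colorings that exploit the symmetry \(H = H^\top\) can in general need fewer colors.

```python
import jax
import jax.numpy as jnp
import numpy as np

def g(x):
    # Hypothetical example with a tridiagonal Hessian.
    return jnp.sum((x[1:] - x[:-1]) ** 2)

def hvp(x, v):
    # Forward-over-reverse: the JVP of the gradient in direction v gives H @ v.
    return jax.jvp(jax.grad(g), (x,), (v,))[1]

n = 20
x = jnp.linspace(-1.0, 1.0, n)
colors = np.arange(n) % 3  # columns j and j + 3 of a tridiagonal matrix never share a row

# One HVP per color, seeded with the sum of that color's basis vectors.
compressed = np.stack(
    [np.asarray(hvp(x, jnp.asarray(colors == c, dtype=x.dtype))) for c in range(3)], axis=1
)  # shape (n, 3)

# Decompress: each tridiagonal entry H[i, j] is stored in compressed[i, colors[j]].
H = np.zeros((n, n), dtype=compressed.dtype)
for i in range(n):
    for j in range(max(0, i - 1), min(n, i + 2)):
        H[i, j] = compressed[i, colors[j]]

np.testing.assert_allclose(H, jax.hessian(g)(x), atol=1e-5)
```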

Next Steps