Computing Sparse Hessians¶
asdex computes sparse Hessians for scalar-valued functions
\(f: \mathbb{R}^n \to \mathbb{R}\)
using symmetric coloring and forward-over-reverse AD.
Verify correctness at least once
asdex's sparsity patterns should always be conservative, but a bug in sparsity detection could cause missing nonzeros. Always verify against vanilla JAX at least once on a new function. See Verifying Results below.
Basic Usage¶
Pass your scalar-valued function and its input_shape to hessian:
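import jax.numpy as jnp
from asdex import hessian

def f(x):
    # Illustrative function: each term couples only neighboring entries,
    # so the Hessian is sparse.
    return jnp.sum(x[:-1] * x[1:])

hess_fn = hessian(f, input_shape=100)
H = hess_fn(jnp.ones(100))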
This runs the computationally expensive sparsity detection and coloring steps when defining hess_fn.
Subsequent calls to hess_fn only need to perform the cheap decompression step.
The result is a JAX BCOO sparse matrix.
The same function can be reused across evaluations at different inputs:
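H1 = hess_fn(jnp.zeros(100))
H2 = hess_fn(jnp.ones(100))  # reuses the stored coloring; only the decompression work is repeated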
asdex supports multi-dimensional input arrays.
The Hessian is always returned as a 2D matrix
of shape \((n, n)\) where \(n\) is the total number of input elements.
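For example, with a (2, 3) input (input_shape accepts a shape tuple, consistent with the input_shape=x.shape usage further below):

def f2(x):
    # x has shape (2, 3); all 6 input elements index the Hessian rows/columns.
    return jnp.sum(x ** 2)

hess_fn2 = hessian(f2, input_shape=(2, 3))
H = hess_fn2(jnp.ones((2, 3)))  # BCOO matrix of shape (6, 6)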
Precomputing the Colored Pattern¶
For more control, precompute the coloring explicitly:
import jax.numpy as jnp
from asdex import hessian_coloring, hessian_from_coloring
def g(x):
    return jnp.sum((1 - x[:-1]) ** 2 + 100 * (x[1:] - x[:-1] ** 2) ** 2)
coloring = hessian_coloring(g, input_shape=100)
ColoredPattern(100×100, nnz=298, sparsity=97.0%, HVP, 3 colors)
3 HVPs (instead of 100 HVPs)
⎡⠻⣦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎤ ⎡⡷⡇⎤
⎢⠀⠈⠻⣦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⣿⡅⎥
⎢⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⣿⠆⎥
⎢⠀⠀⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⣯⡇⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⡷⡇⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⣿⡅⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⣿⠆⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⣯⡇⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⡷⡇⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⣿⡅⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ → ⎢⣿⠆⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⣿⡃⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⡷⡇⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⣟⡇⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⣿⠆⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⣿⡃⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠀⠀⎥ ⎢⡷⡇⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⎥ ⎢⣟⡇⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠻⣦⡀⠀⎥ ⎢⣿⠆⎥
⎣⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠻⣦⎦ ⎣⣿⠃⎦
This is useful when you want to visually inspect the coloring for correctness,
or save it to disk to avoid recomputation.
Pass the coloring to hessian_from_coloring to compute the Hessian:
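hess_fn = hessian_from_coloring(g, coloring)
H = hess_fn(jnp.zeros(100))  # evaluation point chosen arbitrarily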
Tip
If your coloring looks wrong or overly dense,
please help out asdex's development by
reporting it.
These reports directly drive improvements
and are one of the most impactful ways to contribute.
Saving and Loading Patterns¶
Save a coloring to disk and reload it in a later session:
from asdex import hessian_coloring
coloring = hessian_coloring(g, input_shape=100)
coloring.save("colored.npz")
# Later, in a fresh session:
from asdex import ColoredPattern, hessian_from_coloring
coloring = ColoredPattern.load("colored.npz")
hess_fn = hessian_from_coloring(g, coloring)
SparsityPattern supports the same save/load interface.
Symmetric Coloring¶
Hessians are symmetric (\(H = H^\top\)),
and asdex exploits this with star coloring
(Gebremedhin et al., 2005).
Symmetric coloring typically needs fewer colors than row or column coloring,
since both \(H_{ij}\) and \(H_{ji}\) can be recovered from a single coloring.
The convenience functions hessian_coloring and hessian use symmetric coloring automatically.
Here we use the Rosenbrock function,
a classic optimization benchmark whose Hessian is tridiagonal:
import jax.numpy as jnp
from asdex import hessian_coloring
def rosenbrock(x):
    return jnp.sum((1 - x[:-1]) ** 2 + 100 * (x[1:] - x[:-1] ** 2) ** 2)
coloring = hessian_coloring(rosenbrock, input_shape=100)
ColoredPattern(100×100, nnz=298, sparsity=97.0%, HVP, 3 colors)
3 HVPs (instead of 100 HVPs)
⎡⠻⣦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎤ ⎡⡷⡇⎤
⎢⠀⠈⠻⣦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⣿⡅⎥
⎢⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⣿⠆⎥
⎢⠀⠀⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⣯⡇⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⡷⡇⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⣿⡅⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⣿⠆⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⣯⡇⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⡷⡇⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⣿⡅⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ → ⎢⣿⠆⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⣿⡃⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⡷⡇⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⣟⡇⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⣿⠆⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⣿⡃⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠀⠀⎥ ⎢⡷⡇⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⎥ ⎢⣟⡇⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠻⣦⡀⠀⎥ ⎢⣿⠆⎥
⎣⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠻⣦⎦ ⎣⣿⠃⎦
Separate Detection and Coloring¶
For even more control, you can split detection and coloring:
from asdex import hessian_sparsity, hessian_coloring_from_sparsity
sparsity = hessian_sparsity(g, input_shape=100)
coloring = hessian_coloring_from_sparsity(sparsity)
Since the Hessian is the Jacobian of the gradient,
hessian_sparsity simply calls jacobian_sparsity(jax.grad(f), input_shape).
The sparsity interpreter composes naturally with JAX's autodiff transforms.
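Spelled out, the two calls below should produce the same pattern (this assumes jacobian_sparsity is exported from asdex's top level like the other helpers):

import jax
from asdex import hessian_sparsity, jacobian_sparsity

sparsity_hess = hessian_sparsity(g, input_shape=100)
sparsity_grad = jacobian_sparsity(jax.grad(g), input_shape=100)  # same pattern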
Splitting detection and coloring this way is useful when you want to manually provide a sparsity pattern.
Manually Providing a Sparsity Pattern¶
You can provide a sparsity pattern manually if you already know it ahead of time.
Create a SparsityPattern from coordinate arrays, a dense matrix, or a JAX BCOO matrix.
From a dense boolean or numeric matrix:
import numpy as np
from asdex import SparsityPattern
dense = np.array([[1, 1, 0, 0],
                  [1, 1, 1, 0],
                  [0, 1, 1, 1],
                  [0, 0, 1, 1]])
sparsity = SparsityPattern.from_dense(dense)
From row and column index arrays:
sparsity = SparsityPattern.from_coo(
    rows=[0, 0, 1, 1, 1, 2, 2, 2, 3, 3],
    cols=[0, 1, 0, 1, 2, 1, 2, 3, 2, 3],
    shape=(4, 4),
)
From a JAX BCOO sparse matrix:
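The constructor name from_bcoo below is an assumption, chosen by analogy with from_dense and from_coo; check the asdex API reference for the exact name.

import jax.numpy as jnp
from jax.experimental import sparse

bcoo = sparse.BCOO.fromdense(jnp.asarray(dense))
sparsity = SparsityPattern.from_bcoo(bcoo)  # assumed constructor name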
Finally, color the sparsity pattern and compute the Hessian:
from asdex import hessian_coloring_from_sparsity, hessian_from_coloring
coloring = hessian_coloring_from_sparsity(sparsity)
hess_fn = hessian_from_coloring(f, coloring)  # f: your scalar-valued function
H = hess_fn(x)                                # x: an input with 4 elements
Choosing an HVP Mode¶
By default, hessian uses forward-over-reverse AD to compute Hessian-vector products.
You can select a different AD composition strategy via the mode parameter:
from asdex import hessian
hess_fn_for = hessian(f, input_shape=100, mode="fwd_over_rev") # default
hess_fn_rof = hessian(f, input_shape=100, mode="rev_over_fwd")
hess_fn_ror = hessian(f, input_shape=100, mode="rev_over_rev")
All three modes produce the same mathematical result. They differ in how JAX's AD primitives are composed:
- fwd_over_rev (default): jvp(grad(f), ...). Generally the fastest under JIT.
- rev_over_fwd: grad(lambda p: jvp(f, (p,), (v,))[1]). Can use less memory than forward-over-reverse for functions with many intermediates.
- rev_over_rev: grad(lambda y: dot(grad(f)(y), v)). Avoids forward-mode entirely; useful when forward-mode is expensive or unsupported.
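For intuition, the three compositions can be written as standalone JAX Hessian-vector products. This is a sketch of the underlying idea, not asdex's internal implementation:

import jax
import jax.numpy as jnp

def hvp_fwd_over_rev(f, x, v):
    # Push v forward through grad(f): jvp(grad(f), ...).
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

def hvp_rev_over_fwd(f, x, v):
    # Differentiate the directional derivative of f in direction v.
    return jax.grad(lambda p: jax.jvp(f, (p,), (v,))[1])(x)

def hvp_rev_over_rev(f, x, v):
    # Differentiate the scalar dot(grad(f)(y), v); avoids forward mode.
    return jax.grad(lambda y: jnp.dot(jax.grad(f)(y), v))(x)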
Tip
When in doubt, stick with the default "fwd_over_rev".
It is the most widely used and typically the most efficient under jax.jit.
Verifying Results¶
Use check_hessian_correctness
to verify asdex's sparse Hessian against vanilla JAX.
import jax.numpy as jnp
from asdex import check_hessian_correctness, hessian_coloring
x = jnp.ones(100)  # any representative input
coloring = hessian_coloring(g, input_shape=x.shape)
check_hessian_correctness(g, x, coloring)
Use verification for debugging and initial setup, not in production loops. A good place to call it is in your test suite.
By default, this uses randomized matrix-vector products (method="matvec")
to check the sparse Hessian against an HVP reference.
The AD mode is derived from the coloring.
This is cheap (the cost is O(k) in the number k of probes) and scales to large problems.
If the results match, the function returns silently.
If they disagree, it raises a VerificationError.
You can also set custom tolerances, the number of probes, and the PRNG seed:
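The keyword names below are illustrative assumptions, not confirmed API; consult the signature of check_hessian_correctness for the actual parameter names.

check_hessian_correctness(
    g, x, coloring,
    rtol=1e-5,       # assumed name: relative tolerance
    atol=1e-8,       # assumed name: absolute tolerance
    num_probes=16,   # assumed name: number of random matvec probes
    seed=0,          # assumed name: PRNG seed
)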
For an exact element-wise comparison against the full dense Hessian,
use method="dense":
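check_hessian_correctness(g, x, coloring, method="dense")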
Dense computation
method="dense" materializes the full dense Hessian,
which is computationally very expensive for large problems.