
asdex


Automatic Sparse Differentiation in JAX.

asdex (pronounced Aztecs) exploits sparsity structure to efficiently compute sparse Jacobians and Hessians. It implements a custom Jaxpr interpreter that uses abstract interpretation to detect global sparsity patterns from the computation graph, then uses graph coloring to minimize the number of AD passes needed. Refer to our Illustrated Guide to Automatic Sparse Differentiation for more information.
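To make the two stages concrete, here is a toy sketch in plain Python. It is not asdex's Jaxpr interpreter. Instead, it hand-writes an abstract interpreter for just the two operations used in the quick example below (elementwise subtraction and squaring). Each array element carries the set of input indices it depends on instead of a number, so the resulting pattern is global: it holds for every input value. The pattern then feeds a greedy column coloring. The helper names (`input_sets`, `abstract_sub`, `abstract_square`, `greedy_column_coloring`) are hypothetical. The simple natural-order greedy heuristic is an assumption; asdex's actual coloring algorithm may differ.

```python
def input_sets(n):
    """Abstract value of the input x: element i depends only on x[i]."""
    return [frozenset({i}) for i in range(n)]


def abstract_sub(a, b):
    # Elementwise a - b: each output element depends on the union of both operands' dependencies.
    return [sa | sb for sa, sb in zip(a, b)]


def abstract_square(a):
    # Elementwise a ** 2: the dependency set of each element is unchanged.
    return list(a)


def sparsity_pattern_of_f(n):
    # Mirrors f(x) = (x[1:] - x[:-1]) ** 2; slicing the abstract value slices the dependency sets.
    x = input_sets(n)
    return abstract_square(abstract_sub(x[1:], x[:-1]))


def greedy_column_coloring(pattern, n_cols):
    # Two columns conflict if some row depends on both of them.
    # Columns that receive the same color never share a nonzero row,
    # so they can be recovered from a single JVP.
    neighbors = [set() for _ in range(n_cols)]
    for row in pattern:
        for j in row:
            neighbors[j] |= row - {j}
    colors = [-1] * n_cols
    for j in range(n_cols):
        used = {colors[k] for k in neighbors[j] if colors[k] >= 0}
        c = 0
        while c in used:  # pick the smallest color not used by an already-colored neighbor
            c += 1
        colors[j] = c
    return colors
```

For the 1000-element example, `pattern = sparsity_pattern_of_f(1000)` gives row `i` the dependency set `{i, i + 1}`. Calling `greedy_column_coloring(pattern, 1000)` alternates colors 0 and 1, so two JVPs suffice.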

Installation

pip install asdex

Or with uv:

uv add asdex

Quick Example

import numpy as np
from asdex import jacobian

def f(x):
    return (x[1:] - x[:-1]) ** 2

x = np.random.randn(1000)

jac_fn = jacobian(f, input_shape=x.shape)
J = jac_fn(x)

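Instead of 999 VJPs (one per output) or 1000 JVPs (one per input), asdex computes the full sparse Jacobian with just 2 JVPs. Each output depends only on two adjacent inputs, so row i of the 999 × 1000 Jacobian is nonzero only in columns i and i + 1. The even-indexed and odd-indexed columns therefore form two groups whose members never share a nonzero row, and each group can be recovered from a single JVP.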
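The following minimal NumPy sketch illustrates where the 2 comes from. To keep it dependency-free, it makes two assumptions: the JVP of `f` is derived by hand (in JAX this would be `jax.jvp(f, (x,), (v,))[1]`), and the even/odd coloring is hard-coded rather than detected. It shows compressed evaluation and decompression only. It is not asdex's internal implementation, and the function names are hypothetical.

```python
import numpy as np


def f(x):
    return (x[1:] - x[:-1]) ** 2


def jvp_f(x, v):
    # Hand-derived forward-mode product J(x) @ v for f.
    return 2 * (x[1:] - x[:-1]) * (v[1:] - v[:-1])


def sparse_jacobian_two_jvps(x):
    n = x.shape[0]
    colors = np.arange(n) % 2  # even columns get color 0, odd columns get color 1
    rows = np.arange(n - 1)
    J = np.zeros((n - 1, n))
    for c in (0, 1):
        seed = (colors == c).astype(x.dtype)  # sum of the basis vectors e_j with color c
        compressed = jvp_f(x, seed)  # one JVP per color
        # Row i is nonzero only in columns i and i + 1, exactly one of which has color c,
        # so compressed[i] is that single entry.
        cols = np.where(rows % 2 == c, rows, rows + 1)
        J[rows, cols] = compressed
    return J


def dense_jacobian(x):
    # Reference Jacobian written out analytically: J[i, i] = -2 d_i and J[i, i + 1] = 2 d_i.
    d = 2 * (x[1:] - x[:-1])
    idx = np.arange(x.shape[0] - 1)
    J = np.zeros((x.shape[0] - 1, x.shape[0]))
    J[idx, idx] = -d
    J[idx, idx + 1] = d
    return J
```

Comparing `sparse_jacobian_two_jvps(x)` with `dense_jacobian(x)` for a random `x` of length 1000 shows the two agree. The number of JVPs equals the number of colors, not the number of inputs.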

Next Steps

Acknowledgements

This package is built with Claude Code based on previous work by Adrian Hill, Guillaume Dalle, and Alexis Montoison in the Julia programming language:

which in turn stands on the shoulders of giants, notably Andreas Griewank, Andrea Walther, and Assefaw Gebremedhin.