
Jacobian

Jacobian Computation

asdex.jacobian(f, *args, argnums=0, has_aux=False, holomorphic=False, allow_int=False, mode=None, symmetric=False, output_format='bcoo')

Detect sparsity, color, and return a function computing sparse Jacobians.

Combines jacobian_coloring and jacobian_from_coloring in one call.

Parameters:

f (Callable[..., Any], required): Function whose Jacobian is to be computed.

*args (Any, default ()): Sample arguments of f. Only their structure and dtypes are used; values are ignored.

argnums (int | Sequence[int], default 0): Which positional argument(s) to differentiate with respect to.

has_aux (bool, default False): Whether f returns (output, auxiliary_data), mirroring jax.jacrev. When True, the returned function yields (jac, aux).

holomorphic (bool, default False): Whether f is promised to be holomorphic, mirroring jax.jacrev. Dtype compatibility is validated at call time.

allow_int (bool, default False): Whether to allow differentiating with respect to integer-valued inputs, mirroring jax.jacrev.

mode (JacobianMode | None, default None): AD mode. "fwd" uses JVPs (forward-mode AD), "rev" uses VJPs (reverse-mode AD). None picks whichever of fwd/rev needs fewer colors.

symmetric (bool, default False): Whether to use symmetric (star) coloring. Requires a square Jacobian.

output_format (OutputFormat, default 'bcoo'): Type of the output matrix. "bcoo" returns a sparse matrix of type jax.experimental.sparse.BCOO; "dense" returns a dense matrix of type jax.Array.

Returns:

Callable[..., Any]: A function that takes the same positional args as f and returns a pytree of Jacobian blocks matching argnums, with each leaf shaped (*out_shape, *in_leaf_shape). The block type depends on output_format (jax.experimental.sparse.BCOO by default, or jax.Array when "dense").
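To make the block-shape rule concrete, the pure-Python sketch below takes an output shape and the shapes of the selected input leaves, and reports each block's (*out_shape, *in_leaf_shape) shape together with the range of columns that leaf would occupy in a flat (prod(out_shape), n_selected) Jacobian. It is an illustration, not asdex code: jacobian_block_shapes is a hypothetical helper, and the left-to-right leaf order with row-major flattening is an assumption.

```python
from math import prod


def jacobian_block_shapes(out_shape, in_leaf_shapes):
    """Hypothetical helper (not part of asdex).

    Returns ((m, n_selected), [(block_shape, (col_start, col_stop)), ...]),
    assuming leaves are laid out left to right in the flat Jacobian and
    each leaf is flattened row-major.
    """
    m = prod(out_shape)  # prod(()) == 1, so a scalar output gives one row
    blocks = []
    start = 0
    for leaf_shape in in_leaf_shapes:
        size = prod(leaf_shape)
        # Each block is shaped (*out_shape, *in_leaf_shape).
        blocks.append(((*out_shape, *leaf_shape), (start, start + size)))
        start += size
    return (m, start), blocks


flat_shape, blocks = jacobian_block_shapes((2, 3), [(4,), (2, 2)])
print(flat_shape)  # (6, 8)
for block_shape, cols in blocks:
    print(block_shape, cols)
# (2, 3, 4) (0, 4)
# (2, 3, 2, 2) (4, 8)
```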

asdex.value_and_jacobian(f, *args, argnums=0, has_aux=False, holomorphic=False, allow_int=False, mode=None, symmetric=False, output_format='bcoo')

Detect sparsity, color, and return a function computing value and sparse Jacobian.

Like jacobian, but also returns the primal value f(*args) without an extra forward pass.

Returns:

Callable[..., Any]: A function that takes the same positional args as f and returns (value, jac), or ((value, aux), jac) when has_aux=True, matching jax.value_and_grad ordering.
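One way the primal value can come for free is that every forward-mode pass already evaluates f(x) on its way to a Jacobian-vector product, so the value from the first pass can simply be kept. The pure-Python sketch below shows this with a hand-written JVP for a toy function. The names f_jvp and value_and_jvps are hypothetical, and this is not asdex's implementation, which builds on JAX transformations.

```python
def f_jvp(x, v):
    """Hypothetical toy: f(x)_i = x_i**2 + x_{i+1} (last entry x_i**2).

    Returns (f(x), J @ v), both computed in a single pass over the inputs.
    """
    n = len(x)
    y, dy = [], []
    for i in range(n):
        val = x[i] ** 2
        tan = 2 * x[i] * v[i]
        if i + 1 < n:
            val += x[i + 1]
            tan += v[i + 1]
        y.append(val)
        dy.append(tan)
    return y, dy


def value_and_jvps(jvp, x, seeds):
    """Run one JVP per seed vector; keep the primal from the first pass."""
    value = None
    products = []
    for v in seeds:
        y, dy = jvp(x, v)
        if value is None:
            value = y  # no separate forward evaluation of f is needed
        products.append(dy)
    return value, products


x = [1.0, 2.0, 3.0]
identity_seeds = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
value, cols = value_and_jvps(f_jvp, x, identity_seeds)
print(value)  # [3.0, 7.0, 9.0]
```

Identity seeds are used here only for clarity; with a coloring, each seed would combine several structurally orthogonal columns and the same value-reuse trick still applies.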

asdex.jacobian_from_coloring(f, coloring, output_format='bcoo', *, has_aux=False, holomorphic=False, allow_int=False)

Build a sparse Jacobian function from a pre-computed coloring.

Uses row coloring + VJPs or column coloring + JVPs, depending on which needs fewer colors.

The returned callable accepts *args, **kwargs; kwargs are forwarded to f at call time (matching jax.jacfwd / jax.jacrev).
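The column-coloring + JVP path can be sketched in pure Python. Each color c gets one seed vector with ones at the columns of color c, so one JVP per color yields a compressed matrix B = J·S; because same-colored columns never share a row, each nonzero J[i][j] can be read back from B[i][colors[j]]. The data layout (a list of (row, col) nonzeros, a list of column colors, a jvp callable) and the function names are assumptions for illustration, not asdex's ColoredPattern or internals.

```python
def toy_jvp(x, v):
    """Hypothetical toy JVP: J @ v for f(x)_i = x_i**2 + x_{i+1}."""
    n = len(x)
    return [2 * x[i] * v[i] + (v[i + 1] if i + 1 < n else 0.0) for i in range(n)]


def sparse_jacobian_from_column_coloring(jvp, x, pattern, colors):
    """Hypothetical sketch: recover J's nonzeros with one JVP per color.

    pattern: list of (row, col) structural nonzeros.
    colors:  colors[j] is column j's color; same-colored columns must
             never share a row (structural orthogonality).
    """
    n = len(x)
    compressed = []
    for c in range(max(colors) + 1):
        # Seed vector: 1.0 at every column of color c, 0.0 elsewhere.
        seed = [1.0 if colors[j] == c else 0.0 for j in range(n)]
        compressed.append(jvp(x, seed))  # column c of B = J @ S
    # Decompress: J[i][j] is the only contribution to B[i][colors[j]].
    return {(i, j): compressed[colors[j]][i] for (i, j) in pattern}


x = [1.0, 2.0, 3.0, 4.0]
# Bidiagonal pattern: row i has nonzeros in columns i and i + 1.
pattern = [(i, i) for i in range(4)] + [(i, i + 1) for i in range(3)]
colors = [j % 2 for j in range(4)]  # adjacent columns share a row, so alternate
entries = sparse_jacobian_from_column_coloring(toy_jvp, x, pattern, colors)
print(entries[(1, 1)], entries[(1, 2)])  # 4.0 1.0
```

For this bidiagonal pattern, two JVPs recover all seven nonzeros of the 4 × 4 Jacobian instead of the four JVPs a dense forward-mode Jacobian would need. Row coloring with VJPs is the same idea applied to the transpose.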

asdex.value_and_jacobian_from_coloring(f, coloring, output_format='bcoo', *, has_aux=False, holomorphic=False, allow_int=False)

Build a function computing value and sparse Jacobian from a pre-computed coloring.

Coloring

asdex.jacobian_coloring(f, *args, argnums=0, has_aux=False, mode=None, symmetric=False, postprocess=False)

Detect Jacobian sparsity and color in one step.

Parameters:

f (Callable, required): Function whose Jacobian is to be computed.

*args (Any, default ()): Sample arguments of f. Only their structure and dtypes are used; values are ignored.

argnums (int | Sequence[int], default 0): Which positional argument(s) to differentiate with respect to.

has_aux (bool, default False): If True, f is assumed to return (output, aux), where aux is auxiliary data ignored by sparsity detection.

mode (JacobianMode | None, default None): AD mode. "fwd" uses JVPs (forward-mode AD), "rev" uses VJPs (reverse-mode AD). None picks whichever of fwd/rev needs fewer colors, unless symmetric is True, in which case it defaults to "fwd".

symmetric (bool, default False): Whether to use symmetric (star) coloring. Requires a square Jacobian.

postprocess (bool, default False): Only read when symmetric=True. Prunes colors never used as hubs and compacts the remaining ones, which can reduce the number of VJPs/JVPs required. Defaults to False, matching SparseMatrixColorings.jl.

Returns:

ColoredPattern
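The mode=None rule can be illustrated by coloring columns greedily (for "fwd") and rows greedily (for "rev", via the transposed pattern), then keeping whichever needs fewer colors. In this sketch the greedy vertex order and the tie-breaking toward "fwd" are assumptions; asdex's actual coloring algorithm and tie rule may differ, and greedy_column_coloring and choose_mode are hypothetical names.

```python
def greedy_column_coloring(pattern, n_cols):
    """Hypothetical helper: color columns so no two same-colored columns share a row.

    pattern: set of (row, col) structural nonzeros.
    """
    rows_of = [set() for _ in range(n_cols)]
    for i, j in pattern:
        rows_of[j].add(i)
    colors = []
    for j in range(n_cols):
        # Colors of earlier columns that share at least one row with column j.
        forbidden = {colors[k] for k in range(j) if rows_of[k] & rows_of[j]}
        c = 0
        while c in forbidden:
            c += 1
        colors.append(c)
    return colors


def choose_mode(pattern, m, n):
    """Hypothetical sketch: return ("fwd" or "rev", colors), ties going to "fwd"."""
    col_colors = greedy_column_coloring(pattern, n)
    # Row coloring of J is column coloring of J^T.
    row_colors = greedy_column_coloring({(j, i) for i, j in pattern}, m)
    if max(row_colors, default=-1) < max(col_colors, default=-1):
        return "rev", row_colors
    return "fwd", col_colors


# One dense row plus a diagonal: every column collides in row 0.
pattern = {(0, j) for j in range(4)} | {(i, i) for i in range(1, 4)}
mode, colors = choose_mode(pattern, 4, 4)
print(mode, colors)  # rev [0, 1, 1, 1]
```

A single dense row is the classic case where reverse mode wins: column coloring needs one color per column, while the rows need only two colors.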

asdex.jacobian_coloring_from_sparsity(sparsity, *, mode=None, symmetric=False, postprocess=False)

Color a sparsity pattern for sparse Jacobian computation.

Assigns colors so that same-colored rows (or columns) can be computed together in a single VJP (or JVP).

Parameters:

sparsity (SparsityPattern | NDArray | BCOO, required): A SparsityPattern, NumPy array, or JAX BCOO matrix of shape (m, n).

mode (JacobianMode | None, default None): AD mode. "fwd" uses JVPs (column coloring), "rev" uses VJPs (row coloring). None picks whichever of fwd/rev needs fewer colors, unless symmetric is True, in which case it defaults to "fwd".

symmetric (bool, default False): Whether to use symmetric (star) coloring. Requires a square pattern.

postprocess (bool, default False): Only read when symmetric=True. Prunes colors never used as hubs and compacts the remaining ones, which can reduce the number of VJPs/JVPs required. Defaults to False, matching SparseMatrixColorings.jl.

Returns:

ColoredPattern
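The requirement that same-colored columns "can be computed together in a single JVP" amounts to structural orthogonality: no row may contain nonzeros from two different columns of the same color (for row coloring, swap the roles of rows and columns). A small pure-Python checker, with is_valid_column_coloring as a hypothetical name that is not part of asdex's documented API:

```python
def is_valid_column_coloring(pattern, colors):
    """Hypothetical checker: True if no row has nonzeros in two same-colored columns.

    pattern: iterable of (row, col) nonzeros; colors[j] is column j's color.
    When this holds, summing a color's columns into one JVP seed leaves
    every nonzero recoverable without overlap.
    """
    seen = {}  # (row, color) -> column that already occupies that slot
    for i, j in pattern:
        key = (i, colors[j])
        if key in seen and seen[key] != j:
            return False
        seen[key] = j
    return True


# Bidiagonal pattern: row i has nonzeros in columns i and i + 1.
pattern = [(i, i) for i in range(4)] + [(i, i + 1) for i in range(3)]
print(is_valid_column_coloring(pattern, [0, 1, 0, 1]))  # True
print(is_valid_column_coloring(pattern, [0, 0, 1, 1]))  # False: row 0 mixes columns 0 and 1
```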

Sparsity Detection

asdex.jacobian_sparsity(f, *args, argnums=0, has_aux=False)

Detect global Jacobian sparsity pattern for f.

Analyzes the computation graph structure directly, without evaluating any derivatives. The result is valid for all inputs.

Parameters:

f (Callable, required): Function whose Jacobian sparsity pattern is to be detected.

*args (Any, default ()): Sample arguments of f. Only their structure and dtypes are used; values are ignored.

argnums (int | Sequence[int], default 0): Which positional argument(s) to differentiate with respect to.

has_aux (bool, default False): Whether f returns (output, auxiliary_data). When True, only output is analyzed for sparsity; the auxiliary branch of the computation is not traced.

Returns:

SparsityPattern: A pattern of shape (m, n_selected), where m = prod(output_shape) and n_selected is the total flat size of the selected inputs.
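A global pattern can be obtained without evaluating derivatives by propagating, through every operation, the set of input indices each intermediate value depends on. The sketch below does this with Python operator overloading on a hypothetical Tracer class; this is a swapped-in technique for illustration (asdex analyzes the computation graph directly), and it supports only basic arithmetic. Because no numeric values are tracked, the resulting pattern holds for every input.

```python
class Tracer:
    """Hypothetical tracer: carries the flat input indices a value depends on."""

    def __init__(self, deps):
        self.deps = frozenset(deps)

    def _combine(self, other):
        other_deps = other.deps if isinstance(other, Tracer) else frozenset()
        return Tracer(self.deps | other_deps)

    # Any binary arithmetic op depends on the union of its operands' inputs.
    __add__ = __radd__ = __sub__ = __rsub__ = _combine
    __mul__ = __rmul__ = __truediv__ = __rtruediv__ = _combine
    __pow__ = _combine


def jacobian_sparsity_sketch(f, n):
    """Return the set of (row, col) pairs where J may be nonzero."""
    outputs = f([Tracer({j}) for j in range(n)])
    # Constant (non-Tracer) outputs depend on no inputs: an empty row.
    return {(i, j) for i, out in enumerate(outputs) for j in getattr(out, "deps", ())}


def f(x):
    return [x[0] * x[1], x[1] + 2.0 * x[2], x[2] ** 2]


print(sorted(jacobian_sparsity_sketch(f, 3)))
# [(0, 0), (0, 1), (1, 1), (1, 2), (2, 2)]
```

This sketch is conservative: x[0] * 0.0 still registers a dependency on x[0], since only the structure of the computation is inspected.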

Verification

asdex.check_jacobian_correctness(f, x, coloring, *, method='matvec', num_probes=25, seed=0, rtol=None, atol=None)

Verify asdex's sparse Jacobian against a JAX reference at a given input.

Parameters:

f (Callable[..., Any], required): Function whose Jacobian is to be verified.

x (Any, required): Input at which to evaluate the Jacobian. For multi-input functions (where the coloring was built with a tuple argnums), pass a tuple of all positional arguments.

coloring (ColoredPattern, required): Pre-computed colored pattern from asdex.jacobian_coloring.

method (Literal['matvec', 'dense'], default 'matvec'): Verification method. "matvec" compares randomized matrix-vector products, so its cost grows linearly with the number of probes (O(k) for k = num_probes). "dense" materializes the full dense Jacobian, whose cost scales with its full size: O(m·n) for an m × n Jacobian, or O(n^2) when square.

num_probes (int, default 25): Number of random probe vectors (only used by "matvec").

seed (int, default 0): PRNG seed for reproducibility (only used by "matvec").

rtol (float | None, default None): Relative tolerance for comparison. If None, uses 1e-5 for "matvec" and 1e-7 for "dense".

atol (float | None, default None): Absolute tolerance for comparison. If None, uses 1e-5 for "matvec" and 1e-7 for "dense".

Raises:

VerificationError: If the sparse and reference Jacobians disagree.
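The "matvec" method can be sketched as follows: draw seeded random probe vectors v, compute J_sparse·v from the stored nonzeros, and compare against a reference J·v. The sketch assumes Gaussian probes and an elementwise |a - b| <= atol + rtol·|b| test in the style of NumPy's allclose; asdex's actual probe distribution and tolerance formula are not stated here. check_matvec, the dict-of-nonzeros format, and the reference callable are hypothetical, and a plain ValueError stands in for VerificationError.

```python
import random


def check_matvec(sparse_entries, ref_matvec, m, n, num_probes=25, seed=0,
                 rtol=1e-5, atol=1e-5):
    """Hypothetical sketch: raise ValueError if sparse J @ v != reference J @ v.

    sparse_entries: dict {(row, col): value} of the sparse Jacobian's nonzeros.
    ref_matvec: callable returning the reference product J @ v as a list.
    """
    rng = random.Random(seed)  # seeded for reproducibility
    for probe in range(num_probes):
        v = [rng.gauss(0.0, 1.0) for _ in range(n)]
        got = [0.0] * m
        for (i, j), value in sparse_entries.items():
            got[i] += value * v[j]
        want = ref_matvec(v)
        for i in range(m):
            # Assumed allclose-style test: absolute plus relative tolerance.
            if abs(got[i] - want[i]) > atol + rtol * abs(want[i]):
                raise ValueError(f"probe {probe}, row {i}: {got[i]} != {want[i]}")


def reference_matvec(v):
    """Dense reference for J = [[2, 1], [0, 4]]."""
    return [2.0 * v[0] + 1.0 * v[1], 4.0 * v[1]]


check_matvec({(0, 0): 2.0, (0, 1): 1.0, (1, 1): 4.0}, reference_matvec, 2, 2)  # passes
```

Each probe costs one product with the sparse matrix and one reference product, which is why the cost grows with num_probes rather than with the Jacobian's full size; a "dense" check would instead compare every entry.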

Configuration

asdex.JacobianMode = Literal['fwd', 'rev'] (module attribute)

AD mode for Jacobian computation.

"fwd" uses JVPs (forward-mode AD), "rev" uses VJPs (reverse-mode AD).
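Because JacobianMode is a typing.Literal alias, its allowed values can be read at runtime with typing.get_args. The sketch below mirrors the documented alias locally rather than importing it from asdex, and validate_mode is a hypothetical helper showing how a mode argument (including None for the automatic choice) might be checked.

```python
from typing import Literal, Optional, get_args

# Local mirror of the documented alias asdex.JacobianMode.
JacobianMode = Literal["fwd", "rev"]


def validate_mode(mode: Optional[str]) -> Optional[str]:
    """Hypothetical helper: accept None (automatic) or a JacobianMode value."""
    allowed = get_args(JacobianMode)  # ('fwd', 'rev')
    if mode is None or mode in allowed:
        return mode
    raise ValueError(f"mode must be one of {allowed} or None, got {mode!r}")


print(get_args(JacobianMode))  # ('fwd', 'rev')
```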