
Full API

Differentiation

asdex.jacobian(f, *args, argnums=0, has_aux=False, holomorphic=False, allow_int=False, mode=None, symmetric=False, output_format='bcoo')

Detect sparsity, color, and return a function computing sparse Jacobians.

Combines jacobian_coloring and jacobian_from_coloring in one call.

Parameters:

Name Type Description Default
f Callable[..., Any]

Function whose Jacobian is to be computed.

required
*args Any

Sample arguments of f. Only their structure and dtypes are used; values are ignored.

()
argnums int | Sequence[int]

Specifies which positional argument(s) to differentiate with respect to (default 0).

0
has_aux bool

Whether f returns (output, auxiliary_data), mirroring jax.jacrev. When True, the returned function yields (jac, aux).

False
holomorphic bool

Whether f is promised to be holomorphic, mirroring jax.jacrev. Validates dtype compatibility at call time.

False
allow_int bool

Whether to allow differentiating with respect to integer-valued inputs, mirroring jax.jacrev.

False
mode JacobianMode | None

AD mode. "fwd" uses JVPs (forward-mode AD), "rev" uses VJPs (reverse-mode AD). None picks whichever of fwd/rev needs fewer colors.

None
symmetric bool

Whether to use symmetric (star) coloring. Requires a square Jacobian.

False
output_format OutputFormat

Type of the output matrix. "bcoo" returns a sparse matrix of type jax.experimental.sparse.BCOO (default), "dense" returns a dense matrix of type jax.Array.

'bcoo'

Returns:

Type Description
Callable[..., Any]

A function that takes the same positional args as f and returns a pytree of Jacobian blocks matching argnums, with each leaf shaped (*out_shape, *in_leaf_shape). The block type depends on output_format (jax.experimental.sparse.BCOO by default, or jax.Array when "dense").
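To show what the returned function has to compute, here is a minimal sketch of column compression in plain JAX. It is not asdex's implementation. It assumes a tridiagonal Jacobian, for which giving column j the color j % 3 is a valid column coloring, so three JVPs recover all n columns. The names f and sparse_jacobian_fwd are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np

def f(x):
    # Output i depends on x[i-1], x[i], x[i+1], so the Jacobian is tridiagonal.
    padded = jnp.pad(x, 1)
    return padded[:-2] * padded[1:-1] + jnp.sin(padded[2:])

def sparse_jacobian_fwd(f, x, rows, cols, colors, num_colors):
    """Hypothetical helper: recover the nonzeros of J = df/dx with one JVP per color."""
    n = x.shape[0]
    # Seed column c is the sum of the unit vectors e_j over all j with colors[j] == c.
    seeds = jnp.zeros((n, num_colors)).at[jnp.arange(n), colors].set(1.0)
    # One JVP per seed gives the compressed Jacobian J @ seeds, of shape (m, num_colors).
    jvp = lambda v: jax.jvp(f, (x,), (v,))[1]
    compressed = jax.vmap(jvp, in_axes=1, out_axes=1)(seeds)
    # Entry (i, j) sits in compressed column colors[j], because no other
    # column of that color has a nonzero in row i.
    return compressed[rows, colors[cols]]

n = 8
x = jnp.linspace(0.1, 0.8, n)
# Tridiagonal pattern: (i, j) is structurally nonzero iff |i - j| <= 1.
rows, cols = np.nonzero(np.abs(np.subtract.outer(np.arange(n), np.arange(n))) <= 1)
colors = np.arange(n) % 3  # columns j and j + 3 never share a row
values = sparse_jacobian_fwd(f, x, rows, cols, colors, num_colors=3)

dense = np.zeros((n, n))
dense[rows, cols] = values
```

The final comparison against jax.jacfwd is the natural sanity check: the decompressed values should match the dense Jacobian exactly on the pattern.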

asdex.value_and_jacobian(f, *args, argnums=0, has_aux=False, holomorphic=False, allow_int=False, mode=None, symmetric=False, output_format='bcoo')

Detect sparsity, color, and return a function computing value and sparse Jacobian.

Like jacobian, but also returns the primal value f(*args) without an extra forward pass.

Returns:

Type Description
Callable[..., Any]

A function that takes the same positional args as f and returns (value, jac), or ((value, aux), jac) when has_aux=True, matching the jax.value_and_grad ordering.
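The idea behind "without an extra forward pass" can be illustrated in plain JAX. jax.linearize evaluates f once and returns both the primal output and a linear JVP function that can then be applied to each color's seed. This is a sketch under that assumption, not a description of asdex internals. The example uses an elementwise f, whose diagonal Jacobian needs only a single color.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Elementwise map: the Jacobian is diagonal, so one color (one seed) suffices.
    return jnp.exp(x) * jnp.sin(x)

x = jnp.array([0.1, 0.5, 1.0, 2.0])

# A single linearization yields the primal value and a reusable JVP function.
value, f_jvp = jax.linearize(f, x)

# All columns share one color, so the seed is the all-ones vector.
# For a diagonal Jacobian, the compressed column is exactly the diagonal.
diag = f_jvp(jnp.ones_like(x))
jac = jnp.diag(diag)
```

Reverse mode offers the same property through jax.vjp, which also returns the primal output alongside the VJP function.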

asdex.hessian(f, *args, argnums=0, has_aux=False, holomorphic=False, allow_int=False, mode=None, symmetric=True, output_format='bcoo')

Detect sparsity, color, and return a function computing sparse Hessians.

If f returns a squeezable shape like (1,) or (1, 1), it is automatically squeezed to a scalar.

asdex.value_and_hessian(f, *args, argnums=0, has_aux=False, holomorphic=False, allow_int=False, mode=None, symmetric=True, output_format='bcoo')

Detect sparsity, color, and return a function computing value and sparse Hessian.

Like hessian, but also returns the primal value f(*args) without an extra forward pass.


asdex.jacobian_from_coloring(f, coloring, output_format='bcoo', *, has_aux=False, holomorphic=False, allow_int=False)

Build a sparse Jacobian function from a pre-computed coloring.

Uses row coloring with VJPs or column coloring with JVPs, according to the mode stored in the coloring (ColoredPattern.mode).

The returned callable accepts *args, **kwargs; kwargs are forwarded to f at call time (matching jax.jacfwd / jax.jacrev).

asdex.value_and_jacobian_from_coloring(f, coloring, output_format='bcoo', *, has_aux=False, holomorphic=False, allow_int=False)

Build a function computing value and sparse Jacobian from a pre-computed coloring.

asdex.hessian_from_coloring(f, coloring, output_format='bcoo', *, has_aux=False, holomorphic=False, allow_int=False)

Build a sparse Hessian function from a pre-computed coloring.

Uses symmetric (star) coloring and Hessian-vector products by default.

asdex.value_and_hessian_from_coloring(f, coloring, output_format='bcoo', *, has_aux=False, holomorphic=False, allow_int=False)

Build a function computing value and sparse Hessian from a pre-computed coloring.

Coloring

asdex.jacobian_coloring(f, *args, argnums=0, has_aux=False, mode=None, symmetric=False, postprocess=False)

Detect Jacobian sparsity and color in one step.

Parameters:

Name Type Description Default
f Callable

Function whose Jacobian is to be computed.

required
*args Any

Sample arguments of f. Only their structure and dtypes are used; values are ignored.

()
argnums int | Sequence[int]

Specifies which positional argument(s) to differentiate with respect to (default 0).

0
has_aux bool

If True, f is assumed to return (output, aux) where aux is auxiliary data ignored by sparsity detection.

False
mode JacobianMode | None

AD mode. "fwd" uses JVPs (forward-mode AD), "rev" uses VJPs (reverse-mode AD), None picks whichever of fwd/rev needs fewer colors (unless symmetric is True, in which case defaults to "fwd").

None
symmetric bool

Whether to use symmetric (star) coloring. Requires a square Jacobian.

False
postprocess bool

Only read when symmetric=True. Prune colors never used as hubs and compact the remaining ones (reduces the number of VJPs/JVPs during decompression). Defaults to False, matching SparseMatrixColorings.jl.

False

Returns:

Type Description
ColoredPattern

asdex.hessian_coloring(f, *args, argnums=0, has_aux=False, mode=None, symmetric=True, postprocess=False)

Detect Hessian sparsity and color in one step.

Parameters:

Name Type Description Default
f Callable

Scalar-valued function taking one or more positional arrays.

required
*args Any

Sample arguments of f. Only their structure and dtypes are used; values are ignored.

()
argnums int | Sequence[int]

Specifies which positional argument(s) to differentiate with respect to (default 0).

0
has_aux bool

If True, f is assumed to return (output, aux) where aux is auxiliary data ignored by sparsity detection.

False
mode HessianMode | None

AD composition strategy for Hessian-vector products. "fwd_over_rev" uses forward-over-reverse, "rev_over_fwd" uses reverse-over-forward, "rev_over_rev" uses reverse-over-reverse. Defaults to "fwd_over_rev".

None
symmetric bool

Whether to use symmetric (star) coloring. Defaults to True (exploits H = H^T for fewer colors).

True
postprocess bool

Only read when symmetric=True. Prune colors never used as hubs and compact the remaining ones (reduces the number of HVPs during decompression). Defaults to False, matching SparseMatrixColorings.jl.

False

Returns:

Type Description
ColoredPattern

asdex.jacobian_coloring_from_sparsity(sparsity, *, mode=None, symmetric=False, postprocess=False)

Color a sparsity pattern for sparse Jacobian computation.

Assigns colors so that same-colored rows (or columns) can be computed together in a single VJP (or JVP).

Parameters:

Name Type Description Default
sparsity SparsityPattern | NDArray | BCOO

A SparsityPattern, NumPy array, or JAX BCOO matrix of shape (m, n).

required
mode JacobianMode | None

AD mode. "fwd" uses JVPs (column coloring), "rev" uses VJPs (row coloring). None picks whichever of fwd/rev needs fewer colors (unless symmetric is True, in which case defaults to "fwd").

None
symmetric bool

Whether to use symmetric (star) coloring. Requires a square pattern.

False
postprocess bool

Only read when symmetric=True. Prune colors never used as hubs and compact the remaining ones (reduces the number of VJPs/JVPs during decompression). Defaults to False, matching SparseMatrixColorings.jl.

False

Returns:

Type Description
ColoredPattern

asdex.hessian_coloring_from_sparsity(sparsity, *, mode=None, symmetric=True, postprocess=False)

Color a sparsity pattern for sparse Hessian computation.

Parameters:

Name Type Description Default
sparsity SparsityPattern | NDArray | BCOO

A SparsityPattern, NumPy array, or JAX BCOO matrix of shape (n, n).

required
mode HessianMode | None

AD composition strategy for Hessian-vector products. "fwd_over_rev" uses forward-over-reverse, "rev_over_fwd" uses reverse-over-forward, "rev_over_rev" uses reverse-over-reverse. Defaults to "fwd_over_rev".

None
symmetric bool

Whether to use symmetric (star) coloring. Defaults to True (exploits Hessian symmetry for fewer colors).

True
postprocess bool

Only read when symmetric=True. Prune colors never used as hubs and compact the remaining ones (reduces the number of HVPs during decompression). Pruned vertices get the neutral color -1 in the output (no HVP is computed for them). Defaults to False, matching SparseMatrixColorings.jl.

False

Returns:

Type Description
ColoredPattern

asdex.color_rows(sparsity)

Greedy row-wise coloring for sparse Jacobian computation.

Assigns colors to rows such that no two rows sharing a non-zero column have the same color. This enables computing multiple Jacobian rows in a single VJP by using a combined seed vector.

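Uses LargestFirst vertex ordering (rows are colored in decreasing order of their degree in the conflict graph), which typically reduces the number of colors.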
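The following minimal sketch in NumPy illustrates greedy row coloring with a largest-first ordering, followed by the compressed-VJP recovery it enables. It is not asdex's implementation. The helper greedy_row_coloring is hypothetical, works on a small dense pattern, and emulates the VJPs with a matrix product (seeds.T @ J).

```python
import numpy as np

def greedy_row_coloring(dense):
    """Hypothetical sketch: rows sharing a nonzero column must get different colors."""
    nz = (np.asarray(dense) != 0).astype(np.int64)
    m = nz.shape[0]
    # Rows i and k conflict when they share at least one nonzero column.
    conflict = (nz @ nz.T) > 0
    np.fill_diagonal(conflict, False)
    degree = conflict.sum(axis=1)
    order = np.argsort(-degree, kind="stable")  # largest degree first, ties by index
    colors = np.full(m, -1, dtype=np.int32)
    for i in order:
        forbidden = {int(k) for k in colors[conflict[i]] if k >= 0}
        c = 0
        while c in forbidden:  # smallest color unused by already-colored neighbors
            c += 1
        colors[i] = c
    num_colors = int(colors.max()) + 1 if m > 0 else 0
    return colors, num_colors

J = np.array([[1.0, 2.0, 0.0, 0.0],
              [0.0, 3.0, 4.0, 0.0],
              [0.0, 0.0, 5.0, 6.0],
              [7.0, 0.0, 0.0, 8.0]])
colors, num_colors = greedy_row_coloring(J)

# One VJP per color: seed w_c sums the unit vectors of the rows with color c.
seeds = np.zeros((J.shape[0], num_colors))
seeds[np.arange(J.shape[0]), colors] = 1.0
compressed = seeds.T @ J  # stands in for num_colors VJPs, shape (num_colors, n)
rows, cols = np.nonzero(J)
recovered = compressed[colors[rows], cols]  # entry (i, j) lives in compressed row colors[i]
```

Here the four rows form a conflict cycle, so two colors (two VJPs) suffice instead of four.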

Parameters:

Name Type Description Default
sparsity SparsityPattern

SparsityPattern of shape (m, n) representing the Jacobian sparsity pattern.

required

Returns:

Type Description
tuple[NDArray[int32], int]

Tuple of (colors, num_colors) where:

  • colors: Array of shape (m,) with color assignment for each row
  • num_colors: Total number of colors used

asdex.color_cols(sparsity)

Greedy column-wise coloring for sparse Jacobian computation.

Assigns colors to columns such that no two columns sharing a non-zero row have the same color. This enables computing multiple Jacobian columns in a single JVP by using a combined tangent vector.

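Uses LargestFirst vertex ordering (columns are colored in decreasing order of their degree in the conflict graph), which typically reduces the number of colors.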
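A column coloring is valid exactly when same-colored columns are structurally orthogonal, meaning they share no nonzero row. The hypothetical helper below checks that property on a dense pattern. It is a sketch for experimenting with colorings, not part of asdex, and it does not handle the neutral color -1.

```python
import numpy as np

def is_valid_column_coloring(dense, colors):
    """Return True if no two columns sharing a nonzero row have the same color."""
    nz = np.asarray(dense) != 0
    colors = np.asarray(colors)
    for i in range(nz.shape[0]):
        row_colors = colors[nz[i]]  # colors of the columns that are nonzero in row i
        if len(np.unique(row_colors)) < len(row_colors):
            return False  # two columns in this row share a color
    return True

# Lower bidiagonal 4x4 pattern: column j is nonzero in rows j and j + 1.
pattern = np.eye(4) + np.eye(4, k=-1)
good = [0, 1, 0, 1]  # columns 0 and 2 (and 1 and 3) never share a row
bad = [0, 0, 1, 1]   # columns 0 and 1 both have a nonzero in row 1
```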

Parameters:

Name Type Description Default
sparsity SparsityPattern

SparsityPattern of shape (m, n) representing the Jacobian sparsity pattern.

required

Returns:

Type Description
tuple[NDArray[int32], int]

Tuple of (colors, num_colors) where:

  • colors: Array of shape (n,) with color assignment for each column
  • num_colors: Total number of colors used

asdex.color_symmetric(sparsity, *, postprocess=False, forced_colors=None)

Greedy symmetric coloring for sparse Hessian computation.

Implements Algorithm 4.1 from Gebremedhin et al. (2007). A star coloring is a distance-1 coloring with the additional constraint that every path on 4 vertices uses at least 3 colors. Returns a StarSet alongside the colors so that Hessian decompression can use hub-based extraction.

Uses LargestFirst vertex ordering.
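The definition above can be checked by brute force on small graphs. The helper below, is_star_coloring, is hypothetical and not how asdex verifies forced_colors. It enumerates every path a-b-c-d and ignores the diagonal of the pattern, since diagonal Hessian entries are not graph edges.

```python
import itertools
import numpy as np

def is_star_coloring(adj, colors):
    """Brute-force star-coloring check (hypothetical helper, small graphs only)."""
    adj = np.array(adj, dtype=bool)  # copy, so the caller's array is not modified
    np.fill_diagonal(adj, False)
    n = adj.shape[0]
    # Distance-1 condition: adjacent vertices get different colors.
    for u, v in zip(*np.nonzero(adj)):
        if colors[u] == colors[v]:
            return False
    # Every path on 4 vertices a-b-c-d must use at least 3 distinct colors.
    for a, b, c, d in itertools.permutations(range(n), 4):
        if adj[a, b] and adj[b, c] and adj[c, d]:
            if len({colors[a], colors[b], colors[c], colors[d]}) < 3:
                return False
    return True

# Path graph 0-1-2-3, e.g. the off-diagonal pattern of a 4x4 tridiagonal Hessian.
path = np.eye(4, k=1) + np.eye(4, k=-1)
```

On this path, [0, 1, 0, 1] is a proper distance-1 coloring but not a star coloring, because the path 0-1-2-3 uses only two colors. [0, 1, 0, 2] satisfies both conditions.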

Parameters:

Name Type Description Default
sparsity SparsityPattern

SparsityPattern of shape (n, n) representing the symmetric Hessian sparsity pattern.

required
postprocess bool

If True, replace colors that are never used as a hub color (and not forced by a diagonal entry) with -1 (neutral), then compact remaining colors down. This reduces the number of HVPs needed during decompression. Defaults to False, matching SparseMatrixColorings.jl's postprocessing=false default.

False
forced_colors NDArray[int32] | list[int] | None

Optional pre-computed color assignment of shape (n,). When provided, the algorithm verifies that it satisfies the star-coloring constraints and raises InvalidColoringError otherwise.

None

Returns:

Type Description
tuple[NDArray[int32], int, StarSet]

Tuple (colors, num_colors, star_set) where:

  • colors: Array of shape (n,) with the color assignment for each vertex. Values are in [0, num_colors - 1] for active vertices. After postprocessing, vertices whose color was pruned have the value -1 (neutral: no HVP is needed for them).
  • num_colors: Number of active colors (i.e., the number of HVPs).
  • star_set: StarSet encoding the 2-colored star decomposition.

Raises:

Type Description
ValueError

If pattern is not square.

InvalidColoringError

If forced_colors violates a star-coloring constraint.

Visualization

asdex.spy(pattern, *, ax=None, compressed=False, cmap=None, **kwargs)

Plot a sparsity pattern or colored pattern using matplotlib.

For a SparsityPattern, plots nonzeros as filled cells on a grid. For a ColoredPattern, fills cells with their assigned color.

When compressed=True on a ColoredPattern, plots the compressed pattern after coloring instead of the original.

Parameters:

Name Type Description Default
pattern SparsityPattern | ColoredPattern

The sparsity or colored pattern to plot.

required
ax Axes | None

Matplotlib axes to plot on. If None, creates a new figure.

None
compressed bool

If True and pattern is a ColoredPattern, plot the compressed pattern instead of the original.

False
cmap Any

Matplotlib colormap for colored patterns. If None, uses tab10.

None
**kwargs Any

Extra keyword arguments passed to ax.imshow.

{}

Returns:

Type Description
Axes

The matplotlib axes with the plot.

Sparsity Detection

asdex.jacobian_sparsity(f, *args, argnums=0, has_aux=False)

Detect global Jacobian sparsity pattern for f.

Analyzes the computation graph structure directly, without evaluating any derivatives. The result is valid for all inputs.
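"Valid for all inputs" matters most for functions with data-dependent branches. The plain-JAX sketch below does not use asdex. It shows that the nonzeros of a dense Jacobian at a single point can miss entries that are nonzero elsewhere, so a global pattern has to contain at least their union. Taking a union over two sample points, as done here, is only an illustration and does not replace graph-based detection. The helper local_pattern is hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np

def f(x):
    # Output 0 reads x[1] or x[2] depending on the sign of x[0].
    return jnp.stack([jnp.where(x[0] > 0, x[1], x[2]), x[0] * x[1]])

def local_pattern(f, x):
    """Nonzero pattern of the dense Jacobian at one specific point."""
    return np.asarray(jax.jacfwd(f)(x)) != 0

p_pos = local_pattern(f, jnp.array([1.0, 2.0, 3.0]))   # branch x[1] taken
p_neg = local_pattern(f, jnp.array([-1.0, 2.0, 3.0]))  # branch x[2] taken
union = p_pos | p_neg
```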

Parameters:

Name Type Description Default
f Callable

Function whose Jacobian sparsity pattern is to be detected.

required
*args Any

Sample arguments of f. Only their structure and dtypes are used; values are ignored.

()
argnums int | Sequence[int]

Specifies which positional argument(s) to differentiate with respect to (default 0).

0
has_aux bool

Whether f returns (output, auxiliary_data). When True, only output is analyzed for sparsity; the auxiliary branch of the computation is not traced.

False

Returns:

Type Description
SparsityPattern

SparsityPattern of shape (m, n_selected) where m = prod(output_shape) and n_selected is the total flat size of the selected inputs.
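The shape convention can be reproduced with jax.eval_shape. Here m is the flat size of the output, and n_selected sums the flat sizes of every array leaf in the selected arguments. The function f, its argument shapes, and the dict key "b" are made up for illustration. This only mirrors the documented bookkeeping; it does not call asdex.

```python
import math
import jax
import jax.numpy as jnp

def f(x, y, scale):
    # x: (2, 3); y: a pytree with one (4,) leaf; scale is not differentiated.
    return scale * (x.sum(axis=1)[:, None] + y["b"][None, :])

args = (jnp.zeros((2, 3)), {"b": jnp.zeros(4)}, jnp.array(1.0))
argnums = (0, 1)

out = jax.eval_shape(f, *args)
m = math.prod(out.shape)  # output shape (2, 4) flattens to 8 rows
selected = [args[i] for i in argnums]
n_selected = sum(math.prod(leaf.shape) for leaf in jax.tree_util.tree_leaves(selected))
```

With these shapes the pattern has 8 rows and 6 + 4 = 10 columns. The non-selected argument scale contributes no columns.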

asdex.hessian_sparsity(f, *args, argnums=0, has_aux=False)

Detect global Hessian sparsity pattern for a scalar-valued f.

Analyzes the Jacobian sparsity of the gradient function, without evaluating any derivatives. The result is valid for all inputs.

If f returns a squeezable shape like (1,) or (1, 1), it is automatically squeezed to a scalar.

Parameters:

Name Type Description Default
f Callable

Scalar-valued function taking one or more positional arrays.

required
*args Any

Sample arguments of f. Only their structure and dtypes are used; values are ignored.

()
argnums int | Sequence[int]

Specifies which positional argument(s) to differentiate with respect to (default 0).

0
has_aux bool

Whether f returns (scalar_output, auxiliary_data). When True, aux is stripped before detection.

False

Returns:

Type Description
SparsityPattern

Square SparsityPattern over the combined, selected input space.

Data Structures

asdex.SparsityPattern dataclass

Sparse matrix pattern storing only structural information (no values).

Stores row and column indices separately for efficient access by the coloring and decompression stages.

Attributes:

Name Type Description
rows NDArray[int32]

Row indices of non-zero entries, shape (nnz,).

cols NDArray[int32]

Column indices of non-zero entries, shape (nnz,).

shape tuple[int, int]

Matrix dimensions (m, n).

input_avals tuple[Any, ...]

One pytree of jax.ShapeDtypeStruct per positional argument of the traced function, in the same order jax.eval_shape(fun, *args) expects. Positions not in argnums are still stored (so the full input structure is preserved), but they do not contribute columns to the Jacobian / rows to the Hessian.

argnums int | tuple[int, ...]

Positions of input_avals that were differentiated, mirroring jax.grad / jax.jacfwd. An int stays an int and a sequence becomes tuple[int, ...]; that distinction determines whether example_input is a single aval or a tuple of avals.

dyn_avals property

Sub-tuple of input_avals selected by argnums.

example_input property

The aval structure the returned Jacobian / Hessian mirrors.

When argnums is an int this is the single selected aval; when argnums is a tuple this is the tuple of selected avals. Matches jax/_src/api.py:746 (jacfwd) and line 840 (jacrev).
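The int-versus-tuple distinction follows JAX's own convention, which can be checked directly with jax.jacfwd. This snippet uses plain JAX only. The functions and inputs are illustrative.

```python
import jax
import jax.numpy as jnp

def f(x, y):
    return x * y

x, y = jnp.arange(3.0), jnp.ones(3)

jac_int = jax.jacfwd(f, argnums=0)(x, y)         # a single array of shape (3, 3)
jac_tuple = jax.jacfwd(f, argnums=(0,))(x, y)    # a 1-tuple holding that array
jac_both = jax.jacfwd(f, argnums=(0, 1))(x, y)   # a 2-tuple: (df/dx, df/dy)
```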

leaf_shapes property

Per-leaf shapes of the selected (differentiated) inputs.

leaf_sizes property

Per-leaf flat sizes (prod(shape)) of the selected inputs.

input_treedef property

Pytree structure of dyn_avals.

nnz property

Number of non-zero elements.

m property

Number of rows.

n property

Number of columns.

density property

Fraction of non-zero entries.

col_to_rows cached property

Mapping from column index to list of row indices with non-zeros.

Used by the coloring algorithm to build the row conflict graph.

row_to_cols cached property

Mapping from row index to list of column indices with non-zeros.

Used by the coloring algorithm to build the column conflict graph.
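Both mappings can be derived from the COO index arrays in a single pass. The hypothetical build_adjacency below is a sketch of equivalent data, not asdex's code, and the exact container types asdex uses may differ. Two rows conflict in row coloring whenever they appear together in some col_to_rows list. Symmetrically, two columns conflict whenever they appear together in some row_to_cols list.

```python
from collections import defaultdict
import numpy as np

def build_adjacency(rows, cols):
    """Hypothetical equivalent of col_to_rows / row_to_cols from COO index arrays."""
    col_to_rows = defaultdict(list)
    row_to_cols = defaultdict(list)
    for i, j in zip(rows.tolist(), cols.tolist()):
        col_to_rows[j].append(i)
        row_to_cols[i].append(j)
    return dict(col_to_rows), dict(row_to_cols)

dense = np.array([[1, 0, 2], [0, 3, 0], [4, 0, 5]])
rows, cols = np.nonzero(dense)  # row-major order: (0,0), (0,2), (1,1), (2,0), (2,2)
col_to_rows, row_to_cols = build_adjacency(rows, cols)
```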

__post_init__()

Validate inputs and fill in the default single-leaf aval.

from_coo(rows, cols, shape, *, input_avals=(), argnums=0) classmethod

Create pattern from row and column index arrays.

Parameters:

Name Type Description Default
rows NDArray[int32] | list[int]

Row indices of non-zero entries.

required
cols NDArray[int32] | list[int]

Column indices of non-zero entries.

required
shape tuple[int, int]

Matrix dimensions (m, n).

required
input_avals tuple[Any, ...]

One pytree of ShapeDtypeStruct per positional argument of the traced function. Defaults to a single 1-D aval of size n.

()
argnums int | tuple[int, ...]

Positions of input_avals that were differentiated, mirroring jax.grad / jax.jacfwd.

0

from_bcoo(bcoo) classmethod

Create pattern from JAX BCOO sparse matrix.

from_dense(dense) classmethod

Create pattern from dense boolean/numeric matrix.

Non-zero entries indicate pattern positions.

to_bcoo(data=None)

Convert to JAX BCOO sparse matrix.

Parameters:

Name Type Description Default
data ndarray | None

Optional data values. If None, uses all 1s.

None

todense()

Convert to dense numpy array with 1s at pattern positions.
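The conversions above correspond to a few lines of NumPy and jax.experimental.sparse. This sketch mimics from_dense, to_bcoo with data=None, and todense on a raw 2-D array. It does not call asdex, and the variable names are illustrative.

```python
import numpy as np
import jax.numpy as jnp
from jax.experimental import sparse

dense = np.array([[0.0, 2.5, 0.0], [1.0, 0.0, 0.0]])

# from_dense-style: nonzero positions become (rows, cols) index arrays.
rows, cols = np.nonzero(dense)

# to_bcoo-style with data=None: every stored entry gets the value 1.
indices = jnp.stack([jnp.asarray(rows), jnp.asarray(cols)], axis=1)  # shape (nnz, 2)
data = jnp.ones(len(rows))
pattern_bcoo = sparse.BCOO((data, indices), shape=dense.shape)

# todense-style: a 0/1 array with ones at the pattern positions.
pattern_dense = np.asarray(pattern_bcoo.todense())
```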

save(path)

Save sparsity pattern to an .npz file.

Supports multi-input and PyTree-structured patterns.

Parameters:

Name Type Description Default
path str | PathLike[str]

Destination file path.

required

load(path) classmethod

Load sparsity pattern from an .npz file.

Parameters:

Name Type Description Default
path str | PathLike[str]

Source file path.

required

__str__()

Render sparsity pattern with header and dot/braille grid.

__repr__()

Return compact single-line representation.

asdex.ColoredPattern dataclass

Result of a graph coloring for sparse differentiation.

Attributes:

Name Type Description
sparsity SparsityPattern

The sparsity pattern that was colored.

colors NDArray[int32]

Color assignment array. Shape (m,) for "rev" mode, (n,) for all other modes. A value of -1 means "neutral": the vertex is not seeded (used after star-coloring postprocessing).

num_colors int

Total number of active colors (number of JVPs/VJPs/HVPs).

symmetric bool

Whether symmetric (star) coloring was used.

mode ColoringMode

The AD mode, always resolved to a concrete value (never "auto"). "fwd" uses JVPs (forward-mode AD), "rev" uses VJPs (reverse-mode AD), "fwd_over_rev" uses forward-over-reverse HVPs, "rev_over_fwd" uses reverse-over-forward HVPs, "rev_over_rev" uses reverse-over-reverse HVPs.

star_set StarSet | None

Star-coloring structure (hub/spoke assignment per edge). Present only for symmetric colorings produced by color_symmetric; None otherwise.
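In the usual compression scheme, colors and num_colors are turned into a 0/1 seed matrix with one column per active color. Each column is then fed to one JVP, VJP, or HVP, and neutral vertices (color -1) are simply left out of every seed. The hypothetical seed_matrix below sketches that step in NumPy. It is an assumption about how such a matrix can be built, not asdex's decompression code.

```python
import numpy as np

def seed_matrix(colors, num_colors):
    """Build an (n, num_colors) 0/1 seed matrix; vertices with color -1 stay unseeded."""
    colors = np.asarray(colors)
    seeds = np.zeros((colors.shape[0], num_colors))
    active = colors >= 0
    seeds[np.nonzero(active)[0], colors[active]] = 1.0
    return seeds

# Five vertices and two active colors; vertex 3 was pruned to the neutral color.
seeds = seed_matrix([0, 1, 0, -1, 1], num_colors=2)
```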

save(path)

Save colored pattern to an .npz file.

Supports multi-input and PyTree-structured patterns.

Parameters:

Name Type Description Default
path str | PathLike[str]

Destination file path.

required

load(path) classmethod

Load colored pattern from an .npz file.

Parameters:

Name Type Description Default
path str | PathLike[str]

Source file path.

required

__repr__()

Return compact single-line representation.

__str__()

Render colored pattern with sparsity grid and color assignments.


asdex.JacobianMode = Literal['fwd', 'rev'] module-attribute

AD mode for Jacobian computation.

"fwd" uses JVPs (forward-mode AD), "rev" uses VJPs (reverse-mode AD).

asdex.HessianMode = Literal['fwd_over_rev', 'rev_over_fwd', 'rev_over_rev'] module-attribute

AD composition strategy for Hessian-vector products.

"fwd_over_rev" uses forward-over-reverse, "rev_over_fwd" uses reverse-over-forward, "rev_over_rev" uses reverse-over-reverse.
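All three strategies compute the same Hessian-vector product H v and differ only in how JAX's transforms are composed. The plain-JAX sketch below shows one way to write each composition and compares them with a dense reference. The helper names and the test function are illustrative and not asdex internals. Which composition is fastest depends on the function and backend.

```python
import jax
import jax.numpy as jnp

def f(x):
    # A scalar function with a tridiagonal Hessian.
    return jnp.sum(jnp.sin(x[:-1] * x[1:])) + jnp.sum(x ** 3)

def hvp_fwd_over_rev(f, x, v):
    # Forward-mode JVP of the reverse-mode gradient.
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

def hvp_rev_over_fwd(f, x, v):
    # Reverse-mode gradient of the directional derivative y -> f'(y) v, computed with a JVP.
    return jax.grad(lambda y: jax.jvp(f, (y,), (v,))[1])(x)

def hvp_rev_over_rev(f, x, v):
    # Reverse-mode gradient of y -> <grad f(y), v>, with the inner gradient also in reverse mode.
    return jax.grad(lambda y: jnp.vdot(jax.grad(f)(y), v))(x)

x = jnp.array([0.3, -0.7, 1.2, 0.5])
v = jnp.array([1.0, 0.0, -2.0, 0.5])
reference = jax.hessian(f)(x) @ v  # dense Hessian, for comparison only
```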

asdex.VerificationError

Bases: AssertionError

Raised when asdex's sparse result does not match JAX's dense reference.

This indicates that the detected sparsity pattern is missing nonzeros, which is a bug: asdex's patterns should always be conservative (i.e., contain at least all true nonzeros). If you encounter this error, please help asdex's development by reporting it at https://github.com/adrhill/asdex/issues.
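Before filing a report, a quick spot check is to evaluate the dense Jacobian with JAX at a few random points and confirm that every nonzero lies inside the pattern. The hypothetical check_conservative below takes the pattern as a dense boolean mask (SparsityPattern.todense() can supply one) and raises a plain AssertionError. Passing at a few points is evidence of conservativeness, not proof.

```python
import jax
import jax.numpy as jnp
import numpy as np

def check_conservative(f, pattern_mask, x_shape, num_points=3, seed=0):
    """Raise AssertionError if jax.jacfwd(f) is nonzero outside pattern_mask at a random point."""
    key = jax.random.PRNGKey(seed)
    for _ in range(num_points):
        key, subkey = jax.random.split(key)
        x = jax.random.normal(subkey, x_shape)
        jac = np.asarray(jax.jacfwd(f)(x)).reshape(pattern_mask.shape)
        missing = (jac != 0) & ~pattern_mask
        if missing.any():
            raise AssertionError(f"nonzeros outside pattern at {np.argwhere(missing).tolist()}")

def f(x):
    return x[1:] - x[:-1]  # bidiagonal Jacobian of shape (n - 1, n)

n = 5
full = np.eye(n - 1, n, dtype=bool) | np.eye(n - 1, n, k=1, dtype=bool)
too_small = np.eye(n - 1, n, dtype=bool)  # deliberately misses the superdiagonal
```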