
Full API

Differentiation

asdex.jacobian(f, *sample_args, argnums=_DEFAULT_ARGNUMS, has_aux=_DEFAULT_HAS_AUX, holomorphic=_DEFAULT_HOLOMORPHIC, allow_int=_DEFAULT_ALLOW_INT, mode=_DEFAULT_MODE, symmetric=_DEFAULT_SYMMETRIC_JACOBIAN, output_format=_DEFAULT_OUTPUT_FORMAT, chunk_size=_DEFAULT_CHUNK_SIZE, **sample_kwargs)

Detect sparsity, color, and return a function computing sparse Jacobians.

Combines jacobian_coloring and jacobian_from_coloring in one call.

For repeated evaluation, wrap the returned function in jax.jit: each unjitted call re-traces f, which can cost far more than the differentiation itself. The "numpy_dense" and scipy output formats cannot be jitted since they produce non-JAX arrays.

Parameters:

Name Type Description Default
f Callable[..., Any]

Function whose Jacobian is to be computed.

required
*sample_args Any

Sample arguments of f. Only structure and dtypes are used, values are ignored.

()
argnums int | Sequence[int]

Specifies which positional argument(s) to differentiate with respect to. Defaults to 0.

_DEFAULT_ARGNUMS
has_aux bool

Whether f returns (output, auxiliary_data). Defaults to False.

_DEFAULT_HAS_AUX
holomorphic bool

Whether f is promised to be holomorphic. Defaults to False.

_DEFAULT_HOLOMORPHIC
allow_int bool

Whether to allow differentiating with respect to integer inputs.

_DEFAULT_ALLOW_INT
mode JacobianMode | None

AD mode for Jacobian computation. "fwd" uses JVPs (forward-mode AD), "rev" uses VJPs (reverse-mode AD). Defaults to picking whichever of fwd/rev needs fewer colors (unless symmetric is True, in which case defaults to "fwd").

_DEFAULT_MODE
symmetric bool

Whether to use symmetric coloring. Requires a square Jacobian. Defaults to False.

_DEFAULT_SYMMETRIC_JACOBIAN
output_format OutputFormat

Type of the output matrix. "bcoo" returns jax.experimental.sparse.BCOO (default), "dense" returns jax.Array, "numpy_dense" returns numpy.ndarray, "scipy_coo" returns scipy.sparse.coo_array, "scipy_csr" returns scipy.sparse.csr_array, "scipy_csc" returns scipy.sparse.csc_array. SciPy formats require scipy and only support 2D Jacobians: the input and output must each be a single flat (1D) array (scalar outputs are not supported).

_DEFAULT_OUTPUT_FORMAT
chunk_size int | None

Maximum number of colors to process in parallel. Defaults to None, processing all colors in a single vmapped batch. When specified, colors are processed in chunks of this size to reduce peak memory usage.

_DEFAULT_CHUNK_SIZE
**sample_kwargs Any

Sample keyword arguments of f. Merged with sample_args based on f's signature.

{}

Returns:

Type Description
Callable[..., Any]

A function that takes the same positional args as f and returns a pytree of Jacobian blocks matching argnums, with each leaf shaped (*out_shape, *in_leaf_shape). The block type depends on output_format (jax.experimental.sparse.BCOO by default, or jax.Array when "dense").
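The pipeline that jacobian bundles (detect the pattern, color it, compress, decompress) can be sketched in plain JAX. This is a minimal illustration, not asdex's implementation. The function f, its hand-written sparsity pattern and the 2-color column coloring are assumptions, chosen so the result can be checked against jax.jacfwd.

```python
import jax
import jax.numpy as jnp
import numpy as np


def f(x):
    # Hypothetical example: output i depends only on x[i] and x[i + 1].
    return x[1:] ** 2 - x[:-1]


n = 6
m = n - 1
x = jnp.arange(1.0, n + 1.0)

# Sparsity pattern in COO form: nonzeros at (i, i) and (i, i + 1).
rows = np.repeat(np.arange(m), 2)
cols = np.stack([np.arange(m), np.arange(m) + 1], axis=1).ravel()

# Columns j and j + 2 never share a row, so alternating colors are valid.
colors = np.arange(n) % 2
num_colors = 2

# Seed matrix: column c is the sum of the basis vectors of all inputs with color c.
seeds = jnp.eye(num_colors, dtype=x.dtype)[colors]  # shape (n, num_colors)

# One JVP per color instead of one per input column.
compressed = jax.vmap(lambda s: jax.jvp(f, (x,), (s,))[1], out_axes=1)(seeds.T)

# Decompress: entry (i, j) is read from compressed column colors[j].
J = np.zeros((m, n))
J[rows, cols] = np.asarray(compressed)[rows, colors[cols]]
```

Here 2 JVPs replace 6. For this band structure, the number of colors stays at 2 however large n grows, which is where the savings come from.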

asdex.jacobian_from_coloring(f, coloring, output_format=_DEFAULT_OUTPUT_FORMAT, *, has_aux=_DEFAULT_HAS_AUX, holomorphic=_DEFAULT_HOLOMORPHIC, allow_int=_DEFAULT_ALLOW_INT, chunk_size=_DEFAULT_CHUNK_SIZE)

Build a sparse Jacobian function from a pre-computed coloring.

Uses row coloring + VJPs or column coloring + JVPs, depending on the mode the coloring was computed for (by default, whichever of the two needs fewer colors).

The returned callable accepts *args, **kwargs; kwargs are forwarded to f at call time (matching jax.jacfwd / jax.jacrev).

For repeated evaluation, wrap the returned function in jax.jit: each unjitted call re-traces f, which can cost far more than the differentiation itself. The "numpy_dense" and scipy output formats cannot be jitted since they produce non-JAX arrays.
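Row coloring + VJPs can be sketched in plain JAX. Each cotangent seed sums the basis vectors of all rows that share a color, and each compressed row is then split back into Jacobian rows. This is a minimal sketch, not asdex's code. The function f, its pattern and the row coloring are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
import numpy as np


def f(x):
    # Hypothetical example: output i depends only on x[i] and x[i + 1].
    return x[:-1] * x[1:]


n = 6
m = n - 1
x = jnp.arange(1.0, n + 1.0)

rows = np.repeat(np.arange(m), 2)
cols = np.stack([np.arange(m), np.arange(m) + 1], axis=1).ravel()

# Rows i and i + 1 share column i + 1; rows i and i + 2 share no column.
row_colors = np.arange(m) % 2
num_colors = 2

# Cotangent seeds: column c is the sum of the basis vectors of all rows with color c.
seeds = jnp.eye(num_colors, dtype=x.dtype)[row_colors]  # shape (m, num_colors)

# One forward pass, then one VJP per color.
_, vjp_fun = jax.vjp(f, x)
(compressed,) = jax.vmap(vjp_fun)(seeds.T)  # shape (num_colors, n)

# Decompress: entry (i, j) is read from compressed row row_colors[i].
J = np.zeros((m, n))
J[rows, cols] = np.asarray(compressed)[row_colors[rows], cols]
```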

Parameters:

Name Type Description Default
f Callable[..., Any]

Function whose Jacobian is to be computed.

required
coloring ColoredPattern

Pre-computed colored sparsity pattern of type ColoredPattern.

required
output_format OutputFormat

Type of the output matrix. "bcoo" returns jax.experimental.sparse.BCOO (default), "dense" returns jax.Array, "numpy_dense" returns numpy.ndarray, "scipy_coo" returns scipy.sparse.coo_array, "scipy_csr" returns scipy.sparse.csr_array, "scipy_csc" returns scipy.sparse.csc_array. SciPy formats require scipy and only support 2D Jacobians: the input and output must each be a single flat (1D) array (scalar outputs are not supported).

_DEFAULT_OUTPUT_FORMAT
has_aux bool

Whether f returns (output, auxiliary_data). Defaults to False.

_DEFAULT_HAS_AUX
holomorphic bool

Whether f is promised to be holomorphic. Defaults to False.

_DEFAULT_HOLOMORPHIC
allow_int bool

Whether to allow differentiating with respect to integer inputs.

_DEFAULT_ALLOW_INT
chunk_size int | None

Maximum number of colors to process in parallel. Defaults to None, processing all colors in a single vmapped batch. When specified, colors are processed in chunks of this size to reduce peak memory usage.

_DEFAULT_CHUNK_SIZE
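The chunk_size trade-off can be illustrated with a rough JAX sketch. Vmapping over all seeds at once keeps every tangent in memory together, while looping over chunks bounds that at chunk_size tangents. The helper compressed_jvps is hypothetical and only mirrors the documented behavior. It is not asdex's implementation.

```python
import jax
import jax.numpy as jnp


def compressed_jvps(f, x, seeds, chunk_size=None):
    """Apply one JVP per seed column, optionally chunk_size columns at a time."""
    # Hypothetical helper: seeds has shape (n, num_colors).
    batched = jax.vmap(lambda s: jax.jvp(f, (x,), (s,))[1], out_axes=1)
    num_colors = seeds.shape[1]
    if chunk_size is None:
        return batched(seeds.T)  # all colors in a single vmapped batch
    parts = [
        batched(seeds[:, start:start + chunk_size].T)
        for start in range(0, num_colors, chunk_size)
    ]
    return jnp.concatenate(parts, axis=1)  # same result, lower peak memory


def f(x):
    # Hypothetical function with a dense first Jacobian column.
    return jnp.sin(x) * x[0]


x = jnp.linspace(0.5, 1.5, 5)
seeds = jnp.eye(5, dtype=x.dtype)  # dense seeding: one color per column
full = compressed_jvps(f, x, seeds)
chunked = compressed_jvps(f, x, seeds, chunk_size=2)  # chunks of 2, 2 and 1
```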

asdex.value_and_jacobian(f, *sample_args, argnums=_DEFAULT_ARGNUMS, has_aux=_DEFAULT_HAS_AUX, holomorphic=_DEFAULT_HOLOMORPHIC, allow_int=_DEFAULT_ALLOW_INT, mode=_DEFAULT_MODE, symmetric=_DEFAULT_SYMMETRIC_JACOBIAN, output_format=_DEFAULT_OUTPUT_FORMAT, chunk_size=_DEFAULT_CHUNK_SIZE, **sample_kwargs)

Detect sparsity, color, and return a function computing value and sparse Jacobian.

Like jacobian, but also returns the primal value f(*args) without an extra forward pass.

For repeated evaluation, wrap the returned function in jax.jit: each unjitted call re-traces f, which can cost far more than the differentiation itself. The "numpy_dense" and scipy output formats cannot be jitted since they produce non-JAX arrays.
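One standard JAX way to get the value "for free" alongside forward-mode derivatives is jax.linearize. It runs the primal computation once and returns a reusable linear map for JVPs. The sketch below applies that map to a single compressed seed. Whether asdex uses jax.linearize internally is not stated here. The function f and its one-color coloring are assumptions.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical elementwise function, so the Jacobian is diagonal.
    return 2.0 * jnp.sin(x)


x = jnp.linspace(0.0, 1.0, 4)

# One forward pass yields the primal output and a reusable JVP function.
y, f_jvp = jax.linearize(f, x)

# A diagonal Jacobian needs a single color: seed every column at once.
compressed = f_jvp(jnp.ones_like(x))
J = jnp.diag(compressed)
```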

Parameters:

Name Type Description Default
f Callable[..., Any]

Function whose Jacobian is to be computed.

required
*sample_args Any

Sample arguments of f. Only structure and dtypes are used, values are ignored.

()
argnums int | Sequence[int]

Specifies which positional argument(s) to differentiate with respect to. Defaults to 0.

_DEFAULT_ARGNUMS
has_aux bool

Whether f returns (output, auxiliary_data). Defaults to False.

_DEFAULT_HAS_AUX
holomorphic bool

Whether f is promised to be holomorphic. Defaults to False.

_DEFAULT_HOLOMORPHIC
allow_int bool

Whether to allow differentiating with respect to integer inputs.

_DEFAULT_ALLOW_INT
mode JacobianMode | None

AD mode for Jacobian computation. "fwd" uses JVPs (forward-mode AD), "rev" uses VJPs (reverse-mode AD). Defaults to picking whichever of fwd/rev needs fewer colors (unless symmetric is True, in which case defaults to "fwd").

_DEFAULT_MODE
symmetric bool

Whether to use symmetric coloring. Requires a square Jacobian. Defaults to False.

_DEFAULT_SYMMETRIC_JACOBIAN
output_format OutputFormat

Type of the output matrix. "bcoo" returns jax.experimental.sparse.BCOO (default), "dense" returns jax.Array, "numpy_dense" returns numpy.ndarray, "scipy_coo" returns scipy.sparse.coo_array, "scipy_csr" returns scipy.sparse.csr_array, "scipy_csc" returns scipy.sparse.csc_array. SciPy formats require scipy and only support 2D Jacobians: the input and output must each be a single flat (1D) array (scalar outputs are not supported).

_DEFAULT_OUTPUT_FORMAT
chunk_size int | None

Maximum number of colors to process in parallel. Defaults to None, processing all colors in a single vmapped batch. When specified, colors are processed in chunks of this size to reduce peak memory usage.

_DEFAULT_CHUNK_SIZE
**sample_kwargs Any

Sample keyword arguments of f. Merged with sample_args based on f's signature.

{}

Returns:

Type Description
Callable[..., Any]

A function that takes the same positional args as f and returns (value, jac) — or ((value, aux), jac) when has_aux=True, matching jax.value_and_grad ordering.
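For reference, this is the nesting that jax.value_and_grad produces with has_aux=True. The returned function is documented to follow the same ordering, with the Jacobian in place of the gradient. The loss function below is a hypothetical example.

```python
import jax
import jax.numpy as jnp


def loss(x):
    value = jnp.sum(x ** 2)
    aux = {"norm": jnp.sqrt(value)}  # auxiliary data, not differentiated
    return value, aux


x = jnp.array([3.0, 4.0])
(value, aux), grad = jax.value_and_grad(loss, has_aux=True)(x)
```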

asdex.value_and_jacobian_from_coloring(f, coloring, output_format=_DEFAULT_OUTPUT_FORMAT, *, has_aux=_DEFAULT_HAS_AUX, holomorphic=_DEFAULT_HOLOMORPHIC, allow_int=_DEFAULT_ALLOW_INT, chunk_size=_DEFAULT_CHUNK_SIZE)

Build a function computing value and sparse Jacobian from a pre-computed coloring.

For repeated evaluation, wrap the returned function in jax.jit: each unjitted call re-traces f, which can cost far more than the differentiation itself. The "numpy_dense" and scipy output formats cannot be jitted since they produce non-JAX arrays.

Parameters:

Name Type Description Default
f Callable[..., Any]

Function whose Jacobian is to be computed.

required
coloring ColoredPattern

Pre-computed colored sparsity pattern of type ColoredPattern.

required
output_format OutputFormat

Type of the output matrix. "bcoo" returns jax.experimental.sparse.BCOO (default), "dense" returns jax.Array, "numpy_dense" returns numpy.ndarray, "scipy_coo" returns scipy.sparse.coo_array, "scipy_csr" returns scipy.sparse.csr_array, "scipy_csc" returns scipy.sparse.csc_array. SciPy formats require scipy and only support 2D Jacobians: the input and output must each be a single flat (1D) array (scalar outputs are not supported).

_DEFAULT_OUTPUT_FORMAT
has_aux bool

Whether f returns (output, auxiliary_data). Defaults to False.

_DEFAULT_HAS_AUX
holomorphic bool

Whether f is promised to be holomorphic. Defaults to False.

_DEFAULT_HOLOMORPHIC
allow_int bool

Whether to allow differentiating with respect to integer inputs.

_DEFAULT_ALLOW_INT
chunk_size int | None

Maximum number of colors to process in parallel. Defaults to None, processing all colors in a single vmapped batch. When specified, colors are processed in chunks of this size to reduce peak memory usage.

_DEFAULT_CHUNK_SIZE

asdex.hessian(f, *sample_args, argnums=_DEFAULT_ARGNUMS, has_aux=_DEFAULT_HAS_AUX, holomorphic=_DEFAULT_HOLOMORPHIC, allow_int=_DEFAULT_ALLOW_INT, mode=_DEFAULT_MODE, symmetric=_DEFAULT_SYMMETRIC_HESSIAN, output_format=_DEFAULT_OUTPUT_FORMAT, chunk_size=_DEFAULT_CHUNK_SIZE, **sample_kwargs)

Detect sparsity, color, and return a function computing sparse Hessians.

If f returns a squeezable shape like (1,) or (1, 1), it is automatically squeezed to scalar.

For repeated evaluation, wrap the returned function in jax.jit: each unjitted call re-traces f, which can cost far more than the differentiation itself. The "numpy_dense" and scipy output formats cannot be jitted since they produce non-JAX arrays.

Parameters:

Name Type Description Default
f Callable[..., Any]

Scalar-valued function whose Hessian is to be computed.

required
*sample_args Any

Sample arguments of f. Only structure and dtypes are used, values are ignored.

()
argnums int | Sequence[int]

Specifies which positional argument(s) to differentiate with respect to. Defaults to 0.

_DEFAULT_ARGNUMS
has_aux bool

Whether f returns (output, auxiliary_data). Defaults to False.

_DEFAULT_HAS_AUX
holomorphic bool

Whether f is promised to be holomorphic. Defaults to False.

_DEFAULT_HOLOMORPHIC
allow_int bool

Unsupported for Hessians; passing True raises TypeError (integer inputs cannot be differentiated twice, matching jax.hessian).

_DEFAULT_ALLOW_INT
mode HessianMode | None

AD composition strategy for Hessian-vector products. "fwd_over_rev" uses forward-over-reverse, "rev_over_fwd" uses reverse-over-forward, "rev_over_rev" uses reverse-over-reverse. Defaults to "fwd_over_rev".

_DEFAULT_MODE
symmetric bool

Whether to use symmetric coloring. Defaults to True.

_DEFAULT_SYMMETRIC_HESSIAN
output_format OutputFormat

Type of the output matrix. "bcoo" returns jax.experimental.sparse.BCOO (default), "dense" returns jax.Array, "numpy_dense" returns numpy.ndarray, "scipy_coo" returns scipy.sparse.coo_array, "scipy_csr" returns scipy.sparse.csr_array, "scipy_csc" returns scipy.sparse.csc_array. SciPy formats require scipy and only support 2D Hessians: the input must be a single flat (1D) array.

_DEFAULT_OUTPUT_FORMAT
chunk_size int | None

Maximum number of colors to process in parallel. Defaults to None, processing all colors in a single vmapped batch. When specified, colors are processed in chunks of this size to reduce peak memory usage.

_DEFAULT_CHUNK_SIZE
**sample_kwargs Any

Sample keyword arguments of f. Merged with sample_args based on f's signature.

{}

Returns:

Type Description
Callable[..., Any]

A function that takes the same positional args as f and returns the sparse Hessian.
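The three mode values name standard ways of composing JAX transformations into a Hessian-vector product (HVP). The sketch writes each out with the usual JAX idioms and checks them against a dense Hessian. How asdex composes these internally is not shown here. The function f is a hypothetical example.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical scalar function with a tridiagonal Hessian.
    return jnp.sum(x[1:] * x[:-1] ** 2)


def hvp_fwd_over_rev(f, x, v):
    # Forward-mode JVP of the reverse-mode gradient.
    return jax.jvp(jax.grad(f), (x,), (v,))[1]


def hvp_rev_over_fwd(f, x, v):
    # Reverse-mode gradient of the forward-mode directional derivative.
    return jax.grad(lambda y: jax.jvp(f, (y,), (v,))[1])(x)


def hvp_rev_over_rev(f, x, v):
    # Reverse-mode gradient of <grad f(y), v>, with grad f also from reverse mode.
    return jax.grad(lambda y: jnp.vdot(jax.grad(f)(y), v))(x)


x = jnp.array([1.0, 2.0, 3.0, 4.0])
v = jnp.array([0.5, -1.0, 2.0, 1.0])
reference = jax.hessian(f)(x) @ v
```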

asdex.hessian_from_coloring(f, coloring, output_format=_DEFAULT_OUTPUT_FORMAT, *, has_aux=_DEFAULT_HAS_AUX, holomorphic=_DEFAULT_HOLOMORPHIC, allow_int=_DEFAULT_ALLOW_INT, chunk_size=_DEFAULT_CHUNK_SIZE)

Build a sparse Hessian function from a pre-computed coloring.

Uses symmetric (star) coloring and Hessian-vector products by default.

For repeated evaluation, wrap the returned function in jax.jit: each unjitted call re-traces f, which can cost far more than the differentiation itself. The "numpy_dense" and scipy output formats cannot be jitted since they produce non-JAX arrays.
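Symmetric coloring pays off on patterns such as an arrowhead, with a dense first row and column plus the diagonal. Column coloring needs one color per column there, because every column has a nonzero in row 0. A star coloring needs only 2. asdex documents hub-based extraction through a StarSet. The NumPy sketch below uses the simpler direct-recovery rule instead, which a star coloring also guarantees to work: each entry is read either from its own compressed column or, by symmetry, from its mirror. The dense matrix product stands in for the Hessian-vector products.

```python
import numpy as np


def symmetric_decompress(compressed, pattern, colors):
    """Recover symmetric H from compressed = H @ seeds by direct recovery."""
    n = pattern.shape[0]
    H = np.zeros((n, n))
    for i, j in zip(*np.nonzero(pattern)):
        # Columns with the same color as j that are nonzero in row i.
        direct = np.flatnonzero((colors == colors[j]) & pattern[i])
        if direct.tolist() == [j]:
            H[i, j] = compressed[i, colors[j]]  # H[i, j] is alone in its sum
        else:
            # Fall back on symmetry: read H[j, i] from compressed column colors[i].
            mirrored = np.flatnonzero((colors == colors[i]) & pattern[j])
            assert mirrored.tolist() == [i], "coloring cannot recover this entry"
            H[i, j] = compressed[j, colors[i]]
    return H


# Arrowhead pattern: dense first row/column plus the diagonal.
n = 5
pattern = np.eye(n, dtype=bool)
pattern[0, :] = pattern[:, 0] = True

rng = np.random.default_rng(0)
A = rng.standard_normal((n, n)) * pattern
H_true = A + A.T  # symmetric, with the same pattern

# Star coloring: the hub vertex 0 gets its own color, all leaves share color 1.
colors = np.array([0, 1, 1, 1, 1])
seeds = np.eye(2)[colors]      # shape (n, 2)
compressed = H_true @ seeds    # stands in for 2 Hessian-vector products
H = symmetric_decompress(compressed, pattern, colors)
```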

Parameters:

Name Type Description Default
f Callable[..., Any]

Scalar-valued function whose Hessian is to be computed.

required
coloring ColoredPattern

Pre-computed colored sparsity pattern of type ColoredPattern.

required
output_format OutputFormat

Type of the output matrix. "bcoo" returns jax.experimental.sparse.BCOO (default), "dense" returns jax.Array, "numpy_dense" returns numpy.ndarray, "scipy_coo" returns scipy.sparse.coo_array, "scipy_csr" returns scipy.sparse.csr_array, "scipy_csc" returns scipy.sparse.csc_array. SciPy formats require scipy and only support 2D Hessians: the input must be a single flat (1D) array.

_DEFAULT_OUTPUT_FORMAT
has_aux bool

Whether f returns (output, auxiliary_data). Defaults to False.

_DEFAULT_HAS_AUX
holomorphic bool

Whether f is promised to be holomorphic. Defaults to False.

_DEFAULT_HOLOMORPHIC
allow_int bool

Unsupported for Hessians; passing True raises TypeError (integer inputs cannot be differentiated twice, matching jax.hessian).

_DEFAULT_ALLOW_INT
chunk_size int | None

Maximum number of colors to process in parallel. Defaults to None, processing all colors in a single vmapped batch. When specified, colors are processed in chunks of this size to reduce peak memory usage.

_DEFAULT_CHUNK_SIZE

asdex.value_and_hessian(f, *sample_args, argnums=_DEFAULT_ARGNUMS, has_aux=_DEFAULT_HAS_AUX, holomorphic=_DEFAULT_HOLOMORPHIC, allow_int=_DEFAULT_ALLOW_INT, mode=_DEFAULT_MODE, symmetric=_DEFAULT_SYMMETRIC_HESSIAN, output_format=_DEFAULT_OUTPUT_FORMAT, chunk_size=_DEFAULT_CHUNK_SIZE, **sample_kwargs)

Detect sparsity, color, and return a function computing value and sparse Hessian.

Like hessian, but also returns the primal value f(*args) without an extra forward pass.

For repeated evaluation, wrap the returned function in jax.jit: each unjitted call re-traces f, which can cost far more than the differentiation itself. The "numpy_dense" and scipy output formats cannot be jitted since they produce non-JAX arrays.
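In plain JAX, one way to obtain the value, the gradient and reusable Hessian-vector products from a single primal pass is to linearize jax.value_and_grad. This is a sketch of the idea only; the source does not say whether asdex does it this way. The function f and the direction v are hypothetical.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical scalar function.
    return jnp.sum(jnp.exp(x) * x)


x = jnp.array([0.0, 0.5, 1.0])

# Linearize value_and_grad once: the primal pass yields f(x) and grad f(x),
# and the returned linear map gives Hessian-vector products.
(value, grad), lin = jax.linearize(jax.value_and_grad(f), x)

v = jnp.array([1.0, 0.0, 0.0])
_, hv = lin(v)  # (directional derivative of f, H @ v)
```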

Parameters:

Name Type Description Default
f Callable[..., Any]

Scalar-valued function whose Hessian is to be computed.

required
*sample_args Any

Sample arguments of f. Only structure and dtypes are used, values are ignored.

()
argnums int | Sequence[int]

Specifies which positional argument(s) to differentiate with respect to. Defaults to 0.

_DEFAULT_ARGNUMS
has_aux bool

Whether f returns (output, auxiliary_data). Defaults to False.

_DEFAULT_HAS_AUX
holomorphic bool

Whether f is promised to be holomorphic. Defaults to False.

_DEFAULT_HOLOMORPHIC
allow_int bool

Unsupported for Hessians; passing True raises TypeError (integer inputs cannot be differentiated twice, matching jax.hessian).

_DEFAULT_ALLOW_INT
mode HessianMode | None

AD composition strategy for Hessian-vector products. "fwd_over_rev" uses forward-over-reverse, "rev_over_fwd" uses reverse-over-forward, "rev_over_rev" uses reverse-over-reverse. Defaults to "fwd_over_rev".

_DEFAULT_MODE
symmetric bool

Whether to use symmetric coloring. Defaults to True.

_DEFAULT_SYMMETRIC_HESSIAN
output_format OutputFormat

Type of the output matrix. "bcoo" returns jax.experimental.sparse.BCOO (default), "dense" returns jax.Array, "numpy_dense" returns numpy.ndarray, "scipy_coo" returns scipy.sparse.coo_array, "scipy_csr" returns scipy.sparse.csr_array, "scipy_csc" returns scipy.sparse.csc_array. SciPy formats require scipy and only support 2D Hessians: the input must be a single flat (1D) array.

_DEFAULT_OUTPUT_FORMAT
chunk_size int | None

Maximum number of colors to process in parallel. Defaults to None, processing all colors in a single vmapped batch. When specified, colors are processed in chunks of this size to reduce peak memory usage.

_DEFAULT_CHUNK_SIZE
**sample_kwargs Any

Sample keyword arguments of f. Merged with sample_args based on f's signature.

{}

Returns:

Type Description
Callable[..., Any]

A function that takes the same positional args as f and returns (value, hessian).

asdex.value_and_hessian_from_coloring(f, coloring, output_format=_DEFAULT_OUTPUT_FORMAT, *, has_aux=_DEFAULT_HAS_AUX, holomorphic=_DEFAULT_HOLOMORPHIC, allow_int=_DEFAULT_ALLOW_INT, chunk_size=_DEFAULT_CHUNK_SIZE)

Build a function computing value and sparse Hessian from a pre-computed coloring.

For repeated evaluation, wrap the returned function in jax.jit: each unjitted call re-traces f, which can cost far more than the differentiation itself. The "numpy_dense" and scipy output formats cannot be jitted since they produce non-JAX arrays.

Parameters:

Name Type Description Default
f Callable[..., Any]

Scalar-valued function whose Hessian is to be computed.

required
coloring ColoredPattern

Pre-computed colored sparsity pattern of type ColoredPattern.

required
output_format OutputFormat

Type of the output matrix. "bcoo" returns jax.experimental.sparse.BCOO (default), "dense" returns jax.Array, "numpy_dense" returns numpy.ndarray, "scipy_coo" returns scipy.sparse.coo_array, "scipy_csr" returns scipy.sparse.csr_array, "scipy_csc" returns scipy.sparse.csc_array. SciPy formats require scipy and only support 2D Hessians: the input must be a single flat (1D) array.

_DEFAULT_OUTPUT_FORMAT
has_aux bool

Whether f returns (output, auxiliary_data). Defaults to False.

_DEFAULT_HAS_AUX
holomorphic bool

Whether f is promised to be holomorphic. Defaults to False.

_DEFAULT_HOLOMORPHIC
allow_int bool

Unsupported for Hessians; passing True raises TypeError (integer inputs cannot be differentiated twice, matching jax.hessian).

_DEFAULT_ALLOW_INT
chunk_size int | None

Maximum number of colors to process in parallel. Defaults to None, processing all colors in a single vmapped batch. When specified, colors are processed in chunks of this size to reduce peak memory usage.

_DEFAULT_CHUNK_SIZE

Coloring

asdex.jacobian_coloring(f, *args, argnums=_DEFAULT_ARGNUMS, has_aux=_DEFAULT_HAS_AUX, mode=_DEFAULT_MODE, symmetric=_DEFAULT_SYMMETRIC_JACOBIAN, postprocess=_DEFAULT_POSTPROCESS, **kwargs)

Detect Jacobian sparsity and color in one step.

Parameters:

Name Type Description Default
f Callable

Function whose Jacobian is to be computed.

required
*args Any

Sample arguments of f. Only structure and dtypes are used, values are ignored.

()
argnums int | Sequence[int]

Specifies which positional argument(s) to differentiate with respect to. Defaults to 0.

_DEFAULT_ARGNUMS
has_aux bool

Whether f returns (output, auxiliary_data). When True, the auxiliary output is ignored and only output is analyzed for sparsity.

_DEFAULT_HAS_AUX
mode JacobianMode | None

AD mode. "fwd" uses JVPs (column coloring), "rev" uses VJPs (row coloring). Defaults to picking whichever of fwd/rev needs fewer colors (unless symmetric is True, in which case defaults to "fwd").

_DEFAULT_MODE
symmetric bool

Whether to use symmetric coloring. Requires a square Jacobian. Defaults to False.

_DEFAULT_SYMMETRIC_JACOBIAN
postprocess bool

Only read when symmetric=True. Prune colors never used as hubs and compact the remaining ones (reduces the number of VJPs/JVPs during decompression). Defaults to False.

_DEFAULT_POSTPROCESS
**kwargs Any

Sample keyword arguments of f. Non-traceable values (bools, strings, ints) are bound statically.

{}

Returns:

Type Description
ColoredPattern

asdex.jacobian_coloring_from_sparsity(sparsity, *, mode=_DEFAULT_MODE, symmetric=_DEFAULT_SYMMETRIC_JACOBIAN, postprocess=_DEFAULT_POSTPROCESS)

Color a sparsity pattern for sparse Jacobian computation.

Assigns colors so that same-colored rows (or columns) can be computed together in a single VJP (or JVP).
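The condition a column coloring must satisfy can be checked directly. No row may contain two nonzero columns of the same color, or their contributions would be summed into one entry of the compressed JVP. This minimal NumPy checker is hypothetical (not part of asdex) and assumes a COO pattern without duplicate entries.

```python
import numpy as np


def is_valid_column_coloring(rows, cols, colors):
    """True if no row has two nonzero columns sharing a color."""
    seen = set()
    for r, c in zip(rows, cols):
        key = (int(r), int(colors[c]))
        if key in seen:
            return False  # two same-colored columns would collide in one JVP
        seen.add(key)
    return True


# Tridiagonal 5 x 5 pattern.
n = 5
pattern = np.abs(np.subtract.outer(np.arange(n), np.arange(n))) <= 1
rows, cols = np.nonzero(pattern)
```

For a row coloring, the same check applies with rows and cols swapped.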

Parameters:

Name Type Description Default
sparsity SparsityPattern | NDArray | BCOO

A SparsityPattern, NumPy array, or JAX BCOO matrix of shape (m, n).

required
mode JacobianMode | None

AD mode. "fwd" uses JVPs (column coloring), "rev" uses VJPs (row coloring). Defaults to picking whichever of fwd/rev needs fewer colors (unless symmetric is True, in which case defaults to "fwd").

_DEFAULT_MODE
symmetric bool

Whether to use symmetric coloring. Requires a square Jacobian. Defaults to False.

_DEFAULT_SYMMETRIC_JACOBIAN
postprocess bool

Only read when symmetric=True. Prune colors never used as hubs and compact the remaining ones (reduces the number of VJPs/JVPs during decompression). Defaults to False.

_DEFAULT_POSTPROCESS

Returns:

Type Description
ColoredPattern

asdex.hessian_coloring(f, *args, argnums=_DEFAULT_ARGNUMS, has_aux=_DEFAULT_HAS_AUX, mode=_DEFAULT_MODE, symmetric=_DEFAULT_SYMMETRIC_HESSIAN, postprocess=_DEFAULT_POSTPROCESS, **kwargs)

Detect Hessian sparsity and color in one step.

Parameters:

Name Type Description Default
f Callable

Scalar-valued function whose Hessian is to be computed.

required
*args Any

Sample arguments of f. Only structure and dtypes are used, values are ignored.

()
argnums int | Sequence[int]

Specifies which positional argument(s) to differentiate with respect to. Defaults to 0.

_DEFAULT_ARGNUMS
has_aux bool

Whether f returns (output, auxiliary_data). When True, the auxiliary output is ignored and only output is analyzed for sparsity.

_DEFAULT_HAS_AUX
mode HessianMode | None

AD composition strategy for Hessian-vector products. "fwd_over_rev" uses forward-over-reverse, "rev_over_fwd" uses reverse-over-forward, "rev_over_rev" uses reverse-over-reverse. Defaults to "fwd_over_rev".

_DEFAULT_MODE
symmetric bool

Whether to use symmetric coloring. Defaults to True.

_DEFAULT_SYMMETRIC_HESSIAN
postprocess bool

Only read when symmetric=True. Prune colors never used as hubs and compact the remaining ones (reduces the number of HVPs during decompression). Pruned vertices get the neutral color -1 (no HVP is computed for them). Defaults to False.

_DEFAULT_POSTPROCESS
**kwargs Any

Sample keyword arguments of f. Non-traceable values (bools, strings, ints) are bound statically.

{}

Returns:

Type Description
ColoredPattern

asdex.hessian_coloring_from_sparsity(sparsity, *, mode=_DEFAULT_MODE, symmetric=_DEFAULT_SYMMETRIC_HESSIAN, postprocess=_DEFAULT_POSTPROCESS)

Color a sparsity pattern for sparse Hessian computation.

Parameters:

Name Type Description Default
sparsity SparsityPattern | NDArray | BCOO

A SparsityPattern, NumPy array, or JAX BCOO matrix of shape (n, n).

required
mode HessianMode | None

AD composition strategy for Hessian-vector products. "fwd_over_rev" uses forward-over-reverse, "rev_over_fwd" uses reverse-over-forward, "rev_over_rev" uses reverse-over-reverse. Defaults to "fwd_over_rev".

_DEFAULT_MODE
symmetric bool

Whether to use symmetric coloring. Defaults to True.

_DEFAULT_SYMMETRIC_HESSIAN
postprocess bool

Only read when symmetric=True. Prune colors never used as hubs and compact the remaining ones (reduces the number of HVPs during decompression). Pruned vertices get the neutral color -1 (no HVP is computed for them). Defaults to False.

_DEFAULT_POSTPROCESS

Returns:

Type Description
ColoredPattern

asdex.color_rows(sparsity)

Greedy row-wise coloring for sparse Jacobian computation.

Assigns colors to rows such that no two rows sharing a non-zero column have the same color. This enables computing multiple Jacobian rows in a single VJP by using a combined seed vector.

Uses LargestFirst vertex ordering (rows with the most conflicts are colored first), which tends to reduce the number of colors.
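Greedy row coloring with LargestFirst ordering can be sketched in a few lines of NumPy. Build the row conflict graph, visit rows in decreasing degree, and give each row the smallest color not used by a neighbor. This is a minimal sketch assuming a dense boolean pattern. asdex's own implementation works on SparsityPattern and may break degree ties differently.

```python
import numpy as np


def greedy_color_rows(pattern):
    """Greedy coloring of the row conflict graph in LargestFirst order."""
    m = pattern.shape[0]
    p = pattern.astype(np.int64)
    conflict = (p @ p.T) > 0           # rows conflict if they share a nonzero column
    np.fill_diagonal(conflict, False)
    degree = conflict.sum(axis=1)
    order = np.argsort(-degree, kind="stable")  # highest degree first
    colors = np.full(m, -1, dtype=np.int32)
    for r in order:
        used = set(colors[conflict[r]].tolist())
        c = 0
        while c in used:
            c += 1                     # smallest color unused by any neighbor
        colors[r] = c
    return colors, int(colors.max()) + 1 if m else 0


# Tridiagonal 5 x 5 Jacobian pattern.
n = 5
pattern = np.abs(np.subtract.outer(np.arange(n), np.arange(n))) <= 1
colors, num_colors = greedy_color_rows(pattern)
```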

Parameters:

Name Type Description Default
sparsity SparsityPattern

Jacobian sparsity pattern of shape (m, n).

required

Returns:

Type Description
tuple[NDArray[int32], int]

Tuple of (colors, num_colors) where:

  • colors: Array of shape (m,) with color assignment for each row
  • num_colors: Total number of colors used

asdex.color_cols(sparsity)

Greedy column-wise coloring for sparse Jacobian computation.

Assigns colors to columns such that no two columns sharing a non-zero row have the same color. This enables computing multiple Jacobian columns in a single JVP by using a combined tangent vector.

Uses LargestFirst vertex ordering (columns with the most conflicts are colored first), which tends to reduce the number of colors.
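Column coloring is row coloring applied to the transposed pattern. Comparing the two color counts is what the documented default mode choice ("whichever of fwd/rev needs fewer colors") amounts to. The NumPy sketch below uses a pattern with one dense row, where forward mode needs 6 colors and reverse mode only 2. The helper greedy_color is hypothetical, and preferring "fwd" on ties is an assumption of the sketch.

```python
import numpy as np


def greedy_color(conflict):
    """Greedy coloring of a boolean conflict graph in LargestFirst order."""
    conflict = conflict.copy()
    np.fill_diagonal(conflict, False)
    order = np.argsort(-conflict.sum(axis=1), kind="stable")
    colors = np.full(conflict.shape[0], -1, dtype=np.int32)
    for v in order:
        used = set(colors[conflict[v]].tolist())
        c = 0
        while c in used:
            c += 1
        colors[v] = c
    return colors, int(colors.max()) + 1


# Diagonal pattern plus one dense row on top.
n = 6
pattern = np.eye(n, dtype=bool)
pattern[0, :] = True
p = pattern.astype(np.int64)

col_colors, n_fwd = greedy_color((p.T @ p) > 0)  # columns conflict via a shared row
row_colors, n_rev = greedy_color((p @ p.T) > 0)  # rows conflict via a shared column
mode = "fwd" if n_fwd <= n_rev else "rev"        # tie-break is an assumption
```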

Parameters:

Name Type Description Default
sparsity SparsityPattern

Jacobian sparsity pattern of shape (m, n).

required

Returns:

Type Description
tuple[NDArray[int32], int]

Tuple of (colors, num_colors) where:

  • colors: Array of shape (n,) with color assignment for each column
  • num_colors: Total number of colors used

asdex.color_symmetric(sparsity, *, postprocess=_DEFAULT_POSTPROCESS, forced_colors=None)

Greedy symmetric coloring for sparse Hessian computation.

Implements Algorithm 4.1 from Gebremedhin et al. (2007). A star coloring is a distance-1 coloring with the additional constraint that every path on 4 vertices uses at least 3 colors. Returns a StarSet alongside the colors so that Hessian decompression can use hub-based extraction.

Uses LargestFirst vertex ordering.
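The star-coloring definition can be checked by brute force on small graphs. In a valid distance-1 coloring, a path on 4 vertices that uses only 2 colors must alternate x, y, x, y, so it is enough to search for such alternating paths. The checker below is a hypothetical helper for illustration, not part of asdex, and it is far too slow for large patterns.

```python
import numpy as np


def is_star_coloring(adjacency, colors):
    """Brute-force check of the star-coloring conditions on a small graph."""
    n = len(colors)
    nbrs = [set(np.flatnonzero(adjacency[v]).tolist()) - {v} for v in range(n)]
    # Distance-1 condition: adjacent vertices get different colors.
    if any(colors[v] == colors[w] for v in range(n) for w in nbrs[v]):
        return False
    # Star condition: no path a-b-c-d is bicolored (colors alternate x, y, x, y).
    for a in range(n):
        for b in nbrs[a]:
            for c in nbrs[b] - {a}:
                for d in nbrs[c] - {a, b}:
                    if colors[a] == colors[c] and colors[b] == colors[d]:
                        return False
    return True


# Path graph 0-1-2-3.
path = np.zeros((4, 4), dtype=bool)
for k in range(3):
    path[k, k + 1] = path[k + 1, k] = True

# Arrowhead (star graph): vertex 0 connected to every other vertex.
arrow = np.eye(5, dtype=bool)
arrow[0, :] = arrow[:, 0] = True
```

A star graph contains no path on 4 vertices, so 2 colors suffice for the arrowhead. The path graph needs 3.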

Parameters:

Name Type Description Default
sparsity SparsityPattern

Hessian sparsity pattern of shape (n, n).

required
postprocess bool

If True, replace colors that are never used as a hub color (and not forced by a diagonal entry) with -1 (neutral), then compact remaining colors down. This reduces the number of HVPs needed during decompression. Defaults to False, matching SparseMatrixColorings.jl's postprocessing=false default.

_DEFAULT_POSTPROCESS
forced_colors NDArray[int32] | list[int] | None

Optional pre-computed color assignment of shape (n,). When provided, the algorithm verifies it satisfies the star-coloring constraints and raises InvalidColoringError otherwise.

None

Returns:

Type Description
tuple[NDArray[int32], int, StarSet]

Tuple (colors, num_colors, star_set) where:

  • colors: Array of shape (n,) with color assignment for each vertex. Values are in [0, num_colors - 1] for active vertices. After postprocessing, vertices whose color is pruned have value -1 (neutral — no HVP needed for them).
  • num_colors: Number of active colors (i.e. number of HVPs).
  • star_set: StarSet encoding the 2-colored star decomposition.

Raises:

Type Description
ValueError

If pattern is not square.

InvalidColoringError

If forced_colors violates a star-coloring constraint.

Visualization

asdex.spy(pattern, *, ax=None, compressed=False, cmap=None, **kwargs)

Plot a sparsity pattern or colored pattern using matplotlib.

For a SparsityPattern, plots nonzeros as filled cells on a grid. For a ColoredPattern, fills cells with their assigned color.

When compressed=True on a ColoredPattern, plots the compressed pattern after coloring instead of the original.

Parameters:

Name Type Description Default
pattern SparsityPattern | ColoredPattern

The sparsity or colored pattern to plot.

required
ax Any

Matplotlib axes to plot on. If None, creates a new figure.

None
compressed bool

If True and pattern is a ColoredPattern, plot the compressed pattern instead of the original.

False
cmap Any

Matplotlib colormap for colored patterns. If None, uses tab10.

None
**kwargs Any

Extra keyword arguments passed to ax.imshow.

{}

Returns:

Type Description
Any

The matplotlib axes with the plot.

Sparsity Detection

asdex.jacobian_sparsity(f, *args, argnums=_DEFAULT_ARGNUMS, has_aux=_DEFAULT_HAS_AUX, **kwargs)

Detect global Jacobian sparsity pattern for f.

Analyzes the computation graph structure directly, without evaluating any derivatives. The result is valid for all inputs.
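asdex works on the JAX computation graph. The reason a structural result holds for every input can be seen with a different, simpler technique: an operator-overloading tracer that propagates sets of input indices through arithmetic. In this toy sketch, the class Deps and the function jacobian_sparsity_toy are hypothetical. Only the operators shown are supported, and every output must depend on at least one input.

```python
class Deps:
    """Toy tracer: records which inputs a value structurally depends on."""

    def __init__(self, deps):
        self.deps = frozenset(deps)

    def _combine(self, other):
        other_deps = other.deps if isinstance(other, Deps) else frozenset()
        return Deps(self.deps | other_deps)

    # Any binary arithmetic merges the dependency sets of its operands.
    __add__ = __radd__ = __sub__ = __rsub__ = _combine
    __mul__ = __rmul__ = __truediv__ = __rtruediv__ = _combine


def jacobian_sparsity_toy(f, n):
    """Return the (row, col) nonzeros of the Jacobian of f: R^n -> R^m."""
    outputs = f([Deps({j}) for j in range(n)])
    return {(i, j) for i, out in enumerate(outputs) for j in out.deps}


def f(x):
    # Hypothetical function; the pattern does not depend on the values of x.
    return [x[0] * x[1], x[1] + 2.0, x[2] / x[0] - 1.0]


pattern = jacobian_sparsity_toy(f, 3)
```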

Parameters:

Name Type Description Default
f Callable

Function whose Jacobian sparsity pattern is to be detected.

required
*args Any

Sample arguments of f. Only structure and dtypes are used, values are ignored.

()
argnums int | Sequence[int]

Specifies which positional argument(s) to differentiate with respect to. Defaults to 0.

_DEFAULT_ARGNUMS
has_aux bool

Whether f returns (output, auxiliary_data). When True, the auxiliary output is ignored and only output is analyzed for sparsity.

_DEFAULT_HAS_AUX
**kwargs Any

Sample keyword arguments of f. Non-traceable values (bools, strings, ints) are bound statically.

{}

Returns:

Type Description
SparsityPattern

SparsityPattern of shape (m, n_selected) where m = prod(output_shape) and n_selected is the total flat size of the selected inputs.

asdex.hessian_sparsity(f, *args, argnums=_DEFAULT_ARGNUMS, has_aux=_DEFAULT_HAS_AUX, **kwargs)

Detect global Hessian sparsity pattern for a scalar-valued f.

Analyzes the Jacobian sparsity of the gradient function, without evaluating any derivatives. The result is valid for all inputs.

If f returns a squeezable shape like (1,) or (1, 1), it is automatically squeezed to scalar.

Parameters:

Name Type Description Default
f Callable

Scalar-valued function whose Hessian sparsity pattern is to be detected.

required
*args Any

Sample arguments of f. Only structure and dtypes are used, values are ignored.

()
argnums int | Sequence[int]

Specifies which positional argument(s) to differentiate with respect to. Defaults to 0.

_DEFAULT_ARGNUMS
has_aux bool

Whether f returns (output, auxiliary_data). When True, the auxiliary output is ignored and only output is analyzed for sparsity.

_DEFAULT_HAS_AUX
**kwargs Any

Sample keyword arguments of f. Non-traceable values (bools, strings, ints) are bound statically.

{}

Returns:

Type Description
SparsityPattern

Square SparsityPattern over the combined, selected input space.

Data Structures

asdex.SparsityPattern dataclass

Sparse matrix pattern storing only structural information (no values).

Stores row and column indices separately for efficient access by the coloring and decompression stages.

Attributes:

Name Type Description
rows NDArray[int32]

Row indices of non-zero entries, shape (nnz,).

cols NDArray[int32]

Column indices of non-zero entries, shape (nnz,).

shape tuple[int, int]

Matrix dimensions (m, n).

input_avals tuple[Any, ...]

One pytree of jax.ShapeDtypeStruct per positional argument of the traced function, in the same order jax.eval_shape(fun, *args) expects. Positions not in argnums are still stored (so the full input structure is preserved), but they do not contribute columns to the Jacobian / rows to the Hessian.

argnums int | tuple[int, ...]

Positions of input_avals that were differentiated, mirroring jax.grad / jax.jacfwd. An int stays int and a sequence becomes tuple[int, ...] — that distinction drives whether example_input is a single aval or a tuple of avals.

dyn_avals property

Sub-tuple of input_avals selected by argnums.

example_input property

The aval structure the returned Jacobian / Hessian mirrors.

When argnums is an int this is the single selected aval; when argnums is a tuple this is the tuple of selected avals. This matches the convention of jax.jacfwd and jax.jacrev (jax/_src/api.py:746 and line 840, respectively).
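The int-versus-tuple distinction is the same one plain JAX makes, which is easy to check with jax.jacfwd. An int argnums gives a single block, and a tuple gives a tuple of blocks, even when it has one element. The function f is a hypothetical example.

```python
import jax
import jax.numpy as jnp


def f(x, y):
    return x * y[0]


x = jnp.array([1.0, 2.0])
y = jnp.array([3.0, 4.0])

jac_int = jax.jacfwd(f, argnums=0)(x, y)         # a single array, shape (2, 2)
jac_tuple = jax.jacfwd(f, argnums=(0,))(x, y)    # a 1-tuple holding that array
jac_both = jax.jacfwd(f, argnums=(0, 1))(x, y)   # one block per selected input
```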

leaf_shapes property

Per-leaf shapes of the selected (differentiated) inputs.

leaf_sizes property

Per-leaf flat sizes (prod(shape)) of the selected inputs.

input_treedef property

Pytree structure of dyn_avals.

nnz property

Number of non-zero elements.

m property

Number of rows.

n property

Number of columns.

density property

Fraction of non-zero entries.

col_to_rows cached property

Mapping from column index to list of row indices with non-zeros.

Used by the coloring algorithm to build the row conflict graph.

row_to_cols cached property

Mapping from row index to list of column indices with non-zeros.

Used by the coloring algorithm to build the column conflict graph.
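The derived properties follow directly from the COO arrays. The NumPy sketch below computes them for a small hypothetical pattern. Representing the mappings as dicts of lists is an assumption; asdex's exact container types are not documented here. Two rows conflict exactly when they appear together in some col_to_rows list, and two columns conflict exactly when they appear together in some row_to_cols list.

```python
import numpy as np

# Hypothetical 3 x 4 pattern in COO form.
rows = np.array([0, 0, 1, 2, 2], dtype=np.int32)
cols = np.array([0, 3, 1, 0, 2], dtype=np.int32)
shape = (3, 4)

col_to_rows = {j: [] for j in range(shape[1])}
row_to_cols = {i: [] for i in range(shape[0])}
for r, c in zip(rows.tolist(), cols.tolist()):
    col_to_rows[c].append(r)
    row_to_cols[r].append(c)

nnz = len(rows)
density = nnz / (shape[0] * shape[1])
```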

__post_init__()

Validate inputs and fill in the default single-leaf aval.

from_coo(rows, cols, shape, *, input_avals=(), argnums=0) classmethod

Create pattern from row and column index arrays.

Parameters:

Name Type Description Default
rows NDArray[int32] | list[int]

Row indices of non-zero entries.

required
cols NDArray[int32] | list[int]

Column indices of non-zero entries.

required
shape tuple[int, int]

Matrix dimensions (m, n).

required
input_avals tuple[Any, ...]

One pytree of ShapeDtypeStruct per positional argument of the traced function. Defaults to a single 1-D aval of size n.

()
argnums int | tuple[int, ...]

Positions of input_avals that were differentiated, mirroring jax.grad / jax.jacfwd.

0

from_bcoo(bcoo) classmethod

Create pattern from JAX BCOO sparse matrix.

from_dense(dense) classmethod

Create pattern from dense boolean/numeric matrix.

Non-zero entries indicate pattern positions.

to_bcoo(data=None)

Convert to JAX BCOO sparse matrix.

Parameters:

Name Type Description Default
data ndarray | None

Optional data values. If None, uses all 1s.

None

todense()

Convert to dense numpy array with 1s at pattern positions.
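A NumPy sketch of the from_dense / todense round trip, assuming the semantics described above: non-zeros define the pattern, todense writes 1s, and to_bcoo uses 1s when data is None. Both helpers are hypothetical and operate only on plain index arrays.

```python
import numpy as np


def coo_from_dense(dense):
    """Hypothetical stand-in for from_dense: every non-zero marks a pattern position."""
    rows, cols = np.nonzero(dense)
    return rows.astype(np.int32), cols.astype(np.int32), dense.shape


def dense_from_coo(rows, cols, shape, data=None):
    """Hypothetical stand-in for todense / to_bcoo: fill with 1s unless data is given."""
    values = np.ones(len(rows)) if data is None else np.asarray(data)
    out = np.zeros(shape, dtype=values.dtype)
    out[rows, cols] = values
    return out


dense = np.array([[2.0, 0.0, 0.0], [0.0, 0.0, -1.0]])
rows, cols, shape = coo_from_dense(dense)  # rows [0, 1], cols [0, 2]
pattern = dense_from_coo(rows, cols, shape)  # 1s at the two pattern positions
restored = dense_from_coo(rows, cols, shape, data=dense[rows, cols])
```

Because np.nonzero treats True as non-zero, the same sketch accepts boolean input.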

save(path)

Save sparsity pattern to an .npz file.

Supports multi-input and PyTree-structured patterns.

Parameters:

Name Type Description Default
path str | PathLike[str]

Destination file path.

required

load(path) classmethod

Load sparsity pattern from an .npz file.

Parameters:

Name Type Description Default
path str | PathLike[str]

Source file path.

required

__str__()

Render sparsity pattern with header and dot/braille grid.

__repr__()

Return compact single-line representation.

asdex.ColoredPattern dataclass

Result of a graph coloring for sparse differentiation.

Attributes:

Name Type Description
sparsity SparsityPattern

The sparsity pattern that was colored.

colors NDArray[int32]

Color assignment array. Shape (m,) for "rev" mode, (n,) for all other modes. A value of -1 means "neutral": the vertex is not seeded (used after star-coloring postprocessing).

num_colors int

Total number of active colors (number of JVPs/VJPs/HVPs).

symmetric bool

Whether symmetric (star) coloring was used.

mode ColoringMode

The resolved AD mode, never "auto". "fwd" uses JVPs (forward-mode AD), "rev" uses VJPs (reverse-mode AD), "fwd_over_rev" uses forward-over-reverse HVPs, "rev_over_fwd" uses reverse-over-forward HVPs, "rev_over_rev" uses reverse-over-reverse HVPs.

star_set StarSet | None

Star-coloring structure (hub/spoke assignment per edge). Present only for symmetric colorings produced by color_symmetric; None otherwise.
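To make the colors and num_colors fields concrete, here is a simple greedy coloring for "fwd" mode. Columns that share a row get different colors, so each color can be seeded in a single JVP. Greedy distance-2 coloring is a standard technique used here only for illustration. It is not necessarily the algorithm asdex runs, and it never produces the -1 neutral entries that star-coloring postprocessing can introduce.

```python
import numpy as np


def greedy_column_coloring(rows, cols, n):
    """Greedy coloring: columns that share a row never get the same color."""
    row_to_cols = {}
    for r, c in zip(rows.tolist(), cols.tolist()):
        row_to_cols.setdefault(r, []).append(c)
    # Neighbors of column j: every other column sharing at least one row with j.
    neighbors = [set() for _ in range(n)]
    for cs in row_to_cols.values():
        for j in cs:
            neighbors[j].update(k for k in cs if k != j)
    colors = np.full(n, -1, dtype=np.int32)
    for j in range(n):
        used = {int(colors[k]) for k in neighbors[j] if colors[k] >= 0}
        color = 0
        while color in used:
            color += 1
        colors[j] = color
    num_colors = int(colors.max()) + 1 if n > 0 else 0
    return colors, num_colors


# Tridiagonal 5x5 pattern: entries with |i - j| <= 1.
pairs = [(i, j) for i in range(5) for j in range(5) if abs(i - j) <= 1]
rows = np.array([i for i, _ in pairs], dtype=np.int32)
cols = np.array([j for _, j in pairs], dtype=np.int32)
colors, num_colors = greedy_column_coloring(rows, cols, 5)  # [0, 1, 2, 0, 1]: 3 JVPs instead of 5
```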

save(path)

Save colored pattern to an .npz file.

Supports multi-input and PyTree-structured patterns.

Parameters:

Name Type Description Default
path str | PathLike[str]

Destination file path.

required

load(path) classmethod

Load colored pattern from an .npz file.

Parameters:

Name Type Description Default
path str | PathLike[str]

Source file path.

required

__repr__()

Return compact single-line representation.

__str__()

Render colored pattern with sparsity grid and color assignments.


asdex.JacobianMode = Literal['fwd', 'rev'] module-attribute

AD mode for Jacobian computation.

"fwd" uses JVPs (forward-mode AD), "rev" uses VJPs (reverse-mode AD).

asdex.HessianMode = Literal['fwd_over_rev', 'rev_over_fwd', 'rev_over_rev'] module-attribute

AD composition strategy for Hessian-vector products.

"fwd_over_rev" uses forward-over-reverse, "rev_over_fwd" uses reverse-over-forward, "rev_over_rev" uses reverse-over-reverse.

asdex.ColoringMode = JacobianMode | HessianMode module-attribute

AD mode that a coloring was computed for.
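The three HessianMode strategies compute the same Hessian-vector product H v using different AD compositions. The sketch below writes each one out with standard JAX transformations. The function names are hypothetical, and asdex's internal HVP code may differ.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Scalar test function whose Hessian is tridiagonal.
    return jnp.sum((x[1:] - x[:-1]) ** 2) + jnp.sum(jnp.sin(x))


def hvp_fwd_over_rev(f, x, v):
    # Forward mode (JVP) applied to the reverse-mode gradient.
    return jax.jvp(jax.grad(f), (x,), (v,))[1]


def hvp_rev_over_fwd(f, x, v):
    # Reverse mode (grad) applied to the forward-mode directional derivative f'(x) v.
    return jax.grad(lambda y: jax.jvp(f, (y,), (v,))[1])(x)


def hvp_rev_over_rev(f, x, v):
    # Reverse mode applied to <grad f(x), v>, which is itself computed in reverse mode.
    return jax.grad(lambda y: jnp.vdot(jax.grad(f)(y), v))(x)


x = jnp.arange(4.0)
v = jnp.array([1.0, 0.0, -1.0, 2.0])
reference = jax.hessian(f)(x) @ v
```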

asdex.VerificationError

Bases: AssertionError

Raised when asdex's sparse result does not match JAX's dense reference.

This indicates that the detected sparsity pattern is missing nonzeros, which is a bug: asdex's patterns should always be conservative (i.e., contain at least all true nonzeros). If you encounter this error, please help asdex's development by reporting it at https://github.com/adrhill/asdex/issues.
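The condition behind VerificationError can be stated directly: every non-zero of the dense reference must be a pattern position, while extra pattern entries are allowed. The helper below is a hypothetical NumPy sketch of that check, not the code asdex uses.

```python
import numpy as np


def missing_nonzeros(reference, rows, cols):
    """Hypothetical check: (i, j) where reference is non-zero but the pattern has no entry."""
    pattern = np.zeros(reference.shape, dtype=bool)
    pattern[rows, cols] = True
    i, j = np.nonzero((reference != 0) & ~pattern)
    return list(zip(i.tolist(), j.tolist()))


reference = np.array([[1.0, 0.0], [3.0, 4.0]])
# Conservative: the pattern may contain extra positions, like (0, 1) here.
ok = missing_nonzeros(reference, np.array([0, 0, 1, 1]), np.array([0, 1, 0, 1]))
# Not conservative: the true non-zero at (1, 0) is absent from the pattern.
bad = missing_nonzeros(reference, np.array([0, 1]), np.array([0, 1]))
```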

Compressed Differentiation

Advanced entry points that stop at the compressed matrix \(B\), leaving decompression to the caller. See Skipping Decompression.

asdex.compressed_jacobian(f, *sample_args, argnums=_DEFAULT_ARGNUMS, has_aux=_DEFAULT_HAS_AUX, holomorphic=_DEFAULT_HOLOMORPHIC, allow_int=_DEFAULT_ALLOW_INT, mode=_DEFAULT_MODE, symmetric=_DEFAULT_SYMMETRIC_JACOBIAN, chunk_size=_DEFAULT_CHUNK_SIZE, **sample_kwargs)

Detect sparsity, color, and return a function computing the compressed Jacobian.

Runs the same detect-and-color steps as jacobian, but stops at the dense compressed matrix B of shape (num_colors, dim), computed with one JVP or VJP per color, before decompression would scatter B into the pattern. Recover the sparse matrix with decompress or decompress_data, or work with B directly (custom solvers, cross-checks, debugging).

The returned B is a plain jax.Array, so the returned function is jit-able by the caller.

Unlike jacobian, it takes no output_format: formatting is the job of decompress.
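A NumPy sketch of the compressed matrix in "fwd" mode, assuming a plain (non-symmetric) column coloring. Each color c defines a seed vector, one JVP per color yields J @ S[:, c], and stacking the results gives B of shape (num_colors, m), where m is the flattened size of f's output. The explicit product J @ S stands in for the JVPs, which compute the same thing without ever forming J. The last line shows where each non-zero lives in B.

```python
import numpy as np

# Example 4x5 Jacobian with a banded structure.
J = np.array([
    [1.0, 2.0, 0.0, 0.0, 0.0],
    [0.0, 3.0, 4.0, 0.0, 0.0],
    [0.0, 0.0, 5.0, 6.0, 0.0],
    [0.0, 0.0, 0.0, 7.0, 8.0],
])
m, n = J.shape

# Column colors ("fwd" mode): columns sharing a row get different colors.
colors = np.array([0, 1, 0, 1, 0])
num_colors = int(colors.max()) + 1

# Seed matrix S (n, num_colors): S[j, c] = 1 when column j has color c.
S = np.zeros((n, num_colors))
S[np.arange(n), colors] = 1.0

# One JVP per color computes J @ S[:, c]; stacking them gives B (num_colors, m).
B = (J @ S).T

# Decompression: entry J[i, j] sits in B[colors[j], i].
rows, cols = np.nonzero(J)
data = B[colors[cols], rows]
```

Two JVPs recover all eight non-zeros because no row contains two columns of the same color.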

Parameters:

Name Type Description Default
f Callable[..., Any]

Function whose Jacobian is to be computed.

required
*sample_args Any

Sample arguments of f. Only structure and dtypes are used, values are ignored.

()
argnums int | Sequence[int]

Specifies which positional argument(s) to differentiate with respect to. Defaults to 0.

_DEFAULT_ARGNUMS
has_aux bool

Whether f returns (output, auxiliary_data). Defaults to False.

_DEFAULT_HAS_AUX
holomorphic bool

Whether f is promised to be holomorphic. Defaults to False.

_DEFAULT_HOLOMORPHIC
allow_int bool

Whether to allow differentiating with respect to integer inputs.

_DEFAULT_ALLOW_INT
mode JacobianMode | None

AD mode for Jacobian computation. "fwd" uses JVPs (forward-mode AD), "rev" uses VJPs (reverse-mode AD). Defaults to picking whichever of fwd/rev needs fewer colors (unless symmetric is True, in which case defaults to "fwd").

_DEFAULT_MODE
symmetric bool

Whether to use symmetric coloring. Requires a square Jacobian. Defaults to False.

_DEFAULT_SYMMETRIC_JACOBIAN
chunk_size int | None

Maximum number of colors to process in parallel. Defaults to None, processing all colors in a single vmapped batch. When specified, colors are processed in chunks of this size to reduce peak memory usage.

_DEFAULT_CHUNK_SIZE
**sample_kwargs Any

Sample keyword arguments of f. Merged with sample_args based on f's signature.

{}

Returns:

Type Description
Callable[..., Any]

A function that takes the same positional args as f and returns the compressed matrix B of shape (num_colors, dim), or (B, aux) when has_aux=True. dim is the flattened size of the differentiated inputs (the leaves selected by argnums) in "rev" mode, and the flattened size of f's output in "fwd" mode.
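In "rev" mode the roles of rows and columns swap: colors has one entry per row, each VJP seeds one row color, and B has shape (num_colors, n). The NumPy sketch below assumes a plain (non-symmetric) row coloring, with R.T @ J standing in for the VJPs.

```python
import numpy as np

J = np.array([
    [1.0, 2.0, 0.0, 0.0, 0.0],
    [0.0, 3.0, 4.0, 0.0, 0.0],
    [0.0, 0.0, 5.0, 6.0, 0.0],
    [0.0, 0.0, 0.0, 7.0, 8.0],
])
m, n = J.shape

# Row colors ("rev" mode): rows sharing a column get different colors.
colors = np.array([0, 1, 0, 1])  # shape (m,)
num_colors = int(colors.max()) + 1

# Seed matrix R (m, num_colors): R[i, c] = 1 when row i has color c.
R = np.zeros((m, num_colors))
R[np.arange(m), colors] = 1.0

# One VJP per color computes R[:, c] @ J; stacking them gives B (num_colors, n).
B = R.T @ J

# Decompression: entry J[i, j] sits in B[colors[i], j].
rows, cols = np.nonzero(J)
data = B[colors[rows], cols]
```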

asdex.compressed_jacobian_from_coloring(f, coloring, *, has_aux=_DEFAULT_HAS_AUX, holomorphic=_DEFAULT_HOLOMORPHIC, allow_int=_DEFAULT_ALLOW_INT, chunk_size=_DEFAULT_CHUNK_SIZE)

Build a compressed Jacobian function from a pre-computed coloring.

Like jacobian_from_coloring, but stops at the compressed matrix B of shape (num_colors, dim) instead of materializing the sparse matrix. See compressed_jacobian for B's layout.

Parameters:

Name Type Description Default
f Callable[..., Any]

Function whose Jacobian is to be computed.

required
coloring ColoredPattern

Pre-computed colored sparsity pattern of type ColoredPattern.

required
has_aux bool

Whether f returns (output, auxiliary_data). Defaults to False.

_DEFAULT_HAS_AUX
holomorphic bool

Whether f is promised to be holomorphic. Defaults to False.

_DEFAULT_HOLOMORPHIC
allow_int bool

Whether to allow differentiating with respect to integer inputs.

_DEFAULT_ALLOW_INT
chunk_size int | None

Maximum number of colors to process in parallel. Defaults to None, processing all colors in a single vmapped batch. When specified, colors are processed in chunks of this size to reduce peak memory usage.

_DEFAULT_CHUNK_SIZE

Returns:

Type Description
Callable[..., Any]

A function returning B of shape (num_colors, dim), or (B, aux) when has_aux=True.

asdex.value_and_compressed_jacobian(f, *sample_args, argnums=_DEFAULT_ARGNUMS, has_aux=_DEFAULT_HAS_AUX, holomorphic=_DEFAULT_HOLOMORPHIC, allow_int=_DEFAULT_ALLOW_INT, mode=_DEFAULT_MODE, symmetric=_DEFAULT_SYMMETRIC_JACOBIAN, chunk_size=_DEFAULT_CHUNK_SIZE, **sample_kwargs)

Like compressed_jacobian, also returning the value.

The primal value f(*args) rides the compression forward pass, so it is nearly free. See compressed_jacobian for B's layout.
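One way a primal value can come along with the compressed Jacobian is jax.linearize. It evaluates f once and returns a linear JVP function, which can then be vmapped over one seed per color. This is a sketch under that assumption, not asdex's implementation. The chunked loop at the end only illustrates the idea behind chunk_size, which bounds how many seeds are batched at once.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Products of neighbors: J[i, i] = x[i + 1] and J[i, i + 1] = x[i].
    return x[1:] * x[:-1]


x = jnp.arange(1.0, 6.0)  # n = 5 inputs, m = 4 outputs
colors = jnp.arange(5) % 2  # adjacent columns alternate colors, so no row sees a repeat
num_colors = 2
seeds = jax.nn.one_hot(colors, num_colors).T  # (num_colors, n): one seed per color

# linearize evaluates f once and returns its linear JVP at x.
value, f_jvp = jax.linearize(f, x)
B = jax.vmap(f_jvp)(seeds)  # (num_colors, m): all colors in a single batch

# Chunked alternative: batch at most chunk_size seeds at a time.
chunk_size = 1
B_chunked = jnp.concatenate(
    [jax.vmap(f_jvp)(seeds[i : i + chunk_size]) for i in range(0, num_colors, chunk_size)]
)
```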

Parameters:

Name Type Description Default
f Callable[..., Any]

Function whose Jacobian is to be computed.

required
*sample_args Any

Sample arguments of f. Only structure and dtypes are used, values are ignored.

()
argnums int | Sequence[int]

Specifies which positional argument(s) to differentiate with respect to. Defaults to 0.

_DEFAULT_ARGNUMS
has_aux bool

Whether f returns (output, auxiliary_data). Defaults to False.

_DEFAULT_HAS_AUX
holomorphic bool

Whether f is promised to be holomorphic. Defaults to False.

_DEFAULT_HOLOMORPHIC
allow_int bool

Whether to allow differentiating with respect to integer inputs.

_DEFAULT_ALLOW_INT
mode JacobianMode | None

AD mode for Jacobian computation. "fwd" uses JVPs (forward-mode AD), "rev" uses VJPs (reverse-mode AD). Defaults to picking whichever of fwd/rev needs fewer colors (unless symmetric is True, in which case defaults to "fwd").

_DEFAULT_MODE
symmetric bool

Whether to use symmetric coloring. Requires a square Jacobian. Defaults to False.

_DEFAULT_SYMMETRIC_JACOBIAN
chunk_size int | None

Maximum number of colors to process in parallel. Defaults to None, processing all colors in a single vmapped batch. When specified, colors are processed in chunks of this size to reduce peak memory usage.

_DEFAULT_CHUNK_SIZE
**sample_kwargs Any

Sample keyword arguments of f. Merged with sample_args based on f's signature.

{}

Returns:

Type Description
Callable[..., Any]

A function that takes the same positional args as f and returns (value, B), or ((value, aux), B) when has_aux=True, matching the jax.value_and_grad ordering.

asdex.value_and_compressed_jacobian_from_coloring(f, coloring, *, has_aux=_DEFAULT_HAS_AUX, holomorphic=_DEFAULT_HOLOMORPHIC, allow_int=_DEFAULT_ALLOW_INT, chunk_size=_DEFAULT_CHUNK_SIZE)

Value and compressed Jacobian from a pre-computed coloring.

Like value_and_jacobian_from_coloring, but stops at the compressed matrix B. See compressed_jacobian for B's layout.

Parameters:

Name Type Description Default
f Callable[..., Any]

Function whose Jacobian is to be computed.

required
coloring ColoredPattern

Pre-computed colored sparsity pattern of type ColoredPattern.

required
has_aux bool

Whether f returns (output, auxiliary_data). Defaults to False.

_DEFAULT_HAS_AUX
holomorphic bool

Whether f is promised to be holomorphic. Defaults to False.

_DEFAULT_HOLOMORPHIC
allow_int bool

Whether to allow differentiating with respect to integer inputs.

_DEFAULT_ALLOW_INT
chunk_size int | None

Maximum number of colors to process in parallel. Defaults to None, processing all colors in a single vmapped batch. When specified, colors are processed in chunks of this size to reduce peak memory usage.

_DEFAULT_CHUNK_SIZE

Returns:

Type Description
Callable[..., Any]

A function returning (value, B), or ((value, aux), B) when has_aux=True.


asdex.compressed_hessian(f, *sample_args, argnums=_DEFAULT_ARGNUMS, has_aux=_DEFAULT_HAS_AUX, holomorphic=_DEFAULT_HOLOMORPHIC, allow_int=_DEFAULT_ALLOW_INT, mode=_DEFAULT_MODE, symmetric=_DEFAULT_SYMMETRIC_HESSIAN, chunk_size=_DEFAULT_CHUNK_SIZE, **sample_kwargs)

Detect sparsity, color, and return a function computing the compressed Hessian.

Runs the same detect-and-color steps as hessian, but stops at the dense compressed matrix B of shape (num_colors, n), computed with one HVP per color, before decompression would scatter B into the pattern. n is the flattened size of the differentiated inputs (the leaves selected by argnums), so the Hessian has shape (n, n). Recover the sparse matrix with decompress or decompress_data, or work with B directly.

The returned B is a plain jax.Array, so the returned function is jit-able by the caller.

Unlike hessian, it takes no output_format: formatting is the job of decompress.
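As a NumPy sketch of the layout, assume symmetric=False with a plain column coloring of a tridiagonal Hessian. Each row of B is then one HVP H @ s_c, so B has shape (num_colors, n). With the default symmetric=True, star coloring and its decompression differ from this sketch.

```python
import numpy as np

n = 5
# Tridiagonal Hessian of a scalar function (symmetric by construction).
H = 2.0 * np.eye(n) - np.eye(n, k=1) - np.eye(n, k=-1)

# Column coloring: columns that share a row (|j - j'| <= 2 here) must differ, so j % 3 works.
colors = np.arange(n) % 3
num_colors = 3
S = np.zeros((n, num_colors))
S[np.arange(n), colors] = 1.0

# Each row of B is one HVP: B[c] = H @ S[:, c].
B = (H @ S).T  # (num_colors, n)

# Decompression: H[i, j] is stored in B[colors[j], i].
rows, cols = np.nonzero(H)
data = B[colors[cols], rows]
```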

Parameters:

Name Type Description Default
f Callable[..., Any]

Scalar-valued function whose Hessian is to be computed.

required
*sample_args Any

Sample arguments of f. Only structure and dtypes are used, values are ignored.

()
argnums int | Sequence[int]

Specifies which positional argument(s) to differentiate with respect to. Defaults to 0.

_DEFAULT_ARGNUMS
has_aux bool

Whether f returns (output, auxiliary_data). Defaults to False.

_DEFAULT_HAS_AUX
holomorphic bool

Whether f is promised to be holomorphic. Defaults to False.

_DEFAULT_HOLOMORPHIC
allow_int bool

Unsupported for Hessians; passing True raises TypeError (integer inputs cannot be differentiated twice, matching jax.hessian).

_DEFAULT_ALLOW_INT
mode HessianMode | None

AD composition strategy for Hessian-vector products. "fwd_over_rev" uses forward-over-reverse, "rev_over_fwd" uses reverse-over-forward, "rev_over_rev" uses reverse-over-reverse. Defaults to "fwd_over_rev".

_DEFAULT_MODE
symmetric bool

Whether to use symmetric coloring. Defaults to True.

_DEFAULT_SYMMETRIC_HESSIAN
chunk_size int | None

Maximum number of colors to process in parallel. Defaults to None, processing all colors in a single vmapped batch. When specified, colors are processed in chunks of this size to reduce peak memory usage.

_DEFAULT_CHUNK_SIZE
**sample_kwargs Any

Sample keyword arguments of f. Merged with sample_args based on f's signature.

{}

Returns:

Type Description
Callable[..., Any]

A function that takes the same positional args as f and returns the compressed matrix B of shape (num_colors, n), or (B, aux) when has_aux=True.

asdex.compressed_hessian_from_coloring(f, coloring, *, has_aux=_DEFAULT_HAS_AUX, holomorphic=_DEFAULT_HOLOMORPHIC, allow_int=_DEFAULT_ALLOW_INT, chunk_size=_DEFAULT_CHUNK_SIZE)

Build a compressed Hessian function from a pre-computed coloring.

Like hessian_from_coloring, but stops at the compressed matrix B of shape (num_colors, n) instead of materializing the sparse matrix. See compressed_hessian for B's layout.

Parameters:

Name Type Description Default
f Callable[..., Any]

Scalar-valued function whose Hessian is to be computed.

required
coloring ColoredPattern

Pre-computed colored sparsity pattern of type ColoredPattern.

required
has_aux bool

Whether f returns (output, auxiliary_data). Defaults to False.

_DEFAULT_HAS_AUX
holomorphic bool

Whether f is promised to be holomorphic. Defaults to False.

_DEFAULT_HOLOMORPHIC
allow_int bool

Unsupported for Hessians; passing True raises TypeError (integer inputs cannot be differentiated twice, matching jax.hessian).

_DEFAULT_ALLOW_INT
chunk_size int | None

Maximum number of colors to process in parallel. Defaults to None, processing all colors in a single vmapped batch. When specified, colors are processed in chunks of this size to reduce peak memory usage.

_DEFAULT_CHUNK_SIZE

Returns:

Type Description
Callable[..., Any]

A function returning B of shape (num_colors, n), or (B, aux) when has_aux=True.

asdex.value_and_compressed_hessian(f, *sample_args, argnums=_DEFAULT_ARGNUMS, has_aux=_DEFAULT_HAS_AUX, holomorphic=_DEFAULT_HOLOMORPHIC, allow_int=_DEFAULT_ALLOW_INT, mode=_DEFAULT_MODE, symmetric=_DEFAULT_SYMMETRIC_HESSIAN, chunk_size=_DEFAULT_CHUNK_SIZE, **sample_kwargs)

Like compressed_hessian, also returning the value.

The primal value f(*args) rides the HVP forward pass, so it is nearly free. See compressed_hessian for B's layout.
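A sketch of how the value can be obtained together with the HVPs in "fwd_over_rev" style, assuming jax.linearize is applied to jax.value_and_grad(f). f is evaluated once, and each vmapped tangent returns H @ s_c. This is illustrative only; asdex's actual HVP pipeline is not shown in this reference.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Scalar function with a constant tridiagonal Hessian.
    return jnp.sum((x[1:] - x[:-1]) ** 2)


x = jnp.array([0.0, 1.0, 3.0, 2.0, 5.0])
colors = jnp.arange(5) % 3  # columns sharing a row get distinct colors
seeds = jax.nn.one_hot(colors, 3).T  # (num_colors, n)

# Linearize the map x -> (f(x), grad f(x)): f is evaluated once at x.
(value, grad), lin = jax.linearize(jax.value_and_grad(f), x)
# Each tangent output is (d value, H @ seed); keep the HVP part as B.
_, B = jax.vmap(lin)(seeds)  # (num_colors, n)
```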

Parameters:

Name Type Description Default
f Callable[..., Any]

Scalar-valued function whose Hessian is to be computed.

required
*sample_args Any

Sample arguments of f. Only structure and dtypes are used, values are ignored.

()
argnums int | Sequence[int]

Specifies which positional argument(s) to differentiate with respect to. Defaults to 0.

_DEFAULT_ARGNUMS
has_aux bool

Whether f returns (output, auxiliary_data). Defaults to False.

_DEFAULT_HAS_AUX
holomorphic bool

Whether f is promised to be holomorphic. Defaults to False.

_DEFAULT_HOLOMORPHIC
allow_int bool

Unsupported for Hessians; passing True raises TypeError (integer inputs cannot be differentiated twice, matching jax.hessian).

_DEFAULT_ALLOW_INT
mode HessianMode | None

AD composition strategy for Hessian-vector products. "fwd_over_rev" uses forward-over-reverse, "rev_over_fwd" uses reverse-over-forward, "rev_over_rev" uses reverse-over-reverse. Defaults to "fwd_over_rev".

_DEFAULT_MODE
symmetric bool

Whether to use symmetric coloring. Defaults to True.

_DEFAULT_SYMMETRIC_HESSIAN
chunk_size int | None

Maximum number of colors to process in parallel. Defaults to None, processing all colors in a single vmapped batch. When specified, colors are processed in chunks of this size to reduce peak memory usage.

_DEFAULT_CHUNK_SIZE
**sample_kwargs Any

Sample keyword arguments of f. Merged with sample_args based on f's signature.

{}

Returns:

Type Description
Callable[..., Any]

A function that takes the same positional args as f and returns (value, B), or ((value, aux), B) when has_aux=True.

asdex.value_and_compressed_hessian_from_coloring(f, coloring, *, has_aux=_DEFAULT_HAS_AUX, holomorphic=_DEFAULT_HOLOMORPHIC, allow_int=_DEFAULT_ALLOW_INT, chunk_size=_DEFAULT_CHUNK_SIZE)

Value and compressed Hessian from a pre-computed coloring.

Like value_and_hessian_from_coloring, but stops at the compressed matrix B. See compressed_hessian for B's layout.

Parameters:

Name Type Description Default
f Callable[..., Any]

Scalar-valued function whose Hessian is to be computed.

required
coloring ColoredPattern

Pre-computed colored sparsity pattern of type ColoredPattern.

required
has_aux bool

Whether f returns (output, auxiliary_data). Defaults to False.

_DEFAULT_HAS_AUX
holomorphic bool

Whether f is promised to be holomorphic. Defaults to False.

_DEFAULT_HOLOMORPHIC
allow_int bool

Unsupported for Hessians; passing True raises TypeError (integer inputs cannot be differentiated twice, matching jax.hessian).

_DEFAULT_ALLOW_INT
chunk_size int | None

Maximum number of colors to process in parallel. Defaults to None, processing all colors in a single vmapped batch. When specified, colors are processed in chunks of this size to reduce peak memory usage.

_DEFAULT_CHUNK_SIZE

Returns:

Type Description
Callable[..., Any]

A function returning (value, B), or ((value, aux), B) when has_aux=True.


asdex.decompress(compressed, coloring, output_format=_DEFAULT_OUTPUT_FORMAT)

Decompress a compressed matrix B into a 2-D sparse matrix.

Composes decompress_data with format dispatch: it gathers B into the sparse values, then materializes the flat (m, n) matrix in the requested format.

Unlike the matrices returned by jacobian / hessian, the result is always the flat 2-D matrix, regardless of the input/output pytree structure, because B itself is a flat 2-D compressed matrix.

Parameters:

Name Type Description Default
compressed Array

The compressed matrix B of shape (num_colors, dim), as returned by compressed_jacobian or compressed_hessian.

required
coloring ColoredPattern

The ColoredPattern that produced compressed.

required
output_format OutputFormat

Type of the output matrix. "bcoo" returns jax.experimental.sparse.BCOO (default), "dense" returns jax.Array, "numpy_dense" returns numpy.ndarray, "scipy_coo" returns scipy.sparse.coo_array, "scipy_csr" returns scipy.sparse.csr_array, "scipy_csc" returns scipy.sparse.csc_array. SciPy formats require scipy.

_DEFAULT_OUTPUT_FORMAT

Returns:

Type Description
Any

The sparse matrix of shape (m, n) in the requested format.

Raises:

Type Description
ValueError

If compressed does not match coloring's expected shape, or output_format is unknown.

ImportError

If a scipy output_format is requested but scipy is not installed.

asdex.decompress_data(compressed, coloring)

Gather a compressed matrix B into sparse values in pattern order.

Returns a plain jax.Array of shape (coloring.sparsity.nnz,) holding the sparse values in coloring.sparsity order, so data[k] is the entry at (coloring.sparsity.rows[k], coloring.sparsity.cols[k]).

This is the jittable numeric core of decompression: it always returns a jax.Array, so it composes inside jax.jit and can feed a custom solver or sparse format, whereas decompress may return host (numpy/scipy) objects that cannot. Pair it with to_bcoo for a BCOO, or with coloring.sparsity.rows / coloring.sparsity.cols to assemble a custom format.
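Because decompress_data returns values aligned with coloring.sparsity.rows and coloring.sparsity.cols, assembling a custom format is ordinary array work. In the sketch below, plain NumPy arrays stand in for those attributes and for the returned jax.Array, and coo_to_csr is a hypothetical helper. When SciPy is available, scipy.sparse.csr_array((data, (rows, cols)), shape=(m, n)) builds the same matrix.

```python
import numpy as np


def coo_to_csr(data, rows, cols, m):
    """Hypothetical helper: CSR arrays (indptr, indices, values) from COO triplets."""
    order = np.lexsort((cols, rows))  # sort by row, then by column within a row
    indptr = np.zeros(m + 1, dtype=np.int64)
    indptr[1:] = np.cumsum(np.bincount(rows, minlength=m))
    return indptr, cols[order], data[order]


# Stand-ins for coloring.sparsity.rows / .cols and the output of decompress_data.
rows = np.array([1, 0, 1])
cols = np.array([2, 0, 0])
data = np.array([30.0, 10.0, 20.0])  # data[k] is the entry at (rows[k], cols[k])
m, n = 2, 3

indptr, indices, values = coo_to_csr(data, rows, cols, m)
dense = np.zeros((m, n))
dense[rows, cols] = data  # the same values scattered into a dense matrix
```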

Parameters:

Name Type Description Default
compressed Array

The compressed matrix B of shape (num_colors, dim), as returned by compressed_jacobian or compressed_hessian.

required
coloring ColoredPattern

The ColoredPattern that produced compressed.

required

Returns:

Type Description
Array

A jax.Array of shape (nnz,) with the sparse values in pattern order, matching compressed's dtype.

Raises:

Type Description
ValueError

If compressed does not have shape (num_colors, dim) for coloring (see compressed_jacobian for the per-mode dim).