
Jacobian

Jacobian Computation

asdex.jacobian(f, *sample_args, argnums=_DEFAULT_ARGNUMS, has_aux=_DEFAULT_HAS_AUX, holomorphic=_DEFAULT_HOLOMORPHIC, allow_int=_DEFAULT_ALLOW_INT, mode=_DEFAULT_MODE, symmetric=_DEFAULT_SYMMETRIC_JACOBIAN, output_format=_DEFAULT_OUTPUT_FORMAT, chunk_size=_DEFAULT_CHUNK_SIZE, **sample_kwargs)

Detect sparsity, color, and return a function computing sparse Jacobians.

Combines jacobian_coloring and jacobian_from_coloring in one call.

For repeated evaluation, wrap the returned function in jax.jit: each unjitted call re-traces f, which can cost far more than the differentiation itself. The "numpy_dense" and scipy output formats cannot be jitted since they produce non-JAX arrays.
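The sketch below walks through the four steps that jacobian chains together (detect, color, compress, decompress) in forward mode, using NumPy. It is not asdex's implementation. The sparsity pattern and the coloring are written by hand. Because the hypothetical f is linear, calling f(v) stands in for a JVP along v.

```python
import numpy as np

def f(x):
    # Hypothetical example: y_i = x_{i+1} - x_i, so the Jacobian is (n-1, n) and bidiagonal.
    return x[1:] - x[:-1]

n = 6
# 1. Sparsity: row i is nonzero only in columns i and i+1 (asdex would detect this).
rows = np.repeat(np.arange(n - 1), 2)
cols = rows + np.tile([0, 1], n - 1)

# 2. Coloring: even and odd columns never share a row, so two colors suffice.
colors = np.arange(n) % 2

# 3. Compression: one JVP per color, seeded with the indicator of that color's columns.
#    f is linear, so its JVP along v is simply f(v). B has shape (num_colors, m).
B = np.stack([f((colors == c).astype(float)) for c in range(2)])

# 4. Decompression: entry (i, j) is stored in row colors[j] of B, at position i.
J = np.zeros((n - 1, n))
J[rows, cols] = B[colors[cols], rows]
print(J[0, :3], J[1, :3])  # [-1.  1.  0.] [ 0. -1.  1.]
```

With asdex, a JAX-traceable f gets the same result from one call, such as asdex.jacobian(f, x_sample). As the note above says, wrap the returned function in jax.jit if you evaluate it repeatedly.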

Parameters:

Name Type Description Default
f Callable[..., Any]

Function whose Jacobian is to be computed.

required
*sample_args Any

Sample arguments of f. Only structure and dtypes are used, values are ignored.

()
argnums int | Sequence[int]

Specifies which positional argument(s) to differentiate with respect to. Defaults to 0.

_DEFAULT_ARGNUMS
has_aux bool

Whether f returns (output, auxiliary_data). Defaults to False.

_DEFAULT_HAS_AUX
holomorphic bool

Whether f is promised to be holomorphic. Defaults to False.

_DEFAULT_HOLOMORPHIC
allow_int bool

Whether to allow differentiating with respect to integer inputs.

_DEFAULT_ALLOW_INT
mode JacobianMode | None

AD mode for Jacobian computation. "fwd" uses JVPs (forward-mode AD), "rev" uses VJPs (reverse-mode AD). Defaults to picking whichever of fwd/rev needs fewer colors (unless symmetric is True, in which case defaults to "fwd").

_DEFAULT_MODE
symmetric bool

Whether to use symmetric coloring. Requires a square Jacobian. Defaults to False.

_DEFAULT_SYMMETRIC_JACOBIAN
output_format OutputFormat

Type of the output matrix. "bcoo" returns jax.experimental.sparse.BCOO (default), "dense" returns jax.Array, "numpy_dense" returns numpy.ndarray, "scipy_coo" returns scipy.sparse.coo_array, "scipy_csr" returns scipy.sparse.csr_array, "scipy_csc" returns scipy.sparse.csc_array. SciPy formats require scipy and only support 2D Jacobians: the input and output must each be a single flat (1D) array (scalar outputs are not supported).

_DEFAULT_OUTPUT_FORMAT
chunk_size int | None

Maximum number of colors to process in parallel. Defaults to None, processing all colors in a single vmapped batch. When specified, colors are processed in chunks of this size to reduce peak memory usage.

_DEFAULT_CHUNK_SIZE
**sample_kwargs Any

Sample keyword arguments of f. Merged with sample_args based on f's signature.

{}

Returns:

Type Description
Callable[..., Any]

A function that takes the same positional args as f and returns a pytree of Jacobian blocks matching argnums, with each leaf shaped (*out_shape, *in_leaf_shape). The block type depends on output_format (jax.experimental.sparse.BCOO by default, or jax.Array when "dense").
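This NumPy sketch shows how a flat (m, n) Jacobian maps onto per-leaf blocks of shape (*out_shape, *in_leaf_shape). The hypothetical helper splits the columns by input leaf and then reshapes each block. It assumes row-major (C-order) flattening of outputs and input leaves, which may not match asdex's internal ordering.

```python
import numpy as np

def split_jacobian(J_flat, out_shape, in_shapes):
    # Hypothetical helper: cut the flat (m, n) Jacobian into one column block per
    # input leaf, then reshape each block to (*out_shape, *in_leaf_shape).
    # Assumes row-major (C-order) flattening of outputs and input leaves.
    sizes = [int(np.prod(s)) for s in in_shapes]
    offsets = np.cumsum([0] + sizes)
    return tuple(
        J_flat[:, offsets[k]:offsets[k + 1]].reshape(tuple(out_shape) + tuple(shape))
        for k, shape in enumerate(in_shapes)
    )

# Output of shape (2, 3) and two differentiated inputs of shapes (4,) and (2, 2):
# m = 6 and n = 4 + 4 = 8.
blocks = split_jacobian(np.arange(48.0).reshape(6, 8), (2, 3), [(4,), (2, 2)])
print([b.shape for b in blocks])  # [(2, 3, 4), (2, 3, 2, 2)]
```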

asdex.jacobian_from_coloring(f, coloring, output_format=_DEFAULT_OUTPUT_FORMAT, *, has_aux=_DEFAULT_HAS_AUX, holomorphic=_DEFAULT_HOLOMORPHIC, allow_int=_DEFAULT_ALLOW_INT, chunk_size=_DEFAULT_CHUNK_SIZE)

Build a sparse Jacobian function from a pre-computed coloring.

Uses row coloring + VJPs or column coloring + JVPs, depending on whether coloring colors rows or columns. By default, jacobian_coloring chooses whichever needs fewer colors.

The returned callable accepts *args, **kwargs; kwargs are forwarded to f at call time (matching jax.jacfwd / jax.jacrev).

For repeated evaluation, wrap the returned function in jax.jit: each unjitted call re-traces f, which can cost far more than the differentiation itself. The "numpy_dense" and scipy output formats cannot be jitted since they produce non-JAX arrays.
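This NumPy sketch shows the reverse-mode path (row coloring + VJPs). The hypothetical Jacobian has one dense row, so column coloring would need n colors while row coloring needs only two. The product u @ J_true stands in for a VJP; asdex would obtain that product from f without ever forming J.

```python
import numpy as np

n = 5
# Hypothetical Jacobian: a dense first row (y_0 = sum(x)) over the identity (y_i = x_i).
J_true = np.vstack([np.ones(n), np.eye(n)[1:]])
rows, cols = np.nonzero(J_true)

def vjp(u):
    # Stand-in for a reverse-mode VJP: returns the row vector u^T J.
    return u @ J_true

# Row coloring: row 0 shares a column with every other row, but rows 1..n-1 are
# pairwise disjoint, so two colors suffice.
row_colors = np.array([0] + [1] * (n - 1))

# One VJP per color, seeded with that color's rows: B has shape (num_colors, n).
B = np.stack([vjp((row_colors == c).astype(float)) for c in range(2)])

# Decompression: entry (i, j) is stored in row row_colors[i] of B, at position j.
J = np.zeros_like(J_true)
J[rows, cols] = B[row_colors[rows], cols]
```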

Parameters:

Name Type Description Default
f Callable[..., Any]

Function whose Jacobian is to be computed.

required
coloring ColoredPattern

Pre-computed colored sparsity pattern of type ColoredPattern.

required
output_format OutputFormat

Type of the output matrix. "bcoo" returns jax.experimental.sparse.BCOO (default), "dense" returns jax.Array, "numpy_dense" returns numpy.ndarray, "scipy_coo" returns scipy.sparse.coo_array, "scipy_csr" returns scipy.sparse.csr_array, "scipy_csc" returns scipy.sparse.csc_array. SciPy formats require scipy and only support 2D Jacobians: the input and output must each be a single flat (1D) array (scalar outputs are not supported).

_DEFAULT_OUTPUT_FORMAT
has_aux bool

Whether f returns (output, auxiliary_data). Defaults to False.

_DEFAULT_HAS_AUX
holomorphic bool

Whether f is promised to be holomorphic. Defaults to False.

_DEFAULT_HOLOMORPHIC
allow_int bool

Whether to allow differentiating with respect to integer inputs.

_DEFAULT_ALLOW_INT
chunk_size int | None

Maximum number of colors to process in parallel. Defaults to None, processing all colors in a single vmapped batch. When specified, colors are processed in chunks of this size to reduce peak memory usage.

_DEFAULT_CHUNK_SIZE
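This NumPy sketch shows what chunk_size changes. It computes the same compressed matrix twice: once in a single batch, and once in chunks of two colors. The matrix product stands in for a vmapped JVP. The coloring is arbitrary because only the compression step is shown. Both runs give the same result up to floating-point rounding; only the batch size differs.

```python
import numpy as np

rng = np.random.default_rng(0)
J = rng.standard_normal((4, 7))   # hypothetical Jacobian, shape (m, n)
colors = np.arange(7) % 3         # hypothetical column colors (3 colors)
seeds = np.stack([(colors == c).astype(float) for c in range(3)])  # (num_colors, n)

def batched_jvp(seed_batch):
    # Stand-in for a vmapped JVP: each row of seed_batch is one tangent direction.
    return seed_batch @ J.T

def compress(seeds, chunk_size=None):
    # chunk_size=None: all colors in one batch. Otherwise each batch works on at
    # most chunk_size seed vectors, and the partial results are concatenated.
    if chunk_size is None:
        chunk_size = len(seeds)
    parts = [batched_jvp(seeds[s:s + chunk_size]) for s in range(0, len(seeds), chunk_size)]
    return np.concatenate(parts, axis=0)

B_one_batch = compress(seeds)
B_chunked = compress(seeds, chunk_size=2)
```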

asdex.value_and_jacobian(f, *sample_args, argnums=_DEFAULT_ARGNUMS, has_aux=_DEFAULT_HAS_AUX, holomorphic=_DEFAULT_HOLOMORPHIC, allow_int=_DEFAULT_ALLOW_INT, mode=_DEFAULT_MODE, symmetric=_DEFAULT_SYMMETRIC_JACOBIAN, output_format=_DEFAULT_OUTPUT_FORMAT, chunk_size=_DEFAULT_CHUNK_SIZE, **sample_kwargs)

Detect sparsity, color, and return a function computing value and sparse Jacobian.

Like jacobian, but also returns the primal value f(*args) without an extra forward pass.

For repeated evaluation, wrap the returned function in jax.jit: each unjitted call re-traces f, which can cost far more than the differentiation itself. The "numpy_dense" and scipy output formats cannot be jitted since they produce non-JAX arrays.

Parameters:

Name Type Description Default
f Callable[..., Any]

Function whose Jacobian is to be computed.

required
*sample_args Any

Sample arguments of f. Only structure and dtypes are used, values are ignored.

()
argnums int | Sequence[int]

Specifies which positional argument(s) to differentiate with respect to. Defaults to 0.

_DEFAULT_ARGNUMS
has_aux bool

Whether f returns (output, auxiliary_data). Defaults to False.

_DEFAULT_HAS_AUX
holomorphic bool

Whether f is promised to be holomorphic. Defaults to False.

_DEFAULT_HOLOMORPHIC
allow_int bool

Whether to allow differentiating with respect to integer inputs.

_DEFAULT_ALLOW_INT
mode JacobianMode | None

AD mode for Jacobian computation. "fwd" uses JVPs (forward-mode AD), "rev" uses VJPs (reverse-mode AD). Defaults to picking whichever of fwd/rev needs fewer colors (unless symmetric is True, in which case defaults to "fwd").

_DEFAULT_MODE
symmetric bool

Whether to use symmetric coloring. Requires a square Jacobian. Defaults to False.

_DEFAULT_SYMMETRIC_JACOBIAN
output_format OutputFormat

Type of the output matrix. "bcoo" returns jax.experimental.sparse.BCOO (default), "dense" returns jax.Array, "numpy_dense" returns numpy.ndarray, "scipy_coo" returns scipy.sparse.coo_array, "scipy_csr" returns scipy.sparse.csr_array, "scipy_csc" returns scipy.sparse.csc_array. SciPy formats require scipy and only support 2D Jacobians: the input and output must each be a single flat (1D) array (scalar outputs are not supported).

_DEFAULT_OUTPUT_FORMAT
chunk_size int | None

Maximum number of colors to process in parallel. Defaults to None, processing all colors in a single vmapped batch. When specified, colors are processed in chunks of this size to reduce peak memory usage.

_DEFAULT_CHUNK_SIZE
**sample_kwargs Any

Sample keyword arguments of f. Merged with sample_args based on f's signature.

{}

Returns:

Type Description
Callable[..., Any]

A function that takes the same positional args as f and returns (value, jac), or ((value, aux), jac) when has_aux=True, matching jax.value_and_grad ordering.
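The value costs no extra pass because a forward-mode pass computes the primal alongside the tangent; jax.jvp returns both. This minimal pure-Python dual-number sketch illustrates the idea and is not how JAX implements jvp. With a hypothetical f, one JVP along the color seed [1, 0, 1] returns both f(x) and one row of the compressed Jacobian.

```python
from dataclasses import dataclass

@dataclass
class Dual:
    # A forward-mode AD number: primal value plus tangent (directional derivative).
    val: float
    tan: float

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other, 0.0)
        return Dual(self.val + other.val, self.tan + other.tan)

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other, 0.0)
        # Product rule: (a + a'e)(b + b'e) = ab + (ab' + a'b)e
        return Dual(self.val * other.val, self.val * other.tan + self.tan * other.val)

    __rmul__ = __mul__

def f(x):
    # Hypothetical f with a bidiagonal Jacobian: y_i = x_i * x_{i+1}.
    return [x[i] * x[i + 1] for i in range(len(x) - 1)]

def value_and_jvp(f, x, v):
    # One pass over f yields both f(x) and J(x) @ v.
    out = f([Dual(xi, vi) for xi, vi in zip(x, v)])
    return [o.val for o in out], [o.tan for o in out]

# Columns 0 and 2 never share a row, so the seed [1, 0, 1] is one valid color.
value, tangent = value_and_jvp(f, [1.0, 2.0, 3.0], [1.0, 0.0, 1.0])
print(value, tangent)  # [2.0, 6.0] [2.0, 2.0]
```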

asdex.value_and_jacobian_from_coloring(f, coloring, output_format=_DEFAULT_OUTPUT_FORMAT, *, has_aux=_DEFAULT_HAS_AUX, holomorphic=_DEFAULT_HOLOMORPHIC, allow_int=_DEFAULT_ALLOW_INT, chunk_size=_DEFAULT_CHUNK_SIZE)

Build a function computing value and sparse Jacobian from a pre-computed coloring.

For repeated evaluation, wrap the returned function in jax.jit: each unjitted call re-traces f, which can cost far more than the differentiation itself. The "numpy_dense" and scipy output formats cannot be jitted since they produce non-JAX arrays.

Parameters:

Name Type Description Default
f Callable[..., Any]

Function whose Jacobian is to be computed.

required
coloring ColoredPattern

Pre-computed colored sparsity pattern of type ColoredPattern.

required
output_format OutputFormat

Type of the output matrix. "bcoo" returns jax.experimental.sparse.BCOO (default), "dense" returns jax.Array, "numpy_dense" returns numpy.ndarray, "scipy_coo" returns scipy.sparse.coo_array, "scipy_csr" returns scipy.sparse.csr_array, "scipy_csc" returns scipy.sparse.csc_array. SciPy formats require scipy and only support 2D Jacobians: the input and output must each be a single flat (1D) array (scalar outputs are not supported).

_DEFAULT_OUTPUT_FORMAT
has_aux bool

Whether f returns (output, auxiliary_data). Defaults to False.

_DEFAULT_HAS_AUX
holomorphic bool

Whether f is promised to be holomorphic. Defaults to False.

_DEFAULT_HOLOMORPHIC
allow_int bool

Whether to allow differentiating with respect to integer inputs.

_DEFAULT_ALLOW_INT
chunk_size int | None

Maximum number of colors to process in parallel. Defaults to None, processing all colors in a single vmapped batch. When specified, colors are processed in chunks of this size to reduce peak memory usage.

_DEFAULT_CHUNK_SIZE

Compressed Jacobian

asdex.compressed_jacobian(f, *sample_args, argnums=_DEFAULT_ARGNUMS, has_aux=_DEFAULT_HAS_AUX, holomorphic=_DEFAULT_HOLOMORPHIC, allow_int=_DEFAULT_ALLOW_INT, mode=_DEFAULT_MODE, symmetric=_DEFAULT_SYMMETRIC_JACOBIAN, chunk_size=_DEFAULT_CHUNK_SIZE, **sample_kwargs)

Detect sparsity, color, and return a function computing the compressed Jacobian.

Runs the same detect-and-color steps as jacobian, but stops at the dense compressed matrix B of shape (num_colors, dim). B holds one VJP/JVP result per color, before decompression scatters B into the sparsity pattern. Recover the sparse matrix with decompress or decompress_data, or work with B directly (custom solvers, cross-checks, debugging).

The returned B is a plain jax.Array, so the caller can wrap the returned function in jax.jit.

Unlike jacobian, it takes no output_format: formatting is the job of decompress.

Parameters:

Name Type Description Default
f Callable[..., Any]

Function whose Jacobian is to be computed.

required
*sample_args Any

Sample arguments of f. Only structure and dtypes are used, values are ignored.

()
argnums int | Sequence[int]

Specifies which positional argument(s) to differentiate with respect to. Defaults to 0.

_DEFAULT_ARGNUMS
has_aux bool

Whether f returns (output, auxiliary_data). Defaults to False.

_DEFAULT_HAS_AUX
holomorphic bool

Whether f is promised to be holomorphic. Defaults to False.

_DEFAULT_HOLOMORPHIC
allow_int bool

Whether to allow differentiating with respect to integer inputs.

_DEFAULT_ALLOW_INT
mode JacobianMode | None

AD mode for Jacobian computation. "fwd" uses JVPs (forward-mode AD), "rev" uses VJPs (reverse-mode AD). Defaults to picking whichever of fwd/rev needs fewer colors (unless symmetric is True, in which case defaults to "fwd").

_DEFAULT_MODE
symmetric bool

Whether to use symmetric coloring. Requires a square Jacobian. Defaults to False.

_DEFAULT_SYMMETRIC_JACOBIAN
chunk_size int | None

Maximum number of colors to process in parallel. Defaults to None, processing all colors in a single vmapped batch. When specified, colors are processed in chunks of this size to reduce peak memory usage.

_DEFAULT_CHUNK_SIZE
**sample_kwargs Any

Sample keyword arguments of f. Merged with sample_args based on f's signature.

{}

Returns:

Type Description
Callable[..., Any]

A function that takes the same positional args as f and returns the compressed matrix B of shape (num_colors, dim), or (B, aux) when has_aux=True. dim is the flattened size of the differentiated inputs (the leaves selected by argnums) in "rev" mode, and the flattened size of f's output in "fwd" mode.
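B's layout and the decompression rule can be written in matrix form. The notation is introduced here: J is m × n, p = num_colors, c(j) is the color of column j in "fwd" mode, r(i) is the color of row i in "rev" mode, and [·] is 1 when its condition holds and 0 otherwise.

```latex
\begin{aligned}
&\text{fwd: } S_{jk} = [c(j) = k],\quad B = (J S)^{\top} \in \mathbb{R}^{p \times m},\quad
  B_{ki} = \textstyle\sum_{j:\,c(j)=k} J_{ij},\quad J_{ij} = B_{c(j),\,i} \\
&\text{rev: } R_{ik} = [r(i) = k],\quad B = R^{\top} J \in \mathbb{R}^{p \times n},\quad
  B_{kj} = \textstyle\sum_{i:\,r(i)=k} J_{ij},\quad J_{ij} = B_{r(i),\,j}
\end{aligned}
```

For a non-symmetric coloring, each sum contains at most one structural nonzero. This is because same-colored columns share no nonzero row (or same-colored rows share no nonzero column). As a result, the right-hand identities hold for every entry in the pattern.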

asdex.compressed_jacobian_from_coloring(f, coloring, *, has_aux=_DEFAULT_HAS_AUX, holomorphic=_DEFAULT_HOLOMORPHIC, allow_int=_DEFAULT_ALLOW_INT, chunk_size=_DEFAULT_CHUNK_SIZE)

Build a compressed Jacobian function from a pre-computed coloring.

Like jacobian_from_coloring, but stops at the compressed matrix B of shape (num_colors, dim) instead of materializing the sparse matrix. See compressed_jacobian for B's layout.

Parameters:

Name Type Description Default
f Callable[..., Any]

Function whose Jacobian is to be computed.

required
coloring ColoredPattern

Pre-computed colored sparsity pattern of type ColoredPattern.

required
has_aux bool

Whether f returns (output, auxiliary_data). Defaults to False.

_DEFAULT_HAS_AUX
holomorphic bool

Whether f is promised to be holomorphic. Defaults to False.

_DEFAULT_HOLOMORPHIC
allow_int bool

Whether to allow differentiating with respect to integer inputs.

_DEFAULT_ALLOW_INT
chunk_size int | None

Maximum number of colors to process in parallel. Defaults to None, processing all colors in a single vmapped batch. When specified, colors are processed in chunks of this size to reduce peak memory usage.

_DEFAULT_CHUNK_SIZE

Returns:

Type Description
Callable[..., Any]

A function returning B of shape (num_colors, dim), or (B, aux) when has_aux=True.

asdex.value_and_compressed_jacobian(f, *sample_args, argnums=_DEFAULT_ARGNUMS, has_aux=_DEFAULT_HAS_AUX, holomorphic=_DEFAULT_HOLOMORPHIC, allow_int=_DEFAULT_ALLOW_INT, mode=_DEFAULT_MODE, symmetric=_DEFAULT_SYMMETRIC_JACOBIAN, chunk_size=_DEFAULT_CHUNK_SIZE, **sample_kwargs)

Like compressed_jacobian, but also returns the primal value.

The primal value f(*args) comes out of the same forward pass that computes B, so no separate evaluation of f is needed. See compressed_jacobian for B's layout.

Parameters:

Name Type Description Default
f Callable[..., Any]

Function whose Jacobian is to be computed.

required
*sample_args Any

Sample arguments of f. Only structure and dtypes are used, values are ignored.

()
argnums int | Sequence[int]

Specifies which positional argument(s) to differentiate with respect to. Defaults to 0.

_DEFAULT_ARGNUMS
has_aux bool

Whether f returns (output, auxiliary_data). Defaults to False.

_DEFAULT_HAS_AUX
holomorphic bool

Whether f is promised to be holomorphic. Defaults to False.

_DEFAULT_HOLOMORPHIC
allow_int bool

Whether to allow differentiating with respect to integer inputs.

_DEFAULT_ALLOW_INT
mode JacobianMode | None

AD mode for Jacobian computation. "fwd" uses JVPs (forward-mode AD), "rev" uses VJPs (reverse-mode AD). Defaults to picking whichever of fwd/rev needs fewer colors (unless symmetric is True, in which case defaults to "fwd").

_DEFAULT_MODE
symmetric bool

Whether to use symmetric coloring. Requires a square Jacobian. Defaults to False.

_DEFAULT_SYMMETRIC_JACOBIAN
chunk_size int | None

Maximum number of colors to process in parallel. Defaults to None, processing all colors in a single vmapped batch. When specified, colors are processed in chunks of this size to reduce peak memory usage.

_DEFAULT_CHUNK_SIZE
**sample_kwargs Any

Sample keyword arguments of f. Merged with sample_args based on f's signature.

{}

Returns:

Type Description
Callable[..., Any]

A function that takes the same positional args as f and returns (value, B), or ((value, aux), B) when has_aux=True, matching jax.value_and_grad ordering.

asdex.value_and_compressed_jacobian_from_coloring(f, coloring, *, has_aux=_DEFAULT_HAS_AUX, holomorphic=_DEFAULT_HOLOMORPHIC, allow_int=_DEFAULT_ALLOW_INT, chunk_size=_DEFAULT_CHUNK_SIZE)

Build a function computing value and compressed Jacobian from a pre-computed coloring.

Like value_and_jacobian_from_coloring, but stops at the compressed matrix B. See compressed_jacobian for B's layout.

Parameters:

Name Type Description Default
f Callable[..., Any]

Function whose Jacobian is to be computed.

required
coloring ColoredPattern

Pre-computed colored sparsity pattern of type ColoredPattern.

required
has_aux bool

Whether f returns (output, auxiliary_data). Defaults to False.

_DEFAULT_HAS_AUX
holomorphic bool

Whether f is promised to be holomorphic. Defaults to False.

_DEFAULT_HOLOMORPHIC
allow_int bool

Whether to allow differentiating with respect to integer inputs.

_DEFAULT_ALLOW_INT
chunk_size int | None

Maximum number of colors to process in parallel. Defaults to None, processing all colors in a single vmapped batch. When specified, colors are processed in chunks of this size to reduce peak memory usage.

_DEFAULT_CHUNK_SIZE

Returns:

Type Description
Callable[..., Any]

A function returning (value, B), or ((value, aux), B) when has_aux=True.

Decompression

asdex.decompress(compressed, coloring, output_format=_DEFAULT_OUTPUT_FORMAT)

Decompress a compressed matrix B into a 2-D sparse matrix.

Composes decompress_data with format dispatch: it gathers B into the sparse values, then materializes the flat (m, n) matrix in the requested format.

Unlike the pytrees of blocks returned by the jacobian / hessian functions, the result is always the flat 2-D matrix, regardless of the input/output pytree structure. This is because B itself is a flat 2-D compressed matrix.
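This hypothetical NumPy sketch covers the format-dispatch half of decompress. Given values in pattern order plus their row and column indices, it builds a dense (m, n) array or returns a COO triple. It rejects unknown formats with ValueError, as documented. The "coo" branch returns the (data, (row, col)) tuple that scipy.sparse.coo_array accepts. The function name and format strings are illustrative only.

```python
import numpy as np

def format_sparse(data, rows, cols, shape, output_format="dense"):
    # Hypothetical dispatch from pattern-order values to an output container.
    if output_format == "dense":
        out = np.zeros(shape, dtype=data.dtype)
        out[rows, cols] = data
        return out
    if output_format == "coo":
        # The (data, (row, col)) form accepted by scipy.sparse.coo_array.
        return data, (rows, cols)
    raise ValueError(f"unknown output_format: {output_format!r}")

rows = np.array([0, 0, 1])
cols = np.array([0, 2, 1])
data = np.array([1.0, 2.0, 3.0])
dense = format_sparse(data, rows, cols, (2, 3))
print(dense)  # rows [1, 0, 2] and [0, 3, 0]
```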

Parameters:

Name Type Description Default
compressed Array

The compressed matrix B of shape (num_colors, dim), as returned by compressed_jacobian or compressed_hessian.

required
coloring ColoredPattern

The ColoredPattern that produced compressed.

required
output_format OutputFormat

Type of the output matrix. "bcoo" returns jax.experimental.sparse.BCOO (default), "dense" returns jax.Array, "numpy_dense" returns numpy.ndarray, "scipy_coo" returns scipy.sparse.coo_array, "scipy_csr" returns scipy.sparse.csr_array, "scipy_csc" returns scipy.sparse.csc_array. SciPy formats require scipy.

_DEFAULT_OUTPUT_FORMAT

Returns:

Type Description
Any

The sparse matrix of shape (m, n) in the requested format.

Raises:

Type Description
ValueError

If compressed does not match coloring's expected shape, or output_format is unknown.

ImportError

If a scipy output_format is requested but scipy is not installed.

asdex.decompress_data(compressed, coloring)

Gather a compressed matrix B into sparse values in pattern order.

Returns a plain jax.Array of shape (coloring.sparsity.nnz,) holding the sparse values in coloring.sparsity order, so data[k] is the entry at (coloring.sparsity.rows[k], coloring.sparsity.cols[k]).

This is the jittable numeric core of decompression: it always returns a jax.Array, so it composes inside jax.jit and can feed a custom solver or sparse format, whereas decompress may return host (numpy/scipy) objects that cannot. Pair it with to_bcoo for a BCOO, or with coloring.sparsity.rows / coloring.sparsity.cols to assemble a custom format.
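The gather itself is a single indexing operation. This NumPy sketch handles a "fwd" (column-colored) pattern with a hand-made coloring: entry k is read from row colors[cols[k]] of B, at output position rows[k]. The same fancy indexing on jax.numpy arrays is a gather, which is why it composes with jax.jit. In "rev" mode the roles swap to B[row_colors[rows], cols]. The row-major pattern order used here is an assumption.

```python
import numpy as np

# Hypothetical 2x3 Jacobian and a hand-made fwd-mode column coloring.
J = np.array([[1.0, 0.0, 2.0],
              [0.0, 3.0, 0.0]])
rows, cols = np.nonzero(J)      # pattern order: row-major here
colors = np.array([0, 0, 1])    # columns 0 and 1 never share a row, so they share color 0

# Compressed matrix: one JVP per color, so B has shape (num_colors, m).
S = np.stack([(colors == c).astype(float) for c in range(2)])
B = S @ J.T

def decompress_data_fwd(B, rows, cols, colors):
    # data[k] is J[rows[k], cols[k]]: read row colors[cols[k]] of B at index rows[k].
    return B[colors[cols], rows]

data = decompress_data_fwd(B, rows, cols, colors)
print(data)  # [1. 2. 3.]
```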

Parameters:

Name Type Description Default
compressed Array

The compressed matrix B of shape (num_colors, dim), as returned by compressed_jacobian or compressed_hessian.

required
coloring ColoredPattern

The ColoredPattern that produced compressed.

required

Returns:

Type Description
Array

A jax.Array of shape (nnz,) with the sparse values in pattern order, matching compressed's dtype.

Raises:

Type Description
ValueError

If compressed does not have shape (num_colors, dim) for coloring (see compressed_jacobian for the per-mode dim).

Coloring

asdex.jacobian_coloring(f, *args, argnums=_DEFAULT_ARGNUMS, has_aux=_DEFAULT_HAS_AUX, mode=_DEFAULT_MODE, symmetric=_DEFAULT_SYMMETRIC_JACOBIAN, postprocess=_DEFAULT_POSTPROCESS, **kwargs)

Detect Jacobian sparsity and color in one step.

Parameters:

Name Type Description Default
f Callable

Function whose Jacobian is to be computed.

required
*args Any

Sample arguments of f. Only structure and dtypes are used, values are ignored.

()
argnums int | Sequence[int]

Specifies which positional argument(s) to differentiate with respect to. Defaults to 0.

_DEFAULT_ARGNUMS
has_aux bool

Whether f returns (output, auxiliary_data). When True, the auxiliary output is ignored and only output is analyzed for sparsity.

_DEFAULT_HAS_AUX
mode JacobianMode | None

AD mode. "fwd" uses JVPs (column coloring), "rev" uses VJPs (row coloring). Defaults to picking whichever of fwd/rev needs fewer colors (unless symmetric is True, in which case defaults to "fwd").

_DEFAULT_MODE
symmetric bool

Whether to use symmetric coloring. Requires a square Jacobian. Defaults to False.

_DEFAULT_SYMMETRIC_JACOBIAN
postprocess bool

Only read when symmetric=True. Prune colors never used as hubs and compact the remaining ones. This reduces the number of colors, and hence the number of VJPs/JVPs needed to compute the compressed matrix. Defaults to False.

_DEFAULT_POSTPROCESS
**kwargs Any

Sample keyword arguments of f. Non-traceable values (bools, strings, ints) are bound statically.

{}

Returns:

Type Description
ColoredPattern

asdex.jacobian_coloring_from_sparsity(sparsity, *, mode=_DEFAULT_MODE, symmetric=_DEFAULT_SYMMETRIC_JACOBIAN, postprocess=_DEFAULT_POSTPROCESS)

Color a sparsity pattern for sparse Jacobian computation.

Assigns colors so that same-colored rows (or columns) can be computed together in a single VJP (or JVP).
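A common way to build such a coloring is greedy partial distance-2 coloring, shown below in NumPy for illustration; asdex's actual algorithm and column ordering may differ. Row coloring is column coloring of the transpose. Comparing the two counts therefore mimics the documented default of picking whichever mode needs fewer colors. Breaking ties toward "fwd" is an assumption.

```python
import numpy as np

def greedy_column_coloring(pattern):
    # Columns that share a nonzero row must get different colors; each column
    # takes the smallest color not already used by such a neighbor.
    n = pattern.shape[1]
    colors = np.full(n, -1)
    for j in range(n):
        rows = np.flatnonzero(pattern[:, j])
        neighbors = np.flatnonzero(pattern[rows].any(axis=0))
        used = {colors[k] for k in neighbors if colors[k] >= 0}
        c = 0
        while c in used:
            c += 1
        colors[j] = c
    return colors

def pick_mode(pattern):
    # Hypothetical mode choice: compare column colors (fwd) with row colors (rev).
    n_fwd = int(greedy_column_coloring(pattern).max()) + 1
    n_rev = int(greedy_column_coloring(pattern.T).max()) + 1
    return ("fwd", n_fwd) if n_fwd <= n_rev else ("rev", n_rev)

n = 6
# Tridiagonal pattern: three colors suffice either way.
tri = np.eye(n, dtype=bool) | np.eye(n, k=1, dtype=bool) | np.eye(n, k=-1, dtype=bool)
# Dense first row plus diagonal: every column conflicts, but two row colors suffice.
arrow = np.eye(n, dtype=bool)
arrow[0, :] = True

print(pick_mode(tri), pick_mode(arrow))  # ('fwd', 3) ('rev', 2)
```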

Parameters:

Name Type Description Default
sparsity SparsityPattern | NDArray | BCOO

A SparsityPattern, NumPy array, or JAX BCOO matrix of shape (m, n).

required
mode JacobianMode | None

AD mode. "fwd" uses JVPs (column coloring), "rev" uses VJPs (row coloring). Defaults to picking whichever of fwd/rev needs fewer colors (unless symmetric is True, in which case defaults to "fwd").

_DEFAULT_MODE
symmetric bool

Whether to use symmetric coloring. Requires a square Jacobian. Defaults to False.

_DEFAULT_SYMMETRIC_JACOBIAN
postprocess bool

Only read when symmetric=True. Prune colors never used as hubs and compact the remaining ones. This reduces the number of colors, and hence the number of VJPs/JVPs needed to compute the compressed matrix. Defaults to False.

_DEFAULT_POSTPROCESS

Returns:

Type Description
ColoredPattern

Sparsity Detection

asdex.jacobian_sparsity(f, *args, argnums=_DEFAULT_ARGNUMS, has_aux=_DEFAULT_HAS_AUX, **kwargs)

Detect global Jacobian sparsity pattern for f.

Analyzes the computation graph structure directly, without evaluating any derivatives. The result is valid for all inputs.
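One classic way to get a global pattern without evaluating derivatives is index-set propagation. Each input carries the set {j}, and every operation unions the sets of its operands. This pure-Python sketch does that with operator overloading on a hypothetical f; asdex's own mechanism may differ. Because the sets never depend on numeric values, the pattern holds for every input. The price is that it is conservative: x * 0.0 still counts as a dependency.

```python
class IndexSet:
    # Tracer recording which inputs a value may depend on (no derivatives involved).
    def __init__(self, deps):
        self.deps = frozenset(deps)

    def _combine(self, other):
        # Any binary op: the result may depend on everything either operand depends on.
        other_deps = other.deps if isinstance(other, IndexSet) else frozenset()
        return IndexSet(self.deps | other_deps)

    __add__ = __radd__ = __sub__ = __rsub__ = __mul__ = __rmul__ = _combine

def sparsity_sketch(f, n):
    # Hypothetical helper: feed one tracer per input; output i's set is row i of the pattern.
    outputs = f([IndexSet({j}) for j in range(n)])
    return [[j in y.deps for j in range(n)] for y in outputs]

def f(x):
    # Hypothetical function: each output depends on two neighboring inputs.
    return [x[i] * x[i + 1] - 2.0 for i in range(len(x) - 1)]

pattern = sparsity_sketch(f, 4)
```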

Parameters:

Name Type Description Default
f Callable

Function whose Jacobian sparsity pattern is to be detected.

required
*args Any

Sample arguments of f. Only structure and dtypes are used, values are ignored.

()
argnums int | Sequence[int]

Specifies which positional argument(s) to differentiate with respect to. Defaults to 0.

_DEFAULT_ARGNUMS
has_aux bool

Whether f returns (output, auxiliary_data). When True, the auxiliary output is ignored and only output is analyzed for sparsity.

_DEFAULT_HAS_AUX
**kwargs Any

Sample keyword arguments of f. Non-traceable values (bools, strings, ints) are bound statically.

{}

Returns:

Type Description
SparsityPattern

SparsityPattern of shape (m, n_selected) where m = prod(output_shape) and n_selected is the total flat size of the selected inputs.

Verification

asdex.check_jacobian_correctness(f, x, coloring, *, method=_DEFAULT_VERIFY_METHOD, num_probes=_DEFAULT_NUM_PROBES, seed=_DEFAULT_SEED, rtol=_DEFAULT_TOL, atol=_DEFAULT_TOL)

Verify asdex's sparse Jacobian against a JAX reference at a given input.

Parameters:

Name Type Description Default
f Callable[..., Any]

Function whose Jacobian is to be verified.

required
x Any

Input at which to evaluate the Jacobian. For multi-input functions (where argnums is a tuple), pass a tuple of all positional arguments.

required
coloring ColoredPattern

Pre-computed colored pattern from jacobian_coloring.

required
method Literal['matvec', 'dense']

Verification method. "matvec" compares randomized matrix-vector products, so its cost grows linearly with num_probes. "dense" materializes the full dense Jacobian, which needs O(m·n) memory for an (m, n) Jacobian (O(n^2) when square). Defaults to "matvec".

_DEFAULT_VERIFY_METHOD
num_probes int

Number of random probe vectors (only used by "matvec"). Defaults to 25.

_DEFAULT_NUM_PROBES
seed int

PRNG seed for reproducibility (only used by "matvec").

_DEFAULT_SEED
rtol float | None

Relative tolerance for comparison. Defaults to 1e-05 for "matvec" and 1e-07 for "dense".

_DEFAULT_TOL
atol float | None

Absolute tolerance for comparison. Defaults to 1e-05 for "matvec" and 1e-07 for "dense".

_DEFAULT_TOL

Raises:

Type Description
VerificationError

If the sparse and reference Jacobians disagree.
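The idea behind method="matvec" fits in a few lines of NumPy: draw random probe vectors and compare the two Jacobian-vector products within the tolerances. This is a hypothetical helper, not asdex's code. Its defaults mirror the documented num_probes and "matvec" tolerances; seed=0 is an arbitrary choice. It returns a bool, whereas check_jacobian_correctness raises VerificationError.

```python
import numpy as np

def check_matvec(sparse_matvec, ref_matvec, n, num_probes=25, seed=0, rtol=1e-5, atol=1e-5):
    # Compare J_sparse @ v against a reference J @ v for random probes v.
    # Only matvec callables are needed; cost grows linearly with num_probes.
    rng = np.random.default_rng(seed)
    for _ in range(num_probes):
        v = rng.standard_normal(n)
        if not np.allclose(sparse_matvec(v), ref_matvec(v), rtol=rtol, atol=atol):
            return False
    return True

J = np.array([[1.0, 0.0, 2.0],
              [0.0, 3.0, 0.0]])
J_missing = J.copy()
J_missing[0, 2] = 0.0  # a sparse result that dropped a true nonzero

print(check_matvec(lambda v: J @ v, lambda v: J @ v, n=3))          # True
print(check_matvec(lambda v: J_missing @ v, lambda v: J @ v, n=3))  # False
```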

Configuration

asdex.JacobianMode = Literal['fwd', 'rev'] (module attribute)

AD mode for Jacobian computation.

"fwd" uses JVPs (forward-mode AD), "rev" uses VJPs (reverse-mode AD).