
Jacobian

Jacobian Computation

asdex.jacobian(f, input_shape, *, mode=None, symmetric=False, output_format='bcoo')

jacobian(
    f: Callable[[ArrayLike], ArrayLike],
    input_shape: int | tuple[int, ...],
    *,
    mode: JacobianMode | None = None,
    symmetric: bool = False,
    output_format: Literal["bcoo"] = ...,
) -> Callable[[ArrayLike], BCOO]
jacobian(
    f: Callable[[ArrayLike], ArrayLike],
    input_shape: int | tuple[int, ...],
    *,
    mode: JacobianMode | None = None,
    symmetric: bool = False,
    output_format: Literal["dense"],
) -> Callable[[ArrayLike], jax.Array]

Detect sparsity, color, and return a function computing sparse Jacobians.

Combines jacobian_coloring and jacobian_from_coloring in one call.
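The three stages behind this call (detect sparsity, color, decompress) can be illustrated in plain JAX. This is a minimal sketch, not asdex's implementation. The function f is a made-up example assumed to have a tridiagonal Jacobian, and both its sparsity pattern and its column coloring are written by hand instead of being detected.

```python
import jax
import jax.numpy as jnp
import numpy as np

def f(x):
    # Hypothetical example: output i depends only on x[i - 1], x[i], x[i + 1],
    # so the Jacobian is tridiagonal.
    left = jnp.pad(x[:-1], (1, 0))   # x[i - 1], zero at i = 0
    right = jnp.pad(x[1:], (0, 1))   # x[i + 1], zero at i = n - 1
    return left - 2.0 * x + right + x**2

n = 10
x = jnp.linspace(0.1, 1.0, n)

# Coloring (done by hand here): in a tridiagonal pattern, columns j and k never
# share a nonzero row when j % 3 == k % 3, so 3 colors suffice for any n.
colors = np.arange(n) % 3
num_colors = 3

# Compressed evaluation: one seed per color (the sum of the unit vectors of
# that color's columns) and one JVP per seed.
seeds = jnp.asarray(np.stack([colors == c for c in range(num_colors)]), dtype=x.dtype)
compressed = jax.vmap(lambda s: jax.jvp(f, (x,), (s,))[1])(seeds).T  # (n, num_colors)

# Decompression: entry J[i, j] is stored in compressed column colors[j].
rows, cols = np.nonzero(np.abs(np.subtract.outer(np.arange(n), np.arange(n))) <= 1)
J = np.zeros((n, n), dtype=compressed.dtype)
J[rows, cols] = np.asarray(compressed)[rows, colors[cols]]
```

Here 3 JVPs replace the n = 10 JVPs that a dense forward-mode Jacobian would need, and the count stays at 3 as n grows.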

Parameters:

f : Callable[[ArrayLike], ArrayLike] (required)
    Function taking an array and returning an array. Input and output may be multi-dimensional.

input_shape : int | tuple[int, ...] (required)
    Shape of the input array.

mode : JacobianMode | None (default: None)
    AD mode. "fwd" uses JVPs (forward-mode AD), "rev" uses VJPs (reverse-mode AD). None picks whichever of fwd/rev needs fewer colors.

symmetric : bool (default: False)
    Whether to use symmetric (star) coloring. Requires a square Jacobian.

output_format : OutputFormat (default: 'bcoo')
    Type of the output matrix. "bcoo" returns a sparse matrix of type jax.experimental.sparse.BCOO (default), "dense" returns a dense matrix of type jax.Array.

Returns:

Callable[[ArrayLike], BCOO | Array]
    A function that takes an input array and returns the Jacobian of shape (m, n), where n = x.size and m = prod(output_shape). The output type depends on output_format: a sparse jax.experimental.sparse.BCOO (default) or a dense jax.Array.
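To make the (m, n) convention concrete for multi-dimensional inputs and outputs, the sketch below builds the same layout with plain JAX. It assumes row-major (C-order) flattening of both input and output, which is how jnp.ravel and reshape order elements; the function f is a made-up example, and the BCOO conversion only illustrates what the "bcoo" output type holds.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import sparse

def f(x):
    # Hypothetical example: input shape (2, 3) -> output shape (3,).
    return jnp.sum(x**2, axis=0)

x = jnp.arange(1.0, 7.0).reshape(2, 3)

# jax.jacfwd returns an array of shape output_shape + input_shape = (3, 2, 3).
J_nd = jax.jacfwd(f)(x)

# Flattening both sides (row-major assumed) gives m = prod(output_shape) = 3
# rows and n = x.size = 6 columns.
m = int(np.prod(f(x).shape))
J = J_nd.reshape(m, x.size)

# A BCOO matrix stores only the structurally nonzero entries.
J_bcoo = sparse.BCOO.fromdense(J)
```

Output k depends only on x[0, k] and x[1, k], so J has 6 nonzeros, at flattened input positions k and 3 + k.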

asdex.value_and_jacobian(f, input_shape, *, mode=None, symmetric=False, output_format='bcoo')

value_and_jacobian(
    f: Callable[[ArrayLike], ArrayLike],
    input_shape: int | tuple[int, ...],
    *,
    mode: JacobianMode | None = None,
    symmetric: bool = False,
    output_format: Literal["bcoo"] = ...,
) -> Callable[[ArrayLike], tuple[jax.Array, BCOO]]
value_and_jacobian(
    f: Callable[[ArrayLike], ArrayLike],
    input_shape: int | tuple[int, ...],
    *,
    mode: JacobianMode | None = None,
    symmetric: bool = False,
    output_format: Literal["dense"],
) -> Callable[[ArrayLike], tuple[jax.Array, jax.Array]]

Detect sparsity, color, and return a function computing value and sparse Jacobian.

Like jacobian, but also returns the primal value f(x) without an extra forward pass.
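In JAX, the primal value is available from the same call that sets up the derivative products: jax.linearize (forward mode) and jax.vjp (reverse mode) both return f(x) together with a function for JVPs or VJPs. The sketch below uses jax.linearize on a made-up elementwise function whose Jacobian is diagonal, so a single color (the all-ones seed) recovers every nonzero. It illustrates the idea only and does not claim this is how asdex is implemented.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical elementwise map: the Jacobian is diagonal.
    return jnp.tanh(x) * x

x = jnp.linspace(-1.0, 1.0, 8)

# One evaluation of f yields both the value y and a reusable JVP function.
y, f_jvp = jax.linearize(f, x)

# With a diagonal pattern all columns share one color, so the single seed is
# the all-ones vector and J @ 1 is exactly the diagonal of J.
diag = f_jvp(jnp.ones_like(x))
```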

Parameters:

f : Callable[[ArrayLike], ArrayLike] (required)
    Function taking an array and returning an array. Input and output may be multi-dimensional.

input_shape : int | tuple[int, ...] (required)
    Shape of the input array.

mode : JacobianMode | None (default: None)
    AD mode. "fwd" uses JVPs (forward-mode AD), "rev" uses VJPs (reverse-mode AD). None picks whichever of fwd/rev needs fewer colors.

symmetric : bool (default: False)
    Whether to use symmetric (star) coloring. Requires a square Jacobian.

output_format : OutputFormat (default: 'bcoo')
    Type of the output matrix. "bcoo" returns a sparse matrix of type jax.experimental.sparse.BCOO (default), "dense" returns a dense matrix of type jax.Array.

Returns:

Callable[[ArrayLike], tuple[Array, BCOO | Array]]
    A function that takes an input array and returns (f(x), J), where J is the Jacobian of shape (m, n) with n = x.size and m = prod(output_shape). The type of J depends on output_format: a sparse jax.experimental.sparse.BCOO (default) or a dense jax.Array.

asdex.jacobian_from_coloring(f, coloring, output_format='bcoo')

jacobian_from_coloring(
    f: Callable[[ArrayLike], ArrayLike],
    coloring: ColoredPattern,
    output_format: Literal["bcoo"] = ...,
) -> Callable[[ArrayLike], BCOO]
jacobian_from_coloring(
    f: Callable[[ArrayLike], ArrayLike],
    coloring: ColoredPattern,
    output_format: Literal["dense"],
) -> Callable[[ArrayLike], jax.Array]

Build a sparse Jacobian function from a pre-computed coloring.

Uses row coloring + VJPs or column coloring + JVPs, according to the mode the coloring was computed with (with mode=None, jacobian_coloring picks whichever needs fewer colors).
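The "row coloring + VJPs" variant can be sketched in plain JAX. In this sketch the sparsity pattern and the row coloring are hard-coded for a made-up forward-difference function and stand in for a pre-computed coloring; the real ColoredPattern is not used. Each VJP with a color's seed returns the sum of that color's rows, and because same-colored rows never share a column, every entry can be read back and assembled into a BCOO matrix.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import sparse

def f(x):
    # Hypothetical example: output i depends on x[i] and x[i + 1] (m = n - 1).
    return x[1:] ** 2 - x[:-1]

n = 7
m = n - 1
x = jnp.linspace(0.5, 2.0, n)

# Hand-written pattern: row i has nonzeros in columns i and i + 1.
rows = np.repeat(np.arange(m), 2)
cols = np.stack([np.arange(m), np.arange(m) + 1], axis=1).ravel()

# Rows i and i + 1 share column i + 1, but rows two apart share nothing,
# so coloring rows by i % 2 is valid.
row_colors = np.arange(m) % 2
num_colors = 2

# The value y comes out of the same jax.vjp call.
y, f_vjp = jax.vjp(f, x)

# One output cotangent per color: the sum of the unit vectors of its rows.
seeds = jnp.asarray(np.stack([row_colors == c for c in range(num_colors)]), dtype=y.dtype)

# Each VJP gives seed^T @ J; stacked, the compressed matrix has shape (num_colors, n).
compressed = jax.vmap(lambda s: f_vjp(s)[0])(seeds)

# Decompression: entry (i, j) is read from the compressed row of color row_colors[i].
data = compressed[row_colors[rows], cols]
J = sparse.BCOO((data, jnp.stack([rows, cols], axis=1)), shape=(m, n))
```

Two VJPs recover all 12 nonzeros, independent of n.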

Parameters:

f : Callable[[ArrayLike], ArrayLike] (required)
    Function taking an array and returning an array. Input and output may be multi-dimensional.

coloring : ColoredPattern (required)
    Pre-computed ColoredPattern from jacobian_coloring.

output_format : OutputFormat (default: 'bcoo')
    Type of the output matrix. "bcoo" returns a sparse matrix of type jax.experimental.sparse.BCOO (default), "dense" returns a dense matrix of type jax.Array.

Returns:

Callable[[ArrayLike], BCOO | Array]
    A function that takes an input array and returns the Jacobian of shape (m, n), where n = x.size and m = prod(output_shape). The output type depends on output_format: a sparse jax.experimental.sparse.BCOO (default) or a dense jax.Array.

asdex.value_and_jacobian_from_coloring(f, coloring, output_format='bcoo')

value_and_jacobian_from_coloring(
    f: Callable[[ArrayLike], ArrayLike],
    coloring: ColoredPattern,
    output_format: Literal["bcoo"] = ...,
) -> Callable[[ArrayLike], tuple[jax.Array, BCOO]]
value_and_jacobian_from_coloring(
    f: Callable[[ArrayLike], ArrayLike],
    coloring: ColoredPattern,
    output_format: Literal["dense"],
) -> Callable[[ArrayLike], tuple[jax.Array, jax.Array]]

Build a function computing value and sparse Jacobian from a pre-computed coloring.

Like jacobian_from_coloring, but also returns the primal value f(x) without an extra forward pass.

Parameters:

f : Callable[[ArrayLike], ArrayLike] (required)
    Function taking an array and returning an array. Input and output may be multi-dimensional.

coloring : ColoredPattern (required)
    Pre-computed ColoredPattern from jacobian_coloring.

output_format : OutputFormat (default: 'bcoo')
    Type of the output matrix. "bcoo" returns a sparse matrix of type jax.experimental.sparse.BCOO (default), "dense" returns a dense matrix of type jax.Array.

Returns:

Callable[[ArrayLike], tuple[Array, BCOO | Array]]
    A function that takes an input array and returns (f(x), J), where J is the Jacobian of shape (m, n) with n = x.size and m = prod(output_shape). The type of J depends on output_format: a sparse jax.experimental.sparse.BCOO (default) or a dense jax.Array.

Coloring

asdex.jacobian_coloring(f, input_shape, *, mode=None, symmetric=False, postprocess=False)

Detect Jacobian sparsity and color in one step.

Parameters:

f : Callable (required)
    Function taking an array and returning an array.

input_shape : int | tuple[int, ...] (required)
    Shape of the input array.

mode : JacobianMode | None (default: None)
    AD mode. "fwd" uses JVPs (forward-mode AD), "rev" uses VJPs (reverse-mode AD). None picks whichever of fwd/rev needs fewer colors, unless symmetric is True, in which case it defaults to "fwd".

symmetric : bool (default: False)
    Whether to use symmetric (star) coloring. Requires a square Jacobian.

postprocess : bool (default: False)
    Only read when symmetric=True. Prunes colors never used as hubs and compacts the remaining ones, which reduces the number of VJPs/JVPs per Jacobian evaluation. Defaults to False, matching the default in SparseMatrixColorings.jl.

Returns:

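ColoredPattern
    Colored sparsity pattern, to be passed as the coloring argument of jacobian_from_coloring or value_and_jacobian_from_coloring.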

asdex.jacobian_coloring_from_sparsity(sparsity, *, mode=None, symmetric=False, postprocess=False)

Color a sparsity pattern for sparse Jacobian computation.

Assigns colors so that same-colored rows (or columns) can be computed together in a single VJP (or JVP).
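A plain greedy distance-1 coloring shows the constraint: two columns that share a nonzero row must get different colors, and row coloring is the same procedure applied to the transposed pattern. The sketch also mimics mode=None by picking the side with fewer colors. It uses natural column order and breaks ties toward "fwd"; both choices are assumptions, and asdex's actual ordering and algorithm may differ. The helpers greedy_column_coloring and choose_mode are hypothetical.

```python
import numpy as np

def greedy_column_coloring(pattern):
    """Greedy distance-1 coloring of the column intersection graph.

    Two columns conflict if some row has a nonzero in both; conflicting
    columns get different colors, so each color class can share one JVP.
    """
    pattern = np.asarray(pattern, dtype=bool)
    n = pattern.shape[1]
    colors = np.full(n, -1, dtype=int)
    for j in range(n):
        # All columns with a nonzero in some row where column j is nonzero.
        neighbors = np.nonzero(pattern[pattern[:, j]].any(axis=0))[0]
        used = {int(colors[k]) for k in neighbors if k != j and colors[k] >= 0}
        # Smallest color not taken by an already-colored neighbor.
        c = 0
        while c in used:
            c += 1
        colors[j] = c
    return colors

def choose_mode(pattern):
    """Pick "fwd" (column coloring) or "rev" (row coloring) by color count."""
    pattern = np.asarray(pattern, dtype=bool)
    n_fwd = int(greedy_column_coloring(pattern).max()) + 1    # JVPs needed
    n_rev = int(greedy_column_coloring(pattern.T).max()) + 1  # VJPs needed
    return ("fwd", n_fwd) if n_fwd <= n_rev else ("rev", n_rev)

# "Arrow" pattern: a dense first row plus the diagonal.
arrow = np.eye(5, dtype=bool)
arrow[0, :] = True

print(greedy_column_coloring(arrow))  # [0 1 2 3 4]: every column meets row 0
print(choose_mode(arrow))             # ('rev', 2): row 0, then all other rows
```

The dense row forces five column colors, while rows 1 to 4 touch disjoint columns and can share one color, so reverse mode needs only two VJPs.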

Parameters:

sparsity : SparsityPattern | NDArray | BCOO (required)
    A SparsityPattern, NumPy array, or JAX BCOO matrix of shape (m, n).

mode : JacobianMode | None (default: None)
    AD mode. "fwd" uses JVPs (column coloring), "rev" uses VJPs (row coloring). None picks whichever of fwd/rev needs fewer colors, unless symmetric is True, in which case it defaults to "fwd".

symmetric : bool (default: False)
    Whether to use symmetric (star) coloring. Requires a square pattern.

postprocess : bool (default: False)
    Only read when symmetric=True. Prunes colors never used as hubs and compacts the remaining ones, which reduces the number of VJPs/JVPs per Jacobian evaluation. Defaults to False, matching the default in SparseMatrixColorings.jl.

Returns:

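ColoredPattern
    Colored sparsity pattern, to be passed as the coloring argument of jacobian_from_coloring or value_and_jacobian_from_coloring.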

Sparsity Detection

asdex.jacobian_sparsity(f, input_shape)

Detect the global Jacobian sparsity pattern of f: R^n -> R^m.

Analyzes the computation graph structure directly, without evaluating any derivatives. The result is valid for all inputs.
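Global sparsity can be obtained without computing any derivatives by propagating, for every intermediate value, the set of input indices it depends on. The sketch below does this by operator overloading on a hypothetical Tracer class, which is a different mechanism from asdex's analysis of the JAX computation graph. Because no numeric values are inspected, the result holds for every input, provided f does not branch on values. The function f is a made-up example written on Python lists so the tracer can flow through it.

```python
class Tracer:
    """Carries the set of input indices an intermediate value depends on."""

    def __init__(self, deps):
        self.deps = frozenset(deps)

    def _combine(self, other):
        # Constants depend on no inputs; tracers contribute their index sets.
        other_deps = other.deps if isinstance(other, Tracer) else frozenset()
        return Tracer(self.deps | other_deps)

    # Every binary operation unions the dependency sets of its operands.
    __add__ = __radd__ = __sub__ = __rsub__ = _combine
    __mul__ = __rmul__ = __truediv__ = __rtruediv__ = __pow__ = _combine


def jacobian_sparsity_sketch(f, n):
    """Return the sorted (i, j) pairs where output i depends on input j."""
    inputs = [Tracer({j}) for j in range(n)]
    outputs = f(inputs)
    return sorted((i, j) for i, out in enumerate(outputs) for j in out.deps)


def f(x):
    # Forward differences plus one output that depends on x[0] alone.
    return [x[i + 1] ** 2 - x[i] for i in range(len(x) - 1)] + [x[0] * 3.0]

print(jacobian_sparsity_sketch(f, 4))
# [(0, 0), (0, 1), (1, 1), (1, 2), (2, 2), (2, 3), (3, 0)]
```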

Parameters:

f : Callable (required)
    Function taking an array and returning an array.

input_shape : int | tuple[int, ...] (required)
    Shape of the input array. An integer n is treated as the 1D shape (n,).

Returns:

SparsityPattern
    SparsityPattern of shape (m, n), where n = prod(input_shape) and m = prod(output_shape). Entry (i, j) is present if output i depends on input j.

Verification

asdex.check_jacobian_correctness(f, x, coloring, *, method='matvec', num_probes=25, seed=0, rtol=None, atol=None)

Verify asdex's sparse Jacobian against a JAX reference at a given input.

Parameters:

f : Callable[[ArrayLike], ArrayLike] (required)
    Function taking an array and returning an array.

x : ArrayLike (required)
    Input at which to evaluate the Jacobian.

coloring : ColoredPattern (required)
    Pre-computed ColoredPattern from jacobian_coloring.

method : Literal['matvec', 'dense'] (default: 'matvec')
    Verification method. "matvec" compares randomized matrix-vector products, so its cost grows linearly with num_probes. "dense" materializes the full dense Jacobian, which needs O(m·n) memory for an (m, n) Jacobian.

num_probes : int (default: 25)
    Number of random probe vectors (only used by "matvec").

seed : int (default: 0)
    PRNG seed for reproducibility (only used by "matvec").

rtol : float | None (default: None)
    Relative tolerance for comparison. Defaults to 1e-5 for "matvec" and 1e-7 for "dense".

atol : float | None (default: None)
    Absolute tolerance for comparison. Defaults to 1e-5 for "matvec" and 1e-7 for "dense".

Raises:

VerificationError
    If the sparse and reference Jacobians disagree.
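The idea behind the "matvec" method can be sketched as follows: multiply the candidate Jacobian by random probe vectors and compare against reference products. This sketch assumes the reference products come from jax.jvp (the text above only says "a JAX reference"). The helper check_matvec is hypothetical and raises AssertionError rather than asdex's VerificationError. A BCOO matrix also supports J @ v, so the same check applies to sparse results.

```python
import jax
import jax.numpy as jnp
import numpy as np

def check_matvec(f, x, J, num_probes=25, seed=0, rtol=1e-5, atol=1e-5):
    """Compare J @ v against jax.jvp for random probe vectors v.

    Raises AssertionError on mismatch (a stand-in for VerificationError).
    """
    key = jax.random.PRNGKey(seed)
    for probe_key in jax.random.split(key, num_probes):
        v = jax.random.normal(probe_key, x.shape, dtype=x.dtype)
        _, ref = jax.jvp(f, (x,), (v,))   # reference product J @ v
        got = J @ v.ravel()               # candidate product
        np.testing.assert_allclose(got, ref.ravel(), rtol=rtol, atol=atol)

def f(x):
    # Hypothetical example function.
    return jnp.sin(x[1:]) * x[:-1]

x = jnp.linspace(0.0, 1.0, 6)
J_good = jax.jacfwd(f)(x)
check_matvec(f, x, J_good)  # passes silently

# Corrupting a single entry is caught by the random probes.
J_bad = J_good.at[0, 0].add(1.0)
try:
    check_matvec(f, x, J_bad)
    detected = False
except AssertionError:
    detected = True
```

Each probe costs one JVP and one matrix-vector product, which is why the cost scales with num_probes rather than with the size of the Jacobian.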

Configuration

asdex.JacobianMode = Literal['fwd', 'rev']

AD mode for Jacobian computation.

"fwd" uses JVPs (forward-mode AD), "rev" uses VJPs (reverse-mode AD).