Full API
Differentiation
asdex.jacobian(f, *args, argnums=0, has_aux=False, holomorphic=False, allow_int=False, mode=None, symmetric=False, output_format='bcoo')
Detect sparsity, color, and return a function computing sparse Jacobians.
Combines jacobian_coloring and jacobian_from_coloring in one call.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `f` | `Callable[..., Any]` | Function whose Jacobian is to be computed. | required |
| `*args` | `Any` | Sample arguments of | `()` |
| `argnums` | `int \| Sequence[int]` | Specifies which positional argument(s) to differentiate with respect to (default | `0` |
| `has_aux` | `bool` | Whether | `False` |
| `holomorphic` | `bool` | Whether | `False` |
| `allow_int` | `bool` | Whether to allow differentiating with respect to integer-valued inputs, mirroring | `False` |
| `mode` | `JacobianMode \| None` | AD mode. | `None` |
| `symmetric` | `bool` | Whether to use symmetric (star) coloring. Requires a square Jacobian. | `False` |
| `output_format` | `OutputFormat` | Type of the output matrix. | `'bcoo'` |
Returns:
| Type | Description |
|---|---|
| `Callable[..., Any]` | A function that takes the same positional args as |
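To make the three stages that asdex.jacobian combines concrete, here is a plain-JAX sketch of what a column-colored (forward-mode) sparse Jacobian amounts to for a small tridiagonal map. It is an illustration under stated assumptions, not asdex's implementation: the sparsity pattern is written by hand instead of detected, the coloring `j % 3` is chosen by hand instead of computed, and `f` and `sparse_jacobian` are hypothetical names.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import sparse

def f(x):
    # y_i depends on x_{i-1}, x_i, x_{i+1}: a tridiagonal Jacobian.
    padded = jnp.pad(x, 1)
    return padded[:-2] - 2.0 * padded[1:-1] + padded[2:] + jnp.sin(x)

n = 6
# Hand-written tridiagonal pattern (hypothetical; asdex would detect it).
pairs = [(i, j) for i in range(n) for j in range(max(0, i - 1), min(n, i + 2))]
rows = np.array([i for i, _ in pairs], dtype=np.int32)
cols = np.array([j for _, j in pairs], dtype=np.int32)

# Columns j and j + 3 never share a row, so three colors suffice for any n.
colors = np.arange(n) % 3
num_colors = 3

def sparse_jacobian(x):
    # One tangent per color: the sum of the unit vectors of that color's columns.
    seeds = (colors[None, :] == np.arange(num_colors)[:, None]).astype(x.dtype)
    # One JVP per color; stacking gives the compressed (n, num_colors) Jacobian.
    compressed = jnp.stack([jax.jvp(f, (x,), (s,))[1] for s in seeds], axis=1)
    # Nonzero J[i, j] is the only contribution to compressed[i, colors[j]].
    values = compressed[rows, colors[cols]]
    return sparse.BCOO((values, np.stack([rows, cols], axis=1)), shape=(n, n))
```

A dense jax.jacfwd needs one JVP per input (six here), whereas this colored version needs three JVPs for any n, because the tridiagonal pattern admits a three-coloring regardless of its size.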
asdex.value_and_jacobian(f, *args, argnums=0, has_aux=False, holomorphic=False, allow_int=False, mode=None, symmetric=False, output_format='bcoo')
Detect sparsity, color, and return a function computing value and sparse Jacobian.
Like jacobian, but also returns the primal value f(*args) without an extra forward pass.
Returns:
| Type | Description |
|---|---|
| `Callable[..., Any]` | A function that takes the same positional args as |
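In plain JAX, one way to get the primal value and all colored JVPs from a single forward evaluation is jax.linearize. It returns f(x) together with a linear JVP function that can then be vmapped over the seed vectors. This sketch only illustrates the idea and is not necessarily how asdex avoids the extra pass. `value_and_compressed_jacobian` is a hypothetical helper that stops at the compressed Jacobian and does not decompress it.

```python
import jax
import jax.numpy as jnp
import numpy as np

def f(x):
    # Elementwise map: the Jacobian is diagonal, so a single color suffices.
    return jnp.exp(x) * x

def value_and_compressed_jacobian(f, x, seeds):
    # seeds: (n, num_colors) matrix whose columns are the color tangents.
    y, f_jvp = jax.linearize(f, x)  # one forward pass yields y = f(x)
    # Apply the linearized map to every seed column without re-running f.
    compressed = jax.vmap(f_jvp, in_axes=1, out_axes=1)(seeds)
    return y, compressed
```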
asdex.hessian(f, *args, argnums=0, has_aux=False, holomorphic=False, allow_int=False, mode=None, symmetric=True, output_format='bcoo')
Detect sparsity, color, and return a function computing sparse Hessians.
If f returns a squeezable shape like (1,) or (1, 1), it is automatically squeezed to a scalar.
asdex.value_and_hessian(f, *args, argnums=0, has_aux=False, holomorphic=False, allow_int=False, mode=None, symmetric=True, output_format='bcoo')
Detect sparsity, color, and return a function computing value and sparse Hessian.
Like hessian, but also returns the primal value f(*args) without an extra forward pass.
asdex.jacobian_from_coloring(f, coloring, output_format='bcoo', *, has_aux=False, holomorphic=False, allow_int=False)
Build a sparse Jacobian function from a pre-computed coloring.
Uses row coloring + VJPs or column coloring + JVPs, depending on which needs fewer colors.
The returned callable accepts *args, **kwargs; kwargs are forwarded to f at call time (matching jax.jacfwd / jax.jacrev).
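The reverse-mode branch can be sketched in plain JAX as well. A single jax.vjp call evaluates f once, and its pullback is applied to one seed per row color. In this example every column shares the dense first row, so column coloring would need n JVPs, whereas two row colors suffice. The function, pattern, and coloring are hand-written assumptions for illustration and do not come from asdex.

```python
import jax
import jax.numpy as jnp
import numpy as np

def f(x):
    # Row 0 depends on every input; row i >= 1 depends only on x_i.
    return jnp.concatenate([jnp.sum(x, keepdims=True), x[1:] ** 2])

n = 5
# Hand-written pattern: dense first row plus the diagonal below it.
rows = np.array([0] * n + list(range(1, n)), dtype=np.int32)
cols = np.array(list(range(n)) + list(range(1, n)), dtype=np.int32)
# Rows 1..n-1 share no column with each other; only row 0 conflicts with them.
row_colors = np.array([0] + [1] * (n - 1), dtype=np.int32)
num_colors = 2  # column coloring of this pattern would need n colors

def sparse_jacobian_rev(x):
    y, vjp_fun = jax.vjp(f, x)  # one forward pass, reused for every VJP
    seeds = (row_colors[None, :] == np.arange(num_colors)[:, None]).astype(y.dtype)
    # Each VJP returns the sum of the rows of one color: (num_colors, n).
    compressed = jnp.stack([vjp_fun(s)[0] for s in seeds])
    # J[i, j] is the only contribution of color row_colors[i] to column j.
    values = compressed[row_colors[rows], cols]
    return rows, cols, values
```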
asdex.value_and_jacobian_from_coloring(f, coloring, output_format='bcoo', *, has_aux=False, holomorphic=False, allow_int=False)
Build a function computing value and sparse Jacobian from a pre-computed coloring.
asdex.hessian_from_coloring(f, coloring, output_format='bcoo', *, has_aux=False, holomorphic=False, allow_int=False)
Build a sparse Hessian function from a pre-computed coloring.
Uses symmetric (star) coloring and Hessian-vector products by default.
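Each HessianMode strategy (fwd_over_rev, rev_over_fwd, rev_over_rev) corresponds to a standard composition of JAX transformations for a Hessian-vector product. This is the primitive that colored Hessian evaluation applies once per color. The sketch shows one HVP per strategy. The helper names are hypothetical, and asdex may structure these differently, but each helper computes H v for a twice-differentiable scalar f.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Scalar function with a tridiagonal Hessian.
    return jnp.sum(x[:-1] * x[1:] ** 2) + jnp.sum(jnp.cos(x))

def hvp_fwd_over_rev(f, x, v):
    # Forward-mode JVP of the reverse-mode gradient.
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

def hvp_rev_over_fwd(f, x, v):
    # Reverse-mode gradient of the forward-mode directional derivative.
    return jax.grad(lambda y: jax.jvp(f, (y,), (v,))[1])(x)

def hvp_rev_over_rev(f, x, v):
    # Reverse-mode gradient of <grad f(y), v>, where grad f is also reverse mode.
    return jax.grad(lambda y: jnp.vdot(jax.grad(f)(y), v))(x)
```

Forward-over-reverse is the usual default in JAX because it avoids storing a second reverse pass. All three are exact; they differ only in cost.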
asdex.value_and_hessian_from_coloring(f, coloring, output_format='bcoo', *, has_aux=False, holomorphic=False, allow_int=False)
Build a function computing value and sparse Hessian from a pre-computed coloring.
Coloring
asdex.jacobian_coloring(f, *args, argnums=0, has_aux=False, mode=None, symmetric=False, postprocess=False)
Detect Jacobian sparsity and color in one step.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `f` | `Callable` | Function whose Jacobian is to be computed. | required |
| `*args` | `Any` | Sample arguments of | `()` |
| `argnums` | `int \| Sequence[int]` | Specifies which positional argument(s) to differentiate with respect to (default | `0` |
| `has_aux` | `bool` | If | `False` |
| `mode` | `JacobianMode \| None` | AD mode. | `None` |
| `symmetric` | `bool` | Whether to use symmetric (star) coloring. Requires a square Jacobian. | `False` |
| `postprocess` | `bool` | Only read when | `False` |
Returns:
| Type | Description |
|---|---|
| `ColoredPattern` | A |
asdex.hessian_coloring(f, *args, argnums=0, has_aux=False, mode=None, symmetric=True, postprocess=False)
Detect Hessian sparsity and color in one step.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `f` | `Callable` | Scalar-valued function taking one or more positional arrays. | required |
| `*args` | `Any` | Sample arguments of | `()` |
| `argnums` | `int \| Sequence[int]` | Specifies which positional argument(s) to differentiate with respect to (default | `0` |
| `has_aux` | `bool` | If | `False` |
| `mode` | `HessianMode \| None` | AD composition strategy for Hessian-vector products. | `None` |
| `symmetric` | `bool` | Whether to use symmetric (star) coloring. Defaults to True (exploits H = H^T for fewer colors). | `True` |
| `postprocess` | `bool` | Only read when | `False` |
Returns:
| Type | Description |
|---|---|
| `ColoredPattern` | A |
asdex.jacobian_coloring_from_sparsity(sparsity, *, mode=None, symmetric=False, postprocess=False)
Color a sparsity pattern for sparse Jacobian computation.
Assigns colors so that same-colored rows (or columns) can be computed together in a single VJP (or JVP).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `sparsity` | `SparsityPattern \| NDArray \| BCOO` | A | required |
| `mode` | `JacobianMode \| None` | AD mode. | `None` |
| `symmetric` | `bool` | Whether to use symmetric (star) coloring. Requires a square pattern. | `False` |
| `postprocess` | `bool` | Only read when | `False` |
Returns:
| Type | Description |
|---|---|
| `ColoredPattern` | A |
asdex.hessian_coloring_from_sparsity(sparsity, *, mode=None, symmetric=True, postprocess=False)
Color a sparsity pattern for sparse Hessian computation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `sparsity` | `SparsityPattern \| NDArray \| BCOO` | A | required |
| `mode` | `HessianMode \| None` | AD composition strategy for Hessian-vector products. | `None` |
| `symmetric` | `bool` | Whether to use symmetric (star) coloring. Defaults to True (exploits Hessian symmetry for fewer colors). | `True` |
| `postprocess` | `bool` | Only read when | `False` |
Returns:
| Type | Description |
|---|---|
| `ColoredPattern` | A |
asdex.color_rows(sparsity)
Greedy row-wise coloring for sparse Jacobian computation.
Assigns colors to rows such that no two rows sharing a non-zero column have the same color. This enables computing multiple Jacobian rows in a single VJP by using a combined seed vector.
Uses LargestFirst vertex ordering for fewer colors.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `sparsity` | `SparsityPattern` | SparsityPattern of shape (m, n) representing the Jacobian sparsity pattern | required |
Returns:
| Type | Description |
|---|---|
| `tuple[NDArray[int32], int]` | Tuple of (colors, num_colors) where: |
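The following is a minimal sketch of greedy row coloring with LargestFirst ordering. For brevity, it works on a dense boolean pattern. Two rows conflict when they share a non-zero column. Rows are visited in order of decreasing conflict degree, and each row receives the smallest color not already used by its colored neighbours. The function name is hypothetical, and its tie-breaking (stable by row index) is an assumption, so its colors need not match asdex.color_rows exactly.

```python
import numpy as np

def color_rows_largest_first(pattern):
    # pattern: boolean (m, n) Jacobian sparsity, dense for simplicity.
    p = pattern.astype(np.int64)
    # Rows i and k conflict iff they have a nonzero in a common column.
    conflicts = (p @ p.T) > 0
    np.fill_diagonal(conflicts, False)
    degrees = conflicts.sum(axis=1)
    # LargestFirst: color rows in order of decreasing conflict degree.
    order = np.argsort(-degrees, kind="stable")
    colors = np.full(pattern.shape[0], -1, dtype=np.int32)
    for v in order:
        # Colors already taken by colored neighbours of row v.
        used = {int(c) for c in colors[conflicts[v]] if c >= 0}
        color = 0
        while color in used:
            color += 1
        colors[v] = color
    return colors, int(colors.max()) + 1
```

Column coloring is the same procedure applied to the transposed pattern.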
asdex.color_cols(sparsity)
Greedy column-wise coloring for sparse Jacobian computation.
Assigns colors to columns such that no two columns sharing a non-zero row have the same color. This enables computing multiple Jacobian columns in a single JVP by using a combined tangent vector.
Uses LargestFirst vertex ordering for fewer colors.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `sparsity` | `SparsityPattern` | SparsityPattern of shape (m, n) representing the Jacobian sparsity pattern | required |
Returns:
| Type | Description |
|---|---|
| `tuple[NDArray[int32], int]` | Tuple of (colors, num_colors) where: |
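Once columns are colored, compression and recovery reduce to a seed matrix S with S[j, c] = 1 when column j has color c. The NumPy sketch below forms J S explicitly, which is what num_colors JVPs would return, and then reads each nonzero back out. For illustration only, it uses a dense J and takes J's numerical nonzeros as the pattern. `compress_and_decompress` is a hypothetical name.

```python
import numpy as np

def compress_and_decompress(J, col_colors):
    # J: dense (m, n) Jacobian; col_colors: a valid column coloring of its pattern.
    m, n = J.shape
    num_colors = int(col_colors.max()) + 1
    # Seed matrix S (n, num_colors): S[j, c] = 1 iff column j has color c.
    S = np.zeros((n, num_colors))
    S[np.arange(n), col_colors] = 1.0
    B = J @ S  # column c of B is the JVP with tangent S[:, c]
    # Recover J[i, j] from B[i, color of j]; valid because no other column
    # of the same color has a nonzero in row i.
    rows, cols = np.nonzero(J)
    J_rec = np.zeros_like(J)
    J_rec[rows, cols] = B[rows, col_colors[cols]]
    return B, J_rec
```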
asdex.color_symmetric(sparsity, *, postprocess=False, forced_colors=None)
Greedy symmetric coloring for sparse Hessian computation.
Implements Algorithm 4.1 from Gebremedhin et al. (2007).
A star coloring is a distance-1 coloring with the additional constraint that every path on 4 vertices uses at least 3 colors.
Returns a StarSet alongside the colors so that Hessian decompression can use hub-based extraction.
Uses LargestFirst vertex ordering.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `sparsity` | `SparsityPattern` | SparsityPattern of shape | required |
| `postprocess` | `bool` | If | `False` |
| `forced_colors` | `NDArray[int32] \| list[int] \| None` | Optional pre-computed color assignment of shape | `None` |
Returns:
| Type | Description |
|---|---|
| `tuple[NDArray[int32], int, StarSet]` | Tuple |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If pattern is not square. |
| `InvalidColoringError` | If |
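The star-coloring property quoted above can be checked by brute force on the adjacency graph of a symmetric pattern. The sketch below verifies a given coloring rather than computing one as Algorithm 4.1 does. `is_star_coloring` and its dict-of-sets input are hypothetical and not part of asdex.

```python
def is_star_coloring(adj, colors):
    # adj: vertex -> set of neighbours in the adjacency graph of a symmetric
    # pattern, without self-loops; colors: indexable vertex -> color.
    # Distance-1 condition: adjacent vertices must differ in color.
    for v, nbrs in adj.items():
        if any(colors[v] == colors[u] for u in nbrs):
            return False
    # Star condition: every path v0 - v1 - v2 - v3 on four distinct vertices
    # must use at least three colors (no two-colored path of length three).
    for v1, nbrs1 in adj.items():
        for v2 in nbrs1:
            for v0 in nbrs1 - {v2}:
                for v3 in adj[v2] - {v1, v0}:
                    if len({colors[v0], colors[v1], colors[v2], colors[v3]}) < 3:
                        return False
    return True
```

On the path 0 - 1 - 2 - 3, the alternating coloring [0, 1, 0, 1] is a valid distance-1 coloring but not a star coloring, while [0, 1, 2, 0] is both.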
Visualization
asdex.spy(pattern, *, ax=None, compressed=False, cmap=None, **kwargs)
Plot a sparsity pattern or colored pattern using matplotlib.
For a SparsityPattern, plots nonzeros as filled cells on a grid.
For a ColoredPattern, fills cells with their assigned color.
When compressed=True on a ColoredPattern, plots the compressed pattern after coloring instead of the original.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `pattern` | `SparsityPattern \| ColoredPattern` | The sparsity or colored pattern to plot. | required |
| `ax` | `Axes \| None` | Matplotlib axes to plot on. If | `None` |
| `compressed` | `bool` | If | `False` |
| `cmap` | `Any` | Matplotlib colormap for colored patterns. If | `None` |
| `**kwargs` | `Any` | Extra keyword arguments passed to | `{}` |
Returns:
| Type | Description |
|---|---|
| `Axes` | The matplotlib axes with the plot. |
Sparsity Detection
asdex.jacobian_sparsity(f, *args, argnums=0, has_aux=False)
Detect global Jacobian sparsity pattern for f.
Analyzes the computation graph structure directly, without evaluating any derivatives. The result is valid for all inputs.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `f` | `Callable` | Function whose Jacobian sparsity pattern is to be detected. | required |
| `*args` | `Any` | Sample arguments of | `()` |
| `argnums` | `int \| Sequence[int]` | Specifies which positional argument(s) to differentiate with respect to (default | `0` |
| `has_aux` | `bool` | Whether | `False` |
Returns:
| Type | Description |
|---|---|
| `SparsityPattern` | SparsityPattern of shape |
asdex.hessian_sparsity(f, *args, argnums=0, has_aux=False)
Detect global Hessian sparsity pattern for a scalar-valued f.
Analyzes the Jacobian sparsity of the gradient function, without evaluating any derivatives. The result is valid for all inputs.
If f returns a squeezable shape like (1,) or (1, 1), it is automatically squeezed to a scalar.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `f` | `Callable` | Scalar-valued function taking one or more positional arrays. | required |
| `*args` | `Any` | Sample arguments of | `()` |
| `argnums` | `int \| Sequence[int]` | Specifies which positional argument(s) to differentiate with respect to (default | `0` |
| `has_aux` | `bool` | Whether | `False` |
Returns:
| Type | Description |
|---|---|
| `SparsityPattern` | Square SparsityPattern over the combined, selected input space. |
Data Structures
asdex.SparsityPattern (dataclass)
Sparse matrix pattern storing only structural information (no values).
Stores row and column indices separately for efficient access by the coloring and decompression stages.
Attributes:
| Name | Type | Description |
|---|---|---|
| `rows` | `NDArray[int32]` | Row indices of non-zero entries, shape |
| `cols` | `NDArray[int32]` | Column indices of non-zero entries, shape |
| `shape` | `tuple[int, int]` | Matrix dimensions |
| `input_avals` | `tuple[Any, ...]` | One pytree of |
| `argnums` | `int \| tuple[int, ...]` | Positions of |
dyn_avals (property)
Sub-tuple of input_avals selected by argnums.
example_input (property)
The aval structure the returned Jacobian / Hessian mirrors.
When argnums is an int, this is the single selected aval; when argnums is a tuple, this is the tuple of selected avals.
Matches jax/_src/api.py:746 (jacfwd) and line 840 (jacrev).
leaf_shapes (property)
Per-leaf shapes of the selected (differentiated) inputs.
leaf_sizes (property)
Per-leaf flat sizes (prod(shape)) of the selected inputs.
input_treedef (property)
Pytree structure of dyn_avals.
nnz (property)
Number of non-zero elements.
m (property)
Number of rows.
n (property)
Number of columns.
density (property)
Fraction of non-zero entries.
col_to_rows (cached property)
Mapping from column index to list of row indices with non-zeros.
Used by the coloring algorithm to build the row conflict graph.
row_to_cols (cached property)
Mapping from row index to list of column indices with non-zeros.
Used by the coloring algorithm to build the column conflict graph.
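Both cached mappings can be derived from the COO arrays in one pass. The row conflict graph follows from col_to_rows, because any two rows listed under the same column conflict. The sketch uses plain dicts. The container types and the helper names `index_maps` and `row_conflict_edges` are assumptions and do not describe SparsityPattern internals.

```python
from collections import defaultdict

def index_maps(rows, cols):
    # Group the COO nonzeros by column and by row in a single pass.
    col_to_rows = defaultdict(list)
    row_to_cols = defaultdict(list)
    for r, c in zip(rows, cols):
        col_to_rows[c].append(r)
        row_to_cols[r].append(c)
    return dict(col_to_rows), dict(row_to_cols)

def row_conflict_edges(col_to_rows):
    # Rows that share a column must receive different row colors.
    edges = set()
    for rs in col_to_rows.values():
        for i, a in enumerate(rs):
            for b in rs[i + 1:]:
                edges.add((min(a, b), max(a, b)))
    return edges
```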
__post_init__()
Validate inputs and fill in the default single-leaf aval.
from_coo(rows, cols, shape, *, input_avals=(), argnums=0) (classmethod)
Create pattern from row and column index arrays.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `rows` | `NDArray[int32] \| list[int]` | Row indices of non-zero entries. | required |
| `cols` | `NDArray[int32] \| list[int]` | Column indices of non-zero entries. | required |
| `shape` | `tuple[int, int]` | Matrix dimensions | required |
| `input_avals` | `tuple[Any, ...]` | One pytree of | `()` |
| `argnums` | `int \| tuple[int, ...]` | Positions of | `0` |
from_bcoo(bcoo) (classmethod)
Create pattern from JAX BCOO sparse matrix.
from_dense(dense) (classmethod)
Create pattern from dense boolean/numeric matrix.
Non-zero entries indicate pattern positions.
to_bcoo(data=None)
Convert to JAX BCOO sparse matrix.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `data` | `ndarray \| None` | Optional data values. If None, uses all 1s. | `None` |
todense()
Convert to dense numpy array with 1s at pattern positions.
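The dense and BCOO conversions each take a few lines with NumPy and jax.experimental.sparse.BCOO. This minimal sketch uses hypothetical helper names. It mirrors the documented behavior, where non-zero entries define the pattern and to_bcoo defaults to all-ones data, but it is not asdex's code.

```python
import numpy as np
from jax.experimental import sparse

def pattern_from_dense(dense):
    # Non-zero entries mark the pattern; indices are int32 like SparsityPattern.rows/cols.
    rows, cols = np.nonzero(np.asarray(dense))
    return rows.astype(np.int32), cols.astype(np.int32), np.shape(dense)

def pattern_to_bcoo(rows, cols, shape, data=None):
    # Default to all-ones data, as the documented to_bcoo(data=None) does.
    if data is None:
        data = np.ones(len(rows), dtype=np.float32)
    indices = np.stack([rows, cols], axis=1)  # BCOO expects (nse, 2) indices
    return sparse.BCOO((data, indices), shape=shape)
```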
save(path)
Save sparsity pattern to an .npz file.
Supports multi-input and PyTree-structured patterns.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `path` | `str \| PathLike[str]` | Destination file path. | required |
load(path) (classmethod)
Load sparsity pattern from an .npz file.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `path` | `str \| PathLike[str]` | Source file path. | required |
__str__()
Render sparsity pattern with header and dot/braille grid.
__repr__()
Return compact single-line representation.
asdex.ColoredPattern (dataclass)
Result of a graph coloring for sparse differentiation.
Attributes:
| Name | Type | Description |
|---|---|---|
| `sparsity` | `SparsityPattern` | The sparsity pattern that was colored. |
| `colors` | `NDArray[int32]` | Color assignment array. Shape |
| `num_colors` | `int` | Total number of active colors (number of JVPs/VJPs/HVPs). |
| `symmetric` | `bool` | Whether symmetric (star) coloring was used. |
| `mode` | `ColoringMode` | The AD mode. Resolved, never |
| `star_set` | `StarSet \| None` | Star-coloring structure (hub/spoke assignment per edge). Present only for symmetric colorings produced by |
save(path)
Save colored pattern to an .npz file.
Supports multi-input and PyTree-structured patterns.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `path` | `str \| PathLike[str]` | Destination file path. | required |
load(path) (classmethod)
Load colored pattern from an .npz file.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `path` | `str \| PathLike[str]` | Source file path. | required |
__repr__()
Return compact single-line representation.
__str__()
Render colored pattern with sparsity grid and color assignments.
asdex.JacobianMode = Literal['fwd', 'rev'] (module attribute)
AD mode for Jacobian computation.
"fwd" uses JVPs (forward-mode AD); "rev" uses VJPs (reverse-mode AD).
asdex.HessianMode = Literal['fwd_over_rev', 'rev_over_fwd', 'rev_over_rev'] (module attribute)
AD composition strategy for Hessian-vector products.
"fwd_over_rev" uses forward-over-reverse, "rev_over_fwd" uses reverse-over-forward, and "rev_over_rev" uses reverse-over-reverse.
asdex.VerificationError
Bases: AssertionError
Raised when asdex's sparse result does not match JAX's dense reference.
This indicates that the detected sparsity pattern is missing nonzeros, which is a bug — asdex's patterns should always be conservative (i.e., contain at least all true nonzeros). If you encounter this error, please help out asdex's development by reporting this at https://github.com/adrhill/asdex/issues.
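The condition behind this error is that the pattern must contain every true nonzero. It can be expressed as a subset check against a dense reference Jacobian. This is a minimal NumPy sketch with a hypothetical `is_conservative` helper, not asdex's verification code. A dense reference taken at one input can reveal missing entries, but it cannot prove that none are missing for all inputs.

```python
import numpy as np

def is_conservative(pattern, dense_jacobian):
    # pattern: boolean (m, n) detected sparsity; dense_jacobian: reference
    # values at one input point, e.g. from jax.jacfwd.
    # A true nonzero outside the pattern is the failure described above;
    # pattern entries that happen to be zero are allowed.
    missing = (np.asarray(dense_jacobian) != 0) & ~np.asarray(pattern, dtype=bool)
    return not missing.any(), np.argwhere(missing)
```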