
Data Structures

asdex.ColoredPattern dataclass

Result of a graph coloring for sparse differentiation.

Attributes:

sparsity (SparsityPattern): The sparsity pattern that was colored.

colors (NDArray[int32]): Color assignment array. Shape (m,) for "rev" mode, (n,) for all other modes. A value of -1 means "neutral": the vertex is not seeded (used after star-coloring postprocessing).

num_colors (int): Total number of active colors (number of JVPs/VJPs/HVPs).

symmetric (bool): Whether symmetric (star) coloring was used.

mode (ColoringMode): The AD mode. Always resolved, never "auto". "fwd" uses JVPs (forward-mode AD), "rev" uses VJPs (reverse-mode AD), "fwd_over_rev" uses forward-over-reverse HVPs, "rev_over_fwd" uses reverse-over-forward HVPs, and "rev_over_rev" uses reverse-over-reverse HVPs.

star_set (StarSet | None): Star-coloring structure (hub/spoke assignment per edge). Present only for symmetric colorings produced by color_symmetric; None otherwise.
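As a rough illustration of how colors and num_colors relate to the number of products, the NumPy sketch below covers the "fwd" case: it builds a 0/1 seed matrix from a column coloring, forms the compressed product J @ S (standing in for num_colors JVPs), and reads every structural non-zero back out. The helpers seed_matrix and decompress are hypothetical and not part of asdex, and a dense Jacobian replaces real JVPs purely for demonstration. The symmetric (star) case, which also relies on star_set, is not covered.

```python
import numpy as np


def seed_matrix(colors: np.ndarray, num_colors: int) -> np.ndarray:
    """Hypothetical helper: 0/1 seed matrix of shape (n, num_colors).

    Column j is added to seed vector colors[j]; neutral columns (-1) are skipped.
    """
    seeds = np.zeros((colors.shape[0], num_colors))
    active = np.flatnonzero(colors >= 0)
    seeds[active, colors[active]] = 1.0
    return seeds


def decompress(compressed: np.ndarray, rows: np.ndarray, cols: np.ndarray,
               colors: np.ndarray) -> np.ndarray:
    """Hypothetical helper: read J[rows[k], cols[k]] out of B = J @ S."""
    return compressed[rows, colors[cols]]


# Tridiagonal 4x4 Jacobian: columns 0 and 3 share no row, so they can share a color.
J = np.array([[1.0, 2.0, 0.0, 0.0],
              [3.0, 4.0, 5.0, 0.0],
              [0.0, 6.0, 7.0, 8.0],
              [0.0, 0.0, 9.0, 10.0]])
rows, cols = np.nonzero(J)
colors = np.array([0, 1, 2, 0], dtype=np.int32)
num_colors = 3

S = seed_matrix(colors, num_colors)
B = J @ S  # stands in for 3 JVPs instead of 4
values = decompress(B, rows, cols, colors)
```

Decompression is only exact because no two columns with the same color share a row. In "rev" mode the roles flip: colors has shape (m,), and the compressed product would be S.T @ J.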

save(path)

Save colored pattern to an .npz file.

Supports multi-input and PyTree-structured patterns.

Parameters:

path (str | PathLike[str], required): Destination file path.

load(path) classmethod

Load colored pattern from an .npz file.

Parameters:

path (str | PathLike[str], required): Source file path.

__repr__()

Return compact single-line representation.

__str__()

Render colored pattern with sparsity grid and color assignments.

asdex.SparsityPattern dataclass

Sparse matrix pattern storing only structural information (no values).

Stores row and column indices separately for efficient access by the coloring and decompression stages.

Attributes:

rows (NDArray[int32]): Row indices of non-zero entries, shape (nnz,).

cols (NDArray[int32]): Column indices of non-zero entries, shape (nnz,).

shape (tuple[int, int]): Matrix dimensions (m, n).

input_avals (tuple[Any, ...]): One pytree of jax.ShapeDtypeStruct per positional argument of the traced function, in the order jax.eval_shape(fun, *args) expects. Positions not in argnums are still stored, so the full input structure is preserved, but they contribute no columns to the Jacobian and no rows to the Hessian.

argnums (int | tuple[int, ...]): Positions of input_avals that were differentiated, mirroring jax.grad / jax.jacfwd. An int stays an int and a sequence becomes tuple[int, ...]; this distinction determines whether example_input is a single aval or a tuple of avals.

dyn_avals property

Sub-tuple of input_avals selected by argnums.

example_input property

The aval structure the returned Jacobian / Hessian mirrors.

When argnums is an int, this is the single selected aval; when argnums is a tuple, it is the tuple of selected avals. This matches jacfwd (jax/_src/api.py:746) and jacrev (line 840 of the same file).
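The selection rule behind dyn_avals and example_input can be sketched in plain Python. The functions below are hypothetical re-implementations for illustration only, and short strings stand in for jax.ShapeDtypeStruct avals so the sketch runs without JAX.

```python
from typing import Any, Tuple, Union

Argnums = Union[int, Tuple[int, ...]]


def select_dyn_avals(input_avals: Tuple[Any, ...], argnums: Argnums) -> Tuple[Any, ...]:
    """Hypothetical: sub-tuple of input_avals picked out by argnums."""
    positions = (argnums,) if isinstance(argnums, int) else tuple(argnums)
    return tuple(input_avals[i] for i in positions)


def select_example_input(input_avals: Tuple[Any, ...], argnums: Argnums) -> Any:
    """Hypothetical: a bare aval for an int argnums, a tuple for a tuple argnums."""
    selected = select_dyn_avals(input_avals, argnums)
    return selected[0] if isinstance(argnums, int) else selected


# Stand-ins for the jax.ShapeDtypeStruct values of f(x, y, z).
input_avals = ("f32[3]", "f32[2,2]", "i32[]")

single = select_example_input(input_avals, 1)       # "f32[2,2]"
wrapped = select_example_input(input_avals, (1,))   # ("f32[2,2]",)
pair = select_example_input(input_avals, (0, 2))    # ("f32[3]", "i32[]")
```

The (1,) case is the one the int/tuple distinction exists for: as with jax.jacfwd(f, argnums=(1,)), a one-element tuple still yields a tuple-structured result rather than a bare one.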

leaf_shapes property

Per-leaf shapes of the selected (differentiated) inputs.

leaf_sizes property

Per-leaf flat sizes (prod(shape)) of the selected inputs.

input_treedef property

Pytree structure of dyn_avals.
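One way to picture leaf_shapes and leaf_sizes is as a partition of the flat column axis. The sketch below assumes that each differentiated leaf occupies a contiguous block of prod(shape) columns in flattening order; that layout and the helper column_to_leaf are assumptions for illustration, not documented asdex behavior.

```python
import math

import numpy as np

# Hypothetical differentiated leaves: a vector of 3 and a 2x2 matrix.
leaf_shapes = [(3,), (2, 2)]
leaf_sizes = [math.prod(s) for s in leaf_shapes]  # [3, 4]
offsets = np.cumsum([0] + leaf_sizes)             # [0, 3, 7]
n = int(offsets[-1])                              # 7 columns in total


def column_to_leaf(j: int):
    """Hypothetical: map flat column j to (leaf index, index within that leaf)."""
    leaf = int(np.searchsorted(offsets, j, side="right") - 1)
    local = np.unravel_index(j - offsets[leaf], leaf_shapes[leaf])
    return leaf, tuple(int(i) for i in local)
```

Under this assumed layout, column_to_leaf(5) returns (1, (1, 0)): column 5 is the third entry of the 2x2 leaf in row-major order.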

nnz property

Number of non-zero elements.

m property

Number of rows.

n property

Number of columns.

density property

Fraction of non-zero entries.
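The count-style properties follow directly from the stored index arrays and the shape. The function below is a hypothetical sketch of that arithmetic on plain NumPy arrays, not the asdex implementation. It assumes rows and cols contain no duplicate (row, col) pairs, and the zero-size guard is an assumed convention.

```python
import numpy as np


def pattern_stats(rows: np.ndarray, cols: np.ndarray, shape: tuple) -> dict:
    """Hypothetical: derive nnz, m, n and density from COO indices."""
    if rows.shape != cols.shape:
        raise ValueError("rows and cols must have the same length")
    m, n = shape
    nnz = int(rows.size)  # one entry per structural non-zero
    density = nnz / (m * n) if m * n else 0.0  # assumed convention for empty matrices
    return {"nnz": nnz, "m": m, "n": n, "density": density}


stats = pattern_stats(np.array([0, 1, 1, 2], dtype=np.int32),
                      np.array([0, 0, 2, 1], dtype=np.int32),
                      (3, 3))
```

Here 4 of 9 entries are structurally non-zero, so the density is 4/9.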

col_to_rows cached property

Mapping from column index to list of row indices with non-zeros.

Used by the coloring algorithm to build the row conflict graph.

row_to_cols cached property

Mapping from row index to list of column indices with non-zeros.

Used by the coloring algorithm to build the column conflict graph.
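To show how these two mappings yield a conflict graph, here is a plain greedy distance-1 column coloring: two columns conflict when some row has non-zeros in both, and row_to_cols lists exactly those column groups. This is a generic textbook greedy scheme written for illustration. It is not necessarily the ordering or algorithm asdex uses, and the helper names are hypothetical.

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def build_maps(rows: Sequence[int], cols: Sequence[int]
               ) -> Tuple[Dict[int, List[int]], Dict[int, List[int]]]:
    """Hypothetical: col_to_rows and row_to_cols from COO indices."""
    col_to_rows: Dict[int, List[int]] = defaultdict(list)
    row_to_cols: Dict[int, List[int]] = defaultdict(list)
    for i, j in zip(rows, cols):
        col_to_rows[j].append(i)
        row_to_cols[i].append(j)
    return col_to_rows, row_to_cols


def greedy_column_coloring(n: int, col_to_rows: Dict[int, List[int]],
                           row_to_cols: Dict[int, List[int]]) -> List[int]:
    """Give each column the smallest color not used by a conflicting column."""
    colors = [-1] * n
    for j in range(n):
        # Columns that share any row with column j must get a different color.
        forbidden = {
            colors[k]
            for i in col_to_rows.get(j, [])
            for k in row_to_cols[i]
            if k != j and colors[k] >= 0
        }
        c = 0
        while c in forbidden:
            c += 1
        colors[j] = c
    return colors


# Tridiagonal 4x4 pattern.
rows = [0, 0, 1, 1, 1, 2, 2, 2, 3, 3]
cols = [0, 1, 0, 1, 2, 1, 2, 3, 2, 3]
col_to_rows, row_to_cols = build_maps(rows, cols)
colors = greedy_column_coloring(4, col_to_rows, row_to_cols)  # [0, 1, 2, 0]
```

Row coloring for "rev" mode is the mirror image: iterate over rows and use row_to_cols, then col_to_rows, to find conflicting rows.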

__post_init__()

Validate inputs and fill in the default single-leaf aval.

from_coo(rows, cols, shape, *, input_avals=(), argnums=0) classmethod

Create pattern from row and column index arrays.

Parameters:

rows (NDArray[int32] | list[int], required): Row indices of non-zero entries.

cols (NDArray[int32] | list[int], required): Column indices of non-zero entries.

shape (tuple[int, int], required): Matrix dimensions (m, n).

input_avals (tuple[Any, ...], default ()): One pytree of ShapeDtypeStruct per positional argument of the traced function. If left empty (the default), a single 1-D aval of size n is used.

argnums (int | tuple[int, ...], default 0): Positions of input_avals that were differentiated, mirroring jax.grad / jax.jacfwd.

from_bcoo(bcoo) classmethod

Create pattern from JAX BCOO sparse matrix.

from_dense(dense) classmethod

Create pattern from dense boolean/numeric matrix.

Non-zero entries indicate pattern positions.

to_bcoo(data=None)

Convert to JAX BCOO sparse matrix.

Parameters:

data (ndarray | None, default None): Optional data values. If None, all values are set to 1.

todense()

Convert to dense numpy array with 1s at pattern positions.
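A NumPy sketch of the dense round trip described by from_dense and todense follows. The helpers dense_to_coo and coo_to_dense are hypothetical stand-ins, and the int8 output dtype is an assumption; the asdex methods may use a different dtype.

```python
import numpy as np


def dense_to_coo(dense):
    """Hypothetical: non-zero (or True) entries become pattern positions."""
    dense = np.asarray(dense)
    rows, cols = np.nonzero(dense)
    return rows.astype(np.int32), cols.astype(np.int32), dense.shape


def coo_to_dense(rows, cols, shape):
    """Hypothetical: dense array with 1 at pattern positions and 0 elsewhere."""
    out = np.zeros(shape, dtype=np.int8)  # dtype is an assumption
    out[rows, cols] = 1
    return out


A = np.array([[0.0, 2.5, 0.0],
              [-1.0, 0.0, 0.0]])
rows, cols, shape = dense_to_coo(A)  # values are discarded; only positions remain
back = coo_to_dense(rows, cols, shape)
```

Only structure survives the round trip: back marks where A is non-zero, regardless of the original values or their signs.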

save(path)

Save sparsity pattern to an .npz file.

Supports multi-input and PyTree-structured patterns.

Parameters:

path (str | PathLike[str], required): Destination file path.

load(path) classmethod

Load sparsity pattern from an .npz file.

Parameters:

path (str | PathLike[str], required): Source file path.
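For intuition about what an .npz round trip involves, the sketch below stores the COO indices and shape with numpy.savez and reads them back with numpy.load. The key names "rows", "cols" and "shape" are a hypothetical layout, not the actual asdex file format, which also has to encode input_avals and argnums to support multi-input and PyTree-structured patterns.

```python
import os
import tempfile

import numpy as np


def save_pattern(path, rows, cols, shape):
    """Hypothetical layout: one array per field in a single .npz archive."""
    np.savez(path,
             rows=np.asarray(rows, dtype=np.int32),
             cols=np.asarray(cols, dtype=np.int32),
             shape=np.asarray(shape, dtype=np.int64))


def load_pattern(path):
    """Read the hypothetical layout back; arrays are copied into memory on access."""
    with np.load(path) as data:
        return data["rows"], data["cols"], tuple(int(s) for s in data["shape"])


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "pattern.npz")
    save_pattern(path, [0, 1, 2], [1, 2, 0], (3, 3))
    rows, cols, shape = load_pattern(path)
```

Using a path that already ends in .npz matters here: numpy.savez appends the extension to string paths that lack it, which would make the load path differ from the save path.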

__str__()

Render sparsity pattern with header and dot/braille grid.

__repr__()

Return compact single-line representation.