Hessian
Hessian Computation
asdex.hessian(f, *sample_args, argnums=_DEFAULT_ARGNUMS, has_aux=_DEFAULT_HAS_AUX, holomorphic=_DEFAULT_HOLOMORPHIC, allow_int=_DEFAULT_ALLOW_INT, mode=_DEFAULT_MODE, symmetric=_DEFAULT_SYMMETRIC_HESSIAN, output_format=_DEFAULT_OUTPUT_FORMAT, chunk_size=_DEFAULT_CHUNK_SIZE, **sample_kwargs)
Detect sparsity, color, and return a function computing sparse Hessians.
If f returns a squeezable shape like (1,) or (1, 1),
it is automatically squeezed to a scalar.
For repeated evaluation, wrap the returned function in jax.jit:
each unjitted call re-traces f,
which can cost far more than the differentiation itself.
The "numpy_dense" and scipy output formats cannot be jitted
since they produce non-JAX arrays.
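The cost of re-tracing can be observed with plain JAX. In the sketch below, a hand-written forward-over-reverse `hvp` stands in for the function asdex returns (an assumption for illustration). The hypothetical counter `trace_count` is incremented by a Python side effect in `f`, and such side effects run only while JAX traces `f`.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter: only incremented while JAX traces f

def f(x):
    global trace_count
    trace_count += 1  # Python side effect, so it does not run inside compiled code
    return jnp.sum(jnp.sin(x) ** 2)

def hvp(x, v):
    # Stand-in for a sparse-Hessian function: forward-over-reverse HVP.
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

x = jnp.arange(4.0)
v = jnp.ones(4)

# Unjitted: every call runs the Python code of f again.
for _ in range(3):
    hvp(x, v)
unjitted_traces = trace_count

# Jitted: f is traced on the first call, then the compiled version is reused.
trace_count = 0
hvp_jit = jax.jit(hvp)
for _ in range(3):
    hvp_jit(x, v)
jitted_traces = trace_count
```

With `jax.jit`, later calls with the same shapes and dtypes reuse the compiled computation, so `f` is not traced again.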
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `f` | `Callable[..., Any]` | Scalar-valued function whose Hessian is to be computed. | required |
| `*sample_args` | `Any` | Sample arguments of | `()` |
| `argnums` | `int \| Sequence[int]` | Specifies which positional argument(s) to differentiate with respect to. | `_DEFAULT_ARGNUMS` |
| `has_aux` | `bool` | Whether | `_DEFAULT_HAS_AUX` |
| `holomorphic` | `bool` | Whether | `_DEFAULT_HOLOMORPHIC` |
| `allow_int` | `bool` | Unsupported for Hessians; passing | `_DEFAULT_ALLOW_INT` |
| `mode` | `HessianMode \| None` | AD composition strategy for Hessian-vector products. | `_DEFAULT_MODE` |
| `symmetric` | `bool` | Whether to use symmetric coloring. | `_DEFAULT_SYMMETRIC_HESSIAN` |
| `output_format` | `OutputFormat` | Type of the output matrix. | `_DEFAULT_OUTPUT_FORMAT` |
| `chunk_size` | `int \| None` | Maximum number of colors to process in parallel. | `_DEFAULT_CHUNK_SIZE` |
| `**sample_kwargs` | `Any` | Sample keyword arguments of | `{}` |
Returns:
| Type | Description |
|---|---|
| `Callable[..., Any]` | A function that takes the same positional args as |
asdex.hessian_from_coloring(f, coloring, output_format=_DEFAULT_OUTPUT_FORMAT, *, has_aux=_DEFAULT_HAS_AUX, holomorphic=_DEFAULT_HOLOMORPHIC, allow_int=_DEFAULT_ALLOW_INT, chunk_size=_DEFAULT_CHUNK_SIZE)
Build a sparse Hessian function from a pre-computed coloring.
Uses symmetric (star) coloring and Hessian-vector products by default.
For repeated evaluation, wrap the returned function in jax.jit:
each unjitted call re-traces f,
which can cost far more than the differentiation itself.
The "numpy_dense" and scipy output formats cannot be jitted
since they produce non-JAX arrays.
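Symmetric coloring pays off when a column-only coloring cannot group columns. In an arrowhead Hessian (dense first row and column plus the diagonal), every pair of columns shares row 0, so a column coloring needs n colors; by using H[i, j] = H[j, i], two colors suffice. The sketch below hard-codes this pattern, its two-color assignment, and a recovery rule specific to it. It illustrates the principle and is not asdex's star-coloring decompression.

```python
import jax
import jax.numpy as jnp
import numpy as np

def f(x):
    # Hessian: H[0, 0] = x0, H[0, j] = H[j, 0] = xj, H[j, j] = x0 + xj, all else 0.
    return x[0] * jnp.sum(x[1:] ** 2) / 2 + jnp.sum(x ** 3) / 6

def hvp(x, v):
    # Forward-over-reverse Hessian-vector product.
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

n = 6
x = jnp.linspace(0.5, 1.5, n)

# Symmetric coloring of the arrowhead pattern: vertex 0 gets color 0,
# all other vertices share color 1.
colors = np.array([0] + [1] * (n - 1))
num_colors = 2
seeds = jnp.asarray(np.eye(num_colors)[colors].T)  # (num_colors, n) 0/1 seed vectors

# Two HVPs in total, independent of n: B[c] = H @ seeds[c].
B = np.asarray(jax.vmap(lambda s: hvp(x, s))(seeds))

H = np.zeros((n, n), dtype=B.dtype)
H[:, 0] = B[0]  # color 0 seeds only column 0, so B[0] is that column
H[0, :] = B[0]  # symmetry: row 0 equals column 0
for j in range(1, n):
    # Among the color-1 columns, row j >= 1 is nonzero only on the diagonal.
    H[j, j] = B[1, j]
```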
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `f` | `Callable[..., Any]` | Scalar-valued function whose Hessian is to be computed. | required |
| `coloring` | `ColoredPattern` | Pre-computed colored sparsity pattern of type | required |
| `output_format` | `OutputFormat` | Type of the output matrix. | `_DEFAULT_OUTPUT_FORMAT` |
| `has_aux` | `bool` | Whether | `_DEFAULT_HAS_AUX` |
| `holomorphic` | `bool` | Whether | `_DEFAULT_HOLOMORPHIC` |
| `allow_int` | `bool` | Unsupported for Hessians; passing | `_DEFAULT_ALLOW_INT` |
| `chunk_size` | `int \| None` | Maximum number of colors to process in parallel. | `_DEFAULT_CHUNK_SIZE` |
asdex.value_and_hessian(f, *sample_args, argnums=_DEFAULT_ARGNUMS, has_aux=_DEFAULT_HAS_AUX, holomorphic=_DEFAULT_HOLOMORPHIC, allow_int=_DEFAULT_ALLOW_INT, mode=_DEFAULT_MODE, symmetric=_DEFAULT_SYMMETRIC_HESSIAN, output_format=_DEFAULT_OUTPUT_FORMAT, chunk_size=_DEFAULT_CHUNK_SIZE, **sample_kwargs)
Detect sparsity, color, and return a function computing value and sparse Hessian.
Like hessian, but also returns the primal value
f(*args) without an extra forward pass.
For repeated evaluation, wrap the returned function in jax.jit:
each unjitted call re-traces f,
which can cost far more than the differentiation itself.
The "numpy_dense" and scipy output formats cannot be jitted
since they produce non-JAX arrays.
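One way to get the primal value from the same pass as an HVP in plain JAX is to differentiate `jax.value_and_grad(f)` in forward mode. This is a sketch of the idea, not asdex's implementation; `value_and_hvp` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(jnp.exp(x) * x)

def value_and_hvp(x, v):
    # jvp of value_and_grad: the primal outputs are (f(x), grad f(x)) and the
    # tangent of the gradient is H @ v, all from one forward-over-reverse pass.
    (value, grad), (_, hv) = jax.jvp(jax.value_and_grad(f), (x,), (v,))
    return value, grad, hv

x = jnp.array([0.1, -0.3, 0.7])
v = jnp.array([1.0, 0.0, 2.0])
value, grad, hv = value_and_hvp(x, v)
```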
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `f` | `Callable[..., Any]` | Scalar-valued function whose Hessian is to be computed. | required |
| `*sample_args` | `Any` | Sample arguments of | `()` |
| `argnums` | `int \| Sequence[int]` | Specifies which positional argument(s) to differentiate with respect to. | `_DEFAULT_ARGNUMS` |
| `has_aux` | `bool` | Whether | `_DEFAULT_HAS_AUX` |
| `holomorphic` | `bool` | Whether | `_DEFAULT_HOLOMORPHIC` |
| `allow_int` | `bool` | Unsupported for Hessians; passing | `_DEFAULT_ALLOW_INT` |
| `mode` | `HessianMode \| None` | AD composition strategy for Hessian-vector products. | `_DEFAULT_MODE` |
| `symmetric` | `bool` | Whether to use symmetric coloring. | `_DEFAULT_SYMMETRIC_HESSIAN` |
| `output_format` | `OutputFormat` | Type of the output matrix. | `_DEFAULT_OUTPUT_FORMAT` |
| `chunk_size` | `int \| None` | Maximum number of colors to process in parallel. | `_DEFAULT_CHUNK_SIZE` |
| `**sample_kwargs` | `Any` | Sample keyword arguments of | `{}` |
Returns:
| Type | Description |
|---|---|
| `Callable[..., Any]` | A function that takes the same positional args as |
asdex.value_and_hessian_from_coloring(f, coloring, output_format=_DEFAULT_OUTPUT_FORMAT, *, has_aux=_DEFAULT_HAS_AUX, holomorphic=_DEFAULT_HOLOMORPHIC, allow_int=_DEFAULT_ALLOW_INT, chunk_size=_DEFAULT_CHUNK_SIZE)
Build a function computing value and sparse Hessian from a pre-computed coloring.
For repeated evaluation, wrap the returned function in jax.jit:
each unjitted call re-traces f,
which can cost far more than the differentiation itself.
The "numpy_dense" and scipy output formats cannot be jitted
since they produce non-JAX arrays.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `f` | `Callable[..., Any]` | Scalar-valued function whose Hessian is to be computed. | required |
| `coloring` | `ColoredPattern` | Pre-computed colored sparsity pattern of type | required |
| `output_format` | `OutputFormat` | Type of the output matrix. | `_DEFAULT_OUTPUT_FORMAT` |
| `has_aux` | `bool` | Whether | `_DEFAULT_HAS_AUX` |
| `holomorphic` | `bool` | Whether | `_DEFAULT_HOLOMORPHIC` |
| `allow_int` | `bool` | Unsupported for Hessians; passing | `_DEFAULT_ALLOW_INT` |
| `chunk_size` | `int \| None` | Maximum number of colors to process in parallel. | `_DEFAULT_CHUNK_SIZE` |
Compressed Hessian
asdex.compressed_hessian(f, *sample_args, argnums=_DEFAULT_ARGNUMS, has_aux=_DEFAULT_HAS_AUX, holomorphic=_DEFAULT_HOLOMORPHIC, allow_int=_DEFAULT_ALLOW_INT, mode=_DEFAULT_MODE, symmetric=_DEFAULT_SYMMETRIC_HESSIAN, chunk_size=_DEFAULT_CHUNK_SIZE, **sample_kwargs)
Detect sparsity, color, and return a function computing the compressed Hessian.
Runs the same detect-and-color steps as hessian,
but stops at the dense compressed matrix B of shape (num_colors, n),
which holds one HVP per color, before decompression scatters B into the pattern.
n is the flattened size of the differentiated inputs
(the leaves selected by argnums), so the Hessian is (n, n).
Recover the sparse matrix with decompress or
decompress_data,
or work with B directly.
The returned B is a plain jax.Array,
so the returned function can be jitted by the caller.
Unlike hessian, it takes no output_format:
formatting is the job of decompress.
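A minimal sketch of how B can be formed in plain JAX, assuming a tridiagonal Hessian and the column coloring `j % 3` (not necessarily the coloring asdex would compute). Row c of B is a single HVP with the 0/1 seed vector of color c, and `jax.vmap` evaluates all of them in one batched call; `seeds` and `hvp` are illustrative names.

```python
import jax
import jax.numpy as jnp
import numpy as np

def f(x):
    # Couples only neighbours, so the Hessian is tridiagonal.
    return jnp.sum(jnp.cos(x[1:] * x[:-1])) + jnp.sum(x ** 2)

n = 7
x = jnp.linspace(-1.0, 1.0, n)

# Columns j and j + 3 never share a row of a tridiagonal matrix,
# so j % 3 is a valid column coloring with 3 colors for any n.
colors = np.arange(n) % 3
num_colors = 3
seeds = jnp.asarray(np.eye(num_colors)[colors].T)  # (num_colors, n) 0/1 seed vectors

def hvp(v):
    # Forward-over-reverse Hessian-vector product at the fixed point x.
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

# One HVP per color: B[c] = H @ seeds[c], so B has shape (num_colors, n).
B = jax.vmap(hvp)(seeds)
```

Because columns of one color never share a row, each nonzero H[i, j] appears unmixed at B[colors[j], i], which is where decompression reads it back.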
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `f` | `Callable[..., Any]` | Scalar-valued function whose Hessian is to be computed. | required |
| `*sample_args` | `Any` | Sample arguments of | `()` |
| `argnums` | `int \| Sequence[int]` | Specifies which positional argument(s) to differentiate with respect to. | `_DEFAULT_ARGNUMS` |
| `has_aux` | `bool` | Whether | `_DEFAULT_HAS_AUX` |
| `holomorphic` | `bool` | Whether | `_DEFAULT_HOLOMORPHIC` |
| `allow_int` | `bool` | Unsupported for Hessians; passing | `_DEFAULT_ALLOW_INT` |
| `mode` | `HessianMode \| None` | AD composition strategy for Hessian-vector products. | `_DEFAULT_MODE` |
| `symmetric` | `bool` | Whether to use symmetric coloring. | `_DEFAULT_SYMMETRIC_HESSIAN` |
| `chunk_size` | `int \| None` | Maximum number of colors to process in parallel. | `_DEFAULT_CHUNK_SIZE` |
| `**sample_kwargs` | `Any` | Sample keyword arguments of | `{}` |
Returns:
| Type | Description |
|---|---|
| `Callable[..., Any]` | A function that takes the same positional args as |
asdex.compressed_hessian_from_coloring(f, coloring, *, has_aux=_DEFAULT_HAS_AUX, holomorphic=_DEFAULT_HOLOMORPHIC, allow_int=_DEFAULT_ALLOW_INT, chunk_size=_DEFAULT_CHUNK_SIZE)
Build a compressed Hessian function from a pre-computed coloring.
Like hessian_from_coloring,
but stops at the compressed matrix B of shape (num_colors, n)
instead of materializing the sparse matrix.
See compressed_hessian for B's layout.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `f` | `Callable[..., Any]` | Scalar-valued function whose Hessian is to be computed. | required |
| `coloring` | `ColoredPattern` | Pre-computed colored sparsity pattern of type | required |
| `has_aux` | `bool` | Whether | `_DEFAULT_HAS_AUX` |
| `holomorphic` | `bool` | Whether | `_DEFAULT_HOLOMORPHIC` |
| `allow_int` | `bool` | Unsupported for Hessians; passing | `_DEFAULT_ALLOW_INT` |
| `chunk_size` | `int \| None` | Maximum number of colors to process in parallel. | `_DEFAULT_CHUNK_SIZE` |
Returns:
| Type | Description |
|---|---|
| `Callable[..., Any]` | A function returning |
asdex.value_and_compressed_hessian(f, *sample_args, argnums=_DEFAULT_ARGNUMS, has_aux=_DEFAULT_HAS_AUX, holomorphic=_DEFAULT_HOLOMORPHIC, allow_int=_DEFAULT_ALLOW_INT, mode=_DEFAULT_MODE, symmetric=_DEFAULT_SYMMETRIC_HESSIAN, chunk_size=_DEFAULT_CHUNK_SIZE, **sample_kwargs)
Like compressed_hessian, but also returns the value.
The primal value f(*args) rides the HVP forward pass,
so it is nearly free.
See compressed_hessian for B's layout.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `f` | `Callable[..., Any]` | Scalar-valued function whose Hessian is to be computed. | required |
| `*sample_args` | `Any` | Sample arguments of | `()` |
| `argnums` | `int \| Sequence[int]` | Specifies which positional argument(s) to differentiate with respect to. | `_DEFAULT_ARGNUMS` |
| `has_aux` | `bool` | Whether | `_DEFAULT_HAS_AUX` |
| `holomorphic` | `bool` | Whether | `_DEFAULT_HOLOMORPHIC` |
| `allow_int` | `bool` | Unsupported for Hessians; passing | `_DEFAULT_ALLOW_INT` |
| `mode` | `HessianMode \| None` | AD composition strategy for Hessian-vector products. | `_DEFAULT_MODE` |
| `symmetric` | `bool` | Whether to use symmetric coloring. | `_DEFAULT_SYMMETRIC_HESSIAN` |
| `chunk_size` | `int \| None` | Maximum number of colors to process in parallel. | `_DEFAULT_CHUNK_SIZE` |
| `**sample_kwargs` | `Any` | Sample keyword arguments of | `{}` |
Returns:
| Type | Description |
|---|---|
| `Callable[..., Any]` | A function that takes the same positional args as |
asdex.value_and_compressed_hessian_from_coloring(f, coloring, *, has_aux=_DEFAULT_HAS_AUX, holomorphic=_DEFAULT_HOLOMORPHIC, allow_int=_DEFAULT_ALLOW_INT, chunk_size=_DEFAULT_CHUNK_SIZE)
Build a function computing value and compressed Hessian from a pre-computed coloring.
Like value_and_hessian_from_coloring,
but stops at the compressed matrix B.
See compressed_hessian for B's layout.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `f` | `Callable[..., Any]` | Scalar-valued function whose Hessian is to be computed. | required |
| `coloring` | `ColoredPattern` | Pre-computed colored sparsity pattern of type | required |
| `has_aux` | `bool` | Whether | `_DEFAULT_HAS_AUX` |
| `holomorphic` | `bool` | Whether | `_DEFAULT_HOLOMORPHIC` |
| `allow_int` | `bool` | Unsupported for Hessians; passing | `_DEFAULT_ALLOW_INT` |
| `chunk_size` | `int \| None` | Maximum number of colors to process in parallel. | `_DEFAULT_CHUNK_SIZE` |
Returns:
| Type | Description |
|---|---|
| `Callable[..., Any]` | A function returning |
Decompression
asdex.decompress(compressed, coloring, output_format=_DEFAULT_OUTPUT_FORMAT)
Decompress a compressed matrix B into a 2-D sparse matrix.
Composes decompress_data with format dispatch:
it gathers B into the sparse values,
then materializes the flat (m, n) matrix in the requested format.
Unlike the matrices returned by jacobian /
hessian,
this is always the flat 2-D matrix regardless of the input/output pytree structure,
because B itself is a flat 2-D matrix.
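For a column (non-symmetric) coloring, decompression reduces to a single gather: entry (i, j) is read from B[colors[j], i]. The NumPy sketch below assumes B has the (num_colors, n) layout described under compressed_hessian and produces a dense array. `decompress_dense` is a hypothetical helper, not asdex's implementation, and a symmetric coloring would need a different recovery rule.

```python
import numpy as np

def decompress_dense(B, rows, cols, colors, shape):
    """Scatter a column-compressed B into a dense matrix on the given pattern."""
    out = np.zeros(shape, dtype=B.dtype)
    # Column j was summed into compressed row colors[j]; its entry in row i sits at B[colors[j], i].
    out[rows, cols] = B[colors[cols], rows]
    return out

n = 6
# Tridiagonal pattern as parallel (row, col) index arrays.
mask = np.abs(np.subtract.outer(np.arange(n), np.arange(n))) <= 1
rows, cols = np.nonzero(mask)

# A symmetric tridiagonal stand-in for the true Hessian.
rng = np.random.default_rng(0)
H = np.where(mask, rng.normal(size=(n, n)), 0.0)
H = (H + H.T) / 2

colors = np.arange(n) % 3    # columns j and j + 3 never share a row
seeds = np.eye(3)[colors].T  # (num_colors, n) 0/1 seed vectors
B = (H @ seeds.T).T          # B[c] = H @ seeds[c]

H_rec = decompress_dense(B, rows, cols, colors, (n, n))
```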
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `compressed` | `Array` | The compressed matrix | required |
| `coloring` | `ColoredPattern` | The | required |
| `output_format` | `OutputFormat` | Type of the output matrix. | `_DEFAULT_OUTPUT_FORMAT` |
Returns:
| Type | Description |
|---|---|
| `Any` | The sparse matrix of shape |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If |
| `ImportError` | If a scipy |
asdex.decompress_data(compressed, coloring)
Gather a compressed matrix B into sparse values in pattern order.
Returns a plain jax.Array of shape (coloring.sparsity.nnz,)
holding the sparse values in coloring.sparsity order,
so data[k] is the entry at
(coloring.sparsity.rows[k], coloring.sparsity.cols[k]).
This is the jittable numeric core of decompression:
it always returns a jax.Array, so it composes inside jax.jit
and can feed a custom solver or sparse format,
whereas decompress may return host
(numpy/scipy) objects, which cannot be used inside jax.jit.
Pair it with to_bcoo for a BCOO,
or with coloring.sparsity.rows / coloring.sparsity.cols
to assemble a custom format.
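A sketch of the same kind of gather written to stay inside `jax.jit`, assuming a column coloring of a tridiagonal pattern and B laid out as (num_colors, n); `gather_data` is hypothetical. Because the result is a plain `jax.Array`, it pairs directly with the pattern's index arrays; here a BCOO is built with `jax.experimental.sparse.BCOO` itself rather than through asdex's to_bcoo.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import sparse

n = 6
mask = np.abs(np.subtract.outer(np.arange(n), np.arange(n))) <= 1
rows, cols = np.nonzero(mask)  # pattern order: k -> (rows[k], cols[k])
colors = np.arange(n) % 3      # column coloring of the tridiagonal pattern

@jax.jit
def gather_data(B):
    # data[k] = B[colors[cols[k]], rows[k]]: the nnz values in pattern order.
    return B[colors[cols], rows]

# Compressed matrix of a known symmetric tridiagonal matrix, for checking.
H = jnp.asarray(np.where(mask, np.add.outer(np.arange(n), np.arange(n)) + 1.0, 0.0))
seeds = jnp.asarray(np.eye(3)[colors].T)
B = (H @ seeds.T).T

data = gather_data(B)
H_bcoo = sparse.BCOO((data, jnp.stack([rows, cols], axis=1)), shape=(n, n))
```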
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `compressed` | `Array` | The compressed matrix | required |
| `coloring` | `ColoredPattern` | The | required |
Returns:
| Type | Description |
|---|---|
| `Array` | A |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If |
Coloring
asdex.hessian_coloring(f, *args, argnums=_DEFAULT_ARGNUMS, has_aux=_DEFAULT_HAS_AUX, mode=_DEFAULT_MODE, symmetric=_DEFAULT_SYMMETRIC_HESSIAN, postprocess=_DEFAULT_POSTPROCESS, **kwargs)
Detect Hessian sparsity and color in one step.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `f` | `Callable` | Scalar-valued function whose Hessian is to be computed. | required |
| `*args` | `Any` | Sample arguments of | `()` |
| `argnums` | `int \| Sequence[int]` | Specifies which positional argument(s) to differentiate with respect to. | `_DEFAULT_ARGNUMS` |
| `has_aux` | `bool` | Whether | `_DEFAULT_HAS_AUX` |
| `mode` | `HessianMode \| None` | AD composition strategy for Hessian-vector products. | `_DEFAULT_MODE` |
| `symmetric` | `bool` | Whether to use symmetric coloring. | `_DEFAULT_SYMMETRIC_HESSIAN` |
| `postprocess` | `bool` | Only read when | `_DEFAULT_POSTPROCESS` |
| `**kwargs` | `Any` | Sample keyword arguments of | `{}` |
Returns:
| Type | Description |
|---|---|
| `ColoredPattern` | A |
asdex.hessian_coloring_from_sparsity(sparsity, *, mode=_DEFAULT_MODE, symmetric=_DEFAULT_SYMMETRIC_HESSIAN, postprocess=_DEFAULT_POSTPROCESS)
Color a sparsity pattern for sparse Hessian computation.
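As a rough picture of what coloring does, the sketch below greedily gives each column the smallest color not already used by a column it shares a row with, which yields groups of structurally orthogonal columns. This is the simpler non-symmetric (column) variant, not the symmetric star coloring asdex uses by default, and `greedy_column_coloring` is a hypothetical helper; practical colorers also choose the column order with care.

```python
import numpy as np

def greedy_column_coloring(mask):
    """Assign each column the smallest color not used by any column sharing a row with it."""
    mask = np.asarray(mask, dtype=bool)
    n_cols = mask.shape[1]
    colors = np.full(n_cols, -1)
    for j in range(n_cols):
        rows_j = mask[:, j]
        # Columns with a nonzero in any row of column j conflict with it.
        conflicts = np.any(mask[rows_j], axis=0)
        used = set(colors[conflicts & (colors >= 0)].tolist())
        c = 0
        while c in used:
            c += 1
        colors[j] = c
    return colors

n = 7
tridiagonal = np.abs(np.subtract.outer(np.arange(n), np.arange(n))) <= 1
tri_colors = greedy_column_coloring(tridiagonal)

# Arrowhead: dense first row and column plus the diagonal.
arrowhead = np.eye(n, dtype=bool)
arrowhead[0, :] = arrowhead[:, 0] = True
arrow_colors = greedy_column_coloring(arrowhead)
```

On the arrowhead pattern every pair of columns shares row 0, so this column coloring needs one color per column; that is the situation where a symmetric coloring, which exploits H[i, j] = H[j, i], can use far fewer colors.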
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `sparsity` | `SparsityPattern \| NDArray \| BCOO` | A | required |
| `mode` | `HessianMode \| None` | AD composition strategy for Hessian-vector products. | `_DEFAULT_MODE` |
| `symmetric` | `bool` | Whether to use symmetric coloring. | `_DEFAULT_SYMMETRIC_HESSIAN` |
| `postprocess` | `bool` | Only read when | `_DEFAULT_POSTPROCESS` |
Returns:
| Type | Description |
|---|---|
| `ColoredPattern` | A |
Sparsity Detection
asdex.hessian_sparsity(f, *args, argnums=_DEFAULT_ARGNUMS, has_aux=_DEFAULT_HAS_AUX, **kwargs)
Detect the global Hessian sparsity pattern of a scalar-valued f.
Analyzes the Jacobian sparsity of the gradient function, without evaluating any derivatives. The result is valid for all inputs.
If f returns a squeezable shape like (1,) or (1, 1),
it is automatically squeezed to a scalar.
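`jax.grad`, and any HVP built on it, only accepts scalar outputs, so a single-element output has to be squeezed before differentiation. A minimal sketch of such a wrapper; `squeeze_to_scalar` is hypothetical and not part of asdex.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Returns shape (1, 1) rather than a scalar.
    return jnp.sum(x ** 3).reshape(1, 1)

def squeeze_to_scalar(fun):
    """Hypothetical wrapper: squeeze a single-element output to shape ()."""
    def wrapped(*args, **kwargs):
        out = fun(*args, **kwargs)
        if jnp.size(out) != 1:
            raise ValueError(f"expected a single-element output, got shape {jnp.shape(out)}")
        return jnp.squeeze(out)  # (1,), (1, 1), ... -> ()
    return wrapped

g = squeeze_to_scalar(f)
x = jnp.array([1.0, 2.0, 3.0])
H = jax.hessian(g)(x)                               # shape (3, 3)
hv = jax.jvp(jax.grad(g), (x,), (jnp.ones(3),))[1]  # jax.grad(f) itself would raise an error
```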
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `f` | `Callable` | Scalar-valued function whose Hessian sparsity pattern is to be detected. | required |
| `*args` | `Any` | Sample arguments of | `()` |
| `argnums` | `int \| Sequence[int]` | Specifies which positional argument(s) to differentiate with respect to. | `_DEFAULT_ARGNUMS` |
| `has_aux` | `bool` | Whether | `_DEFAULT_HAS_AUX` |
| `**kwargs` | `Any` | Sample keyword arguments of | `{}` |
Returns:
| Type | Description |
|---|---|
| `SparsityPattern` | Square SparsityPattern over the combined, selected input space. |
Verification
asdex.check_hessian_correctness(f, x, coloring, *, method=_DEFAULT_VERIFY_METHOD, num_probes=_DEFAULT_NUM_PROBES, seed=_DEFAULT_SEED, rtol=_DEFAULT_TOL, atol=_DEFAULT_TOL)
Verify asdex's sparse Hessian against a JAX reference at a given input.
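The idea behind a matvec-style check can be sketched in plain JAX: multiply the candidate Hessian by random probe vectors and compare against exact HVPs. This hypothetical `check_matvec` returns a bool instead of raising VerificationError, and its defaults are illustrative rather than asdex's. Random probes catch a wrong entry with high probability but not with certainty; comparing against the full dense Hessian checks every entry but needs n² storage.

```python
import jax
import jax.numpy as jnp
import numpy as np

def f(x):
    return jnp.sum(jnp.cos(x[1:] * x[:-1])) + jnp.sum(x ** 2)

def check_matvec(f, x, H_candidate, num_probes=5, seed=0, rtol=1e-5, atol=1e-5):
    """Compare H_candidate @ v against an exact HVP for random probe vectors v."""
    key = jax.random.PRNGKey(seed)
    for _ in range(num_probes):
        key, subkey = jax.random.split(key)
        v = jax.random.normal(subkey, x.shape, dtype=x.dtype)
        reference = jax.jvp(jax.grad(f), (x,), (v,))[1]  # exact H @ v
        if not np.allclose(H_candidate @ v, reference, rtol=rtol, atol=atol):
            return False
    return True

x = jnp.linspace(-1.0, 1.0, 6)
H_good = jax.hessian(f)(x)
H_bad = H_good.at[0, 1].add(0.1)  # corrupt a single entry
ok_good = check_matvec(f, x, H_good)
ok_bad = check_matvec(f, x, H_bad)
```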
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `f` | `Callable[..., Any]` | Scalar-valued function whose Hessian is to be verified. | required |
| `x` | `Any` | Input at which to evaluate the Hessian. For multi-input functions (where | required |
| `coloring` | `ColoredPattern` | Pre-computed colored pattern from | required |
| `method` | `Literal['matvec', 'dense']` | Verification method. | `_DEFAULT_VERIFY_METHOD` |
| `num_probes` | `int` | Number of random probe vectors (only used by | `_DEFAULT_NUM_PROBES` |
| `seed` | `int` | PRNG seed for reproducibility (only used by | `_DEFAULT_SEED` |
| `rtol` | `float \| None` | Relative tolerance for comparison. Defaults to 1e-05 for | `_DEFAULT_TOL` |
| `atol` | `float \| None` | Absolute tolerance for comparison. Defaults to 1e-05 for | `_DEFAULT_TOL` |
Raises:
| Type | Description |
|---|---|
| `VerificationError` | If the sparse and reference Hessians disagree. |
Configuration
asdex.HessianMode = Literal['fwd_over_rev', 'rev_over_fwd', 'rev_over_rev']
module-attribute
AD composition strategy for Hessian-vector products.
"fwd_over_rev" uses forward-over-reverse,
"rev_over_fwd" uses reverse-over-forward,
"rev_over_rev" uses reverse-over-reverse.
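In plain JAX, the three compositions can be written as below. These are standard formulations shown as a sketch; asdex's internal versions may differ in detail. All three return H @ v because the Hessian of a scalar function is symmetric.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(jnp.sin(x) * x[::-1])

def hvp_fwd_over_rev(f, x, v):
    # Forward-mode JVP of the reverse-mode gradient.
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

def hvp_rev_over_fwd(f, x, v):
    # Reverse-mode gradient of the forward-mode directional derivative y -> <grad f(y), v>.
    return jax.grad(lambda y: jax.jvp(f, (y,), (v,))[1])(x)

def hvp_rev_over_rev(f, x, v):
    # Reverse-mode gradient of y -> <grad f(y), v>, with the inner gradient also in reverse mode.
    return jax.grad(lambda y: jnp.vdot(jax.grad(f)(y), v))(x)

x = jnp.array([0.3, -1.2, 0.8, 2.0])
v = jnp.array([1.0, 0.5, -0.25, 2.0])
results = {
    "fwd_over_rev": hvp_fwd_over_rev(f, x, v),
    "rev_over_fwd": hvp_rev_over_fwd(f, x, v),
    "rev_over_rev": hvp_rev_over_rev(f, x, v),
}
```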