Sparse Automatic Differentiation

Group seminar at COSMO lab

Adrian Hill

TU Berlin

Guillaume Dalle

EPFL

2024-09-19

Introduction to Differentiation

Motivation

What is a derivative?

A linear approximation of a function around a point.

Why do we care?

Derivatives of computer code are essential in optimization and machine learning.

What do I need to do?

Not much: with Automatic Differentiation (AD), derivatives are easy to compute!

Flavors of differentiation

  • Manual: work out \(f'\) by hand
  • Numeric: \(f'(x) \approx \frac{f(x+\varepsilon) - f(x)}{\varepsilon}\)
  • Symbolic: code a formula for \(f\), get a formula for \(f'\)
  • Automatic: code a program for \(f\), get a value for \(f'(x)\)
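To make the difference concrete, here is a minimal sketch (the test function \(f(x) = \sin(x)\), the step size, and the use of ForwardDiff are illustrative choices, not part of the talk):

import ForwardDiff

f(x) = sin(x)
x, h = 1.0, 1e-6

manual    = cos(x)                        # derivative worked out by hand
numeric   = (f(x + h) - f(x)) / h         # finite differences: approximate, step-size sensitive
automatic = ForwardDiff.derivative(f, x)  # AD: exact up to floating-point rounding

manual ≈ automatic      # true
abs(numeric - manual)   # small but nonzero truncation error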

Introduction to AD

Automatic differentiation

Three key ideas (Griewank and Walther 2008):

  1. Programs are composition chains (or DAGs) of many functions
  2. Jacobian of \(f = f_L \circ \dots \circ f_2 \circ f_1\) given by the chain rule: \[ J = J_L J_{L-1} \dots J_2 J_1 \]
  3. Avoid materializing full Jacobians with matrix-vector products: we only need \(Jv\) and \(v^\top J\)
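As a sanity check, a small sketch with random matrices standing in for the layer Jacobians (sizes and values are arbitrary, chosen only for illustration):

using LinearAlgebra

J1, J2, J3 = randn(4, 3), randn(5, 4), randn(2, 5)   # J = J3 * J2 * J1 is 2×3
v = randn(3)

full_jacobian = J3 * J2 * J1          # materializes J via matrix-matrix products
jvp           = J3 * (J2 * (J1 * v))  # matrix-free: only matrix-vector products

jvp ≈ full_jacobian * v   # true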

Pushforwards and Pullbacks

Let’s introduce our notation for matrix-free linear maps:

  • Jacobian-Vector Products (JVPs), aka pushforwards \[ J(v) = Jv \]
  • Vector-Jacobian Products (VJPs), aka pullbacks \[ J^\top(w) = J^\top w = (w^\top J)^\top \]

Forward mode

JVPs / pushforwards are naturally decomposed from \(1\) to \(L\): \[ \begin{align} Jv &= J_L \cdot J_{L-1} \cdot \ldots \cdot J_2 \cdot J_1 v\\ J(v) &= J_L (J_{L-1}(\dots J_2(J_1(v)))) \end{align} \]

  • \(J(e_i)\) computes the \(i\)-th column of the Jacobian \(J\).
  • For \(f: \mathbb{R}^n \rightarrow \mathbb{R}^m\), the \(m \times n\) Jacobian requires \(n\) JVPs: one per input dimension.

Special case

The derivative of \(f : \mathbb{R} \rightarrow \mathbb{R}^m\) requires just one JVP: \(J(1)\).
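One way to realize a JVP in practice is to differentiate \(t \mapsto f(x + tv)\) at \(t = 0\) with dual numbers. The sketch below does this via ForwardDiff.derivative on an illustrative \(f: \mathbb{R}^2 \rightarrow \mathbb{R}^3\) (a pedagogical construction, not how forward-mode backends build Jacobians internally):

import ForwardDiff

f(x) = [x[1] * x[2], sin(x[1]), x[2]^2]   # f: ℝ² → ℝ³
x = [1.0, 2.0]

jvp(f, x, v) = ForwardDiff.derivative(t -> f(x .+ t .* v), 0.0)   # J(x) * v

jvp(f, x, [1.0, 0.0])   # first column of the 3×2 Jacobian
jvp(f, x, [0.0, 1.0])   # second column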

Reverse mode

VJPs / pullbacks are naturally decomposed from \(L\) to \(1\): \[ \begin{align} w^\top J &= w^\top J_L \cdot J_{L-1} \cdot\ldots\cdot J_2 \cdot J_1\\ J^\top(w) &= J_1^\top(J_2^\top(\dots J_{L-1}^\top(J_L^\top(w)))) \end{align} \]

  • \(J^\top(e_i)\) computes the \(i\)-th row of the Jacobian \(J\).
  • For \(f: \mathbb{R}^n \rightarrow \mathbb{R}^m\), the \(m \times n\) Jacobian requires \(m\) VJPs: one per output dimension.

Special case

The gradient of \(f : \mathbb{R}^n \rightarrow \mathbb{R}\) requires just one VJP: \(J^\top(1)\).
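The corresponding sketch for reverse mode, using Zygote's pullback as one possible backend (the same illustrative \(f\) as before):

import Zygote

f(x) = [x[1] * x[2], sin(x[1]), x[2]^2]   # f: ℝ² → ℝ³
x = [1.0, 2.0]

y, back = Zygote.pullback(f, x)   # forward sweep, records intermediate values
back([1.0, 0.0, 0.0])             # VJP with w = e₁: transpose of the first Jacobian row

# Special case: scalar output ⇒ the whole gradient is one VJP with w = 1
g(x) = sum(abs2, x)
Zygote.gradient(g, x)   # ([2.0, 4.0],)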

Comparison

| Forward mode (JVPs) | Reverse mode (VJPs) |
|---|---|
| \(J(v) = J_L (\dots(J_1(v)))\) | \(J^\top(w) = J_1^\top(\dots(J_L^\top(w)))\) |
| computes Jacobian column-wise | computes Jacobian row-wise |
| forward sweep only | forward + reverse sweep |
| often based on dual numbers | often based on tapes |
| low memory cost | high memory cost |

How about gradients?

For scalar functions \(f: \mathbb{R}^n \rightarrow \mathbb{R}\), the Jacobian is a \(1 \times n\) row vector and the gradient is its transpose. It can be computed with either:

  • n JVPs in forward mode: \(J(e_i)\) for \(i=1\ldots n\)
  • 1 VJP in reverse mode: \(J^\top(1)\)

Deep Learning

When computing gradients of scalar loss functions, reverse mode (“backpropagation”) is vastly more efficient than forward mode.

How about Hessians?

Defined for scalar functions, computed by nesting two AD calls:

  • pushforwards OR pullbacks over a gradient computation yield Hessian-vector products (HVPs)
  • analogous to Jacobians, Hessians are computed via HVPs
  • modes can be arbitrarily composed (in theory)
  • usual choice: forward over reverse-mode
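A minimal sketch of such a nested HVP, written here as forward-over-forward with ForwardDiff alone to stay self-contained (for forward-over-reverse, the inner ForwardDiff.gradient would be replaced by a reverse-mode gradient); the test function is an arbitrary choice:

import ForwardDiff

f(x) = sum(abs2, x) + prod(x)   # illustrative scalar function
x = [1.0, 2.0, 3.0]
v = [1.0, 0.0, 0.0]

# d/dt ∇f(x + t v) at t = 0 equals H(x) * v
hvp(f, x, v) = ForwardDiff.derivative(t -> ForwardDiff.gradient(f, x .+ t .* v), 0.0)

hvp(f, x, v) ≈ ForwardDiff.hessian(f, x) * v   # true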

Introduction to Sparse AD

Core Idea

Use graph coloring to find structurally orthogonal columns/rows

(Gebremedhin, Manne, and Pothen 2005)

Core Idea

  1. JVPs & VJPs are linear maps
  2. Linear maps are additive \[ J(e_i+\ldots+e_j) = J(e_i) +\ldots+ J(e_j) \]
  3. If the summands on the right-hand side are structurally orthogonal (their non-zero entries never overlap) and their sparsity structure is known, the sum can be uniquely decomposed into its summands

We can compute several columns/rows of the Jacobian in a single JVP/VJP!
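A worked sketch with a hand-picked \(3 \times 3\) Jacobian: columns 1 and 2 have non-overlapping rows, so a single product with the seed \(e_1 + e_2\) recovers both at once (the values are made up for illustration):

J = [1.0  0.0  3.0
     0.0  2.0  0.0
     0.0  0.0  4.0]

seed = [1.0, 1.0, 0.0]    # e₁ + e₂: columns 1 and 2 are structurally orthogonal
compressed = J * seed     # [1.0, 2.0, 0.0] — one product instead of two

# Knowing the sparsity pattern, decompression assigns row 1 to column 1 and row 2 to column 2
column1 = [compressed[1], 0.0, 0.0]   # == J[:, 1]
column2 = [0.0, compressed[2], 0.0]   # == J[:, 2]
# Column 3 overlaps with column 1 in row 1, so it needs a second product: 2 colors in total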

Sparsity Pattern Detection

How it works

  • For \(x \in \mathbb{R}^n\), the gradient of \(f\) is defined as \[ \left(\nabla f(x)\right)_{i} = \frac{\partial f}{\partial x_i} \]

  • sparsity patterns correspond to the mask of non-zero values

  • can efficiently be represented by the set of indices corresponding to non-zero values: \[ \left\{i \;\big|\; \frac{\partial f}{\partial x_i} \neq 0\right\} \]

  • Related work: (Walther 2008), (Walther 2012), (Varnik et al. 2011)

  • Open source code: SparseConnectivityTracer.jl

Motivating example

\[ f(x) = x_1 + x_2x_3 + \text{sgn}(x_4) \quad \nabla f(x) = \begin{bmatrix} 1 & x_3 & x_2 & 0 \end{bmatrix}^\top \]

flowchart LR
    subgraph Inputs
    X1["x1"]
    X2["x2"]
    X3["x3"]
    X4["x4"]
    end

    PLUS((+))
    TIMES((*))
    SIGN((sgn))
    PLUS2((+))

    X1 --> |"{1}"| PLUS
    X2 --> |"{2}"| TIMES
    X3 --> |"{3}"| TIMES
    X4 --> |"{4}"| SIGN
    TIMES  --> |"{2,3}"| PLUS
    PLUS --> |"{1,2,3}"| PLUS2
    SIGN --> |"{}"| PLUS2

    PLUS2 --> |"{1,2,3}"| RES["y=f(x)"]

Toy implementation (1)

import Base: +, *, sign

# A Tracer carries the set of input indices its value depends on
struct Tracer
    indexset::Set
end

# Addition and multiplication propagate the union of both index sets
Base.:+(a::Tracer, b::Tracer) = Tracer(union(a.indexset, b.indexset))
Base.:*(a::Tracer, b::Tracer) = Tracer(union(a.indexset, b.indexset))
# sign is piecewise constant, so its derivative is zero and all dependencies are dropped
Base.sign(x::Tracer) = Tracer(Set()) # return empty index set

Toy implementation (2)

Let’s test this on our motivating example:

xtracer = [
    Tracer(Set(1)),
    Tracer(Set(2)),
    Tracer(Set(3)),
    Tracer(Set(4)),
]

f(x) = x[1] + x[2]*x[3] + sign(x[4])
ytracer = f(xtracer)
Tracer(Set(Any[2, 3, 1]))

\[ \nabla f(x) = \begin{bmatrix} 1 & x_3 & x_2 & 0 \end{bmatrix}^\top \]

Toy implementation (3)

It can also be used for Jacobian computations:

g(x) = [x[1], x[2]*x[3], x[1]+x[4]]
g(xtracer)
3-element Vector{Tracer}:
 Tracer(Set([1]))
 Tracer(Set([2, 3]))
 Tracer(Set([4, 1]))

Matches correct Jacobian \[ J_g(x)= \begin{pmatrix} 1 & 0 & 0 & 0 \\ 0 & x_3 & x_2 & 0 \\ 1 & 0 & 0 & 1 \end{pmatrix} \]

Demonstration: SCT

SparseConnectivityTracer Demo (1)

Let’s compare our toy implementation to SCT:

using SparseConnectivityTracer
x = rand(4)  # the detector ignores the values; only the input size matters
pattern = jacobian_sparsity(g, x, TracerSparsityDetector())
3×4 SparseArrays.SparseMatrixCSC{Bool, Int64} with 5 stored entries:
 1  ⋅  ⋅  ⋅
 ⋅  1  1  ⋅
 1  ⋅  ⋅  1
greedy_column_coloring(pattern)  # coloring helper used in these slides (not exported by the packages loaded above)
[ Info: Compressed pattern of size 3 × 4 using 2 colors

\[ J_g(x)= \begin{pmatrix} 1 & 0 & 0 & 0 \\ 0 & x_3 & x_2 & 0 \\ 1 & 0 & 0 & 1 \end{pmatrix} \]

SparseConnectivityTracer Demo (2)

More complex example: convolutional layers

using SparseConnectivityTracer, Flux

layer = Conv((3, 3), 3 => 1);
x = rand(10, 10, 3, 1);

pattern = jacobian_sparsity(layer, x, TracerSparsityDetector())
64×300 SparseArrays.SparseMatrixCSC{Bool, Int64} with 1728 stored entries:
⎡⠙⢾⡮⣷⣤⡀⠀⠀⠀⠀⠀⠀⠀⠙⠮⣷⣷⣦⡀⠀⠀⠀⠀⠀⠀⠀⠘⠿⣷⢽⣦⣄⠀⠀⠀⠀⠀⠀⠀⠀⎤
⎢⠀⠀⠉⠻⣿⣿⡦⣄⠀⠀⠀⠀⠀⠀⠀⠈⠻⢿⡯⣷⣄⠀⠀⠀⠀⠀⠀⠀⠈⠙⢿⣿⣷⢤⡀⠀⠀⠀⠀⠀⎥
⎢⠀⠀⠀⠀⠀⠙⠯⣿⣿⣦⡀⠀⠀⠀⠀⠀⠀⠀⠉⠻⣿⢽⣦⣄⠀⠀⠀⠀⠀⠀⠀⠉⠻⢽⣿⣶⣄⠀⠀⠀⎥
⎢⠀⠀⠀⠀⠀⠀⠀⠈⠻⢿⡯⣷⣤⡀⠀⠀⠀⠀⠀⠀⠈⠙⢿⣿⣷⢤⡀⠀⠀⠀⠀⠀⠀⠀⠙⠿⣿⢽⣦⡀⎥
⎣⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠈⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠈⠁⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠁⠁⠉⎦
greedy_column_coloring(pattern)
[ Info: Compressed pattern of size 64 × 300 using 27 colors

SparseConnectivityTracer Demo (3)

Even more complex: sparsity detection over an ODE solve

using SimpleDiffEq: ODEProblem, solve, SimpleEuler

# Define ODE of Brusselator
function brusselator_f(x, y, t)
    return ifelse((((x - 0.3)^2 + (y - 0.6)^2) <= 0.1^2) && (t >= 1.1), 5.0, 0.0)
end

limit(a, N) =
    if a == N + 1
        1
    elseif a == 0
        N
    else
        a
    end

function brusselator_2d_loop!(du, u, p, t)
    A, B, alpha, xyd, dx, N = p
    alpha = alpha / dx^2
    @inbounds for I in CartesianIndices((N, N))
        i, j = Tuple(I)
        x, y = xyd[I[1]], xyd[I[2]]
        ip1, im1, jp1, jm1 = limit(i + 1, N),
        limit(i - 1, N), limit(j + 1, N),
        limit(j - 1, N)
        du[i, j, 1] =
            alpha *
            (u[im1, j, 1] + u[ip1, j, 1] + u[i, jp1, 1] + u[i, jm1, 1] - 4u[i, j, 1]) +
            B +
            u[i, j, 1]^2 * u[i, j, 2] - (A + 1) * u[i, j, 1] + brusselator_f(x, y, t)
        du[i, j, 2] =
            alpha *
            (u[im1, j, 2] + u[ip1, j, 2] + u[i, jp1, 2] + u[i, jm1, 2] - 4u[i, j, 2]) +
            A * u[i, j, 1] - u[i, j, 1]^2 * u[i, j, 2]
    end
end

# Define Brusselator
struct Brusselator!{P}
    N::Int
    params::P
end

function Brusselator!(N::Integer)
    dims = (N, N, 2)
    A = 1.0
    B = 1.0
    alpha = 1.0
    xyd = fill(1.0, N)
    dx = 1.0
    params = (; A, B, alpha, xyd, dx, N)
    return Brusselator!(N, params)
end

(b!::Brusselator!)(y, x) = brusselator_2d_loop!(y, x, b!.params, nothing)

# Define ODE problem
N = 6
f! = Brusselator!(N)
x = rand(N, N, 2)
y = similar(x)
solver = SimpleEuler()
prob = ODEProblem(brusselator_2d_loop!, x, (0.0, 1.0), f!.params)

# Call ODE solver
function brusselator_ode_solve(x)
    return solve(ODEProblem(brusselator_2d_loop!, x, (0.0, 1.0), f!.params), solver; dt=0.5).u[end]
end

SparseConnectivityTracer Demo (3)

pattern = jacobian_sparsity(brusselator_ode_solve, x, TracerSparsityDetector())
greedy_column_coloring(pattern)
[ Info: Compressed pattern of size 72 × 72 using 37 colors

Demonstration: DifferentiationInterface

Software Implementation

MIT licensed, well documented, thoroughly tested:

  • DifferentiationInterface.jl
  • SparseConnectivityTracer.jl
  • SparseMatrixColorings.jl

Dense Jacobian computation

using DifferentiationInterface
using SparseConnectivityTracer
import ForwardDiff

f(x) = diff(x .^ 2) + diff(reverse(x .^ 2))
x = [1.0, 2.0, 3.0, 4.0, 5.0]

backend = AutoForwardDiff()
jacobian(f, backend, x)
4×5 Matrix{Float64}:
 -2.0   4.0    0.0   8.0  -10.0
  0.0  -4.0   12.0  -8.0    0.0
  0.0   4.0  -12.0   8.0    0.0
  2.0  -4.0    0.0  -8.0   10.0

Sparse Jacobian computation

using SparseMatrixColorings  # provides GreedyColoringAlgorithm

sparse_backend = AutoSparse(
    backend;
    sparsity_detector=TracerSparsityDetector(),
    coloring_algorithm=GreedyColoringAlgorithm(),
)
jacobian(f, sparse_backend, x)
4×5 SparseArrays.SparseMatrixCSC{Float64, Int64} with 14 stored entries:
 -2.0   4.0     ⋅    8.0  -10.0
   ⋅   -4.0   12.0  -8.0     ⋅ 
   ⋅    4.0  -12.0   8.0     ⋅ 
  2.0  -4.0     ⋅   -8.0   10.0

Performance comparison

using BenchmarkTools
n = 1000
x = randn(n)

extras_dense = prepare_jacobian(f, backend, randn(n))
@benchmark jacobian($f, $extras_dense, $backend, $x)
BenchmarkTools.Trial: 1016 samples with 1 evaluation.
 Range (min … max):  4.618 ms …  11.892 ms  ┊ GC (min … max): 4.86% … 38.26%
 Time  (median):     4.839 ms               ┊ GC (median):    5.28%
 Time  (mean ± σ):   4.910 ms ± 389.165 μs  ┊ GC (mean ± σ):  5.69% ±  1.73%

       ▅▇█▅▄▂                                                  
  ▂▃▂▃▇███████▇▅▅▄▃▄▃▃▃▃▃▃▂▃▂▃▂▁▂▂▁▁▁▂▁▁▁▂▁▁▁▁▂▁▁▁▂▁▁▁▂▁▁▁▂▁▂ ▃
  4.62 ms         Histogram: frequency by time        6.12 ms <

 Memory estimate: 57.62 MiB, allocs estimate: 1011.
extras_sparse = prepare_jacobian(f, sparse_backend, rand(n))
@benchmark jacobian($f, $extras_sparse, $sparse_backend, $x)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  23.313 μs … 583.592 μs  ┊ GC (min … max): 0.00% … 87.72%
 Time  (median):     26.459 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   28.760 μs ±  27.892 μs  ┊ GC (mean ± σ):  6.24% ±  6.04%

        ▂▇█▆▂                                                   
  ▂▃▄▅▅▆█████▇▆▆▄▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
  23.3 μs         Histogram: frequency by time         44.9 μs <

 Memory estimate: 305.25 KiB, allocs estimate: 17.

Conclusion

Applications

[Meme image; as the speakers note, reality is a bit more nuanced.]

Outlook

  • “Chunked” mode with statically sized index sets
  • GPU support

References

Gebremedhin, Assefaw Hadish, Fredrik Manne, and Alex Pothen. 2005. “What Color Is Your Jacobian? Graph Coloring for Computing Derivatives.” SIAM Review 47 (4): 629–705. https://doi.org/cmwds4.
Griewank, Andreas, and Andrea Walther. 2008. Evaluating Derivatives: Principles and Techniques of Algorithmic Differentiation. 2nd ed. Philadelphia, PA: Society for Industrial and Applied Mathematics.
Varnik, Ebadollah, Lukas Razik, Viktor Mosenkis, and Uwe Naumann. 2011. “Fast Conservative Estimation of Hessian Sparsity.” In Fifth SIAM Workshop on Combinatorial Scientific Computing, 18. Darmstadt, Germany. http://ftp.informatik.rwth-aachen.de/Publications/AIB/2011/2011-09.pdf#page=21.
Walther, Andrea. 2008. “Computing Sparse Hessians with Automatic Differentiation.” ACM Transactions on Mathematical Software 34 (1): 3:1–15. https://doi.org/10.1145/1322436.1322439.
———. 2012. “On the Efficient Computation of Sparsity Patterns for Hessians.” In Recent Advances in Algorithmic Differentiation, edited by Shaun Forth, Paul Hovland, Eric Phipps, Jean Utke, and Andrea Walther, 139–49. Berlin, Heidelberg: Springer. https://doi.org/10.1007/978-3-642-30023-3_13.

The Julia AD ecosystem

Three types of AD users

  1. Package users want to differentiate through functions
  2. Package developers want to write differentiable functions
  3. Backend developers want to create new AD systems

Python vs. Julia: user experience

Python vs. Julia: developers


Why so many backends?

  • Conflicting paradigms:
    • numeric vs. symbolic vs. algorithmic
    • operator overloading vs. source-to-source (which source?)
  • Cover varying subsets of the language
  • Historical reasons: developed by different people

DifferentiationInterface

Goals

  • DifferentiationInterface (DI) offers a common syntax for all AD backends
  • AD users can compare backends for correctness and performance without reading each backend’s documentation
  • AD developers get access to a wider user base

Supported packages

Getting started with DI

Step 1: load the necessary packages

using DifferentiationInterface
import ForwardDiff, Enzyme, Zygote

f(x) = sum(abs2, x)
x = [1.0, 2.0, 3.0, 4.0]

Step 2: Combine DI’s operators with a backend from ADTypes

value_and_gradient(f, AutoForwardDiff(), x)
(30.0, [2.0, 4.0, 6.0, 8.0])
value_and_gradient(f, AutoEnzyme(), x)
(30.0, [2.0, 4.0, 6.0, 8.0])
value_and_gradient(f, AutoZygote(), x)
(30.0, [2.0, 4.0, 6.0, 8.0])

Step 3: Increase performance via DI’s preparation mechanism
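A minimal sketch of that third step, reusing f and x from above and mirroring the (f, prep, backend, x) calling convention used in the benchmark later on (assumed to carry over to the gradient operators):

prep = prepare_gradient(f, AutoForwardDiff(), x)    # done once
value_and_gradient(f, prep, AutoForwardDiff(), x)   # reused many times
# (30.0, [2.0, 4.0, 6.0, 8.0])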

Features of DI

  • Support for functions f(x) or f!(y, x) with scalar/array inputs & outputs
  • Eight standard operators: pushforward, pullback, derivative, gradient, jacobian, hvp, second_derivative, hessian
  • Out-of-place and in-place versions
  • Combine different backends using SecondOrder
  • Translate between backends using DifferentiateWith
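As a brief illustration of two of these features, here is a hedged sketch; the (f!, y, backend, x) argument order and the SecondOrder(outer, inner) ordering reflect my reading of the DI documentation, not code from the talk:

# In-place function f!(y, x): the output array is passed explicitly
f!(y, x) = (y .= abs2.(x); nothing)
y = zeros(4)
jacobian(f!, y, AutoForwardDiff(), [1.0, 2.0, 3.0, 4.0])   # 4×4 diagonal Jacobian

# Second order: forward-over-reverse Hessian of the scalar f from the earlier example
hessian(f, SecondOrder(AutoForwardDiff(), AutoZygote()), [1.0, 2.0, 3.0, 4.0])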