Adding Overloads

Internals may change

The developer documentation might refer to internals which can change without warning in a future release of SparseConnectivityTracer. Only functionality that is exported or part of the user documentation adheres to semantic versioning.

Having read our guide "How SparseConnectivityTracer works", you might want to add your own methods on GradientTracer, HessianTracer and Dual to improve the performance of your functions or to work around some of SCT's limitations.

Don't overload manually

If you want to overload a Function that takes Real arguments, we strongly discourage manually adding methods to it that dispatch on our internal tracer types.

Instead, use the same code generation mechanisms that we use. This page shows you how.

Generated overloads

Copy one of our package extensions

The easiest way to add overloads is to copy one of our package extensions, e.g. our NNlib extension, and to modify it. Please upstream your additions by opening a pull request! We will help you get your feature merged.

Operator classification

SCT currently supports three types of functions:

  1. 1-to-1: operators with one input and one output
  2. 2-to-1: operators with two inputs and one output
  3. 1-to-2: operators with one input and two outputs

Depending on the type of your function, you have to implement methods that classify which of its first- and second-order derivatives are zero:

| In | Out | Examples | Methods you need to implement |
|:--|:--|:--|:--|
| 1 | 1 | sin, cos, abs | is_der1_zero_global, is_der2_zero_global |
| 2 | 1 | +, *, >, isequal | is_der1_arg1_zero_global, is_der2_arg1_zero_global, is_der1_arg2_zero_global, is_der2_arg2_zero_global, is_der_cross_zero_global |
| 1 | 2 | sincos | is_der1_out1_zero_global, is_der2_out1_zero_global, is_der1_out2_zero_global, is_der2_out2_zero_global |
Methods you have to implement for 1-to-1 operators
| Function | Meaning |
|:--|:--|
| is_der1_zero_global(::typeof(f)) = false | $\frac{\partial f}{\partial x} \neq 0$ for some $x$ |
| is_der2_zero_global(::typeof(f)) = false | $\frac{\partial^2 f}{\partial x^2} \neq 0$ for some $x$ |

Optionally, to obtain sparser patterns from TracerLocalSparsityDetector, you can additionally implement

| Function | Meaning |
|:--|:--|
| is_der1_zero_local(::typeof(f), x) = false | $\frac{\partial f}{\partial x} \neq 0$ for given $x$ |
| is_der2_zero_local(::typeof(f), x) = false | $\frac{\partial^2 f}{\partial x^2} \neq 0$ for given $x$ |

These fall back to

is_der1_zero_local(f::F, x) where {F} = is_der1_zero_global(f)
is_der2_zero_local(f::F, x) where {F} = is_der2_zero_global(f)
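As an illustration of the 1-to-1 methods, here is a minimal sketch for a hypothetical function half_squared_relu, which is not part of SCT or NNlib. Its branch on x > 0 is the kind of code that global tracers cannot enter, which is one reason to overload it. The sketch assumes SCT is installed and only covers classification; the overloads themselves still have to be generated as described in "Generating code".

```julia
import SparseConnectivityTracer as SCT

# Hypothetical 1-to-1 operator (not part of SCT or NNlib):
# f(x) = x^2 / 2 for x > 0 and f(x) = 0 otherwise
half_squared_relu(x::Real) = x > 0 ? x^2 / 2 : zero(x)

# Global classification: f'(x) = x and f''(x) = 1 for x > 0,
# so neither derivative is zero everywhere
SCT.is_der1_zero_global(::typeof(half_squared_relu)) = false
SCT.is_der2_zero_global(::typeof(half_squared_relu)) = false

# Optional local classification for TracerLocalSparsityDetector:
# f'(x) = 0 for x <= 0 and f''(x) = 0 for x < 0.
# f'' is undefined at x = 0, so it is conservatively reported as non-zero there
SCT.is_der1_zero_local(::typeof(half_squared_relu), x) = x <= 0
SCT.is_der2_zero_local(::typeof(half_squared_relu), x) = x < 0
```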
Methods you have to implement for 2-to-1 operators
| Function | Meaning |
|:--|:--|
| is_der1_arg1_zero_global(::typeof(f)) = false | $\frac{\partial f}{\partial x} \neq 0$ for some $x,y$ |
| is_der2_arg1_zero_global(::typeof(f)) = false | $\frac{\partial^2 f}{\partial x^2} \neq 0$ for some $x,y$ |
| is_der1_arg2_zero_global(::typeof(f)) = false | $\frac{\partial f}{\partial y} \neq 0$ for some $x,y$ |
| is_der2_arg2_zero_global(::typeof(f)) = false | $\frac{\partial^2 f}{\partial y^2} \neq 0$ for some $x,y$ |
| is_der_cross_zero_global(::typeof(f)) = false | $\frac{\partial^2 f}{\partial x \partial y} \neq 0$ for some $x,y$ |

Optionally, to obtain sparser patterns from TracerLocalSparsityDetector, you can additionally implement

| Function | Meaning |
|:--|:--|
| is_der1_arg1_zero_local(::typeof(f), x, y) = false | $\frac{\partial f}{\partial x} \neq 0$ for given $x,y$ |
| is_der2_arg1_zero_local(::typeof(f), x, y) = false | $\frac{\partial^2 f}{\partial x^2} \neq 0$ for given $x,y$ |
| is_der1_arg2_zero_local(::typeof(f), x, y) = false | $\frac{\partial f}{\partial y} \neq 0$ for given $x,y$ |
| is_der2_arg2_zero_local(::typeof(f), x, y) = false | $\frac{\partial^2 f}{\partial y^2} \neq 0$ for given $x,y$ |
| is_der_cross_zero_local(::typeof(f), x, y) = false | $\frac{\partial^2 f}{\partial x \partial y} \neq 0$ for given $x,y$ |

These fall back to

is_der1_arg1_zero_local(f::F, x, y) where {F} = is_der1_arg1_zero_global(f)
is_der2_arg1_zero_local(f::F, x, y) where {F} = is_der2_arg1_zero_global(f)
is_der1_arg2_zero_local(f::F, x, y) where {F} = is_der1_arg2_zero_global(f)
is_der2_arg2_zero_local(f::F, x, y) where {F} = is_der2_arg2_zero_global(f)
is_der_cross_zero_local(f::F, x, y) where {F} = is_der_cross_zero_global(f)
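For the 2-to-1 methods, a minimal sketch for a hypothetical product function myprod(x, y) = x * y, assuming SCT is installed. Only the two local first-order methods are overridden; the remaining local methods use the global fallbacks listed above.

```julia
import SparseConnectivityTracer as SCT

# Hypothetical 2-to-1 operator f(x, y) = x * y
myprod(x::Real, y::Real) = x * y

# ∂f/∂x = y, ∂²f/∂x² = 0, ∂f/∂y = x, ∂²f/∂y² = 0, ∂²f/∂x∂y = 1
SCT.is_der1_arg1_zero_global(::typeof(myprod)) = false
SCT.is_der2_arg1_zero_global(::typeof(myprod)) = true
SCT.is_der1_arg2_zero_global(::typeof(myprod)) = false
SCT.is_der2_arg2_zero_global(::typeof(myprod)) = true
SCT.is_der_cross_zero_global(::typeof(myprod)) = false

# Optional local classification: ∂f/∂x = y vanishes for y == 0,
# and ∂f/∂y = x vanishes for x == 0
SCT.is_der1_arg1_zero_local(::typeof(myprod), x, y) = iszero(y)
SCT.is_der1_arg2_zero_local(::typeof(myprod), x, y) = iszero(x)
```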
Methods you have to implement for 1-to-2 operators
| Function | Meaning |
|:--|:--|
| is_der1_out1_zero_global(::typeof(f)) = false | $\frac{\partial f_1}{\partial x} \neq 0$ for some $x$ |
| is_der2_out1_zero_global(::typeof(f)) = false | $\frac{\partial^2 f_1}{\partial x^2} \neq 0$ for some $x$ |
| is_der1_out2_zero_global(::typeof(f)) = false | $\frac{\partial f_2}{\partial x} \neq 0$ for some $x$ |
| is_der2_out2_zero_global(::typeof(f)) = false | $\frac{\partial^2 f_2}{\partial x^2} \neq 0$ for some $x$ |

Optionally, to obtain sparser patterns from TracerLocalSparsityDetector, you can additionally implement

| Function | Meaning |
|:--|:--|
| is_der1_out1_zero_local(::typeof(f), x) = false | $\frac{\partial f_1}{\partial x} \neq 0$ for given $x$ |
| is_der2_out1_zero_local(::typeof(f), x) = false | $\frac{\partial^2 f_1}{\partial x^2} \neq 0$ for given $x$ |
| is_der1_out2_zero_local(::typeof(f), x) = false | $\frac{\partial f_2}{\partial x} \neq 0$ for given $x$ |
| is_der2_out2_zero_local(::typeof(f), x) = false | $\frac{\partial^2 f_2}{\partial x^2} \neq 0$ for given $x$ |

These fall back to

is_der1_out1_zero_local(f::F, x) where {F} = is_der1_out1_zero_global(f)
is_der2_out1_zero_local(f::F, x) where {F} = is_der2_out1_zero_global(f)
is_der1_out2_zero_local(f::F, x) where {F} = is_der1_out2_zero_global(f)
is_der2_out2_zero_local(f::F, x) where {F} = is_der2_out2_zero_global(f)
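For the 1-to-2 methods, a minimal sketch for a hypothetical function mysinhcosh that returns (sinh(x), cosh(x)), in the spirit of sincos. It assumes SCT is installed; the two local methods that are not overridden fall back to the global ones.

```julia
import SparseConnectivityTracer as SCT

# Hypothetical 1-to-2 operator returning (f₁(x), f₂(x)) = (sinh(x), cosh(x))
mysinhcosh(x::Real) = (sinh(x), cosh(x))

# f₁' = cosh, f₁'' = sinh, f₂' = sinh, f₂'' = cosh: none of them is zero everywhere
SCT.is_der1_out1_zero_global(::typeof(mysinhcosh)) = false
SCT.is_der2_out1_zero_global(::typeof(mysinhcosh)) = false
SCT.is_der1_out2_zero_global(::typeof(mysinhcosh)) = false
SCT.is_der2_out2_zero_global(::typeof(mysinhcosh)) = false

# Optional local classification: sinh(x) vanishes only at x == 0,
# while cosh(x) >= 1 never does, so only the sinh-valued derivatives get local methods
SCT.is_der2_out1_zero_local(::typeof(mysinhcosh), x) = iszero(x)
SCT.is_der1_out2_zero_local(::typeof(mysinhcosh), x) = iszero(x)
```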

Generating code

Implementing the classification methods alone does not overload the function on our tracer types. SCT provides three functions that generate these overloads via metaprogramming:

  • 1-to-1: eval(SCT.generate_code_1_to_1(module_symbol, f))
  • 2-to-1: eval(SCT.generate_code_2_to_1(module_symbol, f))
  • 1-to-2: eval(SCT.generate_code_1_to_2(module_symbol, f))

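Call the generator that matches your type of operator. Here, module_symbol is the name of the module that owns f: the generated methods extend module_symbol.f, so eval has to run in a scope where that name is defined, e.g. in your package extension.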
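Putting both steps together, the following sketch classifies and overloads a hypothetical 2-to-1 operator mymax inside a hypothetical module MyModule, then traces it. It assumes SCT is installed and uses its exported jacobian_sparsity and hessian_sparsity functions with TracerSparsityDetector and TracerLocalSparsityDetector. Because the generated code refers to MyModule.mymax, eval is called inside MyModule, where that name resolves.

```julia
using SparseConnectivityTracer  # exports jacobian_sparsity, hessian_sparsity and the detectors

module MyModule
import SparseConnectivityTracer as SCT

# Hypothetical 2-to-1 operator whose branch global tracers cannot enter
mymax(x::Real, y::Real) = x > y ? x : y

# Step 1: classification. First derivatives are 0 or 1, second derivatives vanish
SCT.is_der1_arg1_zero_global(::typeof(mymax)) = false
SCT.is_der2_arg1_zero_global(::typeof(mymax)) = true
SCT.is_der1_arg2_zero_global(::typeof(mymax)) = false
SCT.is_der2_arg2_zero_global(::typeof(mymax)) = true
SCT.is_der_cross_zero_global(::typeof(mymax)) = true

# Optional local classification: the smaller argument does not affect the output
# (on ties, both first derivatives are conservatively treated as non-zero)
SCT.is_der1_arg1_zero_local(::typeof(mymax), x, y) = x < y
SCT.is_der1_arg2_zero_local(::typeof(mymax), x, y) = y < x

# Step 2: generate the overloads; :MyModule names the module that owns mymax
eval(SCT.generate_code_2_to_1(:MyModule, mymax))
end

g(x) = [MyModule.mymax(x[1], x[2])]
x = [1.0, 2.0]

J_global = jacobian_sparsity(g, x, TracerSparsityDetector())
J_local = jacobian_sparsity(g, x, TracerLocalSparsityDetector())
H_global = hessian_sparsity(x -> MyModule.mymax(x[1], x[2]), x, TracerSparsityDetector())
```

The global Jacobian pattern should mark both inputs, while the local one at x = [1.0, 2.0] should only keep x[2], since mymax returns its second argument there. The Hessian pattern should be empty because all second-order derivatives were classified as zero.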

Code generation

We will take a look at the code generation mechanism in the example below.

Example

For more examples of how to overload methods, take a look at our package extensions. Let's look at the relu activation function from ext/SparseConnectivityTracerNNlibExt.jl, which is a 1-to-1 operator defined as $\text{relu}(x) = \max(0, x)$.

Step 1: Classification

The relu function has a non-zero first-order derivative $\frac{\partial f}{\partial x}=1$ for inputs $x>0$ and a zero first-order derivative for inputs $x<0$. The second-order derivative is zero wherever it is defined. We therefore implement:

import SparseConnectivityTracer as SCT
using NNlib

SCT.is_der1_zero_global(::typeof(relu)) = false
SCT.is_der2_zero_global(::typeof(relu)) = true

SCT.is_der1_zero_local(::typeof(relu), x) = x < 0

Note that we imported SCT to extend its operator classification methods on typeof(relu).

Step 2: Generating code

The relu function has not been overloaded on our tracer types yet. Let's call the code generation utilities from the "Generating code" section for this purpose:

eval(SCT.generate_code_1_to_1(:NNlib, relu))

The relu function is now ready to be called with SCT's tracer types.
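To check the result, here is a short sketch, assuming SCT and NNlib are installed. Loading both packages activates SCT's NNlib extension, which already contains the relu overloads discussed here, so the sketch does not repeat Steps 1 and 2.

```julia
using SparseConnectivityTracer
using NNlib  # loading NNlib activates SCT's NNlib extension with the relu overloads

f(x) = relu.(x)
x = [1.0, -1.0]

# Global pattern: has to hold for all inputs, so each output keeps its dependency
J_global = jacobian_sparsity(f, x, TracerSparsityDetector())

# Local pattern at x: is_der1_zero_local(relu, -1.0) is true,
# so the dependency of the second output on x[2] is dropped
J_local = jacobian_sparsity(f, x, TracerLocalSparsityDetector())
```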

What is the eval call doing?

Let's call generate_code_1_to_1 without wrapping it in eval:

SCT.generate_code_1_to_1(:NNlib, relu)
quote
    begin
        begin
            #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/gradient_tracer.jl:29 =#
            function NNlib.relu(t::(SparseConnectivityTracer).GradientTracer)
                #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/gradient_tracer.jl:29 =#
                #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/gradient_tracer.jl:30 =#
                return #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/gradient_tracer.jl:30 =# @noinline((SparseConnectivityTracer).gradient_tracer_1_to_1(t, false))
            end
        end
        begin
            #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/gradient_tracer.jl:43 =#
            function NNlib.relu(d::D) where {P, T <: (SparseConnectivityTracer).GradientTracer, D <: (SparseConnectivityTracer).Dual{P, T}}
                #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/gradient_tracer.jl:43 =#
                #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/gradient_tracer.jl:44 =#
                x = (SparseConnectivityTracer).primal(d)
                #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/gradient_tracer.jl:45 =#
                p_out = NNlib.relu(x)
                #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/gradient_tracer.jl:47 =#
                t = (SparseConnectivityTracer).tracer(d)
                #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/gradient_tracer.jl:48 =#
                is_der1_zero = (SparseConnectivityTracer).is_der1_zero_local(NNlib.relu, x)
                #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/gradient_tracer.jl:49 =#
                t_out = #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/gradient_tracer.jl:49 =# @noinline((SparseConnectivityTracer).gradient_tracer_1_to_1(t, is_der1_zero))
                #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/gradient_tracer.jl:50 =#
                return (SparseConnectivityTracer).Dual(p_out, t_out)
            end
        end
    end
    begin
        begin
            #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/hessian_tracer.jl:67 =#
            function NNlib.relu(t::(SparseConnectivityTracer).HessianTracer)
                #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/hessian_tracer.jl:67 =#
                #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/hessian_tracer.jl:68 =#
                return #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/hessian_tracer.jl:68 =# @noinline((SparseConnectivityTracer).hessian_tracer_1_to_1(t, false, true))
            end
        end
        begin
            #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/hessian_tracer.jl:81 =#
            function NNlib.relu(d::D) where {P, T <: (SparseConnectivityTracer).HessianTracer, D <: (SparseConnectivityTracer).Dual{P, T}}
                #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/hessian_tracer.jl:81 =#
                #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/hessian_tracer.jl:82 =#
                x = (SparseConnectivityTracer).primal(d)
                #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/hessian_tracer.jl:83 =#
                p_out = NNlib.relu(x)
                #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/hessian_tracer.jl:85 =#
                t = (SparseConnectivityTracer).tracer(d)
                #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/hessian_tracer.jl:86 =#
                is_der1_zero = (SparseConnectivityTracer).is_der1_zero_local(NNlib.relu, x)
                #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/hessian_tracer.jl:87 =#
                is_der2_zero = (SparseConnectivityTracer).is_der2_zero_local(NNlib.relu, x)
                #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/hessian_tracer.jl:88 =#
                t_out = #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/hessian_tracer.jl:88 =# @noinline((SparseConnectivityTracer).hessian_tracer_1_to_1(t, is_der1_zero, is_der2_zero))
                #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/hessian_tracer.jl:89 =#
                return (SparseConnectivityTracer).Dual(p_out, t_out)
            end
        end
    end
end

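As you can see, this returns a quote, a type of expression containing our generated Julia code. It defines four methods of NNlib.relu: one each for GradientTracer and HessianTracer, and one each for Dual numbers wrapping these tracers.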

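We have to use quotes because the code generation mechanism lives in SCT, while the generated code has to be evaluated in the package extension, not in SCT. As the generated quote shows, we handle the necessary namespacing for you: SCT's own names are fully qualified, and the overloaded function is referenced through the module symbol you pass in (here, NNlib.relu).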
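The same pattern can be reproduced in plain Julia, without SCT. In this sketch, a hypothetical module CodeGen stands in for SCT and only builds an expression, and a hypothetical module Client stands in for a package extension and evaluates it. The generator never needs to know Client at definition time; it only receives the module symbol.

```julia
module CodeGen
# Hypothetical stand-in for SCT's generate_code_* functions:
# it only builds an expression and never evaluates it itself
function generate_double_method(module_symbol::Symbol, f)
    fname = nameof(f)  # e.g. :scale
    return quote
        # The target function is qualified by the module symbol passed in
        function $(module_symbol).$(fname)(x::Int)
            return 2x
        end
    end
end
end

module Client
import ..CodeGen

scale(x::Float64) = x  # existing method, owned by Client

# Build the quote in CodeGen, but evaluate it here, where Client.scale resolves
eval(CodeGen.generate_double_method(:Client, scale))
end
```

After evaluation, Client.scale has a new Int method next to its original Float64 method.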

Manual overloads

As mentioned above, manual overloads should generally be avoided for functions that take Real arguments. If such an overload is necessary (e.g. for array inputs), it should follow these design principles, ordered by importance:

Local sparsity detection (Dual):

  • Overloads must return conservative sparsity patterns (no false negatives) at the given input.
  • Local tracers are allowed to enter branches in user code. User code can be stateful.
  • MethodErrors due to missing overloads can be avoided by returning a very conservative sparsity pattern.
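To make "conservative" concrete: a pattern P is conservative with respect to the true pattern S if it contains no false negatives, i.e. every structural non-zero of S is also set in P. The following plain-Julia check uses a hypothetical helper is_conservative, which is not part of SCT.

```julia
# A candidate pattern P is conservative for the true pattern S if it has
# no false negatives: every structural non-zero of S is also marked in P
is_conservative(P::AbstractMatrix{Bool}, S::AbstractMatrix{Bool}) =
    size(P) == size(S) && all(P .| .!S)

# True Jacobian pattern of the hypothetical f(x) = [x[1] * x[2], x[3]]
S = Bool[1 1 0; 0 0 1]

P_exact = copy(S)             # exact pattern: conservative
P_dense = trues(size(S))      # dense pattern: always conservative, but not sparse
P_wrong = Bool[1 0 0; 0 0 1]  # misses ∂f₁/∂x₂: a false negative
```

The dense pattern corresponds to the "very conservative" fallback mentioned above: it can never contain a false negative, at the cost of all sparsity.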

Global sparsity detection (GradientTracer and HessianTracer):

  • Overloads must return conservative sparsity patterns (no false negatives) over the entire input domain.
  • Tracers must error instead of entering branches in user code. This requires that overloaded functions return tracers instead of Bool (or numbers), as the former are designed to error in comparisons.
  • Overloads must ignore scalar values of non-tracer inputs. (While not set in stone, changing this rule in the future would require a breaking release.)
  • Sparsity should be prioritized over performance. We assume global sparsity detection can be amortized.
  • MethodErrors due to missing overloads can be avoided by returning a very conservative sparsity pattern.
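A sketch of what the rule on non-tracer inputs means in practice, assuming SCT is installed and that its built-in global overload of * follows this rule. The constant factor c is zero, so the true Jacobian of f is zero, yet the global pattern is expected to keep the diagonal because the value of c is ignored.

```julia
using SparseConnectivityTracer

c = 0.0          # non-tracer input whose value makes the true Jacobian zero
f(x) = c .* x    # elementwise: f_i(x) = c * x_i
x = [1.0, 2.0]

# The global overload of * ignores the value of c,
# so each output is still reported as depending on its input
J = jacobian_sparsity(f, x, TracerSparsityDetector())
```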