Adding Overloads

Internals may change

The developer documentation might refer to internals which can change without warning in a future release of SparseConnectivityTracer. Only functionality that is exported or part of the user documentation adheres to semantic versioning.

Having read our guide "How SparseConnectivityTracer works", you might want to add your own methods on GradientTracer, HessianTracer and Dual to improve the performance of your functions or to work around some of SCT's limitations.

Avoid hand-written overloads

Don't overload manually

If you want to overload a Function that takes Real arguments, we strongly discourage you from manually adding methods to your function that use our internal tracer types.

Instead, use the same code generation mechanisms that we use. This page of the documentation shows you how.

Copy one of our package extensions

The easiest way to add overloads is to copy one of our package extensions, e.g. our NNlib extension, and to modify it. Please upstream your additions by opening a pull request! We will help you get your feature merged.

Operator classification

SCT currently supports three types of functions:

  1. 1-to-1: operators with one input and one output
  2. 2-to-1: operators with two inputs and one output
  3. 1-to-2: operators with one input and two outputs

Depending on the type of function you're dealing with, you will have to specify the way in which your function is differentiable:

| In | Out | Examples | Methods you need to implement |
|:---|:----|:---------|:------------------------------|
| 1 | 1 | sin, cos, abs | is_der1_zero_global, is_der2_zero_global |
| 2 | 1 | +, *, >, isequal | is_der1_arg1_zero_global, is_der2_arg1_zero_global, is_der1_arg2_zero_global, is_der2_arg2_zero_global, is_der_cross_zero_global |
| 1 | 2 | sincos | is_der1_out1_zero_global, is_der2_out1_zero_global, is_der1_out2_zero_global, is_der2_out2_zero_global |

Methods you have to implement for 1-to-1 operators

| Function | Meaning |
|:---------|:--------|
| is_der1_zero_global(::typeof(f)) = false | $\frac{\partial f}{\partial x} \neq 0$ for some $x$ |
| is_der2_zero_global(::typeof(f)) = false | $\frac{\partial^2 f}{\partial x^2} \neq 0$ for some $x$ |

Optionally, to increase the sparsity of TracerLocalSparsityDetector, you can additionally implement

| Function | Meaning |
|:---------|:--------|
| is_der1_zero_local(::typeof(f), x) = false | $\frac{\partial f}{\partial x} \neq 0$ for given $x$ |
| is_der2_zero_local(::typeof(f), x) = false | $\frac{\partial^2 f}{\partial x^2} \neq 0$ for given $x$ |

These fall back to

is_der1_zero_local(f::F, x) where {F} = is_der1_zero_global(f)
is_der2_zero_local(f::F, x) where {F} = is_der2_zero_global(f)

Methods you have to implement for 2-to-1 operators

| Function | Meaning |
|:---------|:--------|
| is_der1_arg1_zero_global(::typeof(f)) = false | $\frac{\partial f}{\partial x} \neq 0$ for some $x,y$ |
| is_der2_arg1_zero_global(::typeof(f)) = false | $\frac{\partial^2 f}{\partial x^2} \neq 0$ for some $x,y$ |
| is_der1_arg2_zero_global(::typeof(f)) = false | $\frac{\partial f}{\partial y} \neq 0$ for some $x,y$ |
| is_der2_arg2_zero_global(::typeof(f)) = false | $\frac{\partial^2 f}{\partial y^2} \neq 0$ for some $x,y$ |
| is_der_cross_zero_global(::typeof(f)) = false | $\frac{\partial^2 f}{\partial x \partial y} \neq 0$ for some $x,y$ |
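As a worked illustration of the 2-to-1 classification (our own derivation, not copied from SCT's source), consider multiplication, $f(x, y) = x \cdot y$. Each partial derivative is checked for whether it can be non-zero for some inputs:

```latex
\begin{align*}
\frac{\partial f}{\partial x} &= y \neq 0 \text{ for some } y
  &&\Rightarrow \texttt{is\_der1\_arg1\_zero\_global} = \texttt{false} \\
\frac{\partial^2 f}{\partial x^2} &= 0 \text{ everywhere}
  &&\Rightarrow \texttt{is\_der2\_arg1\_zero\_global} = \texttt{true} \\
\frac{\partial f}{\partial y} &= x \neq 0 \text{ for some } x
  &&\Rightarrow \texttt{is\_der1\_arg2\_zero\_global} = \texttt{false} \\
\frac{\partial^2 f}{\partial y^2} &= 0 \text{ everywhere}
  &&\Rightarrow \texttt{is\_der2\_arg2\_zero\_global} = \texttt{true} \\
\frac{\partial^2 f}{\partial x \partial y} &= 1 \neq 0
  &&\Rightarrow \texttt{is\_der\_cross\_zero\_global} = \texttt{false}
\end{align*}
```

Only the cross term contributes second-order information here, which is why the cross derivative needs its own classification method.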

Optionally, to increase the sparsity of TracerLocalSparsityDetector, you can additionally implement

| Function | Meaning |
|:---------|:--------|
| is_der1_arg1_zero_local(::typeof(f), x, y) = false | $\frac{\partial f}{\partial x} \neq 0$ for given $x,y$ |
| is_der2_arg1_zero_local(::typeof(f), x, y) = false | $\frac{\partial^2 f}{\partial x^2} \neq 0$ for given $x,y$ |
| is_der1_arg2_zero_local(::typeof(f), x, y) = false | $\frac{\partial f}{\partial y} \neq 0$ for given $x,y$ |
| is_der2_arg2_zero_local(::typeof(f), x, y) = false | $\frac{\partial^2 f}{\partial y^2} \neq 0$ for given $x,y$ |
| is_der_cross_zero_local(::typeof(f), x, y) = false | $\frac{\partial^2 f}{\partial x \partial y} \neq 0$ for given $x,y$ |

These fall back to

is_der1_arg1_zero_local(f::F, x, y) where {F} = is_der1_arg1_zero_global(f)
is_der2_arg1_zero_local(f::F, x, y) where {F} = is_der2_arg1_zero_global(f)
is_der1_arg2_zero_local(f::F, x, y) where {F} = is_der1_arg2_zero_global(f)
is_der2_arg2_zero_local(f::F, x, y) where {F} = is_der2_arg2_zero_global(f)
is_der_cross_zero_local(f::F, x, y) where {F} = is_der_cross_zero_global(f)

Methods you have to implement for 1-to-2 operators

| Function | Meaning |
|:---------|:--------|
| is_der1_out1_zero_global(::typeof(f)) = false | $\frac{\partial f_1}{\partial x} \neq 0$ for some $x$ |
| is_der2_out1_zero_global(::typeof(f)) = false | $\frac{\partial^2 f_1}{\partial x^2} \neq 0$ for some $x$ |
| is_der1_out2_zero_global(::typeof(f)) = false | $\frac{\partial f_2}{\partial x} \neq 0$ for some $x$ |
| is_der2_out2_zero_global(::typeof(f)) = false | $\frac{\partial^2 f_2}{\partial x^2} \neq 0$ for some $x$ |
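As a worked illustration for a 1-to-2 operator (our own derivation, not copied from SCT's source), take sincos, whose outputs are $f_1(x) = \sin x$ and $f_2(x) = \cos x$:

```latex
\begin{align*}
\frac{\partial f_1}{\partial x} &= \cos x, &
\frac{\partial^2 f_1}{\partial x^2} &= -\sin x, \\
\frac{\partial f_2}{\partial x} &= -\sin x, &
\frac{\partial^2 f_2}{\partial x^2} &= -\cos x.
\end{align*}
```

Each of these is non-zero for some $x$, so all four global classification methods for sincos would return false.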

Optionally, to increase the sparsity of TracerLocalSparsityDetector, you can additionally implement

| Function | Meaning |
|:---------|:--------|
| is_der1_out1_zero_local(::typeof(f), x) = false | $\frac{\partial f_1}{\partial x} \neq 0$ for given $x$ |
| is_der2_out1_zero_local(::typeof(f), x) = false | $\frac{\partial^2 f_1}{\partial x^2} \neq 0$ for given $x$ |
| is_der1_out2_zero_local(::typeof(f), x) = false | $\frac{\partial f_2}{\partial x} \neq 0$ for given $x$ |
| is_der2_out2_zero_local(::typeof(f), x) = false | $\frac{\partial^2 f_2}{\partial x^2} \neq 0$ for given $x$ |

These fall back to

is_der1_out1_zero_local(f::F, x) where {F} = is_der1_out1_zero_global(f)
is_der2_out1_zero_local(f::F, x) where {F} = is_der2_out1_zero_global(f)
is_der1_out2_zero_local(f::F, x) where {F} = is_der1_out2_zero_global(f)
is_der2_out2_zero_local(f::F, x) where {F} = is_der2_out2_zero_global(f)
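
All three operator classes rely on the same fallback pattern: a generic local method forwards to the global one unless a more specific method has been defined. The self-contained sketch below reproduces that dispatch behavior with hypothetical stand-in names (is_zero_global, is_zero_local, ramp); it does not use or modify SCT.

```julia
# Hypothetical stand-ins mirroring the fallback pattern; not part of SCT.
ramp(x) = max(zero(x), x)  # toy function with a piecewise derivative

# Global classification: is the derivative zero for all x?
is_zero_global(::typeof(sin)) = false
is_zero_global(::typeof(ramp)) = false

# Generic fallback: without extra information, reuse the global answer.
is_zero_local(f::F, x) where {F} = is_zero_global(f)

# More specific method: Julia's dispatch prefers it over the fallback.
is_zero_local(::typeof(ramp), x) = x < 0

is_zero_local(sin, 1.0)    # false, via the fallback
is_zero_local(ramp, -2.0)  # true, via the specialized method
is_zero_local(ramp, 3.0)   # false, via the specialized method
```

Because the fallback is always available, defining a local method is purely an optimization: leaving it out can only make the detected pattern more conservative.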

Overloading

Implementing the required classification methods does not by itself overload a function on our tracer types. For that step, SCT provides three functions that generate code via meta-programming:

  • 1-to-1: eval(SCT.generate_code_1_to_1(module_symbol, f))
  • 2-to-1: eval(SCT.generate_code_2_to_1(module_symbol, f))
  • 1-to-2: eval(SCT.generate_code_1_to_2(module_symbol, f))

You are required to call the function that matches your type of operator.

Code generation

We will take a look at the code generation mechanism in the example below.

Example

For more examples of how to overload methods, take a look at our package extensions. Let's look at the relu activation function from ext/SparseConnectivityTracerNNlibExt.jl, which is a 1-to-1 operator defined as $\text{relu}(x) = \max(0, x)$.

Step 1: Classification

The relu function has a non-zero first-order derivative $\frac{\partial f}{\partial x}=1$ for inputs $x>0$. The second derivative is zero everywhere. We therefore implement:

import SparseConnectivityTracer as SCT
using NNlib

SCT.is_der1_zero_global(::typeof(relu)) = false
SCT.is_der2_zero_global(::typeof(relu)) = true

SCT.is_der1_zero_local(::typeof(relu), x) = x < 0

import SparseConnectivityTracer

Note that we imported SCT to extend its operator classification methods on typeof(relu).

Step 2: Overloading

The relu function has not been overloaded on our tracer types yet. Let's call the code generation utilities from the "Overloading" section for this purpose:

eval(SCT.generate_code_1_to_1(:NNlib, relu))

The relu function is now ready to be called with SCT's tracer types.
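
As a quick sanity check, the sketch below compares the global and local detectors on an element-wise relu. It assumes that SparseConnectivityTracer and NNlib are both installed and that loading NNlib activates SCT's NNlib extension, so the relu overloads are already available. The helper name relu_all and the input vector are illustrative choices, not part of either package.

```julia
using SparseConnectivityTracer  # provides jacobian_sparsity and the detectors
using NNlib                     # provides relu

# Hypothetical helper: apply relu element-wise to a vector.
relu_all(v) = relu.(v)

x = [-1.0, 2.0]  # one negative and one positive input

# Global pattern: must hold for every possible input.
J_global = jacobian_sparsity(relu_all, x, TracerSparsityDetector())

# Local pattern: only has to hold at this particular x,
# so is_der1_zero_local can remove entries.
J_local = jacobian_sparsity(relu_all, x, TracerLocalSparsityDetector())
```

Given the classification from Step 1, the global pattern should contain both diagonal entries, while the local pattern should omit the entry belonging to the negative input, because is_der1_zero_local returns true there.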

What is the eval call doing?

Let's call generate_code_1_to_1 without wrapping it in eval:

SCT.generate_code_1_to_1(:NNlib, relu)
quote
    begin
        begin
            #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/gradient_tracer.jl:29 =#
            function NNlib.relu(t::(SparseConnectivityTracer).GradientTracer)
                #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/gradient_tracer.jl:29 =#
                #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/gradient_tracer.jl:30 =#
                return #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/gradient_tracer.jl:30 =# @noinline((SparseConnectivityTracer).gradient_tracer_1_to_1(t, false))
            end
        end
        begin
            #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/gradient_tracer.jl:43 =#
            function NNlib.relu(d::D) where {P, T <: (SparseConnectivityTracer).GradientTracer, D <: (SparseConnectivityTracer).Dual{P, T}}
                #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/gradient_tracer.jl:43 =#
                #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/gradient_tracer.jl:44 =#
                x = (SparseConnectivityTracer).primal(d)
                #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/gradient_tracer.jl:45 =#
                p_out = NNlib.relu(x)
                #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/gradient_tracer.jl:47 =#
                t = (SparseConnectivityTracer).tracer(d)
                #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/gradient_tracer.jl:48 =#
                is_der1_zero = (SparseConnectivityTracer).is_der1_zero_local(NNlib.relu, x)
                #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/gradient_tracer.jl:49 =#
                t_out = #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/gradient_tracer.jl:49 =# @noinline((SparseConnectivityTracer).gradient_tracer_1_to_1(t, is_der1_zero))
                #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/gradient_tracer.jl:50 =#
                return (SparseConnectivityTracer).Dual(p_out, t_out)
            end
        end
    end
    begin
        begin
            #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/hessian_tracer.jl:67 =#
            function NNlib.relu(t::(SparseConnectivityTracer).HessianTracer)
                #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/hessian_tracer.jl:67 =#
                #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/hessian_tracer.jl:68 =#
                return #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/hessian_tracer.jl:68 =# @noinline((SparseConnectivityTracer).hessian_tracer_1_to_1(t, false, true))
            end
        end
        begin
            #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/hessian_tracer.jl:81 =#
            function NNlib.relu(d::D) where {P, T <: (SparseConnectivityTracer).HessianTracer, D <: (SparseConnectivityTracer).Dual{P, T}}
                #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/hessian_tracer.jl:81 =#
                #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/hessian_tracer.jl:82 =#
                x = (SparseConnectivityTracer).primal(d)
                #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/hessian_tracer.jl:83 =#
                p_out = NNlib.relu(x)
                #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/hessian_tracer.jl:85 =#
                t = (SparseConnectivityTracer).tracer(d)
                #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/hessian_tracer.jl:86 =#
                is_der1_zero = (SparseConnectivityTracer).is_der1_zero_local(NNlib.relu, x)
                #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/hessian_tracer.jl:87 =#
                is_der2_zero = (SparseConnectivityTracer).is_der2_zero_local(NNlib.relu, x)
                #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/hessian_tracer.jl:88 =#
                t_out = #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/hessian_tracer.jl:88 =# @noinline((SparseConnectivityTracer).hessian_tracer_1_to_1(t, is_der1_zero, is_der2_zero))
                #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/hessian_tracer.jl:89 =#
                return (SparseConnectivityTracer).Dual(p_out, t_out)
            end
        end
    end
end

As you can see, this returns a quote, a type of expression containing our generated Julia code.

We have to use quotes: The code generation mechanism lives in SCT, but the generated code has to be evaluated in the package extension, not SCT. As you can see in the generated quote, we handle the necessary name-spacing for you.
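
The split between generating code in one module and evaluating it in another can be reproduced in a few lines of plain Julia. The sketch below is a toy model under stated assumptions: ToyOps stands in for a package such as NNlib, ToyGenerator stands in for SCT, and every name in it (double, Tracer, trace_1_to_1, generate_code) is hypothetical. It is not how SCT is implemented, only an illustration of the same mechanism.

```julia
# Toy stand-in for a package whose function we want to overload.
module ToyOps
double(x::Float64) = 2x
end

# Toy stand-in for the code generator. It only builds an expression.
module ToyGenerator
struct Tracer end                # placeholder for a tracer type
trace_1_to_1(t::Tracer) = t      # placeholder for the tracer update

function generate_code(M::Symbol, f::Function)
    fname = nameof(f)            # e.g. :double
    return quote
        # The method is attached to M.fname; the generator's own names are
        # interpolated as a module object, so they resolve wherever the
        # quote is evaluated.
        function $M.$fname(t::$ToyGenerator.Tracer)
            return $ToyGenerator.trace_1_to_1(t)
        end
    end
end
end

# Generating the code defines nothing yet; it only returns an Expr.
expr = ToyGenerator.generate_code(:ToyOps, ToyOps.double)

# Evaluating it in the calling module adds the new method to ToyOps.double.
eval(expr)

ToyOps.double(ToyGenerator.Tracer())  # dispatches to the generated method
```

Returning an expression instead of defining the method directly lets the caller decide where the definition is evaluated, which matches the reason given above for using quotes.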