Adding Overloads
The developer documentation might refer to internals which can change without warning in a future release of SparseConnectivityTracer. Only functionality that is exported or part of the user documentation adheres to semantic versioning.
Having read our guide "How SparseConnectivityTracer works", you might want to add your own methods on GradientTracer, HessianTracer and Dual to improve the performance of your functions or to work around some of SCT's limitations.
Avoid hand-written overloads
If you want to overload a Function that takes Real arguments, we strongly discourage you from manually adding methods to your function that use our internal tracer types.
Instead, use the same code generation mechanisms that we use. This page of the documentation shows you how.
The easiest way to add overloads is to copy one of our package extensions, e.g. our NNlib extension, and to modify it. Please upstream your additions by opening a pull request! We will help you get your feature merged.
Operator classification
SCT currently supports three types of functions:
- 1-to-1: operators with one input and one output
- 2-to-1: operators with two inputs and one output
- 1-to-2: operators with one input and two outputs
Depending on the type of function you're dealing with, you will have to specify the way in which your function is differentiable:
In | Out | Examples | Methods you need to implement |
---|---|---|---|
1 | 1 | sin , cos , abs | is_der1_zero_global , is_der2_zero_global |
2 | 1 | + , * , > , isequal | is_der1_arg1_zero_global , is_der2_arg1_zero_global , is_der1_arg2_zero_global , is_der2_arg2_zero_global , is_der_cross_zero_global |
1 | 2 | sincos | is_der1_out1_zero_global , is_der2_out1_zero_global , is_der1_out2_zero_global , is_der2_out2_zero_global |
Methods you have to implement for 1-to-1 operators
Function | Meaning |
---|---|
is_der1_zero_global(::typeof(f)) = false | $\frac{\partial f}{\partial x} \neq 0$ for some $x$ |
is_der2_zero_global(::typeof(f)) = false | $\frac{\partial^2 f}{\partial x^2} \neq 0$ for some $x$ |
Optionally, to increase the sparsity of TracerLocalSparsityDetector, you can additionally implement
Function | Meaning |
---|---|
is_der1_zero_local(::typeof(f), x) = false | $\frac{\partial f}{\partial x} \neq 0$ for given $x$ |
is_der2_zero_local(::typeof(f), x) = false | $\frac{\partial^2 f}{\partial x^2} \neq 0$ for given $x$ |
These fall back to
is_der1_zero_local(f::F, x) where {F} = is_der1_zero_global(f)
is_der2_zero_local(f::F, x) where {F} = is_der2_zero_global(f)
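As a minimal sketch of the 1-to-1 classification and its fallbacks, consider a piecewise-linear function. The name leaky is hypothetical and exists only for this illustration; the classification methods are the ones listed in the tables above.

```julia
import SparseConnectivityTracer as SCT

# Hypothetical 1-to-1 operator, defined here only for illustration.
leaky(x) = x > 0 ? x : 0.1 * x

# The slope is either 1 or 0.1, so the first derivative is non-zero for some x.
SCT.is_der1_zero_global(::typeof(leaky)) = false
# The function is piecewise linear, so the second derivative is zero everywhere.
SCT.is_der2_zero_global(::typeof(leaky)) = true

# No local methods are defined, so the local queries
# forward to the global classification via the fallbacks.
der1_local = SCT.is_der1_zero_local(leaky, -2.0)
der2_local = SCT.is_der2_zero_local(leaky, 3.0)
```

Since the local answers never differ from the global ones for this function, there is nothing to gain from implementing the optional local methods here.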
Methods you have to implement for 2-to-1 operators
Function | Meaning |
---|---|
is_der1_arg1_zero_global(::typeof(f)) = false | $\frac{\partial f}{\partial x} \neq 0$ for some $x,y$ |
is_der2_arg1_zero_global(::typeof(f)) = false | $\frac{\partial^2 f}{\partial x^2} \neq 0$ for some $x,y$ |
is_der1_arg2_zero_global(::typeof(f)) = false | $\frac{\partial f}{\partial y} \neq 0$ for some $x,y$ |
is_der2_arg2_zero_global(::typeof(f)) = false | $\frac{\partial^2 f}{\partial y^2} \neq 0$ for some $x,y$ |
is_der_cross_zero_global(::typeof(f)) = false | $\frac{\partial^2 f}{\partial x \partial y} \neq 0$ for some $x,y$ |
Optionally, to increase the sparsity of TracerLocalSparsityDetector, you can additionally implement
Function | Meaning |
---|---|
is_der1_arg1_zero_local(::typeof(f), x, y) = false | $\frac{\partial f}{\partial x} \neq 0$ for given $x,y$ |
is_der2_arg1_zero_local(::typeof(f), x, y) = false | $\frac{\partial^2 f}{\partial x^2} \neq 0$ for given $x,y$ |
is_der1_arg2_zero_local(::typeof(f), x, y) = false | $\frac{\partial f}{\partial y} \neq 0$ for given $x,y$ |
is_der2_arg2_zero_local(::typeof(f), x, y) = false | $\frac{\partial^2 f}{\partial y^2} \neq 0$ for given $x,y$ |
is_der_cross_zero_local(::typeof(f), x, y) = false | $\frac{\partial^2 f}{\partial x \partial y} \neq 0$ for given $x,y$ |
These fall back to
is_der1_arg1_zero_local(f::F, x, y) where {F} = is_der1_arg1_zero_global(f)
is_der2_arg1_zero_local(f::F, x, y) where {F} = is_der2_arg1_zero_global(f)
is_der1_arg2_zero_local(f::F, x, y) where {F} = is_der1_arg2_zero_global(f)
is_der2_arg2_zero_local(f::F, x, y) where {F} = is_der2_arg2_zero_global(f)
is_der_cross_zero_local(f::F, x, y) where {F} = is_der_cross_zero_global(f)
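The 2-to-1 classification can be sketched for a simple product. The function weighted_product is hypothetical and only serves as an illustration; with $f(x, y) = 2xy$ we have $\frac{\partial f}{\partial x} = 2y$, $\frac{\partial f}{\partial y} = 2x$, $\frac{\partial^2 f}{\partial x \partial y} = 2$, and both pure second derivatives are zero.

```julia
import SparseConnectivityTracer as SCT

# Hypothetical 2-to-1 operator, defined here only for illustration.
weighted_product(x, y) = 2 * x * y

# Global classification: is the derivative zero for all inputs?
SCT.is_der1_arg1_zero_global(::typeof(weighted_product)) = false  # ∂f/∂x = 2y
SCT.is_der2_arg1_zero_global(::typeof(weighted_product)) = true   # ∂²f/∂x² = 0
SCT.is_der1_arg2_zero_global(::typeof(weighted_product)) = false  # ∂f/∂y = 2x
SCT.is_der2_arg2_zero_global(::typeof(weighted_product)) = true   # ∂²f/∂y² = 0
SCT.is_der_cross_zero_global(::typeof(weighted_product)) = false  # ∂²f/∂x∂y = 2

# Optional local classification: the first derivatives vanish
# when the other argument is exactly zero.
SCT.is_der1_arg1_zero_local(::typeof(weighted_product), x, y) = iszero(y)
SCT.is_der1_arg2_zero_local(::typeof(weighted_product), x, y) = iszero(x)
```

The remaining local queries are not implemented and therefore use the fallbacks to the global methods.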
Methods you have to implement for 1-to-2 operators
Function | Meaning |
---|---|
is_der1_out1_zero_global(::typeof(f)) = false | $\frac{\partial f_1}{\partial x} \neq 0$ for some $x$ |
is_der2_out1_zero_global(::typeof(f)) = false | $\frac{\partial^2 f_1}{\partial x^2} \neq 0$ for some $x$ |
is_der1_out2_zero_global(::typeof(f)) = false | $\frac{\partial f_2}{\partial x} \neq 0$ for some $x$ |
is_der2_out2_zero_global(::typeof(f)) = false | $\frac{\partial^2 f_2}{\partial x^2} \neq 0$ for some $x$ |
Optionally, to increase the sparsity of TracerLocalSparsityDetector, you can additionally implement
Function | Meaning |
---|---|
is_der1_out1_zero_local(::typeof(f), x) = false | $\frac{\partial f_1}{\partial x} \neq 0$ for given $x$ |
is_der2_out1_zero_local(::typeof(f), x) = false | $\frac{\partial^2 f_1}{\partial x^2} \neq 0$ for given $x$ |
is_der1_out2_zero_local(::typeof(f), x) = false | $\frac{\partial f_2}{\partial x} \neq 0$ for given $x$ |
is_der2_out2_zero_local(::typeof(f), x) = false | $\frac{\partial^2 f_2}{\partial x^2} \neq 0$ for given $x$ |
These fall back to
is_der1_out1_zero_local(f::F, x) where {F} = is_der1_out1_zero_global(f)
is_der2_out1_zero_local(f::F, x) where {F} = is_der2_out1_zero_global(f)
is_der1_out2_zero_local(f::F, x) where {F} = is_der1_out2_zero_global(f)
is_der2_out2_zero_local(f::F, x) where {F} = is_der2_out2_zero_global(f)
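For a 1-to-2 operator, each output is classified separately. The sketch below uses a hypothetical function square_and_double with outputs $f_1(x) = x^2$ and $f_2(x) = 2x$; the function name is an assumption made for this illustration.

```julia
import SparseConnectivityTracer as SCT

# Hypothetical 1-to-2 operator, defined here only for illustration.
square_and_double(x) = (x^2, 2 * x)

# First output f₁(x) = x²
SCT.is_der1_out1_zero_global(::typeof(square_and_double)) = false  # ∂f₁/∂x = 2x
SCT.is_der2_out1_zero_global(::typeof(square_and_double)) = false  # ∂²f₁/∂x² = 2
# Second output f₂(x) = 2x
SCT.is_der1_out2_zero_global(::typeof(square_and_double)) = false  # ∂f₂/∂x = 2
SCT.is_der2_out2_zero_global(::typeof(square_and_double)) = true   # ∂²f₂/∂x² = 0

# Optional local refinement: the slope of x² vanishes exactly at x = 0.
SCT.is_der1_out1_zero_local(::typeof(square_and_double), x) = iszero(x)
```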
Overloading
After implementing the required classification methods for a function, the function has not been overloaded on our tracer types yet. SCT provides three functions that generate code via meta-programming:
- 1-to-1:
eval(SCT.generate_code_1_to_1(module_symbol, f))
- 2-to-1:
eval(SCT.generate_code_2_to_1(module_symbol, f))
- 1-to-2:
eval(SCT.generate_code_1_to_2(module_symbol, f))
You are required to call the function that matches your type of operator.
We will take a look at the code generation mechanism in the example below.
Example
For some examples on how to overload methods, take a look at our package extensions. Let's look at the relu activation function from ext/SparseConnectivityTracerNNlibExt.jl, which is a 1-to-1 operator defined as $\text{relu}(x) = \text{max}(0, x)$.
Step 1: Classification
The relu function has a non-zero first-order derivative $\frac{\partial f}{\partial x}=1$ for inputs $x>0$. The second derivative is zero everywhere. We therefore implement:
import SparseConnectivityTracer as SCT
using NNlib
SCT.is_der1_zero_global(::typeof(relu)) = false
SCT.is_der2_zero_global(::typeof(relu)) = true
SCT.is_der1_zero_local(::typeof(relu), x) = x < 0
Note that we imported SCT to extend its operator classification methods on typeof(relu).
Step 2: Overloading
The relu function has not been overloaded on our tracer types yet. Let's call the code generation utilities from the "Overloading" section for this purpose:
eval(SCT.generate_code_1_to_1(:NNlib, relu))
The relu function is now ready to be called with SCT's tracer types.
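To try both steps end to end without depending on NNlib, the following sketch applies them to a function in a hypothetical module MyActivations. The module, the function leaky and the helper f are assumptions made for this illustration. The sketch also assumes that jacobian_sparsity and TracerSparsityDetector are available from SCT as described in the user documentation.

```julia
import SparseConnectivityTracer as SCT

# Hypothetical module and function, defined here only for illustration.
module MyActivations
    # Branches on the value of x. The generated overloads bypass this body
    # for tracer inputs and rely on the classification below instead.
    leaky(x::Real) = x > 0 ? x : 0.1 * x
end

# Step 1: classification. The slope is 1 or 0.1, the curvature is zero.
SCT.is_der1_zero_global(::typeof(MyActivations.leaky)) = false
SCT.is_der2_zero_global(::typeof(MyActivations.leaky)) = true

# Step 2: generate the overloads and evaluate them in the current module,
# where the name MyActivations is defined.
eval(SCT.generate_code_1_to_1(:MyActivations, MyActivations.leaky))

# Detect the Jacobian sparsity pattern of an element-wise application.
f(x) = MyActivations.leaky.(x)
J = SCT.jacobian_sparsity(f, rand(3), SCT.TracerSparsityDetector())
```

Because leaky is applied element-wise and its first derivative is classified as non-zero, each output should depend only on the matching input, which corresponds to a diagonal pattern.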
What is the eval call doing?
Let's call generate_code_1_to_1 without wrapping it in eval:
SCT.generate_code_1_to_1(:NNlib, relu)
quote
begin
begin
#= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/gradient_tracer.jl:29 =#
function NNlib.relu(t::(SparseConnectivityTracer).GradientTracer)
#= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/gradient_tracer.jl:29 =#
#= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/gradient_tracer.jl:30 =#
return #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/gradient_tracer.jl:30 =# @noinline((SparseConnectivityTracer).gradient_tracer_1_to_1(t, false))
end
end
begin
#= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/gradient_tracer.jl:43 =#
function NNlib.relu(d::D) where {P, T <: (SparseConnectivityTracer).GradientTracer, D <: (SparseConnectivityTracer).Dual{P, T}}
#= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/gradient_tracer.jl:43 =#
#= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/gradient_tracer.jl:44 =#
x = (SparseConnectivityTracer).primal(d)
#= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/gradient_tracer.jl:45 =#
p_out = NNlib.relu(x)
#= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/gradient_tracer.jl:47 =#
t = (SparseConnectivityTracer).tracer(d)
#= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/gradient_tracer.jl:48 =#
is_der1_zero = (SparseConnectivityTracer).is_der1_zero_local(NNlib.relu, x)
#= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/gradient_tracer.jl:49 =#
t_out = #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/gradient_tracer.jl:49 =# @noinline((SparseConnectivityTracer).gradient_tracer_1_to_1(t, is_der1_zero))
#= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/gradient_tracer.jl:50 =#
return (SparseConnectivityTracer).Dual(p_out, t_out)
end
end
end
begin
begin
#= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/hessian_tracer.jl:67 =#
function NNlib.relu(t::(SparseConnectivityTracer).HessianTracer)
#= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/hessian_tracer.jl:67 =#
#= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/hessian_tracer.jl:68 =#
return #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/hessian_tracer.jl:68 =# @noinline((SparseConnectivityTracer).hessian_tracer_1_to_1(t, false, true))
end
end
begin
#= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/hessian_tracer.jl:81 =#
function NNlib.relu(d::D) where {P, T <: (SparseConnectivityTracer).HessianTracer, D <: (SparseConnectivityTracer).Dual{P, T}}
#= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/hessian_tracer.jl:81 =#
#= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/hessian_tracer.jl:82 =#
x = (SparseConnectivityTracer).primal(d)
#= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/hessian_tracer.jl:83 =#
p_out = NNlib.relu(x)
#= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/hessian_tracer.jl:85 =#
t = (SparseConnectivityTracer).tracer(d)
#= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/hessian_tracer.jl:86 =#
is_der1_zero = (SparseConnectivityTracer).is_der1_zero_local(NNlib.relu, x)
#= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/hessian_tracer.jl:87 =#
is_der2_zero = (SparseConnectivityTracer).is_der2_zero_local(NNlib.relu, x)
#= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/hessian_tracer.jl:88 =#
t_out = #= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/hessian_tracer.jl:88 =# @noinline((SparseConnectivityTracer).hessian_tracer_1_to_1(t, is_der1_zero, is_der2_zero))
#= /home/runner/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overloads/hessian_tracer.jl:89 =#
return (SparseConnectivityTracer).Dual(p_out, t_out)
end
end
end
end
As you can see, this returns a quote, a type of expression containing our generated Julia code.
We have to use quotes: The code generation mechanism lives in SCT, but the generated code has to be evaluated in the package extension, not SCT. As you can see in the generated quote, we handle the necessary name-spacing for you.
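The pattern of generating a quote in one module and evaluating it in another can be sketched in plain Julia without SCT. The modules Shapes and CodeGen and the function generate_string_method are hypothetical stand-ins, not part of SCT; they only mimic the shape of the mechanism described above.

```julia
# Hypothetical module that owns the function we want to extend.
module Shapes
    describe(x::Int) = "integer"
end

# Hypothetical code generator living in a different module.
module CodeGen
    # Returns an expression; nothing is defined until the caller evaluates it.
    function generate_string_method(M::Symbol, f::Function)
        fname = nameof(f)
        return quote
            # The method is qualified with the module name passed by the caller.
            function $M.$fname(x::String)
                return "string"
            end
        end
    end
end

ex = CodeGen.generate_string_method(:Shapes, Shapes.describe)
had_method_before = hasmethod(Shapes.describe, Tuple{String})

# Evaluate in the calling module, where the name Shapes is defined.
eval(ex)
had_method_after = hasmethod(Shapes.describe, Tuple{String})
```

The generator never needs to know where Shapes is loaded. It only needs the module's name as a symbol, and the caller decides where the resulting methods are defined.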