Differentiable Subprograms

A differentiable subprogram is a program fragment that can participate in derivative propagation as a coherent unit. Instead of differentiating an entire application monolithically, AD systems decompose computation into smaller callable pieces with well-defined derivative behavior.

A simple example is:

def square(x):
    return x * x

Mathematically:

$$ f(x) = x^2. $$

Its derivative is:

$$ f'(x) = 2x. $$

An AD system can either inline the body into a larger graph or treat the function as a reusable differentiable component.
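The reusable-component view can be made concrete with forward mode over dual numbers. A dual value carries a primal part and a tangent part, and `square` runs unchanged on it. The `Dual` class below is a minimal illustrative sketch: it is hypothetical, supports only `*`, and is not any particular system's implementation.

```python
class Dual:
    """A number x + dx*eps with eps**2 == 0: a primal value plus a tangent."""

    def __init__(self, primal, tangent):
        self.primal = primal
        self.tangent = tangent

    def __mul__(self, other):
        # Product rule: (a + a' eps)(b + b' eps) = ab + (a'b + ab') eps
        return Dual(self.primal * other.primal,
                    self.tangent * other.primal + self.primal * other.tangent)


def square(x):
    return x * x


# Seeding the tangent with 1.0 yields f'(3) alongside f(3).
out = square(Dual(3.0, 1.0))
print(out.primal, out.tangent)  # 9.0 6.0
```

`square` never inspects the type of its argument. The same body therefore serves both ordinary evaluation and derivative evaluation.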

Functions as Differentiable Maps

A differentiable subprogram behaves like a map:

$$ f : X \to Y. $$

Automatic differentiation constructs associated derivative maps.

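Forward mode constructs a map that pushes an input tangent $\dot x$ forward:

$$ Df : (x, \dot x) \mapsto (f(x), \dot y), \qquad \dot y = J_f(x)\,\dot x, $$

where $J_f(x)$ is the Jacobian of $f$ at $x$.

Reverse mode constructs a backward map that pulls an output adjoint $\bar y$ back:

$$ B_f : (x, \bar y) \mapsto \bar x, \qquad \bar x = J_f(x)^\top \bar y. $$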

The function boundary becomes part of the differentiation structure.
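For a small example function, both maps can be written by hand. In the sketch below, $J_f(x)$ denotes the Jacobian of $f$ at $x$. The function `f` and the names `Df` and `B_f` are illustrative assumptions; a real AD system would derive both maps automatically. The final check uses the identity $\langle \bar y, J_f \dot x \rangle = \langle J_f^\top \bar y, \dot x \rangle$, which any correct pair of forward and backward maps must satisfy.

```python
import math

def f(x):
    """Example map f: R^2 -> R^2, f(x0, x1) = (x0 * x1, sin x0)."""
    x0, x1 = x
    return (x0 * x1, math.sin(x0))

def Df(x, x_dot):
    """Forward map: (x, x_dot) -> (f(x), J_f(x) @ x_dot)."""
    x0, x1 = x
    d0, d1 = x_dot
    y_dot = (x1 * d0 + x0 * d1, math.cos(x0) * d0)
    return f(x), y_dot

def B_f(x, y_bar):
    """Backward map: (x, y_bar) -> J_f(x)^T @ y_bar."""
    x0, x1 = x
    b0, b1 = y_bar
    return (x1 * b0 + math.cos(x0) * b1, x0 * b0)

# Consistency check: <y_bar, J x_dot> == <J^T y_bar, x_dot>.
x, x_dot, y_bar = (0.5, 2.0), (1.0, -3.0), (4.0, 0.25)
_, y_dot = Df(x, x_dot)
x_bar = B_f(x, y_bar)
lhs = sum(a * b for a, b in zip(y_bar, y_dot))
rhs = sum(a * b for a, b in zip(x_bar, x_dot))
assert math.isclose(lhs, rhs)
```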

Encapsulation

Subprograms encapsulate local computation.

Example:

def layer(x, w, b):
    return relu(matmul(x, w) + b)

A larger model may call this repeatedly:

h1 = layer(x, w1, b1)
h2 = layer(h1, w2, b2)
y  = layer(h2, w3, b3)

The AD system can:

| Strategy | Meaning |
|---|---|
| Inline | Expand the body at each call site |
| Reuse transformed version | Cache the derivative transform and reuse it at every call |
| Treat as primitive | Use a custom derivative rule |

Encapsulation allows modular differentiation.

Call Graphs

Programs with functions form a call graph.

Example:

main
 ├─> encoder
 │    ├─> attention
 │    └─> mlp
 └─> decoder
      ├─> attention
      └─> mlp

AD propagates derivatives through the same call structure.

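Forward mode propagates tangents in call order, from each callee's inputs to its outputs. Reverse mode propagates adjoints in the opposite order, from each callee's outputs back to its inputs along the return dependencies.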

A reverse-mode engine must know:

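| Item | Purpose |
|---|---|
| Inputs to each function | Evaluating its backward derivatives |
| Outputs from each function | Attaching the output adjoints |
| Saved intermediates | Evaluating local backward rules |
| Call ordering | Driving the reverse traversal |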
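All four items in the table appear in a scalar version of the three-layer example from the Encapsulation section. This sketch assumes scalar inputs and weights, and the parameter values are hypothetical. Each call records its input, weight, and pre-activation on a tape. The backward pass then replays the tape in reverse call order, applying the same layer backward rule once per call.

```python
def layer_fwd(x, w, b):
    """Scalar layer relu(x * w + b); returns the output and saved intermediates."""
    z = x * w + b
    return max(z, 0.0), (x, w, z)


def layer_bwd(saved, y_adj):
    """Backward rule for one call: output adjoint -> adjoints of (x, w, b)."""
    x, w, z = saved
    z_adj = y_adj if z > 0.0 else 0.0   # relu'(z), taken as 0 at z == 0
    return z_adj * w, z_adj * x, z_adj


params = [(2.0, 0.5), (1.5, -0.25), (3.0, 1.0)]  # hypothetical (w, b) per layer

# Forward: record each call's saved intermediates in call order.
tape = []
h = 1.0                                   # input x
for w, b in params:
    h, saved = layer_fwd(h, w, b)
    tape.append(saved)
y = h

# Reverse: visit calls newest-first, threading the adjoint from output to input.
adj = 1.0                                 # seed: dy/dy = 1
param_grads = []
for saved in reversed(tape):
    adj, w_adj, b_adj = layer_bwd(saved, adj)
    param_grads.append((w_adj, b_adj))
param_grads.reverse()                     # back to call order
x_adj = adj
```

Here `y` is 11.5 and `x_adj` is 9.0. Because every ReLU is active, `x_adj` is just the product of the three weights, 3.0 · 1.5 · 2.0.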

Function Composition

Subprograms compose naturally.

If

$$ y = g(f(x)), $$

then, writing $u = f(x)$,

$$ \frac{dy}{dx} = \frac{dy}{du}\,\frac{du}{dx} = g'(f(x))\, f'(x). $$

In code:

u = f(x)
y = g(u)

The dependency graph becomes:

x -> f -> u -> g -> y

AD applies the chain rule across function boundaries exactly as across primitive operations.
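As a worked instance, take $f(x) = x^2$ and $g(u) = \sin u$. Each function contributes only its local derivative, and the chain rule multiplies them across the call boundary:

$$ \frac{du}{dx} = 2x, \qquad \frac{dy}{du} = \cos u, \qquad \frac{dy}{dx} = \cos(x^2)\cdot 2x. $$

At $x = 1$ this gives $2\cos 1 \approx 1.0806$.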

Local Derivative Contracts

A differentiable subprogram exposes a local derivative contract.

For forward mode:

(primal_in, tangent_in)
    ->
(primal_out, tangent_out)

For reverse mode:

(primal_in, primal_out, output_adj)
    ->
input_adj

The outer system does not need to know the internal implementation if the contract is correct.

This abstraction enables custom derivative rules.
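The contracts above can be modeled as plain data. In this sketch, a hypothetical `DiffFn` record bundles a primal function with a forward rule and a reverse rule shaped like the two contracts. `compose` then builds the derivative rules of g ∘ f purely from those contracts, never looking at the bodies. Recomputing the intermediate `u` inside the reverse rule is one design choice; an engine could store it instead.

```python
import math
from typing import Callable, NamedTuple

class DiffFn(NamedTuple):
    """Hypothetical contract: a primal function plus its local derivative rules."""
    primal: Callable[[float], float]
    jvp: Callable[[float, float], tuple]          # (x, x_dot) -> (y, y_dot)
    vjp: Callable[[float, float, float], float]   # (x, y, y_bar) -> x_bar

square = DiffFn(
    primal=lambda x: x * x,
    jvp=lambda x, dx: (x * x, 2 * x * dx),
    vjp=lambda x, y, y_bar: 2 * x * y_bar,
)

sine = DiffFn(
    primal=math.sin,
    jvp=lambda x, dx: (math.sin(x), math.cos(x) * dx),
    vjp=lambda x, y, y_bar: math.cos(x) * y_bar,
)

def compose(g: DiffFn, f: DiffFn) -> DiffFn:
    """Build g . f using only the contracts of f and g."""
    def jvp(x, dx):
        u, du = f.jvp(x, dx)
        return g.jvp(u, du)

    def vjp(x, y, y_bar):
        u = f.primal(x)              # recompute the intermediate
        u_bar = g.vjp(u, y, y_bar)
        return f.vjp(x, u, u_bar)

    return DiffFn(primal=lambda x: g.primal(f.primal(x)), jvp=jvp, vjp=vjp)

h = compose(sine, square)            # h(x) = sin(x^2)
y, dy = h.jvp(1.0, 1.0)
x_bar = h.vjp(1.0, y, 1.0)
```

Both `dy` and `x_bar` equal $2\cos 1$, agreeing with the chain rule.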

Primitive Operations

Some subprograms are treated as primitives.

Example:

y = sin(x)

The AD system does not differentiate through the numerical approximation code that implements sine. Instead, it uses the known derivative rule:

$$ \frac{d}{dx}\sin(x) = \cos(x). $$

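Other primitives carry their own rules:

| Primitive | Derivative rule |
|---|---|
| exp(x) | exp(x) |
| log(x) | 1/x |
| matmul(a, b) | for $c = ab$: $\bar a = \bar c\, b^\top$, $\bar b = a^\top \bar c$ |
| conv(x, w) | adjoints are again convolution-type operations |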

Primitive differentiation hides implementation complexity.
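A primitive rule table can be as simple as a mapping from each primitive to its implementation and derivative. In this sketch, `RULES`, `jvp_primitive`, and `jvp_chain` are hypothetical names. The `math` implementations are called but never differentiated; only the rule in the table is consulted.

```python
import math

# Hypothetical rule table: each primitive maps to (implementation, derivative).
RULES = {
    "sin": (math.sin, math.cos),
    "exp": (math.exp, math.exp),
    "log": (math.log, lambda x: 1.0 / x),
}

def jvp_primitive(name, x, dx):
    """Apply a unary primitive and push the tangent dx through its rule."""
    fn, deriv = RULES[name]
    return fn(x), deriv(x) * dx

def jvp_chain(names, x, dx=1.0):
    """Forward mode over a straight-line chain of unary primitives."""
    for name in names:
        x, dx = jvp_primitive(name, x, dx)
    return x, dx

y, dy = jvp_chain(["exp", "log"], 2.0)   # log(exp(x)) == x, so dy should be ~1
```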

User-Defined Functions

User-defined functions built from differentiable primitives can usually be differentiated automatically.

Example:

from math import sin

def f(x):
    a = x * x
    b = sin(a)
    return b + 1

The system builds a dependency graph for the body and differentiates it mechanically.

Conceptually, the forward-mode transformed version becomes:

from math import sin, cos

def df(x, dx):
    # each primal statement is paired with its tangent statement
    a  = x * x
    da = 2 * x * dx

    b  = sin(a)
    db = cos(a) * da

    y  = b + 1
    dy = db

    return y, dy

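This transformation is systematic: every statement is paired with a tangent statement derived from the local derivative rule of the primitive it applies.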
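By analogy, a reverse-mode transform of the same f runs the primal statements forward, then visits them in reverse to accumulate adjoints. The name `f_vjp` is hypothetical, and this hand-written version only illustrates what such a transform would produce.

```python
from math import sin, cos

def f_vjp(x, y_adj):
    # Forward sweep: compute and keep the intermediates the backward sweep needs.
    a = x * x
    b = sin(a)
    y = b + 1

    # Backward sweep: visit the statements in reverse order.
    b_adj = y_adj            # from y = b + 1
    a_adj = cos(a) * b_adj   # from b = sin(a)
    x_adj = 2 * x * a_adj    # from a = x * x
    return y, x_adj
```

For scalar f, `f_vjp(x, 1.0)` returns the same derivative as the tangent output of `df(x, 1.0)`, namely $2x\cos(x^2)$.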

Closures

A closure captures external variables.

def make_scale(a):
    def scale(x):
        return a * x
    return scale

The inner function depends on both x and the captured variable a.

Mathematically:

$$ f(x; a) = ax. $$

The derivatives are:

$$ \frac{\partial f}{\partial x} = a, \qquad \frac{\partial f}{\partial a} = x. $$

AD systems must track captured values as dependencies.
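One way to track captured values is closure conversion. The closure is rewritten as a function that takes its environment as an explicit argument, so captured variables receive adjoints like any other input. The dictionary environment and the names `scale_converted` and `scale_vjp` are assumptions for illustration.

```python
def make_scale(a):
    def scale(x):
        return a * x
    return scale

# Closure conversion (illustrative): the captured value becomes an explicit input.
def scale_converted(env, x):
    return env["a"] * x

def scale_vjp(env, x, y_adj):
    """Return adjoints for the captured environment and for x."""
    env_adj = {"a": x * y_adj}      # df/da = x
    x_adj = env["a"] * y_adj        # df/dx = a
    return env_adj, x_adj

env = {"a": 3.0}
assert make_scale(3.0)(2.0) == scale_converted(env, 2.0) == 6.0
env_adj, x_adj = scale_vjp(env, 2.0, 1.0)
```

The result matches the partial derivatives above: `env_adj["a"]` is x = 2.0 and `x_adj` is a = 3.0.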

Recursive Functions

A recursive subprogram is defined in terms of calls to itself on smaller instances of the problem.

Example:

def f(x, n):
    if n == 0:
        return x
    return sin(f(x, n - 1))

For fixed n, the call tree expands into a finite dependency graph.

AD differentiates the expanded execution trace.

Reverse mode must preserve:

| Item | Reason |
|---|---|
| Call stack | Reverse traversal |
| Local variables | Local derivatives |
| Return structure | Adjoint propagation |

Recursive AD therefore resembles stack replay.
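For the example above, stack replay can be made explicit. This sketch keeps the same recursion but has each call frame push the value it passes to sin onto a tape. The backward pass pops frames last-call-first and multiplies in each local derivative. The helper names `f_forward` and `f_grad` are hypothetical.

```python
import math

def f(x, n):
    if n == 0:
        return x
    return math.sin(f(x, n - 1))

def f_forward(x, n, tape):
    """Same recursion as f, but each level saves the input it passes to sin."""
    if n == 0:
        return x
    u = f_forward(x, n - 1, tape)
    tape.append(u)               # saved local variable for this call frame
    return math.sin(u)

def f_grad(x, n):
    tape = []
    y = f_forward(x, n, tape)
    adj = 1.0
    while tape:                  # replay frames in reverse (last call first)
        u = tape.pop()
        adj = math.cos(u) * adj  # local rule: d/du sin(u) = cos(u)
    return y, adj

y, dy = f_grad(0.7, 3)
```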

Higher-Order Functions

A higher-order function takes functions as inputs or outputs.

Example:

def apply_twice(f, x):
    return f(f(x))

If f is differentiable:

$$ y = f(f(x)). $$

Then:

$$ \frac{dy}{dx} = f'(f(x))f'(x). $$

AD systems for functional languages often treat derivative transforms themselves as higher-order functions.
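In forward mode this can be seen directly. Represent a function by its tangent map acting on (primal, tangent) pairs; handing that map to the unchanged `apply_twice` then yields the tangent map of the composite. Here `sin_jvp` is a hypothetical hand-written transform standing in for what an AD system would generate.

```python
import math

def apply_twice(f, x):
    return f(f(x))

def sin_jvp(primal_and_tangent):
    """Forward-mode transform of sin, acting on (x, dx) pairs."""
    x, dx = primal_and_tangent
    return math.sin(x), math.cos(x) * dx

# Passing the transformed sin to the unchanged higher-order function gives the
# derivative of apply_twice(sin, .), i.e. f'(f(x)) * f'(x).
x = 0.4
y, dy = apply_twice(sin_jvp, (x, 1.0))
```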

Differentiation as Program Transformation

A differentiable subprogram may be transformed into a new subprogram.

Original:

f : X -> Y

Forward-mode transform:

Df : (X × TX) -> (Y × TY)

Reverse-mode transform:

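Rf : X -> (Y × (Y* -> X*))

where $TX$ and $TY$ are tangent spaces and $X^*$, $Y^*$ are cotangent spaces. The reverse-mode transform returns the primal output together with a pullback, a function that maps an output adjoint in $Y^*$ to an input adjoint in $X^*$.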

Thus AD itself becomes a compiler transform on callable units.
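In a language with closures, one way to realize the reverse-mode signature is to return the pullback as a closure over the values it needs. The name `R_sin` is hypothetical. The benefit shown here is that a single forward evaluation can serve several output adjoints.

```python
import math
from typing import Callable, Tuple

def R_sin(x: float) -> Tuple[float, Callable[[float], float]]:
    """Reverse-mode transform of sin: X -> (Y x (Y* -> X*))."""
    y = math.sin(x)

    def pullback(y_adj: float) -> float:
        # The closure captures x, which plays the role of a saved intermediate.
        return math.cos(x) * y_adj

    return y, pullback

# One forward evaluation, then several output adjoints pulled back through it.
y, pullback = R_sin(0.3)
x_adjs = [pullback(v) for v in (1.0, 2.0, -0.5)]
```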

Custom Gradient Rules

Sometimes the default derivative is inefficient or numerically unstable.

A user may define a custom backward rule.

Example:

def stable_logsumexp(x):
    ...

The backward rule can be supplied directly:

def backward(output_adj):
    ...

Advantages include:

| Benefit | Example |
|---|---|
| Better stability | log-sum-exp |
| Lower memory | Fused backward |
| Faster execution | Custom kernels |
| Implicit differentiation | Solvers |
| Approximate gradients | Quantization |

The subprogram becomes a primitive with user-defined differentiation semantics.
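The elided bodies above stay as in the source. What follows is a separate minimal sketch in plain Python, assuming a list of floats as input; the names `logsumexp_fwd` and `logsumexp_bwd` are hypothetical. The forward pass shifts by the maximum before exponentiating. The custom backward rule reuses the saved output y to compute the softmax directly, rather than differentiating through `max` and the shifted sum. With inputs of 1000.0, evaluating `math.exp(1000.0)` directly would raise `OverflowError`.

```python
import math

def logsumexp_fwd(xs):
    """Stable forward pass: shift by the max before exponentiating."""
    m = max(xs)
    y = m + math.log(sum(math.exp(x - m) for x in xs))
    return y, (xs, y)          # residuals saved for the backward rule

def logsumexp_bwd(saved, y_adj):
    """Custom rule: d/dx_i logsumexp(x) = softmax(x)_i = exp(x_i - y)."""
    xs, y = saved
    return [y_adj * math.exp(x - y) for x in xs]

y, saved = logsumexp_fwd([1000.0, 1000.0])
grad = logsumexp_bwd(saved, 1.0)
```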

Opaque External Functions

Some subprograms call external libraries.

y = cuda_kernel(x)

The AD system may not know the internal implementation.

Possible strategies:

| Strategy | Meaning |
|---|---|
| Treat as non-differentiable | Stop gradient |
| Provide custom rule | Manual backward |
| Trace internal ops | If supported |
| Use finite differences | Fallback approximation |

Large systems often rely heavily on custom derivative rules for external kernels.
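Two of these strategies are sketched below for the scalar case: a registered manual rule, and a central finite-difference fallback used when no rule exists. `CUSTOM_RULES`, `register_rule`, `derivative`, and `opaque_kernel` are hypothetical. `math.tanh` merely stands in for an external kernel whose internals AD cannot see. Finite differences are approximate, and with many inputs they need one pair of evaluations per input direction.

```python
import math

# Hypothetical registry of hand-written derivative rules for opaque functions.
CUSTOM_RULES = {}

def register_rule(fn, deriv):
    CUSTOM_RULES[fn] = deriv

def derivative(fn, x, h=1e-5):
    """Use a registered rule if one exists, else fall back to central differences."""
    if fn in CUSTOM_RULES:
        return CUSTOM_RULES[fn](x)
    return (fn(x + h) - fn(x - h)) / (2 * h)   # approximate, O(h^2) truncation error

def opaque_kernel(x):
    # Stands in for an external call whose internals AD cannot see.
    return math.tanh(x)

approx = derivative(opaque_kernel, 0.5)          # finite-difference fallback
register_rule(opaque_kernel, lambda x: 1 - math.tanh(x) ** 2)
exact = derivative(opaque_kernel, 0.5)           # manual backward rule
```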

Nested Differentiation

A differentiable subprogram may itself invoke AD.

Example:

def gradient_norm(f, x):
    g = grad(f)(x)
    return dot(g, g)

Now AD is applied to a program that already performs differentiation.

This creates nested derivative structures:

| Outer level | Inner level |
|---|---|
| Differentiate gradient_norm | Differentiate f |

Nested AD requires careful management of tangent and adjoint scopes.
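The outer level's task can be made explicit by writing $g(x) = \nabla f(x)$, so that gradient_norm returns $g(x)^\top g(x)$. Differentiating this product requires differentiating $g$ itself, so second derivatives of $f$ appear. Assuming $f$ is twice continuously differentiable, its Hessian is symmetric and

$$ \nabla_x \left( \nabla f(x)^\top \nabla f(x) \right) = 2\, \nabla^2 f(x)\, \nabla f(x). $$

For the scalar case $f(x) = \sin x$, gradient_norm is $\cos^2 x$. Its derivative, $-2 \sin x \cos x$, equals $2 f''(x) f'(x)$.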

Perturbation Confusion

Nested differentiation can accidentally mix derivative levels.

Example:

grad(lambda x: grad(f)(x))(x)

The inner and outer derivative computations must remain distinct.

Correct systems isolate derivative contexts so that:

| Level | Meaning |
|---|---|
| Inner tangent | Derivative of f |
| Outer tangent | Derivative of that derivative |

Without isolation, perturbations may interfere and produce incorrect higher-order derivatives.
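Tagging is one standard way to isolate derivative contexts. Each call to the derivative operator gets a fresh tag, and arithmetic only combines tangents that carry the same tag. The sketch below is a minimal forward-mode version with hypothetical names. It supports only `+` and `*` and does not handle tags that escape through returned closures. The classic test is $\frac{d}{dx}\left[x \cdot \frac{d}{dy}(x + y)\right]$. A naive untagged implementation returns 2 here instead of the correct 1, because both perturbations land on the same infinitesimal.

```python
import itertools

_fresh_tag = itertools.count()   # every derivative call gets a distinct tag

class Dual:
    """primal + tangent * eps_tag with eps_tag**2 == 0.

    primal and tangent may themselves be Duals, but only with smaller tags.
    """
    def __init__(self, primal, tangent, tag):
        self.primal, self.tangent, self.tag = primal, tangent, tag

    def __add__(self, other):
        return add(self, other)

    __radd__ = __add__

    def __mul__(self, other):
        return mul(self, other)

    __rmul__ = __mul__

def _tag(v):
    return v.tag if isinstance(v, Dual) else -1

def _split(v, tag):
    """View v as primal + tangent * eps_tag (tangent 0 if v lacks that tag)."""
    if isinstance(v, Dual) and v.tag == tag:
        return v.primal, v.tangent
    return v, 0.0

def add(a, b):
    tag = max(_tag(a), _tag(b))
    if tag < 0:
        return a + b                      # both are plain numbers
    ap, at = _split(a, tag)
    bp, bt = _split(b, tag)
    return Dual(add(ap, bp), add(at, bt), tag)

def mul(a, b):
    tag = max(_tag(a), _tag(b))
    if tag < 0:
        return a * b
    ap, at = _split(a, tag)
    bp, bt = _split(b, tag)
    return Dual(mul(ap, bp), add(mul(ap, bt), mul(at, bp)), tag)

def derivative(f, x):
    """d/dx f(x), using a fresh tag so nested calls cannot confuse tangents."""
    tag = next(_fresh_tag)
    out = f(Dual(x, 1.0, tag))
    return out.tangent if isinstance(out, Dual) and out.tag == tag else 0.0

# Classic test: d/dx [ x * (d/dy (x + y)) at y = 1 ] = d/dx [ x * 1 ] = 1.
def outer(x):
    return x * derivative(lambda y: x + y, 1.0)

print(derivative(outer, 5.0))                                        # 1.0
print(derivative(lambda x: derivative(lambda z: z * z * z, x), 2.0))  # 12.0
```

The second line is a genuine second derivative: $\frac{d^2}{dx^2}x^3 = 6x$, which is 12 at $x = 2$.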

Function Boundaries and Optimization

Subprogram boundaries influence optimization.

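Inlining may expose more fusion opportunities. With the body of layer visible, a compiler may fuse the matrix multiply, bias add, and ReLU below into a single kernel, and likewise their backward operations: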

y = relu(matmul(x, w) + b)

Modular boundaries may improve reuse and compilation caching.

Compiler-based systems often balance:

| Goal | Preference |
|---|---|
| Maximum optimization | Aggressive inlining |
| Fast compilation | Preserve modularity |
| Reusability | Cached differentiated kernels |
| Lower memory | Fused backward passes |

Differentiable subprograms are therefore both semantic and optimization units.

Interface Design

A minimal differentiable interface may look like:

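// Value is a placeholder for the system's scalar or tensor type (hypothetical).
type Value any

// Function is a single-input, single-output differentiable unit.
type Function interface {
    Forward(x Value) Value     // compute y = f(x), saving what Backward needs
    Backward(yAdj Value) Value // map the output adjoint to the input adjoint
}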

More realistic systems require:

| Requirement | Reason |
|---|---|
| Multiple inputs | Tensor programs |
| Multiple outputs | Structured models |
| Saved intermediates | Reverse mode |
| Device awareness | GPU execution |
| Shape metadata | Tensor validation |
| Batched evaluation | Vectorization |

Production AD systems therefore build sophisticated callable abstractions around these basic ideas.
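A step beyond the minimal interface is a unit with two inputs, two outputs, and saved intermediates. This sketch uses polar-to-Cartesian conversion as an example; the class name and method names are hypothetical. The backward method receives one adjoint per output and returns one adjoint per input, computing $J^\top v$ for the Jacobian $J$ of $(r, \theta) \mapsto (r\cos\theta, r\sin\theta)$.

```python
import math

class PolarToCartesian:
    """Hypothetical two-input, two-output differentiable unit with saved state."""

    def forward(self, r, theta):
        c, s = math.cos(theta), math.sin(theta)
        self.saved = (r, c, s)              # intermediates kept for backward
        return r * c, r * s

    def backward(self, x_adj, y_adj):
        """Map adjoints of both outputs to adjoints of both inputs (J^T v)."""
        r, c, s = self.saved
        r_adj = x_adj * c + y_adj * s
        theta_adj = -x_adj * r * s + y_adj * r * c
        return r_adj, theta_adj

fn = PolarToCartesian()
x, y = fn.forward(2.0, 0.0)
r_adj, theta_adj = fn.backward(0.0, 1.0)
```

Storing the saved state on the object keeps the sketch short, but it allows only one pending call at a time. A system that reuses the same unit at several call sites would keep a separate context per call.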

Core Idea

A differentiable subprogram is a callable computation unit with well-defined derivative behavior. Automatic differentiation propagates through function boundaries exactly as through primitive operations: by composing local derivative rules according to the dependency structure.

Subprograms provide modularity, reuse, abstraction, and optimization boundaries. They also enable custom gradients, higher-order differentiation, and compiler-level differentiation transforms.