Differentiable Programming

Differentiable programming treats differentiation as a general programming-language feature. A program can contain numerical kernels, control flow, data structures, solvers, simulations, and model code. If the program denotes a differentiable computation, the system should be able to compute derivatives of that program.

This is broader than neural network training. A differentiable program may include a physics simulator, a renderer, a database query, an optimizer, or a probabilistic model. The central idea is that gradients should be available across the whole program, not only inside a fixed tensor graph.

From Automatic Differentiation to Differentiable Programming

Automatic differentiation began as a technique for computing derivatives of numerical programs. Differentiable programming extends that idea into a design principle for languages and systems.

The core operator is a program transformation:

$$ \mathrm{grad} : (X \rightarrow \mathbb{R}) \rightarrow (X \rightarrow X) $$

Given a scalar-output function, grad produces another function that returns its gradient.

In code, this appears as:

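```python
def loss(params, batch):
    pred = model(params, batch.x)   # model: any ordinary user-defined function
    return mse(pred, batch.y)       # mse: mean squared error

g = grad(loss)(params, batch)       # grad: the transformation defined above
```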

The important point is not the syntax. The important point is that loss can be an ordinary program. It can call functions, branch, loop, allocate intermediate values, and invoke libraries.
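To make the grad transformation concrete, here is a minimal reverse-mode sketch in plain Python that uses operator overloading. `Var`, `backward`, `grad`, and the toy `DATA` set are hypothetical names chosen for illustration, not any library's API. Real systems handle tensors and many more primitives.

```python
class Var:
    """A scalar that remembers how it was computed (a reverse-mode AD node)."""

    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents  # tuples of (input Var, local partial derivative)
        self.grad = 0.0

    @staticmethod
    def lift(x):
        return x if isinstance(x, Var) else Var(x)

    def __add__(self, other):
        other = Var.lift(other)
        return Var(self.value + other.value, ((self, 1.0), (other, 1.0)))

    __radd__ = __add__

    def __sub__(self, other):
        other = Var.lift(other)
        return Var(self.value - other.value, ((self, 1.0), (other, -1.0)))

    def __mul__(self, other):
        other = Var.lift(other)
        return Var(self.value * other.value,
                   ((self, other.value), (other, self.value)))

    __rmul__ = __mul__


def backward(out):
    """Propagate adjoints from out to every node it depends on."""
    order, seen = [], set()

    def visit(v):  # depth-first topological sort: inputs before outputs
        if id(v) not in seen:
            seen.add(id(v))
            for parent, _ in v.parents:
                visit(parent)
            order.append(v)

    visit(out)
    out.grad = 1.0
    for v in reversed(order):  # outputs first, so each adjoint is complete
        for parent, local in v.parents:
            parent.grad += local * v.grad


def grad(f):
    """Transform a scalar-output function into its gradient function."""
    def grad_f(xs):
        vs = [Var(x) for x in xs]
        backward(f(vs))
        return [v.grad for v in vs]
    return grad_f


# Hypothetical toy data: (input, target) pairs.
DATA = [(0.0, 1.0), (1.0, 3.0), (2.0, 5.0)]


def loss(params):
    """An ordinary program: unpacking, a loop, and an accumulator."""
    w, b = params
    total = 0.0
    for x, y in DATA:
        err = w * x + b - y
        total = total + err * err
    return total


print(grad(loss)([0.0, 0.0]))  # [-26.0, -18.0]
```

The user code in `loss` never mentions derivatives; the gradient comes from recording each arithmetic step at run time and replaying the chain rule backward.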

Differentiation as a First-Class Transformation

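In a differentiable programming system, transformations such as grad, vjp, jvp, jacfwd, and jacrev are first-class operations: they take functions as input, return new functions, and compose. For example, a Hessian can be computed as jacfwd(jacrev(f)).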

| Transformation | Meaning |
|---|---|
| grad(f) | Gradient of a scalar-output function |
| jvp(f) | Jacobian-vector product |
| vjp(f) | Vector-Jacobian product |
| jacfwd(f) | Jacobian using forward mode |
| jacrev(f) | Jacobian using reverse mode |
| hessian(f) | Second-derivative matrix |
| value_and_grad(f) | Primal value and gradient together |

This style changes how numerical software is written. Instead of manually deriving update rules, the programmer writes the objective and asks the system for derivatives.
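The relationship between jvp and jacfwd can be sketched with dual numbers. Each forward pass carries one tangent vector, so a full Jacobian takes one pass per input. The `Dual` class, `jvp`, `jacfwd`, and the example `f` below are hypothetical minimal stand-ins, not a particular library's signatures.

```python
class Dual:
    """A primal value paired with a tangent (directional derivative)."""

    def __init__(self, primal, tangent=0.0):
        self.primal, self.tangent = primal, tangent

    @staticmethod
    def lift(x):
        return x if isinstance(x, Dual) else Dual(x)

    def __add__(self, o):
        o = Dual.lift(o)
        return Dual(self.primal + o.primal, self.tangent + o.tangent)

    __radd__ = __add__

    def __mul__(self, o):
        o = Dual.lift(o)
        return Dual(self.primal * o.primal,
                    self.primal * o.tangent + self.tangent * o.primal)

    __rmul__ = __mul__


def jvp(f, x, v):
    """Evaluate f(x) and the Jacobian-vector product J_f(x) v in one forward pass."""
    outs = f([Dual(xi, vi) for xi, vi in zip(x, v)])
    return [o.primal for o in outs], [o.tangent for o in outs]


def jacfwd(f, x):
    """Build the Jacobian column by column: one jvp per input basis vector."""
    n = len(x)
    cols = [jvp(f, x, [1.0 if j == i else 0.0 for j in range(n)])[1]
            for i in range(n)]
    return [[cols[j][i] for j in range(n)] for i in range(len(cols[0]))]


def f(x):
    return [x[0] * x[1], x[0] + x[1]]


print(jacfwd(f, [2.0, 3.0]))  # [[3.0, 2.0], [1.0, 1.0]]
```

By the same counting argument, a reverse-mode Jacobian (jacrev) needs one backward pass per output, which is why reverse mode is preferred for a scalar loss over many parameters.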

Programs Beyond Tensor Graphs

Early deep learning systems often represented models as static tensor graphs. A graph contained operations such as matrix multiplication, convolution, addition, and activation functions.

Differentiable programming relaxes this model. The differentiable object is the program itself.

For example:

```python
def f(x):
    y = 0.0
    for i in range(10):
        if x[i] > 0:
            y = y + x[i] * x[i]
        else:
            y = y - x[i]
    return y
```

This program contains a loop and a branch. A differentiable programming system differentiates the executed computation. The derivative depends on the path taken through the program.

This makes differentiation useful for real numerical code, where fixed graphs are often too restrictive.
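Running the program above on dual numbers shows the derivative following the executed path. The `Dual` class and `gradient` helper are a minimal, hypothetical forward-mode sketch. Comparisons inspect only the primal value, so branch decisions are never differentiated.

```python
class Dual:
    """Forward-mode value: primal plus tangent; comparisons use the primal only."""

    def __init__(self, primal, tangent=0.0):
        self.primal, self.tangent = primal, tangent

    @staticmethod
    def lift(x):
        return x if isinstance(x, Dual) else Dual(x)

    def __add__(self, o):
        o = Dual.lift(o)
        return Dual(self.primal + o.primal, self.tangent + o.tangent)

    __radd__ = __add__

    def __sub__(self, o):
        o = Dual.lift(o)
        return Dual(self.primal - o.primal, self.tangent - o.tangent)

    def __rsub__(self, o):
        o = Dual.lift(o)
        return Dual(o.primal - self.primal, o.tangent - self.tangent)

    def __mul__(self, o):
        o = Dual.lift(o)
        return Dual(self.primal * o.primal,
                    self.primal * o.tangent + self.tangent * o.primal)

    __rmul__ = __mul__

    def __gt__(self, o):
        # Branch decisions look only at the primal value.
        return self.primal > Dual.lift(o).primal


def f(x):  # the program from the text, unchanged
    y = 0.0
    for i in range(10):
        if x[i] > 0:
            y = y + x[i] * x[i]
        else:
            y = y - x[i]
    return y


def gradient(f, x):
    """One forward pass per input, seeding the tangent of x[i] with 1."""
    return [f([Dual(xj, 1.0 if j == i else 0.0) for j, xj in enumerate(x)]).tangent
            for i in range(len(x))]


x = [1.0, -2.0, 3.0, 0.0, -1.0, 2.0, 0.5, -0.5, 4.0, -3.0]
print(gradient(f, x))
# [2.0, -1.0, 6.0, -1.0, -1.0, 4.0, 1.0, -1.0, 8.0, -1.0]
```

Each entry is 2 * x[i] where the positive branch ran and -1 elsewhere. At x[3] == 0.0 the function has a kink (one-sided derivatives 0 and -1); the system reports -1 simply because that is the branch that executed.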

The Language Boundary

A central difficulty is deciding which parts of the language are differentiable.

Arithmetic over real-valued arrays is differentiable. Integer indexing, mutation, allocation, I/O, hashing, sorting, and discrete control decisions need more care.

| Construct | Differentiation issue |
|---|---|
| Floating-point arithmetic | Usually differentiable, except at singularities such as division by zero |
| Branching | Differentiates the selected branch |
| Loops | Differentiates the executed iterations |
| Integer indexing | The index choice itself is usually non-differentiable |
| Sorting | Sorted values are piecewise linear and non-differentiable at ties; the permutation is discrete |
| Mutation | Requires correct adjoint semantics |
| I/O | Usually outside the derivative computation |
| Randomness | Needs reparameterization or score-function estimators |
| Allocation | Affects runtime and memory, not usually the mathematical derivative |

A differentiable programming language must make these boundaries explicit. Silently returning a zero or arbitrary derivative is dangerous. At each boundary, a system should reject the differentiation, apply a documented subgradient convention, or require a custom rule.
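One way to make these boundaries explicit is a per-primitive rule table in which every entry is a derivative rule, a documented subgradient convention, or an explicit refusal. `DERIVATIVE_RULES` and `derivative` below are hypothetical, a sketch of the policy rather than any system's implementation.

```python
import math

# Hypothetical rule table: each primitive maps to its derivative rule,
# or to None when the system refuses to differentiate it.
DERIVATIVE_RULES = {
    "sin": math.cos,
    # relu is not differentiable at 0; this table *chooses* the subgradient 0 there.
    "relu": lambda x: 1.0 if x > 0 else 0.0,
    # abs: subgradient convention sign(x), with 0 chosen at x == 0.
    "abs": lambda x: 1.0 if x > 0 else (-1.0 if x < 0 else 0.0),
    # floor is piecewise constant: its derivative is 0 almost everywhere,
    # which would silently stop gradient flow, so it is rejected instead.
    "floor": None,
}


def derivative(op, x):
    rule = DERIVATIVE_RULES[op]
    if rule is None:
        raise TypeError(f"{op} is not differentiable; supply a custom rule")
    return rule(x)


print(derivative("relu", 0.0))  # 0.0, by convention
print(derivative("sin", 0.0))   # 1.0
```

Calling `derivative("floor", 1.5)` raises `TypeError`, turning a silent zero gradient into a visible error.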

Custom Derivative Rules

Not every useful operation should be differentiated by expanding its implementation. Some operations have better derivative formulas than their source code suggests.

For example, a numerically stable logsumexp implementation may contain branching and shifts:

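```python
import numpy as np

def logsumexp(x):
    x = np.asarray(x, dtype=float)
    m = np.max(x)  # shift by the maximum so exp never overflows
    return m + np.log(np.sum(np.exp(x - m)))
```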

Differentiating the implementation directly may work, but a custom derivative rule is often clearer and more stable.
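For logsumexp the local rule is known in closed form: the gradient is the softmax of the input, $\partial\,\mathrm{logsumexp}(x)/\partial x_i = \exp(x_i - \mathrm{logsumexp}(x))$. A custom rule can reuse the stable forward value instead of differentiating through the max and the shift. This is a plain-Python sketch; the function names are illustrative and it does not use any framework's custom-derivative API.

```python
import math


def logsumexp(xs):
    """Stable forward pass: shift by the max so exp never overflows."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def logsumexp_grad(xs):
    """Custom rule: d logsumexp / d x_i = softmax(x)_i = exp(x_i - logsumexp(x)).

    Every exponent here is <= 0, so the rule is as overflow-safe as the forward pass.
    """
    lse = logsumexp(xs)
    return [math.exp(x - lse) for x in xs]


xs = [1000.0, 1000.0]
print(logsumexp(xs))       # 1000 + log(2)
print(logsumexp_grad(xs))  # approximately [0.5, 0.5]
```

A naive `math.log(sum(math.exp(x) for x in xs))` fails on this input, because `math.exp(1000.0)` raises `OverflowError`.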

A custom rule gives the AD system a local derivative definition, which it uses instead of differentiating the operation's implementation. Common cases include:

| Operation | Reason for a custom derivative |
|---|---|
| logsumexp | Numerical stability |
| Matrix inverse | Avoid differentiating solver internals naively |
| Cholesky factorization | Structured derivative |
| ODE solver | Use adjoint method or sensitivity equation |
| Optimization solver | Use implicit differentiation |
| Sampling operation | Use reparameterization |

Custom derivatives are part of the language interface. They allow library authors to expose mathematically correct and efficient differentiation behavior.
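As a worked example of such a rule, the matrix inverse has a closed-form derivative, so an AD system never needs to trace the elimination steps inside a solver. Writing $Y = A^{-1}$ and differentiating the identity $AY = I$ gives

$$ dA\,Y + A\,dY = 0 \quad\Longrightarrow\quad dY = -Y\,(dA)\,Y $$

In reverse mode, with output adjoint $\bar{Y}$, the corresponding rule is

$$ \bar{A} = -\,Y^{\top}\,\bar{Y}\,Y^{\top} $$

which costs two matrix products once $Y$ has been computed in the forward pass.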

Differentiating Through Libraries

A real program uses libraries. A differentiable programming system must decide how derivative information crosses library boundaries.

There are several strategies.

| Strategy | Description |
|---|---|
| Operator overloading | Library functions execute on derivative-aware values |
| Tracing | Runtime records primitive operations |
| Source transformation | Compiler rewrites source code |
| Compiler IR differentiation | AD operates on a lowered intermediate representation |
| Custom primitive rules | Libraries expose derivative rules manually |

Each strategy has a different boundary. Operator overloading works well when all operations dispatch through overloaded types. Tracing works when the runtime can observe the operations. Source transformation works when source code is available and transformable. Compiler IR differentiation works when the program has been lowered into an analyzable representation.
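The tracing strategy can be sketched by running a function once on placeholder objects that record each primitive into a list, a tiny stand-in for an IR that a reverse pass could later walk backward. `Tracer` and `trace` are hypothetical names for this illustration only.

```python
class Tracer:
    """Stands in for a value while recording each primitive applied to it."""

    def __init__(self, graph, name):
        self.graph, self.name = graph, name

    def _emit(self, op, other):
        other_name = other.name if isinstance(other, Tracer) else repr(other)
        out = Tracer(self.graph, f"t{len(self.graph)}")
        self.graph.append((out.name, op, self.name, other_name))
        return out

    def __add__(self, other):
        return self._emit("add", other)

    def __mul__(self, other):
        return self._emit("mul", other)


def trace(f):
    """Run f once on a Tracer and return the recorded primitives and output name."""
    graph = []
    result = f(Tracer(graph, "x"))
    return graph, result.name


graph, out = trace(lambda x: x * x + 3.0)
for step in graph:
    print(step)
# ('t0', 'mul', 'x', 'x')
# ('t1', 'add', 't0', '3.0')
```

Because `Tracer` defines no comparison operators, a branch such as `if x > 0` inside `f` raises `TypeError` during tracing. That is the boundary described above: tracing sees only operations the runtime can observe, and systems differ in whether they specialize the trace to one path or offer structured control-flow primitives.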

Differentiable Programming and Compilation

Differentiable programming is closely connected to compiler design.

A compiler-based system can perform several transformations together:

  1. Normalize the original program.
  2. Differentiate the normalized representation.
  3. Optimize the primal and derivative code.
  4. Fuse kernels.
  5. Plan memory.
  6. Lower to CPU, GPU, TPU, or accelerator code.

This is important because naive differentiated programs are often inefficient. Reverse mode can store too many intermediates. Higher-order AD can duplicate computation. Tensor programs can produce many small kernels.

A practical differentiable compiler performs AD and optimization as a single pipeline.
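Memory planning is one place where this pipeline view pays off. Gradient checkpointing (rematerialization), a technique not named in the list above, stores only some intermediates and recomputes the rest during the reverse pass. The sketch below applies it to a scalar chain $x_{k+1} = \sin(x_k)$; `step`, `dstep`, and both gradient functions are illustrative names.

```python
import math


def step(x):
    return math.sin(x)


def dstep(x):
    return math.cos(x)


def grad_full_tape(x0, n):
    """Reverse pass that stores all n intermediates: O(n) memory."""
    xs = [x0]
    for _ in range(n):
        xs.append(step(xs[-1]))
    adj = 1.0
    for x in reversed(xs[:-1]):  # chain rule from the output back to x0
        adj *= dstep(x)
    return adj


def grad_checkpointed(x0, n, every):
    """Store only every `every`-th state; recompute each segment when reversing.

    Memory is about n / every checkpoints plus one segment of `every` states,
    at the price of roughly one extra forward pass.
    """
    checkpoints = {0: x0}
    x = x0
    for i in range(1, n + 1):
        x = step(x)
        if i % every == 0:
            checkpoints[i] = x
    adj = 1.0
    seg_end = n
    for start in sorted((k for k in checkpoints if k < n), reverse=True):
        xs = [checkpoints[start]]
        for _ in range(start + 1, seg_end):  # rematerialize x_{start+1} .. x_{seg_end-1}
            xs.append(step(xs[-1]))
        for xi in reversed(xs):
            adj *= dstep(xi)
        seg_end = start
    return adj


print(grad_full_tape(0.7, 10), grad_checkpointed(0.7, 10, 3))  # equal up to rounding
```

A compiler that sees both the primal and derivative code can choose where to place such checkpoints automatically instead of leaving it to the user.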

The Role of Types

Types help define which programs are differentiable.

A language may distinguish between:

| Type | Differentiation role |
|---|---|
| Float | Differentiable scalar |
| Vector Float | Differentiable array |
| Int | Usually non-differentiable |
| Bool | Usually non-differentiable |
| String | Non-differentiable |
| Array Int Float | Differentiable values with discrete indices |
| Function | Differentiable only under constraints |

A typed system can reject invalid uses early:

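```swift
import _Differentiation

// Ask the compiler to verify that f supports reverse-mode differentiation.
@differentiable(reverse)
func f(_ x: Float) -> Float {
    return x * x
}
```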

The annotation marks a function as participating in differentiation. The compiler can then check whether all operations inside the function support derivatives.

Shape types and effect systems extend this further. Shape types ensure tensor dimensions match. Effect systems track mutation, randomness, I/O, and other behavior that affects derivative semantics.

Mutation and State

Mutation is one of the hardest issues in differentiable programming.

Consider:

```python
def f(x):
    y = x              # no copy: y aliases x
    y[0] = y[0] * 2    # in-place write, also visible through x
    return sum(y)
```

The assignment changes the value of y. If y aliases x, the mutation also changes x. Reverse mode must reconstruct the correct sequence of states and propagate adjoints through each update.

There are several implementation choices.

| Approach | Behavior |
|---|---|
| Disallow mutation | Simplest semantics |
| Functionalize mutation | Rewrite updates into immutable values |
| Tape mutations | Record old values for the reverse pass |
| Use linear types | Ensure values have unique ownership |
| Define array update adjoints | Treat mutation as scatter and gather |

Functionalization is common. The system rewrites mutation into pure operations, making the derivative transformation easier to define.
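Functionalization plus an explicit update adjoint fits in a few lines. `index_update` and `index_update_vjp` are hypothetical plain-Python helpers in the spirit of the scatter-and-gather approach; the reverse pass is written by hand to show which values receive which adjoints.

```python
def index_update(y, i, v):
    """Pure replacement for `y[i] = v`: returns a new list, leaves y unchanged."""
    out = list(y)
    out[i] = v
    return out


def index_update_vjp(out_bar, i):
    """Adjoint of index_update, a scatter/gather pair.

    The overwritten slot of y receives no gradient; the new value v
    receives the cotangent that arrived at position i.
    """
    y_bar = list(out_bar)
    v_bar = y_bar[i]
    y_bar[i] = 0.0
    return y_bar, v_bar


def f_functional(x):
    """Functionalized f: no aliasing, no in-place update."""
    y = index_update(x, 0, x[0] * 2)
    return sum(y)


def grad_f_functional(x):
    """Hand-written reverse pass for f_functional."""
    y_bar = [1.0] * len(x)                     # d sum(y) / d y
    x_bar, v_bar = index_update_vjp(y_bar, 0)
    x_bar[0] += 2.0 * v_bar                    # v = x[0] * 2
    return x_bar


x = [1.0, 2.0, 3.0]
print(f_functional(x), x)    # 7.0 [1.0, 2.0, 3.0]
print(grad_f_functional(x))  # [2.0, 1.0, 1.0]
```

The caller's list is left untouched, and the overwritten slot correctly contributes only through the new value.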

Control Flow and Dynamic Programs

Differentiable programming must support control flow because scientific and machine learning code is full of it.

Differentiating a loop differentiates the iterations that actually ran:

```python
def fixed_point(x):
    # g (the iteration map) and norm are assumed to be defined elsewhere
    y = x
    while norm(g(y) - y) > 1e-6:
        y = g(y)
    return y
```

Differentiating this program by unrolling the loop gives the derivative of the algorithm, not necessarily the derivative of the mathematical fixed point. These can differ.

This distinction matters.

| Target | Meaning |
|---|---|
| Differentiate the algorithm | Derivative of the finite executed computation |
| Differentiate the solution | Derivative of the mathematical object computed |
| Differentiate the implementation | Derivative of the exact source-level operations |

For solvers, implicit differentiation may be preferable to differentiating every iteration.
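The gap between these targets shows up even in a scalar case. This sketch assumes a parameterized map $g(y, a) = (y + a/y)/2$, Newton's iteration for $\sqrt{a}$, in place of the unspecified `g` above. Unrolled forward-mode differentiation through $k$ iterations gives the derivative of the algorithm. The implicit function theorem at the fixed point $y^* = g(y^*, a)$ gives $dy^*/da = g_a / (1 - g_y)$, the derivative of the solution, $1/(2\sqrt{a})$. All names are hypothetical.

```python
class Dual:
    """Forward-mode number: value plus derivative with respect to a."""

    def __init__(self, val, der=0.0):
        self.val, self.der = val, der

    @staticmethod
    def lift(o):
        return o if isinstance(o, Dual) else Dual(o)

    def __add__(self, o):
        o = Dual.lift(o)
        return Dual(self.val + o.val, self.der + o.der)

    __radd__ = __add__

    def __mul__(self, o):
        o = Dual.lift(o)
        return Dual(self.val * o.val, self.val * o.der + self.der * o.val)

    __rmul__ = __mul__

    def __truediv__(self, o):
        o = Dual.lift(o)
        return Dual(self.val / o.val,
                    (self.der * o.val - self.val * o.der) / (o.val * o.val))


def g(y, a):
    """Assumed iteration map: Newton's step for y*y = a, so y* = sqrt(a)."""
    return 0.5 * (y + a / y)


def unrolled_derivative(a, iters):
    """Derivative of the algorithm: differentiate every executed iteration."""
    a_dual = Dual(a, 1.0)
    y = Dual(1.0)  # fixed initial guess, independent of a
    for _ in range(iters):
        y = g(y, a_dual)
    return y.der


def implicit_derivative(a, iters=50):
    """Derivative of the solution via the implicit function theorem."""
    y = 1.0
    for _ in range(iters):  # solve without tracking derivatives
        y = g(y, a)
    g_a = 0.5 / y                    # partial of g with respect to a
    g_y = 0.5 * (1.0 - a / (y * y))  # partial of g with respect to y
    return g_a / (1.0 - g_y)


print(unrolled_derivative(4.0, 1))   # 0.5: one step is far from the solution
print(unrolled_derivative(4.0, 20))  # close to 0.25
print(implicit_derivative(4.0))      # close to 0.25 = 1 / (2 * sqrt(4))
```

The implicit version needs no record of the iterations, only the converged value and local partials of `g`; that is the main reason solvers tend to prefer it.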

Differentiable Programming in Practice

Modern systems implement differentiable programming with different tradeoffs.

| System | Main style |
|---|---|
| PyTorch | Eager execution; autograd records operations dynamically as they run |
| TensorFlow | Eager tape (GradientTape) plus graph tracing and compilation (tf.function, XLA) |
| JAX | Pure functional transformations over traced programs |
| Julia Zygote | Source-to-source AD |
| Enzyme | Compiler IR-level AD (LLVM IR) |
| Swift AD | Language-integrated typed AD |
| Taichi | Differentiable simulation DSL |
| Dr.Jit | Differentiable rendering and simulation kernels |

The field has no single dominant architecture. The right design depends on the host language, execution target, and expected workloads.

Differentiable Programming Beyond Machine Learning

Differentiable programming is useful wherever a program contains parameters that should be optimized.

Examples include:

| Domain | Differentiable program |
|---|---|
| Rendering | Image formation pipeline |
| Physics | Simulator with tunable parameters |
| Robotics | Controller and dynamics model |
| Finance | Pricing model and risk objective |
| Biology | Kinetic model or molecular simulation |
| Databases | Learned cost model or differentiable query component |
| Compilers | Autotuning objective |
| Control | Trajectory optimizer |

The common pattern is objective-driven computation. A program computes a loss, error, likelihood, reward, or constraint violation. AD provides the derivative needed to improve parameters.
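A toy instance of this pattern: a hypothetical one-dimensional simulator (a cart slowed by linear drag, integrated with explicit Euler steps) whose launch speed is tuned by gradient descent. Forward-mode dual numbers supply the derivative of the final position. All constants and names are made up for illustration.

```python
class Dual:
    """Forward-mode number: value plus derivative with respect to one parameter."""

    def __init__(self, val, der=0.0):
        self.val, self.der = val, der

    def __add__(self, o):
        o = o if isinstance(o, Dual) else Dual(o)
        return Dual(self.val + o.val, self.der + o.der)

    __radd__ = __add__

    def __mul__(self, o):
        o = o if isinstance(o, Dual) else Dual(o)
        return Dual(self.val * o.val, self.val * o.der + self.der * o.val)

    __rmul__ = __mul__


def simulate(v0, steps=50, dt=0.1, drag=0.5):
    """Explicit Euler for x' = v, v' = -drag * v. Returns the final position."""
    x, v = 0.0, v0
    for _ in range(steps):
        x = x + dt * v
        v = v + (-dt * drag) * v
    return x


def fit_speed(target, v0=1.0, lr=0.1, iters=100):
    """Gradient descent on the loss (simulate(v0) - target) ** 2."""
    for _ in range(iters):
        pos = simulate(Dual(v0, 1.0))           # value and d(position)/d(v0)
        err = pos.val - target
        v0 -= lr * 2.0 * err * pos.der          # chain rule through the simulator
    return v0


v = fit_speed(10.0)
print(v, simulate(v))  # final position close to the target 10.0
```

The simulator is written as ordinary code; the optimizer only needs the loss value and its derivative, which is the objective-driven pattern described above.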

Design Requirements

A serious differentiable programming system should provide:

| Requirement | Reason |
|---|---|
| Correct derivative semantics | Users must know what is being differentiated |
| Efficient reverse mode | Scalar losses over many parameters are common |
| Forward mode support | Needed for JVPs, Jacobians, and implicit methods |
| Higher-order derivatives | Needed for curvature and meta-optimization |
| Custom derivative rules | Needed for stability and performance |
| Control-flow support | Real programs branch and loop |
| Mutation model | State must have defined adjoint behavior |
| Compiler optimization | Naive AD produces inefficient code |
| Debugging tools | Gradients need inspection and testing |

Differentiable programming becomes a systems problem, not only a calculus problem.
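A basic debugging tool is a finite-difference gradient check: compare a claimed gradient against central differences at a few points. The functions below are a hypothetical example. The step size and tolerance are illustrative, and central differences are themselves only approximate.

```python
import math


def f(x):
    """Example objective: sin(x0) * x1 + x1**2."""
    return math.sin(x[0]) * x[1] + x[1] ** 2


def grad_f(x):
    """Gradient under test (hand-derived here; could come from an AD system)."""
    return [math.cos(x[0]) * x[1], math.sin(x[0]) + 2.0 * x[1]]


def bad_grad_f(x):
    """A deliberately wrong rule (drops the x1 factor), like an invalid custom rule."""
    return [math.cos(x[0]), math.sin(x[0]) + 2.0 * x[1]]


def gradient_check(f, grad_f, x, eps=1e-6, tol=1e-5):
    """Compare grad_f(x) with central differences, coordinate by coordinate."""
    claimed = grad_f(x)
    for i in range(len(x)):
        xp, xm = list(x), list(x)
        xp[i] += eps
        xm[i] -= eps
        numeric = (f(xp) - f(xm)) / (2.0 * eps)
        if abs(numeric - claimed[i]) > tol * max(1.0, abs(numeric)):
            return False
    return True


print(gradient_check(f, grad_f, [0.3, -1.2]))      # True
print(gradient_check(f, bad_grad_f, [0.3, -1.2]))  # False
```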

Failure Modes

Differentiable programming systems fail in characteristic ways.

| Failure mode | Example |
|---|---|
| Wrong derivative target | Differentiating solver iterations instead of the solved equation |
| Memory blowup | Reverse mode stores every intermediate |
| Silent zero gradients | Discrete operations cut gradient flow |
| Numerically unstable gradients | Naive rules amplify floating-point error |
| Excessive recompilation | Dynamic shapes or branches trigger new traces |
| Perturbation confusion | Nested AD mixes derivative levels |
| Invalid custom rules | User-supplied adjoints violate the true derivative |

These failures are often subtle. Good systems expose diagnostics, gradient checks, and explicit boundaries between differentiable and non-differentiable code.
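Perturbation confusion is easy to reproduce with nested forward mode. The standard test computes $\frac{d}{dx}\big[x \cdot \frac{d}{dy}(x + y)\big]$, which equals 1 because the inner derivative is 1. In the sketch below, a single shared perturbation tag lets the inner derivative pick up x's perturbation and the result becomes 2; a fresh tag per derivative call restores 1. All names are hypothetical, and real systems implement tagging in various ways.

```python
import itertools

_fresh_tags = itertools.count(1)


class Dual:
    """primal + tangent * eps_tag; primal/tangent may hold Duals with smaller tags."""

    def __init__(self, tag, primal, tangent):
        self.tag, self.primal, self.tangent = tag, primal, tangent

    def __add__(self, other):
        return add(self, other)

    __radd__ = __add__

    def __mul__(self, other):
        return mul(self, other)

    __rmul__ = __mul__


def tag_of(v):
    return v.tag if isinstance(v, Dual) else -1  # plain numbers carry no perturbation


def split(v, tag):
    """Return (primal, tangent) of v with respect to the perturbation `tag`."""
    if isinstance(v, Dual) and v.tag == tag:
        return v.primal, v.tangent
    return v, 0.0  # v is constant with respect to this tag


def add(a, b):
    t = max(tag_of(a), tag_of(b))  # peel the most recent perturbation first
    (pa, ta), (pb, tb) = split(a, t), split(b, t)
    return Dual(t, pa + pb, ta + tb)


def mul(a, b):
    t = max(tag_of(a), tag_of(b))
    (pa, ta), (pb, tb) = split(a, t), split(b, t)
    return Dual(t, pa * pb, pa * tb + ta * pb)


def derivative(f, x, tagged=True):
    """d f / d x at x. With tagged=False every call reuses tag 0 (the bug)."""
    tag = next(_fresh_tags) if tagged else 0
    return split(f(Dual(tag, x, 1.0)), tag)[1]


def nested(tagged):
    # d/dx [ x * d/dy (x + y) ] at x = 3, inner derivative taken at y = 1
    return derivative(lambda x: x * derivative(lambda y: x + y, 1.0, tagged),
                      3.0, tagged)


print(nested(tagged=False))  # 2.0 (wrong: perturbations confused)
print(nested(tagged=True))   # 1.0 (correct)
```

Both runs execute the same arithmetic; only the bookkeeping of which perturbation belongs to which derivative call differs, which is why this failure is hard to spot without a targeted test.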

Summary

Differentiable programming generalizes automatic differentiation from numerical kernels to whole programs. It asks the language and compiler to treat derivatives as ordinary program transformations.

The core challenge is semantic clarity. A user should know whether the system differentiates the mathematical function, the algorithm, the implementation, or a custom abstraction exposed by a library. Once that boundary is clear, the remaining problems are compiler and runtime engineering: representation, memory, optimization, dispatch, and hardware lowering.