Nested AD

Nested automatic differentiation means applying automatic differentiation inside another automatic differentiation computation.

Conceptually, this means differentiating derivative computations.

Examples include:

| Expression | Meaning |
| --- | --- |
| $\nabla(\nabla f)$ | Hessian |
| $D(Df)$ | second directional derivative |
| $\nabla(Df[v])$ | Hessian-vector product |
| $D(\nabla f)$ | Jacobian of gradient |
| $\nabla(\text{optimization step})$ | meta-gradient |

Nested AD is essential for higher-order derivatives, implicit differentiation, meta-learning, differentiable optimization, and differentiable programming systems.

The mathematics is straightforward. The implementation is subtle.

First-Order AD as a Transformation

An AD transform maps a program into another program.

If

$$ f : X \to Y, $$

then forward mode produces a transformed program:

$$ Df : TX \to TY, $$

where $T$ represents tangent information.

Reverse mode produces a program that computes pullbacks:

$$ f^\ast : T^\ast Y \to T^\ast X. $$
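To make the two signatures concrete, here is one way to write them in coordinates. This is a sketch under assumptions not stated in the text: $X = \mathbb{R}^n$, $Y = \mathbb{R}^m$, $f$ is differentiable, and $J_f(x)$ (notation introduced here) denotes its $m \times n$ Jacobian.

$$ Df(x, v) = \big(f(x),\; J_f(x)\,v\big), \qquad f^\ast(x, \bar{y}) = J_f(x)^\top \bar{y}. $$

Forward mode pushes a tangent $v$ through the Jacobian, while reverse mode pulls a cotangent $\bar{y}$ back through its transpose.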

Nested AD applies these transformations repeatedly.

For example:

$$ D(Df) $$

means applying forward-mode transformation twice.

Similarly,

$$ \nabla(\nabla f) $$

means applying reverse mode to a reverse-mode derivative computation.

The resulting program has multiple derivative levels active simultaneously.

Example: Forward-over-Forward

Consider

$$ f(x) = x^3. $$

First forward mode computes:

$$ f(x + \epsilon v) = x^3 + 3x^2 v\,\epsilon. $$

The tangent coefficient is:

$$ Df(x)[v] = 3x^2v. $$

Apply forward mode again with another perturbation:

$$ x + \epsilon_1 v + \epsilon_2 w + \epsilon_1\epsilon_2 u. $$

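Now mixed infinitesimal terms appear. Using $\epsilon_1^2 = \epsilon_2^2 = 0$:

$$ f(x + \epsilon_1 v + \epsilon_2 w + \epsilon_1\epsilon_2 u) = x^3 + 3x^2 v\,\epsilon_1 + 3x^2 w\,\epsilon_2 + (6xvw + 3x^2 u)\,\epsilon_1\epsilon_2. $$

The coefficient of $\epsilon_1\epsilon_2$ contains second-order information: with $u = 0$ it is $6xvw = D^2 f(x)[v, w]$.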

Nested forward mode therefore naturally produces higher derivatives.

Dual Number Nesting

Forward mode is commonly implemented using dual numbers:

$$ a + b\epsilon, \quad \epsilon^2 = 0. $$

Nested forward mode uses nested dual structures:

$$ (a + b\epsilon_1) + (c + d\epsilon_1)\epsilon_2. $$

Expanding gives:

$$ a + b\epsilon_1 + c\epsilon_2 + d\epsilon_1\epsilon_2. $$

Each infinitesimal direction corresponds to a derivative level.

The mixed term

$$ d\epsilon_1\epsilon_2 $$

encodes second-order interaction.

This construction generalizes to arbitrary derivative order, but the algebra grows rapidly: a $k$-fold nesting carries $2^k$ coefficients.
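The nested structure above can be tried out with a few lines of plain Python. The following is a minimal sketch, not a production design: the `Dual` class and its field names are hypothetical, only `+` and `*` are supported, and nesting is purely structural (no tags), which is safe here only because every value is nested to the same depth.

```python
class Dual:
    """A dual number primal + tangent*eps with eps**2 == 0.

    Both parts may themselves be Dual, which gives the nested form
    (a + b*eps1) + (c + d*eps1)*eps2.
    """

    def __init__(self, primal, tangent):
        self.primal = primal
        self.tangent = tangent

    def __add__(self, other):
        if isinstance(other, Dual):
            return Dual(self.primal + other.primal, self.tangent + other.tangent)
        return Dual(self.primal + other, self.tangent)

    __radd__ = __add__

    def __mul__(self, other):
        if isinstance(other, Dual):
            # Product rule; the eps**2 term is dropped.
            return Dual(self.primal * other.primal,
                        self.primal * other.tangent + self.tangent * other.primal)
        return Dual(self.primal * other, self.tangent * other)

    __rmul__ = __mul__


def f(x):
    return x * x * x


# Seed x + 1*eps1 + 1*eps2 + 0*eps1*eps2 at x = 4.
x = 4.0
nested = Dual(Dual(x, 1.0), Dual(1.0, 0.0))
y = f(nested)

value = y.primal.primal      # a: f(x)
d_eps1 = y.primal.tangent    # b: f'(x)
d_eps2 = y.tangent.primal    # c: f'(x)
d_mixed = y.tangent.tangent  # d: f''(x)
```

For $f(x) = x^3$ at $x = 4$, the four coefficients are $f(4) = 64$, $f'(4) = 48$ (twice), and $f''(4) = 24$ in the mixed slot.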

Reverse-over-Reverse Nesting

Nested reverse mode is more complicated.

A reverse-mode computation creates adjoints and backward passes. Differentiating that computation introduces adjoints of adjoints.

Suppose:

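g = grad(f)   # first reverse pass: gradient of a scalar-valued f
H = grad(g)   # second reverse pass; valid as written for scalar x, for vector x take the Jacobian of g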

The second gradient differentiates the reverse pass used by the first gradient.

This means:

  1. The original primal computation must remain differentiable.
  2. The backward pass must itself behave like differentiable code.
  3. Adjoint accumulation logic becomes part of the differentiated computation.

The system must distinguish derivative levels carefully.
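The three requirements above can be illustrated with a toy reverse-mode implementation. This is a hedged sketch rather than how any particular framework works: `Var`, `as_var`, `topo_order`, and this `grad` are hypothetical names, only `+` and `*` on scalars are supported, and adjoints live in a dictionary local to each backward pass so that different levels cannot share them. The key point is that the backward pass is written with the same overloaded operations as the primal code, so it builds a graph that can be differentiated again.

```python
class Var:
    """Graph node; parents holds (parent, local_derivative) pairs, both Vars."""

    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents

    def __add__(self, other):
        other = as_var(other)
        return Var(self.value + other.value,
                   ((self, Var(1.0)), (other, Var(1.0))))

    __radd__ = __add__

    def __mul__(self, other):
        other = as_var(other)
        # Local derivatives are graph nodes, so they stay differentiable.
        return Var(self.value * other.value, ((self, other), (other, self)))

    __rmul__ = __mul__


def as_var(x):
    return x if isinstance(x, Var) else Var(x)


def topo_order(root):
    """Return nodes so that every node appears before its parents."""
    order, seen = [], set()

    def visit(node):
        if id(node) in seen:
            return
        seen.add(id(node))
        for parent, _ in node.parents:
            visit(parent)
        order.append(node)

    visit(root)
    return order[::-1]


def grad(f):
    def df(x):
        x = as_var(x)
        y = f(x)
        adjoint = {id(y): Var(1.0)}  # owned by this backward pass only
        for node in topo_order(y):
            node_bar = adjoint.get(id(node))
            if node_bar is None:
                continue
            for parent, local in node.parents:
                contribution = node_bar * local  # recorded as graph operations
                previous = adjoint.get(id(parent))
                adjoint[id(parent)] = (contribution if previous is None
                                       else previous + contribution)
        return adjoint.get(id(x), Var(0.0))
    return df


f = lambda x: x * x * x
first = grad(f)(4.0).value               # f'(4)   = 3 * 16
second = grad(grad(f))(4.0).value        # f''(4)  = 6 * 4
third = grad(grad(grad(f)))(4.0).value   # f'''(4) = 6
```

Because the accumulation `previous + contribution` is itself a recorded operation, the adjoint accumulation logic becomes part of the program that the outer `grad` differentiates.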

Mixed-Mode Nesting

Many practical systems use mixed-mode nesting.

Examples include:

| Nesting | Purpose |
| --- | --- |
| forward-over-reverse | Hessian-vector products |
| reverse-over-forward | directional derivative gradients |
| reverse-over-reverse | higher-order scalar derivatives |
| forward-over-forward | Taylor coefficients |
| reverse-over-forward-over-reverse | advanced implicit differentiation |

Mixed mode is often preferable because different AD modes have complementary strengths.

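As a rule of thumb, each mode is cheap when a different dimension is small (a scalar-output function is the extreme case that favors reverse mode):

| Mode | Efficient when |
| --- | --- |
| forward mode | input dimension is small |
| reverse mode | output dimension is small |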

Nested systems combine these strengths.
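The Hessian-vector product shows how the combination works. As a sketch, assume $f : \mathbb{R}^n \to \mathbb{R}$ is twice continuously differentiable with Hessian $H(x)$, and $v$ is a fixed direction:

$$ H(x)\,v = \left.\frac{d}{dt}\,\nabla f(x + t v)\right|_{t=0} \quad \text{(forward-over-reverse)}, \qquad H(x)\,v = \nabla_x \big( Df(x)[v] \big) \quad \text{(reverse-over-forward)}. $$

In both forms reverse mode handles the scalar output and forward mode handles the single direction $v$, so the $n \times n$ Hessian is never formed.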

Perturbation Confusion

The most famous nested-AD failure mode is perturbation confusion.

Suppose two forward-mode computations accidentally share the same infinitesimal symbol:

$$ \epsilon. $$

Then derivative information from different levels mixes incorrectly.

For example:

$$ (a + b\epsilon) + (c + d\epsilon) = (a+c) + (b+d)\epsilon. $$

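If these infinitesimals were meant to represent different derivative levels, the result is wrong: the tangents should have stayed in separate components, $b\epsilon_1 + d\epsilon_2$, but the shared symbol merges them into one.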

The problem becomes severe when derivative computations are passed through higher-order functions or closures.
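A standard way to reproduce the failure is to differentiate $x \mapsto x \cdot \frac{d}{dy}(x + y)\big|_{y=1}$ at $x = 1$. The inner derivative is $1$, so the correct answer is $\frac{d}{dx} x = 1$. The sketch below uses a deliberately naive, untagged `Dual` class and a hypothetical `naive_derivative` helper; it is meant to show the bug, not a usable design.

```python
class Dual:
    """Dual number primal + tangent*eps using ONE shared, untagged eps."""

    def __init__(self, primal, tangent=0.0):
        self.primal = primal
        self.tangent = tangent

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.primal + other.primal, self.tangent + other.tangent)

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.primal * other.primal,
                    self.primal * other.tangent + self.tangent * other.primal)

    __rmul__ = __mul__


def naive_derivative(f, x):
    """Forward-mode derivative that cannot tell perturbation levels apart."""
    y = f(Dual(x, 1.0))
    return y.tangent if isinstance(y, Dual) else 0.0


def outer(x):
    # inner closes over x, which already carries the OUTER perturbation.
    inner = lambda y: x + y
    return x * naive_derivative(inner, 1.0)


confused = naive_derivative(outer, 1.0)  # mathematically this should be 1
```

Here `confused` comes out as `2.0`: the inner derivative reads the outer tangent of `x` as if it belonged to `y`, so it returns `2` instead of `1`.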

Tagging Derivative Levels

Correct nested forward mode assigns a unique tag to each perturbation level.

Instead of one universal infinitesimal:

$$ \epsilon, $$

the system uses:

$$ \epsilon_1, \epsilon_2, \epsilon_3, \ldots $$

with independent nilpotent behavior.

Then:

$$ \epsilon_i \epsilon_j \ne 0 \quad \text{for } i \ne j, $$

while

$$ \epsilon_i^2 = 0. $$

Each derivative transform introduces a fresh perturbation identity.

Implementation-wise, systems usually represent this with:

| Technique | Description |
| --- | --- |
| unique IDs | each AD transform gets a fresh perturbation tag |
| lexical scoping | perturbations scoped to derivative level |
| type-level tagging | derivative levels encoded in types |
| runtime tagging | tags stored dynamically |
| staged transforms | derivative levels separated during compilation |

Without tagging, nested forward mode is unreliable.
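The unique-ID technique can be sketched in plain Python. Everything here is illustrative: `Dual`, `_split`, `_tag_of`, and `derivative` are hypothetical names, tags come from a global counter, and the sketch assumes that a later derivative call always gets a larger tag than any value it receives, so the largest tag is always outermost. Real systems need more care, for example when tagged values escape their scope.

```python
import itertools

_fresh_tag = itertools.count(1)  # each derivative call draws a new tag


def _tag_of(value):
    return value.tag if isinstance(value, Dual) else 0


def _split(value, tag):
    """Return (primal, tangent) of value with respect to one tag only."""
    if isinstance(value, Dual) and value.tag == tag:
        return value.primal, value.tangent
    return value, 0.0  # constant with respect to this perturbation


class Dual:
    def __init__(self, primal, tangent, tag):
        self.primal, self.tangent, self.tag = primal, tangent, tag

    def __add__(self, other):
        tag = max(self.tag, _tag_of(other))
        (a, da), (b, db) = _split(self, tag), _split(other, tag)
        return Dual(a + b, da + db, tag)

    __radd__ = __add__

    def __mul__(self, other):
        tag = max(self.tag, _tag_of(other))
        (a, da), (b, db) = _split(self, tag), _split(other, tag)
        return Dual(a * b, a * db + da * b, tag)

    __rmul__ = __mul__


def derivative(f, x):
    tag = next(_fresh_tag)           # fresh perturbation identity
    y = f(Dual(x, 1.0, tag))
    return _split(y, tag)[1]         # read back only this level's tangent


def outer(x):
    inner = lambda y: x + y          # x carries the outer tag, y the inner tag
    return x * derivative(inner, 1.0)


result = derivative(outer, 1.0)
second = derivative(lambda x: derivative(lambda z: z * z * z, x), 4.0)
```

With tags, the inner derivative treats `x` as a constant, so `result` is `1.0`, and the nested call gives the second derivative of $z^3$ at $4$, which is `24.0`.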

Cotangent Level Separation

Reverse mode has a similar issue.

Adjoints from different derivative levels must remain separate.

Suppose:

outer gradient
    inner gradient

The inner reverse pass should not accidentally consume cotangents from the outer reverse pass.

A correct nested reverse system tracks cotangent levels explicitly.

Conceptually:

| Level | Meaning |
| --- | --- |
| primal level | original computation |
| tangent/cotangent level 1 | first derivative |
| tangent/cotangent level 2 | second derivative |
| tangent/cotangent level $k$ | $k$th derivative |

This separation is part of the semantics, not merely debugging metadata.

Closures and Higher-Order Functions

Nested AD becomes harder in languages with closures and higher-order functions.

Example:

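def outer(x):
    def inner(y):
        return x * y    # x is a free variable captured from outer
    return grad(inner)  # the returned derivative function still closes over x

The differentiated function captures the free variable x from the outer scope, so x may itself carry tangent or cotangent state from an enclosing derivative.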

The AD system must decide:

| Question | Requirement |
| --- | --- |
| which values are primal | original computation |
| which values are tangent | derivative level |
| which values are cotangent | reverse accumulation |
| which closures capture derivative state | nested differentiation |
| which tapes belong to which level | nesting correctness |

Functional languages often model this more cleanly because closures are explicit semantic objects.

Dynamic Computation Graphs

Nested AD interacts strongly with dynamic graph systems.

If the computation graph depends on runtime control flow, the graph structure itself may differ between derivative levels.

For example:

if grad(f)(x) > 0:
    ...

The derivative computation influences the primal control flow.

A nested AD system must specify:

  1. Which computations are traced.
  2. Which branches are differentiated.
  3. Whether graph structure is static or dynamic.
  4. Whether nested derivative traces are compositional.

Different frameworks make different choices.

Differentiating Optimizers

Nested AD is central in differentiable optimization.

Suppose gradient descent performs:

$$ x_{t+1} = x_t - \eta \nabla f(x_t). $$

Now suppose we want derivatives with respect to hyperparameters:

$$ \frac{\partial x_T}{\partial \eta}. $$

The optimizer itself becomes part of the differentiated program.

This requires differentiating through:

| Object | Derivative target |
| --- | --- |
| gradients | first-order derivatives |
| update rules | optimization dynamics |
| momentum accumulators | optimizer state |
| learning rate schedules | hyperparameters |
| inner training loops | meta-learning |

Nested AD enables this.
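To see where the nesting comes from, differentiate the gradient descent update with respect to $\eta$. This is a sketch assuming $f$ is twice differentiable and the initial point $x_0$ does not depend on $\eta$:

$$ \frac{\partial x_{t+1}}{\partial \eta} = \frac{\partial x_t}{\partial \eta} - \nabla f(x_t) - \eta\, \nabla^2 f(x_t)\, \frac{\partial x_t}{\partial \eta}, \qquad \frac{\partial x_0}{\partial \eta} = 0. $$

Each step needs a Hessian-vector product $\nabla^2 f(x_t)\,\frac{\partial x_t}{\partial \eta}$, which is a derivative of the inner gradient computation.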

Implicit Differentiation

Nested AD also appears in implicit differentiation.

Suppose:

$$ g(x, y(x)) = 0. $$

We may avoid differentiating every iteration of a solver by differentiating the fixed-point condition itself.
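Concretely, differentiating the condition with respect to $x$ gives the following, assuming $g$ is continuously differentiable and $\partial g / \partial y$ is invertible at the solution:

$$ \frac{\partial g}{\partial x} + \frac{\partial g}{\partial y}\,\frac{dy}{dx} = 0 \quad\Longrightarrow\quad \frac{dy}{dx} = -\left(\frac{\partial g}{\partial y}\right)^{-1} \frac{\partial g}{\partial x}. $$

Only derivatives of $g$ at the solution are needed, not the solver's iteration history.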

Still, the resulting derivative computations often involve nested linearizations and reverse passes.

This is especially common in:

| Area | Example |
| --- | --- |
| meta-learning | differentiating equilibrium states |
| optimization layers | differentiating argmin operators |
| physics simulation | differentiating steady states |
| probabilistic inference | differentiating fixed-point solvers |

Nested AD provides the underlying machinery.

Tape Nesting

In tape-based reverse mode, nested AD introduces nested tapes.

Possible models include:

| Model | Description |
| --- | --- |
| independent tapes | each derivative level has a separate tape |
| hierarchical tapes | tapes reference parent tapes |
| reentrant tapes | backward passes may themselves record operations |
| staged tapes | derivative levels separated during compilation |

Incorrect tape interaction can cause:

| Failure | Meaning |
| --- | --- |
| tape corruption | nested passes overwrite state |
| missing gradients | tape lifetime ends too early |
| duplicated gradients | nested replay occurs twice |
| memory explosion | all nested levels retained simultaneously |

Robust nested systems need explicit tape ownership semantics.

Compiler Perspective

A compiler-based AD system often handles nesting better than a runtime operator-overloading system.

Instead of dynamically stacking derivative objects, the compiler transforms intermediate representations explicitly:

program
→ linearized program
→ transposed program
→ differentiated transposed program

This exposes derivative levels structurally.

Compiler IRs can annotate:

| IR annotation | Purpose |
| --- | --- |
| primal variable | original value |
| tangent variable | forward derivative |
| cotangent variable | reverse derivative |
| residual | saved intermediate |
| derivative level | nesting separation |

This structure reduces ambiguity.

Complexity Explosion

Higher-order derivatives grow rapidly.

For dimension $n$:

| Derivative order | Tensor size |
| --- | --- |
| gradient | $n$ |
| Hessian | $n^2$ |
| third derivative tensor | $n^3$ |
| fourth derivative tensor | $n^4$ |

Nested AD can therefore create storage and computation costs that grow exponentially in the derivative order $k$: a full $k$th derivative tensor has $n^k$ entries.

Most practical systems avoid materializing full higher-order tensors.

Instead they compute:

| Operation | Scalable form |
| --- | --- |
| Hessian-vector product | $Hv$ |
| Jacobian-vector product | $Jv$ |
| vector-Hessian-vector | $v^\top Hv$ |
| directional $k$th derivative | scalar directional expansion |

Operator forms scale better than explicit tensors.
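As an illustration of an operator form, the scalar $v^\top H v$ can be read off a second-order Taylor expansion of $t \mapsto f(x + t v)$ without ever storing $H$. The sketch below uses a hypothetical `Jet2` class that propagates three Taylor coefficients through `+` and `*`; the test function `f` is an arbitrary example chosen for this illustration.

```python
class Jet2:
    """Truncated Taylor series c0 + c1*t + c2*t**2; higher terms are dropped."""

    def __init__(self, c0, c1=0.0, c2=0.0):
        self.c0, self.c1, self.c2 = c0, c1, c2

    def __add__(self, other):
        other = lift(other)
        return Jet2(self.c0 + other.c0, self.c1 + other.c1, self.c2 + other.c2)

    __radd__ = __add__

    def __mul__(self, other):
        other = lift(other)
        # Cauchy product of the two series, truncated after t**2.
        return Jet2(self.c0 * other.c0,
                    self.c0 * other.c1 + self.c1 * other.c0,
                    self.c0 * other.c2 + self.c1 * other.c1 + self.c2 * other.c0)

    __rmul__ = __mul__


def lift(value):
    return value if isinstance(value, Jet2) else Jet2(value)


def directional_second_derivative(f, x, v):
    """Return v^T H(x) v for f: R^n -> R without forming the n x n Hessian."""
    path = [Jet2(xi, vi) for xi, vi in zip(x, v)]  # the line x + t*v
    # f(x + t*v) = f(x) + (grad f . v) t + (1/2)(v^T H v) t**2 + ...
    return 2.0 * f(path).c2


def f(x):
    # Example function: f(x0, x1, x2) = x0*x1*x2 + x0**2
    return x[0] * x[1] * x[2] + x[0] * x[0]


vhv = directional_second_derivative(f, [1.0, 2.0, 3.0], [1.0, 1.0, 1.0])
```

For this `f` at $x = (1, 2, 3)$ with $v = (1, 1, 1)$, the Hessian entries sum to $14$, and the jet computes the same value while carrying only three numbers per intermediate instead of an $n \times n$ matrix.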

Practical Design Principles

A robust nested AD system should:

| Principle | Reason |
| --- | --- |
| separate derivative levels | avoid perturbation confusion |
| represent derivatives explicitly | improve correctness |
| isolate tapes per level | prevent state corruption |
| distinguish primal/tangent/cotangent data | preserve semantics |
| expose operator APIs | avoid tensor explosion |
| support compositional transforms | enable higher-order programming |

Nested AD is fundamentally about composing derivative transformations safely and predictably.

Conceptual View

Automatic differentiation is often introduced as a technique for computing gradients.

Nested AD reveals a deeper interpretation.

Differentiation becomes a compositional program transform:

$$ \mathcal{D}(\mathcal{D}(f)), \quad \mathcal{D}(\mathcal{D}(\mathcal{D}(f))), \quad \ldots $$

The challenge is no longer merely computing derivatives. The challenge is preserving semantic structure across multiple interacting derivative levels while controlling memory, complexity, and numerical stability.