Higher-Order Reverse Mode

Reverse mode is efficient for scalar-output functions because it propagates one adjoint backward through the computation and produces a full gradient. For

$$ f : \mathbb{R}^n \to \mathbb{R}, $$

one reverse pass computes

$$ \nabla f(x) $$

at a cost that is a small constant multiple of the cost of evaluating $f$.

Higher-order reverse mode asks for derivatives of reverse-mode derivative computations. The simplest case is differentiating the gradient:

$$ D(\nabla f)(x) = \nabla^2 f(x). $$

The idea is mathematically clean. The implementation is much harder.

First-Order Reverse Mode Recap

A program computes intermediate values:

$$ v_1, v_2, \ldots, v_k. $$

The final output is

$$ y = v_k. $$

Reverse mode associates each intermediate value $v_i$ with an adjoint:

$$ \bar{v}_i = \frac{\partial y}{\partial v_i}. $$

The backward pass starts with

$$ \bar{y} = 1. $$

Then it propagates adjoints backward through local derivative rules.

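For example, if one step multiplies two intermediate values $x$ and $y$ (here $y$ is an operand, not the program output),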

$$ z = xy, $$

then reverse mode applies

$$ \bar{x} \mathrel{+}= \bar{z}y, $$

$$ \bar{y} \mathrel{+}= \bar{z}x. $$

This produces first derivatives.
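A minimal sketch of these rules as a hand-written reverse pass. The function $\sin(xy)$ is chosen here for illustration and is not taken from the text; the names `f_and_grad` and `out` are hypothetical.

```python
import math

def f_and_grad(x, y):
    """Evaluate sin(x * y) and its gradient with a hand-written reverse pass."""
    # Forward pass: compute and keep the intermediate values.
    z = x * y
    out = math.sin(z)

    # Backward pass: seed the output adjoint, then apply local rules in reverse.
    out_bar = 1.0
    z_bar = out_bar * math.cos(z)   # rule for out = sin(z)
    x_bar = 0.0
    y_bar = 0.0
    x_bar += z_bar * y              # rule for z = x * y
    y_bar += z_bar * x
    return out, (x_bar, y_bar)

value, gradient = f_and_grad(2.0, 3.0)
```

One forward sweep and one backward sweep yield both partial derivatives, which is the property that makes reverse mode attractive for scalar outputs.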

Differentiating the Backward Pass

Higher-order reverse mode differentiates these backward computations.

For

$$ z = xy, $$

the backward rules are:

$$ \bar{x} \mathrel{+}= \bar{z}y, $$

$$ \bar{y} \mathrel{+}= \bar{z}x. $$

These rules are themselves programs. They use multiplication, addition, saved primal values, and adjoints. If we differentiate this backward program, we obtain second-order information.

This means a higher-order reverse system must treat the backward pass as differentiable code, not as an opaque implementation detail.
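As a worked step, the product rule can be differentiated by hand. Write $\Delta\bar{x}$ and $\Delta\bar{y}$ for the contributions that the rule adds to each adjoint (this notation is introduced here for illustration):

$$ \Delta\bar{x} = \bar{z}\,y, \qquad \Delta\bar{y} = \bar{z}\,x. $$

Taking differentials with respect to both the incoming adjoint and the saved primal values gives

$$ d(\Delta\bar{x}) = d\bar{z}\,y + \bar{z}\,dy, \qquad d(\Delta\bar{y}) = d\bar{z}\,x + \bar{z}\,dx. $$

The terms $\bar{z}\,dy$ and $\bar{z}\,dx$ exist only because the backward rule reads primal values. They carry the mixed second derivative $\partial^2 z / \partial x\,\partial y = 1$, which is exactly the second-order information a nested system has to keep.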

Reverse-over-Reverse

Reverse-over-reverse means applying reverse mode to a computation that was itself produced by reverse mode.

Conceptually:

g = grad(f)
H = jacobian(g)

If the Jacobian of `g` is itself computed by reverse mode, the system is doing reverse-over-reverse.

This can compute second derivatives, but it creates several problems.

The first reverse pass records the primal computation. The second reverse pass may need to record the backward computation. This may include adjoint accumulation, tape reads, saved intermediates, and control flow in derivative rules.

The resulting computation can be much larger than the original one.
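One way to make the backward pass differentiable is to build adjoints out of the same graph nodes as the primal computation. The sketch below is a toy under that assumption: `Var`, `lift`, and `grad` are hypothetical names, only `+` and `*` are supported, and it does not describe how any particular library implements nesting.

```python
class Var:
    """A graph node: a value plus (parent, local derivative) pairs."""
    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents  # local derivatives are Vars, so they stay differentiable

    def __add__(self, other):
        other = lift(other)
        return Var(self.value + other.value, ((self, Var(1.0)), (other, Var(1.0))))
    __radd__ = __add__

    def __mul__(self, other):
        other = lift(other)
        return Var(self.value * other.value, ((self, other), (other, self)))
    __rmul__ = __mul__

def lift(x):
    return x if isinstance(x, Var) else Var(float(x))

def grad(output, wrt):
    """Reverse pass whose adjoints are Vars, so the result can be differentiated again."""
    order, seen = [], set()

    def visit(node):  # depth-first post-order: parents are listed before consumers
        if id(node) not in seen:
            seen.add(id(node))
            for parent, _ in node.parents:
                visit(parent)
            order.append(node)

    visit(output)
    adjoint = {id(output): Var(1.0)}
    for node in reversed(order):
        node_bar = adjoint[id(node)]
        for parent, local in node.parents:
            contribution = node_bar * local  # recorded as new graph nodes
            previous = adjoint.get(id(parent))
            adjoint[id(parent)] = contribution if previous is None else previous + contribution
    return adjoint.get(id(wrt), Var(0.0))

x = Var(3.0)
y = x * x * x
first = grad(y, x)        # graph for 3 * x**2, value 27.0
second = grad(first, x)   # reverse mode applied to the backward graph, value 18.0
```

The second call traverses nodes that were created by the first backward pass, so the graph it walks is larger than the primal graph. That growth is the cost described above.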

Tapes Become Part of the Program

Many reverse-mode systems use a tape. The tape records operations from the forward pass so that the backward pass can replay them in reverse.

For first-order AD, the tape is an implementation device.

For higher-order reverse mode, the tape can become part of the differentiated computation.

This creates hard questions:

| Issue | Why it matters |
| --- | --- |
| tape allocation | allocation behavior may affect differentiated execution |
| mutation | adjoints are often accumulated destructively |
| aliasing | multiple references may point to the same storage |
| saved values | higher-order rules may need more saved intermediates |
| custom gradients | first-order custom rules may not have valid higher derivatives |
| control flow | backward control flow must remain differentiable |

A clean higher-order system separates mathematical derivative semantics from tape mechanics.

Mutation and Adjoint Accumulation

Reverse mode commonly updates adjoints by mutation:

x_bar += z_bar * y
y_bar += z_bar * x

This is natural and efficient.

But higher-order differentiation must account for how these updates depend on primal values and incoming adjoints. Mutation introduces ordering and aliasing concerns.

For example, if two paths contribute to the same adjoint, reverse mode sums them. The sum is mathematically commutative, but floating-point accumulation is order-dependent: adding the same contributions in a different order can change the rounded result. Higher-order differentiation can expose this ordering when numerical reproducibility matters.

A robust implementation may use an intermediate representation where adjoint accumulation is explicit and analyzable.
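A small sketch of the ordering concern, using three made-up contribution values. It contrasts in-place accumulation, whose result depends on arrival order, with an explicit list of contributions reduced by a fixed policy. `math.fsum` is used here only as one example of an order-independent reduction, not as a recommendation for production AD systems.

```python
import math

# Adjoint contributions arriving along three paths (illustrative values).
contributions = [0.1, 0.2, 0.3]

def accumulate_in_place(values):
    """Mimic `x_bar += contribution` applied in arrival order."""
    total = 0.0
    for value in values:
        total += value
    return total

forward_order = accumulate_in_place(contributions)
reverse_order = accumulate_in_place(reversed(contributions))
order_dependent = forward_order != reverse_order

# Explicit accumulation: keep the contributions visible, then reduce once.
reduced = math.fsum(contributions)
reduced_reversed = math.fsum(reversed(contributions))
```

Keeping the contributions as an explicit collection lets the system choose the reduction, which is harder once the updates have been fused into destructive `+=` operations.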

Saved Primal Values

Reverse-mode rules often need primal values from the forward pass.

For

$$ z = \sin x, $$

the backward rule is

$$ \bar{x} \mathrel{+}= \bar{z}\cos x. $$

The backward pass needs $x$, or some equivalent saved value.

For second derivatives, differentiating the backward rule also needs the derivative of $\cos x$:

$$ d(\bar{x}) = d(\bar{z})\cos x - \bar{z}\sin x\,dx. $$

So a second-order system must preserve enough information to differentiate the backward rule itself.

If a first-order system saves too little, higher-order differentiation may be impossible or incorrect.
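The differential above can be checked numerically. This sketch writes the backward rule for $z = \sin x$ and its differential as two plain functions and compares the latter with a central finite difference. The function names and the test point are assumptions made for illustration.

```python
import math

def sin_backward(x, z_bar):
    """Contribution to x_bar from z = sin(x)."""
    return z_bar * math.cos(x)

def sin_backward_differential(x, z_bar, dx, dz_bar):
    """Differential of the backward rule in the direction (dx, dz_bar)."""
    return dz_bar * math.cos(x) - z_bar * math.sin(x) * dx

# Compare against a central finite difference at an arbitrary point.
x, z_bar, dx, dz_bar = 0.7, 1.3, 0.5, -0.2
h = 1e-6
finite_difference = (
    sin_backward(x + h * dx, z_bar + h * dz_bar)
    - sin_backward(x - h * dx, z_bar - h * dz_bar)
) / (2 * h)
analytic = sin_backward_differential(x, z_bar, dx, dz_bar)
```

Note that the differential needs $\sin x$ as well as $\cos x$. A first-order system that saved only $\cos x$ would have to recompute or re-save $x$ to support the second-order rule.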

Custom Gradients

Many AD systems allow custom first-order gradients.

For example, a user may define:

forward: y = op(x)
backward: x_bar = custom_rule(x, y, y_bar)

This is enough for first-order AD.

For higher-order AD, the custom backward rule must also be differentiable and mathematically correct.

A custom rule can be first-order correct but second-order wrong.

For example, a rule may stop gradients through an intermediate value for numerical reasons. That may preserve the first derivative while destroying the second derivative.

Therefore, production systems should distinguish:

| Rule type | Meaning |
| --- | --- |
| first-order custom gradient | valid for gradients only |
| higher-order custom gradient | valid under nested AD |
| non-differentiable custom gradient | blocks higher-order differentiation |
| symbolic derivative rule | supplies explicit higher-order behavior |

This distinction prevents silent errors.
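A worked example of a rule that is first-order correct but second-order wrong, under assumptions chosen for illustration: take $\mathrm{op}(x) = x^2$ and suppose the custom rule reads a detached copy of the input, written $\mathrm{sg}(x)$ (a stop-gradient, notation introduced here):

$$ \bar{x} = 2\,\mathrm{sg}(x)\,\bar{y}. $$

The value of $\bar{x}$ equals $2x\,\bar{y}$, so every first-order gradient is correct. Differentiating the rule, however, treats $\mathrm{sg}(x)$ as a constant:

$$ d(\bar{x}) = 2\,\mathrm{sg}(x)\,d\bar{y} \quad \text{instead of} \quad d(\bar{x}) = 2x\,d\bar{y} + 2\bar{y}\,dx. $$

With $\bar{y} = 1$, the missing term $2\,dx$ is the second derivative of $x^2$. A nested system that trusts this rule would report a second derivative of zero without raising any error.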

Reverse Mode for Vector Outputs

For

$$ F : \mathbb{R}^n \to \mathbb{R}^m, $$

reverse mode computes vector-Jacobian products:

$$ w^\top J_F(x). $$

Higher-order reverse mode can differentiate these products.

If

$$ \phi(x) = w^\top F(x), $$

then reverse mode computes

$$ \nabla \phi(x) = J_F(x)^\top w. $$

Differentiating this gradient gives

$$ \nabla^2 \phi(x). $$

So higher-order reverse mode naturally handles scalarizations of vector-output functions.

This is important in machine learning, where losses are scalar but models are vector-valued internally.

Reverse Mode and Hessian-Vector Products

Higher-order reverse mode can compute Hessian-vector products, but pure reverse-over-reverse is often not the best route.

For scalar $f$, a Hessian-vector product can be computed as:

$$ H_f(x)v = \nabla(\nabla f(x)^\top v). $$

This uses reverse mode on the scalar function

$$ \phi(x) = \nabla f(x)^\top v. $$

The inner gradient usually comes from reverse mode. The outer gradient also uses reverse mode. That is reverse-over-reverse.

It works, but may require differentiating the backward pass. Forward-over-reverse often avoids some of this complexity by pushing a tangent through the gradient computation instead.
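A minimal sketch of the forward-over-reverse idea in plain Python. It assumes an illustrative function $f(x_0, x_1) = x_0^2 x_1 + x_1^3$ with a hand-written reverse pass, and a hypothetical `Dual` class that carries a tangent through that pass. Nothing here reflects a specific library's API.

```python
class Dual:
    """First-order dual number: primal + tangent * eps, with eps**2 = 0."""
    def __init__(self, primal, tangent=0.0):
        self.primal, self.tangent = primal, tangent

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.primal + other.primal, self.tangent + other.tangent)
    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.primal * other.primal,
                    self.primal * other.tangent + self.tangent * other.primal)
    __rmul__ = __mul__

def grad_f(x0, x1):
    """Hand-written reverse pass for f = x0**2 * x1 + x1**3, generic in the number type."""
    a = x0 * x0                       # forward pass
    c = x1 * x1
    out_bar = 1.0                     # backward pass for out = a * x1 + c * x1
    a_bar = out_bar * x1
    c_bar = out_bar * x1
    x1_bar = out_bar * a + out_bar * c
    x1_bar = x1_bar + c_bar * x1 + c_bar * x1   # c = x1 * x1 uses x1 twice
    x0_bar = a_bar * x0 + a_bar * x0            # a = x0 * x0 uses x0 twice
    return x0_bar, x1_bar

def hvp(x, v):
    """Push the tangent v through the gradient computation; read off H v."""
    g0, g1 = grad_f(Dual(x[0], v[0]), Dual(x[1], v[1]))
    return g0.tangent, g1.tangent

# The Hessian at (1, 2) is [[4, 2], [2, 12]].
first_column = hvp((1.0, 2.0), (1.0, 0.0))    # (4.0, 2.0)
second_column = hvp((1.0, 2.0), (0.0, 1.0))   # (2.0, 12.0)
```

The backward pass is executed once with richer numbers and is never recorded on a second tape. That is the complexity forward-over-reverse avoids.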

Complexity and Memory

Higher-order reverse mode can be expensive because it nests derivative computations.

For first-order reverse mode, memory is dominated by saved intermediates.

For higher-order reverse mode, memory may include:

| Memory source | Description |
| --- | --- |
| primal tape | operations from original forward pass |
| backward tape | operations from derivative computation |
| saved primal values | values needed by first backward pass |
| saved adjoint values | values needed by differentiated backward pass |
| nested AD metadata | tags, levels, tangent or adjoint structures |
| temporary derivative arrays | intermediate higher-order values |

The exact cost depends on the program and implementation. The key point holds regardless: higher-order reverse mode can consume much more memory than first-order reverse mode.

Perturbation and Cotangent Levels

Nested AD needs distinct derivative levels.

In forward mode, this prevents perturbation confusion. In reverse mode, the analogous issue concerns adjoint levels.

A nested AD system must know which derivative level each tangent or cotangent belongs to.

Otherwise, an inner derivative computation may accidentally consume or modify derivative information intended for an outer computation.

Correct systems usually track derivative levels explicitly:

level 0: primal computation
level 1: first derivative
level 2: second derivative

This bookkeeping is not cosmetic. It is part of the semantics of nested differentiation.
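The forward-mode version of this problem is easy to reproduce. The sketch below uses hypothetical tagged dual numbers (`Dual`, `split`, `add`, `mul`, `derivative` are illustrative names) and evaluates a nested derivative twice: once with a fresh tag per derivative call, and once with a single shared tag, which is the deliberately broken variant.

```python
import itertools

SHARED_TAG = 1                      # used only by the deliberately broken variant
_fresh_tags = itertools.count(2)    # each correct derivative call gets its own level

class Dual:
    """primal + tangent * eps_tag; components may be Duals of lower tags."""
    def __init__(self, primal, tangent, tag):
        self.primal, self.tangent, self.tag = primal, tangent, tag

def tag_of(x):
    return x.tag if isinstance(x, Dual) else 0

def split(x, tag):
    """Read x at one level; values from other levels are constants there."""
    if isinstance(x, Dual) and x.tag == tag:
        return x.primal, x.tangent
    return x, 0.0

def add(a, b):
    tag = max(tag_of(a), tag_of(b))
    if tag == 0:
        return a + b
    (ap, at), (bp, bt) = split(a, tag), split(b, tag)
    return Dual(add(ap, bp), add(at, bt), tag)

def mul(a, b):
    tag = max(tag_of(a), tag_of(b))
    if tag == 0:
        return a * b
    (ap, at), (bp, bt) = split(a, tag), split(b, tag)
    return Dual(mul(ap, bp), add(mul(ap, bt), mul(at, bp)), tag)

def derivative(f, x, fresh_tag=True):
    tag = next(_fresh_tags) if fresh_tag else SHARED_TAG
    return split(f(Dual(x, 1.0, tag)), tag)[1]

def nested(fresh_tag):
    # d/dx [ x * (d/dy (x + y) at y = 1) ] at x = 1; the true value is 1.
    inner = lambda x: derivative(lambda y: add(x, y), 1.0, fresh_tag)
    return derivative(lambda x: mul(x, inner(x)), 1.0, fresh_tag)

with_levels = nested(fresh_tag=True)      # 1.0
without_levels = nested(fresh_tag=False)  # 2.0, the inner call consumed the outer tangent
```

With a shared tag, the inner derivative reads the outer perturbation of `x` as if it were its own. A reverse-mode system has to guard cotangents against the same kind of leakage between levels.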

Checkpointing for Higher-Order Reverse

Checkpointing trades recomputation for memory.

In first-order reverse mode, checkpointing avoids saving every intermediate value. During the backward pass, some values are recomputed as needed.

In higher-order reverse mode, checkpointing becomes more complicated because recomputation itself may occur inside a differentiated computation.

A checkpoint must specify:

| Question | Requirement |
| --- | --- |
| what is saved | primal values, derivative values, or both |
| what is recomputed | forward code, backward code, or nested code |
| at which AD level | derivative level must remain consistent |
| with what side effects | recomputation must preserve semantics |

Checkpointing remains essential, but the implementation must be level-aware.
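As a baseline, here is a first-order sketch of the trade-off on an assumed toy program, a chain of `n` applications of `sin`. The function names and the stored-value counts are illustrative. In a higher-order setting the recomputation loop itself would have to be traced at the correct derivative level, which this sketch does not attempt.

```python
import math

def grad_full_tape(x0, n):
    """d x_n / d x_0 for x_{i+1} = sin(x_i), saving every intermediate."""
    xs = [x0]
    for _ in range(n):
        xs.append(math.sin(xs[-1]))
    x_bar = 1.0
    for i in reversed(range(n)):
        x_bar *= math.cos(xs[i])
    return x_bar, len(xs)            # gradient, number of stored values

def grad_checkpointed(x0, n, every):
    """Same gradient, saving only every `every`-th state and recomputing the rest."""
    checkpoints = {}
    x = x0
    for i in range(n):
        if i % every == 0:
            checkpoints[i] = x
        x = math.sin(x)

    x_bar = 1.0
    peak = len(checkpoints)
    for start in sorted(checkpoints, reverse=True):
        stop = min(start + every, n)
        segment = [checkpoints[start]]           # recompute x_start .. x_{stop-1}
        for _ in range(start, stop - 1):
            segment.append(math.sin(segment[-1]))
        for value in reversed(segment):
            x_bar *= math.cos(value)
        peak = max(peak, len(checkpoints) + len(segment))
    return x_bar, peak               # gradient, upper bound on values held at once

full_grad, full_stored = grad_full_tape(0.5, 100)
ckpt_grad, ckpt_stored = grad_checkpointed(0.5, 100, every=10)
```

Both functions return the same derivative. The checkpointed version holds far fewer values at once and pays for it with one extra forward evaluation of each segment.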

Implementation Strategy

A practical higher-order reverse system usually needs a disciplined internal representation.

Good designs tend to make these objects explicit:

primal value
tangent value
cotangent value
AD level
saved residuals
linearized function
transpose rule

The system should avoid hiding derivative semantics inside opaque mutation-heavy runtime code.

One useful design is to split differentiation into two phases:

linearize(f, x) -> y, linearized_program
transpose(linearized_program) -> reverse_program

Then higher-order differentiation can operate on the linearized program representation rather than on ad hoc tape operations.
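A toy sketch of this two-phase split, under assumptions made for illustration: the linearized program is a list of `(dst, src, coeff)` instructions meaning `slots[dst] += coeff * slots[src]`, and the example function is $\sin(x_0 x_1)$. The names `linearize_example`, `transpose`, and `run` are hypothetical.

```python
import math

def linearize_example(x0, x1):
    """Linearize sin(x0 * x1) at a point; return the value and a linear program."""
    z = x0 * x1
    y = math.sin(z)
    # Slots: 0 -> x0, 1 -> x1, 2 -> z, 3 -> y. Coefficients are the saved residuals.
    program = [(2, 0, x1), (2, 1, x0), (3, 2, math.cos(z))]
    return y, program

def transpose(program):
    """Reverse the instruction order and swap source and destination."""
    return [(src, dst, coeff) for dst, src, coeff in reversed(program)]

def run(program, slots):
    """Interpret a linear program; the same interpreter serves both directions."""
    slots = list(slots)
    for dst, src, coeff in program:
        slots[dst] += coeff * slots[src]
    return slots

y, linear_program = linearize_example(2.0, 3.0)

# Forward direction: push a tangent (dx0, dx1) = (1, 0) through to the output slot.
jvp = run(linear_program, [1.0, 0.0, 0.0, 0.0])[3]

# Reverse direction: seed the output cotangent and read the input slots.
reverse_program = transpose(linear_program)
x0_bar, x1_bar = run(reverse_program, [0.0, 0.0, 0.0, 1.0])[:2]
```

Because the reverse program is ordinary data in the same instruction format, a further differentiation pass can analyze it like any other program instead of reverse-engineering tape side effects.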

Practical Guidance

Higher-order reverse mode is powerful, but it should be used with care.

Use it when:

| Use case | Reason |
| --- | --- |
| differentiating optimization procedures | outer gradients require gradients of gradients |
| meta-learning | training rules are themselves differentiated |
| implicit layers | derivatives of solver outputs are needed |
| curvature analysis | second-order information is required |
| differentiable programming languages | nested AD is part of the language model |

Avoid treating it as a default replacement for simpler methods.

For Hessian-vector products, prefer forward-over-reverse when available. For full Hessians, prefer structured or sparse methods when possible. For higher-order derivatives beyond second order, consider Taylor mode or specialized higher-order representations.

Design Principle

First-order reverse mode can be implemented as a runtime technique.

Higher-order reverse mode needs semantic discipline.

The backward pass must be a differentiable program with clear rules for values, adjoints, mutation, saved residuals, and derivative levels. Without that structure, nested reverse mode becomes fragile, memory-heavy, and prone to silent mathematical errors.