Memory Explosion

Reverse-mode automatic differentiation trades computation for memory. To compute gradients efficiently, the backward pass requires access to intermediate values produced during the forward pass. For large computational graphs, storing these intermediates can dominate total system cost.

This phenomenon is known as memory explosion.

It is one of the central engineering constraints in modern automatic differentiation systems.

Memory pressure limits:

  • model size,
  • batch size,
  • sequence length,
  • simulation resolution,
  • graph depth,
  • and hardware utilization.

In many large systems, computation is abundant while memory bandwidth and capacity are scarce. Modern AD design therefore revolves around managing intermediate state efficiently.

Why Reverse Mode Requires Memory

Consider a computation:

$$ x \to v_1 \to v_2 \to \cdots \to y. $$

Reverse mode computes gradients backward:

$$ \bar{v}_i = \frac{\partial y}{\partial v_i}. $$

Many local reverse rules require primal values.

Example:

$$ z = xy. $$

Backward propagation uses:

$$ \bar{x} \mathrel{+}= \bar{z}y, $$

$$ \bar{y} \mathrel{+}= \bar{z}x. $$

The backward pass therefore needs access to both:

$$ x, y. $$

If the forward computation has millions or billions of operations, storing every intermediate becomes extremely expensive.

Linear Growth of Activation Memory

Suppose a computation graph contains:

$$ n $$

operations, each producing an activation of size:

$$ s. $$

Naively, reverse mode stores:

$$ O(ns) $$

memory.

For deep neural networks:

Component Typical scale
Activations GBs to TBs
Parameters MBs to hundreds of GBs
Gradients Similar to parameter size
Optimizer state 2× to 8× parameter size

Activation memory often dominates total memory usage during training.

Memory in Deep Networks

Consider a transformer with:

  • depth $L$,
  • sequence length $T$,
  • hidden dimension $d$,
  • batch size $B$.

Activation tensors scale roughly as:

$$ O(BTdL). $$

Attention layers that materialize their score matrices add terms quadratic in sequence length, per layer and per head:

$$ O(BT^2). $$

Long-context transformers therefore experience rapid memory growth.

Example:

Sequence length Attention memory
1K manageable
8K large
32K severe
128K often impractical

Memory often becomes the bottleneck before computation does.
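As a rough check on the table above, the sketch below computes the bytes needed to materialize one $T \times T$ score matrix ($QK^T$). It assumes batch size 1, a single head, a single layer, and 2-byte (float16 or bfloat16) elements. `attention_score_bytes` is a hypothetical helper; real models multiply the result by the number of heads and layers unless the scores are never materialized.

```python
def attention_score_bytes(batch: int, heads: int, seq_len: int, bytes_per_elem: int = 2) -> int:
    """Bytes to materialize one layer's T x T score matrices (QK^T) for every head."""
    return batch * heads * seq_len * seq_len * bytes_per_elem


GiB = 1024 ** 3
for t in (1024, 8192, 32768, 131072):
    per_head = attention_score_bytes(batch=1, heads=1, seq_len=t)
    print(f"T={t:>6}: {per_head / GiB:8.3f} GiB per head per layer")
```

Under these assumptions a single score matrix grows from 2 MiB at $T = 1024$ to 32 GiB at $T = 131072$, before multiplying by heads and layers.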

Reverse Mode as a Tape

Many AD systems conceptualize reverse mode as a tape.

During forward execution:

  1. execute operation,
  2. record metadata,
  3. store required intermediates.

During backward execution:

  1. traverse tape backward,
  2. retrieve stored values,
  3. apply reverse rules.

The tape may contain:

Stored item Purpose
Primal values Local derivatives
Tensor shapes Broadcasting and reduction
Data types Correct gradient kernels
Operation identifiers Backward dispatch
Control flow metadata Dynamic graph reconstruction

Large graphs therefore create large tapes.
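A minimal tape can be sketched in Python under simplifying assumptions: scalar values only, two operations, and one closure per operation standing in for a tape entry. `Var` and `Tape` are hypothetical names, not a library API. The multiplication entry stores $x$ and $y$ exactly as the rules $\bar{x} \mathrel{+}= \bar{z}y$ and $\bar{y} \mathrel{+}= \bar{z}x$ require.

```python
from typing import Callable, List


class Var:
    """Scalar value plus its adjoint (the 'bar' quantity)."""

    def __init__(self, value: float):
        self.value = value
        self.grad = 0.0


class Tape:
    """Each entry is a closure holding the primal values its reverse rule needs."""

    def __init__(self) -> None:
        self.entries: List[Callable[[], None]] = []

    def mul(self, x: Var, y: Var) -> Var:
        z = Var(x.value * y.value)
        x_val, y_val = x.value, y.value  # stored intermediates for the reverse rule

        def reverse() -> None:
            x.grad += z.grad * y_val
            y.grad += z.grad * x_val

        self.entries.append(reverse)
        return z

    def add(self, x: Var, y: Var) -> Var:
        z = Var(x.value + y.value)

        def reverse() -> None:  # addition needs no primal values
            x.grad += z.grad
            y.grad += z.grad

        self.entries.append(reverse)
        return z

    def backward(self, out: Var) -> None:
        out.grad = 1.0
        for reverse in reversed(self.entries):  # traverse the tape backward
            reverse()


tape = Tape()
x, y = Var(3.0), Var(4.0)
out = tape.add(tape.mul(x, y), x)  # out = x*y + x
tape.backward(out)
print(x.grad, y.grad)  # 5.0 3.0
```

Every recorded operation appends one closure, and each closure keeps its saved values (and the `Var` objects it references) alive until the tape is discarded, so tape memory grows with the number of operations executed.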

Wengert Lists

A Wengert list stores intermediate variables explicitly:

$$ v_1, v_2, \dots, v_n. $$

Each variable depends on previous variables.

Example:

$$ v_1 = x_1 x_2, $$

$$ v_2 = \sin(v_1), $$

$$ v_3 = v_2 + x_3. $$

Reverse mode traverses:

$$ v_3 \to v_2 \to v_1. $$

The larger the dependency chain, the larger the stored state.
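The three-line Wengert list above can be written out as explicit data. In this sketch the operation names and the dictionary layout are illustrative assumptions: the forward sweep keeps every $v_i$ in `values`, and the reverse sweep reads them back in the order $v_3 \to v_2 \to v_1$.

```python
import math
from typing import Dict, List, Tuple

# Each entry: (output name, operation, input names), in execution order.
Trace = List[Tuple[str, str, Tuple[str, ...]]]


def forward(trace: Trace, inputs: Dict[str, float]) -> Dict[str, float]:
    values = dict(inputs)  # every v_i is kept: this dict is the stored state
    for out, op, args in trace:
        a = [values[name] for name in args]
        if op == "mul":
            values[out] = a[0] * a[1]
        elif op == "sin":
            values[out] = math.sin(a[0])
        elif op == "add":
            values[out] = a[0] + a[1]
        else:
            raise ValueError(f"unknown op {op}")
    return values


def reverse(trace: Trace, values: Dict[str, float], output: str) -> Dict[str, float]:
    adj = {name: 0.0 for name in values}
    adj[output] = 1.0
    for out, op, args in reversed(trace):  # v3 -> v2 -> v1
        a = [values[name] for name in args]  # primal values read back from storage
        g = adj[out]
        if op == "mul":
            adj[args[0]] += g * a[1]
            adj[args[1]] += g * a[0]
        elif op == "sin":
            adj[args[0]] += g * math.cos(a[0])
        elif op == "add":
            adj[args[0]] += g
            adj[args[1]] += g
    return adj


trace: Trace = [
    ("v1", "mul", ("x1", "x2")),
    ("v2", "sin", ("v1",)),
    ("v3", "add", ("v2", "x3")),
]
values = forward(trace, {"x1": 2.0, "x2": 3.0, "x3": 1.0})
grads = reverse(trace, values, "v3")
print(grads["x1"], grads["x2"], grads["x3"])
```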

Dynamic Graphs

Dynamic graph systems allocate graph structures at runtime.

Examples include:

  • eager execution frameworks,
  • dynamic control flow,
  • recursive differentiable programs.

Dynamic graphs increase memory pressure because:

  • graph structure itself consumes memory,
  • metadata cannot always be statically optimized,
  • allocations become fragmented,
  • runtime bookkeeping increases overhead.

Static graph compilers can optimize memory reuse more aggressively.

Memory Fragmentation

Memory explosion is not only about total size.

Fragmentation matters.

Suppose free memory exists in many small blocks rather than one contiguous region. Large tensor allocations may fail even though nominal free memory appears sufficient.

GPU allocators therefore use:

  • pooling,
  • caching allocators,
  • arena allocation,
  • and tensor reuse strategies.

Fragmentation becomes severe in dynamic workloads with varying tensor shapes.
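The failure mode can be reproduced with a toy first-fit allocator. `FirstFitPool` is hypothetical and far simpler than a real GPU caching allocator: sizes are abstract units, adjacent free blocks are merged, and allocated blocks are never moved.

```python
from typing import List, Optional, Tuple


class FirstFitPool:
    """Toy first-fit allocator over a fixed address range."""

    def __init__(self, capacity: int):
        self.free: List[Tuple[int, int]] = [(0, capacity)]  # (offset, size), sorted by offset

    def alloc(self, size: int) -> Optional[int]:
        for i, (off, blk) in enumerate(self.free):
            if blk >= size:
                if blk == size:
                    del self.free[i]
                else:
                    self.free[i] = (off + size, blk - size)
                return off
        return None  # no single contiguous block is large enough

    def release(self, off: int, size: int) -> None:
        self.free.append((off, size))
        self.free.sort()
        merged: List[Tuple[int, int]] = []
        for o, s in self.free:  # coalesce neighbouring free blocks
            if merged and merged[-1][0] + merged[-1][1] == o:
                merged[-1] = (merged[-1][0], merged[-1][1] + s)
            else:
                merged.append((o, s))
        self.free = merged

    def total_free(self) -> int:
        return sum(s for _, s in self.free)


pool = FirstFitPool(100)
offsets = [pool.alloc(10) for _ in range(10)]  # fill the pool
for off in offsets[::2]:  # free every other block
    pool.release(off, 10)
print(pool.total_free(), pool.alloc(20))  # 50 None
```

Half of the pool is free, yet a 20-unit request fails because no contiguous free block exceeds 10 units.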

Gradient Storage

Backward propagation requires gradient buffers.

For parameters:

$$ \theta_i, $$

systems store:

$$ \nabla_{\theta_i} L. $$

Optimizer state often multiplies memory further.

Example: Adam optimizer.

For a parameter tensor:

$$ \theta, $$

training with Adam keeps:

Quantity Memory multiplier
Parameters 1×
Gradients 1×
First moment 1×
Second moment 1×

Total:

$$ 4\times $$

parameter memory before activations are considered.
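A minimal Adam update over a flat list of scalars shows where the extra state comes from. This sketch assumes the standard Adam update with bias correction and common default hyperparameters; `Adam` here is a hypothetical class, not a framework's optimizer. The `m` and `v` lists persist across steps with one entry per parameter, so together with the parameters and gradients they account for the $4\times$ total.

```python
import math
from typing import List


class Adam:
    """Minimal Adam over a flat list of parameters (updated in place)."""

    def __init__(self, params: List[float], lr: float = 1e-3,
                 beta1: float = 0.9, beta2: float = 0.999, eps: float = 1e-8):
        self.params = params
        self.m = [0.0] * len(params)  # first moment: one value per parameter
        self.v = [0.0] * len(params)  # second moment: one value per parameter
        self.lr, self.b1, self.b2, self.eps = lr, beta1, beta2, eps
        self.t = 0

    def step(self, grads: List[float]) -> None:
        self.t += 1
        for i, g in enumerate(grads):
            self.m[i] = self.b1 * self.m[i] + (1 - self.b1) * g
            self.v[i] = self.b2 * self.v[i] + (1 - self.b2) * g * g
            m_hat = self.m[i] / (1 - self.b1 ** self.t)  # bias correction
            v_hat = self.v[i] / (1 - self.b2 ** self.t)
            self.params[i] -= self.lr * m_hat / (math.sqrt(v_hat) + self.eps)


params = [1.0, -2.0]
grads = [1.0, -1.0]
opt = Adam(params, lr=1e-3)
opt.step(grads)
# Resident per parameter: value, gradient, m, v, i.e. 4x parameter memory.
print(params)  # approximately [0.999, -1.999]
```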

Higher-Order AD

Higher-order differentiation dramatically increases memory cost.

Suppose reverse mode is nested inside reverse mode.

The outer reverse pass must preserve:

  • primal values,
  • inner reverse state,
  • gradient intermediates,
  • higher-order adjoints.

Naively nesting reverse mode can make memory grow exponentially in the order of differentiation, because each level must record the reverse pass of the level below.

This is one reason why Hessian computation is substantially harder than gradient computation.

Recomputation Tradeoffs

Memory can be reduced by recomputing values instead of storing them.

This creates a tradeoff:

Strategy Memory Computation
Store everything High Low
Recompute everything Low High
Checkpointing Moderate Moderate

The central idea:

Instead of saving all activations, save only selected checkpoints.

Missing values are recomputed during backward execution.

Checkpointing

Checkpointing partitions the graph into segments.

Suppose:

$$ x_0 \to x_1 \to \cdots \to x_n. $$

Rather than storing every $x_i$, store only selected states:

$$ x_0, x_k, x_{2k}, \dots $$

During backward execution:

  1. reload nearest checkpoint,
  2. recompute intermediate states,
  3. continue backward pass.

This reduces memory from:

$$ O(n) $$

toward:

$$ O(\sqrt{n}) $$

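when $k \approx \sqrt{n}$, since storing $n/k$ checkpoints plus one recomputed segment of $k$ states costs $O(n/k + k)$. The price is roughly one extra forward pass. Recursive schedules can reduce memory further, to $O(\log n)$, at the cost of $O(n \log n)$ recomputation.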
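The sketch below applies uniform checkpointing to a scalar chain $x_{i+1} = \mathrm{step}(x_i)$. The map `step` is an arbitrary example chosen for illustration, and scalar states are assumed so that the chain rule reduces to a product of local derivatives. With $n = 100$ and $k = 10$ it keeps 10 checkpoints plus one 10-state segment instead of 100 states, and calls `step` 190 times instead of roughly 100.

```python
import math
from typing import Tuple


def step(x: float) -> float:
    """One forward step x_{i+1} = sin(x_i) + 0.5 * x_i (an arbitrary example map)."""
    return math.sin(x) + 0.5 * x


def step_derivative(x: float) -> float:
    """Local derivative d step / dx, evaluated at the step's input."""
    return math.cos(x) + 0.5


def grad_store_all(x0: float, n: int) -> float:
    """Baseline: store every step input x_0 .. x_{n-1}, i.e. O(n) states."""
    xs = [x0]
    for _ in range(n - 1):
        xs.append(step(xs[-1]))
    g = 1.0
    for xi in reversed(xs):
        g *= step_derivative(xi)
    return g


def grad_checkpointed(x0: float, n: int, k: int) -> Tuple[float, int, int]:
    """Store only x_0, x_k, x_2k, ...; recompute each segment during the backward sweep.

    Returns (d x_n / d x_0, number of checkpoints kept, total calls to step).
    """
    checkpoints = {0: x0}
    x, calls = x0, 0
    for i in range(1, n + 1):  # full forward pass: x becomes x_i
        x = step(x)
        calls += 1
        if i % k == 0 and i < n:
            checkpoints[i] = x
    g = 1.0
    for start in sorted(checkpoints, reverse=True):  # last segment first
        end = min(start + k, n)
        segment = [checkpoints[start]]
        for _ in range(end - start - 1):  # recompute x_{start+1} .. x_{end-1}
            segment.append(step(segment[-1]))
            calls += 1
        for xi in reversed(segment):  # reverse sweep within the segment
            g *= step_derivative(xi)
    return g, len(checkpoints), calls


g_ckpt, n_ckpt, n_calls = grad_checkpointed(1.0, n=100, k=10)
print(n_ckpt, n_calls)  # 10 190
print(math.isclose(g_ckpt, grad_store_all(1.0, 100), rel_tol=1e-12))  # True
```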

Revolve Algorithm

Optimal checkpoint scheduling is a classical problem.

The Revolve algorithm computes checkpoint schedules that minimize the number of recomputed forward steps for a fixed number of checkpoint slots.

It treats reverse-mode differentiation as the problem of reversing a long sequential computation, step by step.

This becomes important in:

  • PDE solvers,
  • climate simulation,
  • differentiable physics,
  • and long time-horizon systems.

Activation Checkpointing in Deep Learning

Modern deep learning systems commonly use activation checkpointing.

Typical policy:

  • save activations at layer boundaries,
  • recompute inside segments.

This enables training larger models on limited hardware.

Tradeoff:

Effect Result
Lower memory Larger models
Higher recomputation Slower training

Large language models rely heavily on this technique.

Reversible Networks

Some architectures reconstruct activations exactly during backward execution.

Example:

$$ y_1 = x_1 + f(x_2), $$

$$ y_2 = x_2 + g(y_1). $$

The original inputs can be recovered:

$$ x_2 = y_2 - g(y_1), $$

$$ x_1 = y_1 - f(x_2). $$

Thus activations need not be stored explicitly.

Reversible networks trade:

  • extra recomputation,
  • stricter architectural constraints,

for dramatically lower memory use.
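The coupling above can be checked numerically. In this sketch `f` and `g` are arbitrary bounded functions chosen for illustration (neither needs to be invertible); the forward loop keeps only the final state, and the reverse loop reconstructs each block's inputs from its outputs. In floating point the reconstruction is exact only up to rounding error.

```python
import math
from typing import Tuple


def f(x: float) -> float:
    return math.tanh(x)  # any function works; it need not be invertible


def g(x: float) -> float:
    return math.sin(x)


def block_forward(x1: float, x2: float) -> Tuple[float, float]:
    y1 = x1 + f(x2)
    y2 = x2 + g(y1)
    return y1, y2


def block_inverse(y1: float, y2: float) -> Tuple[float, float]:
    x2 = y2 - g(y1)
    x1 = y1 - f(x2)
    return x1, x2


depth = 8
state = (0.3, -1.2)
for _ in range(depth):  # forward: intermediate states are not stored
    state = block_forward(*state)

for _ in range(depth):  # backward sweep: rebuild each block's inputs on demand
    state = block_inverse(*state)
print(state)  # approximately (0.3, -1.2)
```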

Gradient Checkpoint Granularity

Checkpoint placement matters.

Fine-grained checkpoints:

Property Result
Low recomputation High memory
High scheduling complexity More metadata

Coarse-grained checkpoints:

Property Result
Lower memory More recomputation
Simpler scheduling Reduced flexibility

Optimal checkpoint placement depends on:

  • graph structure,
  • tensor sizes,
  • recomputation cost,
  • hardware bandwidth.

Memory in Attention

Self-attention is especially memory intensive.

Attention scores require:

$$ QK^T. $$

For sequence length $T$:

$$ O(T^2) $$

memory.

Backward propagation also requires:

  • attention probabilities,
  • normalization statistics,
  • softmax intermediates.

Long-context transformers therefore become memory-bound quickly.

Flash Attention

Flash Attention reduces memory usage by avoiding explicit materialization of large attention matrices.

Instead:

  • compute attention in blocks,
  • fuse operations,
  • recompute partial quantities as needed.

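Extra memory for attention then grows linearly rather than quadratically in sequence length: the forward pass keeps only the output and per-row softmax statistics, and the backward pass recomputes attention blocks from them.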

The key idea:

Trade extra arithmetic for reduced memory traffic.

Modern accelerators are often compute-rich but bandwidth-limited, making this trade favorable.
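The core of the blockwise computation is an online softmax. The sketch below handles a single query with one scalar value per key, an assumption made to keep it short; real kernels process tiles of queries with vector values in fused GPU code, and this is not the FlashAttention kernel itself. It keeps only a running maximum `running_max`, normalizer `norm`, and weighted sum `acc` rather than all $T$ scores, and it agrees with the naive computation.

```python
import math
from typing import List


def dot(a: List[float], b: List[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def naive_attention(q: List[float], keys: List[List[float]], values: List[float]) -> float:
    """Materializes all T scores and weights for one query."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [dot(q, k) * scale for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)


def blocked_attention(q: List[float], keys: List[List[float]], values: List[float],
                      block: int = 4) -> float:
    """Online softmax over key blocks: O(block) scratch instead of O(T)."""
    scale = 1.0 / math.sqrt(len(q))
    running_max = -math.inf  # largest score seen so far
    norm = 0.0               # sum of exp(score - running_max)
    acc = 0.0                # sum of exp(score - running_max) * value
    for start in range(0, len(keys), block):
        scores = [dot(q, k) * scale for k in keys[start:start + block]]
        new_max = max(running_max, max(scores))
        rescale = math.exp(running_max - new_max)  # 0.0 on the first block
        norm = norm * rescale + sum(math.exp(s - new_max) for s in scores)
        acc = acc * rescale + sum(
            math.exp(s - new_max) * v for s, v in zip(scores, values[start:start + block]))
        running_max = new_max
    return acc / norm


q = [0.3, -0.7, 0.5, 0.1]
keys = [[math.sin(i + j) for j in range(4)] for i in range(10)]
vals = [math.cos(i) for i in range(10)]
print(naive_attention(q, keys, vals), blocked_attention(q, keys, vals, block=4))
```

When a block raises the running maximum, `rescale` shrinks the earlier partial sums so every term is measured against the same maximum. The quantity `running_max + log(norm)` is the per-row log-sum-exp, the kind of $O(T)$ statistic a blockwise backward pass can keep in place of the $O(T^2)$ probabilities.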

Offloading

Large systems sometimes move tensors between devices.

Examples:

Offload target Purpose
CPU RAM Extend GPU capacity
NVMe SSD Very large models
Remote memory Distributed systems

Offloading reduces peak GPU memory but introduces transfer latency.

Efficient scheduling becomes essential.

Tensor Compression

Activation storage can be compressed.

Methods include:

Method Idea
Reduced precision float16 or bfloat16
Quantization Integer representation
Sparsification Store only important entries
Delta encoding Store differences
Low-rank compression Factorized activations

Compression reduces memory but may degrade gradient accuracy.
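Quantization, for example, can be sketched as symmetric per-tensor int8 rounding: one byte per stored value plus a single float scale, about a quarter of float32 storage. The helper names are hypothetical, and real systems often use per-channel scales or other schemes. The reconstruction error per element is at most half the scale, which is the accuracy cost described above.

```python
from typing import List, Tuple


def quantize_int8(x: List[float]) -> Tuple[List[int], float]:
    """Symmetric per-tensor quantization to integers in [-127, 127] plus one scale."""
    max_abs = max((abs(v) for v in x), default=0.0)
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    q = [max(-127, min(127, round(v / scale))) for v in x]  # round to nearest, then clamp
    return q, scale


def dequantize_int8(q: List[int], scale: float) -> List[float]:
    return [qi * scale for qi in q]


activations = [0.5, -1.27, 0.0, 0.031, 1.0]
q, scale = quantize_int8(activations)
restored = dequantize_int8(q, scale)
print(q, max(abs(a - b) for a, b in zip(activations, restored)))
```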

In-Place Operations

In-place updates reuse memory:

$$ x \leftarrow x + y. $$

This saves allocations.

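However, an earlier operation may have saved the original value of $x$ for its reverse rule. For $z = xw$, for example, the rule $\bar{w} \mathrel{+}= \bar{z}x$ reads $x$ during the backward pass.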

Unsafe in-place mutation can therefore destroy information needed for gradients.

AD systems carefully track aliasing and mutation dependencies.
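One common safeguard is a version counter: each tensor counts its in-place writes, and a value saved for the backward pass records the count at save time. PyTorch uses a check of this kind; the classes below are a hypothetical pure-Python sketch of the idea, not any framework's implementation.

```python
from typing import List


class Tensor:
    """Hypothetical tensor with a version counter bumped by every in-place write."""

    def __init__(self, data: List[float]):
        self.data = list(data)
        self.version = 0

    def add_(self, other: "Tensor") -> "Tensor":
        for i, v in enumerate(other.data):
            self.data[i] += v  # x <- x + y, reusing x's storage
        self.version += 1
        return self


class SavedTensor:
    """What a reverse rule keeps: a reference plus the version it expects."""

    def __init__(self, t: Tensor):
        self.tensor = t
        self.expected_version = t.version

    def unpack(self) -> List[float]:
        if self.tensor.version != self.expected_version:
            raise RuntimeError("a tensor saved for backward was modified in place")
        return self.tensor.data


x = Tensor([1.0, 2.0])
saved_x = SavedTensor(x)  # z = x * w saves x for w's adjoint
x.add_(Tensor([10.0, 10.0]))  # in-place update after x was saved
try:
    saved_x.unpack()
except RuntimeError as err:
    print("backward would fail:", err)
```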

Static Memory Planning

Static graph compilers can analyze tensor lifetimes.

Suppose two tensors are never simultaneously live.

Then they may share memory.

This resembles register allocation in classical compilers.

Static planning enables:

  • tensor reuse,
  • buffer recycling,
  • preallocation,
  • memory pooling.

Framework compilers aggressively optimize these schedules.
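Lifetime-based sharing can be sketched as interval coloring, much like register allocation. The sketch assumes equal-sized tensors and lifetimes given as inclusive `(first_use, last_use)` step ranges; `plan_buffers` is hypothetical. Processing tensors in order of first use and reusing any buffer whose previous occupant is already dead needs only as many buffers as the maximum number of simultaneously live tensors.

```python
from typing import Dict, List, Tuple


def plan_buffers(lifetimes: Dict[str, Tuple[int, int]]) -> Dict[str, int]:
    """Assign a buffer id to each tensor; tensors with disjoint lifetimes may share one."""
    assignment: Dict[str, int] = {}
    last_use: List[int] = []  # buffer id -> last step its current occupant is live
    for name, (start, end) in sorted(lifetimes.items(), key=lambda kv: kv[1][0]):
        for buf, busy_until in enumerate(last_use):
            if busy_until < start:  # previous occupant is dead: reuse the buffer
                assignment[name] = buf
                last_use[buf] = end
                break
        else:  # every existing buffer is still live: allocate a new one
            assignment[name] = len(last_use)
            last_use.append(end)
    return assignment


lifetimes = {"a": (0, 1), "b": (1, 2), "c": (2, 3), "d": (3, 4)}
print(plan_buffers(lifetimes))  # {'a': 0, 'b': 1, 'c': 0, 'd': 1}
```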

Distributed Memory

Large models exceed single-device memory.

Distributed strategies include:

Strategy Partition
Data parallelism Batch dimension
Tensor parallelism Tensor dimensions
Pipeline parallelism Layers
ZeRO optimization Optimizer states, gradients, parameters (by stage)

Memory becomes a distributed systems problem rather than a single-device problem.
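The effect of partitioning can be estimated with simple per-parameter accounting. The sketch follows the accounting used in the ZeRO paper for mixed-precision Adam: 2 bytes of fp16 parameters, 2 bytes of fp16 gradients, and $k = 12$ bytes of fp32 optimizer state (master weights and both moments) per parameter. It counts model states only, not activations, and `zero_bytes_per_device` is a hypothetical helper.

```python
def zero_bytes_per_device(num_params: float, num_devices: int, stage: int, k: int = 12) -> float:
    """Model-state bytes per device.

    stage 0: nothing partitioned; 1: optimizer state partitioned;
    2: optimizer state and gradients; 3: optimizer state, gradients, and parameters.
    """
    p, n = num_params, num_devices
    if stage == 0:
        return (2 + 2 + k) * p
    if stage == 1:
        return (2 + 2) * p + k * p / n
    if stage == 2:
        return 2 * p + (2 + k) * p / n
    if stage == 3:
        return (2 + 2 + k) * p / n
    raise ValueError("stage must be 0, 1, 2, or 3")


for s in range(4):
    print(s, zero_bytes_per_device(7.5e9, 64, s) / 1e9, "GB")
```

For 7.5 billion parameters on 64 devices this gives 120 GB, about 31.4 GB, about 16.6 GB, and 1.875 GB per device for stages 0 through 3.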

Sparse Activation Systems

Sparse models activate only subsets of parameters.

Mixture-of-experts architectures are a major example.

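Compared with a dense model of the same total parameter count, sparse activation reduces:

  • per-token computation,
  • activation memory.

Optimizer state, however, still scales with the total parameter count, and routing adds metadata, all-to-all communication, and irregular execution.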

Memory Bandwidth vs Capacity

Capacity is only part of the problem.

Bandwidth matters equally.

Backward propagation repeatedly loads:

  • activations,
  • weights,
  • gradients,
  • optimizer states.

Memory traffic often dominates runtime.

Modern AD systems therefore optimize:

  • tensor locality,
  • kernel fusion,
  • cache reuse,
  • recomputation balance.

Computational Graph Lifetime

Some graphs are short-lived.

Others persist across iterations.

Persistent graphs consume memory through:

  • retained references,
  • closure capture,
  • caching,
  • graph history accumulation.

Improper graph cleanup is a common source of memory leaks.
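A common instance is accumulating a loss object rather than its numeric value across iterations. The sketch uses a hypothetical `Node` class whose parent references stand in for an autograd graph, and relies on CPython's reference counting, which frees an object as soon as its last reference disappears.

```python
import weakref
from typing import List, Tuple


class Node:
    """Hypothetical autograd node: a value plus references to the nodes it came from."""

    def __init__(self, value: float, parents: Tuple["Node", ...] = ()):
        self.value = value
        self.parents = parents  # these references keep the whole history reachable


def make_loss(x: float) -> Node:
    """Build a tiny graph: loss = x * x + 1."""
    inp = Node(x)
    hidden = Node(x * x, (inp,))
    return Node(hidden.value + 1.0, (hidden,))


# Safe pattern: keep only the float, so the graph can be freed.
safe_history: List[float] = []
loss = make_loss(3.0)
probe = weakref.ref(loss.parents[0].parents[0])  # watch the input node
safe_history.append(loss.value)
del loss
print(probe() is None)  # True in CPython: nothing references the graph any more

# Leaky pattern: keep the node itself, so every graph stays alive.
leaky_history: List[Node] = []
loss = make_loss(3.0)
probe = weakref.ref(loss.parents[0].parents[0])
leaky_history.append(loss)
del loss
print(probe() is None)  # False: leaky_history still reaches the input node
```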

Gradient Accumulation

Large effective batch sizes may exceed device memory.

Gradient accumulation simulates larger batches:

  1. compute gradients on microbatches,
  2. accumulate gradients,
  3. update parameters later.

This trades:

  • more sequential forward and backward passes per parameter update,
  • a gradient buffer that must persist across microbatches,

for lower activation memory per step.
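For a loss defined as a mean over the batch, scaling each microbatch gradient by $1/M$ for $M$ equal-sized microbatches reproduces the full-batch gradient. The sketch assumes a one-parameter linear model with mean squared error; `mse_grad` and `accumulated_grad` are hypothetical helpers. In a real model only one microbatch's activations would be live at a time, while `acc` plays the role of the single buffer that persists across microbatches.

```python
from typing import List, Tuple


def mse_grad(w: float, batch: List[Tuple[float, float]]) -> float:
    """d/dw of mean((w * x - y)^2) over the batch."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulated_grad(w: float, batch: List[Tuple[float, float]], num_micro: int) -> float:
    size = len(batch) // num_micro
    if size * num_micro != len(batch):
        raise ValueError("this sketch assumes equal-sized microbatches")
    acc = 0.0  # persistent gradient buffer
    for i in range(num_micro):
        micro = batch[i * size:(i + 1) * size]
        acc += mse_grad(w, micro) / num_micro  # scale so the sum equals the full-batch mean
    return acc


data = [(float(i), 2.0 * i + 1.0) for i in range(8)]
print(accumulated_grad(0.5, data, num_micro=4), mse_grad(0.5, data))
```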

Memory Complexity of Reverse Mode

Forward mode complexity:

Quantity Complexity
Memory Small
Compute Scales with inputs

Reverse mode:

Quantity Complexity
Memory Potentially very large
Compute Scales with outputs

Reverse mode's low cost for functions with many inputs and few outputs is paid for partly in memory.

This is a fundamental tradeoff.

Core Idea

Reverse-mode automatic differentiation requires access to intermediate program state during backward propagation. As computational graphs grow larger, storing these intermediates becomes a dominant systems constraint. Memory explosion is therefore not an implementation accident but a structural consequence of reverse accumulation. Modern AD systems manage this through checkpointing, recomputation, reversible computation, compression, static planning, and distributed execution strategies.