AD in Python


Python became the dominant language for modern machine learning and differentiable computing because it combines a simple programming model with access to high-performance native libraries. Most Python automatic differentiation systems therefore follow a hybrid architecture:

| Layer | Role |
|---|---|
| Python frontend | User-facing model and control logic |
| Tensor runtime | Dense array execution |
| AD engine | Gradient propagation |
| Native backend | CPU/GPU kernels in C/C++/CUDA |
| Compiler subsystem | Graph optimization and lowering |

Python itself is slow for numerical kernels, so the heavy tensor operations execute outside the Python interpreter, in native code. The AD system therefore differentiates tensor programs whose control flow is driven by Python.

Tensor-Centric Computation

Modern Python AD systems are built around tensors.

A tensor object typically contains:

| Component | Meaning |
|---|---|
| Shape | Tensor dimensions |
| Dtype | Numeric type |
| Storage | Underlying memory |
| Device | CPU, GPU, TPU |
| Gradient metadata | Information for reverse mode |
| Graph reference | Dependency structure |

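A simple example, written here with PyTorch:

import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x * x      # elementwise square; the operation is recorded
z = y.sum()    # scalar output
z.backward()   # reverse pass fills x.grad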

After execution:

x.grad

contains:

$$ \nabla_x z = 2x = [2, 4, 6] $$

The user writes imperative Python code, but the AD system records tensor operations and constructs derivative propagation rules internally.

Dynamic Computational Graphs

Many Python systems use dynamic graphs.

A graph is built during execution. Each tensor operation creates graph nodes connecting inputs to outputs.

For example:

a = x + 1
b = sin(x)
c = a * b

produces a graph:

x
├── add ── a
├── sin ── b
└── mul(a,b) ── c

Each node stores:

| Field | Purpose |
|---|---|
| Operation type | Determines derivative rule |
| Inputs | Parent references |
| Outputs | Result tensor |
| Saved tensors | Needed for backward pass |
| Backward function | Propagates adjoints |

Dynamic graphs are flexible because they naturally support:

  • Loops
  • Branches
  • Recursion
  • Variable shapes
  • Interactive execution

This matched the needs of machine learning research, where models evolve rapidly.
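The node fields listed earlier, and the way a Python loop reshapes the graph, can be illustrated with a minimal scalar sketch. The names `Node`, `Tensor`, `sin`, `ops_in_graph`, and `model` are hypothetical and do not correspond to any particular framework. Real systems record tensor operations in native code.

```python
import math
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple


@dataclass
class Node:
    op: str                                          # operation type: selects the derivative rule
    inputs: List["Tensor"]                           # parent references
    saved: Tuple[float, ...]                         # values saved for the backward pass
    backward: Callable[[float], Tuple[float, ...]]   # output adjoint -> input adjoints
    # The output is the Tensor that owns this node.


class Tensor:
    """Scalar stand-in for a tensor; `node` is None for leaf values."""

    def __init__(self, value: float, node: Optional[Node] = None):
        self.value = value
        self.node = node

    def __add__(self, other: "Tensor") -> "Tensor":
        node = Node("add", [self, other], (), lambda g: (g, g))
        return Tensor(self.value + other.value, node)

    def __mul__(self, other: "Tensor") -> "Tensor":
        a, b = self.value, other.value
        node = Node("mul", [self, other], (a, b), lambda g: (g * b, g * a))
        return Tensor(a * b, node)


def sin(t: Tensor) -> Tensor:
    x = t.value
    node = Node("sin", [t], (x,), lambda g: (g * math.cos(x),))
    return Tensor(math.sin(x), node)


def ops_in_graph(out: Tensor) -> List[str]:
    """List recorded operations reachable from `out`, inputs before outputs."""
    seen, order = set(), []

    def visit(t: Tensor) -> None:
        if t.node is None or id(t) in seen:
            return
        seen.add(id(t))
        for parent in t.node.inputs:
            visit(parent)
        order.append(t.node.op)

    visit(out)
    return order


def model(x: Tensor, steps: int) -> Tensor:
    # An ordinary Python loop: each iteration appends one mul and one sin node.
    y = x
    for _ in range(steps):
        y = sin(y * x)
    return y
```

With `steps=3`, `ops_in_graph(model(Tensor(0.5), 3))` lists three `mul`/`sin` pairs. With `steps=1` it lists one pair. The graph is exactly what the Python code executed on that call.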

Reverse Mode in Python Systems

Most Python AD frameworks optimize for scalar-loss reverse mode because neural network training requires gradients with respect to many parameters.

Suppose:

y = f(x)

where:

$$ f : \mathbb{R}^n \rightarrow \mathbb{R} $$

Reverse mode computes:

$$ \nabla f(x) $$

with cost bounded by a small constant multiple of the cost of evaluating f, independent of the input dimension n.

Internally, the reverse pass traverses the graph backward:

  1. Initialize output adjoint to 1.
  2. Visit graph nodes in reverse topological order.
  3. Apply local derivative rules.
  4. Accumulate gradients into parent tensors.

For multiplication:

$$ z = xy $$

the reverse rule is:

$$ \bar{x} += \bar{z}y $$

$$ \bar{y} += \bar{z}x $$
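The four steps and the multiplication rule above fit in a few dozen lines for scalars. This is a minimal sketch, not any framework's engine. `Var`, `sin`, and `backward` are hypothetical names, and a real system propagates tensor adjoints rather than floats.

```python
import math
from typing import Callable, List, Tuple


class Var:
    """A scalar that remembers the operation that produced it (hypothetical)."""

    def __init__(self, value: float, parents: Tuple["Var", ...] = (),
                 local_grads: Callable[[], Tuple[float, ...]] = lambda: ()):
        self.value = value
        self.parents = parents          # inputs of the producing operation
        self.local_grads = local_grads  # partial derivatives w.r.t. each parent
        self.grad = 0.0                 # adjoint, filled in by backward()

    def __add__(self, other: "Var") -> "Var":
        return Var(self.value + other.value, (self, other), lambda: (1.0, 1.0))

    def __mul__(self, other: "Var") -> "Var":
        # z = x * y: dz/dx = y and dz/dy = x, as in the rule above
        return Var(self.value * other.value, (self, other),
                   lambda: (other.value, self.value))


def sin(x: Var) -> Var:
    return Var(math.sin(x.value), (x,), lambda: (math.cos(x.value),))


def backward(out: Var) -> None:
    # Build a topological order (every node after its inputs) with a DFS.
    order: List[Var] = []
    seen = set()

    def visit(v: Var) -> None:
        if id(v) in seen:
            return
        seen.add(id(v))
        for p in v.parents:
            visit(p)
        order.append(v)

    visit(out)
    out.grad = 1.0                                        # 1. seed the output adjoint
    for v in reversed(order):                             # 2. reverse topological order
        for p, d in zip(v.parents, v.local_grads()):      # 3. apply the local rule
            p.grad += v.grad * d                          # 4. accumulate into parents
```

Consider c = (x + 1) · sin(x). Calling `backward(c)` leaves sin(x) + (x + 1) cos(x) in `x.grad`. Here x receives contributions along two paths, which is why step 4 accumulates with `+=` rather than assigning.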


The Python frontend hides this machinery, but the runtime manages graph traversal, tensor lifetimes, and adjoint accumulation.

Eager Execution

Many Python systems use eager execution.

Operations execute immediately:

x = tensor([1, 2, 3])
y = x + 1

y is computed immediately rather than deferred into a static graph.

Advantages include:

| Advantage | Explanation |
|---|---|
| Debuggability | Intermediate values are visible |
| Natural control flow | Python semantics preserved |
| Interactive workflows | REPL and notebooks work naturally |
| Simpler mental model | Execution follows source order |

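The downside is that eager programs are harder to optimize. Because the runtime sees operations one at a time, it has little global visibility into the program. That limits whole-program optimizations such as kernel fusion and memory planning.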

Tracing and Graph Capture

To recover optimization opportunities, many systems trace Python functions into graph representations.

Example:

def f(x):
    return sin(x) + x * x

Tracing executes the function with special tensor objects that record operations instead of performing ordinary computation.
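A minimal sketch of tracing follows. It assumes a hypothetical `Tracer` class and a `sin` function that dispatches on its argument type. The same Python function runs either on ordinary floats or on tracers. In the tracer case it only records a list of `(output, op, inputs)` entries.

```python
import math
from typing import Callable, List, Tuple

Entry = Tuple[str, str, Tuple[str, ...]]   # (output name, op, input names)


class Tracer:
    """Placeholder tensor that records operations instead of computing them."""

    def __init__(self, name: str, graph: List[Entry]):
        self.name = name
        self.graph = graph          # shared list that every traced op appends to

    def _emit(self, op: str, *args: "Tracer") -> "Tracer":
        out = f"t{len(self.graph)}"
        self.graph.append((out, op, tuple(a.name for a in args)))
        return Tracer(out, self.graph)

    def __add__(self, other: "Tracer") -> "Tracer":
        return self._emit("add", self, other)

    def __mul__(self, other: "Tracer") -> "Tracer":
        return self._emit("mul", self, other)


def sin(x):
    # Record when tracing; compute normally otherwise.
    if isinstance(x, Tracer):
        return x._emit("sin", x)
    return math.sin(x)


def trace(fn: Callable, *arg_names: str) -> Tuple[List[Entry], str]:
    graph: List[Entry] = []
    result = fn(*(Tracer(n, graph) for n in arg_names))
    return graph, result.name


def f(x):
    return sin(x) + x * x
```

`trace(f, "x")` returns three entries (`sin`, `mul`, `add`), while `f(0.5)` still computes a number. The function runs only once with tracers. A Python `if` whose condition depends on a tensor value therefore cannot be captured this way, which is a typical limitation of tracing-based capture.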

The result is an intermediate graph:

x
├── sin
├── mul(x,x)
└── add

The graph can then be optimized, fused, compiled, or lowered to accelerators.
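As one example of such an optimization, here is a toy constant-folding pass over a traced program. Two things are assumptions made for illustration: the instruction format (a destination name, an op, and arguments that are either value names or float literals) and the function `constant_fold`.

```python
import operator
from typing import Dict, List, Tuple, Union

Arg = Union[str, float]                    # a value name or a float literal
Instr = Tuple[str, str, Tuple[Arg, ...]]   # (destination, op, args)

OPS = {"add": operator.add, "mul": operator.mul}


def constant_fold(program: List[Instr]) -> Tuple[List[Instr], Dict[str, float]]:
    """Evaluate instructions whose arguments are all constants at compile time."""
    known: Dict[str, float] = {}
    kept: List[Instr] = []
    for dst, op, args in program:
        # Substitute names that an earlier instruction already folded.
        resolved = tuple(known.get(a, a) if isinstance(a, str) else a for a in args)
        if all(isinstance(a, float) for a in resolved):
            known[dst] = OPS[op](*resolved)   # computed once, never emitted
        else:
            kept.append((dst, op, resolved))
    return kept, known
```

Take the program `c0 = 2.0 * 3.0; t1 = x * c0; t2 = t1 + 1.0`. The pass folds `c0` to 6.0 and keeps two instructions, `t1 = x * 6.0` and `t2 = t1 + 1.0`. A production pass would also need to handle integer literals and dead values.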

Tracing enables:

| Optimization | Purpose |
|---|---|
| Kernel fusion | Reduce kernel launch overhead and memory traffic |
| Constant folding | Precompute values that depend only on constants |
| Memory planning | Reuse buffers |
| Vectorization | Improve throughput |
| Device lowering | Generate accelerator code |

This produces a hybrid execution model:

| Mode | Behavior |
|---|---|
| Eager mode | Flexible interactive execution |
| Traced mode | Optimized graph execution |

Static vs Dynamic Graph Systems

Early Python AD systems often used static graphs.

The user first defined a graph:

x = placeholder()
y = x * x

and later executed it:

session.run(y, feed_dict={x: ...})

Static graphs enabled aggressive optimization but created awkward programming models.
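The define-then-run split can be sketched in plain Python. `Placeholder`, `Op`, and `run` are hypothetical stand-ins that only mimic the shape of a TensorFlow 1.x session call. Building `y` performs no arithmetic; values appear only when `run` is given a feed dictionary.

```python
from typing import Dict


class Node:
    """Symbolic graph node; operators build new nodes instead of computing."""

    def __add__(self, other: "Node") -> "Node":
        return Op("add", self, other)

    def __mul__(self, other: "Node") -> "Node":
        return Op("mul", self, other)


class Placeholder(Node):
    """An input slot whose value is supplied only at run time."""


class Op(Node):
    def __init__(self, kind: str, lhs: Node, rhs: Node):
        self.kind, self.lhs, self.rhs = kind, lhs, rhs


def run(node: Node, feed_dict: Dict[Node, float]) -> float:
    """Execute a previously defined graph with concrete inputs."""
    if isinstance(node, Placeholder):
        return feed_dict[node]
    a = run(node.lhs, feed_dict)
    b = run(node.rhs, feed_dict)
    return a * b if node.kind == "mul" else a + b
```

With `x = Placeholder()` and `y = x * x`, the same graph can be run repeatedly; for example, `run(y, {x: 3.0})` gives 9.0. Python control flow executes only while the graph is being built. Data-dependent branches therefore need graph-level constructs, which is the awkwardness mentioned above.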

Dynamic systems later became dominant because they matched ordinary Python execution.

The distinction today is less strict. Modern systems often combine both:

| System | Main style |
|---|---|
| PyTorch | Dynamic eager execution |
| TensorFlow 1.x | Static graph |
| TensorFlow 2.x | Eager + tracing |
| JAX | Functional tracing |
| Tinygrad | Minimal dynamic graph |
| MindSpore | Graph-oriented hybrid execution |

Mutation and In-Place Operations

Python tensor systems frequently support mutation:

x += y

Mutation complicates reverse mode because the old value of x may be needed during the backward pass.

Systems handle this differently.

| Strategy | Explanation |
|---|---|
| Disallow unsafe mutation | Simplifies correctness |
| Version counters | Detect illegal overwrites |
| Functionalization | Rewrite mutation into pure operations |
| Copy-on-write | Preserve old values automatically |
| Tape snapshots | Save overwritten tensors |

PyTorch, for example, tracks tensor versions to detect modifications that invalidate gradient computation.
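Below is a minimal sketch of the version-counter strategy. It uses hypothetical `Tensor`, `SavedTensor`, and `square_with_backward` names. The trailing underscore in `add_` follows PyTorch's naming convention for in-place methods, but none of this is PyTorch's implementation. Each in-place write bumps a counter, and the backward pass refuses to use a saved tensor whose counter has changed.

```python
from typing import Callable, List, Tuple


class Tensor:
    """Toy tensor whose version counter is bumped by every in-place write."""

    def __init__(self, data: List[float]):
        self.data = list(data)
        self.version = 0

    def add_(self, other: "Tensor") -> "Tensor":
        # In-place update in the style of `x += y`: overwrite storage, bump version.
        self.data = [a + b for a, b in zip(self.data, other.data)]
        self.version += 1
        return self


class SavedTensor:
    """What a backward node keeps: the tensor plus its version at save time."""

    def __init__(self, t: Tensor):
        self.tensor = t
        self.saved_version = t.version

    def unpack(self) -> Tensor:
        if self.tensor.version != self.saved_version:
            raise RuntimeError(
                "tensor needed for gradient computation was modified in place "
                f"(saved at version {self.saved_version}, now {self.tensor.version})")
        return self.tensor


def square_with_backward(x: Tensor) -> Tuple[Tensor, Callable[[List[float]], List[float]]]:
    """Forward pass of y = x * x plus a closure computing dL/dx = 2x * dL/dy."""
    saved = SavedTensor(x)
    y = Tensor([a * a for a in x.data])

    def backward(grad_y: List[float]) -> List[float]:
        xs = saved.unpack().data          # fails loudly if x was overwritten
        return [2 * a * g for a, g in zip(xs, grad_y)]

    return y, backward
```

If `x.add_(...)` is called after `square_with_backward(x)`, the returned closure raises an error instead of silently using the new values of x. PyTorch's actual bookkeeping is more involved, but it likewise raises a `RuntimeError` for this situation.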

Custom Gradient Functions

Many operations need manually defined derivatives.

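A Python framework usually exposes an API:

class MyOp(Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return ...
    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return ...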

This separates:

| Phase | Role |
|---|---|
| Forward | Compute primal output |
| Context storage | Save required intermediates |
| Backward | Propagate adjoints to the inputs |

Custom gradients are critical for:

  • Numerical stability
  • Efficient kernels
  • External libraries
  • Implicit differentiation
  • Physics simulators
  • Specialized GPU operations
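The forward/context/backward split can be mimicked without any framework. In this sketch `Context`, `Exp`, and `apply` are hypothetical. The method names echo the `ctx.save_for_backward` / `ctx.saved_tensors` pattern of PyTorch's `torch.autograd.Function`, but this is not a real library's dispatch logic.

```python
import math
from typing import Callable, Tuple


class Context:
    """Stores values between the forward and backward passes."""

    def __init__(self) -> None:
        self.saved_tensors: Tuple[float, ...] = ()

    def save_for_backward(self, *values: float) -> None:
        self.saved_tensors = values


class Exp:
    @staticmethod
    def forward(ctx: Context, x: float) -> float:
        y = math.exp(x)
        ctx.save_for_backward(y)        # d/dx exp(x) = exp(x), so keep the output
        return y

    @staticmethod
    def backward(ctx: Context, grad_output: float) -> float:
        y, = ctx.saved_tensors
        return grad_output * y


def apply(fn, x: float) -> Tuple[float, Callable[[float], float]]:
    """Run forward now and return the output with a closure for the backward pass."""
    ctx = Context()
    out = fn.forward(ctx, x)
    return out, lambda g: fn.backward(ctx, g)
```

Saving the output y = exp(x) instead of x lets the backward pass reuse it directly, because the derivative of exp is exp itself. `apply(Exp, 1.0)` returns e together with a closure mapping an upstream gradient g to g·e.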

Higher-Order Differentiation

Python systems increasingly support higher-order derivatives.

Example:

g = grad(f)
h = grad(g)

This requires differentiating the backward pass itself.

The system must ensure:

  • Reverse-mode operations are differentiable
  • Graphs remain valid through nesting
  • Saved tensors survive nested passes
  • Perturbation confusion is avoided
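Differentiating the backward pass itself can be made concrete with a scalar reverse-mode sketch. Its backward pass uses the same recorded operations as the forward pass, so its output is again a differentiable value. `Var`, `gradient_of`, and `grad` are hypothetical names chosen to echo nested calls such as grad(grad(f)); this is not JAX's implementation.

```python
from typing import Callable, Dict, List, Tuple


class Var:
    """Scalar whose parents carry local derivatives that are themselves Vars."""

    def __init__(self, value: float, parents: Tuple[Tuple["Var", "Var"], ...] = ()):
        self.value = value
        self.parents = parents   # (parent, d self / d parent) pairs

    def __add__(self, other) -> "Var":
        other = other if isinstance(other, Var) else Var(float(other))
        return Var(self.value + other.value, ((self, Var(1.0)), (other, Var(1.0))))

    __radd__ = __add__

    def __mul__(self, other) -> "Var":
        other = other if isinstance(other, Var) else Var(float(other))
        return Var(self.value * other.value, ((self, other), (other, self)))

    __rmul__ = __mul__


def gradient_of(out: Var, wrt: Var) -> Var:
    """Reverse pass built from Var operations, so the result can be differentiated again."""
    order: List[Var] = []
    seen = set()

    def visit(v: Var) -> None:
        if id(v) in seen:
            return
        seen.add(id(v))
        for p, _ in v.parents:
            visit(p)
        order.append(v)

    visit(out)
    adjoint: Dict[int, Var] = {id(out): Var(1.0)}
    for v in reversed(order):
        g = adjoint.get(id(v))
        if g is None:
            continue
        for p, local in v.parents:
            contrib = g * local                      # recorded, not just computed
            prev = adjoint.get(id(p))
            adjoint[id(p)] = contrib if prev is None else prev + contrib
    return adjoint.get(id(wrt), Var(0.0))


def grad(f: Callable[[Var], Var]) -> Callable:
    def df(x):
        xv = x if isinstance(x, Var) else Var(float(x))
        return gradient_of(f(xv), xv)
    return df


def cube(x):
    return x * x * x    # f'(x) = 3x^2, f''(x) = 6x, f'''(x) = 6
```

`grad(cube)` returns a `Var` whose graph still reaches x. Wrapping it in another `grad` therefore differentiates the backward computation; at x = 3 the first, second and third derivatives come out as 27, 18 and 6. A real system must also keep saved tensors alive across nested passes, which this sketch gets for free from Python references.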

Higher-order AD is important for:

| Application | Need |
|---|---|
| Meta-learning | Differentiate through inner optimization steps |
| Physics | Curvature information |
| Scientific computing | Hessian-vector products |
| Implicit methods | Jacobian structure |
| Probabilistic inference | Laplace approximations |

Python and Functional Transformations

Some Python AD systems adopt a more functional style.

Instead of mutating tensor objects:

y.backward()

they expose explicit transformations:

grad(f)
vmap(f)
jit(f)
pmap(f)

These transformations compose.

Example:

jit(grad(f))

means:

  1. Differentiate f
  2. Compile the resulting derivative program

This model treats differentiation as a pure program transformation rather than as a side effect attached to tensors.
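Composable transformations can be sketched as higher-order functions that take a function and return a new one. Here `grad` uses forward-mode dual numbers, which suffices for a scalar-to-scalar function; JAX's `jax.grad` is reverse mode. The sketch's `vmap` simply loops over a list, whereas a real implementation would vectorize. Both functions and the `Dual` class are illustrative assumptions.

```python
import math
from typing import Callable, List


class Dual:
    """Forward-mode value: a primal and its tangent (minimal scalar sketch)."""

    def __init__(self, primal: float, tangent: float):
        self.primal, self.tangent = primal, tangent

    def __add__(self, other: "Dual") -> "Dual":
        return Dual(self.primal + other.primal, self.tangent + other.tangent)

    def __mul__(self, other: "Dual") -> "Dual":
        return Dual(self.primal * other.primal,
                    self.primal * other.tangent + self.tangent * other.primal)


def sin(x: Dual) -> Dual:
    return Dual(math.sin(x.primal), math.cos(x.primal) * x.tangent)


def grad(f: Callable[[Dual], Dual]) -> Callable[[float], float]:
    """Transformation: returns a new function computing f'(x); f is untouched."""
    return lambda x: f(Dual(x, 1.0)).tangent


def vmap(f: Callable[[float], float]) -> Callable[[List[float]], List[float]]:
    """Transformation: lifts a per-example function to a batch (here a list)."""
    return lambda xs: [f(x) for x in xs]


def f(x: Dual) -> Dual:
    return sin(x) + x * x          # f'(x) = cos(x) + 2x
```

`vmap(grad(f))([0.0, 1.0, 2.0])` returns cos(x) + 2x for each input. Neither transformation mutates `f` or any value. Each returns a fresh function, which is what allows arbitrary nesting.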

Compilation Pipelines

Modern Python AD frameworks often lower programs into compiler IRs.

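The pipeline may look like:

Python
→ traced graph
→ normalized IR
→ optimized IR
→ backend lowering
→ machine code or accelerator kernels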

Common IR forms include:

| IR | Purpose |
|---|---|
| FX graphs | Python graph capture |
| HLO | Tensor compiler IR |
| MLIR | Multi-level compiler infrastructure |
| XLA graphs | Accelerator optimization |
| TorchScript IR | PyTorch compilation |

AD may operate at multiple levels:

| Level | AD approach |
|---|---|
| Python AST | Source transformation |
| Runtime graph | Dynamic tracing |
| Tensor IR | Graph-level AD |
| LLVM IR | Low-level compiler differentiation |

Interaction with NumPy

NumPy heavily influenced Python AD systems.

Many frameworks mimic NumPy APIs:

sin
exp
matmul
reshape
broadcast_to
sum
transpose

This allows numerical code to become differentiable with minimal changes.

However, ordinary NumPy arrays do not carry gradient metadata. Frameworks therefore provide tensor types that emulate NumPy behavior while tracking derivatives.

Compatibility layers are essential for ecosystem adoption.

GPU and Accelerator Execution

Python frameworks usually execute tensor kernels outside Python.

The Python interpreter orchestrates computation, but dense operations are dispatched to:

| Backend | Typical implementation |
|---|---|
| CPU | BLAS, vectorized kernels |
| GPU | CUDA or ROCm kernels |
| TPU | XLA-compiled programs |
| Specialized accelerators | Vendor-specific runtimes |

AD systems must therefore manage:

  • Device placement
  • Gradient synchronization
  • Memory transfers
  • Kernel scheduling
  • Mixed precision

The derivative computation becomes part of a distributed runtime system.

Memory Management

Reverse mode requires storing intermediates from the forward pass.

Memory costs can dominate execution.

Strategies include:

| Technique | Purpose |
|---|---|
| Gradient checkpointing | Recompute instead of storing |
| Activation rematerialization | Trade compute for memory |
| Buffer reuse | Reduce allocations |
| Lazy gradient allocation | Allocate only when needed |
| Static memory planning | Optimize graph execution |

Large neural networks are often constrained more by activation memory than by arithmetic throughput.
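Gradient checkpointing can be illustrated on a chain of scalar layers x_{i+1} = sin(x_i), whose derivative is the product of the cos(x_i). The names in this sketch are hypothetical, and its "memory" figure simply counts stored floats. `grad_full` keeps every activation. `grad_checkpointed` keeps only every k-th one and recomputes each segment during the backward pass.

```python
import math
from typing import Dict, List, Tuple


def layer(x: float) -> float:
    return math.sin(x)


def layer_grad(x: float) -> float:
    return math.cos(x)          # derivative of the layer at its input x


def grad_full(x0: float, n: int) -> Tuple[float, int]:
    """Store every layer input during the forward pass."""
    acts: List[float] = []
    x = x0
    for _ in range(n):
        acts.append(x)
        x = layer(x)
    g = 1.0
    for a in reversed(acts):    # chain rule, last layer first
        g *= layer_grad(a)
    return g, len(acts)


def grad_checkpointed(x0: float, n: int, k: int) -> Tuple[float, int]:
    """Store only every k-th layer input; recompute the rest segment by segment."""
    checkpoints: Dict[int, float] = {}
    x = x0
    for i in range(n):
        if i % k == 0:
            checkpoints[i] = x
        x = layer(x)
    g = 1.0
    peak = len(checkpoints)
    for start in sorted(checkpoints, reverse=True):
        end = min(start + k, n)
        seg = [checkpoints[start]]
        for _ in range(start + 1, end):      # recompute inputs of layers start+1 .. end-1
            seg.append(layer(seg[-1]))
        peak = max(peak, len(checkpoints) + len(seg))
        for a in reversed(seg):
            g *= layer_grad(a)
    return g, peak
```

With n = 100 and k = 10 both functions return the same gradient. The checkpointed version, however, holds at most 20 values instead of 100, at the price of recomputing most of the forward pass once more. Since storage grows like n/k + k, a k near √n minimizes it.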

Numerical Stability

Naive derivatives can be numerically unstable.

Examples include:

| Operation | Problem |
|---|---|
| Softmax | Overflow |
| Logarithm | Singularities near zero |
| Division | Unstable denominators |
| Exponentials | Exploding gradients |
| Normalization | Small variance instability |

Python AD systems therefore rely heavily on custom stable primitives.

For example:

logsumexp
softplus
cross_entropy
layer_norm

often have carefully engineered backward implementations.
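logsumexp is a good example of why such primitives get dedicated implementations. The sketch below uses only Python's `math` module, and its function names are illustrative. It contrasts the naive formula with the max-shifted one and gives the matching backward rule: the derivative of logsumexp with respect to x_i is softmax(x)_i.

```python
import math
from typing import List


def logsumexp_naive(xs: List[float]) -> float:
    # exp(1000.0) is not representable: math.exp raises OverflowError
    # (NumPy's exp would instead return inf with a warning).
    return math.log(sum(math.exp(x) for x in xs))


def logsumexp(xs: List[float]) -> float:
    # Shift by the maximum so the largest term is exp(0) = 1.
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def logsumexp_backward(xs: List[float], grad_out: float) -> List[float]:
    # d/dx_i logsumexp(x) = softmax(x)_i, computed with the same shift.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [grad_out * e / total for e in exps]
```

For xs = [1000.0, 1000.0] the naive version fails, while the shifted version returns 1000 + log 2. Its backward pass yields [0.5, 0.5].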

Major Python AD Systems

| System | Main characteristics |
|---|---|
| PyTorch | Dynamic eager reverse mode |
| TensorFlow | Hybrid graph/eager system |
| JAX | Functional transformations and tracing |
| Autograd | Pure NumPy-based tracing |
| Tinygrad | Minimal educational framework |
| MindSpore | Graph-oriented execution |
| Chainer | Early dynamic graph system |

These systems differ mainly in:

  • Graph construction strategy
  • Compilation model
  • Mutation semantics
  • Transformation interface
  • Hardware integration

Python as an AD Host Language

Python succeeded because it provided:

| Feature | Importance |
|---|---|
| Simple syntax | Rapid experimentation |
| Scientific ecosystem | NumPy, SciPy, plotting |
| Dynamic execution | Flexible model definition |
| Native extension support | Access to optimized kernels |
| Interactive workflow | Notebook-based research |

The AD engine is usually not written primarily in Python. Python acts as the orchestration layer above highly optimized native runtimes and compiler systems.

Modern Python AD frameworks therefore resemble compiler toolchains hidden behind an imperative scripting interface.