Dynamic Computation Graphs

Deep learning models are built from sequences of mathematical operations. During training, the system must compute not only the forward result of these operations, but also derivatives with respect to model parameters. PyTorch achieves this through dynamic computation graphs.

A computation graph is a directed graph where nodes represent operations or tensors, and edges represent data dependencies between them. When a tensor operation is executed, PyTorch records how the result was produced. This recorded structure allows gradients to be computed automatically during backpropagation.

The defining feature of PyTorch is that this graph is dynamic. The graph is created at runtime as Python code executes.

Static Versus Dynamic Graphs

Deep learning frameworks historically followed two major approaches.

| Approach | Graph construction | Example frameworks |
| --- | --- | --- |
| Static graph | Define graph before execution | Early TensorFlow |
| Dynamic graph | Build graph during execution | PyTorch |

In a static graph system, the computation graph is declared first and executed later. The graph behaves like a compiled program.

In a dynamic graph system, the graph is built operation by operation as Python executes. Each forward pass constructs a new graph.

For example:

import torch

x = torch.tensor(2.0, requires_grad=True)

y = x * x + 3 * x + 1

When this code runs, PyTorch dynamically records the operations:

x
├── multiply(x, x)
├── multiply(3, x)
└── add(...)

The graph exists only because the operations were executed.

This gives PyTorch several advantages:

| Advantage | Meaning |
| --- | --- |
| Natural Python control flow | Use loops, branches, recursion |
| Easier debugging | Inspect intermediate tensors directly |
| Interactive experimentation | Works naturally in notebooks |
| Flexible architectures | Graph structure may vary per input |

Computation as a Graph

Suppose we define

$$y = x^2 + 3x + 1.$$

The computation can be decomposed into simpler operations:

$$a = x^2, \qquad b = 3x, \qquad y = a + b + 1.$$

This forms a graph:

x
├── square ── a
├── multiply by 3 ── b
└──────────────┐
          add operations
               y

Each node stores enough information for differentiation. During the backward pass, gradients propagate in reverse order through this graph.

In PyTorch:

x = torch.tensor(2.0, requires_grad=True)

a = x * x
b = 3 * x
y = a + b + 1

print(y)

The tensor y remembers how it was constructed.

print(y.grad_fn)

The field grad_fn references the operation that produced the tensor.
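As a rough illustration of what grad_fn points to, the sketch below walks the recorded graph backward from y by following next_functions. The helper walk_graph is a hypothetical name introduced here, and the exact node class names (for example AddBackward0) are an implementation detail that may differ between PyTorch versions.

```python
import torch

def walk_graph(fn, depth=0, names=None):
    """Collect the class names of autograd nodes reachable from fn (hypothetical helper)."""
    if names is None:
        names = []
    if fn is None:  # constants such as the scalar 1 have no graph node
        return names
    names.append("  " * depth + type(fn).__name__)
    for parent, _ in fn.next_functions:
        walk_graph(parent, depth + 1, names)
    return names

x = torch.tensor(2.0, requires_grad=True)
a = x * x
b = 3 * x
y = a + b + 1

graph_names = walk_graph(y.grad_fn)
print("\n".join(graph_names))  # indented listing of the nodes behind y
```

The leaf tensor x appears in this listing as an accumulation node, which is where its gradient is eventually written to x.grad.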

The Forward Pass

The forward pass computes outputs from inputs.

A neural network is a composition of functions:

$$x \to h_1 \to h_2 \to \cdots \to y.$$

For a simple multilayer network:

$$h_1 = \sigma(W_1 x + b_1),$$
$$h_2 = \sigma(W_2 h_1 + b_2),$$
$$y = W_3 h_2 + b_3.$$

Each operation creates new tensors and graph nodes.

In PyTorch:

import torch
from torch import nn

model = nn.Sequential(
    nn.Linear(4, 8),
    nn.ReLU(),
    nn.Linear(8, 2),
)

x = torch.randn(3, 4)

y = model(x)

During execution, PyTorch dynamically constructs the graph associated with this forward computation.

The graph records:

  • Tensor operations
  • Parent tensors
  • Operation types
  • Information needed for differentiation
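A small sketch of how this recorded information can be observed from Python, assuming the same two-layer model as above. The input x does not require gradients, yet the output does, because the model parameters take part in the computation.

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
x = torch.randn(3, 4)   # plain input data, not tracked by autograd
y = model(x)

print(x.requires_grad)            # False
print(y.requires_grad)            # True, because the parameters require gradients
print(y.grad_fn)                  # node for the operation that produced y
print(y.grad_fn.next_functions)   # nodes for the inputs of that operation
```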

The Backward Pass

Training requires gradients of the loss with respect to parameters.

Suppose

$$L = \ell(f_\theta(x), y).$$

We need

$$\nabla_\theta L.$$

PyTorch computes these gradients using reverse-mode automatic differentiation, commonly called backpropagation.

The backward pass traverses the computation graph in reverse order.

Example:

x = torch.tensor(2.0, requires_grad=True)

y = x * x + 3 * x + 1

y.backward()

print(x.grad)  # tensor(7.)

The derivative is

$$\frac{dy}{dx} = 2x + 3.$$

At $x = 2$:

$$\frac{dy}{dx} = 7.$$

PyTorch computes this automatically.

The backward pass applies the chain rule repeatedly across the graph.
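For the example above, the chain rule can be written out term by term. This uses the intermediate values $a = x^2$ and $b = 3x$ from the decomposition in "Computation as a Graph", so that $y = a + b + 1$:

$$\frac{dy}{dx} = \frac{\partial y}{\partial a}\frac{da}{dx} + \frac{\partial y}{\partial b}\frac{db}{dx} = 1 \cdot 2x + 1 \cdot 3 = 2x + 3.$$

Each factor corresponds to one edge of the graph, and the sum arises because $x$ reaches $y$ along two paths.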

Reverse-Mode Automatic Differentiation

Deep learning models usually have many parameters and a scalar loss.

Suppose:

$$f : \mathbb{R}^n \to \mathbb{R}.$$

The input dimension $n$ may be millions or billions. The output is usually a scalar loss.

Reverse-mode differentiation is efficient in this setting because it computes gradients of the scalar output with respect to all parameters in one backward traversal.

The chain rule is the core principle.

If

$$z = f(y), \quad y = g(x),$$

then

$$\frac{dz}{dx} = \frac{dz}{dy} \frac{dy}{dx}.$$

PyTorch applies this rule automatically across the graph.
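To make the principle concrete, here is a minimal pure-Python sketch of reverse-mode differentiation on scalars. It is a teaching toy, not how PyTorch is implemented: the class Value and its attributes are hypothetical names, and only addition and multiplication are supported. It shows that one backward traversal yields the derivative with respect to every input.

```python
class Value:
    """A scalar that remembers how it was computed (hypothetical teaching class)."""

    def __init__(self, data, parents=(), local_grads=()):
        self.data = data
        self.grad = 0.0
        self.parents = parents          # input Values of the operation
        self.local_grads = local_grads  # d(output)/d(parent), one per parent

    def __add__(self, other):
        return Value(self.data + other.data, (self, other), (1.0, 1.0))

    def __mul__(self, other):
        return Value(self.data * other.data, (self, other), (other.data, self.data))

    def backward(self):
        # Build a topological order so that each node is processed
        # only after every node that depends on it.
        order, seen = [], set()

        def visit(node):
            if id(node) not in seen:
                seen.add(id(node))
                for parent in node.parents:
                    visit(parent)
                order.append(node)

        visit(self)
        self.grad = 1.0  # d(output)/d(output)
        for node in reversed(order):
            for parent, local in zip(node.parents, node.local_grads):
                parent.grad += node.grad * local  # chain rule, summed over paths


u = Value(2.0)
v = Value(5.0)
z = u * v + u * u   # z = uv + u^2
z.backward()        # a single reverse traversal

print(u.grad, v.grad)  # 9.0 2.0, matching dz/du = v + 2u and dz/dv = u
```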

Leaf Tensors and Gradients

A tensor with requires_grad=True becomes part of gradient tracking.

x = torch.tensor([1.0, 2.0], requires_grad=True)

Such tensors are called leaf tensors if they are created directly by the user rather than by another operation.

By default, gradients accumulate in .grad only for leaf tensors that require gradients.

x = torch.tensor(2.0, requires_grad=True)

y = x * x

y.backward()

print(x.grad)

Intermediate (non-leaf) tensors do not store gradients unless this is explicitly requested with retain_grad().
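The following sketch contrasts a leaf tensor with an intermediate tensor and shows how retain_grad() keeps the gradient of the intermediate one. The variable names are illustrative only.

```python
import torch

x = torch.tensor(2.0, requires_grad=True)  # leaf: created directly by the user
a = x * x                                  # intermediate: produced by an operation
a.retain_grad()                            # ask autograd to keep a.grad as well
y = 3 * a
y.backward()

print(x.is_leaf, a.is_leaf)  # True False
print(x.grad)                # tensor(12.), since dy/dx = 6x
print(a.grad)                # tensor(3.), since dy/da = 3
```

Without the call to retain_grad(), a.grad would remain None after the backward pass.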

Dynamic Graph Construction

The graph is rebuilt every forward pass.

This is important because the executed operations may depend on input data or control flow.

Example:

def f(x):
    if x.sum() > 0:
        return x * 2
    else:
        return x * x

Different inputs may produce different computation graphs.
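A short sketch of this effect, reusing the function f from above: the same Python function yields different gradients depending on which branch was recorded during the forward pass.

```python
import torch

def f(x):
    if x.sum() > 0:
        return x * 2
    else:
        return x * x

# Positive sum: the recorded graph is x * 2, so the gradient is 2 everywhere.
x_pos = torch.tensor([1.0, 2.0], requires_grad=True)
f(x_pos).sum().backward()
print(x_pos.grad)  # tensor([2., 2.])

# Negative sum: the recorded graph is x * x, so the gradient is 2x.
x_neg = torch.tensor([-1.0, -2.0], requires_grad=True)
f(x_neg).sum().backward()
print(x_neg.grad)  # tensor([-2., -4.])
```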

This flexibility makes PyTorch natural for:

  • Recursive models
  • Variable-length sequences
  • Tree structures
  • Dynamic routing
  • Conditional computation
  • Reinforcement learning environments

Static graph systems historically struggled with these patterns.

Control Flow in PyTorch

Because graphs are built dynamically, ordinary Python control flow works naturally.

Loops:

x = torch.tensor(1.0, requires_grad=True)

y = x

for _ in range(5):
    y = y * 2

y.backward()

print(x.grad)  # tensor(32.), because y = 2^5 * x

Conditionals:

x = torch.tensor(-2.0, requires_grad=True)

if x.item() > 0:
    y = x
else:
    y = x * x

y.backward()

Recursion also works naturally.
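As a minimal sketch of recursion under autograd, the hypothetical helper power below computes an integer power by calling itself. Each recursive call adds one multiplication to the graph.

```python
import torch

def power(x, n):
    """Compute x**n by recursion for an integer n >= 1 (hypothetical helper)."""
    if n == 1:
        return x
    return x * power(x, n - 1)

x = torch.tensor(2.0, requires_grad=True)
y = power(x, 3)   # builds the graph for x * (x * x)
y.backward()

print(y.item())   # 8.0
print(x.grad)     # tensor(12.), matching d(x^3)/dx = 3x^2 at x = 2
```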

This is one reason PyTorch became popular in research environments.

Graph Lifetime

The computation graph exists only as long as needed for gradient computation.

After backward(), PyTorch frees most graph information to save memory.

Example:

x = torch.tensor(2.0, requires_grad=True)

y = x * x

y.backward()

Calling backward() a second time fails because the graph has been freed:

# y.backward()  # RuntimeError

To preserve the graph, pass retain_graph=True to the first backward() call instead:

y.backward(retain_graph=True)

Retaining graphs increases memory usage and should only be used when necessary.
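The sketch below puts these pieces together, under the assumption of a fresh tensor x. It also shows that repeated backward passes add to x.grad rather than overwrite it.

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = x * x

y.backward(retain_graph=True)  # keep the saved tensors for another pass
print(x.grad)                  # tensor(4.)

y.backward()                   # works once more, then the graph is freed
print(x.grad)                  # tensor(8.), gradients accumulate across calls

try:
    y.backward()               # the graph was freed by the previous call
    third_call_failed = False
except RuntimeError:
    third_call_failed = True

print(third_call_failed)       # True
```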

Detaching Tensors

Sometimes a tensor should stop participating in gradient computation.

The detach() operation creates a tensor without gradient history.

x = torch.tensor(2.0, requires_grad=True)

y = x * x

z = y.detach()

print(z.requires_grad)  # False

This is useful when:

  • Storing intermediate results
  • Logging metrics
  • Performing inference
  • Preventing gradient flow
  • Implementing target networks

Detached tensors share underlying storage but do not participate in autograd.
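A small sketch of how detach() blocks gradient flow. The detached tensor z is treated as a constant, so only the direct dependence of w on x contributes to the gradient.

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = x * x
z = y.detach()   # same value as y, but cut off from the graph

w = z * x        # z acts as the constant 4 here
w.backward()

print(z.requires_grad)               # False
print(x.grad)                        # tensor(4.); using y instead of z would give 12
print(z.data_ptr() == y.data_ptr())  # True, both tensors share storage
```

Because the storage is shared, an in-place change to z would also change the values seen through y.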

Disabling Gradient Tracking

Inference does not require gradients. Tracking them wastes memory and compute.

PyTorch provides torch.no_grad():

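model.eval()  # switch layers such as dropout and batch norm to inference behavior

with torch.no_grad():  # no graph is recorded inside this block
    outputs = model(x)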

This disables graph construction inside the block.

Benefits:

| Benefit | Effect |
| --- | --- |
| Lower memory use | No graph storage |
| Faster execution | No autograd overhead |
| Cleaner inference | No accidental gradient tracking |

Modern PyTorch also provides torch.inference_mode() for further optimization.
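The following sketch compares the three modes on a small linear layer. It assumes a PyTorch version recent enough to include torch.inference_mode() and Tensor.is_inference().

```python
import torch
from torch import nn

model = nn.Linear(4, 2)
x = torch.randn(3, 4)

out_tracked = model(x)             # default: the graph is recorded

with torch.no_grad():
    out_no_grad = model(x)         # no graph is recorded

with torch.inference_mode():
    out_inference = model(x)       # no graph, and the result is an inference tensor

print(out_tracked.grad_fn is not None)  # True
print(out_no_grad.grad_fn)              # None
print(out_inference.is_inference())     # True
```

Tensors created under inference_mode() cannot later be used in operations that are recorded for autograd, so no_grad() remains the more permissive choice when results may feed back into training code.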

In-Place Operations

PyTorch supports in-place tensor modification:

x.add_(1)

The underscore convention indicates in-place modification.

In-place operations can reduce memory usage, but they are dangerous when tensors participate in autograd graphs.

Example:

x = torch.tensor(2.0, requires_grad=True)

y = x * x

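# x.add_(1)  # RuntimeError: in-place operation on a leaf tensor that requires grad

Autograd saves some tensors during the forward pass for use in the backward pass. Modifying such a tensor in place invalidates the graph, and PyTorch typically raises an error rather than compute a wrong gradient.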

For this reason, in-place operations should be used carefully.
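A sketch of two common failure modes, with the errors caught so the script runs to completion. The flag names leaf_error and backward_error are illustrative only.

```python
import torch

x = torch.tensor(2.0, requires_grad=True)

# Case 1: in-place modification of a leaf tensor that requires gradients
# is rejected immediately.
try:
    x.add_(1)
    leaf_error = False
except RuntimeError:
    leaf_error = True

# Case 2: in-place modification of an intermediate tensor that autograd
# saved for the backward pass is detected when backward() runs.
a = x * 2
y = a * a        # the multiplication saves a for its backward computation
a.add_(1)        # changes a after it was saved
try:
    y.backward()
    backward_error = False
except RuntimeError:
    backward_error = True

print(leaf_error, backward_error)  # True True
```

Replacing a.add_(1) with the out-of-place a = a + 1 leaves the saved tensor untouched and avoids the error.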

Computational Graphs in Neural Networks

Consider a linear layer:

$$y = Wx + b.$$

Suppose:

$$x \in \mathbb{R}^{B \times d}, \qquad W \in \mathbb{R}^{h \times d}, \qquad b \in \mathbb{R}^{h}.$$

Because each row of $x$ is one input in a batch of size $B$, the forward pass computes:

$$y = xW^\top + b.$$

The graph includes:

  • Matrix multiplication
  • Broadcasting
  • Addition
  • Activation functions
  • Loss computation

For example:

model = nn.Linear(4, 2)

x = torch.randn(3, 4)
target = torch.randint(0, 2, (3,))

loss_fn = nn.CrossEntropyLoss()

logits = model(x)
loss = loss_fn(logits, target)

loss.backward()

The graph links every parameter in model to the scalar loss.

After backpropagation:

print(model.weight.grad.shape)  # torch.Size([2, 4])
print(model.bias.grad.shape)    # torch.Size([2])

Gradients are tensors with the same shapes as the parameters.
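As a quick check of the batched formula $y = xW^\top + b$ given earlier in this section, the sketch below recomputes the output of nn.Linear by hand. The tolerance passed to allclose is an assumption chosen to absorb floating-point rounding differences.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(4, 2)    # d = 4 input features, h = 2 output features
x = torch.randn(3, 4)      # batch of B = 3 inputs

# The weight is stored with shape (h, d), so it is transposed before multiplying.
manual = x @ model.weight.T + model.bias   # bias is broadcast over the batch

print(model.weight.shape)                           # torch.Size([2, 4])
print(torch.allclose(model(x), manual, atol=1e-6))  # True
```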

Graph Depth and Memory

Deep models produce deep graphs.

Transformers with hundreds of layers may create enormous computation graphs during training. Activations from the forward pass must often be stored because backward computation needs them.

This leads to high memory usage.

Several techniques reduce memory cost:

| Technique | Idea |
| --- | --- |
| Mixed precision | Smaller tensor formats |
| Gradient checkpointing | Recompute activations |
| Activation offloading | Move tensors between devices |
| Tensor parallelism | Split tensors across GPUs |

Graph memory is one of the central constraints in large-scale deep learning.
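As one illustration of the techniques listed above, here is a minimal sketch of gradient checkpointing with torch.utils.checkpoint. It assumes a PyTorch version that accepts the use_reentrant argument, and the small block module is a stand-in for a much larger layer. The checkpointed call discards the intermediate activations of block and recomputes them during the backward pass, yet yields the same gradients.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
block = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
x = torch.randn(4, 16, requires_grad=True)

# Ordinary forward and backward: activations inside block are stored.
block(x).sum().backward()
grad_plain = x.grad.clone()

# Reset gradients before the second run.
x.grad = None
block.zero_grad()

# Checkpointed forward and backward: activations inside block are recomputed.
checkpoint(block, x, use_reentrant=False).sum().backward()
grad_ckpt = x.grad.clone()

print(torch.allclose(grad_plain, grad_ckpt))  # True
```

The saving comes at the cost of a second forward computation through block, so checkpointing trades compute for memory.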

Eager Execution

PyTorch originally emphasized eager execution. Operations execute immediately.

Example:

x = torch.tensor([1.0, 2.0])

y = x + 1

print(y)

This feels like ordinary Python programming.

Advantages:

  • Easy debugging
  • Direct inspection
  • Interactive development
  • Natural control flow

However, eager execution may introduce overhead because operations are dispatched one at a time from Python.

Modern PyTorch combines eager programming with graph compilation through torch.compile.

Graph Compilation

PyTorch can trace and optimize execution graphs.

compiled_model = torch.compile(model)

Compilation may:

  • Fuse operations
  • Remove Python overhead
  • Optimize memory access
  • Improve kernel scheduling

This bridges the gap between dynamic flexibility and static optimization.

The programmer still writes ordinary PyTorch code, but execution becomes more optimized internally.

Dynamic Graphs and Modern AI Systems

Dynamic computation graphs are especially important for modern AI systems because many tasks involve variable structure.

Examples include:

| Domain | Dynamic behavior |
| --- | --- |
| NLP | Variable sequence lengths |
| Reinforcement learning | Environment-dependent trajectories |
| Graph learning | Different graph sizes |
| Agents | Conditional tool use |
| Mixture-of-experts models | Input-dependent routing |
| Reasoning systems | Branching computation |

A static graph fixed in advance may not naturally express these behaviors.

Dynamic graph systems allow the executed computation to depend on the data itself.

Summary

A computation graph represents tensor operations and their dependencies. PyTorch builds this graph dynamically during execution. The graph records how outputs were computed from inputs and parameters.

During training, PyTorch traverses the graph backward to compute gradients using reverse-mode automatic differentiation. This enables efficient optimization of neural networks with millions or billions of parameters.

Dynamic graphs provide flexibility, natural Python control flow, and easy debugging. They are one of the defining design choices of PyTorch and a major reason for its adoption in modern deep learning research and production systems.