Population-Based Training

Population-based training, or PBT, is a hyperparameter optimization method that trains many models at the same time. Each model has its own weights and hyperparameters. During training, weak models are replaced or modified using information from stronger models.

Grid search, random search, and Bayesian optimization usually treat each trial as a separate run. A configuration is selected before training begins and stays fixed until the run ends. PBT drops this assumption: hyperparameters can change while training is in progress.

This makes PBT useful for hyperparameters whose best values vary over time, such as learning rate, weight decay, dropout, augmentation strength, entropy bonus, or reinforcement learning exploration rate.

The Basic Idea

PBT keeps a population of workers. Each worker trains a model.

A worker has:

| Component | Meaning |
| --- | --- |
| Weights $\theta_i$ | Current model parameters |
| Hyperparameters $\lambda_i$ | Current training settings |
| Score $s_i$ | Current validation or reward metric |
| Checkpoint | Saved state for replacement |

At regular intervals, the population is evaluated. Strong workers are allowed to continue. Weak workers copy weights from stronger workers and then perturb their hyperparameters.

The algorithm alternates between two phases:

| Phase | Purpose |
| --- | --- |
| Explore | Modify hyperparameters |
| Exploit | Copy from better-performing workers |

This creates an evolutionary process over both model states and hyperparameters.

Why PBT Is Different

In ordinary hyperparameter search, each trial trains from scratch. If a trial starts with a poor learning rate, the entire run may be wasted.

PBT can recover from poor choices. A weak worker may copy a better checkpoint and continue training with modified hyperparameters.

This means PBT searches over schedules, not just fixed values. For example, instead of choosing one learning rate for all training, PBT may discover that a high learning rate works early and a low learning rate works later.

A fixed configuration looks like:

$$ \lambda = (\eta, \lambda_{\text{wd}}, p_{\text{drop}}). $$

A schedule is a function of training time:

$$ \lambda(t) = (\eta(t), \lambda_{\text{wd}}(t), p_{\text{drop}}(t)). $$

PBT searches for useful schedules by adapting hyperparameters during training.
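One way to picture the difference is to store a schedule as a list of change points. The sketch below is only an illustration: the `history` values and the `config_at` helper are hypothetical, and it assumes hyperparameters stay constant between exploit-explore steps.

```python
from bisect import bisect_right

# Hypothetical record of hyperparameter changes: (training step, config).
# Each entry takes effect at its step and holds until the next entry.
history = [
    (0, {"learning_rate": 1e-3, "dropout": 0.10}),
    (2000, {"learning_rate": 1.2e-3, "dropout": 0.10}),
    (4000, {"learning_rate": 9.6e-4, "dropout": 0.15}),
]


def config_at(history, step):
    """Return the config in effect at `step` (piecewise-constant schedule)."""
    starts = [start for start, _ in history]
    index = bisect_right(starts, step) - 1
    if index < 0:
        raise ValueError("step precedes the first recorded config")
    return history[index][1]


print(config_at(history, 2500)["learning_rate"])
```

A fixed configuration is the special case where `history` has a single entry.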

Population State

Suppose the population has $N$ workers.

At time $t$, worker $i$ has state:

$$ (\theta_i^{(t)}, \lambda_i^{(t)}, s_i^{(t)}). $$

Here:

| Symbol | Meaning |
| --- | --- |
| $\theta_i^{(t)}$ | Model weights for worker $i$ at time $t$ |
| $\lambda_i^{(t)}$ | Hyperparameters for worker $i$ at time $t$ |
| $s_i^{(t)}$ | Validation score for worker $i$ at time $t$ |

A worker trains locally for some number of steps, then reports its score. The population controller compares workers and decides whether any worker should be replaced.
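The per-worker state can be held in a small record. This is a minimal sketch, not a prescribed layout: the `WorkerState` class, its field names, and the `report` helper are hypothetical names chosen for illustration.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class WorkerState:
    """State (theta_i, lambda_i, s_i) of one worker, plus bookkeeping."""
    worker_id: int
    weights: Any                     # theta_i, e.g. a model state dict
    hyperparams: Dict[str, float]    # lambda_i
    score: Optional[float] = None    # s_i, None until the first evaluation
    step: int = 0


def report(population):
    """Return worker ids ordered from worst to best score.

    Workers that have not been evaluated yet are left out.
    """
    scored = [w for w in population if w.score is not None]
    return [w.worker_id for w in sorted(scored, key=lambda w: w.score)]


population = [
    WorkerState(0, weights=None, hyperparams={"learning_rate": 1e-3}, score=0.5),
    WorkerState(1, weights=None, hyperparams={"learning_rate": 3e-3}),
    WorkerState(2, weights=None, hyperparams={"learning_rate": 1e-2}, score=0.9),
]
print(report(population))
```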

Exploitation

Exploitation means copying from a stronger worker.

For example, after evaluation, workers are ranked by validation score. A worker in the bottom 20 percent may copy the checkpoint of a worker in the top 20 percent.

If worker $j$ is strong and worker $i$ is weak, exploitation performs:

$$ \theta_i \leftarrow \theta_j, $$

$$ \lambda_i \leftarrow \lambda_j. $$

This gives the weak worker a better starting point. It avoids spending more compute on a clearly poor trajectory.
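The bottom-versus-top rule described above is often called truncation selection. A minimal sketch follows, assuming higher scores are better; the function name `truncation_select` and its `fraction` parameter are illustrative choices.

```python
import random


def truncation_select(scores, fraction=0.2, rng=random):
    """Map each bottom-fraction worker index to a random top-fraction donor.

    `scores` is a list where higher is better.
    Returns a dict {weak_index: strong_index}.
    """
    n = len(scores)
    k = max(1, int(n * fraction))
    if 2 * k > n:
        raise ValueError("fraction too large: top and bottom groups overlap")
    ranked = sorted(range(n), key=lambda i: scores[i])  # worst first
    bottom, top = ranked[:k], ranked[-k:]
    return {weak: rng.choice(top) for weak in bottom}


scores = [0.10, 0.90, 0.50, 0.30, 0.80, 0.20, 0.70, 0.40, 0.60, 0.95]
plan = truncation_select(scores, fraction=0.2, rng=random.Random(0))
print(plan)
```

Workers that are neither in the bottom nor the top group are absent from the returned plan and simply keep training.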

In practice, exploitation copies:

| Item | Usually copied |
| --- | --- |
| Model weights | Yes |
| Optimizer state | Often yes |
| Learning rate scheduler state | Depends |
| Hyperparameters | Yes |
| Training step | Usually yes |
| Data loader state | Usually no |

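Copying optimizer state can matter. AdamW keeps running first- and second-moment estimates of the gradient, and these estimates shape future updates. If weights are copied but optimizer state is not, the first updates after replacement use reset or mismatched estimates, so training may behave differently from the source worker.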
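The effect of dropping optimizer state can be seen on a toy problem. The sketch below uses plain SGD with momentum on a single scalar as a simpler stand-in for AdamW; the loss, step counts, and function names are assumptions made for illustration.

```python
def momentum_step(weight, velocity, grad, lr=0.1, beta=0.9):
    """One step of SGD with heavy-ball momentum on a scalar parameter."""
    velocity = beta * velocity + grad
    weight = weight - lr * velocity
    return weight, velocity


def grad(weight):
    """Gradient of the toy loss 0.5 * weight**2."""
    return weight


# The donor worker trains for a few steps and accumulates momentum.
w, v = 1.0, 0.0
for _ in range(5):
    w, v = momentum_step(w, v, grad(w))

# Recipient A copies weights and optimizer state.
w_full, _ = momentum_step(w, v, grad(w))
# Recipient B copies weights only and starts with an empty momentum buffer.
w_reset, _ = momentum_step(w, 0.0, grad(w))

print(w_full, w_reset)
```

Both recipients start from identical weights, yet their next updates differ because only one of them inherits the accumulated velocity.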

Exploration

Exploration means modifying copied hyperparameters.

After copying from a strong worker, the weak worker perturbs the inherited hyperparameters. This prevents all workers from becoming identical.

A simple perturbation rule is multiplicative:

$$ \eta \leftarrow \eta \times r, $$

where

$$ r \in \{0.8, 1.2\}. $$

For example, if the copied learning rate is $10^{-3}$, exploration may change it to $8\times10^{-4}$ or $1.2\times10^{-3}$.

For continuous hyperparameters, perturbation may use random noise:

$$ \eta \leftarrow \eta \cdot \exp(\epsilon), \qquad \epsilon \sim \mathcal{N}(0,\sigma^2). $$
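The log-normal rule can be sketched in a few lines. The function name `perturb_lognormal`, the default $\sigma = 0.2$, and the clamping bounds are illustrative assumptions, not recommended values.

```python
import math
import random


def perturb_lognormal(value, sigma=0.2, low=1e-5, high=1e-1, rng=random):
    """Multiply `value` by exp(eps), eps ~ N(0, sigma^2), then clamp."""
    eps = rng.gauss(0.0, sigma)
    return min(max(value * math.exp(eps), low), high)


rng = random.Random(0)
samples = [perturb_lognormal(1e-3, rng=rng) for _ in range(5)]
print(samples)
```

Because the noise is applied in log space, the perturbed value stays positive and a step up by some ratio is as likely as a step down by the same ratio.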

For categorical hyperparameters, exploration may randomly resample from a set:

optimizer = random.choice(["SGD", "AdamW"])

Exploration should respect valid ranges:

learning_rate = min(max(learning_rate, 1e-5), 1e-1)
dropout = min(max(dropout, 0.0), 0.5)

A Minimal PBT Algorithm

A simple PBT loop looks like this:

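import random

population = initialize_workers(num_workers)

# One iteration per exploit-explore interval, not per training step.
num_intervals = total_steps // steps_per_interval

for interval in range(num_intervals):
    for worker in population:
        worker.train(num_steps=steps_per_interval)

    for worker in population:
        worker.score = worker.evaluate()

    # Sort ascending by score: worst workers first, best workers last.
    ranked = sorted(population, key=lambda w: w.score)

    bottom = ranked[:num_replace]
    top = ranked[-num_replace:]

    for weak_worker in bottom:
        strong_worker = random.choice(top)

        # Exploit: inherit the stronger worker's saved state.
        weak_worker.load_checkpoint(strong_worker.checkpoint)
        # Explore: perturb the inherited hyperparameters.
        weak_worker.config = perturb(strong_worker.config)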

This code omits many engineering details, but it shows the core pattern.

The key difference from random search is that a worker can inherit both weights and hyperparameters from another worker.
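To make the loop concrete, here is a self-contained toy version that runs without any deep learning library. It is a sketch under stated assumptions: each "model" is a single number `x`, the objective $f(x) = x^2$ and all default arguments are invented for illustration, and workers are simulated sequentially.

```python
import copy
import random


def run_toy_pbt(num_workers=8, num_intervals=20, steps_per_interval=5,
                num_replace=2, seed=0):
    """Toy PBT: each worker minimizes f(x) = x**2 by gradient descent.

    The only hyperparameter is the learning rate.
    The score is -f(x), so higher is better.
    """
    rng = random.Random(seed)
    population = [
        {"x": 5.0,
         "config": {"learning_rate": 10 ** rng.uniform(-4, -1)},
         "score": None}
        for _ in range(num_workers)
    ]

    for _interval in range(num_intervals):
        # Train and evaluate (the gradient of x**2 is 2x).
        for worker in population:
            lr = worker["config"]["learning_rate"]
            for _step in range(steps_per_interval):
                worker["x"] -= lr * 2.0 * worker["x"]
            worker["score"] = -(worker["x"] ** 2)

        # Rank worst first, then exploit and explore for the bottom workers.
        ranked = sorted(population, key=lambda w: w["score"])
        bottom, top = ranked[:num_replace], ranked[-num_replace:]
        for weak in bottom:
            strong = rng.choice(top)
            weak["x"] = strong["x"]  # exploit: copy the "weights"
            weak["score"] = strong["score"]
            config = copy.deepcopy(strong["config"])
            config["learning_rate"] *= rng.choice([0.8, 1.2])  # explore
            config["learning_rate"] = min(max(config["learning_rate"], 1e-5), 1e-1)
            weak["config"] = config

    return max(population, key=lambda w: w["score"])


best = run_toy_pbt()
print(best["score"], best["config"]["learning_rate"])
```

The structure mirrors the loop above: train, evaluate, rank, copy from a top worker, then perturb. A real system would replace the inner loop with model training and the copy with checkpoint loading.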

PyTorch Worker Structure

A PBT worker needs to save and load complete training state.

import torch

class Worker:
    def __init__(self, model, optimizer, config, device):
        self.model = model.to(device)
        self.optimizer = optimizer
        self.config = config
        self.device = device
        self.step = 0
        self.score = None

    def train(self, loader, num_steps):
        self.model.train()

        iterator = iter(loader)

        for _ in range(num_steps):
            try:
                x, y = next(iterator)
            except StopIteration:
                iterator = iter(loader)
                x, y = next(iterator)

            x = x.to(self.device)
            y = y.to(self.device)

            logits = self.model(x)
            loss = torch.nn.functional.cross_entropy(logits, y)

            self.optimizer.zero_grad(set_to_none=True)
            loss.backward()
            self.optimizer.step()

            self.step += 1

    @torch.no_grad()
    def evaluate(self, loader):
        self.model.eval()

        correct = 0
        total = 0

        for x, y in loader:
            x = x.to(self.device)
            y = y.to(self.device)

            logits = self.model(x)
            pred = logits.argmax(dim=1)

            correct += (pred == y).sum().item()
            total += y.numel()

        self.score = correct / total
        return self.score

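The checkpoint should include model weights, optimizer state, configuration, and step. The two methods below belong to the Worker class. The state is deep-copied because state_dict() returns references to the live tensors, so an in-memory checkpoint would otherwise keep changing while the source worker continues to train:

import copy

class Worker:
    # ... __init__, train, and evaluate as above ...

    def state_dict(self):
        # Snapshot everything needed to resume this worker elsewhere.
        return {
            "model": copy.deepcopy(self.model.state_dict()),
            "optimizer": copy.deepcopy(self.optimizer.state_dict()),
            "config": copy.deepcopy(self.config),
            "step": self.step,
            "score": self.score,
        }

    def load_state_dict(self, state):
        # Loading optimizer state also restores the source worker's lr.
        self.model.load_state_dict(state["model"])
        self.optimizer.load_state_dict(state["optimizer"])
        self.config = dict(state["config"])
        self.step = state["step"]
        self.score = state["score"]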

Perturbing Hyperparameters

A perturbation function should know which hyperparameters are mutable.

import copy
import random

def perturb(config):
    """Return a perturbed copy of config; the input is left unchanged."""
    config = copy.deepcopy(config)

    # Multiplicative perturbation, clamped to (low, high) bounds.
    # Clamping inside the guard avoids a KeyError when a key is absent.
    bounds = {"learning_rate": (1e-5, 1e-1), "weight_decay": (1e-6, 1e-1)}
    for name, (low, high) in bounds.items():
        if name in config:
            factor = random.choice([0.8, 1.2])
            config[name] = min(max(config[name] * factor, low), high)

    # Additive perturbation for dropout, clamped to [0.0, 0.5].
    if "dropout" in config:
        config["dropout"] += random.choice([-0.05, 0.05])
        config["dropout"] = min(max(config["dropout"], 0.0), 0.5)

    return config

PyTorch optimizers read the learning rate and weight decay from param_groups, so the optimizer must be updated after perturbation, and again after loading a copied optimizer state:

def apply_config_to_optimizer(optimizer, config):
    for group in optimizer.param_groups:
        group["lr"] = config["learning_rate"]
        group["weight_decay"] = config["weight_decay"]

For architecture hyperparameters such as hidden dimension or number of layers, simple PBT cannot modify them after training starts because the parameter shapes would change. PBT is best suited for training hyperparameters and regularization settings.

Choosing the Evaluation Interval

PBT requires an interval between exploit-explore steps.

If the interval is too short, scores are noisy and workers may copy from models that are only temporarily ahead.

If the interval is too long, weak workers waste compute before being replaced.

A practical interval depends on the task:

| Task | Typical interval |
| --- | --- |
| Small image classification | Every few epochs |
| Large supervised training | Every few thousand steps |
| Reinforcement learning | Every fixed number of environment steps |
| Language model fine-tuning | Every validation checkpoint |
| Diffusion training | Less frequent, due to noisy metrics |

The interval should be long enough that validation scores contain useful signal.

Metrics for Selection

PBT needs a scalar score for ranking workers.

For classification, this might be validation accuracy. For language modeling, it may be negative validation loss or negative perplexity. For reinforcement learning, it may be average episodic return.

When the objective has multiple terms, a scalar score can combine them:

$$ s = \text{accuracy} - \alpha \cdot \text{latency} - \beta \cdot \text{memory}. $$

This allows PBT to optimize under deployment constraints.
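As a worked instance of the combined score, the sketch below compares two hypothetical models. The units (milliseconds, gigabytes), the weights `alpha` and `beta`, and all measurements are made-up values for illustration.

```python
def deployment_score(accuracy, latency_ms, memory_gb, alpha=0.001, beta=0.01):
    """Scalar score: accuracy minus weighted latency and memory penalties."""
    return accuracy - alpha * latency_ms - beta * memory_gb


candidates = {
    "small": deployment_score(accuracy=0.90, latency_ms=10.0, memory_gb=1.0),
    "large": deployment_score(accuracy=0.93, latency_ms=60.0, memory_gb=4.0),
}
best = max(candidates, key=candidates.get)
print(best, candidates)
```

With these weights the more accurate model loses the ranking, since its latency and memory penalties (0.06 + 0.04) outweigh its 0.03 accuracy advantage. The choice of $\alpha$ and $\beta$ therefore determines what the population converges toward.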

The score should be stable enough to compare workers. If validation measurement is noisy, use moving averages or repeated evaluations.
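One common way to smooth a noisy metric is an exponential moving average. The sketch below is a minimal version with bias correction; the class name `SmoothedScore` and the default decay are illustrative assumptions.

```python
class SmoothedScore:
    """Exponential moving average of a noisy metric, with bias correction."""

    def __init__(self, decay=0.8):
        if not 0.0 <= decay < 1.0:
            raise ValueError("decay must be in [0, 1)")
        self.decay = decay
        self.average = 0.0
        self.count = 0

    def update(self, value):
        """Fold in one new measurement and return the smoothed score."""
        self.count += 1
        self.average = self.decay * self.average + (1.0 - self.decay) * value
        # Divide out the bias toward the zero initialization.
        return self.average / (1.0 - self.decay ** self.count)


smoother = SmoothedScore(decay=0.8)
for raw in [0.70, 0.90, 0.60, 0.85]:
    smoothed = smoother.update(raw)
print(smoothed)
```

Ranking workers on the smoothed value makes it less likely that a single lucky evaluation triggers a copy. A smoother that belongs to a replaced worker should be reset or copied along with the checkpoint, since its history describes the old weights.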

Population Size

Population size controls diversity.

A small population is cheaper but explores fewer schedules. A large population explores more schedules but requires more hardware.

| Population size | Behavior |
| --- | --- |
| 4 to 8 | Minimal, useful for small experiments |
| 16 to 32 | Common practical range |
| 64 or more | Large-scale search |

PBT works best when workers run in parallel. If only one GPU is available, PBT loses much of its advantage because the population must be simulated sequentially.

Strengths and Weaknesses

| Strengths | Weaknesses |
| --- | --- |
| Searches dynamic schedules | Requires parallel compute |
| Reuses partial training progress | More complex than random search |
| Can recover from poor early choices | Harder to reproduce exactly |
| Works well for RL and large training runs | Needs careful checkpoint management |
| Handles nonstationary hyperparameters | Less useful for architecture choices |

PBT is especially useful when the best hyperparameters change during training.

When to Use PBT

PBT is appropriate when:

| Situation | Reason |
| --- | --- |
| Many workers can run in parallel | PBT is population-based |
| Hyperparameters should change over time | PBT discovers schedules |
| Training is long | Mid-training adaptation helps |
| Early bad choices are costly | Workers can recover |
| Reinforcement learning is involved | RL often has unstable dynamics |

PBT is less appropriate when training is cheap, when only a few trials are possible, or when the main choices are fixed architecture decisions.

PBT and Learning Rate Schedules

Learning rate schedules are a natural fit for PBT. Instead of choosing a schedule manually, PBT can adapt the learning rate based on observed performance.

For example, one worker may keep a high learning rate longer and improve quickly. Another may reduce the learning rate earlier and generalize better. PBT can copy from the better trajectory and perturb it.

The resulting schedule may be irregular:

$$ 10^{-3} \rightarrow 1.2\times10^{-3} \rightarrow 9.6\times10^{-4} \rightarrow 7.7\times10^{-4} \rightarrow 9.2\times10^{-4}. $$

This kind of schedule may perform well, even if it looks less clean than cosine decay or step decay.
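The example schedule above is what results from applying the multiplicative factors 1.2, 0.8, 0.8, 1.2 in sequence to a starting learning rate of $10^{-3}$. The short sketch below replays that arithmetic; the helper name `replay_schedule` is hypothetical.

```python
def replay_schedule(initial, factors):
    """Apply a sequence of multiplicative perturbations, keeping every value."""
    values = [initial]
    for factor in factors:
        values.append(values[-1] * factor)
    return values


schedule = replay_schedule(1e-3, [1.2, 0.8, 0.8, 1.2])
print([f"{value:.3g}" for value in schedule])
```

The unrounded fourth and fifth values are $7.68\times10^{-4}$ and $9.216\times10^{-4}$, which round to the $7.7\times10^{-4}$ and $9.2\times10^{-4}$ shown in the schedule.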

Summary

Population-based training trains a population of models while adapting their hyperparameters during training. Weak workers copy weights and hyperparameters from stronger workers, then perturb those hyperparameters to continue exploration.

PBT combines training, selection, checkpoint reuse, and hyperparameter search into one process. It is useful for long-running training jobs, reinforcement learning, and hyperparameters whose best values change over time.

Its cost is engineering complexity. A reliable PBT system needs parallel workers, checkpoints, reproducible logging, stable evaluation, and careful handling of optimizer state.