Transfer Learning

Transfer learning reuses a model trained on one task as the starting point for another task. In image classification, this usually means taking a convolutional network or vision transformer trained on a large image dataset, replacing its final classifier, and fine-tuning it on a smaller target dataset.

The central idea is simple: early and middle layers learn reusable visual features. They detect edges, colors, textures, shapes, object parts, and higher-level patterns. A new task often needs different class labels, but it can still benefit from these learned representations.

Why Transfer Learning Works

A randomly initialized model starts with no useful visual features. It must learn low-level patterns and high-level decision boundaries from the target dataset. This requires more data, more compute, and more tuning.

A pretrained model already contains useful representations. Fine-tuning adapts those representations to the new task.

Let a pretrained model be written as

$$ f_\theta(x) = g_\phi(h_\psi(x)). $$

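Here $h_\psi$ is the feature extractor (the backbone, with parameters $\psi$), and $g_\phi$ is the classifier head (with parameters $\phi$). In transfer learning, we usually keep the pretrained $h_\psi$, replace $g_\phi$ with a freshly initialized head $g_{\phi'}$ sized for the target classes, and train the new model on the target data.

For a target dataset with $K$ classes, the new model produces one logit per class:

$$ z = g_{\phi'}(h_\psi(x)) \in \mathbb{R}^{K}. $$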

The feature extractor may be frozen, partially unfrozen, or fully fine-tuned.

Feature Extraction Versus Fine-Tuning

There are two common transfer learning modes.

| Mode | What changes | Best when |
|---|---|---|
| Feature extraction | Freeze pretrained backbone, train only new head | Dataset is small, classes are similar to pretraining data |
| Fine-tuning | Train some or all pretrained layers | Dataset is larger, task differs from pretraining data |

Feature extraction is cheaper and less likely to overfit. Fine-tuning is more flexible and usually gives better final accuracy when enough data is available.

Loading a Pretrained Model

PyTorch provides pretrained models through torchvision.models.

import torch
import torch.nn as nn
from torchvision import models

model = models.resnet18(
    weights=models.ResNet18_Weights.DEFAULT
)

The model was trained with a specific preprocessing convention. The weights object provides the matching transform:

weights = models.ResNet18_Weights.DEFAULT
transform = weights.transforms()

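Using the correct transform matters. The pretrained weights expect a specific resize, center crop, and per-channel normalization (for ImageNet weights, the ImageNet mean and standard deviation), so fine-tuning and inference should use the same image size and normalization.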

Replacing the Classifier Head

A ResNet classifier ends with a fully connected layer called fc.

num_classes = 5

model.fc = nn.Linear(
    in_features=model.fc.in_features,
    out_features=num_classes,
)

Now the model outputs logits with shape:

[B, 5]

where B is the batch size. Only the final layer changed; the convolutional backbone and its pretrained weights are untouched.

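For other architectures, the classifier lives elsewhere. Many torchvision models, including EfficientNet, name the head classifier instead of fc, and in EfficientNet it is an nn.Sequential whose last element is the linear layer.

model = models.efficientnet_b0(
    weights=models.EfficientNet_B0_Weights.DEFAULT
)

num_classes = 5
# classifier is nn.Sequential(Dropout, Linear); index 1 is the linear layer
in_features = model.classifier[1].in_features
model.classifier[1] = nn.Linear(in_features, num_classes)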

Always inspect the model before replacing the head:

print(model)

Freezing the Backbone

To use the pretrained model as a fixed feature extractor, disable gradients for the backbone.

for param in model.parameters():
    param.requires_grad = False

model.fc = nn.Linear(model.fc.in_features, num_classes)

Because the new head is created after the freezing loop, its parameters have requires_grad=True by default. The optimizer should receive only the trainable parameters:

optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=1e-3,
    weight_decay=1e-4,
)

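Passing only the trainable parameters makes the optimizer's scope explicit. PyTorch optimizers already skip parameters whose gradient is None, so frozen parameters would not move either way, but filtering guarantees that the backbone is not updated if it is later unfrozen without rebuilding the optimizer.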

Fine-Tuning the Whole Model

For full fine-tuning, leave all parameters trainable:

for param in model.parameters():
    param.requires_grad = True

Then use a much smaller learning rate than you would for training from scratch:

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-5,
    weight_decay=1e-4,
)

Pretrained weights already encode useful structure, and large learning rates can destroy it within a few updates. This loss of previously learned knowledge is closely related to catastrophic forgetting.

Differential Learning Rates

Often, the classifier head should learn faster than the pretrained backbone. PyTorch allows parameter groups with different learning rates.

backbone_params = []
head_params = []

for name, param in model.named_parameters():
    if name.startswith("fc."):
        head_params.append(param)
    else:
        backbone_params.append(param)

optimizer = torch.optim.AdamW(
    [
        {"params": backbone_params, "lr": 3e-5},
        {"params": head_params, "lr": 3e-4},
    ],
    weight_decay=1e-4,
)

This trains the new classifier more aggressively while making smaller updates to the pretrained representation.

Progressive Unfreezing

Progressive unfreezing starts by training only the classifier head. It then gradually unfreezes more of the backbone, starting with the blocks closest to the head and moving toward the input.

A common schedule is:

| Phase | Trainable parameters |
|---|---|
| Phase 1 | Classifier head only |
| Phase 2 | Last block plus classifier head |
| Phase 3 | Full model |

This approach is useful when the target dataset is small. It reduces the risk of damaging useful pretrained features early in training.

For ResNet, the final residual block is usually layer4.

for param in model.parameters():
    param.requires_grad = False

for param in model.fc.parameters():
    param.requires_grad = True

# later, in phase 2: also train the final residual block
for param in model.layer4.parameters():
    param.requires_grad = True

After changing which parameters are trainable, recreate the optimizer so it tracks the correct parameter set.

Batch Normalization During Transfer Learning

Batch normalization needs special care. A batch normalization layer has two kinds of state:

| State | Example | Updated by |
|---|---|---|
| Trainable parameters | scale (weight) and shift (bias) | optimizer steps, using gradients |
| Running statistics | running mean and running variance | forward passes in training mode |

Freezing parameters does not automatically freeze running statistics. If the model remains in training mode, batch normalization statistics may still change.

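For small target datasets, this can hurt performance. One option is to keep batch normalization layers in evaluation mode while training the classifier. Because model.train() switches every submodule back to training mode, apply this immediately after each model.train() call, for example right after the model.train() line at the start of the training loop:

def set_batchnorm_eval(module):
    # Put BatchNorm layers in eval mode so their running statistics stay fixed
    if isinstance(module, nn.BatchNorm2d):
        module.eval()

model.train()                    # sets every module, including BatchNorm, to training mode
model.apply(set_batchnorm_eval)  # then switches only BatchNorm layers back to eval mode

This keeps running statistics fixed while the rest of the model trains normally. The right choice depends on dataset size and domain shift: fixed statistics are usually safer for very small datasets, while a strong domain shift may call for updating them.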

Transfer Learning Training Loop

The training loop is the same as ordinary classification. The main differences are the pretrained initialization, replaced classifier head, and optimizer parameter selection.

def train_one_epoch(model, loader, loss_fn, optimizer, device):
    model.train()

    total_loss = 0.0
    total_correct = 0
    total_count = 0

    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)

        logits = model(images)
        loss = loss_fn(logits, labels)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        preds = logits.argmax(dim=1)

        total_loss += loss.item() * labels.size(0)
        total_correct += (preds == labels).sum().item()
        total_count += labels.size(0)

    return {
        "loss": total_loss / total_count,
        "accuracy": total_correct / total_count,
    }

Evaluation is the same as for ordinary classification:

@torch.no_grad()
def evaluate(model, loader, loss_fn, device):
    model.eval()

    total_loss = 0.0
    total_correct = 0
    total_count = 0

    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)

        logits = model(images)
        loss = loss_fn(logits, labels)

        preds = logits.argmax(dim=1)

        total_loss += loss.item() * labels.size(0)
        total_correct += (preds == labels).sum().item()
        total_count += labels.size(0)

    return {
        "loss": total_loss / total_count,
        "accuracy": total_correct / total_count,
    }

Complete Example

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, models

device = "cuda" if torch.cuda.is_available() else "cpu"

weights = models.ResNet18_Weights.DEFAULT
transform = weights.transforms()

train_set = datasets.ImageFolder("dataset/train", transform=transform)
val_set = datasets.ImageFolder("dataset/val", transform=transform)

train_loader = DataLoader(
    train_set,
    batch_size=64,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

val_loader = DataLoader(
    val_set,
    batch_size=64,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

num_classes = len(train_set.classes)

model = models.resnet18(weights=weights)

for param in model.parameters():
    param.requires_grad = False

model.fc = nn.Linear(model.fc.in_features, num_classes)

model = model.to(device)

loss_fn = nn.CrossEntropyLoss()

optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=1e-3,
    weight_decay=1e-4,
)

best_acc = 0.0

for epoch in range(10):
    train_metrics = train_one_epoch(
        model=model,
        loader=train_loader,
        loss_fn=loss_fn,
        optimizer=optimizer,
        device=device,
    )

    val_metrics = evaluate(
        model=model,
        loader=val_loader,
        loss_fn=loss_fn,
        device=device,
    )

    print(
        f"epoch={epoch + 1} "
        f"train_loss={train_metrics['loss']:.4f} "
        f"train_acc={train_metrics['accuracy']:.4f} "
        f"val_loss={val_metrics['loss']:.4f} "
        f"val_acc={val_metrics['accuracy']:.4f}"
    )

    if val_metrics["accuracy"] > best_acc:
        best_acc = val_metrics["accuracy"]

        torch.save(
            {
                "model": model.state_dict(),
                "classes": train_set.classes,
                "class_to_idx": train_set.class_to_idx,
                "weights": "ResNet18_Weights.DEFAULT",
                "val_acc": best_acc,
            },
            "transfer_classifier.pt",
        )

This version trains only the final classifier. It is a strong baseline for small image datasets.

When to Fine-Tune More Layers

The decision depends on the target data.

| Situation | Recommended approach |
|---|---|
| Very small dataset | Freeze backbone, train head |
| Small dataset similar to ImageNet | Freeze backbone, then unfreeze final block |
| Medium dataset | Fine-tune final blocks with small learning rate |
| Large dataset | Fine-tune full model |
| Strong domain shift | Fine-tune more layers |
| Medical, satellite, or scientific images | Fine-tune more layers, possibly from self-supervised weights |

A dataset of dog and cat photos is close to common pretraining data. A dataset of microscope images is farther away. Larger domain shift usually requires deeper adaptation.

Common Mistakes

The most common transfer learning errors are simple.

| Mistake | Consequence |
|---|---|
| Wrong normalization | Poor accuracy |
| Forgetting to replace the classifier head | Wrong output shape |
| Accidentally leaving parameters trainable that should be frozen | Wasted memory and compute |
| Freezing all parameters, including the new head | No learning |
| Learning rate too high | Destroyed pretrained features |
| Changing class order at inference | Wrong labels |
| Random augmentation in validation transforms | Noisy metrics |
| Ignoring batch normalization behavior | Unstable fine-tuning |

Before training, verify which parameters are trainable:

for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)

This simple check prevents many silent failures.

Practical Defaults

A good first transfer learning setup is:

| Component | Default |
|---|---|
| Model | ResNet18 or EfficientNet-B0 |
| Weights | Official pretrained weights |
| Image size | 224 |
| Transform | Weight-specific transform |
| First phase | Freeze backbone |
| Optimizer | AdamW |
| Head learning rate | $10^{-3}$ |
| Fine-tune learning rate | $10^{-5}$ to $10^{-4}$ |
| Loss | Cross-entropy |
| Metric | Validation accuracy |
| Checkpoint | Best validation metric |

These defaults are not always optimal, but they are usually stable. Once this baseline works, tune augmentation, learning rate, batch size, unfreezing depth, and model size.

Summary

Transfer learning uses pretrained models as reusable representation learners. For image classification, the usual workflow is to load pretrained weights, replace the classifier head, train the head, and optionally fine-tune deeper layers.

Feature extraction is safer and cheaper. Full fine-tuning is more powerful but more sensitive to learning rate, dataset size, and normalization. A reliable approach is to begin with a frozen-backbone baseline, then unfreeze progressively when validation results show that more adaptation is needed.