AlexNet has 61 million parameters — but 90% of them can be removed with almost no accuracy loss. Which 10% actually matter, and how do you find them? This lesson derives the pruning problem from first principles, walks every granularity from individual weights to whole channels, builds five criteria for deciding what to cut (including the Optimal Brain Damage second-order derivation), and shows the iterative prune–finetune loop that recovers accuracy at extreme sparsity. MIT 6.5940 by Song Han.
Take AlexNet. 61 million parameters. The weight tensor for the second fully-connected layer (fc7) alone holds 4096 × 4096 ≈ 16.8 million values. If you plot a histogram of them, you'll see something striking: the vast majority of those values hover near zero. Not exactly zero, but close enough that removing them barely changes what the network computes.
That observation, made empirically by Song Han and colleagues in 2015, led to a result that still surprises people: you can prune 90% of AlexNet's connections and, after a brief finetuning pass, restore accuracy to within 0.5% of the original. You can do the same to VGG-16 at 12× compression. The 10% of weights you keep are doing almost all the work. The other 90% are passengers.
This is the core insight behind neural network pruning: trained networks are massively over-parameterized, and most of the excess can be removed without hurting the task they were trained for. The brain does something similar: a toddler (around age 2 to 4) has roughly 15,000 synapses per neuron, and by adolescence that drops to roughly 7,000 as unused connections are eliminated. The brain prunes based on activity; we prune based on importance criteria.
Here is the weight magnitude picture for a toy fully-connected layer, representative of what you see in practice. Most weights cluster near zero while a few have large magnitude. A simple magnitude threshold wipes out the near-zero mass while preserving the large weights. The canvas below lets you visualize this matrix and sweep a sparsity threshold.
A simulated 8×8 weight matrix. Darker cells = larger |w|. The threshold line (drag slider) zeroes out weights below it. Watch how many cells go dark at 50% vs 90% sparsity — yet the few large weights survive.
The tour this lesson takes: (1) formalize what pruning is as an optimization problem; (2) survey the spectrum of granularities — from individual weights up to whole channels — and what each buys you on real hardware; (3) study five criteria for scoring which weights to remove, including the elegant second-order method from Optimal Brain Damage (LeCun, 1989); (4) understand the prune–finetune–repeat loop that recovers accuracy at extreme sparsity; and (5) see how different layers tolerate wildly different sparsity levels.
Let's be precise about what we're actually trying to do. You have a trained network with weights W and an objective (loss) function L(x; W). You want to find a pruned set of weights WP that minimizes the same objective while having few non-zero entries:

$$\arg\min_{W_P} \; L(x; W_P) \quad \text{subject to} \quad \lVert W_P \rVert_0 \le N$$
Here ‖WP‖0 counts the number of non-zero elements in WP. It is called the L0 "norm", although it is not a true norm. N is your budget of non-zero weights: how many connections the pruned network is allowed to have (a smaller N means higher sparsity).
This formulation is clean but the solution is not. Exactly minimizing this is an NP-hard combinatorial search — you'd have to evaluate every possible subset of weights of size N and pick the best one. For a 61M-parameter network, that's more subsets than atoms in the observable universe.
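To get a feel for the size of that search space, here is a small sketch that counts subsets with Python's standard library. The helper name `log10_num_subsets` and the 10% keep budget are illustrative assumptions, not part of the formulation above.

```python
import math

def log10_num_subsets(n: int, k: int) -> float:
    """log10 of C(n, k), computed with log-gamma so huge n does not overflow."""
    ln_comb = math.lgamma(n + 1) - math.lgamma(k + 1) - math.lgamma(n - k + 1)
    return ln_comb / math.log(10)

# Tiny layer: exact number of ways to keep 8 of 16 weights.
print(math.comb(16, 8))  # 12870

# AlexNet scale: keep 10% of 61M weights (assumed budget, for illustration).
digits = log10_num_subsets(61_000_000, 6_100_000)
print(f"about 10^{digits:,.0f} candidate subsets")
```

The exponent comes out in the millions, far beyond the roughly 10^80 atoms usually quoted for the observable universe, which is why nobody searches this space exhaustively.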
The practical approach is a three-step decomposition:
Two design choices dominate the rest of this lesson: what you prune (granularity — individual weights? entire channels?) and why you prune it (criterion — magnitude? curvature? activation statistics?). These are independent decisions: you can use magnitude as a criterion for either fine-grained or channel pruning.
Worked numbers — the compression table from Han et al. 2015:
| Network | Params before pruning | Params after pruning | Compression (params) | MAC reduction |
|---|---|---|---|---|
| AlexNet | 61M | 6.7M | 9× | 3× |
| VGG-16 | 138M | 10.3M | 12× | 5× |
| GoogLeNet | 7M | 2.0M | 3.5× | 5× |
| ResNet-50 | 26M | 7.47M | 3.4× | 6.3× |
| SqueezeNet | 1M | 0.38M | 3.2× | 3.5× |
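Note: the parameter compression ratio and the MAC reduction differ because parameters and compute are spread differently across layers. In AlexNet and VGG-16 most parameters sit in the fully-connected layers, which prune heavily but account for few MACs, while the conv layers dominate compute; in the mostly convolutional networks the MAC reduction is the larger of the two. Neither column is a wall-clock speedup: with unstructured pruning, zero weights still occupy compute cycles on dense GPU kernels. Real speedup requires sparse hardware support or structured pruning.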
Before you decide which weights to prune, you must decide at what scale to prune. This is granularity: how big is the unit you remove? The spectrum runs from individual scalar weights at one end, to entire filters (all weights in an output channel) at the other.
Fine-grained pruning (also called unstructured pruning) removes individual weights. For a 2D weight matrix, you get a pattern that looks random — a few surviving non-zeros scattered irregularly across the tensor. This is the most flexible option: you can remove exactly the weights with the smallest importance score, regardless of where they live. The result is a sparse matrix.
The problem: regular hardware runs on dense matrix operations (GEMM). A sparse weight tensor still occupies the same memory slots as its dense counterpart — you'd need to explicitly zero out entries and skip them in the multiply-accumulate pipeline. On a standard GPU, zeroing out 90% of a matrix doesn't make inference 10× faster; it might be the same speed or slower, because the hardware still executes dense vector operations and just multiplies by 0. To actually accelerate fine-grained sparse inference, you need:
N:M sparsity is a structured-irregular middle ground. For every contiguous M elements in the weight tensor, exactly N are kept and M−N are zeroed. The classic case is 2:4 sparsity: in every group of 4 weights, exactly 2 are non-zero (50% sparsity). NVIDIA's Ampere architecture (A100) supports 2:4 sparsity natively in its Sparse Tensor Cores: weights are stored in a compressed format (non-zero values + 2-bit indices), and the hardware skips zero multiplications automatically, delivering up to 2× throughput. Accuracy tests across BERT, ResNet, and ViT show that 2:4 sparsity typically recovers to within 0.5% of the dense baseline after sparse-aware training.
A 6×8 weight matrix. Switch modes to see which cells are zeroed (gray) and what pattern emerges. Note how fine-grained produces random scatter, N:M produces a regular local pattern, and channel pruning removes entire rows.
Why 2:4 was chosen: the 2-bit index per non-zero (4 possible positions in a group of 4) keeps the metadata overhead small while retaining much of the flexibility of unstructured pruning. At 2:4, each compressed group stores 2 values + 2 × 2 bits of indices, versus the 4 values in the dense form. The weight footprint is therefore roughly halved (for FP16 values, 36 bits instead of 64 per group). The hardware can then run sparse-dense matmuls with a dedicated sparse GEMM kernel that reads the compressed matrix and uses the indices to select the matching activation elements, achieving up to 2× throughput with accuracy typically close to what fine-grained 50% pruning would give.
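A minimal NumPy sketch of how an N:M mask can be built by magnitude: within each contiguous group of M weights along a row, keep the N largest absolute values. The function name `nm_prune_mask` is hypothetical, and this only illustrates the pattern; it does not produce the compressed storage format that Sparse Tensor Cores consume.

```python
import numpy as np

def nm_prune_mask(weight: np.ndarray, n: int = 2, m: int = 4) -> np.ndarray:
    """Boolean mask keeping the n largest-|w| entries in each group of m along a row."""
    rows, cols = weight.shape
    if cols % m != 0:
        raise ValueError("number of columns must be a multiple of m")
    # View each row as consecutive groups of m weights.
    groups = np.abs(weight).reshape(rows, cols // m, m)
    # Ascending argsort: the last n indices in each group are the largest magnitudes.
    keep = np.argsort(groups, axis=-1)[..., m - n:]
    mask = np.zeros(groups.shape, dtype=bool)
    np.put_along_axis(mask, keep, True, axis=-1)
    return mask.reshape(rows, cols)

W = np.array([[0.1, -0.9, 0.4, 0.05, 2.0, -0.3, 0.2, 1.5]])
mask = nm_prune_mask(W)                # 2:4 pattern
W_sparse = np.where(mask, W, 0.0)      # zero out the pruned entries
print(W_sparse)
```

In the first group of four, -0.9 and 0.4 survive; in the second, 2.0 and 1.5. Every group ends up with exactly two non-zeros, regardless of how the magnitudes are distributed across groups.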
Structured pruning removes entire groups of weights that correspond to a unit the hardware naturally processes: a row, a column, a kernel (one 3×3 slice), or a channel (all kernels feeding one output). The result is a smaller dense network — no sparse formats needed, no special hardware, just the same GPU GEMM running on a tensor with fewer dimensions.
For a convolutional layer, the weight tensor has shape (Cout, Cin, kh, kw). The granularity hierarchy:
| Granularity | Unit removed | Result | HW-friendly? | Compression ratio |
|---|---|---|---|---|
| Fine-grained | Individual w | Sparse tensor | No (needs sparse HW) | Highest |
| Pattern (N:M) | M−N weights in every group of M contiguous weights | Structured sparse | Yes (A100+) | Fixed by the pattern (50% for 2:4) |
| Vector | A row of kernel | Irregular sparse | Partial | Medium |
| Kernel | One k×k filter slice | Irregular sparse | Partial | Medium |
| Channel | Entire output channel | Smaller dense tensor | Yes (always) | Lower |
| Filter | All filters feeding a layer | Smaller dense tensor | Yes (always) | Lower |
Channel pruning is the most popular structured method. If a layer has Cout = 512 output channels and you prune 50% of them, you get Cout = 256. The next layer's input channels must also shrink from 512 to 256 — you remove the corresponding input-channel slices in that layer too. The network literally becomes smaller: every tensor in the pruned region is narrower. No zero-skipping needed; plain dense convolution runs faster because the tensor is smaller.
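The slicing described above can be sketched with plain NumPy arrays standing in for two consecutive conv layers. The shapes, the random weights, and the choice of L1 norm as the channel score are assumptions for illustration; in a real network the biases and any BatchNorm parameters for the removed channels would need the same slicing.

```python
import numpy as np

rng = np.random.default_rng(0)
# Two consecutive conv layers, weights shaped (C_out, C_in, k_h, k_w).
W1 = rng.normal(size=(8, 4, 3, 3))    # layer 1: 4 -> 8 channels
W2 = rng.normal(size=(16, 8, 3, 3))   # layer 2: 8 -> 16 channels

# Score each output channel of layer 1 by the L1 norm of its filter.
scores = np.abs(W1).sum(axis=(1, 2, 3))          # shape (8,)
n_keep = W1.shape[0] // 2                        # 50% channel sparsity
keep = np.sort(np.argsort(scores)[-n_keep:])     # indices of surviving channels

W1_pruned = W1[keep]        # drop output channels of layer 1
W2_pruned = W2[:, keep]     # drop the matching input slices of layer 2

print(W1_pruned.shape, W2_pruned.shape)  # (4, 4, 3, 3) (16, 4, 3, 3)
```

Both results are ordinary dense arrays, just narrower, which is why no sparse format or special kernel is needed afterwards.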
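Worked numbers for channel pruning a conv layer:
Layer spec: Cin=256, Cout=512, k=3×3, output 14×14. Params = 512 × 256 × 3 × 3 = 1,179,648; MACs = 1,179,648 × 14 × 14 ≈ 231.2M.
After 50% channel pruning (Cout → 256): params = 256 × 256 × 3 × 3 = 589,824; MACs = 589,824 × 14 × 14 ≈ 115.6M. Both are halved.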
For channel pruning at sparsity s: params and MACs both scale as (1−s), because the output tensor shrinks by (1−s) in the channel dimension and the next layer also shrinks its input side. This is the key advantage: real latency reduction without sparse-execution hardware.
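The scaling is easy to check with a few lines of arithmetic. This sketch uses the layer spec from the worked example (Cin=256, Cout=512, 3×3 kernel, 14×14 output); the helper `conv_cost` is a hypothetical name and ignores bias terms.

```python
def conv_cost(c_in: int, c_out: int, k: int, h_out: int, w_out: int):
    """Return (params, MACs) for a k×k conv layer, ignoring bias."""
    params = c_out * c_in * k * k
    macs = params * h_out * w_out      # each weight is used once per output position
    return params, macs

for s in (0.0, 0.5):
    kept = round(512 * (1 - s))        # surviving output channels
    params, macs = conv_cost(256, kept, 3, 14, 14)
    print(f"s={s:.0%}: {params:,} params, {macs / 1e6:.1f}M MACs")
# s=0%: 1,179,648 params, 231.2M MACs
# s=50%: 589,824 params, 115.6M MACs
```

If the preceding layer is pruned by the same fraction, this layer's input side shrinks too and its cost scales as (1−s)² rather than (1−s).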
You've decided on a granularity. Now you need to score every weight (or group of weights) with an importance value. The simplest and most widely-used criterion is magnitude: larger absolute value = more important. The intuition is direct — a weight of 0.001 contributes almost nothing to the output, while a weight of 5.0 strongly amplifies or suppresses its input. Remove the small ones.
For element-wise (fine-grained) pruning, the importance of a single weight wi is its absolute value:

$$\text{Importance}(w_i) = |w_i|$$
You sort all weights by |w|, and remove the bottom-k%. This is the approach in Han et al. 2015 — what they called "learning connections" (find a threshold, zero out everything below it, retrain with a frozen mask).
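For row-wise (structured) pruning, you need a single scalar importance for a whole group of weights. Two natural choices are the L1 norm (sum of absolute values) and the L2 norm (square root of the sum of squares) of the group S:

$$\text{Importance}_{L_1}(S) = \sum_{i \in S} |w_i|, \qquad \text{Importance}_{L_2}(S) = \sqrt{\sum_{i \in S} w_i^2}$$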
Worked example with numbers: Take a 2×4 weight matrix:
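```python
import math

# Weight matrix (2 rows, 4 cols)
W = [[3, -2, 0, 1],
     [-5, 0, 1, -0.2]]

# L1-norm row importance: sum of absolute values
row0_L1 = sum(abs(w) for w in W[0])             # |3| + |-2| + |0| + |1| = 6
row1_L1 = sum(abs(w) for w in W[1])             # |-5| + |0| + |1| + |-0.2| = 6.2

# L2-norm row importance: square root of the sum of squares
row0_L2 = math.sqrt(sum(w * w for w in W[0]))   # sqrt(9 + 4 + 0 + 1) = sqrt(14) ≈ 3.74
row1_L2 = math.sqrt(sum(w * w for w in W[1]))   # sqrt(25 + 0 + 1 + 0.04) = sqrt(26.04) ≈ 5.10

# Both norms agree: row1 is more important than row0.
# At 50% row sparsity, row0 would be pruned.
```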
A simulated weight distribution (Gaussian, matching real layer statistics). Drag the sparsity slider to sweep the threshold. Bars below threshold are shown in red (pruned). The accuracy proxy curve shows how accuracy typically degrades.
Per-layer magnitude pruning in PyTorch:
```python
import torch

def magnitude_prune(weight: torch.Tensor, sparsity: float) -> torch.Tensor:
    """Return a binary mask: 1=keep, 0=prune. Keeps top (1-sparsity) fraction."""
    flat = weight.abs().reshape(-1)
    k = int(sparsity * flat.numel())          # number of weights to prune
    if k == 0:                                # nothing to prune; kthvalue needs k >= 1
        return torch.ones_like(weight)
    threshold, _ = torch.kthvalue(flat, k)    # k-th smallest magnitude
    # Ties at the threshold are pruned too, so slightly more than k may be removed.
    mask = (weight.abs() > threshold).float()
    return mask

# Example: 4x4 weight matrix, 50% sparsity
W = torch.randn(4, 4)
mask = magnitude_prune(W, sparsity=0.5)
W_pruned = W * mask                           # zero out pruned weights
print(f"Nonzero weights: {int(mask.sum())}/16")  # → 8
```
During finetuning, the mask must be re-applied after every optimizer step: weight.data *= mask. If you forget this, gradient descent will gradually re-inflate the pruned weights, undoing your sparsity. This is why production pruning frameworks (torch.nn.utils.prune) register a forward pre-hook that re-applies the mask before every forward pass.
Magnitude of weights is a natural criterion when weights are all at the same scale. But two other signals carry more direct information about channel importance: the BatchNorm scaling factor (γ) and the average percentage of zero activations (APoZ). Each exploits information that is not available from weight magnitude alone.
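In a convolutional network with BatchNorm, each output channel j has a learned scaling factor γj. The BatchNorm output is:

$$z_j^{\text{out}} = \gamma_j \, \frac{z_j^{\text{in}} - \mu_j}{\sqrt{\sigma_j^2 + \epsilon}} + \beta_j$$

where μj and σj² are the batch mean and variance of channel j, and βj is the learned shift.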
If γj is very small, channel j barely contributes to the next layer's input — regardless of what the convolutional weights themselves look like. A γ near zero says: "this channel's output is being globally suppressed by training." That's a much more reliable importance signal than raw weight magnitude, because it reflects the network's learned decision about which channels matter.
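The Network Slimming technique (Liu et al. ICCV 2017) adds an L1 sparsity penalty on all γ values during training:

$$L_{\text{total}} = \sum_{(x, y)} \ell\big(f(x; W), y\big) + \lambda \sum_{\gamma \in \Gamma} |\gamma|$$

where ℓ is the usual training loss, Γ is the set of all BatchNorm scaling factors, and λ controls the strength of the penalty.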
This regularization nudges unimportant channel scaling factors toward zero. After training, pruning is trivial: sort channels by |γj|, remove the bottom-k%, and retrain.
| Channel | γ value | Action |
|---|---|---|
| Filter 0 | 1.17 | Keep |
| Filter 1 | 0.10 | Prune |
| Filter 2 | 0.29 | Prune |
| Filter 3 | 0.82 | Keep |
| Filter N-1 | 0.56 | Keep |
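The selection step can be sketched in a few lines. The γ values are the ones from the table above; treating those five rows as one complete layer, the 40% prune fraction, and the helper name `select_channels_to_prune` are assumptions made for illustration.

```python
# γ values from the table above, treated here as one complete 5-channel layer (assumption).
gammas = {"Filter 0": 1.17, "Filter 1": 0.10, "Filter 2": 0.29,
          "Filter 3": 0.82, "Filter N-1": 0.56}

def select_channels_to_prune(gammas: dict, prune_fraction: float) -> list:
    """Return the names of the channels with the smallest |γ|."""
    n_prune = int(prune_fraction * len(gammas))
    ranked = sorted(gammas, key=lambda name: abs(gammas[name]))  # ascending |γ|
    return ranked[:n_prune]

pruned = select_channels_to_prune(gammas, prune_fraction=0.4)
print(pruned)  # ['Filter 1', 'Filter 2']
```

The result matches the Keep/Prune column: the two channels with the smallest scaling factors are removed.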
A different approach looks not at weights but at activations. ReLU networks produce many zeros in activation maps — any negative pre-activation becomes exactly zero. If a neuron (or channel) produces zero for the vast majority of inputs, it is contributing almost nothing to downstream computation. APoZ — Average Percentage of Zero activations — quantifies this.
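$$\text{APoZ}_c = \frac{\sum_{n=1}^{N} \sum_{h=1}^{H} \sum_{w=1}^{W} \phi\big(o_c(n, h, w) = 0\big)}{N \times H \times W}$$

where o_c(n, h, w) is the post-ReLU activation of channel c at position (h, w) for sample n, ϕ is an indicator that equals 1 when the activation is exactly zero, N is the number of evaluation samples, and H×W is the spatial dimension of the feature map.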
APoZ is computed on a calibration dataset (typically a few thousand training examples) by running a forward pass and recording which channels are zero. A channel with APoZ = 95% is dead for 95% of inputs — an obvious pruning candidate.
```python
import torch

def compute_apoz(model, dataloader, layer_name):
    """Compute Average Percentage of Zeros for each channel in a layer.

    layer_name should point at a post-activation module (e.g. a ReLU);
    before the activation, exact zeros are rare.
    """
    zero_counts = None
    total = 0

    def hook(module, inp, output):
        nonlocal zero_counts, total
        # output: (N, C, H, W)
        zeros = (output == 0).float().sum(dim=[0, 2, 3])   # sum over N, H, W -> (C,)
        count = output.shape[0] * output.shape[2] * output.shape[3]
        zero_counts = zeros if zero_counts is None else zero_counts + zeros
        total += count

    h = model.get_submodule(layer_name).register_forward_hook(hook)
    model.eval()
    try:
        with torch.no_grad():
            for x, _ in dataloader:
                model(x)
    finally:
        h.remove()              # always detach the hook, even on error
    return zero_counts / total  # APoZ per channel, as a fraction in [0, 1]
```
Magnitude tells you how big a weight is. But what you really want to know is: how much does the loss change if I remove this weight? For a well-trained network, removing weight wi means setting δwi = wi (deleting it is equivalent to perturbing it by its current value). The change in loss is δL.
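Step 1: Taylor expansion of δL. Expand L(x; W − δW) around the current weights W and define δL as the resulting change in loss:

$$\delta L = L(x; W - \delta W) - L(x; W) = -\sum_i g_i\,\delta w_i + \frac{1}{2}\sum_i h_{ii}\,\delta w_i^2 + \frac{1}{2}\sum_{i \neq j} h_{ij}\,\delta w_i\,\delta w_j + O(\lVert \delta W \rVert^3)$$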
where gi = ∂L/∂wi (gradient) and hij = ∂2L/(∂wi∂wj) (Hessian entry).
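Step 2: Three OBD assumptions. (1) Quadratic: the loss is close to quadratic around W, so the O(‖δW‖³) term is dropped. (2) Extremal: training has converged, so gi ≈ 0 and the first-order term vanishes. (3) Diagonal: the loss change from removing several weights is the sum of the individual changes, so the cross terms hij δwi δwj with i ≠ j are dropped. What remains is δL ≈ ½ Σi hii δwi².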
Step 3: The OBD saliency. Now set δwi = wi (pruning weight i means removing it, i.e., the perturbation equals the weight value) and δwj = 0 for all j ≠ i:

$$\delta L_i \approx \text{saliency}_i = \frac{1}{2}\, h_{ii}\, w_i^2$$
This is the OBD saliency — the estimated loss increase from removing weight i. Weights with small saliency are safe to prune. The key difference from magnitude: saliency weighs the weight by its curvature hii. A small weight in a high-curvature region (hii large) can be crucial; a large weight in a flat region (hii small) can be safely removed.
Two-weight model. Contours show the loss surface. Current weights: w1=0.1, w2=3.0. Note that w1 lies along a steep, high-curvature direction, while w2 lies along a flat one. Magnitude prunes w1; OBD correctly prunes w2. Drag w2 to explore.
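Worked numerical example: Two weights, w1 = 0.1 with h11 = 200 and w2 = 3.0 with h22 = 0.01. Saliency of w1 = ½ × 200 × 0.1² = 1.0. Saliency of w2 = ½ × 0.01 × 3.0² = 0.045.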
OBD decision: Prune w2 (saliency 0.045 < 1.0, so removing it costs only about 0.045 loss units). Magnitude decision: Prune w1 (|w1| = 0.1 < |w2| = 3.0). OBD avoids the expensive mistake: w1 lies along a steep direction of the loss surface (h11=200 means removing w1=0.1 costs about 1.0 loss units), while w2=3.0 lies along a flat direction (h22=0.01 means removing it barely changes the loss).
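The same comparison in code, as a minimal sketch using the values from the worked example (w1=0.1 with h11=200, w2=3.0 with h22=0.01). The function name `obd_saliency` is hypothetical, and the diagonal Hessian entries are simply given here; estimating them for a real network is the expensive part.

```python
def obd_saliency(w: float, h_diag: float) -> float:
    """Estimated loss increase from deleting a weight: 0.5 * h_ii * w_i**2."""
    return 0.5 * h_diag * w * w

weights = {"w1": (0.1, 200.0), "w2": (3.0, 0.01)}   # name -> (value, h_ii)

saliency = {name: obd_saliency(w, h) for name, (w, h) in weights.items()}
by_obd = min(saliency, key=saliency.get)
by_magnitude = min(weights, key=lambda name: abs(weights[name][0]))

print(saliency)
print(by_obd, by_magnitude)   # w2 w1
```

The two criteria pick opposite weights: OBD removes the large weight in the flat direction, while magnitude removes the small weight in the steep one.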
Taylor expansion pruning (1st order): A simpler compromise. Keep the gradient term and approximate saliency as |gi × wi|, the magnitude of the first-order Taylor term. This needs only the gradients that backpropagation already computes, so it is much cheaper than computing second derivatives. It works well for small pruning steps, where the pruned network stays close to the original weights and the linear approximation holds. A closely related squared form, (gi × wi)², is used in Molchanov et al. CVPR 2019.
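As a rough illustration of the first-order score, the sketch below ranks three weights by |g × w|. The weight and gradient values are made up for the example, and `taylor1_score` is a hypothetical helper; in practice the gradients would be accumulated over mini-batches.

```python
def taylor1_score(w: float, g: float) -> float:
    """First-order importance |g_i * w_i|."""
    return abs(g * w)

# Hypothetical (weight, gradient) pairs, e.g. gradients averaged over a mini-batch.
params = {"a": (0.8, 0.02), "b": (-0.05, 0.9), "c": (2.0, 0.001)}

scores = {name: taylor1_score(w, g) for name, (w, g) in params.items()}
prune_first = min(scores, key=scores.get)
print(prune_first)  # c
```

Here the largest weight, c, is the first candidate for removal because its gradient is tiny, while the smallest weight, b, scores highest. That is a different ordering from pure magnitude.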
You've picked a granularity and a criterion. Now comes the most practical question: do you prune the entire network at once (one-shot), or do you alternate between pruning and finetuning in small increments (iterative)?
The intuition for why iterative wins: imagine you're hiking and you want to remove 90% of your gear. If you dump 90% immediately, you've made a huge random change to how you carry load — you might drop critical items and the remaining 10% might not form a coherent system. But if you remove 10% at a time, evaluating and re-packing after each step, you can carefully select what to discard and the remaining gear stays coherent.
In network terms: one-shot pruning at 90% sparsity simultaneously removes many interdependent weights. The loss surface shifts sharply; the surviving weights were optimized for a different parameter configuration and finetuning may not recover accuracy. Iterative pruning at 10% per step makes a small change, finetunes to re-adapt the surviving weights, then repeats. Each finetuning step keeps the network near a good minimum before the next prune.
The standard loop:
```python
# Iterative magnitude pruning with PyTorch
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# YourModel, train_one_epoch, dataloader and finetune_epochs are placeholders
# that you supply for your own task.
model = YourModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

target_sparsity = 0.90
steps = 9                                  # prune 10% → 20% → ... → 90%
step_size = target_sparsity / steps        # 0.10 of the original weights per step

for step in range(steps):
    prev_sparsity = step * step_size
    current_sparsity = (step + 1) * step_size
    # torch.nn.utils.prune applies `amount` to the weights that are still
    # unpruned, so convert the global step into a fraction of the remainder.
    amount = (current_sparsity - prev_sparsity) / (1.0 - prev_sparsity)

    # --- Prune each linear/conv layer ---
    for name, module in model.named_modules():
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            prune.l1_unstructured(module, name='weight', amount=amount)

    # --- Finetune for a few epochs ---
    for epoch in range(finetune_epochs):
        train_one_epoch(model, optimizer, dataloader)

    print(f"Step {step+1}: sparsity={current_sparsity:.0%}")

# Remove the pruning hooks and make sparsity permanent
for name, module in model.named_modules():
    if hasattr(module, 'weight_mask'):
        prune.remove(module, 'weight')
```
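One detail that is easy to get wrong in an iterative schedule is what each step's prune fraction is relative to. In torch.nn.utils.prune, repeated calls apply `amount` to the weights that are still unpruned, so a constant 10% per step does not add up to 90%. The sketch below works through the arithmetic in plain Python; `sparsity_after` is a hypothetical helper.

```python
def sparsity_after(amounts):
    """Overall sparsity when each amount is a fraction of the still-unpruned weights."""
    remaining = 1.0
    for a in amounts:
        remaining *= (1.0 - a)
    return 1.0 - remaining

steps, target = 9, 0.90

# Naive: the same 10% of the remainder at every step.
naive = sparsity_after([0.10] * steps)

# Linear schedule: step i moves overall sparsity from i*0.1 to (i+1)*0.1,
# which means pruning step_size / (1 - i*step_size) of what is left.
step_size = target / steps
relative = [step_size / (1.0 - i * step_size) for i in range(steps)]
linear = sparsity_after(relative)

print(f"naive: {naive:.1%}, linear schedule: {linear:.1%}")
# naive: 61.3%, linear schedule: 90.0%
```

With the naive schedule, nine steps leave 0.9⁹ ≈ 38.7% of the weights in place. Converting each step to a fraction of the remainder reaches the intended 90%.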
Not all layers are created equal. The first convolutional layer in a CNN processes raw pixels — its small number of learned edge detectors is hard to replace. If you prune 80% of them, the whole network sees degraded features. The last fully-connected layer, by contrast, often has massive redundancy and can tolerate 80–90% sparsity with negligible accuracy loss.
Sensitivity analysis: for each layer independently, measure the accuracy drop as a function of sparsity. Plot it as a curve (accuracy vs sparsity per layer). This tells you the maximum safe sparsity for each layer and motivates non-uniform sparsity: sensitive layers get low sparsity (10–30%); robust layers get high sparsity (70–90%). The overall sparsity is then the parameter-weighted average of the per-layer sparsities.
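Turning sensitivity curves into a per-layer budget can be sketched as follows. The curves, layer names, parameter counts, and the 1-point accuracy tolerance are all made-up values for illustration, and `max_safe_sparsity` is a hypothetical helper.

```python
# Hypothetical sensitivity curves: layer -> {sparsity: accuracy drop in points}.
sensitivity = {
    "conv1": {0.3: 0.4, 0.5: 2.1, 0.7: 6.0, 0.9: 15.0},
    "conv2": {0.3: 0.1, 0.5: 0.5, 0.7: 1.8, 0.9: 7.0},
    "fc1":   {0.3: 0.0, 0.5: 0.1, 0.7: 0.3, 0.9: 0.9},
}
param_counts = {"conv1": 10_000, "conv2": 200_000, "fc1": 1_000_000}

def max_safe_sparsity(curve: dict, max_drop: float) -> float:
    """Largest measured sparsity whose accuracy drop stays within max_drop (0 if none)."""
    safe = [s for s, drop in curve.items() if drop <= max_drop]
    return max(safe, default=0.0)

budget = {layer: max_safe_sparsity(curve, max_drop=1.0)
          for layer, curve in sensitivity.items()}

# Overall sparsity is the parameter-weighted average of the per-layer values.
total = sum(param_counts.values())
overall = sum(budget[name] * param_counts[name] for name in budget) / total

print(budget)              # {'conv1': 0.3, 'conv2': 0.5, 'fc1': 0.9}
print(f"{overall:.1%}")    # 82.9%
```

Because the curves are measured one layer at a time, the combined accuracy drop after pruning all layers together can be larger than any single curve suggests, so the resulting budget is a starting point to validate rather than a guarantee.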
Drag the sparsity slider. The orange curve shows one-shot pruning; the teal curve shows iterative pruning + finetuning. Notice how iterative recovers far more accuracy at the same final sparsity. Click “Run iterative step” to animate each prune–finetune cycle.
This showcase puts together everything from this lesson: granularity, criterion, and the per-layer sensitivity landscape. You're the ML engineer. You have a small CNN with five layers. Each layer has a different sensitivity to pruning — the first conv is fragile, the FC layers are robust. Your tools are: choose the granularity (fine or channel), choose the criterion (magnitude or OBD proxy), choose per-layer sparsity, and see how accuracy and MAC savings evolve.
Each bar shows accuracy drop vs sparsity for that layer (measured independently). Taller bars = more sensitive = set lower target sparsity. The tool recommends a safe uniform vs adaptive budget allocation.
Watch the conv weight tensor shrink as you increase channel sparsity. The gray cells disappear, and the next layer's input dimension (left side) shrinks too. Slide to see how params and MACs scale linearly with retained channels.
Worked MAC savings at different sparsity targets:
| Sparsity | Remaining params | Remaining MACs | Notes |
|---|---|---|---|
| 0% | 1,179,648 | 231.2M | Baseline (C_in=256, C_out=512, 3×3, 14×14) |
| 30% | 825,754 | 161.8M | Safe for most layers; 1.43× speedup |
| 50% | 589,824 | 115.6M | Aggressive but recoverable with finetuning; 2× speedup |
| 70% | 353,894 | 69.4M | High sparsity; only for robust later layers; 3.3× speedup |
| 90% | 117,965 | 23.1M | Extreme; only final FC layers can typically tolerate this |
Channel pruning gives exact proportional scaling: 50% channel sparsity ⇒ 50% param reduction ⇒ 50% MAC reduction ⇒ roughly 50% lower latency, i.e. up to about 2× wall-clock speedup (the realized figure varies with memory bandwidth and other overheads).
You now have the core vocabulary of pruning. Here's the full picture in two tables, one for granularities and one for criteria, followed by what comes next.
| Granularity | Best criteria | Hardware target | When to use |
|---|---|---|---|
| Fine-grained | OBD saliency, magnitude | Custom sparse accelerators (e.g. the EIE ASIC) | Maximum compression ratio; edge accelerators with sparse support |
| N:M (2:4) | Magnitude within groups | NVIDIA Ampere (A100, RTX 30+) | GPU deployment needing real speedup at 50% sparsity |
| Vector/kernel | L1/L2 group norm | Partially regular HW | Rare — middle ground between fine and channel |
| Channel/filter | BN γ, L1-norm of filter, Taylor criterion | Any GPU/CPU/MCU | Latency-critical deployment; no sparse HW available |

The criteria, side by side:

| Criterion | Complexity | When it excels | Limitation |
|---|---|---|---|
| Magnitude (\|W\|) | O(n) | Large networks; fast iterations; works well on average | Ignores curvature; can misprioritize in steep regions |
| L1/L2 group norm | O(n) | Structured/channel pruning; natural for filter ranking | Same blind spot as magnitude |
| BN γ | O(C) free | Networks with BatchNorm; training-time regularization | Only for channel pruning; requires retraining with L1-γ reg |
| APoZ | O(n × data) | Detecting dead neurons; domain-specific pruning | Data-dependent; must run calibration set forward pass |
| OBD saliency | O(n) + Hessian diag | When you want principled pruning near a minimum | Diagonal Hessian approximation; off-diagonal errors |
This lesson covered the foundations. Pruning II dives into three open frontiers: