You compressed a ResNet-50 down to MobileNetV2-Tiny via pruning and quantization — but accuracy fell from 76% to 48%. What if the big model could teach the small one, sharing not just hard labels but the full probability landscape it learned? Knowledge distillation is that teaching protocol: a student network trained on a teacher's soft outputs learns more from every example than any hard label could convey.
You just finished building a TinyML pipeline. You started with a ResNet-50 (76% top-1 on ImageNet, 4.1G MACs). You pruned it aggressively — down to MobileNetV2-Tiny with 23.5M MACs. You quantized it to INT8. You ran it on the MCU. It fits in 512 kB SRAM. It runs in 18 ms. Your manager is happy — until you check accuracy. The tiny model hits 48%, down from 76%. The compression worked; the accuracy collapsed.
Why the collapse? There are actually two distinct problems.
1. Capacity under-fitting. A tiny model like MobileNetV2-Tiny has far fewer parameters than ResNet-50. Even if trained perfectly, it may simply lack the representational capacity to reach 76%. The "expressiveness ceiling" is lower.
2. The hard-label problem. You trained the tiny model on one-hot labels: cat = [1, 0, 0, 0, 0, …]. Those labels are maximally sparse. Every example tells the model exactly one fact: "this is a cat." The model gets no information about how much a given image resembles a dog, or a leopard, or a lion. Hard labels are brutal teachers.
Here's the insight Hinton, Vinyals, and Dean formalized in their 2015 paper: a trained big network is a much better teacher than a one-hot label. Why? Because when a trained ResNet sees an image of a tabby cat, it doesn't output [1, 0, 0, 0, …]. It outputs something like [0.82, 0.06, 0.05, 0.04, 0.02, …]: a full probability distribution across all 1,000 ImageNet classes. That distribution says: "probably cat, a little dog, a little leopard." Those nonzero probabilities on non-cat classes encode structural knowledge about visual similarity. They are richer than any ground-truth label, and we can transfer that richness to the small model.
The rest of this lesson derives the exact math, builds the practical loss function, and explores what information the teacher can share beyond output logits — intermediate features, attention maps, and relationships between examples.
Hinton coined the term "dark knowledge" for the information buried in a trained network's output probabilities that never shows up in hard labels. Let's make this concrete with actual numbers.
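Suppose a trained teacher network looks at an image of a tabby cat and outputs raw logits of [5, 1, 0.5] for the classes [cat, dog, car]. At temperature T=1, the softmax gives:

$$p = \frac{[e^{5},\ e^{1},\ e^{0.5}]}{e^{5} + e^{1} + e^{0.5}} = \frac{[148.41,\ 2.72,\ 1.65]}{152.78} \approx [0.971,\ 0.018,\ 0.011]$$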
The hard label is just [1, 0, 0]. The soft output is [0.971, 0.018, 0.011]. What's the difference? The soft output says: "this image is very much a cat, and cats are a little bit like dogs, and not much like cars." The relative probability between dog and car — 0.018 vs 0.011 — tells the student something genuinely useful: dogs and cats share visual features (fur, ears, eyes) that cars don't. A student trained purely on hard labels never sees this.
Hinton's term "dark knowledge" is evocative precisely because this information is invisible in the dataset. No human annotator wrote "this cat image is 1.6× more dog-like than car-like." That knowledge was discovered by the teacher during training and encoded implicitly in its weights, and soft targets are the most direct way to extract and transfer it.
Here's a second worked example with a more ambiguous image — a drawing of a cat:
| Class | Teacher logit | Soft prob (T=1) | Hard label |
|---|---|---|---|
| cat | 3.0 | 0.731 | 1 |
| dog | 2.0 | 0.269 | 0 |
| car | –5.0 | ~0.000 | 0 |
Notice that the teacher is much less confident than on the real photo (0.731 vs 0.971). That reduced confidence is genuine information: the drawing looks ambiguous. A student trained with the hard label [1, 0, 0] gets no signal about this ambiguity. The soft label [0.731, 0.269, 0.000] correctly calibrates the student's uncertainty. This is especially valuable at test time: a well-distilled student produces well-calibrated probabilities, not just sharp argmax predictions.
```python
# Worked example: teacher soft targets vs hard labels for 3 classes
import torch
import torch.nn.functional as F

# Teacher logits for one image [cat=3.0, dog=2.0, car=-5.0]
teacher_logits = torch.tensor([3.0, 2.0, -5.0])

# Soft targets (T=1): the dark knowledge signal
soft_targets = F.softmax(teacher_logits, dim=0)
# ≈ [0.7309, 0.2689, 0.0002]  ← dog gets ~26.9%, not 0%

# Hard label: cat=1, dog=0, car=0
hard_label = torch.tensor([1.0, 0.0, 0.0])

# Cross-entropy against either target has gradient (p_student - target) w.r.t.
# the student logits. With the hard label, the dog target is 0, so training keeps
# pushing the dog probability toward 0. With soft targets, the dog target is
# ~0.269, so that gradient vanishes once the student assigns dog ≈ 26.9%.
```
We know soft targets carry dark knowledge, but there's a practical problem. For a confident teacher, the soft targets at T=1 can be extremely peaked: [0.971, 0.018, 0.011]. The probabilities on the "wrong" classes are so small that their contribution to the loss and its gradient is tiny compared to the dominant class, so most of the teacher's dark knowledge stays hidden.
The fix is temperature scaling. The softmax is generalized with a temperature T:

$$p_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}$$

At T=1 this is the standard softmax. At T>1, dividing all logits by T compresses their range before exponentiation, which flattens the distribution: the probabilities of low-logit classes rise relative to the top class. This reveals more of the structural information buried in the wrong-class logits.
Let's work through the logit vector [3, 1, 0.5] for three classes [cat, dog, car] at different temperatures:
| Temperature T | p(cat) | p(dog) | p(car) | Effect |
|---|---|---|---|---|
| T = 1 | 0.821 | 0.111 | 0.067 | Peaked: cat dominates |
| T = 2 | 0.604 | 0.222 | 0.173 | Softer: dog and car get more signal |
| T = 4 | 0.467 | 0.283 | 0.250 | Very soft: wrong classes carry substantial signal, ordering still visible |
| T = 10 | 0.385 | 0.315 | 0.300 | Almost uniform: too flat, classes lose meaning |
Let's derive T=4 by hand for [3, 1, 0.5]. Scale by 1/4: logits become [0.75, 0.25, 0.125]. Exponentiate: exp(0.75) = 2.117, exp(0.25) = 1.284, exp(0.125) = 1.133. Sum = 4.534. Probabilities: 2.117/4.534 = 0.467; 1.284/4.534 = 0.283; 1.133/4.534 = 0.250, matching the T = 4 row.
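To check these numbers yourself (or try other logits), here is a minimal pure-Python sketch of the temperature-scaled softmax. The function name `softmax_T` is just for illustration; subtracting the maximum before exponentiating is a standard numerical-stability trick that does not change the result.

```python
import math

def softmax_T(logits, T=1.0):
    """Temperature-scaled softmax: exp(z_i / T) / sum_j exp(z_j / T)."""
    scaled = [z / T for z in logits]
    m = max(scaled)                      # subtract max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [3.0, 1.0, 0.5]  # [cat, dog, car]
for T in [1, 2, 4, 10]:
    probs = softmax_T(logits, T)
    print(f"T={T:>2}: " + ", ".join(f"{p:.3f}" for p in probs))
# T= 1: 0.821, 0.111, 0.067
# T= 2: 0.604, 0.222, 0.173
# T= 4: 0.467, 0.283, 0.250
# T=10: 0.385, 0.315, 0.300
```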
During distillation, both teacher and student run softmax at the same elevated temperature T. This produces equally-softened distributions on both sides, making the KL divergence meaningful even when the teacher is confident. After training, you discard T — the student runs at T=1 for deployment.
Use the interactive canvas to see how temperature transforms any 3-class logit distribution. Drag T left and right to watch the dark knowledge emerge (or disappear into uniformity at very high T).
Bar heights show the probability of each class. Drag T to see how the distribution softens. The teal bars show absolute probability; the orange dashes show the T=1 baseline for comparison. Notice how the "wrong" classes (dog, car) gain signal at higher T — that's the dark knowledge becoming visible.
We want the student to match the teacher's soft distribution. But we also want the student to learn the actual ground-truth labels — if the teacher has any residual error, we don't want to distill that error into the student. The Hinton KD loss combines both signals into a single weighted objective.
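Formally, the distillation loss is:

$$\mathcal{L}_{KD} = (1-\alpha)\, T^2\, \mathrm{KL}\big(\sigma(z_T/T)\,\|\,\sigma(z_S/T)\big) + \alpha\, \mathrm{CE}\big(\sigma(z_S),\, y\big)$$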
Where: zT = teacher logits, zS = student logits, σ = softmax, y = one-hot ground-truth label, T = temperature, α = mixing coefficient (0=all soft targets, 1=all hard labels).
The two terms play different roles. The KL term (left) measures how far the student's soft distribution is from the teacher's soft distribution, computed at high temperature T to reveal dark knowledge. The CE term (right) is the standard cross-entropy loss against the ground-truth label at T=1, ensuring the student learns the actual correct answer.
Let's compute the KD loss for one concrete example. Teacher logits = [3, 1, 0.5], student logits = [1, 1, 1] (a totally untrained student — uniform), hard label = cat (index 0), T=2, α=0.5.
Step 1: Soft targets at T=2. Teacher logits/T = [1.5, 0.5, 0.25]. exp = [4.482, 1.649, 1.284]. Sum = 7.414. pT ≈ [0.6045, 0.2224, 0.1732]. Student logits/T = [0.5, 0.5, 0.5], so pS = [1/3, 1/3, 1/3] (uniform).
Step 2: KL divergence. KL(pT ‖ pS) = ∑i pT,i · log(pT,i / pS,i) = 0.6045·log(1.8134) + 0.2224·log(0.6671) + 0.1732·log(0.5195) = 0.6045·0.5952 + 0.2224·(−0.4048) + 0.1732·(−0.6548) = 0.3598 − 0.0900 − 0.1134 ≈ 0.156.
Step 3: CE loss. Student softmax at T=1: [1/3, 1/3, 1/3]. CE = −log(1/3) = log 3 ≈ 1.099.
Step 4: Full loss. LKD = 0.5 · T² · KL + 0.5 · CE = 0.5 · 4 · 0.156 + 0.5 · 1.099 = 0.313 + 0.549 = 0.862.
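To check the arithmetic, here is a minimal pure-Python sketch of the same single-example loss. It follows the convention stated above (α = 0 is all soft targets, α = 1 is all hard labels); the helper names `softmax`, `kl_div`, and `kd_loss_single` are illustrative and not from any library.

```python
import math

def softmax(logits, T=1.0):
    """Temperature-scaled softmax (max-subtracted for numerical stability)."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kl_div(p, q):
    """KL(p || q) in nats; terms with p_i = 0 contribute nothing."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def kd_loss_single(student_logits, teacher_logits, label, T=2.0, alpha=0.5):
    """(1 - alpha) * T^2 * KL(soft teacher || soft student) + alpha * CE(student, label)."""
    kl = kl_div(softmax(teacher_logits, T), softmax(student_logits, T))
    ce = -math.log(softmax(student_logits, 1.0)[label])  # hard-label CE at T=1
    return (1 - alpha) * T ** 2 * kl + alpha * ce, kl, ce

loss, kl, ce = kd_loss_single([1.0, 1.0, 1.0], [3.0, 1.0, 0.5], label=0)
print(f"KL = {kl:.3f}, CE = {ce:.3f}, L_KD = {loss:.3f}")
# KL = 0.156, CE = 1.099, L_KD = 0.862
```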
```python
import torch
import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.5):
    """
    Knowledge Distillation loss (Hinton et al., 2015).
    student_logits: (B, C) raw student outputs
    teacher_logits: (B, C) raw teacher outputs (no grad, teacher is frozen)
    labels:         (B,) ground-truth class indices
    T:              temperature for softening
    alpha:          weight for the hard-label CE term; (1 - alpha) weights the
                    soft KL term (alpha=0: pure soft, alpha=1: pure hard)
    """
    # Soft targets: both at temperature T
    soft_student = F.log_softmax(student_logits / T, dim=-1)  # log-probs, as F.kl_div expects
    soft_teacher = F.softmax(teacher_logits / T, dim=-1)      # probs; no grad needed

    # KL(teacher || student): F.kl_div takes log-probs as input and probs as target
    kl = F.kl_div(soft_student, soft_teacher, reduction='batchmean')

    # T² scaling: gradients of the softened KL shrink roughly as 1/T², so
    # multiplying by T² keeps them comparable to the CE term across temperatures
    kl_scaled = kl * (T ** 2)

    # Hard-label cross-entropy (standard, at T=1)
    ce = F.cross_entropy(student_logits, labels)

    # Combined loss
    return (1 - alpha) * kl_scaled + alpha * ce


# Training step
def train_step(student, teacher, batch_x, batch_y, optimizer, T=4.0, alpha=0.5):
    student.train()
    teacher.eval()
    with torch.no_grad():
        teacher_logits = teacher(batch_x)  # teacher is frozen: no graph, no updates
    student_logits = student(batch_x)
    loss = kd_loss(student_logits, teacher_logits, batch_y, T, alpha)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```
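Why multiply the KL term by T²? For the softened KL, the gradient with respect to the student logits is (pS − pT)/T, and at high T the gap pS − pT itself shrinks roughly like 1/T, so the raw gradient falls off roughly as 1/T² (the argument Hinton et al. give for the T² factor). The pure-Python sketch below reuses the worked-example logits and prints the raw gradient norm next to the T²-scaled one; the scaled column should level off as T grows. Function names are illustrative.

```python
import math

def softmax(logits, T=1.0):
    """Temperature-scaled softmax (max-subtracted for numerical stability)."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def soft_kl_grad(student_logits, teacher_logits, T):
    """Gradient of KL(softmax(z_T/T) || softmax(z_S/T)) w.r.t. z_S, i.e. (p_S - p_T) / T."""
    p_s = softmax(student_logits, T)
    p_t = softmax(teacher_logits, T)
    return [(ps - pt) / T for ps, pt in zip(p_s, p_t)]

def l2_norm(v):
    return math.sqrt(sum(x * x for x in v))

teacher = [3.0, 1.0, 0.5]  # same logits as the worked example
student = [1.0, 1.0, 1.0]  # untrained, uniform student
for T in [1, 2, 4, 8, 16, 32]:
    g = l2_norm(soft_kl_grad(student, teacher, T))
    print(f"T={T:>2}  |grad| = {g:.5f}   T^2 * |grad| = {T ** 2 * g:.4f}")
```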
The canvas below lets you drag the mixing coefficient α and see how the two loss components contribute to the total gradient. At α=0, the student trains purely on the teacher's soft knowledge. At α=1, purely on hard labels. The optimal α typically lies between 0.1 and 0.9 depending on how reliable the teacher is.
The bar chart shows the contribution of each loss term (KL soft vs CE hard) to the total gradient magnitude. Proxy "student accuracy" peaks near α=0.5 for most teacher/student pairs. Drag α to find the balance point.
Hinton's original paper matched only the output logits — the last layer's raw numbers before the final softmax. But a trained deep network contains structural knowledge at every layer. Subsequent work asked: what else can we transfer? The answer turned out to be almost everything.
| What to Match | Method | Loss Type | Key Papers |
|---|---|---|---|
| Output logits | Temperature-softmax, KL divergence on final distribution | KL div or L2 on logits | Hinton et al. 2015; Ba & Caruana 2014 (L2 on logits) |
| Intermediate weights | L2 match teacher weights to student weights (with linear transform for dim mismatch) | L2 norm | Romero et al. 2015 (FitNets) |
| Intermediate features | Match hidden-layer activation maps via MMD (Maximum Mean Discrepancy) or L2 | MMD or L2 | Huang & Wang 2017 (NST) |
| Attention maps | Match spatial attention: ∂L/∂x across channels, reduced to 2D map | L2 on attention maps | Zagoruyko & Komodakis 2017 |
| Sparsity patterns | Match post-ReLU activation boundaries: ρ(x) = 1[x>0] | Boundary matching loss | Heo et al. 2019 |
| Relational info | Match pairwise distances between multiple examples' feature vectors | L2 on distance/angle matrices | Park et al. 2019 (RKD); Yim et al. 2017 (FSP) |
These approaches form a hierarchy from "output only" to "full internal structure." Each deeper level of matching potentially transfers more knowledge — but also requires more careful engineering, since intermediate representations have different shapes in teacher and student, requiring adapter layers.
Let's think about why each type captures something different. Logit matching transfers the final classification decision structure. Feature matching (NST) transfers the distribution of internal representations — which neurons fire how strongly for which inputs. Attention matching transfers where the teacher looks in an image — a teacher that correctly focuses on the cat's face transfers that spatial prior to the student. Relational KD transfers the geometry of the embedding space — the fact that a cat and a dog are closer to each other than a cat and a car, even in the hidden space.
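Relational KD is the easiest of these to sketch without a framework. Below is a simplified, pure-Python version of the distance-wise idea from Park et al.: compare each example pair's distance in teacher space with the same pair's distance in student space, after dividing each set by its mean pair distance. The paper uses a Huber loss and adds an angle-wise term; this sketch uses plain squared error on distances only, and all function names and embeddings are illustrative.

```python
import math
from itertools import combinations

def pairwise_distances(embeddings):
    """Euclidean distance for every unordered pair (i, j) with i < j."""
    return [math.dist(a, b) for a, b in combinations(embeddings, 2)]

def normalized(dists):
    """Divide by the mean pair distance so the loss ignores overall scale."""
    mu = sum(dists) / len(dists)
    return [d / mu for d in dists]

def rkd_distance_loss(teacher_emb, student_emb):
    """Mean squared gap between normalized pairwise distances.
    Teacher and student embeddings may have different dimensions."""
    t = normalized(pairwise_distances(teacher_emb))
    s = normalized(pairwise_distances(student_emb))
    return sum((ti - si) ** 2 for ti, si in zip(t, s)) / len(t)

# Hypothetical embeddings for [cat, dog, car]: teacher is 4-D, student is 2-D
teacher = [[1.0, 0.0, 0.0, 0.0], [0.9, 0.3, 0.0, 0.0], [0.0, 0.0, 1.0, 1.0]]
student_2d = [[0.0, 0.0], [5.0, 0.0], [0.1, 0.1]]  # car near cat, dog far: wrong geometry
print(rkd_distance_loss(teacher, student_2d))
```

Because only distances are compared, no adapter layer is needed between the 4-D teacher and 2-D student spaces.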
FitNets (Romero et al., ICLR 2015) took KD deeper, literally. Instead of matching only the final output, FitNets match intermediate layers. The teacher has a "hint layer" (some intermediate layer), and the student has a "guided layer" (an intermediate layer in the student that should learn to match the hint). The guidance loss is:

$$\mathcal{L}_{HT} = \tfrac{1}{2}\,\big\| f_T(x) - W_r\big(f_S(x)\big) \big\|^2$$
Where fT(x) is the teacher's hint layer output (shape: e.g., H×W×256), fS(x) is the student's guided layer output (shape: H×W×64, because the student is thinner), and Wr is a learned "regressor" — a small convolutional layer that maps the student's 64-channel output to 256 channels so the shapes match for comparison. The regressor is trained jointly with the student during the "hint training" phase.
Attention Transfer (Zagoruyko & Komodakis, ICLR 2017) takes a different angle. Instead of matching the raw feature maps, it matches attention maps derived from those features. The attention map for a feature tensor $F \in \mathbb{R}^{C \times H \times W}$ is the sum of squared absolute values across channels: $A(F) = \sum_c |F_c|^2 \in \mathbb{R}^{H \times W}$. This 2D map shows where in the image the layer is attending: high values mean "this spatial location was important."
The key empirical observation: high-accuracy ResNets (ResNet-34 at 73%, ResNet-101 at 77%) have very similar attention maps to each other. A lower-accuracy model (Network-in-Network at 62%) has qualitatively different attention maps. This suggests attention patterns are a marker of "correct" learned behavior — and matching them transfers that behavioral quality to the student.
```python
# FitNets: feature-matching loss with a regressor layer
import torch
import torch.nn as nn


class FitNetRegressor(nn.Module):
    """Adapts student's thin feature map to teacher's wide feature map."""

    def __init__(self, student_channels, teacher_channels):
        super().__init__()
        # 1x1 conv changes the channel count only, so H and W must already match
        self.regressor = nn.Conv2d(student_channels, teacher_channels, kernel_size=1)

    def forward(self, student_feat, teacher_feat):
        # student_feat: (B, 64, H, W)  thin student guided layer
        # teacher_feat: (B, 256, H, W) wide teacher hint layer
        adapted = self.regressor(student_feat)  # (B, 256, H, W)
        return 0.5 * (adapted - teacher_feat).pow(2).mean()


# Attention Transfer: compute 2D attention map from feature tensor
def attention_map(feat):
    """feat: (B, C, H, W) -> L2-normalized attention (B, H*W)."""
    # Sum squared activations across channel dim -> (B, H, W)
    att = feat.pow(2).sum(dim=1)
    # Flatten and L2-normalize per example -> (B, H*W)
    att = att.view(att.size(0), -1)
    return att / att.norm(dim=1, keepdim=True)


def attention_loss(teacher_feat, student_feat):
    # Channel counts may differ, but teacher and student maps need the same H and W
    return (attention_map(teacher_feat) - attention_map(student_feat)).pow(2).mean()
```
The canvas shows what happens as a student's feature activations align with the teacher's hint layer over training steps. At step 0, the student's activations are random — no spatial structure. By step 100, the student has learned to focus on the same spatial regions as the teacher (the canvas illustrates this as blob convergence). This spatial alignment is what attention transfer achieves.
Left: teacher's 2D attention map (fixed gold circles = important spatial regions). Right: student's 2D attention map evolving over training. Drag the slider to watch the student's blobs converge to the teacher's hint structure — this is what FitNets/AT teach beyond logits.
Standard KD requires a pre-trained teacher — a large network trained first, then frozen, then used as the teaching signal. This creates a two-stage pipeline. But what if you want to distill knowledge between networks that are trained simultaneously? Or what if you want to distill a network into itself without a separate teacher at all? Two approaches handle these cases: online distillation and self-distillation.
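Deep Mutual Learning (Zhang et al., CVPR 2018) trains two (or more) student networks simultaneously, each acting as the other's teacher. With p1 and p2 denoting the two networks' predicted class distributions, the loss for each network has two terms, a cross-entropy on the true label plus a KL term toward its peer:

$$\mathcal{L}_{\Theta_1} = \mathrm{CE}(p_1, y) + \mathrm{KL}(p_2 \,\|\, p_1), \qquad \mathcal{L}_{\Theta_2} = \mathrm{CE}(p_2, y) + \mathrm{KL}(p_1 \,\|\, p_2)$$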
Each network is penalized for diverging from the other's current prediction. They push each other toward better generalization without needing a pre-trained teacher. The surprising empirical result: both networks end up better than if they were trained independently. Even a large WRN-28-10 paired with a small ResNet-32 sees the small network gain 0.71% while the large network gains 0.74% — mutual teaching works.
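A per-example sketch of the two mutual-learning losses, in pure Python with illustrative function names. It treats each peer's current prediction as a fixed target for the other; how the two networks' parameter updates are scheduled is left out.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def kl_div(p, q):
    """KL(p || q) in nats."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def dml_losses(logits_1, logits_2, label):
    """Each peer: CE on the true label + KL from the other peer's prediction to its own."""
    p1, p2 = softmax(logits_1), softmax(logits_2)
    loss_1 = -math.log(p1[label]) + kl_div(p2, p1)
    loss_2 = -math.log(p2[label]) + kl_div(p1, p2)
    return loss_1, loss_2

l1, l2 = dml_losses([2.0, 0.5, -1.0], [1.0, 1.2, 0.0], label=0)
print(f"peer 1 loss = {l1:.3f}, peer 2 loss = {l2:.3f}")
```

When the two peers agree exactly, both KL terms vanish and each loss reduces to ordinary cross-entropy.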
Born-Again Networks (Furlanello et al., ICML 2018) push the self-distillation idea further. The protocol is: train a model T (generation 0). Then train an identical architecture S1 using T as teacher (distillation). Then train S2 using S1 as teacher. And so on. The key findings:
1. Each generation outperforms the previous: T < S1 < S2 < …, so distilling the same architecture into itself, iteratively, improves it. This is philosophically puzzling: teacher and student have identical capacity, yet the student exceeds the teacher.
2. Ensembling all generations (T, S1, S2, …) gives even larger gains, often exceeding a single model of twice the capacity.
3. You can alternatively distill S1 into itself (teacher and student share the same architecture and capacity), which is called self-distillation, and still gain accuracy.
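The Born-Again protocol is essentially a loop. This sketch assumes you supply two callables: a hypothetical `train_from_scratch()` for generation 0 and a hypothetical `distill(teacher)` that trains a same-architecture student on the teacher's soft targets. The toy stand-ins at the bottom exist only to make it runnable and do no real training.

```python
import math
from typing import Callable, List, Sequence

Model = Callable[[Sequence[float]], List[float]]  # input -> class logits

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def born_again(train_from_scratch: Callable[[], Model],
               distill: Callable[[Model], Model],
               generations: int) -> List[Model]:
    """Generation 0 trains on hard labels; every later generation uses the same
    architecture and is distilled from the generation before it."""
    models = [train_from_scratch()]
    for _ in range(generations):
        models.append(distill(models[-1]))  # previous generation is the teacher
    return models

def ensemble_probs(models: List[Model], x: Sequence[float]) -> List[float]:
    """Average the softmax outputs of all generations."""
    per_model = [softmax(m(x)) for m in models]
    return [sum(col) / len(models) for col in zip(*per_model)]

# Hypothetical stand-ins so the sketch runs: no real training happens here.
def toy_train_from_scratch() -> Model:
    return lambda x: [x[0], -x[0], 0.0]

def toy_distill(teacher: Model) -> Model:
    return lambda x: teacher(x)  # placeholder for "fit a student to the teacher's soft targets"

generations = born_again(toy_train_from_scratch, toy_distill, generations=3)
print(len(generations), ensemble_probs(generations, [1.0]))
```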
Zhang et al. (ICCV 2019) combine both ideas: attach auxiliary classifiers at each quarter of a ResNet's depth. The deeper classifiers act as teachers for the shallower ones — deep supervision via distillation. Training Classifier 3/4 (at 75% depth) to distill into Classifier 1/4 (at 25% depth) forces the shallow layers to learn richer features earlier. On CIFAR-100, self-distillation improves ResNet18 from 68% to 78%, VGG19 from 70% to 73%. Notably, early-exit classifiers (1/4 depth) can sometimes outperform the final classifier on certain inputs, enabling adaptive inference.
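A simplified per-example sketch of this self-distillation objective, assuming (as in Zhang et al.'s setup) that the deepest classifier teaches every shallower exit. The `kd_weight` and `T` values, the T² scaling borrowed from the Hinton loss, and the omission of the paper's feature-hint term are simplifications for illustration; the function names are hypothetical.

```python
import math

def softmax(logits, T=1.0):
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kl_div(p, q):
    """KL(p || q) in nats."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def self_distillation_loss(exit_logits, label, kd_weight=0.5, T=3.0):
    """exit_logits: logit vectors ordered shallow -> deep; the last one is the
    deepest classifier and serves as teacher for every shallower exit."""
    deepest = exit_logits[-1]
    teacher_soft = softmax(deepest, T)           # in a real framework, detach this target
    total = -math.log(softmax(deepest)[label])   # deepest exit: plain cross-entropy
    for logits in exit_logits[:-1]:
        ce = -math.log(softmax(logits)[label])
        kl = kl_div(teacher_soft, softmax(logits, T))
        total += (1 - kd_weight) * ce + kd_weight * T ** 2 * kl
    return total

# Hypothetical logits from exits at 1/4, 2/4, 3/4, and 4/4 depth
exits = [[0.5, 0.2, 0.1], [1.0, 0.3, -0.2], [1.8, 0.2, -0.5], [2.5, 0.1, -1.0]]
print(self_distillation_loss(exits, label=0))
```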
This showcase lets you watch the entire distillation process from first principles. A teacher network (4 classes: cat, dog, car, bird) has fixed logits that produce a specific soft distribution. A student starts with random logits and learns to match the teacher over training epochs via the full KD loss.
You control three parameters that govern the whole distillation: Temperature T (higher = more dark knowledge revealed), mixing α (how much hard vs soft supervision), and epoch (where in training we are). Hit "Play Training" to watch the student's distribution animate toward the teacher's over 100 simulated epochs. The KL divergence readout drops toward zero as the student learns.
Try these experiments: (1) Set T=1 and note the KL: the student converges, but misses the fine structure of the teacher's distribution. (2) Set T=4 and watch: convergence reveals more similarity structure. (3) Set α=0 (pure soft): the student matches the teacher's distribution but may be slightly miscalibrated on the hard label. (4) Set α=1 (pure hard): the student learns the correct class but the distribution has no structure. The optimal lies between.
Top: teacher (gold) and student (teal) probability bars for 4 classes. Bottom left: KL divergence curve over training (lower = student has learned the teacher). Bottom right: the distillation loss components. Drag sliders to see how T and α change the convergence.
KD doesn't replace pruning and quantization — it composes with them. In fact, the most powerful efficiency pipelines use all three together in the right order. Understanding when to apply KD in the pipeline is as important as understanding how it works.
You prune a ResNet-50 by 80%, then quantize to INT8. Accuracy drops from 76% to 58%. Now use the original full-precision ResNet-50 as the teacher, and the pruned+quantized model as the student. The KD loss pulls the compressed model's behavior back toward the uncompressed teacher. Minitron (NVIDIA, NeurIPS 2024) applies the same recipe to LLMs: prune first, then distill from the original model to recover the lost capability. The result is a compressed model that ends up much closer to the original than fine-tuning on hard labels alone typically achieves.
A different composition: what if the bottleneck isn't knowledge transfer but training signal quality? For tiny models like MobileNetV2-Tiny (23.5M MACs), conventional data augmentation (Mixup, AutoAugment, Cutout) and dropout hurt performance — the model is too small to benefit from regularization designed for large models. NetAug (Cai et al., ICLR 2022, MIT) solves this differently.
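Instead of regularizing the small model, NetAug augments it during training. The tiny model is treated as a sub-network of a set of larger augmented models. The augmented loss is:

$$\mathcal{L}_{aug} = \mathcal{L}(W_{base}) + \lambda\, \mathcal{L}\big([W_{base},\, W_{aug}]\big)$$

where λ weights the auxiliary term.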
The key: Wbase is the tiny model's weights. Waug are extra "augmentation" weights that extend the model's width during training only. The tiny model must work well both alone (the base term) and as part of the wider augmented model (the auxiliary term). This forces the tiny model's weights to be more versatile — they must generalize enough to be useful in a bigger context. At inference, Waug is discarded — zero overhead. MobileNetV2-Tiny goes from 52% to 56% on ImageNet with NetAug at the same 23.5M MAC budget.
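A toy, pure-Python sketch of the weight-sharing idea behind the augmented loss. It assumes a two-layer MLP and squared-error loss purely for illustration (NetAug itself operates on CNNs with the task loss); the function names are hypothetical and λ is the `lam` argument. The tiny model is the first `base_width` hidden units, and the augmented model reuses those exact weights plus extra units.

```python
import random

def forward(x, w1, w2, width):
    """Two-layer ReLU MLP that uses only the first `width` hidden units.
    w1: hidden-unit weight rows (each of len(x)); w2: output rows over hidden units."""
    hidden = [max(0.0, sum(w * xi for w, xi in zip(w1[h], x))) for h in range(width)]
    return [sum(row[h] * hidden[h] for h in range(width)) for row in w2]

def mse(pred, target):
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)

def netaug_loss(x, y, w1, w2, base_width, aug_width, lam=0.5):
    """L(W_base) + lam * L([W_base, W_aug]): units [0, base_width) are shared,
    units [base_width, aug_width) exist only in the augmented model."""
    base_loss = mse(forward(x, w1, w2, base_width), y)
    aug_loss = mse(forward(x, w1, w2, aug_width), y)
    return base_loss + lam * aug_loss

# Toy usage: 3 inputs, 2 outputs, tiny width 4, augmented width 8
random.seed(0)
x, y = [0.5, -1.0, 2.0], [1.0, 0.0]
w1 = [[random.uniform(-1, 1) for _ in range(3)] for _ in range(8)]
w2 = [[random.uniform(-1, 1) for _ in range(8)] for _ in range(2)]
print(netaug_loss(x, y, w1, w2, base_width=4, aug_width=8))
# At inference only forward(x, w1, w2, 4) is used; units 4..7 are discarded.
```

Both loss terms send gradient into the shared units, which is the mechanism the paragraph above describes: the tiny model's weights must also be useful inside the wider model.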
The same principles scale to billion-parameter language models. Llama 3.2's 1B and 3B models were not trained purely from scratch: Meta pruned them from Llama 3.1 8B and used logits from Llama 3.1 8B and 70B as token-level targets during pre-training. Minitron (NVIDIA) prunes Llama-3.1-8B, either by removing layers (depth pruning) or by shrinking hidden and MLP dimensions (width pruning), then distills from the original. The distillation loss includes: logit KL (output distribution matching), embedding output loss (match the embedding layer's geometry), and transformer block output losses (layer-wise feature matching, essentially FitNets at LLM scale). The result: NVIDIA reports that its pruned-and-distilled 4B models are competitive with similarly sized models trained from scratch while needing up to 40× fewer training tokens.
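A per-token sketch of a combined LLM distillation loss of the kind described above: logit KL plus squared error between paired hidden states. This is not NVIDIA's implementation; the layer pairing, the projection to a common size, the weighting, and the function names are all assumptions for illustration.

```python
import math

def softmax(logits, T=1.0):
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kl_div(p, q):
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def mse(a, b):
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b)) / len(a)

def llm_distill_loss(student_logits, teacher_logits,
                     student_states, teacher_states, T=1.0, state_weight=1.0):
    """Logit KL plus MSE between paired hidden states for one token.
    *_states: lists of vectors (e.g. embedding output, then selected block
    outputs), assumed already paired and projected to a common size."""
    logit_term = kl_div(softmax(teacher_logits, T), softmax(student_logits, T))
    state_term = sum(mse(s, t) for s, t in zip(student_states, teacher_states))
    return logit_term + state_weight * state_term

# Hypothetical 5-token vocabulary and two 3-D hidden states per model
t_logits, s_logits = [2.0, 0.5, 0.1, -1.0, 0.0], [1.5, 0.7, 0.0, -0.5, 0.2]
t_states = [[0.1, 0.2, 0.3], [1.0, -1.0, 0.5]]
s_states = [[0.0, 0.25, 0.3], [0.8, -0.9, 0.4]]
print(llm_distill_loss(s_logits, t_logits, s_states, t_states))
```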
| Loss Term | Formula | What it does |
|---|---|---|
| KL (soft) | (1−α)·T²·KL(σ(zT/T) ‖ σ(zS/T)) | Match teacher's full soft distribution; T² restores gradient scale |
| CE (hard) | α·CE(zS, y) | Anchor to ground-truth label; reduces the risk of distilling teacher error |
| Feature (L2) | (1/2)‖Wr(fS) − fT‖² | FitNets: match hidden activations; needs shape-adapting regressor |
| Attention | ‖A(fT) − A(fS)‖² | AT: match where the network looks spatially; normalized 2D map |
| Relational | ‖ψ(t1..n) − ψ(s1..n)‖² | RKD: match pairwise distances between example embeddings |
| Hyperparameter | Typical range | Effect at extremes |
|---|---|---|
| Temperature T | 3–5 for most CNNs; 1–2 for LLMs | T≈1: standard KD, less dark knowledge. T≫5: near-uniform, dark knowledge lost |
| Mix α | 0.1–0.5 | α=0: pure soft (may over-distill error). α=1: pure hard (standard CE, no benefit) |
| Signal Available | Recommended Match | Why |
|---|---|---|
| Only output logits | Hinton KD (logit KL) | Simplest, often sufficient; zero architecture constraints |
| Teacher too large to run (LLM) | Pre-generate soft targets offline, cache to disk | Run teacher once, reuse; no teacher memory during student training |
| Student much thinner than teacher | FitNets (feature + logit) | Feature matching compensates for capacity gap at intermediate layers |
| Detection/segmentation (spatial output) | Attention maps + logit KD | Spatial attention teaches where to look — critical for localization |
| Metric learning / retrieval | Relational KD (RKD) | Preserves embedding space geometry — pairwise distances matter most |
| No large teacher available | Born-Again Networks or Deep Mutual Learning | Self-distillation still improves over hard-label baseline |
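The "pre-generate soft targets offline" pattern is simple to sketch. This version assumes a JSON Lines file with one record per example and uses illustrative function names; for real LLM vocabularies a more compact storage format would likely be preferable. Whatever the format, the temperature used when caching must match the one used in the student's loss.

```python
import json
import math
import os
import tempfile

def softmax(logits, T=1.0):
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def cache_soft_targets(teacher_outputs, path, T=2.0):
    """Run the teacher once: store its temperature-T probabilities, one JSON line per example."""
    with open(path, "w") as f:
        for example_id, logits in teacher_outputs:
            f.write(json.dumps({"id": example_id, "probs": softmax(logits, T)}) + "\n")

def load_soft_targets(path):
    """During student training: read cached targets instead of running the teacher."""
    with open(path) as f:
        return {rec["id"]: rec["probs"] for rec in map(json.loads, f)}

# Hypothetical teacher logits for two examples
teacher_outputs = [("ex0", [3.0, 1.0, 0.5]), ("ex1", [0.2, 2.5, -1.0])]
path = os.path.join(tempfile.mkdtemp(), "soft_targets.jsonl")
cache_soft_targets(teacher_outputs, path, T=2.0)
targets = load_soft_targets(path)
print(targets["ex0"])
```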
| Combination | Order | Gain |
|---|---|---|
| Pruning + KD | Prune first, then distill from full-precision teacher | Recovers 3–8% acc vs pruning alone |
| Quantization + KD | Quantize-aware train as student, FP teacher | Especially useful for INT4/INT8; teacher provides smooth gradient |
| NAS + KD | Search for architecture, then distill large teacher into found arch | Architecture gives efficiency; KD gives accuracy |
| NetAug + KD | NetAug during student training; optional teacher for logits | Best for tiny models (<50M MACs) that under-fit standard augmentation |
This lesson closes the compression trilogy (Pruning L3–L4, Quantization L5–L6, KD here). The next lecture is MCUNet (Lecture 10) — which deploys the fully compressed+distilled model onto microcontrollers with 256 kB SRAM. MCUNet is the payoff: a NAS-found architecture trained with KD and quantized to INT8, running ImageNet-scale inference on a device smaller than your fingernail.