TinyML & Efficient Deep Learning · MIT 6.5940 · Lecture 9

Knowledge Distillation: Teaching Small Networks

You compressed a ResNet-50 down to MobileNetV2-Tiny via pruning and quantization — but accuracy fell from 76% to 48%. What if the big model could teach the small one, sharing not just hard labels but the full probability landscape it learned? Knowledge distillation is that teaching protocol: a student network trained on a teacher's soft outputs learns more from every example than any hard label could convey.

Prerequisites: TinyML L3–L4 (Pruning) and L5–L6 (Quantization) for compression concepts; familiarity with softmax and cross-entropy (L2) is helpful.
10 Chapters · 5 Live Canvases · Derived From First Principles

Chapter 0: The Accuracy Gap

You just finished building a TinyML pipeline. You started with a ResNet-50 (76% top-1 on ImageNet, 4.1G MACs). You pruned it aggressively — down to MobileNetV2-Tiny with 23.5M MACs. You quantized it to INT8. You ran it on the MCU. It fits in 512 kB SRAM. It runs in 18 ms. Your manager is happy — until you check accuracy. The tiny model hits 48%, down from 76%. The compression worked; the accuracy collapsed.

Why the collapse? There are actually two distinct problems. Problem 1 — capacity under-fitting: a tiny model like MobileNetV2-Tiny has far fewer parameters than ResNet-50. Even if trained perfectly, it may simply lack the representational capacity to reach 76%. The "expressiveness ceiling" is lower. Problem 2 — the hard-label problem: you trained the tiny model on one-hot labels: cat = [1, 0, 0, 0, 0, …]. Those labels are maximally sparse. Every example tells the model exactly one fact: "this is a cat." The model gets no information about how much a given image resembles a dog, or a leopard, or a lion. Hard labels are brutal teachers.

Here's the insight that Hinton, Vinyals, and Dean formalized in 2015: a trained big network is a much better teacher than a one-hot label. Why? Because when a trained ResNet sees an image of a tabby cat, it doesn't output [1, 0, 0, 0, …]. It outputs something like [0.82, 0.06, 0.05, 0.04, 0.02, …], which is a full probability distribution across all 1,000 ImageNet classes. That distribution says: "probably cat, a little dog, a little leopard." Those nonzero probabilities on non-cat classes encode structural knowledge about visual similarity. They are richer than any ground-truth label, and we can transfer that richness to the small model.

The key insight in one sentence: The full soft probability distribution from a trained teacher is a better training signal than the ground-truth label — it encodes what the teacher has learned about class similarities, making every training example teach the student many facts at once.
Conventional Training
Student sees only: image → one-hot label [1, 0, 0, …]. Learns: "this = cat." Gets no information about class similarities.
↓ vs
Knowledge Distillation
Student sees: image → teacher's soft distribution [0.82, 0.06, 0.05, …]. Learns: "this = mostly cat, a bit dog, a bit leopard." Much richer signal per example.
↓ result
Accuracy Recovery
MobileNetV2-Tiny trained with KD recovers ≈3–5% top-1 over the hard-label baseline, at zero inference cost — same architecture, same latency.

The rest of this lesson derives the exact math, builds the practical loss function, and explores what information the teacher can share beyond output logits — intermediate features, attention maps, and relationships between examples.

A tiny MobileNetV2 trained from scratch on ImageNet hard labels hits 48% top-1. What is the primary reason knowledge distillation can recover accuracy above this baseline, without changing the student architecture?

Chapter 1: Dark Knowledge

Hinton coined the term "dark knowledge" for the information buried in a trained network's output probabilities that never shows up in hard labels. Let's make this concrete with actual numbers.

Suppose a trained teacher network looks at an image of a tabby cat and outputs raw logits of [5, 1, 0.5] for the classes [cat, dog, car]. At temperature T=1, the softmax gives:

p(cat) = exp(5) / (exp(5)+exp(1)+exp(0.5)) = 148.41 / (148.41+2.72+1.65) = 148.41 / 152.78 ≈ 0.971
p(dog) = 2.72 / 152.78 ≈ 0.018
p(car) = 1.65 / 152.78 ≈ 0.011

The hard label is just [1, 0, 0]. The soft output is [0.971, 0.018, 0.011]. What's the difference? The soft output says: "this image is very much a cat, and cats are a little bit like dogs, and not much like cars." The relative probability between dog and car — 0.018 vs 0.011 — tells the student something genuinely useful: dogs and cats share visual features (fur, ears, eyes) that cars don't. A student trained purely on hard labels never sees this.

Hinton's term "dark knowledge" is evocative precisely because this information is invisible in the dataset. No human annotator wrote "this cat image is 1.6× more dog-like than car-like." The teacher discovered that knowledge during training and encoded it implicitly in its weights. Soft targets are the most direct way to extract and transfer it.

Why wrong-class probabilities are the whole point: the hard label already tells you the correct class, so most of what soft targets add sits in the small probabilities on wrong classes. A probability of 0.018 on "dog" looks negligible, but it is a nonzero target where the hard label offered exactly zero. Summed over millions of training examples, these small corrections accumulate into a meaningfully better student. Do not mistake low probability for unimportant signal.

Here's a second worked example with a more ambiguous image — a drawing of a cat:

| Class | Teacher logit | Soft prob (T=1) | Hard label |
|---|---|---|---|
| cat | 3.0 | 0.731 | 1 |
| dog | 2.0 | 0.269 | 0 |
| car | −5.0 | ≈0.0002 | 0 |

Notice that the teacher is much less confident here than on the real photo (0.731 vs 0.971). That reduced confidence is genuine information: the drawing looks ambiguous. A student trained with the hard label [1, 0, 0] gets no signal about this ambiguity. The soft label [0.731, 0.269, ≈0.0002] conveys that ambiguity to the student. This pays off at test time: a student trained on such targets tends to produce better-calibrated probabilities, not just sharp argmax predictions.

Common misconception — "the soft probabilities on wrong classes are noise": They are not noise. They are structure. The teacher has processed millions of examples and organized its weight space to group similar concepts together. The nonzero probability on "dog" when looking at a cat is not a mistake by the teacher — it is a hard-won structural fact: cats and dogs are visually similar. Discarding this and training only on hard labels is deliberately blinding the student to this structure.
python
# Worked example: teacher soft targets vs hard labels for 3 classes
import torch
import torch.nn.functional as F

# Teacher logits for one image [cat=3.0, dog=2.0, car=-5.0]
teacher_logits = torch.tensor([3.0, 2.0, -5.0])

# Soft targets (T=1): the dark knowledge signal
soft_targets = F.softmax(teacher_logits, dim=0)
# ≈ [0.7309, 0.2689, 0.0002]  ← dog gets ~26.9%, not 0%

# Hard label: cat=1, dog=0, car=0
hard_label = torch.tensor([1.0, 0.0, 0.0])

# For softmax outputs p_student, the gradient w.r.t. the student logits is
# (p_student - target) for both losses. With the hard label the dog target is 0,
# so training always pushes the dog logit down. With soft targets, the student's
# p(dog) is pulled toward 0.269 instead ← the key difference
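To see the difference numerically, here is a small autograd check. It is a sketch built on one assumption: a hypothetical student whose logits [2.0, 0.0, −1.0] already put about 11% on "dog". It uses soft cross-entropy −Σ target·log p_student. This differs from KL(teacher ‖ student) only by the teacher's entropy, which is a constant, so both give the same gradient. `logit_grad` is a helper name introduced here.

```python
import torch
import torch.nn.functional as F

# Teacher logits from the cat-drawing example: [cat, dog, car]
teacher_logits = torch.tensor([3.0, 2.0, -5.0])
soft_targets = F.softmax(teacher_logits, dim=0)   # ≈ [0.731, 0.269, 0.0002]
hard_label = torch.tensor([1.0, 0.0, 0.0])

# Hypothetical student logits (an assumption for illustration)
student_logits = torch.tensor([2.0, 0.0, -1.0])

def logit_grad(logits, target):
    """Gradient of soft cross-entropy -sum(target * log softmax(logits)) w.r.t. the logits."""
    z = logits.clone().requires_grad_(True)
    loss = -(target * F.log_softmax(z, dim=0)).sum()
    loss.backward()
    return z.grad

p_student = F.softmax(student_logits, dim=0)
g_hard = logit_grad(student_logits, hard_label)
g_soft = logit_grad(student_logits, soft_targets)

print("student p:       ", p_student)
print("hard-label grad: ", g_hard)  # dog entry = p(dog) - 0 > 0: descent lowers the dog logit
print("soft-target grad:", g_soft)  # dog entry = p(dog) - 0.269 < 0: descent raises the dog logit
```

In both cases the gradient equals p_student − target. The only thing that changes is the target the student is pulled toward.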
A teacher network outputs soft targets [0.97, 0.02, 0.01] for a cat image (classes: cat, dog, car). The hard label is [1, 0, 0]. What information does the soft target provide that the hard label cannot?

Chapter 2: Temperature & Soft Targets

We know soft targets carry dark knowledge — but there's a practical problem. For a confident teacher, the soft targets at T=1 may still be extremely peaked: [0.971, 0.018, 0.011]. The nonzero probabilities on "wrong" classes are very small, so their gradient contribution is still tiny compared to the dominant class. The teacher's dark knowledge is still mostly hidden.

The fix is temperature scaling. The temperature-scaled softmax is:

$$p(z_i, T) = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}$$

At T=1 this is the standard softmax. At T>1, dividing every logit by T shrinks the gaps between the logits before exponentiation. Each probability ratio p_i/p_j = exp((z_i − z_j)/T) moves toward 1, which flattens the distribution. This reveals more of the structural information buried in the wrong-class logits.

Let's work through the logit vector [3, 1, 0.5] for three classes [cat, dog, car] at different temperatures:

| Temperature T | p(cat) | p(dog) | p(car) | Effect |
|---|---|---|---|---|
| T = 1 | 0.821 | 0.111 | 0.067 | Peaked: cat dominates |
| T = 2 | 0.604 | 0.222 | 0.173 | Softer: dog and car get more signal |
| T = 4 | 0.467 | 0.283 | 0.250 | Very soft: wrong-class structure clearly visible |
| T = 10 | 0.385 | 0.315 | 0.300 | Almost uniform: classes start to lose meaning |

Exact derivation for logits [3, 1, 0.5] at T=4:

1. Divide by T to get [0.75, 0.25, 0.125].
2. Exponentiate: exp([0.75, 0.25, 0.125]) = [2.117, 1.284, 1.133], with sum 4.534.
3. Normalize: 2.117/4.534 = 0.467, 1.284/4.534 = 0.283, 1.133/4.534 = 0.250.

Notice that even the "car" class (logit 0.5) gets 25% probability, up from about 7% at T=1. The teacher is now effectively saying "this image could be cat, dog, or car in roughly a 1.9 : 1.1 : 1 ratio". That is much richer than "cat" alone.
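The table can be reproduced with a few lines of PyTorch. This is a minimal sketch. `temperature_softmax` is a helper name introduced here, not a library function. T=20 is added to show the over-softening regime.

```python
import torch
import torch.nn.functional as F

def temperature_softmax(logits: torch.Tensor, T: float) -> torch.Tensor:
    """Softmax of logits / T along the last dimension."""
    return F.softmax(logits / T, dim=-1)

logits = torch.tensor([3.0, 1.0, 0.5])  # [cat, dog, car]
for T in (1.0, 2.0, 4.0, 10.0, 20.0):
    p = temperature_softmax(logits, T)
    print(f"T={T:>4}: " + ", ".join(f"{x:.3f}" for x in p.tolist()))
```

Each printed row should match the table to three decimals. The T=20 row shows all three classes bunching up near one third.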

During distillation, both teacher and student run softmax at the same elevated temperature T. This softens both distributions in the same way, so the KL divergence stays informative even when the teacher is confident. After training, you discard T: the student runs at T=1 for deployment.

Temperature Softmax Explorer (interactive canvas, logits [3, 1, 0.5]): bars show each class's probability at temperature T, with the T=1 baseline marked for comparison. As T increases, the wrong classes (dog, car) gain probability, which is the dark knowledge becoming visible. At very high T the distribution approaches uniform.
There is such a thing as too much temperature. At T=20, the three classes in this example get nearly equal probability (about 0.36, 0.32, 0.32), and the distribution has lost most of the relative ordering the teacher learned. In practice, moderate temperatures tend to work well, with T≈3–5 a common choice for image classifiers. That range is soft enough to reveal structure and sharp enough to preserve it. If T is so high that "car" looks almost as likely as "cat" for a cat photo, you have washed out the dark knowledge along with the noise.
For logits [3, 1, 0.5] representing [cat, dog, car], what does raising the temperature T from 1 to 4 do to the soft targets, and why is this useful for distillation?

Chapter 3: The KD Loss (Derived)

We want the student to match the teacher's soft distribution. But we also want the student to learn the actual ground-truth labels — if the teacher has any residual error, we don't want to distill that error into the student. The Hinton KD loss combines both signals into a single weighted objective.

Formally, the distillation loss is:

$$\mathcal{L}_{KD} = \alpha \cdot T^2 \cdot \mathrm{KL}\big(\sigma(z_T/T)\,\|\,\sigma(z_S/T)\big) + (1-\alpha)\cdot \mathrm{CE}(z_S, y)$$

Where: z_T = teacher logits, z_S = student logits, σ = softmax, y = one-hot ground-truth label, T = temperature, and α = mixing coefficient that weights the soft term (α=1: all soft targets; α=0: all hard labels).

The two terms play different roles. The KL term (left) measures how far the student's soft distribution is from the teacher's soft distribution, computed at high temperature T to reveal dark knowledge. The CE term (right) is the standard cross-entropy loss against the ground-truth label at T=1, ensuring the student learns the actual correct answer.

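Why is there a T² factor? This is the most commonly forgotten detail in KD implementations. Getting it wrong makes the soft-target term vanishingly small at high temperature.

Here's the derivation. Write q = σ(z_S/T) for the student's softened distribution and p = σ(z_T/T) for the teacher's. The gradient of the KL term with respect to a student logit is ∂KL/∂z_S,i = (1/T)(q_i − p_i). That expression contains two factors of 1/T:

- One comes from the chain rule through z/T.
- The other comes from the probabilities themselves: at high T, q_i − p_i ≈ (z_S,i − z_T,i)/(N·T) for N classes, assuming zero-mean logits.

So the gradient scales as 1/T², which at T=4 is a factor of about 1/16. Without T², the soft-target term becomes increasingly negligible as T grows. Multiplying by T² keeps its gradient magnitude roughly constant as T changes. As a result, the balance between the soft and hard terms set by α stays about the same regardless of the temperature you choose.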
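A quick autograd check of the 1/T² argument, as a sketch. It uses the logits from the worked example below (teacher [3, 1, 0.5], uniform student [1, 1, 1]). It compares the gradient norm of the raw KL term with that of the T²-scaled one. `kl_grad_norm` is a hypothetical helper name.

```python
import torch
import torch.nn.functional as F

def kl_grad_norm(student_logits, teacher_logits, T, scale_by_T2):
    """Norm of d(KL at temperature T)/d(student logits), optionally with the T^2 factor."""
    z = student_logits.clone().requires_grad_(True)
    kl = F.kl_div(F.log_softmax(z / T, dim=-1),
                  F.softmax(teacher_logits / T, dim=-1),
                  reduction="batchmean")
    loss = kl * T**2 if scale_by_T2 else kl
    loss.backward()
    return z.grad.norm().item()

teacher = torch.tensor([[3.0, 1.0, 0.5]])
student = torch.tensor([[1.0, 1.0, 1.0]])
for T in (1.0, 2.0, 4.0, 8.0, 16.0):
    raw = kl_grad_norm(student, teacher, T, scale_by_T2=False)
    scaled = kl_grad_norm(student, teacher, T, scale_by_T2=True)
    print(f"T={T:>4}: raw grad norm={raw:.5f}  T^2-scaled={scaled:.5f}")
```

For these logits, the raw gradient shrinks by roughly 4× each time T doubles. The T²-scaled gradient stays between about 0.6 and 0.7.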

Let's compute the KD loss for one concrete example. Teacher logits = [3, 1, 0.5], student logits = [1, 1, 1] (a totally untrained student — uniform), hard label = cat (index 0), T=2, α=0.5.

Step 1: Soft targets at T=2. Teacher logits/T = [1.5, 0.5, 0.25]. exp = [4.482, 1.649, 1.284]. Sum ≈ 7.414. p_T = [0.604, 0.222, 0.173]. Student logits/T = [0.5, 0.5, 0.5], so p_S = [1/3, 1/3, 1/3] (uniform).

Step 2: KL divergence. KL(p_T ‖ p_S) = Σ_i p_T,i · log(p_T,i / p_S,i). Because p_S,i = 1/3, each log term is log(3·p_T,i):

KL = 0.604·log(1.813) + 0.222·log(0.667) + 0.173·log(0.520)
   ≈ 0.604·0.595 + 0.222·(−0.405) + 0.173·(−0.654)
   ≈ 0.359 − 0.090 − 0.113
   ≈ 0.156 (0.1563 before rounding).

Step 3: CE loss. Student softmax at T=1: [1/3, 1/3, 1/3]. CE = −log(1/3) ≈ 1.099.

Step 4: Full loss. L_KD = 0.5 · T² · KL + 0.5 · CE = 0.5 · 4 · 0.1563 + 0.5 · 1.0986 ≈ 0.313 + 0.549 = 0.862.
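The four steps can be checked in PyTorch. This minimal sketch uses the same logits, T=2 and α=0.5. Note that `F.kl_div` expects log-probabilities for the student and probabilities for the teacher.

```python
import torch
import torch.nn.functional as F

teacher_logits = torch.tensor([[3.0, 1.0, 0.5]])   # batch of 1
student_logits = torch.tensor([[1.0, 1.0, 1.0]])   # untrained, uniform student
labels = torch.tensor([0])                          # ground truth: cat
T, alpha = 2.0, 0.5

p_T = F.softmax(teacher_logits / T, dim=-1)         # Step 1: teacher soft targets
log_p_S = F.log_softmax(student_logits / T, dim=-1) # Step 1: log(1/3) for every class

kl = F.kl_div(log_p_S, p_T, reduction="batchmean")  # Step 2
ce = F.cross_entropy(student_logits, labels)        # Step 3 (T=1)
loss = alpha * T**2 * kl + (1 - alpha) * ce         # Step 4

print(f"KL={kl.item():.4f}  CE={ce.item():.4f}  L_KD={loss.item():.4f}")
```

The printed values should agree with the hand computation: KL ≈ 0.156, CE ≈ 1.099, L_KD ≈ 0.862.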

Over many gradient steps, training drives both terms down. The KL term pulls the student's softened distribution toward the teacher's. The CE term pushes the student's argmax toward the correct class. When both are small, the student (a) gives the correct answer and (b) has internalized the teacher's similarity structure. This is exactly what we want.
python
import torch
import torch.nn as nn
import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.5):
    """
    Knowledge Distillation loss (Hinton et al., 2015).
    student_logits: (B, C) raw student outputs
    teacher_logits: (B, C) raw teacher outputs (no grad, teacher is frozen)
    labels:         (B,)   ground-truth class indices
    T:              temperature for softening
    alpha:          weight for KL (soft) term; (1-alpha) weights CE (hard) term
    """
    # Soft targets: both at temperature T
    soft_student = F.log_softmax(student_logits / T, dim=-1)   # log for KLDivLoss
    soft_teacher = F.softmax(teacher_logits / T, dim=-1)         # no grad needed

    # KL divergence: KLDivLoss expects log-probs for student, probs for teacher
    kl = F.kl_div(soft_student, soft_teacher, reduction='batchmean')

    # T² scaling: restores gradient magnitude (see derivation above)
    kl_scaled = kl * (T ** 2)

    # Hard-label cross-entropy (standard, at T=1)
    ce = F.cross_entropy(student_logits, labels)

    # Combined loss
    return alpha * kl_scaled + (1 - alpha) * ce

# Training step
def train_step(student, teacher, batch_x, batch_y, optimizer, T=4, alpha=0.5):
    student.train()
    teacher.eval()
    with torch.no_grad():
        teacher_logits = teacher(batch_x)   # freeze teacher completely
    student_logits = student(batch_x)
    loss = kd_loss(student_logits, teacher_logits, batch_y, T, alpha)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

KD Loss Mixer, Hard Labels vs Soft Targets (interactive canvas): a bar chart shows how much each loss term (KL soft vs CE hard) contributes to the total gradient as the mixing coefficient α varies. Following the formula above, α weights the soft term. At α=1 the student trains purely on the teacher's soft targets, and at α=0 purely on hard labels. The best α typically lies between 0.1 and 0.9, depending on how reliable the teacher is.
The KD loss is: L = α·T²·KL(teacher_soft ‖ student_soft) + (1−α)·CE(student, hard). Why is the T² factor necessary?

Chapter 4: What to Match

Hinton's original paper matched only the output logits — the last layer's raw numbers before the final softmax. But a trained deep network contains structural knowledge at every layer. Subsequent work asked: what else can we transfer? The answer turned out to be almost everything.

| What to match | Method | Loss type | Key papers |
|---|---|---|---|
| Output logits | Temperature softmax, KL divergence on the final distribution | KL div or L2 on logits | Hinton et al. 2015; Ba & Caruana 2014 (L2 on logits) |
| Intermediate weights | L2-match teacher weights to student weights (with a linear transform for dimension mismatch) | L2 norm | Romero et al. 2015 (FitNets) |
| Intermediate features | Match hidden-layer activation maps via MMD (Maximum Mean Discrepancy) or L2 | MMD or L2 | Huang & Wang 2017 (NST) |
| Attention maps | Match spatial attention maps (activation-based, or gradient-based from ∂L/∂x), reduced across channels to a 2D map | L2 on attention maps | Zagoruyko & Komodakis 2017 |
| Sparsity patterns | Match post-ReLU activation boundaries: ρ(x) = 1[x>0] | Boundary matching loss | Heo et al. 2019 |
| Relational info | Match pairwise distances between multiple examples' feature vectors | L2 on distance/angle matrices | Park et al. 2019 (RKD); Yim et al. 2017 (FSP) |

These approaches form a hierarchy from "output only" to "full internal structure." Each deeper level of matching potentially transfers more knowledge — but also requires more careful engineering, since intermediate representations have different shapes in teacher and student, requiring adapter layers.

The fundamental insight behind all "what to match" variants: Wherever the teacher has learned to be structured — in its output distribution, its feature space geometry, its attention patterns, or even the relationships between the representations of different inputs — there is transferable knowledge. The student benefits from seeing all of it. Logit distillation is the cheapest starting point; feature/attention distillation adds signal that logits can't convey.

Let's think about why each type captures something different. Logit matching transfers the final classification decision structure. Feature matching (NST) transfers the distribution of internal representations — which neurons fire how strongly for which inputs. Attention matching transfers where the teacher looks in an image — a teacher that correctly focuses on the cat's face transfers that spatial prior to the student. Relational KD transfers the geometry of the embedding space — the fact that a cat and a dog are closer to each other than a cat and a car, even in the hidden space.
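A sketch of the distance-wise part of relational KD, in the spirit of Park et al. 2019. It builds the B×B matrix of pairwise embedding distances for teacher and student, normalizes each by its mean, and compares them with a Huber (smooth L1) loss. The angle-wise term is omitted, and the helper names are hypothetical. Because only distances are compared, teacher and student embeddings may have different widths.

```python
import torch
import torch.nn.functional as F

def normalized_distances(emb: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """(B, D) embeddings -> (B, B) Euclidean distances divided by their mean off-diagonal value."""
    diff = emb.unsqueeze(1) - emb.unsqueeze(0)               # (B, B, D)
    dist = diff.pow(2).sum(dim=-1).clamp_min(eps).sqrt()     # clamp keeps gradients finite at 0
    off_diag = ~torch.eye(emb.size(0), dtype=torch.bool, device=emb.device)
    dist = dist.masked_fill(~off_diag, 0.0)                  # self-distances are exactly 0
    return dist / dist[off_diag].mean()

def rkd_distance_loss(student_emb: torch.Tensor, teacher_emb: torch.Tensor) -> torch.Tensor:
    """Distance-wise relational loss between two batches of embeddings (widths may differ)."""
    with torch.no_grad():
        target = normalized_distances(teacher_emb)
    return F.smooth_l1_loss(normalized_distances(student_emb), target)

# Usage: a 512-d teacher embedding and a 64-d student embedding for the same 8 inputs
torch.manual_seed(0)
teacher_emb = torch.randn(8, 512)
student_emb = torch.randn(8, 64, requires_grad=True)
loss = rkd_distance_loss(student_emb, teacher_emb)
loss.backward()
```

Normalizing by the mean distance makes the loss ignore overall scale. A student whose embedding space is a uniformly scaled copy of the teacher's gets zero loss.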

Matching sparsity patterns is a counterintuitive win: After ReLU, some neurons are activated (value>0) and some are not (value=0). The pattern of which neurons activate for which input is itself a learned structure — it encodes the "activation boundary" that the teacher has learned for each class. A student that fires the wrong neurons for "cat" inputs will have incorrect internal structure even if its final output is right. Forcing the student to match the teacher's sparsity pattern aligns the internal structure, not just the output.
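One way to turn "match the teacher's activation pattern" into a differentiable loss is a margin (hinge) penalty on the student's pre-ReLU responses, in the spirit of Heo et al.'s activation-boundary loss. Where the teacher neuron is active, the loss pushes the student's response above +μ. Where the teacher neuron is inactive, it pushes the response below −μ. This sketch assumes student and teacher responses already have the same shape, so any connector layer is omitted. `activation_boundary_loss` and `margin` are hypothetical names.

```python
import torch

def activation_boundary_loss(student_pre: torch.Tensor, teacher_pre: torch.Tensor,
                             margin: float = 1.0) -> torch.Tensor:
    """Hinge-style loss on pre-ReLU responses of identical shape.
    Active teacher neurons (>0): penalize student responses below +margin.
    Inactive teacher neurons: penalize student responses above -margin."""
    teacher_active = (teacher_pre > 0).float()                 # rho(x) = 1[x > 0]
    on_term = teacher_active * torch.relu(margin - student_pre)
    off_term = (1.0 - teacher_active) * torch.relu(margin + student_pre)
    return (on_term + off_term).pow(2).mean()
```

The margin keeps the student's responses away from zero. The student's ReLU then reproduces the teacher's on/off pattern with some robustness, instead of sitting right on the boundary.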
Relational Knowledge Distillation (RKD) differs from standard logit KD in a fundamental way. What does RKD match that standard KD does not?

Chapter 5: FitNets & Attention Maps

FitNets (Romero et al., ICLR 2015) took KD deeper — literally. Instead of matching only the final output, FitNets match intermediate layers. The teacher has a "hint layer" (some intermediate layer), and the student has a "guided layer" (an intermediate layer in the student that should learn to match the hint). The guidance loss is:

$$\mathcal{L}_{hint} = \frac{1}{2}\,\big\| W_r(f_S(x)) - f_T(x) \big\|_2^2$$

Here f_T(x) is the teacher's hint layer output (shape: e.g., H×W×256), and f_S(x) is the student's guided layer output (shape: H×W×64, because the student is thinner). W_r is a learned "regressor": a small convolutional layer that maps the student's 64-channel output to 256 channels so the shapes match for comparison. The regressor is trained jointly with the student during the "hint training" phase.

FitNets are trained in two stages, not one:

- Stage 1 (hint training): only the student's layers up to the guided layer, together with W_r, are trained, guided purely by the feature-matching loss.
- Stage 2 (knowledge distillation): the full student, initialized from stage 1, is trained end-to-end with the logit KD loss.

The hint-training stage gives the student a good initialization for its intermediate representations before output-level training begins. Skipping stage 1 and training everything at once typically hurts performance.
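The two-stage schedule can be sketched with toy convolutional stand-ins. Everything here is hypothetical: the layer sizes, 5 steps per stage, α=0.9 and T=4. The point is which parameters each stage updates. Stage 1 trains only the student's layers up to the guided layer plus W_r. Stage 2 trains the whole student with the KD loss and no longer uses the regressor.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
# Hypothetical toy networks: teacher hint layer has 256 channels, student guided layer has 64.
teacher_stem = nn.Sequential(nn.Conv2d(3, 256, 3, padding=1), nn.ReLU())   # up to hint layer
teacher_head = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(256, 10))
student_stem = nn.Sequential(nn.Conv2d(3, 64, 3, padding=1), nn.ReLU())    # up to guided layer
student_head = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(64, 10))
regressor = nn.Conv2d(64, 256, kernel_size=1)                              # W_r

x = torch.randn(8, 3, 16, 16)
y = torch.randint(0, 10, (8,))
with torch.no_grad():                          # the teacher stays frozen throughout
    hint = teacher_stem(x)
    teacher_logits = teacher_head(hint)

# Stage 1: hint training updates only the student stem and the regressor.
opt1 = torch.optim.SGD(list(student_stem.parameters()) + list(regressor.parameters()), lr=0.01)
for _ in range(5):
    loss_hint = 0.5 * (regressor(student_stem(x)) - hint).pow(2).mean()
    opt1.zero_grad()
    loss_hint.backward()
    opt1.step()

# Stage 2: KD on the whole student; the regressor is no longer used.
T, alpha = 4.0, 0.9
opt2 = torch.optim.SGD(list(student_stem.parameters()) + list(student_head.parameters()), lr=0.01)
for _ in range(5):
    student_logits = student_head(student_stem(x))
    kl = F.kl_div(F.log_softmax(student_logits / T, dim=-1),
                  F.softmax(teacher_logits / T, dim=-1), reduction="batchmean")
    loss_kd = alpha * T**2 * kl + (1 - alpha) * F.cross_entropy(student_logits, y)
    opt2.zero_grad()
    loss_kd.backward()
    opt2.step()
```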

Attention Transfer (Zagoruyko & Komodakis, ICLR 2017) takes a different angle. Instead of matching the raw feature maps, it matches the attention maps derived from those features. For a feature tensor $F \in \mathbb{R}^{C\times H\times W}$, the attention map is the sum of squared absolute values across channels: $A(F) = \sum_c |F_c|^2 \in \mathbb{R}^{H\times W}$. This 2D map shows where in the image the layer is attending. High values mean "this spatial location was important."

The key empirical observation: high-accuracy ResNets (ResNet-34 at 73%, ResNet-101 at 77%) have very similar attention maps to each other. A lower-accuracy model (Network-in-Network at 62%) has qualitatively different attention maps. This suggests attention patterns are a marker of "correct" learned behavior — and matching them transfers that behavioral quality to the student.

python
# FitNets: feature-matching loss with a regressor layer
import torch
import torch.nn as nn

class FitNetRegressor(nn.Module):
    """Adapts student's thin feature map to teacher's wide feature map."""
    def __init__(self, student_channels, teacher_channels):
        super().__init__()
        self.regressor = nn.Conv2d(student_channels, teacher_channels, kernel_size=1)

    def forward(self, student_feat, teacher_feat):
        # student_feat: (B, 64, H, W) — thin student layer
        # teacher_feat: (B, 256, H, W) — wide teacher hint layer
        adapted = self.regressor(student_feat)  # (B, 256, H, W)
        return 0.5 * (adapted - teacher_feat).pow(2).mean()

# Attention Transfer: compute 2D attention map from feature tensor
def attention_map(feat):
    """feat: (B, C, H, W) → attention (B, H*W) for L2 matching"""
    # Sum squared activations across channel dim → (B, H, W)
    att = feat.pow(2).sum(dim=1)
    # Normalize and flatten: (B, H*W)
    att = att.view(att.size(0), -1)
    return att / att.norm(dim=1, keepdim=True)

def attention_loss(teacher_feat, student_feat):
    return (attention_map(teacher_feat) - attention_map(student_feat)).pow(2).mean()

Feature Alignment, Student Activations Converging to Teacher's Hint (interactive canvas). Left: the teacher's 2D attention map, with fixed gold circles marking important spatial regions. Right: the student's attention map over training steps. At step 0 the student's activations have no spatial structure. By step 100 they concentrate on the same regions as the teacher. This spatial alignment is what FitNets and attention transfer teach beyond logits.
FitNets train a student to match the teacher's intermediate features. Why is a "regressor" (Wr) needed between the student and teacher layers?

Chapter 6: Online & Self-Distillation

Standard KD requires a pre-trained teacher — a large network trained first, then frozen, then used as the teaching signal. This creates a two-stage pipeline. But what if you want to distill knowledge between networks that are trained simultaneously? Or what if you want to distill a network into itself without a separate teacher at all? Two approaches handle these cases: online distillation and self-distillation.

Online Distillation: Deep Mutual Learning

Deep Mutual Learning (Zhang et al., CVPR 2018) trains two (or more) student networks simultaneously, each acting as the other's teacher. The loss for each network has two terms:

$$\mathcal{L}(S_1) = \mathrm{CE}(S_1(x), y) + \mathrm{KL}\big(S_2(x)\,\|\,S_1(x)\big)$$
$$\mathcal{L}(S_2) = \mathrm{CE}(S_2(x), y) + \mathrm{KL}\big(S_1(x)\,\|\,S_2(x)\big)$$

Each network is penalized for diverging from the other's current prediction. They push each other toward better generalization without needing a pre-trained teacher. The surprising empirical result: both networks end up better than if they were trained independently. Even a large WRN-28-10 paired with a small ResNet-32 sees the small network gain 0.71% while the large network gains 0.74% — mutual teaching works.
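A sketch of one Deep Mutual Learning step for two peers, following the losses above. The peer's prediction is detached, so each KL term only updates its own network. The networks are updated in turn, with the peer's prediction recomputed in between. The toy models and helper names are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def mutual_loss(logits_self, logits_peer, y):
    """CE on labels + KL(peer || self); the peer is detached so it acts as a fixed target."""
    kl = F.kl_div(F.log_softmax(logits_self, dim=-1),
                  F.softmax(logits_peer.detach(), dim=-1), reduction="batchmean")
    return F.cross_entropy(logits_self, y) + kl

def dml_step(net1, net2, opt1, opt2, x, y):
    """One Deep Mutual Learning step, updating the two networks in turn."""
    loss1 = mutual_loss(net1(x), net2(x), y)
    opt1.zero_grad()
    loss1.backward()
    opt1.step()
    # Recompute net1's prediction after its update before training net2.
    loss2 = mutual_loss(net2(x), net1(x), y)
    opt2.zero_grad()
    loss2.backward()
    opt2.step()
    return loss1.item(), loss2.item()

# Usage with two differently sized (hypothetical) networks
torch.manual_seed(0)
small = nn.Linear(20, 5)
large = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 5))
opt_small = torch.optim.SGD(small.parameters(), lr=0.1)
opt_large = torch.optim.SGD(large.parameters(), lr=0.1)
x, y = torch.randn(32, 20), torch.randint(0, 5, (32,))
for step in range(10):
    l_small, l_large = dml_step(small, large, opt_small, opt_large, x, y)
```

At test time you keep whichever single network you need. Unlike an ensemble, you do not have to run both.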

Why does online distillation work without a stronger teacher? Each network learns from its own mistakes but is also regularized by its peer's current beliefs. When one network is uncertain (high-entropy outputs), the other doesn't force it toward a sharp label. Instead, the peer provides a soft, exploratory training signal. This collaborative uncertainty prevents each network from committing too early to wrong representations. The effect resembles an ensemble, but with a single inference model at test time.

Born-Again Networks: Self-Distillation via Generations

Born-Again Networks (Furlanello et al., ICML 2018) push the self-distillation idea further. The protocol is: train a model T (generation 0). Then train an identical architecture S1 using T as teacher (distillation). Then train S2 using S1 as teacher. And so on. The key findings:

1. Each generation outperforms the previous: T < S1 < S2 < ... . Distilling the same architecture into itself, iteratively, improves it. This is philosophically puzzling: teacher and student have identical capacity, yet the student exceeds the teacher.
2. Ensembling all generations (T, S1, S2, ...) gives even larger gains, often exceeding a single model of twice the capacity.
3. You can alternatively distill S1 into itself (teacher and student are the same network: same architecture, same capacity), which is called self-distillation, and still gain accuracy.
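The generation loop can be sketched as follows. Assumptions, for illustration only:

- a toy MLP as the shared architecture;
- full-batch training on random data;
- the Hinton-style mixed loss from Chapter 3 for each student (the Born-Again paper studies its own loss variants).

`make_model`, `train`, `born_again`, and `ensemble_predict` are hypothetical names.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def make_model():
    """Every generation uses the same (hypothetical) architecture."""
    return nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 5))

def train(model, x, y, teacher=None, T=4.0, alpha=0.5, steps=100, lr=0.1):
    """Full-batch training on hard labels, plus a Hinton-style KD term when a teacher is given."""
    opt = torch.optim.SGD(model.parameters(), lr=lr)
    if teacher is not None:
        teacher.eval()
        with torch.no_grad():
            teacher_logits = teacher(x)
    for _ in range(steps):
        logits = model(x)
        loss = F.cross_entropy(logits, y)
        if teacher is not None:
            kl = F.kl_div(F.log_softmax(logits / T, dim=-1),
                          F.softmax(teacher_logits / T, dim=-1), reduction="batchmean")
            loss = alpha * T**2 * kl + (1 - alpha) * loss
        opt.zero_grad()
        loss.backward()
        opt.step()
    return model

def born_again(x, y, generations=2):
    """Generation 0 learns from labels; generation k is distilled from generation k-1."""
    models = [train(make_model(), x, y)]
    for _ in range(generations):
        models.append(train(make_model(), x, y, teacher=models[-1]))
    return models

def ensemble_predict(models, x):
    """Average the softmax outputs of all generations."""
    with torch.no_grad():
        return torch.stack([F.softmax(m(x), dim=-1) for m in models]).mean(dim=0)

torch.manual_seed(0)
x, y = torch.randn(64, 20), torch.randint(0, 5, (64,))
models = born_again(x, y, generations=2)
probs = ensemble_predict(models, x)
```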

How can the student surpass the teacher with identical architecture? The teacher trains on hard labels, converging to one local minimum in the loss landscape. The student trains on the teacher's soft targets — a smoother, richer signal. The soft targets implicitly encode the geometry of the loss landscape around the teacher's solution. This smoothed signal can guide the student to a different (and better) minimum that the hard-label teacher never explored. Essentially, the teacher is a consultant who not only gives the right answer but also shares their intuitions about nearby wrong answers — the student benefits from this broader guidance even when equally capable.

Be Your Own Teacher: Deep Supervision + Self-Distillation

Zhang et al. (ICCV 2019) combine both ideas: attach auxiliary classifiers at each quarter of a ResNet's depth. The deeper classifiers act as teachers for the shallower ones — deep supervision via distillation. Training Classifier 3/4 (at 75% depth) to distill into Classifier 1/4 (at 25% depth) forces the shallow layers to learn richer features earlier. On CIFAR-100, self-distillation improves ResNet18 from 68% to 78%, VGG19 from 70% to 73%. Notably, early-exit classifiers (1/4 depth) can sometimes outperform the final classifier on certain inputs, enabling adaptive inference.

Born-Again Networks train a student with identical architecture to the teacher. How can the student achieve higher accuracy than the teacher despite identical capacity?

Chapter 7: Showcase: KD Playground

This showcase lets you watch the entire distillation process from first principles. A teacher network (4 classes: cat, dog, car, bird) has fixed logits that produce a specific soft distribution. A student starts with random logits and learns to match the teacher over training epochs via the full KD loss.

You control three parameters that govern the whole distillation: Temperature T (higher = more dark knowledge revealed), mixing α (how much hard vs soft supervision), and epoch (where in training we are). Hit "Play Training" to watch the student's distribution animate toward the teacher's over 100 simulated epochs. The KL divergence readout drops toward zero as the student learns.

Try these experiments:

1. Set T=1 and note the KL: the student converges but misses the fine structure of the teacher's distribution.
2. Set T=4 and watch: convergence reveals more similarity structure.
3. Push α toward pure soft targets: the student matches the teacher's distribution but may be slightly miscalibrated on the hard label.
4. Push α toward pure hard labels: the student learns the correct class, but its distribution has no structure.

The optimum lies in between.

KD Playground, Teacher vs Student Distributions + Live KL Divergence (interactive canvas). Top: teacher (gold) and student (teal) probability bars for 4 classes. Bottom left: the KL divergence curve over training (lower means the student has learned the teacher). Bottom right: the distillation loss components. Controls: training epoch, temperature T (default 4.0), and mixing coefficient α (default 0.50).
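The playground's core loop can be approximated offline. This sketch treats the student's logits as free parameters, with no network behind them (a simplifying assumption). It starts them at random and runs gradient descent on the full KD loss toward a fixed 4-class teacher. `simulate_kd` and the teacher logits in the usage lines are hypothetical.

```python
import torch
import torch.nn.functional as F

def simulate_kd(teacher_logits, label, T=4.0, alpha=0.5, epochs=100, lr=0.5, seed=0):
    """Optimize a free student logit vector with the KD loss; return it and the per-epoch KL at T."""
    torch.manual_seed(seed)
    student = torch.randn(1, teacher_logits.numel(), requires_grad=True)  # random init
    opt = torch.optim.SGD([student], lr=lr)
    p_teacher = F.softmax(teacher_logits.view(1, -1) / T, dim=-1)
    y = torch.tensor([label])
    kl_history = []
    for _ in range(epochs):
        kl = F.kl_div(F.log_softmax(student / T, dim=-1), p_teacher, reduction="batchmean")
        loss = alpha * T**2 * kl + (1 - alpha) * F.cross_entropy(student, y)
        opt.zero_grad()
        loss.backward()
        opt.step()
        kl_history.append(kl.item())
    return student.detach(), kl_history

teacher = torch.tensor([4.0, 2.5, 0.5, 1.0])  # hypothetical logits for [cat, dog, car, bird]
for alpha in (1.0, 0.5, 0.0):
    student, kl_history = simulate_kd(teacher, label=0, alpha=alpha)
    print(f"alpha={alpha}: final KL={kl_history[-1]:.4f}, student p={F.softmax(student, dim=-1)}")
```

With α=1 (pure soft targets) the KL falls to essentially zero. Lowering α trades some of that match for a sharper peak on the labeled class.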

Chapter 8: KD in the Efficiency Stack

KD doesn't replace pruning and quantization — it composes with them. In fact, the most powerful efficiency pipelines use all three together in the right order. Understanding when to apply KD in the pipeline is as important as understanding how it works.

Scenario 1: KD after pruning/quantization (accuracy recovery)

You prune a ResNet-50 by 80%, then quantize to INT8. Accuracy drops from 76% to 58%. Now use the original full-precision ResNet-50 as the teacher and the pruned+quantized model as the student. The KD loss pulls the compressed model's behavior back toward the uncompressed teacher. The Minitron approach (NVIDIA, NeurIPS 2024) applies the same recipe to LLMs: prune first, then distill from the original to recover the lost capability. The result is a compressed model much closer to the original than fine-tuning on hard labels alone typically achieves.

Step 1: Teacher
Train large full-precision model (ResNet-50, 76%, 4.1G MACs). This becomes the teacher — fixed during distillation.
Step 2: Compress
Prune + quantize → small compressed model (e.g., 80% of weights pruned, INT8). Accuracy drops to ~58%.
Step 3: Distill
Train compressed model as student with LKD = αT²·KL(teacher‖student) + (1−α)·CE. Accuracy recovers to ~63–65%. Inference cost: unchanged (student is still compressed).

Scenario 2: NetAug — network augmentation for tiny models

A different composition: what if the bottleneck isn't knowledge transfer but training signal quality? For tiny models like MobileNetV2-Tiny (23.5M MACs), conventional data augmentation (Mixup, AutoAugment, Cutout) and dropout hurt performance — the model is too small to benefit from regularization designed for large models. NetAug (Cai et al., ICLR 2022, MIT) solves this differently.

Instead of regularizing the small model, NetAug augments it during training. The tiny model is treated as a sub-network of a set of larger augmented models. The augmented loss is:

$$\mathcal{L}_{aug} = \mathcal{L}(W_{base}) + \alpha \cdot \mathcal{L}([W_{base}, W_{aug}])$$

The key: W_base is the tiny model's weights. W_aug are extra "augmentation" weights that extend the model's width during training only. The tiny model must work well both alone (the base term) and as part of the wider augmented model (the auxiliary term). This forces the tiny model's weights to be more versatile: they must generalize enough to be useful in a bigger context. At inference, W_aug is discarded, so there is zero overhead. MobileNetV2-Tiny goes from 52% to 56% on ImageNet with NetAug at the same 23.5M MAC budget.
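A minimal sketch of the NetAug objective on a toy two-layer MLP, assuming width augmentation of the hidden layer only. The base model is the first `base_hidden` hidden units. The augmented model reuses those weights plus extra units that play the role of W_aug. For simplicity the sketch uses one fixed augmented width, although NetAug's formulation allows several augmented models. The class and function names are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class WidthAugMLP(nn.Module):
    """Two-layer MLP whose first `base_hidden` hidden units form the tiny base model.
    The remaining hidden units act as W_aug and are used only during training."""
    def __init__(self, in_dim=20, base_hidden=8, aug_hidden=16, num_classes=5):
        super().__init__()
        self.base_hidden = base_hidden
        self.fc1 = nn.Linear(in_dim, aug_hidden)       # rows [:base_hidden] belong to W_base
        self.fc2 = nn.Linear(aug_hidden, num_classes)  # columns [:base_hidden] belong to W_base

    def forward(self, x, augmented=False):
        h = self.fc1.out_features if augmented else self.base_hidden
        hidden = F.relu(F.linear(x, self.fc1.weight[:h], self.fc1.bias[:h]))
        return F.linear(hidden, self.fc2.weight[:, :h], self.fc2.bias)

def netaug_loss(model, x, y, aug_weight=1.0):
    base_loss = F.cross_entropy(model(x), y)                 # L(W_base)
    aug_loss = F.cross_entropy(model(x, augmented=True), y)  # L([W_base, W_aug])
    return base_loss + aug_weight * aug_loss

torch.manual_seed(0)
model = WidthAugMLP()
x, y = torch.randn(16, 20), torch.randint(0, 5, (16,))
loss = netaug_loss(model, x, y)
loss.backward()
```

Calling `model(x)` without `augmented=True` touches only the base slices (`fc1.weight[:base_hidden]`, `fc1.bias[:base_hidden]`, `fc2.weight[:, :base_hidden]`, `fc2.bias`). Exporting just those slices gives the deployed tiny model.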

NetAug can be viewed as a relative of knowledge distillation. The augmented wider network acts like a soft teacher, providing gradient signals that the tiny model would not see if trained alone. The tiny model learns to be a good sub-component of something larger, which instills more generalizable representations. The "teacher" here is not pre-trained: it is built on the fly by augmenting the tiny model's width during each training step.

Scenario 3: KD for LLMs (Llama 3.2, Minitron)

The same principles scale to billion-parameter language models. Llama 3.2's 1B and 3B models were not trained from scratch. Meta initialized them by pruning Llama 3.1 8B, then used logits from Llama 3.1 8B and 70B as token-level targets during pretraining. Minitron (NVIDIA) prunes Llama-3.1-8B by removing attention heads and layers, then distills from the original. The distillation loss includes three parts:

- logit KL (output distribution matching);
- an embedding output loss (match the embedding layer's geometry);
- transformer block output losses (layer-wise feature matching, essentially FitNets at LLM scale).

NVIDIA reports that deriving models this way needs up to 40× fewer training tokens per model than training them from scratch.
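For language models, logit KL is applied per token over the vocabulary. This minimal sketch assumes that teacher and student share a tokenizer, as models in the same family do, and that padding positions are masked out. `token_kd_loss` is a hypothetical name. T defaults to 1, in line with the low temperatures common for LLMs.

```python
import torch
import torch.nn.functional as F

def token_kd_loss(student_logits, teacher_logits, attention_mask, T=1.0):
    """Mean per-token KL(teacher || student) over non-padding positions.
    student_logits, teacher_logits: (B, L, V) over a shared vocabulary.
    attention_mask: (B, L), 1 for real tokens and 0 for padding."""
    log_p_s = F.log_softmax(student_logits / T, dim=-1)
    log_p_t = F.log_softmax(teacher_logits / T, dim=-1)
    kl_per_token = (log_p_t.exp() * (log_p_t - log_p_s)).sum(dim=-1)   # (B, L)
    mask = attention_mask.to(kl_per_token.dtype)
    return T**2 * (kl_per_token * mask).sum() / mask.sum().clamp_min(1.0)

# Usage: batch of 2 sequences, 5 tokens each, toy vocabulary of 11
torch.manual_seed(0)
student_logits = torch.randn(2, 5, 11)
teacher_logits = torch.randn(2, 5, 11)
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
loss = token_kd_loss(student_logits, teacher_logits, mask)
```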

NetAug trains a tiny model with a wider "augmented" version of itself during training. Why does this work better than standard dropout or Mixup for tiny models?

Chapter 9: Connections & Cheat Sheet

KD Cheat Sheet

| Loss term | Formula | What it does |
|---|---|---|
| KL (soft) | α·T²·KL(σ(z_T/T) ‖ σ(z_S/T)) | Match teacher's full soft distribution; T² restores gradient scale |
| CE (hard) | (1−α)·CE(z_S, y) | Anchor to ground-truth label; prevents distilling teacher error |
| Feature (L2) | (1/2)‖W_r(f_S) − f_T‖² | FitNets: match hidden activations; needs shape-adapting regressor |
| Attention | ‖A(f_T) − A(f_S)‖² | AT: match where the network looks spatially; normalized 2D map |
| Relational | ‖ψ(t_1..t_n) − ψ(s_1..s_n)‖² | RKD: match pairwise distances between example embeddings |

Temperature & α Selection Guide

| Hyperparameter | Typical range | Effect at extremes |
|---|---|---|
| Temperature T | 3–5 for most CNNs; 1–2 for LLMs | T≈1: plain softmax targets, less dark knowledge. T≫5: near-uniform, dark knowledge lost |
| Mix α (weight on the soft term) | 0.5–0.9 | α=1: pure soft (may distill teacher errors). α=0: pure hard (standard CE, no KD benefit) |

What to Match — Decision Table

| Signal available | Recommended match | Why |
|---|---|---|
| Only output logits | Hinton KD (logit KL) | Simplest, often sufficient; zero architecture constraints |
| Teacher too large to run (LLM) | Pre-generate soft targets offline, cache to disk | Run teacher once, reuse; no teacher memory during student training |
| Student much thinner than teacher | FitNets (feature + logit) | Feature matching compensates for capacity gap at intermediate layers |
| Detection/segmentation (spatial output) | Attention maps + logit KD | Spatial attention teaches where to look, critical for localization |
| Metric learning / retrieval | Relational KD (RKD) | Preserves embedding space geometry; pairwise distances matter most |
| No large teacher available | Born-Again Networks or Deep Mutual Learning | Self-distillation still improves over hard-label baseline |

How KD Composes with the Efficiency Stack

| Combination | Order | Gain |
|---|---|---|
| Pruning + KD | Prune first, then distill from full-precision teacher | Recovers 3–8% acc vs pruning alone |
| Quantization + KD | Quantization-aware training of the student, with an FP teacher | Especially useful for INT4/INT8; teacher provides smooth gradient |
| NAS + KD | Search for architecture, then distill large teacher into found arch | Architecture gives efficiency; KD gives accuracy |
| NetAug + KD | NetAug during student training; optional teacher for logits | Best for tiny models (<50M MACs) that under-fit standard augmentation |

Bridge to Next Lessons

This lesson closes the compression trilogy (Pruning L3–L4, Quantization L5–L6, KD here). The next lecture is MCUNet (Lecture 10), which deploys the fully compressed and distilled model onto microcontrollers with 256 kB SRAM. MCUNet is the payoff: a NAS-found architecture trained with KD and quantized to INT8, running ImageNet-scale inference on a device smaller than your fingernail.


"What I cannot create, I do not understand." Knowledge distillation operationalizes this: the teacher cannot directly create a small network, but by teaching the small network to reproduce its own outputs and internal structure, it transfers what it understands. The student's ability to match the teacher is the measure of how much was successfully taught — and a student that exceeds the teacher proves the teaching unlocked something the teacher itself hadn't fully exploited.