TinyML & Efficient Deep Learning · MIT 6.5940 · Lecture 21

On-Device Training & Transfer Learning

Inference on a microcontroller is solved — but to personalise a keyword detector to YOUR voice, or a face-unlock model to YOUR face, the device must train. That requires storing every layer's activations for the backward pass. For MobileNetV2 at batch size 8, training memory is 452 MB — 226× larger than the 2 MB an MCU allows. This lesson derives the activation memory bottleneck from first principles, then shows every technique that breaks it: bias-only updates, TinyTL's lite-residual modules, sparse backpropagation, gradient checkpointing, quantized training with QAS, the Tiny Training Engine, and federated learning.

Prerequisites: TinyML L1 (Pruning) — forward/backward basics. TinyML L4 (Quantization) — int8 ops. TinyML L10 (MCUNet) — MCU memory layout.

Chapter 0: The Training Wall

You have just deployed a keyword-spotting model to a hearing aid. It was trained on thousands of voices, and it works — most of the time. But the user has an unusual accent. They want the model to learn their specific pronunciation of "Hey Siri." The obvious solution: collect a few dozen examples on the device and fine-tune. Simple, right?

The obstacle is not computation — even a Cortex-M7 can run a few backward passes per second. The obstacle is memory. Specifically, the memory demanded by the backward pass. During inference, a forward pass through a network with L layers only needs to hold one layer's activations in memory at a time: you compute layer i, pass the output to layer i+1, and discard the intermediate. Peak inference memory is proportional to the largest single-layer activation — perhaps a few hundred KB for a well-designed MCU model.

Training is a completely different story. To compute the weight gradient at layer i during the backward pass, you need the input activation to layer i — which was computed during the forward pass. This means you must store all layer activations throughout the entire forward pass before a single backward step begins. For a network with L layers each producing an activation of size A bytes, training requires L × A bytes of activation memory — while inference requires only A bytes. This is the training wall.

The backward-pass activation requirement, derived. For a linear layer, the forward pass computes $a_{i+1} = a_i W_i + b_i$. During the backward pass, the weight gradient is $\partial L/\partial W_i = a_i^T \cdot \partial L/\partial a_{i+1}$. You need $a_i$, the input activation, to form this outer product. The gradient that propagates backward to the previous layer is $\partial L/\partial a_i = (\partial L/\partial a_{i+1}) \cdot W_i^T$, which does NOT require $a_i$. But the weight gradient does. So: to train weights, you must store activations. To just propagate gradients without updating weights, you don't. This asymmetry is the key to bias-only and sparse update methods.
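
A minimal sketch of that asymmetry for a single linear layer, in plain Python with nested lists standing in for tensors. The helper names (`matmul`, `transpose`, `linear_backward`) and the toy numbers are illustrative assumptions, not part of the lecture.

```python
def matmul(A, B):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def linear_backward(a, W, delta):
    """Backward pass of a_next = a @ W + b.

    a:     (batch, in)   input activation saved from the forward pass
    W:     (in, out)     weights
    delta: (batch, out)  gradient of the loss w.r.t. a_next
    """
    grad_W = matmul(transpose(a), delta)         # needs the stored activation a
    grad_b = [sum(col) for col in zip(*delta)]   # needs only delta
    grad_a = matmul(delta, transpose(W))         # needs only W and delta
    return grad_W, grad_b, grad_a

a = [[1.0, 2.0]]                  # one sample, two input features
W = [[0.5, -1.0, 0.0],
     [1.5,  2.0, 1.0]]            # 2 inputs -> 3 outputs
delta = [[1.0, 0.0, -2.0]]        # upstream gradient
grad_W, grad_b, grad_a = linear_backward(a, W, delta)
```

Only `grad_W` reads `a`. If the weights are frozen, that line disappears and the layer's input never has to be kept.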

Let's put numbers on this. MobileNetV2 at input resolution 224×224: inference memory (batch size 1) ≈ 20 MB. Training memory (batch size 8) ≈ 452 MB, a 22× jump. An STM32H7 MCU has 1 MB SRAM. A first-generation Raspberry Pi has 256 MB DRAM. Neither can hold 452 MB. Even an iPhone's 6 GB of RAM struggles with full-precision training of a large model. The MCU budget is 320 KB of activation memory and 1 MB of weight storage. At 452 MB, the training footprint is roughly 1,400× larger than that 320 KB activation budget.

What about cloud training plus device deployment — the approach we used in TinyML L10? That works for fixed tasks. But it requires data to leave the device. If the task is personalisation (your face, your voice, your writing patterns) or the data is sensitive (medical signals, private code, enterprise documents), uploading to the cloud is unacceptable. On-device training is not a nice-to-have: it is the only privacy-preserving path to personalised AI.

Key distinction: inference memory vs training memory. Inference: store one activation tensor at a time — output of layer i becomes input to layer i+1, then layer i's output can be freed. Peak memory = max single-layer activation ≈ A bytes. Training: must store ALL activations from the forward pass to use during the backward pass. Peak memory = sum of all activation tensors = L × A bytes. For L=20 layers and A=20 MB each, training needs 400 MB. Inference needs 20 MB. Factor: L = 20×.
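
The arithmetic above can be written as a two-line model. This is a sketch under the same simplifying assumption the text makes (a plain chain of layers where inference holds one activation at a time); the function name `peak_activation_mb` is made up for illustration.

```python
def peak_activation_mb(layer_sizes_mb, training):
    """Peak activation memory for a plain chain of layers.

    Inference keeps one activation at a time; training keeps all of them
    until the backward pass has consumed them.
    """
    return sum(layer_sizes_mb) if training else max(layer_sizes_mb)

sizes = [20] * 20                                        # L = 20 layers, A = 20 MB each
inference = peak_activation_mb(sizes, training=False)    # one activation
training = peak_activation_mb(sizes, training=True)      # all activations
```

With unequal layer sizes the ratio is no longer exactly L, but the pattern is the same: a maximum for inference, a sum for training.
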
Inference vs Training Memory — The Explosion

Toggle between inference and training. Drag the slider to change the number of layers. Each bar segment represents one layer's activation memory. Inference only needs to hold the current layer; training must hold ALL simultaneously.

You add a second convolutional layer to your MCU model. How does this affect training memory?

Chapter 1: Deep Leakage — Gradients Are Not Safe

The obvious workaround for training memory is to keep the data on-device but share gradients with a powerful cloud server for aggregation. The user trains locally for a few steps, uploads only the gradient tensor — not their raw images or audio — and the server merges updates from many devices. This is the promise of federated learning, introduced by McMahan et al. in 2016. The key claim: user data never leaves the device, so privacy is preserved.

In 2019, Zhu et al. (NeurIPS) showed this claim is dangerously wrong for naive implementations. Their attack — Deep Leakage from Gradients (DLG) — reconstructs the original training data (images, text, labels) from gradient tensors alone, with no access to the model's training data. The attack is startling: if a server sees your gradient, it can reconstruct the photo or text you used to compute it.

The DLG attack works by gradient matching. The attacker initialises random "dummy" inputs x′ and labels y′. They run a forward+backward pass with the same model to get dummy gradients ∇′. They then minimise the L2 distance between real gradients ∇ and dummy gradients ∇′ by gradient descent on x′ and y′. After a few hundred iterations, x′ converges to the original training data. The attacker never touched the data — only the gradients.

The DLG mechanism: why it works. Gradients $\partial L/\partial W = a_{in}^T \cdot \partial L/\partial a_{out}$ encode information about the input activation $a_{in}$. For batch size 1, the first FC layer's gradient is an outer product of the input image and the output gradient, which essentially exposes the input. For convolutional layers, the gradient similarly encodes local patch statistics. The matching loss $D = \|\nabla - \nabla'\|^2$ is differentiable with respect to $x'$, so standard optimisers can minimise it. For batch sizes > 1, reconstruction is harder but still possible for small batches.
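
For the batch-size-1 case the leak can be shown without any optimisation. The sketch below is not the DLG attack itself; it is an analytic shortcut that assumes a fully connected first layer with a bias term, so that the weight gradient is the outer product of the input and the bias gradient. All names and numbers are illustrative.

```python
def fc_gradients(x, delta):
    """Gradients of y = x W + b for a single sample.

    x: list of inputs; delta: list of dL/dy, one entry per output unit.
    """
    grad_W = [[xi * dj for dj in delta] for xi in x]   # outer product x^T delta
    grad_b = list(delta)                               # bias gradient is delta itself
    return grad_W, grad_b

def recover_input(grad_W, grad_b):
    """Read the input back out of the gradients: x_i = grad_W[i][j] / grad_b[j]."""
    j = max(range(len(grad_b)), key=lambda k: abs(grad_b[k]))  # largest divisor
    if grad_b[j] == 0:
        raise ValueError("all bias gradients are zero; nothing to divide by")
    return [row[j] / grad_b[j] for row in grad_W]

secret = [0.25, -1.5, 3.0, 0.0]                        # stands in for a flattened image
grad_W, grad_b = fc_gradients(secret, delta=[0.5, -2.0, 0.125])
recovered = recover_input(grad_W, grad_b)
```

A server that sees `grad_W` and `grad_b` for this layer recovers `secret` with one division per element, which is why the text calls the batch-size-1 gradient a near-perfect fingerprint.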

Defences against DLG have been proposed: adding Gaussian or Laplacian noise to gradients, gradient quantisation, and gradient compression (e.g., sending only the top 0.1% of gradient values by magnitude). Song Han's group showed that gradient compression at 99% sparsity effectively prevents reconstruction while preserving model accuracy — because highly sparse gradients leak far less information about the input structure. Without gradient compression, adding noise sufficient to prevent leakage destroys accuracy.
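
A minimal sketch of magnitude-based gradient sparsification, the compression step described above. The function name and the example vector are assumptions for illustration; production systems usually also accumulate the dropped values locally for later rounds, which is omitted here.

```python
def top_k_sparsify(grad, keep_fraction):
    """Zero all but the largest-magnitude entries of a flat gradient list."""
    k = max(1, int(len(grad) * keep_fraction))          # always send at least one value
    order = sorted(range(len(grad)), key=lambda i: abs(grad[i]), reverse=True)
    keep = set(order[:k])
    return [g if i in keep else 0.0 for i, g in enumerate(grad)]

grad = [0.1, -5.0, 0.3, 2.0, -0.2, 0.05, 1.0, -0.4, 0.0, 0.6]
sparse = top_k_sparsify(grad, keep_fraction=0.2)        # keep the top 20% here
```

Only the surviving indices and values would be uploaded; the zeros are implicit.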

The deep implication for on-device learning: keeping data on the device is the only truly safe option. Sharing even compressed gradients has risks. This motivates fully local training approaches — TinyTL, sparse backpropagation, and the Tiny Training Engine — where no gradient ever leaves the device.

Privacy budget intuition. Think of your gradient as a signed fingerprint of your training data. For batch size 1, it is nearly a perfect fingerprint: the outer product $a_{in}^T \otimes \delta$ is essentially the data. As batch size grows, individual samples are blended, which makes them harder to separate but not impossible. Gradient compression helps: a gradient that keeps only 1% of its entries masks 99% of the fingerprint. But the aggregating server sees many users' compressed gradients, and can sometimes reconstruct individuals by studying the overlap of their sparse indices across rounds.
python
# DLG attack — reconstruct training data from gradients alone
# ~20 lines, 2400+ citations. From Zhu et al. NeurIPS 2019

import torch
import torch.nn.functional as F

def dlg_attack(model, true_grads, input_shape, n_classes, n_iters=300):
    """Recover training data from gradient tensors alone.

    input_shape (e.g. (1, 3, 32, 32)) and n_classes are assumed known to the attacker.
    """
    # Initialise random dummy data + soft label
    dummy_x = torch.randn(input_shape, requires_grad=True)
    dummy_y = torch.randn(1, n_classes, requires_grad=True)
    optimizer = torch.optim.LBFGS([dummy_x, dummy_y])

    for i in range(n_iters):
        def closure():
            optimizer.zero_grad()
            dummy_pred = model(dummy_x)
            dummy_loss = F.cross_entropy(dummy_pred, F.softmax(dummy_y, dim=-1))
            dummy_grads = torch.autograd.grad(
                dummy_loss, model.parameters(), create_graph=True)
            # Minimise distance between real grads and dummy grads
            grad_diff = sum(((dg - tg) ** 2).sum()
                            for dg, tg in zip(dummy_grads, true_grads))
            grad_diff.backward()
            return grad_diff
        optimizer.step(closure)
    return dummy_x.detach()  # ≈ original training image
The DLG attack works by minimising a loss over what variable?

Chapter 2: Transfer Learning Modes

On-device training almost always means transfer learning: start from a pretrained model (trained in the cloud on a large dataset) and adapt it to a new, small on-device dataset. This works because deep networks learn hierarchical features — early layers detect edges and textures, middle layers detect object parts, late layers detect semantics — and these general features transfer well across tasks. You only need to shift the final semantics, not relearn vision from scratch.

There are three classic modes. Feature extraction (Last-only): freeze all layers, train only the final classification head. Memory cost: tiny (only head gradients). Accuracy: limited — the frozen backbone cannot adjust its representation to the new domain. Partial fine-tuning (BN+Last): freeze conv weights, train BatchNorm scale/shift parameters plus the head. Cheap in parameters (13× fewer than full) but not cheap in activation memory (only 1.8× savings — BN params don't cause activations to be discarded). Full fine-tuning: train all layers. Best accuracy but maximum memory.

The parameter-efficiency trap. Reducing trainable parameters does NOT proportionally reduce activation memory. Why? Activation memory is determined by which gradients need a stored forward activation, not by how many parameters are trainable. The weight gradient $\partial L/\partial W_i = a_i^T \cdot \delta_i$ needs $a_i$ (the input to layer i) from the forward pass. If layer i's weights are frozen, that term disappears, and propagating the gradient BACKWARDS to layer i-1 via $\partial L/\partial a_i = \delta_i \cdot W_i^T$ needs only the weights. But a BatchNorm scale is a multiplicative parameter just like a weight: its gradient is $\delta$ multiplied by the normalised input and summed, so every trained BN layer must keep its input. BN+Last therefore drops the activations that only the frozen conv weights needed, while the gradient signal still travels through the whole network and each BN layer still stores its own input. The saving is real but small: 1.8×.

The result for ResNet-50 on the Stanford Cars dataset: Full fine-tuning — 92% accuracy, 680 MB activation memory. BN+Last — 78% accuracy, 370 MB (1.8× savings). Last-only — 70% accuracy, 25 MB (27× savings). There is a stark tradeoff: big memory savings destroy accuracy; accuracy-preserving partial tuning barely saves memory. Neither is acceptable on a 2 MB MCU. We need a new axis: activation-memory-efficient fine-tuning without accuracy loss.

Transfer Learning Mode Comparison

Click a training mode to highlight it. The bars show activation memory (left axis) and accuracy proxy (right axis). Notice that BN+Last saves parameters (13×) but barely saves memory (1.8×) — the parameter-efficiency trap. TinyTL breaks this pattern.

BN+Last reduces trainable parameters by 13× versus full fine-tuning. By how much does it reduce activation memory?

Chapter 3: TinyTL — Reduce Activations, Not Parameters

TinyTL (Cai et al., NeurIPS 2020) reframes the problem. Previous parameter-efficient transfer learning methods asked: "how do we reduce trainable parameters?" TinyTL asks: "how do we reduce activation memory?" The key insight is to train only biases while freezing all weight tensors. This sounds deceptively simple — let's derive why it actually works for memory.

For a linear layer $a_{i+1} = a_i W_i + b_i$, two gradient equations govern the backward pass. The weight gradient: $\partial L/\partial W_i = a_i^T \cdot (\partial L/\partial a_{i+1})$. This requires storing $a_i$, the layer's input activation. The bias gradient: $\partial L/\partial b_i = \partial L/\partial a_{i+1}$. The bias gradient depends only on the downstream gradient, not on the stored activation. If we freeze $W_i$ and only train $b_i$, we never need to compute $\partial L/\partial W_i$, so we never need to store $a_i$. Activation memory for layer i drops to zero: we only need the propagated gradient, not the activation that generated it.

Bias-only update: the activation-free backward pass. Compare the two backward equations side by side. Weight gradient: $\partial L/\partial W = a_{in}^T \cdot \delta_{out}$, which requires $a_{in}$ (must be stored during forward). Bias gradient: $\partial L/\partial b = \sum \delta_{out}$ (summed over spatial dims for conv), which requires only $\delta_{out}$ (the gradient flowing back from the next layer). Conclusion: bias-only training needs ZERO stored activations per layer. The backward pass still propagates through the frozen weight layers (using $W^T \delta$ to pass the gradient signal backwards), but since we don't need to compute $\partial L/\partial W$, we don't need to store the activation. This is the core mechanism of TinyTL.
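
The conv-layer bias gradient mentioned above is just a sum over spatial positions. A small sketch for one sample, with the gradient tensor written as nested lists indexed [channel][row][col]; the function name and numbers are illustrative assumptions.

```python
def conv_bias_grad(delta_out):
    """Bias gradient of a conv layer for one sample.

    delta_out: gradient w.r.t. the layer output, indexed [channel][row][col].
    Returns one value per output channel: the sum of delta over all positions.
    """
    return [sum(sum(row) for row in channel) for channel in delta_out]

delta_out = [
    [[1.0, 2.0], [3.0, 4.0]],       # channel 0
    [[0.5, -0.5], [0.0, 0.0]],      # channel 1
]
grad_b = conv_bias_grad(delta_out)  # computed without touching any forward activation
```

The function takes no forward-pass tensor as an argument, which is the whole point: nothing from the forward pass needs to be kept for it.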

Empirically, bias-only training on ResNet-50 cuts activation memory by 12× compared to full fine-tuning. But there is a cost: accuracy on Stanford Cars drops from 92% (full) to 76% (bias-only) — a 16 percentage point gap. The model lacks the capacity to adapt its feature representation using only 1D bias shifts. The features are stuck as they were during pretraining.

TinyTL's solution to the accuracy gap: add a lightweight lite residual module to each inverted bottleneck block. The lite residual is a small, newly initialised branch that runs in parallel with the frozen backbone. It has three design decisions that keep its activation cost minimal. First, reduce the spatial resolution: the lite residual downsamples the input by 2× in each spatial dimension before processing, cutting the spatial activation size to 1/4. Second, avoid the inverted bottleneck: standard MobileNet blocks expand channels 6× (hence "inverted bottleneck"), which is the largest activation in the block. The lite residual uses only C channels (not 6C). Third, reduce depth: use 2/3 the block depth of the main branch. Combined: the lite residual's activation size is roughly 1/6 channels × 1/4 spatial (from 2× downsample) × 2/3 depth = 1/36, approximately 3% of the main branch's activation cost.

python
# TinyTL: freeze backbone weights, train biases + lite residual
import torch.nn as nn

class LiteResidual(nn.Module):
    """Low-activation branch: downsample → group conv → upsample."""
    def __init__(self, C, stride=1):
        super().__init__()
        # Downsample to 0.5× spatial resolution — key activation saving
        self.down = nn.AvgPool2d(kernel_size=2, stride=2)
        # Group conv at reduced channels (no inverted bottleneck)
        self.conv = nn.Conv2d(C, C, kernel_size=3, padding=1,
                              groups=C // 4)  # group conv, cheap
        self.bn   = nn.BatchNorm2d(C)
        self.up   = nn.Upsample(scale_factor=2, mode='bilinear')

    def forward(self, x):
        return self.up(self.bn(self.conv(self.down(x))))

def prepare_tinytl(model):
    """Freeze weights; leave biases + lite residuals trainable."""
    for name, param in model.named_parameters():
        if 'weight' in name and 'lite_residual' not in name:
            param.requires_grad = False  # freeze backbone weights
        else:
            param.requires_grad = True   # train biases + lite residual
    return model

The result: TinyTL reaches 90.1% on Stanford Cars (versus 92.2% for full fine-tuning, 76.0% for bias-only), while using only 12% of the activation memory of full fine-tuning. That is roughly an 8× reduction, far beyond the 1.8× of BN+Last. Furthermore, at batch size 1 with group normalisation replacing batch normalisation (needed since batch size 1 makes BN statistics noisy), TinyTL fits within 16 MB, the size of a typical L3 CPU cache. It enables training inside the cache, which is dramatically more energy-efficient than training in DRAM.

Why does bias-only training save activation memory compared to weight training?

Chapter 4: Sparse Backpropagation

TinyTL freezes all backbone weights and trains only biases. This is a form of sparse update — we update a strict subset of parameters. Sparse backpropagation (SparseBP) generalises this idea: rather than a fixed bias-only policy, use contribution analysis to find which layers, and which channel fractions, are most worth updating for a given downstream task — then skip the backward pass (and thus the activation storage) for all other layers.

The insight: not all layers contribute equally when adapting to a new task. Contribution analysis measures this directly. Fine-tune only one layer at a time on the target task, measure the accuracy improvement Δacc, and rank layers by their contribution. For MobileNetV2, the first depth-wise conv contributes most to visual adaptation tasks. For BERT, the QKV projection layers and the first FFN layer contribute most. Different models, different optimal sparse update patterns.

SparseBP memory calculation. Write a conv layer's backward pass as a matrix product: the output gradient $dY$ has shape $(N, H)$ and the stored input activation $X$ has shape $(N, M)$. The weight gradient is $dY^T X$, a product of shapes $(H, N) \times (N, M)$, which costs $M \times H \times N$ FLOPs and requires storing $X$. If you update only the slice of the weight that reads 1/4 of the channels of $X$ (sparse channel update), the activation to store shrinks to $(N, M/4)$, a 4× reduction in activation memory and FLOPs simultaneously. Layer sparsity (skipping entire layers) saves the entire layer's activation. The two modes, sparse layers (which layers to update) and sparse tensors (which channels within a layer), can be combined via evolutionary search over the space of valid sparse configs to find the Pareto frontier of memory versus accuracy.
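
The shape bookkeeping above, as a small calculator. This is a sketch that counts elements and multiply-accumulates only; the function name and the example sizes are assumptions chosen for illustration.

```python
def weight_grad_cost(N, M, H, channel_ratio=1.0):
    """Cost of dW = dY^T X for X of shape (N, M) and dY of shape (N, H),
    when only a fraction of the M channels of X is used."""
    m_kept = int(M * channel_ratio)
    return {
        "stored_activation": N * m_kept,   # elements of X that must be kept
        "flops": m_kept * H * N,           # multiply-accumulates for dW
    }

full = weight_grad_cost(N=196, M=64, H=128)
quarter = weight_grad_cost(N=196, M=64, H=128, channel_ratio=0.25)
```

Both numbers scale linearly with the kept fraction, so a 1/4 channel update cuts stored activation and FLOPs by the same factor of 4.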

SparseBP uses evolutionary search to find the optimal sparse configuration. The search space is combinatorial — for each of L layers, choose: update all, update 1/2 channels, update 1/4 channels, or skip. Random search works poorly in this space; evolutionary search converges much faster. The search evaluates each candidate configuration by running a few training steps on a validation split and measuring accuracy. After finding the optimal config, it is fixed for all on-device fine-tuning steps.
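
A toy version of that search, to make the loop concrete. Everything here is an assumption for illustration: the fitness proxy treats per-layer contributions as additive and proportional to the updated fraction, the activation sizes and contribution scores are invented, and a real search would score candidates by running training steps rather than by a formula.

```python
import random

CHOICES = (0.0, 0.25, 0.5, 1.0)      # fraction of channels updated in each layer

def memory(config, act_kb):
    """Activation memory stored for a per-layer update configuration."""
    return sum(r * a for r, a in zip(config, act_kb))

def fitness(config, contribution, act_kb, budget_kb):
    """Assumed proxy for accuracy gain; configurations over budget are invalid."""
    if memory(config, act_kb) > budget_kb:
        return float("-inf")
    return sum(r * c for r, c in zip(config, contribution))

def evolve(contribution, act_kb, budget_kb, pop=30, gens=40, seed=0):
    rng = random.Random(seed)
    n = len(act_kb)
    # Seed with the all-skip config so at least one candidate is always valid
    population = [[0.0] * n] + [[rng.choice(CHOICES) for _ in range(n)]
                                for _ in range(pop - 1)]
    score = lambda c: fitness(c, contribution, act_kb, budget_kb)
    for _ in range(gens):
        parents = sorted(population, key=score, reverse=True)[: pop // 3]
        children = []
        while len(parents) + len(children) < pop:
            a, b = rng.sample(parents, 2)
            child = [rng.choice(pair) for pair in zip(a, b)]   # crossover
            child[rng.randrange(n)] = rng.choice(CHOICES)      # mutation
            children.append(child)
        population = parents + children                        # parents survive (elitism)
    return max(population, key=score)

act_kb = [64, 64, 32, 32, 16, 16, 8, 8]                 # invented per-layer sizes
contribution = [0.1, 0.2, 0.3, 0.5, 0.4, 0.8, 0.6, 1.0]  # invented scores
best = evolve(contribution, act_kb, budget_kb=48)
```

Because the best parents are carried into every generation, the returned configuration is always within budget and never worse than skipping every layer.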

python
# Sparse backpropagation sketch: restrict which weight gradients are computed
import torch
import torch.nn as nn
import torch.nn.functional as F

class SparseBackwardLayer(nn.Module):
    """Conv layer with optional sparse channel backward (assumes groups == 1)."""
    def __init__(self, in_c, out_c, channel_ratio=1.0, **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(in_c, out_c, **kwargs)
        self.ratio = channel_ratio

    def _conv(self, x, w, b):
        # Reuse the wrapped layer's geometry for a sliced weight tensor
        return F.conv2d(x, w, b, self.conv.stride, self.conv.padding,
                        self.conv.dilation)

    def forward(self, x):
        if not self.training or self.ratio == 0.0:
            # Skipped layer: no graph is built, so nothing upstream of it can train
            with torch.no_grad():
                return self.conv(x)
        if self.ratio < 1.0:
            # Only compute weight gradients for a fraction of output channels
            n_active = int(self.conv.out_channels * self.ratio)
            w, b = self.conv.weight, self.conv.bias
            out_active = self._conv(x, w[:n_active],
                                    None if b is None else b[:n_active])
            # Frozen channels: detach the parameters so no weight gradient is
            # formed, while the gradient w.r.t. x still flows through them
            out_frozen = self._conv(x, w[n_active:].detach(),
                                    None if b is None else b[n_active:].detach())
            return torch.cat([out_active, out_frozen], dim=1)
        return self.conv(x)  # full backward

On benchmarks — MCUNet, MobileNetV2, ResNet-50, DistilBERT, BERT — SparseBP matches full backpropagation accuracy within ~0.5% while using 4.5–7.5× less activation memory. For LLaMA-2-7B fine-tuning on the Stanford Alpaca dataset using the PockEngine framework (Chapter 7), SparseBP achieves comparable Alpaca-eval win rates to LoRA at rank=8, with 2× faster iteration latency and 14 GB less GPU memory.

Sparse Layer Update — Activation Memory Savings

A 10-layer network. Toggle each layer's update mode: full update (stores activation, high cost), bias-only (no activation stored), or skip (no backward). Watch total activation memory change. The contribution bars show which layers matter most for a sample visual task.

You are fine-tuning a 20-layer CNN. You discover that layers 5, 10, and 15 have the highest contribution analysis scores. You set those three layers to full update and skip all others. Compared to full backpropagation, how does activation memory change?

Chapter 5: Gradient Checkpointing

TinyTL and SparseBP reduce activation memory by restricting which weights are trained. But what if you genuinely need to train all weights of a deep network? Gradient checkpointing (also called activation rematerialisation) offers a different trade: instead of reducing which activations are stored, it reduces how many are stored simultaneously — at the cost of extra computation.

The standard forward pass stores all L activations. The standard backward pass uses each activation exactly once to compute the weight gradient for that layer. Gradient checkpointing instead stores only a subset of C "checkpoint" activations during the forward pass, discarding the rest. When the backward pass reaches a layer whose activation was discarded, it recomputes that activation by running the forward pass again from the nearest checkpoint. You pay for the discarded activations with extra compute, but you save memory.

The square-root memory result, derived. Place a checkpoint every k = L/C layers. Between checkpoints, the backward pass must rematerialise up to k activations at a time (the segment between two checkpoints). Memory cost: C checkpoints stored permanently + up to k = L/C activations in the recomputed segment = C + L/C. By the AM-GM inequality, C + L/C is minimised when C = L/C, i.e. C = √L, giving total memory = 2√L. This is the classic √L memory result: gradient checkpointing with optimally spaced checkpoints reduces activation memory from O(L) to O(√L), at the cost of recomputing each segment once, which adds up to roughly one extra forward pass. For L=100 layers: standard stores 100 activations; checkpointing stores √100 = 10 checkpoints and recomputes at most 10 activations per segment, so about 20 are live at once. That is a 5× memory saving for roughly one extra forward pass of compute.
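
The C + L/C trade-off can be checked by brute force. A short sketch; the function names are made up, and segment length is rounded up so that C does not have to divide L.

```python
import math

def checkpoint_memory(L, C):
    """Activations held at once: C checkpoints plus one recomputed segment."""
    return C + math.ceil(L / C)

def best_checkpoint_count(L):
    """Try every C from 1 to L and return the one with the smallest footprint."""
    return min(range(1, L + 1), key=lambda C: checkpoint_memory(L, C))

for L in (36, 100):
    C = best_checkpoint_count(L)
    print(f"L={L}: best C={C}, activations held={checkpoint_memory(L, C)}")
```

For perfect squares the brute-force optimum lands exactly on C = √L with 2√L activations held, matching the AM-GM argument.
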
python
# Gradient checkpointing with PyTorch's built-in API
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp

class CheckpointedResBlock(nn.Module):
    def __init__(self, block):
        super().__init__()
        self.block = block

    def forward(self, x):
        # checkpoint() re-runs block in backward pass instead of storing activations
        return cp.checkpoint(self.block, x, use_reentrant=False)

# Manual segment checkpointing: keep the input to every k-th layer
def forward_with_checkpoints(layers, x, k=4):
    """Run forward pass, storing only 1 in k activations."""
    checkpoints = []
    with torch.no_grad():  # no autograd graph, so intermediates are freed after use
        for i, layer in enumerate(layers):
            if i % k == 0:
                # Store this segment input; backward recomputes the segment from it
                checkpoints.append(x.detach())
            x = layer(x)
    return x, checkpoints

# Memory: C + L/C stored activations (optimal C = sqrt(L))
# Compute: ~2× forward passes (one forward + one per segment in backward)

Gradient checkpointing does not help with the weight memory problem — you still need all model parameters. But it directly attacks the activation memory problem. Combining checkpointing with bias-only training or SparseBP can further reduce memory: skip backpropagation through early layers (no activation storage needed) and checkpoint the remaining layers that are trained.
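
To see the recomputation step end to end, here is a framework-free sketch on a chain of scalar layers y = tanh(w·x). It compares ordinary backpropagation with a checkpointed version and counts how many layer inputs are held at once. The layer choice, function names, and counting convention are all assumptions made for illustration.

```python
import math

def layer_forward(w, x):
    return math.tanh(w * x)

def layer_backward(w, x, grad_out):
    """Return (dL/dw, dL/dx) for y = tanh(w * x); needs the layer input x."""
    dy = 1.0 - math.tanh(w * x) ** 2
    return grad_out * dy * x, grad_out * dy * w

def backprop_full(ws, x0):
    xs = [x0]
    for w in ws:                            # keep every layer input
        xs.append(layer_forward(w, xs[-1]))
    grad, grads_w = 1.0, [0.0] * len(ws)    # loss = final output
    for i in reversed(range(len(ws))):
        grads_w[i], grad = layer_backward(ws[i], xs[i], grad)
    return grads_w, len(ws)                 # L layer inputs were stored

def backprop_checkpointed(ws, x0, k):
    L, ckpt, x = len(ws), {}, x0
    for i, w in enumerate(ws):              # forward: keep only every k-th input
        if i % k == 0:
            ckpt[i] = x
        x = layer_forward(w, x)
    grad, grads_w, peak = 1.0, [0.0] * L, 0
    for start in reversed(range(0, L, k)):  # backward, one segment at a time
        end = min(start + k, L)
        seg = [ckpt[start]]                 # recompute this segment's layer inputs
        for i in range(start, end - 1):
            seg.append(layer_forward(ws[i], seg[-1]))
        peak = max(peak, len(ckpt) + len(seg) - 1)   # seg[0] is itself a checkpoint
        for i in reversed(range(start, end)):
            grads_w[i], grad = layer_backward(ws[i], seg[i - start], grad)
    return grads_w, peak

ws = [0.5 + 0.1 * i for i in range(16)]
full_grads, full_stored = backprop_full(ws, 0.3)
ckpt_grads, ckpt_peak = backprop_checkpointed(ws, 0.3, k=4)
```

Both routines produce the same weight gradients; the checkpointed one simply pays for them with a second pass over each segment.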

Gradient Checkpointing — Memory vs Compute Tradeoff

Drag the slider to set the number of checkpoints C (from 1 to L). The left bar shows activation memory (stored checkpoints + largest recomputed segment). The right bar shows extra compute (recomputation cost). The optimal is at C = √L. An MCU memory limit line shows which configurations fit.

For a 36-layer network, what is the optimal number of checkpoints and the resulting memory compared to storing all activations?

Chapter 6: Quantized Training & QAS

All the techniques so far operate at FP32. But MCUs are built for integer arithmetic: many have no floating-point unit at all, and INT8 or INT16 operations are far cheaper even where one exists. To run training on an MCU, we need quantised training: keep the entire training computation in low-precision integers. This would also solve the memory problem more radically, since INT8 activations are 4× smaller than FP32, and INT8 weights are 4× smaller.

The naive approach is Quantisation-Aware Training (QAT), also called fake quantisation: maintain FP32 weights and activations throughout training, but insert "fake quantise" operations that round values to the nearest integer grid and back. This preserves the FP32 computation graph (backward passes work normally) but does not actually save memory — the tensors are still FP32 internally. Fake quantisation is used for preparing a model for inference deployment, not for reducing training memory.

The real vs fake quantisation distinction. Fake QAT: weights stored as FP32, passed through a differentiable rounding simulation q(x) = Round(x/s) × s where s is the scale factor. The gradient ∂q/∂x is approximated as 1 (straight-through estimator). Memory savings: zero — everything is FP32. Real quantisation: weights stored as INT8, all multiply-accumulate operations done in INT32, activations stored as INT8. Memory: 4× smaller than FP32. Problem: gradients are ill-defined for integer operations — how do you differentiate INT32 → INT8 quantisation? And batch normalisation (standard in FP32 training) has no integer analogue. Real quantised training is 4× more memory-efficient but dramatically harder to optimise.
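
The difference shows up directly in what gets stored. A small sketch with a single per-tensor scale; the function names and sample values are assumptions, and a real implementation would work on arrays with per-channel scales.

```python
def quantise_int8(values):
    """Real quantisation: INT8 codes plus the one FP scale needed to decode them."""
    scale = max(abs(v) for v in values) / 127 or 1.0   # fall back to 1.0 for all zeros
    codes = [max(-128, min(127, round(v / scale))) for v in values]
    return codes, scale

def fake_quantise(values):
    """Fake quantisation: snap to the INT8 grid but keep floating-point storage."""
    codes, scale = quantise_int8(values)
    return [c * scale for c in codes]

weights = [2.54, -1.0, 0.0, 0.3]
codes, scale = quantise_int8(weights)     # integers: one byte each on device
simulated = fake_quantise(weights)        # floats: still four bytes each in FP32
```

`simulated` carries the rounding error that inference will see, but it occupies exactly as much memory as the unquantised weights.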

The core problem of real quantised training is a scale mismatch between weights and gradients. In FP32 training, Adam adaptively scales gradients to match the weight update magnitude, so the ratio $\|W\|/\|G\|$ stays roughly consistent across layers. In INT8 training, quantising weights to INT8 changes their effective scale by a factor $S_W$ (the weight quantisation scale), but the gradient scale doesn't automatically adjust. Empirically, $\log_{10}(\|W\|/\|G\|)$ spreads from +5 to -5 across layers in INT8 training versus a tight band around 0 in FP32. This scale mismatch causes catastrophic oscillations and non-convergence: INT8 SGD on ResNet-50 drops accuracy from 75.4% (FP32) to 64.8%.

Quantisation-Aware Scaling (QAS), proposed by Han Lab, fixes this by explicitly re-scaling the gradients to compensate for the quantisation scale shift. For a weight tensor with per-channel quantisation scale $S_W$ (one value per channel, not a single scalar), the weight gradient $\partial L/\partial W$ is multiplied by $S_W^{-2}$ before the weight update. This compensates for the fact that quantising W by $S_W$ changes the effective loss curvature by $S_W^2$. After QAS correction, the W/G ratio plot for INT8+QAS closely matches FP32, and convergence is recovered: ResNet-50 INT8+QAS achieves 74.1% (versus 75.4% FP32), a gap of only 1.3 percentage points versus the catastrophic 10.6-point drop without QAS.
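
The exponent of -2 can be traced with the chain rule. The following is a sketch that ignores rounding and treats one channel at a time; the symbols $\bar{W}$ (the integer-scale weight), $G_W$, and $\tilde{G}_{\bar{W}}$ (the corrected gradient) are notation assumed here rather than taken from the text.

```latex
\begin{aligned}
\bar{W} &= W / S_W
  && \text{(stored INT8 weight, rounding ignored)} \\
G_{\bar{W}} &= \frac{\partial L}{\partial \bar{W}}
   = \frac{\partial L}{\partial W}\,\frac{\partial W}{\partial \bar{W}}
   = S_W \, G_W
  && \text{(since } W = S_W \bar{W}\text{)} \\
\frac{\|\bar{W}\|}{\|G_{\bar{W}}\|}
  &= \frac{\|W\| / S_W}{S_W \,\|G_W\|}
   = S_W^{-2}\,\frac{\|W\|}{\|G_W\|}
  && \text{(ratio is off by } S_W^{-2}\text{)} \\
\tilde{G}_{\bar{W}} &= G_{\bar{W}} \cdot S_W^{-2}
  \;\Longrightarrow\;
  \frac{\|\bar{W}\|}{\|\tilde{G}_{\bar{W}}\|} = \frac{\|W\|}{\|G_W\|}
  && \text{(ratio restored)}
\end{aligned}
```

One factor of $S_W$ comes from the weight shrinking and one from the gradient growing, which is why the correction is quadratic in the scale.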

python
# Quantisation-Aware Scaling (QAS): correct the gradient scale mismatch
import torch
import torch.nn as nn
import torch.nn.functional as F

class QASLayer(nn.Module):
    """Conv with simulated INT8 weights and per-channel QAS gradient rescaling."""
    def __init__(self, in_c, out_c, kernel_size=3):
        super().__init__()
        # Weights stored as FP32 internally; snapped to the INT8 grid during forward
        self.weight = nn.Parameter(torch.randn(out_c, in_c, kernel_size, kernel_size))
        self.padding = kernel_size // 2  # "same" padding for odd kernel sizes

    def weight_scale(self):
        # Per-channel scale: sw = max(|W|) / 127 for each output channel
        sw = self.weight.detach().abs().amax(dim=(1, 2, 3), keepdim=True) / 127
        return sw.clamp_min(1e-8)  # guard against all-zero channels

    def quantise_weight(self):
        sw = self.weight_scale()
        w_q = (self.weight / sw).round().clamp(-128, 127) * sw
        # Straight-through estimator: round() has zero gradient, so route the
        # gradient around it and straight to self.weight
        return self.weight + (w_q - self.weight).detach(), sw

    def forward(self, x):
        w_q, _ = self.quantise_weight()
        return F.conv2d(x, w_q, padding=self.padding)

    def qas_correct_grads(self):
        """Apply sw^-2 correction to gradient (fixes scale mismatch)."""
        if self.weight.grad is not None:
            with torch.no_grad():
                self.weight.grad /= self.weight_scale() ** 2  # QAS correction
Why does fake quantisation (QAT) fail to reduce training memory on an MCU?

Chapter 7: The Tiny Training Engine & PockEngine

Algorithms are only half the story. Even after choosing sparse updates, bias-only gradients, and INT8 activations, a standard deep learning framework like PyTorch generates a dense computation graph for the backward pass. It doesn't know that most layers are frozen — it preallocates gradient buffers for all parameters, allocates activations for all layers, and runs backward through every node in the autograd graph. To deploy on-device training to a microcontroller, we need a system co-designed with the sparse training algorithm.

The Tiny Training Engine (TTE) (Lin et al., NeurIPS 2022 — "On-Device Training Under 256 KB Memory") is exactly this co-design. TTE operates at compile time rather than runtime. Given a neural network and a sparse update specification (which layers, which parameter types are trainable), TTE performs compile-time automatic differentiation: it traces the computation graph, identifies which activations are actually needed for the specified gradients, prunes the backward graph to eliminate unnecessary operations, and generates a minimal C++ kernel that runs the entire backward pass without any framework overhead.

Graph pruning: why compile-time matters. PyTorch's autograd builds the backward graph dynamically at runtime — it can't know ahead of time which gradients will be needed. So it stores activations "just in case." A layer with frozen weights still has its activation stored, because autograd doesn't know whether you'll call .backward() on it. TTE, by contrast, performs autograd at compile time: given a fixed sparse update spec, it symbolically computes which nodes in the computation graph are ancestors of the trainable parameters' gradient nodes. Anything that's not an ancestor can be pruned — its activation never needs to be stored. This is static graph pruning, and it's only possible because on-device training uses a fixed (not dynamic) update configuration.

TTE achieves on-device training of a VWW (Visual Wake Words) model on an ARM Cortex-M7 MCU with 256 KB SRAM and 1 MB Flash. The resulting configuration: 158 KB activation memory (below 256 KB SRAM), 1.36 MB weights (fits in Flash), 5.6% accuracy improvement on a new task over inference-only deployment. Training throughput: approximately 5 forward-backward passes per second.

PockEngine extends TTE to mobile and edge GPU platforms (Jetson Orin, Apple M1). For LLaMA-2-7B fine-tuning on Jetson Orin, PockEngine achieves 1.8 seconds per iteration versus PyTorch's 7.7 seconds (4.3× speedup) for full fine-tuning, and 0.9 seconds with sparse backpropagation (8.6× speedup). The key system techniques: (1) prune the backward autograd graph at compile time based on the sparse update mask; (2) reorder operators to maximise data locality and minimise cache misses; (3) fuse adjacent sparse conv backward operations; (4) generate device-specific Metal (Apple) or CUDA (NVIDIA) kernels from the pruned graph.

python
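# TTE-style compile-time graph pruning (simplified sketch on a toy graph).
# The backward graph is a dict: node name -> names of the nodes it reads.
# Tracing a real model and generating the C++ kernel are omitted here.

def ancestors(node, graph):
    """Return node plus every node it transitively depends on."""
    seen, stack = set(), [node]
    while stack:
        current = stack.pop()
        if current not in seen:
            seen.add(current)
            stack.extend(graph.get(current, ()))
    return seen

def prune_backward_graph(backward_graph, trainable_grads):
    """Remove backward nodes not needed for the trainable gradients."""
    # Step 1: targets = gradient nodes of the trainable parameters
    targets = set(trainable_grads)

    # Step 2: collect each target and all of its ancestors (these are NEEDED)
    needed = set()
    for node in targets:
        needed |= ancestors(node, backward_graph)

    # Step 3: keep only the needed nodes, preserving graph order.
    # Activations outside this list are never allocated; a real engine
    # would now generate the backward kernel from the pruned graph.
    return [n for n in backward_graph if n in needed]

# Two-layer example: update only the last layer's bias b2.
toy = {
    "x": [], "a1": [], "W2": [], "dL/dy": [],
    "dL/db2": ["dL/dy"],        # bias gradient needs no stored activation
    "dL/dW2": ["dL/dy", "a1"],  # weight gradient needs stored activation a1
    "dL/da1": ["dL/dy", "W2"],
    "dL/dW1": ["dL/da1", "x"],
}
print(prune_backward_graph(toy, ["dL/db2"]))  # ['dL/dy', 'dL/db2']

Why does TTE perform autograd at compile time rather than using PyTorch's standard dynamic autograd?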

Chapter 8: Showcase — Full On-Device Training Lab

This is the interactive payoff. You're the system designer deploying a keyword-spotting network to an MCU. Configure every axis of the training problem: how many layers to train (full vs sparse), which parameter types (weights vs biases), how many gradient checkpoints, and whether to use federated aggregation across multiple devices. Watch activation memory, model capacity, and training quality update in real time.

On-Device Training Configuration Explorer

A 12-layer CNN on an MCU with a 320 KB activation budget (red line). Configure your training strategy. The canvas shows the activation memory breakdown (by layer), the total memory bar versus the MCU budget, the trainable parameter count, and an accuracy-proxy estimate. Find configurations that fit the MCU while preserving accuracy.

Controls: Sparse Update Fraction (30%), Gradient Checkpoints (4), Federated Devices (1), Federated rounds.

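The trade-off the explorer visualises can be approximated with a few lines of arithmetic. The sketch below is only an illustration under stated assumptions: the function name `activation_memory_kb`, the uniform 64 KB per-layer activation size, and the simple checkpointing model are all hypothetical and are not taken from the canvas itself. Only the 12 layers and the 320 KB budget come from the text.

```python
import math

def activation_memory_kb(num_layers, act_kb, sparse_fraction, checkpoints):
    """Rough activation memory (KB) that must be kept for the backward pass.

    Illustrative assumptions, not the canvas's actual model:
      * every layer's activation occupies act_kb kilobytes;
      * only the last ceil(sparse_fraction * num_layers) layers are updated,
        and frozen earlier layers store nothing;
      * with k checkpoints over n updated layers, about k + ceil(n / k)
        activations are alive at once (checkpoints plus one recomputed segment).
    """
    updated = max(1, math.ceil(sparse_fraction * num_layers))
    if checkpoints <= 0 or checkpoints >= updated:
        stored = updated  # checkpointing cannot help: keep every activation
    else:
        segment = math.ceil(updated / checkpoints)
        stored = min(updated, checkpoints + segment)
    return stored * act_kb

BUDGET_KB = 320   # activation budget from the text
ACT_KB = 64       # hypothetical per-layer activation size

for fraction, ckpts in [(1.0, 0), (1.0, 3), (0.3, 4)]:
    need = activation_memory_kb(12, ACT_KB, fraction, ckpts)
    verdict = "fits" if need <= BUDGET_KB else "exceeds budget"
    print(f"sparse={fraction:.0%} checkpoints={ckpts}: {need} KB ({verdict})")
```

With these made-up sizes, full training needs 768 KB, three checkpoints bring that down to 448 KB, and a 30% sparse update needs 256 KB, which is the only one of the three that fits the 320 KB budget.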
Federated Learning — Local Training Stays Local

N devices each hold private data. Each trains locally for E epochs. Only model weights (or compressed gradients) are sent to the server for FedAvg aggregation. Watch the global model accuracy improve over rounds. The data dots never leave their device columns.
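The round structure described above (local training for E epochs, then weighted averaging on the server) can be sketched in plain Python. This is a toy illustration, not the canvas's implementation: a single scalar weight stands in for the model, each device fits y = w·x on its own private pairs, and the helper names `local_train`, `fedavg`, and `federated_round` are hypothetical.

```python
def local_train(w, data, lr=0.05, epochs=5):
    """Run a few epochs of SGD on one device's private (x, y) pairs."""
    for _ in range(epochs):
        for x, y in data:
            grad = 2 * (w * x - y) * x  # derivative of (w*x - y)^2 w.r.t. w
            w -= lr * grad
    return w

def fedavg(weights, sizes):
    """Server step: average client weights, weighted by local sample count."""
    total = sum(sizes)
    return sum(w * n for w, n in zip(weights, sizes)) / total

def federated_round(global_w, devices, lr=0.05, epochs=5):
    """One round: every device trains locally, only weights reach the server."""
    local_weights = [local_train(global_w, d, lr, epochs) for d in devices]
    return fedavg(local_weights, [len(d) for d in devices])

# Private datasets; both follow y = 2x but never leave their device.
devices = [
    [(1.0, 2.0), (2.0, 4.0)],
    [(1.0, 2.0), (3.0, 6.0), (0.5, 1.0)],
]

w = 0.0
for round_idx in range(10):
    w = federated_round(w, devices)
print(f"global weight after 10 rounds: {w:.4f}")
```

The server only ever sees the per-device weights and sample counts, which is the property the canvas illustrates with the data dots staying in their columns.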

Chapter 9: Connections & Series Recap

Lesson 17 closes the TinyML series. Let's consolidate the complete on-device training toolkit, then step back to see how this lesson connects to every other concept in the series.

On-Device Training Cheat Sheet

| Technique | Memory Saving | Accuracy Cost | Compute Cost | When to Use |
|---|---|---|---|---|
| Last-only | 27× activation | High (−16%) | Minimal | Extreme memory budget; linear head only |
| BN+Last | 1.8× activation | Moderate (−8%) | Minimal | When parameter count matters more than memory |
| Bias-only (TinyTL) | 12× activation | Moderate (−5%) | Same fwd; lighter bwd | Privacy-critical edge devices; no lite residual capacity |
| TinyTL full | 6–8× activation | Minimal (−1%) | +5% for lite residual fwd | MCU with 2–16 MB SRAM; best accuracy-per-byte |
| Sparse BP | 4.5–7.5× activation | Minimal (−0.5%) | +search cost once | Heterogeneous model families; LLM fine-tuning on edge GPU |
| Grad. Checkpoint | √L× activation | Zero | ~2× FLOPs | Full fine-tuning; sufficient compute; insufficient memory |
| INT8 + QAS | 4× (all tensors) | ~1% vs FP32 | Hardware-dependent | MCU with INT8 ALU; complement to all above methods |
| TTE/PockEngine | Graph pruning | Zero (system only) | 3–4× faster | Production deployment; compile-time fixed update spec |
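
The √L entry for gradient checkpointing can be recovered with a short calculation. This is a sketch under simplifying assumptions: L equal-sized layers of A bytes each (as in Chapter 0), and k evenly spaced checkpoints, where k is a symbol introduced here for illustration. The memory holds the k checkpoints plus the L/k activations of the one segment being recomputed.

```latex
\begin{align*}
M_{\text{full}} &= L \cdot A \\
M_{\text{ckpt}}(k) &\approx \left(k + \frac{L}{k}\right) A \\
\frac{dM_{\text{ckpt}}}{dk} = 0 \;&\Rightarrow\; k^{*} = \sqrt{L},
\qquad M_{\text{ckpt}}(k^{*}) \approx 2\sqrt{L}\, A \\
\frac{M_{\text{full}}}{M_{\text{ckpt}}(k^{*})} &\approx \frac{L}{2\sqrt{L}} = \frac{\sqrt{L}}{2}
\end{align*}
```

Under these assumptions the saving grows as O(√L), so the table's √L× is best read as an order-of-growth figure rather than an exact ratio.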

Transfer Learning Decision Tree

Memory < 50 KB? → Bias-only or Last-only: almost no activations need to be stored
↓ else
Memory 50 KB–2 MB? → TinyTL (bias + lite residual): 12× activation saving with minimal accuracy loss
↓ else
Memory 2 MB–32 MB? → Sparse BP via contribution analysis, plus gradient checkpointing for the remaining layers
↓ else (edge GPU / phone)
Memory 32 MB+? → PockEngine with sparse BP: fine-tune an LLM on Jetson/Apple Silicon at 4–8× speedup
↓ data privacy required?
Federated Learning → FedAvg: local training plus gradient compression; data never leaves the device
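
The decision tree above can also be written as a small lookup function. This is a minimal sketch: the name `choose_strategy`, the treatment of 1 MB as 1024 KB, the handling of exact boundary values, and reading the privacy question as an add-on to any branch are assumptions made for illustration.

```python
def choose_strategy(memory_kb, needs_privacy=False):
    """Map an activation-memory budget (in KB) to the technique in the tree.

    Boundary handling (for example exactly 2 MB) is an assumption here.
    """
    if memory_kb < 50:
        strategy = "Bias-only or Last-only"
    elif memory_kb < 2 * 1024:
        strategy = "TinyTL (bias + lite residual)"
    elif memory_kb < 32 * 1024:
        strategy = "Sparse BP + gradient checkpointing"
    else:
        strategy = "PockEngine with sparse BP"
    if needs_privacy:
        # Assumed to be orthogonal: federated learning wraps any local method.
        strategy += " + Federated Learning (FedAvg)"
    return strategy

print(choose_strategy(256))
print(choose_strategy(8 * 1024 * 1024, needs_privacy=True))
```

For a 256 KB budget the function returns the TinyTL branch; for an 8 GB edge GPU with a privacy requirement it returns the PockEngine branch with FedAvg appended.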

The Complete TinyML Efficiency Stack

This series traced a single arc: how do we bring state-of-the-art deep learning to devices with kilobyte budgets? Each lesson attacked one constraint. Here is the full stack, from model design down to on-device training:

| Lesson | Topic | Key Insight |
|---|---|---|
| L1–L2 | Pruning | 70–90% of weights are redundant; magnitude and structured pruning |
| L3 | Quantisation | INT8 is 4× smaller and 4× faster; scale calibration is the key problem |
| L4 | Low-Rank Decomposition | SVD + Tucker decomposition; LoRA as rank-constrained delta |
| L5 | Knowledge Distillation | Train small from large soft labels; intermediate feature matching |
| L6 | NAS: Efficiency | Search for architectures on the Pareto frontier of accuracy vs latency |
| L7 | NAS: Algorithms | DARTS, one-shot, hardware-aware search spaces |
| L8 | Transformers | Efficient attention: linear, sparse, MLA, FlashAttention IO complexity |
| L9 | LLM Efficiency | MQA, GQA, speculative decoding, MoE routing, quantised inference |
| L10 | MCU Inference (MCUNet) | Joint NAS over architecture + deployment schedule in 256 KB SRAM |
| L11 | Deployment Engines | TVM, MLIR, operator fusion, memory planning, runtime scheduling |
| L12 | Video Understanding | Temporal redundancy; 3D conv decomposition; token filtering |
| L13 | Point Clouds | Irregular 3D data; PointNet invariance; sparse voxel convolutions |
| L14 | GAN Compression | Discriminator pruning; progressive distillation; once-for-all GAN |
| L15 | Diffusion Compression | DDIM fewer steps; progressive distillation; latent diffusion |
| L16 | Distributed Training | Data/tensor/pipeline parallelism; ZeRO; 3D parallelism for 175B models |
| L17 | On-Device Training | Activation memory is the wall; bias-only/TinyTL/SparseBP/TTE break it |

Series complete. The TinyML series (Lessons 1–17) covers the complete spectrum of efficient deep learning: from removing weights and bits at the model level, to hardware-aware deployment on MCUs and edge GPUs, to the newest frontier — training on-device while keeping data private. The core principle that runs through every lesson: efficiency is a first-class constraint, not an afterthought. The best model is the one that fits your hardware while solving the task.
You need to personalise a face-recognition model on an MCU (256 KB SRAM). Which combination of techniques enables this?