Inference on a microcontroller is solved — but to personalise a keyword detector to YOUR voice, or a face-unlock model to YOUR face, the device must train. That requires storing every layer's activations for the backward pass. For MobileNetV2 at batch size 8, training memory is 452 MB — 226× larger than the 2 MB an MCU allows. This lesson derives the activation memory bottleneck from first principles, then shows every technique that breaks it: bias-only updates, TinyTL's lite-residual modules, sparse backpropagation, gradient checkpointing, quantized training with QAS, the Tiny Training Engine, and federated learning.
You have just deployed a keyword-spotting model to a hearing aid. It was trained on thousands of voices, and it works — most of the time. But the user has an unusual accent. They want the model to learn their specific pronunciation of "Hey Siri." The obvious solution: collect a few dozen examples on the device and fine-tune. Simple, right?
The obstacle is not computation — even a Cortex-M7 can run a few backward passes per second. The obstacle is memory. Specifically, the memory demanded by the backward pass. During inference, a forward pass through a network with L layers only needs to hold one layer's activations in memory at a time: you compute layer i, pass the output to layer i+1, and discard the intermediate. Peak inference memory is proportional to the largest single-layer activation — perhaps a few hundred KB for a well-designed MCU model.
Training is a completely different story. To compute the weight gradient at layer i during the backward pass, you need the input activation to layer i — which was computed during the forward pass. This means you must store all layer activations throughout the entire forward pass before a single backward step begins. For a network with L layers each producing an activation of size A bytes, training requires L × A bytes of activation memory — while inference requires only A bytes. This is the training wall.
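The L × A versus A comparison can be checked with a few lines of arithmetic. This is a minimal sketch under a simplifying assumption: a plain chain of layers where inference keeps only the largest activation alive and training keeps all of them. The function name and the 20-layer, 100 KB-per-layer network are hypothetical, chosen only for illustration.

```python
def peak_activation_bytes(layer_activation_bytes, training):
    """Peak activation memory for a plain feed-forward chain.

    Simplifying assumption: inference keeps only the largest single
    activation alive, while training keeps every layer's activation
    until the backward pass consumes it.
    """
    if training:
        return sum(layer_activation_bytes)
    return max(layer_activation_bytes)


# Hypothetical 20-layer network with 100 KB of activations per layer.
layers = [100 * 1024] * 20
inference = peak_activation_bytes(layers, training=False)
training = peak_activation_bytes(layers, training=True)
print(f"inference: {inference // 1024} KB, training: {training // 1024} KB")
print(f"ratio: {training // inference}x")
```

With equal-sized layers the ratio is exactly the layer count, which is why deeper networks hit the wall sooner.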
Let's put numbers on this. MobileNetV2 at input resolution 224×224: inference memory (batch size 1) ≈ 20 MB; training memory (batch size 8) ≈ 452 MB, a 22× jump. An STM32H7 MCU has 1 MB of SRAM. A Raspberry Pi has 256 MB of DRAM. Neither can hold 452 MB. Even an iPhone's 6 GB of RAM struggles with full-precision training of a large model. The MCU budget is 320 KB of activation memory and 1 MB of weight storage, so the 452 MB training footprint is roughly 1,400× larger than the activation budget the MCU allows.
What about cloud training plus device deployment — the approach we used in TinyML L10? That works for fixed tasks. But it requires data to leave the device. If the task is personalisation (your face, your voice, your writing patterns) or the data is sensitive (medical signals, private code, enterprise documents), uploading to the cloud is unacceptable. On-device training is not a nice-to-have: it is the only privacy-preserving path to personalised AI.
Toggle between inference and training. Drag the slider to change the number of layers. Each bar segment represents one layer's activation memory. Inference only needs to hold the current layer; training must hold ALL simultaneously.
The obvious workaround for training memory is to keep the data on-device but share gradients with a powerful cloud server for aggregation. The user trains locally for a few steps, uploads only the gradient tensor — not their raw images or audio — and the server merges updates from many devices. This is the promise of federated learning, introduced by McMahan et al. in 2016. The key claim: user data never leaves the device, so privacy is preserved.
In 2019, Zhu et al. (NeurIPS) showed this claim is dangerously wrong for naive implementations. Their attack — Deep Leakage from Gradients (DLG) — reconstructs the original training data (images, text, labels) from gradient tensors alone, with no access to the model's training data. The attack is startling: if a server sees your gradient, it can reconstruct the photo or text you used to compute it.
The DLG attack works by gradient matching. The attacker initialises random "dummy" inputs x′ and labels y′. They run a forward+backward pass with the same model to get dummy gradients ∇′. They then minimise the L2 distance between real gradients ∇ and dummy gradients ∇′ by gradient descent on x′ and y′. After a few hundred iterations, x′ converges to the original training data. The attacker never touched the data — only the gradients.
Defences against DLG have been proposed: adding Gaussian or Laplacian noise to gradients, gradient quantisation, and gradient compression (e.g., sending only the top 0.1% of gradient values by magnitude). Song Han's group showed that gradient compression at 99% sparsity effectively prevents reconstruction while preserving model accuracy — because highly sparse gradients leak far less information about the input structure. Without gradient compression, adding noise sufficient to prevent leakage destroys accuracy.
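Magnitude-based gradient compression can be sketched in a few lines. The helper below is an illustrative assumption, not the implementation used in the cited work: `sparsify_topk` and `keep_ratio` are hypothetical names, and the sketch shows only the sparsification step, not a privacy guarantee.

```python
import torch


def sparsify_topk(grad, keep_ratio=0.01):
    """Zero all but the largest-magnitude entries of a gradient tensor."""
    flat = grad.flatten()
    k = max(1, int(flat.numel() * keep_ratio))
    _, idx = torch.topk(flat.abs(), k)      # indices of the k largest |g|
    sparse = torch.zeros_like(flat)
    sparse[idx] = flat[idx]                 # keep the sign and value of survivors
    return sparse.view_as(grad)


torch.manual_seed(0)
g = torch.randn(64, 128)                    # stand-in for one layer's gradient
g_sparse = sparsify_topk(g, keep_ratio=0.01)
kept = int((g_sparse != 0).sum())
print(f"kept {kept} of {g.numel()} entries")
```

In practice the dropped entries are often accumulated locally and sent in a later round rather than discarded, which is part of how accuracy is preserved at high sparsity.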
The deep implication for on-device learning: keeping data on the device is the only truly safe option. Sharing even compressed gradients has risks. This motivates fully local training approaches — TinyTL, sparse backpropagation, and the Tiny Training Engine — where no gradient ever leaves the device.
```python
# DLG attack: reconstruct training data from gradients alone
# ~20 lines, 2400+ citations. From Zhu et al. NeurIPS 2019
import torch
import torch.nn.functional as F


def dlg_attack(model, true_grads, input_shape, n_classes, n_iters=300):
    """Recover training data from a gradient tensor alone.

    input_shape includes the batch dimension, e.g. (1, 3, 32, 32).
    """
    # Initialise random dummy data + label
    dummy_x = torch.randn(input_shape, requires_grad=True)
    dummy_y = torch.randn(input_shape[0], n_classes, requires_grad=True)
    optimizer = torch.optim.LBFGS([dummy_x, dummy_y])

    for _ in range(n_iters):
        def closure():
            optimizer.zero_grad()
            dummy_pred = model(dummy_x)
            # Soft-label cross-entropy (probability targets need PyTorch >= 1.10)
            dummy_loss = F.cross_entropy(dummy_pred, F.softmax(dummy_y, dim=-1))
            dummy_grads = torch.autograd.grad(
                dummy_loss, model.parameters(), create_graph=True)
            # Minimise distance between real grads and dummy grads
            grad_diff = sum(((dg - tg) ** 2).sum()
                            for dg, tg in zip(dummy_grads, true_grads))
            grad_diff.backward()
            return grad_diff

        optimizer.step(closure)

    return dummy_x.detach()  # ≈ original training image
```
On-device training almost always means transfer learning: start from a pretrained model (trained in the cloud on a large dataset) and adapt it to a new, small on-device dataset. This works because deep networks learn hierarchical features — early layers detect edges and textures, middle layers detect object parts, late layers detect semantics — and these general features transfer well across tasks. You only need to shift the final semantics, not relearn vision from scratch.
There are three classic modes. Feature extraction (Last-only): freeze all layers, train only the final classification head. Memory cost: tiny (only head gradients). Accuracy: limited — the frozen backbone cannot adjust its representation to the new domain. Partial fine-tuning (BN+Last): freeze conv weights, train BatchNorm scale/shift parameters plus the head. Cheap in parameters (13× fewer than full) but not cheap in activation memory (only 1.8× savings — BN params don't cause activations to be discarded). Full fine-tuning: train all layers. Best accuracy but maximum memory.
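The three modes differ only in which parameters have `requires_grad` set. The sketch below assumes the classifier head is available as a separate module; `set_transfer_mode`, the mode names, and the toy model are hypothetical and exist only to make the parameter counts concrete.

```python
import torch.nn as nn


def set_transfer_mode(model, head, mode):
    """Mark parameters trainable for 'last', 'bn_last', or 'full'.

    Returns the number of trainable parameters.
    """
    if mode not in {"last", "bn_last", "full"}:
        raise ValueError(f"unknown mode: {mode}")
    for p in model.parameters():
        p.requires_grad = (mode == "full")
    if mode == "bn_last":
        # Unfreeze BatchNorm scale/shift only
        for m in model.modules():
            if isinstance(m, nn.BatchNorm2d):
                for p in m.parameters():
                    p.requires_grad = True
    for p in head.parameters():          # the head is trained in every mode
        p.requires_grad = True
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# Toy backbone + head, for illustration only
head = nn.Linear(16, 10)
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU(),
    nn.Conv2d(8, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), head,
)
counts = {m: set_transfer_mode(model, head, m) for m in ("last", "bn_last", "full")}
print(counts)
```

Counting trainable parameters this way says nothing about activation memory, which depends on which layers the backward pass still has to traverse.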
The result for ResNet-50 on the Stanford Cars dataset: Full fine-tuning — 92% accuracy, 680 MB activation memory. BN+Last — 78% accuracy, 370 MB (1.8× savings). Last-only — 70% accuracy, 25 MB (27× savings). There is a stark tradeoff: big memory savings destroy accuracy; accuracy-preserving partial tuning barely saves memory. Neither is acceptable on a 2 MB MCU. We need a new axis: activation-memory-efficient fine-tuning without accuracy loss.
Click a training mode to highlight it. The bars show activation memory (left axis) and accuracy proxy (right axis). Notice that BN+Last saves parameters (13×) but barely saves memory (1.8×) — the parameter-efficiency trap. TinyTL breaks this pattern.
TinyTL (Cai et al., NeurIPS 2020) reframes the problem. Previous parameter-efficient transfer learning methods asked: "how do we reduce trainable parameters?" TinyTL asks: "how do we reduce activation memory?" The key insight is to train only biases while freezing all weight tensors. This sounds deceptively simple — let's derive why it actually works for memory.
For a linear layer $a_{i+1} = a_i W_i + b_i$, two gradient equations govern the backward pass. The weight gradient: $\partial L / \partial W_i = a_i^{\top} \cdot (\partial L / \partial a_{i+1})$. This requires storing $a_i$, the layer's input activation. The bias gradient: $\partial L / \partial b_i = \partial L / \partial a_{i+1}$. The bias gradient depends only on the downstream gradient, not on the stored activation. If we freeze $W_i$ and only train $b_i$, we never need to compute $\partial L / \partial W_i$, so we never need to store $a_i$. Activation memory for layer $i$ drops to zero: we only need the propagated gradient, not the activation that generated it.
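The two gradient formulas can be checked numerically against autograd. This is a small sketch with arbitrary shapes and an arbitrary squared-sum loss, both assumptions made only for the demonstration. With a batch of inputs, the bias gradient is the upstream gradient summed over the batch dimension.

```python
import torch

torch.manual_seed(0)
a = torch.randn(4, 5)                       # input activation a_i (batch of 4)
W = torch.randn(5, 3, requires_grad=True)   # weight W_i
b = torch.zeros(3, requires_grad=True)      # bias b_i

out = a @ W + b                 # a_{i+1} = a_i W_i + b_i
loss = out.pow(2).sum()         # arbitrary scalar loss for illustration
loss.backward()

upstream = (2 * out).detach()   # dL/da_{i+1} for this particular loss

# The weight gradient needs the stored input activation a_i ...
manual_grad_W = a.T @ upstream
# ... while the bias gradient uses only the upstream gradient.
manual_grad_b = upstream.sum(dim=0)

print(torch.allclose(W.grad, manual_grad_W, rtol=1e-4, atol=1e-4))
print(torch.allclose(b.grad, manual_grad_b, rtol=1e-4, atol=1e-4))
```

Note that `a` appears in the weight-gradient line and nowhere in the bias-gradient line, which is the whole memory argument.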
Empirically, bias-only training on ResNet-50 cuts activation memory by 12× compared to full fine-tuning. But there is a cost: accuracy on Stanford Cars drops from 92% (full) to 76% (bias-only) — a 16 percentage point gap. The model lacks the capacity to adapt its feature representation using only 1D bias shifts. The features are stuck as they were during pretraining.
TinyTL's solution to the accuracy gap: add a lightweight lite residual module to each inverted bottleneck block. The lite residual is a small, newly initialised branch that runs in parallel with the frozen backbone. Three design decisions keep its activation cost minimal. First, reduce the spatial resolution: the lite residual downsamples the input by 2× in each spatial dimension before processing, cutting the spatial activation size to one quarter. Second, avoid the inverted bottleneck: standard MobileNet blocks expand channels 6× (hence "inverted bottleneck"), which produces the largest activation in the block. The lite residual uses only C channels (not 6C). Third, reduce depth: use 2/3 the block depth of the main branch. Combined, the lite residual's activation size is roughly 1/6 (channels) × 1/4 (spatial, from the 2× downsample) × 2/3 (depth) = 1/36, or approximately 3% of the main branch's activation cost.
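The combined saving is a product of three ratios, which is easy to verify exactly. The variable names below are illustrative; the three factors are the ones stated in the paragraph above.

```python
from fractions import Fraction

channel_ratio = Fraction(1, 6)        # C channels instead of the 6C expansion
spatial_ratio = Fraction(1, 2) ** 2   # 2x downsample in both height and width
depth_ratio = Fraction(2, 3)          # 2/3 of the main branch's depth

relative_cost = channel_ratio * spatial_ratio * depth_ratio
print(relative_cost, f"= {float(relative_cost):.1%}")   # 1/36 = 2.8%
```

The product comes to a few percent of the main branch's activation cost, so the added branch is close to free in memory terms.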
```python
# TinyTL: freeze backbone weights, train biases + lite residual
import torch.nn as nn


class LiteResidual(nn.Module):
    """Low-activation branch: downsample → group conv → upsample."""

    def __init__(self, C, stride=1):
        super().__init__()
        # Downsample to 0.5× spatial resolution: key activation saving
        self.down = nn.AvgPool2d(kernel_size=2, stride=2)
        # Group conv at reduced channels (no inverted bottleneck)
        self.conv = nn.Conv2d(C, C, kernel_size=3, padding=1,
                              groups=C // 4)  # group conv, cheap
        self.bn = nn.BatchNorm2d(C)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear')

    def forward(self, x):
        return self.up(self.bn(self.conv(self.down(x))))


def prepare_tinytl(model):
    """Freeze weights; leave biases + lite residuals trainable."""
    for name, param in model.named_parameters():
        if 'weight' in name and 'lite_residual' not in name:
            param.requires_grad = False  # freeze backbone weights
        else:
            param.requires_grad = True   # train biases + lite residual
    return model
```
The result: TinyTL reaches 90.1% on Stanford Cars (versus 92.2% for full fine-tuning and 76.0% for bias-only) while using only 12% of the activation memory of full fine-tuning. That is roughly an 8× reduction, far beyond the 1.8× that BN+Last achieves. Furthermore, at batch size 1 with group normalisation replacing batch normalisation (needed because BN statistics are too noisy at batch size 1), TinyTL fits within 16 MB, the size of a typical L3 CPU cache. This enables training inside the cache, which is dramatically more energy-efficient than training in DRAM.
TinyTL freezes all backbone weights and trains only biases. This is a form of sparse update — we update a strict subset of parameters. Sparse backpropagation (SparseBP) generalises this idea: rather than a fixed bias-only policy, use contribution analysis to find which layers, and which channel fractions, are most worth updating for a given downstream task — then skip the backward pass (and thus the activation storage) for all other layers.
The insight: not all layers contribute equally when adapting to a new task. Contribution analysis measures this directly. Fine-tune only one layer at a time on the target task, measure the accuracy improvement Δacc, and rank layers by their contribution. For MobileNetV2, the first point-wise conv in each block contributes most to visual adaptation tasks. For BERT, the QKV projection layers and the first FFN layer contribute most. Different models have different optimal sparse update patterns.
SparseBP uses evolutionary search to find the optimal sparse configuration. The search space is combinatorial — for each of L layers, choose: update all, update 1/2 channels, update 1/4 channels, or skip. Random search works poorly in this space; evolutionary search converges much faster. The search evaluates each candidate configuration by running a few training steps on a validation split and measuring accuracy. After finding the optimal config, it is fixed for all on-device fine-tuning steps.
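The search loop can be sketched with a toy fitness function. Everything numeric here is a stated assumption: the per-layer `contribution` scores, the `act_cost` values, and the additive fitness stand in for "train a few steps and measure validation accuracy", which is what a real search would do. All function names are hypothetical.

```python
import random

CHOICES = (0.0, 0.25, 0.5, 1.0)  # skip, 1/4 channels, 1/2 channels, full update


def memory_cost(config, act_cost):
    """Activation memory of a config, assuming cost scales with channel ratio."""
    return sum(r * c for r, c in zip(config, act_cost))


def fitness(config, contribution, act_cost, budget):
    """Toy stand-in for accuracy after a few training steps."""
    if memory_cost(config, act_cost) > budget:
        return float("-inf")  # infeasible: exceeds the activation budget
    return sum(r * g for r, g in zip(config, contribution))


def mutate(config, rng, p=0.2):
    """Resample each layer's choice with probability p."""
    return tuple(rng.choice(CHOICES) if rng.random() < p else r for r in config)


def evolutionary_search(contribution, act_cost, budget, pop=32, gens=30, seed=0):
    rng = random.Random(seed)
    n = len(contribution)

    def score(config):
        return fitness(config, contribution, act_cost, budget)

    # Seed with the all-skip config so at least one candidate is feasible
    population = [tuple(0.0 for _ in range(n))]
    population += [tuple(rng.choice(CHOICES) for _ in range(n))
                   for _ in range(pop - 1)]
    for _ in range(gens):
        parents = sorted(population, key=score, reverse=True)[: pop // 4]
        children = [mutate(rng.choice(parents), rng)
                    for _ in range(pop - len(parents))]
        population = parents + children   # elitism: parents survive unchanged
    return max(population, key=score)


# Hypothetical 8-layer network: later layers contribute more and cost less
contribution = [0.1, 0.3, 0.2, 0.9, 0.4, 1.2, 0.8, 1.5]
act_cost = [64, 64, 32, 32, 16, 16, 8, 8]
best = evolutionary_search(contribution, act_cost, budget=60)
print(best, memory_cost(best, act_cost))
```

Because the search runs once offline, its cost does not count against the device; only the resulting fixed configuration is deployed.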
```python
# Sparse backpropagation: only a fraction of output channels gets weight gradients
import torch
import torch.nn as nn
import torch.nn.functional as F


class SparseBackwardLayer(nn.Module):
    """Conv layer with optional sparse channel backward (assumes groups == 1)."""

    def __init__(self, in_c, out_c, channel_ratio=1.0, **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(in_c, out_c, **kwargs)
        self.ratio = channel_ratio

    def _conv(self, x, weight, bias):
        # Reuse the wrapped layer's hyperparameters for the split convolutions
        c = self.conv
        return F.conv2d(x, weight, bias, c.stride, c.padding, c.dilation)

    def forward(self, x):
        w, b = self.conv.weight, self.conv.bias
        if not self.training or self.ratio == 0.0:
            # Skipped layer: detached parameters receive no gradient, but the
            # gradient w.r.t. x can still flow to earlier trainable layers
            return self._conv(x, w.detach(), None if b is None else b.detach())
        if self.ratio < 1.0:
            # Only keep weight gradients for a fraction of output channels
            n_active = int(self.conv.out_channels * self.ratio)
            out_active = self._conv(x, w[:n_active],
                                    None if b is None else b[:n_active])
            # Frozen channels: detach the parameters rather than using no_grad,
            # so the input gradient stays correct
            out_frozen = self._conv(x, w[n_active:].detach(),
                                    None if b is None else b[n_active:].detach())
            return torch.cat([out_active, out_frozen], dim=1)
        return self.conv(x)  # full backward
```
On benchmarks — MCUNet, MobileNetV2, ResNet-50, DistilBERT, BERT — SparseBP matches full backpropagation accuracy within ~0.5% while using 4.5–7.5× less activation memory. For LLaMA-2-7B fine-tuning on the Stanford Alpaca dataset using the PockEngine framework (Chapter 7), SparseBP achieves comparable Alpaca-eval win rates to LoRA at rank=8, with 2× faster iteration latency and 14 GB less GPU memory.
A 10-layer network. Toggle each layer's update mode: full update (stores activation, high cost), bias-only (no activation stored), or skip (no backward). Watch total activation memory change. The contribution bars show which layers matter most for a sample visual task.
TinyTL and SparseBP reduce activation memory by restricting which weights are trained. But what if you genuinely need to train all weights of a deep network? Gradient checkpointing (also called activation rematerialisation) offers a different trade: instead of reducing which activations are stored, it reduces how many are stored simultaneously — at the cost of extra computation.
The standard forward pass stores all L activations. The standard backward pass uses each activation exactly once to compute the weight gradient for that layer. Gradient checkpointing instead stores only a subset of C "checkpoint" activations during the forward pass, discarding the rest. When the backward pass reaches a layer whose activation was discarded, it recomputes that activation by running the forward pass again from the nearest checkpoint. You pay for the discarded activations with extra compute, but you save memory.
```python
# Gradient checkpointing with PyTorch's built-in API
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp


class CheckpointedResBlock(nn.Module):
    def __init__(self, block):
        super().__init__()
        self.block = block

    def forward(self, x):
        # checkpoint() re-runs block in backward pass instead of storing activations
        return cp.checkpoint(self.block, x, use_reentrant=False)


# Manual segment checkpointing: place a checkpoint every k layers.
# Forward half only; the backward pass must re-run each segment from its checkpoint.
def forward_with_checkpoints(layers, x, k=4):
    """Run forward pass, storing only 1 in k activations."""
    checkpoints = []
    for i, layer in enumerate(layers):
        if i % k == 0:
            # Store this activation as a checkpoint
            checkpoints.append(x.detach().requires_grad_(True))
            x = layer(checkpoints[-1])
        else:
            with torch.no_grad():  # no grad graph for non-checkpoint layers
                x = layer(x)       # activation discarded after use
    return x, checkpoints

# Memory: C + L/C stored activations (optimal C = sqrt(L))
# Compute: ~2× forward passes (one forward + one per segment in backward)
```
Gradient checkpointing does not help with the weight memory problem — you still need all model parameters. But it directly attacks the activation memory problem. Combining checkpointing with bias-only training or SparseBP can further reduce memory: skip backpropagation through early layers (no activation storage needed) and checkpoint the remaining layers that are trained.
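The memory side of the trade can be tabulated directly. The sketch assumes equal-sized activations and evenly spaced checkpoints, and counts memory in units of one layer's activation: the stored checkpoints plus the largest segment that is recomputed during the backward pass. The function name is hypothetical.

```python
import math


def checkpoint_memory(num_layers, num_checkpoints):
    """Stored activations, in units of one layer's activation.

    Assumes equal-size activations and evenly spaced checkpoints:
    C checkpoints plus the largest recomputed segment of ceil(L / C) layers.
    """
    segment = math.ceil(num_layers / num_checkpoints)
    return num_checkpoints + segment


L = 64
best_c = min(range(1, L + 1), key=lambda c: checkpoint_memory(L, c))
print(f"best C = {best_c}, memory = {checkpoint_memory(L, best_c)} activations")
print(f"sqrt(L) = {math.isqrt(L)}")
```

For this 64-layer example the minimum lands at C = 8, matching √L, and the stored footprint falls from 64 activations to 16.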
Drag the slider to set the number of checkpoints C (from 1 to L). The left bar shows activation memory (stored checkpoints + largest recomputed segment). The right bar shows extra compute (recomputation cost). The optimal is at C = √L. An MCU memory limit line shows which configurations fit.
All the techniques so far operate at FP32. But many MCUs have no floating-point unit, and even those that do (such as the Cortex-M7) run INT8 or INT16 kernels far more efficiently. To run training on an MCU, we need quantised training: keep the entire training computation in low-precision integers. This would also attack the memory problem more radically: INT8 activations are 4× smaller than FP32, and so are INT8 weights.
The naive approach is Quantisation-Aware Training (QAT), also called fake quantisation: maintain FP32 weights and activations throughout training, but insert "fake quantise" operations that round values to the nearest integer grid and back. This preserves the FP32 computation graph (backward passes work normally) but does not actually save memory — the tensors are still FP32 internally. Fake quantisation is used for preparing a model for inference deployment, not for reducing training memory.
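A fake-quantise operation is short enough to write out. The sketch below assumes symmetric per-tensor quantisation and uses a straight-through estimator for the backward pass, a common choice that the text above does not specify; `fake_quantise` is a hypothetical name. The output stays FP32, which is why this approach saves no training memory.

```python
import torch


def fake_quantise(x, num_bits=8):
    """Round x onto a symmetric integer grid and map it back to float."""
    qmax = 2 ** (num_bits - 1) - 1                       # 127 for INT8
    scale = x.detach().abs().max().clamp(min=1e-8) / qmax
    q = torch.clamp(torch.round(x / scale), -qmax, qmax) * scale
    # Straight-through estimator: forward returns q, backward treats
    # the rounding step as the identity function
    return x + (q - x).detach()


x = torch.linspace(-1.0, 1.0, steps=11, requires_grad=True)
y = fake_quantise(x)
y.sum().backward()

print(y.dtype, y.element_size(), "bytes per element")    # still FP32
print("max rounding error:", (y - x).abs().max().item())
```

Each element still occupies 4 bytes after the operation, so the tensor simulates INT8 precision without gaining INT8 storage.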
The core problem of real quantised training is a scale mismatch between weights and gradients. In FP32 training, Adam adaptively scales gradients to match the weight update magnitude, so the ratio $\lVert W \rVert / \lVert G \rVert$ stays roughly consistent across layers. In INT8 training, quantising weights to INT8 changes their effective scale by a factor $S_W$ (the weight quantisation scale), but the gradient scale doesn't automatically adjust. Empirically, $\log_{10}(\lVert W \rVert / \lVert G \rVert)$ spreads from +5 to −5 across layers in INT8 training versus a tight band around 0 in FP32. This scale mismatch causes catastrophic oscillations and non-convergence: INT8 SGD on ResNet-50 drops accuracy from 75.4% (FP32) to 64.8%.
Quantisation-Aware Scaling (QAS), proposed by Han Lab, fixes this by explicitly re-scaling the gradients to compensate for the quantisation scale shift. For a weight tensor with quantisation scale $S_W$ (one value per channel, not a single scalar), the weight gradient $\partial L / \partial W$ is multiplied by $S_W^{-2}$ before the weight update. This compensates for the fact that quantising $W$ by $S_W$ changes the effective loss curvature by $S_W^{2}$. After QAS correction, the W/G ratio plot for INT8+QAS closely matches FP32, and convergence is recovered: ResNet-50 INT8+QAS achieves 74.1% (versus 75.4% FP32), a gap of only 1.3 percentage points versus the catastrophic 10.6-point drop without QAS.
```python
# Quantisation-Aware Scaling (QAS): correct gradient scale mismatch
import torch
import torch.nn as nn
import torch.nn.functional as F


class QASLayer(nn.Module):
    """Conv with INT8-emulated weights and per-channel gradient rescaling for QAS."""

    def __init__(self, in_c, out_c, kernel_size=3):
        super().__init__()
        # Weights stored as FP32 internally; quantised to the INT8 grid during forward
        self.weight = nn.Parameter(torch.randn(out_c, in_c, kernel_size, kernel_size))
        self.padding = kernel_size // 2

    def weight_scale(self):
        # Per-channel scale: sw = max(|W|) / 127 for each output channel
        sw = self.weight.detach().abs().amax(dim=(1, 2, 3), keepdim=True) / 127
        return sw.clamp(min=1e-8)

    def quantise_weight(self):
        sw = self.weight_scale()
        w_q = (self.weight / sw).round().clamp(-128, 127) * sw
        # Straight-through estimator: round() has zero gradient almost everywhere,
        # so pass the gradient through to the FP32 weight unchanged
        return self.weight + (w_q - self.weight).detach(), sw

    def forward(self, x):
        w_q, _ = self.quantise_weight()
        return F.conv2d(x, w_q, padding=self.padding)

    def qas_correct_grads(self):
        """Apply sw^-2 correction to gradient; fixes scale mismatch."""
        # Emulation only: in true INT8 training the gradient is taken with
        # respect to the integer weight W / sw, which is where the mismatch arises
        if self.weight.grad is not None:
            with torch.no_grad():
                self.weight.grad /= self.weight_scale() ** 2  # QAS correction
```
Algorithms are only half the story. Even after choosing sparse updates, bias-only gradients, and INT8 activations, a standard deep learning framework like PyTorch generates a dense computation graph for the backward pass. It doesn't know that most layers are frozen — it preallocates gradient buffers for all parameters, allocates activations for all layers, and runs backward through every node in the autograd graph. To deploy on-device training to a microcontroller, we need a system co-designed with the sparse training algorithm.
The Tiny Training Engine (TTE) (Lin et al., NeurIPS 2022 — "On-Device Training Under 256 KB Memory") is exactly this co-design. TTE operates at compile time rather than runtime. Given a neural network and a sparse update specification (which layers, which parameter types are trainable), TTE performs compile-time automatic differentiation: it traces the computation graph, identifies which activations are actually needed for the specified gradients, prunes the backward graph to eliminate unnecessary operations, and generates a minimal C++ kernel that runs the entire backward pass without any framework overhead.
TTE achieves on-device training of a VWW (Visual Wake Words) model on an ARM Cortex-M7 MCU with 256 KB SRAM and 1 MB Flash. The resulting configuration: 158 KB activation memory (below 256 KB SRAM), 1.36 MB weights (fits in Flash), 5.6% accuracy improvement on a new task over inference-only deployment. Training throughput: approximately 5 forward-backward passes per second.
PockEngine extends TTE to mobile and edge GPU platforms (Jetson Orin, Apple M1). For LLaMA-2-7B fine-tuning on Jetson Orin, PockEngine achieves 1.8 seconds per iteration versus PyTorch's 7.7 seconds (4.3× speedup) for full fine-tuning, and 0.9 seconds with sparse backpropagation (8.6× speedup). The key system techniques: (1) prune the backward autograd graph at compile time based on the sparse update mask; (2) reorder operators to maximise data locality and minimise cache misses; (3) fuse adjacent sparse conv backward operations; (4) generate device-specific Metal (Apple) or CUDA (NVIDIA) kernels from the pruned graph.
```python
# TTE-style compile-time graph pruning (simplified pseudocode).
# trace_computation_graph, ancestors and codegen_backward_kernel are
# placeholders for compiler passes, not real library functions.
def prune_backward_graph(model, trainable_params):
    """Remove backward nodes not needed for trainable_params gradients."""
    # Step 1: Build full autograd computation graph
    full_graph = trace_computation_graph(model)
    # Step 2: Mark leaf nodes = trainable parameter nodes
    targets = {id(p) for p in trainable_params}
    # Step 3: Find the target nodes and all their ancestors (these are NEEDED)
    needed = set()
    for node in full_graph:
        if id(node.param) in targets:
            needed.add(node)
            needed |= ancestors(node, full_graph)
    # Step 4: Generate pruned C++ backward kernel
    pruned_graph = [n for n in full_graph if n in needed]
    return codegen_backward_kernel(pruned_graph)

# Result: only activations needed by trainable-param ancestors are stored
# All other activations: never allocated
```
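For a plain chain of layers, the pruning decision reduces to two rules that fit in a runnable toy. This is a sketch under strong assumptions, not TTE's actual compiler pass: it ignores nonlinearities and branching, and `plan_backward` and the layer names are hypothetical.

```python
def plan_backward(layers, trainable):
    """Plan a pruned backward pass for a simple chain of layers.

    layers: layer names in forward order.
    trainable: set of (layer, kind) pairs, where kind is "weight" or "bias".
    Returns (layers the backward pass must visit, in backward order,
             layers whose input activation must be stored).
    """
    touched = [i for i, name in enumerate(layers)
               if (name, "weight") in trainable or (name, "bias") in trainable]
    if not touched:
        return [], []
    earliest = min(touched)
    # Rule 1: gradients flow from the loss back to the earliest trainable layer
    visit = layers[earliest:][::-1]
    # Rule 2: only weight gradients need the layer's stored input activation
    store = [name for name in layers[earliest:] if (name, "weight") in trainable]
    return visit, store


layers = ["conv1", "conv2", "conv3", "conv4", "fc"]
trainable = {("conv3", "bias"), ("conv4", "weight"), ("fc", "weight"), ("fc", "bias")}
visit, store = plan_backward(layers, trainable)
print("backward visits:", visit)        # ['fc', 'conv4', 'conv3']
print("activations stored:", store)     # ['conv4', 'fc']
```

Here `conv1` and `conv2` never appear in the backward plan, and `conv3` is visited without storing its input because only its bias is trained.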
This is the interactive payoff. You're the system designer deploying a keyword-spotting network to an MCU. Configure every axis of the training problem: how many layers to train (full vs sparse), which parameter types (weights vs biases), how many gradient checkpoints, and whether to use federated aggregation across multiple devices. Watch activation memory, model capacity, and training quality update in real time.
A 12-layer CNN on an MCU with 320 KB activation budget (red line). Configure your training strategy. The canvas shows: activation memory breakdown (by layer), total memory bar versus MCU budget, trainable parameter count, and an accuracy-proxy estimate. Find configurations that fit the MCU while preserving accuracy.
N devices each hold private data. Each trains locally for E epochs. Only model weights (or compressed gradients) are sent to the server for FedAvg aggregation. Watch the global model accuracy improve over rounds. The data dots never leave their device columns.
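The server-side aggregation step of FedAvg is a weighted average of client parameters, weighted by how much data each client holds. The sketch below uses plain Python lists as stand-ins for parameter vectors; the client values and dataset sizes are made up for illustration.

```python
def fedavg(client_weights, client_sizes):
    """Weighted average of client parameter vectors (FedAvg aggregation step)."""
    total = sum(client_sizes)
    dim = len(client_weights[0])
    return [
        sum(w[j] * n for w, n in zip(client_weights, client_sizes)) / total
        for j in range(dim)
    ]


# Three hypothetical devices with 10, 30 and 60 local examples
clients = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
sizes = [10, 30, 60]
global_w = fedavg(clients, sizes)
print(global_w)   # [4.0, 5.0]
```

The device with the most data pulls the global model furthest toward its own parameters, which is the intended behaviour when client datasets differ in size.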
Lesson 17 closes the TinyML series. Let's consolidate the complete on-device training toolkit, then step back to see how this lesson connects to every other concept in the series.
| Technique | Memory Saving | Accuracy Cost | Compute Cost | When to Use |
|---|---|---|---|---|
| Last-only | 27× activation | High (−22 pts) | Minimal | Extreme memory budget; linear head only |
| BN+Last | 1.8× activation | High (−14 pts) | Minimal | When parameter count matters more than memory |
| Bias-only (TinyTL) | 12× activation | High (−16 pts) | Same fwd; lighter bwd | Privacy-critical edge devices; no lite residual capacity |
| TinyTL full | 6–8× activation | Small (−2 pts) | +5% for lite residual fwd | MCU with 2–16 MB SRAM; best accuracy-per-byte |
| Sparse BP | 4.5–7.5× activation | Minimal (−0.5%) | +search cost once | Heterogeneous model families; LLM fine-tuning on edge GPU |
| Grad. Checkpoint | ~√L/2× activation | Zero | ~1 extra forward pass (≈ +33% FLOPs) | Full fine-tuning; sufficient compute; insufficient memory |
| INT8 + QAS | 4× (all tensors) | ~1% vs FP32 | Hardware-dependent | MCU with INT8 ALU; complement to all above methods |
| TTE/PockEngine | Graph pruning | Zero (system only) | 3–4× faster | Production deployment; compile-time fixed update spec |
This series traced a single arc: how do we bring state-of-the-art deep learning to devices with kilobyte budgets? Each lesson attacked one constraint. Here is the full stack, from model design down to on-device training:
| Lesson | Topic | Key Insight |
|---|---|---|
| L1–L2 | Pruning | 70–90% of weights are redundant; magnitude and structured pruning |
| L3 | Quantisation | INT8 is 4× smaller and 4× faster; scale calibration is the key problem |
| L4 | Low-Rank Decomposition | SVD + Tucker decomposition; LoRA as rank-constrained delta |
| L5 | Knowledge Distillation | Train small from large soft labels; intermediate feature matching |
| L6 | NAS: Efficiency | Search for architectures on the Pareto frontier of accuracy vs latency |
| L7 | NAS: Algorithms | DARTS, one-shot, hardware-aware search spaces |
| L8 | Transformers | Efficient attention: linear, sparse, MLA, FlashAttention IO complexity |
| L9 | LLM Efficiency | MQA, GQA, speculative decoding, MoE routing, quantised inference |
| L10 | MCU Inference (MCUNet) | Joint NAS over architecture + deployment schedule in 256 KB SRAM |
| L11 | Deployment Engines | TVM, MLIR, operator fusion, memory planning, runtime scheduling |
| L12 | Video Understanding | Temporal redundancy; 3D conv decomposition; token filtering |
| L13 | Point Clouds | Irregular 3D data; PointNet invariance; sparse voxel convolutions |
| L14 | GAN Compression | Discriminator pruning; progressive distillation; once-for-all GAN |
| L15 | Diffusion Compression | DDIM fewer steps; progressive distillation; latent diffusion |
| L16 | Distributed Training | Data/tensor/pipeline parallelism; ZeRO; 3D parallelism for 175B models |
| L17 | On-Device Training | Activation memory is the wall; bias-only/TinyTL/SparseBP/TTE break it |