Tensor parallelism maxes out at 8 GPUs per node — yet frontier runs use thousands. This lesson derives how to go further: split the model across nodes with pipeline parallelism (micro-batches, the bubble, 1F1B schedules), shard every byte with FSDP/ZeRO-3 (all-gather on-the-fly, memory ↔ comm tradeoff), add sequence parallelism for activations, and learn the proven recipe for combining all three into a real 3D run on 512+ GPUs.
You learned in Lec 7 that tensor parallelism (TP) splits weight matrices across GPUs, reducing memory by the TP degree. It works brilliantly — inside one node. An NVIDIA DGX A100 node has 8 GPUs connected by NVLink at 600 GB/s. Tensor parallel all-reduces take ~0.05 ms. You barely notice them.
Now scale to a real training run: 512 GPUs across 64 nodes. TP saturates at 8 — you can't extend it beyond the NVLink boundary without paying 24× the communication cost. But the model still doesn't fit on 8 GPUs if it's large enough, and 8 GPUs gives you only 8× compute speedup regardless. You need two more tools.
Let's make this concrete. Suppose you're training a 70B-parameter model on 512 GPUs across 64 nodes (8 GPUs each). Here's the capacity problem: 70B × 2 bytes (bf16) = 140 GB for weights alone, and one A100 has 80 GB. With TP=8 each GPU holds 140/8 = 17.5 GB of weights, which fits. But the fp32 optimizer state (master weights plus two Adam moments) is 3 × 4 bytes × 70B = 840 GB in total. Divided by 8, that is still 105 GB per GPU, which does not fit. Now add compute scaling: 512 GPUs / 8 = 64 TP groups, each a full replica of the model. How do you coordinate the 64 replicas? Plain data parallelism all-reduces the full 70B-parameter gradient across all 64 groups every step, and every replica still carries 105 GB of optimizer state per GPU. FSDP attacks that memory burden by sharding optimizer state, gradients, and parameters across the replicas.
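The capacity arithmetic above is easy to script. This is a minimal sketch, assuming decimal gigabytes (1 GB = 10^9 bytes) and the byte counts used in the text: 2 bytes per bf16 weight and 12 bytes per parameter of fp32 optimizer state. The helper name `per_gpu_gb` is hypothetical.

```python
A100_GB = 80  # HBM capacity of one A100 80 GB

def per_gpu_gb(num_params: float, bytes_per_param: int, shards: int) -> float:
    """Memory in decimal GB for one kind of state split evenly over `shards` GPUs."""
    return num_params * bytes_per_param / shards / 1e9

N = 70e9  # parameters in the running example
weights_tp8 = per_gpu_gb(N, 2, shards=8)   # bf16 weights under TP=8
opt_tp8 = per_gpu_gb(N, 3 * 4, shards=8)   # fp32 master + Adam m + Adam v under TP=8

print(f"weights/GPU: {weights_tp8:.1f} GB, fits: {weights_tp8 <= A100_GB}")
print(f"optimizer/GPU: {opt_tp8:.1f} GB, fits: {opt_tp8 <= A100_GB}")
```

Running it prints 17.5 GB (fits) for the weights and 105.0 GB (does not fit) for the optimizer state.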
And pipeline parallelism? Imagine a 70B model with 80 transformer layers. Instead of TP=8 struggling to fit them, assign layers 0–9 to stage 0 (one node), layers 10–19 to stage 1 (next node), etc. Each stage only stores 1/8 of the parameters. The data flows like an assembly line — a micro-batch enters stage 0, stage 0 sends activations to stage 1, and so on. This is the third axis of 3D parallelism. But assembly lines have idle time. Deriving and minimizing that idle time is this chapter's core problem.
The idea is elegant: a 96-layer transformer doesn't need all 96 layers on the same GPU. Partition the L layers (here L = 96) into P equal groups called pipeline stages. Stage 0 holds layers 0 through L/P-1, stage 1 holds layers L/P through 2L/P-1, and so on. Each stage lives on one GPU (or one group of GPUs if combined with TP). The forward pass is literally a pipeline: stage 0 computes its layers, emits the activation tensor, and hands it to stage 1 via a point-to-point send.
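The contiguous partition can be written as a small helper. This sketch assumes L divides evenly by P; `stage_layers` is a hypothetical name.

```python
def stage_layers(num_layers: int, num_stages: int, stage: int) -> range:
    """Layers owned by pipeline `stage`: stage*L/P through (stage+1)*L/P - 1."""
    if num_layers % num_stages != 0:
        raise ValueError("this sketch assumes L is divisible by P")
    per_stage = num_layers // num_stages
    return range(stage * per_stage, (stage + 1) * per_stage)

# A 96-layer model on P = 8 stages: 12 layers per stage.
for stage in range(8):
    layers = stage_layers(96, 8, stage)
    print(f"stage {stage}: layers {layers.start}-{layers.stop - 1}")
```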
Point-to-point communication is the key difference. Tensor parallelism's all-reduce involves every GPU in the group; pipeline parallelism only ever sends activations between adjacent stages. The communication volume is the activation tensor at the stage boundary: for a batch of B sequences of length s with hidden dimension h, that is B × s × h × 2 bytes (bf16). For B=1, s=4096, h=8192: 4096 × 8192 × 2 bytes = 64 MiB (about 67 MB) per boundary. On InfiniBand at 25 GB/s, that's roughly 67 MB / 25 GB/s ≈ 2.7 ms. Compare that with TP: a single TP=8 all-reduce over NVLink is roughly 6× cheaper, but TP needs one in every single layer, whereas the pipeline needs only one send/receive per stage boundary.
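Here is a sketch of the point-to-point cost. It assumes bf16 activations and an idealized link that reaches its full 25 GB/s with no latency, so real transfers will be somewhat slower. The helper names are hypothetical.

```python
def boundary_bytes(B: int, s: int, h: int, bytes_per_elem: int = 2) -> int:
    """Size of the activation tensor sent across one pipeline-stage boundary."""
    return B * s * h * bytes_per_elem

def transfer_ms(num_bytes: int, bandwidth_gb_s: float) -> float:
    """Idealized transfer time: bytes / bandwidth, ignoring latency and overlap."""
    return num_bytes / (bandwidth_gb_s * 1e9) * 1e3

act = boundary_bytes(B=1, s=4096, h=8192)       # 67,108,864 bytes = 64 MiB
print(act, f"{transfer_ms(act, 25):.2f} ms")     # InfiniBand at 25 GB/s
```

This prints `67108864 2.68 ms`.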
The backward pass is more subtle. Backprop through a pipeline stage needs (1) the gradient flowing backward from the next stage and (2) the activations saved during that stage's forward pass. The backward pass reverses the pipeline: stage P-1 computes first (it has the loss), sends gradient tensors back to stage P-2, and so on. Each stage must keep its forward activations in memory until the matching backward pass arrives. This storage cost is the other price of pipeline parallelism, and how the micro-batches are scheduled (Chapter 3) determines how large it gets.
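Minimal 2-stage pipeline (CS336 lecture_08.py), forward pass only:

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F

def get_init_params(num_inputs: int, num_outputs: int, rank: int) -> torch.Tensor:
    # Stand-in for the lecture's init helper (assumption): scaled Gaussian weights
    return torch.randn(num_inputs, num_outputs) / num_inputs ** 0.5

def pipeline_parallelism_main(rank, world_size, data, num_layers, num_micro_batches):
    # Assumes dist.init_process_group(...) has already run on every rank.
    batch_size, num_dim = data.shape
    local_num_layers = num_layers // world_size  # e.g. 4 layers / 2 GPUs = 2 per GPU
    local_params = [get_init_params(num_dim, num_dim, rank) for _ in range(local_num_layers)]
    micro_batch_size = batch_size // num_micro_batches
    if rank == 0:  # the first stage owns the input data
        micro_batches = data.chunk(chunks=num_micro_batches, dim=0)
    else:          # later stages allocate buffers to receive activations into
        micro_batches = [torch.empty(micro_batch_size, num_dim) for _ in range(num_micro_batches)]
    for x in micro_batches:
        if rank - 1 >= 0:
            dist.recv(tensor=x, src=rank - 1)  # receive from previous stage
        for param in local_params:
            x = F.gelu(x @ param)  # compute this stage's layers
        if rank + 1 < world_size:
            dist.send(tensor=x, dst=rank + 1)  # send activations to next stage
```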
Here's the problem with a naive pipeline schedule. You have P=4 stages and one mini-batch to process. Stage 0 starts immediately. But stage 1 can't start until stage 0 finishes its forward pass and sends the activation. Stage 2 must wait for stage 1. Stage 3 waits for stage 2. In the forward direction, stages 0, 1, 2, 3 start at times t=0, 1, 2, 3 (in units of one forward-pass time).
Then comes the backward pass. Assume for simplicity that a backward pass takes the same one unit as a forward pass. Stage 3 computes its backward pass first (it has the loss from the forward pass), while stages 0, 1, 2 sit idle again, waiting for the gradient to propagate back to them. They finish the backward pass in reverse order: stage 3 at t=5, stage 2 at t=6, stage 1 at t=7, stage 0 at t=8. The total time is 8 units. Ideal time (if all stages ran in parallel) is 2 units (1 forward + 1 backward). The bubble costs us P-1 = 3 forward-pass-equivalents of startup and P-1 = 3 of draining.
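The timeline can be checked with a tiny event simulator. This sketch uses the same simplifications as the text: unit forward and backward times per stage, instantaneous sends, and a GPipe-style flush in which every forward finishes before any backward starts. It already accepts m micro-batches, although the timeline above uses m = 1. The function names are hypothetical.

```python
def gpipe_makespan(P: int, m: int, t_f: float = 1.0, t_b: float = 1.0) -> float:
    """End time of the last backward pass when all m forwards flow through the
    P stages first, then all m backwards flow back (naive / GPipe-style schedule)."""
    fwd_end = [[0.0] * m for _ in range(P)]
    for s in range(P):
        for i in range(m):
            ready = max(fwd_end[s - 1][i] if s > 0 else 0.0,   # activation from stage s-1
                        fwd_end[s][i - 1] if i > 0 else 0.0)   # this stage is free
            fwd_end[s][i] = ready + t_f
    bwd_end = [[0.0] * m for _ in range(P)]
    for s in reversed(range(P)):
        for i in reversed(range(m)):                  # last micro-batch goes back first
            ready = fwd_end[s][m - 1]                 # flush: all forwards on this stage done
            if s < P - 1:
                ready = max(ready, bwd_end[s + 1][i])   # gradient from stage s+1
            if i < m - 1:
                ready = max(ready, bwd_end[s][i + 1])   # this stage is free
            bwd_end[s][i] = ready + t_b
    return bwd_end[0][0]

def simulated_bubble(P: int, m: int) -> float:
    """Idle fraction of each stage: 1 - busy time / makespan (unit-time passes)."""
    return 1.0 - (2.0 * m) / gpipe_makespan(P, m)

print(gpipe_makespan(4, 1), simulated_bubble(4, 1))  # 8.0 0.75
```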
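Formally, for one micro-batch and P stages (with backward time equal to forward time, as in the timeline above), each stage is busy for 2 of the 2P time units, so the bubble fraction is:

$$\text{bubble}(m=1) = \frac{2(P-1)}{2P} = \frac{P-1}{P}$$

For P=4 with one micro-batch: bubble = 3/4 = 75%, matching the 6 idle units out of 8 in the timeline. Three quarters of every stage's time is wasted. The fix is micro-batching: instead of one large batch, split it into m smaller micro-batches. Stage 0 can immediately process micro-batch 1 after sending micro-batch 0 to stage 1, so the pipeline stays busy. Filling and draining the pipeline still costs P-1 idle slots in each direction, but the busy period now spans m slots instead of one, so the bubble fraction shrinks:

$$\text{bubble}(m) = \frac{P-1}{m + P - 1}$$

Let's verify with a concrete example. P=8 pipeline stages, m=32 micro-batches:

$$\text{bubble} = \frac{8-1}{32+8-1} = \frac{7}{39} \approx 17.9\%$$

With m=1 (no micro-batching) and P=8, the bubble would be 7/8 = 87.5%. Adding 32 micro-batches brings it down to about 18%. Going to m=64: 7/71 ≈ 9.9%. The tradeoff: more micro-batches shrink the bubble, but under the naive schedule described next they also increase the pipeline buffer memory, because more activations are in flight simultaneously.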
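The closed form is a one-liner, and inverting it tells you how many micro-batches a target bubble requires. This is a sketch; the function names and the 10% target are illustrative assumptions. The search compares exact fractions (`fractions.Fraction`) so that a target like 0.1 is not missed because of floating-point rounding.

```python
from fractions import Fraction

def bubble(P: int, m: int) -> float:
    """Bubble fraction (P-1)/(m+P-1) for P stages and m micro-batches."""
    return (P - 1) / (m + P - 1)

def min_micro_batches(P: int, target: str) -> int:
    """Smallest m with (P-1)/(m+P-1) <= target; target given as a decimal string."""
    t = Fraction(target)
    m = 1
    while Fraction(P - 1, m + P - 1) > t:
        m += 1
    return m

for m in (1, 32, 64):
    print(f"P=8, m={m}: bubble = {bubble(8, m):.1%}")
print(min_micro_batches(8, "0.1"))
```

For P=8 this prints 87.5%, 17.9%, and 9.9%, and reports that 63 micro-batches are needed to bring the bubble down to 10%.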
The naive schedule (often called the GPipe schedule, after the 2018 Google paper) pushes all m micro-batches forward before starting any backward pass. It has the bubble fraction (P-1)/(m+P-1), but a painful memory cost: every stage must hold activations for all m micro-batches at once. Counting only the stage-boundary tensor (B×s×h bf16 values, 64 MiB at B=1, s=4096, h=8192), m=32 micro-batches already need 32 × 64 MiB = 2 GiB of extra buffer per stage, and the intermediate activations saved inside each stage's layers multiply that further. The cost grows linearly with m, which is exactly the knob you turn up to shrink the bubble.
The 1F1B schedule (one-forward-one-backward, from the PipeDream paper) interleaves forward and backward passes at each stage. After a short warm-up of P-k-1 forward passes, stage k stops racing ahead: it alternates one forward pass for the next micro-batch with one backward pass for the oldest micro-batch whose gradient has come back from stage k+1, then drains the remaining backward passes at the end. The bubble fraction is the same as GPipe's, but each stage now holds activations for at most P-k micro-batches (never more than P) instead of all m, so peak activation memory per stage is O(P) rather than O(m).
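The memory difference comes from how many micro-batches each stage has in flight, meaning forwarded but not yet backpropagated. This sketch generates the per-stage operation order, assuming a warm-up of P-k-1 forwards at stage k, and counts the peak number of in-flight micro-batches for 1F1B and for the GPipe order. The function names are hypothetical.

```python
def one_f_one_b_order(stage: int, P: int, m: int) -> list[tuple[str, int]]:
    """Operation order for `stage` under 1F1B: warm-up forwards, then alternate
    one forward / one backward, then drain the remaining backwards."""
    warmup = min(P - stage - 1, m)
    ops = [("F", i) for i in range(warmup)]
    fwd, bwd = warmup, 0
    while fwd < m:                             # steady state: 1F then 1B
        ops += [("F", fwd), ("B", bwd)]
        fwd, bwd = fwd + 1, bwd + 1
    ops += [("B", i) for i in range(bwd, m)]   # cool-down
    return ops

def gpipe_order(m: int) -> list[tuple[str, int]]:
    return [("F", i) for i in range(m)] + [("B", i) for i in range(m)]

def peak_in_flight(ops: list[tuple[str, int]]) -> int:
    """Max number of micro-batches whose activations are held at the same time."""
    live = peak = 0
    for kind, _ in ops:
        live += 1 if kind == "F" else -1
        peak = max(peak, live)
    return peak

P, m = 4, 32
print([peak_in_flight(one_f_one_b_order(k, P, m)) for k in range(P)])  # [4, 3, 2, 1]
print(peak_in_flight(gpipe_order(m)))                                   # 32
```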
There's a subtler variant called interleaved 1F1B (Megatron-LM, Narayanan et al. 2021): instead of giving each stage one contiguous block of layers, you assign each GPU v non-contiguous "virtual stages" (model chunks). For example, a 16-layer model on 4 GPUs with v=2 chunks of 2 layers each puts layers {0,1} and {8,9} on GPU 0, {2,3} and {10,11} on GPU 1, and so on. The pipeline now has P×v virtual stages, each holding 1/v as many layers, so filling and draining the pipeline takes 1/v as long and the bubble fraction becomes (P-1)/(v×m + P-1). The cost: every virtual-stage boundary requires an activation send/receive, so pipeline communication grows by a factor of v (it doubles for v=2). This is most worthwhile when m is small.
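Both the round-robin layer assignment and the interleaved bubble formula fit in a few lines. This sketch assumes the layer count divides evenly into P × v chunks. The helper names are hypothetical.

```python
def virtual_stage_layers(num_layers: int, P: int, v: int, gpu: int) -> list[int]:
    """Layers held by `gpu` when the model is cut into P*v chunks assigned
    round-robin: virtual stage j*P + gpu (for j = 0..v-1) lives on `gpu`."""
    chunk = num_layers // (P * v)
    layers = []
    for j in range(v):
        vstage = j * P + gpu
        layers.extend(range(vstage * chunk, (vstage + 1) * chunk))
    return layers

def interleaved_bubble(P: int, m: int, v: int) -> float:
    """(P-1)/(v*m + P-1); v = 1 recovers the plain GPipe/1F1B bubble."""
    return (P - 1) / (v * m + P - 1)

print(virtual_stage_layers(16, P=4, v=2, gpu=0))  # [0, 1, 8, 9]
print(virtual_stage_layers(16, P=4, v=2, gpu=1))  # [2, 3, 10, 11]
print(f"{interleaved_bubble(4, 8, 1):.1%} -> {interleaved_bubble(4, 8, 2):.1%}")
```

With P=4 and m=8, going from v=1 to v=2 cuts the bubble from 27.3% to 15.8%.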
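1F1B schedule, conceptual sketch for one pipeline stage:

```python
# Conceptual 1F1B loop for pipeline stage `rank`. forward_pass, backward_pass and the
# send_*/recv_* point-to-point helpers are placeholders for the stage's compute and P2P ops.
is_first, is_last = rank == 0, rank == world_size - 1
num_warmup = min(world_size - rank - 1, num_micro_batches)  # forwards before the first backward
activation_buffer = {}  # micro-batch index -> (input, output) saved for backward
fwd_idx = bwd_idx = 0

while bwd_idx < num_micro_batches:
    # Warm-up: forwards only. Steady state: alternate 1 forward, 1 backward.
    # Cool-down (all forwards issued): backwards only.
    do_forward = fwd_idx < num_micro_batches and fwd_idx <= bwd_idx + num_warmup
    if do_forward:
        x = micro_batches[fwd_idx] if is_first else recv_activation(src=rank - 1)
        y = forward_pass(x)
        activation_buffer[fwd_idx] = (x, y)  # save for backward
        if not is_last:
            send_activation(y, dst=rank + 1)
        fwd_idx += 1
    else:
        x, y = activation_buffer.pop(bwd_idx)  # freed as soon as this backward is done
        grad_out = None if is_last else recv_gradient(src=rank + 1)  # last stage starts from the loss
        grad_in = backward_pass(x, y, grad_out)
        if not is_first:
            send_gradient(grad_in, dst=rank - 1)
        bwd_idx += 1
```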
Data parallelism (DP) with D replicas must all-reduce the full gradient every step; with a ring all-reduce each GPU sends about twice the gradient size, regardless of D. But the deeper problem is memory: each DP rank holds a full copy of parameters, gradients, and optimizer state. A 70B model with AdamW in mixed precision uses 16 bytes per parameter (2 for bf16 weights, 2 for bf16 gradients, 12 for fp32 optimizer state), which is 1.12 TB per GPU if nothing is sharded. ZeRO stages 1–3 shard these three components progressively. FSDP (Fully Sharded Data Parallel, PyTorch's implementation of ZeRO-3) shards all three.
The key idea of ZeRO-3 / FSDP: don't store the full parameter tensor on any GPU. Instead, each of the D data-parallel ranks holds only 1/D of each parameter tensor. When you need a full parameter tensor (for a forward or backward computation), you reconstruct it on-the-fly by doing an all-gather across D ranks. After the computation, discard the gathered tensor and keep only your shard. This way you only ever pay for 1/D of the parameters in steady state.
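The shard / gather / discard cycle can be mimicked on a single process with plain lists. This sketch assumes each parameter tensor is flattened and zero-padded so its length divides evenly by D. FSDP's flat parameters are padded in a similar spirit, but this is not its code, and the collective is simulated by concatenation.

```python
def shard(flat: list[float], D: int, rank: int) -> list[float]:
    """The 1/D slice of a zero-padded flat parameter that `rank` keeps permanently."""
    size = -(-len(flat) // D)                       # ceil(len / D)
    padded = flat + [0.0] * (size * D - len(flat))
    return padded[rank * size:(rank + 1) * size]

def all_gather(shards: list[list[float]], numel: int) -> list[float]:
    """Simulated all-gather: every rank ends up with the concatenation of all shards."""
    return [x for s in shards for x in s][:numel]

D = 4
weight = [float(i) for i in range(10)]              # a 10-element "layer"
shards = [shard(weight, D, r) for r in range(D)]    # steady state: 3 floats per rank
full = all_gather(shards, len(weight))              # materialize just before compute
assert full == weight
del full                                            # discard; keep only the local shard
```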
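Compare this with naive DDP. Let M be the parameter size in bytes. DDP performs one ring all-reduce of the gradients per step, which is a reduce-scatter plus an all-gather, so about 2M bytes of traffic per GPU. ZeRO-3/FSDP costs about 3M bytes (an all-gather of the parameters in the forward pass, another in the backward pass, and a reduce-scatter of the gradients), 50% more, but it shards the memory: naive DDP stores 16 bytes per parameter for the whole model on every GPU, while FSDP stores 16/D bytes per parameter per GPU. For a 70B model on D=64 data-parallel ranks: naive DDP needs 1.12 TB per GPU (impossible), while FSDP needs 1.12 TB / 64 = 17.5 GB per GPU (fits with room to spare).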
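The 2M-versus-3M comparison follows from ring collectives, in which each GPU sends about (D-1)/D of the buffer per all-gather or reduce-scatter and twice that per all-reduce. This is a sketch of that accounting, assuming bandwidth-optimal ring algorithms and ignoring latency. The names are hypothetical.

```python
def ring_bytes_per_gpu(M: float, D: int, collective: str) -> float:
    """Bytes each GPU sends for one ring collective over a buffer of M bytes."""
    frac = (D - 1) / D
    return {"all_gather": frac * M, "reduce_scatter": frac * M,
            "all_reduce": 2 * frac * M}[collective]

def step_comm(M: float, D: int, strategy: str) -> float:
    if strategy == "ddp":    # one gradient all-reduce
        return ring_bytes_per_gpu(M, D, "all_reduce")
    if strategy == "fsdp":   # all-gather (fwd) + all-gather (bwd) + reduce-scatter (grads)
        return (2 * ring_bytes_per_gpu(M, D, "all_gather")
                + ring_bytes_per_gpu(M, D, "reduce_scatter"))
    raise ValueError(strategy)

M, D = 140e9, 64   # bf16 70B parameters, 64 data-parallel ranks
print(f"DDP:  {step_comm(M, D, 'ddp') / 1e9:.0f} GB sent per GPU per step")
print(f"FSDP: {step_comm(M, D, 'fsdp') / 1e9:.0f} GB sent per GPU per step")
```

For the 70B example this gives about 276 GB versus 413 GB per GPU per step, the 1.5× ratio quoted above.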
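FSDP wrap pattern (PyTorch):

```python
import functools

import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# Assumes torch.distributed.init_process_group(...) has run and
# torch.cuda.set_device(local_rank) has selected this process's GPU.

# Make each transformer block its own FSDP unit (shard per block, not per
# individual layer), so only one block's full parameters are gathered at a time.
auto_wrap = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={TransformerBlock},  # your block class
)

model = FSDP(
    model,
    auto_wrap_policy=auto_wrap,
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # ZeRO-3
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,  # gather and compute in bf16; shards keep their original dtype
        reduce_dtype=torch.float32,  # reduce-scatter gradients in fp32
    ),
    device_id=torch.cuda.current_device(),
)
```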
Let's do the full accounting. A 70B model with AdamW mixed precision uses 16 bytes per parameter in the naive case. What does FSDP/ZeRO-3 save, and what is the remaining memory per GPU?
Parameter shards. Let N be the parameter count. Each of the D data-parallel ranks holds 1/D of every parameter. With bf16 parameters: 2 bytes × N / D. For N=70B, D=64: 2 × 70B / 64 = 2.19 GB. When gathered for computation, a full parameter tensor temporarily occupies 2 bytes per parameter on the calling GPU, but only for one block at a time (block-level wrapping), not the whole model at once.
Gradient shards. Same story: the reduce-scatter leaves each rank holding only its shard. With bf16 gradients: 2 bytes × N / D, the same 2.19 GB.
Optimizer state shards. fp32 master weights + Adam first moment + Adam second moment = 12 bytes/param. After ZeRO-3 sharding: 12 × N / D = 12 × 70B / 64 = 13.1 GB. This is the dominant term.
Total steady-state per GPU: (2 + 2 + 12) × 70B / 64 = 16 × 70B / 64 = 17.5 GB. We've fit 70B on 64 GPUs with only 17.5 GB per GPU for model state! An A100 80 GB has 80 − 17.5 = 62.5 GB headroom for activations.
Communication volume comparison. Let M = total parameter bytes (140 GB for 70B bf16). Per training step:
| Strategy | Comm volume / step | Memory / GPU (70B, D=64) | Fits on A100? |
|---|---|---|---|
| Naive DDP | 2M (all-reduce) | 1120 GB | No (14× over) |
| ZeRO-1 | 2M (RS + AG) | ~293 GB (optimizer sharded) | No (3.7×) |
| ZeRO-2 | 2M (RS + AG) | ~160 GB (opt + grad sharded) | No (2×) |
| ZeRO-3 / FSDP | 3M (AG+AG+RS) | 17.5 GB | Yes (4.6× margin) |
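Here is a sketch that recomputes the memory column from the 2 + 2 + 12 bytes-per-parameter accounting, with each ZeRO stage dividing one more term by D. It assumes bf16 parameters and gradients, 12 bytes of fp32 optimizer state per parameter, decimal GB, and no activations. `zero_memory_gb` is a hypothetical name.

```python
def zero_memory_gb(num_params: float, D: int, stage: int) -> float:
    """Per-GPU model-state memory for ZeRO stage 0 (plain DDP) through 3 (FSDP)."""
    params, grads, optim = 2 * num_params, 2 * num_params, 12 * num_params
    if stage >= 1:
        optim /= D      # ZeRO-1: shard optimizer state
    if stage >= 2:
        grads /= D      # ZeRO-2: also shard gradients
    if stage >= 3:
        params /= D     # ZeRO-3 / FSDP: also shard parameters
    return (params + grads + optim) / 1e9

for stage, name in enumerate(["DDP", "ZeRO-1", "ZeRO-2", "ZeRO-3/FSDP"]):
    gb = zero_memory_gb(70e9, 64, stage)
    print(f"{name:12s} {gb:7.1f} GB  ({gb / 80:.2f}x an 80 GB A100)")
```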
You've sharded parameters (FSDP), split layers (PP), and split weight matrices (TP). But there's one more memory consumer that none of these address directly: activation memory. For a transformer layer with batch B, sequence length s, and hidden dimension h, the activations stored for backpropagation cost roughly 10 × B × s × h × 2 bytes. For s=8192, B=1, h=8192, one layer = 10 × 8192 × 8192 × 2 bytes = 1.34 GB. With 96 layers that's about 129 GB, already more than an A100's 80 GB.
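Here is a sketch of the estimate. It takes the text's rough factor of 10 × B×s×h bf16 values per layer as an assumption; the true constant depends on the architecture and on which intermediates are saved.

```python
def activation_bytes_per_layer(B: int, s: int, h: int, bytes_per_elem: int = 2,
                               factor: int = 10) -> int:
    """Rough stored-activation size of one transformer layer: factor x B*s*h values."""
    return factor * B * s * h * bytes_per_elem

per_layer = activation_bytes_per_layer(B=1, s=8192, h=8192)
print(f"{per_layer / 1e9:.2f} GB per layer, {96 * per_layer / 1e9:.0f} GB for 96 layers")
```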
Tensor parallelism (from Lec 7) splits the FFN and attention projections across TP ranks. The activations inside those computations are sharded. But some operations are not split: LayerNorm, Dropout, and the residual add all operate on the full B×s×h tensor replicated on every TP rank. These terms can dominate activation memory at long context.
Sequence parallelism (Korthikanti et al. 2022, Megatron-LM) solves this by splitting the sequence dimension across the TP ranks in the LayerNorm and Dropout regions, so outside the tensor-parallel blocks each TP rank holds only 1/TP of the sequence. Switching between the sequence-parallel and tensor-parallel regions requires an all-gather or a reduce-scatter at each boundary. These replace TP's existing all-reduce (which is itself a reduce-scatter followed by an all-gather), so the total communication volume does not grow.
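Splitting LayerNorm along the sequence is exact because LayerNorm normalizes each token's h-vector independently. In this single-process sketch with plain lists, the sequence is split into TP contiguous chunks that stand in for the TP ranks. Each chunk is normalized locally, and a simulated all-gather (concatenation) reproduces the unsplit result. The learnable scale and bias are omitted for brevity.

```python
import math

def layer_norm(rows: list[list[float]], eps: float = 1e-5) -> list[list[float]]:
    """Normalize each row (one token's hidden vector) to zero mean, unit variance."""
    out = []
    for x in rows:
        mu = sum(x) / len(x)
        var = sum((xi - mu) ** 2 for xi in x) / len(x)
        out.append([(xi - mu) / math.sqrt(var + eps) for xi in x])
    return out

s, h, tp = 8, 4, 4
acts = [[float((3 * t + 5 * j) % 7) for j in range(h)] for t in range(s)]  # toy s x h activations
chunks = [acts[r * s // tp:(r + 1) * s // tp] for r in range(tp)]          # rank r holds s/tp tokens
local = [layer_norm(c) for c in chunks]                                    # no communication needed
gathered = [row for part in local for row in part]                         # simulated all-gather
assert gathered == layer_norm(acts)
```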
A different kind of sequence parallelism, sometimes called context parallelism or ring attention, splits the sequence dimension across a separate dimension of GPUs. This is useful when the sequence length itself is the bottleneck (e.g., 1M-context models). Ring attention extends FlashAttention's blockwise computation across GPUs: GPU 0 processes query tokens 0 to s/P, GPU 1 processes s/P to 2s/P, and so on, while the key/value blocks rotate around the ring so every query eventually attends to every key. This requires 2× the activation bandwidth of local attention but enables sequences that can't fit on any single GPU.
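The ring can be simulated in one process to check that rotating key/value blocks still gives every query exact full attention. This sketch assumes non-causal attention, a single head, and plain Python lists. Each simulated device keeps its query block, visits every key/value block in ring order, and merges partial results with a running-max (online) softmax of the kind FlashAttention-style kernels use. The names are hypothetical.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def full_attention(qs, ks, vs):
    """Reference: softmax(q.k) weighted sum of values, over all keys, for every query."""
    out = []
    for q in qs:
        scores = [dot(q, k) for k in ks]
        mx = max(scores)
        w = [math.exp(sc - mx) for sc in scores]
        out.append([sum(wi * v[d] for wi, v in zip(w, vs)) / sum(w) for d in range(len(vs[0]))])
    return out

def ring_attention(q_blocks, k_blocks, v_blocks):
    n = len(q_blocks)                      # devices in the ring
    out = []
    for dev in range(n):
        # Per query: running max, running softmax denominator, running weighted sum.
        state = [(-math.inf, 0.0, [0.0] * len(v_blocks[0][0])) for _ in q_blocks[dev]]
        for step in range(n):
            src = (dev - step) % n         # KV block held by `dev` after `step` ring hops
            for qi, q in enumerate(q_blocks[dev]):
                m, z, acc = state[qi]
                for k, v in zip(k_blocks[src], v_blocks[src]):
                    sc = dot(q, k)
                    m_new = max(m, sc)
                    scale, p = math.exp(m - m_new), math.exp(sc - m_new)
                    z = z * scale + p
                    acc = [a * scale + p * vd for a, vd in zip(acc, v)]
                    m = m_new
                state[qi] = (m, z, acc)
        out.extend([a / z for a in acc] for _, z, acc in state)
    return out

def split_blocks(xs, n):
    """Contiguous sequence split: device r gets tokens [r*s/n, (r+1)*s/n)."""
    return [xs[r * len(xs) // n:(r + 1) * len(xs) // n] for r in range(n)]

s, d, n = 8, 3, 4                          # tokens, head dim, ring size
qs = [[math.sin(1.0 + i + 2 * j) for j in range(d)] for i in range(s)]
ks = [[math.cos(0.5 * i - j) for j in range(d)] for i in range(s)]
vs = [[float((i * 7 + j) % 5) for j in range(d)] for i in range(s)]
ring = ring_attention(split_blocks(qs, n), split_blocks(ks, n), split_blocks(vs, n))
ref = full_attention(qs, ks, vs)
max_err = max(abs(a - b) for r1, r2 in zip(ring, ref) for a, b in zip(r1, r2))
print(f"max |ring - full| = {max_err:.2e}")
```

The only difference from the reference is the order in which keys are visited, so the two results agree up to floating-point rounding.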
| Technique | What's split | Split dimension | Comm cost per layer |
|---|---|---|---|
| Tensor Parallel (TP) | Weight matrices, FFN internals | Hidden (d_model) | All-reduce B×s×h |
| Sequence Parallel (SP) | LayerNorm + Dropout activations | Sequence s | All-gather + RS, fused with TP |
| Context Parallel (CP) | Attention Q/K/V computations | Sequence s (inter-node) | Ring: 2× local attention BW |
Production frontier runs combine all three axes simultaneously. The standard recipe (Megatron-LM, Narayanan et al. 2021): TP within each node (uses NVLink, no inter-node latency), PP across nodes (only activation tensors cross the network, one per stage boundary, not per layer), DP for the rest (replicate the TP+PP group, use ZeRO-1 or ZeRO-2 on top for free memory wins). Total GPU count: TP × PP × DP.
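The TP × PP × DP grid implies a mapping from a global rank to its coordinates. This sketch assumes one possible layout: TP is the fastest-varying index, so each TP group is 8 consecutive ranks on one node, followed by DP and then PP. The ordering beyond "TP innermost" is an assumption, and the function names are hypothetical.

```python
def rank_to_coords(rank: int, tp: int, pp: int, dp: int) -> dict:
    """Global rank -> coordinates, with TP fastest-varying, then DP, then PP."""
    assert 0 <= rank < tp * pp * dp
    return {"tp": rank % tp,
            "dp": (rank // tp) % dp,
            "pp": rank // (tp * dp)}

def tp_group(rank: int, tp: int) -> list:
    """Ranks sharing this rank's TP group: tp consecutive ranks on one node."""
    start = (rank // tp) * tp
    return list(range(start, start + tp))

def pp_group(rank: int, tp: int, pp: int, dp: int) -> list:
    """Ranks holding the same TP slice and DP replica, one per pipeline stage."""
    c = rank_to_coords(rank, tp, pp, dp)
    return [c["tp"] + tp * c["dp"] + tp * dp * stage for stage in range(pp)]

TP, PP, DP = 8, 8, 8
print(rank_to_coords(13, TP, PP, DP))  # {'tp': 5, 'dp': 1, 'pp': 0}
print(tp_group(13, TP))                 # [8, 9, 10, 11, 12, 13, 14, 15]
print(pp_group(13, TP, PP, DP))         # [13, 77, 141, 205, 269, 333, 397, 461]
```

With this layout, pipeline neighbours are 64 ranks (8 nodes) apart, so only PP traffic crosses the network while TP stays on NVLink.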
Let's verify the math with a real example: 512 GPUs, 70B model. Cluster: 64 nodes × 8 GPUs each. Recipe: TP=8 (fill one node), PP=8 (8 nodes in a pipeline), DP=8 (8 copies of the TP+PP group). Total: 8 × 8 × 8 = 512. ✓
Memory per GPU (TP=8, PP=8, ZeRO-1 on DP=8): weights = 140 GB / (8×8) = 2.19 GB. Adam optimizer state = 840 GB / (8×8×8) = 1.64 GB. Gradients (not sharded across DP under ZeRO-1) = 140 GB / 64 = 2.19 GB. Activation memory (sequence s=4096, batch B=1, 96/8=12 layers per stage, h=8192): 12 × 10 × 4096 × 8192 × 2 bytes = 8.05 GB, or about 1 GB after TP and SP divide it by 8. That is about 6 GB of model state plus 1 GB of activations, roughly 7 GB per GPU, leaving about 73 GB of an 80 GB A100 for more activations. Holding activations for 32 such sequences at once adds 32 × 1 GB = 32 GB, which still fits with headroom.
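The per-GPU budget can be recomputed in a few lines. This sketch reuses the text's assumptions: 2-byte bf16 weights and gradients split over TP × PP, 12 bytes of fp32 optimizer state further split over DP (ZeRO-1), 12 layers per stage, the rough 10 × B×s×h activation factor per layer, and an 8× activation reduction from TP + SP. All values are in decimal GB, and the function name is hypothetical.

```python
def budget_gb(N=70e9, tp=8, pp=8, dp=8, layers_per_stage=12, B=1, s=4096, h=8192):
    weights = 2 * N / (tp * pp)                        # bf16 weights
    grads = 2 * N / (tp * pp)                          # bf16 grads, replicated across DP (ZeRO-1)
    optim = 12 * N / (tp * pp * dp)                    # fp32 master + Adam m, v, sharded by ZeRO-1
    acts = layers_per_stage * 10 * B * s * h * 2 / tp  # TP + SP divide activations by tp
    return {k: v / 1e9 for k, v in
            dict(weights=weights, grads=grads, optim=optim, acts=acts).items()}

b = budget_gb()
total = sum(b.values())
print({k: round(v, 2) for k, v in b.items()}, f"total ~ {total:.1f} GB")
```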
Even with all four forms of parallelism — TP, PP, DP, SP — activation memory can be the binding constraint at long context. The activations stored for the backward pass grow as O(L × B × s × h). Pipeline parallelism helps by splitting L, tensor + sequence parallelism split h and s. But with context lengths of 8192, 32768, or more, activations still dominate.
Activation recomputation (also called gradient checkpointing) trades memory for compute: instead of storing all intermediate activations, you store only "checkpoints" at selected points in the network and recompute the intermediate values on-the-fly during the backward pass. The classic full recomputation stores only the input to each transformer block and recomputes the entire block during backward. Cost: one extra forward pass worth of compute, about 33% more, since the backward pass costs roughly two forward passes. Benefit: per layer you store one B×s×h block input instead of all the intermediates (about 10 × B×s×h by the estimate above), plus the full activations of the one block being recomputed at any moment, so activation memory shrinks by roughly an order of magnitude.
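In PyTorch, full recomputation at block granularity can be expressed with `torch.utils.checkpoint.checkpoint`. In this minimal sketch, `Block` is a hypothetical stand-in for a transformer block (LayerNorm, MLP, residual; no attention or dropout). `use_reentrant=False` selects the non-reentrant implementation that recent PyTorch versions recommend. Gradients match the non-checkpointed model; only the memory/compute tradeoff changes.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Block(nn.Module):
    """Hypothetical stand-in for a transformer block: LayerNorm -> MLP -> residual."""
    def __init__(self, h: int):
        super().__init__()
        self.norm = nn.LayerNorm(h)
        self.mlp = nn.Sequential(nn.Linear(h, 4 * h), nn.GELU(), nn.Linear(4 * h, h))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.mlp(self.norm(x))

class Stack(nn.Module):
    def __init__(self, h: int, num_layers: int, recompute: bool):
        super().__init__()
        self.blocks = nn.ModuleList(Block(h) for _ in range(num_layers))
        self.recompute = recompute

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            if self.recompute and self.training:
                # Keep only the block input; the block is re-run during backward.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x

torch.manual_seed(0)
plain = Stack(h=64, num_layers=4, recompute=False)
ckpt = Stack(h=64, num_layers=4, recompute=True)
ckpt.load_state_dict(plain.state_dict())             # identical weights

x = torch.randn(2, 16, 64)                           # (B, s, h)
plain(x).sum().backward()
ckpt(x).sum().backward()
max_diff = max((p1.grad - p2.grad).abs().max().item()
               for p1, p2 in zip(plain.parameters(), ckpt.parameters()))
print(f"max gradient difference: {max_diff:.1e}")
```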
The decision rule: weigh the cost of recomputing an activation against the HBM it would occupy if stored. If recomputation is expensive (many FLOPs, e.g. a large matmul) and the activation is small, store it. If the activation is large (long sequence, large batch) and cheap to recompute (a few elementwise ops), recompute it. In practice:
| Activation type | Size | Recompute cost | Recommendation |
|---|---|---|---|
| Attention scores Q·Kᵀ/√d | B × a × s × s × 2 (a = heads) | O(s²·d) FLOPs | Recompute (FlashAttention already does this) |
| Attention dropout | B × a × s × s × 1 bit | Cheap (Bernoulli) | Recompute (binary mask, tiny) |
| FFN activation (GELU) | B × s × 4h × 2 | O(s·h) FLOPs | Store (cheap, but measurable) |
| LayerNorm inputs | B × s × h × 2 | One sum + sqrt | Recompute if SP is active |
You now have all three axes of 3D parallelism plus the supporting tools that make them work at scale. Here's a decision guide for a real run:
| Technique | Memory win | Compute cost | Comm cost | Best for |
|---|---|---|---|---|
| TP | ÷TP (weights + activations) | None | All-reduce per layer (NVLink) | Intra-node, always on |
| PP | ÷PP (weights + opt state) | Bubble fraction | Activation send per stage | Inter-node, large models |
| DP | None (naive) | None | Gradient all-reduce | Compute scaling |
| ZeRO-1/2 | ÷DP (opt/grad) | None | Same as DP | Free win on top of DP |
| FSDP (ZeRO-3) | ÷DP (all) | None | 3M vs 2M (50% more) | Models that don't fit even with TP+PP |
| Sequence Parallel | ÷TP (LayerNorm act.) | None | Fused with TP comm | Long context + large TP |
| Activation Recompute | ~10× less (block inputs only, full) | +33% (full), ~5% (selective) | None | Long context, large batch |
"The hardware is getting faster, but we will always want to train bigger models — so this hierarchical, multi-axis structure of parallelism will always be with us. The only thing that changes is which bandwidth is the bottleneck."
— Percy Liang, CS336 Lecture 8