Language Modeling from Scratch · CS336 · Lecture 8

Parallelism II: Pipeline & FSDP

Tensor parallelism maxes out at 8 GPUs per node — yet frontier runs use thousands. This lesson derives how to go further: split the model across nodes with pipeline parallelism (micro-batches, the bubble, 1F1B schedules), shard every byte with FSDP/ZeRO-3 (all-gather on-the-fly, memory ↔ comm tradeoff), add sequence parallelism for activations, and learn the proven recipe for combining all three into a real 3D run on 512+ GPUs.

Prerequisites: CS336 Lec 7 (collective ops, data + tensor parallelism, ZeRO stages). Basic PyTorch.
10 chapters · 5 live canvases · derived bubble formula

Chapter 0: Tensor Parallelism Tops Out at 8 GPUs

You learned in Lec 7 that tensor parallelism (TP) splits weight matrices across GPUs, reducing memory by the TP degree. It works brilliantly — inside one node. An NVIDIA DGX A100 node has 8 GPUs connected by NVLink at 600 GB/s. Tensor parallel all-reduces take ~0.05 ms. You barely notice them.

Now scale to a real training run: 512 GPUs across 64 nodes. TP saturates at 8 — you can't extend it beyond the NVLink boundary without paying 24× the communication cost. But the model still doesn't fit on 8 GPUs if it's large enough, and 8 GPUs gives you only 8× compute speedup regardless. You need two more tools.

The gap tensor parallelism leaves open. TP handles memory within a node and compute within a node. But it doesn't give you inter-node compute scaling, and it requires an all-reduce of activations in every transformer block, so its communication cost is paid per layer. For 512 GPUs you need two orthogonal axes: one that goes across nodes without per-layer communication (pipeline parallelism), and one that gives you linear memory scaling at the cost of controlled communication (FSDP / ZeRO-3). This lesson derives both.

Let's make this concrete. Suppose you're training a 70B parameter model on 512 GPUs across 64 nodes (8 GPUs each). Here's the capacity problem: 70B × 2 bytes (bf16) = 140 GB for weights alone. One A100 has 80 GB. With TP=8 you get 140/8 = 17.5 GB for weights, which fits. But optimizer state in fp32 (master weights plus two Adam moments) is 3 × 4 bytes × 70B = 840 GB total, and divided by 8 that is still 105 GB per GPU. It doesn't fit. Now suppose you also want 64 of these TP groups for compute scaling. How do you coordinate the 64 replicas? Plain data parallelism all-reduces the 70B gradients across all 64 groups every step and keeps a full copy of the model state in each group. FSDP shards that memory burden across the replicas instead.

And pipeline parallelism? Imagine a 70B model with 80 transformer layers. Instead of asking TP=8 to fit all of them, assign layers 0–9 to stage 0 (one node), layers 10–19 to stage 1 (next node), etc. Each stage only stores 1/8 of the parameters. The data flows like an assembly line: a micro-batch enters stage 0, stage 0 sends activations to stage 1, and so on. This is the third axis of 3D parallelism. But assembly lines have idle time. Deriving and minimizing that idle time is the core problem of Chapters 2 and 3.

[Interactive canvas: Why TP alone is not enough: memory breakdown at scale. Controls: model size (B params, default 70), TP degree within node (default 8), DP replicas (default 8). Shows how much memory per GPU remains and whether FSDP is needed on top.]
Why can't you just use TP=64 to scale tensor parallelism across 64 GPUs spread over 8 nodes?

Chapter 1: Pipeline Parallelism: Splitting by Depth

The idea is elegant: a 96-layer transformer doesn't need all 96 layers on the same GPU. Partition the layers into P equal groups called pipeline stages. Stage 0 holds layers 0 through L/P-1, stage 1 holds layers L/P through 2L/P-1, and so on. Each stage lives on one GPU (or one group of GPUs if combined with TP). The forward pass is literally a pipeline: stage 0 computes its layers, emits the activation tensor, and hands it to stage 1 via a point-to-point send.

Point-to-point here is key. Unlike tensor parallelism's all-reduce (a collective that involves every GPU in the group), pipeline parallelism only ever sends activations between adjacent stages. The communication volume is the activation tensor shape: for a batch of B sequences of length s with hidden dimension h, the activation at each stage boundary is B × s × h × 2 bytes (bf16). For B=1, s=4096, h=8192: 4096 × 8192 × 2 bytes = 64 MB per boundary. On InfiniBand at 25 GB/s, that's 64 MB / 25 GB/s ≈ 2.6 ms. Compare to TP's all-reduce: each one is about 6× cheaper for TP=8 on NVLink, but TP needs it in every single layer. Pipeline needs only one send/receive per stage boundary, not per layer.
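The boundary-size arithmetic above can be sketched in a few lines. This is a back-of-the-envelope helper, not a benchmark: the function names are made up for illustration, and the transfer time ignores latency and protocol overhead.

```python
def boundary_activation_bytes(batch, seq_len, hidden, bytes_per_elem=2):
    """Size of the activation tensor handed across one pipeline-stage boundary."""
    return batch * seq_len * hidden * bytes_per_elem


def transfer_ms(num_bytes, bandwidth_gb_per_s):
    """Time to move num_bytes over a link, ignoring latency and overhead."""
    return num_bytes / (bandwidth_gb_per_s * 1e9) * 1e3


size = boundary_activation_bytes(batch=1, seq_len=4096, hidden=8192)
print(f"{size / 2**20:.0f} MiB per stage boundary")
# Lands between 2.6 and 2.7 ms depending on whether MB means 10^6 or 2^20 bytes.
print(f"{transfer_ms(size, 25):.2f} ms at 25 GB/s")
```

Doubling any of B, s, or h doubles the per-boundary cost, which is why the micro-batch size (not the global batch size) is what matters for each send.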

Memory wins from pipeline parallelism. With P pipeline stages, each stage stores only L/P layers worth of parameters, gradients, and optimizer state. For a 70B model split into P=8 stages: weights per stage ≈ 140/8 = 17.5 GB, optimizer state per stage ≈ 840/8 = 105 GB, total ≈ 122 GB — still doesn't fit on one 80 GB A100! But combine with TP=8 within each stage-node: parameters divide by another 8, bringing weights down to 2.2 GB and optimizer state down to ~13 GB per GPU. Now it fits with headroom. This is why PP and TP are complementary, not alternatives.

The backward pass is more subtle. Backprop through a pipeline stage needs: (1) the gradients flowing backward from the next stage, and (2) the activations saved during the forward pass. The backward pass reverses the pipeline: stage P-1 computes first (it has the loss), sends gradient tensors back to stage P-2, and so on. Each stage needs its own activations from the forward pass, and they sit in memory during the wait. This storage cost is the other price of pipeline parallelism, and the schedules in Chapter 3 exist to control it.

python — minimal 2-stage pipeline (CS336 lecture_08.py)
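import torch
import torch.distributed as dist
import torch.nn.functional as F

def pipeline_parallelism_main(rank, world_size, data, num_layers, num_micro_batches):
    # Assumes dist.init_process_group() has run; get_init_params is the lecture's helper.
    batch_size, num_dim = data.shape
    local_num_layers = num_layers // world_size  # e.g. 4 layers / 2 GPUs = 2 per GPU
    local_params = [get_init_params(num_dim, num_dim, rank) for _ in range(local_num_layers)]
    micro_batch_size = batch_size // num_micro_batches

    if rank == 0:
        micro_batches = data.chunk(chunks=num_micro_batches, dim=0)  # first stage reads real data
    else:
        # Later stages only need empty buffers to receive activations into
        micro_batches = [torch.empty(micro_batch_size, num_dim, device=data.device)
                         for _ in range(num_micro_batches)]

    for x in micro_batches:
        if rank - 1 >= 0:
            dist.recv(tensor=x, src=rank - 1)  # receive from previous stage
        for param in local_params:
            x = F.gelu(x @ param)              # compute this stage's layers
        if rank + 1 < world_size:
            dist.send(tensor=x, dst=rank + 1)  # send activations to next stage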
What each stage stores. During the forward pass of micro-batch i, stage k must hold the activations it will need for backprop, because by the time the gradient for micro-batch i comes back, stage k has already moved on to later micro-batches' forward passes. This intermediate activation storage is called the pipeline buffer. How many micro-batches it holds depends on the schedule: if all m forwards run before any backward, every stage buffers m of them; with the 1F1B schedule of Chapter 3, stage k buffers at most P−k, so the earliest stages hold the most. Either way the buffer is a per-stage memory cost that has to be budgeted alongside parameters and optimizer state.
Pipeline parallelism sends activations between stages. For B=2, s=2048, h=4096 in bf16, what is the size of one activation tensor passed at a stage boundary?

Chapter 2: The Pipeline Bubble: Deriving the Idle Fraction

Here's the problem with a naive pipeline schedule. You have P=4 stages and one mini-batch to process. Stage 0 starts immediately. But stage 1 can't start until stage 0 finishes its forward pass and sends the activation. Stage 2 must wait for stage 1. Stage 3 waits for stage 2. In the forward direction, stages 0, 1, 2, 3 start at times t=0, 1, 2, 3 (in units of one forward-pass time).

Then backward pass: stage 3 computes its backward pass first (it has the loss from the forward pass). But now stages 0, 1, 2 are idle again — waiting for the backward pass to propagate back to them. They finish the backward pass in reverse order: stage 3 at t=5, stage 2 at t=6, stage 1 at t=7, stage 0 at t=8. The total time is 8 units. Ideal time (if all stages ran in parallel) is 2 units (1 forward + 1 backward). The bubble costs us P-1 = 3 forward-pass-equivalents of startup and P-1 = 3 of draining.

Formally, for one micro-batch and P stages, each stage is busy for only 2 of the 2P time units, so the bubble fraction is:

Bubble fraction (m=1) = 2(P−1) / (2P) = (P−1) / P

For P=4 with one micro-batch: bubble = 3/4 = 75%. Three quarters of the GPU time is wasted. The fix is micro-batching: instead of one large batch, split it into m smaller micro-batches. Stage 0 can immediately process micro-batch 1 after sending micro-batch 0 to stage 1. The pipeline stays busy. The bubble at startup and draining still costs P-1 units of idle, but now each stage does m units of useful work in each direction, so the bubble fraction shrinks:

Bubble fraction = (P−1) / (m + P−1)

Let's verify with a concrete example. P=8 pipeline stages, m=32 micro-batches:

Bubble fraction = (8−1) / (32 + 8−1) = 7/39 ≈ 17.9%

With m=1 (no micro-batching) and P=8, the bubble would be 7/8 ≈ 88%. Using 32 micro-batches brings it down to 18%. Going to m=64: 7/71 = 9.9%. The tradeoff: more micro-batches reduce the bubble, but under a schedule that runs all forwards before any backward they also increase the pipeline buffer memory, because more activations are in flight at once.

Hand derivation of the formula. Think of the timeline as a grid: P rows (stages) × time columns, where each cell is one stage's work on one micro-batch and forward and backward cells take equal time. Forward phase: stage k starts micro-batch 0 at step k and finishes micro-batch m−1 at step k + m, so the last stage finishes at step m + (P−1). The backward phase mirrors it and also takes m + (P−1) steps. Total elapsed time = 2m + 2(P−1) steps. Each stage does useful work in 2m of them (m forward, m backward) and idles in the other 2(P−1). The fraction that is bubble = 2(P−1) / (2m + 2(P−1)) = (P−1)/(m + P−1). ◾
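One way to check the derivation is to build the grid explicitly and count idle cells. The sketch below assumes the same idealization as the derivation (forward and backward each take one time unit per micro-batch per stage, all forwards run before any backward); `simulate_gpipe` and `bubble_formula` are illustrative names.

```python
def simulate_gpipe(P, m):
    """Fraction of idle (stage, time-slot) cells in an all-forward-then-all-backward schedule."""
    busy = set()
    fwd_len = m + P - 1                      # length of the forward phase
    for k in range(P):                       # stage index
        for j in range(m):                   # micro-batch index
            busy.add((k, j + k))                         # forward cell
            busy.add((k, fwd_len + j + (P - 1 - k)))     # backward cell (last stage goes first)
    total_slots = 1 + max(t for _, t in busy)
    idle = P * total_slots - len(busy)
    return idle / (P * total_slots)


def bubble_formula(P, m):
    """Closed form from the derivation."""
    return (P - 1) / (m + P - 1)


# The counted idle fraction and the closed form agree.
print(simulate_gpipe(8, 32), bubble_formula(8, 32))
```

Counting cells instead of trusting the algebra makes it easy to test variations, for example a backward pass that costs two time units instead of one.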
[Interactive canvas: Pipeline bubble: naive schedule (P stages, m micro-batches). Controls: pipeline stages P (default 4), micro-batches m (default 1). Each cell is one stage processing one micro-batch forward or backward; the bubble fraction updates live.]
A pipeline with P=16 stages and m=47 micro-batches. What is the bubble fraction?

Chapter 3: GPipe & 1F1B: Smarter Schedules

The naive schedule (sometimes called GPipe schedule after the 2018 Google paper) flushes all m micro-batches forward before starting any backward pass. It keeps the bubble formula (P-1)/(m+P-1) but has a painful memory cost: all m micro-batch activations must be stored simultaneously in the pipeline buffer. Each activation tensor is B×s×h; for m=32 and a 70B model stage, that's 32 × 64 MB = 2 GB of extra buffer per stage. This may not be a problem for small P, but it compounds.

The 1F1B schedule (one-forward-one-backward, from the PipeDream paper) interleaves forward and backward passes at each stage. After a short warmup of forward passes, each stage alternates: one forward for a new micro-batch, then one backward for the oldest micro-batch whose gradient has arrived back from stage k+1. The pipeline keeps the same bubble fraction as GPipe but reduces the peak number of buffered micro-batches per stage from O(m) to O(P).

1F1B memory win. In 1F1B, a stage runs a backward pass as soon as one is available, which frees that micro-batch's activations before more forwards pile up. Stage k therefore never holds more than P−k micro-batches, so pipeline buffer memory is bounded by P (the number of micro-batches simultaneously "in flight") rather than by m. For a 70B model with P=16 stages and m=64 micro-batches: under GPipe every stage buffers 64 micro-batches; under 1F1B the first stage buffers 16 and later stages fewer. For the same bubble fraction (~19%), 1F1B uses at least 4× less pipeline buffer memory. Production systems almost always use 1F1B or its variants for this reason.
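The buffer counts can be reproduced by listing the order of operations one stage executes and tracking how many micro-batches are held at once. This is a sketch of the non-interleaved 1F1B ordering (warmup forwards, steady-state forward/backward pairs, cooldown backwards); the function names are illustrative, and stages are 0-indexed.

```python
def one_f_one_b_order(stage, P, m):
    """Sequence of 'F'/'B' operations a stage executes under non-interleaved 1F1B."""
    warmup = min(P - stage - 1, m)       # forwards issued before the first backward
    ops = ["F"] * warmup
    for _ in range(m - warmup):          # steady state: one forward, then one backward
        ops += ["F", "B"]
    ops += ["B"] * warmup                # cooldown: drain the remaining backwards
    return ops


def gpipe_order(m):
    """All forwards, then all backwards."""
    return ["F"] * m + ["B"] * m


def peak_in_flight(ops):
    """Maximum number of micro-batches whose activations are held at once."""
    held = peak = 0
    for op in ops:
        held += 1 if op == "F" else -1
        peak = max(peak, held)
    return peak


P, m = 16, 64
print("GPipe, any stage:", peak_in_flight(gpipe_order(m)))
for stage in (0, 8, 15):
    print("1F1B, stage", stage, ":", peak_in_flight(one_f_one_b_order(stage, P, m)))
```

The first stage is the worst case because it must issue the most warmup forwards before its first gradient comes back.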

There's a subtler variant called interleaved 1F1B (Megatron-LM, Narayanan et al. 2021): instead of assigning one contiguous block of L/P layers to each GPU, you assign v smaller non-contiguous chunks ("virtual stages") per GPU. E.g. with 4 GPUs, v=2 and a 16-layer model, GPU 0 holds layers {0,1} and {8,9}, GPU 1 holds {2,3} and {10,11}, etc. The pipeline now has P×v virtual stages. The bubble fraction becomes (P-1)/(v×m + P-1), which is significantly lower because each virtual stage does 1/v of the work per micro-batch, so the fill and drain phases are v times shorter. The cost: every virtual-stage boundary requires an activation send/receive, so pipeline communication grows by a factor of v (it doubles for v=2). Worthwhile when m is small.
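To see how the number of virtual stages trades bubble against communication, the formula can be evaluated directly. This is a sketch using the expression from the text; `bubble_fraction` and `boundary_sends` are illustrative helper names, and the send count simply counts virtual-stage boundaries per micro-batch in the forward direction.

```python
def bubble_fraction(P, m, v=1):
    """Idle fraction with P GPUs in the pipeline, m micro-batches, v virtual stages per GPU."""
    return (P - 1) / (v * m + P - 1)


def boundary_sends(P, v=1):
    """Forward activation sends per micro-batch: one per virtual-stage boundary."""
    return P * v - 1


# Small m is where interleaving helps most.
for v in (1, 2, 4):
    print(f"v={v}: bubble={bubble_fraction(P=8, m=8, v=v):.3f}, sends={boundary_sends(8, v)}")
```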

python — 1F1B schedule conceptual sketch (FSDP/pipeline)
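# Conceptual sketch: can_forward, can_backward, forward_pass, backward_pass and the
# send/recv helpers are placeholders, not a real API.
# Forward queue: next micro-batch to push forward through this stage
# Backward queue: micro-batch whose grads arrived from next stage
fwd_idx = bwd_idx = 0
activation_buffer = {}                        # micro-batch index -> saved activation
warmup_steps = world_size - rank - 1          # forwards issued before the first backward
for step in range(num_micro_batches + warmup_steps):
    if can_forward(step):                     # true while fwd_idx < num_micro_batches
        x = forward_pass(micro_batches[fwd_idx])
        activation_buffer[fwd_idx] = x        # save for backward
        if rank + 1 < world_size:
            send_activation(x, dst=rank + 1)
        fwd_idx += 1
    if can_backward(step):                    # true once step >= warmup_steps
        grad_out = recv_gradient(src=rank + 1)   # on the last stage: gradient of the loss
        grad_in = backward_pass(activation_buffer[bwd_idx], grad_out)
        if rank - 1 >= 0:
            send_gradient(grad_in, dst=rank - 1)
        del activation_buffer[bwd_idx]        # free immediately
        bwd_idx += 1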
Zero-bubble pipeline (Qi et al. 2023). Even 1F1B has a bubble: the very first stage sits idle during the "drain" as the backward pass ripples back from stage P-1. Zero-bubble pipelining splits the backward pass into two parts: backward-for-activation (B) computes the gradient of the activation (what flows backward to the previous stage) and backward-for-weights (W) computes the gradient of this stage's own parameters. B is on the critical path; W can be deferred. By scheduling deferred W passes during what would have been idle time, zero-bubble schedules shrink the bubble toward zero, and the most aggressive variants pay for it by holding more intermediate activation data.
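The B/W split is easiest to see on a single linear layer Y = X·W, where the two gradients are independent matrix products. The sketch below uses plain Python lists so it runs anywhere; the function names are illustrative and this is not the scheduler from the paper, only the decomposition it relies on.

```python
def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    return [list(col) for col in zip(*a)]


def backward_for_activation(grad_out, weight):
    """B pass: dL/dX = dL/dY · Wᵀ. On the critical path, sent to the previous stage."""
    return matmul(grad_out, transpose(weight))


def backward_for_weights(saved_input, grad_out):
    """W pass: dL/dW = Xᵀ · dL/dY. Needed only before the optimizer step, so it can wait."""
    return matmul(transpose(saved_input), grad_out)


x = [[1.0, 2.0]]                               # saved input, shape 1×2
w = [[1.0, 0.0, -1.0], [2.0, 1.0, 0.0]]        # weight, shape 2×3
grad_out = [[1.0, 1.0, 1.0]]                   # gradient arriving from the next stage

grad_x = backward_for_activation(grad_out, w)  # compute now, unblock the previous stage
deferred = (x, grad_out)                       # what the W pass needs must stay in memory
# ... later, in a slot that would otherwise be idle ...
grad_w = backward_for_weights(*deferred)
print(grad_x, grad_w)
```

The `deferred` tuple is where the extra memory goes: the saved input and the incoming gradient both have to survive until the W pass runs.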
[Interactive canvas: GPipe vs 1F1B schedule visualization (P stages, m micro-batches). Controls: pipeline stages P (default 4), micro-batches m (default 4). Teal = forward pass, orange = backward pass, gray = bubble. Compares total idle time (bubble %) between schedules.]
1F1B and GPipe have the same bubble fraction. What does 1F1B improve?

Chapter 4: FSDP / ZeRO-3: Shard Everything On the Fly

Data parallelism (DP) with D replicas all-reduces the full gradient every step, which costs each GPU about 2× the parameter bytes in communication no matter how large D is. But the deeper problem is memory: each DP rank holds a full copy of parameters, gradients, and optimizer state. A 70B model with AdamW uses 16 bytes per parameter = 1.12 TB per GPU if naive. ZeRO stages 1–3 shard these three components progressively. FSDP (Fully Sharded Data Parallel, PyTorch's implementation of ZeRO-3) shards all three.

The key idea of ZeRO-3 / FSDP: don't store the full parameter tensor on any GPU. Instead, each of the D data-parallel ranks holds only 1/D of each parameter tensor. When you need a full parameter tensor (for a forward or backward computation), you reconstruct it on-the-fly by doing an all-gather across D ranks. After the computation, discard the gathered tensor and keep only your shard. This way you only ever pay for 1/D of the parameters in steady state.

FSDP is not free: the all-gather dance. Every forward pass through a layer requires first gathering the full parameter from all D shards. This takes time proportional to the parameter size divided by bandwidth. FSDP implementations overlap the all-gather for layer k+1 with the forward computation of layer k — prefetching so you don't stall. Similarly, after the backward pass computes gradients for a layer, you immediately reduce-scatter those gradients to distribute the shard updates. Discard everything except your shard. The net cost: 2× the all-gather volume (once for forward, once for backward) plus 1× reduce-scatter = 3× the parameter size in communication per training step.
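The gather, compute, discard, reduce-scatter cycle can be simulated with plain lists standing in for ranks. This is a toy model under loud assumptions: the collectives are simulated in one process, the "layer" is the stand-in y = x · Σw, and communication is counted the way the text counts it (one parameter-size per collective).

```python
def shard(flat, D):
    """Split a flat list into D equal contiguous shards (assumes len(flat) % D == 0)."""
    n = len(flat) // D
    return [flat[r * n:(r + 1) * n] for r in range(D)]


def all_gather(shards):
    """Every rank ends up with the concatenation of all shards."""
    full = [v for s in shards for v in s]
    return [list(full) for _ in shards]


def reduce_scatter(per_rank_full, D):
    """Sum the per-rank vectors elementwise; rank r keeps only shard r of the sum."""
    summed = [sum(vals) for vals in zip(*per_rank_full)]
    return shard(summed, D)


D = 4
weights = [0.5, -1.0, 2.0, 1.5, 0.0, 3.0, -2.0, 1.0]   # one layer, flattened
param_shards = shard(weights, D)                       # steady state: 1/D per rank
inputs = [1.0, 2.0, 3.0, 4.0]                          # each rank sees different data
comm_elems = 0

# Forward: gather the full layer, compute, drop the gathered copy.
gathered = all_gather(param_shards)
comm_elems += len(weights)
outputs = [x * sum(w) for x, w in zip(inputs, gathered)]
del gathered

# Backward: gather again, compute local gradients (dy/dw_i = x), reduce-scatter them.
gathered = all_gather(param_shards)
comm_elems += len(weights)
local_grads = [[x] * len(weights) for x in inputs]
del gathered
grad_shards = reduce_scatter(local_grads, D)
comm_elems += len(weights)

print(comm_elems / len(weights), "parameter-sizes of communication per step")
```

Each rank ends the step holding only its parameter shard and the matching gradient shard, which is exactly what its slice of the optimizer state needs.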

Compare to naive DDP. Let M be the total parameter bytes. DDP communicates 2M per step (one ring all-reduce = reduce-scatter + all-gather = 2M bytes). ZeRO-3/FSDP costs 3M but removes the replicated memory: naive DDP stores 16 bytes/param for every parameter on every GPU, while FSDP stores 16/D bytes/param per GPU. For a 70B model on D=64 data-parallel ranks: naive DDP = 1.12 TB per GPU (impossible), FSDP = 1.12 TB / 64 = 17.5 GB per GPU (fits with room to spare).

python — FSDP wrap pattern (PyTorch)
import functools

import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# Wrap at the transformer block level: each block becomes one FSDP unit
# that is all-gathered and freed as a whole.
auto_wrap = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={TransformerBlock},  # your block class
)
model = FSDP(
    model,
    auto_wrap_policy=auto_wrap,
    sharding_strategy=ShardingStrategy.FULL_SHARD,   # ZeRO-3
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,                  # gather and compute params in bf16
        reduce_dtype=torch.float32,                  # reduce gradients in fp32
    ),
    device_id=torch.cuda.current_device(),
)
The FSDP tradeoff in one sentence. FSDP trades 50% more communication for linear memory scaling. With M the total parameter bytes, naive DDP costs 2M communication and O(M) memory per GPU; FSDP costs 3M communication and O(M/D) memory per GPU. For runs that couldn't fit at all with DDP, FSDP is the only option regardless of communication cost.
In FSDP, layer k's forward pass needs the full parameter W_k. What happens right before the forward compute of layer k?

Chapter 5: FSDP Memory Math: Deriving Per-GPU Consumption

Let's do the full accounting. A 70B model with AdamW mixed precision uses 16 bytes per parameter in the naive case. What does FSDP/ZeRO-3 save, and what is the remaining memory per GPU?

Parameter shards. Each of D data-parallel ranks holds 1/D of every parameter. In FSDP with bf16 parameters: 2 bytes × P / D. For P=70B, D=64: 2 × 70B / 64 = 2.19 GB. When gathered for computation, the full tensor temporarily occupies 2 bytes × P on the calling GPU — but only for one layer at a time (block-level sharding), not the whole model at once.

Gradient shards. Same story — reduce-scatter means each rank accumulates only its shard. With bf16 gradients: 2 bytes × P / D. Same 2.19 GB.

Optimizer state shards. fp32 master weights + Adam m1 + Adam m2 = 12 bytes/param. After ZeRO-3 sharding: 12 × P / D = 12 × 70B / 64 = 13.1 GB. This is the dominant term.

Total steady-state per GPU: (2 + 2 + 12) × 70B / 64 = 16 × 70B / 64 = 17.5 GB. We've fit 70B on 64 GPUs with only 17.5 GB per GPU for model state! An A100 80 GB has 80 − 17.5 = 62.5 GB headroom for activations.

Peak memory during all-gather. The steady state is 17.5 GB, but during the forward pass of one block FSDP gathers that block's full parameters, temporarily adding (total parameter bytes / number of layers) = roughly 70B × 2 bytes / 96 ≈ 1.46 GB for a 96-layer model. With 64 data-parallel ranks that's still just 17.5 + 1.5 = 19 GB peak (a little more while the next block is being prefetched). Manageable. The activation memory (see Chapter 8) is the real concern at scale.

Communication volume comparison. Let M = total parameter bytes (140 GB for 70B bf16). Per training step:

| Strategy | Comm volume / step | Memory / GPU (70B, D=64) | Fits on A100? |
| --- | --- | --- | --- |
| Naive DDP | 2M (all-reduce) | 1120 GB | No (14× over) |
| ZeRO-1 | 2M (RS + AG) | ~293 GB (optimizer sharded) | No (3.7×) |
| ZeRO-2 | 2M (RS + AG) | ~155 GB (opt + grad sharded) | No (1.9×) |
| ZeRO-3 / FSDP | 3M (AG+AG+RS) | 17.5 GB | Yes (4.6× margin) |
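The memory column follows mechanically from the 2 + 2 + 12 bytes-per-parameter accounting used in this chapter. Here is a small sketch that recomputes it for each ZeRO stage; `model_state_gb` is an illustrative helper, and it covers model state only (no activations, no temporary all-gather buffers).

```python
def model_state_gb(params_b, dp, stage):
    """Per-GPU model-state memory in GB.

    params_b: parameter count in billions.
    stage: 0 = naive DDP, 1 = optimizer sharded, 2 = + gradients, 3 = + parameters (FSDP).
    """
    weights, grads, optim = 2.0, 2.0, 12.0    # bytes per parameter
    if stage >= 1:
        optim /= dp
    if stage >= 2:
        grads /= dp
    if stage >= 3:
        weights /= dp
    return params_b * (weights + grads + optim)   # billions of params × bytes = GB


for stage, name in enumerate(["Naive DDP", "ZeRO-1", "ZeRO-2", "ZeRO-3 / FSDP"]):
    gb = model_state_gb(70, dp=64, stage=stage)
    print(f"{name:14s} {gb:8.1f} GB  fits on 80 GB: {gb <= 80}")
```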
[Interactive canvas: FSDP layer-by-layer memory animation: gather → compute → free. Controls: model layers L (default 8), DP replicas D (default 8). The memory bar spikes once per layer, one spike per all-gather.]
FSDP costs 3M communication per step vs DDP's 2M. Yet FSDP is preferred for large models. What is the decisive advantage?

Chapter 6: Sequence Parallelism: Splitting the Activations

You've sharded parameters (FSDP), split layers (PP), and split weight matrices (TP). But there's one more memory consumer that none of these address directly: activation memory. For a transformer layer with batch B, sequence length s, and hidden dimension h, the activations stored for backpropagation cost roughly 10 × B × s × h × 2 bytes. For s=8192, B=1, h=8192, one layer = 10 × 8192 × 8192 × 2 = 1.34 GB. With 96 layers that's 128 GB — already over an A100.
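The activation estimate is a one-line formula, so it is worth wrapping in a function and sweeping sequence length. This sketch takes the chapter's rule of thumb (about 10 stored tensors of shape B × s × h in bf16 per layer) as given; the factor of 10 is the text's approximation, not a measured value.

```python
def activation_gb_per_layer(batch, seq_len, hidden, factor=10, bytes_per_elem=2):
    """Stored activations for one transformer layer, in GB, with no sharding or recomputation."""
    return factor * batch * seq_len * hidden * bytes_per_elem / 1e9


layers = 96
for seq_len in (4096, 8192, 32768):
    per_layer = activation_gb_per_layer(batch=1, seq_len=seq_len, hidden=8192)
    print(f"s={seq_len:6d}: {per_layer:6.2f} GB/layer, {per_layer * layers:7.1f} GB for {layers} layers")
```

The cost is linear in s under this approximation; attention scores that are actually materialized add a term that grows with s², which Chapter 8 returns to.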

Tensor parallelism (from Lec 7) splits the FFN and attention projections across TP ranks. The activations inside those computations are sharded. But some operations are not split: LayerNorm, Dropout, and the residual add all operate on the full B×s×h tensor replicated on every TP rank. These terms can dominate activation memory at long context.

Sequence parallelism (Korthikanti et al. 2022, Megatron-LM) solves this by splitting the sequence dimension across TP ranks for the LayerNorm and Dropout regions. In the regions that are not inside TP shards, each TP rank holds only 1/TP of the sequence dimension. The handoff between "sequence-parallel" and "tensor-parallel" modes requires an all-gather or reduce-scatter at each boundary, but these can be fused with the existing TP communication ops.

Sequence parallelism divides the remaining activations by TP. If TP=8, LayerNorm/Dropout activations are stored at 1/8 of their normal size, matching what TP already does for the attention and FFN activations. For the 96-layer model above, the 128 GB activation footprint then shrinks to roughly 128/8 = 16 GB per GPU. Combined with selective recomputation (Chapter 8), this is usually enough to make long-context training fit.

A different kind of sequence parallelism, sometimes called context parallelism or ring attention, splits the sequence dimension across a separate dimension of GPUs. This is useful when the sequence length itself is the bottleneck (e.g., 1M-context models). Ring attention generalizes FlashAttention to distributed attention: GPU 0 processes query tokens 0-s/P, GPU 1 processes s/P to 2s/P, etc., but the key-value pairs rotate in a ring so every query can attend to every key. This requires 2× the activation bandwidth of local attention but enables sequences that can't fit on any single GPU.
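The ring idea can be checked on a toy example: each simulated GPU keeps its own queries, sees one K/V block per step, and maintains an online-softmax running state so the final result matches full attention. This is a single-process sketch with assumptions stated plainly: no causal mask, no 1/√d scaling, one head, and the "ring" is just index arithmetic.

```python
import math


def full_attention(q, keys, values):
    """Reference: softmax(q·k) weighted sum of values for one query vector."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    mx = max(scores)
    w = [math.exp(s - mx) for s in scores]
    z = sum(w)
    return [sum(wi * v[j] for wi, v in zip(w, values)) / z for j in range(len(values[0]))]


def ring_attention(q_blocks, k_blocks, v_blocks):
    """GPU g owns q_blocks[g]; the K/V blocks rotate around the ring for G steps."""
    G = len(q_blocks)
    dim = len(v_blocks[0][0])
    # Per-query running state: (max score so far, normalizer, unnormalized output)
    state = [[(-math.inf, 0.0, [0.0] * dim) for _ in qb] for qb in q_blocks]
    for step in range(G):
        for g in range(G):
            src = (g - step) % G                      # K/V block held by GPU g at this step
            for i, q in enumerate(q_blocks[g]):
                m, z, acc = state[g][i]
                for k, v in zip(k_blocks[src], v_blocks[src]):
                    s = sum(a * b for a, b in zip(q, k))
                    m_new = max(m, s)
                    scale = math.exp(m - m_new)       # rescale what was accumulated so far
                    p = math.exp(s - m_new)
                    z = z * scale + p
                    acc = [a * scale + p * vj for a, vj in zip(acc, v)]
                    m = m_new
                state[g][i] = (m, z, acc)
    return [[[a / z for a in acc] for (_, z, acc) in block] for block in state]


q_blocks = [[[1.0, 0.0], [0.5, 0.5]], [[0.0, 1.0], [-1.0, 2.0]]]
k_blocks = [[[1.0, 2.0], [0.0, -1.0]], [[-1.0, 0.5], [2.0, 0.0]]]
v_blocks = [[[1.0, 0.0], [0.0, 1.0]], [[2.0, 2.0], [-1.0, 3.0]]]
ring_out = ring_attention(q_blocks, k_blocks, v_blocks)
print(ring_out[0][0])
```

No GPU ever holds more than one K/V block at a time, which is the property that lets the sequence exceed a single device's memory.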

| Technique | What's split | Split dimension | Comm cost per layer |
| --- | --- | --- | --- |
| Tensor Parallel (TP) | Weight matrices, FFN internals | Hidden d_model | All-reduce B×s×h |
| Sequence Parallel (SP) | LayerNorm + Dropout activations | Sequence s | All-gather + RS, fused with TP |
| Context Parallel (CP) | Attention Q/K/V computations | Sequence s (inter-node) | Ring: 2× local attention BW |
Why SP is "free" on top of TP. Plain TP performs an all-reduce at the end of each attention and FFN block, and an all-reduce is exactly a reduce-scatter followed by an all-gather. Sequence parallelism splits the all-reduce into those two halves: the reduce-scatter leaves each rank with 1/TP of the sequence for the LayerNorm/Dropout region, and the all-gather rebuilds the full sequence just before the next TP region. Total communication volume is unchanged, so in well-optimized implementations (e.g., Megatron-LM) SP adds almost zero latency.
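The identity behind this (all-reduce equals reduce-scatter followed by all-gather) can be demonstrated with simulated ranks. In this sketch each inner list is one TP rank's partial result for four tokens, and `tokenwise_op` is a hypothetical stand-in for LayerNorm or Dropout, which act on each token independently.

```python
def all_reduce(per_rank):
    """Every rank ends up with the elementwise sum."""
    total = [sum(vals) for vals in zip(*per_rank)]
    return [list(total) for _ in per_rank]


def reduce_scatter(per_rank):
    """Elementwise sum, but rank r keeps only its 1/n slice (here: a slice of the sequence)."""
    n = len(per_rank)
    total = [sum(vals) for vals in zip(*per_rank)]
    chunk = len(total) // n
    return [total[r * chunk:(r + 1) * chunk] for r in range(n)]


def all_gather(shards):
    """Every rank ends up with all slices concatenated."""
    full = [v for s in shards for v in s]
    return [list(full) for _ in shards]


def tokenwise_op(xs):
    """Stand-in for LayerNorm/Dropout: acts on each token independently."""
    return [0.5 * x for x in xs]


partials = [[1.0, 2.0, 3.0, 4.0], [10.0, 20.0, 30.0, 40.0]]   # TP=2, four tokens

# Plain TP: all-reduce, then every rank runs the op on the full sequence.
tp_only = [tokenwise_op(x) for x in all_reduce(partials)]

# TP + SP: reduce-scatter, run the op on 1/TP of the sequence, all-gather.
seq_shards = [tokenwise_op(s) for s in reduce_scatter(partials)]
tp_plus_sp = all_gather(seq_shards)

print(tp_only == tp_plus_sp, [len(s) for s in seq_shards])
```

Both paths move the same data and produce the same result; the difference is that between the two collectives each rank stores activations for only half the tokens.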
Sequence parallelism splits activations along the sequence dimension. Which operation specifically motivates this — i.e., which layer is NOT already sharded by tensor parallelism alone?

Chapter 7: Showcase: 3D Parallelism Configurator

Production frontier runs combine all three axes simultaneously. The standard recipe (Megatron-LM, Narayanan et al. 2021): TP within each node (uses NVLink, no inter-node latency), PP across nodes (only activation tensors cross the network, one per stage boundary, not per layer), DP for the rest (replicate the TP+PP group, use ZeRO-1 or ZeRO-2 on top for free memory wins). Total GPU count: TP × PP × DP.

Let's verify the math with a real example: 512 GPUs, 70B model. Cluster: 64 nodes × 8 GPUs each. Recipe: TP=8 (fill one node), PP=8 (8 nodes in a pipeline), DP=8 (8 copies of the TP+PP group). Total: 8 × 8 × 8 = 512. ✓

Memory per GPU (TP=8, PP=8, ZeRO-1 on DP=8): weights = 140 GB / (8×8) = 2.19 GB. Adam optimizer state = 840 GB / (8×8×8) = 1.64 GB. Gradients (replicated in ZeRO-1) = 140 GB / 64 = 2.19 GB. Total model state ≈ 6 GB per GPU. Activation memory for one sequence (s=4096, B=1, 96/8=12 layers per stage, h=8192): 12 × 10 × 4096 × 8192 × 2 bytes = 8.05 GB, or about 1 GB after TP and SP divide it by 8. On an 80 GB A100 that leaves roughly 74 GB for activations and batch scaling. A batch of 32 sequences at s=4096 adds 32 × 1 GB = 32 GB. Fits with headroom.
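The same budget can be written as a function of the three parallelism degrees. This is a rough estimator under the chapter's assumptions (2/2/12 bytes per parameter, 10 × B × s × h × 2 bytes of activations per layer, TP plus SP dividing activations by TP, one micro-batch counted); `per_gpu_memory_gb` is an illustrative name.

```python
def per_gpu_memory_gb(params_b, tp, pp, dp, layers, seq_len, hidden,
                      micro_batch=1, zero_stage=1):
    """Approximate per-GPU memory breakdown in GB for a TP × PP × DP run."""
    model_shard = tp * pp                                   # TP and PP both split parameters
    weights = 2 * params_b / model_shard
    grads = 2 * params_b / model_shard / (dp if zero_stage >= 2 else 1)
    optim = 12 * params_b / model_shard / (dp if zero_stage >= 1 else 1)
    acts = (layers / pp) * 10 * micro_batch * seq_len * hidden * 2 / 1e9 / tp
    return {"weights": weights, "grads": grads, "optim": optim, "activations": acts}


mem = per_gpu_memory_gb(params_b=70, tp=8, pp=8, dp=8, layers=96, seq_len=4096, hidden=8192)
for name, gb in mem.items():
    print(f"{name:12s} {gb:5.2f} GB")
print(f"{'total':12s} {sum(mem.values()):5.2f} GB")
```

Under 1F1B the first stage holds several micro-batches at once, so the activation term should be multiplied by the number in flight when sizing the worst-case stage.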

Real run numbers (Llama 3 405B training, Meta 2024). The Llama 3 paper reports TP=8 (within an H100 node) and PP=16 (across nodes), with data parallelism filling out a cluster of up to 16K H100 GPUs, and model FLOP utilization (MFU) of roughly 38–43%. As a smaller illustration with the same TP and PP, take DP=16: 8 × 16 × 16 = 2048 GPUs. With m=128 micro-batches per pipeline pass, the bubble fraction is (16-1)/(128+15) = 15/143 ≈ 10.5%. The utilization that is not achieved goes to communication, bubble, and memory management.
[Interactive canvas: 3D Parallelism Configurator. Controls: model size (B params, default 70), tensor parallel TP (1-8, within node, default 8), pipeline parallel PP (across nodes, default 8), data parallel DP (replicas, default 8), micro-batches m (default 32). Shows total GPUs, memory per GPU, bubble fraction, and estimated comm overhead.]
A 3D run with TP=8, PP=16, DP=16, m=64 micro-batches. What is the bubble fraction and total GPU count?

Chapter 8: Activation Recomputation: Trading FLOPs for Memory

Even with all four forms of parallelism — TP, PP, DP, SP — activation memory can be the binding constraint at long context. The activations stored for the backward pass grow as O(L × B × s × h). Pipeline parallelism helps by splitting L, tensor + sequence parallelism split h and s. But with context lengths of 8192, 32768, or more, activations still dominate.

Activation recomputation (also called gradient checkpointing) spends extra compute to save memory: instead of storing all intermediate activations, you store only "checkpoints" at selected points in the network, and recompute the intermediate values on-the-fly during the backward pass. The classic full recomputation stores only the input to each transformer block, recomputing the entire block during backward. Cost: one extra forward pass worth of compute (33% overhead). Benefit: stored activations drop from roughly 10 × B × s × h × 2 bytes per layer to a single B × s × h × 2 byte block input per layer, about a 10× reduction, plus the working set of the one block being recomputed.
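The mechanics of checkpointing fit in a toy example: a chain of scalar layers x → tanh(w·x), where a run of `seg` consecutive layers stands in for one transformer block. The sketch compares storing everything against storing only block inputs and recomputing; it is an illustration of the idea, not how any framework implements it.

```python
import math


def forward_segment(x, weights):
    """Run consecutive layers x -> tanh(w * x); return the input and every intermediate."""
    xs = [x]
    for w in weights:
        xs.append(math.tanh(w * xs[-1]))
    return xs


def backward_segment(xs, weights, grad_out):
    """Backprop through one segment given its saved values xs[0..n]."""
    grads_w = [0.0] * len(weights)
    g = grad_out
    for i in reversed(range(len(weights))):
        local = 1.0 - xs[i + 1] ** 2          # derivative of tanh at this layer
        grads_w[i] = g * local * xs[i]
        g = g * local * weights[i]
    return g, grads_w


def grad_full_store(x0, weights):
    xs = forward_segment(x0, weights)         # keeps every activation
    _, gw = backward_segment(xs, weights, 1.0)
    return gw, len(xs)


def grad_checkpointed(x0, weights, seg):
    checkpoints, x = [], x0
    for s in range(0, len(weights), seg):     # forward: keep only each block's input
        checkpoints.append(x)
        x = forward_segment(x, weights[s:s + seg])[-1]
    gw, g = [0.0] * len(weights), 1.0
    peak = len(checkpoints) + seg             # block inputs + one recomputed block
    for idx in reversed(range(len(checkpoints))):
        s = idx * seg
        xs = forward_segment(checkpoints[idx], weights[s:s + seg])   # the extra forward
        g, gw[s:s + seg] = backward_segment(xs, weights[s:s + seg], g)
    return gw, peak


weights = [0.9, -1.1, 0.7, 1.3, -0.8, 1.0, 0.6, -1.2, 1.1, 0.5, -0.9, 1.4]
full_grads, full_stored = grad_full_store(0.3, weights)
ckpt_grads, ckpt_stored = grad_checkpointed(0.3, weights, seg=4)
print(full_stored, ckpt_stored)
```

The gradients are identical; the only differences are the peak number of stored values and the one extra forward pass over each block.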

Selective recomputation: the best of both worlds. Not all activations cost equally. Attention scores (shape B × heads × s × s) scale quadratically with sequence length — they are expensive to store but cheap-ish to recompute. The FFN intermediate activations (shape B × s × 4h) are cheaper to store. Selective recomputation (Megatron-LM) recomputes only the attention softmax and dropout activations, which account for most of the activation memory at long context, while storing the cheaper ones. Typical result: 60–70% of full recomputation's memory benefit at only ~5% compute overhead instead of 33%.

The decision rule: compare the time to recompute a checkpoint versus the time saved by not keeping those activations in HBM. If arithmetic intensity is high (heavy recompute cost) and the activations are small, store them. If activations are large (long sequence, large batch) and the recompute is cheap (a few elementwise ops), recompute them. In practice:

| Activation type | Size | Recompute cost | Recommendation |
| --- | --- | --- | --- |
| Attention scores Q·Kᵀ/√d | B × heads × s × s × 2 bytes | O(s²·d) FLOPs | Recompute (FlashAttn already does this) |
| Attention dropout | B × heads × s × s × 1 bit | Cheap (Bernoulli) | Recompute (binary mask, tiny) |
| FFN activation (GELU) | B × s × 4h × 2 bytes | O(s·h) FLOPs | Store (cheap, but measurable) |
| LayerNorm inputs | B × s × h × 2 bytes | One sum + sqrt | Recompute if SP is active |
FlashAttention already recomputes by design. The online softmax trick in FlashAttention doesn't materialize the full s×s attention matrix. The forward pass saves only the output O and the per-row softmax statistics (the log-sum-exp), and the backward pass recomputes the attention scores tile by tile from Q, K, and V. So using FlashAttention automatically gives you activation recomputation for the most expensive activation, essentially for free.
[Interactive canvas: Memory vs compute tradeoff: recomputation impact across strategies. Controls: sequence length in tokens (default 4096), batch size per GPU (default 2). Compares DDP vs ZeRO-1/2/3/FSDP on memory and communication, including activation cost at the given sequence length.]
Full activation recomputation stores only one input per transformer block. What does that cost in compute compared to standard backprop?

Chapter 9: Connections: Putting It All Together

You now have all three axes of 3D parallelism plus the supporting tools that make them work at scale. Here's a decision guide for a real run:

The 3D parallelism recipe.
  1. Fill the node with TP. Set TP = min(8, max_useful_split). TP divides parameter memory and the attention/FFN activations by TP. Add sequence parallelism so the LayerNorm/Dropout activations, which TP leaves replicated, are divided by TP as well. Communication: NVLink (fast).
  2. Stack nodes with PP. Set PP to the number of nodes needed to fit the model after TP. Each stage holds L/PP layers. Communication: one activation tensor per stage boundary per micro-batch (InfiniBand, but low volume). Pick m such that bubble = (PP-1)/(m+PP-1) < 20%.
  3. Scale compute with DP. Replicate the TP×PP group across DP copies. Add ZeRO-1 or ZeRO-2 on the DP axis for essentially free optimizer memory savings. Add FSDP if memory is still tight. Communication: gradient all-reduce once per step (InfiniBand, but can overlap with compute).
  4. Add activation recomputation if activation memory is the bottleneck. Use selective recomputation first (only attention scores and dropout), full recomputation only if needed.
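The first three steps of the recipe can be turned into a tiny planner. This is a sketch under simplifying assumptions: PP is supplied by the caller (chosen so the model fits), TP always fills the node, and the only constraint on m is the bubble target. `plan_3d` and `min_micro_batches` are hypothetical helper names.

```python
def min_micro_batches(pp, max_bubble):
    """Smallest m with (pp - 1) / (m + pp - 1) strictly below max_bubble."""
    m = 1
    while (pp - 1) / (m + pp - 1) >= max_bubble:
        m += 1
    return m


def plan_3d(total_gpus, gpus_per_node, pp, max_bubble=0.2):
    """Pick TP, DP and a micro-batch count for a given cluster and pipeline depth."""
    tp = gpus_per_node                          # step 1: fill the node with TP
    if total_gpus % (tp * pp) != 0:
        raise ValueError("total_gpus must be divisible by tp * pp")
    dp = total_gpus // (tp * pp)                # step 3: whatever is left is data parallel
    m = min_micro_batches(pp, max_bubble)       # step 2: enough micro-batches to hide the bubble
    return {"tp": tp, "pp": pp, "dp": dp, "micro_batches": m,
            "bubble": (pp - 1) / (m + pp - 1)}


print(plan_3d(total_gpus=2048, gpus_per_node=8, pp=16))
```

A real planner would also check the per-GPU memory budget and that the global batch size is divisible by DP × m; those checks are left out to keep the sketch short.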
| Technique | Memory win | Compute cost | Comm cost | Best for |
| --- | --- | --- | --- | --- |
| TP | ÷TP (weights + activations) | None | All-reduce per layer (NVLink) | Intra-node, always on |
| PP | ÷PP (weights + opt state) | Bubble fraction | Activation send per stage | Inter-node, large models |
| DP | None (naive) | None | Gradient all-reduce | Compute scaling |
| ZeRO-1/2 | ÷DP (opt/grad) | None | Same as DP | Free win on top of DP |
| FSDP (ZeRO-3) | ÷DP (all) | None | 3M vs 2M (50% more) | Cannot fit even with TP+PP |
| Sequence Parallel | ÷TP (LayerNorm act.) | None | Fused with TP comm | Long context + large TP |
| Activation Recompute | One block input per layer (full) | +33% (full), ~5% (selective) | None | Long context, large batch |
Bubble fraction worked example (Llama 3 405B-style configuration). Take PP=16 and, for illustration, m=64 micro-batches per pipeline pass. Bubble = (16-1)/(64+15) = 15/79 ≈ 19%. With a reported MFU of 38%, removing the bubble entirely while leaving everything else unchanged would give about 38% / (1 − 0.19) ≈ 47% MFU. So the bubble accounts for roughly 9 percentage points of lost efficiency; communication overlap hides much of the remaining overhead.
[Interactive canvas: Bubble fraction vs micro-batch count (bubble curve). Control: pipeline stages PP (default 8). Plots bubble fraction as m increases; the dashed line marks 20%, a common practical target.]
"The hardware is getting faster, but we will always want to train bigger models — so this hierarchical, multi-axis structure of parallelism will always be with us. The only thing that changes is which bandwidth is the bottleneck."
— Percy Liang, CS336 Lecture 8
Cheat sheet: key formulas.
  • Pipeline bubble fraction: (P−1) / (m + P−1)
  • Total GPUs for 3D: N_TP × N_PP × N_DP
  • FSDP memory per GPU: (2+2+12) × Params / D bytes
  • FSDP comm per step: 3M (2 all-gather + 1 reduce-scatter), DDP = 2M
  • Activation memory per layer: 10 × B × s × h × 2 bytes (no recompute)
  • Recompute overhead: +33% FLOPs (full), +5% (selective attention only)
  • PP activation tensor size: B × s × h × 2 bytes per stage boundary
For a 512-GPU run: TP=8, PP=8, DP=8. A 70B model. Bubble target < 20%. What is the minimum number of micro-batches m needed?