A 70B model's weights alone weigh 140 GB in FP16 — and training needs weights + gradients + Adam states, which totals over 1 TB. No single GPU holds that. You must distribute computation across hundreds of GPUs while keeping them synchronized. This lesson derives every communication primitive from scratch, shows exactly how ZeRO shards optimizer states, gradients, and weights to recover memory, explains how Megatron splits a single matmul across GPUs, and derives the pipeline bubble formula that quantifies the idle-time tax of staging layers. Every claim comes with numbers.
Let's start with a concrete problem. You want to train LLaMA-2 70B, the open-source model released by Meta. Its parameter count is 70 × 10⁹. In FP16 (2 bytes/param) that's 140 GB just for weights. An NVIDIA H100, among the most capable GPUs available in 2024, has 80 GB of HBM. The model weights alone don't fit on a single GPU.
But that's only the beginning. Training adds two more memory buckets beyond the weights. Gradients for every parameter: another 140 GB in FP16. Adam's optimizer states: an FP32 master copy of the weights (4 bytes/param = 280 GB), the first moment (4 bytes/param), and the second moment (4 bytes/param), totaling 12 bytes per parameter = 840 GB. Grand total: 140 + 140 + 840 = 1,120 GB (16 bytes/param) to train LLaMA-2 70B with standard mixed-precision Adam, before counting activations. That's 14 fully-loaded H100s just to hold the state, and you haven't run a single forward pass yet.
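Even for a "small" 7B model: weights (14 GB) + gradients (14 GB) + Adam states (84 GB) = 112 GB. Still doesn't fit on one 80 GB GPU. And even if it did, compute is the other wall: with the common estimate of 6 FLOPs per parameter per token, training 7B on a trillion tokens needs about 4.2 × 10²² FLOPs, which would keep a single A100 busy for over four years even at its 312 TFLOPS peak. For scale, GPT-3's training has been estimated at about 355 GPU-years on V100s, roughly 3.1 million GPU-hours. Distributed training isn't optional; it's the only path.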
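The memory and compute figures above are plain arithmetic, so a few lines of Python reproduce them. This is a minimal sketch assuming 16 bytes/param for mixed-precision Adam (2 + 2 + 12), the common 6 × params × tokens FLOP estimate, and a hypothetical perfectly utilized A100 at its 312 TFLOPS FP16 peak; real utilization is well below peak, so real wall-clock time is longer. The helper names are our own.

```python
import math

BYTES_PER_PARAM = 2 + 2 + 12   # FP16 weights + FP16 grads + FP32 master/m/v (Adam)

def training_state_gb(params: float) -> float:
    """Model state (weights, gradients, Adam states) in GB, excluding activations."""
    return params * BYTES_PER_PARAM / 1e9

def gpus_to_hold_state(params: float, gpu_gb: float = 80.0) -> int:
    """Minimum number of GPUs whose combined memory could hold the model state."""
    return math.ceil(training_state_gb(params) / gpu_gb)

def training_flops(params: float, tokens: float) -> float:
    """Approximate training compute: ~6 FLOPs per parameter per token (fwd + bwd)."""
    return 6 * params * tokens

def single_gpu_years(params: float, tokens: float, flops_per_s: float = 312e12) -> float:
    """Years on one GPU sustaining `flops_per_s` (perfect utilization is an assumption)."""
    seconds = training_flops(params, tokens) / flops_per_s
    return seconds / (365.25 * 24 * 3600)

for name, p in [("7B", 7e9), ("70B", 70e9)]:
    print(f"{name}: {training_state_gb(p):,.0f} GB of state, "
          f">= {gpus_to_hold_state(p)} x 80 GB GPUs")
print(f"7B on 1T tokens: {training_flops(7e9, 1e12):.1e} FLOPs, "
      f"{single_gpu_years(7e9, 1e12):.1f} years on one A100 at peak")
```

Note that `gpus_to_hold_state` is only a lower bound: it assumes the state can be split perfectly with nothing else in memory.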
There are three ways to distribute a training job. Data parallelism: every GPU holds a full model copy but processes a different mini-batch shard; gradients are synchronized across GPUs. Tensor parallelism: a single layer's weight matrix is split across multiple GPUs that each compute a partial result and merge it. Pipeline parallelism: different layers run on different GPUs in sequence, like an assembly line. In practice, training large models uses all three simultaneously — called 3D parallelism.
The accompanying chart stacks each model's memory (weights + gradients + Adam states) against the capacity of one 80 GB H100 (red dashed line) and shows how per-GPU memory responds to the GPU count under naive data parallelism (no sharding).
Before we can understand how GPUs synchronize during training, we need to establish the vocabulary of collective communication. These are the fundamental operations — implemented in NCCL (NVIDIA's GPU communication library) — that underpin every distributed training system.
The simplest primitive is Send / Recv: transfer a tensor from one process to another. This is one-to-one communication, the building block for everything else. More interesting are the one-to-many and many-to-one operations. Broadcast sends an identical copy from one GPU to all others. Scatter splits a tensor and sends each chunk to a different GPU. Gather collects chunks from all GPUs into one destination. Reduce is like Gather but applies an aggregation (usually summation or averaging) during collection, so [3][5][2][4] reduces to [14] at the destination.
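To make the semantics concrete, here is a toy, single-process model of these primitives in plain Python. Each "GPU" is just a list of numbers, and the functions are hypothetical helpers, not NCCL or `torch.distributed` calls; they only show what each rank holds afterwards.

```python
from typing import List

Buffer = List[int]   # one GPU's tensor, modeled as a flat list

def send_recv(bufs: List[Buffer], src: int, dst: int) -> List[Buffer]:
    """One-to-one: copy GPU src's tensor into GPU dst."""
    out = [list(b) for b in bufs]
    out[dst] = list(bufs[src])
    return out

def broadcast(bufs: List[Buffer], root: int) -> List[Buffer]:
    """One-to-many: every GPU ends up with a copy of the root's tensor."""
    return [list(bufs[root]) for _ in bufs]

def scatter(data: Buffer, n: int) -> List[Buffer]:
    """One-to-many: split one tensor into n equal chunks, one per GPU."""
    k = len(data) // n
    return [data[i * k:(i + 1) * k] for i in range(n)]

def gather(chunks: List[Buffer]) -> Buffer:
    """Many-to-one: concatenate every GPU's chunk at a single destination."""
    return [x for chunk in chunks for x in chunk]

def reduce(bufs: List[Buffer]) -> Buffer:
    """Many-to-one: elementwise sum of all GPUs' tensors at a single destination."""
    return [sum(vals) for vals in zip(*bufs)]

print(reduce([[3], [5], [2], [4]]))   # [14], the example from the text
```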
The most important primitives for training are many-to-many operations. All-Reduce performs Reduce across all GPUs and then broadcasts the result back — every GPU ends up with the same aggregated tensor. All-Gather is like Gather but every GPU ends up holding all chunks. Reduce-Scatter is the inverse: reduce all chunks, then scatter so each GPU holds one shard of the reduced result. These three — All-Reduce, All-Gather, and Reduce-Scatter — are the workhorses of distributed training.
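The three many-to-many primitives can be modeled the same way. The toy sketch below (again hypothetical helpers, not a real communication library) also checks a fact the ring algorithm relies on: an All-Reduce is exactly a Reduce-Scatter followed by an All-Gather.

```python
from typing import List

Buffer = List[int]

def all_reduce(bufs: List[Buffer]) -> List[Buffer]:
    """Every GPU ends up with the elementwise sum of all tensors."""
    total = [sum(vals) for vals in zip(*bufs)]
    return [list(total) for _ in bufs]

def all_gather(shards: List[Buffer]) -> List[Buffer]:
    """Every GPU ends up with the concatenation of all shards."""
    full = [x for shard in shards for x in shard]
    return [list(full) for _ in shards]

def reduce_scatter(bufs: List[Buffer]) -> List[Buffer]:
    """GPU i ends up with chunk i of the elementwise sum."""
    n = len(bufs)
    total = [sum(vals) for vals in zip(*bufs)]
    k = len(total) // n
    return [total[i * k:(i + 1) * k] for i in range(n)]

grads = [[1, 2, 3, 4], [10, 20, 30, 40]]          # two GPUs, four values each
assert all_reduce(grads) == all_gather(reduce_scatter(grads))
print(reduce_scatter(grads))                       # [[11, 22], [33, 44]]
```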
In the classic Parameter Server architecture, each iteration starts with a Broadcast (the server sends the current weights to all workers) and ends with a Reduce (workers push their gradients and the server aggregates them); in between, each worker runs its forward and backward pass. The problem: the server's traffic grows as O(N) with the number of workers, making it the bottleneck. At N=128 workers, the server must receive 128× the gradient traffic of any single worker. This is why modern distributed training uses All-Reduce instead: no central bottleneck, and the load is symmetric across all nodes.
| Operation | Who sends | Who receives | Result | Use in training |
|---|---|---|---|---|
| Broadcast | 1 GPU | All GPUs | All have same copy | Param server → workers |
| Reduce | All GPUs | 1 GPU | 1 GPU has sum | Workers → param server |
| All-Reduce | All GPUs | All GPUs | All have same sum | Gradient sync in DDP |
| All-Gather | All GPUs | All GPUs | All have full tensor | ZeRO-3: gather parameter shards before each layer's forward and backward |
| Reduce-Scatter | All GPUs | All GPUs | Each holds 1 shard of the sum | ZeRO-2/3: shard the reduced gradients |
The naive all-reduce is sequential: GPU 0 reduces with GPU 1, then with GPU 2, and so on. This takes O(N) steps and O(N) bandwidth at the aggregator — the same bottleneck as a parameter server. Can we do better?
The ring all-reduce arranges all N GPUs in a logical ring: 0 → 1 → 2 → ... → N-1 → 0. Each GPU sends data to its right neighbor and receives from its left neighbor simultaneously. Split the gradient tensor into N equal chunks. In Phase 1 (Reduce-Scatter), over N-1 steps, each GPU sends one chunk's running partial sum to its right neighbor, which adds its own copy of that chunk before passing it on. After N-1 steps, each GPU holds the fully-reduced version of exactly one chunk, which is 1/N of the total data. In Phase 2 (All-Gather), over another N-1 steps, each GPU sends its reduced chunk around the ring. After 2(N-1) total steps, every GPU holds every reduced chunk, the same final result as a centralized all-reduce.
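The two phases can be simulated in a single process to check that the ring really produces the full sum on every GPU. This is a sketch under simplifying assumptions: each "GPU" is a Python list, all sends within a step happen simultaneously (messages are snapshotted before being applied), and the tensor length is divisible by N. The function name is our own, not an NCCL API.

```python
from typing import List, Tuple

def ring_all_reduce(bufs: List[List[int]]) -> Tuple[List[List[int]], List[int]]:
    """Simulate ring all-reduce; return final buffers and elements sent per GPU."""
    n, size = len(bufs), len(bufs[0])
    assert size % n == 0, "tensor length must divide evenly into N chunks"
    k = size // n
    data = [list(b) for b in bufs]      # data[r] is GPU r's buffer
    sent = [0] * n

    def chunk(c: int) -> slice:
        return slice(c * k, (c + 1) * k)

    # Phase 1, reduce-scatter: at step s, GPU r sends chunk (r - s) mod n to GPU r+1,
    # which adds it to its own copy of that chunk.
    for s in range(n - 1):
        msgs = [(r, (r - s) % n, data[r][chunk((r - s) % n)]) for r in range(n)]
        for r, c, payload in msgs:
            dst = (r + 1) % n
            data[dst][chunk(c)] = [a + b for a, b in zip(data[dst][chunk(c)], payload)]
            sent[r] += k
    # Now GPU r holds the fully reduced chunk (r + 1) mod n.

    # Phase 2, all-gather: at step s, GPU r forwards chunk (r + 1 - s) mod n,
    # and the receiver overwrites its stale copy.
    for s in range(n - 1):
        msgs = [(r, (r + 1 - s) % n, data[r][chunk((r + 1 - s) % n)]) for r in range(n)]
        for r, c, payload in msgs:
            data[(r + 1) % n][chunk(c)] = list(payload)
            sent[r] += k
    return data, sent

bufs = [[10 * r + i for i in range(8)] for r in range(4)]   # 4 GPUs, 8 values each
out, sent = ring_all_reduce(bufs)
print(out[0])   # the elementwise sum, identical on every GPU
print(sent)     # 2(N-1)/N * 8 = 12 elements sent per GPU
```

Every GPU sends the same number of elements, which is the symmetric load that the parameter server lacks.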
The bandwidth math is elegant. At each step, each GPU sends and receives exactly (data size / N). Over 2(N-1) steps, each GPU therefore sends a total of 2(N-1)/N × data_size bytes, which approaches 2 × data_size as N → ∞. The per-GPU load is bounded by a constant independent of N, unlike the O(N) load on a central aggregator. The ring all-reduce communication cost per GPU is:

$$\text{bytes sent per GPU} = \frac{2(N-1)}{N}\,D \;\longrightarrow\; 2D \quad (N \to \infty)$$
Where D is the total gradient size in bytes. This is bandwidth-optimal: a known lower bound says that in any all-reduce each GPU must send at least 2(N-1)/N × D bytes (roughly, its data has to leave once to be reduced and the result has to come back once), and the ring meets that bound. For 7B model gradients in FP16, D = 14 GB, so the ring moves about 2 × 14 GB = 28 GB per GPU (24.5 GB exactly at N=8). Most of that traffic can be overlapped with backward computation: PyTorch DDP groups gradients into buckets and starts each bucket's all-reduce as soon as its gradients are ready, before the backward pass finishes.
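The per-GPU volume formula is easy to evaluate directly. The sketch below (hypothetical helper names) compares it with the traffic that lands on a parameter server and converts bytes into seconds using an assumed per-GPU link bandwidth. The 25 GB/s figure is an illustrative assumption, not a measurement, and this bandwidth-only model ignores the per-step latency that real all-reduces also pay.

```python
def ring_allreduce_bytes_per_gpu(data_bytes: float, n: int) -> float:
    """Bytes each GPU sends in a ring all-reduce: 2(N-1)/N * D."""
    return 2 * (n - 1) / n * data_bytes

def param_server_bytes_received(data_bytes: float, n: int) -> float:
    """Gradient bytes a single parameter server receives per step: N * D."""
    return n * data_bytes

def ring_allreduce_seconds(data_bytes: float, n: int, link_bytes_per_s: float) -> float:
    """Bandwidth-only time estimate (ignores per-step latency)."""
    return ring_allreduce_bytes_per_gpu(data_bytes, n) / link_bytes_per_s

D = 14e9   # 7B-parameter gradients in FP16
for n in (8, 64, 1024):
    print(f"N={n:5d}: ring {ring_allreduce_bytes_per_gpu(D, n) / 1e9:5.2f} GB/GPU, "
          f"param server receives {param_server_bytes_received(D, n) / 1e9:8.0f} GB, "
          f"~{ring_allreduce_seconds(D, n, 25e9):.2f} s at an assumed 25 GB/s")
```

The ring's per-GPU volume creeps up toward 2D = 28 GB and stops there, while the server's load keeps growing with N.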
The accompanying visualization shows the ring for a chosen number of GPUs and a chosen step; each colored segment is a partial-sum chunk. After the reduce-scatter phase (N-1 steps), each node holds one fully-reduced chunk. After the all-gather phase (another N-1 steps), every node holds all chunks.
Data parallelism is the most natural way to scale training. The idea: give every GPU a complete, identical copy of the model. Split the training batch across GPUs — if you have N GPUs and a batch of B, each GPU sees B/N samples. Each GPU runs a full forward and backward pass on its shard, computing local gradients. Then, all-reduce the gradients across GPUs. Now every GPU has the same gradient — as if the full batch B had been processed on a single GPU. Every GPU performs the same weight update. Repeat.
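Why does averaging local gradients reproduce the full-batch gradient? For a loss that is a mean over samples and equal-sized shards, the mean of the shard means is the overall mean. The toy check below uses a one-parameter linear model with squared error (our own illustrative choice) and exact fractions, so the equality holds exactly rather than up to rounding.

```python
from fractions import Fraction

def grad_mse(w: Fraction, batch) -> Fraction:
    """d/dw of mean((w*x - y)^2) over the batch."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)

def data_parallel_grad(w: Fraction, batch, n_gpus: int) -> Fraction:
    """Each 'GPU' computes a local gradient on its shard; all-reduce(sum) / N averages them."""
    shard = len(batch) // n_gpus
    assert shard * n_gpus == len(batch), "equal shard sizes assumed"
    local = [grad_mse(w, batch[i * shard:(i + 1) * shard]) for i in range(n_gpus)]
    return sum(local) / n_gpus

batch = [(x, 3 * x + 1) for x in range(1, 9)]   # 8 samples of y = 3x + 1
w = Fraction(1, 2)
print(data_parallel_grad(w, batch, 4) == grad_mse(w, batch))   # True
```

With unequal shards (for example a ragged last batch), a plain average of shard means is no longer exact; each shard's gradient would have to be weighted by its size.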
In PyTorch, this is DistributedDataParallel (DDP). DDP registers autograd hooks that group gradients into buckets (25 MB by default) and launch each bucket's all-reduce as soon as all of its gradients are ready, while the rest of the backward pass is still running. This overlap of computation and communication hides most of the all-reduce time behind the backward pass, so DDP scales close to linearly up to moderate GPU counts.
The math is clean. With N GPUs, each processes B/N samples per step. The effective batch size is still B, because averaging the N local gradients reproduces the full-batch gradient. Ideally, throughput scales linearly: N GPUs process N× as many tokens per second. There is a subtle catch, though: very large batches degrade convergence. The "linear scaling rule" (scale the learning rate in proportion to the batch size) held up to a batch size of about 8K in ImageNet experiments; beyond that, large-batch training needs learning-rate warm-up, layer-wise optimizers such as LARS/LAMB, and often still reaches lower final accuracy than small-batch training. Throughput also has limits at scale: GPT-3 has 175B parameters and was trained on 300B tokens, and even on 1,024 A100s that works out to roughly 34 days. Data parallelism handles throughput, not model scale.
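The 34-day figure can be reproduced with a back-of-the-envelope estimate. This sketch assumes 8 FLOPs per parameter per token (6 for the forward and backward passes plus 2 for recomputing the forward under activation checkpointing) and a sustained 140 TFLOPS per A100, about 45% of its 312 TFLOPS peak. Both numbers are assumptions, and changing either moves the answer proportionally.

```python
def training_days(params: float, tokens: float, n_gpus: int,
                  flops_per_gpu: float, flops_per_param_token: float = 8) -> float:
    """Wall-clock days = total training FLOPs / aggregate sustained FLOP/s."""
    total_flops = flops_per_param_token * params * tokens
    return total_flops / (n_gpus * flops_per_gpu) / 86400

gpt3 = training_days(175e9, 300e9, n_gpus=1024, flops_per_gpu=140e12)
print(f"GPT-3 on 1,024 A100s: ~{gpt3:.0f} days")
# Doubling the GPU count halves the time (ideal data-parallel scaling)
print(f"on 2,048 A100s: ~{training_days(175e9, 300e9, 2048, 140e12):.0f} days")
```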
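```python
# PyTorch DDP: launch with `torchrun --nproc_per_node=8 train.py`
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

dist.init_process_group(backend="nccl")       # NCCL for GPU-to-GPU traffic
local_rank = int(os.environ["LOCAL_RANK"])    # set by torchrun: this GPU's index on its node
torch.cuda.set_device(local_rank)
device = torch.device(f"cuda:{local_rank}")

# MyModel, dataset, forward(), and num_epochs are placeholders for your own code
model = MyModel().to(device)
model = DDP(model, device_ids=[local_rank])   # wraps model, registers gradient hooks
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

sampler = DistributedSampler(dataset)         # each rank reads a disjoint 1/N shard
dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)

# Training loop: DDP handles the all-reduce transparently
for epoch in range(num_epochs):
    sampler.set_epoch(epoch)                  # reshuffle the shards every epoch
    for batch in dataloader:
        loss = forward(model, batch)
        loss.backward()                       # bucketed all-reduce overlaps with backward
        optimizer.step()                      # same averaged gradient on all GPUs
        optimizer.zero_grad()

dist.destroy_process_group()
# Gradient sync volume per GPU: 2*(N-1)/N * 14 GB = 24.5 GB for 7B in FP16, N=8
```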
Data parallelism wastes memory on redundancy: N GPUs hold N identical copies of the same weights, gradients, and optimizer states. ZeRO (Zero Redundancy Optimizer), from DeepSpeed (Microsoft Research), eliminates this redundancy by sharding these states across the N data-parallel GPUs while keeping the same computation pattern.
ZeRO has three progressive stages. ZeRO-1 shards only the optimizer states. Each GPU holds full weights and full gradients but only 1/N of the optimizer states (FP32 master weights, momentum, and variance). Memory: (2 + 2 + 12/N) bytes/param. For N=64: 2 + 2 + 0.1875 = 4.1875 bytes/param, so a 7B model needs 7B × 4.19 ≈ 29.3 GB per GPU. ZeRO-2 also shards gradients: each GPU holds full weights but only 1/N of the gradients and 1/N of the optimizer states. Memory: (2 + 2/N + 12/N) = (2 + 14/N) bytes/param. For N=64: 2.22 bytes/param (about 15.5 GB for 7B). ZeRO-3 shards everything: weights, gradients, and optimizer states. Memory: (2 + 2 + 12)/N = 16/N bytes/param. For N=64: 0.25 bytes/param, so 7B × 0.25 = 1.75 GB per GPU. Turned around, 0.25 bytes/param × 320B params = 80 GB: in principle, the model state of a 320B model would just fit across 64 80-GB A100s, although activations and communication buffers push the practical limit lower.
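The stages differ only in which of the 2 + 2 + 12 bytes get divided by N. Here is a minimal calculator, assuming mixed-precision Adam and ignoring activations and communication buffers; the function names are ours, and "stage 0" stands for plain data parallelism.

```python
def zero_bytes_per_param(stage: int, n: int) -> float:
    """Per-GPU model-state bytes per parameter for ZeRO stage 0 (plain DP) to 3."""
    weights, grads, optim = 2, 2, 12          # FP16 W, FP16 grad, FP32 master + m + v
    if stage == 0:
        return weights + grads + optim
    if stage == 1:
        return weights + grads + optim / n
    if stage == 2:
        return weights + (grads + optim) / n
    if stage == 3:
        return (weights + grads + optim) / n
    raise ValueError("stage must be 0, 1, 2, or 3")

def zero_gb_per_gpu(params: float, stage: int, n: int) -> float:
    return params * zero_bytes_per_param(stage, n) / 1e9

for stage in range(4):
    print(f"ZeRO-{stage}: {zero_bytes_per_param(stage, 64):.4f} B/param, "
          f"7B at N=64 -> {zero_gb_per_gpu(7e9, stage, 64):.2f} GB, "
          f"at N=2 -> {zero_gb_per_gpu(7e9, stage, 2):.1f} GB")
```

The N=2 column shows that for 7B every ZeRO stage already fits under 80 GB with just two GPUs (70, 63, and 56 GB), while plain DP stays at 112 GB.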
But ZeRO-3 has a catch: before a GPU can run the forward or backward pass through a layer, it must first gather that layer's full parameters from all other GPUs (an All-Gather), then compute, then discard them. Gradients are reduced and sharded with a Reduce-Scatter as each layer's backward pass finishes. Because the parameters are gathered twice (once for forward, once for backward) in addition to the Reduce-Scatter, ZeRO-3 moves about 1.5× the data of a standard all-reduce, whereas ZeRO-1 and ZeRO-2 move the same volume as plain data parallelism. The extra cost is modest relative to the memory savings, but it adds latency to every layer.
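Counting collectives per step makes the 1.5× figure concrete. In this sketch (our own accounting, in the bandwidth-only ring model), one Reduce-Scatter or one All-Gather of a full parameter-sized tensor costs (N-1)/N × the parameter bytes per GPU, and an All-Reduce is one of each.

```python
def comm_bytes_per_gpu(strategy: str, n: int, param_bytes: float) -> float:
    """Per-GPU bytes sent per training step (ring model) for a model of param_bytes."""
    one_pass = (n - 1) / n * param_bytes      # one reduce-scatter OR one all-gather
    passes = {
        "ddp": 2,     # all-reduce of gradients = reduce-scatter + all-gather
        "zero1": 2,   # reduce-scatter grads + all-gather updated params
        "zero2": 2,   # same two collectives; gradients stay sharded
        "zero3": 3,   # all-gather params (fwd) + all-gather params (bwd) + reduce-scatter grads
    }
    return passes[strategy] * one_pass

psi_bytes = 14e9   # 7B params in FP16
for s in ("ddp", "zero1", "zero2", "zero3"):
    print(f"{s:5s}: {comm_bytes_per_gpu(s, 64, psi_bytes) / 1e9:.1f} GB/GPU/step")
```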
In PyTorch, the ZeRO-3 approach is implemented as FSDP (Fully Sharded Data Parallel). FSDP wraps each module (or a group of modules). When the forward pass enters a wrapped module, FSDP runs an All-Gather to reconstruct its full parameters and frees them again right after use. During the backward pass it all-gathers them once more, computes the gradients, and runs a Reduce-Scatter so that each GPU keeps only its own gradient shard.
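```python
import functools
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# Assumes dist.init_process_group() has run; TransformerBlock is your layer class
# Wrap each transformer layer independently (shard at layer granularity)
auto_wrap = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={TransformerBlock},
)
model = FSDP(
    model,
    auto_wrap_policy=auto_wrap,
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,   # low-precision params for fwd/bwd (BF16 needs no loss scaler)
        reduce_dtype=torch.float32,   # reduce-scatter gradients in FP32
    ),
    device_id=torch.cuda.current_device(),
)
# Build the optimizer after wrapping so it sees only this rank's shards
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

# Training loop is identical to DDP; FSDP issues the all-gathers and reduce-scatters
for batch in dataloader:
    loss = forward(model, batch)
    loss.backward()
    optimizer.step()        # updates only the local shard
    optimizer.zero_grad()
```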
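The accompanying chart plots per-GPU memory for a 7B model under baseline DP, ZeRO-1, ZeRO-2, and ZeRO-3 as the GPU count grows, with the 80 GB H100 limit as a reference line. Baseline DP never fits (112 GB on every GPU). All three ZeRO stages drop below 80 GB already at N=2 (70, 63, and 56 GB), but only ZeRO-3 keeps shrinking as N grows; ZeRO-1 and ZeRO-2 level off near 4 and 2 bytes/param (28 and 14 GB).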
Data parallelism improves throughput but not per-GPU memory unless combined with ZeRO. Tensor parallelism reduces per-GPU memory but requires fast intra-node interconnect. A third strategy is pipeline parallelism: split the model by layer across GPUs. GPU 0 runs layers 0–15, GPU 1 runs layers 16–31, and so on. The model passes activations between GPUs like an assembly line.
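The memory benefit for weights is immediate. A 70B model with 80 transformer layers split across 8 GPUs puts 10 layers on each GPU, or 70B/8 = 8.75B parameters. At 2 bytes/param (FP16), that's 17.5 GB of weights per GPU, well within an 80 GB H100. The full training state is another matter: at 16 bytes/param, each stage still needs 8.75B × 16 = 140 GB, so for 70B training pipeline parallelism is combined with ZeRO or tensor parallelism. Because pipeline parallelism splits the model by layer depth, model-state memory per GPU scales as roughly 1/P, where P is the number of pipeline stages (only roughly, since embedding and output layers make the first and last stages heavier).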
But there's a painful inefficiency: the pipeline bubble. In a naive pipeline, GPU 1 can't start until GPU 0 finishes its forward pass, GPU 2 waits for GPU 1, and so on. With a single batch, only one GPU is working at any moment while the others sit idle, both on the way forward and again on the way back through the backward pass. With P stages and a single micro-batch, the bubble fraction is:

$$\text{bubble} = \frac{P-1}{P}$$
For P=4 stages, bubble fraction = 3/4 = 75% idle time. GPipe (Google, 2019) introduced the fix: micro-batching. Split the mini-batch into m micro-batches and feed them continuously. While stage 3 finishes micro-batch 1's forward, stage 2 can process micro-batch 2. The bubble fraction with m micro-batches is:

$$\text{bubble} = \frac{P-1}{m + P - 1}$$
For P=4 stages and m=8 micro-batches: bubble = 3/(8+3) = 3/11 ≈ 27.3% idle. For m=16: 3/19 ≈ 15.8%. For m→∞, bubble → 0. Practical sweet spot: m ≈ 4P gives bubble fraction below 20%. Let's verify: P=4, m=8. Total time slots = m + (P-1) = 8 + 3 = 11. Useful slots = m = 8. Idle slots = 3 (both ramp-up and drain). Bubble = 3/11 = 27.3%. Matches the formula.
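Both the formula and the slot-counting argument can be checked by building the forward half of a GPipe-style schedule explicitly. In this simplified model (our assumption), every stage takes one time slot per micro-batch, and stage s starts micro-batch j at slot s + j; the backward half is the mirror image and has the same bubble fraction.

```python
from typing import List, Optional

def bubble_fraction(p: int, m: int) -> float:
    """Idle fraction of a GPipe schedule: (P - 1) / (m + P - 1)."""
    return (p - 1) / (m + p - 1)

def gpipe_forward_schedule(p: int, m: int) -> List[List[Optional[int]]]:
    """grid[s][t] = micro-batch that stage s runs at slot t, or None if idle."""
    grid: List[List[Optional[int]]] = [[None] * (m + p - 1) for _ in range(p)]
    for s in range(p):
        for j in range(m):
            grid[s][s + j] = j
    return grid

def idle_fraction(grid) -> float:
    cells = [c for row in grid for c in row]
    return sum(c is None for c in cells) / len(cells)

grid = gpipe_forward_schedule(4, 8)
for row in grid:
    print(" ".join("." if c is None else str(c) for c in row))
print(f"measured {idle_fraction(grid):.3f} vs formula {bubble_fraction(4, 8):.3f}")
```

The printed grid shows the staircase: stage s idles s slots at the start and P-1-s slots at the end, P-1 idle slots per stage in total.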
In the accompanying schedule diagram, each row is a pipeline stage (GPU) and each cell a time slot: blue = forward pass, orange = backward pass, gray = bubble (idle). The bubble fraction is shown for the chosen number of stages and micro-batches.
Pipeline parallelism splits between layers — each GPU processes a different set of layers. Tensor parallelism splits within a single layer — multiple GPUs each compute a different part of the same matrix multiplication. This is the approach pioneered by Megatron-LM (NVIDIA, 2019).
Consider an FFN with two linear layers: X ∈ [seq, d_model] → A ∈ [d_model, d_ff] → ReLU → B ∈ [d_ff, d_model] → Z, with d_ff = 4·d_model. For d_model=4096 and d_ff=16384, matrix A alone is 4096×16384 ≈ 67M params = 128 MB in FP16. Now split A by columns across T=8 GPUs: each GPU holds A_i ∈ [d_model, d_ff/T] = [4096, 2048] ≈ 8.4M params = 16 MB. Each GPU receives the full input X (replicated on all GPUs) and computes its slice of the hidden state, H_i = ReLU(X A_i), locally. Because ReLU acts elementwise, applying it to each column slice separately gives exactly the corresponding columns of ReLU(X A), so no communication is needed yet. Then split B by rows to match: each GPU holds B_i ∈ [d_ff/T, d_model] and computes a partial output Z_i = H_i B_i. The final output Z = Σ_i Z_i requires one All-Reduce to sum the partial results.
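The column-then-row split is exact, not an approximation. The toy check below uses plain-Python matrices with small integers (shapes and values chosen arbitrarily for illustration) and confirms that summing the T partial outputs, which is what the All-Reduce does, reproduces the unsplit FFN for several values of T.

```python
def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def relu(m):
    return [[max(0, v) for v in row] for row in m]

def split_cols(m, t):
    k = len(m[0]) // t
    return [[row[c * k:(c + 1) * k] for row in m] for c in range(t)]

def split_rows(m, t):
    k = len(m) // t
    return [m[c * k:(c + 1) * k] for c in range(t)]

def tensor_parallel_ffn(x, A, B, t):
    """Each 'GPU' i computes relu(x @ A_i) @ B_i; the all-reduce sums the partials."""
    partials = [matmul(relu(matmul(x, A_i)), B_i)
                for A_i, B_i in zip(split_cols(A, t), split_rows(B, t))]
    return [[sum(p[i][j] for p in partials) for j in range(len(partials[0][0]))]
            for i in range(len(x))]

seq, d_model, d_ff = 3, 4, 8                    # tiny shapes for illustration
X = [[(i + 2 * j) % 5 - 2 for j in range(d_model)] for i in range(seq)]
A = [[(3 * i + 5 * j) % 7 - 3 for j in range(d_ff)] for i in range(d_model)]
B = [[(2 * i + j) % 5 - 2 for j in range(d_model)] for i in range(d_ff)]
reference = matmul(relu(matmul(X, A)), B)        # single-GPU FFN
print(all(tensor_parallel_ffn(X, A, B, t) == reference for t in (1, 2, 4, 8)))  # True
```

Splitting A by rows instead would leave each GPU with a partial pre-activation sum, and ReLU of a partial sum is not a partial ReLU, so an All-Reduce would be needed before the nonlinearity as well.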
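Memory saving: each GPU holds 1/T of A and 1/T of B. The All-Reduce tensor Z has shape [seq_len, d_model]; for seq_len=4096 and d_model=4096 in FP16 that is 32 MB. A ring all-reduce over T=8 GPUs sends 2(T-1)/T × 32 MB = 56 MB per GPU, which takes roughly 0.2 ms over A100 NVLink (600 GB/s total, about 300 GB/s in each direction). The FFN's two matmuls cost 2 × (2 × seq × d_model × d_ff) ≈ 1.1 × 10¹² FLOPs, about 3.5 ms on one A100 at its 312 TFLOPS peak; split 8 ways, each GPU computes for about 0.44 ms. So even on NVLink the all-reduce is a noticeable fraction of the per-GPU compute (tens of percent), and in the basic Megatron scheme it sits on the critical path rather than overlapping with compute. Over InfiniBand or Ethernet (inter-node), with roughly an order of magnitude less bandwidth per GPU, it would dominate, which is why tensor parallelism is usually kept within a single NVLink-connected node.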
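The comparison between all-reduce time and compute time reduces to a few lines of arithmetic. This sketch counts the full ring traffic and both FFN matmuls split over T GPUs, and it assumes a bandwidth-only ring all-reduce, about 300 GB/s per direction for A100 NVLink, and matmuls running at the 312 TFLOPS FP16 peak. Real kernels run below peak, which shrinks the communication share, while per-message latency grows it.

```python
def tp_ffn_costs(seq: int, d_model: int, d_ff: int, t: int,
                 bytes_per_elem: int = 2, peak_flops: float = 312e12,
                 link_bytes_per_s: float = 300e9):
    """Return (all-reduce tensor bytes, all-reduce seconds, per-GPU compute seconds)."""
    tensor_bytes = seq * d_model * bytes_per_elem           # Z: [seq, d_model]
    ring_bytes = 2 * (t - 1) / t * tensor_bytes              # sent per GPU
    comm_s = ring_bytes / link_bytes_per_s
    ffn_flops = 2 * (2 * seq * d_model * d_ff)              # two matmuls, 2 FLOPs per MAC
    compute_s = ffn_flops / t / peak_flops                   # each GPU does 1/T of it
    return tensor_bytes, comm_s, compute_s

size, comm, compute = tp_ffn_costs(4096, 4096, 16384, t=8)
print(f"all-reduce tensor {size / 2**20:.0f} MiB, comm {comm * 1e3:.2f} ms, "
      f"compute {compute * 1e3:.2f} ms per GPU, ratio {comm / compute:.2f}")
```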
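```python
# Megatron-style tensor parallelism for one FFN block (sketch; assumes
# dist.init_process_group() has run and tp_group holds the T tensor-parallel ranks)
import torch
import torch.nn as nn
import torch.distributed as dist

class _CopyToTP(torch.autograd.Function):      # Megatron's "f": identity fwd, all-reduce bwd
    @staticmethod
    def forward(ctx, x, group):
        ctx.group = group
        return x
    @staticmethod
    def backward(ctx, grad):
        grad = grad.clone()
        dist.all_reduce(grad, group=ctx.group)  # sum input grads across TP ranks
        return grad, None

class _ReduceFromTP(torch.autograd.Function):  # Megatron's "g": all-reduce fwd, identity bwd
    @staticmethod
    def forward(ctx, z, group):
        z = z.clone()
        dist.all_reduce(z, group=group)         # sum partial outputs across TP ranks
        return z
    @staticmethod
    def backward(ctx, grad):
        return grad, None

class TensorParallelFFN(nn.Module):
    def __init__(self, d_model, d_ff, tp_group):
        super().__init__()
        t = dist.get_world_size(tp_group)
        self.tp_group = tp_group
        self.W1 = nn.Parameter(torch.randn(d_model, d_ff // t) * 0.02)  # column-parallel: d_ff/T columns
        self.W2 = nn.Parameter(torch.randn(d_ff // t, d_model) * 0.02)  # row-parallel: d_ff/T rows

    def forward(self, x):                       # x: [batch, seq, d_model], replicated
        x = _CopyToTP.apply(x, self.tp_group)
        h = torch.relu(x @ self.W1)             # [batch, seq, d_ff/T], no communication
        z = h @ self.W2                         # [batch, seq, d_model], partial sum
        return _ReduceFromTP.apply(z, self.tp_group)  # one all-reduce -> full output
```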
Data parallelism scales throughput. ZeRO/FSDP reduces memory redundancy. Pipeline parallelism splits layers across GPUs. Tensor parallelism splits individual weight matrices. Large-scale training systems use all three simultaneously — called 3D parallelism. The total GPU count is the product of three degrees: N = D × T × P, where D is data-parallel degree, T is tensor-parallel degree, and P is pipeline stages.
Consider training a 530B-parameter model on 2,048 A100s. Configuration: T=8 (one node, NVLink), P=12 (pipeline stages across nodes), D = 2048/(8×12) ≈ 21.3, rounded down to 21 data-parallel replicas. In practice: T=8, P=12, D=21, total = 8×12×21 = 2,016 GPUs. Each pipeline stage holds 530B/12 ≈ 44B params, and splitting each stage 8 ways with tensor parallelism leaves 44B/8 ≈ 5.5B params per GPU. Unsharded, that would be 5.5B × 16 bytes ≈ 88 GB; with ZeRO-1 sharding the optimizer states across the 21 replicas, it becomes 5.5B × (2 + 2 + 12/21) ≈ 5.5B × 4.57 ≈ 25 GB per GPU, which fits in an 80 GB A100 with room left for activations.
The key rule for ordering the parallelism dimensions: tensor parallelism innermost (uses fast NVLink), pipeline parallelism middle (uses InfiniBand but is point-to-point), data parallelism outermost (uses all-reduce but only once per step). Within a single node of 8 GPUs: T=8. Across nodes within a rack: P stages. Across racks: D data parallel. This layering minimizes the use of the slowest interconnect.
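One common way to realize this ordering is to make tensor parallelism the fastest-varying index of the global rank, so that consecutive ranks, which launchers typically place on the same node, form a TP group. The mapping below sketches that convention under the assumption of 8 GPUs per node with ranks assigned node by node; the function names are hypothetical, and real frameworks such as Megatron-LM define their own group layouts.

```python
def rank_to_coords(rank: int, t: int, p: int, d: int):
    """Map a global rank to (data, pipeline, tensor) coordinates, TP innermost."""
    assert 0 <= rank < t * p * d
    tp = rank % t                 # fastest-varying: neighbors on the same node
    pp = (rank // t) % p          # next: pipeline stage
    dp = rank // (t * p)          # slowest-varying: data-parallel replica
    return dp, pp, tp

def tensor_parallel_group(rank: int, t: int, p: int, d: int):
    """All global ranks that share this rank's data and pipeline coordinates."""
    dp, pp, _ = rank_to_coords(rank, t, p, d)
    base = dp * t * p + pp * t
    return list(range(base, base + t))

T, P, D = 8, 12, 21               # the 2,016-GPU configuration from the text
print(rank_to_coords(2015, T, P, D))        # last GPU: (20, 11, 7)
print(tensor_parallel_group(13, T, P, D))   # ranks 8..15, one 8-GPU node
```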
Beyond 3D parallelism, sequence parallelism (splitting the token dimension for very long sequences), expert parallelism (for mixture-of-experts models like Mixtral), and activation checkpointing interact with the 3D scheme. Activation checkpointing trades compute for memory: discard intermediate activations during the forward pass, recompute them from the checkpoint during backward. This reduces activation memory from O(L×seq×d) to O(sqrt(L)×seq×d) at the cost of roughly 33% extra compute.
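The O(sqrt(L)) figure comes from a simple trade-off, sketched here in units of one layer's activations (a simplification that treats every layer as equally sized). With k evenly spaced checkpoints, the forward pass stores k boundary activations, and during backward one segment of about L/k layers is recomputed and held at a time, so peak memory is roughly k + L/k, minimized near k = sqrt(L). The helper names are ours.

```python
import math

def checkpointed_activation_units(n_layers: int, n_segments: int) -> int:
    """Stored checkpoints plus one recomputed segment's activations."""
    return n_segments + math.ceil(n_layers / n_segments)

def best_segment_count(n_layers: int) -> int:
    return min(range(1, n_layers + 1),
               key=lambda k: checkpointed_activation_units(n_layers, k))

L = 64
k = best_segment_count(L)
print(f"L={L}: no checkpointing keeps {L} layers of activations; "
      f"{k} segments keep ~{checkpointed_activation_units(L, k)}")
```

The price is recomputing each segment's forward pass once during backward, which is where the roughly 33% extra compute comes from.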
In the accompanying grid, each cell is one GPU, colored by data-parallel group. Rows are tensor-parallel groups (within a node), and columns of colored groups form pipeline stages; the title shows the total GPU count D×T×P.
The closing showcase combines the five key numerical results of this lesson (the memory wall, ring all-reduce cost, ZeRO memory savings, pipeline bubble fraction, and 3D-parallel memory per GPU) for a configurable model size, parallelism strategy, and hardware. It shows the per-GPU memory breakdown under the chosen strategy, plus the pipeline bubble fraction and the ring all-reduce communication cost, all derived from the formulas in this lesson rather than looked up.
You now have the complete distributed training toolkit. Let's consolidate everything into a decision framework — then connect to the broader landscape.
| Strategy | What it splits | Model-state memory/GPU | Communication | Interconnect needed | When to use |
|---|---|---|---|---|---|
| Data Parallel (DDP) | Batch across GPUs | Full model (16 bytes/param) | All-Reduce of gradients, ≈2× gradient bytes per GPU per step | Any | Model fits on 1 GPU; want throughput |
| ZeRO-1 | Optimizer states sharded | ~4 bytes/param | Reduce-Scatter grads + All-Gather updated params (same volume as DDP) | InfiniBand | Moderate memory savings; easy to enable |
| ZeRO-2 | Optim + gradients sharded | ~2 bytes/param | Reduce-Scatter grads + All-Gather updated params (same volume as DDP) | InfiniBand | Larger models; 7B on 4×A100 |
| ZeRO-3 / FSDP | Everything sharded | 16Ψ/N bytes | All-Gather params (fwd and bwd) + Reduce-Scatter grads per layer (~1.5× DDP) | InfiniBand | 70B+ on 16–64+ GPUs |
| Pipeline Parallel | Layers across GPUs | 1/P of model | Activations between stages | InfiniBand OK | Very deep models; ≤20% bubble with m≥4P |
| Tensor Parallel | Weight matrices within layers | 1/T of model | Blocking All-Reduce per layer | NVLink strongly preferred | Very large layers; T≤8 (one node) |
| 3D Parallel | All three combined | ≈16Ψ/(D×T×P) with ZeRO-3 | All of the above | NVLink + InfiniBand | 100B+ models on 1000+ GPUs |
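Before adding complexity, always start with this formula. For a model with Ψ parameters trained with Adam in FP16 mixed precision, on N = D × T × P GPUs with ZeRO-3 across the D data-parallel replicas, pipeline degree P, and tensor degree T, the model state per GPU is:

$$M_{\text{GPU}} \approx \frac{(2 + 2 + 12)\,\Psi}{D \cdot T \cdot P} = \frac{16\,\Psi}{N} \text{ bytes}$$

plus activation memory, which this formula does not include.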
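As a sketch, the per-GPU model-state memory 16Ψ/N (16 bytes/param with mixed-precision Adam, divided over N = D × T × P GPUs with ZeRO-3) can be computed directly. The version below is slightly generalized (our own extension) so the ZeRO stage across the D data-parallel replicas can be chosen; it assumes parameters split evenly over the T × P model-parallel ranks and, like the formula, leaves out activations.

```python
def per_gpu_state_gb(params: float, t: int, p: int, d: int, zero_stage: int = 3) -> float:
    """Model-state GB per GPU for N = D*T*P GPUs (mixed-precision Adam, no activations)."""
    local_params = params / (t * p)        # tensor + pipeline parallelism split the model
    w, g, o = 2, 2, 12                     # bytes/param: FP16 W, FP16 grad, FP32 optimizer
    bytes_per_param = {
        0: w + g + o,
        1: w + g + o / d,
        2: w + (g + o) / d,
        3: (w + g + o) / d,
    }[zero_stage]
    return local_params * bytes_per_param / 1e9

# The 530B example: T=8, P=12, D=21 with ZeRO-1
print(f"{per_gpu_state_gb(530e9, 8, 12, 21, zero_stage=1):.1f} GB")
# Pipeline-only 70B over 8 stages, no sharding: still far above 80 GB
print(f"{per_gpu_state_gb(70e9, 1, 8, 1, zero_stage=0):.1f} GB")
```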
Activation checkpointing (gradient checkpointing): keep only selected activations (checkpoints) during the forward pass and recompute the rest from them during backward. Memory savings: typically a 2×–4× reduction in activation memory. Cost: ~33% more compute. Mixed precision: run the forward and backward matmuls in BF16 or FP16 while keeping an FP32 master copy of the weights for the optimizer update. With Adam the model state stays near 16 bytes/param, but activations take half the memory they would in FP32 and tensor-core matmuls run much faster. Gradient accumulation: run several micro-batches and sum their gradients before a single optimizer step; this simulates a larger batch without storing more activations. These three techniques are usually the first adjustments before adding model parallelism.
The lesson that follows covers the other extreme: training not on a cluster of 1,000+ GPUs, but on device, on a smartphone or a microcontroller (MCU) where the memory available for training can be as small as a few megabytes. The memory budgets flip: where distributed training worries about fitting 1 TB+ of optimizer state, on-device training worries about fitting a 1 MB model update. Techniques like LoRA and federated learning bridge the two extremes. Both ends share the same fundamental constraint: model state must fit in available memory, and communication is often the bottleneck.