Language Modeling from Scratch · CS336 · Lecture 7

Parallelism I: Data & Tensor Parallelism

A 70B model needs 140 GB just for its bf16 weights, while a single A100 holds 80 GB. Training is too slow on one GPU anyway. This lesson derives every tool you need: collective communication primitives (all-reduce, reduce-scatter, all-gather), data parallelism and the ring all-reduce cost formula, ZeRO stages 1–3, tensor parallelism (column/row parallel matmuls), and the memory/bandwidth/batch-size tradeoffs. Pipeline and sequence parallelism are covered in the next lecture.

Prerequisites: CS336 Lec 5 GPUs (memory hierarchy, arithmetic intensity). Lec 6 Kernels (HBM traffic). Basic PyTorch.

Chapter 0: One GPU Isn't Enough Anymore

You want to train Llama 3 405B. Its weights alone occupy 405 × 10^9 × 2 bytes (bf16) = 810 GB. An NVIDIA A100 has at most 80 GB of HBM. The model does not fit, not even close. And even if memory weren't the issue, training on a single GPU would take decades: GPT-3 (175B parameters) trained on 300B tokens needs roughly 3 × 10^23 FLOPs, and an A100 delivers about 312 TFLOP/s of bf16 compute at peak. That's roughly 10^9 seconds, or over 11,000 GPU-days (about 30 years), for a single training run, even at 100% utilization.
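A quick way to check this arithmetic is to script it. The sketch below is a back-of-envelope calculator that assumes the common 6 × params × tokens rule of thumb for training FLOPs (an assumption, not stated in the text) and 100% utilization of the A100's 312 TFLOP/s bf16 peak; the function names are illustrative.

```python
# Back-of-envelope check of the numbers above.
# Assumptions: bf16 weights (2 bytes/param), training FLOPs ~ 6 * params * tokens,
# and one A100 running at its 312 TFLOP/s bf16 peak the whole time.

def weight_bytes(num_params: float, bytes_per_param: int = 2) -> float:
    """Memory needed just to hold the weights."""
    return num_params * bytes_per_param

def training_flops(num_params: float, num_tokens: float) -> float:
    """Rule-of-thumb training compute: ~6 FLOPs per parameter per token."""
    return 6 * num_params * num_tokens

def single_gpu_days(total_flops: float, peak_flops_per_s: float = 312e12) -> float:
    """Wall-clock days on one GPU at 100% of peak (an optimistic lower bound)."""
    return total_flops / peak_flops_per_s / 86_400

llama_405b_gb = weight_bytes(405e9) / 1e9   # bf16 weights of Llama 3 405B, in GB
gpt3_flops = training_flops(175e9, 300e9)   # GPT-3: 175B params, 300B tokens
days = single_gpu_days(gpt3_flops)

print(f"Llama 3 405B weights: {llama_405b_gb:.0f} GB")
print(f"GPT-3 training compute: {gpt3_flops:.2e} FLOPs")
print(f"One A100 at peak: {days:,.0f} days (~{days / 365:.0f} years)")
```

The 6 × N × D rule gives about 3.15 × 10^23 FLOPs, in line with the ~3 × 10^23 quoted above. Real runs reach well below peak utilization, so the true single-GPU time is even longer.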

The answer is multi-GPU parallelism: split both the memory and the compute across many devices. But "split" is not one thing. You can split the data (each GPU trains on a different subset of the batch), split the model layers (each GPU holds different layers), or split the weight matrices themselves (each GPU holds different columns or rows of a weight). These three ideas — data, pipeline, and tensor parallelism — are the three axes of what practitioners call 3D parallelism.

This lesson covers axes one and three; axis two (pipeline parallelism) is the subject of the next lecture. By the end, you will be able to compute exactly how much communication each approach requires, derive the ring all-reduce cost formula, and understand why production runs combine all three axes simultaneously.

The core tension. Every form of parallelism forces a tradeoff. Data parallelism is simple and scales compute, but every GPU must store a full copy of the model — the memory problem stays. Tensor parallelism splits the memory, but requires fast all-reduce communication every single layer. ZeRO shards the optimizer state across GPUs, reducing memory with nearly zero communication overhead. No single approach wins. The skill is knowing when to use which — and how to combine them.

Throughout this lesson we work with concrete numbers: a 70B parameter model (Llama 2 scale), bf16 precision, AdamW optimizer, A100 80 GB GPUs. By the end you will know exactly how many GPUs you need to fit this model, and what the bandwidth cost is at each step.

A 70B parameter model in bf16 needs how many bytes just for weights?

Chapter 1: Network Topology: NVLink vs PCIe vs InfiniBand

Before we can reason about the cost of communication, we need to understand the network. Modern multi-GPU systems have a strict hierarchy of interconnects, each with very different bandwidths and latencies. Getting the topology wrong means paying 10–100× in communication cost.

Within a single machine (intra-node), NVIDIA GPUs are connected via NVLink, a point-to-point interconnect. Each A100 (NVLink 3.0) has 600 GB/s of total NVLink bandwidth summed over all its links (300 GB/s in each direction), and each H100 (NVLink 4.0) has 900 GB/s total. That is roughly an order of magnitude faster than PCIe Gen 4, which tops out at about 32 GB/s per direction for an x16 slot. Within a DGX node the 8 GPUs are connected through NVSwitch, so every GPU can communicate with every other GPU at full NVLink bandwidth simultaneously.

Across machines (inter-node), you're limited to whatever the data center network provides. Modern HPC clusters use InfiniBand HDR at 200 Gb/s (25 GB/s) per port, or Ethernet at 100–400 Gb/s per link. Even top-tier inter-node bandwidth is ~10–20× slower than NVLink. This bandwidth asymmetry is fundamental: it forces a different choice of parallelism strategy depending on whether you're communicating within a node or across nodes.

The golden rule of parallelism placement. Place communication-heavy parallelism on fast links, communication-light parallelism on slow links. In practice: tensor parallelism (four activation all-reduces per transformer block per training step, two in the forward pass and two in the backward pass) goes within a node on NVLink. Pipeline parallelism (a single activation tensor passed point-to-point between stages) goes across nodes on InfiniBand. Data parallelism (one all-reduce over the full gradients per training step) also spans nodes; because that communication happens only once per step and can overlap with the backward pass, the slower links are tolerable.

TPUs take a different approach: chips are wired into a 2D or 3D torus (a "toroidal mesh", i.e. a grid whose edges wrap around), with each chip linked directly to its neighbors over dedicated high-bandwidth ICI (inter-chip interconnect) links. TPUs are optimized for collective communication patterns (all-reduce, reduce-scatter) at the hardware level, which is why they favor data-parallel and fully-sharded training. GPU clusters lean on pipeline + tensor parallelism because NVLink is fast but does not provide all-to-all bandwidth beyond the 8 GPUs of a node.

Interconnect | Bandwidth (per GPU) | Topology | Best for
NVLink 3.0 (A100) | 600 GB/s total | All-to-all within 8 GPUs | Tensor parallel, FSDP
NVLink 4.0 (H100) | 900 GB/s total | All-to-all within 8 GPUs | Tensor parallel, FSDP
PCIe Gen 4 | ~32 GB/s | Star through CPU | Small DP, CPU offload
InfiniBand HDR | 25 GB/s | Fat-tree / Dragonfly | Pipeline parallel, DP
TPU ICI | ~150 GB/s each direction | Toroidal mesh | Data parallel, ZeRO
A production training cluster uses 32 nodes, each with 8 A100s. You want to use tensor parallelism (which requires all-reduce every transformer block). Where should tensor parallel boundaries fall?

Chapter 2: Collective Operations

Nearly all distributed ML training is built from seven collective communication primitives. A collective is an operation where every GPU in a group participates, and the result depends on contributions from all of them. Unlike point-to-point communication (GPU 0 sends to GPU 1), collectives have a defined mathematical structure that lets the communication library use optimized algorithms.

Let's define all seven with P = 4 GPUs, each holding a vector of M elements. Call GPU i's vector v_i.

Broadcast. One GPU sends its data to all others. GPU 0 starts with v_0; after broadcast, every GPU holds v_0. Cost: each of the other P−1 GPUs receives M elements. Use case: distributing model weights from a checkpoint.
Scatter. One GPU splits its data into P chunks and sends one chunk to each GPU. GPU 0 starts with the concatenation [v_0, v_1, v_2, v_3] (here v_i denotes the chunk destined for GPU i); after scatter, GPU i holds v_i. Cost: M elements leave the root in total; each GPU receives M/P. Use case: distributing a batch evenly across workers.
Gather. The reverse of scatter. Each GPU has a chunk; one GPU receives all chunks. GPU 0 ends with the full vector. Cost: M elements arrive at the root in total.
All-Gather. Every GPU ends up with the full data. Like gather, but the result is replicated to all P GPUs rather than just one. Each GPU starts with M/P elements; each ends with M elements. Cost: (P−1)/P × M elements sent and received per GPU. This is how ZeRO stage 3 collects parameters before a forward pass.
Reduce. Each GPU holds a vector; one GPU receives the elementwise sum. GPU 0 ends with v_0 + v_1 + v_2 + v_3. Cost: a naive implementation has the root receive (P−1) × M elements; ring or tree algorithms spread the work so each GPU sends about M. Use case: gradient accumulation to a parameter server.
All-Reduce. Every GPU starts with M elements; every GPU ends with the elementwise sum. This is the critical operation in data-parallel training: each GPU computed gradients on its shard of the batch, and now all GPUs need the same averaged gradient. Naive all-reduce (reduce to one GPU, then broadcast) makes that root GPU receive and then send (P−1) × M elements, so its traffic grows linearly with P. Ring all-reduce (Chapter 5) sends and receives 2 × (P−1)/P × M elements per GPU, which approaches 2M and is essentially independent of P.
Reduce-Scatter. Every GPU holds M elements; after reduce-scatter, each GPU holds M/P elements that are the reduced sum of that partition across all GPUs: GPU i holds the sum of the i-th chunk from every GPU. Cost: (P−1)/P × M elements sent and received per GPU. Critically, reduce-scatter followed by all-gather computes an all-reduce at the same bandwidth cost: (P−1)/P × M + (P−1)/P × M = 2 × (P−1)/P × M, exactly the ring all-reduce cost. This equivalence is the key insight behind ZeRO.
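The definitions above can be pinned down as reference semantics. This sketch simulates P GPUs as plain Python lists and implements what each collective computes, not how a library such as NCCL schedules the transfers; all function names are illustrative. The final check verifies the identity ZeRO relies on: reduce-scatter followed by all-gather yields exactly an all-reduce.

```python
# Reference semantics of the collectives on P simulated GPUs.
# Each "GPU" is a Python list; no real communication happens here.
from typing import List

Vec = List[float]

def chunks(v: Vec, p: int) -> List[Vec]:
    """Split v into p equal chunks (assumes len(v) % p == 0)."""
    n = len(v) // p
    return [v[i * n:(i + 1) * n] for i in range(p)]

def elementwise_sum(vs: List[Vec]) -> Vec:
    return [sum(xs) for xs in zip(*vs)]

def broadcast(data: List[Vec], root: int = 0) -> List[Vec]:
    return [list(data[root]) for _ in data]

def scatter(full: Vec, p: int) -> List[Vec]:
    return chunks(full, p)

def gather(data: List[Vec]) -> Vec:
    return [x for part in data for x in part]

def all_gather(data: List[Vec]) -> List[Vec]:
    full = gather(data)
    return [list(full) for _ in data]

def reduce(data: List[Vec]) -> Vec:
    return elementwise_sum(data)

def all_reduce(data: List[Vec]) -> List[Vec]:
    total = elementwise_sum(data)
    return [list(total) for _ in data]

def reduce_scatter(data: List[Vec]) -> List[Vec]:
    # GPU i ends with the sum of chunk i across all GPUs.
    return chunks(elementwise_sum(data), len(data))

P, M = 4, 8
gpus = [[float(10 * i + j) for j in range(M)] for i in range(P)]

# The identity ZeRO relies on: reduce-scatter then all-gather == all-reduce.
assert all_gather(reduce_scatter(gpus)) == all_reduce(gpus)
print(all_reduce(gpus)[0])
```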

In data-parallel training, each GPU computes gradients on its mini-batch shard. Which collective is used to synchronize gradients so all GPUs have the same values before the optimizer step?

Chapter 3: Data Parallelism

Data parallelism (DP) is the simplest form of parallelism and the most commonly used. The idea: replicate the entire model on every GPU, split each mini-batch across GPUs, run forward and backward independently on each shard, then synchronize gradients with an all-reduce before the optimizer step.

Formally, consider SGD with a global batch of size B across M GPUs (in this chapter M counts GPUs and P counts parameters): each GPU i computes gradients on B/M examples. The update rule is:

$$\theta_{t+1} = \theta_t - \eta \cdot \frac{1}{M} \sum_{i=0}^{M-1} \nabla_i$$

The sum is computed by all-reduce. Every GPU starts with its local gradient ∇_i and ends with the average ∇ = (∇_0 + ∇_1 + ... + ∇_{M−1}) / M. Then every GPU applies the same optimizer step with the same gradient, so all replicas stay synchronized.
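To see why averaging local gradients is equivalent to training on the whole batch, here is a minimal sketch that assumes a 1-D linear model with mean-squared-error loss and equal-sized shards (both illustrative assumptions). Each simulated GPU computes the gradient of its shard's mean loss; dividing the all-reduced sum by M reproduces the full-batch gradient, so every replica applies the same update.

```python
# Simulated data parallelism for y ~ w * x with MSE loss (illustrative setup).
import random

def grad_mse(w: float, xs, ys) -> float:
    """d/dw of mean((w*x - y)^2) over the given examples."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

random.seed(0)
B, M = 32, 4                      # global batch size, number of GPUs
xs = [random.uniform(-1, 1) for _ in range(B)]
ys = [3.0 * x + random.gauss(0, 0.1) for x in xs]
w, lr = 0.0, 0.5

# GPU i gets B/M contiguous examples and computes its local gradient.
shard = B // M
local_grads = [grad_mse(w, xs[i * shard:(i + 1) * shard], ys[i * shard:(i + 1) * shard])
               for i in range(M)]

avg_grad = sum(local_grads) / M   # all-reduce (sum), then divide by M
full_grad = grad_mse(w, xs, ys)   # what one GPU would compute on the whole batch

w_next = w - lr * avg_grad        # identical update applied on every replica
print(avg_grad, full_grad, w_next)
```

The equality relies on equal shard sizes: each shard's mean loss is weighted by 1/M, which matches the full-batch mean only when every shard holds B/M examples.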

The communication cost. With P parameters, each in bf16 (2 bytes), the ring all-reduce transmits about 2 × P × 2 bytes per GPU per step (exactly 2 × (M−1)/M × 2P bytes for M GPUs). The leading factor of 2 comes from the ring all-reduce algorithm: a reduce-scatter phase (about 2P bytes) plus an all-gather phase (about 2P bytes). For a 70B model that's 2 × 70 × 10^9 × 2 = 280 GB of data transmitted per GPU per training step. At InfiniBand HDR (25 GB/s), that's 280/25 = 11.2 seconds per step, a cost the forward/backward pass cannot amortize. This is why data parallelism alone is not sufficient at large scale.

The saving grace: gradient all-reduce can be overlapped with the backward pass. As backprop computes gradients for the last layer, those gradients can be immediately sent. By the time backprop reaches the first layer, the last layer's gradients are already synchronized. PyTorch's DistributedDataParallel (DDP) does this automatically via gradient hooks.
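The benefit of overlapping can be estimated with a toy timing model. This sketch assumes hypothetical per-layer backward and all-reduce times, a single network link that runs one all-reduce at a time, and one all-reduce per layer (real DDP groups gradients into buckets instead). When each layer's all-reduce is shorter than its backward compute, only the final all-reduce (for the first layer, whose gradient is produced last) remains exposed.

```python
# Toy timing model (hypothetical numbers, not measurements) of overlapping
# gradient all-reduce with backprop. Backward runs last layer -> first layer.
from typing import List

def step_time_no_overlap(bwd: List[float], comm: List[float]) -> float:
    """Finish the whole backward pass, then all-reduce everything."""
    return sum(bwd) + sum(comm)

def step_time_overlap(bwd: List[float], comm: List[float]) -> float:
    """Launch each layer's all-reduce as soon as its gradient is ready."""
    compute_clock = 0.0   # time at which backward finishes the current layer
    link_free_at = 0.0    # time at which the shared network link becomes idle
    for b, c in zip(reversed(bwd), reversed(comm)):
        compute_clock += b                         # this layer's grads are ready
        start = max(compute_clock, link_free_at)   # need grads AND a free link
        link_free_at = start + c
    return max(compute_clock, link_free_at)

# Hypothetical 4-layer model: per-layer backward and all-reduce times (ms).
bwd_ms = [10.0, 10.0, 10.0, 10.0]
comm_ms = [8.0, 8.0, 8.0, 8.0]
print(step_time_no_overlap(bwd_ms, comm_ms))  # 72.0
print(step_time_overlap(bwd_ms, comm_ms))     # 48.0
```

If communication per layer exceeds compute per layer, the link becomes the bottleneck and overlap can only hide part of it; this is the regime the 11.2-second example above falls into.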

Data parallelism has good scaling properties for compute: doubling GPUs doubles throughput. The catch is memory: every GPU must store the full model — weights, gradients, and optimizer state. For a 70B model with AdamW in mixed precision: 2 (weights) + 2 (gradients) + 4 (fp32 master) + 4 (adam m1) + 4 (adam m2) = 16 bytes per parameter × 70B = 1.12 TB per GPU. An A100 has 80 GB. Data parallelism alone fails catastrophically on memory.

Naive data parallelism with M=8 GPUs trains on a batch of 8192 examples. What is each GPU's batch size, and how does memory scale compared to 1 GPU?

Chapter 4: ZeRO: Sharding the Optimizer State

Data parallelism wastes memory: every GPU stores the full model, gradients, and optimizer state. ZeRO (Zero Redundancy Optimizer) eliminates this redundancy by sharding the expensive parts across GPUs, at the same communication cost as naive DDP for stages 1 and 2 and about 1.5× that cost for stage 3.

ZeRO has three stages, each sharding more aggressively:

ZeRO Stage 1: Optimizer state sharding. Every GPU still holds a full copy of the bf16 weights and gradients, but each GPU keeps optimizer state (fp32 master weights and Adam moments) only for its own 1/N_GPU slice of the parameters. Instead of an all-reduce, a reduce-scatter averages the gradients so that GPU i receives only the averaged gradient for parameter shard i. Each GPU updates its shard using its own moment estimates, and an all-gather then reconstructs the full parameter vector on every GPU. Communication cost: 1 reduce-scatter + 1 all-gather = 2P (P = parameter count), the same as naive DDP. Memory win: optimizer state is divided by N_GPU. For a 70B model on 8 GPUs, the Adam moments alone (8 bytes/param) drop from 560 GB to 70 GB per GPU; including the fp32 master copy (12 bytes/param of optimizer state), they drop from 840 GB to 105 GB. Nearly free!
ZeRO Stage 2: Gradient sharding. Also shard the gradients. During backprop, each layer's gradients are reduce-scattered as soon as they are computed, so each GPU keeps only the averaged gradient for its own shard and frees the rest. Memory win: gradient memory is also divided by N_GPU. Communication cost: still 2P (one reduce-scatter over the full gradients + one all-gather over the parameters).
ZeRO Stage 3 (FSDP): Full sharding. Shard the parameters too. During the forward pass, an all-gather reconstructs each layer's weights just before that layer runs, and they are freed afterwards. The backward pass needs the weights again, so each layer is all-gathered a second time, and its gradients are then reduce-scattered to their owners. Communication cost: 3P per step (2 all-gathers + 1 reduce-scatter), 1.5× naive DDP. Memory win: weights, gradients, and optimizer state are all divided by N_GPU, so model-state memory scales linearly. PyTorch's FullyShardedDataParallel (FSDP) implements ZeRO stage 3.
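Stage 1's correctness argument is that Adam is elementwise, so an optimizer step on one shard using only that shard's moments produces the same numbers as the full step. The sketch below checks this on simulated GPUs in pure Python; the random stand-in gradients, the Adam hyperparameters, and the helper names are illustrative assumptions. It compares plain data parallelism (all-reduce, full Adam state everywhere) with a ZeRO-1-style step (reduce-scatter, sharded Adam, all-gather).

```python
# Minimal sketch: ZeRO stage 1 vs plain data parallelism on simulated GPUs.
# Illustrative assumptions: random stand-in gradients, standard Adam, 4 GPUs.
import math
import random

def adam_update(params, grads, m, v, t, lr=1e-2, b1=0.9, b2=0.999, eps=1e-8):
    """In-place Adam on equal-length lists; t is the 1-based step number."""
    for j in range(len(params)):
        m[j] = b1 * m[j] + (1 - b1) * grads[j]
        v[j] = b2 * v[j] + (1 - b2) * grads[j] ** 2
        m_hat = m[j] / (1 - b1 ** t)
        v_hat = v[j] / (1 - b2 ** t)
        params[j] -= lr * m_hat / (math.sqrt(v_hat) + eps)

random.seed(0)
N_GPU, N_PARAMS, STEPS = 4, 8, 3
SHARD = N_PARAMS // N_GPU
init = [random.uniform(-1, 1) for _ in range(N_PARAMS)]
# local[t][gpu] = that GPU's gradient at step t (stand-in for backprop on its data)
local = [[[random.gauss(0, 1) for _ in range(N_PARAMS)] for _ in range(N_GPU)]
         for _ in range(STEPS)]

def averaged_grad(step, j):
    """What the all-reduce / reduce-scatter delivers for parameter j."""
    return sum(g[j] for g in local[step]) / N_GPU

# Plain DDP: every GPU has the averaged full gradient and full Adam state.
ddp = list(init)
m_full, v_full = [0.0] * N_PARAMS, [0.0] * N_PARAMS
for t in range(STEPS):
    adam_update(ddp, [averaged_grad(t, j) for j in range(N_PARAMS)], m_full, v_full, t + 1)

# ZeRO-1: GPU i owns parameters [i*SHARD, (i+1)*SHARD) and only that slice of m, v.
replicas = [list(init) for _ in range(N_GPU)]
m_shard = [[0.0] * SHARD for _ in range(N_GPU)]
v_shard = [[0.0] * SHARD for _ in range(N_GPU)]
for t in range(STEPS):
    for i in range(N_GPU):
        lo, hi = i * SHARD, (i + 1) * SHARD
        g_i = [averaged_grad(t, j) for j in range(lo, hi)]    # reduce-scatter output
        p_i = replicas[i][lo:hi]
        adam_update(p_i, g_i, m_shard[i], v_shard[i], t + 1)   # update own shard only
        replicas[i][lo:hi] = p_i
    # All-gather: everyone copies each owner's freshly updated shard.
    full = [x for i in range(N_GPU) for x in replicas[i][i * SHARD:(i + 1) * SHARD]]
    replicas = [list(full) for _ in range(N_GPU)]

print("max diff vs DDP:", max(abs(a - b) for a, b in zip(ddp, replicas[0])))
print("Adam floats per GPU:", 2 * SHARD, "instead of", 2 * N_PARAMS)
```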

The key insight — proven in the ZeRO paper — is that all-reduce is mathematically equivalent to reduce-scatter + all-gather, and in the bandwidth-limited regime both cost the same. So ZeRO stages 1 and 2 are free memory wins: you pay the same communication bandwidth but store much less on each GPU.

Strategy | Comm cost / step (P = params) | Model-state bytes/param per GPU (8 GPUs) | Max params per 80 GB GPU (8 GPUs, excl. activations)
Naive DDP | 2P all-reduce | 2 + 2 + 12 = 16 bytes | 5B
ZeRO Stage 1 | 2P (RS + AG) | 2 + 2 + 12/8 = 5.5 bytes | ~14.5B
ZeRO Stage 2 | 2P (RS + AG) | 2 + (2 + 12)/8 = 3.75 bytes | ~21B
ZeRO Stage 3 | 3P (AG + AG + RS) | (2 + 2 + 12)/8 = 2 bytes | 40B
Pure BF16 training (Kahan summation). The table assumes bf16 weights (2 bytes) + bf16 grads (2 bytes) + fp32 master weights (4 bytes) + fp32 Adam moments (4 + 4 = 8 bytes) = 16 bytes per parameter, i.e. 12 bytes/param of optimizer state. With pure BF16 training and Kahan summation for the weight updates, you drop the fp32 master copy, leaving a 12 bytes/param baseline. ZeRO stage 3 on a single 8×80 GB node then needs 12/8 = 1.5 bytes/param per GPU, which fits about 80 GB / 1.5 ≈ 53B parameters. A 70B model still does not fit: it would need 12 × 70 × 10⁹ / 8 = 105 GB per GPU, more than 80 GB even before activations.
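The per-GPU byte counts follow a simple formula per stage. This sketch uses the ZeRO paper's accounting (2 bytes bf16 weights + 2 bytes bf16 gradients + K bytes of optimizer state per parameter, K = 12 for fp32 master weights plus Adam m and v) and asks how many parameters fit when each GPU has 80 GB. It ignores activations, temporary buffers, and fragmentation, so real limits are lower; the function names are illustrative.

```python
# Per-GPU model-state memory under ZeRO stages 0 (naive DDP) through 3.

def bytes_per_param_per_gpu(stage: int, n_gpus: int, k: float = 12.0) -> float:
    if stage == 0:                      # naive DDP: everything replicated
        return 2 + 2 + k
    if stage == 1:                      # shard optimizer state
        return 2 + 2 + k / n_gpus
    if stage == 2:                      # shard optimizer state + gradients
        return 2 + (2 + k) / n_gpus
    if stage == 3:                      # shard everything (FSDP)
        return (2 + 2 + k) / n_gpus
    raise ValueError(f"unknown ZeRO stage {stage}")

def max_params(stage: int, n_gpus: int, gpu_bytes: float = 80e9, k: float = 12.0) -> float:
    """Largest model whose model states fit in one GPU's memory."""
    return gpu_bytes / bytes_per_param_per_gpu(stage, n_gpus, k)

for stage in range(4):
    b = bytes_per_param_per_gpu(stage, 8)
    print(f"stage {stage}: {b:5.2f} B/param/GPU -> max {max_params(stage, 8) / 1e9:5.1f}B params")

# Pure-bf16 variant (no fp32 master copy): K = 8 for the fp32 Adam moments.
print(f"stage 3, K=8: max {max_params(3, 8, k=8.0) / 1e9:.1f}B params")
```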
ZeRO Stage 1 costs the same communication bandwidth as naive DDP, but saves memory. Why is this possible?

Chapter 5: Ring All-Reduce: Deriving the Cost

Naive all-reduce (reduce to one GPU, then broadcast) is terrible: the root GPU receives P−1 vectors and then sends P−1 copies of the result, so its traffic grows linearly with P. The ring all-reduce algorithm fixes this by spreading the work evenly across all P GPUs arranged in a ring.

The algorithm has two phases, each requiring P-1 steps:

Phase 1 (Reduce-Scatter): Arrange P GPUs in a ring. Each GPU holds a vector of M elements, split into P chunks of M/P each. In each of P−1 steps, every GPU sends one chunk to its right neighbor, receives one chunk from its left neighbor, and adds the received chunk into its own copy of that chunk. After P−1 steps, each GPU holds one fully-reduced chunk of size M/P.

Phase 2 — All-Gather: Each GPU holds one correct chunk. In P-1 more steps, each GPU sends its chunk rightward. After P-1 steps, every GPU has all P chunks — the full reduced vector.

Now let's derive the bandwidth cost per GPU. In phase 1, each GPU sends M/P bytes in each of P-1 steps: (P-1) × M/P = (P-1)/P × M bytes sent, same received. Phase 2: same. Total per GPU:

Bytes per GPU = 2 × (P−1)/P × M

For large P, (P-1)/P → 1, so the cost approaches 2M — twice the message size, independent of P. This is remarkable: adding more GPUs doesn't increase the communication cost per GPU. The ring is bandwidth-optimal.
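The two phases and the byte count can be checked by simulation. This is a minimal pure-Python sketch: the chunk-indexing schedule (at step s, GPU r sends chunk (r − s) mod P in phase 1) is one standard way to run the ring and is assumed here. It verifies that every GPU ends with the elementwise sum and counts the elements each GPU sends.

```python
# Simulated ring all-reduce: reduce-scatter then all-gather around a ring.
from typing import List, Tuple

def ring_all_reduce(data: List[List[int]]) -> Tuple[List[List[int]], List[int]]:
    """Return (per-GPU results, elements sent per GPU)."""
    p = len(data)
    n = len(data[0]) // p                      # chunk size (assumes M % P == 0)
    # buf[r][c] is GPU r's current copy of chunk c.
    buf = [[list(v[c * n:(c + 1) * n]) for c in range(p)] for v in data]
    sent = [0] * p

    # Phase 1: reduce-scatter. At step s, GPU r sends chunk (r - s) mod p rightward.
    for s in range(p - 1):
        msgs = [(r, (r - s) % p, list(buf[r][(r - s) % p])) for r in range(p)]
        for r, c, payload in msgs:             # all sends happen "simultaneously"
            dst = (r + 1) % p
            buf[dst][c] = [a + b for a, b in zip(buf[dst][c], payload)]
            sent[r] += len(payload)
    # Now GPU r holds the fully reduced chunk (r + 1) mod p.

    # Phase 2: all-gather. At step s, GPU r forwards chunk (r + 1 - s) mod p.
    for s in range(p - 1):
        msgs = [(r, (r + 1 - s) % p, list(buf[r][(r + 1 - s) % p])) for r in range(p)]
        for r, c, payload in msgs:
            buf[(r + 1) % p][c] = payload      # overwrite with the finished chunk
            sent[r] += len(payload)

    return [[x for chunk in row for x in chunk] for row in buf], sent

P, M = 4, 8
data = [[10 * r + j for j in range(M)] for r in range(P)]
result, sent = ring_all_reduce(data)
print(result[0])
print(sent, "elements sent per GPU; formula 2(P-1)/P * M =", 2 * (P - 1) * M // P)
```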

Time to complete the all-reduce, given bandwidth B GB/s per link:

T_allreduce = 2 × (P−1)/P × M / B

Example: 70B parameters × 2 bytes (bf16 gradients) = 140 GB. Ring all-reduce over P=32 GPUs on InfiniBand HDR (25 GB/s): T = 2 × 31/32 × 140 / 25 = 10.85 seconds. That's the communication cost floor per training step — which must be hidden by overlapping with computation.
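The plateau is easy to see numerically. This sketch evaluates the bandwidth-only formula for the 140 GB gradient example at 25 GB/s over a range of GPU counts; it ignores per-hop latency, so it only describes the bandwidth-limited regime.

```python
# Bandwidth-only ring all-reduce time as the GPU count grows.

def ring_allreduce_seconds(msg_bytes: float, n_gpus: int, bw_bytes_per_s: float) -> float:
    """T = 2 (N-1)/N * M / B, ignoring latency."""
    return 2 * (n_gpus - 1) / n_gpus * msg_bytes / bw_bytes_per_s

grad_bytes = 70e9 * 2        # 70B params x 2 bytes (bf16 gradients) = 140 GB
ib_hdr = 25e9                # InfiniBand HDR, 25 GB/s

for n in (2, 4, 8, 32, 256, 1024):
    t = ring_allreduce_seconds(grad_bytes, n, ib_hdr)
    print(f"N = {n:5d} GPUs: {t:6.2f} s")
# The time approaches 2 * M / B = 11.2 s from below and never exceeds it.
```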

Bandwidth vs latency regimes. The ring formula assumes bandwidth-limited operation (message large enough that startup latency is negligible). For small messages (latency-limited), the cost is dominated by the per-hop latency times P-1 hops. Real all-reduce libraries like NCCL use the ring for large tensors and tree-based algorithms for small tensors. Gradient communication in LLM training is almost always bandwidth-limited.
For ring all-reduce over P=8 GPUs with a 10 GB gradient tensor and 25 GB/s bandwidth, what is the approximate communication time?

Chapter 6: Tensor Parallelism: Splitting the Matmul

Data parallelism replicates the model — memory doesn't scale. ZeRO shards state — memory scales but you're still data-parallel at the compute level. Tensor parallelism (also called model parallelism along the width axis) takes a different approach: split the weight matrices themselves across GPUs, so each GPU only ever holds a fraction of the parameters.

The key mathematical insight: matrix multiplication can be decomposed into sub-matrix multiplications. Given Y = X W, where X is the input (B × d_in) and W is the weight (d_in × d_out), you can partition W along either its columns or its rows:

Column-parallel: Split W into P column blocks W = [W_1 | W_2 | ... | W_P], where each W_i is d_in × (d_out/P). GPU i computes Y_i = X W_i without any communication (X is replicated on every GPU or gathered first). Y_i is a column slice of the output, of shape B × (d_out/P); the full output is the concatenation Y = [Y_1 | Y_2 | ... | Y_P].

Row-parallel: Split W into P row blocks W = [W_1; W_2; ...; W_P], each of shape (d_in/P) × d_out. GPU i gets the matching input shard X_i (shape B × d_in/P, the i-th column block of X) and computes a partial product Y_i = X_i W_i of full shape B × d_out. The final output Y = Y_1 + Y_2 + ... + Y_P requires an all-reduce over the P GPUs.

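Megatron-LM convention. For a transformer block, you chain a column-parallel layer (the first projection, no communication) with a row-parallel layer (the second projection, followed by an all-reduce). Specifically, for a two-layer MLP with weights W_1 and W_2, GPU i computes A_i = GELU(X W_{1,i}) (column-parallel, no communication; GELU can be applied locally because each GPU owns whole columns), then Z_i = A_i W_{2,i} (row-parallel, a partial sum), and finally an all-reduce sums the Z_i. Megatron names the two communication points f and g: f sits before the column-parallel layer and is the identity in the forward pass, while g sits after the row-parallel layer and is the all-reduce. In the backward pass the roles swap: f becomes an all-reduce (summing each GPU's contribution to the gradient with respect to X) and g becomes the identity, the conjugate pattern.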
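The column-then-row split can be checked numerically. This pure-Python sketch assumes tiny illustrative sizes and random weights; it computes GELU(X W_1) W_2 on one simulated GPU, then again with W_1 split by columns and W_2 split by rows across P simulated GPUs, summing the partial outputs in place of the all-reduce g.

```python
# Megatron-style tensor-parallel MLP on simulated GPUs (illustrative sizes).
import math
import random

def matmul(a, b):
    """Naive (n x k) @ (k x m) product on lists of lists."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def gelu(x: float) -> float:
    """Exact GELU via the error function."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_rows(m):
    return [[gelu(v) for v in row] for row in m]

def col_block(w, lo, hi):
    """Columns lo..hi-1 of w (a column-parallel shard)."""
    return [row[lo:hi] for row in w]

def rand_matrix(rows, cols):
    return [[random.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]

random.seed(0)
B, H, F, P = 2, 4, 8, 2          # batch, hidden size, FFN width, TP degree
X, W1, W2 = rand_matrix(B, H), rand_matrix(H, F), rand_matrix(F, H)

# Reference: a single GPU computes GELU(X W1) W2.
Z_ref = matmul(gelu_rows(matmul(X, W1)), W2)

# Tensor parallel: GPU i owns W1[:, cols_i] and W2[rows_i, :].
f = F // P
partials = []
for i in range(P):
    A_i = gelu_rows(matmul(X, col_block(W1, i * f, (i + 1) * f)))  # column-parallel, no comm
    partials.append(matmul(A_i, W2[i * f:(i + 1) * f]))            # row-parallel partial sum

# All-reduce (g in the forward pass): elementwise sum of the partial outputs.
Z_tp = [[sum(part[r][c] for part in partials) for c in range(H)] for r in range(B)]

max_err = max(abs(a - b) for ra, rb in zip(Z_ref, Z_tp) for a, b in zip(ra, rb))
print(f"max |Z_ref - Z_tp| = {max_err:.2e}")
```

The only difference between the two results is floating-point summation order, which is why the error is tiny but not always exactly zero.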

The communication cost per transformer block is two all-reduces over activations of shape B × s × h (one after attention, one after the FFN), not over the parameters. For batch B = 1, sequence s = 4096, hidden h = 4096, in bf16, each all-reduce moves a 1 × 4096 × 4096 × 2 bytes = 32 MiB (≈33.5 MB) tensor. For a 96-block model: 2 × 96 × 33.5 MB ≈ 6.4 GB of all-reduced activations per forward pass. With the ring formula at P = 8 on NVLink (600 GB/s), that takes 2 × 7/8 × 6.4 GB / 600 GB/s ≈ 19 ms. Against roughly 100 ms of attention + FFN compute, that is under 20% overhead: tolerable on NVLink, unacceptable on InfiniBand.

Why tensor parallel works well within a node. On NVLink (600 GB/s on A100), a single 33.5 MB transfer takes 33.5/600 ≈ 0.056 ms; with P = 8 GPUs and the ring formula, the all-reduce takes 2 × 7/8 × 33.5 MB / 600 GB/s ≈ 0.1 ms. Effectively free. On InfiniBand (25 GB/s): 2 × 7/8 × 33.5 MB / 25 GB/s ≈ 2.35 ms per all-reduce. With 2 all-reduces × 96 blocks = 192 per forward pass, that's about 450 ms of communication vs ~100 ms of compute, roughly 4.5× the compute time. This is why tensor parallelism must stay within a node.
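These per-layer numbers can be reproduced in a few lines. This sketch makes the same simplifications as the text: bf16 activations of shape B × s × h, the bandwidth-only ring formula, 600 GB/s for NVLink and 25 GB/s for InfiniBand, two forward all-reduces per block, and a 96-block model; latency and any overlap with compute are ignored.

```python
# Tensor-parallel activation all-reduce cost, bandwidth-only model.

def ring_allreduce_seconds(msg_bytes: float, n: int, bw_bytes_per_s: float) -> float:
    """T = 2 (n-1)/n * msg / bw."""
    return 2 * (n - 1) / n * msg_bytes / bw_bytes_per_s

b, s, h = 1, 4096, 4096
msg = b * s * h * 2                   # bf16 activations: 33,554,432 bytes (32 MiB)
n_tp, n_blocks = 8, 96
n_allreduces = 2 * n_blocks           # attention + FFN all-reduce per block (forward)

nvlink_per_ar = ring_allreduce_seconds(msg, n_tp, 600e9)
ib_per_ar = ring_allreduce_seconds(msg, n_tp, 25e9)

for name, per_ar in [("NVLink", nvlink_per_ar), ("InfiniBand HDR", ib_per_ar)]:
    print(f"{name}: {per_ar * 1e3:.3f} ms per all-reduce, "
          f"{n_allreduces * per_ar * 1e3:.1f} ms per forward pass")
```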
In tensor-parallel row-parallel mode, each GPU computes a partial matrix product. What collective is needed to get the final output, and why?

Chapter 7: Memory Per GPU: What Gets Sharded

Let's do the accounting concretely. A transformer model has four memory consumers: parameters, gradients, optimizer state, and activations. Each parallelism strategy handles these differently. Understanding which gets sharded — and by how much — tells you exactly how many GPUs you need to fit a model.

For a 70B model with AdamW mixed precision (fp32 master + bf16 weights + bf16 grads + fp32 moments):

Component | Bytes/param | 70B total | Per-GPU (8 GPUs)
bf16 weights | 2 | 140 GB | 17.5 GB (TP or ZeRO-3) or 140 GB (DP)
bf16 gradients | 2 | 140 GB | 17.5 GB (ZeRO-2+) or 140 GB (naive)
fp32 master weights | 4 | 280 GB | 35 GB (ZeRO-1+) or 280 GB (naive)
fp32 Adam m1 | 4 | 280 GB | 35 GB (ZeRO-1+) or 280 GB (naive)
fp32 Adam m2 | 4 | 280 GB | 35 GB (ZeRO-1+) or 280 GB (naive)
Total | 16 | 1.12 TB | 140 GB (fully sharded) to 1.12 TB (naive DDP)

Activations are the fourth consumer and are often overlooked. For a transformer layer with batch size b, sequence length s, and hidden dimension h, the activation memory required to store all intermediate values for backprop is approximately:

Activation memory per layer ≈ 10 × s × b × h × bytes

The factor of 10 is a rough count of the s × b × h-sized tensors each layer saves for backprop: inputs to the attention and FFN projections, FFN intermediate values, layer-norm inputs, and dropout masks. The s × s attention-score matrices (one per head) come on top of this unless a fused attention kernel avoids storing them. For Llama 3 70B (h = 8192, 80 layers) with s = 4096, b = 1, in bf16: 10 × 4096 × 1 × 8192 × 2 bytes ≈ 0.67 GB per layer, or about 53.7 GB across 80 layers. This is in addition to the parameter memory above.
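The rough activation formula is straightforward to evaluate. This sketch takes the number of layers as a parameter (Llama 3 70B has 80 transformer layers) and uses the text's approximate factor of 10, so treat its output as an order-of-magnitude estimate rather than an exact count; the function name is illustrative.

```python
# Rough activation-memory estimate: factor * s * b * h elements per layer.

def activation_bytes(s: int, b: int, h: int, n_layers: int,
                     bytes_per_elem: int = 2, factor: int = 10) -> int:
    """Approximate bytes of saved activations for n_layers transformer layers."""
    return factor * s * b * h * bytes_per_elem * n_layers

per_layer = activation_bytes(s=4096, b=1, h=8192, n_layers=1)
llama3_70b = activation_bytes(s=4096, b=1, h=8192, n_layers=80)
print(f"per layer: {per_layer / 1e9:.2f} GB, 80 layers: {llama3_70b / 1e9:.1f} GB")
```

Because this estimate is linear in s and b, doubling the context length or the micro-batch size doubles it.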

Tensor parallelism shards most activation memory. When the attention heads and the FFN intermediate dimension are split across P GPUs, each GPU holds activations only for its slice. The all-reduce at the end of the row-parallel layer reconstitutes the full activation, but only at the layer output; the internals stay sharded. With P = 8 this cuts the attention and FFN internals by roughly 8×, though some terms (LayerNorm inputs, dropout masks) remain unsharded until sequence parallelism is added.
Activation recomputation: trading compute for memory. Instead of storing all activations for backprop, you can store only each layer's input and recompute the rest during the backward pass. This costs roughly one extra forward pass of compute (~33% overhead, since the backward pass costs about twice the forward) but shrinks activation memory to those layer-boundary checkpoints. Selective recomputation, which recomputes only the core attention operations that are memory-hungry but cheap to compute (softmax, attention dropout, attention over values), gets most of the memory benefit at a small fraction of the compute cost. Production runs at scale nearly always use some form of recomputation.
Tensor parallelism with P=8 GPUs on a 70B model. Each GPU now holds 1/8 of the weights. What is the weight memory per GPU?

Chapter 8: Showcase: 3D Parallelism Configurator

Production LLM training runs rarely use just one form of parallelism. Llama 3 405B and Yi combine data parallelism, tensor parallelism, and pipeline parallelism simultaneously, while DeepSeek V3 combines pipeline, expert, and ZeRO-1 data parallelism. Combining the data, tensor, and pipeline axes is called 3D parallelism: three orthogonal axes of splitting.

The rules of thumb from the Megatron-LM paper (Narayanan et al. 2021) are:

Step 1: Tensor parallel within each node. Use a TP degree up to the number of GPUs per node (8 on a typical DGX node), so every TP all-reduce runs over NVLink and stays cheap. TP splits the weight matrices, dividing their memory by the TP degree.

Step 2: Pipeline parallel across nodes. After TP fills up the intra-node bandwidth, scale to multiple nodes by assigning stages of the model to different nodes. PP communication is point-to-point activation passing — much smaller than gradient all-reduce. But PP requires large batch sizes to hide the "pipeline bubble."

Step 3: Data parallel for the rest. Once the model fits (after TP + PP), scale out compute by replicating the TP+PP group across more GPU sets with data parallelism. Use ZeRO stage 1 or 2 to reduce the memory overhead of DDP without increasing communication cost.

From Narayanan et al. 2021 (Megatron + pipeline): a 1-trillion-parameter model on 3072 A100s used TP = 8 (within node), PP = 64 (across nodes), and DP = 6 (8 × 64 × 6 = 3072). Aggregate throughput was 502 petaFLOP/s, about 163 teraFLOP/s per GPU (52% of peak), with near-linear scaling.
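The bookkeeping behind a 3D configuration can be sketched as a small helper. Everything here is hypothetical: the ParallelConfig class, the assumption that layers split evenly across pipeline stages and weights evenly across TP ranks, and ZeRO-1 sharding of the 12 bytes/param optimizer state across DP replicas. Activations and embedding-layer imbalance are ignored, so real per-GPU memory is higher.

```python
# Hypothetical 3D-parallel layout helper: TP x PP x DP GPUs.
from dataclasses import dataclass

@dataclass
class ParallelConfig:
    tp: int
    pp: int
    dp: int
    gpus_per_node: int = 8

    @property
    def total_gpus(self) -> int:
        return self.tp * self.pp * self.dp

    def tp_fits_in_node(self) -> bool:
        # Rule of thumb: keep every TP group on NVLink inside one node.
        return self.tp <= self.gpus_per_node

    def model_state_bytes_per_gpu(self, n_params: float, k: float = 12.0) -> float:
        # bf16 weights + grads are split over TP * PP ranks; the K bytes/param
        # of optimizer state are further sharded over DP replicas (ZeRO-1).
        model_parallel = self.tp * self.pp
        weights_and_grads = (2 + 2) * n_params / model_parallel
        optimizer = k * n_params / (model_parallel * self.dp)
        return weights_and_grads + optimizer

cfg = ParallelConfig(tp=8, pp=4, dp=8)
print(cfg.total_gpus, cfg.tp_fits_in_node())
print(f"{cfg.model_state_bytes_per_gpu(70e9) / 1e9:.2f} GB of model state per GPU (70B model)")
```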
For a 3D parallelism setup with TP=8, PP=16, DP=4, how many total GPUs are needed?

Chapter 9: Connections & What's Next

This lesson covered the two most bandwidth-critical forms of LLM parallelism. Here's the full picture of what we covered and where it fits:

Technique | What's sharded | Comm per step (P = params) | Memory scaling | Effect on global batch size
Naive DDP | Batch only | 2P (all-reduce) | None | Linear
ZeRO Stage 1 | Optimizer state | 2P (RS+AG) | Partial (opt) | Linear
ZeRO Stage 2 | Opt + Grads | 2P (RS+AG) | Partial | Linear
ZeRO Stage 3 / FSDP | Opt + Grads + Params | 3P | Linear | Linear
Tensor Parallel | Weight matrices | 8×bsh per layer (NVLink) | Linear | None
Pipeline Parallel | Layers | bsh per microbatch | Linear | Needs large!

What's coming in Lec 8 (Parallelism II): pipeline parallelism in depth, including micro-batches, the bubble formula (n_stages − 1)/n_micro, zero-bubble pipelining (splitting the backward pass into activation-gradient and weight-gradient passes), and sequence parallelism (splitting LayerNorm and dropout along the sequence dimension to reduce activation memory). Sequence parallelism is what lets activation memory shrink linearly with the tensor-parallel degree.

Real-world combinations (as of 2025). DeepSeek V3: 16-way pipeline parallelism + 64-way expert parallelism + ZeRO-1 data parallelism across 2048 H800s, deliberately avoiding tensor parallelism. Llama 3 405B: TP = 8, PP = 16, DP = 128 (16,384 H100s) for the main pre-training stage, with selective activation recomputation. Gemma 2: ZeRO-3 + MP (TP+SP) + DP. Common pattern: fill each node with TP (TP = 8), then scale across nodes with PP, then scale compute with DP. Use ZeRO-1 or ZeRO-2 for free memory wins on top of whatever else you're doing.
Cheat sheet: key formulas.
  • Model memory (naive, mixed precision): 16 bytes × P params
  • ZeRO-3 memory: 16 bytes × P params / N_GPU
  • Ring all-reduce cost per GPU: 2 × (N−1)/N × M bytes (N GPUs, M-byte message)
  • Tensor parallel all-reduce (two per block per forward pass): 2 × (N_TP−1)/N_TP × b×s×h × 2 bytes
  • Activation memory per layer (approx): 10 × s × b × h × 2 bytes
  • Total GPUs for 3D: N_TP × N_PP × N_DP
"The question is not whether to parallelize, but how to partition the problem across the three axes — memory, compute, and communication — and that requires understanding all three at once."
— paraphrased from Tatsu Hashimoto, CS336 Lecture 7