Language Modeling from Scratch · CS336 · Lecture 2

PyTorch Primitives & Resource Accounting

You have 32 H100s for two weeks. How big a model can you train, and how long will it take? This lecture answers those questions from first principles — tensors, dtypes, FLOPs, gradients, optimizers, and the full training loop. Every number derived, every byte accounted for.

Prerequisites: basic Python + heard the words "neural network" and "GPU". We build the rest from zero.
10 chapters · 6 live calculators · 0 assumed math

Chapter 0: The Budget Problem

You just got access to 32 H100 GPUs for two weeks. Someone asks: "What's the biggest model you can train?" You blink. You have no idea. This is the problem CS336 Lecture 2 solves — and the answer requires understanding every resource your training run consumes, down to the byte.

The resource accounting discipline pays off at every scale. Whether you're fitting a fine-tune into an 8 GB consumer GPU or planning a 10,000-GPU frontier run, the same arithmetic applies. The numbers change; the formulas don't. Once you internalize this framework, every architectural paper you read gives you a new lens: what does this change to N, D, memory, or FLOPs?

Two types of resources govern every training decision: memory (measured in gigabytes, GB) and compute (measured in floating-point operations, FLOPs). They interact in interesting ways. A bigger model uses more memory (caps what fits on one GPU) and more compute (caps how fast you can train). Get the accounting wrong and you either crash out-of-memory or dramatically underestimate training time.

Percy Liang opens this lecture with two napkin-math questions that anchor everything. Before writing a single line of training code, you should be able to answer both in under two minutes.

Question 1: Time
Training a 70B-parameter model on 15T tokens across 1,024 H100s — how many days?
Total FLOPs = 6 × 70×10⁹ × 15×10¹² = 6.3×10²⁴. Each H100 does ~989 TFLOP/s (dense bf16). At 50% utilization and 1,024 GPUs: ≈144 days.
Question 2: Size
Largest model on 8 H100s with AdamW (naive fp32) — how many parameters?
8 GPUs × 80 GB = 640 GB. Naive AdamW = 16 bytes per parameter (4 params + 4 grads + 4+4 optimizer states). So max params = 640×10⁹ / 16 = 40×10⁹ = 40B. But activations aren't accounted for yet!

Where does "6 × N × D" come from? Why 16 bytes per parameter? Why does utilization only reach 50%? The rest of this lesson builds every one of those numbers from first principles. By the end, you'll do this arithmetic in your head.

The lecture is structured as an "executable lecture" — every claim is backed by a Python assertion that runs and verifies. We follow the same spirit: every formula here comes with a worked example and an interactive calculator. The CS336 course uses this resource accounting mindset as the lens for all subsequent design decisions, from architecture tweaks to parallelism strategies. The six interactive calculators in this lesson let you build direct intuition about these numbers — drag the sliders and watch how FLOPs, memory, and time respond to your changes.

Napkin-math: training time estimator

Adjust the knobs. Observe how parameters, data, GPU count, and MFU interact to determine training time.

Model size (params) 70 B
Training tokens 15 T
H100 count 32
MFU (0–100%) 50%

Let's work the math for Question 1 precisely. The H100's peak bf16 dense throughput is 989 TFLOP/s (= 1979/2, since the 1979 figure is for sparse tensors). At 50% MFU across 1,024 GPUs:

effective FLOP/s = 989×10¹² × 0.50 × 1024 = 5.06×10¹⁷ FLOP/s
days = 6.3×10²⁴ / (5.06×10¹⁷ × 86,400) ≈ 144 days

The lecture gives 88 days; reaching that figure on the same hardware would take an MFU of roughly 82%, so treat the two numbers as the ends of a range. These estimates always have a range. The point: training frontier models takes months on thousands of GPUs. Resource accounting is what separates a viable training plan from wishful thinking.
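The arithmetic above fits in a few lines of Python. This is a minimal sketch that assumes the 989 TFLOP/s dense bf16 peak quoted in the text; the function name `training_days` is made up for illustration.

```python
H100_PEAK_FLOPS = 989e12   # dense bf16 peak quoted in the text, FLOP/s
SECONDS_PER_DAY = 86_400

def training_days(n_params, n_tokens, n_gpus, mfu, peak=H100_PEAK_FLOPS):
    """Napkin estimate: 6*N*D total FLOPs divided by the sustained cluster FLOP/s."""
    total_flops = 6 * n_params * n_tokens
    sustained_flops_per_sec = peak * mfu * n_gpus
    return total_flops / sustained_flops_per_sec / SECONDS_PER_DAY

days_1024 = training_days(70e9, 15e12, n_gpus=1024, mfu=0.5)
days_32 = training_days(70e9, 15e12, n_gpus=32, mfu=0.5)
print(f"1,024 H100s: {days_1024:.0f} days")
print(f"   32 H100s: {days_32:.0f} days")
```

Running the same estimate for 32 GPUs makes the budget problem from Chapter 0 concrete: a 70B model on 15T tokens is far outside a two-week allocation, so either N or D has to shrink.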

Common misconception: "MFU should be near 100%; anything less is wasted hardware." In practice 50% MFU is considered very good. GPUs are pipelined compute engines; memory bandwidth, communication overhead between GPUs, and scheduling gaps all pull utilization below peak. At 50% MFU, the cluster sustains half of its peak arithmetic rate; the rest of the time goes to waiting for data, syncing gradients, or launching kernels. The overhead comes from: (1) gradient synchronization across GPUs (all-reduce communication), (2) data loading latency, (3) kernel launch overhead, (4) memory-bound ops like softmax and layer-norm that cannot saturate the tensor cores. Getting above 50% MFU requires careful systems engineering: fused kernels, pipeline parallelism, and overlap of compute and communication.
The 6×N×D rule says training FLOPs ≈ 6 × (# parameters) × (# tokens). Why 6 and not some other constant?

Chapter 1: Tensors & Memory

Tensors are the universal container for everything in deep learning: parameters, gradients, optimizer states, activations, input data. Understanding them at the storage level is non-negotiable for resource accounting.

Conceptually, a PyTorch tensor is a pointer into a contiguous block of memory plus metadata: the shape (number of elements along each dimension), the stride (how many storage slots to skip to advance one step along each axis), and the dtype (how many bytes each element occupies).

python
import torch

# Four ways to create a 4×8 tensor
x = torch.zeros(4, 8)       # shape [4, 8], dtype float32
x = torch.ones(4, 8)        # all ones
x = torch.randn(4, 8)       # iid Normal(0,1)
x = torch.empty(4, 8)       # allocate, don't initialize

# Memory is: numel() × element_size() bytes
assert x.numel() == 4 * 8    # = 32 values
assert x.element_size() == 4  # float32 = 4 bytes
mem = x.numel() * x.element_size()  # = 128 bytes

Memory formula: for a tensor of shape [d0, d1, …, dk] with dtype bytes b:

memory = d0 × d1 × … × dk × b bytes

Let's make this concrete with a real model weight. GPT-3's feedforward layers use a weight matrix of shape [12288 × 4, 12288] = [49152, 12288]. That's 49152 × 12288 = 603,979,776 values. In float32 (4 bytes each), that's 2,415,919,104 bytes ≈ 2.4 GB (2.25 GiB) for a single weight matrix. The entire model has many such matrices, which is why GPT-3 at 175B parameters (700 GB in float32) cannot fit in one GPU's 80 GB.
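The memory formula can be wrapped in a small helper. This is a sketch under the assumption that the tensor is dense; `tensor_bytes` is a hypothetical name, and the element size is read from an empty scalar tensor so nothing large is ever allocated.

```python
import math
import torch

def tensor_bytes(shape, dtype=torch.float32):
    """Bytes for a dense tensor: product of the dims times bytes per element."""
    bytes_per_element = torch.empty((), dtype=dtype).element_size()
    return math.prod(shape) * bytes_per_element

ffn_shape = (4 * 12288, 12288)   # GPT-3 feedforward weight from the text
fp32_bytes = tensor_bytes(ffn_shape, torch.float32)
bf16_bytes = tensor_bytes(ffn_shape, torch.bfloat16)
print(f"float32:  {fp32_bytes:,} bytes = {fp32_bytes / 2**30:.2f} GiB")
print(f"bfloat16: {bf16_bytes:,} bytes = {bf16_bytes / 2**30:.2f} GiB")
```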

Strides explain views. A view is a tensor that shares the same underlying storage as another tensor, just with different metadata. When you index a row (x[0]), transpose (x.T), or reshape (x.view(8, 4)), no memory is copied — only the stride metadata changes. Views are free. Copies are expensive (both memory and compute).

python
x = torch.tensor([[0., 1, 2, 3],
                   [4, 5, 6, 7]])
# stride(0)=4: advance dim-0 by skipping 4 storage slots
# stride(1)=1: advance dim-1 by skipping 1 storage slot
assert x.stride(0) == 4
assert x.stride(1) == 1

# x[1][2] = storage[1*4 + 2*1] = storage[6] = 6.0

y = x[0]           # row slice — a view (no copy)
z = x.transpose(1, 0)  # transpose — a view (no copy)
assert y.data_ptr() == x.data_ptr()  # same memory!

# Transposed tensor is non-contiguous — can't .view() it
# Must call .contiguous() first, which DOES copy
z_cont = z.contiguous()  # allocates new memory
Tensor memory visualizer

Choose a tensor shape and dtype. See the physical storage layout and the total memory footprint.

Rows (d0) 4
Cols (d1) 8
Dtype float32

Einops notation for readable tensor operations. Traditional PyTorch code uses x.transpose(-2, -1) and relies on you keeping track of what dimension is what. Einops names the dimensions explicitly:

python
from einops import einsum, rearrange, reduce

# Old way: what is -2, -1 here?
x = torch.ones(2, 2, 3)  # batch, sequence, hidden
z = x @ x.transpose(-2, -1)  # confusing

# New way: named dimensions
z = einsum(x, x, "batch seq1 hidden, batch seq2 hidden -> batch seq1 seq2")

# Rearrange: split a packed head dimension
x = torch.ones(2, 3, 8)  # batch, seq, (heads*head_dim)
x = rearrange(x, "... (heads hd) -> ... heads hd", heads=2)
# shape is now [2, 3, 2, 4]

# Reduce: mean over the head_dim axis (x is 4-D after the rearrange above)
y = reduce(x, "batch seq heads hd -> batch seq heads", "mean")
Key insight: the contiguous trap. After transposing or slicing non-contiguous dims, calling .view() raises a RuntimeError. You fix it with .contiguous(), which silently allocates a new copy of the data. In tight training loops with large activation tensors this hidden copy doubles your memory usage. Always check .is_contiguous() before assuming a view is free. The rule: if you need to both transpose AND reshape (e.g., multi-head attention projections), always call .contiguous() explicitly or switch to .reshape() (which calls .contiguous() internally when needed).
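The contiguous trap is easy to reproduce on a tiny tensor. The sketch below uses data pointers to tell a view from a copy; the variable names are illustrative only.

```python
import torch

x = torch.arange(8.0).view(2, 4)
z = x.t()                          # transpose: a view with swapped strides
assert z.data_ptr() == x.data_ptr()
assert not z.is_contiguous()

try:
    z.view(8)                      # strides cannot express a flat view
    view_failed = False
except RuntimeError:
    view_failed = True

flat = z.reshape(8)                # reshape falls back to a copy here
copied = flat.data_ptr() != x.data_ptr()
print(view_failed, copied)
```

On a contiguous tensor the same `.reshape(8)` call would return a view, which is why the copy is easy to miss.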
A weight matrix of shape [4096, 4096] in float32. How many bytes of GPU memory does it occupy?

Chapter 2: Dtypes: fp32, fp16, bf16, fp8

Floating-point numbers have a fixed number of bits, divided into a sign bit, exponent bits, and mantissa bits. The tradeoffs between these formats are everything in modern training.

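| Dtype | Bits | Sign | Exponent | Mantissa | Range | Precision |
|---|---|---|---|---|---|---|
| float32 | 32 | 1 | 8 | 23 | ±3.4×10³⁸ | ~7 decimal digits |
| float16 | 16 | 1 | 5 | 10 | ±65,504 | ~3 decimal digits |
| bfloat16 | 16 | 1 | 8 | 7 | ±3.4×10³⁸ | ~2 decimal digits |
| fp8 E4M3 | 8 | 1 | 4 | 3 | ±448 | ~1 decimal digit |
| fp8 E5M2 | 8 | 1 | 5 | 2 | ±57,344 | ~0.5 decimal digits |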

float32 is the historical default. It never underflows for realistic values, gives 7 decimal digits of precision, but costs 4 bytes per value. In scientific computing this was fine; in deep learning with billions of parameters it dominates memory.

float16 halves memory to 2 bytes per value but has a tiny dynamic range. The key danger is underflow: values below about 6×10⁻⁵ fall into the subnormal range and lose precision, and values below about 6×10⁻⁸ round to exactly zero. During training, gradients for rarely-seen tokens or in deep networks can be this small. When they underflow, those weights stop updating and training destabilizes.

python
# float16 underflow demonstration
x = torch.tensor([1e-8], dtype=torch.float16)
print(x)  # tensor([0.], dtype=torch.float16)  ← UNDERFLOWS to zero!

# bfloat16: same exponent bits as float32, safe range
x = torch.tensor([1e-8], dtype=torch.bfloat16)
print(x)  # nonzero, approximately 1e-08 (dtype=torch.bfloat16)  ← fine

bfloat16 is Google Brain's solution (2018). By keeping float32's full 8-bit exponent, it preserves the dynamic range exactly. The trade-off: the mantissa drops from 23 to 7 bits, so precision is much lower (≈2 decimal digits). For gradient descent this is acceptable — the stochastic noise in mini-batch gradients already swamps that level of precision. bfloat16 is now the standard for LLM training forward passes.

fp8, standardized in 2022, pushes further: 1 byte per value. H100s have native fp8 hardware support. Two variants: E4M3 (more precision, less range, used for forward/activations) and E5M2 (more range, less precision, used for backward/gradients). At fp8, quantization errors become significant; you need scaling factors and careful monitoring. The payoff: roughly 2× the bf16 tensor-core throughput on H100.

Dtype dynamic range & memory comparison

Visual comparison of the five key formats. The bar shows the usable exponent range (log-scale). Hover each dtype to see implications for training.

Common misconception: "fp16 and bf16 are interchangeable." They have the same bit-count and memory usage, but critically different exponent widths. fp16's 5-bit exponent covers only about 6×10⁻⁸ to 65,504, so small gradients underflow and loss scaling (multiplying the loss by a large constant before backward, then dividing the gradients by it afterwards) is mandatory. bf16's 8-bit exponent matches float32's range, so no loss scaling is needed. For LLMs, always prefer bf16 over fp16.
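The range and precision gap can be read straight from `torch.finfo`. The second half is a toy illustration of the idea behind loss scaling, not a real training setup: the scale factor 2¹⁶ and the gradient value 1e-8 are arbitrary choices for the demo.

```python
import torch

for dtype in (torch.float32, torch.float16, torch.bfloat16):
    info = torch.finfo(dtype)
    print(f"{str(dtype):15s} max={info.max:.3e}  "
          f"smallest normal={info.tiny:.3e}  eps={info.eps:.3e}")

fp16, bf16 = torch.finfo(torch.float16), torch.finfo(torch.bfloat16)

# Toy loss scaling: a gradient that underflows in fp16 survives once scaled up.
tiny_grad = 1e-8
scale = 2.0 ** 16
lost = torch.tensor(tiny_grad, dtype=torch.float16)          # rounds to zero
kept = torch.tensor(tiny_grad * scale, dtype=torch.float16)  # representable
recovered = kept.float() / scale                             # unscale in float32
print(lost.item(), recovered.item())
```

bf16 has the larger `eps` (coarser precision) but a `max` and smallest normal value on the same scale as float32, which is the trade-off described above.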

Why deep learning is less sensitive to precision than scientific computing. In scientific computing, you might simulate a physical system where floating-point errors accumulate across millions of time steps, so you need 15 significant digits (float64). In deep learning, your gradient is already a noisy estimator of the true gradient (mini-batch stochastic sampling introduces variance). Adding 2-digit precision noise to a quantity that already has noise from sampling is insignificant. The stochastic gradient noise swamps the arithmetic precision noise. This is why bf16, with only about 2 significant digits, works fine, while float64 would quadruple memory usage relative to bf16 for no benefit.

Bits = exponent + mantissa tradeoff. With a fixed number of bits, you must choose between range (more exponent bits) and precision (more mantissa bits). fp16 chose precision over range and burned itself. bf16 chose range (same as float32) and accepted reduced precision. The lesson: for neural networks, range matters more than precision because gradients span many orders of magnitude but their exact values are noisy anyway. The H100's tensor cores run bf16 matmuls at roughly 15× its fp32 rate (989 vs 67.5 TFLOP/s), making this the single most impactful low-precision choice in modern LLM training.
fp8: the frontier for inference. FP8 (fp8 E4M3 and fp8 E5M2) was standardized in 2022 by Nvidia, ARM, and Intel for ML workloads. H100 supports fp8 natively. The key challenge: with only 3 mantissa bits (E4M3), quantization error is significant. Solutions include (a) per-tensor scaling factors that bring the range into the fp8 representable zone, and (b) hybrid schemes where only the linear layers use fp8 (the bulk of FLOPs) while sensitive ops (attention softmax, layer norms) stay in bf16. The payoff: fp8 delivers 2× the TFLOP/s of bf16 on H100 tensor cores = the fastest training currently achievable.
Why does bfloat16 not require "loss scaling" during training, but float16 does?

Chapter 3: FLOPs: Counting Compute

A FLOP (floating-point operation) is one multiply or one add on a float. It's the fundamental unit of compute. We use it to estimate how long training will take, compare hardware, and decide whether an architectural choice is "expensive."

Two acronyms that sound identical but mean different things: FLOPs (with a lowercase "s") = floating-point operations, a count of work done; FLOP/s (or FLOPS) = floating-point operations per second, a measure of speed.

To build intuition: Training GPT-3 took ~3.14×10²³ FLOPs. The US executive order (2023) flagged models trained with ≥10²⁶ FLOPs for government reporting. H100 peak: ~989 TFLOP/s (dense bf16) = 989×10¹² FLOP/s.

Matrix multiplication dominates. For a matmul C = A @ B where A is [m × k] and B is [k × n]: each output element C[i,j] = Σ_l A[i,l] × B[l,j] requires k multiplications and k−1 additions ≈ 2k FLOPs. There are m×n outputs, so total FLOPs = 2×m×k×n.

FLOPs(A @ B) = 2 × m × k × n

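For a batch of B examples, each of dimension D, with weight matrix [D × K]: that's a [B×D] @ [D×K] matmul = 2×B×D×K FLOPs. Note this equals 2 × (# data points) × (# parameters), since the layer has D×K parameters. Generalizing, the forward pass costs about 2 × N × (# tokens) for a model with N parameters; the 6×N×D rule reuses the letter D for the token count, not the layer width.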
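A short sketch of the counting rule, with hypothetical helper names. It checks that the forward FLOPs of a bias-free linear layer equal 2 × tokens × parameters.

```python
def matmul_flops(m, k, n):
    """FLOPs for an [m, k] @ [k, n] matmul: one multiply and one add per term."""
    return 2 * m * k * n

def linear_forward_flops(n_tokens, d_in, d_out):
    """Forward FLOPs of a bias-free linear layer applied to n_tokens rows."""
    return matmul_flops(n_tokens, d_in, d_out)

tokens, d_in, d_out = 8 * 512, 1024, 4096
n_params = d_in * d_out
flops = linear_forward_flops(tokens, d_in, d_out)
print(f"{flops:,} FLOPs for {tokens:,} tokens and {n_params:,} parameters")
```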

Elementwise operations are cheap. Adding two [m×n] matrices: m×n FLOPs. Applying ReLU to a [m×n] matrix: m×n FLOPs. Softmax over a sequence of T tokens: O(T) per token. These are negligible compared to the matmuls in attention and the feedforward layers, which scale as O(d²) per token.

Batched and higher-dimensional matmuls. PyTorch's @ operator broadcasts over batch dimensions. Given x of shape [B, T, D] and W of shape [D, K]: the result is [B, T, K] and costs 2×(B×T)×D×K FLOPs. This is the fundamental operation in every Transformer: the Q, K, V projections are batched matmuls over all tokens simultaneously, making GPU parallelism trivially exploitable.

python
import time
import torch

# Batched matmul: @ broadcasts over every leading dim of x
x = torch.ones(4, 8, 16, 32)  # [batch, seq, ?, d_in]
w = torch.ones(32, 2)          # [d_in, d_out]
y = x @ w                      # [4, 8, 16, 2]
# Each of the 4*8 leading slices is a [16,32] @ [32,2] matmul
flops = 2 * (4 * 8 * 16) * 32 * 2  # 2 * m * k * n = 65,536

# Timing a matmul on GPU (with proper synchronization!)
if torch.cuda.is_available():
    x_gpu, w_gpu = x.cuda(), w.cuda()
    torch.cuda.synchronize()  # wait for previous ops to finish
    t0 = time.perf_counter()
    y = x_gpu @ w_gpu
    torch.cuda.synchronize()  # wait for this op to finish
    t1 = time.perf_counter()
    actual_flops_per_sec = flops / (t1 - t0)

The synchronize calls are critical. CUDA launches are asynchronous — the Python line y = x @ w returns immediately (just enqueues the kernel), and the actual GPU computation happens in parallel. Without synchronize(), you're timing kernel launch latency (≈microseconds), not execution time (≈milliseconds for large matmuls).

Model FLOPs Utilization (MFU) measures how efficiently you're using the hardware:

MFU = actual FLOP/s achieved ÷ peak FLOP/s of hardware

A100 peak: 312 TFLOP/s (bf16); 19.5 TFLOP/s (fp32). H100: 989 TFLOP/s (bf16 dense); 67.5 TFLOP/s (fp32). The gap between bf16 and fp32 (16× on A100, roughly 15× on H100) is a core reason to use mixed precision: it directly multiplies your effective compute budget.
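The MFU definition and the peak numbers above combine into a small calculator. This is a sketch: the helper names are invented, and the 70 µs "measured" time is a made-up value used only to show the calculation.

```python
PEAK_FLOPS = {  # peak figures quoted in the text, FLOP/s
    ("A100", "bf16"): 312e12, ("A100", "fp32"): 19.5e12,
    ("H100", "bf16"): 989e12, ("H100", "fp32"): 67.5e12,
}

def ideal_seconds(flops, gpu, dtype):
    """Time the op would take if the GPU ran at its peak rate."""
    return flops / PEAK_FLOPS[(gpu, dtype)]

def mfu(flops, measured_seconds, gpu, dtype):
    """Achieved FLOP/s divided by peak FLOP/s."""
    return (flops / measured_seconds) / PEAK_FLOPS[(gpu, dtype)]

flops = 2 * 1024 * 4096 * 4096   # the [1024, 4096] @ [4096, 4096] matmul
for (gpu, dtype) in PEAK_FLOPS:
    print(f"{gpu} {dtype}: {ideal_seconds(flops, gpu, dtype) * 1e6:.1f} µs at peak")

example_mfu = mfu(flops, measured_seconds=70e-6, gpu="H100", dtype="bf16")
h100_ratio = PEAK_FLOPS[("H100", "bf16")] / PEAK_FLOPS[("H100", "fp32")]
print(f"MFU at a hypothetical 70 µs: {example_mfu:.0%}; bf16/fp32 on H100: {h100_ratio:.1f}x")
```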

FLOPs calculator: matmul & MFU

Enter a matmul shape. See the FLOPs, and the theoretical time on A100 vs H100 in fp32 vs bf16.

m (rows of A) 1024
k (inner dim) 4096
n (cols of B) 4096

Scaling laws and the 6×N×D rule. This formula generalizes far beyond a simple linear model. For a Transformer with N parameters trained on D tokens, the total FLOPs is approximately 6ND. Where does this come from? Each Transformer block has attention and feedforward layers. The dominant cost is the weight matmuls in the feedforward layers (two per block: [d, 4d] and [4d, d]) and the attention projections (four per block: Q, K, V, O). Summing those matmuls over all B×T tokens and L layers, the total forward FLOPs is approximately 2ND (where N ≈ 12d²L, as Chapter 5 shows). Backward adds 4ND. Total: 6ND. This is a first-order approximation that ignores attention's O(T²) cost, which is valid when d≫T, as it is for large models on moderate context lengths.

Common misconception: "More FLOPs = slower. Fewer FLOPs = faster." Not necessarily! The bottleneck is often memory bandwidth, not compute. A giant matmul is compute-bound (high FLOPs, good GPU utilization). A layer-norm or softmax is memory-bound (reads/writes tensors but barely multiplies). Clever kernel design (fused operations, flash attention) makes memory-bound ops fast by reducing round-trips through HBM. Always ask: is this op compute-bound or memory-bound? The H100's HBM3 bandwidth is ~3.35 TB/s. A layer-norm on a [B,T,d] tensor just reads the tensor and writes the result, so you're limited by how fast you can read and write ~B×T×d values, not by arithmetic. Flash Attention fuses the softmax(QKᵀ)V computation into a single kernel to avoid materializing the full O(T²) attention matrix in HBM, turning a memory-bound op into a compute-bound one.
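One rough way to answer "compute-bound or memory-bound?" is to compare an op's arithmetic intensity (FLOPs per byte moved) against the hardware ratio of peak FLOP/s to bandwidth. The sketch below is a simplified roofline-style estimate. It assumes each operand crosses HBM exactly once and ignores caches, and the function names are hypothetical.

```python
H100_PEAK_BF16 = 989e12        # FLOP/s, from the text
H100_HBM_BANDWIDTH = 3.35e12   # bytes/s, from the text
ridge = H100_PEAK_BF16 / H100_HBM_BANDWIDTH   # FLOPs per byte needed to stay compute-bound

def matmul_intensity(m, k, n, bytes_per_el=2):
    """FLOPs per byte for [m, k] @ [k, n], reading A and B and writing C once."""
    flops = 2 * m * k * n
    traffic = bytes_per_el * (m * k + k * n + m * n)
    return flops / traffic

def elementwise_intensity(bytes_per_el=2):
    """One FLOP per element that is read once and written once."""
    return 1 / (2 * bytes_per_el)

big_matmul = matmul_intensity(4096, 4096, 4096)
relu_like = elementwise_intensity()
print(f"ridge point: {ridge:.0f} FLOP/byte")
print(f"4096^3 matmul: {big_matmul:.0f} FLOP/byte, elementwise: {relu_like:.2f} FLOP/byte")
```

Ops well above the ridge point can in principle keep the tensor cores busy; ops far below it wait on memory no matter how few FLOPs they perform.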
A feedforward layer in a Transformer maps [B, T, d_model] → [B, T, 4*d_model] → [B, T, d_model]. Given B=8, T=512, d_model=1024, how many FLOPs for the two weight matrices (forward pass only)?

Chapter 4: Gradients & the Backward Pass

The forward pass produces a loss. The backward pass answers: "for each parameter w, how does changing w by ε change the loss?" That derivative, d(loss)/dw, is the gradient of w. PyTorch's autograd engine computes these automatically — but you should understand what it's doing and how much it costs.

To compute gradients, you first flag the tensors you care about with requires_grad=True. PyTorch then builds a computation graph as operations execute — every tensor remembers which operation created it and from what inputs. Calling .backward() on the loss traverses this graph in reverse, applying the chain rule.

python
# Simple linear model: loss = 0.5 * (x·w - 5)²
x = torch.tensor([1., 2, 3])
w = torch.tensor([1., 1, 1], requires_grad=True)

# Forward pass
pred_y = x @ w          # scalar: 1+2+3 = 6
loss = 0.5 * (pred_y - 5).pow(2)  # 0.5*(6-5)² = 0.5

# Backward pass — fills .grad on all requires_grad tensors
loss.backward()

# Chain rule: d(loss)/dw_i = (pred_y - 5) * x_i
# = (6 - 5) * [1, 2, 3] = [1, 2, 3]
assert torch.equal(w.grad, torch.tensor([1., 2, 3]))

# Crucially: loss.grad is None — only leaf tensors (with
# requires_grad=True) get .grad filled in by default
assert loss.grad is None

How many FLOPs does the backward pass cost? Let's trace through a two-layer network: x →W1→ h1 →W2→ h2 → loss. W1 is [D×D], W2 is [D×K], inputs batch [B×D].

To compute W2.grad = d(loss)/dW2, we need: W2.grad[j,k] = Σ_i h1[i,j] × h2.grad[i,k]. That's a [D×B] @ [B×K] matmul = 2×B×D×K FLOPs, exactly the same as the forward pass through W2.

To compute h1.grad (needed for W1's gradient): h1.grad[i,j] = Σ_k W2[j,k] × h2.grad[i,k]. That's another [B×K] @ [K×D] matmul = 2×B×D×K FLOPs.
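Both gradient formulas can be checked against autograd on a tiny instance of the two-layer network. This is a sketch: the squared-sum loss is an arbitrary choice made only so that backward has something to differentiate.

```python
import torch

torch.manual_seed(0)
B, D, K = 4, 3, 2
x = torch.randn(B, D)
W1 = torch.randn(D, D, requires_grad=True)
W2 = torch.randn(D, K, requires_grad=True)

h1 = x @ W1
h2 = h1 @ W2
h1.retain_grad()               # keep gradients of the non-leaf activations
h2.retain_grad()
loss = h2.pow(2).sum()         # toy loss
loss.backward()

with torch.no_grad():
    W2_grad_manual = h1.T @ h2.grad    # [D, B] @ [B, K]
    h1_grad_manual = h2.grad @ W2.T    # [B, K] @ [K, D]

forward_flops = 2 * B * D * D + 2 * B * D * K
backward_flops = 2 * forward_flops     # two matmuls per weight matrix
print(forward_flops, backward_flops)
```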

So for W2 alone, the backward pass costs 4×B×D×K FLOPs, twice the forward. Applying the same logic to W1 (and counting an input-gradient matmul for every layer, which is exact for all but the first layer):

forward FLOPs = 2BD² + 2BDK
backward FLOPs = 4BD² + 4BDK = 2 × forward
total per step = forward + backward = 3 × forward = 6 × B × (D² + DK) = 6 × B × Nparams

Here Nparams = D² + DK is the parameter count of the two-layer model. Dividing by the B tokens in the batch gives 6N FLOPs per token. Over a training set of D tokens (where D now denotes the token count, not the layer width), the total is 6 × N × D.

The 6× rule, derived. Forward pass = 2N FLOPs per token (factor of 2 from multiply-add). Backward pass = 4N FLOPs per token (2N for parameter gradients + 2N for activation gradients). Total = 6N per token, or 6ND over D tokens, where N is the number of parameters. This is the rule of thumb the course uses everywhere: training FLOPs ≈ 6 × parameters × tokens. One caveat: this ignores the O(T²) self-attention term, which becomes significant at very long context lengths. For context length T and model dim d, the attention scores add roughly 4×B×T²×d×L FLOPs to the forward pass. When T≪d (e.g., T=2048, d=8192), this is small compared to the 12d²L weight-matmul cost and the 6ND approximation holds.
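To see when the attention term stops being negligible, compare it with the weight-matmul cost per token and per layer. The sketch below assumes the standard block with N ≈ 12d² parameters per layer and counts forward FLOPs only; the function name is made up.

```python
def attention_to_param_flops(seq_len, d_model):
    """Attention-score FLOPs divided by weight-matmul FLOPs (per token, per layer, forward)."""
    attention = 4 * seq_len * d_model    # QK^T and AV: 2*T*d each per token
    weights = 2 * 12 * d_model ** 2      # 2 FLOPs per parameter, 12*d^2 parameters
    return attention / weights           # simplifies to T / (6 * d)

for T, d in [(2048, 8192), (8192, 4096), (131072, 4096)]:
    print(f"T={T:>6}, d={d}: attention is {attention_to_param_flops(T, d):.2f}x the weight matmuls")
```

The ratio is T/(6d), so the 6ND rule is a good approximation until the context length approaches several times the model width.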

What tensors get gradients? Only tensors that are reachable in the computation graph from a requires_grad=True leaf tensor accumulate gradients. Leaf tensors are tensors you created directly (not the result of an operation). In a model, parameters are the leaf tensors with requires_grad=True. Non-leaf tensors (activations) get gradients only if you explicitly call .retain_grad() before backward — otherwise they're freed to save memory.

python
x = torch.tensor([1., 2, 3])  # leaf, no grad
w = torch.tensor([1., 1, 1], requires_grad=True)  # leaf, wants grad

h = x @ w         # non-leaf; h.requires_grad=True (inherited)
h.retain_grad()   # ask PyTorch to keep h.grad after backward
loss = 0.5 * (h - 5).pow(2)

loss.backward()
print(w.grad)   # tensor([1., 2., 3.]) (chain rule: (h-5)*x)
print(h.grad)   # tensor(1.) (d(loss)/dh = h-5 = 6-5 = 1)
print(x.grad)   # None  (x is a leaf without requires_grad)
print(loss.grad) # None  (non-leaf, no retain_grad; PyTorch also emits a warning)

Gradient clipping is not optional for LLMs. In deep networks, gradients can explode — a single bad batch or numerical near-singularity can produce gradients of magnitude 1,000,000, which would destroy training. Gradient clipping caps the L2 norm of all gradients:

if ||g|| > max_norm: g ← g × max_norm / ||g||

Typical values: max_norm=1.0. This ensures no single optimizer step moves parameters more than a bounded amount. The PyTorch function clips all parameter gradients jointly (not per-layer): torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0). Call it after backward, before step.
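The clipping rule is short enough to write by hand and compare with the built-in. This is a sketch: `clip_by_global_norm` is a hypothetical helper, and the loss is multiplied by 1000 only to force large gradients.

```python
import torch

def clip_by_global_norm(grads, max_norm):
    """Scale all gradients in place so their joint L2 norm is at most max_norm."""
    total_norm = torch.sqrt(sum(g.pow(2).sum() for g in grads))
    if total_norm > max_norm:
        for g in grads:
            g.mul_(max_norm / total_norm)
    return total_norm

torch.manual_seed(0)
model = torch.nn.Linear(8, 8)
model(torch.randn(4, 8)).pow(2).sum().mul(1000).backward()   # deliberately large gradients

manual = [p.grad.clone() for p in model.parameters()]
norm_before = clip_by_global_norm(manual, max_norm=1.0)
norm_builtin = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
norm_after = torch.sqrt(sum(p.grad.pow(2).sum() for p in model.parameters()))
print(norm_before.item(), norm_after.item())
```

Because the norm is taken jointly over all parameters, every gradient is scaled by the same factor and the update direction is preserved.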

Common misconception: "optimizer.zero_grad() has to go in one specific place." What matters is that gradients are cleared exactly once per optimizer step, somewhere between optimizer.step() and the next backward(). Calling it right before backward() works just as well as calling it right after step(). The real concern is that PyTorch accumulates gradients: calling backward twice without zeroing sums the two gradients into .grad. The idiomatic pattern is forward → backward → step → zero_grad. Passing set_to_none=True (the default in recent PyTorch versions) is faster than filling with zeros and frees memory. Also: gradient accumulation (for large effective batch sizes) intentionally skips zero_grad between micro-batches, which is correct behavior, not a bug. Just make sure to scale the loss by 1/accumulation_steps.
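The accumulation behavior is easy to observe with the same three-element example used earlier in this chapter. This is a minimal sketch; `loss_fn` is just a local helper.

```python
import torch

x = torch.tensor([1., 2, 3])
w = torch.tensor([1., 1, 1], requires_grad=True)

def loss_fn():
    return 0.5 * (x @ w - 5).pow(2)

loss_fn().backward()
first = w.grad.clone()       # gradient of one backward pass
loss_fn().backward()         # no zeroing in between: gradients add up
doubled = w.grad.clone()

w.grad = None                # what zero_grad(set_to_none=True) does per parameter
loss_fn().backward()
fresh = w.grad.clone()
print(first, doubled, fresh)
```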
A single weight matrix W of shape [1024, 1024] is used with a batch of 512 tokens (shape [512, 1024]). The backward pass must compute W.grad. What is the FLOPs cost for W.grad alone?

Chapter 5: Parameters, Modules & Initialization

In PyTorch, model parameters are stored as nn.Parameter objects — they're just tensors with a flag saying "please compute my gradient and include me in optimizer updates." Grouping parameters into nn.Modules gives you structured naming, easy device movement, and state_dict serialization.

python
import torch.nn as nn

# A linear layer with Xavier-like initialization
class Linear(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        # Scale by 1/sqrt(in_dim) to keep output variance ~1
        self.weight = nn.Parameter(
            torch.randn(in_dim, out_dim) / (in_dim ** 0.5)
        )

    def forward(self, x):
        return x @ self.weight

# Count parameters in a model
model = Linear(1024, 512)
num_params = sum(p.numel() for p in model.parameters())
# 1024 * 512 = 524,288 parameters

Why divide by √d? If you initialize W with iid Normal(0,1) entries and compute x @ W where x is a d-dimensional vector, each output element y[k] = Σ_j x[j] × W[j,k] is a sum of d independent zero-mean terms. Variances of independent terms add, so Var(y[k]) ≈ d × Var(x[j]). So y has standard deviation √d times larger than x, and that factor compounds through every layer. With even 10 layers, activations explode. Dividing by √d (fan-in scaling, in the spirit of Xavier initialization) keeps the variance roughly constant across layers.
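The compounding effect shows up immediately in a stack of random linear layers. The sketch below deliberately omits nonlinearities and biases; the width of 256 and depth of 10 are arbitrary choices.

```python
import torch

torch.manual_seed(0)
d, n_layers = 256, 10
x = torch.randn(1024, d)

def final_std(scale):
    """Push x through n_layers random linear layers whose weights are N(0,1) * scale."""
    h = x
    for _ in range(n_layers):
        h = h @ (torch.randn(d, d) * scale)
    return h.std().item()

std_unscaled = final_std(1.0)          # grows by about sqrt(d) = 16 per layer
std_scaled = final_std(d ** -0.5)      # stays near 1
print(f"unscaled: {std_unscaled:.3e}, scaled by 1/sqrt(d): {std_scaled:.3f}")
```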

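Let's count parameters for a realistic model. A Transformer block has four attention projections (Q, K, V, O), each [d, d], for 4d² parameters, and two feedforward matrices, [d, 4d] and [4d, d], for 8d² parameters. That is 12d² per block, or 12d²L for L layers (ignoring biases and norm parameters).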

For a 1.4B model with d_model=2048 and 24 layers: 12 × 2048² × 24 = 1.21B (the rest comes from the embedding tables). This matches! The formula 12d²L is surprisingly accurate for large Transformers, where the blocks dominate the embedding tables.

Parameter count by component

Adjust model dimensions. See how parameters distribute across Transformer components.

d_model 1024
Layers (L) 12
Vocab size 32k

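How to count parameters for a real Transformer. Let's be explicit. A GPT-style model with d_model=1024, 12 layers, 16 attention heads (head dim=64), FFN dim=4096, vocab size=50,257: each block holds 4 × 1024² ≈ 4.2M attention-projection weights and 2 × 1024 × 4096 ≈ 8.4M feedforward weights, about 12.6M per block and 151.0M over 12 layers. The token embedding adds 50,257 × 1024 ≈ 51.5M.

Check: 12d²L = 12×1024²×12 = 151.0M (matching the block params!). The formula works.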
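The same count as code. This is a sketch that tallies weight matrices only (no biases, norm parameters, or positional embeddings); the helper names are hypothetical.

```python
def block_params(d_model):
    attention = 4 * d_model * d_model           # Q, K, V, O projections
    feedforward = 2 * d_model * (4 * d_model)   # [d, 4d] and [4d, d]
    return attention + feedforward              # 12 * d^2

def transformer_params(d_model, n_layers, vocab_size):
    """Return (block parameters, token-embedding parameters)."""
    blocks = n_layers * block_params(d_model)
    embedding = vocab_size * d_model
    return blocks, embedding

blocks, embedding = transformer_params(d_model=1024, n_layers=12, vocab_size=50_257)
print(f"blocks: {blocks / 1e6:.1f}M, token embedding: {embedding / 1e6:.1f}M")
```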

State dict = all parameters + buffers. model.state_dict() returns an ordered dict of name → tensor for every parameter and buffer. This is what you save in a checkpoint and what you load to resume. The key is the dotted path ("layers.0.weight"), reflecting the module hierarchy. When parameters have identical names across layers, they're separate tensors. PyTorch never shares weights unless you explicitly tie them (some models tie the input embedding to the output projection, saving d_model × vocab_size parameters). With weight tying in our 204M model, that saves 50,257 × 1024 ≈ 51.5M parameters, or 51.5M × 4 bytes ≈ 206 MB in float32. In practice this is the first optimization applied to smaller models.
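Weight tying amounts to assigning one `nn.Parameter` to two modules. The sketch below uses a made-up two-layer `TinyLM` with a small vocabulary; it is an illustration, not the model discussed in the text.

```python
import torch.nn as nn

class TinyLM(nn.Module):
    """Hypothetical embedding + output head, used only to illustrate tying."""
    def __init__(self, vocab_size, d_model, tie_weights):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)           # weight: [vocab, d_model]
        self.head = nn.Linear(d_model, vocab_size, bias=False)   # weight: [vocab, d_model]
        if tie_weights:
            self.head.weight = self.embed.weight                 # share one Parameter

    def forward(self, tokens):
        return self.head(self.embed(tokens))

def count(model):
    # model.parameters() yields each shared Parameter only once
    return sum(p.numel() for p in model.parameters())

untied, tied = TinyLM(1000, 64, False), TinyLM(1000, 64, True)
print(count(untied), count(tied))
print(list(tied.state_dict().keys()))
```

Note that the tied model's state dict still lists both names; they simply refer to the same tensor.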
Why does dividing weight init by √(input_dim) keep activations stable across layers?

Chapter 6: Optimizer State & Memory

Stochastic gradient descent (SGD) is the simplest optimizer: w ← w − α∇w. But modern LLM training uses AdamW, which requires maintaining additional per-parameter statistics. Understanding why, and how much extra memory it costs, is essential for training budget calculations.

Why not plain SGD? SGD uses the same learning rate for every parameter. Early in training, parameters in dense layers get huge gradients; parameters in the embedding table (for rare tokens) might see zero gradient for thousands of steps. A fixed α either overshoots the dense layers or barely moves the sparse ones.

AdaGrad (Duchi et al., 2011) tracks the sum of squared gradients per parameter, G[i] = Σ_{t<τ} g_t[i]², and divides the update by √G[i]. Parameters that received large gradients slow down; rarely-updated parameters speed up. This is adaptive per-coordinate learning rates.

Adam extends this with momentum (exponential moving average of gradients, mt) and adapts with an exponential moving average of squared gradients (vt), plus bias-correction. AdamW adds decoupled weight decay. The update rule:

m_t = β₁ · m_{t−1} + (1−β₁) · g_t
v_t = β₂ · v_{t−1} + (1−β₂) · g_t²
m̂_t = m_t / (1−β₁ᵗ),  v̂_t = v_t / (1−β₂ᵗ)   (bias correction)
w_{t+1} = w_t − α · m̂_t / (√v̂_t + ε) − αλ · w_t
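The update rule translates almost line for line into code. The sketch below runs one step by hand and compares it with `torch.optim.AdamW` under the same hyperparameters; `adamw_step` is a hypothetical helper, and the hyperparameter values are arbitrary.

```python
import torch

def adamw_step(w, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01):
    """One AdamW update following the equations above; returns the new (w, m, v)."""
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    m_hat = m / (1 - beta1 ** t)     # bias correction
    v_hat = v / (1 - beta2 ** t)
    w = w - lr * m_hat / (v_hat.sqrt() + eps) - lr * weight_decay * w
    return w, m, v

torch.manual_seed(0)
w0, g = torch.randn(5), torch.randn(5)
w_manual, m, v = adamw_step(w0, g, torch.zeros(5), torch.zeros(5), t=1)

p = torch.nn.Parameter(w0.clone())
opt = torch.optim.AdamW([p], lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01)
p.grad = g.clone()
opt.step()
print(w_manual, p.detach())
```

The two tensors m and v returned by the helper are exactly the per-parameter state the optimizer has to keep in memory between steps.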

Each parameter has two optimizer state tensors: m (first moment) and v (second moment), both the same dtype as the parameter. In float32, that's 4+4+4 = 12 bytes per parameter just for the model + optimizer state, plus another 4 bytes for the gradient = 16 bytes per parameter in naive float32 AdamW.

python
# Memory accounting for a 1B parameter model with naive AdamW
N = 1e9  # parameters
mem_params    = N * 4  # float32 weights   = 4 GB
mem_grads     = N * 4  # float32 gradients = 4 GB
mem_adam_m    = N * 4  # float32 m states  = 4 GB
mem_adam_v    = N * 4  # float32 v states  = 4 GB
total = mem_params + mem_grads + mem_adam_m + mem_adam_v
# = 16 GB — plus activations, which depend on batch size!
print(f"Static memory: {total/1e9:.0f} GB")  # 16 GB

For a 70B parameter model: 70B × 16 bytes = 1,120 GB just for static state, before a single activation. H100s have 80 GB each, so you'd need 14 just for the static part. This motivates mixed precision (bf16 working weights and activations at 2 bytes per value, which cuts activation memory and speeds up matmuls) and ZeRO (shard optimizer state across GPUs).
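The static-memory arithmetic, including the "how big a model fits" question from Chapter 0, can be sketched as follows. The helper names are made up, activations are ignored, and everything is assumed to be stored in float32.

```python
import math

BYTES_PER_PARAM = 4 + 4 + 4 + 4   # fp32 weights, gradients, Adam m, Adam v
H100_BYTES = 80e9

def static_bytes(n_params):
    return n_params * BYTES_PER_PARAM

def gpus_needed(n_params, gpu_bytes=H100_BYTES):
    """GPUs required to hold the static state alone."""
    return math.ceil(static_bytes(n_params) / gpu_bytes)

def max_params(n_gpus, gpu_bytes=H100_BYTES):
    """Largest parameter count whose static state fits, ignoring activations."""
    return n_gpus * gpu_bytes / BYTES_PER_PARAM

print(f"70B static state: {static_bytes(70e9) / 1e9:,.0f} GB on {gpus_needed(70e9)} H100s")
print(f"8 H100s hold at most {max_params(8) / 1e9:.0f}B parameters")
```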

Common misconception: "The optimizer state is a small add-on next to the model itself." With AdamW it doubles the static footprint relative to plain SGD: SGD = params + grads = 8 bytes/param; AdamW = params + grads + m + v = 16 bytes/param. Each optimizer state tensor is full-sized (m and v have the same shape as the parameter tensor they track), so the optimizer state alone is twice the size of the weights.

Optimizer state accumulation across training. In the AdamW update, β1 and β2 control how quickly old gradient information decays. Typical values: β1=0.9, β2=0.999. After t steps, the effective "memory" of the optimizer extends back roughly 1/(1-β) steps: ≈10 steps for m, ≈1000 steps for v. This means:

The ZeRO insight. All 16 bytes per parameter must live somewhere, but they don't all need to live on the same GPU. ZeRO (Zero Redundancy Optimizer, Rajbhandari et al. 2019) shards the optimizer states (stage 1), then gradients (stage 2), then parameters (stage 3) across data-parallel ranks. With ZeRO-3 across 64 GPUs, the static memory per GPU drops from 16N bytes to 16N/64 = 0.25N bytes. Communication is the price: each GPU must reconstruct the full parameter before its forward/backward, adding all-gather collective operations. ZeRO is the technique that made training 70B+ parameter models practical without obscenely large per-node memory. It is the core of DeepSpeed, and PyTorch's FSDP implements the same sharding idea.
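The per-GPU saving at each stage follows directly from which of the 16 bytes get divided by the GPU count. The sketch below uses the naive fp32 split from this chapter (4 for weights, 4 for gradients, 8 for Adam m and v) and ignores communication buffers; the function name is hypothetical.

```python
def zero_bytes_per_param(stage, n_gpus):
    """Static bytes per parameter held on each GPU, naive fp32 accounting."""
    params, grads, optim = 4.0, 4.0, 8.0
    if stage >= 1:
        optim /= n_gpus    # stage 1: shard optimizer states
    if stage >= 2:
        grads /= n_gpus    # stage 2: also shard gradients
    if stage >= 3:
        params /= n_gpus   # stage 3: also shard parameters
    return params + grads + optim

for stage in range(4):
    per_param = zero_bytes_per_param(stage, n_gpus=64)
    print(f"stage {stage}: {per_param:.4f} bytes/param, "
          f"{70e9 * per_param / 1e9:,.1f} GB per GPU for a 70B model")
```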
A 7B parameter model is trained with AdamW using float32 for all tensors. What is the minimum GPU memory for just the static training state (parameters + gradients + optimizer states)?

Chapter 7: Full Training Budget — the Showcase

Everything we've learned now combines into a complete training memory budget. Four categories of memory are needed:

Parameters
The model weights themselves. N × bytes/dtype.
+
Gradients
One gradient per parameter, same dtype as parameters. N × bytes/dtype.
+
Optimizer State
AdamW: m + v per parameter, typically float32 regardless of training dtype. 2N × 4 bytes.
+
Activations
Intermediate tensors saved for backward. Scales with batch size × sequence length × d_model × num_layers. Often dominates for large batch sizes.

Activation memory is the tricky one. For a single Transformer layer processing a batch of B sequences of length T with d_model dimensions: the attention scores alone are [B, H, T, T] where H = num heads. With B=32, T=2048, H=32: that's 32×32×2048×2048 = 4.3B values. In fp32 that's 17.2 GB — per layer! This is why activation checkpointing (recomputing activations during backward instead of saving them) is almost always used in practice.
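The quadratic growth in T is the key property of the attention-score tensor. Here is a minimal sketch with a hypothetical helper that reproduces the number above and shows the effect of doubling the sequence length.

```python
def attention_score_bytes(batch, heads, seq_len, bytes_per_el=4):
    """Bytes for one layer's [B, H, T, T] attention-score tensor."""
    return batch * heads * seq_len * seq_len * bytes_per_el

fp32_scores = attention_score_bytes(32, 32, 2048)
bf16_scores = attention_score_bytes(32, 32, 2048, bytes_per_el=2)
longer = attention_score_bytes(32, 32, 4096)
print(f"fp32: {fp32_scores / 1e9:.1f} GB, bf16: {bf16_scores / 1e9:.1f} GB, "
      f"T=4096 in fp32: {longer / 1e9:.1f} GB")
```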

What activations are stored for the backward pass? PyTorch's autograd must store any intermediate value that's needed to compute a gradient. For a matmul Y = X @ W, the gradients are: W.grad = Xᵀ @ Y.grad (needs X) and X.grad = Y.grad @ Wᵀ (needs W). So the autograd graph must keep both the input X and the weight W until the backward pass. For a 32-layer Transformer with B=8, T=2048, d=4096, the activations per layer include:

Across 32 layers: ~320 GB just for activations in this example. This is unsustainable. Flash Attention eliminates the T×T attention score tensors (recomputes them on-the-fly during backward), cutting activation memory by roughly 60% for large T. Even with Flash Attention, activation checkpointing across block boundaries is typically necessary for batch sizes above a few sequences.

Interactive training memory breakdown

[Interactive calculator: a stacked bar shows how GPU memory is consumed. Controls: parameters (B), batch × sequence length (tokens), d_model, layers, and a precision toggle for comparing fp32 with mixed precision.]

Mixed precision memory formula. Standard mixed precision (as used in practice): bf16 parameters (2 bytes) + bf16 gradients (2 bytes) + float32 master copy of parameters (4 bytes) + float32 Adam m (4 bytes) + float32 Adam v (4 bytes) = 16 bytes per parameter, the same as naive fp32! The memory saving in mixed precision comes from activations (stored in bf16 during forward, a 2× reduction), not from the static tensors. The speed saving comes from the bf16 tensor cores: on an H100 they peak at ~989 TFLOP/s dense, roughly 15× the ~67 TFLOP/s of plain fp32 arithmetic.
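
To make the byte count explicit, here is a small sketch that sums the per-parameter cost of each recipe. The dictionary and function names are illustrative, not part of any library.

```python
# Bytes per parameter for each tensor kept during AdamW training.
FP32_RECIPE = {"params": 4, "grads": 4, "adam_m": 4, "adam_v": 4}
MIXED_RECIPE = {"params_bf16": 2, "grads_bf16": 2, "master_params_fp32": 4,
                "adam_m": 4, "adam_v": 4}


def static_bytes(n_params, recipe):
    """Static training state in bytes: parameters, gradients, optimizer state."""
    return n_params * sum(recipe.values())


for name, recipe in [("fp32", FP32_RECIPE), ("mixed", MIXED_RECIPE)]:
    print(f"{name}: {sum(recipe.values())} bytes/param, "
          f"{static_bytes(1e9, recipe) / 1e9:.0f} GB for 1B params")
```

Both lines report 16 bytes per parameter, so any memory difference between the two setups has to come from activations.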

Gradient accumulation as a workaround for small batch sizes. Large effective batch sizes improve gradient quality (less noise) and are often necessary for stable training. But a large batch means more activations, which requires more memory. Solution: accumulate gradients over multiple small micro-batches before taking an optimizer step:

python
import torch

accumulation_steps = 8  # effective batch = 8 × micro_batch
optimizer.zero_grad()

for i, (x, y) in enumerate(dataloader):
    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        loss = model(x, y) / accumulation_steps  # scale loss! (model returns the mean loss)
    loss.backward()  # gradients accumulate in .grad

    if (i + 1) % accumulation_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()
# Memory: only one micro-batch of activations at a time
# Compute: same total FLOPs as large batch, but sequential

The key is dividing the loss by accumulation_steps before backward — otherwise gradients are 8× too large and the effective learning rate is wrong.
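
The scaling claim can be checked without a GPU. The sketch below uses a hypothetical one-parameter model (prediction w·x, mean squared error) with made-up data, and compares the gradient accumulated over equal-sized micro-batches against the full-batch gradient.

```python
def grad_mse(w, xs, ys):
    """d/dw of mean((w*x - y)^2) over the given examples."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


w = 0.5
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0, 4.1, 5.9, 8.2, 9.8, 12.1, 14.0, 15.9]

full_batch_grad = grad_mse(w, xs, ys)

micro_batch_size = 2
accumulation_steps = len(xs) // micro_batch_size
accumulated = 0.0
for i in range(accumulation_steps):
    lo, hi = i * micro_batch_size, (i + 1) * micro_batch_size
    # Dividing by accumulation_steps mirrors `loss / accumulation_steps`.
    accumulated += grad_mse(w, xs[lo:hi], ys[lo:hi]) / accumulation_steps

# What you would get if the division were forgotten.
unscaled = accumulated * accumulation_steps

print(abs(accumulated - full_batch_grad) < 1e-9)  # scaled sum matches the full batch
print(unscaled / full_batch_grad)                 # close to accumulation_steps
```

The equivalence is exact only when every micro-batch has the same number of examples (or tokens); with uneven micro-batches the losses need to be weighted by their sizes.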

Activation checkpointing trades compute for memory. Without it, every intermediate activation tensor is kept alive until its backward pass. With it, you only store the boundary tensors (e.g., the input to each Transformer block) and recompute everything inside the block during backward. Cost: one extra forward pass per checkpointed region, so training compute rises from about 6ND to about 8ND, roughly 33% more. Benefit: instead of every intermediate in all L layers, you hold one B × T × d boundary tensor per layer plus the internals of the single block being recomputed. Per-block checkpointing therefore still scales with L, but with a much smaller constant. In PyTorch, torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False) runs a module under checkpointing. For a 32-layer Transformer, activation checkpointing can reduce activation memory from 128 GB to the order of 4 GB: the difference between needing 2 GPUs and fitting on one.
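
A back-of-the-envelope sketch of this trade-off follows. The parameter tensors_per_block is an assumption: it stands for the number of [B, T, d]-sized tensors saved inside one block, and its true value depends on the architecture and kernels. The function names are hypothetical.

```python
def activation_bytes(layers, batch, seq_len, d_model,
                     tensors_per_block=16, bytes_per_value=2, checkpoint=False):
    """Rough activation memory for a stack of Transformer blocks.

    tensors_per_block is an ASSUMED count of [B, T, d]-sized tensors saved
    inside one block; it is not a measured value.
    """
    boundary = batch * seq_len * d_model * bytes_per_value  # one [B, T, d] tensor
    if not checkpoint:
        return layers * tensors_per_block * boundary
    # One boundary tensor per block, plus one block's internals during recompute.
    return (layers + tensors_per_block) * boundary


def training_flops(n_params, n_tokens, checkpoint=False):
    """6ND normally; one extra forward pass (2ND) with full checkpointing."""
    return (8 if checkpoint else 6) * n_params * n_tokens


plain = activation_bytes(32, 8, 2048, 4096)
ckpt = activation_bytes(32, 8, 2048, 4096, checkpoint=True)
print(f"no checkpointing: {plain / 1e9:.1f} GB, with: {ckpt / 1e9:.1f} GB")
print(f"compute overhead: {training_flops(1, 1, True) / training_flops(1, 1) - 1:.0%}")
```

Under these assumptions the saving is roughly tensors_per_block × L / (L + tensors_per_block), so deeper models and fatter blocks benefit most.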
Standard mixed precision (bf16 forward + fp32 optimizer) uses the same 16 bytes/param as naive fp32. So why use it at all?

Chapter 8: Mixed Precision & the Training Loop

Training with float32 throughout is stable but slow and memory-intensive. Training with bf16 throughout is fast but loses precision for gradient accumulation. Mixed precision training (Micikevicius et al., 2017) gives you the best of both by using different dtypes for different parts of the computation.

The standard recipe:

PyTorch's torch.autocast automates most of this:

python
import torch

model = MyModel().to('cuda')  # MyModel, loss_fn, and dataloader are your own definitions
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
# No GradScaler here: loss scaling is only needed for fp16, not bf16

for x, y in dataloader:
    optimizer.zero_grad(set_to_none=True)

    # Forward in bf16
    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        pred = model(x)
        loss = loss_fn(pred, y)

    # Backward (parameters stay fp32 under autocast, so their grads are fp32)
    loss.backward()

    # Gradient clipping: essential for stability
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    optimizer.step()

The full training loop in CS336 follows exactly this structure, plus data loading and checkpointing. Checkpointing saves both model state and optimizer state to disk periodically, so a cluster crash doesn't lose days of training. The checkpoint dict is minimal:

python
import torch

# Save checkpoint
checkpoint = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "step": current_step,
    "loss": current_loss,
}
torch.save(checkpoint, f"checkpoint_step_{current_step}.pt")

# Resume from checkpoint
ckpt = torch.load("checkpoint_step_1000.pt", map_location="cuda")
model.load_state_dict(ckpt["model"])
optimizer.load_state_dict(ckpt["optimizer"])
start_step = ckpt["step"] + 1

Pinned memory and asynchronous data loading. By default, CPU tensors live in pageable memory — the OS can swap them out. For GPU transfers, x.pin_memory() locks the tensor in physical RAM, enabling asynchronous host-to-device transfer (.to('cuda', non_blocking=True)). This lets you overlap data loading with GPU compute: while the GPU processes batch N, the CPU is already transferring batch N+1. On a compute-intensive training run this can mask the entire data loading cost.

The full training loop, completely written out. Everything from Chapter 0 through Chapter 8 converges here:

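python
import numpy as np
import torch
import torch.nn.functional as F

# Transformer, get_batch, and num_steps are defined elsewhere
model = Transformer(d_model=2048, n_layers=24, vocab=32000).to('cuda')
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.1)
# Assumes a raw int32 token dump; a true .npy file has a header (use np.load(..., mmap_mode='r'))
data = np.memmap("tokens.npy", dtype=np.int32, mode='r')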

for step in range(num_steps):
    # 1. Load batch (async GPU transfer)
    x, y = get_batch(data, batch_size=8, seq_len=2048, device='cuda')

    # 2. Forward pass in bf16
    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        logits = model(x)
        loss = F.cross_entropy(logits.view(-1, 32000), y.view(-1))

    # 3. Backward pass
    loss.backward()

    # 4. Gradient clipping (prevent explosions)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    # 5. Optimizer step (parameters and Adam state are fp32)
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)

    # 6. Logging
    if step % 100 == 0:
        print(f"step {step}: loss {loss.item():.4f}")

    # 7. Checkpoint
    if step % 1000 == 0:
        torch.save({'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'step': step}, f'ckpt_{step}.pt')

Randomness and reproducibility. Training has randomness in three places: parameter initialization (torch.randn), data ordering (shuffled batches), and any stochastic regularization (dropout). For debugging, you want reproducibility: rerunning with the same seeds should give the same result. Seed every random number generator in play:

python
import torch, numpy as np, random

seed = 42
torch.manual_seed(seed)      # PyTorch CPU + CUDA ops
np.random.seed(seed)          # NumPy ops
random.seed(seed)             # Python built-in random

# For deterministic CUDA operations (slower, but reproducible)
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.deterministic = True

Note: even with all seeds set, some GPU operations are non-deterministic, typically because they use atomic adds whose floating-point accumulation order varies from run to run. use_deterministic_algorithms(True) makes PyTorch pick deterministic implementations where they exist and raise an error where they don't, usually at a speed cost. Use it during debugging, not for production runs.

Data loading: memmap and pinned memory. LLaMA's training data is 2.8 TB. Loading that into RAM at once is impossible. The solution: np.memmap memory-maps the file, so the OS lazily loads only the pages that are accessed (backed by the actual on-disk file).

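python
import numpy as np
import torch

batch_size, seq_len, device = 8, 2048, "cuda"

# Lazy loading of the 2.8 TB tokenized corpus. Assumes a raw int32 dump (no header);
# a true .npy file should be opened with np.load(path, mmap_mode='r') instead.
data = np.memmap("tokens.npy", dtype=np.int32, mode='r')
# Random batch: sample start indices, slice sequences
starts = np.random.randint(0, len(data) - seq_len, size=(batch_size,))
x = np.stack([data[s:s+seq_len] for s in starts])  # [B, T], copied out of the memmap
x = torch.from_numpy(x).pin_memory()   # lock in RAM for fast GPU transfer
x = x.to(device, non_blocking=True)  # async transfer while GPU works
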
Common misconception: "The training loop is trivial — it's just forward/backward/step." The subtle bugs are everywhere: forgetting zero_grad (gradient accumulates across steps, effectively using a larger batch size with wrong normalization); calling backward() on a non-scalar tensor (raises RuntimeError unless you pass a gradient tensor); missing torch.cuda.synchronize() before timing (CUDA execution is asynchronous — your Python timer sees the kernel-launch time, not the execution time); and checkpointing the wrong optimizer state (resuming with wrong momentum = wrong early steps). One more: forgetting model.train() vs model.eval() mode. Some layers (dropout, batch norm) behave differently in train vs eval — forgetting to switch mode causes mysteriously poor validation metrics while training loss looks fine.
Why is it important to save the optimizer state dict in checkpoints, not just the model state dict?

Chapter 9: Connections & Cheat Sheet

Let's close the loop on the Chapter 0 motivating questions and survey where this knowledge connects to the rest of CS336 and this site.

The napkin-math answers, fully derived.
  • 70B model, 15T tokens, 1024 H100s, MFU=0.5: Total FLOPs = 6 × 70×10⁹ × 15×10¹² = 6.3×10²⁴. Each H100: 989 TFLOP/s (bf16 dense). Effective: 989×10¹² × 0.5 × 1024 ≈ 5.06×10¹⁷ FLOP/s. Days = 6.3×10²⁴ / (5.06×10¹⁷ × 86,400) ≈ 144 days of 24/7 operation. The ≈88 days quoted in Chapter 0 does not follow from these inputs; reaching it would take about 82% MFU at the same dense peak.
  • Largest model on 8 H100s, naive fp32 AdamW: 8×80GB = 640GB. At 16 bytes/param = 40B parameters max — before activations.
  • With mixed precision, still ignoring activations: the static state is the same 16 bytes/param, so the limit is still 40B. The gain is that bf16 activations take half the space, which leaves headroom for usable batch sizes.
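
These napkin calculations fit in a few lines. A minimal sketch, with hypothetical function names, using the 989 TFLOP/s dense bf16 peak and 80 GB per H100 quoted above:

```python
H100_BF16_DENSE = 989e12  # FLOP/s, the dense bf16 peak used in this lecture


def training_days(n_params, n_tokens, n_gpus, mfu=0.5, peak_flops=H100_BF16_DENSE):
    """Wall-clock days for 6ND FLOPs at the given utilization."""
    total_flops = 6 * n_params * n_tokens
    flops_per_day = n_gpus * peak_flops * mfu * 86_400
    return total_flops / flops_per_day


def max_params(n_gpus, gpu_mem_bytes=80e9, bytes_per_param=16):
    """Largest parameter count whose static state fits, ignoring activations."""
    return n_gpus * gpu_mem_bytes / bytes_per_param


print(f"{training_days(70e9, 15e12, 1024):.0f} days")  # 144
print(f"{max_params(8) / 1e9:.0f}B parameters")         # 40
```

Changing mfu or peak_flops shows how sensitive the day count is to the utilization assumption.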

Five Things to Double-Check Before Starting a Training Run

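  1. Does it fit in memory? Estimate static memory (16 bytes/param for AdamW fp32) + activation memory (depends on batch). Check torch.cuda.max_memory_allocated() after the first step; it reports the peak, whereas torch.cuda.memory_allocated() only shows what is still live once activations have been freed. Leave 10-20% headroom.
  2. Is the FLOPs budget reasonable? 6×N×D ÷ (GPUs × peak FLOP/s × MFU) gives seconds; divide by 86,400 for days. If it's 6 months, reconsider.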
  3. Are seeds set? At least one reproducible run for debugging. Production can be non-deterministic for speed.
  4. Is the data loader the bottleneck? Profile: if GPU utilization drops between batches, your data loading is too slow. Use pin_memory + non_blocking=True + multiple DataLoader workers.
  5. Is checkpoint saving wired up? Before the first GPU-day of compute, test the save/load cycle. Resume from step 10 and verify the loss continues from the same point.

The Resource Accounting Cheat Sheet

| What | Formula | Example (1B params, B=8, T=2048, d=2048, L=24) |
| --- | --- | --- |
| Tensor memory | numel × bytes/dtype | [1024,1024] fp32 = 4 MB |
| Forward FLOPs | 2 × N × D (tokens) | 2 × 1B × 16K = 32 TFLOP |
| Backward FLOPs | ≈ 2 × forward | ≈ 64 TFLOP |
| Training FLOPs (total) | 6 × N × D | 6 × 1B × 16K = 96 TFLOP |
| Params memory | N × 4 (fp32) | 4 GB |
| Grads memory | N × 4 (fp32) | 4 GB |
| Adam m+v | 2N × 4 (fp32) | 8 GB |
| Activations (no ckpt) | ≈ 2 × B × T × d × L values × bytes/dtype | 2×8×2048×2048×24 × 2 (bf16) ≈ 3.2 GB |
| MFU | actual FLOP/s / peak FLOP/s | 0.5 is excellent |
| Matmul FLOPs | 2×m×k×n | [1024,4096]@[4096,4096] = 34 GFLOP |
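
Two of these rows are worth having as code. The helpers below are an illustrative sketch (the names are not from any library) that reproduces the tensor-memory and matmul examples.

```python
import math


def tensor_bytes(shape, bytes_per_value=4):
    """Memory of a dense tensor: number of elements times bytes per element."""
    return math.prod(shape) * bytes_per_value


def matmul_flops(m, k, n):
    """[m, k] @ [k, n]: one multiply and one add per (m, k, n) triple."""
    return 2 * m * k * n


print(tensor_bytes((1024, 1024)) / 2**20)     # 4.0 (MiB, fp32)
print(matmul_flops(1024, 4096, 4096) / 1e9)   # about 34.4 GFLOP
```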

Worked Example: A ~1.3B Parameter Llama-like Model

Let's apply everything to a concrete, realistic model: d_model=2048, 24 layers, 32 attention heads, FFN factor=4, vocab=32,000 tokens.

| Quantity | Formula | Value |
| --- | --- | --- |
| Parameters (blocks) | 12 × d² × L | 12 × 2048² × 24 = 1.21B |
| Parameters (embed) | V × d | 32,000 × 2048 = 65.5M |
| Total parameters | | ~1.28B |
| Params memory (fp32) | N × 4 | 1.28B × 4 = 5.1 GB |
| Params memory (bf16) | N × 2 | 1.28B × 2 = 2.6 GB |
| AdamW static (fp32) | N × 16 | 1.28B × 16 = 20.5 GB |
| Forward FLOPs/token | 2N | 2 × 1.28B = 2.56 GFLOP/tok |
| Training FLOPs (300B tok) | 6ND | 6 × 1.28B × 300B = 2.3×10²¹ FLOP |
| Time on 8 H100s (50% MFU) | FLOPs / (G × peak × MFU) | 2.3×10²¹ / (8 × 989T × 0.5) ≈ 6.7 days |
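
The table can be regenerated from the model's shape. This sketch assumes the same simplifications as the table: 12·d² parameters per block (4·d² for attention, 8·d² for an FFN with factor 4), a single embedding matrix, and no biases or norm parameters. The function name is hypothetical.

```python
def transformer_params(d_model, n_layers, vocab_size):
    """Approximate parameter count: 12*d^2 per block plus one embedding matrix."""
    block_params = 12 * d_model**2 * n_layers
    embed_params = vocab_size * d_model
    return block_params + embed_params


n = transformer_params(d_model=2048, n_layers=24, vocab_size=32_000)
tokens = 300e9
train_flops = 6 * n * tokens
seconds = train_flops / (8 * 989e12 * 0.5)   # 8 H100s at 50% MFU

print(f"parameters: {n / 1e9:.2f}B")
print(f"AdamW static state (fp32): {16 * n / 1e9:.1f} GB")
print(f"training FLOPs: {train_flops:.2e}")
print(f"time: {seconds / 86_400:.1f} days")
```

An untied output projection would add another V × d parameters, which this sketch leaves out to match the table.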

This is the scale of a "teaching model": trainable in about a week on a single 8-GPU node. At the same 300B-token count, compute scales with N, so GPT-3 (175B) would take roughly 137× more compute and LLaMA-3 70B roughly 55×. The 1-1.5B range is the sweet spot for CS336 experiments.

Connections to Related Lessons

Compute budget explorer: model size vs tokens (Chinchilla frontier)

[Interactive calculator. Controls: H100 count, days, MFU (%).] Given a fixed compute budget C = H100 count × days × 86,400 × peak FLOP/s × MFU, explore the tradeoff between model size N and training tokens D under C = 6ND. The Chinchilla rule says optimal D ≈ 20 × N.
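
The budget arithmetic behind this tradeoff can be sketched in a few lines. Substituting D = 20N into C = 6ND gives C = 120N², so N = √(C/120). The function names and the 989 TFLOP/s default are illustrative assumptions.

```python
import math


def compute_budget(n_gpus, days, mfu=0.5, peak_flops=989e12):
    """Total FLOPs available: GPUs x seconds x peak x utilization."""
    return n_gpus * days * 86_400 * peak_flops * mfu


def chinchilla_optimal(total_flops, tokens_per_param=20):
    """Solve C = 6*N*D with D = r*N, which gives N = sqrt(C / (6*r))."""
    n_params = math.sqrt(total_flops / (6 * tokens_per_param))
    return n_params, tokens_per_param * n_params


budget = compute_budget(n_gpus=8, days=14)
n_opt, d_opt = chinchilla_optimal(budget)
print(f"C = {budget:.2e} FLOPs, N = {n_opt / 1e9:.1f}B, D = {d_opt / 1e9:.0f}B tokens")
```

Because N grows with the square root of C, quadrupling the budget only doubles the compute-optimal model size.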

The Optimizer Landscape at a Glance

| Optimizer | State/param | Bytes/param (fp32) | Key idea |
| --- | --- | --- | --- |
| SGD | none | 8 (params+grads) | w -= lr × g |
| Momentum SGD | 1 tensor | 12 | Exponential moving avg of grads |
| AdaGrad | 1 tensor (G) | 12 | Per-param lr: α/√(G+ε) |
| RMSProp | 1 tensor (v) | 12 | AdaGrad + exponential decay of G |
| Adam / AdamW | 2 tensors (m,v) | 16 | Momentum + RMSProp + bias correction |
| Adafactor | factorized | ~8-10 | Factored second moment saves memory |

Why not use AdaGrad for LLMs? AdaGrad's accumulated squared gradient G grows monotonically; it never decays. After many steps, G is huge, the effective learning rate shrinks toward zero, and learning stalls. RMSProp and Adam fix this by exponentially decaying the second moment, so recent gradients matter more than ancient ones. For LLM training over trillions of tokens (hundreds of thousands to millions of optimizer steps at typical batch sizes), Adam's decaying second moment is essential.
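
The stalling effect is easy to see numerically. The sketch below feeds a constant gradient of 1.0 to both accumulators and reports the effective step size α/√(G+ε); the function names, learning rate, and decay factor are illustrative choices, and bias correction is omitted.

```python
import math


def effective_lr_adagrad(grads, lr=0.01, eps=1e-8):
    g_sum = 0.0
    for g in grads:
        g_sum += g * g  # accumulates forever, never decays
    return lr / math.sqrt(g_sum + eps)


def effective_lr_rmsprop(grads, lr=0.01, beta=0.99, eps=1e-8):
    v = 0.0
    for g in grads:
        v = beta * v + (1 - beta) * g * g  # exponential moving average
    return lr / math.sqrt(v + eps)


for steps in (100, 10_000):
    grads = [1.0] * steps
    print(steps, effective_lr_adagrad(grads), effective_lr_rmsprop(grads))
```

After 10,000 steps AdaGrad's effective step has fallen by a factor of 100, while the RMSProp-style average has settled near the base learning rate.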

Learning rate schedules. The learning rate is not constant throughout training. A typical schedule: (1) linear warmup from 0 to peak lr over the first ~1000 steps (this tames early instability while Adam's m and v estimates are still cold); (2) cosine decay from peak lr down to 10% of peak over the remaining training steps. The intuition: start slow to let the optimizer "warm up," then decrease gradually so later steps make smaller but more precise updates. The peak lr (typically 1e-4 to 3e-4 for LLMs) is chosen via small ablations: train for a few thousand steps at several lr values and select the one with the lowest loss.
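
A minimal sketch of this warmup-plus-cosine schedule as a pure function of the step count is shown below. The function name and default values are illustrative; in a PyTorch loop the returned value would be written into each optimizer parameter group's "lr" entry before the step.

```python
import math


def lr_at(step, total_steps, peak_lr=3e-4, warmup_steps=1000, min_ratio=0.1):
    """Linear warmup from 0 to peak_lr, then cosine decay to min_ratio * peak_lr."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))  # goes from 1 down to 0
    min_lr = peak_lr * min_ratio
    return min_lr + (peak_lr - min_lr) * cosine


for step in (0, 500, 1000, 5500, 10_000):
    print(step, f"{lr_at(step, total_steps=10_000):.2e}")
```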

The discipline of resource accounting. Every design decision in LLM training is ultimately a resource trade-off. Larger model = more parameters to store and update. Longer context = quadratic attention costs + huge activation memory. More layers = more compute but also more gradient checkpointing opportunities. The accounting framework in this lecture is the universal lens. Before adding a feature, ask: how many FLOPs does this add per token? How many bytes per parameter? Does it fit in the budget? Percy Liang's phrase for this mindset is "resource accounting" — the habit of quantifying costs before committing to a design. It's the difference between an engineer and a researcher who hopes their compute budget is sufficient.
One number to remember. If you remember nothing else from this lecture, remember 16 bytes per parameter. It tells you: the minimum GPU memory for a training run (16N bytes static), the cost of Adam vs SGD (+8 bytes/param), why a 70B model needs ~1.1 TB of static training state (about 560 GB of it optimizer states), and why ZeRO sharding across 64 GPUs brings that to ~17.5 GB per GPU. Everything else (the 6ND rule, bf16 vs fp32, activation checkpointing) is the arithmetic that flows from understanding memory and compute at this level of precision.
You have a fixed compute budget of 6×10²¹ FLOPs. According to the Chinchilla rule (optimal: D ≈ 20×N), what is the optimal model size N?