You have 32 H100s for two weeks. How big a model can you train, and how long will it take? This lecture answers those questions from first principles — tensors, dtypes, FLOPs, gradients, optimizers, and the full training loop. Every number derived, every byte accounted for.
You just got access to 32 H100 GPUs for two weeks. Someone asks: "What's the biggest model you can train?" You blink. You have no idea. This is the problem CS336 Lecture 2 solves — and the answer requires understanding every resource your training run consumes, down to the byte.
The resource accounting discipline pays off at every scale. Whether you're fitting a fine-tune into an 8 GB consumer GPU or planning a 10,000-GPU frontier run, the same arithmetic applies. The numbers change; the formulas don't. Once you internalize this framework, every architecture paper you read can be examined through the same lens: what does this change do to N, D, memory, or FLOPs?
Two types of resources govern every training decision: memory (measured in gigabytes, GB) and compute (measured in floating-point operations, FLOPs). They interact in interesting ways. A bigger model uses more memory (caps what fits on one GPU) and more compute (caps how fast you can train). Get the accounting wrong and you either crash out-of-memory or dramatically underestimate training time.
Percy Liang opens this lecture with two napkin-math questions that anchor everything. Before writing a single line of training code, you should be able to answer both in under two minutes.
Where does "6 × N × D" come from? Why 16 bytes per parameter? Why does utilization only reach 50%? The rest of this lesson builds every one of those numbers from first principles. By the end, you'll do this arithmetic in your head.
The lecture is structured as an "executable lecture" — every claim is backed by a Python assertion that runs and verifies. We follow the same spirit: every formula here comes with a worked example and an interactive calculator. The CS336 course uses this resource accounting mindset as the lens for all subsequent design decisions, from architecture tweaks to parallelism strategies. The six interactive calculators in this lesson let you build direct intuition about these numbers — drag the sliders and watch how FLOPs, memory, and time respond to your changes.
Adjust the knobs. Observe how parameters, data, GPU count, and MFU interact to determine training time.
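Let's work the math for Question 1 precisely. The H100's peak bf16 dense throughput is 989 TFLOP/s (= 1979/2, since the 1979 figure assumes structured sparsity). At 50% MFU across 1,024 GPUs, the estimate is total training FLOPs divided by sustained cluster throughput: time ≈ 6 × N × D / (1,024 × 989×10¹² × 0.5) seconds.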
The lecture gives 88 days using a slightly different MFU assumption. These estimates always have a range. The point: training frontier models takes months on thousands of GPUs. Resource accounting is what separates a viable training plan from wishful thinking.
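The estimate above can be sketched as a small function. This is a minimal sketch that assumes the 6 × N × D rule and a constant MFU; the function name and the 7B-parameter, 1T-token configuration are hypothetical values chosen for illustration, not numbers from the lecture.

```python
SECONDS_PER_DAY = 86_400
H100_BF16_DENSE = 989e12  # peak dense bf16 FLOP/s quoted in the text


def training_days(n_params, n_tokens, n_gpus, peak_flops_per_gpu, mfu):
    """Estimate wall-clock days using total FLOPs = 6 * N * D."""
    total_flops = 6 * n_params * n_tokens
    sustained_flops_per_sec = n_gpus * peak_flops_per_gpu * mfu
    return total_flops / sustained_flops_per_sec / SECONDS_PER_DAY


# Hypothetical run: 7B parameters, 1T tokens, 1,024 H100s at 50% MFU
days = training_days(7e9, 1e12, 1024, H100_BF16_DENSE, 0.5)
print(f"{days:.2f} days")
```

Halving the MFU or the GPU count doubles the estimate, which is why small changes in the utilization assumption move the answer so much.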
Tensors are the universal container for everything in deep learning: parameters, gradients, optimizer states, activations, input data. Understanding them at the storage level is non-negotiable for resource accounting.
Conceptually, a PyTorch tensor is a pointer into a contiguous block of memory plus metadata: the shape (number of elements along each dimension), the stride (how many storage slots to skip to advance one step along each axis), and the dtype (how many bytes each element occupies).
```python
import torch

# Several ways to create a 4×8 tensor
x = torch.zeros(4, 8)   # shape [4, 8], dtype float32
x = torch.ones(4, 8)    # all ones
x = torch.randn(4, 8)   # iid Normal(0, 1)
x = torch.empty(4, 8)   # allocate, don't initialize

# Memory is numel() × element_size() bytes
assert x.numel() == 4 * 8            # 32 values
assert x.element_size() == 4         # float32 = 4 bytes
mem = x.numel() * x.element_size()   # 128 bytes
```

Memory formula: for a tensor of shape [d0, d1, …, dk] with dtype bytes b: memory = d0 × d1 × … × dk × b bytes.

Let's make this concrete with a real model weight. GPT-3's feedforward layers use a weight matrix of shape [12288 × 4, 12288] = [49152, 12288]. That's 49152 × 12288 = 603,979,776 values. In float32 (4 bytes each), that's 2,415,919,104 bytes ≈ 2.4 GB (2.25 GiB) for a single weight matrix. The entire model has many such matrices: at 175B parameters, the float32 weights alone take about 700 GB, which is why GPT-3 simply cannot fit in one GPU's 80 GB.
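The memory formula can be checked without allocating anything. The sketch below is plain Python; the `DTYPE_BYTES` table and the `tensor_bytes` helper are illustrative names, not part of PyTorch.

```python
from math import prod

# Bytes per element for the dtypes discussed in this lesson
DTYPE_BYTES = {"float32": 4, "bfloat16": 2, "float16": 2, "fp8": 1}


def tensor_bytes(shape, dtype="float32"):
    """Memory footprint = number of elements × bytes per element."""
    return prod(shape) * DTYPE_BYTES[dtype]


ffn_shape = (4 * 12288, 12288)  # GPT-3 feedforward weight matrix
fp32_bytes = tensor_bytes(ffn_shape, "float32")
bf16_bytes = tensor_bytes(ffn_shape, "bfloat16")
print(f"fp32: {fp32_bytes / 1e9:.2f} GB, bf16: {bf16_bytes / 1e9:.2f} GB")
```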
Strides explain views. A view is a tensor that shares the same underlying storage as another tensor, just with different metadata. When you index a row (x[0]), transpose (x.T), or reshape (x.view(8, 4)), no memory is copied — only the stride metadata changes. Views are free. Copies are expensive (both memory and compute).
```python
import torch

x = torch.tensor([[0., 1, 2, 3], [4, 5, 6, 7]])

# stride(0)=4: advance dim-0 by skipping 4 storage slots
# stride(1)=1: advance dim-1 by skipping 1 storage slot
assert x.stride(0) == 4
assert x.stride(1) == 1
# x[1][2] = storage[1*4 + 2*1] = storage[6] = 6.0

y = x[0]               # row slice: a view (no copy)
z = x.transpose(1, 0)  # transpose: a view (no copy)
assert y.data_ptr() == x.data_ptr()  # same memory!

# The transposed tensor is non-contiguous, so .view() fails on it.
# Call .contiguous() first, which DOES copy.
z_cont = z.contiguous()  # allocates new memory
```
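To see why a transpose costs nothing, it can help to do the stride arithmetic by hand. The following is a plain-Python sketch (no PyTorch); `storage_offset` is a hypothetical helper that mimics how an index is mapped to a position in flat storage.

```python
def storage_offset(index, strides):
    """Map a multi-dimensional index to a flat storage position."""
    return sum(i * s for i, s in zip(index, strides))


storage = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]  # one flat buffer
strides = (4, 1)             # row-major strides for shape (2, 4)
transposed_strides = (1, 4)  # transpose = swap the strides, shape (4, 2)

a = storage[storage_offset((1, 2), strides)]             # x[1][2]
b = storage[storage_offset((2, 1), transposed_strides)]  # x.T[2][1]
print(a, b)  # both read storage[6]
```

Both lookups hit the same slot of the same buffer; only the metadata differs, which is the sense in which a view is free.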
Choose a tensor shape and dtype. See the physical storage layout and the total memory footprint.
Einops notation for readable tensor operations. Traditional PyTorch code uses x.transpose(-2, -1) and relies on you keeping track of what dimension is what. Einops names the dimensions explicitly:
```python
import torch
from einops import einsum, rearrange, reduce

# Old way: what is -2, -1 here?
x = torch.ones(2, 2, 3)        # batch, sequence, hidden
z = x @ x.transpose(-2, -1)    # works, but hard to read

# New way: named dimensions
z = einsum(x, x, "batch seq1 hidden, batch seq2 hidden -> batch seq1 seq2")

# Rearrange: split a packed head dimension
x = torch.ones(2, 3, 8)        # batch, seq, (heads * head_dim)
x = rearrange(x, "... (heads hd) -> ... heads hd", heads=2)
# shape is now [2, 3, 2, 4]

# Reduce: mean over the hidden dim (on a fresh 3-D tensor)
h = torch.ones(2, 3, 8)        # batch, seq, hidden
y = reduce(h, "batch seq hidden -> batch seq", "mean")
```

Common pitfall: calling .view() on a non-contiguous tensor (for example, right after a transpose) raises a RuntimeError. The usual fix is .contiguous(), which silently allocates a new copy of the data. In tight training loops with large activation tensors, this hidden copy doubles the memory held for that tensor. Always check .is_contiguous() before assuming a view is free. The rule: if you need to both transpose AND reshape (e.g., multi-head attention projections), either call .contiguous() explicitly or switch to .reshape(), which copies only when a view is impossible.

Floating-point numbers have a fixed number of bits, divided into a sign bit, exponent bits, and mantissa bits. The tradeoffs between these formats are everything in modern training.
| Dtype | Bits | Sign | Exponent | Mantissa | Range | Precision |
|---|---|---|---|---|---|---|
| float32 | 32 | 1 | 8 | 23 | ±3.4×10³⁸ | ~7 decimal digits |
| float16 | 16 | 1 | 5 | 10 | ±65,504 | ~3 decimal digits |
| bfloat16 | 16 | 1 | 8 | 7 | ±3.4×10³⁸ | ~2 decimal digits |
| fp8 E4M3 | 8 | 1 | 4 | 3 | ±448 | ~1 decimal digit |
| fp8 E5M2 | 8 | 1 | 5 | 2 | ±57,344 | ~0.5 decimal digits |
float32 is the historical default. It never underflows for realistic values, gives 7 decimal digits of precision, but costs 4 bytes per value. In scientific computing this was fine; in deep learning with billions of parameters it dominates memory.
float16 halves memory to 2 bytes per value, but has a tiny dynamic range. The key danger is underflow: values below ~6×10⁻⁵ fall into the subnormal range and lose precision, and values below roughly 3×10⁻⁸ round to exactly zero. During training, gradients for rarely-seen tokens or in deep networks can be this small. When they underflow, those weights stop updating and training destabilizes.

```python
import torch

# float16 underflow demonstration
x = torch.tensor([1e-8], dtype=torch.float16)
print(x)  # tensor([0.], dtype=torch.float16)  <- underflows to zero!

# bfloat16: same exponent bits as float32, so the range is safe
x = torch.tensor([1e-8], dtype=torch.bfloat16)
print(x)  # nonzero, approximately 1e-08
```
bfloat16 is Google Brain's solution (2018). By keeping float32's full 8-bit exponent, it preserves the dynamic range exactly. The trade-off: the mantissa drops from 23 to 7 bits, so precision is much lower (≈2 decimal digits). For gradient descent this is acceptable — the stochastic noise in mini-batch gradients already swamps that level of precision. bfloat16 is now the standard for LLM training forward passes.
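The range difference between the two 16-bit formats can also be reproduced with only the standard library. This is a rough sketch: `struct`'s `"e"` format gives true IEEE half precision, while `to_bfloat16` is a hypothetical helper that emulates bfloat16 by truncating a float32 to its top 16 bits (real hardware typically rounds to nearest instead of truncating).

```python
import struct


def to_float16(x):
    """Round a Python float to IEEE half precision and back."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bfloat16(x):
    """Emulate bfloat16: keep sign, 8 exponent bits, top 7 mantissa bits."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    return struct.unpack(">f", struct.pack(">I", bits & 0xFFFF0000))[0]


small = 1e-8
half = to_float16(small)   # below float16's smallest subnormal
bf16 = to_bfloat16(small)  # still representable, with ~2 digits of precision
print(half, bf16)
```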
fp8, proposed in 2022, pushes further: 1 byte per value. H100s have native fp8 hardware support. Two variants: E4M3 (more precision, less range, used for forward/activations) and E5M2 (more range, less precision, used for backward/gradients). At fp8, quantization errors become significant; you need scaling factors and careful monitoring. The payoff: roughly twice the dense bf16 tensor-core throughput on an H100.
Visual comparison of the five key formats. The bar shows the usable exponent range (log-scale). Hover each dtype to see implications for training.
Why deep learning is less sensitive to precision than scientific computing. In scientific computing, you might simulate a physical system where floating-point errors accumulate across millions of time steps, so you need 15 significant digits (float64). In deep learning, your gradient is already a noisy estimator of the true gradient (mini-batch stochastic sampling introduces variance). Adding 2-digit precision noise to a quantity that already has noise from sampling is insignificant. The stochastic gradient noise swamps the arithmetic precision noise. This is why bf16, with only about 2 significant digits, works fine, while float64 would use 4× the memory of bf16 (2× that of float32) for no benefit.
A FLOP (floating-point operation) is one multiply or one add on a float. It's the fundamental unit of compute. We use it to estimate how long training will take, compare hardware, and decide whether an architectural choice is "expensive."
Two acronyms that sound identical but mean different things: FLOPs (with a lowercase "s") = floating-point operations, a count of work done; FLOP/s (or FLOPS) = floating-point operations per second, a measure of speed.
To build intuition: training GPT-3 took ~3.14×10²³ FLOPs. The US executive order (2023) flagged models trained with ≥10²⁶ FLOPs for government reporting. H100 peak: ~989 TFLOP/s (dense bf16) = 989×10¹² FLOP/s.
Matrix multiplication dominates. For a matmul C = A @ B where A is [m × k] and B is [k × n]: each output element C[i,j] = Σ_l A[i,l] × B[l,j] requires k multiplications and k−1 additions ≈ 2k FLOPs. There are m×n outputs, so total FLOPs = 2×m×k×n.

For a batch of B examples, each of dimension D, with weight matrix [D × K]: that's a [B×D] @ [D×K] matmul = 2×B×D×K FLOPs. Here B is the number of data points and D×K is the number of parameters, so the cost is 2 × (# data points) × (# parameters). In the scaling notation used in the rest of this lesson (N parameters, D tokens, where D is now a token count rather than the input dimension), the forward pass costs 2×N×D.
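The 2×m×k×n rule is easy to sanity-check by counting operations in a naive matmul loop. This is a small illustrative sketch; both function names are made up for the example.

```python
def matmul_flops(m, k, n):
    """Rule-of-thumb FLOP count for an [m, k] @ [k, n] matmul."""
    return 2 * m * k * n


def count_flops_naive(m, k, n):
    """Exact count: each output needs k multiplies and k - 1 adds."""
    flops = 0
    for _ in range(m):
        for _ in range(n):
            flops += k        # multiplications
            flops += k - 1    # additions
    return flops


approx = matmul_flops(8, 16, 4)
exact = count_flops_naive(8, 16, 4)
print(approx, exact)  # the rule overcounts by one add per output element
```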
Elementwise operations are cheap. Adding two [m×n] matrices: m×n FLOPs. Applying ReLU to a [m×n] matrix: m×n FLOPs. Softmax over a sequence of T tokens: O(T) per token. These are negligible compared to the matmuls in attention and the feedforward layers, which scale as O(d²) per token.
Batched and higher-dimensional matmuls. PyTorch's @ operator broadcasts over batch dimensions. Given x of shape [B, T, D] and W of shape [D, K]: the result is [B, T, K] and costs 2×(B×T)×D×K FLOPs. This is the fundamental operation in every Transformer: the Q, K, V projections are batched matmuls over all tokens simultaneously, making GPU parallelism trivially exploitable.
```python
import time
import torch

# Batched matmul: the leading dims are treated as batch dims
x = torch.ones(4, 8, 16, 32)   # [batch, seq, extra, d_in]
w = torch.ones(32, 2)          # [d_in, d_out]
y = x @ w                      # [4, 8, 16, 2]
# Iterates over dims 0 and 1, multiplies each [16,32] @ [32,2]
flops = 2 * 4 * 8 * 16 * 32 * 2   # = 65,536

# Timing a matmul on GPU (with proper synchronization!)
if torch.cuda.is_available():
    x_gpu, w_gpu = x.cuda(), w.cuda()
    torch.cuda.synchronize()   # wait for previous ops to finish
    t0 = time.perf_counter()
    y = x_gpu @ w_gpu
    torch.cuda.synchronize()   # wait for this op to finish
    t1 = time.perf_counter()
    actual_flops_per_sec = flops / (t1 - t0)
```
The synchronize calls are critical. CUDA launches are asynchronous — the Python line y = x @ w returns immediately (just enqueues the kernel), and the actual GPU computation happens in parallel. Without synchronize(), you're timing kernel launch latency (≈microseconds), not execution time (≈milliseconds for large matmuls).
Model FLOPs Utilization (MFU) measures how efficiently you're using the hardware: MFU = (achieved model FLOP/s) / (hardware peak FLOP/s).

A100 peak: 312 TFLOP/s (bf16); 19.5 TFLOP/s (fp32). H100: 989 TFLOP/s (bf16 dense); 67.5 TFLOP/s (fp32). The roughly 15× gap between bf16 and fp32 on the H100 (16× on the A100) is a core reason to use mixed precision: it directly multiplies your effective compute budget.
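A minimal sketch of how these peak numbers turn into a time estimate and an MFU figure. The peak values are the ones quoted above; the matmul size and the 450 TFLOP/s "achieved" figure are hypothetical inputs for illustration.

```python
# Peak throughput in FLOP/s, as quoted in the text
PEAK_FLOPS = {
    ("A100", "bf16"): 312e12,
    ("A100", "fp32"): 19.5e12,
    ("H100", "bf16"): 989e12,
    ("H100", "fp32"): 67.5e12,
}


def ideal_matmul_seconds(m, k, n, gpu, dtype):
    """Lower bound on matmul time if the GPU ran at peak throughput."""
    return 2 * m * k * n / PEAK_FLOPS[(gpu, dtype)]


def mfu(achieved_flops_per_sec, gpu, dtype):
    """Model FLOPs Utilization = achieved / peak."""
    return achieved_flops_per_sec / PEAK_FLOPS[(gpu, dtype)]


t = ideal_matmul_seconds(8192, 8192, 8192, "H100", "bf16")
u = mfu(450e12, "H100", "bf16")  # hypothetical measured throughput
print(f"ideal time: {t * 1e3:.2f} ms, MFU: {u:.2f}")
```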
Enter a matmul shape. See the FLOPs, and the theoretical time on A100 vs H100 in fp32 vs bf16.
Scaling laws and the 6×N×D rule. This formula generalizes far beyond a simple linear model. For a Transformer with N parameters trained on D tokens, the total FLOPs is approximately 6ND. Where does this come from? Each Transformer block has attention and feedforward layers. The dominant cost is the weight matmuls in the feedforward layers (two per block: [d, 4d] and [4d, d]) and the attention projections (four per block: Q, K, V, O). Summing those matmuls over all B×T tokens and L layers, the total forward FLOPs is approximately 2ND (where N ≈ 12d²L as we showed in Chapter 5). Backward adds 4ND. Total: 6ND. This is a first-order approximation that ignores attention's O(T²) cost, which is valid when d≫T, as holds for large models on moderate context lengths.
The forward pass produces a loss. The backward pass answers: "for each parameter w, how does changing w by ε change the loss?" That derivative, d(loss)/dw, is the gradient of w. PyTorch's autograd engine computes these automatically — but you should understand what it's doing and how much it costs.
To compute gradients, you first flag the tensors you care about with requires_grad=True. PyTorch then builds a computation graph as operations execute — every tensor remembers which operation created it and from what inputs. Calling .backward() on the loss traverses this graph in reverse, applying the chain rule.
```python
import torch

# Simple linear model: loss = 0.5 * (x·w - 5)²
x = torch.tensor([1., 2, 3])
w = torch.tensor([1., 1, 1], requires_grad=True)

# Forward pass
pred_y = x @ w                      # scalar: 1+2+3 = 6
loss = 0.5 * (pred_y - 5).pow(2)    # 0.5*(6-5)² = 0.5

# Backward pass: fills .grad on all requires_grad leaf tensors
loss.backward()

# Chain rule: d(loss)/dw_i = (pred_y - 5) * x_i
#           = (6 - 5) * [1, 2, 3] = [1, 2, 3]
assert torch.equal(w.grad, torch.tensor([1., 2, 3]))

# Crucially: loss.grad is None. Only leaf tensors (with
# requires_grad=True) get .grad filled in by default.
assert loss.grad is None
```
How many FLOPs does the backward pass cost? Let's trace through a two-layer network: x →W1→ h1 →W2→ h2 → loss. W1 is [D×D], W2 is [D×K], inputs batch [B×D].
To compute W2.grad = d(loss)/dW2, we need: W2.grad[j,k] = ∑i h1[i,j] × h2.grad[i,k]. That's a [D×B] @ [B×K] matmul = 2×B×D×K FLOPs — exactly the same as the forward pass through W2.
To compute h1.grad (needed for W1's gradient): h1.grad[i,j] = ∑k W2[j,k] × h2.grad[i,k]. That's another [B×K] @ [K×D] matmul = 2×B×D×K FLOPs.
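So for W2 alone, the backward costs 2× the forward. By the same logic for W1, the full network costs:

forward = 2×B×D×D + 2×B×D×K
backward = 4×B×D×D + 4×B×D×K

Putting it all together for the two-layer model:

total = forward + backward = 6×B×(D×D + D×K)

And therefore, with N = D×D + D×K parameters:

total = 6×B×N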
Per token (dividing by B), per step: 6N FLOPs. Over D total tokens: 6 × N × D.
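The 6N-per-token bookkeeping for the two-layer network can be written out as arithmetic. This is an illustrative sketch; the function name and the sizes B, D, K are arbitrary choices.

```python
def two_layer_training_flops(B, D, K):
    """FLOPs for x -> W1 [D, D] -> h1 -> W2 [D, K] -> h2 with batch size B."""
    forward = 2 * B * D * D + 2 * B * D * K       # one matmul per layer
    bwd_weights = 2 * B * D * D + 2 * B * D * K   # gradients w.r.t. W1, W2
    bwd_inputs = 2 * B * D * D + 2 * B * D * K    # gradients w.r.t. layer inputs
    return forward, bwd_weights + bwd_inputs


B, D, K = 32, 512, 128
forward, backward = two_layer_training_flops(B, D, K)
n_params = D * D + D * K
print(forward / (B * n_params), backward / (B * n_params))  # 2.0 and 4.0
```

Strictly, the gradient with respect to the very first input is not needed, so the count slightly overestimates the first layer; the 6N rule keeps it for simplicity.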
What tensors get gradients? Only tensors that are reachable in the computation graph from a requires_grad=True leaf tensor accumulate gradients. Leaf tensors are tensors you created directly (not the result of an operation). In a model, parameters are the leaf tensors with requires_grad=True. Non-leaf tensors (activations) get gradients only if you explicitly call .retain_grad() before backward — otherwise they're freed to save memory.
```python
import torch

x = torch.tensor([1., 2, 3])                      # leaf, no grad
w = torch.tensor([1., 1, 1], requires_grad=True)  # leaf, wants grad
h = x @ w          # non-leaf; h.requires_grad=True (inherited)
h.retain_grad()    # ask PyTorch to keep h.grad after backward
loss = 0.5 * (h - 5).pow(2)
loss.backward()

print(w.grad)     # tensor([1., 2., 3.])  (chain rule: (h-5)*x)
print(h.grad)     # tensor(1.)            (d(loss)/dh = h-5 = 6-5 = 1)
print(x.grad)     # None  (x is a leaf without requires_grad)
print(loss.grad)  # None  (non-leaf, no retain_grad; PyTorch also warns)
```

Gradient clipping is not optional for LLMs. In deep networks, gradients can explode: a single bad batch or numerical near-singularity can produce gradients of magnitude 1,000,000, which would destroy training. Gradient clipping caps the L2 norm of all gradients: if the global norm ‖g‖₂ (computed over every parameter's gradient together) exceeds max_norm, every gradient is rescaled by max_norm / ‖g‖₂; otherwise the gradients are left unchanged.
Typical values: max_norm=1.0. This ensures no single optimizer step moves parameters more than a bounded amount. The PyTorch function clips all parameter gradients jointly (not per-layer): torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0). Call it after backward, before step.
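To make the joint (not per-layer) clipping concrete, here is a plain-Python sketch of the same computation on nested lists. It is not PyTorch's implementation; `clip_grad_norm` and the toy gradients are illustrative, and the small `eps` mirrors the usual guard against division by zero.

```python
import math


def clip_grad_norm(grads, max_norm, eps=1e-6):
    """Rescale all gradients together so their global L2 norm <= max_norm."""
    total_norm = math.sqrt(sum(g * g for group in grads for g in group))
    scale = min(1.0, max_norm / (total_norm + eps))
    clipped = [[g * scale for g in group] for group in grads]
    return clipped, total_norm


# Two "parameter tensors" with a global norm of sqrt(9 + 16 + 144) = 13
grads = [[3.0, 4.0], [12.0]]
clipped, norm_before = clip_grad_norm(grads, max_norm=1.0)
norm_after = math.sqrt(sum(g * g for group in clipped for g in group))
print(norm_before, norm_after)
```

Because one scale factor is applied to every gradient, the direction of the update is preserved; only its length changes.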
A common piece of advice is: "Call optimizer.zero_grad() before backward()." Calling it right after optimizer.step() works equally well; what matters is that gradients are cleared exactly once per optimizer step. The real concern is that PyTorch accumulates gradients: calling backward twice without zeroing sums the two gradients (doubling them if the batch is the same). One idiomatic pattern: forward → backward → step → zero_grad. Use set_to_none=True (faster than filling with zeros, frees memory). Also: gradient accumulation (for large effective batch sizes) intentionally skips zero_grad between micro-batches. This is correct behavior, not a bug. Just make sure to scale the loss by 1/accumulation_steps.

In PyTorch, model parameters are stored as nn.Parameter objects: tensors with a flag saying "please compute my gradient and include me in optimizer updates." Grouping parameters into nn.Modules gives you structured naming, easy device movement, and state_dict serialization.

```python
import torch
import torch.nn as nn

# A linear layer with Xavier-like initialization
class Linear(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        # Scale by 1/sqrt(in_dim) to keep output variance ~1
        self.weight = nn.Parameter(
            torch.randn(in_dim, out_dim) / (in_dim ** 0.5)
        )

    def forward(self, x):
        return x @ self.weight

# Count parameters in a model
model = Linear(1024, 512)
num_params = sum(p.numel() for p in model.parameters())
# 1024 * 512 = 524,288 parameters
```

Why divide by √d? If you initialize W with iid Normal(0,1) entries and compute x @ W where x is a d-dimensional vector, each output element y[k] = Σ_j x[j] × W[j,k] is a sum of d independent terms. Because the variances of independent terms add, Var(y[k]) ≈ d × Var(x[j]). So y has standard deviation √d times larger than x, and that scaling factor compounds through every layer. With even 10 layers, activations explode. Dividing by √d (a Xavier-style initialization) keeps variance constant across layers.
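The variance argument can be checked empirically with the standard library alone. This is a small simulation sketch: the input is fixed to all ones (so every term has unit scale), and `output_variance`, the sample count, and the seed are illustrative choices.

```python
import random
import statistics


def output_variance(d, scale, n_outputs=2000, seed=0):
    """Sample variance of y = sum_j x[j] * W[j], with W ~ scale * Normal(0, 1)."""
    rng = random.Random(seed)
    x = [1.0] * d  # unit-magnitude inputs
    ys = []
    for _ in range(n_outputs):
        column = (rng.gauss(0.0, 1.0) * scale for _ in range(d))
        ys.append(sum(xj * wj for xj, wj in zip(x, column)))
    return statistics.pvariance(ys)


d = 256
var_raw = output_variance(d, scale=1.0)           # grows like d
var_scaled = output_variance(d, scale=d ** -0.5)  # stays near 1
print(f"unscaled: {var_raw:.1f}, scaled by 1/sqrt(d): {var_scaled:.3f}")
```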
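Let's count parameters for a realistic model. A Transformer block has four attention projection matrices (Q, K, V, O), each [d, d], for 4d² parameters, and two feedforward matrices, [d, 4d] and [4d, d], for 8d² parameters. That gives about 12d² parameters per block (ignoring biases and layer norms), or 12d²L for L layers.

For a 1.4B model with d_model=2048 and 24 layers: 12 × 2048² × 24 = 1.21B (the rest comes from the embedding tables). This matches! The formula 12d²L is surprisingly accurate for large Transformers, where the blocks dominate the embedding parameters.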
Adjust model dimensions. See how parameters distribute across Transformer components.
How to count parameters for a real Transformer. Let's be explicit. A GPT-style model with d_model=1024, 12 layers, 16 attention heads (head dim=64), FFN dim=4096, vocab size=50,257:
Check: 12d²L = 12×1024²×12 = 151.0M (matching the block params!). The formula works.

model.state_dict() returns an ordered dict of name → tensor for every parameter and buffer. This is what you save in a checkpoint and what you load to resume. The key is the dotted path ("layers.0.weight"), reflecting the module hierarchy. When parameters have identical names across layers, they're separate tensors: PyTorch never shares weights unless you explicitly tie them (some models tie the input embedding to the output projection, saving d_model × vocab_size parameters). With weight tying in our 204M model, the saving is 1024 × 50,257 ≈ 51.5M parameters, or about 206 MB at 4 bytes each. In practice this is the first optimization applied to smaller models.

Stochastic gradient descent (SGD) is the simplest optimizer: w ← w − α∇w. But modern LLM training uses AdamW, which requires maintaining additional per-parameter statistics. Understanding why, and how much extra memory it costs, is essential for training budget calculations.
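The count for this configuration can be reproduced with a few lines of arithmetic. This sketch ignores biases, layer norms, and positional embeddings; the function and dictionary keys are illustrative names. Note that the number of attention heads does not enter: splitting d_model into heads does not change the size of the projection matrices.

```python
def transformer_param_counts(d_model, n_layers, d_ff, vocab_size):
    """Approximate parameter counts, matrices only."""
    attention = 4 * d_model * d_model  # Q, K, V, O projections
    feedforward = 2 * d_model * d_ff   # up and down projections
    per_block = attention + feedforward
    return {
        "per_block": per_block,
        "blocks": n_layers * per_block,
        "embedding": vocab_size * d_model,
    }


counts = transformer_param_counts(
    d_model=1024, n_layers=12, d_ff=4096, vocab_size=50_257
)
for name, value in counts.items():
    print(f"{name}: {value / 1e6:.1f}M")
```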
Why not plain SGD? SGD uses the same learning rate for every parameter. Early in training, parameters in dense layers get huge gradients; parameters in the embedding table (for rare tokens) might see zero gradient for thousands of steps. A fixed α either overshoots the dense layers or barely moves the sparse ones.
AdaGrad (Duchi et al., 2011) tracks the sum of squared gradients per parameter, G[i] = Σ_{t<τ} g_t[i]², and divides the update by √G[i]. Parameters that received large gradients slow down; rarely-updated parameters speed up. This is adaptive per-coordinate learning rates.
Adam extends this with momentum (exponential moving average of gradients, mt) and adapts with an exponential moving average of squared gradients (vt), plus bias-correction. AdamW adds decoupled weight decay. The update rule:
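As a reference point, here is the standard AdamW update for a single scalar parameter, written in plain Python. It is a sketch of the textbook algorithm, not a transcription of the lecture's formula; the function name and default hyperparameters are illustrative, and library implementations may differ in small details such as where the weight decay is applied within the step.

```python
import math


def adamw_step(w, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=0.01):
    """One AdamW update for a scalar parameter w with gradient g at step t >= 1."""
    m = beta1 * m + (1 - beta1) * g          # first moment (momentum)
    v = beta2 * v + (1 - beta2) * g * g      # second moment (squared grads)
    m_hat = m / (1 - beta1 ** t)             # bias correction
    v_hat = v / (1 - beta2 ** t)
    w = w * (1 - lr * weight_decay)          # decoupled weight decay
    w = w - lr * m_hat / (math.sqrt(v_hat) + eps)
    return w, m, v


w, m, v = 1.0, 0.0, 0.0
w, m, v = adamw_step(w, g=0.5, m=m, v=v, t=1, lr=0.1, weight_decay=0.0)
print(w)  # first step moves by about lr, regardless of gradient scale
```

On the first step the bias-corrected ratio m_hat / sqrt(v_hat) is g / |g|, so the update size is set by the learning rate rather than by the gradient magnitude.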
Each parameter has two optimizer state tensors: m (first moment) and v (second moment), both the same dtype as the parameter. In float32, that's 4+4+4 = 12 bytes per parameter just for the model + optimizer state, plus another 4 bytes for the gradient = 16 bytes per parameter in naive float32 AdamW.
```python
# Memory accounting for a 1B parameter model with naive AdamW
N = 1e9               # parameters
mem_params = N * 4    # float32 weights   = 4 GB
mem_grads = N * 4     # float32 gradients = 4 GB
mem_adam_m = N * 4    # float32 m states  = 4 GB
mem_adam_v = N * 4    # float32 v states  = 4 GB
total = mem_params + mem_grads + mem_adam_m + mem_adam_v
# = 16 GB, plus activations, which depend on batch size!
print(f"Static memory: {total/1e9:.0f} GB")  # 16 GB
```

For a 70B parameter model: 70B × 16 bytes = 1,120 GB just for static state, before a single activation. H100s have 80 GB each, so you'd need 14 just for the static part. This motivates mixed precision (bf16 working weights and activations at 2 bytes per value) and ZeRO (shard optimizer state across GPUs).
Optimizer state accumulation across training. In the AdamW update, β1 and β2 control how quickly old gradient information decays. Typical values: β1=0.9, β2=0.999. After t steps, the effective "memory" of the optimizer extends back roughly 1/(1-β) steps: ≈10 steps for m, ≈1000 steps for v. This means:
Everything we've learned now combines into a complete training memory budget. Four categories of memory are needed: parameters, gradients, optimizer state, and activations.
Activation memory is the tricky one. For a single Transformer layer processing a batch of B sequences of length T with d_model dimensions: the attention scores alone are [B, H, T, T] where H = num heads. With B=32, T=2048, H=32: that's 32×32×2048×2048 = 4.3B values. In fp32 that's 17.2 GB — per layer! This is why activation checkpointing (recomputing activations during backward instead of saving them) is almost always used in practice.
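The attention-score figure is a one-line product, which also makes the quadratic dependence on sequence length easy to see. A minimal sketch, with an illustrative function name:

```python
def attention_score_bytes(batch, heads, seq_len, bytes_per_value):
    """Memory for one layer's [B, H, T, T] attention score tensor."""
    return batch * heads * seq_len * seq_len * bytes_per_value


fp32_bytes = attention_score_bytes(32, 32, 2048, 4)
bf16_bytes = attention_score_bytes(32, 32, 2048, 2)
longer = attention_score_bytes(32, 32, 4096, 4)  # double the context
print(f"{fp32_bytes / 1e9:.1f} GB fp32, {bf16_bytes / 1e9:.1f} GB bf16, "
      f"{longer / 1e9:.1f} GB at T=4096")
```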
What activations are stored for the backward pass? PyTorch's autograd must store any intermediate value that's needed to compute a gradient. For a matmul Y = X @ W, the gradients are: W.grad = Xᵀ @ Y.grad (needs X) and X.grad = Y.grad @ Wᵀ (needs W). So the autograd graph must keep both the input X and the weight W until the backward pass. For a 32-layer Transformer with B=8, T=2048, d=4096, the activations per layer include:
Across 32 layers: ~320 GB just for activations in this example. This is unsustainable. Flash Attention eliminates the T×T attention score tensors (recomputes them on-the-fly during backward), cutting activation memory by roughly 60% for large T. Even with Flash Attention, activation checkpointing across block boundaries is typically necessary for batch sizes above a few sequences.
Adjust model and training settings. The stacked bar shows how GPU memory is consumed. Toggle precision to see mixed-precision savings.
Gradient accumulation as a workaround for small batch sizes. Large effective batch sizes improve gradient quality (less noise) and are often necessary for stable training. But a large batch means more activations, which requires more memory. Solution: accumulate gradients over multiple small micro-batches before taking an optimizer step:
```python
import torch

# Assumes model, optimizer, and dataloader are already defined,
# and that model(x, y) returns a scalar loss.
accumulation_steps = 8  # effective batch = 8 × micro_batch

optimizer.zero_grad()
for i, (x, y) in enumerate(dataloader):
    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        loss = model(x, y) / accumulation_steps  # scale loss!
    loss.backward()  # gradients accumulate
    if (i + 1) % accumulation_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()

# Memory: only one micro-batch of activations at a time
# Compute: same total FLOPs as large batch, but sequential
```
The key is dividing the loss by accumulation_steps before backward — otherwise gradients are 8× too large and the effective learning rate is wrong.
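The scaling claim can be verified on a toy problem with an analytic gradient. The sketch below uses a one-parameter least-squares model in plain Python; the data, the helper name `grad_mse`, and the micro-batch split are all made up for illustration.

```python
def grad_mse(w, batch):
    """Gradient w.r.t. w of the batch mean of 0.5 * (w * x - y)^2."""
    return sum((w * x - y) * x for x, y in batch) / len(batch)


data = [(1.0, 2.0), (2.0, 1.0), (3.0, 5.0), (4.0, 3.0),
        (5.0, 9.0), (6.0, 8.0), (7.0, 6.0), (8.0, 9.0)]
w = 0.5
accumulation_steps = 4
micro_size = len(data) // accumulation_steps
micro_batches = [data[i * micro_size:(i + 1) * micro_size]
                 for i in range(accumulation_steps)]

full_batch_grad = grad_mse(w, data)

# Correct: each micro-batch loss (hence gradient) is divided by accumulation_steps
accumulated = sum(grad_mse(w, mb) / accumulation_steps for mb in micro_batches)

# Forgetting the division makes the accumulated gradient too large
unscaled = sum(grad_mse(w, mb) for mb in micro_batches)
print(full_batch_grad, accumulated, unscaled)
```

With equal-sized micro-batches the scaled sum reproduces the full-batch gradient, while the unscaled sum is larger by exactly the number of accumulation steps.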
In PyTorch, torch.utils.checkpoint.checkpoint(layer, x) wraps any module. For a 32-layer Transformer, activation checkpointing can reduce activation memory from 128 GB to 4 GB, which is the difference between needing 2 GPUs and fitting on one.

Training with float32 throughout is stable but slow and memory-intensive. Training with bf16 throughout is fast but loses precision for gradient accumulation. Mixed precision training (Micikevicius et al., 2017) gives you the best of both by using different dtypes for different parts of the computation.
The standard recipe:
PyTorch's autocast context manager (torch.autocast, part of torch.amp) automates most of this:

```python
import torch

# Assumes MyModel, loss_fn, and dataloader are defined elsewhere.
model = MyModel().to('cuda')
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
# A gradient scaler is only needed for fp16, not bf16, so none is created here.

for x, y in dataloader:
    optimizer.zero_grad(set_to_none=True)

    # Forward in bf16
    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        pred = model(x)
        loss = loss_fn(pred, y)

    # Backward (parameters and their gradients stay in fp32 for the optimizer)
    loss.backward()

    # Gradient clipping: essential for stability
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
```
The full training loop in CS336 follows exactly this structure, plus data loading and checkpointing. Checkpointing saves both model state and optimizer state to disk periodically, so a cluster crash doesn't lose days of training. The checkpoint dict is minimal:
```python
import torch

# Save checkpoint (model, optimizer, current_step, and current_loss
# come from the surrounding training loop)
checkpoint = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "step": current_step,
    "loss": current_loss,
}
torch.save(checkpoint, f"checkpoint_step_{current_step}.pt")

# Resume from checkpoint
ckpt = torch.load("checkpoint_step_1000.pt")
model.load_state_dict(ckpt["model"])
optimizer.load_state_dict(ckpt["optimizer"])
start_step = ckpt["step"] + 1
```
Pinned memory and asynchronous data loading. By default, CPU tensors live in pageable memory — the OS can swap them out. For GPU transfers, x.pin_memory() locks the tensor in physical RAM, enabling asynchronous host-to-device transfer (.to('cuda', non_blocking=True)). This lets you overlap data loading with GPU compute: while the GPU processes batch N, the CPU is already transferring batch N+1. On a compute-intensive training run this can mask the entire data loading cost.
The full training loop, completely written out. Everything from Chapter 0 through Chapter 8 converges here:
```python
import numpy as np
import torch
import torch.nn.functional as F

# Transformer, get_batch, and num_steps are assumed to be defined elsewhere;
# get_batch is assumed to return int64 token tensors already on the GPU.
model = Transformer(d_model=2048, n_layers=24, vocab=32000).to('cuda')
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.1)
# np.memmap reads raw bytes, so the file must be a headerless int32 dump
data = np.memmap("tokens.npy", dtype=np.int32, mode='r')

for step in range(num_steps):
    # 1. Load batch (async GPU transfer)
    x, y = get_batch(data, batch_size=8, seq_len=2048, device='cuda')

    # 2. Forward pass in bf16
    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        logits = model(x)
        loss = F.cross_entropy(logits.view(-1, 32000), y.view(-1))

    # 3. Backward pass
    loss.backward()

    # 4. Gradient clipping (prevent explosions)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    # 5. Optimizer step (fp32 parameters and optimizer state)
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)

    # 6. Logging
    if step % 100 == 0:
        print(f"step {step}: loss {loss.item():.4f}")

    # 7. Checkpoint
    if step % 1000 == 0:
        torch.save({'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'step': step}, f'ckpt_{step}.pt')
```
Randomness and reproducibility. Training has randomness in three places: parameter initialization (torch.randn), data ordering (shuffled batches), and any stochastic regularization (dropout). For debugging, you want reproducibility — the exact same run can be reproduced. Set all three random seeds:
```python
import random

import numpy as np
import torch

seed = 42
torch.manual_seed(seed)   # PyTorch CPU + CUDA ops
np.random.seed(seed)      # NumPy ops
random.seed(seed)         # Python built-in random

# For deterministic CUDA operations (slower, but reproducible)
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.deterministic = True
```

Note: even with all seeds set, some GPU kernels are non-deterministic, for example because parallel floating-point reductions can accumulate in a different order from run to run. use_deterministic_algorithms(True) makes PyTorch pick deterministic implementations where they exist and raise an error where they do not, usually at a speed cost. Use it during debugging, not for production runs.
Data loading: memmap and pinned memory. LLaMA's training data is 2.8 TB. Loading that into RAM at once is impossible. The solution: np.memmap memory-maps the file, so the OS lazily loads only the pages that are accessed (backed by the actual on-disk file).
```python
import numpy as np
import torch

batch_size, seq_len, device = 8, 2048, 'cuda'  # example settings

# Lazy loading of a 2.8 TB tokenized corpus
# (np.memmap reads raw bytes, so the file must be a headerless int32 dump)
data = np.memmap("tokens.npy", dtype=np.int32, mode='r')

# Random batch: sample start indices, slice sequences
starts = np.random.randint(0, len(data) - seq_len, size=(batch_size,))
x = np.stack([data[s:s + seq_len] for s in starts])  # [B, T]
x = torch.tensor(x).pin_memory()      # lock in RAM for fast GPU transfer
x = x.to(device, non_blocking=True)   # async transfer while GPU works
```

Common training-loop bugs: forgetting zero_grad (gradients accumulate across steps, effectively using a larger batch size with the wrong normalization); calling backward() on a non-scalar tensor (raises a RuntimeError unless you pass a gradient tensor); missing torch.cuda.synchronize() before timing (CUDA execution is asynchronous, so your Python timer sees the kernel-launch time, not the execution time); and checkpointing the wrong optimizer state (resuming with wrong momentum = wrong early steps). One more: forgetting model.train() vs model.eval() mode. Some layers (dropout, batch norm) behave differently in train vs eval, and forgetting to switch mode causes mysteriously poor validation metrics while training loss looks fine.

Let's close the loop on the Chapter 0 motivating questions and survey where this knowledge connects to the rest of CS336 and this site.

Measure actual usage with torch.cuda.memory_allocated() after the first batch. Leave 10-20% headroom.

| What | Formula | Example (1B params, B=8, T=2048, d=2048, L=24) |
|---|---|---|
| Tensor memory | numel × bytes/dtype | [1024,1024] fp32 = 4 MB |
| Forward FLOPs | 2 × N × D (tokens) | 2 × 1B × 16K = 32 TFLOP |
| Backward FLOPs | ≈ 2 × forward | ≈ 64 TFLOP |
| Training FLOPs (total) | 6 × N × D | 6 × 1B × 16K = 96 TFLOP |
| Params memory | N × 4 (fp32) | 4 GB |
| Grads memory | N × 4 (fp32) | 4 GB |
| Adam m+v | 2N × 4 (fp32) | 8 GB |
| Activations (no ckpt) | ≈ 2 × B × T × d × L values × bytes per value | 2×8×2048×2048×24 values × 2 bytes (bf16) ≈ 3.2 GB |
| MFU | actual FLOP/s / peak FLOP/s | 0.5 is excellent |
| Matmul FLOPs | 2×m×k×n | [1024,4096]@[4096,4096] = 34 GFLOP |
Let's apply everything to a concrete, realistic model: d_model=2048, 24 layers, 32 attention heads, FFN factor=4, vocab=32,000 tokens.
| Quantity | Formula | Value |
|---|---|---|
| Parameters (blocks) | 12 × d² × L | 12 × 2048² × 24 = 1.21B |
| Parameters (embed) | V × d | 32,000 × 2048 = 65.5M |
| Total parameters | — | ~1.28B |
| Params memory (fp32) | N × 4 | 1.28B × 4 = 5.1 GB |
| Params memory (bf16) | N × 2 | 1.28B × 2 = 2.6 GB |
| AdamW static (fp32) | N × 16 | 1.28B × 16 = 20.5 GB |
| Forward FLOPs/token | 2N | 2 × 1.28B = 2.56 GFLOP/tok |
| Training FLOPs (300B tok) | 6ND | 6 × 1.28B × 300B = 2.3×10²¹ FLOP |
| Time on 8 H100s (50% MFU) | FLOPs / (G × peak × MFU) | 2.3×10²¹ / (8 × 989T × 0.5) ≈ 5.8×10⁵ s ≈ 6.7 days |

This is the scale of a "teaching model": trainable in about a week on a single 8-GPU node. GPT-3 (175B, ~3.14×10²³ FLOPs) took roughly 140× more compute; LLaMA-3 70B has ~55× more parameters, so it would need ~55× more compute even at the same token count. The 1.4B range is the sweet spot for CS336 experiments.
Given a fixed compute budget (days × number of H100s), explore the tradeoff between model size N and training tokens D. The Chinchilla rule says optimal D ≈ 20 × N.
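Combining C = 6 × N × D with D ≈ 20 × N gives N ≈ √(C / 120). The sketch below applies that to the 32-H100, two-week budget from the introduction. It is a compute-only estimate under assumed values (50% MFU, 989 TFLOP/s peak) and ignores whether the resulting model fits in memory; the function names are illustrative.

```python
import math


def compute_budget(n_gpus, days, peak_flops_per_gpu=989e12, mfu=0.5):
    """Total FLOPs delivered by the cluster over the given number of days."""
    return n_gpus * days * 86_400 * peak_flops_per_gpu * mfu


def chinchilla_split(compute_flops, tokens_per_param=20.0):
    """Solve C = 6 * N * D with D = tokens_per_param * N."""
    n_params = math.sqrt(compute_flops / (6.0 * tokens_per_param))
    n_tokens = tokens_per_param * n_params
    return n_params, n_tokens


C = compute_budget(n_gpus=32, days=14)
N, D = chinchilla_split(C)
print(f"C = {C:.2e} FLOPs -> N = {N / 1e9:.1f}B params, D = {D / 1e9:.0f}B tokens")
```

Because N scales with the square root of the budget, quadrupling the compute only doubles the compute-optimal model size.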
| Optimizer | State/param | Bytes/param (fp32) | Key idea |
|---|---|---|---|
| SGD | none | 8 (params+grads) | w -= lr × g |
| Momentum SGD | 1 tensor | 12 | Exponential moving avg of grads |
| AdaGrad | 1 tensor (G) | 12 | Per-param lr: α/√(G+ε) |
| RMSProp | 1 tensor (v) | 12 | AdaGrad + exponential decay of G |
| Adam / AdamW | 2 tensors (m,v) | 16 | Momentum + RMSProp + bias correction |
| Adafactor | factorized | ~8-10 | Factored second moment saves memory |
Why not use AdaGrad for LLMs? AdaGrad's accumulated squared gradient G grows monotonically and never decays. After many steps, G is huge, the effective learning rate shrinks toward zero, and learning stalls. RMSProp and Adam fix this by exponentially decaying the second moment, so recent gradients matter more than ancient ones. For LLM training over trillions of tokens (hundreds of thousands to millions of optimizer steps), Adam's decaying second moment is essential.
Learning rate schedules. The learning rate is not constant throughout training. A typical schedule: (1) linear warmup from 0 to peak lr over the first ~1000 steps (helps early instability when m and v are cold); (2) cosine decay from peak lr down to 10% of peak over the remaining training steps. The intuition: start slow to let the optimizer "warm up," then decrease gradually so later steps make smaller but more precise updates. Choosing the right peak lr (typically 1e-4 to 3e-4 for LLMs) is done via small ablations — training for a few thousand steps across lr values and selecting the one with the lowest loss.
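The warmup-plus-cosine schedule described above fits in a few lines. This is a sketch with illustrative defaults (peak 3e-4, 1,000 warmup steps, 100,000 total steps, floor at 10% of peak); the function name and the total step count are assumptions for the example.

```python
import math


def lr_at_step(step, peak_lr=3e-4, warmup_steps=1000,
               total_steps=100_000, min_ratio=0.1):
    """Linear warmup from 0 to peak_lr, then cosine decay to min_ratio * peak_lr."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = min(1.0, (step - warmup_steps) / (total_steps - warmup_steps))
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))  # goes from 1 to 0
    min_lr = min_ratio * peak_lr
    return min_lr + (peak_lr - min_lr) * cosine


for s in (0, 500, 1000, 50_000, 100_000):
    print(s, f"{lr_at_step(s):.2e}")
```

In a PyTorch loop, the returned value would typically be written into each optimizer parameter group's `lr` before calling `optimizer.step()`.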