Language Modeling from Scratch · CS336 · Lecture 5

GPUs: The Hardware Under Everything

Why is my GPU at 5% utilization? Start from transistors — learn the SM/warp/thread model, the memory hierarchy, arithmetic intensity, the roofline model, coalescing, tiling, and why minimizing data movement is the single most important skill in GPU programming. Build up to Flash Attention from first principles.

Prerequisites: CS336 Lec 2 resource accounting (FLOPs, memory). Matrix multiplication basics.

Chapter 0: The Bottleneck

You buy an A100 GPU. It costs $10,000. The spec sheet says 312 TFLOP/s of bfloat16 compute. You run your transformer training loop. The GPU profiler shows 5-8% utilization. You are wasting 92-95% of what you paid for.

This is not unusual. It is the default. And the reason is almost never "not enough math" — it is almost always data starvation. Your compute units are waiting, hungry, for data that hasn't arrived from memory yet. They sit idle while memory buses shuffle bytes. The compute is fast; the data pipeline is not.

Every technique in this lecture — coalescing, tiling, operator fusion, Flash Attention — attacks the same root cause: too many trips to slow memory. Before any of those techniques make sense, you need to understand the hardware well enough to feel the bottleneck yourself.

The warehouse analogy. Think of a GPU like a factory floor. The workers (compute cores) are extremely fast. But the raw materials (data) live in a distant warehouse (DRAM/HBM). Every time a worker needs new material, a truck must drive to the warehouse and back. If the workers are faster than the trucks, most workers sit idle waiting for deliveries. That is your 5% utilization. The goal of all GPU optimization is to reduce truck trips — by pre-staging materials in a local storage room (SRAM) that workers can reach instantly.

By the end of this lesson you will be able to look at any operation — a matrix multiply, an elementwise ReLU, a softmax — and immediately know: is this compute-bound or memory-bound? How many bytes must move per FLOP? Where on the roofline does it land? And what technique (fusion, tiling, quantization) would help most.

The utilization gap: compute vs memory bandwidth

An A100 has 312 TFLOP/s (bf16) but only 2 TB/s of memory bandwidth, so how much of that compute you actually get depends on the operation's arithmetic intensity (FLOPs/byte).

Chapter 1: GPU vs CPU

A CPU is a sprinter: a small number of extremely fast, sophisticated cores designed to run one thread as quickly as possible. An Intel i9 has 24 physical cores. Each can execute instructions out-of-order, perform branch prediction, handle complex control flow, and has megabytes of L3 cache to absorb memory latency. The design goal is latency: finish one piece of work as fast as possible.

A GPU is a marathon relay team: thousands of small, simple cores designed to run many threads simultaneously. An A100 has 6912 CUDA cores organized into 108 Streaming Multiprocessors (SMs). Each core is relatively weak: no branch prediction, limited cache, simple in-order execution. But all 6912 can compute at once, each on a different thread's data. Each SM also keeps up to 2048 threads resident (over 220,000 across the chip), so there is always more work ready to run. The design goal is throughput: finish the maximum amount of total work per second.

Why does this matter for ML? Take a transformer forward pass on a sequence of 2048 tokens with a 4096-dim hidden state. A single 4096 × 4096 projection performs 2048 × 4096 × 4096 ≈ 34 billion multiply-adds, and none of them depends on another. A CPU with a few dozen cores can only work through a small slice of them at a time, wasting its sophistication on trivially parallel work. A GPU spreads them across thousands of cores at once.

The key tradeoff. CPUs use most of their die area on caches, branch predictors, and out-of-order execution engines — all designed to hide latency and speed up sequential code. GPUs use most of their die area on arithmetic units — floating-point multipliers and adders. A GPU die is like a field of calculators. A CPU die is like a small army of highly-trained accountants with filing systems. If you have ten billion independent multiplications to do, the calculator field wins by a thousand to one.

The SIMT model (Single Instruction, Multiple Threads) is the GPU execution model. Every thread in a warp (a group of 32 threads) executes the same instruction simultaneously — but on different data. Think of it like a marching band: every player performs the same move at the same time, but each is standing in a different position on the field. Your code looks like a single-threaded program, but the GPU secretly runs it on 32 inputs in parallel.

This is powerful but has a sharp edge. Suppose threads in a warp take different paths through an if-else branch (control divergence). The warp must then execute both paths one after the other: first with the threads that satisfy the condition active and the rest masked off, then the reverse. That section takes as long as both branches combined, and part of the warp sits idle during each path. Well-written GPU code avoids data-dependent branches inside warps, or arranges the data so that all 32 threads of a warp take the same branch.
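To make the cost of divergence concrete, here is a minimal NumPy sketch that emulates one warp under SIMT. It assumes a simplified cost model: each branch path issued for the warp counts as one step, no matter how many lanes are active. `run_warp` is a hypothetical helper, not a CUDA API. When all 32 lanes agree on the condition the warp issues one path. When they disagree it issues both, masking off the inactive lanes each time.

```python
import numpy as np

WARP_SIZE = 32

def run_warp(x):
    """Emulate one warp running `y = 2*x if x > 0 else -x` under SIMT.

    Returns per-lane results and the number of branch paths the warp issued
    (1 if all lanes agree, 2 if they diverge). Simplified cost model: each
    issued path costs one step whether 1 or 32 lanes are active.
    """
    assert x.shape == (WARP_SIZE,)
    y = np.empty_like(x)
    cond = x > 0                    # per-lane predicate
    steps = 0
    if cond.any():                  # "then" path: only lanes with cond == True active
        y[cond] = 2 * x[cond]
        steps += 1
    if (~cond).any():               # "else" path: the remaining lanes active
        y[~cond] = -x[~cond]
        steps += 1
    return y, steps

uniform = np.arange(1, 33, dtype=np.float32)                                   # all lanes positive
divergent = np.where(np.arange(32) % 2 == 0, 1.0, -1.0).astype(np.float32)     # alternating signs

print(run_warp(uniform)[1])    # one path issued
print(run_warp(divergent)[1])  # two paths issued, one after the other
```

Real hardware tracks this with an active-lane mask per warp. The sketch only illustrates why a divergent section costs the sum of both paths.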

Tensor Cores: the matmul accelerator. Starting with the Volta architecture (V100), NVIDIA added Tensor Cores: specialized circuits that each perform a small matrix multiply-accumulate in a single operation (4×4×4 on Volta). On an A100, the regular CUDA cores peak at ~19.5 TFLOP/s (fp32), while the Tensor Cores peak at 312 TFLOP/s (bf16), a 16× difference. This is why every serious ML workload uses mixed precision. Your matmuls land on Tensor Cores. Other ops (activations, normalization) run on the regular CUDA cores, at a small fraction of the Tensor Core rate.
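A quick back-of-the-envelope check shows what the 16× gap means in wall-clock time. This assumes, optimistically, that a kernel runs at each unit's quoted peak throughput; real kernels only approach these peaks. The matmul size K = 4096 is an illustrative choice.

```python
# Idealized time for one square matmul (M = N = K) at peak throughput.
K = 4096
flops = 2 * K**3                  # K*K outputs, K multiply-adds each, 2 FLOPs per multiply-add
TENSOR_CORE_PEAK = 312e12         # A100 bf16 Tensor Core peak, FLOP/s
CUDA_CORE_PEAK = 19.5e12          # A100 fp32 CUDA core peak, FLOP/s

t_tensor = flops / TENSOR_CORE_PEAK
t_cuda = flops / CUDA_CORE_PEAK
print(f"{flops:.3e} FLOPs: {t_tensor * 1e3:.2f} ms (Tensor Cores) vs {t_cuda * 1e3:.2f} ms (CUDA cores)")
```

At these peaks, the same 4096-cubed matmul takes roughly 0.44 ms on the Tensor Cores and about 7 ms on the fp32 CUDA cores.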

Chapter 2: The Hardware Model

To write fast GPU code, you need a mental map of how the hardware is organized. The GPU has a three-level execution hierarchy that mirrors its three-level memory hierarchy. Understanding both simultaneously is the key to everything else in this lecture.

The top level is the GPU device. An A100 has 108 Streaming Multiprocessors (SMs). Each SM is an independent processing unit that can run several blocks of threads at once. SMs don't share data with each other except by writing to global memory (HBM). If you launch a kernel with 1000 thread blocks, the scheduler distributes them across the 108 SMs, so some SMs get 9 blocks and some get 10. If each SM has enough registers and shared memory to hold its share, they all run concurrently. Otherwise the remaining blocks wait and start as earlier blocks finish.

Inside each SM are CUDA cores and Tensor Cores. An A100 SM has 64 CUDA cores for general fp32 arithmetic and 4 Tensor Core units for matmul. Threads within a block share the SM's resources — most importantly, its shared memory (SRAM), which is like a programmable L1 cache. Threads in the same block can communicate through shared memory without going to global memory. Threads in different blocks cannot share data this way.

The bottom level is the warp: a group of exactly 32 consecutively numbered threads that the SM executes together. When you launch a thread block of size 256, the SM splits it into 256/32 = 8 warps and schedules them. Warps are the atomic unit of scheduling — the SM either runs all 32 threads of a warp this cycle, or none of them. This is why warp-level thinking matters for optimization.
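The arithmetic of warps and block distribution fits in a few lines of Python. This sketch makes simplifying assumptions. Threads are identified by their linearized index within the block. Blocks are spread evenly across SMs, which the real hardware scheduler does not promise; it also caps how many blocks can be resident at once. The helper names are illustrative.

```python
import math

WARP_SIZE = 32

def warps_per_block(threads_per_block):
    # A partially filled last warp still occupies a full warp slot.
    return math.ceil(threads_per_block / WARP_SIZE)

def warp_and_lane(thread_idx):
    # Consecutive (linearized) thread indices are grouped into the same warp.
    return thread_idx // WARP_SIZE, thread_idx % WARP_SIZE

def blocks_per_sm(num_blocks, num_sms=108):
    # Idealized even split: {blocks on an SM: number of SMs with that many}.
    base, extra = divmod(num_blocks, num_sms)
    return {base + 1: extra, base: num_sms - extra}

print(warps_per_block(256))   # the 256-thread block from the text
print(warps_per_block(100))   # last warp has only 4 active threads
print(warp_and_lane(70))      # thread 70 is lane 6 of warp 2
print(blocks_per_sm(1000))    # the 1000-block launch on 108 SMs
```

A 100-thread block still costs four full warp slots, which is one reason block sizes are usually chosen as multiples of 32.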

GPU Device (1 chip)
A100: 108 SMs, HBM with 2 TB/s bandwidth, 80 GB capacity
↓ contains
Streaming Multiprocessor (SM)
64 CUDA cores + 4 Tensor Cores + 192 KB shared memory/L1
↓ schedules
Thread Block
Up to 1024 threads; shares the SM's registers and shared-memory capacity with other resident blocks, but its own shared memory is visible only to its threads
↓ subdivided into
Warp (32 threads)
All 32 execute the same instruction; atomic unit of scheduling and memory access
↓ each thread has
Registers (up to 255 per thread)
Fastest memory: on-die, ~1 cycle. Shared memory is shared by a block; HBM is shared by all SMs.
Occupancy: the hidden lever. Each A100 SM has a fixed resource budget: 65,536 32-bit registers, 192 KB of combined L1/shared memory, and at most 2048 resident threads (64 warps). If your kernel uses too many registers per thread (say 128), fewer threads can co-reside on the SM: 65,536 / 128 = 512 threads, or 16 warps, so occupancy drops to 25%. Low occupancy means fewer warps to hide latency. When one warp is waiting for a memory read, the SM switches to another warp to stay busy. With only a couple of warps resident instead of 64, there's nobody to switch to, and the SM stalls. High occupancy does not guarantee high performance, but low occupancy often causes memory stalls.
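The occupancy arithmetic can be sketched as a tiny calculator. It is a simplified model of an A100 SM (compute capability 8.0) using these per-SM limits:

- 65,536 registers
- 2048 threads
- 32 resident blocks
- up to 164 KB of shared memory

It ignores register and shared-memory allocation granularity, so treat its results as estimates. The function name is illustrative.

```python
# Simplified A100 per-SM limits (compute capability 8.0)
REGS_PER_SM = 65_536
MAX_THREADS_PER_SM = 2048          # = 64 warps
MAX_BLOCKS_PER_SM = 32
SMEM_PER_SM = 164 * 1024           # max shared memory per SM, bytes
WARP_SIZE = 32

def occupancy(threads_per_block, regs_per_thread, smem_per_block=0):
    """Fraction of the SM's 64 warp slots that can be resident (estimate).

    Ignores register/shared-memory allocation granularity.
    """
    warps_per_block = -(-threads_per_block // WARP_SIZE)       # ceil division
    threads_alloc = warps_per_block * WARP_SIZE
    limits = [
        MAX_BLOCKS_PER_SM,                                     # hardware block limit
        MAX_THREADS_PER_SM // threads_alloc,                   # thread/warp slots
        REGS_PER_SM // (regs_per_thread * threads_alloc),      # register file
    ]
    if smem_per_block > 0:
        limits.append(SMEM_PER_SM // smem_per_block)           # shared memory
    blocks = min(limits)
    return blocks * warps_per_block / (MAX_THREADS_PER_SM // WARP_SIZE)

print(occupancy(256, 32))                            # all 64 warp slots filled
print(occupancy(256, 128))                           # register-limited
print(occupancy(256, 32, smem_per_block=96 * 1024))  # shared-memory-limited
```

The second call reproduces the 128-register example from the paragraph above: two 256-thread blocks fit, giving 16 of 64 warps (25%).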
GPU execution hierarchy: SMs, blocks, warps, threads (numbers reflect real A100 specs)


Chapter 3: Memory Hierarchy

Every GPU has multiple memory tiers. They differ in three ways: capacity (how much you can store), bandwidth (how fast you can read/write), and latency (how long a single read takes). The faster the memory, the smaller and more expensive it is. This is the fundamental tension that GPU optimization must navigate.

Starting closest to the compute and working outward:

Registers are the on-die register file inside each SM. Access takes about one clock cycle. An A100 has 65,536 registers per SM, each 32 bits (256 KB). With 2048 threads per SM, that's 32 registers per thread (a thread can use up to 255, but using more means fewer threads can fit). Registers hold the thread's own working variables. No sharing, no bus, no wait.

Shared memory (SRAM) sits inside the SM, shared among all threads in a block. An A100 has 192 KB of combined L1 cache + shared memory per SM (configurable split, up to 164 KB as shared memory). Bandwidth is enormous: around 19 TB/s aggregated across the 108 SMs, roughly 10× HBM bandwidth. Latency is ~20 clock cycles. This is the "local storage room" in the factory analogy. The key discipline: you must explicitly load data from HBM into SRAM (using cooperative loading across threads in a block), then compute from SRAM, then write results back.

L2 cache is on-chip (but off-SM), shared by all SMs. An A100 has 40 MB of L2. Bandwidth ~12 TB/s for the whole device. Automatically managed by hardware — you don't program it directly, but L2 hits matter for repeated accesses to the same data.

HBM (High Bandwidth Memory) is the main GPU DRAM — stacked memory chips mounted directly on the GPU package. An A100 has 80 GB at 2 TB/s bandwidth. An H100 SXM5 has 80 GB at ~3.35 TB/s. This is the "warehouse." It's fast compared to system RAM (CPU DRAM is ~50 GB/s), but tiny compared to SRAM bandwidth, and each access takes ~600-700 clock cycles. Almost every GPU performance problem traces back to too many HBM accesses.

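Memory tier | Capacity (A100) | Bandwidth | Latency | Scope
--- | --- | --- | --- | ---
Registers | 256 KB / SM (~32 reg/thread at 2048 threads) | Highest (operands feed the ALUs directly) | ~1 cycle | Per-thread
Shared (SRAM) | 192 KB / SM (L1 + shared) | ~19 TB/s (all SMs combined) | ~20 cycles | Per-block
L2 cache | 40 MB | ~12 TB/s | ~200 cycles | All SMs
HBM (DRAM) | 80 GB | 2 TB/s (A100) / 3.35 TB/s (H100) | ~600 cycles | All SMs
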
SRAM is ~100× more expensive than DRAM. Per bit, SRAM requires 6 transistors and occupies far more die area. DRAM uses 1 transistor + 1 capacitor. That is why an A100 has 80 GB of HBM but only 192 KB/SM of SRAM. The scarcity of SRAM is what makes tiling (Chapter 6) non-trivial: you must carefully partition work into chunks that fit in the tiny fast memory. If matrix tiles don't fit in SRAM, you fall back to HBM, and performance collapses.
Memory hierarchy pyramid: area ∝ capacity, color intensity = bandwidth.


Chapter 4: Arithmetic Intensity

Here is the central question of GPU performance: for a given computation, how many floating-point operations do you do per byte of memory traffic? This ratio is called arithmetic intensity (also "compute intensity" or "FLOP/byte ratio"), and it determines whether your operation is limited by compute or by memory bandwidth.

Arithmetic Intensity = FLOPs performed ÷ Bytes moved to/from HBM

Let's derive it for the two most common operations in a transformer.

Elementwise ReLU on a vector of n values (float32). You read each value from memory (4 bytes each), apply max(0, x) — one comparison, ~1 FLOP — and write it back (4 bytes). Total FLOPs = n. Total bytes = 4n (read) + 4n (write) = 8n. Arithmetic intensity = n / 8n = 1/8 = 0.125 FLOPs/byte. Extremely memory-bound.

In float16/bfloat16 (2 bytes each): bytes = 2n + 2n = 4n. Intensity = n/4n = 0.25 FLOPs/byte. Still very memory-bound, but halving precision doubles the intensity — this is why low-precision helps memory-bound ops.

Matrix multiplication: A(M×K) × B(K×N) → C(M×N). FLOPs = 2MKN: each of the MN output elements requires K multiply-accumulates, and each multiply-accumulate is 2 FLOPs. Bytes = read A and B once, write C once = 2(MK + KN + MN) bytes (for float16, 2 bytes each). This assumes each matrix crosses the HBM bus only once, the ideal that tiling (Chapter 6) approaches. Intensity = 2MKN / 2(MK + KN + MN). For the square case M=N=K: intensity = 2K³ / (3 × 2K²) = K/3 FLOPs/byte. For K=4096 (typical d_model), that's ~1365 FLOPs/byte. This is why large matmuls are compute-bound.

The ridge point: where compute meets memory. An A100 delivers 312 TFLOP/s (bf16) and 2 TB/s of HBM bandwidth. The ridge point is the intensity at which you transition from memory-bound to compute-bound: 312 × 10^12 / (2 × 10^12) = 156 FLOPs/byte. Any operation with intensity below 156 FLOPs/byte is memory-bound. Any operation above is compute-bound. ReLU in fp32 (0.125) is 1248× below the ridge. Matmul with K=4096 (1365) is 8.7× above. This roughly 10,000× range is why mixed workloads are so hard to optimize.
python
# Roofline arithmetic intensity calculator (A100, bf16 by default)
PEAK_FLOPS = 312e12                # A100 bf16 Tensor Core peak, FLOP/s
HBM_BW = 2e12                      # A100 HBM bandwidth, bytes/s
RIDGE = PEAK_FLOPS / HBM_BW        # 156 FLOPs/byte

def intensity(op, dtype_bytes=2):
    if op == 'relu':
        flops = 1                              # one max(0, x) per element
        bytes_moved = dtype_bytes * 2          # read x + write y
        return flops / bytes_moved             # 0.25 for bf16

    if op == 'matmul_square':
        K = 4096                               # M = N = K = 4096
        flops = 2 * K**3                       # K*K outputs, K multiply-adds each
        bytes_moved = dtype_bytes * 3 * K**2   # read A, read B, write C
        return flops / bytes_moved             # ~1365 for bf16

    if op == 'layer_norm':
        # Rough per-element estimate: ~5 FLOPs (sub, square, add, div, mul) and
        # ~5 element-sized HBM transfers for a multi-pass, unfused implementation
        return 5 / (5 * dtype_bytes)           # 0.5 for bf16: very memory-bound

    raise ValueError(f"unknown op: {op}")

for op in ['relu', 'layer_norm', 'matmul_square']:
    ai = intensity(op)
    region = 'compute-bound' if ai > RIDGE else 'MEMORY-BOUND'
    print(f"{op:20s}  {ai:8.1f} FLOPs/byte  →  {region}")
# relu                       0.2 FLOPs/byte  →  MEMORY-BOUND
# layer_norm                 0.5 FLOPs/byte  →  MEMORY-BOUND
# matmul_square           1365.3 FLOPs/byte  →  compute-bound

Chapter 5: The Roofline Model

Now that you know arithmetic intensity, you can ask: what is the theoretical maximum performance of a given operation on this hardware? The answer comes from the Roofline Model, one of the most useful thinking tools in performance engineering.

The key insight: performance is limited by whichever of two resources runs out first — compute capacity or memory bandwidth. The roofline model makes this precise.

For a GPU with peak compute P (FLOP/s) and peak memory bandwidth B (bytes/s), an operation with arithmetic intensity I (FLOPs/byte) achieves at most:

Performance = min( P, I × B )    [FLOPs/s]

Let's unpack this. If I is very small (memory-bound side), the GPU could do P FLOP/s, but data arrives at only B bytes/s, which is enough to feed just I × B FLOP/s. Since I is small, I × B < P: the GPU is waiting on memory. You get I × B FLOP/s, not P.

If I is very large (compute-bound side): data arrives faster than you can compute. The limiting factor is now P FLOPs/s. You hit the compute ceiling.

The crossover — the ridge point — is where I × B = P, so I = P/B. For the A100: ridge = 312 TFLOP/s ÷ 2 TB/s = 156 FLOPs/byte. For the H100 SXM5: 989 TFLOP/s ÷ 3.35 TB/s ≈ 295 FLOPs/byte. The ridge point keeps rising because FLOPs have scaled faster than memory bandwidth.
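The roofline formula translates directly into code. This is a minimal sketch using the peak numbers quoted in this lecture (bf16 Tensor Core FLOP/s and HBM bandwidth). The intensities are example values from this chapter. The function and variable names are illustrative.

```python
# Peak numbers quoted in this lecture: bf16 Tensor Core FLOP/s and HBM bytes/s.
GPUS = {
    "A100": {"peak_flops": 312e12, "hbm_bw": 2.0e12},
    "H100": {"peak_flops": 989e12, "hbm_bw": 3.35e12},
}

def attainable_flops(intensity, peak_flops, hbm_bw):
    """Roofline bound: Performance = min(P, I * B), in FLOP/s."""
    return min(peak_flops, intensity * hbm_bw)

def ridge_point(peak_flops, hbm_bw):
    """Intensity (FLOPs/byte) at which I * B == P."""
    return peak_flops / hbm_bw

OPS = [("softmax", 4.0), ("matmul d=256", 256 / 3), ("matmul d=4096", 4096 / 3)]

for name, spec in GPUS.items():
    ridge = ridge_point(**spec)
    print(f"{name}: ridge point = {ridge:.0f} FLOPs/byte")
    for op, ai in OPS:
        perf = attainable_flops(ai, **spec)
        bound = "compute" if ai >= ridge else "memory"
        print(f"  {op:14s} I={ai:7.1f}  {perf / 1e12:6.1f} TFLOP/s "
              f"({perf / spec['peak_flops']:.1%} of peak, {bound}-bound)")
```

For softmax (I ≈ 4), the sketch gives 8 TFLOP/s on the A100, about 2.6% of peak. That is exactly the kind of number behind the single-digit utilization in Chapter 0.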

Why FLOPs scaled faster than bandwidth. Shrinking transistors (Moore's law) let each generation pack more compute per mm². Memory bandwidth requires physical pins, wires, and bus width, and these scale slowly with die geometry. Across recent GPU generations, peak FLOP/s has grown much faster than HBM bandwidth. From A100 to H100 alone, FLOPs grew 3.2× but bandwidth only 1.7×. That's why the ridge point keeps rising and memory-bound ops get harder to feed. The "memory wall" is not just a metaphor; it is a physical constraint.
Roofline Model: attainable performance vs arithmetic intensity on the A100 and H100, relative to each GPU's ridge point.

Applying concrete numbers. A transformer layer's main costs:

Operation | Intensity (bf16) | A100 region | H100 region
--- | --- | --- | ---
ReLU/GELU activation | ~0.25–1 FLOPs/byte | Memory-bound | Memory-bound
LayerNorm | ~0.5 FLOPs/byte | Memory-bound | Memory-bound
Softmax (attention) | ~3–5 FLOPs/byte | Memory-bound | Memory-bound
Matmul d=256 (small) | ~85 FLOPs/byte | Memory-bound | Memory-bound
Matmul d=2048 | ~683 FLOPs/byte | Compute-bound | Compute-bound
Matmul d=4096 | ~1365 FLOPs/byte | Compute-bound | Compute-bound

Practical implication. Elementwise ops (activations, normalization) are always memory-bound and never benefit from more compute. Adding fancier activation functions (GELU vs ReLU) barely affects wall-clock time — both are dominated by memory traffic. Fusing multiple elementwise ops into a single kernel (see Chapter 7) eliminates redundant memory reads/writes and can give 3-5× speedup with zero change to math.

Chapter 6: Coalescing & Tiling

Two techniques dominate GPU memory optimization: memory coalescing (read the right addresses to avoid wasting bus transactions) and tiling (move data into SRAM to reuse it many times without returning to HBM). Together they are the reason that well-written GPU matmul kernels can achieve 80%+ of theoretical peak, while naive implementations hit under 5%.

Memory coalescing exploits how DRAM physically works. When a warp of 32 threads issues memory loads, the GPU doesn't dispatch 32 individual read requests. Instead, it groups consecutive thread addresses into a single burst transfer — one transaction that fetches 128 bytes at once (a full cache line). If 32 threads each load a consecutive 4-byte float from addresses [base, base+4, base+8, ..., base+124], they all arrive in a single burst: perfectly coalesced, 32 useful loads for 1 transaction.

The disaster case: if thread 0 loads address base, thread 1 loads base+512, thread 2 loads base+1024 (strided access pattern), each load must be a separate burst — 32 transactions for 32 useful values. You get the same 32 values, but pay 32× the bus cost. Strided memory access is one of the most common performance killers in GPU code.

Row-major vs column-major in matmul. In a matrix stored row-major (C/NumPy default), consecutive memory addresses move along a row. Thread 0 loading element [0,0], thread 1 loading [0,1], etc. — consecutive addresses, coalesced. But if you need column access (loading all elements of column 0), thread 0 loads [0,0] at address 0, thread 1 loads [1,0] at address N×sizeof(float), thread 2 loads [2,0] — strided, non-coalesced. This is why matmul implementations transpose one operand or carefully choose thread-to-data assignments to ensure coalesced access.
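The transaction count can be estimated by grouping a warp's 32 addresses into 128-byte segments. This is a simplified sketch. It assumes the base address is 128-byte aligned and counts one transaction per distinct 128-byte segment. Real NVIDIA hardware tracks 32-byte sectors within 128-byte lines, so exact counts differ, but the coalesced-vs-strided contrast is the same. `transactions` and `row_major_addr` are illustrative helpers.

```python
SEGMENT_BYTES = 128  # simplified: one transaction per distinct 128-byte segment

def transactions(addresses):
    """Number of distinct 128-byte segments touched by one warp's loads."""
    return len({addr // SEGMENT_BYTES for addr in addresses})

def row_major_addr(i, j, n_cols, elem_bytes=4):
    """Byte offset of element [i, j] in a row-major matrix (float32 by default)."""
    return (i * n_cols + j) * elem_bytes

N = 1024                                                   # columns of a row-major float32 matrix
row_access = [row_major_addr(0, t, N) for t in range(32)]  # thread t reads M[0, t]
col_access = [row_major_addr(t, 0, N) for t in range(32)]  # thread t reads M[t, 0]
strided = [t * 512 for t in range(32)]                     # base + tid * 512, as in the text

print(transactions(row_access))  # 32 consecutive floats share one 128-byte segment
print(transactions(col_access))  # consecutive threads are 4 KB apart
print(transactions(strided))     # every thread lands in its own segment
```

The row access costs 1 transaction, while the column and strided accesses each cost 32, for the same 32 useful floats.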

Tiling attacks the reuse problem. Consider computing C = A × B where A is M×K and B is K×N. To compute one element C[i,j], you need all K values of row i of A and column j of B. Each element A[i,k] is needed by all N outputs in row i of C, and each element B[k,j] by all M outputs in column j. So if every output fetches its inputs straight from HBM, each element of A is read N times and each element of B is read M times. For M=N=K=4096, that is 4096 passes over A and B. In bf16 the two matrices together occupy 64 MB, so a single matmul generates about 256 GB of HBM traffic.

Tiling has each thread block compute one T×T tile of C. The block walks along K in steps of T. At each step it loads a T×T tile of A and a T×T tile of B into SRAM, accumulates their product from SRAM, and moves on. At the end it writes its tile of C back to HBM once. Each element loaded into SRAM now serves T outputs instead of 1. So each element of A is loaded N/T times from HBM (once per column of output tiles) instead of N times, and each element of B M/T times instead of M. You get a T× reduction in HBM traffic.

HBM reads per element of A (no tiling): N passes    HBM reads per element of A (tile size T): N/T passes
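The N/T claim can be checked with a small NumPy simulation of the blocking scheme. It is a sketch, not a GPU kernel:

- array slices stand in for loads from HBM into SRAM;
- the loops run serially instead of across thread blocks;
- shapes are assumed to be multiples of T.

`tiled_matmul` is an illustrative name.

```python
import numpy as np

def tiled_matmul(A, B, T):
    """Blocked C = A @ B with T x T tiles, counting elements loaded from 'HBM'.

    Each (i, j) output tile is accumulated from K/T pairs of A and B tiles,
    mimicking a thread block that stages tiles in shared memory.
    Assumes M, N, K are all multiples of T.
    """
    M, K = A.shape
    K2, N = B.shape
    assert K == K2 and M % T == 0 and N % T == 0 and K % T == 0
    C = np.zeros((M, N), dtype=A.dtype)
    loads = 0
    for i in range(0, M, T):
        for j in range(0, N, T):
            acc = np.zeros((T, T), dtype=A.dtype)   # stays on-chip in a real kernel
            for k in range(0, K, T):
                a_tile = A[i:i + T, k:k + T]        # "load" T*T elements of A
                b_tile = B[k:k + T, j:j + T]        # "load" T*T elements of B
                loads += a_tile.size + b_tile.size
                acc += a_tile @ b_tile              # compute entirely from the tiles
            C[i:i + T, j:j + T] = acc               # one HBM write per output element
    return C, loads

n = 32
rng = np.random.default_rng(0)
A = rng.standard_normal((n, n))
B = rng.standard_normal((n, n))
for T in (1, 4, 16):
    C, loads = tiled_matmul(A, B, T)
    assert np.allclose(C, A @ B)
    # Average number of times each input element was read: n / T
    print(T, loads / (2 * n * n))
```

In this model each input element is read n/T times. For the 4096 × 4096 example above, T = 128 would cut 4096 passes to 32.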
Memory coalescing: bus transactions needed when the 32 threads of a warp load consecutive (coalesced) vs strided addresses.

pseudocode (CUDA-style, Python syntax)
# Coalesced access: thread i loads address base + i*4 (consecutive floats)
# All 32 threads → one 128-byte burst transaction → 32 useful loads
tid = thread_id()
val = load(base + tid * 4)    # GOOD: stride-1 in units of element size

# Strided access: thread i loads address base + i*512
# All 32 threads → 32 separate transactions → only 4 useful bytes each
val = load(base + tid * 512)  # BAD: stride of 128 floats

# Tiled matmul: each thread block computes one TILE_SIZE × TILE_SIZE tile of C,
# and thread (ty, tx) of the block owns output element C[row, col].
row = block_y * TILE_SIZE + ty
col = block_x * TILE_SIZE + tx
acc = 0.0                                           # per-thread accumulator in a register
for tile in range(K // TILE_SIZE):                  # assumes K is a multiple of TILE_SIZE
    A_shm[ty, tx] = A[row, tile * TILE_SIZE + tx]   # coalesced: consecutive tx → consecutive columns
    B_shm[ty, tx] = B[tile * TILE_SIZE + ty, col]   # coalesced: consecutive tx → consecutive columns
    __syncthreads()                                 # wait until the whole tile is in SRAM
    for k in range(TILE_SIZE):
        acc += A_shm[ty, k] * B_shm[k, tx]          # reads from fast SRAM
    __syncthreads()                                 # finish computing before the next load overwrites the tile
C[row, col] = acc                                   # one HBM write per output element

Chapter 7: Operator Fusion & Recomputation

Tiling keeps data in SRAM during a single operation. But what if you need to apply many operations in sequence? The naive approach launches a separate GPU kernel for each operation. Kernel 1 computes A and writes the result to HBM; kernel 2 reads it from HBM, computes B, and writes to HBM; and so on. Each round-trip to HBM wastes bandwidth on intermediate results that only the next kernel in the chain ever reads.

Operator fusion merges multiple operations into a single kernel, keeping intermediate results in SRAM or registers and never writing them to HBM. Consider computing sin²x + cos²x elementwise. The unfused approach launches 5 kernels: sin, ², cos, ², +. Each writes its output to HBM and the next reads it back. Five round-trips for data that could stay on-chip. Fused: one kernel reads x once, computes everything, writes the final result. One round-trip, 5× less HBM traffic.
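Counting tensor-sized HBM transfers makes the fusion argument concrete. This sketch assumes each unfused kernel reads all of its inputs from HBM and writes its output back. It also assumes the fused kernel keeps every intermediate in registers. The kernel lists are illustrative, not what any particular framework emits.

```python
# Each entry: (kernel, tensors read from HBM, tensors written to HBM).
# Assumption: every unfused kernel reads its inputs from HBM and writes its output back.
UNFUSED_KERNELS = [
    ("s = sin(x)", 1, 1),
    ("s2 = s ** 2", 1, 1),
    ("c = cos(x)", 1, 1),
    ("c2 = c ** 2", 1, 1),
    ("y = s2 + c2", 2, 1),
]
# Fused: read x once, keep sin, cos, and the squares in registers, write y once.
FUSED_KERNELS = [("y = sin(x)**2 + cos(x)**2", 1, 1)]

def hbm_transfers(kernels):
    """Total tensor-sized reads + writes across a list of kernels."""
    return sum(reads + writes for _, reads, writes in kernels)

def hbm_bytes(kernels, n_elements, dtype_bytes=2):
    """HBM traffic in bytes for tensors of n_elements (bf16 by default)."""
    return hbm_transfers(kernels) * n_elements * dtype_bytes

print(hbm_transfers(UNFUSED_KERNELS), hbm_transfers(FUSED_KERNELS))
```

The unfused chain moves 11 tensors' worth of data versus 2 for the fused kernel, about 5.5×, in line with the roughly 5× estimate above. For 100 million bf16 elements that is 2.2 GB versus 0.4 GB.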

Why does PyTorch launch separate kernels by default? PyTorch is eager by default: each Python operation creates one CUDA kernel launch. The simplicity is valuable — you can inspect intermediate tensors, set breakpoints, add print statements. But it leaves performance on the table for memory-bound op chains. torch.compile() (PyTorch 2.0+) automatically identifies fusible op chains and generates fused CUDA kernels. FlashAttention does this manually for attention, getting ~3× speedup over unfused attention.

Recomputation (also called gradient or activation checkpointing) trades extra compute for less memory storage and traffic. In backpropagation, you normally save activations during the forward pass so you can use them during the backward pass. For a 100-layer transformer, this means storing at least 100 activation tensors in HBM. Each takes up memory and must be loaded during the backward pass.

Recomputation throws away intermediate activations during the forward pass, then recomputes them on-the-fly during the backward pass. This sounds wasteful — you're doing the forward pass twice. But consider the memory bandwidth argument: storing activations to HBM and loading them back during the backward pass can cost more bandwidth than just recomputing them from scratch. For certain op chains with low arithmetic intensity (like sigmoid sequences), recomputation actually reduces total HBM traffic.

python
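import torch
from torch.utils.checkpoint import checkpoint

def three_sigmoids(x0):
    x1 = torch.sigmoid(x0)
    x2 = torch.sigmoid(x1)
    x3 = torch.sigmoid(x2)
    return x3

# Without recomputation (one kernel per op), counting tensor-sized HBM transfers:
#   Forward:  3 reads + 3 writes = 6 (x1, x2, x3 are kept for backward,
#             since sigmoid's gradient is computed from its output)
#   Backward: each sigmoid reads the incoming grad and its saved output and
#             writes its input grad: 3 x (2 reads + 1 write) = 9
#   Total: 15 transfers, very low arithmetic intensity
x0 = torch.randn(1 << 20, requires_grad=True)
x3 = three_sigmoids(x0)

# With recomputation: checkpoint() discards the intermediates inside
# three_sigmoids and reruns it during backward. If forward and backward are each
# fused into a single kernel (eager PyTorch does not do this by itself):
#   Forward:  1 read (x0) + 1 write (x3) = 2
#   Backward: 2 reads (x0, grad of x3), recompute x1..x3 on-chip, 1 write = 3
#   Total: 5 transfers, 3x fewer, at the cost of recomputing three sigmoids
x3_ckpt = checkpoint(three_sigmoids, x0, use_reentrant=False)
x3_ckpt.sum().backward()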
When is recomputation worth it? Recomputation is almost always worth it for memory-bound op chains (elementwise ops, softmax, normalization). It is rarely worth it for compute-bound ops (large matmuls) — there the activation storage cost is small relative to the computation cost, so you'd be doing expensive matmuls twice for little bandwidth savings. Flash Attention uses recomputation for the softmax in the backward pass — exactly the right call, because softmax is highly memory-bound.

Chapter 8: Flash Attention: Putting It Together

Flash Attention (Dao et al. 2022) is the single best demonstration of GPU principles applied to a real problem. It makes standard attention 2-4× faster and uses O(n) rather than O(n²) memory. It computes exact attention, giving the same result up to floating-point rounding. It, or a successor such as FlashAttention-2, is a standard component of modern LLM training and inference stacks. Understanding why it works is a direct application of everything in this lecture.

Standard attention computes: S = QK^T / √d_k, P = softmax(S), O = PV. The naive PyTorch implementation materializes two O(n²) matrices in HBM: the full attention score matrix S and the probability matrix P. For a sequence of length 4096 and head dimension 64, S has 4096² = 16.7M entries × 2 bytes = 33 MB per head. With 32 heads, that's over 1 GB just for the score matrices. S is written and then read back by the softmax, which writes P, which is read again for the PV multiplication. That is four passes over ~1 GB each, about 4.3 GB of HBM traffic for the attention intermediates alone.

Flash Attention eliminates this by combining three techniques we've now covered: tiling (process Q, K, V in blocks that fit in SRAM), operator fusion (compute the entire attention pipeline — QK multiplication, softmax, PV — in one kernel without writing intermediates to HBM), and recomputation in the backward pass (don't save S and P; recompute them from Q, K tile-by-tile during backward).

The softmax tiling trick. The hard part is tiling softmax. You need the maximum of the entire row to numerically stabilize the exp, and the row's sum of exponentials to normalize. But you're processing the row in tiles and don't have the full row in memory at once. The solution is online softmax from Milakov & Gimelshein (2018): maintain a running maximum m and running sum s. For each new tile, update m to the new maximum if it's larger, and rescale the previous sum: s ← s × exp(m_old − m_new) + ∑ exp(logit − m_new), where the sum runs over the logits in the new tile. After the last tile, m and s equal the true row maximum and normalizer, so you get the exact softmax without ever seeing the full row at once.
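The update rule can be checked numerically. This minimal NumPy sketch processes a single row of logits in tiles of 4 (an arbitrary choice). It keeps only the running max `m` and running sum `s`, then compares the resulting probabilities with a standard full-row softmax. `online_softmax_stats` is an illustrative name.

```python
import numpy as np

def online_softmax_stats(logits, tile=4):
    """Stream a row of logits in tiles, keeping only a running max m and sum s."""
    m, s = -np.inf, 0.0
    for start in range(0, len(logits), tile):
        chunk = logits[start:start + tile]
        m_new = max(m, chunk.max())
        # Rescale the old sum to the new max, then add this tile's terms.
        s = s * np.exp(m - m_new) + np.exp(chunk - m_new).sum()
        m = m_new
    return m, s

x = np.array([1.0, 3.0, -2.0, 0.5, 7.0, 2.0, -1.0, 4.0, 6.5])
m, s = online_softmax_stats(x, tile=4)
probs = np.exp(x - m) / s                     # final normalization needs only m and s
reference = np.exp(x - x.max()) / np.exp(x - x.max()).sum()
print(np.allclose(probs, reference))
```

On the first tile, m starts at −∞. The rescaling factor is then exp(−∞) = 0, so the empty initial sum contributes nothing.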

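The algorithm in steps, using the loop order of FlashAttention-2:

1. Each thread block loads one block of Q into SRAM and then iterates over tiles of K and V.
2. For each tile, load the K tile and V tile into SRAM and compute the QK^T tile for this block.
3. Update the running max and sum for online softmax.
4. Rescale the output accumulated so far by the same factor exp(m_old − m_new), and add this tile's PV contribution, all in registers and SRAM.
5. After the last tile, normalize by the running sum to get the final output O, and write it to HBM just once.

No S or P matrix ever touches HBM.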
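Putting the steps together, here is a NumPy sketch of the forward pass for a single head. It uses an outer loop over blocks of Q and an inner loop over tiles of K and V (the loop order used by FlashAttention-2). It illustrates the algorithm, not the real kernel: there is no causal masking, no dropout, and no storage of softmax statistics for the backward pass, and NumPy arrays stand in for SRAM tiles. The block sizes and function names are illustrative choices.

```python
import numpy as np

def flash_attention_forward(Q, K, V, block_q=16, block_kv=16):
    """Tiled single-head attention forward pass with online softmax (no masking)."""
    n, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.empty_like(Q)
    for qs in range(0, n, block_q):
        q = Q[qs:qs + block_q]                      # Q block, stays in SRAM
        m = np.full(q.shape[0], -np.inf)            # running row max
        l = np.zeros(q.shape[0])                    # running row sum
        acc = np.zeros((q.shape[0], d))             # unnormalized output
        for ks in range(0, n, block_kv):
            k = K[ks:ks + block_kv]
            v = V[ks:ks + block_kv]
            s = (q @ k.T) * scale                   # score tile (block_q x block_kv)
            m_new = np.maximum(m, s.max(axis=1))
            p = np.exp(s - m_new[:, None])          # unnormalized tile probabilities
            correction = np.exp(m - m_new)          # rescale earlier tiles' contributions
            l = l * correction + p.sum(axis=1)
            acc = acc * correction[:, None] + p @ v
            m = m_new
        O[qs:qs + block_q] = acc / l[:, None]       # normalize once, write O once
    return O

def naive_attention(Q, K, V):
    S = (Q @ K.T) / np.sqrt(Q.shape[1])             # full n x n score matrix
    P = np.exp(S - S.max(axis=1, keepdims=True))
    P /= P.sum(axis=1, keepdims=True)
    return P @ V

rng = np.random.default_rng(0)
n, d = 64, 16
Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))
print(np.allclose(flash_attention_forward(Q, K, V), naive_attention(Q, K, V)))
```

The only state carried across K/V tiles is `m`, `l`, and `acc`, all of size proportional to the block of Q. That is why no n × n matrix is ever needed.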

Flash Attention: how HBM traffic scales with sequence length for standard (O(n²)) vs Flash (O(n)) attention.
python
# HBM traffic comparison: standard vs Flash Attention (idealized byte counts)
def attention_hbm_bytes(n, d, h, dtype_bytes=2):
    # n = seq len, d = head dim, h = num heads
    qkv_read = 3 * h * n * d * dtype_bytes   # read Q, K, V
    o_write  = h * n * d * dtype_bytes       # write output O
    nn       = h * n * n * dtype_bytes       # one n×n matrix across all heads

    # Standard: S = QK^T is written, read by softmax, which writes P,
    # which is read again for PV: four passes over an n×n matrix per head
    standard = qkv_read + 4 * nn + o_write

    # Flash: no S or P in HBM. Idealized: Q, K, V each read once; real kernels
    # re-read K/V tiles once per Q block, so this is a lower bound
    flash = qkv_read + o_write

    return standard, flash

n, d, h = 4096, 64, 32
std, fla = attention_hbm_bytes(n, d, h)
print(f"Standard: {std/1e9:.1f} GB  Flash: {fla/1e9:.2f} GB  Ratio: {std/fla:.1f}x")
# Standard: 4.4 GB  Flash: 0.07 GB  Ratio: 65.0x

Chapter 9: Connections & Cheat Sheet

This lecture is the foundation for everything in CS336 that deals with training speed, memory, and efficiency. The GPU model you built here — memory hierarchy, arithmetic intensity, roofline, coalescing, tiling, fusion — applies to every operator you'll ever write or profile.

Concept | What it is | Why it matters
--- | --- | ---
SIMT / warp | 32 threads execute the same instruction on different data | Basis for control divergence; warp = atom of scheduling
SM / SRAM | Per-SM fast memory (~192 KB), programmable cache | The "staging area" for tiling; key to reducing HBM trips
HBM | Main GPU DRAM, 80 GB at 2-3.35 TB/s | The bottleneck; almost every perf problem traces here
Arithmetic intensity | FLOPs ÷ bytes moved to/from HBM | Determines whether you're memory- or compute-bound
Ridge point | P / B (peak FLOP/s ÷ bandwidth), ~156 FLOPs/byte on A100 | Intensity below this = memory-bound; above = compute-bound
Coalescing | 32 warp threads load consecutive addresses → 1 transaction | Strided access = up to 32× overhead; must think in warp patterns
Tiling | Load tiles of A and B into SRAM, compute, move to the next tile | Reduces HBM reads by a factor of T; foundational matmul optimization
Operator fusion | Merge multiple kernels into one, keep intermediates on-chip | 3-5× speedup for elementwise-heavy chains (attention, norms)
Recomputation | Discard activations, recompute during backward | Trades compute for memory traffic; worthwhile for memory-bound layers
Flash Attention | Tiling + fusion + online softmax + recomputation for attention | Far less HBM traffic (no n×n matrices in HBM), O(n) memory, same output
A100 vs H100 quick-reference.
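Spec | A100 SXM4 | H100 SXM5
--- | --- | ---
bf16 Tensor Core TFLOP/s (dense) | 312 | 989
HBM bandwidth | 2 TB/s | 3.35 TB/s
HBM capacity | 80 GB | 80 GB
L1 + shared memory per SM | 192 KB (up to 164 KB shared) | 256 KB (up to 228 KB shared)
SMs | 108 | 132
Ridge point (bf16) | ~156 FLOPs/byte | ~295 FLOPs/byte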

The compute/bandwidth gap keeps widening: H100 has 3.2× the FLOPs of A100 but only 1.7× the bandwidth. The ridge point jumped from 156 to 295 — more operations are now memory-bound on H100 than on A100. This means the techniques in this lecture become more important with each GPU generation, not less.

What comes next. CS336 Lecture 6 (Kernels) takes these ideas to implementation: writing actual Triton kernels that tile and fuse manually. Lecture 9 (Parallelism) connects the SM/device model to multi-GPU training — tensor parallelism, pipeline parallelism, and how to choose the right strategy for different model sizes. The arithmetic intensity framework from this lecture predicts communication bottlenecks in distributed training too: allreduce bandwidth vs FLOP intensity of the operation you're synchronizing.
"The bottleneck in almost every ML workload is not compute — it is data movement. The hardware has compute to burn; it just can't be fed fast enough."
— paraphrase of Horace He's central thesis, CS336 Lec 5