Language Modeling from Scratch · CS336 · Lecture 6

Kernels & Triton: Writing Fast GPU Code

Your elementwise chain reads HBM five times when one would do. Learn why — then fix it. Cover kernel launch overhead, operator fusion from first principles, the Triton programming model (program_id, tiles, tl.load/tl.store, masks), write fused kernels for GELU and softmax, and arrive at FlashAttention's online-softmax as the canonical tiled-fused kernel. Bench everything.

Prerequisites: CS336 Lec 5 GPUs (memory hierarchy, arithmetic intensity, roofline). Python basics.
10 chapters · 5 live canvases · real Triton code

Chapter 0: Your MLP is Reading HBM Five Times

You have an MLP layer: x = gelu(W₂ · gelu(W₁ · x)). You profile it and see four CUDA kernels firing in sequence — a matmul, a GELU, another matmul, another GELU. On paper, that GELU is trivial: tanh approximation, a handful of arithmetic ops. It should be fast. But the profiler shows it taking 40% of your layer's wall-clock time.

The reason is not the arithmetic. For an elementwise GELU on a tensor of n float16 values, you do about 8 FLOPs per element. With 2 bytes per element and a read + write pass, that's 4 bytes/element. Arithmetic intensity: 8/4 = 2 FLOPs/byte. The A100's ridge point is 156 FLOPs/byte. Your GELU can therefore use at most 2/156 ≈ 1.3% of the theoretical compute ceiling, purely because it spends almost all of its time shuffling data between HBM and the registers.
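Here is the same arithmetic as a minimal roofline sketch. It assumes an A100 bf16 dense peak of 312 TFLOP/s and 2 TB/s of HBM bandwidth, the pair that yields the 156 FLOPs/byte ridge point quoted above. The constant and function names are illustrative.

```python
PEAK_FLOPS = 312e12   # assumed A100 dense bf16 peak, FLOP/s
HBM_BW = 2.0e12       # assumed HBM bandwidth, bytes/s

def attainable_fraction(flops_per_elem, bytes_per_elem):
    """Roofline model: return (arithmetic intensity, fraction of peak compute reachable)."""
    intensity = flops_per_elem / bytes_per_elem        # FLOPs per byte of HBM traffic
    attainable = min(PEAK_FLOPS, intensity * HBM_BW)   # memory roof vs compute roof
    return intensity, attainable / PEAK_FLOPS

ridge = PEAK_FLOPS / HBM_BW                            # intensity where the two roofs meet
intensity, frac = attainable_fraction(8, 4)            # GELU: 8 FLOPs, 2 B read + 2 B write
print(f"ridge={ridge:.0f} FLOPs/B, GELU intensity={intensity:.0f} FLOPs/B, "
      f"at most {frac:.1%} of peak compute")
```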

Now chain three ops: y = mul(add(gelu(x), bias), scale). PyTorch eager launches three separate kernels. gelu writes its result to HBM; add reads that result back and writes its output to HBM; mul reads that output and writes again. That is six full-tensor HBM transfers where two would do, because the two intermediate results never needed to leave the chip.

The warehouse analogy, made concrete. Each kernel launch is a truck trip to the warehouse (HBM). Take a chain of three elementwise ops on a 16 MB tensor. Unfused, it makes 6 one-way transfers (3 reads + 3 writes) × 16 MB = 96 MB of HBM traffic. Fused into one kernel, it makes 1 read + 1 write = 32 MB. That is three times less traffic, so the op runs roughly three times faster when it is memory-bound. This is not a micro-optimization: it is the difference between 2% GPU utilization and 6%.
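The counting rule can be written as a minimal traffic counter. It assumes every unfused op reads and writes the whole tensor once, and it ignores the small bias and scale operands. The helper chain_traffic_bytes is hypothetical, not a library function.

```python
MB = 1e6  # decimal megabytes, matching the text

def chain_traffic_bytes(k, tensor_bytes, fused):
    """HBM bytes moved by a chain of k elementwise ops on one tensor.

    Unfused: every op reads its input and writes its output (2 transfers per op).
    Fused: one read of the input and one write of the final output.
    """
    transfers = 2 if fused else 2 * k
    return transfers * tensor_bytes

for k in (1, 3, 5):
    unfused = chain_traffic_bytes(k, 16 * MB, fused=False)
    fused = chain_traffic_bytes(k, 16 * MB, fused=True)
    print(f"k={k}: unfused {unfused / MB:.0f} MB, fused {fused / MB:.0f} MB, "
          f"ratio {unfused / fused:.0f}x")
```

The ratio is always k, which is why the benefit of fusion grows with the length of the chain.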

This lecture answers one question: how do you actually fuse kernels? You have three paths.

- Write CUDA (C++ plus GPU intrinsics): powerful but verbose.
- Use Triton (Python plus a compiler): almost as fast, far more readable.
- Call torch.compile() and let the compiler figure it out.

We will cover all three but focus on Triton, which many ML researchers use to write custom kernels for new ops.

[Interactive canvas] HBM traffic for an unfused chain vs a fused kernel, as a function of the number of chained elementwise ops (tensor size fixed at 64 MB, bf16).
An unfused chain of 5 elementwise ops on a bf16 tensor writes its output to HBM after each op. Approximately how many times does the data travel between HBM and SRAM compared to the fused version?

Chapter 1: Kernel Launch Overhead

Every time PyTorch calls a CUDA function — matmul, GELU, softmax — the CPU sends a "launch" command to the GPU. The GPU queues this command, schedules the kernel across its SMs, sets up thread blocks, and then runs. This pipeline has a fixed overhead of roughly 5–20 microseconds per kernel launch, whether the kernel does 1 FLOP or 1 trillion.

For a large matmul (milliseconds of work), 10 µs of overhead is invisible. For a tiny elementwise op on a small tensor, 10 µs can exceed the actual computation time. In a real transformer training loop you might launch hundreds or thousands of kernels per forward pass: a 16-layer transformer with attention, FFN, and normalization at each layer can easily launch 500+ kernels. At 10 µs each, that is 5 ms of launch overhead per forward pass. The GPU can hide this cost only if the kernels themselves run longer than it takes to launch them.

The kernel launch tax. The overhead is not from the GPU being slow — it is from the CPU-GPU coordination protocol. The CPU enqueues launch instructions into the CUDA command queue asynchronously (so the CPU can keep running while the GPU catches up). But if your GPU kernels are short-lived, the GPU finishes each one before the CPU has submitted the next, leaving the GPU idle between kernels. You can measure this with the CUDA profiler: idle gaps between kernels are the launch overhead in action.
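The effect can be illustrated with a toy model of the asynchronous launch queue. The model assumes a fixed CPU cost per launch, kernels that run strictly one after another, and no other CPU work. The helper simulate_launch_queue is hypothetical, and the 10 µs launch cost is just one value from the 5–20 µs range above. Unlike simply adding the launch cost to each kernel, this model lets the CPU enqueue the next kernel while the GPU is still busy. With 5 µs kernels the GPU sits idle about half the time; with 100 µs kernels the launch cost is almost entirely hidden.

```python
def simulate_launch_queue(n_kernels, launch_us, exec_us):
    """Toy model of asynchronous kernel launches.

    The CPU spends launch_us enqueuing each kernel, back to back. The GPU runs
    kernels in order, each for exec_us, starting a kernel only once it has been
    enqueued and the previous kernel has finished. Returns (total_us, busy_us).
    """
    enqueued_at = 0.0   # time the CPU finishes enqueuing the current kernel
    gpu_free_at = 0.0   # time the GPU finishes the previous kernel
    for _ in range(n_kernels):
        enqueued_at += launch_us
        start = max(enqueued_at, gpu_free_at)   # wait for the launch and for the GPU
        gpu_free_at = start + exec_us
    return gpu_free_at, n_kernels * exec_us

for exec_us in (5.0, 100.0):
    total, busy = simulate_launch_queue(500, launch_us=10.0, exec_us=exec_us)
    print(f"exec={exec_us:>5} us: total {total:.0f} us, GPU idle {1 - busy / total:.0%}")
```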

CUDA graphs amortize this overhead. You record an entire forward pass (all kernel launches) as a single graph, then replay it in one shot.

CUDA graphs are a CUDA feature. PyTorch exposes them through torch.cuda.CUDAGraph and torch.cuda.graph (available since 1.10), and torch.compile's "reduce-overhead" mode applies them automatically.

With a graph, the CPU overhead for 500 launches becomes a single graph-replay call, roughly a 100× reduction in CPU overhead. For inference-heavy workloads, CUDA graphs can give a 20–30% speedup in small-batch regimes where kernel launch dominates.

But CUDA graphs require that tensor shapes and addresses don't change between runs — no dynamic control flow, no variable-length inputs. They are a specialized optimization. The more general technique is to eliminate unnecessary launches in the first place through operator fusion (Chapter 2), which addresses both the launch overhead and the redundant HBM traffic at once.

python
# Measuring kernel launch overhead: time a tiny op vs a large op
import torch, time

x = torch.randn(1, device='cuda')  # tiny tensor
y = torch.randn(10_000_000, device='cuda')  # large tensor

def bench(fn, n=200):
    for _ in range(5): fn()  # warmup
    torch.cuda.synchronize()
    t = time.perf_counter()
    for _ in range(n): fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - t) / n * 1e6  # µs per call

tiny_time  = bench(lambda: torch.nn.functional.relu(x))
large_time = bench(lambda: torch.nn.functional.relu(y))

print(f"tiny tensor:  {tiny_time:.1f} µs  (mostly launch overhead)")
print(f"large tensor: {large_time:.1f} µs  (mostly real work)")
# illustrative output (varies with GPU, driver, and PyTorch version):  tiny: ~8 µs  large: ~120 µs
# For tiny tensors the per-call time stays roughly constant: it is launch overhead, not work
The five-way comparison (lecture numbers on an A100; exact figures depend on tensor shape, dtype, and software versions). For GELU on a 16K-element float32 tensor:

- manual PyTorch (unfused, 5 ops) ≈ 0.8 ms
- fused PyTorch ≈ 0.05 ms
- CUDA kernel ≈ 0.04 ms
- Triton kernel ≈ 0.05 ms
- torch.compile ≈ 0.05 ms

The manual version is ~16× slower. All of that gap comes from unnecessary HBM trips and kernel launches; none of it comes from arithmetic complexity.
A transformer forward pass launches 500 CUDA kernels. Each kernel takes 5 µs to execute and incurs 10 µs of launch overhead. What fraction of total GPU wall-clock time is launch overhead?

Chapter 2: Operator Fusion: Count the Trips

Operator fusion is the practice of merging multiple GPU operations into a single kernel so that intermediate results never leave the chip. Instead of writing output A to HBM and then reading it back as input to operation B, you pipe A directly from registers into B inside the same kernel thread. This eliminates HBM traffic for all intermediates.

Let's count exactly. Consider the GELU activation: gelu(x) = 0.5 · x · (1 + tanh(0.7979 · (x + 0.044715 · x³))). If you implement this naively in PyTorch, the inner expression decomposes into five steps: (1) cube x, (2) scale by 0.044715, (3) add x, (4) scale by 0.7979, (5) tanh. Each step is a separate elementwise kernel. The outer add-one and the multiplications by x and 0.5 would add a few more kernels; we count five to keep the arithmetic simple. With a tensor of n bf16 values (2 bytes each):

Unfused GELU: 5 reads + 5 writes = 10 × 2n bytes = 20n bytes of HBM traffic
Fused GELU: 1 read + 1 write = 2 × 2n bytes = 4n bytes of HBM traffic
Traffic reduction from fusion: 20n / 4n = 5× less HBM traffic

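For a tensor of n = 16,384 bf16 elements (32 KB), unfused traffic is 327 KB and fused traffic is 65 KB. On an A100 with 2 TB/s bandwidth, the unfused version takes ~163 ns just for the memory transfers; the fused version takes ~33 ns.

At this size, both are tiny next to a single ~10 µs kernel launch, so the real win comes from replacing five launches with one. For large tensors, where transfers dominate, the 5× traffic reduction translates directly into a ~5× bandwidth-limited speedup.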
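A crude time model combines the two costs. It assumes the 2 TB/s figure above and an illustrative 10 µs per launch, with each kernel reading and writing the whole tensor once. The function estimated_time_us is a hypothetical helper. Launch cost dominates at n = 16,384, and transfer time dominates for a 256M-element tensor.

```python
def estimated_time_us(n, dtype_bytes, n_kernels, bw_bytes_per_s=2e12, launch_us=10.0):
    """Crude model: fixed launch cost per kernel + bytes moved / bandwidth.

    Each kernel is assumed to read and write the full tensor once.
    Returns (bytes_moved, transfer_us, total_us).
    """
    bytes_moved = 2 * n_kernels * n * dtype_bytes
    transfer_us = bytes_moved / bw_bytes_per_s * 1e6
    return bytes_moved, transfer_us, n_kernels * launch_us + transfer_us

for n in (16_384, 256 * 1024 * 1024):
    for name, kernels in (("unfused", 5), ("fused", 1)):
        b, xfer, total = estimated_time_us(n, dtype_bytes=2, n_kernels=kernels)
        print(f"n={n:>11,} {name:>7}: {b / 1e3:>14,.1f} KB, "
              f"transfer {xfer:>10.3f} us, total {total:>10.3f} us")
```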

The factory floor analogy. Each kernel is a factory that takes raw materials (from the warehouse = HBM) and produces a product (writes back to HBM). Unfused GELU: five small factories, each doing one step, shipping product back to the warehouse after each step. Fused GELU: one factory that does all five steps, keeping the work-in-progress on the factory floor (SRAM/registers). One warehouse trip total. The work is the same; the shipping cost is five times less.

What can be fused? Elementwise operations are the easiest case — any chain of ops where each output element depends only on the corresponding input element. No reduction, no cross-element dependencies. GELU, ReLU, swish, sigmoid, bias-add, scale — all fusible into a single kernel.

Reductions (sum, max, mean) are harder — they require communication across elements. But you can fuse an elementwise op before a reduction in the same kernel (compute exp(x) and accumulate the sum in one pass). Fused softmax does exactly this: compute max(x), subtract, exp, accumulate sum — all in one kernel, with the row of x read exactly once from HBM.

[Interactive canvas] Fusion traffic bars: HBM traffic per op for an unfused chain vs the total for the fused kernel, as a function of the number of chained elementwise ops and the tensor size (millions of bf16 elements).
You have a chain: x₁ = gelu(x₀), x₂ = x₁ + bias, x₃ = x₂ × scale. How many HBM read+write round-trips does the unfused version make vs the fused version (each tensor is 8 MB bf16)?

Chapter 3: CUDA vs Triton vs torch.compile

To write a fused kernel, you have three options. They form a ladder of abstraction: CUDA at the bottom (maximum control, maximum boilerplate), Triton in the middle (near-CUDA performance, Python-level code), and torch.compile at the top (automatic fusion, zero kernel code, but less controllable).

CUDA is NVIDIA's GPU programming language: C/C++ with extensions for GPU execution. You write code at the level of individual threads. Each thread is responsible for one or a few elements, and it combines its blockIdx, blockDim, and threadIdx into a global index to figure out which elements it owns.

You also manage shared memory explicitly: you decide what to load, when to synchronize, and when to write back. This gives you maximum performance, since hand-written CUDA can squeeze every last FLOP out of the hardware. But a correct CUDA kernel for even a simple fused softmax is 100+ lines of C++, with subtle bugs lurking in the synchronization logic.
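To make the per-thread view concrete, here is a plain-Python emulation of the index arithmetic a CUDA elementwise kernel performs. It is not CUDA: the loops run sequentially, whereas on a GPU every (blockIdx, threadIdx) pair runs in parallel. The function name cuda_style_add is hypothetical.

```python
def cuda_style_add(x, y, block_dim=256):
    """Emulates a CUDA elementwise kernel: one 'thread' per element.

    On a GPU every (block_idx, thread_idx) pair runs concurrently; here we loop
    over them sequentially just to show the index arithmetic.
    """
    n = len(x)
    z = [0.0] * n
    grid_dim = (n + block_dim - 1) // block_dim        # number of thread blocks (ceiling division)
    for block_idx in range(grid_dim):
        for thread_idx in range(block_dim):
            i = block_idx * block_dim + thread_idx     # global element index
            if i < n:                                  # manual bounds check, as in CUDA
                z[i] = x[i] + y[i]
    return z

print(cuda_style_add([1, 2, 3], [10, 20, 30], block_dim=2))
```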

Triton (OpenAI, 2021) is a Python-based GPU programming language that operates at the level of thread blocks rather than individual threads. You write code that processes a tile of elements, and Triton's compiler handles thread assignment, memory coalescing, and shared memory management automatically. The key insight from the Triton paper: the hard manual work in CUDA (coalescing, shared memory banks, instruction scheduling) can be automated for most practical kernels — you only need to specify the tiling strategy. Result: Triton kernels are typically 80-100% of CUDA performance at 20% of the code complexity.

torch.compile (PyTorch 2.0+) takes regular Python/PyTorch code and compiles it to optimized Triton kernels automatically. It traces the computation graph, identifies fusible op chains, and generates Triton code behind the scenes. For standard elementwise fusion, it matches hand-written Triton. The limitation: it cannot fuse operations that require tiling strategies it doesn't know (like custom FlashAttention variants), and it adds compilation overhead on the first call.

| Approach | Abstraction level | Coalescing | Shared memory | Scheduling across SMs | Typical use |
|---|---|---|---|---|---|
| CUDA | Thread | Manual | Manual | Manual | Maximum perf, novel algorithms |
| Triton | Thread block (tile) | Automatic | Automatic | Manual | Custom ops, research kernels |
| torch.compile | Python op | Automatic | Automatic | Automatic | Standard fusion, production |
| PyTorch eager | Python op | N/A (each op separate) | N/A | N/A | Debugging, prototyping |
The compiler does more work over time. In 2019, writing a fast softmax required CUDA. In 2021, Triton let you write it in Python at 95% of CUDA performance. In 2023, torch.compile started fusing elementwise chains automatically. By 2025, torch.compile can fuse most things — but you still need Triton for truly novel algorithms (new attention variants, quantized ops, custom positional encodings). The window where you need to write Triton is narrowing, but it still exists at the research frontier.
Which GPU programming model automates memory coalescing and shared-memory management but still leaves you to decide how work is split across SMs (the launch grid)?

Chapter 4: The Triton Programming Model

Triton replaces CUDA's per-thread view with a per-tile view. Instead of asking "which element does this thread own?", you ask "which block of elements does this program instance own?" The GPU runs many program instances in parallel, each identified by a program_id. Each instance processes a contiguous tile of elements. The tile size, called BLOCK_SIZE, is a compile-time constant — Triton uses it to generate efficient vectorized code.

The three fundamental Triton primitives are tl.program_id(axis), tl.load(ptr + offsets, mask=mask), and tl.store(ptr + offsets, values, mask=mask). Together, they let you say: "instance pid works on elements [pid·BLOCK_SIZE, (pid+1)·BLOCK_SIZE). Load them, compute, store."

The mask is mandatory whenever n may not be a multiple of BLOCK_SIZE, because then the last program instance has a partial tile. Without a mask, its out-of-bounds lanes would read garbage from adjacent memory or trigger an illegal-memory-access error.

The mask offsets < n disables those lanes. Masked loads return the value passed as other (undefined if none is given), and masked stores are skipped. Triton makes masking easy; CUDA requires manual boundary checking.
triton
import torch
import triton
import triton.language as tl

# Triton kernel: process BLOCK_SIZE elements per program instance
@triton.jit
def add_kernel(x_ptr, y_ptr, z_ptr, n, BLOCK_SIZE: tl.constexpr):
    # Step 1: which tile does this program instance own?
    pid = tl.program_id(axis=0)          # e.g. 0, 1, 2, 3, ...
    block_start = pid * BLOCK_SIZE         # first element of this tile

    # Step 2: compute indices for all elements in this tile
    offsets = block_start + tl.arange(0, BLOCK_SIZE)  # shape [BLOCK_SIZE]

    # Step 3: mask out-of-bounds elements (if n % BLOCK_SIZE != 0)
    mask = offsets < n

    # Step 4: load from HBM (Triton emits coalesced accesses for consecutive offsets)
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)

    # Step 5: compute (stays in registers)
    z = x + y

    # Step 6: store to HBM
    tl.store(z_ptr + offsets, z, mask=mask)

# Launch: determine grid = number of program instances
def add(x, y):
    z = torch.empty_like(x)
    n = x.numel()
    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n, BLOCK_SIZE),)  # ceiling division
    add_kernel[grid](x, y, z, n, BLOCK_SIZE=BLOCK_SIZE)
    return z

The launch grid (triton.cdiv(n, BLOCK_SIZE),) tells the GPU how many program instances to create. For n=8192 and BLOCK_SIZE=1024, that is 8 instances (pid 0 through 7), each handling 1024 elements. The hardware scheduler places these 8 instances on the A100's 108 SMs, so at most 8 SMs have work and the rest sit idle: a grid this small cannot fill the GPU. This is the "scheduling across SMs" that Triton leaves to you. You choose the grid size, but the SM assignment is automatic.
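The tile-and-mask logic can be checked on a CPU with a sequential, plain-Python emulation of add_kernel. This is an illustration, not Triton. The helper names are hypothetical, lists stand in for tensors, and masked-off lanes read 0.0 as a placeholder; in real Triton, a masked tl.load without other leaves those lanes undefined.

```python
def cdiv(a, b):
    """Ceiling division, the role triton.cdiv plays in the launcher above."""
    return (a + b - 1) // b

def emulate_add_kernel(x, y, BLOCK_SIZE):
    """Sequential emulation of add_kernel's per-program tile and mask logic."""
    n = len(x)
    z = [None] * n
    grid = (cdiv(n, BLOCK_SIZE),)
    for pid in range(grid[0]):                                  # each pid = one program instance
        block_start = pid * BLOCK_SIZE
        offsets = [block_start + i for i in range(BLOCK_SIZE)]  # like tl.arange(0, BLOCK_SIZE)
        mask = [off < n for off in offsets]
        # Masked loads: lanes with mask=False never touch memory (0.0 is a stand-in value)
        xs = [x[o] if m else 0.0 for o, m in zip(offsets, mask)]
        ys = [y[o] if m else 0.0 for o, m in zip(offsets, mask)]
        zs = [a + b for a, b in zip(xs, ys)]                    # compute stays "in registers"
        for o, m, v in zip(offsets, mask, zs):                  # masked store
            if m:
                z[o] = v
    return z, grid

z, grid = emulate_add_kernel(list(range(1000)), [1] * 1000, BLOCK_SIZE=256)
print(grid, z[:3], z[-3:])
```

With n = 1000 and BLOCK_SIZE = 256, the grid has 4 instances, and the last one masks off 24 of its 256 lanes.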

[Interactive canvas] Triton tiling: how BLOCK_SIZE partitions the tensor into program instances (one colored block per pid); the last block may be partial and is handled by the mask.
A Triton kernel has BLOCK_SIZE=512 and the input tensor has n=1300 elements. How many program instances (pids) are launched, and what does pid=2 compute?

Chapter 5: Writing a Fused GELU Kernel

Let's build the fused GELU kernel from scratch. GELU uses a tanh approximation: gelu(x) = 0.5 · x · (1 + tanh(√(2/π) · (x + 0.044715 · x³))). The constant √(2/π) ≈ 0.7979. Since Triton's tl namespace does not have a tanh primitive, we implement it using the identity $\tanh(a) = (e^{2a} - 1)/(e^{2a} + 1)$, which only needs tl.exp.
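This identity has a numerical caveat, which can be checked in plain Python. In float32, e^(2a) overflows to inf once 2a exceeds about 88.7 (for GELU, once x is a little above 10). The identity as written then evaluates inf/inf = NaN. The algebraically equivalent form 1 − 2/(e^(2a) + 1) returns exactly 1 in that case, and −1 when e^(2a) underflows to 0.

The sketch below feeds inf directly to mimic float32 overflow, since Python's own floats are 64-bit. The helper names are illustrative.

```python
import math

def tanh_naive(exp2a):
    # (e^{2a} - 1) / (e^{2a} + 1): becomes inf/inf = nan once e^{2a} overflows
    return (exp2a - 1) / (exp2a + 1)

def tanh_stable(exp2a):
    # 1 - 2 / (e^{2a} + 1): same value, but gives 1.0 for inf and -1.0 for 0.0
    return 1 - 2 / (exp2a + 1)

def gelu_tanh(x):
    """Tanh-approximation GELU using the stable identity."""
    a = 0.79788456 * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1 + tanh_stable(math.exp(2 * a)))

# The identity agrees with math.tanh for moderate arguments
for a in (-3.0, -0.5, 0.0, 0.5, 3.0):
    assert abs(tanh_stable(math.exp(2 * a)) - math.tanh(a)) < 1e-12

# Mimic float32 overflow of exp(2a) by passing inf directly
print(tanh_naive(math.inf), tanh_stable(math.inf))   # nan 1.0
print(gelu_tanh(2.0))
```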

The PyTorch reference implementation is a single tensor expression. The Triton version performs the same arithmetic, but it operates on a BLOCK_SIZE-wide vector at once and keeps the entire computation in registers, without touching HBM for intermediates.

python + triton
import torch
import triton
import triton.language as tl

# PyTorch reference (unfused: each elementwise op launches its own kernel)
def manual_gelu(x):
    return 0.5 * x * (1 + torch.tanh(0.79788456 * (x + 0.044715 * x * x * x)))

# Triton fused version — same math, one kernel, one HBM read+write
@triton.jit
def triton_gelu_kernel(x_ptr, y_ptr, n, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n

    x = tl.load(x_ptr + offsets, mask=mask)  # read once from HBM

    # All arithmetic stays in registers:
    a = 0.79788456 * (x + 0.044715 * x * x * x)
    exp2a = tl.exp(2 * a)
    tanh = 1 - 2 / (exp2a + 1)             # equals (exp2a - 1) / (exp2a + 1), but no inf/inf NaN if exp2a overflows
    y = 0.5 * x * (1 + tanh)               # result in registers

    tl.store(y_ptr + offsets, y, mask=mask)   # write once to HBM

def triton_gelu(x):
    assert x.is_cuda and x.is_contiguous()   # is_cuda is a property, not a method
    y = torch.empty_like(x)
    n = x.numel()
    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n, BLOCK_SIZE),)
    triton_gelu_kernel[grid](x, y, n, BLOCK_SIZE=BLOCK_SIZE)
    return y

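# PyTorch's built-in fused GELU, using the same tanh approximation
def pytorch_gelu(x):
    return torch.nn.functional.gelu(x, approximate="tanh")

# Benchmark comparison (lecture figures, A100, n=16384 elements; varies with hardware):
# manual_gelu (unfused):  ~0.8 ms   (one kernel launch per op, ~5× HBM traffic)
# pytorch_gelu (fused):   ~0.05 ms  (PyTorch's built-in fused kernel)
# triton_gelu (fused):    ~0.05 ms  (our Triton version, matches PyTorch)
# torch.compile(manual):  ~0.05 ms  (auto-fused by the compiler)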
Thread coarsening: the hidden Triton win. For bf16 inputs, the PTX generated by Triton for this kernel shows each thread processing 8 elements at once, a technique called thread coarsening. Instead of one thread per element, each thread loads a vector of 8 bf16 values in one instruction (ld.global.v4.b32 = 128-bit load), which reduces instruction overhead. Triton's compiler does this automatically based on BLOCK_SIZE, the number of warps, and the dtype. CUDA requires you to implement it manually with explicit vectorized loads.

The takeaway: the Triton kernel is nearly line-for-line the PyTorch reference, with the same constants and structure and with tanh expanded via tl.exp. The difference is that we placed the load/store calls manually, ensuring all arithmetic happens in registers between them. The compiler handles the rest: thread assignment, vectorization, coalescing. This is why Triton is so useful for research: you express the algorithm, not the hardware.

In the Triton GELU kernel, where do the intermediate variables a, exp2a, tanh, and y live?

Chapter 6: Fused Softmax: Row-Wise Tiling

GELU is elementwise — each output depends only on the same-index input. Softmax is different: each output element depends on all inputs in the same row. Computing softmax requires a full-row scan to find the maximum (for numerical stability), then a full-row scan to compute the sum, then a full-row division. Three passes over the same data.

The naive PyTorch decomposition of softmax(x) is:

xmax = max(x, dim=1)  —  MN reads, M writes
x' = x − xmax[:, None]  —  MN + M reads, MN writes
num = exp(x')  —  MN reads, MN writes
denom = num.sum(dim=1)  —  MN reads, M writes
y = num / denom[:, None]  —  MN + M reads, MN writes
Total: 5MN + 2M reads, 3MN + 2M writes  ≈  8MN HBM ops
Fused: 1 read + 1 write  =  2MN HBM ops  —  4× fewer trips
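The tally can be written out step by step, counted in elements; multiply by the bytes per element to get bytes. The helper naive_softmax_traffic is hypothetical. For an M×N input the ratio is (8MN + 4M)/(2MN) = 4 + 2/N, so it is essentially 4×.

```python
def naive_softmax_traffic(M, N):
    """(reads, writes) in elements for the five-kernel softmax decomposition."""
    steps = [
        (M * N, M),              # xmax = max(x, dim=1)
        (M * N + M, M * N),      # x' = x - xmax[:, None]
        (M * N, M * N),          # num = exp(x')
        (M * N, M),              # denom = num.sum(dim=1)
        (M * N + M, M * N),      # y = num / denom[:, None]
    ]
    reads = sum(r for r, _ in steps)
    writes = sum(w for _, w in steps)
    return reads, writes

M, N = 4096, 4096
reads, writes = naive_softmax_traffic(M, N)
fused = 2 * M * N                # one read of x, one write of y
print(reads == 5 * M * N + 2 * M, writes == 3 * M * N + 2 * M)   # True True
print(f"naive/fused traffic ratio: {(reads + writes) / fused:.3f}")
```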

The Triton fused softmax kernel assigns one program instance per row. Each instance loads the entire row on-chip into registers, which works when the row fits in one BLOCK_SIZE tile; longer rows would require iterating over tiles. It then computes the max, subtracts, exponentiates, sums, normalizes, and writes back. The result is one round-trip to HBM per row, regardless of how many sub-ops happen inside the kernel.

triton
import torch
import triton
import triton.language as tl

@triton.jit
def softmax_kernel(x_ptr, y_ptr, x_row_stride, y_row_stride,
                    num_cols, BLOCK_SIZE: tl.constexpr):
    row_idx = tl.program_id(0)     # one program per row
    col_offsets = tl.arange(0, BLOCK_SIZE)

    # Load the entire row from HBM (one read)
    x_start = x_ptr + row_idx * x_row_stride
    x_row = tl.load(x_start + col_offsets,
                     mask=col_offsets < num_cols,
                     other=float('-inf'))  # out-of-bounds → -inf for safety

    # Compute in registers (no HBM traffic for intermediates):
    x_row = x_row - tl.max(x_row, axis=0)   # subtract max for stability
    numerator = tl.exp(x_row)
    denominator = tl.sum(numerator, axis=0)
    y_row = numerator / denominator

    # Write back (one write)
    y_start = y_ptr + row_idx * y_row_stride
    tl.store(y_start + col_offsets, y_row, mask=col_offsets < num_cols)

def triton_softmax(x):
    M, N = x.shape
    y = torch.empty_like(x)
    BLOCK_SIZE = triton.next_power_of_2(N)  # power-of-2 for efficiency
    softmax_kernel[(M,)](x, y,
        x_row_stride=x.stride(0),
        y_row_stride=y.stride(0),
        num_cols=N, BLOCK_SIZE=BLOCK_SIZE)
    return y

# Benchmark (lecture figures, M=16384, N=16384, bf16; varies with hardware):
# manual_softmax (unfused):  ~2.1 ms   (8MN HBM trips)
# pytorch_softmax:           ~0.55 ms  (built-in fused)
# triton_softmax (fused):    ~0.52 ms  (matches PyTorch!)
# torch.compile(manual):     ~0.54 ms  (auto-fused, close)
Why BLOCK_SIZE must be a power of 2. The hard constraint is that tl.arange requires a power-of-2 range, so a tile built from tl.arange(0, BLOCK_SIZE) must have a power-of-2 width.

Power-of-2 sizes also help performance. Triton vectorizes loads using 128-bit (16-byte) instructions, so for coalesced vectorized access the block size should be a multiple of the hardware vector width (8 fp16/bf16 elements per 128-bit load). Power-of-2 block sizes also let the compiler use balanced reduction trees (log₂N stages for max/sum).

triton.next_power_of_2(N) rounds N up to the nearest power of 2. If N=1300, BLOCK_SIZE=2048, and the extra 748 elements are masked out.
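A plain-Python emulation of one softmax_kernel program instance shows why the −inf padding from other=float('-inf') is harmless: it never wins the max, and exp(−inf) = 0 adds nothing to the sum. The helpers are hypothetical stand-ins (next_power_of_2 mirrors the rounding described above for triton.next_power_of_2), and lists stand in for tensors.

```python
import math

def next_power_of_2(n):
    """Smallest power of 2 >= n (for n >= 1)."""
    return 1 << (n - 1).bit_length()

def emulate_softmax_row(row, BLOCK_SIZE):
    """Plain-Python emulation of one softmax_kernel program instance."""
    n = len(row)
    # Masked load: lanes past the end of the row get other=-inf
    x = [row[i] if i < n else -math.inf for i in range(BLOCK_SIZE)]
    x_max = max(x)                                    # like tl.max; -inf padding never wins
    numerator = [math.exp(v - x_max) for v in x]      # exp(-inf) == 0.0, so padding adds nothing
    denominator = sum(numerator)                      # like tl.sum
    y = [v / denominator for v in numerator]
    return y[:n]                                      # masked store: keep only real columns

row = [1.0, 2.0, 3.0] * 433 + [0.5]                   # N = 1300 columns
BLOCK_SIZE = next_power_of_2(len(row))
y = emulate_softmax_row(row, BLOCK_SIZE)
print(BLOCK_SIZE, BLOCK_SIZE - len(row))              # 2048 748
```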
The fused softmax kernel reduces HBM traffic from 8MN to 2MN operations. This gives a theoretical 4× speedup. In practice (from the benchmarks above), the speedup is ~4×. What happens if the row length N is very large — say N = 128,000 — and no longer fits in the kernel's registers?

Chapter 7: FlashAttention: Online Softmax + Tiling

Standard attention computes S = QKᵀ/√d, P = softmax(S), O = PV. For sequence length N and head dimension d, the score matrix S is N×N. At N=4096, d=64, with 32 heads in bf16, S alone takes 32×4096×4096×2 = 1.07 GB.

S is written to HBM and read back for the softmax. The softmax output P, just as large, is then written and read again for the PV product. That adds up to several full passes over a 1 GB matrix per layer, per forward pass.

FlashAttention (Dao et al., 2022) eliminates S from HBM entirely. It tiles Q into row-blocks of size Br and K, V into column-blocks of size Bc. For each Q-block, it iterates over all K/V blocks, maintaining a running max and running sum for the online softmax. The PV accumulator is also maintained as a running sum. At the end of the K/V loop, it normalizes the accumulator and writes one block of output O to HBM. S is never materialized.

The challenge: softmax requires the maximum of the entire row before you can compute stable exps. But we're processing the row in tiles — we never have the whole row at once. The online softmax recurrence (Milakov & Gimelshein, 2018) solves this by maintaining two running statistics that can be corrected as each new tile arrives.

Online softmax recurrence. Suppose you've seen tiles 0, …, t−1 and computed a running max m and running sum s. Now tile t arrives with new values $\{x_i\}$. The new running max is $m' = \max(m, \max_{i \in t} x_i)$.

The previous sum s was computed relative to m. Now that the running max is m', it must be rescaled: $s' = s \cdot e^{m - m'} + \sum_{i \in t} e^{x_i - m'}$. The output accumulator O is rescaled the same way: $O' = O \cdot e^{m - m'} + P_t V_t$, where $P_t = e^{x - m'}$ elementwise over the tile.

At the very end, divide O by s to normalize. This gives the same result as the standard softmax (up to floating-point rounding), computed one tile at a time.
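The recurrence can be checked in plain Python for a single query row. This sketch assumes scalar values (head dimension 1) to keep it short; with vector values, o becomes a vector and is rescaled the same way. Splitting the row into tiles of any size gives the same result as the direct softmax-weighted sum, up to rounding. The function names are illustrative.

```python
import math

def online_softmax_weighted_sum(scores, values, tile_size):
    """One query row: sum_i softmax(scores)_i * values[i], computed tile by tile."""
    m = -math.inf      # running max
    s = 0.0            # running sum of exp(x_i - m)
    o = 0.0            # running (unnormalized) output accumulator
    for start in range(0, len(scores), tile_size):
        xs = scores[start:start + tile_size]
        vs = values[start:start + tile_size]
        m_new = max(m, max(xs))
        correction = math.exp(m - m_new)               # exp(-inf) = 0 on the first tile
        p = [math.exp(x - m_new) for x in xs]          # P_tile = exp(x - m')
        s = s * correction + sum(p)
        o = o * correction + sum(pi * vi for pi, vi in zip(p, vs))
        m = m_new
    return o / s

def reference(scores, values):
    """Direct (non-tiled) softmax-weighted sum for comparison."""
    m = max(scores)
    e = [math.exp(x - m) for x in scores]
    return sum(ei * vi for ei, vi in zip(e, values)) / sum(e)

scores = [0.3, 2.5, -1.0, 4.2, 0.0, 3.9, -2.2]
values = [1.0, -2.0, 0.5, 3.0, 2.0, -1.0, 0.25]
for tile in (1, 2, 3, 7):
    print(tile, online_softmax_weighted_sum(scores, values, tile), reference(scores, values))
```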
python (FlashAttention forward, PyTorch reference; runs on CPU or GPU)
import math
import torch

def flash_attention_forward(Q, K, V, Br=64, Bc=64):
    # Inputs: Q, K, V of shape [N, d] (single head); tile sizes Br, Bc
    # Output: O of shape [N, d], the attention output
    N, d = Q.shape
    O = torch.empty_like(Q)
    for i_block in range(0, N, Br):
        Q_i = Q[i_block : i_block + Br]                  # Q tile [Br × d] (would live in SRAM)
        rows = Q_i.shape[0]

        # Initialize running stats for this Q block
        m_i = torch.full((rows,), float('-inf'), dtype=Q.dtype, device=Q.device)  # running row max
        s_i = torch.zeros(rows, dtype=Q.dtype, device=Q.device)                  # running denominator
        O_i = torch.zeros(rows, d, dtype=Q.dtype, device=Q.device)               # output accumulator

        for j_block in range(0, N, Bc):
            K_j = K[j_block : j_block + Bc]              # K tile [Bc × d]
            V_j = V[j_block : j_block + Bc]              # V tile [Bc × d]

            S_ij = Q_i @ K_j.T / math.sqrt(d)            # [Br × Bc] scores, never written to HBM
            m_ij = S_ij.amax(dim=-1)                     # [Br] max per row in this tile

            # Online softmax correction
            m_prev = m_i
            m_i = torch.maximum(m_i, m_ij)               # update running max
            correction = torch.exp(m_prev - m_i)         # rescale factor (0 on the first tile)

            P_ij = torch.exp(S_ij - m_i[:, None])        # [Br × Bc] unnormalized probabilities
            s_i = s_i * correction + P_ij.sum(dim=-1)    # update running sum
            O_i = O_i * correction[:, None] + P_ij @ V_j # update output accumulator

        O[i_block : i_block + Br] = O_i / s_i[:, None]   # normalize; write once per Q block
    return O

if __name__ == "__main__":
    torch.manual_seed(0)
    Q, K, V = (torch.randn(256, 64) for _ in range(3))
    ref = torch.softmax(Q @ K.T / math.sqrt(64), dim=-1) @ V
    assert torch.allclose(flash_attention_forward(Q, K, V), ref, atol=1e-5)

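HBM traffic analysis. Standard attention for seq_len N, head_dim d, num_heads h, bf16 (2 bytes), batch size 1:

Standard: 3hNd⋅2 (read Q,K,V) + 2hN²⋅2 (write+read S) + hN²⋅2 (write P) + hNd⋅2 (write O)
For N=4096, d=64, h=32: 3×32×4096×64×2 + 3×32×4096²×2 + 32×4096×64×2 = 50 MB + 3221 MB + 17 MB ≈ 3.3 GB (reading P back for the PV product would add another 1.07 GB)
FlashAttention (idealized, each of Q, K, V read once): 3hNd⋅2 (read Q,K,V) + hNd⋅2 (write O) = 50 MB + 17 MB ≈ 67 MB
Traffic reduction: 3288 MB / 67 MB ≈ 49× less HBM traffic. This is a best case. In the tiled loop, every Q block re-reads all of K and V, so real traffic is higher, though the L2 cache absorbs part of it.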
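A rough byte counter for these formulas follows. It assumes batch size 1 and ignores L2 cache hits. The hypothetical q_block_rows option models the K/V re-reads that the tiled loop performs for every Q block, which the idealized count ignores; on real hardware some of those re-reads are served from L2. attention_hbm_bytes is an illustrative helper, not a library function.

```python
def attention_hbm_bytes(N, d, h, dtype_bytes=2, flash=False, q_block_rows=None,
                        count_p_read=False):
    """Rough HBM byte counts for one attention layer (batch size 1).

    Standard: read Q, K, V; write S, read S, write P (optionally read P for PV); write O.
    Flash:    read Q once, read K and V once per Q block (once in total if
              q_block_rows is None, the idealized count), write O.
    """
    qkvo = h * N * d * dtype_bytes                 # bytes in one of Q, K, V, or O
    if not flash:
        n2_passes = 4 if count_p_read else 3       # N×N matrix passes counted
        return 3 * qkvo + n2_passes * h * N * N * dtype_bytes + qkvo
    kv_passes = 1 if q_block_rows is None else -(-N // q_block_rows)   # ceil(N / Br)
    return qkvo + kv_passes * 2 * qkvo + qkvo

std = attention_hbm_bytes(4096, 64, 32)
ideal = attention_hbm_bytes(4096, 64, 32, flash=True)
streamed = attention_hbm_bytes(4096, 64, 32, flash=True, q_block_rows=128)
print(f"standard {std / 1e6:.0f} MB, flash (ideal) {ideal / 1e6:.0f} MB, "
      f"ratio {std / ideal:.0f}x")
print(f"flash with K/V re-read per Q block (Br=128): {streamed / 1e6:.0f} MB")
```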
[Interactive canvas] FlashAttention step by step: the running max m and sum s update as each K/V tile arrives, and the correction factor exp(m_prev − m_new) keeps the accumulator consistent.

In the idealized count, FlashAttention achieves up to ~49× HBM traffic reduction for N=4096. The key technique is tiling. Why is the online softmax recurrence necessary for FlashAttention to be correct (not just fast)?

Chapter 8: Benchmarking & Profiling

Writing a Triton kernel is only the beginning. You must verify it is actually faster — not just different. The performance of GPU code is deeply non-intuitive: a kernel that looks correct can be 10× slower than expected due to register pressure, occupancy, or uncoalesced access patterns that only appear at certain tensor shapes.

The benchmarking protocol has two rules: warmup first (the first kernel execution often triggers JIT compilation, making it artificially slow) and synchronize before timing (CUDA is asynchronous — the CPU submits work to the GPU queue without waiting; torch.cuda.synchronize() blocks until the GPU is truly done). Without synchronization, you measure only the CPU queue-submission time, not the GPU execution time.

python — correct GPU benchmarking
import time, torch

def manual_softmax(x):
    # Unfused reference: each line launches its own kernel
    x_max = x.max(dim=1)[0]
    x = x - x_max[:, None]
    numerator = torch.exp(x)
    denominator = numerator.sum(dim=1)
    return numerator / denominator[:, None]

def benchmark(fn, n_warmup=3, n_trials=10):
    # Warmup: run without timing to allow JIT compilation + caching
    for _ in range(n_warmup): fn()
    torch.cuda.synchronize()           # wait for warmup to complete

    times = []
    for _ in range(n_trials):
        t0 = time.perf_counter()
        fn()
        torch.cuda.synchronize()       # CRITICAL: wait for GPU
        times.append((time.perf_counter() - t0) * 1000)  # ms

    return sum(times) / len(times)   # mean across trials

x = torch.randn(4096, 4096, device='cuda')
print(f"manual_softmax: {benchmark(lambda: manual_softmax(x)):.3f} ms")

# Profile to see which CUDA kernels are actually called:
with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CUDA],
) as prof:
    manual_softmax(x)
    torch.cuda.synchronize()

# Sorted by CUDA time: shows each kernel call and its duration
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
# manual_softmax shows 5 separate kernel entries
# triton_softmax shows 1 kernel, named after the @triton.jit function (softmax_kernel)

The profiler output names CUDA kernels and shows their duration. For manual_softmax, you see five separate entries (max, subtract, exp, sum, divide). For triton_softmax, you see exactly one. This is the fusion you implemented — confirmed by the profiler, not just assumed.

Benchmarking gives you scaling; profiling gives you why.

Use benchmarking to compare implementations across different input sizes: does the speedup grow with size, or shrink? Once launch overhead is negligible, memory-bound ops scale linearly with tensor size (double the tensor, double the time). Compute-bound ops such as large matmuls often get more efficient as sizes grow, because each tile loaded into SRAM is reused for more arithmetic.

Use profiling to diagnose unexpected slowness. torch.profiler shows which kernels dominate the runtime, and a kernel-level tool such as NVIDIA Nsight Compute shows why a single kernel is slow (memory throughput, occupancy, stalls).

The speedup from fusion depends on arithmetic intensity.

- GELU (8 FLOPs/elem / 4 bytes/elem = 2 FLOPs/byte): the fused version reduces HBM traffic 5×, so for large tensors you expect roughly a 5× speedup from traffic alone. The lecture's measured ~16× gap (0.8 ms vs 0.05 ms) is larger still, since the unfused version also pays per-kernel launch overhead.
- Softmax (5+ passes → 1 pass): you expect a 4–8× speedup.
- FlashAttention (up to ~49× less traffic in the idealized count): you expect a large speedup, but the measured speedup on A100 is 2–4×. The tiling and reduction logic adds compute overhead, and K/V are re-read for every Q block; both partially offset the bandwidth savings.

[Interactive canvas] Speedup from fusion in the memory-bound regime: theoretical speedup vs an estimated actual speedup (accounting for kernel launch overhead and partial L2 reuse), as a function of the number of chained elementwise ops, each at ~2 FLOPs/byte.
You benchmark a fused kernel and an unfused chain on an A100. The unfused version takes 2 ms; the fused version takes 0.5 ms. Is this consistent with the memory-bound traffic analysis that predicts a 4× speedup?

Chapter 9: Connections & Cheat Sheet

This lecture sits at the junction of hardware understanding (Lecture 5) and distributed training (Lecture 7). The kernel-writing skills you built here are the foundation for understanding how modern LLMs are actually implemented — not as a stack of generic PyTorch ops, but as a set of carefully engineered fused kernels that collectively achieve 40–70% of theoretical hardware throughput.

| Concept | What it is | Why it matters |
|---|---|---|
| Kernel launch overhead | ~5–20 µs per CUDA kernel launch | Dominates small ops; motivates CUDA graphs & fusion |
| Operator fusion | Merge N kernels into 1; intermediates stay in registers | N× HBM traffic reduction for elementwise chains |
| Triton program_id | Instance index in the launch grid | Maps each instance to its tile: block_start = pid × BLOCK_SIZE |
| tl.load / tl.store | Vector HBM read/write with mask | Coalescing automatic; mask handles boundary tiles |
| BLOCK_SIZE (constexpr) | Tile width, compile-time constant (power of 2) | Allows vectorization, loop unrolling, optimized reductions |
| Thread coarsening | Each Triton thread processes multiple elements | Done automatically by the Triton compiler; fewer instructions (e.g., one 128-bit load moves 8 bf16 values) |
| Online softmax | Running max m + running sum s, corrected per tile | Enables tiled softmax without materializing the full row |
| FlashAttention | Tiling + fusion + online softmax + recomputation | Up to ~49× less HBM traffic (idealized count, N=4096); O(N) extra memory vs O(N²) |
| torch.compile | Auto-generates Triton kernels from PyTorch code | Matches hand-written Triton for standard elementwise chains |
| torch.cuda.synchronize() | Block CPU until GPU finishes | Required for correct GPU benchmarking |
The HBM traffic formula. For a chain of k elementwise ops on a tensor of n elements at dtype_bytes bytes/element:
Unfused: k reads + k writes = 2k × n × dtype_bytes bytes.
Fused: 1 read + 1 write = 2 × n × dtype_bytes bytes.
Traffic reduction: k×. Speedup: up to k× when the ops are bandwidth-bound (for small tensors, launch overhead dominates instead).
For GELU (k=5, n=16384, bf16): unfused = 327 KB, fused = 65 KB, 5× less traffic.
"Key principle: organize computation to minimize reads/writes. Key ideas: kernel fusion (warehouse/factory analogy), tiling (shared memory). Automatic compilers will get better over time — but until then, you need to understand what they are compiling toward."
— Tatsu Hashimoto, CS336 Lecture 6 summary