Your elementwise chain reads HBM five times when one would do. Learn why, then fix it. This lecture covers kernel launch overhead, operator fusion from first principles, and the Triton programming model (program_id, tiles, tl.load/tl.store, masks). You will write fused kernels for GELU and softmax, arrive at FlashAttention's online softmax as the canonical tiled, fused kernel, and benchmark everything.
You have an MLP layer: x = gelu(W₂ · gelu(W₁ · x)). You profile it and see four CUDA kernels firing in sequence — a matmul, a GELU, another matmul, another GELU. On paper, that GELU is trivial: tanh approximation, a handful of arithmetic ops. It should be fast. But the profiler shows it taking 40% of your layer's wall-clock time.
The reason is not the arithmetic. For an elementwise GELU on a tensor of n float16 values, you do about 8 FLOPs per element. With 2 bytes per element and a read + write pass, that's 4 bytes/element. Arithmetic intensity: 8/4 = 2 FLOPs/byte. The A100's ridge point is 156 FLOPs/byte (312 TFLOP/s of peak half-precision compute divided by 2 TB/s of HBM bandwidth). Your GELU is running at 2/156 = 1.3% of the theoretical compute ceiling, purely because it spends almost all of its time shuffling data between HBM and the registers.
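The roofline arithmetic above can be reproduced in a few lines. This is a minimal sketch: the peak figures (312 TFLOP/s, 2 TB/s) are assumed nominal A100 numbers rather than measurements, and the function and variable names are illustrative.

```python
# Roofline back-of-envelope for an elementwise GELU.
# Assumed nominal A100 figures (not measured): 312 TFLOP/s, 2 TB/s HBM.
PEAK_FLOPS = 312e12      # FLOP/s
PEAK_BANDWIDTH = 2e12    # bytes/s

def arithmetic_intensity(flops_per_elem, bytes_per_elem):
    """FLOPs performed per byte moved through HBM."""
    return flops_per_elem / bytes_per_elem

ridge_point = PEAK_FLOPS / PEAK_BANDWIDTH        # FLOPs/byte needed to be compute-bound
gelu_intensity = arithmetic_intensity(8, 2 + 2)  # 2-byte read + 2-byte write per element
fraction_of_peak = gelu_intensity / ridge_point

print(f"ridge point:      {ridge_point:.0f} FLOPs/byte")
print(f"GELU intensity:   {gelu_intensity:.0f} FLOPs/byte")
print(f"fraction of peak: {fraction_of_peak:.1%}")
```

Any kernel whose intensity falls below the ridge point is limited by bandwidth, so reducing bytes moved is the only lever that helps.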
Now chain three ops: y = mul(add(gelu(x), bias), scale). PyTorch eager launches three separate kernels: gelu reads x and writes its result to HBM; add reads that result back and writes its output to HBM; mul reads that output and writes again. That is six full-tensor HBM passes (three reads, three writes) where a fused kernel needs only two, all for intermediates that never needed to leave the chip in the first place.
This lecture answers the question: how do you actually fuse kernels? You have three paths. You can write CUDA (C++ + GPU intrinsics) — powerful but verbose. You can use Triton (Python + compiler) — almost as fast, far more readable. Or you can call torch.compile() and let the compiler figure it out. We will cover all three, but focus on Triton — it is what real ML researchers use to write custom kernels for new ops.
[Interactive figure: HBM traffic for unfused (separate kernels) vs fused (one kernel) as the number of chained elementwise ops varies. Tensor size fixed at 64 MB (bf16).]
Every time PyTorch calls a CUDA function — matmul, GELU, softmax — the CPU sends a "launch" command to the GPU. The GPU queues this command, schedules the kernel across its SMs, sets up thread blocks, and then runs. This pipeline has a fixed overhead of roughly 5–20 microseconds per kernel launch, whether the kernel does 1 FLOP or 1 trillion.
For a large matmul (milliseconds of work), 10 µs of overhead is invisible. For a tiny elementwise op on a small tensor, 10 µs can exceed the actual computation time. And in a real transformer training loop, you might launch hundreds or thousands of kernels per forward pass — a 16-layer transformer with attention, FFN, and normalization at each layer can easily launch 500+ kernels. At 10 µs each, that is 5 ms of pure overhead before any arithmetic happens.
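To make "invisible" versus "dominant" concrete, the sketch below compares a fixed launch cost against the time to stream a tensor through HBM once. The 10 µs overhead and 2 TB/s bandwidth are assumed round numbers taken from the ranges quoted in this lecture, and the helper names are hypothetical.

```python
# When does launch overhead dominate? Compare a fixed launch cost against
# the time to stream a tensor through HBM once (one read + one write).
LAUNCH_OVERHEAD_S = 10e-6   # assumed ~10 µs per launch
HBM_BANDWIDTH = 2e12        # assumed 2 TB/s

def transfer_time_s(num_elems, bytes_per_elem=2):
    """Time to read and write every element once at peak bandwidth."""
    return 2 * num_elems * bytes_per_elem / HBM_BANDWIDTH

def break_even_elems(bytes_per_elem=2):
    """Tensor size at which transfer time equals one launch overhead."""
    return round(LAUNCH_OVERHEAD_S * HBM_BANDWIDTH / (2 * bytes_per_elem))

total_overhead_ms = 500 * LAUNCH_OVERHEAD_S * 1e3   # 500 launches per forward pass
print(f"500 launches: {total_overhead_ms:.1f} ms of launch overhead")
print(f"break-even size: {break_even_elems():,} bf16 elements")
```

Under these assumptions, an elementwise op on fewer than about five million bf16 elements spends more time being launched than moving data.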
CUDA graphs amortize this overhead: record an entire forward pass (all kernel launches) as a single graph, then replay it in one shot. PyTorch exposes them through torch.cuda.CUDAGraph, and torch.compile's "reduce-overhead" mode (PyTorch 2.0+) uses them automatically. The CPU cost of 500 launches collapses into a single graph-replay call, which can cut CPU launch overhead by roughly 100×. For inference workloads in small-batch regimes where kernel launch dominates, CUDA graphs can give a 20–30% speedup.
But CUDA graphs require that tensor shapes and addresses don't change between runs — no dynamic control flow, no variable-length inputs. They are a specialized optimization. The more general technique is to eliminate unnecessary launches in the first place through operator fusion (Chapter 2), which addresses both the launch overhead and the redundant HBM traffic at once.
```python
# Measuring kernel launch overhead: time a no-op vs a real op
import time

import torch

x = torch.randn(1, device='cuda')           # tiny tensor
y = torch.randn(10_000_000, device='cuda')  # large tensor

def bench(fn, n=200):
    for _ in range(5):
        fn()  # warmup
    torch.cuda.synchronize()
    t = time.perf_counter()
    for _ in range(n):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - t) / n * 1e6  # µs per call

tiny_time = bench(lambda: torch.nn.functional.relu(x))
large_time = bench(lambda: torch.nn.functional.relu(y))
print(f"tiny tensor: {tiny_time:.1f} µs (mostly launch overhead)")
print(f"large tensor: {large_time:.1f} µs (mostly real work)")
# typical output: tiny: ~8 µs   large: ~120 µs
# The launch overhead is ~8 µs regardless of tensor size
```
Operator fusion is the practice of merging multiple GPU operations into a single kernel so that intermediate results never leave the chip. Instead of writing output A to HBM and then reading it back as input to operation B, you pipe A directly from registers into B inside the same kernel thread. This eliminates HBM traffic for all intermediates.
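Let's count exactly. Consider the GELU activation: gelu(x) = 0.5 · x · (1 + tanh(0.7979 · (x + 0.044715 · x³))). If you implement this naively in PyTorch, the tanh term alone decomposes into: (1) cube x, (2) scale by 0.044715, (3) add x, (4) scale by 0.7979, (5) tanh. Each step is a separate elementwise kernel that reads its input from HBM and writes its output back (the final 1 + tanh(·) and 0.5 · x multiplies add still more). With a tensor of n bf16 values (2 bytes each), these five steps cost 5 × (2n bytes read + 2n bytes written) = 20n bytes unfused, versus one read plus one write = 4n bytes fused.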
For a tensor of n = 16,384 bf16 elements (32 KB): unfused = 327 KB of HBM traffic; fused = 65 KB. On an A100 with 2 TB/s bandwidth: unfused takes ~163 ns just for the memory transfers; fused takes ~33 ns. The 5× traffic reduction translates directly to a 5× bandwidth-limited speedup.
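The same count generalizes to any chain length. The sketch below assumes every unfused op reads and writes the full tensor once, and that the fused kernel reads the input once and writes the output once; the function name is illustrative.

```python
def hbm_traffic_bytes(num_elems, num_ops, bytes_per_elem=2):
    """Return (unfused, fused) HBM traffic in bytes for a chain of elementwise ops.

    Unfused: every op reads its input and writes its output.
    Fused: one read of the input, one write of the final output.
    """
    per_pass = num_elems * bytes_per_elem
    unfused = num_ops * 2 * per_pass
    fused = 2 * per_pass
    return unfused, fused

unfused, fused = hbm_traffic_bytes(num_elems=16_384, num_ops=5)
bandwidth = 2e12  # assumed 2 TB/s
print(unfused, fused)  # 327680 65536
print(f"{unfused / bandwidth * 1e9:.0f} ns vs {fused / bandwidth * 1e9:.0f} ns")
```

The ratio is always the number of ops in the chain, which is why longer chains benefit more from fusion.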
What can be fused? Elementwise operations are the easiest case — any chain of ops where each output element depends only on the corresponding input element. No reduction, no cross-element dependencies. GELU, ReLU, swish, sigmoid, bias-add, scale — all fusible into a single kernel.
Reductions (sum, max, mean) are harder — they require communication across elements. But you can fuse an elementwise op before a reduction in the same kernel (compute exp(x) and accumulate the sum in one pass). Fused softmax does exactly this: compute max(x), subtract, exp, accumulate sum — all in one kernel, with the row of x read exactly once from HBM.
[Interactive figure: HBM traffic per op for unfused kernels vs the total for a fused kernel, as a function of the number of chained elementwise ops and the tensor size, in MB.]
To write a fused kernel, you have three options. They form a ladder of abstraction: CUDA at the bottom (maximum control, maximum boilerplate), Triton in the middle (near-CUDA performance, Python-level code), and torch.compile at the top (automatic fusion, zero kernel code, but less controllable).
CUDA is NVIDIA's GPU programming language: C/C++ with extensions for GPU execution. You write code at the level of individual threads. Each thread is responsible for one or a few elements; you combine its blockIdx and threadIdx (set by the launch configuration) into a global index to figure out which elements it owns. You manage shared memory explicitly: decide what to load, when to synchronize, when to write back. This gives you maximum performance, and hand-written CUDA can squeeze every last FLOP out of the hardware. But a correct CUDA kernel for even a simple fused softmax is 100+ lines of C++ with subtle bugs lurking in the synchronization logic.
Triton (OpenAI, 2021) is a Python-based GPU programming language that operates at the level of thread blocks rather than individual threads. You write code that processes a tile of elements, and Triton's compiler handles thread assignment, memory coalescing, and shared memory management automatically. The key insight from the Triton paper: the hard manual work in CUDA (coalescing, shared memory banks, instruction scheduling) can be automated for most practical kernels — you only need to specify the tiling strategy. Result: Triton kernels are typically 80-100% of CUDA performance at 20% of the code complexity.
torch.compile (PyTorch 2.0+) takes regular Python/PyTorch code and compiles it to optimized Triton kernels automatically. It traces the computation graph, identifies fusible op chains, and generates Triton code behind the scenes. For standard elementwise fusion, it matches hand-written Triton. The limitation: it cannot fuse operations that require tiling strategies it doesn't know (like custom FlashAttention variants), and it adds compilation overhead on the first call.
| Approach | Abstraction level | Coalescing | Shared memory | Scheduling across SMs | Typical use |
|---|---|---|---|---|---|
| CUDA | Thread | Manual | Manual | Manual | Maximum perf, novel algorithms |
| Triton | Thread block (tile) | Automatic | Automatic | Manual | Custom ops, research kernels |
| torch.compile | Python op | Automatic | Automatic | Automatic | Standard fusion, production |
| PyTorch eager | Python op | N/A (each op separate) | N/A | N/A | Debugging, prototyping |
Triton replaces CUDA's per-thread view with a per-tile view. Instead of asking "which element does this thread own?", you ask "which block of elements does this program instance own?" The GPU runs many program instances in parallel, each identified by a program_id. Each instance processes a contiguous tile of elements. The tile size, called BLOCK_SIZE, is a compile-time constant — Triton uses it to generate efficient vectorized code.
The three fundamental Triton primitives are tl.program_id(axis), tl.load(ptr + offsets, mask=mask), and tl.store(ptr + offsets, values, mask=mask). Together, they let you say: "instance pid works on elements [pid·BLOCK_SIZE, (pid+1)·BLOCK_SIZE). Load them, compute, store."
The mask offsets < n silently discards out-of-bounds operations: masked lanes are neither loaded nor stored. Triton makes masking easy; CUDA requires manual boundary checking.

```python
import torch
import triton
import triton.language as tl

# Triton kernel: process BLOCK_SIZE elements per program instance
@triton.jit
def add_kernel(x_ptr, y_ptr, z_ptr, n, BLOCK_SIZE: tl.constexpr):
    # Step 1: which tile does this program instance own?
    pid = tl.program_id(axis=0)       # e.g. 0, 1, 2, 3, ...
    block_start = pid * BLOCK_SIZE    # first element of this tile
    # Step 2: compute indices for all elements in this tile
    offsets = block_start + tl.arange(0, BLOCK_SIZE)  # shape [BLOCK_SIZE]
    # Step 3: mask out-of-bounds elements (if n % BLOCK_SIZE != 0)
    mask = offsets < n
    # Step 4: load from HBM (consecutive offsets, so accesses coalesce)
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    # Step 5: compute (stays in registers)
    z = x + y
    # Step 6: store to HBM
    tl.store(z_ptr + offsets, z, mask=mask)

# Launch: determine grid = number of program instances
def add(x, y):
    z = torch.empty_like(x)
    n = x.numel()
    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n, BLOCK_SIZE),)  # ceiling division
    add_kernel[grid](x, y, z, n, BLOCK_SIZE=BLOCK_SIZE)
    return z
```
The launch grid (triton.cdiv(n, BLOCK_SIZE),) tells the GPU how many program instances to create. For n=8192 and BLOCK_SIZE=1024: 8 instances (pid 0 through 7), each handling 1024 elements. The hardware scheduler distributes these 8 instances across the A100's 108 SMs, so at most 8 SMs have work and the rest sit idle; a grid this small cannot saturate the GPU. This is the "scheduling across SMs" that Triton leaves manual: you choose the grid size, but the SM assignment is automatic.
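The grid and mask arithmetic can be checked on the CPU without Triton. The sketch below emulates what each program instance would compute; `cdiv` mirrors the ceiling division of triton.cdiv, and `tile_plan` is a hypothetical helper written for this illustration.

```python
def cdiv(a, b):
    """Ceiling division, the same arithmetic as triton.cdiv."""
    return (a + b - 1) // b

def tile_plan(n, block_size):
    """For each program id, return (block_start, number of unmasked lanes)."""
    plan = []
    for pid in range(cdiv(n, block_size)):
        block_start = pid * block_size
        offsets = range(block_start, block_start + block_size)
        active = sum(1 for off in offsets if off < n)   # lanes where the mask is True
        plan.append((block_start, active))
    return plan

print(len(tile_plan(8192, 1024)))   # 8 program instances
print(tile_plan(8192, 1024)[-1])    # (7168, 1024): last tile is full
print(tile_plan(8000, 1024)[-1])    # (7168, 832): last tile is partial
```

When n is not a multiple of BLOCK_SIZE, only the last instance has masked lanes; every other tile runs at full width.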
[Interactive figure: the tensor partitioned into BLOCK_SIZE tiles. Each colored block is one program instance (pid); the last block may be partial and relies on the mask.]
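Let's build the fused GELU kernel from scratch. GELU uses a tanh approximation: gelu(x) = 0.5 · x · (1 + tanh(√(2/π) · (x + 0.044715 · x³))). The constant √(2/π) ≈ 0.7979. Since Triton's core tl namespace has not consistently offered a tanh primitive across versions, we build it from tl.exp using the identity tanh(a) = (e^(2a) − 1) / (e^(2a) + 1). In floating point, the algebraically equivalent form 1 − 2 / (e^(2a) + 1) is safer, because the first form evaluates to inf/inf = NaN once e^(2a) overflows.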
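How close is the tanh form to the exact GELU? A quick CPU check, using only the standard library, compares it against the Gaussian-CDF definition x · Φ(x). The function names are illustrative and the sampled range [-6, 6] is an arbitrary choice for this sketch.

```python
import math

SQRT_2_OVER_PI = math.sqrt(2.0 / math.pi)   # about 0.7979

def gelu_exact(x):
    """GELU defined via the Gaussian CDF: x * Phi(x)."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh(x):
    """The tanh approximation used in this lecture."""
    inner = SQRT_2_OVER_PI * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))

xs = [i / 10 for i in range(-60, 61)]
max_err = max(abs(gelu_exact(x) - gelu_tanh(x)) for x in xs)
print(f"max |exact - tanh approx| on [-6, 6]: {max_err:.2e}")
```

The gap is small next to bf16 rounding error, which is why the approximation is acceptable for a fused kernel.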
The PyTorch reference implementation is five lines of scalar Python. The Triton version is structurally identical — the same arithmetic — but operates on a BLOCK_SIZE-wide vector at once, keeping the entire computation in registers without touching HBM for intermediates.
```python
import torch
import triton
import triton.language as tl

# PyTorch reference (unfused: 5 separate kernels)
def manual_gelu(x):
    return 0.5 * x * (1 + torch.tanh(0.79788456 * (x + 0.044715 * x * x * x)))

# Triton fused version: same math, one kernel, one HBM read + write
@triton.jit
def triton_gelu_kernel(x_ptr, y_ptr, n, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n
    # Read once from HBM; do the math in fp32 so half-precision inputs work
    x = tl.load(x_ptr + offsets, mask=mask).to(tl.float32)
    # All arithmetic stays in registers:
    a = 0.79788456 * (x + 0.044715 * x * x * x)
    exp2a = tl.exp(2 * a)
    # tanh(a) = (e^(2a) - 1) / (e^(2a) + 1) = 1 - 2 / (e^(2a) + 1);
    # the second form avoids inf/inf = NaN when exp2a overflows
    tanh = 1 - 2 / (exp2a + 1)
    y = 0.5 * x * (1 + tanh)                 # result in registers
    tl.store(y_ptr + offsets, y, mask=mask)  # write once to HBM (cast to y's dtype)

def triton_gelu(x):
    # is_cuda is a property, not a method
    assert x.is_cuda and x.is_contiguous()
    y = torch.empty_like(x)
    n = x.numel()
    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n, BLOCK_SIZE),)
    triton_gelu_kernel[grid](x, y, n, BLOCK_SIZE=BLOCK_SIZE)
    return y

# Benchmark comparison (A100, n=16384 elements):
# manual_gelu (unfused):  ~0.8 ms   (5 kernel launches, 5× HBM traffic)
# pytorch_gelu (fused):   ~0.05 ms  (PyTorch's built-in fused kernel)
# triton_gelu (fused):    ~0.05 ms  (our Triton version, matches PyTorch)
# torch.compile(manual):  ~0.05 ms  (auto-fused by compiler)
```
Vectorized loads: when each thread handles several consecutive elements, the compiler can fetch them with a single wide instruction (e.g., ld.global.v4.b32 = 128-bit load). This reduces instruction overhead and improves register utilization. Triton's compiler does this automatically based on BLOCK_SIZE. CUDA requires you to implement it manually with explicit vectorized load intrinsics.

The takeaway: the Triton kernel is nearly line-for-line identical to the PyTorch reference, with the same constants, the same arithmetic, and the same structure. The difference is that we placed the load/store calls manually, ensuring all arithmetic happens in registers between them. The compiler handles the rest: thread assignment, vectorization, coalescing. This is why Triton is so useful for research: you express the algorithm, not the hardware.
GELU is elementwise: each output depends only on the same-index input. Softmax is different: each output element depends on all inputs in the same row. Computing softmax requires a full-row scan to find the maximum (for numerical stability), then a full-row scan to compute the sum, then a full-row division. Three passes over the same data.
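The naive PyTorch decomposition of softmax(x) on an M×N matrix is five kernels: (1) row max, (2) subtract the max, (3) exp, (4) row sum, (5) divide. Together they read about 5MN elements and write about 3MN (the max and sum outputs are only M values each), for roughly 8MN element transfers through HBM.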
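The cost of a five-step decomposition (max, subtract, exp, sum, divide) can be tallied on a single row. The sketch below is a pure-Python stand-in, not PyTorch code: each list comprehension plays the role of one kernel, and the counters assume every kernel reads its full input and writes its full output. The function name is hypothetical.

```python
import math

def five_pass_softmax(row):
    """Softmax as five separate passes, tallying element reads and writes.

    Each pass stands in for one kernel that reads its input from HBM and
    writes its output back. The scalar max and sum are not counted.
    """
    n = len(row)
    reads = writes = 0

    row_max = max(row)                       # 1. max: read n
    reads += n
    shifted = [v - row_max for v in row]     # 2. subtract: read n, write n
    reads += n
    writes += n
    exps = [math.exp(v) for v in shifted]    # 3. exp: read n, write n
    reads += n
    writes += n
    total = sum(exps)                        # 4. sum: read n
    reads += n
    probs = [e / total for e in exps]        # 5. divide: read n, write n
    reads += n
    writes += n
    return probs, reads, writes

probs, reads, writes = five_pass_softmax([1.0, 2.0, 3.0, 4.0])
print(reads, writes, reads + writes)   # 20 12 32, i.e. 5n + 3n = 8n for n = 4
```

A fused kernel does the same arithmetic with n reads and n writes, a 4× reduction in traffic.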
The Triton fused softmax kernel assigns one program instance per row. Each instance loads the entire row into SRAM (within the kernel's register space for small rows, or iterating over tiles for large rows), computes max, subtracts, exps, sums, normalizes, and writes back — one round-trip to HBM per row, regardless of how many sub-ops happen inside the kernel.
```python
import torch
import triton
import triton.language as tl

@triton.jit
def softmax_kernel(x_ptr, y_ptr, x_row_stride, y_row_stride, num_cols,
                   BLOCK_SIZE: tl.constexpr):
    row_idx = tl.program_id(0)               # one program per row
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < num_cols
    # Load the entire row from HBM (one read); out-of-bounds lanes become -inf
    x_start = x_ptr + row_idx * x_row_stride
    x_row = tl.load(x_start + col_offsets, mask=mask, other=float('-inf'))
    x_row = x_row.to(tl.float32)             # do the math in fp32 (input may be bf16)
    # Compute in registers (no HBM traffic for intermediates):
    x_row = x_row - tl.max(x_row, axis=0)    # subtract max for stability
    numerator = tl.exp(x_row)
    denominator = tl.sum(numerator, axis=0)
    y_row = numerator / denominator
    # Write back (one write); tl.store casts to the output dtype
    y_start = y_ptr + row_idx * y_row_stride
    tl.store(y_start + col_offsets, y_row, mask=mask)

def triton_softmax(x):
    # The kernel assumes a 2D CUDA tensor with unit stride along columns
    assert x.is_cuda and x.dim() == 2 and x.stride(1) == 1
    M, N = x.shape
    y = torch.empty_like(x)
    BLOCK_SIZE = triton.next_power_of_2(N)   # power-of-2 for efficiency
    softmax_kernel[(M,)](x, y,
                         x_row_stride=x.stride(0),
                         y_row_stride=y.stride(0),
                         num_cols=N, BLOCK_SIZE=BLOCK_SIZE)
    return y

# Benchmark (M=16384, N=16384, bf16):
# manual_softmax (unfused): ~2.1 ms   (8MN HBM trips)
# pytorch_softmax:          ~0.55 ms  (built-in fused)
# triton_softmax (fused):   ~0.52 ms  (matches PyTorch!)
# torch.compile(manual):    ~0.54 ms  (auto-fused, close)
```
triton.next_power_of_2(N) rounds N up to the next power of 2. If N=1300, BLOCK_SIZE=2048, and the extra 748 lanes are masked out.

Standard attention computes S = QKᵀ/√d, P = softmax(S), O = PV. For sequence length N and head dimension d, the score matrix S is N×N. At N=4096, d=64, with 32 heads in bf16: S alone takes 32×4096×4096×2 bytes = 1.07 GB. It is written to HBM, then read back for softmax, then read again for the PV product: three full passes over a 1 GB matrix per layer, per forward pass.
FlashAttention (Dao et al., 2022) eliminates S from HBM entirely. It tiles Q into row-blocks of size Br and K, V into column-blocks of size Bc. For each Q-block, it iterates over all K/V blocks, maintaining a running max and running sum for the online softmax. The PV accumulator is also maintained as a running sum. At the end of the K/V loop, it normalizes the accumulator and writes one block of output O to HBM. S is never materialized.
The challenge: softmax requires the maximum of the entire row before you can compute stable exps. But we're processing the row in tiles — we never have the whole row at once. The online softmax recurrence (Milakov & Gimelshein, 2018) solves this by maintaining two running statistics that can be corrected as each new tile arrives.
```
# FlashAttention forward (pseudocode)
# Inputs: Q (N×d), K (N×d), V (N×d); tile sizes Br, Bc
# Output: O (N×d), the attention output
for i_block in range(0, N, Br):
    Q_i = Q[i_block : i_block + Br]           # load Q tile into SRAM [Br × d]

    # Initialize running stats for this Q block (one entry per row)
    m_i = -inf      # running max of QKᵀ scores for this Q block [Br]
    s_i = 0         # running denominator (sum of exps) [Br]
    O_i = 0         # running output accumulator [Br × d]

    for j_block in range(0, N, Bc):
        K_j = K[j_block : j_block + Bc]       # load K tile [Bc × d]
        V_j = V[j_block : j_block + Bc]       # load V tile [Bc × d]

        S_ij = Q_i @ K_j.T / sqrt(d)          # [Br × Bc], SRAM only, never HBM
        m_ij = max(S_ij, dim=-1)              # [Br], max per row in this tile

        # Online softmax correction
        m_prev = m_i
        m_i = max(m_i, m_ij)                  # update running max
        correction = exp(m_prev - m_i)        # rescale factor for previous sum
        P_ij = exp(S_ij - m_i[:, None])       # [Br × Bc], local softmax numerators

        s_i = s_i * correction + P_ij.sum(dim=-1)        # update running sum
        O_i = O_i * correction[:, None] + P_ij @ V_j     # update output accumulator

    O_i = O_i / s_i[:, None]                  # normalize by final sum
    O[i_block : i_block + Br] = O_i           # write to HBM once per Q block
```
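The recurrence is easier to trust once you see it reproduce an ordinary softmax. The sketch below applies the running-max and running-sum update to a single row, one tile at a time, and compares against a conventional two-pass softmax. It covers only the softmax part; the real kernel also rescales the output accumulator so it never needs the final re-read of the row done here. Function names are illustrative.

```python
import math

def online_softmax(row, tile_size):
    """Softmax over `row`, visiting it one tile at a time.

    Keeps only a running max `m` and running sum `s`; when a new tile
    raises the max, the old sum is rescaled by exp(m_prev - m).
    """
    m, s = -math.inf, 0.0
    for start in range(0, len(row), tile_size):
        tile = row[start:start + tile_size]
        m_prev = m
        m = max(m, max(tile))
        correction = math.exp(m_prev - m)   # 0.0 on the first tile (m_prev = -inf)
        s = s * correction + sum(math.exp(v - m) for v in tile)
    return [math.exp(v - m) / s for v in row]

def reference_softmax(row):
    """Conventional softmax: full-row max first, then the sum of exps."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

row = [0.5, -1.0, 3.0, 2.0, 7.5, -4.0, 1.0]
tiled = online_softmax(row, tile_size=3)
full = reference_softmax(row)
print(max(abs(a - b) for a, b in zip(tiled, full)))   # tiny rounding difference
```

The tile size affects only how many correction steps run, not the result, which is what lets FlashAttention pick tile sizes purely to fit SRAM.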
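HBM traffic analysis. Standard attention for seq_len N, head_dim d, num_heads h, bf16 (2 bytes): one accounting consistent with the 49× figure quoted later in this lecture is that each head moves about 4Nd elements for Q, K, V, and O plus three passes over the N×N score matrix, roughly (4Nd + 3N²) · 2 bytes. An idealized FlashAttention that touches Q, K, V, and O once moves about 4Nd · 2 bytes. At N = 4096, d = 64 the ratio is 1 + 3N/(4d) = 49×. This is an idealization: the real kernel re-reads the K and V tiles once per Q block, so its true traffic is somewhat higher.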
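A simplified per-head accounting reproduces the 49× traffic figure this lecture quotes for FlashAttention. It is an assumption made for this sketch, not a measured result: standard attention is charged for Q, K, V, O once each plus three passes over the N×N matrix, and FlashAttention is charged only for Q, K, V, O, ignoring the K/V re-reads across Q blocks.

```python
def attention_traffic_bytes(n, d, bytes_per_elem=2):
    """Per-head HBM traffic under a simplified accounting (an assumption).

    standard: Q, K, V, O once each (4*n*d) plus three passes over the
              n x n score matrix (write S, read S, read P).
    flash:    Q, K, V, O once each; ignores K/V re-reads across Q blocks.
    """
    standard = (4 * n * d + 3 * n * n) * bytes_per_elem
    flash = 4 * n * d * bytes_per_elem
    return standard, flash

standard, flash = attention_traffic_bytes(n=4096, d=64)
print(f"standard: {standard / 2**20:.0f} MiB per head")   # 98 MiB
print(f"flash:    {flash / 2**20:.0f} MiB per head")      # 2 MiB
print(f"ratio:    {standard / flash:.0f}x")               # 49x
```

Because the ratio is 1 + 3N/(4d), the advantage grows linearly with sequence length at fixed head dimension.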
[Interactive figure: stepping through the K/V tiles shows the running max m and sum s updating as each tile arrives, and how the correction factor exp(m_prev − m_new) keeps the accumulator consistent.]
Writing a Triton kernel is only the beginning. You must verify it is actually faster — not just different. The performance of GPU code is deeply non-intuitive: a kernel that looks correct can be 10× slower than expected due to register pressure, occupancy, or uncoalesced access patterns that only appear at certain tensor shapes.
The benchmarking protocol has two rules: warmup first (the first kernel execution often triggers JIT compilation, making it artificially slow) and synchronize before timing (CUDA is asynchronous — the CPU submits work to the GPU queue without waiting; torch.cuda.synchronize() blocks until the GPU is truly done). Without synchronization, you measure only the CPU queue-submission time, not the GPU execution time.
```python
# Correct GPU benchmarking
import time

import torch

def benchmark(fn, n_warmup=3, n_trials=10):
    # Warmup: run without timing to allow JIT + caching
    for _ in range(n_warmup):
        fn()
    torch.cuda.synchronize()          # wait for warmup to complete
    times = []
    for _ in range(n_trials):
        t0 = time.perf_counter()
        fn()
        torch.cuda.synchronize()      # CRITICAL: wait for GPU
        times.append((time.perf_counter() - t0) * 1000)  # ms
    return sum(times) / len(times)    # mean across trials

# Profile to see which CUDA kernels are actually called
# (assumes x and manual_softmax are already defined):
with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CUDA],
) as prof:
    manual_softmax(x)
    torch.cuda.synchronize()

# Sorted by CUDA time: shows each kernel call and its duration
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
# manual_softmax shows 5 separate kernel entries
# triton_softmax shows 1 kernel: triton_softmax_kernel_...
```
The profiler output names CUDA kernels and shows their duration. For manual_softmax, you see five separate entries (max, subtract, exp, sum, divide). For triton_softmax, you see exactly one. This is the fusion you implemented — confirmed by the profiler, not just assumed.
The speedup from fusion depends on arithmetic intensity. For GELU (8 FLOPs/elem / 4 bytes/elem = 2 FLOPs/byte), the fused version reduces HBM traffic 5×, so you expect about a 5× speedup in the bandwidth-bound regime. The GELU numbers above show a larger gap (0.8 ms vs 0.05 ms), which at that small tensor size likely reflects per-kernel overheads more than bandwidth. For softmax (5 passes → 1 pass), you expect 4–8× speedup, and the benchmark above lands at about 4×. For FlashAttention (49× less traffic), you expect a large speedup, but the actual measured speedup on A100 is 2–4×, because the tiling and reduction logic adds compute overhead that partially offsets the bandwidth savings.
[Interactive figure: theoretical speedup from fusion vs estimated actual speedup (after kernel launch overhead and partial L2 reuse) as the number of chained ops varies, each op at ~2 FLOPs/byte intensity.]
This lecture sits at the junction of hardware understanding (Lecture 5) and distributed training (Lecture 7). The kernel-writing skills you built here are the foundation for understanding how modern LLMs are actually implemented — not as a stack of generic PyTorch ops, but as a set of carefully engineered fused kernels that collectively achieve 40–70% of theoretical hardware throughput.
| Concept | What it is | Why it matters |
|---|---|---|
| Kernel launch overhead | ~5–20 µs per CUDA kernel launch | Dominates small ops; motivates CUDA graphs & fusion |
| Operator fusion | Merge N kernels into 1; intermediates stay in registers | N× HBM traffic reduction for elementwise chains |
| Triton program_id | Instance index in the launch grid | Maps each instance to its tile: block_start = pid × BLOCK_SIZE |
| tl.load / tl.store | Vector HBM read/write with mask | Coalescing automatic; mask handles boundary tiles |
| BLOCK_SIZE (constexpr) | Tile width, compile-time constant | Allows vectorization, loop unrolling, optimized reductions |
| Thread coarsening | Each Triton thread processes multiple elements | Triton compiler does this automatically; 8× instruction efficiency |
| Online softmax | Running max m + running sum s, corrected per tile | Enables tiled softmax without materializing the full row |
| FlashAttention | Tiling + fusion + online softmax + recomputation | 49× HBM traffic reduction; O(N) memory vs O(N²) |
| torch.compile | Auto-generates Triton kernels from PyTorch code | Matches hand-written Triton for standard elementwise chains |
| torch.cuda.synchronize() | Block CPU until GPU finishes | Required for correct GPU benchmarking |
"Key principle: organize computation to minimize reads/writes. Key ideas: kernel fusion (warehouse/factory analogy), tiling (shared memory). Automatic compilers will get better over time — but until then, you need to understand what they are compiling toward."
— Tatsu Hashimoto, CS336 Lecture 6 summary