You trained a 70B-parameter model. Now a million users want answers right now. Each generated token reads the entire model from memory — not once, but every step. This lesson derives why decode is fundamentally memory-bound, sizes the KV cache to the byte, plots inference on the roofline, and shows four families of techniques — KV-cache reduction, speculative decoding, quantization, and continuous batching — that make production inference possible.
You've just trained a 70B-parameter language model. Training took weeks, but now it's done. Someone types a question into your chatbot and hits Enter. How long before the first word appears?
Pretend for a moment that the whole model fits on a single H100 GPU. In reality, 140 GB of bf16 weights exceeds its 80 GB of memory, so production deployments shard the model across several GPUs. For a fresh user message of 100 tokens, the first token of the reply arrives after roughly 40 milliseconds, because even a short prefill must stream every weight from memory once. That feels fast. But then the model keeps generating, and for every subsequent token it must read all 140 GB of parameters back from memory, plus the growing record of everything generated so far. At the H100's memory bandwidth of 3.35 TB/s, reading 140 GB takes about 42 milliseconds per token. That's a theoretical ceiling of roughly 24 tokens/second for a single user.
Compare that to a human reading speed of about 4–5 words/second, and 24 tok/s sounds fine. But now imagine 1,000 simultaneous users, each wanting their own 24 tok/s stream. If each GPU serves only one stream at a time, you need 1,000 H100s. At roughly $2/hr per H100, that's $2,000/hr just to keep the lights on. Inference cost is not a footnote; it is the business model.
Inference shows up in far more places than just chatbots. Every time you run an evaluation benchmark, you run inference over thousands of examples. Reinforcement learning from human feedback (RLHF) and GRPO, the techniques behind GPT-4 and DeepSeek-R1, require generating millions of rollouts from the policy model. Test-time compute (o1, R1) can multiply inference cost per query by 100–1000×. The one-time training cost is increasingly dwarfed by cumulative inference cost over the model's lifetime.
This lecture is about understanding why inference is slow, quantifying the bottleneck precisely, and then surveying the most important techniques for fighting back. We will derive every number from first principles — no hand-waving, no "it's just fast."
Every LLM inference call splits into exactly two phases, and they have almost nothing in common. Understanding this split is the foundation of everything in inference optimization.
Prefill is the first phase. You receive a prompt — say, a system message plus a user question, totalling 500 tokens. The model processes all 500 tokens simultaneously in one giant forward pass. Every token can attend to every earlier token in parallel, exactly like during training. This phase is compute-bound: the GPU's arithmetic units are busy the whole time, squeezing out FLOPs efficiently.
Decode is the second phase. Now the model generates the response, one token per step. At each step, the new token can attend to all previous tokens — but only one new token is being processed. This phase is memory-bound: the GPU reads the full weight matrices and the growing key-value cache from HBM, does a tiny amount of arithmetic, and repeats. The arithmetic units sit mostly idle, starved for data.
The metrics that matter differ between phases too. For prefill, the relevant metric is time-to-first-token (TTFT): how long the user waits before any output appears. For decode, the relevant metrics are per-token latency (equivalently, the tokens per second one user sees) and throughput (tokens generated per second across all users). These trade off against each other, as we'll see in Ch 4.
A third metric is end-to-end latency for a full response, which adds total decode time to TTFT. With the 70B numbers above, each decode step takes about 42 ms, so a 50-token reply needs roughly 2 more seconds after the first token appears, far longer than the TTFT itself.
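To make these latency numbers reproducible, here is a minimal roofline sketch. It assumes the 70B model's bf16 weights fit on one H100 (they do not in practice), counts roughly 2 FLOPs per parameter per token, and ignores KV-cache traffic; `step_time` and `latency_estimate` are illustrative names, not part of any library.

```python
PEAK_FLOPS = 989e12   # H100 SXM dense bf16, FLOP/s
BANDWIDTH = 3.35e12   # H100 HBM3, bytes/s

def step_time(n_params, n_tokens, bytes_per_param=2):
    """Roofline lower bound for one forward pass over n_tokens tokens (seconds)."""
    compute = 2 * n_params * n_tokens / PEAK_FLOPS   # ~2 FLOPs per parameter per token
    memory = n_params * bytes_per_param / BANDWIDTH  # every weight is read once per pass
    return max(compute, memory)

def latency_estimate(n_params, prompt_len, response_len):
    ttft = step_time(n_params, prompt_len)    # prefill produces the first output token
    per_token = step_time(n_params, 1)        # each later token is one decode step
    end_to_end = ttft + (response_len - 1) * per_token
    return ttft, per_token, end_to_end

ttft, per_tok, e2e = latency_estimate(70e9, prompt_len=100, response_len=50)
print(f"TTFT >= {ttft * 1e3:.1f} ms, decode >= {per_tok * 1e3:.1f} ms/token "
      f"({1 / per_tok:.1f} tok/s), 50-token reply >= {e2e:.2f} s")
```

Because 100 prompt tokens is below the ~295 FLOPs/byte breakeven derived later, even the prefill is bound by the weight read here, so the first token costs about as much as one decode step.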
python — naive inference (no KV cache, O(T³) total)

```python
import torch

@torch.no_grad()
def naive_generate(model, prompt_ids, max_new_tokens):
    # Assumes model maps token ids [1, T] -> logits [1, T, vocab].
    # prompt_ids: [T_prompt], the starting tokens
    ctx = prompt_ids.clone()                      # grows each step
    for step in range(max_new_tokens):
        # Feed the ENTIRE context every step: O(T²) attention per step -> O(T³) total
        logits = model(ctx.unsqueeze(0))          # [1, T, vocab]
        next_token = logits[0, -1, :].argmax()    # greedy decode
        ctx = torch.cat([ctx, next_token.unsqueeze(0)])
    return ctx

# With a 4096-token prompt and 256 new tokens:
# Step 1:   forward pass over 4096 tokens
# Step 2:   forward pass over 4097 tokens
# ...
# Step 256: forward pass over 4351 tokens
# Total: ~4224 * 256 ≈ 1.08M tokens processed for 256 output tokens!
```
This naive approach is catastrophically inefficient. We push about 1 million token positions through the model to generate 256 tokens, roughly 4,000 per output token. The key observation that rescues us: in a causal model, everything computed for token i at layer l is identical at step t and step t+1, so we can cache it. That's the KV cache (Chapter 3).
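The gap can be counted directly. This small sketch tallies token positions fed to the model with and without a cache; `tokens_processed` is a hypothetical helper, and it counts positions rather than FLOPs (per-step attention cost makes the naive FLOP gap larger still).

```python
def tokens_processed(prompt_len, new_tokens, kv_cache):
    """Token positions pushed through the model to generate new_tokens tokens."""
    if kv_cache:
        # Prefill the prompt once, then feed one new token per remaining step.
        return prompt_len + (new_tokens - 1)
    # Naive: step i re-feeds the prompt plus the i tokens generated so far.
    return sum(prompt_len + i for i in range(new_tokens))

naive = tokens_processed(4096, 256, kv_cache=False)
cached = tokens_processed(4096, 256, kv_cache=True)
print(naive, cached)   # 1081216 4351
```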
To understand why decode is memory-bound, we need to be precise about a concept called arithmetic intensity: the ratio of floating-point operations (FLOPs) to bytes transferred from HBM (high-bandwidth memory). This single number determines whether a workload is compute-limited or memory-limited.
Consider the core operation in every Transformer layer: a matrix multiply X × W. Say X has shape [B×T, D] and W has shape [D, F]. How many FLOPs does this take? For each of the B×T rows of X and each of the F columns of W, we compute a dot product of length D: that's D multiply-adds, or 2×D FLOPs. Total: 2×B×T×D×F FLOPs.
How many bytes are transferred? We read X (2×B×T×D bytes in bf16), read W (2×D×F bytes), and write the output Y (2×B×T×F bytes). Total: 2BT(D+F) + 2DF bytes.
When B×T ≪ D and F (i.e., when the batch is tiny compared to the weight dimensions), the denominator is dominated by 2DF, the bytes of the weight matrix. The numerator is 2×B×T×D×F. So the intensity simplifies to:

$$\text{intensity} = \frac{2BTDF}{2BT(D+F) + 2DF} \approx \frac{2BTDF}{2DF} = BT$$
Remarkably, arithmetic intensity equals the number of tokens being processed simultaneously. Let's put numbers on it. An H100 SXM5 has a peak dense bf16 compute of 989 TFLOP/s and memory bandwidth of 3.35 TB/s. The ratio, the GPU's "hardware intensity", is:

$$\frac{989 \times 10^{12}\ \text{FLOP/s}}{3.35 \times 10^{12}\ \text{bytes/s}} \approx 295\ \text{FLOPs/byte}$$
If your workload's arithmetic intensity is above 295, you're compute-limited (good — GPU fully utilized). If it's below 295, you're memory-limited (bad — GPU sitting idle waiting for data).
What about attention? The intensity formula there is different because attention's memory depends on context length S rather than weight dimensions. With the KV cache, attending to S tokens while generating T new tokens costs 4×B×S×T×D FLOPs and (4×B×S×D + 4×B×T×D) bytes. Intensity = S×T/(S+T).
For prefill (T = S): intensity = S/2. For a 4096-token prompt, intensity = 2048, well above the 295 breakeven. For decode (T = 1): intensity = S/(S+1) < 1, so decode attention is always memory-bound, no matter how long the context. And worse: unlike MLP layers, where a larger batch size B raises intensity, attention intensity does not depend on B at all. B cancels in the ratio because each sequence reads its own KV cache; there is no shared weight matrix whose reads can be amortized across the batch.
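The attention formula above can be checked numerically. This sketch assumes bf16 and a model width D = 4096 purely for illustration (D cancels out of the ratio, as does B).

```python
def attention_intensity(S, T, B=1, D=4096):
    """FLOPs per byte for attention: T new queries over S cached positions (bf16)."""
    flops = 4 * B * S * T * D                    # QK^T plus the weighted sum over V
    bytes_moved = 4 * B * S * D + 4 * B * T * D  # read K and V; read Q and write output
    return flops / bytes_moved                   # simplifies to S*T / (S + T)

print(attention_intensity(S=4096, T=4096))       # prefill
print(attention_intensity(S=4096, T=1))          # decode
print(attention_intensity(S=4096, T=1, B=64))    # decode with a batch of 64
```

The prefill case prints 2048.0; both decode cases print just under 1.0, confirming that batching does not rescue attention.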
python — deriving arithmetic intensity for a matmul

```python
def arithmetic_intensity(B, T, D, F):
    # Matrix multiply: X [B*T, D] @ W [D, F] -> Y [B*T, F]
    flops = 2 * B * T * D * F
    bytes_in = 2*B*T*D + 2*D*F      # read X and W (bf16 = 2 bytes each)
    bytes_out = 2*B*T*F             # write Y
    total_bytes = bytes_in + bytes_out
    return flops / total_bytes

# H100 breakeven: 989e12 FLOP/s / 3.35e12 bytes/s ≈ 295 FLOPs/byte
H100_INTENSITY = 295

D, F = 4096, 11008                  # Llama-2 7B dimensions
# Decode: B=1, T=1 (single token, single user)
print(arithmetic_intensity(1, 1, D, F))      # ≈ 1.0  -> memory-bound!
print(arithmetic_intensity(1, 4096, D, F))   # ≈ 1727 -> compute-bound (4096-token prefill)
print(arithmetic_intensity(295, 1, D, F))    # ≈ 268  -> just below breakeven
# Rule of thumb: need B*T somewhat above ~295 to become compute-limited on H100
```
The KV cache is the single most important data structure in LLM inference. Without it, generating T tokens would require O(T³) total FLOPs, impractical for any reasonable context length. With it, total work drops to O(T²): each decode step computes Q, K, V and the MLP for only one new token, and its attention over t cached positions costs O(t) per layer instead of O(t²).
Here is the key insight. In multi-head attention, each token position $i$ computes a query $Q_i$, key $K_i$, and value $V_i$ by multiplying the hidden state $h_i$ by weight matrices $W_Q$, $W_K$, $W_V$. When we generate token $t+1$, we need to compute attention scores between the new token and all previous tokens. That requires the K and V vectors for all positions 0 through $t$. But those were already computed during earlier steps! The naive approach recomputes them identically at every step.
The KV cache solves this by storing the K and V tensors for every (layer, head, position) that has been computed. At each decode step, we only compute Q, K, V for the new token (one position), then look up all previous K, V from the cache for the attention computation.
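A toy NumPy check of this idea: one attention head, no layers or batching, random stand-in hidden states. `KVCache` is a hypothetical helper for illustration, not a library class; real engines preallocate cache memory instead of growing arrays with `vstack`.

```python
import numpy as np

rng = np.random.default_rng(0)
D = 16                                   # toy model width (hypothetical)
W_q, W_k, W_v = (rng.standard_normal((D, D)) / np.sqrt(D) for _ in range(3))

def attend(q, K, V):
    """Softmax attention of one query over the rows of K and V."""
    scores = K @ q / np.sqrt(D)
    w = np.exp(scores - scores.max())
    return (w / w.sum()) @ V

class KVCache:
    def __init__(self):
        self.K = np.empty((0, D))
        self.V = np.empty((0, D))

    def append(self, h):
        # Only the NEW position's key and value are computed; old rows are reused.
        self.K = np.vstack([self.K, h @ W_k])
        self.V = np.vstack([self.V, h @ W_v])

hidden = rng.standard_normal((10, D))    # stand-in hidden states for 10 positions
cache = KVCache()
cached_out = []
for t in range(len(hidden)):
    cache.append(hidden[t])
    cached_out.append(attend(hidden[t] @ W_q, cache.K, cache.V))

# Reference: recompute K and V for the whole prefix at every step (the naive way).
naive_out = [attend(hidden[t] @ W_q, hidden[:t + 1] @ W_k, hidden[:t + 1] @ W_v)
             for t in range(len(hidden))]
print(np.allclose(cached_out, naive_out))   # True
```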
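Now let's size the KV cache precisely. For a Transformer with:

- $L$ layers
- $N_{KV}$ key/value heads per layer
- $H$ dimensions per head
- $S$ cached positions per sequence
- $B$ sequences in the batch
- 2 bytes per value (bf16)

each key or value vector for one token in one layer has $N_{KV} \times H$ entries, which equals $d_{model}$ for MHA. Counting both keys and values across $L$ layers, $B$ sequences, and $S$ positions:

$$\text{KV bytes} = 2 \times L \times N_{KV} \times H \times S \times B \times 2$$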
Let's compute this for Llama-2 70B with a 4096-token context (batch size 1, bf16):
KV size = 2 × 80 × 8 × 128 × 4096 × 1 × 2 bytes = 1,342,177,280 bytes ≈ 1.34 GB.
That's for one request! Llama-2 70B's weights are about 140 GB (70B params × 2 bytes in bf16), so for a single 4096-token request the KV cache adds about 1% to memory. But at batch size 64 and context length 32,768 (a demanding production setting), the KV cache becomes 1.34 GB × 64 × (32768/4096) ≈ 687 GB, nearly five times the size of the model weights.
python — KV cache size formula with Llama-2 70B example

```python
def kv_cache_bytes(L, n_kv_heads, head_dim, seq_len, batch, dtype_bytes=2):
    # 2 for key + value; dtype_bytes=2 for bf16
    return 2 * L * n_kv_heads * head_dim * seq_len * batch * dtype_bytes

# Llama-2 70B configuration
L = 80        # transformer layers
n_kv = 8      # GQA: only 8 KV heads (vs 64 query heads)
H = 128       # head dimension

# Single request, 4096 tokens
kv_single = kv_cache_bytes(L, n_kv, H, seq_len=4096, batch=1)
print(f"Single request 4k: {kv_single/1e9:.2f} GB")   # 1.34 GB

# Production: B=64 requests, 32768 tokens each
kv_prod = kv_cache_bytes(L, n_kv, H, seq_len=32768, batch=64)
print(f"64 requests @ 32k: {kv_prod/1e9:.1f} GB")    # 687.2 GB

model_weights_gb = 70e9 * 2 / 1e9   # 70B params * 2 bytes = 140 GB
print(f"Model weights: {model_weights_gb:.0f} GB")
print(f"KV/weights ratio: {kv_prod/1e9/model_weights_gb:.1f}x")  # 4.9x!
```
We've established that decode is memory-bound. The question now is: how do we improve throughput without destroying latency? The answer is batching — processing multiple requests simultaneously — and its more sophisticated cousin, continuous batching.
When serving B requests in parallel, arithmetic intensity for MLP layers becomes B (as derived in Ch 2). For B = 295 on an H100, you'd hit the compute-limited regime. In practice, B = 64–128 is common — intensity 64–128, memory-limited but far better than B = 1.
Let's compute theoretical latency and throughput for Llama-2 13B on an H100. The model has ~13B parameters. In bf16 (2 bytes each), that's 26 GB. The KV cache for S = 1024 tokens, batch B, is 2 × 40 × 40 × 128 × 1024 × B × 2 = 838 MB × B.
Total memory read per decode step = model weights + KV cache = 26,000 MB + 838 MB × B. At H100 bandwidth 3350 GB/s, latency per token (seconds) = total_bytes / bandwidth.
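The table below can be regenerated from that formula. This sketch plugs in the figures just stated (26 GB of weights, 838 MB of KV cache per sequence, 3350 GB/s) and adds a check against an assumed 80 GB of H100 memory; `decode_step` is a hypothetical helper.

```python
def decode_step(B, weights_gb=26.0, kv_mb_per_seq=838.0, bw_gb_s=3350.0, gpu_gb=80.0):
    """Bandwidth-bound decode step for Llama-2 13B at S = 1024 (figures from the text)."""
    kv_gb = kv_mb_per_seq * B / 1000
    total_gb = weights_gb + kv_gb               # bytes read from HBM per step
    latency_ms = total_gb / bw_gb_s * 1000      # one token per sequence per step
    throughput = B / (latency_ms / 1000)        # tokens/s summed over the batch
    return kv_gb, total_gb, latency_ms, throughput, total_gb <= gpu_gb

for B in (1, 64, 256):
    kv, total, lat, tput, fits = decode_step(B)
    print(f"B={B:>3}: KV {kv:6.1f} GB, total {total:6.1f} GB, "
          f"{lat:5.1f} ms/tok, {tput:6.0f} tok/s, fits in 80 GB: {fits}")
```

The results match the table to within rounding, and the last column flags that only the first two rows actually fit on one GPU.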
| Batch size B | KV cache (GB) | Total memory (GB) | Latency (ms/tok) | Throughput (tok/s) |
|---|---|---|---|---|
| 1 | 0.84 | 26.84 | 8.0 | 125 |
| 64 | 53.6 | 79.6 | 23.8 | 2,690 |
| 256 | 214 | 240 | 71.6 | 3,575 |
Two observations stand out. First, latency degrades as batch size grows: every decode step must also read the other requests' KV caches, so each token takes longer to appear. Second, throughput grows with batch size but with diminishing returns: going from B=1 to B=64 gives a 21× throughput improvement, but going to B=256 only adds another 1.3×. The weight reads are amortized across the batch, but KV-cache reads grow linearly with B and come to dominate the bytes moved per step. (The B=256 row is a bandwidth bound only: 240 GB does not fit in an 80 GB H100.)
With static batching, a batch runs until its longest request finishes, so slots freed by short requests sit idle and newly arriving requests wait for the whole batch to drain. Continuous batching (also called iteration-level scheduling, introduced in the ORCA paper, 2022) fixes this by operating at the granularity of a single decode step rather than a full request. After every decode step, the scheduler checks whether any request finished; if so, it evicts that request from the batch and admits a waiting request in the next step.
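A minimal sketch of the difference, counting decode steps needed to drain a queue. It assumes every request's output length is known, ignores prefill cost and KV-memory limits, and uses hypothetical lengths; it is not ORCA's actual scheduler.

```python
from collections import deque

def static_batching_steps(lengths, batch_size):
    """Decode steps when each batch runs until its longest request finishes."""
    return sum(max(lengths[i:i + batch_size])
               for i in range(0, len(lengths), batch_size))

def continuous_batching_steps(lengths, batch_size):
    """Decode steps when free slots are refilled after every step."""
    queue = deque(lengths)   # output tokens still owed to each waiting request
    running = []             # output tokens still owed to each in-flight request
    steps = 0
    while queue or running:
        while queue and len(running) < batch_size:
            running.append(queue.popleft())           # admit into a free slot
        running = [r - 1 for r in running if r > 1]  # one step; drop finished requests
        steps += 1
    return steps

lengths = [10, 200, 15, 30, 180, 5, 25, 190]   # hypothetical output lengths
print(static_batching_steps(lengths, 4), continuous_batching_steps(lengths, 4))   # 390 220
```

Short requests no longer hold their slots hostage to the longest request in their batch, so the same work finishes in far fewer steps.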
Continuous batching requires one technical trick. During attention, each sequence has its own KV cache, and sequences can have different lengths. The non-attention operations (MLP layers, LayerNorm) act on each token's representation independently, so for those layers we can concatenate all sequences into a flat [total_tokens, D] tensor and process them together, even though they have different lengths. For attention, we process each sequence's KV cache separately. ORCA calls this selective batching.
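A toy NumPy sketch of selective batching for one layer, assuming each request contributes a different number of tokens. The attention here is a deliberately simplified single head with Q = K = V = x, and all names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
D = 8                                          # toy hidden size (hypothetical)
W_mlp = rng.standard_normal((D, D)) / np.sqrt(D)

def mlp(x):
    # Token-wise: every row is processed independently.
    return np.maximum(x @ W_mlp, 0.0)

def causal_attention(x):
    # Toy single-head self-attention within ONE sequence (Q = K = V = x).
    scores = x @ x.T / np.sqrt(D)
    scores = scores + np.triu(np.full(scores.shape, -np.inf), k=1)  # hide future positions
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (w / w.sum(axis=-1, keepdims=True)) @ x

def selective_batch_layer(seqs):
    """One layer over ragged sequences, with no padding."""
    splits = np.cumsum([len(s) for s in seqs])[:-1]
    # Attention: run per sequence, since each has its own length and history.
    attn = np.concatenate([causal_attention(s) for s in seqs], axis=0)  # [total_tokens, D]
    # MLP: a single batched call over the flat token tensor.
    return np.split(mlp(attn), splits, axis=0)

seqs = [rng.standard_normal((n, D)) for n in (3, 7, 1)]   # three requests, different lengths
outs = selective_batch_layer(seqs)
print([o.shape for o in outs])
```

The flat MLP call produces exactly what per-sequence processing would, which is why the trick is safe.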
So far we've accepted that decode generates exactly one token per forward pass through the big model. What if we could generate multiple tokens per pass — and guarantee the output is identical to what the big model would have produced alone? This is speculative decoding, and it is one of the most elegant ideas in inference optimization.
The core insight is an asymmetry: verifying a token sequence is much faster than generating it. During prefill, the model processes many tokens simultaneously (compute-bound). If we could turn some decode work into prefill-like work, we'd be much faster. Speculative decoding does exactly this.
The algorithm has three steps. First, a cheap draft model (say 8B parameters) autoregressively generates k token candidates; call them $\tilde{x}_1, \ldots, \tilde{x}_k$. Second, the expensive target model (say 70B) processes all k draft tokens in parallel (one forward pass), producing probability distributions $q(\tilde{x}_i \mid \text{context})$ for each position simultaneously. Third, a modified rejection-sampling procedure accepts or rejects each draft token, with a guarantee that the output distribution exactly matches what the target model would have generated token-by-token.
Let's work through the accept/reject math with a concrete two-token example. Suppose the vocabulary is just {A, B}. The target model assigns probabilities q(A) = 0.3, q(B) = 0.7. The draft model assigns p(A) = 0.5, p(B) = 0.5 (it oversamples A).
Draft samples token A. Accept probability = min(1, q(A)/p(A)) = min(1, 0.3/0.5) = 0.6. So A is accepted 60% of the time. Draft samples token B. Accept probability = min(1, q(B)/p(B)) = min(1, 0.7/0.5) = 1.0. So B is always accepted.
What is the overall probability of outputting A? = P(draft A) × P(accept A) = 0.5 × 0.6 = 0.3 = q(A). ✓ What is the probability of outputting B? = P(draft B) × P(accept B) + P(draft A) × P(reject A) × P(resample B) = 0.5 × 1.0 + 0.5 × 0.4 × 1.0 = 0.5 + 0.2 = 0.7 = q(B). ✓ The marginal distribution is exactly q.
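The same bookkeeping works for any vocabulary. This sketch computes, with exact fractions, the distribution of the token emitted by one accept/reject step; `speculative_output_dist` is a hypothetical helper, and it assumes p and q are dicts over the same tokens.

```python
from fractions import Fraction

def speculative_output_dist(p, q):
    """Exact distribution of the token emitted by one draft-then-verify step.

    p: draft distribution, q: target distribution, as dicts over the same tokens.
    """
    # Mass from accepting draft token x: p(x) * min(1, q(x)/p(x)) = min(p(x), q(x))
    accept = {x: min(p[x], q[x]) for x in q}
    reject_mass = 1 - sum(accept.values())
    # On rejection, resample from the normalized residual max(q - p, 0)
    residual = {x: max(q[x] - p[x], 0) for x in q}
    z = sum(residual.values())
    return {x: accept[x] + (reject_mass * residual[x] / z if z else 0) for x in q}

p = {"A": Fraction(1, 2), "B": Fraction(1, 2)}
q = {"A": Fraction(3, 10), "B": Fraction(7, 10)}
print(speculative_output_dist(p, q))   # {'A': Fraction(3, 10), 'B': Fraction(7, 10)}
```

The rejected mass always equals the residual's total, so each token ends up with min(p, q) + max(q − p, 0) = q.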
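If the draft model is very good (per-token acceptance rate α ≈ 0.9) and we speculate k = 4 tokens, the naive count k×α + 1 = 4.6 overstates the yield, because the first rejection discards every later draft token. If each draft token is accepted independently with probability α, the expected number of tokens per target-model pass is $(1-\alpha^{k+1})/(1-\alpha)$, about 4.1 here (and exactly k+1 when α = 1). The speedup over plain decode is approximately that count divided by (1 + k × cost_ratio), where cost_ratio is the draft model's cost relative to the target. For an 8B draft vs 70B target (cost ratio ≈ 0.11), with α = 0.8 and k = 4: expected tokens ≈ 3.36, so speedup ≈ 3.36/1.44 ≈ 2.3×.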
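A small calculator for this estimate. It assumes i.i.d. per-token acceptance with probability α, and that one target pass costs the same whether it scores 1 or k+1 tokens (roughly true while decode is memory-bound); the draft/target cost ratio is approximated by parameter count.

```python
def expected_tokens(alpha, k):
    """Expected tokens per target pass with i.i.d. per-token acceptance alpha."""
    if alpha == 1:
        return k + 1
    return (1 - alpha ** (k + 1)) / (1 - alpha)

def speculative_speedup(alpha, k, cost_ratio):
    # One target pass plus k draft passes, each costing cost_ratio of a target pass.
    return expected_tokens(alpha, k) / (1 + k * cost_ratio)

r = 8 / 70   # draft/target cost ratio, approximated by parameter count
for k in (2, 4, 8):
    print(k, round(expected_tokens(0.8, k), 2), round(speculative_speedup(0.8, k, r), 2))
```

In this setting k = 4 beats both k = 2 and k = 8: expected tokens saturate as k grows while drafting cost keeps rising linearly.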
python — speculative decoding accept/reject sketch

```python
import torch

@torch.no_grad()
def speculative_step(draft_model, target_model, context, k=4):
    # Assumes both models map token ids [1, T] -> logits [1, T, vocab]
    # and that context is a LongTensor of shape [1, T].
    T = context.shape[1]

    # Step 1: Draft model generates k candidate tokens autoregressively
    draft_tokens, draft_dists = [], []
    ctx = context.clone()
    for _ in range(k):
        p = draft_model(ctx)[0, -1].softmax(dim=-1)   # p(x | ctx), shape [vocab]
        tok = torch.multinomial(p, num_samples=1)     # shape [1]
        draft_tokens.append(tok)
        draft_dists.append(p)                         # keep full p for the residual
        ctx = torch.cat([ctx, tok.unsqueeze(0)], dim=1)

    # Step 2: Target model verifies ALL k tokens in ONE parallel pass
    target_probs = target_model(ctx)[0].softmax(dim=-1)  # [T + k, vocab]

    # Step 3: Accept/reject each draft token, left to right
    accepted = []
    for i, (tok, p) in enumerate(zip(draft_tokens, draft_dists)):
        q = target_probs[T + i - 1]                   # target distribution at draft position i
        accept_prob = (q[tok] / p[tok]).clamp(max=1.0)
        if torch.rand(1) < accept_prob:
            accepted.append(tok)
        else:
            # Resample from the residual distribution max(q - p, 0), then stop
            residual = (q - p).clamp(min=0)
            accepted.append(torch.multinomial(residual / residual.sum(), 1))
            return torch.cat(accepted)
    # All k accepted: the target's last position yields one bonus token for free
    accepted.append(torch.multinomial(target_probs[-1], 1))
    return torch.cat(accepted)  # between 1 and k+1 tokens
```
The KV cache is the bottleneck at long contexts (as we saw in Ch 3, it can exceed model weights at production batch sizes). Three architectural families attack this directly: grouped-query attention, multi-head latent attention, and local/hybrid attention.
Grouped-Query Attention (GQA) is the simplest idea. Standard multi-head attention (MHA) has N query heads, N key heads, and N value heads; only the keys and values are cached, so the KV cache holds 2N head vectors per token per layer. Multi-query attention (MQA) shrinks this to a single KV head shared by all query heads. GQA is the middle ground: K KV heads, each shared by a group of N/K query heads. Llama-2 70B uses K=8, N=64, an 8× KV cache reduction with minimal accuracy loss.
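A minimal NumPy sketch of one GQA decode step, assuming Llama-2-70B-like head counts and random tensors. It expands the shared KV heads with `np.repeat` at compute time for clarity; an optimized kernel would index the shared heads instead of copying them.

```python
import numpy as np

def gqa_decode_attention(q, k_cache, v_cache):
    """One decode step of grouped-query attention.

    q:        [n_q_heads, head_dim]            query heads for the new token
    k_cache:  [n_kv_heads, seq_len, head_dim]  only n_kv_heads are stored
    v_cache:  [n_kv_heads, seq_len, head_dim]
    """
    n_q, head_dim = q.shape
    group = n_q // k_cache.shape[0]               # query heads per shared KV head
    k = np.repeat(k_cache, group, axis=0)         # [n_q, seq_len, head_dim]
    v = np.repeat(v_cache, group, axis=0)
    scores = np.einsum("hd,hsd->hs", q, k) / np.sqrt(head_dim)
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    return np.einsum("hs,hsd->hd", w, v)          # [n_q, head_dim]

rng = np.random.default_rng(0)
n_q, n_kv, seq_len, head_dim = 64, 8, 16, 128     # Llama-2-70B-like head counts
q = rng.standard_normal((n_q, head_dim))
k_cache = rng.standard_normal((n_kv, seq_len, head_dim))
v_cache = rng.standard_normal((n_kv, seq_len, head_dim))
out = gqa_decode_attention(q, k_cache, v_cache)
print(out.shape, "KV cache is", n_q // n_kv, "x smaller than MHA")
```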
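Multi-head Latent Attention (MLA), introduced in DeepSeek-V2, takes a more aggressive approach. Instead of caching separate N×H-dimensional K and V vectors for each token, MLA projects the hidden state down to a single compressed latent vector of dimension C, caches only that, and expands it back into keys and values when attention is computed. DeepSeek-V2 caches C = 512 + 64 = 576 values per token per layer. Against an MHA layer with 64 heads of dimension 128, which caches 2 × 64 × 128 = 16,384 values per token (keys plus values), that is about a 28× reduction; DeepSeek-V2 itself uses 128 heads, so its saving relative to an MHA version of the same model is larger still.
The extra 64 dimensions carry a separate key for RoPE positional embeddings, because RoPE's position-dependent rotation cannot be absorbed into the low-rank projection. The result: MLA is not only cheaper than MHA in KV memory; DeepSeek-V2's ablations report it as slightly more accurate, possibly because the low-rank bottleneck acts as a regularizer on the key-value representations.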
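Why caching only the latent is enough can be shown with the key side of attention. This toy sketch (hypothetical sizes, random weights, and no RoPE component, which is exactly the part that cannot be folded in) checks that scoring against expanded keys equals scoring a "latent query" directly against the cached latent. Values are expanded from the same latent by a second up-projection, omitted here.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, n_heads, head_dim, d_latent = 256, 4, 32, 16   # toy sizes (hypothetical)

W_dkv = rng.standard_normal((d_model, d_latent)) / np.sqrt(d_model)            # down-projection
W_uk = rng.standard_normal((n_heads, d_latent, head_dim)) / np.sqrt(d_latent)  # latent -> keys

h = rng.standard_normal((10, d_model))        # hidden states of 10 earlier tokens
c_kv = h @ W_dkv                              # [10, d_latent]: the only per-token cache entry
q = rng.standard_normal((n_heads, head_dim))  # query heads for the new token

# Explicit path: expand the cached latent into per-head keys, then score.
k = np.einsum("sc,hcd->hsd", c_kv, W_uk)      # [n_heads, 10, head_dim]
scores_explicit = np.einsum("hd,hsd->hs", q, k)

# Absorbed path: fold W_uk into the query and score directly against the latent.
q_latent = np.einsum("hd,hcd->hc", q, W_uk)   # [n_heads, d_latent]
scores_absorbed = np.einsum("hc,sc->hs", q_latent, c_kv)

print(np.allclose(scores_explicit, scores_absorbed))   # True
print("cached per token:", d_latent, "vs MHA:", 2 * n_heads * head_dim)
```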
Local (sliding window) attention takes a completely different approach. Instead of each token attending to all S previous tokens, it only attends to the last W tokens (a local window). The KV cache becomes O(W) per sequence instead of O(S) — independent of total sequence length. Multiple layers of local attention with window W effectively cover a context of W×L tokens through transitive information flow across layers.
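A minimal sketch of the rolling KV buffer behind local attention, using Python's `deque(maxlen=...)` so the oldest entry is evicted automatically. Placeholder strings stand in for key and value vectors, and `SlidingWindowKVCache` and `receptive_field` are hypothetical names.

```python
from collections import deque

class SlidingWindowKVCache:
    """Keeps K/V only for the last `window` positions (a rolling buffer)."""

    def __init__(self, window):
        self.keys = deque(maxlen=window)     # oldest entry is dropped automatically
        self.values = deque(maxlen=window)

    def append(self, k, v):
        self.keys.append(k)
        self.values.append(v)

    def visible(self):
        # The new token attends only to these (at most `window`) cached positions.
        return list(self.keys), list(self.values)

def receptive_field(window, n_layers):
    # Each layer reaches up to `window` positions back, so stacked layers
    # can propagate information roughly window * n_layers tokens.
    return window * n_layers

cache = SlidingWindowKVCache(window=4)
for pos in range(10):
    cache.append(f"k{pos}", f"v{pos}")
print(cache.visible()[0])          # ['k6', 'k7', 'k8', 'k9']
print(receptive_field(4096, 32))   # 131072
```

Memory stays fixed at W entries per layer no matter how long generation runs.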
The catch: local attention can hurt accuracy on tasks requiring long-range recall (finding information thousands of tokens back). A common fix, used in Character.AI's production system and other recent models, is to interleave one global attention layer every 6–8 local layers: the global layers handle long-range retrieval, and the local layers handle neighboring context at low cost. Mistral 7B, by contrast, uses sliding-window attention (W = 4096) in every layer and relies on the stacked receptive field to reach distant tokens.
| Method | KV values cached per token per layer | KV reduction vs MHA | Accuracy impact |
|---|---|---|---|
| MHA (baseline, N=64, H=128) | 2×N×H = 16,384 | 1× | Baseline |
| MQA | 2×H = 256 | 64× | Noticeable degradation |
| GQA (K=8) | 2×K×H = 2,048 | 8× | Negligible |
| MLA (DeepSeek-V2) | C = 576, one latent shared by K and V | ~28× | Slightly better in DeepSeek-V2's ablations |
| Local (W=4096) | 2×NKV×H, kept only for the last W tokens | S/W× once S > W | Depends on task |
Cross-layer attention (CLA) extends GQA across layers rather than within a layer: multiple consecutive transformer layers share the same KV cache rather than computing fresh K, V for each layer. If 2 consecutive layers share a KV cache, storage halves for those layers. Character.AI's production system combines CLA with GQA to achieve very aggressive KV compression with minimal accuracy loss.
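The table's reductions, plus the effect of CLA, can be reproduced with one formula. This sketch uses a hypothetical 80-layer stack with Llama-2-70B-like head sizes as the common baseline; applying MLA's 576-value latent or CLA sharing to that stack is an illustrative comparison, not a real model.

```python
def kv_values_per_token(n_layers, n_kv_heads, head_dim, layers_per_kv=1, latent_dim=None):
    """Values cached per token across all layers.

    layers_per_kv: consecutive layers sharing one KV cache (CLA); 1 = no sharing.
    latent_dim: if set, cache one shared latent per layer (MLA-style) instead of K and V.
    """
    per_layer = latent_dim if latent_dim is not None else 2 * n_kv_heads * head_dim
    cached_layers = -(-n_layers // layers_per_kv)   # ceiling division
    return per_layer * cached_layers

L, H = 80, 128   # hypothetical Llama-2-70B-like stack
mha = kv_values_per_token(L, 64, H)
variants = {
    "MHA": mha,
    "MQA": kv_values_per_token(L, 1, H),
    "GQA (8 KV heads)": kv_values_per_token(L, 8, H),
    "GQA + CLA (2 layers share)": kv_values_per_token(L, 8, H, layers_per_kv=2),
    "MLA (576-dim latent)": kv_values_per_token(L, None, None, latent_dim=576),
}
for name, n in variants.items():
    print(f"{name:28s} {n:>9,d} values/token  {mha / n:5.1f}x smaller")
```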
Time to put everything together. This showcase simulation lets you configure a complete LLM inference deployment — model size, hardware, KV compression, batching strategy, and speculative decoding — and see how each choice affects TTFT, decode latency, throughput, and memory budget. This is the tool Percy uses mentally when designing a serving stack.
The simulation models the theoretical roofline limits. Well-optimized serving engines like vLLM, TensorRT-LLM, and TGI can reach roughly 80% of these numbers in practice.
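A minimal, hypothetical Python version of the same roofline arithmetic might look like this. It models a single GPU and counts only weights plus KV cache, ignoring activations, prefill, and speculative decoding; `Deployment` and its fields are illustrative names, and the H100 defaults are the figures used throughout this lesson.

```python
from dataclasses import dataclass, replace

@dataclass
class Deployment:
    n_params: float             # model parameters
    bytes_per_param: float      # 2 = bf16, 1 = fp8/int8, 0.5 = int4
    kv_bytes_per_token: float   # KV-cache formula summed over all layers
    context: int                # tokens cached per sequence
    batch: int                  # concurrent sequences
    gpu_mem_gb: float = 80.0    # H100 80 GB
    bandwidth: float = 3.35e12  # bytes/s
    peak_flops: float = 989e12  # FLOP/s

    def memory_gb(self):
        weights = self.n_params * self.bytes_per_param
        kv = self.kv_bytes_per_token * self.context * self.batch
        return weights / 1e9, kv / 1e9

    def decode_ms_per_token(self):
        w, kv = self.memory_gb()
        mem_time = (w + kv) * 1e9 / self.bandwidth          # read weights + all KV per step
        compute_time = 2 * self.n_params * self.batch / self.peak_flops
        return max(mem_time, compute_time) * 1e3

    def throughput(self):
        return self.batch / (self.decode_ms_per_token() / 1e3)

    def fits(self):
        return sum(self.memory_gb()) <= self.gpu_mem_gb

# Llama-2 13B: 2 * 40 layers * 40 heads * 128 dims * 2 bytes = 819,200 KV bytes per token
llama13b = Deployment(n_params=13e9, bytes_per_param=2, kv_bytes_per_token=819_200,
                      context=1024, batch=64)
int4 = replace(llama13b, bytes_per_param=0.5)   # hypothetical int4 weights, same KV cache
for name, d in [("bf16", llama13b), ("int4", int4)]:
    w, kv = d.memory_gb()
    print(f"{name}: weights {w:.1f} GB + KV {kv:.1f} GB, fits={d.fits()}, "
          f"{d.decode_ms_per_token():.1f} ms/tok, {d.throughput():.0f} tok/s")
```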
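Everything discussed so far either preserves the model's outputs exactly (batching, speculative decoding) or is an architectural choice made before training (GQA, MLA, local attention). Quantization and pruning instead accept a controlled quality tradeoff in exchange for large efficiency gains. When done carefully, the degradation is small on most benchmarks, though not zero, as the table below shows.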
Quantization reduces the number of bits used to represent each weight or activation. Since inference is memory-bandwidth-limited, fewer bytes per weight means proportionally faster inference — not through better arithmetic, but through less data movement.
The number formats commonly used in production are bf16 (2 bytes per value), fp8 and int8 (1 byte), and int4 (half a byte).
The key challenge in quantization is outliers. In large language models, a small fraction of activations (and weights) have very large magnitudes. Standard quantization maps the full range to -128..127 — when a few values are 100× larger than the rest, the quantization grid becomes too coarse for the typical values.
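The outlier effect is easy to reproduce. This sketch uses symmetric absmax int8 quantization with one scale per tensor (clipping to −127..127 for symmetry) on random data, then plants a single value 100× larger than the rest; the helper names are hypothetical.

```python
import numpy as np

def absmax_quantize(x):
    """Symmetric int8 quantization with a single scale for the whole tensor."""
    scale = np.abs(x).max() / 127.0
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale

def roundtrip_error(x):
    q, scale = absmax_quantize(x)
    return float(np.abs(q.astype(np.float32) * scale - x).mean())

rng = np.random.default_rng(0)
x = rng.standard_normal(4096).astype(np.float32)
x_outlier = x.copy()
x_outlier[0] = 100 * np.abs(x).max()   # a single value 100x the typical maximum

print(f"no outlier:   mean error {roundtrip_error(x):.4f}")
print(f"with outlier: mean error {roundtrip_error(x_outlier):.4f}")
```

With the outlier, the grid step becomes so coarse that most ordinary values round to 0 or ±1 step.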
LLM.int8() (Dettmers et al. 2022) addresses the outlier problem with mixed precision: identify the outlier feature dimensions (typically 0.1–1% of columns), keep those in fp16, and quantize everything else to int8. Handling two precision levels costs about 15–23% more time than pure fp16, but weight memory halves, which is worth it when the goal is fitting a large model into less GPU memory.
Activation-aware weight quantization (AWQ) takes a different approach. It starts from the observation that about 1% of weight channels matter most, as identified by the magnitude of the activations they multiply, and that keeping those in high precision preserves accuracy. To avoid mixed-precision storage, AWQ instead scales those salient channels up before quantizing (folding the inverse scale into the activations), so every weight can be stored in int4. This gives 4× memory reduction and 3.2× speedup with minimal accuracy loss.
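A NumPy sketch of the decomposition idea. It assumes vector-wise absmax scales (per row of activations, per column of weights) and a magnitude threshold of 6.0 for flagging outlier dimensions; float32 stands in for the fp16 path, and this is an illustration, not the bitsandbytes implementation.

```python
import numpy as np

def quantize_vectorwise(a, axis):
    """Absmax int8 quantization with one scale per slice along `axis`."""
    scale = np.abs(a).max(axis=axis, keepdims=True) / 127.0
    scale = np.where(scale == 0, 1.0, scale)          # guard all-zero slices
    return np.round(a / scale).astype(np.int8), scale

def int8_matmul(x, w):
    qx, sx = quantize_vectorwise(x, axis=1)           # per-row scales for activations
    qw, sw = quantize_vectorwise(w, axis=0)           # per-column scales for weights
    acc = qx.astype(np.int32) @ qw.astype(np.int32)   # integer accumulate
    return acc * sx * sw                              # dequantize

def mixed_int8_matmul(x, w, threshold=6.0):
    """Outlier feature dims use a high-precision matmul; the rest go through int8."""
    outlier = (np.abs(x) > threshold).any(axis=0)     # hidden dims with any large activation
    regular = int8_matmul(x[:, ~outlier], w[~outlier, :])
    exact = x[:, outlier] @ w[outlier, :]
    return regular + exact

rng = np.random.default_rng(0)
x = rng.standard_normal((16, 512)).astype(np.float32)
x[:, [3, 100]] *= 60.0                                # two outlier feature dimensions
w = (rng.standard_normal((512, 256)) * 0.05).astype(np.float32)
ref = x @ w
err_int8 = np.abs(int8_matmul(x, w) - ref).mean()
err_mixed = np.abs(mixed_int8_matmul(x, w) - ref).mean()
print(f"pure int8 error {err_int8:.4f}, mixed error {err_mixed:.4f}")
```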
Model pruning removes parts of the model entirely — entire layers, attention heads, or hidden dimensions — and then repairs the damage through distillation. The NVIDIA paper (Muralidharan et al. 2024) demonstrates a three-step recipe: (1) identify unimportant structure using 1024 calibration examples; (2) prune to a smaller architecture; (3) distill from the original model into the pruned one. This produces models like Llama-3.1-Minitron-4B from Llama-3.1-8B — half the size, 95% of the accuracy.
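Step (2) of that recipe can be illustrated on a single MLP. This is a simplified sketch, not the paper's exact procedure: it assumes importance is the mean absolute activation of each hidden neuron over 1024 random calibration inputs, keeps the top half, and omits the distillation step that would repair the pruned layer.

```python
import numpy as np

def prune_mlp_width(W_in, W_out, calib_x, keep):
    """Drop the least important hidden neurons of a 2-layer ReLU MLP.

    W_in: [d_model, d_hidden], W_out: [d_hidden, d_model], calib_x: [n, d_model]
    Importance = mean |activation| per hidden neuron (assumed scoring rule).
    """
    acts = np.maximum(calib_x @ W_in, 0.0)            # [n, d_hidden]
    importance = np.abs(acts).mean(axis=0)
    kept = np.sort(np.argsort(importance)[-keep:])    # indices of the top-`keep` neurons
    return W_in[:, kept], W_out[kept, :], kept

rng = np.random.default_rng(0)
d_model, d_hidden = 64, 256
W_in = rng.standard_normal((d_model, d_hidden)) / np.sqrt(d_model)
W_out = rng.standard_normal((d_hidden, d_model)) / np.sqrt(d_hidden)
calib = rng.standard_normal((1024, d_model))         # 1024 calibration examples

W_in_p, W_out_p, kept = prune_mlp_width(W_in, W_out, calib, keep=128)
print(W_in_p.shape, W_out_p.shape)   # (64, 128) (128, 64)
```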
| Method | Memory factor | Speedup vs bf16 | Accuracy | Training needed? |
|---|---|---|---|---|
| bf16 baseline | 1× | 1× | Exact | No |
| fp8 | 0.5× | ~1.8× | Nearly lossless | Optional |
| int8 (LLM.int8) | 0.5× | ~0.8–0.9× (decomposition overhead) | ~1% drop | No |
| int4 (AWQ) | 0.25× | ~3× | ~2–5% drop | Calibration only |
| Pruning + distill | 0.5× model size | 2× (smaller model) | <5% drop | Yes |
Inference optimization is a vast field with techniques from multiple disciplines: systems, statistics, hardware architecture, and ML. The table below collects the key formulas and numbers from this lesson.
| Concept | Formula / Key Number | What it means |
|---|---|---|
| Arithmetic intensity | FLOPs / bytes ≈ B×T (MLP) | Work per byte read; must exceed ~295 on H100 to be compute-limited |
| KV cache size | 2×L×NKV×H×S×B×2 bytes | Grows linearly with all dimensions; dominates at long context |
| Decode latency (theoretical) | (weights + KV) / bandwidth | Llama-2 13B @ B=1: ~8ms/token |
| Speculative speedup | (1−α^(k+1)) / ((1−α)(1+k×r)) | r = draft/target cost ratio; α = per-token acceptance rate |
| GQA KV reduction | NQ/NKV | Llama-2 70B: 64/8 = 8× |
| H100 breakeven intensity | 989 TF/s ÷ 3.35 TB/s ≈ 295 | Need B×T > 295 for compute-limited MLP |