Language Modeling from Scratch · CS336 · Lecture 10

Inference: Speed, Memory & Tricks

You trained a 70B-parameter model. Now a million users want answers right now. Each generated token reads the entire model from memory — not once, but every step. This lesson derives why decode is fundamentally memory-bound, sizes the KV cache to the byte, plots inference on the roofline, and shows four families of techniques — KV-cache reduction, speculative decoding, quantization, and continuous batching — that make production inference possible.

Prerequisites: CS336 Lec 2 (FLOPs, arithmetic intensity), CS336 Lec 3 (Transformer shapes), basic algebra.
10 Chapters · 5 Live Canvases · Derived Memory-Bound Proof

Chapter 0: The Slow Token

You've just trained a 70B-parameter language model. Training took weeks, but now it's done. Someone types a question into your chatbot and hits Enter. How long before the first word appears?

On a single H100 GPU, with a fresh user message of 100 tokens, the first token of the reply arrives within a few tens of milliseconds. That feels fast. But then the model keeps generating, and for every subsequent token it must read all 140 GB of parameters (70B parameters × 2 bytes in bf16) back from memory, plus the growing record of everything generated so far. At the H100's memory bandwidth of 3.35 TB/s, reading 140 GB takes about 42 milliseconds per token. That caps a single user at roughly 24 tokens/second. (This treats the GPU as an idealized device: 140 GB does not fit in one H100's 80 GB, so a real deployment shards the model across several GPUs.)
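The 42 ms figure is just weight bytes divided by memory bandwidth. Here is a minimal sketch of that arithmetic, assuming the weights are the only memory traffic (the KV cache is ignored) and that peak bandwidth is actually reached. The function name is illustrative.

```python
def decode_time_per_token(n_params, bytes_per_param, bandwidth_bytes_per_s):
    """Lower bound on seconds per decoded token when every weight is read once per step."""
    weight_bytes = n_params * bytes_per_param
    return weight_bytes / bandwidth_bytes_per_s

# 70B parameters in bf16 (2 bytes each), H100 bandwidth 3.35 TB/s
t = decode_time_per_token(70e9, 2, 3.35e12)
print(f"{t * 1e3:.1f} ms/token, {1 / t:.1f} tokens/s")  # 41.8 ms/token, 23.9 tokens/s
```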

Compare that to a human reading speed of about 4–5 words/second, and 24 tok/s sounds fine. But now imagine 1,000 simultaneous users, each wanting their own 24 tok/s stream. Served one user per GPU, with no batching, that takes 1,000 H100s. At roughly $2/hr per H100, that's $2,000/hr just to keep the lights on. Inference cost is not a footnote: it is the business model.

Why inference is harder than training. During training you see an entire sequence of T tokens all at once and can compute attention as a giant parallel matrix multiply. During generation you must produce tokens one at a time — each new token depends on all previous ones, and you cannot start token 5 until token 4 is done. That sequential dependence shatters the GPU utilization that makes training so efficient.

Inference shows up in far more places than just chatbots. Every time you run an evaluation benchmark, you run inference over thousands of examples. Reinforcement learning from human feedback (RLHF) and GRPO — the techniques behind GPT-4 and DeepSeek-R1 — require generating millions of rollouts from the policy model. Test-time compute (o1, R1) literally multiplies inference cost by 100–1000× per query. The one-time training cost is increasingly dwarfed by cumulative inference cost over the model's lifetime.

This lecture is about understanding why inference is slow, quantifying the bottleneck precisely, and then surveying the most important techniques for fighting back. We will derive every number from first principles — no hand-waving, no "it's just fast."

Prefill vs Decode: a timeline animation

Watch how prefill processes all prompt tokens in parallel (fast), then decode produces one token at a time (slow). Hit Play to animate, adjust prompt/response length.

Inference is expensive at scale primarily because:

Chapter 1: Prefill vs Decode

Every LLM inference call splits into exactly two phases, and they have almost nothing in common. Understanding this split is the foundation of everything in inference optimization.

Prefill is the first phase. You receive a prompt — say, a system message plus a user question, totalling 500 tokens. The model processes all 500 tokens simultaneously in one giant forward pass. Every token can attend to every earlier token in parallel, exactly like during training. This phase is compute-bound: the GPU's arithmetic units are busy the whole time, squeezing out FLOPs efficiently.

Decode is the second phase. Now the model generates the response, one token per step. At each step, the new token can attend to all previous tokens — but only one new token is being processed. This phase is memory-bound: the GPU reads the full weight matrices and the growing key-value cache from HBM, does a tiny amount of arithmetic, and repeats. The arithmetic units sit mostly idle, starved for data.

The key asymmetry. In prefill, the GPU reads each weight once and uses it for T multiplications (one per token in the prompt). In decode, the GPU reads each weight once and uses it for exactly one multiplication (for the single new token). The work-per-byte ratio collapses, turning a compute-limited device into a bandwidth-limited one.

The metrics that matter differ between phases too. For prefill, the relevant metric is time-to-first-token (TTFT): how long the user waits before any output appears. For decode, the relevant metrics are per-token latency (how long each user waits between tokens) and throughput (tokens generated per second across all users). These can be traded off against each other, as we'll see in Ch 4.

A third metric is end-to-end latency for a full response. TTFT depends mostly on prefill speed, while end-to-end latency also depends on decode speed. For a 50-token response on a slow decode path, the two can differ by 5–10× even on identical hardware.

python — naive inference (no KV cache, O(T³) total)
import torch

@torch.no_grad()  # inference only: no autograd graph needed
def naive_generate(model, prompt_ids, max_new_tokens):
    # prompt_ids: [T_prompt], the starting tokens
    ctx = prompt_ids.clone()  # grows by one token each step

    for step in range(max_new_tokens):
        # Feed the ENTIRE context every step: O(T²) attention work per step, O(T³) total
        logits = model(ctx.unsqueeze(0))  # [1, T, vocab]
        next_token = logits[0, -1, :].argmax()  # greedy decode, 0-dim tensor
        ctx = torch.cat([ctx, next_token.unsqueeze(0)])

    return ctx

# With a 4096-token prompt and 256 new tokens:
# Step 1: forward pass over 4096 tokens
# Step 2: forward pass over 4097 tokens
# ...
# Step 256: forward pass over 4351 tokens
# Total: ~4200 * 256 ≈ 1M tokens processed for 256 output tokens!

This naive approach is catastrophically inefficient. We process 1 million tokens to generate 256 — a 4,000× overhead. The key observation saving us: everything the model computes for token i at layer l is identical in step t and step t+1. We can cache it. That's the KV cache (Chapter 3).

Two phases, two bottlenecks. Prefill: compute-bound (like training). Decode: memory-bound (like reading a big file). Almost every inference optimization targets one or both phases specifically — understanding which phase a technique targets tells you when it helps.
Prompt arrives (Tp tokens, e.g. 500)
→ PREFILL phase: process all Tp tokens in parallel • compute-bound • produces KV cache
→ DECODE phase: generate 1 token/step • memory-bound • reads weights + KV cache each step
→ Response complete (Tr tokens generated)
During the decode phase, generating token number 200 (with a 500-token prompt and 199 tokens already generated), how many tokens does the model need to process in each forward pass?

Chapter 2: Arithmetic Intensity

To understand why decode is memory-bound, we need to be precise about a concept called arithmetic intensity: the ratio of floating-point operations (FLOPs) to bytes transferred from HBM (high-bandwidth memory). This single number determines whether a workload is compute-limited or memory-limited.

Consider the core operation in every Transformer layer: a matrix multiply X × W. Say X has shape [B×T, D] and W has shape [D, F]. How many FLOPs does this take? For each of the B×T rows of X and each of the F columns of W, we compute a dot product of length D: that's D multiply-adds, or 2×D FLOPs. Total: 2×B×T×D×F FLOPs.

How many bytes are transferred? We read X (2×B×T×D bytes in bf16), read W (2×D×F bytes), and write the output Y (2×B×T×F bytes). Total: 2BT(D+F) + 2DF bytes.

Arithmetic Intensity = FLOPs / Bytes = (2·B·T·D·F) / (2BT(D+F) + 2DF)

When B×T << D and F (i.e., when the batch is tiny compared to weight dimensions), the denominator is dominated by 2DF — reading the weight matrix. The numerator is 2×B×T×D×F. So the intensity simplifies to:

Intensity ≈ (2·B·T·D·F) / (2·D·F) = B·T

Remarkable: arithmetic intensity equals the number of tokens being processed simultaneously. Let's put numbers on it. An H100 SXM5 has a peak compute of 989 TFLOPs/s and memory bandwidth of 3.35 TB/s. The ratio — the GPU's "hardware intensity" — is:

Hardware intensity = (989×10¹² FLOPs/s) / (3.35×10¹² bytes/s) ≈ 295 FLOPs/byte

If your workload's arithmetic intensity is above 295, you're compute-limited (good — GPU fully utilized). If it's below 295, you're memory-limited (bad — GPU sitting idle waiting for data).
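The breakeven rule above is the roofline model: attainable performance is the smaller of peak compute and intensity times bandwidth. The sketch below uses the H100 figures quoted in the text. The function names are illustrative, and real kernels land below these ideal limits.

```python
H100_PEAK_FLOPS = 989e12   # FLOP/s, figure from the text
H100_BANDWIDTH = 3.35e12   # bytes/s, figure from the text

def attainable_flops(intensity, peak=H100_PEAK_FLOPS, bandwidth=H100_BANDWIDTH):
    """Roofline: performance is capped by compute or by intensity * bandwidth."""
    return min(peak, intensity * bandwidth)

def regime(intensity, peak=H100_PEAK_FLOPS, bandwidth=H100_BANDWIDTH):
    """Classify a workload by comparing its intensity to the hardware breakeven."""
    return "compute-limited" if intensity >= peak / bandwidth else "memory-limited"

for ai in (1, 64, 512, 2048):
    share = attainable_flops(ai) / H100_PEAK_FLOPS
    print(f"intensity {ai:>4}: {regime(ai)}, {share:.1%} of peak")
```

At intensity 1 the GPU can reach only about 0.3% of its peak FLOP rate, which is the single-user decode case.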

The decode bottleneck in one inequality. During decode, T = 1 (one new token per step). With batch size B = 1, arithmetic intensity = 1 FLOPs/byte. The H100 needs 295 to be compute-limited. We are 295× below the breakeven point. This is not a small inefficiency — it is a fundamental structural mismatch between how autoregressive decoding works and how GPUs are designed.

What about attention? The intensity formula there is different because attention's memory depends on context length S rather than weight dimensions. With the KV cache, attending to S tokens while generating T new tokens costs 4×B×S×T×D FLOPs and (4×B×S×D + 4×B×T×D) bytes. Intensity = S×T/(S+T).

For prefill (T = S): intensity = S/2. For a 4096-token prompt, intensity = 2048, well above the 295 breakeven. For decode (T = 1): intensity = S/(S+1) < 1, always memory-bound no matter how long the context. And worse: unlike MLP layers, where a larger batch size B raises intensity, attention intensity has no dependence on B. Each sequence reads its own KV cache, so a bigger batch adds KV bytes in proportion to the FLOPs it adds; there is no shared read to amortize the way weight reads are shared in the MLP.
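As a quick check, the sketch below plugs the FLOP and byte counts from the text into a small function and confirms that B (and D) cancel. The function name and the choice D = 128 are illustrative assumptions.

```python
def attention_intensity(B, S, T, D):
    """FLOPs per byte for attention over S cached positions with T new tokens.

    Uses the counts from the text: 4*B*S*T*D FLOPs over 4*B*S*D + 4*B*T*D bytes.
    """
    flops = 4 * B * S * T * D
    bytes_moved = 4 * B * S * D + 4 * B * T * D
    return flops / bytes_moved

print(attention_intensity(B=1, S=4096, T=4096, D=128))   # prefill, T = S: 2048.0
print(attention_intensity(B=1, S=4096, T=1, D=128))      # decode: just under 1
print(attention_intensity(B=64, S=4096, T=1, D=128))     # same value: batching does not help
```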

Roofline: prefill compute-bound, decode memory-bound

Drag the batch size and sequence length sliders. The dot shows where the workload sits on the roofline. Above the knee = compute-limited; below = memory-limited.

python — deriving arithmetic intensity for a matmul
def arithmetic_intensity(B, T, D, F):
    # Matrix multiply: X [B*T, D] @ W [D, F] -> Y [B*T, F]
    flops = 2 * B * T * D * F
    bytes_in  = 2*B*T*D + 2*D*F  # read X and W (bf16 = 2 bytes each)
    bytes_out = 2*B*T*F            # write Y
    total_bytes = bytes_in + bytes_out
    return flops / total_bytes

# H100 breakeven: 989e12 FLOPs/s / 3.35e12 bytes/s ≈ 295 FLOPs/byte
H100_INTENSITY = 295

# Decode: B=1, T=1 (single token, single user)
D, F = 4096, 11008  # Llama-2 7B dimensions
print(arithmetic_intensity(1, 1, D, F))   # ≈ 1.0  → memory-bound!
print(arithmetic_intensity(1, 4096, D, F)) # ≈ 1727  → compute-bound
print(arithmetic_intensity(295, 1, D, F)) # ≈ 268  → just below breakeven

# Rule of thumb: need B*T > ~295 to become compute-limited on H100
An H100 has hardware intensity ~295 FLOPs/byte. During decode with batch size B = 64 and T = 1 (one new token per step), the MLP arithmetic intensity is approximately B×T = 64. What does this mean?

Chapter 3: The KV Cache

The KV cache is the single most important data structure in LLM inference. Without it, generating T tokens would require O(T³) total FLOPs, which is impractical for any reasonable context length. With it, we get O(T²) total work: each decode step computes Q, K, V for one new token and attends over the cached positions, so the per-step cost grows only linearly with context length.

Here is the key insight. In multi-head attention, each token position i computes a query Qᵢ, key Kᵢ, and value Vᵢ by multiplying the hidden state hᵢ by weight matrices W_Q, W_K, W_V. When we generate token t+1, we need to compute attention scores between the new token and all previous tokens. That requires the K and V vectors for all positions 0 through t. But those were already computed during earlier steps! We would be recomputing them identically if we ran the naive approach.

The KV cache solves this by storing the K and V tensors for every (layer, head, position) that has been computed. At each decode step, we only compute Q, K, V for the new token (one position), then look up all previous K, V from the cache for the attention computation.

Why only K and V, not Q? Query vectors are used to ask "what am I looking for?" For the current token, we need a fresh Q every step (it changes with each new token's representation). But K and V represent "what information do I have?" for each past token — and past tokens never change. Once you've computed K₅ and V₅ for position 5, they remain identical in all future steps.
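To make the caching argument concrete, here is a minimal single-head, single-layer sketch in NumPy. It assumes the hidden states are given and fixed, which holds within one layer under causal attention. The weights are random and all names are illustrative. It checks that attending over an incrementally built cache gives the same output as recomputing every K and V from scratch.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def attend(q, K, V):
    """Single-head attention for one query over all cached keys and values."""
    scores = K @ q / np.sqrt(q.shape[0])
    return softmax(scores) @ V

rng = np.random.default_rng(0)
d = 8
W_q, W_k, W_v = (rng.standard_normal((d, d)) for _ in range(3))
hidden = rng.standard_normal((5, d))   # hidden states h_0..h_4, one row per position

# Without a cache: recompute K and V for every position at the last step.
K_full, V_full = hidden @ W_k, hidden @ W_v
out_naive = attend(hidden[-1] @ W_q, K_full, V_full)

# With a cache: append one K row and one V row per step, never recompute old ones.
K_cache, V_cache = np.empty((0, d)), np.empty((0, d))
for h in hidden:
    K_cache = np.vstack([K_cache, h @ W_k])
    V_cache = np.vstack([V_cache, h @ W_v])
    out_cached = attend(h @ W_q, K_cache, V_cache)  # only the new token needs a query

print(np.allclose(out_naive, out_cached))  # True
```

Production caches preallocate memory instead of growing arrays with `vstack`; the growth here is only for readability.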

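Now let's size the KV cache precisely. Consider a Transformer with L layers, N_KV key/value heads per layer, head dimension H, context length S, and batch size B, with the cache stored in bf16 (2 bytes per element).

Each cached key or value vector has dimension N_KV×H (equal to d_model for MHA). Counting 2 tensors (key + value) for each of L layers, B sequences, and S positions:

KV cache size = 2 · L · N_KV · H · S · B · 2 bytes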

Let's compute this for Llama-2 70B with a 4096-token context (batch size 1, bf16):

KV size = 2 × 80 × 8 × 128 × 4096 × 1 × 2 bytes = 1,342,177,280 bytes ≈ 1.34 GB.

That's for one request! Llama-2 70B model weights are about 140 GB (70B params × 2 bytes/bf16). So for a single 4096-token request, the KV cache adds about 1% to memory. But at batch size 64 and context length 32768, the KV cache becomes 1.34 GB × 64 × (32768/4096) ≈ 687 GB, nearly 5× the model weights themselves.

KV Cache Size Explorer

Adjust model shape, context, and batch size. See how KV cache memory compares to model weight memory. At long contexts, the cache dominates.

python — KV cache size formula with Llama-2 70B example
def kv_cache_bytes(L, n_kv_heads, head_dim, seq_len, batch, dtype_bytes=2):
    # Leading 2 counts key + value; dtype_bytes=2 for bf16
    return 2 * L * n_kv_heads * head_dim * seq_len * batch * dtype_bytes

# Llama-2 70B configuration
L      = 80    # transformer layers
n_kv   = 8     # GQA: only 8 KV heads (vs 64 query heads)
H      = 128   # head dimension

# Single request, 4096 tokens
kv_single = kv_cache_bytes(L, n_kv, H, seq_len=4096, batch=1)
print(f"Single request 4k: {kv_single/1e9:.2f} GB")  # 1.34 GB

# Production: B=64 requests, 32768 tokens each
kv_prod = kv_cache_bytes(L, n_kv, H, seq_len=32768, batch=64)
print(f"64 requests @ 32k:  {kv_prod/1e9:.1f} GB")   # 687.2 GB

model_weights_gb = 70e9 * 2 / 1e9  # 70B params * 2 bytes = 140 GB
print(f"Model weights:      {model_weights_gb:.0f} GB")
print(f"KV/weights ratio:   {kv_prod/1e9/model_weights_gb:.1f}x")  # 4.9x!
Llama-2 70B uses Grouped Query Attention (GQA) with 8 KV heads instead of the full 64 query heads. By what factor does this reduce KV cache size compared to standard Multi-Head Attention (MHA)?

Chapter 4: Roofline & Batching

We've established that decode is memory-bound. The question now is: how do we improve throughput without destroying latency? The answer is batching — processing multiple requests simultaneously — and its more sophisticated cousin, continuous batching.

When serving B requests in parallel, arithmetic intensity for MLP layers becomes B (as derived in Ch 2). For B = 295 on an H100, you'd hit the compute-limited regime. In practice, B = 64–128 is common — intensity 64–128, memory-limited but far better than B = 1.

Let's compute theoretical latency and throughput for Llama-2 13B on an H100. The model has ~13B parameters. In bf16 (2 bytes each), that's 26 GB. The KV cache for S = 1024 tokens, batch B, is 2 × 40 × 40 × 128 × 1024 × B × 2 = 838 MB × B.

Total memory read per decode step = model weights + KV cache = 26,000 MB + 838 MB × B. At H100 bandwidth 3350 GB/s, latency per token (seconds) = total_bytes / bandwidth.

Batch size B | KV cache (GB) | Total memory (GB) | Latency (ms/tok) | Throughput (tok/s)
1 | 0.84 | 26.84 | 8.0 | 125
64 | 53.6 | 79.6 | 23.8 | 2,690
256 | 214 | 240 | 71.6 | 3,575

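Two observations stand out. First, latency degrades as batch size grows: at B = 64 each token takes longer to appear because the GPU is also serving 63 other requests. Second, throughput grows with batch size but with diminishing returns. Going from B=1 to B=64 gives a 21× throughput improvement, but going on to B=256 adds only another 1.3×. The weight read is amortized across the batch, but the KV cache read grows linearly with B, so at large B the KV cache dominates the bytes moved and throughput approaches a ceiling of bandwidth / (KV bytes per sequence), about 4,000 tok/s here.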
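The latency and throughput figures for Llama-2 13B can be reproduced with a few lines. This is a sketch under the same idealization as the text (every byte of weights and KV cache is read once per step at full bandwidth). The function name and default arguments are illustrative, and results match the table to within about 1% because the table rounds intermediate values.

```python
def decode_stats(batch, weight_bytes=26e9, kv_bytes_per_seq=838_860_800, bandwidth=3.35e12):
    """Return (seconds per token, tokens per second across the batch)."""
    total_bytes = weight_bytes + kv_bytes_per_seq * batch
    latency = total_bytes / bandwidth
    throughput = batch / latency
    return latency, throughput

for b in (1, 64, 256):
    lat, thr = decode_stats(b)
    print(f"B={b:>3}: {lat * 1e3:5.1f} ms/token, {thr:7.0f} tokens/s")
```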

The naive batching problem. Traditional batching waits until a full batch of B requests has arrived, then processes them all together, then releases results. For language generation, this is terrible: request #1 might finish in 50 tokens, request #2 might need 500 tokens, but the server holds request #1's completed response hostage until request #2 finishes. Users experience wildly variable and often terrible latency.

Continuous batching (also called iteration-level scheduling, introduced in the ORCA paper, 2022) fixes this by operating at the granularity of a single decode step rather than a full request. After every token generation step, the scheduler checks: did any request finish? If yes, evict it from the batch and add a waiting request. If new requests are waiting, slot them in during the next step.

This requires one technical trick. During attention, each sequence has its own KV cache — sequences can have different lengths. The non-attention operations (MLP layers, LayerNorm) work on individual token representations. So for those layers, we can concatenate all sequences into a flat [total_tokens, D] tensor and process them together, even if they have different lengths. For attention, we process each sequence's KV cache separately. This is called selective batching.
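The scheduling difference is easy to see in a toy simulation. The sketch below counts decode steps only. It assumes all requests are already waiting, ignores prefill, and treats every step as equally expensive; all names are illustrative. Real schedulers also manage KV memory and interleave prefill.

```python
from collections import deque

def static_batching_steps(lengths, batch_size):
    """Each batch runs until its longest request finishes."""
    steps = 0
    for i in range(0, len(lengths), batch_size):
        steps += max(lengths[i:i + batch_size])
    return steps

def continuous_batching_steps(lengths, batch_size):
    """After every decode step, finished requests leave and waiting ones join."""
    waiting = deque(lengths)
    running = []  # remaining tokens for each in-flight request
    steps = 0
    while waiting or running:
        while waiting and len(running) < batch_size:
            running.append(waiting.popleft())       # fill free slots
        running = [r - 1 for r in running]          # one decode step for the whole batch
        running = [r for r in running if r > 0]     # evict finished requests
        steps += 1
    return steps

lengths = [50, 500, 50, 500, 50, 50, 50, 50]  # response lengths in tokens
print(static_batching_steps(lengths, 2), continuous_batching_steps(lengths, 2))  # 1100 650
```

With 1,300 tokens to generate and two slots, 650 steps is the best possible schedule, so continuous batching keeps both slots busy the whole time in this example.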

Throughput vs Latency: the batching tradeoff

Drag batch size to see how latency and throughput evolve. The knee of the curve is where adding more users hurts latency faster than it gains throughput. Model: Llama-2 13B on H100.

In continuous batching, what happens when one request in a batch of 32 finishes generating?

Chapter 5: Speculative Decoding

So far we've accepted that decode generates exactly one token per forward pass through the big model. What if we could generate multiple tokens per pass — and guarantee the output is identical to what the big model would have produced alone? This is speculative decoding, and it is one of the most elegant ideas in inference optimization.

The core insight is an asymmetry: verifying a token sequence is much faster than generating it. During prefill, the model processes many tokens simultaneously (compute-bound). If we could turn some decode work into prefill-like work, we'd be much faster. Speculative decoding does exactly this.

The algorithm has three steps. First, a cheap draft model (say 8B parameters) autoregressively generates k candidate tokens, x̃₁, ..., x̃ₖ. Second, the expensive target model (say 70B) processes all k draft tokens in parallel (one forward pass), producing probability distributions q(x̃ᵢ | context) for each position simultaneously. Third, a modified rejection-sampling procedure accepts or rejects each draft token, with a guarantee that the output distribution exactly matches what the target model would have generated token-by-token.

Why is the output distribution exact? With draft distribution p and target distribution q, a drafted token x is accepted with probability min(1, q(x)/p(x)). If it is rejected, a replacement is sampled from the residual distribution max(q − p, 0), renormalized to sum to 1. This guarantees that the token emitted at each position is distributed exactly according to q, the target. Two details keep every step productive: a rejection still yields one token (the correction sampled from the residual), and if all k drafts are accepted the target's final distribution supplies one extra token, so each target pass emits between 1 and k+1 tokens.

Let's work through the accept/reject math with a concrete two-vocabulary example. Suppose the vocabulary is just {A, B}. The target model assigns probabilities q(A) = 0.3, q(B) = 0.7. The draft model assigns p(A) = 0.5, p(B) = 0.5 (oversamples A).

Draft samples token A. Accept probability = min(1, q(A)/p(A)) = min(1, 0.3/0.5) = 0.6. So A is accepted 60% of the time. Draft samples token B. Accept probability = min(1, q(B)/p(B)) = min(1, 0.7/0.5) = 1.0. So B is always accepted.

What is the overall probability of outputting A? = P(draft A) × P(accept A) = 0.5 × 0.6 = 0.3 = q(A). ✓ What is the probability of outputting B? = P(draft B) × P(accept B) + P(draft A) × P(reject A) × P(resample B) = 0.5 × 1.0 + 0.5 × 0.4 × 1.0 = 0.5 + 0.2 = 0.7 = q(B). ✓ The marginal distribution is exactly q.
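The same bookkeeping works for any finite vocabulary. The sketch below computes the exact output distribution of one accept/reject step, assuming every token has nonzero draft probability, and confirms it matches q for the {A, B} example. The function name is illustrative.

```python
def speculative_output_dist(p, q):
    """Exact output distribution of one accept/reject step over a finite vocabulary."""
    residual = {x: max(q[x] - p[x], 0.0) for x in q}
    norm = sum(residual.values())
    # Total probability that the drafted token is rejected.
    reject_mass = sum(p[x] * (1 - min(1.0, q[x] / p[x])) for x in p)
    out = {}
    for x in q:
        accepted = p[x] * min(1.0, q[x] / p[x])                      # drafted and accepted
        resampled = reject_mass * residual[x] / norm if norm > 0 else 0.0  # drawn after a rejection
        out[x] = accepted + resampled
    return out

p = {"A": 0.5, "B": 0.5}   # draft model
q = {"A": 0.3, "B": 0.7}   # target model
out = speculative_output_dist(p, q)
print(out)  # approximately {'A': 0.3, 'B': 0.7}
```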

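E[tokens per target pass] ≈ k · α + 1    where α = expected acceptance rate per token

This linear form is an optimistic approximation. Because verification stops at the first rejection, if each draft is accepted independently with probability α the expected count is 1 + α + α² + ... + αᵏ = (1 − αᵏ⁺¹)/(1 − α), which is never larger than k·α + 1.

If the draft model is very good (α ≈ 0.9) and we speculate k = 4 tokens, the linear estimate is 4 × 0.9 + 1 = 4.6 tokens per target-model pass (about 4.1 under the independent-acceptance formula). The speedup over naive decode is approximately (tokens per pass) / (1 + k × cost_ratio), where cost_ratio is the draft model's cost relative to the target. For an 8B draft vs 70B target (cost ratio ≈ 0.11), with α = 0.8, k = 4: the linear estimate gives 4.2/1.44 ≈ 2.9×, and the independent-acceptance count of 3.36 gives about 2.3×.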
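The speedup arithmetic can be packaged as two small functions. The sketch compares the linear estimate k·α + 1 from the text with the count obtained when verification stops at the first rejection and each draft is accepted independently with probability α. That independence is an assumption made for this sketch, and the function names are illustrative.

```python
def expected_tokens_linear(alpha, k):
    """Linear estimate used in the text: k * alpha + 1."""
    return k * alpha + 1

def expected_tokens_iid(alpha, k):
    """1 + alpha + ... + alpha**k: stop at first rejection, independent acceptances."""
    return sum(alpha ** i for i in range(k + 1))

def speedup(tokens_per_pass, k, cost_ratio):
    """Tokens per target pass divided by the cost of one pass plus k draft steps."""
    return tokens_per_pass / (1 + k * cost_ratio)

alpha, k, r = 0.8, 4, 0.11
print(f"linear: {speedup(expected_tokens_linear(alpha, k), k, r):.2f}x")
print(f"iid:    {speedup(expected_tokens_iid(alpha, k), k, r):.2f}x")
```

Both estimates agree at the extremes: α = 0 gives exactly 1 token per pass and α = 1 gives k + 1.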

python — speculative decoding accept/reject sketch
import torch

@torch.no_grad()
def speculative_step(draft_model, target_model, context, k=4):
    # context: [1, T] token ids; both models map [1, T] ids -> [1, T, vocab] logits
    T = context.shape[1]

    # Step 1: Draft model generates k candidate tokens autoregressively
    draft_tokens, draft_dists = [], []
    ctx = context.clone()
    for _ in range(k):
        p = draft_model(ctx)[0, -1].softmax(dim=-1)  # full p(. | ctx), shape [vocab]
        tok = torch.multinomial(p, num_samples=1)    # shape [1]
        draft_tokens.append(tok)
        draft_dists.append(p)                        # keep the whole distribution for the residual
        ctx = torch.cat([ctx, tok.unsqueeze(0)], dim=1)

    # Step 2: Target model scores the context plus ALL k drafts in ONE parallel pass
    target_probs = target_model(ctx)[0].softmax(dim=-1)  # [T + k, vocab]

    # Step 3: Accept/reject each draft token, left to right
    accepted = []
    for i, (tok, p) in enumerate(zip(draft_tokens, draft_dists)):
        q = target_probs[T + i - 1]  # target distribution for draft position i
        accept_prob = (q[tok] / p[tok]).clamp(max=1.0)
        if torch.rand(1, device=q.device) < accept_prob:
            accepted.append(tok)
            continue
        # Rejected: resample from the normalized residual max(q - p, 0), then stop
        residual = (q - p).clamp(min=0)
        accepted.append(torch.multinomial(residual / residual.sum(), 1))
        return accepted

    # All k drafts accepted: sample one bonus token from the target's next distribution
    accepted.append(torch.multinomial(target_probs[T + k - 1], 1))
    return accepted  # between 1 and k+1 tokens
Speculative Decoding: Expected Speedup Explorer

Drag acceptance rate α and draft count k. See expected tokens per step and theoretical speedup. At α=1, you get k+1 tokens per pass — perfect draft. At α=0, you get exactly 1.

What guarantees that speculative decoding produces the exact same output distribution as running only the target model?

Chapter 6: KV Cache Compression

The KV cache is the bottleneck at long contexts (as we saw in Ch 3, it can exceed model weights at production batch sizes). Three architectural families attack this directly: grouped-query attention, multi-head latent attention, and local/hybrid attention.

Grouped-Query Attention (GQA) is the simplest idea. Standard multi-head attention (MHA) has N query heads, N key heads, and N value heads, and the keys and values for all N heads go into the KV cache. Multi-query attention (MQA) reduces to a single KV head shared by all queries. GQA is the middle ground: K KV heads shared among groups of N/K query heads. Llama-2 70B uses K=8, N=64, an 8× KV cache reduction with minimal accuracy loss.

GQA accuracy holds up. The original GQA paper (Ainslie et al. 2023) showed that models converted from MHA checkpoints to GQA with a small amount of additional training reach quality close to MHA, while KV cache memory and the associated decode bandwidth shrink in proportion to K/N. With proper training, the reduction is close to free.

Multi-head Latent Attention (MLA), introduced in DeepSeek-V2, takes a more aggressive approach. Instead of caching N×H-dimensional K and V vectors for each token, MLA projects them down to a compressed latent vector of dimension C before caching, then expands back at inference time. DeepSeek-V2 reduces from N×H = 64×128 = 8192 dimensions to C = 512 + 64 = 576 dimensions — a 14× reduction over MHA.

The extra 64 dimensions carry RoPE positional information, which cannot be absorbed into the low-rank projection. DeepSeek-V2 reports that MLA is not only cheaper than MHA in KV memory but also slightly more accurate; one plausible explanation is that the low-rank bottleneck acts as a regularizer on the key-value representations.

Local (sliding window) attention takes a completely different approach. Instead of each token attending to all S previous tokens, it only attends to the last W tokens (a local window). The KV cache becomes O(W) per sequence instead of O(S) — independent of total sequence length. Multiple layers of local attention with window W effectively cover a context of W×L tokens through transitive information flow across layers.
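A sliding-window cache is essentially a ring buffer. The sketch below uses `collections.deque` with `maxlen` so that the oldest entry is dropped automatically once the window is full. The class name is illustrative, and strings stand in for the actual key and value tensors.

```python
from collections import deque

class SlidingWindowKVCache:
    """Keeps keys and values for only the most recent `window` positions."""

    def __init__(self, window):
        self.keys = deque(maxlen=window)     # oldest entry is dropped automatically
        self.values = deque(maxlen=window)

    def append(self, k, v):
        self.keys.append(k)
        self.values.append(v)

    def __len__(self):
        return len(self.keys)

cache = SlidingWindowKVCache(window=4)
for pos in range(10):
    cache.append(f"k{pos}", f"v{pos}")
print(len(cache), list(cache.keys))  # 4 ['k6', 'k7', 'k8', 'k9']
```

Memory stays at W entries per layer no matter how many tokens have been generated.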

The catch: local attention can hurt accuracy on tasks requiring long-range recall (finding information from thousands of tokens back). The practical solution, used by Mistral 7B, Character.AI, and many production systems: interleave one global attention layer every 6–8 local layers. The global layers handle long-range retrieval; local layers handle neighboring context at low cost.

Method | KV dims per token | KV reduction vs MHA | Accuracy impact
MHA (baseline) | N×H = 8192 | 1× | (baseline)
MQA | H = 128 | 64× | Noticeable degradation
GQA (K=8) | K×H = 1024 | 8× | Negligible
MLA | C = 576 | 14× | Slight improvement
Local (W=4096) | W×N_KV×H | S/W× at ctx > W | Depends on task

Cross-layer attention (CLA) extends GQA across layers rather than within a layer: multiple consecutive transformer layers share the same KV cache rather than computing fresh K, V for each layer. If 2 consecutive layers share a KV cache, storage halves for those layers. Character.AI's production system combines CLA with GQA to achieve very aggressive KV compression with minimal accuracy loss.
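The per-token sizes in the table above follow from a one-line formula. This sketch uses the document's convention of counting N_KV × H dimensions per token, with N = 64 and H = 128 as in the MHA row. The helper name is illustrative, and the MLA latent size of 576 is taken from the text.

```python
def kv_dims_per_token(n_kv_heads, head_dim):
    """Cached dimensions per token per layer, in the table's counting convention."""
    return n_kv_heads * head_dim

N, H = 64, 128
mha = kv_dims_per_token(N, H)
variants = {
    "MHA": mha,
    "GQA (8 KV heads)": kv_dims_per_token(8, H),
    "MQA": kv_dims_per_token(1, H),
    "MLA (C=576)": 576,
}
for name, dims in variants.items():
    print(f"{name:<17} {dims:>5} dims  {mha / dims:5.1f}x smaller than MHA")
```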

Multi-head latent attention (MLA) stores a compressed C=576-dimensional latent per token instead of the full N×H=8192-dimensional KV vectors. What is the approximate KV memory reduction factor?

Chapter 7: Showcase: Inference System

Time to put everything together. This showcase simulation lets you configure a complete LLM inference deployment — model size, hardware, KV compression, batching strategy, and speculative decoding — and see how each choice affects TTFT, decode latency, throughput, and memory budget. This is the tool Percy uses mentally when designing a serving stack.

The simulation models the theoretical roofline limits. Real systems built on well-optimized serving engines like vLLM, TensorRT-LLM, and TGI reach roughly 80% of these numbers.

Inference System Configurator

Configure your serving stack. All numbers are theoretical roofline estimates assuming perfect memory bandwidth utilization. Memory bar shows: weights (blue) + KV cache (orange). Red line = GPU capacity.

Real deployment insight. In practice, operators tune batch size to maximize GPU memory utilization (keep KV caches full) without overflowing. Continuous batching means new requests fill slots dynamically. The "right" batch size is not a fixed number — it's the maximum that fits in memory given the distribution of context lengths at any moment.
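A first-order version of that sizing rule is sketched below. It assumes every request uses the full context length and ignores activations, fragmentation, and framework overhead, so real limits are lower. The function name is illustrative, and the second call is a hypothetical GQA variant of the same model.

```python
def max_batch_size(gpu_bytes, weight_bytes, layers, n_kv_heads, head_dim, context, dtype_bytes=2):
    """Largest batch whose KV cache fits next to the weights (activations ignored)."""
    kv_per_seq = 2 * layers * n_kv_heads * head_dim * context * dtype_bytes
    free_bytes = gpu_bytes - weight_bytes
    return max(0, int(free_bytes // kv_per_seq))

# 13B model in bf16 (26 GB) on an 80 GB GPU, 2048-token contexts
print(max_batch_size(80e9, 26e9, layers=40, n_kv_heads=40, head_dim=128, context=2048))  # 32
# Hypothetical variant with 8 KV heads (GQA): 5x smaller cache per sequence
print(max_batch_size(80e9, 26e9, layers=40, n_kv_heads=8, head_dim=128, context=2048))   # 160
```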

Chapter 8: Quantization & Pruning

Everything discussed so far has been lossless — exact outputs, exact distributions. Quantization and pruning accept a controlled quality tradeoff in exchange for dramatic efficiency gains. When done carefully, the degradation is undetectable on most benchmarks.

Quantization reduces the number of bits used to represent each weight or activation. Since inference is memory-bandwidth-limited, fewer bytes per weight means proportionally faster inference — not through better arithmetic, but through less data movement.
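Because decode time is roughly bytes read divided by bandwidth, halving the bits per weight halves the ideal per-token latency. A minimal sketch for a 70B model on H100 bandwidth, ignoring the KV cache and any dequantization overhead (the function name is illustrative):

```python
def decode_ms(n_params, bits_per_weight, bandwidth=3.35e12):
    """Ideal ms per token when each weight is read once at the given precision."""
    return n_params * bits_per_weight / 8 / bandwidth * 1e3

for name, bits in [("bf16", 16), ("fp8/int8", 8), ("int4", 4)]:
    print(f"{name:<9} {decode_ms(70e9, bits):5.1f} ms/token")
```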

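The number formats commonly used in production, compared in the table at the end of this chapter, are bf16 (2 bytes per value), fp8 and int8 (1 byte), and int4 (half a byte).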

The key challenge in quantization is outliers. In large language models, a small fraction of activations (and weights) have very large magnitudes. Standard quantization maps the full range to -128..127 — when a few values are 100× larger than the rest, the quantization grid becomes too coarse for the typical values.
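The outlier problem shows up even in a six-element example. The sketch below uses plain absmax int8 quantization, a generic scheme chosen for illustration and not the method of any specific paper, and compares the worst-case reconstruction error with and without one large value. The sample values are made up.

```python
def quantize_absmax_int8(values):
    """Scale so the largest magnitude maps to 127, then round to integers."""
    scale = max(abs(v) for v in values) / 127
    q = [max(-128, min(127, round(v / scale))) for v in values]
    return q, scale

def max_abs_error(values):
    """Worst-case error after quantizing and dequantizing."""
    q, scale = quantize_absmax_int8(values)
    return max(abs(v - x * scale) for v, x in zip(values, q))

typical = [0.11, -0.42, 0.27, 0.35, -0.08]
err_clean = max_abs_error(typical)
err_outlier = max_abs_error(typical + [40.0])  # one value ~100x larger than the rest
print(f"no outlier:  max abs error = {err_clean:.4f}")
print(f"one outlier: max abs error = {err_outlier:.4f}")
```

With the outlier present, the grid step grows to about 0.31, so small values such as 0.11 round to zero.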

LLM.int8() (Dettmers et al. 2022) solves this by using mixed precision: identify the outlier dimensions (typically 0.1–1% of columns), keep those in fp16, and quantize everything else to int8. The overhead of handling two precision levels costs about 15–23% more time than pure fp16, but weight memory halves — worth it at large model sizes where the bottleneck is bandwidth.

Activation-aware weight quantization (AWQ) takes a different approach. Rather than handling outliers explicitly, it identifies the roughly 1% of weights that are most important (determined by activation statistics) and protects them from quantization error, quantizing the rest to 3 or 4 bits. At 4 bits that is a 4× weight-memory reduction relative to fp16, with a reported speedup of about 3× and minimal accuracy loss.

Model pruning removes parts of the model entirely — entire layers, attention heads, or hidden dimensions — and then repairs the damage through distillation. The NVIDIA paper (Muralidharan et al. 2024) demonstrates a three-step recipe: (1) identify unimportant structure using 1024 calibration examples; (2) prune to a smaller architecture; (3) distill from the original model into the pruned one. This produces models like Llama-3.1-Minitron-4B from Llama-3.1-8B — half the size, 95% of the accuracy.

Distillation vs quantization. Quantization is applied post-training and preserves the architecture. Distillation creates a fundamentally smaller model but requires training time. For latency-critical applications, distillation is often superior because a smaller bf16 model can be faster than a larger int4 model with the same memory footprint: the model is genuinely smaller, with fewer FLOPs per token, not just compressed.
Method | Memory factor | Latency factor | Accuracy | Training needed?
bf16 baseline | 1× | 1× | Exact | No
fp8 | 0.5× | ~1.8× | Nearly lossless | Optional
int8 (LLM.int8) | 0.5× | ~1.2× | ~1% drop | No
int4 (AWQ) | 0.25× | ~3× | ~2–5% drop | Calibration only
Pruning + distill | 0.5× model size | 2× (smaller model) | <5% drop | Yes
Why does quantization improve decode latency even though the GPU does the same number of arithmetic operations?

Chapter 9: Connections & Cheat Sheet

Inference optimization is a vast field with techniques from multiple disciplines — systems, statistics, hardware architecture, and ML. Here's how the concepts in this lesson connect to the broader CS336 series and adjacent topics.

Cheat Sheet

Concept | Formula / Key Number | What it means
Arithmetic intensity | FLOPs / bytes = B×T (MLP) | Work per byte read; must exceed ~295 on H100 to be compute-limited
KV cache size | 2×L×N_KV×H×S×B×2 bytes | Grows linearly with all dimensions; dominates at long context
Decode latency (theoretical) | (weights + KV) / bandwidth | Llama-2 13B @ B=1: ~8 ms/token
Speculative speedup | (k×α+1) / (1+k×r) | r = draft/target cost ratio; α = acceptance rate
GQA KV reduction | N_Q/N_KV | Llama-2 70B: 64/8 = 8×
H100 breakeven intensity | 989 TF/s ÷ 3.35 TB/s ≈ 295 | Need B×T > 295 for compute-limited MLP


Closing thought. Training a model is a one-time capital expense. Every inference call is an operating expense — paid per token, per user, per day, forever. The engineers who understand both sides of this equation — what it costs to train and what it costs to serve — are the ones who design systems that actually scale. This lecture is about developing that second set of instincts.
Which combination of techniques would you apply to maximize throughput (tokens/sec across all users) for a batch-processing workload where per-request latency doesn't matter, running a 70B model on 2× H100 GPUs?