You want a chatbot that runs forever, but the KV cache for a 128K-token conversation runs from tens to hundreds of GB depending on the model, and attention is O(N²). This lesson builds the complete toolkit: RoPE position interpolation to extend training range, StreamingLLM's attention-sink insight to stream infinite context at constant memory, H2O heavy-hitter eviction to compress the KV budget, and FlashAttention's tiling trick to make long attention memory-linear. Every claim comes with derived numbers.
You want to build a chatbot that summarizes an entire research paper, reasons over a codebase, or holds a multi-hour conversation. These tasks require 100K–1M tokens of context. Let's compute what that costs — and discover why the KV cache, not the model weights, is the real wall.
Start with a 7B LLaMA-2 model: 7 × 10⁹ parameters × 2 bytes (FP16) = 14 GB for the weights. Now add the KV cache. For LLaMA-2-7B: 32 layers, 32 attention heads (MHA), head dim 128. At N = 128K tokens, the KV cache is: 32 × 2 × 128,000 × 32 × 128 × 2 bytes = 67.1 GB. The KV cache is nearly 5× the weight size — and it grows linearly with N while weights stay fixed.
The crossover point where the KV cache overtakes the model weights for LLaMA-2-7B MHA: 14 GB = 32 × 2 × N × 32 × 128 × 2 bytes → N = 14 × 10⁹ / (32 × 2 × 32 × 128 × 2) = 14 × 10⁹ / 524,288 ≈ 26,700 tokens. At roughly 27K tokens, the KV cache equals the weights. Most "long-context" tasks are 50K–200K tokens, and in this regime the KV cache completely dominates.
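The arithmetic above can be reproduced with a small calculator. This is a minimal sketch: the function names `kv_cache_bytes` and `crossover_tokens` are illustrative, the defaults are the LLaMA-2-7B figures quoted in the text, and batch size 1 is assumed.

```python
def kv_cache_bytes(n_tokens, n_layers=32, n_kv_heads=32, head_dim=128, bytes_per_value=2):
    """Bytes needed to cache keys and values for n_tokens (batch size 1)."""
    # Factor 2 = one key vector plus one value vector per head per layer.
    return n_layers * 2 * n_tokens * n_kv_heads * head_dim * bytes_per_value


def crossover_tokens(weight_bytes, **model):
    """Context length at which the KV cache is as large as the weights."""
    per_token = kv_cache_bytes(1, **model)
    return weight_bytes / per_token


weights = 7e9 * 2                      # 7B parameters in FP16
cache_128k = kv_cache_bytes(128_000)   # LLaMA-2-7B defaults
print(f"KV cache at 128K tokens: {cache_128k / 1e9:.1f} GB")
print(f"Cache / weights ratio:   {cache_128k / weights:.2f}x")
print(f"Crossover length:        {crossover_tokens(weights):,.0f} tokens")
```

Passing a smaller `n_kv_heads` (as in grouped-query attention) shows how the crossover point moves to longer contexts.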
Attention compute is quadratic over the full sequence, because each of the N tokens attends to every token before it. During decoding, each generated token computes one query-key inner product per cached token, per head, per layer. With 32 layers and 32 heads, that is 32 × 32 × N = 1024 × N inner products (each over 128 dimensions). At N = 128K that's 1.3 × 10⁸ inner products per generated token. Expensive, but manageable. The memory problem is harder to escape.
This lecture attacks both fronts simultaneously. We'll cover three families of solutions: (1) length extrapolation — train short, run long by adjusting positional encodings (RoPE interpolation); (2) streaming / windowed attention — run infinite-length text with a bounded-size KV cache (StreamingLLM's attention-sink insight); and (3) KV-cache compression — selectively evict less-important token KV pairs (H2O), quantize the KV cache to INT4, and split attention heads into retrieval vs. streaming classes (DuoAttention).
Drag the context-length slider to see when the KV cache overtakes the 14 GB weight line. The crossover at ~27K tokens is the point past which the KV cache, not the weights, dominates serving memory, and compression starts to matter.
Before we can compress long contexts, we need to understand how modern LLMs represent position — and why they fail when context exceeds the training length. The answer involves Rotary Positional Embedding (RoPE), the positional encoding used in LLaMA, Mistral, GPT-NeoX, and virtually every modern LLM.
The key idea: instead of adding positional vectors to token embeddings, RoPE rotates the query and key vectors before computing attention. Split the d-dimensional embedding into d/2 pairs. Each pair (x_{2i}, x_{2i+1}) is treated as a 2D coordinate and rotated by angle m·θ_i, where m is the token position and θ_i = 10000^{−2i/d} is a dimension-specific frequency. The inner product of Q and K then depends only on the relative position (m−n): RoPE is a relative positional encoding expressed in the rotation domain.
The rotation matrix R(m) has 2×2 blocks: for dimension pair i, it rotates by m·θ_i. The inner product (R(m)q)ᵀ(R(n)k) = qᵀR(m−n)k — only the relative distance m−n appears. This is elegant: position is encoded in the phase of the rotation, and relative position is the phase difference. No learned position vectors, no additive positional signal that gets mixed with semantic content.
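The relative-position property can be checked numerically. The sketch below is a plain-Python illustration, not any library's implementation: `rope` and `dot` are hypothetical helper names, the vectors are made-up values, and pairs are taken as consecutive elements (x_{2i}, x_{2i+1}) as described above.

```python
import math


def rope(x, position, base=10000.0):
    """Rotate consecutive pairs (x[2i], x[2i+1]) by angle position * theta_i."""
    d = len(x)
    out = []
    for i in range(d // 2):
        theta = base ** (-2 * i / d)           # per-pair frequency
        angle = position * theta
        c, s = math.cos(angle), math.sin(angle)
        a, b = x[2 * i], x[2 * i + 1]
        out += [a * c - b * s, a * s + b * c]  # 2D rotation of the pair
    return out


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


q = [0.3, -1.2, 0.7, 0.5, -0.4, 0.9, 1.1, -0.6]
k = [1.0, 0.2, -0.8, 0.4, 0.6, -0.3, 0.1, 0.7]

# Same relative offset (m - n = 3) at two very different absolute positions.
score_near = dot(rope(q, 5), rope(k, 2))
score_far = dot(rope(q, 1005), rope(k, 1002))
print(abs(score_near - score_far) < 1e-9)
```

The two scores agree up to floating-point error because only the offset m − n enters the inner product.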
But here's the problem: the model is trained on sequences of length at most N_train (e.g., 2K for original LLaMA, 4K for LLaMA-2). During training, it only ever sees rotation angles in the range [0, N_train · θ_i]. At inference with longer context, position m > N_train produces rotation angles the model has never been trained on. The attention scores computed from out-of-distribution angles are garbage: perplexity spikes from ~5 to hundreds.
ALiBi (Attention with Linear Biases) takes a different approach: instead of rotational embeddings, subtract a linear penalty m_h × |i−j| from every attention score, where m_h is a head-specific slope. No positional vectors at all. This design generalizes better to lengths beyond training (the penalty just keeps growing linearly), and several models (MPT, BLOOM) use it. But most modern LLMs use RoPE — so we need a way to extend it.
There's also the lost-in-the-middle phenomenon: even LLMs with nominally long context windows (128K) show a U-shaped performance curve. Accuracy is high when the relevant information is at the start or end of the context, but drops sharply when it's buried in the middle. This is commonly attributed to a training distribution effect: most training documents are short, and models see the beginning and end more often than the middle in long synthetic examples. Long context ≠ effective long context.
The fix for RoPE's length limit is elegantly simple: instead of using position m directly, rescale it so it falls within the trained range. If the model was trained with N_train = 4096 and you want to run at N_inf = 32768 (8× extension), map every real position m to m' = m × (N_train / N_inf) = m / 8.
Now the largest position in the extended window (32767) maps to 4095.875, just inside the training range. All rotation angles stay inside the trained distribution. The model can see up to 8× longer text without ever encountering out-of-distribution angle values.
The catch: by squashing 32768 positions into 4096 slots, nearby positions become harder to distinguish. Two tokens that were 1 unit apart are now only 1/8 unit apart — the attention mechanism must discriminate much finer angular differences. In practice, a small fine-tuning step (1000–2000 gradient steps on long-context data) is needed to let the model adapt to this compressed positional space. Without fine-tuning, accuracy suffers. With it, LLaMA's context extends from 2K to 32K with near-original perplexity.
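Both the benefit and the catch are visible in a few lines of arithmetic. This is a minimal sketch using the 4096 → 32768 example; `interpolate_position` and `rope_angle` are illustrative names, and d = 128 with base 10000 is assumed.

```python
def interpolate_position(m, n_train=4096, n_inf=32768):
    """Map a real position m in [0, n_inf) into the trained range [0, n_train)."""
    return m * (n_train / n_inf)


def rope_angle(position, i, d=128, base=10000.0):
    """Rotation angle for dimension pair i at a (possibly fractional) position."""
    return position * base ** (-2 * i / d)


last = 32767
print(interpolate_position(last))    # 4095.875, inside the trained range
# Adjacent tokens are now only 1/8 of a position apart.
print(interpolate_position(1001) - interpolate_position(1000))   # 0.125
# Largest angle for the fastest pair (i = 0): raw vs. interpolated.
print(rope_angle(last, 0), rope_angle(interpolate_position(last), 0))
```

The only change to the RoPE computation is the extra multiplication inside `interpolate_position`; the rotation itself is untouched.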
NTK-aware interpolation (Neural Tangent Kernel scaling) modifies the RoPE base (the constant 10000) rather than the positions: instead of θ_i = 10000^{−2i/d}, use θ_i = (scale × 10000)^{−2i/d}. At high frequencies (small i), the rotation changes fast with position; keep these as-is for local structure. At low frequencies (large i), the rotation changes slowly; these encode long-range dependencies and benefit most from extending the range. By scaling the base, NTK naturally compresses the low-frequency dimensions while leaving high-frequency ones mostly intact. YaRN extends this with an interpolation factor that varies continuously across frequency bands.
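To see how unevenly base scaling treats the frequency bands, the sketch below compares each pair's frequency before and after multiplying the base by the scale, using the form given in the text. The function names are illustrative and d = 128 is an assumption. Commonly used implementations raise the scale to the power d/(d−2) so that the slowest pair is slowed by exactly the scale factor; that variant is not shown here.

```python
def rope_frequencies(d=128, base=10000.0):
    """theta_i = base^(-2i/d) for each of the d/2 dimension pairs."""
    return [base ** (-2 * i / d) for i in range(d // 2)]


def ntk_frequencies(scale, d=128, base=10000.0):
    # The form given in the text: multiply the base by the scale factor.
    return rope_frequencies(d, base * scale)


scale = 8
orig = rope_frequencies()
ntk = ntk_frequencies(scale)

for i in (0, 16, 32, 63):
    # ratio = how much slower pair i rotates after rescaling the base
    print(i, round(ntk[i] / orig[i], 4))
```

The ratio equals scale^(−2i/d): exactly 1 for the fastest pair (i = 0) and approaching 1/scale for the slowest pairs, which is the frequency-dependent compression described above.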
Critically, position interpolation requires no architectural change and only minimal compute overhead: one extra multiplication per position during RoPE computation. The extension is purely about what angle you feed to the existing rotation machinery.
The top track shows raw token positions. The bottom track shows interpolated positions after rescaling. Drag the extension ratio slider to see how positions get squashed. Watch how adjacent positions become hard to distinguish at high extension ratios.
You decide to use a sliding window: keep only the most recent L tokens in the KV cache and discard everything older. Memory is now bounded at O(L) regardless of total sequence length. But when you implement this and run a long sequence, perplexity explodes the moment the context length exceeds L and the initial tokens fall out of the window. Why?
The reason is subtle and reveals something deep about how autoregressive transformers work. Examine an attention weight matrix for any large LLM (LLaMA, GPT, Falcon) on a long sequence: the first few tokens always receive disproportionately large attention scores, regardless of their content. A token at position 500 attends heavily to position 0 even when position 0 is just a newline or system prompt character. These are attention sinks — tokens that absorb probability mass without contributing semantic content.
Here's the argument. Consider a query at position m attending to positions 0..m. Each attention weight is a_i = exp(q·k_i/√d) / Σ_j exp(q·k_j/√d), so the weights are positive and always sum to 1. The model cannot attend to "nothing": even when no earlier token is relevant (for example, while processing punctuation or newlines), the full unit of probability mass has to land somewhere. The initial tokens (positions 0, 1, 2, 3) are visible to every later query under causal masking and, with a full cache, are never evicted, so they become the default recipients of this "floating" attention probability.
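A toy softmax makes the effect of evicting a sink concrete. The logits below are made-up values chosen only for illustration (position 0 plays the sink); this is a sketch of the mechanism, not a measurement from any model.

```python
import math


def softmax(logits):
    m = max(logits)                               # subtract max for stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits for one query: position 0 acts as a sink (large logit),
# the remaining positions are only weakly relevant.
logits = [4.0, 0.1, 0.0, 0.2, -0.1, 0.3]

with_sink = softmax(logits)
without_sink = softmax(logits[1:])                # sink evicted from the cache

print(f"sink weight:            {with_sink[0]:.2f}")
print(f"last token, with sink:  {with_sink[-1]:.3f}")
print(f"last token, sink gone:  {without_sink[-1]:.3f}")
```

Because the weights must still sum to 1, removing the sink forces its mass onto the remaining tokens, so every one of them receives a much larger weight than the model saw during training.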
Song Han's lab first observed this in the SpAtten project in 2021 (sparse attention for edge inference), but the explanation came in 2023. A key experiment: replace the first 4 tokens of the input with 4 literal "\n" characters (semantically empty). Perplexity stays the same as with the original initial tokens. The model doesn't care about the semantic content of the sinks — it only needs the sink positions to exist. This proved the sink effect is positional, not semantic.
Attention sinks appear in other transformer architectures too. In Vision Transformers (ViT), sink tokens correspond to low-semantic-content background patches. In BERT, the [SEP] token (which appears at the end of every sentence and is always visible) becomes the attention sink. The phenomenon is universal: any transformer trained with causal or structured masking will develop sink tokens at systematically-accessible positions.
Each row is a query token (recent context). Each column is a key token being attended to. Brighter = higher attention weight. Notice: the first few columns (initial tokens) are bright for all rows — these are the sinks. The diagonal is the local/recent context. Use the slider to change how many sink tokens are visible.
Now that we understand attention sinks, the fix is obvious: keep the sinks in the cache. StreamingLLM (Xiao et al., 2023, MIT + Meta) combines a small fixed set of sink tokens (the first 4 tokens) with a standard sliding window of recent tokens. The total cache size is constant: sink_size + window_size, regardless of total sequence length.
Algorithm: maintain two memory regions. Region A: the first 4 tokens (the sinks) — these are never evicted. Region B: a rolling window of the most recent W tokens — when the window is full, the oldest token in region B is evicted. Every new token attends to: (a) all sink tokens, (b) all tokens in the current window. Total attention budget = 4 + W, which is constant.
```python
import torch


class StreamingKVCache:
    """KV cache for one layer: a few permanent sink tokens plus a rolling recent window."""

    def __init__(self, sink_size=4, window_size=512):
        self.sink_size = sink_size              # keep the first N tokens forever
        self.window_size = window_size          # rolling buffer of recent tokens
        self.sink_k, self.sink_v = [], []       # up to sink_size tensors
        self.recent_k, self.recent_v = [], []   # up to window_size tensors

    def update(self, new_k, new_v, layer_idx=None):
        # new_k, new_v: [batch, heads, 1, head_dim] for a single new token.
        # layer_idx is unused: this sketch holds one layer, so keep one instance per layer.
        if len(self.sink_k) < self.sink_size:
            self.sink_k.append(new_k)           # fill the sinks first
            self.sink_v.append(new_v)
        else:
            if len(self.recent_k) >= self.window_size:
                self.recent_k.pop(0)            # evict the oldest recent token
                self.recent_v.pop(0)
            self.recent_k.append(new_k)
            self.recent_v.append(new_v)

    def get_kv(self):
        # Concatenate (sinks || recent window) along the sequence axis.
        k = torch.cat(self.sink_k + self.recent_k, dim=-2)
        v = torch.cat(self.sink_v + self.recent_v, dim=-2)
        return k, v  # each: [batch, heads, cached_tokens, head_dim]
```
One subtle implementation detail: position IDs must use cache positions, not original text positions. When a token at text position 600 is stored in the window at cache position 20 (sink_size=4, so window position 16), its RoPE encoding should use position 20, not 600. Otherwise the model's positional signal is wrong for the recent tokens. StreamingLLM re-assigns positions based on where in the cache each token lives.
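The re-assignment can be sketched as a mapping from text positions to cache positions. `cached_positions` is a hypothetical helper written for this illustration, with 0-based indices assumed throughout; it is not StreamingLLM's actual code.

```python
def cached_positions(tokens_seen, sink_size=4, window_size=512):
    """Return (text_position, cache_position) pairs for every token still cached."""
    sinks = list(range(min(sink_size, tokens_seen)))
    window_start = max(sink_size, tokens_seen - window_size)
    recent = list(range(window_start, tokens_seen))
    kept = sinks + recent
    # RoPE is applied with the index inside the cache, not the original index.
    return [(text_pos, cache_pos) for cache_pos, text_pos in enumerate(kept)]


# After 1096 tokens the window holds text positions 584..1095.
pairs = cached_positions(tokens_seen=1096)
print(len(pairs))            # 516 = 4 sinks + 512 recent tokens
print(dict(pairs)[600])      # 20: window index 16, offset by the 4 sinks
```

The largest position the model ever sees is sink_size + window_size − 1, no matter how many tokens have streamed past.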
Performance numbers: on LLaMA-2, Falcon, MPT, and Pythia tested up to 4 million tokens, StreamingLLM maintains perplexity within 0.2–0.3 points of the sliding-window-with-recomputation baseline (the gold standard for quality). Compared to naive window attention (which crashes when initial tokens are evicted), StreamingLLM has PPL ~5.40 vs PPL ~5158. Memory is constant at (sink_size + window_size) × KV_per_token. Speedup vs recomputation baseline: up to 22.2× on long sequences.
An additional enhancement: pre-train with a dedicated sink token prepended to every training sequence. This single learnable token becomes the sole attention sink — you only need 1 sink token instead of 4, freeing 3 extra cache slots for recent content. Models pre-trained this way with a sink token use it exclusively and need no other sinks.
StreamingLLM keeps a fixed window of recent tokens plus sinks. But recent ≠ important. What if the key fact in a long document appears 50K tokens ago — well outside any practical window? We need content-aware eviction: keep the tokens that matter, regardless of recency.
H2O (Heavy-Hitter Oracle) (Zhang et al., 2023) builds on a simple observation: a small fraction of tokens receive the vast majority of cumulative attention across all layers and heads. Call these the heavy hitters. If we can identify and keep the heavy hitters plus a recent window, we can dramatically compress the KV cache while preserving quality.
The empirical finding: for most LLMs (OPT, LLaMA, GPT-NeoX), fewer than 20% of tokens account for 80%+ of total attention mass. These heavy hitters are consistent across layers and heads — a token that's semantically central (a named entity, a key fact) tends to be a heavy hitter in most heads. This is not random — it reflects the information structure of language.
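One way to quantify this concentration is to ask how few tokens are needed to cover a given share of the accumulated attention. The sketch below uses made-up scores for 20 tokens; `heavy_hitter_fraction` is an illustrative name, not part of H2O.

```python
def heavy_hitter_fraction(scores, mass=0.8):
    """Smallest fraction of tokens whose accumulated attention covers `mass` of the total."""
    ordered = sorted(scores, reverse=True)
    target = mass * sum(ordered)
    running = 0.0
    for count, s in enumerate(ordered, start=1):
        running += s
        if running >= target:
            return count / len(ordered)
    return 1.0


# Hypothetical accumulated attention for 20 tokens: a few dominate.
scores = [9.0, 7.5, 6.0] + [0.3] * 17
print(heavy_hitter_fraction(scores))   # 0.15: 3 of 20 tokens hold 80% of the mass
```

Running the same function on real accumulated attention scores would show whether a given model and prompt actually follow the 20%/80% pattern.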
The algorithm is simple. Maintain a score for each token in the KV cache: its accumulated attention weight summed over all queries so far. When the cache is full and a new token arrives, evict the token with the lowest score among the older tokens. The most recent tokens are exempt from eviction: they have not yet had time to accumulate attention, so scoring them directly would evict them prematurely (the "cold start" problem).
```python
import torch


class H2OKVCache:
    """Single-head sketch: keep heavy hitters (by accumulated attention) plus a recent window."""

    def __init__(self, budget=512, recent_ratio=0.2):
        self.budget = budget                       # max tokens in the cache
        self.recent = int(budget * recent_ratio)   # newest tokens, never evicted
        self.heavy = budget - self.recent          # slots left for heavy hitters
        self.scores = []                           # accumulated attention per cached token
        self.keys, self.values = [], []

    def update(self, new_k, new_v, attn_weights: torch.Tensor):
        # attn_weights: [num_queries, cache_len], the attention each query gave
        # to each token currently in the cache (before new_k is added).
        token_scores = attn_weights.sum(dim=0).tolist()  # total attention received
        for i, s in enumerate(token_scores):
            self.scores[i] += s
        self.keys.append(new_k)
        self.values.append(new_v)
        self.scores.append(0.0)                    # new token starts with no score
        if len(self.keys) > self.budget:
            # Only tokens older than the recent window are eviction candidates.
            # Slicing by length avoids the empty list that [:-0] gives when recent == 0.
            candidates = self.scores[:len(self.scores) - self.recent]
            evict_idx = candidates.index(min(candidates))
            self.keys.pop(evict_idx)
            self.values.pop(evict_idx)
            self.scores.pop(evict_idx)
```
H2O achieves 5× KV cache compression (a 20% budget) on OPT-6.7B with less than 1 perplexity point loss on text generation tasks. On summarization (CNN/DM) and question answering (TriviaQA), H2O with 20% budget retains over 90% of full-cache accuracy. The baseline (random eviction) with 20% budget drops to near-zero accuracy, so which tokens are kept is what matters.
Bars show simulated accumulated attention scores for 32 tokens. Budget slider controls how many tokens are kept. Green = kept (heavy hitters + recent), red = evicted. Compare to random eviction on the right.
Eviction strategies (StreamingLLM, H2O) reduce the number of tokens in the KV cache. Quantization reduces the bits per token. These are orthogonal: you can apply both, or choose one based on your use case. Together, INT4 KV + H2O can achieve 8–16× total KV memory reduction.
INT8 KV quantization is straightforward: store K and V tensors as INT8 instead of FP16. Per-token scaling: for each token's key vector k ∈ ℝ^d, compute scale s = max|k| / 127, quantize k_q = round(k / s), store (k_q, s) instead of k. At attention time, dequantize: k̃ = k_q × s. Memory halved, with error bounded by ‖k − k̃‖_∞ ≤ s/2.
INT4 KV quantization is a 4× reduction relative to FP16. The challenge: INT4 has only 16 levels (−8 to 7 for signed). Per-token scaling still works, but error is larger: quantization step s = max|k| / 7 vs s = max|k| / 127 for INT8. If max|k| = 5.0 in FP16, the INT4 step is 5/7 ≈ 0.71, so a value like 2.3 becomes round(2.3 / 0.71) = 3 and dequantizes to 3 × 0.71 ≈ 2.14, a rounding error of about 0.16; the worst case is half a step, s/2 ≈ 0.36. The INT8 step is 5/127 ≈ 0.039, with worst-case error ≈ 0.02. INT4 errors are ~18× larger per element (127/7).
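Symmetric per-token quantization is short enough to write out in full. This is a minimal plain-Python sketch: the helper names and the example key vector are made up for illustration, and a real kernel would pack the integers and operate on tensors.

```python
def quantize_per_token(k, bits):
    """Symmetric per-token quantization: one scale for the whole vector."""
    q_max = 2 ** (bits - 1) - 1                   # 127 for INT8, 7 for INT4
    # Fall back to scale 1.0 for an all-zero vector to avoid dividing by zero.
    scale = max(abs(x) for x in k) / q_max or 1.0
    k_q = [round(x / scale) for x in k]           # integers in [-q_max, q_max]
    return k_q, scale


def dequantize(k_q, scale):
    return [q * scale for q in k_q]


k = [5.0, 2.3, -1.7, 0.4, -3.9, 0.05]             # made-up key vector

for bits in (8, 4):
    k_q, s = quantize_per_token(k, bits)
    k_hat = dequantize(k_q, s)
    max_err = max(abs(a - b) for a, b in zip(k, k_hat))
    print(f"INT{bits}: step={s:.4f}  max error={max_err:.4f}  bound={s / 2:.4f}")
```

In both cases the observed maximum error stays within half a quantization step, and the INT4 step is 127/7 ≈ 18 times the INT8 step.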
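For LLaMA-2-70B with GQA (8 KV heads, 128 head dim, 80 layers): FP16 at N=4096 with batch size 1: 1 × 80 × 2 × 4096 × 8 × 128 × 2 bytes = 1.34 GB. INT8: 0.67 GB. INT4: 0.34 GB. At N=128K: FP16 = 41.9 GB; INT4 = 10.5 GB. With model weights in INT4 (~35 GB), the INT4 cache brings the 128K total to about 45.5 GB, which fits comfortably on a single A100-80GB; with an FP16 cache the total is about 77 GB and leaves almost no headroom. On an A100-40GB, the ~5 GB left after the weights holds roughly 61K tokens of INT4 KV cache.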
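The same formula can be turned around to ask how much context fits in a fixed memory budget. The sketch below uses the LLaMA-2-70B GQA figures from the text; the function names are illustrative, batch size 1 is assumed, and the storage for quantization scales is ignored.

```python
def kv_bytes_per_token(n_layers, n_kv_heads, head_dim, bits):
    """Bytes of K and V stored per token (scales/zero-points not counted)."""
    return n_layers * 2 * n_kv_heads * head_dim * bits / 8


def max_context(budget_bytes, **cfg):
    """Longest context whose KV cache fits in budget_bytes."""
    return int(budget_bytes // kv_bytes_per_token(**cfg))


llama2_70b = dict(n_layers=80, n_kv_heads=8, head_dim=128)

for bits in (16, 8, 4):
    per_tok = kv_bytes_per_token(bits=bits, **llama2_70b)
    gb_128k = per_tok * 128_000 / 1e9
    fit = max_context(5e9, bits=bits, **llama2_70b)
    print(f"{bits:>2}-bit: {gb_128k:5.1f} GB at 128K tokens, {fit:,} tokens fit in 5 GB")
```

Changing the 5 GB budget to whatever is left after loading the weights on a given GPU gives a quick feasibility check before any deployment work.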
There are additional tricks for KV quantization. Per-channel quantization: different head dimensions have different dynamic ranges — quantize each dimension independently. Per-key quantization: for each new key/value vector, use its own scale. This is more expensive (one scale per vector), but more accurate than a shared scale across multiple tokens. Grouped quantization: group G adjacent tokens and share a scale — balances per-token accuracy vs scale storage overhead (G=32–64 works well in practice).
A second quantization target is the attention score matrix itself. The scores (before softmax) are FP16; for long sequences, the score matrix is N × N. Quantizing scores to FP8 or INT8 saves memory during the prefill computation. FlashAttention's tiling approach (Chapter 9) sidesteps this by never materializing the full N × N matrix at all: it computes attention tile by tile in on-chip SRAM and writes only the output and per-row softmax statistics back to high-bandwidth memory (HBM).
StreamingLLM and H2O apply the same compression policy uniformly to all attention heads. But not all heads are equal. DuoAttention (Xiao et al., MIT, 2024) makes a key observation: within a single LLM layer, some heads are retrieval heads (they attend to semantically relevant tokens anywhere in the full context) and others are streaming heads (they attend only to recent tokens and attention sinks). These require fundamentally different KV cache policies.
Retrieval heads need the full KV cache — if you compress their cache, the model can no longer recall distant facts. Streaming heads never attend far back — their "effective window" is just sinks + recency. For streaming heads, a constant-size StreamingLLM-style cache is completely sufficient with no quality loss.
The DuoAttention algorithm: (1) identify retrieval vs. streaming heads using a passkey-retrieval task (embed a random number at a random position in a very long context, ask the model to recall it — heads whose attention patterns actually locate the key are retrieval heads); (2) during inference, apply full KV cache to retrieval heads (typically 25–50% of heads) and streaming cache (sinks + window) to streaming heads.
The head identification is done by assigning a learnable gate α_{i,j} ∈ [0,1] to each KV head (i=layer, j=head). During a short training phase on the passkey-retrieval synthetic dataset (10 random passkey sequences, 32 positions each, very long context), the gate blends full attention and streaming attention outputs: attn_{i,j} = α_{i,j} × full_attn + (1 − α_{i,j}) × streaming_attn. A Lasso (L1) regularization pushes gates toward 0 (streaming), while the training loss keeps gates near 1 (retrieval) for heads whose full attention is actually needed. After training, heads with α > 0.5 are classified as retrieval heads.
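The gating and classification steps can be sketched without any training loop. Everything below is illustrative: the gate values are invented, the `weight` on the L1 term is an arbitrary placeholder, and the helper names are not taken from the DuoAttention code.

```python
def blend_head_output(alpha, full_out, streaming_out):
    """Gated mix of one head's full-attention and streaming-attention outputs."""
    return [alpha * f + (1.0 - alpha) * s for f, s in zip(full_out, streaming_out)]


def l1_penalty(gates, weight=0.05):
    """Lasso term on the gates; `weight` is an illustrative value."""
    return weight * sum(abs(a) for row in gates for a in row)


def classify_heads(gates, threshold=0.5):
    """True = retrieval head (full cache), False = streaming head."""
    return [[a > threshold for a in row] for row in gates]


# Hypothetical trained gates for 2 layers x 4 KV heads.
gates = [[0.02, 0.91, 0.10, 0.05],
         [0.88, 0.03, 0.76, 0.97]]

labels = classify_heads(gates)
retrieval_fraction = sum(map(sum, labels)) / 8
print(labels)
print(retrieval_fraction)   # 0.5: half the heads keep the full KV cache
```

With half the heads on a constant-size streaming cache, the KV memory for long contexts approaches half of the full-cache figure.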
The head classification is surprisingly stable across models. Retrieval heads tend to cluster in certain layers: early layers tend to have more streaming heads (they process local syntactic patterns), while middle and late layers have more retrieval heads (they handle semantic long-range dependencies). This layer-wise pattern matches our intuition about how transformers process hierarchical information.
This showcase lets you explore the three memory strategies side-by-side: full KV cache (linear growth), naive window (constant but crashes on sink eviction), and StreamingLLM (constant, sink-stable). Watch how memory grows as you stream tokens, and see the proxy-perplexity penalty for each strategy. Drag the sliders to find the break-even point where StreamingLLM's constant memory starts to pay off.
Left panel: KV cache memory (GB) as token count grows. Right panel: proxy perplexity (lower is better). At the "sink eviction point," naive window crashes. StreamingLLM stays stable. Animate to stream tokens, or set context length manually.
We've built a complete toolkit for long-context LLM inference. Here's the cheat sheet — every technique, its mechanism, its complexity, and when to use it.
| Technique | Mechanism | Memory | Quality Loss | Use Case |
|---|---|---|---|---|
| Full KV Cache | Cache all tokens | O(N) — unbounded | None (baseline) | Short contexts (<32K) |
| RoPE + Interpolation | Rescale positions into trained range | O(N) — still linear | Small after fine-tune | Extending training range |
| StreamingLLM | Sinks (4) + sliding window (W) | O(W) — constant | ~+0.3 PPL | Streaming/chat — infinite tokens |
| H2O Eviction | Keep top-k attention-weighted tokens | O(budget) — constant | ~+1 PPL at 20% budget | Retrieval, long-doc QA |
| INT4 KV | Quantize K/V to 4-bit per-token | O(N) but 4× smaller | Negligible (<0.5 PPL) | Any long-context serving |
| DuoAttention | Full cache for retrieval heads, streaming for rest | O(N), ~2–4× smaller (25–50% retrieval heads) | ~3% on RULER | Very long context (100K+) |
| FlashAttention | Tiled online-softmax, no O(N²) materialization | O(N) — same tokens | None (exact) | Memory-efficient prefill kernel |
FlashAttention's role (brief, since it's a compute optimization, not a memory-count reduction): standard attention materializes the N × N score matrix in HBM. For N = 32K, that's 32K × 32K × 2 bytes = 2 GB of HBM reads/writes just for the score matrix, per head, per layer. FlashAttention tiles the Q, K, V matrices into blocks that fit in on-chip SRAM (about 192 KB per streaming multiprocessor on A100, roughly 20 MB in total), computes attention block-by-block with an online softmax algorithm, and never writes the full N × N matrix to HBM. Memory for attention is reduced from O(N²) to O(N): only the output and the running softmax statistics need HBM. This makes prefill (the forward pass over the full prompt) dramatically faster at long context.
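The online softmax at the heart of the tiling trick can be shown for a single query row. This is a plain-Python sketch of the recurrence only, with made-up inputs and hypothetical function names; it omits the 1/√d scaling and is not the FlashAttention kernel, which tiles queries as well and runs in GPU SRAM.

```python
import math


def attention_row_reference(q, keys, values):
    """Standard attention for one query: materializes all N scores."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    dim = len(values[0])
    return [sum(w[j] * values[j][c] for j in range(len(values))) / z for c in range(dim)]


def attention_row_tiled(q, keys, values, block=2):
    """Same result, processing keys block by block with running statistics."""
    m = -math.inf                     # running max of the scores seen so far
    z = 0.0                           # running softmax denominator
    acc = [0.0] * len(values[0])      # running un-normalized output
    for start in range(0, len(keys), block):
        ks, vs = keys[start:start + block], values[start:start + block]
        scores = [sum(a * b for a, b in zip(q, k)) for k in ks]
        m_new = max(m, max(scores))
        rescale = math.exp(m - m_new)             # correct earlier blocks for the new max
        z *= rescale
        acc = [a * rescale for a in acc]
        for s, v in zip(scores, vs):
            w = math.exp(s - m_new)
            z += w
            acc = [a + w * x for a, x in zip(acc, v)]
        m = m_new
    return [a / z for a in acc]


q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.2, 0.7], [-0.3, 0.4], [0.9, -0.5], [0.0, 1.0]]
values = [[1.0, 2.0], [0.0, 1.0], [3.0, -1.0], [0.5, 0.5], [2.0, 2.0]]
ref = attention_row_reference(q, keys, values)
tiled = attention_row_tiled(q, keys, values, block=2)
print(all(abs(a - b) < 1e-12 for a, b in zip(ref, tiled)))
```

The tiled version only ever holds one block of scores plus three running quantities (max, denominator, accumulator), which is why the full score matrix never needs to exist.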
Looking ahead: State Space Models (SSMs) like Mamba replace attention entirely with a recurrent structure that has O(N) total compute and O(1) per-token generation memory — the hidden state is fixed-size regardless of context length. SSMs are a fundamentally different architecture from transformers, where the "KV cache" is replaced by a fixed-size state vector. Hybrid models (Jamba = Mamba + Transformer layers) combine the long-context efficiency of SSMs with the in-context learning power of attention. This is the subject of future lectures.