TinyML & Efficient Deep Learning · MIT 6.5940 · Lecture 12

Efficient Transformers & LLMs

Attention is O(N²) — double the context, quadruple the compute. And generating one token at a time means re-reading the entire key-value cache from memory on every step. This lesson derives both bottlenecks from arithmetic, then works through every practical fix: KV cache mechanics, MQA/GQA cache shrinkage (with exact MB formulas), sparse and linear attention, RoPE/ALiBi length extrapolation, and the prefill-vs-decode split that determines what hardware you actually need.

Prerequisites: TinyML L2 (NN Building Blocks) — Transformer block cost formulas. TinyML L1 (Efficiency Metrics) — arithmetic intensity, roofline model.
10 Chapters · 5 Live Canvases · Derived From First Principles

Chapter 0: The Two Walls

You've built a Transformer. It achieves state-of-the-art on your benchmark. Now you want to deploy it — run it in a server to answer queries, or on a device at the edge. You hit two walls immediately, and they're completely different problems.

Wall 1 — the quadratic wall. Training on sequences of length 512 was fine. You try 4096-token context (about 6 pages of text). Attention memory goes up 64×. Attention compute goes up 64×. The model that fit comfortably in memory now crashes with OOM. The culprit: attention's O(N²) scaling with sequence length N.

Wall 2 — the memory-bandwidth wall. Even at modest context lengths, generating one token at a time is slow. Not because of FLOPs — the GPU's ALUs are mostly idle. The bottleneck is reading the model's weights (and the growing key-value cache) from GPU memory on every single decode step. You're memory-bandwidth bound, not compute bound. Adding more tensor cores won't help.

The key insight: Training and inference have completely different bottlenecks. Training (parallel over all tokens) is compute-bound — throw more FLOPs at it. Autoregressive decode (one token at a time) is memory-bandwidth-bound — the bottleneck is how fast you can read the KV cache from DRAM. Efficient LLM inference must attack BOTH walls with different tools.

The efficiency stack for Transformers is therefore split into two halves. The first half attacks the quadratic wall: sparse attention, linear attention, multi-scale approaches. The second half attacks the bandwidth wall: KV cache compression (MQA/GQA), better positional encodings for length generalization, and quantization of the cache itself.

We'll derive both problems numerically before looking at solutions, because only when you've seen that a LLaMA-2-70B-sized model with full multi-head attention needs a 10 GB KV cache for a single 4096-token sequence do the solutions feel necessary rather than optional.

The Two Walls: Attention Cost vs Sequence Length

Drag the slider. Watch compute cost (FLOPs, orange) grow as N² while KV-cache memory (teal) grows linearly. At N=4096 the compute bar dwarfs everything — but the memory bar is what kills your decode step.

Sequence length N 512
An LLM serving system is generating tokens slowly despite having a powerful GPU. A profiling tool shows the GPU ALU utilization is only 12% during decode. What is the most likely bottleneck, and what kind of fix will help?

Chapter 1: Attention Mechanics — QKV from First Principles

Before we can quantify the problem, we need to understand what attention actually computes. Start with an input sequence of N tokens, each represented as a vector of dimension d. Stack them into a matrix X of shape [N, d].

Queries, Keys, and Values are three linear projections of X. For head h with head dimension d_h = d / H (where H is the number of heads):

Q = X · W_Q    K = X · W_K    V = X · W_V

Each W is shape [d, d_h]. So Q, K, V are each [N, d_h]. The analogy: Q is a search query, K is a document title, V is the document content. The attention mechanism asks "how relevant is each document to this query?" and returns a weighted average of content.

Attention(Q, K, V) = softmax( Q · K^T / √d_h ) · V

Step by step: (1) compute Q·K^T, shape [N, N]: each entry is the dot product of token i's query with token j's key. (2) Divide by √d_h to prevent softmax saturation (dot products grow with d_h). (3) Apply softmax row-wise → attention weights, shape [N, N], rows sum to 1. (4) Multiply by V, shape [N, d_h] → output, shape [N, d_h].

Why scale by √d_h? If Q and K have independent zero-mean, unit-variance entries, each entry of Q·K^T has variance d_h (sum of d_h products). Without scaling, the logits' standard deviation grows as √d_h, pushing softmax into saturation where gradients vanish. Dividing by √d_h restores unit variance: the softmax sees logits with standard deviation ≈ 1 regardless of d_h.
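The variance argument is easy to check empirically. The sketch below is illustrative only: it assumes i.i.d. standard-normal entries for q and k (real activations are not exactly like this), and `dot_product_variance` is a helper name made up for this example.

```python
import random
import statistics

def dot_product_variance(d_h, n_samples=4000, seed=0):
    """Empirical variance of q.k for q, k with i.i.d. unit-variance entries."""
    rng = random.Random(seed)
    dots = []
    for _ in range(n_samples):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_h)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d_h)]
        dots.append(sum(qi * ki for qi, ki in zip(q, k)))
    raw_var = statistics.pvariance(dots)
    # Dividing every logit by sqrt(d_h) divides the variance by d_h
    scaled_var = statistics.pvariance([x / d_h ** 0.5 for x in dots])
    return raw_var, scaled_var

if __name__ == "__main__":
    for d_h in (16, 64, 128):
        raw, scaled = dot_product_variance(d_h)
        print(f"d_h={d_h:4d}  var(q.k)={raw:7.1f}  var(q.k/sqrt(d_h))={scaled:.2f}")
```

The raw variance should land near d_h and the scaled variance near 1, up to sampling noise.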

Multi-Head Attention (MHA) runs H independent attention heads in parallel, each with its own W_Q, W_K, W_V. The H outputs (each [N, d_h]) are concatenated back to [N, d] and projected through W_O. Each head can learn a different attention pattern: one head might focus on syntactic relationships, another on coreference.

Causal masking for autoregressive models: token i can only attend to tokens j ≤ i. This is implemented by adding −∞ to attention logits at positions j > i before softmax, which makes those weights exactly zero. The attention matrix becomes lower-triangular.

python
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
    # Q, K, V: [batch, heads, seq_len, d_head]
    d_h = Q.shape[-1]
    scores = Q @ K.transpose(-2, -1) / d_h**0.5  # [B, H, N, N]
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    weights = F.softmax(scores, dim=-1)  # [B, H, N, N]
    return weights @ V                   # [B, H, N, d_head]

# For N=512, d=768, H=12: d_h = 64
# Q @ K.T shape: [B, 12, 512, 512] — the N×N attention matrix
# Memory for this matrix in FP32: 512×512×4 bytes × 12 heads = 12.6 MB per sequence in the batch
In standard MHA with d=1024, H=16 heads, and sequence length N=512, what is the shape of the attention weight matrix (before multiplying by V) for a single head?

Chapter 2: The Quadratic Wall — Counting FLOPs and Memory

Let's count exactly what happens when we compute attention for a sequence of length N with model dimension d and H heads (head dimension d_h = d/H).

Step 1: QKV projections. Three matrix multiplications: X[N,d] · W[d,d_h] each. Cost per head and per projection: N · d · d_h multiply-adds × 2 (mul+add) = 2N·d·d_h FLOPs. For all H heads and all three projections: 6N·d² FLOPs total (since H·d_h = d).

Step 2: Q·K^T. Matrix multiply [N, d_h] · [d_h, N] → [N, N]. Cost per head: 2N²d_h FLOPs. For H heads: 2N²d FLOPs. This is the quadratic term.

Step 3: Attn · V. Matrix multiply [N, N] · [N, d_h] → [N, d_h]. Same cost as Q·K^T: 2N²d FLOPs for all heads.

Step 4 — Output projection. [N, d] · [d, d] → [N, d]. Cost: 2Nd² FLOPs.

Total attention FLOPs per layer: 8Nd² (6Nd² for the Q, K, V projections + 2Nd² for the output projection) + 4N²d (Q·K^T + Attn·V). The FFN adds ~16Nd² per layer under the same 2-FLOPs-per-multiply-add convention (two matmuls between d and 4d, 8Nd² each). At short sequences, FFN dominates. At long sequences, the 4N²d term dominates, and it grows with N².

Let's plug in numbers. LLaMA-2-7B: d=4096, H=32, d_h=128, L=32 layers.

At N=512: quadratic attention cost = 4·512²·4096 = 4.3 GFLOPs per layer. FFN cost = 16·512·4096² = 137.4 GFLOPs per layer. FFN is 32× larger, so attention is cheap.

At N=4096: quadratic attention cost = 4·4096²·4096 = 275 GFLOPs per layer. FFN cost = 16·4096·4096² = 1,100 GFLOPs per layer. Now the quadratic term is a quarter of the FFN cost, and growing.

At N=32768 (long context): quadratic attention = 4·32768²·4096 = 17,600 GFLOPs per layer. FFN = 16·32768·4096² = 8,796 GFLOPs per layer. Attention is now 2× more expensive than FFN. The two are equal at N = 4d = 16,384.

Memory for the attention matrix is equally brutal. At N=4096, H=32, FP16: each layer stores [N, N] per head = 4096² × 32 heads × 2 bytes = 1,073 MB ≈ 1 GB per layer. FlashAttention exists to avoid materializing this matrix in HBM (it tiles into SRAM), but the FLOPs remain.
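The per-layer counts can be tabulated with a few lines of arithmetic. This is a sketch under one stated assumption: every matmul is counted at 2 FLOPs per multiply-add, which gives 8Nd² for the four projections, 4N²d for the two N×N matmuls, and 16Nd² for a two-matmul FFN with 4× expansion. The function names are made up for illustration.

```python
def attention_layer_flops(N, d):
    """FLOPs for one attention layer, counting 2 FLOPs per multiply-add."""
    qkv = 6 * N * d * d        # Q, K, V projections over all heads
    scores = 2 * N * N * d     # Q @ K^T over all heads
    weighted = 2 * N * N * d   # softmax weights @ V over all heads
    out = 2 * N * d * d        # output projection W_O
    return {"linear": qkv + out, "quadratic": scores + weighted}

def ffn_layer_flops(N, d, expansion=4):
    """FLOPs for a two-matmul FFN (d -> expansion*d -> d)."""
    return 2 * (2 * N * d * expansion * d)

def attention_matrix_bytes(N, H, bytes_per_elem=2):
    """Bytes needed to materialize the [N, N] score matrix for all H heads."""
    return N * N * H * bytes_per_elem

if __name__ == "__main__":
    d, H = 4096, 32
    for N in (512, 4096, 32768):
        attn = attention_layer_flops(N, d)
        ffn = ffn_layer_flops(N, d)
        print(f"N={N:6d}  quadratic={attn['quadratic'] / 1e9:9.1f} GFLOPs  "
              f"projections={attn['linear'] / 1e9:8.1f} GFLOPs  "
              f"ffn={ffn / 1e9:8.1f} GFLOPs  "
              f"score matrix={attention_matrix_bytes(N, H) / 2**20:9.1f} MiB")
```

Under this uniform convention the quadratic term matches the FFN at N = 4d and matches attention's own projection cost at N = 2d.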

Attention vs FFN FLOPs — The Crossover Point

Drag the slider to change sequence length N. Orange bar = attention FLOPs (grows N²). Teal bar = FFN FLOPs (grows N linearly). The crossover marks where attention dominates.

Sequence length N 512
For a model with d=4096 and a single layer, at approximately what sequence length N do the quadratic attention FLOPs (4N²d) equal the FFN FLOPs (16Nd², counting 2 FLOPs per multiply-add)?

Chapter 3: The KV Cache — Why Decode Is Different

During prefill (processing the prompt), all N tokens exist simultaneously. You compute Q, K, V for all of them in parallel — it's one big matrix multiply. This is compute-bound.

During decode (generating new tokens), you're generating one token at a time. At step t, you have one new token. You need to compute its query Q_t [1, d_h]. But attention requires Q_t to attend to ALL previous keys and values: keys K_0...K_{t-1} and values V_0...V_{t-1}. You can't throw away the past.

The naive solution: recompute K and V for all past tokens at every decode step. Cost: O(t) per step × O(T) steps = O(T²) total compute just for the decode phase. Unacceptable.

The KV cache is the solution: cache the keys and values from all previous steps. At decode step t, append Kt and Vt to the growing cache, then attend Qt over the full cache. No recomputation — the past is frozen.

python
# KV cache decode loop (sketch). model.embed, model.layers, layer.W_K / W_V,
# layer.forward_with_kv and model.lm_head are an assumed interface, not a real
# library API; a single head is shown to keep the shapes simple.
import torch

def decode_with_kv_cache(model, prompt_tokens, max_new_tokens):
    # Prefill: process the entire prompt in parallel and fill the cache
    K_cache = []  # one K tensor per layer, shape [N_ctx, d_h]
    V_cache = []  # one V tensor per layer, shape [N_ctx, d_h]

    x = model.embed(prompt_tokens)  # [N_prompt, d]
    for layer in model.layers:
        K_new = x @ layer.W_K  # [N_prompt, d_h]
        V_new = x @ layer.W_V  # [N_prompt, d_h]
        K_cache.append(K_new)
        V_cache.append(V_new)
        x = layer.forward_with_kv(x, K_new, V_new)  # causal attention over the prompt

    # The final hidden state of the last prompt position predicts the first new token
    next_token = model.lm_head(x[-1:, :]).argmax()
    tokens = [next_token]

    # Decode: one token at a time, each starting from the new token's embedding
    for step in range(max_new_tokens - 1):
        x_t = model.embed(next_token.unsqueeze(0))  # [1, d]
        for i, layer in enumerate(model.layers):
            K_t = x_t @ layer.W_K        # [1, d_h], new key
            V_t = x_t @ layer.W_V        # [1, d_h], new value
            K_cache[i] = torch.cat([K_cache[i], K_t], dim=0)  # append
            V_cache[i] = torch.cat([V_cache[i], V_t], dim=0)  # append
            # The layer forms Q_t = x_t @ W_Q and attends over the full cache:
            # [1, t+1] attention weights, no recomputation of past K and V
            x_t = layer.forward_with_kv(x_t, K_cache[i], V_cache[i])
        next_token = model.lm_head(x_t).argmax()
        tokens.append(next_token)
    return tokens

Now let's count the KV cache size. Each layer stores K and V. Each is shape [N_ctx, d_kv]. With standard MHA, d_kv = H × d_h = d; in general, with H_kv key-value heads, d_kv = H_kv × d_h.

KV cache size = B × L × 2 × N × H_kv × d_h × bytes

Where B = batch size, L = number of layers, 2 = K and V, N = sequence length (grows during decode), H_kv = number of KV heads, d_h = head dimension, bytes = 2 for FP16.
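The formula translates directly into a helper. This is a minimal sketch: `kv_cache_bytes` and `per_token_bytes` are illustrative names, and the result counts only the cache tensors themselves, ignoring allocator overhead and padding.

```python
def kv_cache_bytes(batch, layers, seq_len, kv_heads, head_dim, bytes_per_elem=2):
    """KV cache size = B x L x 2 x N x H_kv x d_h x bytes (the 2 covers K and V)."""
    return batch * layers * 2 * seq_len * kv_heads * head_dim * bytes_per_elem

def per_token_bytes(layers, kv_heads, head_dim, bytes_per_elem=2):
    """Bytes appended to the cache for every new token of one sequence."""
    return kv_cache_bytes(1, layers, 1, kv_heads, head_dim, bytes_per_elem)

if __name__ == "__main__":
    # A 70B-class shape with full MHA: L=80, H_kv=64, d_h=128, N=4096, FP16
    total = kv_cache_bytes(1, 80, 4096, 64, 128)
    print(total, "bytes =", total / 2**30, "GiB")
    print(per_token_bytes(80, 64, 128) / 2**20, "MiB per generated token")
```

Thinking in bytes per token is often handier than the total: the cache grows by that amount on every decode step, for every sequence in the batch.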

LLaMA-2-7B: B=1, L=32, H_kv=32, d_h=128, N=4096, FP16: 1×32×2×4096×32×128×2 = 2,147,483,648 bytes = 2 GB.

LLaMA-2-70B dimensions with full MHA: B=1, L=80, H_kv=64, d_h=128, N=4096, FP16: 1×80×2×4096×64×128×2 = 10,737,418,240 bytes = 10 GB. For a single sequence! The 70B model's weights are 140 GB in FP16, so the KV cache adds 7% of that for one user. (The released model uses GQA with 8 KV heads to shrink exactly this figure; see Chapter 5.)

Misconception to destroy: "The KV cache is a minor detail." For long-context LLM serving, the KV cache rivals or exceeds the model weights once requests are batched. At bs=16, N=4096, with the 70B MHA shape above: KV cache = 16×10 GB = 160 GB. That is more than the 140 GB of weights and fills the entire memory of two A100-80GB cards. At scale, the KV cache, not the weights, is the dominant memory consumer.
LLaMA-2-13B has L=40 layers, H=40 heads, d_h=128, and uses FP16. What is the KV cache size (in MB) for a single sequence of length N=2048?

Chapter 4: Prefill vs Decode — Two Completely Different Regimes

LLM inference has two phases with opposite bottlenecks. Understanding which phase you're in determines every optimization decision.

Prefill phase: You receive a prompt of N tokens and process all N simultaneously. Each projection is a fat matrix multiply: [N, d] × [d, d] = [N, d]. For N=512, d=4096: [512, 4096] × [4096, 4096] takes N × d × d ≈ 8.6 billion multiply-adds (17.2 GFLOPs) while reading the 4096² weights only once. The GPU's arithmetic units are fully busy. The ratio of FLOPs to bytes read is high. Prefill is compute-bound.

Decode phase: You generate one token at a time. Each decode step: input is [1, d] (just the new token). The weight matrices are still [d, d]. But [1, d] × [d, d] is a matrix-vector multiply. You read the entire weight matrix — d² × bytes — to perform only d² multiply-adds. The arithmetic intensity (FLOPs / bytes) is:

AI_decode = 2d² FLOPs / (2d² bytes) = 1 FLOP/byte

A100 GPU: 312 TFLOPS FP16, 2 TB/s HBM bandwidth. Ridge point = 312T / 2T = 156 FLOPs/byte. Any operation with AI < 156 is memory-bandwidth-bound. Decode has AI ≈ 1 FLOP/byte — 156× below the ridge. The GPU's compute units are 99.4% idle during decode. You're limited entirely by how fast you can read weights and KV cache from HBM.

The roofline conclusion: Adding a faster GPU (more TFLOPS) does almost nothing for decode latency. What matters: (1) higher memory bandwidth, (2) smaller working set (fewer bytes to read = less time). This is why KV cache compression (MQA, GQA, quantization) directly improves decode throughput — it shrinks the bytes read per step. More FLOPs cannot.

There's also a batch size effect. At larger batch sizes, multiple sequences are decoded in parallel. The weight matrix read is amortized across bs sequences: AI ≈ bs FLOPs/byte. At bs=156, you reach the ridge point — now you're compute-bound and adding FLOPs helps. For most serving scenarios, bs is 1–32, and you're still memory-bound.
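The roofline reasoning can be written down as a small calculator. This is a sketch that counts only weight traffic (KV cache and activation reads are ignored, as in the estimate above) and uses the A100 figures quoted in this chapter; the function names are illustrative.

```python
def ridge_point(peak_flops, bandwidth_bytes_per_s):
    """Arithmetic intensity (FLOPs/byte) at which compute and bandwidth balance."""
    return peak_flops / bandwidth_bytes_per_s

def decode_arithmetic_intensity(d, batch=1, bytes_per_weight=2):
    """FLOPs per byte for a [batch, d] x [d, d] matmul, counting weight reads only."""
    flops = 2 * batch * d * d
    bytes_read = bytes_per_weight * d * d
    return flops / bytes_read

def regime(ai, ridge):
    return "compute-bound" if ai >= ridge else "memory-bound"

def attainable_flops(ai, peak_flops, bandwidth_bytes_per_s):
    """Roofline: performance is capped by min(peak, AI x bandwidth)."""
    return min(peak_flops, ai * bandwidth_bytes_per_s)

if __name__ == "__main__":
    peak, bw = 312e12, 2e12   # A100: 312 TFLOPS FP16, 2 TB/s
    ridge = ridge_point(peak, bw)
    for bs in (1, 8, 32, 156, 512):
        ai = decode_arithmetic_intensity(4096, batch=bs)
        print(f"bs={bs:4d}  AI={ai:6.1f}  {regime(ai, ridge):13s}  "
              f"attainable={attainable_flops(ai, peak, bw) / 1e12:6.1f} TFLOPS")
```

With FP16 weights the model dimension cancels out, so the intensity depends only on batch size, which is why batching is the main lever for moving decode toward the ridge.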

Prefill vs Decode: Arithmetic Intensity on the Roofline

Toggle between prefill and decode. See where each falls on the roofline. Adjust batch size to watch AI grow. The dashed line is the A100's ridge point at 156 FLOPs/byte.

Batch size 1
A serving system has A100 GPUs with 2 TB/s memory bandwidth and 312 TFLOPS. During single-user decode (batch size=1), what is the arithmetic intensity, and what determines the tokens-per-second rate?

Chapter 5: MQA & GQA — Shrinking the KV Cache

The KV cache grows with H_kv (number of KV heads). In standard Multi-Head Attention (MHA), every head has its own K and V projections: H query heads, H key heads, H value heads. Cache size = 2 × L × N × H × d_h × bytes.

Multi-Query Attention (MQA) (Shazeer 2019): keep H query heads, but reduce to 1 key head and 1 value head. All H query heads share the same K and V. The attention computation per head h becomes: softmax(Q_h · K^T / √d_h) · V, where K and V are the single shared projections.

Cache reduction with MQA: H_kv drops from H to 1. Cache size = 2 × L × N × 1 × d_h × bytes. Reduction factor = H. For H=32: 32× smaller KV cache. LLaMA-2-7B at N=4096: 2 GB → 64 MB per sequence.

Grouped-Query Attention (GQA) (Ainslie et al. 2023): interpolate between MHA and MQA. Use G groups, each with one K/V head shared among H/G query heads. Hkv = G. Typical G = H/8: for H=32, G=4.

GQA cache = 2 × L × N × G × dh × bytes
Reduction vs MHA = H / G     (e.g., H=32, G=4 → 8× smaller)

LLaMA-2-70B uses GQA with H=64 query heads, G=8 KV groups (dh=128). Cache: 1×80×2×N×8×128×2 = 327,680×N bytes. At N=4096: 1.25 GB vs 10 GB for MHA — an 8× reduction.
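To see the three variants side by side, the sketch below tabulates cache size, reduction factor, and K/V projection parameters for one model shape. The helper names are illustrative, sizes are reported in GiB (2^30 bytes), and projection biases are ignored.

```python
def kv_cache_gib(layers, seq_len, kv_heads, head_dim, batch=1, bytes_per_elem=2):
    """KV cache size in GiB for the given number of KV heads."""
    return batch * layers * 2 * seq_len * kv_heads * head_dim * bytes_per_elem / 2**30

def kv_proj_params(d_model, kv_heads, head_dim):
    """Parameters in W_K plus W_V for one layer (biases ignored)."""
    return 2 * d_model * kv_heads * head_dim

def compare_variants(layers, q_heads, head_dim, seq_len, groups):
    """Return cache size and projection size for MHA, GQA (G=groups) and MQA."""
    d_model = q_heads * head_dim
    rows = {}
    for name, kv_heads in (("MHA", q_heads), ("GQA", groups), ("MQA", 1)):
        rows[name] = {
            "kv_heads": kv_heads,
            "cache_gib": kv_cache_gib(layers, seq_len, kv_heads, head_dim),
            "reduction": q_heads // kv_heads,
            "kv_proj_params": kv_proj_params(d_model, kv_heads, head_dim),
        }
    return rows

if __name__ == "__main__":
    # 70B-class shape from the text: L=80, H=64, d_h=128, N=4096, G=8
    for name, row in compare_variants(80, 64, 128, 4096, groups=8).items():
        print(f"{name}: H_kv={row['kv_heads']:2d}  cache={row['cache_gib']:.3f} GiB  "
              f"reduction={row['reduction']}x  W_K+W_V params={row['kv_proj_params']:,}")
```

The cache and the K/V projections shrink by the same factor H / H_kv, since both scale linearly with the number of KV heads.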

Why does quality degrade less with GQA than MQA? With MQA, all 32 query heads share a single K/V, so they can't specialize. GQA with G=4 gives each group of 8 query heads its own K/V, preserving some diversity. Empirically, GQA at G=H/8 roughly matches MHA quality on large models, while MQA shows small degradation. The sweet spot: G is set so the cache reduction is ~8× and accuracy stays close to MHA.

The projection matrices for MQA/GQA also shrink. In MHA, W_K and W_V are each [d, d] (d = H × d_h). In GQA with G groups: W_K and W_V are [d, G × d_h]. For G=4, H=32, d_h=128: W_K goes from [4096, 4096] to [4096, 512], which is 8× fewer parameters for the KV projections. Smaller projections also mean less memory bandwidth per decode step, a double win.

python
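# GQA projection: fewer KV heads than Q heads
import torch
import torch.nn as nn
import torch.nn.functional as F

class GroupedQueryAttention(nn.Module):
    def __init__(self, d, n_heads, n_kv_groups):
        super().__init__()
        assert d % n_heads == 0 and n_heads % n_kv_groups == 0
        self.n_heads = n_heads       # e.g. 32
        self.n_kv = n_kv_groups      # e.g. 4 (GQA) or 1 (MQA)
        self.d_h = d // n_heads      # head dim, e.g. 128
        self.W_Q = nn.Linear(d, d)                         # [d, H*d_h]
        self.W_K = nn.Linear(d, n_kv_groups * self.d_h)   # [d, G*d_h]
        self.W_V = nn.Linear(d, n_kv_groups * self.d_h)   # [d, G*d_h]
        self.W_O = nn.Linear(d, d)

    def forward(self, x, K_cache=None, V_cache=None):
        B, N, d = x.shape
        Q = self.W_Q(x).view(B, N, self.n_heads, self.d_h)   # [B,N,H,d_h]
        K = self.W_K(x).view(B, N, self.n_kv, self.d_h)     # [B,N,G,d_h]
        V = self.W_V(x).view(B, N, self.n_kv, self.d_h)     # [B,N,G,d_h]
        # Prepend cached K,V. The cache keeps only G heads: that is the memory saving.
        if K_cache is not None:
            K = torch.cat([K_cache, K], dim=1)               # [B,N_ctx,G,d_h]
            V = torch.cat([V_cache, V], dim=1)
        K_cache, V_cache = K, V
        # Expand K,V from G groups to H heads by repeating
        reps = self.n_heads // self.n_kv
        K = K.repeat_interleave(reps, dim=2)  # [B,N_ctx,H,d_h]
        V = V.repeat_interleave(reps, dim=2)  # [B,N_ctx,H,d_h]
        # Standard attention from here: heads to dim 1, then softmax(QK^T/sqrt(d_h))V
        Q, K, V = (t.transpose(1, 2) for t in (Q, K, V))      # [B,H,*,d_h]
        scores = Q @ K.transpose(-2, -1) / self.d_h ** 0.5    # [B,H,N,N_ctx]
        # Causal mask: query i sits at absolute position N_ctx-N+i
        N_ctx = K.shape[2]
        mask = torch.tril(torch.ones(N, N_ctx, device=x.device), diagonal=N_ctx - N).bool()
        scores = scores.masked_fill(~mask, float('-inf'))
        out = F.softmax(scores, dim=-1) @ V                   # [B,H,N,d_h]
        out = out.transpose(1, 2).reshape(B, N, d)
        return self.W_O(out), K_cache, V_cache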
LLaMA-3-8B uses H=32 query heads, G=8 KV groups, d_h=128, L=32 layers, FP16. What is the KV cache size reduction compared to MHA, and what is the actual cache size at N=2048, bs=1?

Chapter 6: Sparse & Linear Attention — Attacking the O(N²) Wall

MQA/GQA shrink the KV cache but don't change the O(N²) FLOPs of the attention computation itself. For very long contexts (N several times the model dimension d), the N² term becomes the dominant cost. Three families of approaches attack this.

1. Sparse / Local Attention. The key observation: in practice, most attention weights are near zero. Token i strongly attends to nearby tokens (local window) plus a few long-range tokens. Windowed attention restricts each token to attend only to a window of size W tokens around it. Cost drops from O(N²d) to O(NWd) — linear in N for fixed W.

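Local attention cost = 2 × N × W × d_h × H FLOPs for the scores, and the same again for the weighted sum     (O(NW) not O(N²))

Sliding window attention (Mistral, Longformer): each token attends only to a window of W nearby tokens: the W/2 before it and W/2 after it in a bidirectional encoder such as Longformer, or the previous W tokens in a causal decoder such as Mistral. With W=4096 and N=32768: cost is 4096/32768 = 1/8 of full attention. The effective receptive field can grow through layers: after L layers, each token has "seen" up to L×W tokens of context. For L=32, W=4096: effective field = 131,072 tokens.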

2. Strided / Sparse Global Attention. Combine local windows with a few "global" tokens (like [CLS]) that attend everywhere. BigBird: local + global + random attention. Cost O(N) with a larger constant. Works well for document understanding.

3. Linear Attention. The expensive step is materializing the N×N attention matrix: softmax(Q·K^T/√d_h)·V. What if we avoid this entirely? Using the kernel trick: replace softmax(q·k^T) with φ(q)·φ(k)^T where φ is a positive feature map. Then:

Output = ( φ(Q) · (φ(K)^T · V) ) / ( φ(Q) · (φ(K)^T · 1) )

By computing the inner parenthesis first, φ(K)^T·V of shape [d_h, d_h], and then multiplying by φ(Q) [N, d_h], the total cost is O(N·d_h²) instead of O(N²·d_h). Since d_h ≪ N for long contexts, this is a huge win.
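The reordering is pure associativity, which a small pure-Python sketch can verify. Assumptions: the feature map is φ(x) = elu(x) + 1, one common choice rather than anything this lesson prescribes; the attention is non-causal; and the helper names are made up for illustration.

```python
import math
import random

def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def phi(X):
    """Positive feature map elu(x) + 1 (an assumed choice of phi)."""
    return [[x + 1.0 if x > 0 else math.exp(x) for x in row] for row in X]

def linear_attention(Q, K, V):
    """O(N d_h^2): build the small phi(K)^T V summary first, never an N x N matrix."""
    fq, fk = phi(Q), phi(K)
    kv = matmul(transpose(fk), V)                  # [d_h, d_v]
    k_sum = [sum(col) for col in zip(*fk)]         # phi(K)^T 1, length d_h
    out = matmul(fq, kv)                           # [N, d_v]
    norm = [sum(a * b for a, b in zip(row, k_sum)) for row in fq]
    return [[o / z for o in row] for row, z in zip(out, norm)]

def quadratic_reference(Q, K, V):
    """Same result via the explicit [N, N] similarity matrix, O(N^2 d_h)."""
    fq, fk = phi(Q), phi(K)
    sim = matmul(fq, transpose(fk))                # [N, N]
    weights = [[s / sum(row) for s in row] for row in sim]
    return matmul(weights, V)

if __name__ == "__main__":
    rng = random.Random(0)
    N, d_h = 6, 4
    Q, K, V = ([[rng.gauss(0, 1) for _ in range(d_h)] for _ in range(N)] for _ in range(3))
    fast, slow = linear_attention(Q, K, V), quadratic_reference(Q, K, V)
    print(max(abs(a - b) for ra, rb in zip(fast, slow) for a, b in zip(ra, rb)))
```

Both functions compute the same output up to floating-point rounding; only the order of the matrix products, and therefore the cost, differs.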

Linear attention tradeoff: Removing softmax changes what the model can represent. Standard softmax attention can "select" a single position (place all weight on one key); linear attention approximates this but can't express it exactly. For tasks requiring exact position lookup (e.g., copying a specific earlier token), linear attention is weaker. In practice, it works well for many NLP tasks but loses ground on very structured retrieval tasks.

The attention pattern visualizer below lets you compare full attention (O(N²)) vs local windowed (O(NW)) vs strided sparse — and see how many FLOPs each saves.

Attention Pattern Visualizer

Each cell (i,j) = token i attends to token j. Orange = attended. Toggle patterns and watch the FLOPs counter update.
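The patterns compared in the visualizer can be sketched as boolean masks. This is an illustrative, causal-only version: `local_plus_global` treats the first few tokens as globally visible keys, which simplifies the BigBird-style scheme (no random blocks, no bidirectional global tokens), and FLOPs are counted as 4·d per attended pair (scores plus weighted sum).

```python
def full_causal(N):
    """Token i attends to every token j <= i."""
    return [[j <= i for j in range(N)] for i in range(N)]

def sliding_window_causal(N, W):
    """Token i attends to itself and the W-1 tokens before it."""
    return [[i - W < j <= i for j in range(N)] for i in range(N)]

def local_plus_global(N, W, n_global):
    """Causal sliding window plus the first n_global tokens visible to all later tokens."""
    local = sliding_window_causal(N, W)
    return [[local[i][j] or (j < n_global and j <= i) for j in range(N)]
            for i in range(N)]

def attended_pairs(mask):
    return sum(sum(row) for row in mask)

def attention_flops(mask, d):
    """2*d FLOPs per pair for Q.K^T plus 2*d per pair for weights @ V."""
    return 4 * d * attended_pairs(mask)

if __name__ == "__main__":
    N, W, d = 64, 8, 512
    for name, mask in (("full causal", full_causal(N)),
                       ("sliding window", sliding_window_causal(N, W)),
                       ("local + global", local_plus_global(N, W, n_global=2))):
        print(f"{name:15s} pairs={attended_pairs(mask):5d}  FLOPs={attention_flops(mask, d):,}")
```

For fixed W the number of attended pairs grows linearly in N, while the full causal mask grows as N(N+1)/2.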

Linear attention reformulates softmax attention using the kernel trick. What is the key algebraic rearrangement that reduces complexity from O(N²d) to O(Nd²)?

Chapter 7: RoPE & ALiBi — Length Extrapolation

A Transformer trained on sequences of length 2048 often fails completely on sequences of length 4096 — even though the architecture could theoretically handle it. The problem is positional encoding: the model learns what "position 1500" means but has never seen "position 3000" during training. It extrapolates into uncharted territory.

Absolute Positional Encoding (original Transformer, BERT) adds a learned or sinusoidal vector to each token's embedding before the first layer. This fuses position into the representation directly. Problem: the model never sees positions beyond the training context length. At inference on longer sequences, it encounters out-of-distribution positional embeddings and breaks badly.

Rotary Positional Embedding (RoPE) (Su et al. 2021, used in LLaMA) encodes position directly into Q and K, not into the token embedding. The key idea: for each pair of dimensions (2i, 2i+1), rotate the query/key vector by an angle θ proportional to position m:

RoPE_m(x) = [x_{2i}·cos(mθ_i) − x_{2i+1}·sin(mθ_i),   x_{2i}·sin(mθ_i) + x_{2i+1}·cos(mθ_i)]

Where θ_i = 10000^(−2i/d), with d here being the dimension of the vector that is rotated (the head dimension d_h): the same frequencies as sinusoidal PE but applied as rotation. The magic: the dot product Q_m·K_n depends only on (m−n), the relative distance. Absolute positions are never stored; only differences matter.

Why RoPE extrapolates better: At position m=3000 (unseen at training time), we've still seen angle differences (m−n) for most values of (m−n). The model has learned that "token 50 positions back" has certain relevance — it doesn't need to have seen "position 3000 attending to position 2950" specifically. The representation is relative, not absolute.

RoPE context extension: To extend from training length L_train to L_test, interpolate by scaling positions: replace position m with m × (L_train / L_test). This keeps all angles within the training distribution. Chen et al. 2023 extended LLaMA from 2048 to 32768 tokens this way with fine-tuning on longer sequences.
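Both properties, relative-only dot products and position interpolation, can be checked with a toy implementation. This is a sketch in pure Python on a single 4-dimensional vector: `rope` and its `scale` parameter are illustrative names, and production implementations are vectorized and may pair dimensions differently.

```python
import math

def rope(x, position, base=10000.0, scale=1.0):
    """Rotate each pair (x[2i], x[2i+1]) by the angle (position * scale) * theta_i."""
    d = len(x)
    assert d % 2 == 0
    out = []
    for i in range(d // 2):
        theta = base ** (-2.0 * i / d)
        angle = position * scale * theta
        c, s = math.cos(angle), math.sin(angle)
        a, b = x[2 * i], x[2 * i + 1]
        out += [a * c - b * s, a * s + b * c]
    return out

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

if __name__ == "__main__":
    q = [0.3, -1.2, 0.7, 0.5]
    k = [1.1, 0.4, -0.6, 0.9]
    # Same offset m - n = 50 at two very different absolute positions
    print(dot(rope(q, 100), rope(k, 50)))
    print(dot(rope(q, 3000), rope(k, 2950)))
    # Position interpolation: scale = L_train / L_test maps position 3000 to 1500
    s = 2048 / 4096
    print(dot(rope(q, 3000, scale=s), rope(k, 2950, scale=s)))
```

The first two printed values agree up to rounding because only m − n enters the dot product. The interpolated score is the one the model would see at positions 1500 and 1475, which lie inside a 2048-token training range.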

Attention with Linear Biases (ALiBi) (Press et al. 2021): instead of adding positional info to embeddings, add a distance-proportional negative bias to attention scores before softmax. For head h:

ALiBi attention score_ij = Q_i·K_j / √d_h − m_h × (i−j)

Where m_h is a fixed per-head slope (manually set, geometric sequence: 2^(−8/H), 2^(−16/H), etc.). The bias penalizes attending to distant tokens proportionally to distance, acting as a soft locality prior. ALiBi adds no parameters and requires no modifications to the embedding layer.

ALiBi advantage: Because the bias is purely a function of relative distance (i−j) with no learned parameters, ALiBi extrapolates seamlessly to sequences longer than training. A model trained on L=1024 with ALiBi often works well at L=4096 with zero fine-tuning. This makes ALiBi attractive for models where inference context will be longer than training context.
python
# ALiBi bias mask: computed once, reused for all layers
import torch

def get_alibi_slopes(n_heads):
    # Geometric sequence of slopes, one per head: 2^(-8/H), 2^(-16/H), ...
    # (this simple form assumes n_heads is a power of two)
    m = 2 ** (-8 / n_heads)
    slopes = [m ** (i + 1) for i in range(n_heads)]  # [H]
    return torch.tensor(slopes)

def get_alibi_mask(seq_len, n_heads):
    slopes = get_alibi_slopes(n_heads)               # [H]
    pos = torch.arange(seq_len)
    distances = pos.unsqueeze(1) - pos.unsqueeze(0)  # [N, N], entry (i,j) = i-j
    # Only j <= i survives the causal mask; abs() just keeps the bias
    # non-positive for the j > i entries that get masked out anyway
    distances = distances.abs().float()
    alibi = -slopes.view(-1, 1, 1) * distances.unsqueeze(0)  # [H, N, N]
    return alibi  # add to attention scores before softmax
Why does RoPE generalize better to longer sequences than absolute learned positional embeddings?

Chapter 8: Showcase: The KV Cache Memory Lab

Everything comes together here. Set model architecture parameters, choose an attention variant (MHA / GQA / MQA), and watch the KV cache grow as you generate tokens. The arithmetic intensity readout shows you when you're memory-bound vs compute-bound. This is the calculator every LLM deployment engineer uses daily.

KV Cache Memory Lab — Live Calculator
Layers L: 32
Q heads H: 32
Head dim d_h: 128
Seq length N: 2048
Batch size: 1
Try this: Set L=80, H=64, d_h=128, N=4096, bs=1, MHA — that's LLaMA-2-70B. Watch the cache hit 10 GB. Now switch to GQA (H/8 = 8 groups). It drops to 1.25 GB. Now set bs=16: MHA needs 160 GB (two A100s just for KV cache). GQA needs only 20 GB. The difference between deployment and OOM is 8 groups.
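The same experiment can be run offline as a capacity check: how many sequences fit once the weights are loaded? The sketch below is a rough planner with a hypothetical 4 × 80 GB memory budget chosen purely for illustration; it ignores activations, fragmentation, and framework overhead, and `max_batch_size` is a made-up helper name.

```python
def kv_bytes_per_sequence(layers, kv_heads, head_dim, seq_len, bytes_per_elem=2):
    """KV cache bytes for one sequence: L x 2 x N x H_kv x d_h x bytes."""
    return layers * 2 * seq_len * kv_heads * head_dim * bytes_per_elem

def max_batch_size(memory_budget_bytes, weight_bytes, layers, kv_heads, head_dim, seq_len):
    """Largest batch whose KV cache fits in what is left after the weights."""
    free = memory_budget_bytes - weight_bytes
    if free <= 0:
        return 0
    return free // kv_bytes_per_sequence(layers, kv_heads, head_dim, seq_len)

if __name__ == "__main__":
    budget = 4 * 80 * 10**9      # hypothetical: four 80 GB cards
    weights = 140 * 10**9        # 70B parameters in FP16
    for name, kv_heads in (("MHA", 64), ("GQA G=8", 8)):
        bs = max_batch_size(budget, weights, layers=80, kv_heads=kv_heads,
                            head_dim=128, seq_len=4096)
        print(f"{name:8s} max batch at N=4096: {bs}")
```

Because the free memory is fixed, the achievable batch size scales almost exactly with the cache reduction factor.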
You're deploying a model with L=40 layers, H=40 heads, d_h=128, FP16. You want to serve batch size 8 at context length N=4096. Your GPU has 80 GB total. The weights use 26 GB. How much GPU memory does the KV cache consume with MHA? Would GQA (G=8) allow the deployment?

Chapter 9: Connections & Cheat Sheet

You've derived two walls and the full toolkit for tearing them down. Here's the summary every LLM deployment engineer should have memorized.

The Efficiency Cheat Sheet

Concept | Formula / Value | Key Insight
Attention FLOPs (per layer) | 8Nd² + 4N²d | N² term exceeds the projection term at N > 2d
Attention vs FFN crossover | N = 4d | Set 4N²d = 16Nd² (FFN at 2 FLOPs per multiply-add) → N = 4d
MHA KV cache (per seq) | 2·L·N·H·d_h·bytes | Grows linearly with N, L, H
GQA cache reduction | H / G (e.g. 8× for G=H/8) | Cache shrinks, quality largely preserved
MQA cache reduction | H (e.g. 32× for H=32) | Max shrinkage, small accuracy cost
Decode arithmetic intensity | ≈ bs FLOPs/byte | Memory-bound below ridge point
A100 ridge point | ~156 FLOPs/byte | Need bs ≥ 156 to be compute-bound
Local attention cost | O(NWd) | W = window size, linear in N
Linear attention cost | O(Nd_h²) | No N² term; avoids N×N matrix
RoPE | Rotates Q,K by mθ_i | Relative position = m−n in dot product
ALiBi | Subtracts m_h·|i−j| from scores | No learned params; zero-shot extrapolation

Worked Numbers: LLaMA-2-70B

Model config: d=8192, H=64, G=8 (GQA), d_h=128, L=80, FP16

Weight size: ~140 GB FP16

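KV cache (MHA, N=4096, bs=1): 1×80×2×4096×64×128×2 = 10,737,418,240 bytes (10 GB in binary units, 10.7 × 10⁹ bytes)

KV cache (GQA G=8, N=4096, bs=1): 1×80×2×4096×8×128×2 = 1,342,177,280 bytes (1.25 GB in binary units, 1.34 × 10⁹ bytes)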

Attention FLOPs at N=4096: 4N²d = 4×4096²×8192 = 549 GFLOPs/layer

Decode AI (bs=1): ≈ 1 FLOP/byte → memory-bound

Prefill time (A100, N=512): dominated by FFN FLOPs → compute-bound

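Decode throughput: ≈ bandwidth / weight bytes = bandwidth / (2 bytes × params) = 2 TB/s / 140 GB ≈ 14 tokens/s at A100 bandwidth (bs=1). This is an upper bound: 140 GB of FP16 weights has to be sharded across at least two 80 GB cards.

GQA speedup: 8× fewer KV cache bytes read per step. At bs=1 the weight reads still dominate, so the decode gain is small. The benefit approaches 8× only when cache reads dominate the traffic (large batches, long contexts), and the smaller cache is what lets those batches fit in memory at all.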
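A bandwidth-only estimate makes these throughput figures concrete. The sketch assumes decode is purely memory-bound, that each step reads all weights once plus every sequence's full KV cache, and that 140 GB of FP16 weights are served at 2 TB/s; `decode_tokens_per_second` is an illustrative helper, not a measured benchmark.

```python
def decode_tokens_per_second(bandwidth_bytes_per_s, weight_bytes,
                             kv_cache_bytes_per_seq=0, batch=1):
    """Bandwidth-bound estimate: one step reads the weights plus batch x KV cache."""
    bytes_per_step = weight_bytes + batch * kv_cache_bytes_per_seq
    steps_per_second = bandwidth_bytes_per_s / bytes_per_step
    return batch * steps_per_second   # one new token per sequence per step

if __name__ == "__main__":
    bw = 2e12                      # 2 TB/s
    weights = 140e9                # 70B parameters x 2 bytes
    mha_cache = 10_737_418_240     # N=4096, H_kv=64
    gqa_cache = 1_342_177_280      # N=4096, H_kv=8
    print("weights only, bs=1:", round(decode_tokens_per_second(bw, weights), 1))
    for bs in (1, 16, 64):
        mha = decode_tokens_per_second(bw, weights, mha_cache, bs)
        gqa = decode_tokens_per_second(bw, weights, gqa_cache, bs)
        print(f"bs={bs:3d}  MHA={mha:7.1f} tok/s  GQA={gqa:7.1f} tok/s  ratio={gqa / mha:.2f}x")
```

The GQA-to-MHA ratio starts near 1× at bs=1 and climbs with batch size, staying below the 8× cache reduction because the weight reads never disappear.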

LLaMA-3-70B uses: H=64, G=8 (same GQA config)

Mistral-7B uses: H=32, G=8 KV groups + sliding window W=4096

The Technique Decision Tree

Context > 8K tokens?
Attention FLOPs become dominant
yes → use sparse/linear attention  |  no → standard attention is fine
Decode memory-bound?
KV cache > GPU memory / batch
yes → use GQA (quality) or MQA (max compression)
Need length extrapolation?
Inference context > training context
yes → RoPE (with interpolation) or ALiBi
Maximize throughput at fixed quality?
High-traffic serving
→ GQA + KV cache quantization (INT8 keys) + continuous batching


"The purpose of computing is insight, not numbers." — Richard Hamming. The formulas in this lesson are a means to an end: understanding why LLMs are hard to serve, and what to do about it.
A new LLM architecture paper claims "our method achieves 10× faster generation than standard Transformers by doubling the number of attention heads." Is this claim plausible? What bottleneck analysis would you apply?