Language Modeling from Scratch · CS336 · Lecture 3

Architectures & Hyperparameters

Every "modern" LLM architecture decision explained from first principles: pre-norm vs post-norm, RMSNorm, SwiGLU, RoPE, MQA/GQA, KV-cache math, and every hyperparameter from FFN-ratio to aspect ratio. Derive the formula, run the numbers, see what every real model chose.

Prerequisites: Transformer basics + CS336 Lec 2 resource accounting. Equations in words & HTML math — no LaTeX required.

Chapter 0: The Design Space

You're sitting down to implement a transformer from scratch. You know the rough shape: embedding layer, stack of attention + FFN blocks, final projection to vocabulary. But when you open the original Vaswani 2017 paper and the LLaMA 3 technical report side by side, you find they share almost nothing beyond that rough shape. Different normalization. Different activations. Different positional encoding. Different attention mechanism. Different hyperparameter ratios.

Tatsunori (Tatsu) Hashimoto opens CS336 Lecture 3 with a striking observation: since the 2024 version of this course, over 19 new dense LLM architectures were publicly released, most with minor but deliberate tweaks to the standard transformer. Each choice was made for a reason. Some of those reasons are theoretical. Most are empirical. A few turned out to be wrong in retrospect.

This lecture is not about which architecture is "best." It's about learning to read the evidence — understanding what choices have near-universal consensus today (pre-norm, RMSNorm, RoPE, SwiGLU), which ones have reasonable tradeoffs (GQA vs MHA), and which ones are surprisingly unconstrained (FFN width ratio, aspect ratio). By the end, you'll be able to read any new model's technical report and immediately classify every architecture decision.

The LLaMA baseline. The "simple, modern variant" you implemented in CS336 assignments is essentially the LLaMA architecture: pre-norm with RMSNorm, SwiGLU activations, RoPE positional embeddings, no bias terms, GQA attention. These choices are not arbitrary — each one has a story. This lecture tells those stories.
Original Transformer (2017)
Post-norm LayerNorm · sinusoidal position embeddings · ReLU activations · bias everywhere · MHA
↓ six years of ablations
Modern LLM (2023+)
Pre-norm RMSNorm · RoPE · SwiGLU (no bias) · GQA · no dropout

We'll cover each of these transitions in depth — deriving why the modern choice works, showing the math and the code, and citing the exact papers and models that drove the change.

The LLaMA architecture (pre-norm + RMSNorm + SwiGLU + RoPE) was invented for LLaMA specifically. True or false?

Chapter 1: Pre- vs Post-Norm

The single thing that almost every modern LLM agrees on — more than any other architectural choice — is pre-norm. In post-norm (the original transformer), LayerNorm sits after the residual addition: you compute the sub-layer output, add it to the residual, then normalize. In pre-norm, the LayerNorm moves inside the block before the sub-layer computation: you normalize first, feed through attention or FFN, then add back to the (un-normalized) residual stream.

Why does this matter? The key is what happens to the residual stream — the running sum that passes through all layers. In post-norm, every residual add is immediately renormalized. The gradient flowing backward must pass through that normalization operation at every layer. LayerNorm divides by the standard deviation of the activations. If the standard deviation grows layer-by-layer (which is common early in training), gradients get divided by a growing denominator at each step, driving them toward zero — gradient attenuation.

In pre-norm, the main residual path is an identity highway. Gradients flow directly from the loss all the way back to early layers without ever passing through a normalizer. The LayerNorm only acts on the branch that feeds into attention or FFN, not on the shortcut connection. This means gradient signals arrive at early layers with much less distortion.
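The PyTorch sketch below makes the residual-path difference concrete. `PostNormBlock`, `PreNormBlock`, and `Zero` are hypothetical names. `Zero` stands in for an attention or FFN sublayer that contributes nothing, so all that remains is what each block does to the residual stream itself.

```python
import torch
import torch.nn as nn

class PostNormBlock(nn.Module):
    """Original-transformer placement: normalize AFTER the residual add."""
    def __init__(self, d_model: int, sublayer: nn.Module):
        super().__init__()
        self.sublayer = sublayer              # stand-in for attention or FFN
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.norm(x + self.sublayer(x))  # the residual stream itself is renormalized

class PreNormBlock(nn.Module):
    """Modern placement: normalize only the branch input; the shortcut is untouched."""
    def __init__(self, d_model: int, sublayer: nn.Module):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.sublayer(self.norm(x))  # identity highway + normalized branch

class Zero(nn.Module):
    """Hypothetical sublayer that contributes nothing, to isolate the residual path."""
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.zeros_like(x)

torch.manual_seed(0)
d = 8
x = torch.randn(2, 5, d) * 10 + 3            # large, off-center activations
pre, post = PreNormBlock(d, Zero()), PostNormBlock(d, Zero())
print(torch.allclose(pre(x), x))             # True: x passes through unchanged
print(torch.allclose(post(x), x))            # False: the stream itself was rescaled
```

With the branch switched off, the pre-norm block is exactly the identity: that is the "highway". The post-norm block still renormalizes the stream it was handed.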


The evidence is concrete. Xiong et al. 2020 ("On Layer Normalization in the Transformer Architecture") showed that with post-norm, training from scratch without a warmup period often diverges. Pre-norm allows training with larger learning rates and no warmup. Nguyen and Salazar 2019 ("Transformers without Tears") documented gradient spikes unique to post-norm. The practical upshot: almost every model trained after 2020 uses pre-norm.

Notable post-norm models: the original Transformer, GPT-1, and BERT. GPT-2 already moved LayerNorm to the input of each sub-block, and GPT-3 kept that pre-norm layout. Notable modern pre-norm models: LLaMA 1/2/3, PaLM, Chinchilla, Mistral, Falcon, OLMo. The one funny exception among modern models is OPT-350M, which uses post-norm for unclear reasons.

Sandwich norm (double norm): A newer variant tried by Grok and Gemma 2 places a second LayerNorm outside the residual stream, on the output of each block — but NOT on the residual path itself. This is not post-norm; it's an extra norm that shapes the block's output contribution before it's added to the residual. OLMo 2 uses only this outer non-residual norm. The intuition: let the residual highway stay clean, but tame the magnitude of what each block injects.
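Here is a sketch of the double-norm placement under simplifying assumptions. `nn.LayerNorm` stands in for the RMSNorm these models actually use, and `Loud` is a hypothetical, deliberately badly scaled sublayer. Setting `use_input_norm=False` gives the OLMo 2-style variant with only the outer norm.

```python
import torch
import torch.nn as nn

class DoubleNormBlock(nn.Module):
    """Double ("sandwich") norm: the residual stream x is never normalized,
    but the block's contribution is normalized before it is added back.

    use_input_norm=True  -> norm before AND after the sublayer (Grok / Gemma 2 style)
    use_input_norm=False -> outer norm only (the OLMo 2 variant)
    """
    def __init__(self, d_model: int, sublayer: nn.Module, use_input_norm: bool = True):
        super().__init__()
        self.sublayer = sublayer
        self.in_norm = nn.LayerNorm(d_model) if use_input_norm else nn.Identity()
        self.out_norm = nn.LayerNorm(d_model)   # real models use RMSNorm; LayerNorm for portability

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.out_norm(self.sublayer(self.in_norm(x)))

class Loud(nn.Module):
    """Hypothetical, badly scaled sublayer: outputs 1000x its input."""
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return 1000.0 * x

torch.manual_seed(0)
x = torch.randn(2, 5, 8)
for use_in in (True, False):
    block = DoubleNormBlock(8, Loud(), use_input_norm=use_in)
    injected = block(x) - x                                        # what the block added
    print(use_in, injected.pow(2).mean(-1).sqrt().max().item())   # about 1.0, not 1000
```

Even with a sublayer that amplifies by 1000×, each block injects a contribution of roughly unit RMS. The residual highway itself is never touched.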

The gradient highway intuition. Think of the residual stream as a river, and each transformer block as a tributary that pours new information in. Pre-norm means the main river channel is unobstructed — water (gradients) flows freely from the ocean (loss) all the way to the source (input). Post-norm is like placing a gate on the main channel at every tributary — each gate can slow or reverse the flow.
In pre-norm placement, where exactly is the LayerNorm applied?

Chapter 2: RMSNorm — Why Drop the Mean?

If pre-norm is the unanimous choice, normalization type is the next fork. The original transformer used LayerNorm. Most models since 2021 use RMSNorm. Understanding the difference requires reading both formulas carefully.

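LayerNorm normalizes a vector x of length d_model in three steps. First it subtracts the mean μ. Then it divides by the standard deviation σ, where ε is a small constant that keeps the square root away from zero. Finally it applies a learned scale γ and a learned shift β, one per dimension:

$$\mu = \frac{1}{d}\sum_i x_i, \qquad \sigma = \sqrt{\frac{1}{d}\sum_i (x_i - \mu)^2 + \epsilon}$$

$$\mathrm{LayerNorm}(x)_i = \gamma_i \cdot \frac{x_i - \mu}{\sigma} + \beta_i$$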
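A quick check that the formula, with ε inside the square root as in the σ definition, matches what PyTorch computes. `layer_norm_manual` is a hypothetical helper; `F.layer_norm` is PyTorch's built-in functional form.

```python
import torch
import torch.nn.functional as F

def layer_norm_manual(x, gamma, beta, eps=1e-5):
    """Center, divide by sigma (eps inside the square root), then scale and shift."""
    mu = x.mean(dim=-1, keepdim=True)
    var = (x - mu).pow(2).mean(dim=-1, keepdim=True)   # divide by d (biased), as in the formula
    sigma = torch.sqrt(var + eps)
    return gamma * (x - mu) / sigma + beta

torch.manual_seed(0)
d = 16
x = torch.randn(4, d)
gamma, beta = torch.randn(d), torch.randn(d)
ref = F.layer_norm(x, (d,), weight=gamma, bias=beta, eps=1e-5)
print(torch.allclose(layer_norm_manual(x, gamma, beta), ref, atol=1e-5))   # True
```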

RMSNorm drops both the mean subtraction and the bias term β. It normalizes by the root mean square of the raw values (not the centered values), then rescales with γ:

$$\mathrm{RMS}(x) = \sqrt{\frac{1}{d}\sum_i x_i^2 + \epsilon}$$
$$\mathrm{RMSNorm}(x)_i = \gamma_i \cdot \frac{x_i}{\mathrm{RMS}(x)}$$

Two operations gone: the mean pass and the bias. Why does this matter? At first glance it seems minor — two terms in one formula. But the argument for RMSNorm is not about FLOPs. It's about memory bandwidth.

Ivanov et al. 2023 measured the arithmetic intensity of LayerNorm operations on real hardware. Arithmetic intensity is the ratio of FLOPs performed to bytes transferred from memory. A value of 1 means one FLOP per byte of memory traffic, which is firmly memory-bound. An operation only becomes compute-bound once its intensity exceeds the GPU's ratio of peak FLOP/s to memory bandwidth, which is in the low hundreds for current datacenter GPUs. Large matrix multiplications clear that bar easily; normalization and elementwise operations do not.
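Here is a back-of-envelope comparison. The counts rest on simplifying assumptions, not measurements:

- bf16, so 2 bytes per value;
- each tensor is read or written exactly once, with no caching or kernel fusion;
- roughly 4 FLOPs per element for RMSNorm.

Both function names are hypothetical.

```python
def matmul_intensity(n: int, d: int, bytes_per_val: int = 2) -> float:
    """FLOPs per byte for Y = X @ W, X: [n, d], W: [d, d], in bf16."""
    flops = 2 * n * d * d                                   # multiply + add per MAC
    bytes_moved = bytes_per_val * (n * d + d * d + n * d)   # read X, read W, write Y
    return flops / bytes_moved

def rmsnorm_intensity(n: int, d: int, bytes_per_val: int = 2,
                      flops_per_elem: int = 4) -> float:
    """Rough FLOPs per byte for RMSNorm over [n, d] (square, sum, 1/RMS scale, gamma scale)."""
    flops = flops_per_elem * n * d
    bytes_moved = bytes_per_val * (n * d + d + n * d)       # read x, read gamma, write y
    return flops / bytes_moved

n_tokens, d_model = 4096, 4096
print(f"matmul : {matmul_intensity(n_tokens, d_model):8.1f} FLOPs/byte")
print(f"RMSNorm: {rmsnorm_intensity(n_tokens, d_model):8.2f} FLOPs/byte")
```

Under these assumptions the matmul lands above 1,000 FLOPs/byte while RMSNorm sits near 1. The normalization is therefore limited by how fast bytes move, not by arithmetic.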

LayerNorm's mean subtraction requires reading all d_model values to compute the mean, then reading them again to center them. That extra pass costs a full memory round-trip while performing almost no FLOPs, so normalization ends up with very low arithmetic intensity. RMSNorm eliminates one of those passes. On modern GPUs, where memory bandwidth is the bottleneck for small operations, this translates to measurable wallclock savings.

Validated in practice. Narang et al. 2021 ("Do Transformer Modifications Transfer Across Implementations and Applications?") and others confirmed that RMSNorm provides two benefits. First, a runtime speedup on real hardware, not just lower FLOPs on paper. Second, no perplexity regression: the model trains just as well, sometimes slightly better. Models using RMSNorm: LLaMA 1/2/3, PaLM, Chinchilla, T5, Mistral, OLMo, Falcon 2.

Dropping bias terms more generally. Modern transformers also drop bias terms from the linear projections inside attention and FFN blocks (not just from the norm). The argument is the same: fewer parameters to load = less memory bandwidth consumption per operation. GPT-3 and earlier models kept biases everywhere. LLaMA-family dropped them. The performance difference is negligible; the memory/speed difference compounds across a large model.

python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """RMSNorm: normalize by RMS, rescale with learned gamma."""
    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # gamma is the ONLY learned parameter — no beta, no bias
        self.gamma = nn.Parameter(torch.ones(d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x shape: [batch, seq_len, d_model]
        # Compute RMS over the last dimension (d_model)
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
        # Normalize and rescale — NO mean subtraction, NO beta
        return self.gamma * (x / rms)

# Compare parameter counts:
d = 4096
layer_norm = nn.LayerNorm(d)   # gamma + beta = 2*4096 = 8192 params
rms_norm   = RMSNorm(d)        # gamma only  = 4096 params
print(f"LayerNorm params: {sum(p.numel() for p in layer_norm.parameters())}")
print(f"RMSNorm params:   {sum(p.numel() for p in rms_norm.parameters())}")
# LayerNorm params: 8192   RMSNorm params: 4096
Common misconception: "RMSNorm is just LayerNorm without the bias." It also drops the mean subtraction, and that changes what the output looks like. LayerNorm makes the normalized vector zero-mean AND unit-variance. RMSNorm only makes it unit-RMS, not zero-mean: if the input has a non-zero mean, RMSNorm keeps it (rescaled by 1/RMS). Whether this matters depends on whether the network needs centered activations. Empirically, dropping the centering step does not hurt quality.
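A small demonstration of the difference. It assumes an activation vector whose mean sits well away from zero (here about 5). Both norms are applied without learned parameters (γ = 1, β = 0), so only the normalization itself is compared.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d = 4096
x = torch.randn(d) + 5.0                               # activations with mean ≈ 5

ln = F.layer_norm(x, (d,))                             # LayerNorm, no affine params
rms = x / torch.sqrt(x.pow(2).mean() + 1e-6)           # RMSNorm with gamma = 1

for name, y in [("LayerNorm", ln), ("RMSNorm", rms)]:
    print(f"{name:9s} mean={y.mean().item():+.3f}  rms={y.pow(2).mean().sqrt().item():.3f}")
```

Both outputs have unit RMS. Only LayerNorm's is centered: RMSNorm's output keeps a mean close to 1, because most of this input's magnitude came from its mean.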
Why does RMSNorm provide a wallclock speedup over LayerNorm, even though the FLOPs difference is small?

Chapter 3: Activations & SwiGLU

The FFN sub-layer inside each transformer block computes a two-step projection with a nonlinearity in the middle. That nonlinearity, the activation function, is one of the most actively debated architecture choices of the last five years. The story runs: ReLU → GELU → GLU variants → SwiGLU. Let's understand each step.

ReLU (original transformer, T5, Gopher, Chinchilla, OPT): max(0, x). Simple, fast. The FFN is $\mathrm{FFN}(x) = \max(0, xW_1)\,W_2$, where $W_1$ is $d_\text{model} \times d_\text{ff}$ and $W_2$ is $d_\text{ff} \times d_\text{model}$.

GELU (GPT-1/2/3, BLOOM): x·Φ(x), where Φ is the Gaussian CDF. This is a smooth approximation to ReLU that is nonzero for negative inputs. GELU(x) ≈ x·σ(1.702x) as a fast approximation. Widely used; GPT-3's success made it a default for several years.
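How close is the fast approximation? Here is a quick numerical check over [−6, 6], using PyTorch's `F.gelu` (whose default is the exact, erf-based form) as the reference.

```python
import torch
import torch.nn.functional as F

x = torch.linspace(-6.0, 6.0, 10001)
exact = F.gelu(x)                          # x * Phi(x), erf-based (PyTorch default)
approx = x * torch.sigmoid(1.702 * x)      # the fast sigmoid approximation
max_err = (exact - approx).abs().max().item()
print(f"max |GELU(x) - x*sigmoid(1.702x)| on [-6, 6]: {max_err:.4f}")
```

The gap is on the order of 0.02, largest around |x| ≈ 2. That is small enough that the two forms behave alike inside a network.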

Gated Linear Units (GLU) are a different shape entirely. Instead of applying a single activation to the projected input, you apply it to one projection and multiply elementwise by a second linear projection (the "gate"):

$$\mathrm{ReGLU}(x) = \big(\max(0, xW_1) \otimes xV\big)\,W_2$$

The gate x·V is a learned linear function of the input. It says "how much should each dimension of the activated output contribute?" When the gate is large, information flows. When small, it's suppressed. This gives the model a data-dependent gating mechanism that ReLU/GELU lack.

SwiGLU (LLaMA, PaLM, Mistral, OLMo, most 2023+ models) replaces the ReLU gate with the Swish function. Swish(x) = x·σ(x) — a smooth, differentiable function that's roughly linear for large positive x and near-zero for large negative x. The SwiGLU FFN is:

$$\mathrm{SwiGLU}(x) = \big(\mathrm{Swish}(xW_1) \otimes xV\big)\,W_2$$
$$\text{where } \mathrm{Swish}(z) = z\,\sigma(z) = \frac{z}{1 + e^{-z}}$$

GeGLU replaces Swish with GELU: GeGLU(x) = (GELU(x·W1) ⊗ (x·V)) · W2. Used by T5 v1.1, Gemma 2/3, Phi-3.
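The three gated variants differ only in the activation applied to the first branch, which a single function can make explicit. This is a sketch with plain weight tensors and no biases. `glu_ffn` is a hypothetical helper, and `F.silu` is PyTorch's name for Swish.

```python
import torch
import torch.nn.functional as F

def glu_ffn(x, W1, V, W2, act):
    """Gated FFN: (act(x @ W1) * (x @ V)) @ W2, where * is the elementwise ⊗ above.

    act=F.relu -> ReGLU, act=F.silu -> SwiGLU, act=F.gelu -> GeGLU.
    Shapes: x [..., d_model], W1 and V [d_model, d_ff], W2 [d_ff, d_model].
    """
    return (act(x @ W1) * (x @ V)) @ W2

torch.manual_seed(0)
d_model, d_ff = 64, 171            # d_ff ≈ (8/3) * d_model
x = torch.randn(3, d_model)
W1 = torch.randn(d_model, d_ff) / d_model ** 0.5
V = torch.randn(d_model, d_ff) / d_model ** 0.5
W2 = torch.randn(d_ff, d_model) / d_ff ** 0.5

for name, act in [("ReGLU", F.relu), ("SwiGLU", F.silu), ("GeGLU", F.gelu)]:
    print(name, tuple(glu_ffn(x, W1, V, W2, act).shape))   # (3, 64) for each
```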

The 2/3 width correction. SwiGLU needs three matrix multiplications (W1, V, W2) instead of two (W1, W2). To keep FLOPs constant, we must shrink d_ff. Standard rule: d_ff = 4 × d_model. For SwiGLU, we want $3\,d_\text{ff,new} = 2\,d_\text{ff,old}$ (three matrices vs two, same total FLOPs). So:

$$d_\text{ff,new} = \tfrac{2}{3}\, d_\text{ff,old} = \tfrac{2}{3} \times 4 \times d_\text{model} = \tfrac{8}{3}\, d_\text{model} \approx 2.67\, d_\text{model}$$

This is why you see dff/dmodel ≈ 2.67–3.5 in LLaMA-family models (not 4). The exact values vary by model because engineers often round dff to a multiple of 64 for hardware efficiency.
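This sketch shows how the LLaMA family actually lands on its widths. It is modeled on the d_ff computation in Meta's public LLaMA reference code:

1. start from 4 × d_model;
2. apply the 2/3 correction;
3. optionally multiply by a per-model factor;
4. round up to a multiple.

The `multiple_of` and `ffn_dim_multiplier` values below (256 for LLaMA 7B; 4096 and 1.3 for LLaMA-2 70B) are recalled from the public configs, so treat them as assumptions. The function name is hypothetical.

```python
from typing import Optional

def llama_style_ffn_dim(d_model: int, multiple_of: int,
                        ffn_dim_multiplier: Optional[float] = None) -> int:
    """Sketch of a LLaMA-style d_ff rule (not an official API)."""
    hidden = 4 * d_model                  # classic 4x FFN width
    hidden = int(2 * hidden / 3)          # iso-FLOP correction for 3 matrices: 8/3 * d_model
    if ffn_dim_multiplier is not None:    # optional per-model widening knob
        hidden = int(ffn_dim_multiplier * hidden)
    # round UP to a hardware-friendly multiple
    return multiple_of * ((hidden + multiple_of - 1) // multiple_of)

for name, d, m, mult in [("LLaMA 7B", 4096, 256, None),
                         ("LLaMA-2 70B", 8192, 4096, 1.3)]:
    d_ff = llama_style_ffn_dim(d, m, mult)
    print(f"{name}: d_ff = {d_ff}, ratio = {d_ff / d:.2f}")
```

These reproduce the published widths of 11008 (ratio ≈ 2.69) and 28672 (ratio 3.5). This is why SwiGLU ratios in the wild spread from about 2.67 up to 3.5.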

Activation shapes: GELU and Swish are smooth near 0, while ReLU has a kink at 0.

python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFFN(nn.Module):
    """Feed-forward network with SwiGLU activation.

    Three weight matrices: W1 (gate), W_V (values), W2 (projection).
    d_ff = 8/3 * d_model for iso-FLOP with standard 4x FFN.
    No bias anywhere — LLaMA convention.
    """
    def __init__(self, d_model: int):
        super().__init__()
        # 8/3 rounded to multiple of 64 for hardware efficiency
        d_ff = int(8 / 3 * d_model)
        d_ff = (d_ff + 63) // 64 * 64  # round up to multiple of 64
        # W1 computes the 'gate' branch (will be Swished)
        self.W1 = nn.Linear(d_model, d_ff, bias=False)
        # V computes the 'value' branch (linear passthrough)
        self.V  = nn.Linear(d_model, d_ff, bias=False)
        # W2 projects back to d_model
        self.W2 = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch, seq, d_model]
        gate  = F.silu(self.W1(x))    # silu = Swish = x*sigmoid(x)
        value = self.V(x)              # linear gate
        return self.W2(gate * value)  # elementwise product, then project

# Parameter count comparison for d_model=4096:
d = 4096
ffn_relu   = nn.Sequential(nn.Linear(d, 4*d, bias=False), nn.ReLU(), nn.Linear(4*d, d, bias=False))
ffn_swiglu = SwiGLUFFN(d)
relu_p   = sum(p.numel() for p in ffn_relu.parameters())
swiglu_p = sum(p.numel() for p in ffn_swiglu.parameters())
print(f"ReLU FFN:   {relu_p/1e6:.1f}M params")    # 4096*(4*4096)*2 = 134.2M
print(f"SwiGLU FFN: {swiglu_p/1e6:.1f}M params")  # roughly same with 8/3 rule
Common misconception: "SwiGLU is strictly better than GELU/ReLU." Shazeer 2020 and Narang et al. 2021 show consistent but small perplexity improvements. Nemotron-340B uses squared ReLU; Falcon-2 11B uses plain ReLU. GPT-3 with GELU is still competitive with smaller LLaMA models. The choice matters, but not enormously: it is not the primary lever for model quality. The 2/3 width correction matters more. Get the FLOP accounting right, or you're not comparing apples to apples.
A standard ReLU FFN has d_model=2048 and d_ff=8192 (4× ratio). You want to replace it with SwiGLU while keeping FLOPs roughly constant. What should d_ff be for SwiGLU?

Chapter 4: RoPE — Rotary Position Embeddings

The transformer doesn't know the order of tokens unless you tell it. The original model added sinusoidal position vectors to token embeddings before feeding them in. GPT-1, GPT-2, and GPT-3 used learned absolute position embeddings instead. But both approaches have a fundamental problem that RoPE solves elegantly.

The relative position problem. In attention, what matters is whether token A is 3 positions before token B — not their absolute positions 47 and 50 in the sequence. The attention score between two tokens should ideally depend only on their relative position (i−j), not their absolute positions i and j separately.

Sinusoidal embeddings fail this test. When you add a position vector to a token embedding and compute the dot product between two positions, you get cross-terms involving the absolute positions of both tokens separately — not just their difference. Absolute learned embeddings obviously fail since they encode only absolute position. T5-style relative embeddings (Raffel et al.) are explicitly relative but require adding position biases to the attention logits, which is awkward and doesn't naturally compose with scaled dot-product attention.

The RoPE insight (Su et al. 2021). We want a function f(x, i) that encodes token x at position i such that the dot product f(x,i)·f(y,j) depends only on x, y, and (i−j) — not on i or j individually. The key insight: inner products are invariant to rotation. If you rotate two vectors by the same angle, their dot product doesn't change. So if you rotate each token's embedding by an angle that depends on its position, the dot product (attention score) will automatically depend only on the relative rotation — which is the relative position.

The 2D rotation formula. For a 2D vector $[x_0, x_1]$ at position i, RoPE applies the rotation matrix with angle θ·i:

$$R(i)\begin{bmatrix} x_0 \\ x_1 \end{bmatrix} = \begin{bmatrix} x_0\cos(\theta i) - x_1\sin(\theta i) \\ x_0\sin(\theta i) + x_1\cos(\theta i) \end{bmatrix}$$

For a d_head-dimensional head vector, you pair up coordinates (0,1), (2,3), ..., (d−2, d−1) and rotate each pair independently. Each pair uses a different frequency $\theta_k = 10000^{-2k/d}$, inspired by sinusoidal encodings. The high-frequency pairs (small k) capture fine-grained relative positions; the low-frequency pairs (large k) capture coarse-grained ones.

The relative property, verified. For query vector q at position i and key vector k at position j, the attention score is:

$$\text{score}(i,j) = (R(i)\,q)^\top (R(j)\,k) = q^\top R(i)^\top R(j)\,k = q^\top R(j-i)\,k$$

The rotation matrices multiply as $R(i)^\top R(j) = R(j-i)$, because rotating forward by j then backward by i is the same as rotating forward by (j−i). The attention score is now $q^\top R(j-i)\,k$: it depends only on the relative position (j−i), not on the absolute positions i or j. Goal achieved.
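A numerical check of the relative property in plain Python, for a single 2D pair with θ = 1 (an arbitrary choice). `rotate`, `rope_score`, and `relative_score` are hypothetical helpers.

```python
import math
import random

def rotate(v, angle):
    """Rotate the 2D vector v = (v0, v1) by `angle` radians: R(i) from the text."""
    c, s = math.cos(angle), math.sin(angle)
    return (v[0] * c - v[1] * s, v[0] * s + v[1] * c)

def rope_score(q, k, pos_q, pos_k, theta=1.0):
    """(R(i) q) . (R(j) k): rotate each vector by its own position, then dot."""
    rq, rk = rotate(q, theta * pos_q), rotate(k, theta * pos_k)
    return rq[0] * rk[0] + rq[1] * rk[1]

def relative_score(q, k, pos_q, pos_k, theta=1.0):
    """q . (R(j - i) k): only the offset j - i enters."""
    rk = rotate(k, theta * (pos_k - pos_q))
    return q[0] * rk[0] + q[1] * rk[1]

random.seed(0)
q = (random.gauss(0, 1), random.gauss(0, 1))
k = (random.gauss(0, 1), random.gauss(0, 1))
a = rope_score(q, k, pos_q=3, pos_k=7)
b = rope_score(q, k, pos_q=103, pos_k=107)   # same offset, both shifted by 100
c = relative_score(q, k, pos_q=3, pos_k=7)
print(f"{a:.6f} {b:.6f} {c:.6f}")            # all three agree
```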


Why not add, like sinusoidal? Sinusoidal embeddings are additive: $\mathrm{Embed}(x, i) = v_x + PE_i$. When you dot-product two such embeddings, you get $v_x \cdot v_y + v_x \cdot PE_j + PE_i \cdot v_y + PE_i \cdot PE_j$. That is four terms, two of which ($v_x \cdot PE_j$ and $PE_i \cdot v_y$) depend on the absolute positions i and j separately. RoPE's multiplicative rotation eliminates these cross-terms entirely. The dot product collapses to a single term that depends only on i−j.
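Here is the same shift test applied to additive sinusoidal embeddings. It uses the Vaswani et al. formula (sin on even dimensions, cos on odd ones) with an illustrative 16-dimensional embedding and random token vectors; all function names are hypothetical. The positional term on its own depends only on i − j, but the two cross-terms break shift invariance.

```python
import math
import random

D = 16  # illustrative embedding size

def sinusoidal_pe(pos, d=D):
    """Sinusoidal encoding: sin on even dims, cos on odd dims, frequency 1/10000^(2i/d)."""
    pe = []
    for i in range(d // 2):
        w = 1.0 / (10000 ** (2 * i / d))
        pe += [math.sin(pos * w), math.cos(pos * w)]
    return pe

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def additive_score(vx, vy, i, j):
    """(v_x + PE_i) . (v_y + PE_j), the four-term expansion from the text."""
    ex = [a + p for a, p in zip(vx, sinusoidal_pe(i))]
    ey = [b + p for b, p in zip(vy, sinusoidal_pe(j))]
    return dot(ex, ey)

random.seed(0)
vx = [random.gauss(0, 1) for _ in range(D)]
vy = [random.gauss(0, 1) for _ in range(D)]

near = additive_score(vx, vy, 3, 7)
far = additive_score(vx, vy, 103, 107)                  # same offset (4), shifted by 100
pe_near = dot(sinusoidal_pe(3), sinusoidal_pe(7))
pe_far = dot(sinusoidal_pe(103), sinusoidal_pe(107))
print(f"full score : {near:.4f} vs {far:.4f}")          # differ: cross-terms see i and j
print(f"PE.PE term : {pe_near:.4f} vs {pe_far:.4f}")    # equal: depends only on i - j
```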

python
import torch
import math

def build_rope_cache(seq_len: int, d_head: int, base: float = 10000.0):
    """Precompute cos/sin tables for RoPE. Returns cos, sin each [seq_len, d_head/2]."""
    # Frequencies: theta_k = 1 / (base ** (2k / d_head)) for k = 0..d_head//2-1
    half = d_head // 2
    freqs = 1.0 / (base ** (torch.arange(0, half).float() / half))  # [half]
    positions = torch.arange(seq_len).float()               # [seq_len]
    angles = torch.outer(positions, freqs)                   # [seq_len, half]
    return angles.cos(), angles.sin()                        # each [seq_len, half]

def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    """Apply RoPE to query or key tensor.

    x: [batch, n_heads, seq_len, d_head]
    cos, sin: [seq_len, d_head/2]
    Returns: rotated x, same shape.
    """
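    half = x.shape[-1] // 2
    # "Half-split" pairing: dim m rotates together with dim m + half (the layout
    # used by the Hugging Face LLaMA code). Pairing adjacent dims (0,1), (2,3), ...
    # as described above is equivalent up to a fixed permutation of the head dims.
    x1, x2 = x[..., :half], x[..., half:]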
    # Broadcast cos/sin over batch and heads: [1,1,seq,half]
    c = cos.unsqueeze(0).unsqueeze(0)
    s = sin.unsqueeze(0).unsqueeze(0)
    # 2D rotation for each pair: [x1*cos - x2*sin, x1*sin + x2*cos]
    rot_x1 = x1 * c - x2 * s
    rot_x2 = x1 * s + x2 * c
    return torch.cat([rot_x1, rot_x2], dim=-1)

# Example:
cos, sin = build_rope_cache(seq_len=4096, d_head=128)
q = torch.randn(2, 32, 4096, 128)  # [batch=2, heads=32, seq=4096, d_head=128]
k = torch.randn(2, 32, 4096, 128)
q_rot = apply_rope(q, cos, sin)
k_rot = apply_rope(k, cos, sin)
# Now q_rot and k_rot's dot products are purely relative-position dependent
RoPE enables length extrapolation, at least in principle. Because RoPE encodes position through rotation angles (continuous functions of position), the cos/sin tables extend naturally to positions beyond the training context. This is harder with learned absolute embeddings, which have no values for positions they've never seen. In practice, unmodified RoPE degrades past its training length. Extensions like YaRN and LongRoPE therefore rescale the rotation frequencies to push context lengths to 100k+ tokens.
What mathematical property of rotations makes RoPE's relative-position encoding work?

Chapter 5: MHA → GQA + KV-Cache Math

Standard Multi-Head Attention (MHA) runs n_heads attention heads in parallel, each with its own query, key, and value projections. At inference time, when generating token by token, recomputing the keys and values of every past token at every step would be wasteful: the total work would grow quadratically with sequence length. Instead, you cache them in a KV cache.
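Here is a toy sketch of the caching idea for one attention layer. It assumes the query, key, and value for each new token have already been projected (random tensors stand in for them). `KVCache` and `decode_step` are hypothetical helpers, not a library API, and the reference check uses `F.scaled_dot_product_attention` (PyTorch 2.0+). The final comparison confirms that attending from each new token to the cached keys and values reproduces full causal attention, without recomputing anything for past tokens.

```python
import torch
import torch.nn.functional as F

class KVCache:
    """Hypothetical toy cache for ONE attention layer: grows by one token per step."""
    def __init__(self):
        self.k = None   # [B, H, T_so_far, d_head]
        self.v = None

    def append(self, k_new: torch.Tensor, v_new: torch.Tensor):
        # k_new, v_new: [B, H, 1, d_head] for the token just generated
        if self.k is None:
            self.k, self.v = k_new, v_new
        else:
            self.k = torch.cat([self.k, k_new], dim=2)
            self.v = torch.cat([self.v, v_new], dim=2)
        return self.k, self.v

def decode_step(q_new, k_new, v_new, cache: KVCache) -> torch.Tensor:
    """Attend from the newest query to every cached key/value; past K, V are reused."""
    k, v = cache.append(k_new, v_new)
    scores = q_new @ k.transpose(-2, -1) / k.shape[-1] ** 0.5   # [B, H, 1, T_so_far]
    return F.softmax(scores, dim=-1) @ v                          # [B, H, 1, d_head]

torch.manual_seed(0)
B, H, T, D = 1, 4, 6, 16
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))   # stand-ins for projected Q, K, V

cache = KVCache()
steps = [decode_step(q[:, :, t:t + 1], k[:, :, t:t + 1], v[:, :, t:t + 1], cache)
         for t in range(T)]
incremental = torch.cat(steps, dim=2)

# Reference: causal attention over the whole sequence at once
full = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(incremental, full, atol=1e-5))   # True
```

The cache grows by one (K, V) slice per layer and per token. That growth is exactly what the size formula below counts.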

The KV cache holds the key and value tensors for every layer, every head, and every past token. For a model with nlayers layers, nheads heads per layer, sequence length T, and head dimension dhead, stored in bf16 (2 bytes each):

$$\text{KV cache size} = 2\ (\text{K and V}) \times n_\text{layers} \times n_\text{heads} \times T \times d_\text{head} \times 2\ \text{bytes}$$

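For a hypothetical full-MHA version of LLaMA-2 70B (n_layers=80, n_heads=64, d_head=128, T=4096, batch=1; the released model actually uses GQA, covered below):

$$= 2 \times 80 \times 64 \times 4096 \times 128 \times 2\ \text{bytes} \approx 10.7\ \text{GB}$$

That is about 2.6 MB per token, and the cache grows linearly with both context length and batch size. Serving 32 such sequences at once needs about 344 GB, more than four 80 GB H100s can hold. KV cache memory is a serious bottleneck for deploying large models with long contexts or large batches. This is the problem that MQA and GQA solve.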

Multi-Query Attention (MQA) (Shazeer 2019) keeps multiple query heads but shares a single key-value head across all queries. Every query head computes attention against the same keys and values. The KV cache now only stores one K and one V tensor per layer:

$$\text{MQA KV cache} = 2 \times n_\text{layers} \times 1 \times T \times d_\text{head} \times 2\ \text{bytes}$$

For the same LLaMA-2 70B-sized model: 2×80×1×4096×128×2 bytes ≈ 0.17 GB, a 64× reduction. But MQA trades some model expressiveness for this saving. Shazeer (2019) found a small perplexity hit; subsequent works found it depends heavily on model size and training data.

Grouped-Query Attention (GQA) (Ainslie et al. 2023) is the interpolation. Instead of nheads KV heads (MHA) or 1 (MQA), use nkv_heads KV heads, grouping queries into clusters. Each group of (nheads/nkv_heads) query heads shares one KV head:

$$\text{GQA KV cache} = 2 \times n_\text{layers} \times n_\text{kv\_heads} \times T \times d_\text{head} \times 2\ \text{bytes}$$
$$\text{Cache reduction vs MHA} = n_\text{heads} / n_\text{kv\_heads}$$

LLaMA-2 70B uses n_heads=64, n_kv_heads=8 (8-way grouping). Cache savings: 64/8 = 8×. New cache size: 10.7 GB / 8 ≈ 1.3 GB per 4096-token sequence, so the same memory budget holds 8× as many concurrent sequences.

python
import torch
import torch.nn.functional as F

def grouped_query_attention(
    q: torch.Tensor,       # [B, n_heads, T_q, d_head]
    k: torch.Tensor,       # [B, n_kv_heads, T_kv, d_head]
    v: torch.Tensor,       # [B, n_kv_heads, T_kv, d_head]
    scale: float = None
) -> torch.Tensor:
    """GQA: n_heads queries, n_kv_heads key-value pairs (n_heads >= n_kv_heads).

    Each group of (n_heads // n_kv_heads) query heads attends the same K, V.
    """
    B, n_heads, T_q, d_head = q.shape
    n_kv = k.shape[1]
    assert n_heads % n_kv == 0, "n_heads must be divisible by n_kv_heads"
    group_size = n_heads // n_kv

    if scale is None:
        scale = d_head ** -0.5

    # Expand K, V to match n_heads by repeating each KV head group_size times
    k_exp = k.repeat_interleave(group_size, dim=1)  # [B, n_heads, T_kv, d_head]
    v_exp = v.repeat_interleave(group_size, dim=1)

    # Standard scaled dot-product attention
    scores = torch.einsum('bhqd,bhkd->bhqk', q * scale, k_exp)  # [B,n_heads,T_q,T_kv]
    weights = F.softmax(scores, dim=-1)
    out = torch.einsum('bhqk,bhkd->bhqd', weights, v_exp)       # [B,n_heads,T_q,d_head]
    return out

# Memory savings calculation:
def kv_cache_gb(n_layers, n_kv_heads, T, d_head, bytes_per_val=2):
    """KV cache in GB for one sequence."""
    # K and V, each [n_layers, n_kv_heads, T, d_head]
    total = 2 * n_layers * n_kv_heads * T * d_head * bytes_per_val
    return total / 1e9

print(f"LLaMA-2 70B MHA  (64 heads): {kv_cache_gb(80,64,4096,128):.1f} GB")
print(f"LLaMA-2 70B GQA  (8 heads):  {kv_cache_gb(80, 8,4096,128):.1f} GB")
print(f"LLaMA-2 70B MQA  (1 head):   {kv_cache_gb(80, 1,4096,128):.2f} GB")
Common misconception: "GQA always hurts quality." Ainslie et al. 2023 showed that for sufficiently large models, GQA with n_kv_heads=8 comes close to full MHA quality while delivering a 4×–8× KV cache saving. The quality hit is most pronounced in very small models (where each head carries more unique information) and with very aggressive grouping (MQA). The empirical consensus: GQA with n_kv_heads = n_heads/4 to n_heads/8 is the sweet spot.
LLaMA-3 8B has n_layers=32, n_heads=32, n_kv_heads=8, d_head=128. How many GB is the KV cache for a single 4096-token sequence in bf16?

Chapter 6: Hyperparameters — What Everyone Agrees On

Once you've fixed your architecture choices (pre-norm, RMSNorm, SwiGLU, RoPE, GQA), you still need to set a handful of numerical hyperparameters: how wide is the model (dmodel), how many layers (nlayers), how many heads (nheads), how wide is the FFN (dff), and how large is the vocabulary (V). These values have surprising consensus across models.

1. FFN ratio (dff / dmodel)

Almost every non-gated model uses dff = 4 × dmodel. Kaplan et al. 2020 showed empirically that this hyperparameter has a wide "plateau" between 1× and 10× — the loss doesn't degrade dramatically if you deviate. But the 4× rule became a default because it worked well and everyone copied it.

For SwiGLU/GeGLU models: the iso-FLOP correction gives d_ff ≈ (8/3) × d_model ≈ 2.67 × d_model. In practice the ratios land between about 2.5 and 3.5 once models round d_ff or apply their own multipliers. Real-world values: Mistral 7B (3.5), LLaMA-2 70B (3.5), LLaMA 65B (2.68), Qwen 14B (2.67), DeepSeek 67B (2.68).

One wild outlier: T5 11B used dff = 65,536 with dmodel = 1,024 — a 64× ratio! Their follow-up T5 v1.1 reverted to 2.5× with GeGLU, implying the 64× was suboptimal.

2. Head dimension rule: dhead = dmodel / nheads

Most models keep head dimension times number of heads equal to model dimension: nheads × dhead = dmodel. This is not a mathematical requirement — you could have dhead > dmodel/nheads (the T5 exception, where they used nheads=128, dhead=128, dmodel=1024 → ratio=16). But empirically, the 1:1 ratio works well and keeps the architecture simple.

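| Model | d_model | n_heads | d_head | n_heads × d_head / d_model |
|---|---|---|---|---|
| GPT-3 175B | 12288 | 96 | 128 | 1 |
| LLaMA-2 70B | 8192 | 64 | 128 | 1 |
| PaLM 540B | 18432 | 48 | 256 | 0.67 |
| T5 11B | 1024 | 128 | 128 | 16 (!) |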

3. Aspect ratio: dmodel / nlayers

Should your model be deep (many layers, narrow dmodel) or wide (few layers, large dmodel)? Looking across real models, the sweet spot for dmodel/nlayers is roughly 100–200. BLOOM: 205. PaLM 540B: 156. GPT-3 and Mistral: 128. LLaMA-2: 102. GPT-2: 33 (outlier, very deep for its size).
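The ratios quoted above can be reproduced from each model's (d_model, n_layers). The pairs below are recalled from the public model descriptions, so treat them as inputs to check rather than an authoritative table.

```python
# (d_model, n_layers) as recalled from public model descriptions
models = {
    "BLOOM 176B": (14336, 70),
    "PaLM 540B": (18432, 118),
    "GPT-3 175B": (12288, 96),
    "Mistral 7B": (4096, 32),
    "LLaMA-2 70B": (8192, 80),
    "GPT-2 1.5B": (1600, 48),
}
for name, (d_model, n_layers) in models.items():
    print(f"{name:<12} d_model / n_layers = {d_model / n_layers:6.1f}")
```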

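The intuition: extremely deep models are harder to parallelize (their layers must run one after another) and have higher latency per token at inference. Extremely wide, shallow models get fewer sequential processing steps for the same parameter budget. Kaplan et al. 2020 found that loss changes only mildly across a broad range of aspect ratios. The 100–200 range is therefore a well-trodden default rather than a sharp optimum.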

4. Vocabulary size

Monolingual English models: 30,000–50,000 tokens. GPT-2/3: 50,257. LLaMA 1: 32,000. Multilingual / production systems: 100,000–256,000. GPT-4: 100,276. PaLM: 256,000. mT5: 250,000. Qwen 14B: 152,064.

Larger vocabulary = fewer tokens per sequence (more compact) but larger embedding matrix (dvocab × dmodel parameters) and a heavier final projection layer.

5. Dropout and regularization

Modern frontier models mostly do no dropout during pretraining. The intuition: with trillions of tokens and a single pass through the data, there's virtually no overfitting risk. LLaMA, Chinchilla, PaLM: no dropout. But they do use weight decay (0.1 is typical). Why? Andriushchenko et al. 2023 showed weight decay interacts with the learning rate schedule (cosine decay) to improve optimization dynamics — it's not about preventing overfitting.

The boring truth about hyperparameters. The single most useful takeaway from this section: almost all successful large LLMs use remarkably conservative, similar hyperparameters. dff=4d (or 2.67d for SwiGLU), n_heads × d_head = d_model, aspect ratio 100–200, vocab 32k–100k. The creative space is mostly in data and scale, not in hyperparameter novelty. If you need to tune one thing, tune data quality.
A model has d_model=4096, n_heads=32. What is d_head, and what is the standard FFN width with ReLU? With SwiGLU?

Chapter 7: Live Hyperparameter Explorer

All the hyperparameter rules from Chapter 6 (FFN ratio, head dimension, aspect ratio) come together in the parameter count. The formulas below break it down across embedding, attention, FFN, and output layers, and they can be checked against the published numbers of real models.


The parameter count formula (no biases, SwiGLU optional) per transformer block:

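Attention: $4 \times d_\text{model}^2$ ($W_Q$, $W_K$, $W_V$, $W_O$, each $d_\text{model} \times d_\text{model}$)
FFN (ReLU): $2 \times d_\text{model} \times d_\text{ff}$
FFN (SwiGLU): $3 \times d_\text{model} \times d_\text{ff}$, with $d_\text{ff} = \tfrac{8}{3}\, d_\text{model}$
Norm (RMSNorm): $2 \times d_\text{model}$ (two norms per block)
Total non-embedding params ≈ $n_\text{layers} \times (4d^2 + 2\,d_\text{ff}\,d + 2d)$ with a ReLU FFN (use $3\,d_\text{ff}\,d$ for SwiGLU)
Embedding + output: $2 \times V \times d_\text{model}$ (input embedding + output unembedding; only $V \times d_\text{model}$ when the two are tied)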

For GQA with $n_\text{kv\_heads} < n_\text{heads}$, the attention count becomes $(n_\text{heads} + 2\,n_\text{kv\_heads}) \times d_\text{head} \times d_\text{model} + d_\text{model}^2$, where the last term is $W_O$. The K and V projections are smaller.

Rule of thumb: for a dense transformer with a 4× FFN (or a SwiGLU FFN at the iso-FLOP 8/3× width), the non-embedding parameter count is N ≈ 12 × n_layers × d_model². The 12 comes from 4d² for attention + 8d² for the FFN. Check: LLaMA-2 7B has n_layers=32, d_model=4096 → 12 × 32 × 4096² ≈ 6.44B. Actual: ~6.7B including embeddings. The rule is accurate to ~5%.
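A small Python function that applies the formulas above. It assumes a bias-free pre-norm block with two RMSNorms, plus one final norm before the output layer (as in LLaMA); `transformer_params` is a hypothetical helper. The LLaMA-2 configurations are taken from the public model configs:

- 7B: d_ff = 11008;
- 70B: d_ff = 28672 with 8 KV heads;
- both: vocabulary 32,000.

```python
from typing import Optional

def transformer_params(d_model: int, n_layers: int, n_heads: int, d_ff: int, vocab: int,
                       n_kv_heads: Optional[int] = None, swiglu: bool = True,
                       tied_embeddings: bool = False) -> dict:
    """Parameter count for a bias-free pre-norm transformer, following the formulas above."""
    d_head = d_model // n_heads
    n_kv = n_heads if n_kv_heads is None else n_kv_heads
    attn = (n_heads + 2 * n_kv) * d_head * d_model + d_model * d_model  # W_Q, W_K, W_V, then W_O
    ffn = (3 if swiglu else 2) * d_model * d_ff                         # W1, (V,) W2
    norms = 2 * d_model                                                 # two RMSNorm gains
    per_layer = attn + ffn + norms
    non_embedding = n_layers * per_layer + d_model   # assumption: one final norm, as in LLaMA
    embedding = vocab * d_model * (1 if tied_embeddings else 2)
    return {"non_embedding": non_embedding, "embedding": embedding,
            "total": non_embedding + embedding}

llama2_7b = transformer_params(4096, 32, 32, 11008, 32000)
llama2_70b = transformer_params(8192, 80, 64, 28672, 32000, n_kv_heads=8)
print(f"LLaMA-2 7B : {llama2_7b['total']:,}")
print(f"LLaMA-2 70B: {llama2_70b['total']:,}")
print(f"12 * L * d^2 for 7B: {12 * 32 * 4096 ** 2:,}")
```

The totals come out to 6,738,415,616 and 68,976,648,192, in line with the ~6.7B and ~69B figures usually reported for these checkpoints. The 7B non-embedding count (about 6.48B) sits within 1% of the 12 × n_layers × d_model² estimate.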
Using the 12×n_layers×d_model² approximation, estimate the non-embedding parameters for a model with n_layers=80 and d_model=8192 (LLaMA-2 70B scale).

Chapter 8: Stability Tricks & Softmax Dangers

Training a large model without instability is harder than it sounds. The culprit is usually the softmax function — and there are two places where it appears in a transformer.

The output softmax at the end converts logits (raw scores for each vocabulary token) to probabilities. If any logit grows very large, the exponential in softmax can overflow to infinity. Even before overflow, large logits concentrate the softmax output to near-zero for most tokens, making the gradient signal very sparse. PaLM pioneered a trick called the z-loss to prevent this.

The z-loss adds a small penalty proportional to the square of the log-partition function (the log-sum-exp of all logits). Concretely, if $z = \log \sum_i \exp(\text{logit}_i)$ is the log-partition, the z-loss adds $\epsilon \cdot z^2$ to the cross-entropy loss, with ε a small weight (PaLM used $10^{-4}$). Pushing z toward 0 keeps the softmax normalizer near 1 and stops the logits from drifting to large magnitudes, which stabilizes the output softmax. DCLM (2024) and OLMo 2 (2025) also use this.

The attention softmax inside each attention head can also be problematic. If query and key vectors have large magnitudes, the dot products QKT/√dhead can become very large even after scaling, driving the softmax into a saturated regime where gradients vanish (the "attention collapse" problem). Two tricks address this:

QK-Norm: Apply RMSNorm (or LayerNorm) to the query and key vectors before computing attention scores. This caps their magnitude and prevents the dot products from exploding. It is used by DCLM, OLMo 2, and Gemma 3, and was popularized by large vision and multimodal models (Dehghani et al. 2023's ViT-22B, IDEFICS, Chameleon). Implementation is simple: add a learnable norm after the Q and K projections.
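A minimal sketch of causal self-attention with QK-Norm, with these assumptions:

- RMS normalization is applied per head over d_head;
- each norm has a learnable gain initialized to 1 (real models differ in exactly where the norm sits and whether the gain is shared);
- `QKNormAttention` is a hypothetical class.

Because q and k are normalized, scaling the input by 1000 leaves the attention weights unchanged, so the logits cannot blow up with activation scale.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class QKNormAttention(nn.Module):
    """Causal self-attention with QK-Norm (hypothetical sketch, no biases).

    q and k are RMS-normalized per head over d_head, each with a learnable
    gain initialized to 1, before the scaled dot product.
    """
    def __init__(self, d_model: int, n_heads: int, eps: float = 1e-6):
        super().__init__()
        self.n_heads, self.d_head, self.eps = n_heads, d_model // n_heads, eps
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out = nn.Linear(d_model, d_model, bias=False)
        self.q_gain = nn.Parameter(torch.ones(self.d_head))
        self.k_gain = nn.Parameter(torch.ones(self.d_head))

    def _rms_norm(self, t: torch.Tensor, gain: torch.Tensor) -> torch.Tensor:
        return gain * t / torch.sqrt(t.pow(2).mean(dim=-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, D = x.shape
        q, k, v = self.qkv(x).split(D, dim=-1)
        # [B, T, D] -> [B, n_heads, T, d_head]
        q, k, v = (t.reshape(B, T, self.n_heads, self.d_head).transpose(1, 2) for t in (q, k, v))
        q = self._rms_norm(q, self.q_gain)    # the QK-Norm step: bounded |q|
        k = self._rms_norm(k, self.k_gain)    # ... and bounded |k|
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.out(y.transpose(1, 2).reshape(B, T, D))

torch.manual_seed(0)
attn = QKNormAttention(d_model=64, n_heads=4)
x = torch.randn(2, 10, 64)
with torch.no_grad():
    # Scaling the input by 1000 scales v (and the output) but not the attention weights
    print(torch.allclose(attn(1000 * x) / 1000, attn(x), atol=1e-4))   # True
```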

Logit soft-capping: Smoothly bound attention logits with a tanh transformation. If the raw logit is x, the capped version is max_val × tanh(x / max_val), which stays close to x when |x| is much smaller than max_val and never exceeds max_val in magnitude. Unlike a hard clamp, it is differentiable everywhere. Used by Gemma 2. The potential downside: the tanh compression might hurt performance in cases where very large logits are genuinely meaningful, and early reports suggest a small but real perplexity cost.
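A minimal PyTorch sketch of the soft-cap formula. The cap of 50.0 is an illustrative choice (it matches the value commonly cited for Gemma 2's attention logits, but treat that as an assumption), and `soft_cap` is a hypothetical helper name. The gradient check shows the trade-off mentioned above: logits far beyond the cap pass almost no gradient.

```python
import torch


def soft_cap(logits: torch.Tensor, cap: float) -> torch.Tensor:
    """Smoothly bound logits to magnitude <= cap; roughly the identity when |logit| << cap."""
    return cap * torch.tanh(logits / cap)


x = torch.tensor([-1000.0, -10.0, 0.0, 1.0, 10.0, 1000.0])
capped = soft_cap(x, cap=50.0)
print(capped)  # extremes saturate near ±50, small values are nearly unchanged

# Gradient of the capped logit is 1 - tanh^2(x / cap), which vanishes far beyond the cap
probe = torch.tensor([1.0, 1000.0], requires_grad=True)
soft_cap(probe, cap=50.0).sum().backward()
grads = probe.grad
print(grads)
```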

Parallel layers are an orthogonal trick, aimed mainly at efficiency. Standard transformers compute attention and then the FFN sequentially. A few models (GPT-J, PaLM, GPT-NeoX, Cohere Command A) compute attention and the FFN in parallel from the same input and add both outputs to the residual stream at once. When done carefully, the two branches share one LayerNorm, and their first matrix multiplies can be fused into one, reducing memory bandwidth. The speed gain is real; whether it hurts quality depends on the scale. PaLM reports roughly 15% faster training at large scale, with a small quality degradation at 8B and none at 62B.
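The parallel formulation, y = x + Attn(Norm(x)) + MLP(Norm(x)), can be sketched as a single PyTorch module. This is an illustration under simplifying assumptions, not GPT-J's or PaLM's actual code: full multi-head attention without RoPE, a plain GELU MLP rather than SwiGLU, a minimal RMSNorm as the shared norm, and hypothetical class names. The point is the structure: one norm, one fused input matmul whose output is split into Q, K, V and the FFN hidden activations, and both branch outputs added to the same residual.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNorm(nn.Module):
    def __init__(self, d: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight


class ParallelBlock(nn.Module):
    """y = x + Attn(norm(x)) + MLP(norm(x)) with one shared norm and one fused input matmul."""

    def __init__(self, d_model: int, n_heads: int, d_ff: int):
        super().__init__()
        self.n_heads, self.d_ff = n_heads, d_ff
        self.norm = RMSNorm(d_model)
        # Fused projection: Q, K, V (3 * d_model columns) and the FFN up-projection (d_ff columns)
        self.in_proj = nn.Linear(d_model, 3 * d_model + d_ff, bias=False)
        self.attn_out = nn.Linear(d_model, d_model, bias=False)
        self.ffn_out = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: [B, T, d_model]
        B, T, D = x.shape
        h = self.norm(x)  # computed once, shared by both branches
        q, k, v, ff = self.in_proj(h).split([D, D, D, self.d_ff], dim=-1)
        # [B, T, D] -> [B, n_heads, T, d_head]
        q, k, v = (t.reshape(B, T, self.n_heads, D // self.n_heads).transpose(1, 2) for t in (q, k, v))
        attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        attn = attn.transpose(1, 2).reshape(B, T, D)
        return x + self.attn_out(attn) + self.ffn_out(F.gelu(ff))


torch.manual_seed(0)
block = ParallelBlock(d_model=64, n_heads=4, d_ff=256)
x = torch.randn(2, 5, 64)            # [batch, seq, d_model]
x_perturbed = x.clone()
x_perturbed[:, -1] += 1.0            # change only the last position
with torch.no_grad():
    y, y_perturbed = block(x), block(x_perturbed)
```

Because the mask is causal and the FFN acts per position, perturbing the last token changes only the last output row.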

The instability fingerprint. Training loss curves that spike suddenly (large loss jump, then recovery) or diverge after a warmup period are the classic signs of softmax instability. The fix hierarchy: (1) try z-loss first (cheapest), (2) add QK-Norm if attention collapse is the issue, (3) reduce learning rate or add gradient clipping as a last resort. Monitoring the magnitude of logits and attention entropy during training is the diagnostic.
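The two diagnostics can be sketched as small helpers (hypothetical names, not from any library): the mean entropy of attention rows, which falls toward 0 as attention collapses onto single keys, and the largest absolute logit, which is useful to track over training steps. How often to log them and which thresholds should raise an alarm are left open here.

```python
import math

import torch


def attention_entropy(attn_probs: torch.Tensor) -> torch.Tensor:
    """Mean entropy in nats of attention rows; attn_probs: [..., T_q, T_k], rows sum to 1."""
    logp = torch.log(attn_probs.clamp_min(1e-12))  # clamp avoids log(0) for zero weights
    return -(attn_probs * logp).sum(dim=-1).mean()


def max_abs_logit(logits: torch.Tensor) -> float:
    """Largest logit magnitude in a batch (detached, for logging)."""
    return logits.detach().abs().max().item()


T = 8
uniform = torch.full((4, T), 1.0 / T)   # every query spreads attention evenly
collapsed = torch.eye(T)                # every query attends to exactly one key
print(attention_entropy(uniform).item(), math.log(T))  # both ≈ log(8)
print(attention_entropy(collapsed).item())             # ≈ 0: fully collapsed
print(max_abs_logit(torch.tensor([[3.0, -7.5, 1.0]])))  # 7.5
```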

The no-bias recommendation also connects to stability. When linear layers have bias terms, those biases can grow arbitrarily without being regularized by the norm operations. Removing biases (as the LLaMA family does) constrains the optimization landscape and reduces the surface area for instability.

```python
import torch
import torch.nn.functional as F


def cross_entropy_with_z_loss(
    logits: torch.Tensor,   # [B, T, V]
    targets: torch.Tensor,  # [B, T] token ids
    z_loss_weight: float = 1e-4,
) -> torch.Tensor:
    """Cross-entropy with PaLM's z-loss stability term.

    z = log(sum_i exp(logit_i)) is the log-partition at each position.
    z_loss = z_loss_weight * z^2 pushes z toward 0, which keeps the
    logits from drifting to large magnitudes.
    """
    V = logits.shape[-1]
    logits_2d = logits.reshape(-1, V)   # [B*T, V]
    targets_1d = targets.reshape(-1)    # [B*T]

    # Standard cross-entropy, averaged over all positions
    ce_loss = F.cross_entropy(logits_2d, targets_1d)

    # Z-loss: squared log-partition, averaged over positions
    log_z = torch.logsumexp(logits_2d, dim=-1)   # [B*T]
    z_loss = z_loss_weight * (log_z ** 2).mean()

    return ce_loss + z_loss


class RMSNorm(torch.nn.Module):
    """Minimal RMSNorm over the last dimension, with a learnable gain."""
    def __init__(self, d: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(d))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) * self.weight


class QKNorm(torch.nn.Module):
    """Apply RMSNorm to Q and K before attention scores are computed.

    Insert between the QKV projection and the attention-score computation.
    q and k are expected as [..., d_head], so the norm acts per head.
    """
    def __init__(self, d_head: int):
        super().__init__()
        self.q_norm = RMSNorm(d_head)
        self.k_norm = RMSNorm(d_head)

    def forward(self, q: torch.Tensor, k: torch.Tensor):
        return self.q_norm(q), self.k_norm(k)
```
What problem does the z-loss solve, and what mathematical quantity does it penalize?

Chapter 9: Connections & Cheat Sheet

You've now seen every major architecture decision made in modern LLMs — derived from first principles, with math, code, and real-model evidence. Here's how it all fits together.

The Modern LLM Architecture Checklist

| Choice | Consensus (2024+) | Why | Notable exceptions |
|---|---|---|---|
| Norm placement | Pre-norm | Clean gradient highways, stable training | OPT-350M (post-norm) |
| Norm type | RMSNorm | Lower memory bandwidth, same quality | GPT-2/3 (LayerNorm) |
| Bias terms | None | Fewer params to move, less instability surface | BERT, GPT-2 |
| Activation | SwiGLU or GeGLU | Consistent small gains over ReLU/GELU | GPT-3 (GELU), Falcon 2 (ReLU) |
| Position encoding | RoPE | Relative position via rotation of q and k | T5 (relative bias), OPT (learned absolute) |
| Attention | GQA | 8× smaller KV cache than MHA with 64 query heads sharing 8 KV heads | GPT-3, BLOOM (full MHA) |
| FFN ratio | 4× (or 8/3× for SwiGLU) | Empirical plateau; iso-FLOP correction | T5 (64×!) |
| Aspect ratio | d_model/n_layers ≈ 100–200 | Parallelism & latency constraints | T5 (43), GPT-2 (33) |
| Dropout | None (pretraining) | Trillions of tokens, no overfit risk | GPT-2/3, T5, Qwen (0.1) |
| Stability | z-loss, QK-Norm | Prevents softmax overflow/collapse | LLaMA (no z-loss) |

Parameter Count Quick Reference

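N ≈ 12 × n_layers × d_model² (non-embedding, ±5%)
FFN params per layer: 8 × d_model² in both common setups: SwiGLU uses three d_model × d_ff matrices with d_ff = (8/3) × d_model; a ReLU/GELU FFN uses two with d_ff = 4 × d_model
Attn params per layer: 4 × d_model² (MHA), or 2 × d_model² × (1 + n_kv_heads/n_heads) for GQA
KV cache: 4 × n_layers × n_kv_heads × T × d_head bytes per sequence (bf16: 2 bytes each for K and V)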
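The quick-reference formulas can be checked with plain arithmetic. The sketch below uses hypothetical helper names, assumes n_heads × d_head = d_model and bf16 storage, and plugs in LLaMA-2 70B's shape (80 layers, 64 query heads, 8 KV heads, d_head = 128) at a 4,096-token context.

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_head, seq_len, batch=1, bytes_per_elem=2):
    """KV-cache size: one K and one V tensor per layer, bf16 (2 bytes) by default."""
    return 2 * bytes_per_elem * n_layers * n_kv_heads * d_head * seq_len * batch


def attn_params(d_model, n_heads, n_kv_heads):
    """W_q, W_o: d_model x d_model; W_k, W_v: d_model x (n_kv_heads * d_head)."""
    d_head = d_model // n_heads
    return 2 * d_model * d_model + 2 * d_model * n_kv_heads * d_head


def ffn_params(d_model, d_ff, gated):
    """Gated (SwiGLU/GeGLU) FFNs have three matrices, plain ReLU/GELU FFNs two."""
    return (3 if gated else 2) * d_model * d_ff


GiB = 2 ** 30
mha = kv_cache_bytes(n_layers=80, n_kv_heads=64, d_head=128, seq_len=4096)
gqa = kv_cache_bytes(n_layers=80, n_kv_heads=8, d_head=128, seq_len=4096)
print(mha / GiB, gqa / GiB, mha / gqa)  # 10.0 1.25 8.0

print(attn_params(8192, 64, 8))  # 150994944 = 2 * 8192^2 * (1 + 8/64)
d = 3072                         # divisible by 3, so (8/3) * d is an integer
print(ffn_params(d, 8 * d // 3, gated=True) == ffn_params(d, 4 * d, gated=False) == 8 * d * d)  # True
```

The 8× KV-cache saving is simply n_heads / n_kv_heads = 64 / 8, and the last line confirms that the 8/3 width correction gives SwiGLU the same 8 × d_model² FFN budget as a 4× ReLU FFN.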

Real Model Reference Table

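| Model | d_model | n_layers | n_heads | n_kv | d_ff/d | Vocab | ~Params |
|---|---|---|---|---|---|---|---|
| GPT-3 175B | 12288 | 96 | 96 | 96 | 4× | 50k | 175B |
| LLaMA-2 7B | 4096 | 32 | 32 | 32 | ≈2.69× | 32k | 7B |
| LLaMA-2 70B | 8192 | 80 | 64 | 8 | 3.5× | 32k | 70B |
| Mistral 7B | 4096 | 32 | 32 | 8 | 3.5× | 32k | 7B |
| PaLM 540B | 18432 | 118 | 48 | 1 (MQA) | 4× | 256k | 540B |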


Papers Referenced

  1. Xiong et al. 2020. "On Layer Normalization in the Transformer Architecture." ICML 2020.
  2. Zhang & Sennrich 2019. "Root Mean Square Layer Normalization." NeurIPS 2019.
  3. Shazeer 2020. "GLU Variants Improve Transformer." arXiv 2002.05202.
  4. Su et al. 2021. "RoFormer: Enhanced Transformer with Rotary Position Embedding." arXiv 2104.09864.
  5. Ainslie et al. 2023. "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints." EMNLP 2023.
  6. Kaplan et al. 2020. "Scaling Laws for Neural Language Models." arXiv 2001.08361.
  7. Ivanov et al. 2021. "Data Movement Is All You Need: A Case Study on Optimizing Transformers." MLSys 2021.
Feynman test for this lecture: Can you explain — without notes — why a modern LLM uses pre-norm RMSNorm instead of post-norm LayerNorm, why SwiGLU needs a 2/3 width correction, and how RoPE achieves relative position encoding through rotation matrices? If yes, you've internalized CS336 Lecture 3. If not, the canvas simulations in Chapters 1, 3, and 4 are waiting.
Which combination of choices does LLaMA-2 70B use?