Every "modern" LLM architecture decision explained from first principles: pre-norm vs post-norm, RMSNorm, SwiGLU, RoPE, MQA/GQA, KV-cache math, and every hyperparameter from FFN-ratio to aspect ratio. Derive the formula, run the numbers, see what every real model chose.
You're sitting down to implement a transformer from scratch. You know the rough shape: embedding layer, stack of attention + FFN blocks, final projection to vocabulary. But when you open the original Vaswani 2017 paper and the LLaMA 3 technical report side by side, you find they share almost nothing beyond that rough shape. Different normalization. Different activations. Different positional encoding. Different attention mechanism. Different hyperparameter ratios.
Tatsunori Hashimoto (Tatsu H) opens CS336 Lecture 3 with a striking observation: since the 2024 version of this course, over 19 new dense LLM architectures were publicly released, most with minor but deliberate tweaks to the standard transformer. Each choice was made for a reason. Some of those reasons are theoretical. Most are empirical. A few turned out to be wrong in retrospect.
This lecture is not about which architecture is "best." It's about learning to read the evidence — understanding what choices have near-universal consensus today (pre-norm, RMSNorm, RoPE, SwiGLU), which ones have reasonable tradeoffs (GQA vs MHA), and which ones are surprisingly unconstrained (FFN width ratio, aspect ratio). By the end, you'll be able to read any new model's technical report and immediately classify every architecture decision.
We'll cover each of these transitions in depth — deriving why the modern choice works, showing the math and the code, and citing the exact papers and models that drove the change.
The single thing that almost every modern LLM agrees on — more than any other architectural choice — is pre-norm. In post-norm (the original transformer), LayerNorm sits after the residual addition: you compute the sub-layer output, add it to the residual, then normalize. In pre-norm, the LayerNorm moves inside the block before the sub-layer computation: you normalize first, feed through attention or FFN, then add back to the (un-normalized) residual stream.
Why does this matter? The key is what happens to the residual stream — the running sum that passes through all layers. In post-norm, every residual add is immediately renormalized. The gradient flowing backward must pass through that normalization operation at every layer. LayerNorm divides by the standard deviation of the activations. If the standard deviation grows layer-by-layer (which is common early in training), gradients get divided by a growing denominator at each step, driving them toward zero — gradient attenuation.
In pre-norm, the main residual path is an identity highway. Gradients flow directly from the loss all the way back to early layers without ever passing through a normalizer. The LayerNorm only acts on the branch that feeds into attention or FFN, not on the shortcut connection. This means gradient signals arrive at early layers with much less distortion.
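The difference between the two orderings can be sketched in a few lines. This is a minimal illustration, not any particular model's implementation: `PostNormBlock`, `PreNormBlock`, and the `zero_linear` stand-in sub-layer are hypothetical names introduced here. With a sub-layer that outputs zero, the pre-norm block reduces to an exact identity on the residual stream, while the post-norm block still rewrites it through the normalizer.

```python
import torch
import torch.nn as nn


class PostNormBlock(nn.Module):
    """Original-transformer ordering: sub-layer, add residual, then normalize."""

    def __init__(self, d_model: int, sublayer: nn.Module):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The norm sits ON the residual path: the whole stream passes through it.
        return self.norm(x + self.sublayer(x))


class PreNormBlock(nn.Module):
    """Modern ordering: normalize the branch input, leave the shortcut untouched."""

    def __init__(self, d_model: int, sublayer: nn.Module):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The norm acts on the branch only; the shortcut is a pure identity.
        return x + self.sublayer(self.norm(x))


def zero_linear(d_model: int) -> nn.Linear:
    """Hypothetical stand-in sub-layer whose output is always zero."""
    layer = nn.Linear(d_model, d_model, bias=False)
    nn.init.zeros_(layer.weight)
    return layer


torch.manual_seed(0)
d_model = 16
# Deliberately not zero-mean / unit-variance, so a normalizer visibly changes it.
x = 5.0 * torch.randn(2, 4, d_model) + 3.0

pre = PreNormBlock(d_model, zero_linear(d_model))
post = PostNormBlock(d_model, zero_linear(d_model))

pre_is_identity = torch.allclose(pre(x), x)
post_is_identity = torch.allclose(post(x), x)
print(f"pre-norm leaves the stream untouched:  {pre_is_identity}")
print(f"post-norm leaves the stream untouched: {post_is_identity}")
```

The same structure explains the gradient story: in the pre-norm block the derivative of the output with respect to `x` always contains an identity term, whatever the branch does.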
The evidence is concrete. Xiong et al. 2020 ("On Layer Normalization in the Transformer Architecture") showed that with post-norm, training from scratch without a warmup period often diverges. Pre-norm allows training with larger learning rates and no warmup. Salazar and Nguyen 2019 documented gradient spikes unique to post-norm. The practical upshot: almost every model trained after 2020 uses pre-norm.
Notable post-norm models: the original transformer, GPT-1, and BERT. GPT-2 moved the LayerNorm to the input of each sub-block, and GPT-3 kept that pre-norm layout. Notable modern pre-norm: LLaMA 1/2/3, PaLM, Chinchilla, Mistral, Falcon, OLMo. The one funny exception in modern models is OPT-350M, which uses post-norm for unclear reasons.
Sandwich norm (double norm): A newer variant tried by Grok and Gemma 2 places a second LayerNorm outside the residual stream, on the output of each block — but NOT on the residual path itself. This is not post-norm; it's an extra norm that shapes the block's output contribution before it's added to the residual. OLMo 2 uses only this outer non-residual norm. The intuition: let the residual highway stay clean, but tame the magnitude of what each block injects.
If pre-norm is the unanimous choice, normalization type is the next fork. The original transformer used LayerNorm. Most models since 2021 use RMSNorm. Understanding the difference requires reading both formulas carefully.
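LayerNorm normalizes a vector $x$ of length $d_\text{model}$ in two steps. First it subtracts the mean $\mu = \frac{1}{d}\sum_i x_i$. Then it divides by the standard deviation $\sigma = \sqrt{\frac{1}{d}\sum_i (x_i - \mu)^2 + \varepsilon}$. Finally it applies a learned scale $\gamma$ and learned shift $\beta$, one per dimension:

$$\text{LayerNorm}(x) = \gamma \odot \frac{x - \mu}{\sigma} + \beta$$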
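RMSNorm drops both the mean subtraction and the bias term $\beta$. It normalizes by the root mean square of the raw values (not the centered values), then rescales with $\gamma$:

$$\text{RMSNorm}(x) = \gamma \odot \frac{x}{\sqrt{\frac{1}{d}\sum_i x_i^2 + \varepsilon}}$$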
Two operations gone: the mean pass and the bias. Why does this matter? At first glance it seems minor — two terms in one formula. But the argument for RMSNorm is not about FLOPs. It's about memory bandwidth.
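As a quick sanity check on the two formulas, the sketch below implements both by hand (without the learned $\gamma$ and $\beta$) and compares them. The helper names `layer_norm_manual` and `rms_norm_manual` are made up for this example. It illustrates that the two norms coincide when the input is already zero-mean, and differ otherwise.

```python
import torch
import torch.nn.functional as F


def layer_norm_manual(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """LayerNorm without gamma/beta: subtract the mean, divide by the std."""
    mu = x.mean(dim=-1, keepdim=True)
    var = ((x - mu) ** 2).mean(dim=-1, keepdim=True)
    return (x - mu) / torch.sqrt(var + eps)


def rms_norm_manual(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """RMSNorm without gamma: divide by the root mean square of the raw values."""
    return x / torch.sqrt((x ** 2).mean(dim=-1, keepdim=True) + eps)


torch.manual_seed(0)
x = torch.randn(4, 64) + 2.0  # shifted so the mean is clearly nonzero

# 1) The hand-written LayerNorm agrees with PyTorch's functional version.
reference = F.layer_norm(x, (64,), eps=1e-5)
matches_torch = torch.allclose(layer_norm_manual(x), reference, atol=1e-5)

# 2) On zero-mean input, mean subtraction is a no-op, so the two norms agree.
x_centered = x - x.mean(dim=-1, keepdim=True)
same_when_centered = torch.allclose(
    rms_norm_manual(x_centered), layer_norm_manual(x_centered), atol=1e-5
)

# 3) On input with a nonzero mean they give different results.
differ_when_shifted = not torch.allclose(
    rms_norm_manual(x), layer_norm_manual(x), atol=1e-3
)

print(matches_torch, same_when_centered, differ_when_shifted)
```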
Ivanov et al. 2023 measured the arithmetic intensity of LayerNorm operations on real hardware. Arithmetic intensity is the ratio of FLOPs performed to bytes transferred from memory. A value of 1 means every FLOP requires one byte of memory access — purely memory-bound. A value of 100 means you do 100 FLOPs per byte transferred — compute-bound. Modern matrix multiplications have intensities in the hundreds; you want your other operations to match.
LayerNorm's mean subtraction requires reading all dmodel values to compute the mean, then reading them again to subtract it. That second pass is extra memory traffic for very few FLOPs, so the operation performs only a handful of FLOPs per byte moved, far below the hundreds that matrix multiplications achieve. RMSNorm eliminates one of those passes. On modern GPUs, where memory bandwidth is the bottleneck for small operations, this translates to measurable wallclock savings.
Dropping bias terms more generally. Modern transformers also drop bias terms from the linear projections inside attention and FFN blocks (not just from the norm). The argument is the same: fewer parameters to load = less memory bandwidth consumption per operation. GPT-3 and earlier models kept biases everywhere. LLaMA-family dropped them. The performance difference is negligible; the memory/speed difference compounds across a large model.
```python
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """RMSNorm: normalize by RMS, rescale with learned gamma."""

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # gamma is the ONLY learned parameter: no beta, no bias
        self.gamma = nn.Parameter(torch.ones(d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x shape: [batch, seq_len, d_model]
        # Compute RMS over the last dimension (d_model)
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
        # Normalize and rescale: NO mean subtraction, NO beta
        return self.gamma * (x / rms)


# Compare parameter counts:
d = 4096
layer_norm = nn.LayerNorm(d)  # gamma + beta = 2*4096 = 8192 params
rms_norm = RMSNorm(d)         # gamma only = 4096 params
print(f"LayerNorm params: {sum(p.numel() for p in layer_norm.parameters())}")
print(f"RMSNorm params: {sum(p.numel() for p in rms_norm.parameters())}")
# LayerNorm params: 8192
# RMSNorm params: 4096
```
The FFN sub-layer inside each transformer block computes a two-step projection with a nonlinearity in the middle. That nonlinearity — the activation function — is one of the most actively-debated architecture choices of the last five years. The story runs: ReLU → GELU → GLU variants → SwiGLU. Let's understand each step.
ReLU (original transformer, T5, Gopher, Chinchilla, OPT): max(0, x). Simple, fast. The FFN is: output = max(0, x·W1) · W2, where W1 is dmodel×dff and W2 is dff×dmodel.
GELU (GPT-1/2/3, BLOOM): x·Φ(x), where Φ is the Gaussian CDF. This is a smooth approximation to ReLU that is nonzero for negative inputs. GELU(x) ≈ x·σ(1.702x) as a fast approximation. Widely used; GPT-3's success made it a default for several years.
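To get a feel for how close the sigmoid shortcut is, the following sketch compares it against PyTorch's exact (erf-based) `F.gelu` on a grid of inputs. The grid range and resolution are arbitrary choices for illustration.

```python
import torch
import torch.nn.functional as F

# Evaluate both forms on a dense grid; the range [-6, 6] is an arbitrary choice.
x = torch.linspace(-6.0, 6.0, steps=1201)

exact = F.gelu(x)                      # x * Phi(x), exact erf-based form
approx = x * torch.sigmoid(1.702 * x)  # fast sigmoid approximation from the text

max_err = (exact - approx).abs().max().item()
print(f"max |GELU - sigmoid approx| on [-6, 6]: {max_err:.4f}")

# Unlike ReLU, GELU is nonzero (slightly negative) for negative inputs.
gelu_at_minus_one = F.gelu(torch.tensor(-1.0)).item()
relu_at_minus_one = F.relu(torch.tensor(-1.0)).item()
print(f"GELU(-1) = {gelu_at_minus_one:.4f}, ReLU(-1) = {relu_at_minus_one:.4f}")
```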
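Gated Linear Units (GLU) are a different shape entirely. Instead of applying a single activation to the projected input, you apply it to one projection and multiply elementwise by a second linear projection (the "gate"). With ReLU as the activation (ReGLU):

$$\text{FFN}_\text{ReGLU}(x) = \big(\max(0,\, xW_1) \otimes (xV)\big)\, W_2$$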
The gate x·V is a learned linear function of the input. It says "how much should each dimension of the activated output contribute?" When the gate is large, information flows. When small, it's suppressed. This gives the model a data-dependent gating mechanism that ReLU/GELU lack.
SwiGLU (LLaMA, PaLM, Mistral, OLMo, most 2023+ models) replaces the ReLU gate with the Swish function. $\text{Swish}(x) = x \cdot \sigma(x)$ is a smooth, differentiable function that's roughly linear for large positive $x$ and near-zero for large negative $x$. The SwiGLU FFN is:

$$\text{FFN}_\text{SwiGLU}(x) = \big(\text{Swish}(xW_1) \otimes (xV)\big)\, W_2$$
GeGLU replaces Swish with GELU: GeGLU(x) = (GELU(x·W1) ⊗ (x·V)) · W2. Used by T5 v1.1, Gemma 2/3, Phi-3.
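The 2/3 width correction. SwiGLU needs three matrix multiplications ($W_1$, $V$, $W_2$) instead of two ($W_1$, $W_2$). To keep FLOPs constant, we must shrink $d_\text{ff}$. Standard rule: $d_\text{ff} = 4 \times d_\text{model}$. For SwiGLU, we want $3 \times d_\text{ff,new} = 2 \times d_\text{ff,old}$ (three matrices vs two, same total FLOPs). So:

$$d_\text{ff,new} = \tfrac{2}{3}\, d_\text{ff,old} = \tfrac{2}{3} \cdot 4\, d_\text{model} = \tfrac{8}{3}\, d_\text{model} \approx 2.67\, d_\text{model}$$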
This is why you see dff/dmodel ≈ 2.67–3.5 in LLaMA-family models (not 4). The exact values vary by model because engineers often round dff to a multiple of 64 for hardware efficiency.
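The rounding step is easy to reproduce. The sketch below is an illustration under one assumption: the helper `swiglu_d_ff` and its `multiple_of` parameter are names chosen here, with a default of 256 picked because it reproduces the 11008 hidden size published for the 4096-wide LLaMA models. Other models round to other multiples.

```python
def swiglu_d_ff(d_model: int, multiple_of: int = 256) -> int:
    """Apply the 2/3 correction to the 4x rule, then round UP to a multiple.

    `multiple_of` is an assumption of this sketch; models differ in the
    granularity they round to.
    """
    hidden = int(2 * (4 * d_model) / 3)
    return multiple_of * ((hidden + multiple_of - 1) // multiple_of)


def ffn_matmul_params(d_model: int, d_ff: int, gated: bool) -> int:
    """Weights in the FFN: 3 matrices if gated (W1, V, W2), else 2 (W1, W2)."""
    return (3 if gated else 2) * d_model * d_ff


d_model = 4096
d_ff_gated = swiglu_d_ff(d_model)                    # rounded 8/3 rule
standard = ffn_matmul_params(d_model, 4 * d_model, gated=False)
gated = ffn_matmul_params(d_model, d_ff_gated, gated=True)

print(f"d_ff = {d_ff_gated}, ratio = {d_ff_gated / d_model:.4f}")
print(f"standard 4x FFN params: {standard:,}")
print(f"SwiGLU FFN params:      {gated:,}")
print(f"relative difference:    {gated / standard - 1:+.2%}")
```

Rounding up means the gated FFN ends up slightly larger than the exact iso-FLOP target, which is why published ratios sit a little above 2.67.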
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SwiGLUFFN(nn.Module):
    """Feed-forward network with SwiGLU activation.

    Three weight matrices: W1 (Swish-activated branch), V (linear branch),
    W2 (output projection).
    d_ff = 8/3 * d_model for iso-FLOP with standard 4x FFN.
    No bias anywhere (LLaMA convention).
    """

    def __init__(self, d_model: int):
        super().__init__()
        # 8/3 rounded to multiple of 64 for hardware efficiency
        d_ff = int(8 / 3 * d_model)
        d_ff = (d_ff + 63) // 64 * 64  # round up to multiple of 64
        # W1 feeds the branch that goes through Swish
        self.W1 = nn.Linear(d_model, d_ff, bias=False)
        # V feeds the purely linear branch (x·V in the formula)
        self.V = nn.Linear(d_model, d_ff, bias=False)
        # W2 projects back to d_model
        self.W2 = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch, seq, d_model]
        activated = F.silu(self.W1(x))     # silu = Swish = x*sigmoid(x)
        linear = self.V(x)                 # linear branch, no activation
        return self.W2(activated * linear)  # elementwise product, then project


# Parameter count comparison for d_model=4096:
d = 4096
ffn_relu = nn.Sequential(
    nn.Linear(d, 4 * d, bias=False),
    nn.ReLU(),
    nn.Linear(4 * d, d, bias=False),
)
ffn_swiglu = SwiGLUFFN(d)
relu_p = sum(p.numel() for p in ffn_relu.parameters())
swiglu_p = sum(p.numel() for p in ffn_swiglu.parameters())
print(f"ReLU FFN: {relu_p/1e6:.1f}M params")      # 4096*(4*4096)*2 = 134.2M
print(f"SwiGLU FFN: {swiglu_p/1e6:.1f}M params")  # roughly same with 8/3 rule
```
The transformer doesn't know the order of tokens unless you tell it. The original model added sinusoidal position vectors to token embeddings before feeding them in. The GPT series (GPT-1/2/3) and OPT switched to learned absolute position embeddings. But both approaches have a fundamental problem that RoPE solves elegantly.
The relative position problem. In attention, what matters is whether token A is 3 positions before token B — not their absolute positions 47 and 50 in the sequence. The attention score between two tokens should ideally depend only on their relative position (i−j), not their absolute positions i and j separately.
Sinusoidal embeddings fail this test. When you add a position vector to a token embedding and compute the dot product between two positions, you get cross-terms involving the absolute positions of both tokens separately — not just their difference. Absolute learned embeddings obviously fail since they encode only absolute position. T5-style relative embeddings (Raffel et al.) are explicitly relative but require adding position biases to the attention logits, which is awkward and doesn't naturally compose with scaled dot-product attention.
The RoPE insight (Su et al. 2021). We want a function f(x, i) that encodes token x at position i such that the dot product f(x,i)·f(y,j) depends only on x, y, and (i−j) — not on i or j individually. The key insight: inner products are invariant to rotation. If you rotate two vectors by the same angle, their dot product doesn't change. So if you rotate each token's embedding by an angle that depends on its position, the dot product (attention score) will automatically depend only on the relative rotation — which is the relative position.
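The 2D rotation formula. For a 2D vector $[x_0, x_1]$ at position $i$, RoPE applies the rotation matrix with angle $\theta \cdot i$:

$$R(i)\begin{bmatrix} x_0 \\ x_1 \end{bmatrix} = \begin{bmatrix} \cos(i\theta) & -\sin(i\theta) \\ \sin(i\theta) & \cos(i\theta) \end{bmatrix} \begin{bmatrix} x_0 \\ x_1 \end{bmatrix}$$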
For a $d_\text{head}$-dimensional head vector, you pair up coordinates (0,1), (2,3), ..., (d−2, d−1) and rotate each pair independently. Pair $k$ uses its own base frequency $\theta_k = 10000^{-2k/d}$, inspired by sinusoidal encodings. The high-frequency pairs (small $k$) capture fine-grained relative positions; the low-frequency pairs (large $k$) capture coarse-grained ones.
The relative property, verified. For query vector $q$ at position $i$ and key vector $k$ at position $j$, the attention score is:

$$\big(R(i)\,q\big)^\top \big(R(j)\,k\big) = q^\top R(i)^\top R(j)\,k = q^\top R(j-i)\,k$$

The rotation matrices multiply as $R(i)^\top R(j) = R(j-i)$, because rotating forward by $j$ then backward by $i$ is the same as rotating forward by $(j-i)$. The attention score is now $q^\top R(j-i)\,k$: it depends only on the relative position $(j-i)$, not on absolute positions $i$ or $j$. Goal achieved.
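The identity can be checked numerically in the 2D case. This is a small sketch with arbitrary example vectors and an arbitrary angle $\theta = 0.1$; `rot2d` and `rope_score` are helper names introduced for this example.

```python
import math
import torch


def rot2d(angle: float) -> torch.Tensor:
    """2x2 rotation matrix for the given angle (radians)."""
    c, s = math.cos(angle), math.sin(angle)
    return torch.tensor([[c, -s], [s, c]], dtype=torch.float64)


def rope_score(q: torch.Tensor, k: torch.Tensor, i: int, j: int,
               theta: float = 0.1) -> float:
    """Dot product of q rotated to position i and k rotated to position j."""
    return torch.dot(rot2d(theta * i) @ q, rot2d(theta * j) @ k).item()


# Arbitrary example vectors.
q = torch.tensor([1.0, 2.0], dtype=torch.float64)
k = torch.tensor([0.5, -1.0], dtype=torch.float64)

near = rope_score(q, k, i=3, j=7)         # offset j - i = 4, early in the sequence
far = rope_score(q, k, i=1003, j=1007)    # same offset, 1000 positions later
other = rope_score(q, k, i=3, j=8)        # different offset j - i = 5

# Closed form from the derivation: q^T R(j - i) k with j - i = 4.
closed_form = torch.dot(q, rot2d(0.1 * 4) @ k).item()

print(f"offset 4 at positions (3, 7):       {near:.6f}")
print(f"offset 4 at positions (1003, 1007): {far:.6f}")
print(f"offset 5 at positions (3, 8):       {other:.6f}")
```

The first two scores agree to floating-point precision because only the offset enters the result; changing the offset changes the score.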
Why not add, like sinusoidal? Sinusoidal embeddings are additive: $\text{Embed}(x, i) = v_x + PE_i$. When you dot-product two such embeddings, you get $v_x \cdot v_y + v_x \cdot PE_j + PE_i \cdot v_y + PE_i \cdot PE_j$: four terms, two of which (the cross-terms $v_x \cdot PE_j$ and $PE_i \cdot v_y$) depend on absolute positions $i$ and $j$ separately. RoPE's multiplicative rotation eliminates these cross-terms entirely. The dot product collapses to a single term that depends only on $i - j$.
```python
import torch


def build_rope_cache(seq_len: int, d_head: int, base: float = 10000.0):
    """Precompute cos/sin tables for RoPE. Returns cos, sin each [seq_len, d_head/2]."""
    # Frequencies: theta_k = 1 / (base ** (2k / d_head)) for k = 0..d_head//2-1
    half = d_head // 2
    freqs = 1.0 / (base ** (torch.arange(0, half).float() / half))  # [half]
    positions = torch.arange(seq_len).float()                       # [seq_len]
    angles = torch.outer(positions, freqs)                          # [seq_len, half]
    return angles.cos(), angles.sin()                               # each [seq_len, half]


def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    """Apply RoPE to query or key tensor.

    x: [batch, n_heads, seq_len, d_head]
    cos, sin: [seq_len, d_head/2]
    Returns: rotated x, same shape.

    Pairing convention: the head dimension is split into two halves, so
    coordinate k is paired with coordinate k + d_head/2 rather than with its
    neighbor (2k, 2k+1). The relative-position property holds for either
    layout as long as q and k use the same one.
    """
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]  # split along head dim
    # Broadcast cos/sin over batch and heads: [1,1,seq,half]
    c = cos.unsqueeze(0).unsqueeze(0)
    s = sin.unsqueeze(0).unsqueeze(0)
    # 2D rotation for each pair: [x1*cos - x2*sin, x1*sin + x2*cos]
    rot_x1 = x1 * c - x2 * s
    rot_x2 = x1 * s + x2 * c
    return torch.cat([rot_x1, rot_x2], dim=-1)


# Example:
cos, sin = build_rope_cache(seq_len=4096, d_head=128)
q = torch.randn(2, 32, 4096, 128)  # [batch=2, heads=32, seq=4096, d_head=128]
k = torch.randn(2, 32, 4096, 128)
q_rot = apply_rope(q, cos, sin)
k_rot = apply_rope(k, cos, sin)
# Now the positional dependence of q_rot · k_rot is purely relative
```
Standard Multi-Head Attention (MHA) runs nheads attention heads in parallel, each with its own query, key, and value projections. At inference time, when generating token by token, you cannot recompute all past keys and values at every step — that would be quadratic in sequence length. Instead, you cache them in a KV cache.
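The KV cache holds the key and value tensors for every layer, every head, and every past token. For a model with nlayers layers, nheads heads per layer, sequence length T, batch size B, and head dimension dhead, stored in bf16 (2 bytes each):

$$\text{KV bytes}_\text{MHA} = 2 \times B \times n_\text{layers} \times n_\text{heads} \times T \times d_\text{head} \times 2$$

The leading 2 counts K and V; the trailing 2 is bytes per bf16 value.
For LLaMA-2 70B's dimensions with full MHA (nlayers=80, nheads=64, dhead=128, T=4096, batch=1):

$$2 \times 1 \times 80 \times 64 \times 4096 \times 128 \times 2\ \text{bytes} \approx 10.7\ \text{GB}$$

That is for a single 4,096-token sequence, on top of the weights. The cache grows linearly with batch size and context length: at batch 32 it reaches about 344 GB, more than four 80 GB H100s. KV cache memory is a serious bottleneck for deploying large models with long contexts or large batches. This is the problem that MQA and GQA solve.
Multi-Query Attention (MQA) (Shazeer 2019) keeps multiple query heads but shares a single key-value head across all queries. Every query head computes attention against the same keys and values. The KV cache now only stores one K and one V tensor per layer:

$$\text{KV bytes}_\text{MQA} = 2 \times B \times n_\text{layers} \times 1 \times T \times d_\text{head} \times 2$$

For the LLaMA-2 70B equivalent: 2×80×1×4096×128×2 bytes ≈ 0.17 GB per sequence, a 64× reduction. But MQA trades some model expressiveness for this saving. Shazeer (2019) found a small perplexity hit; subsequent works found it depends heavily on model size and training data.
Grouped-Query Attention (GQA) (Ainslie et al. 2023) is the interpolation. Instead of nheads KV heads (MHA) or 1 (MQA), use nkv_heads KV heads, grouping queries into clusters. Each group of (nheads/nkv_heads) query heads shares one KV head:

$$\text{KV bytes}_\text{GQA} = 2 \times B \times n_\text{layers} \times n_\text{kv\_heads} \times T \times d_\text{head} \times 2$$

LLaMA-2 70B uses nheads=64, nkv_heads=8 (8-way grouping). Cache savings: 64/8 = 8×. New cache size: 10.7 GB / 8 ≈ 1.3 GB per 4,096-token sequence.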
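A per-token view makes the scaling with batch size and context length explicit. The sketch below uses the LLaMA-2 70B dimensions quoted above; the helper names and the batch size of 32 are illustrative choices, not values from any deployment.

```python
def kv_bytes_per_token(n_layers: int, n_kv_heads: int, d_head: int,
                       bytes_per_val: int = 2) -> int:
    """Bytes of K and V stored for ONE token, summed over all layers."""
    return 2 * n_layers * n_kv_heads * d_head * bytes_per_val


def kv_cache_bytes(batch: int, seq_len: int, per_token: int) -> int:
    """Total cache size: linear in both batch size and sequence length."""
    return batch * seq_len * per_token


# LLaMA-2 70B dimensions; only the number of KV heads varies.
shape = dict(n_layers=80, d_head=128)
variants = {"MHA": 64, "GQA": 8, "MQA": 1}

per_token = {name: kv_bytes_per_token(n_kv_heads=n_kv, **shape)
             for name, n_kv in variants.items()}

for name, per_tok in per_token.items():
    for batch in (1, 32):  # batch=32 is an arbitrary serving example
        total = kv_cache_bytes(batch, 4096, per_tok)
        print(f"{name} batch={batch:2d}: "
              f"{per_tok / 1e6:.2f} MB/token, {total / 1e9:.1f} GB total")
```

Because the cost is a fixed number of bytes per cached token, any saving in KV heads translates directly into a proportionally larger batch or longer context for the same memory budget.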
```python
from typing import Optional

import torch
import torch.nn.functional as F


def grouped_query_attention(
    q: torch.Tensor,  # [B, n_heads, T_q, d_head]
    k: torch.Tensor,  # [B, n_kv_heads, T_kv, d_head]
    v: torch.Tensor,  # [B, n_kv_heads, T_kv, d_head]
    scale: Optional[float] = None,
) -> torch.Tensor:
    """GQA: n_heads queries, n_kv_heads key-value pairs (n_heads >= n_kv_heads).

    Each group of (n_heads // n_kv_heads) query heads attends the same K, V.
    No causal mask is applied here; add one for autoregressive decoding.
    """
    B, n_heads, T_q, d_head = q.shape
    n_kv = k.shape[1]
    assert n_heads % n_kv == 0, "n_heads must be divisible by n_kv_heads"
    group_size = n_heads // n_kv
    if scale is None:
        scale = d_head ** -0.5
    # Expand K, V to match n_heads by repeating each KV head group_size times
    k_exp = k.repeat_interleave(group_size, dim=1)  # [B, n_heads, T_kv, d_head]
    v_exp = v.repeat_interleave(group_size, dim=1)
    # Standard scaled dot-product attention
    scores = torch.einsum('bhqd,bhkd->bhqk', q * scale, k_exp)  # [B,n_heads,T_q,T_kv]
    weights = F.softmax(scores, dim=-1)
    out = torch.einsum('bhqk,bhkd->bhqd', weights, v_exp)       # [B,n_heads,T_q,d_head]
    return out


# Memory savings calculation:
def kv_cache_gb(n_layers, n_kv_heads, T, d_head, bytes_per_val=2):
    """KV cache in GB for one sequence."""
    # K and V, each [n_layers, n_kv_heads, T, d_head]
    total = 2 * n_layers * n_kv_heads * T * d_head * bytes_per_val
    return total / 1e9


print(f"LLaMA-2 70B MHA (64 heads): {kv_cache_gb(80, 64, 4096, 128):.1f} GB")  # 10.7 GB
print(f"LLaMA-2 70B GQA (8 heads): {kv_cache_gb(80, 8, 4096, 128):.1f} GB")    # 1.3 GB
print(f"LLaMA-2 70B MQA (1 head): {kv_cache_gb(80, 1, 4096, 128):.2f} GB")     # 0.17 GB
```
Once you've fixed your architecture choices (pre-norm, RMSNorm, SwiGLU, RoPE, GQA), you still need to set a handful of numerical hyperparameters: how wide is the model (dmodel), how many layers (nlayers), how many heads (nheads), how wide is the FFN (dff), and how large is the vocabulary (V). These values have surprising consensus across models.
Almost every non-gated model uses dff = 4 × dmodel. Kaplan et al. 2020 showed empirically that this hyperparameter has a wide "plateau" between 1× and 10× — the loss doesn't degrade dramatically if you deviate. But the 4× rule became a default because it worked well and everyone copied it.
For SwiGLU/GeGLU models: the iso-FLOP correction gives dff ≈ (8/3)×dmodel ≈ 2.67×dmodel. In practice models round to convenient values, anywhere from about 2.5 to 3.5. Real-world ratios: Mistral 7B (3.5), LLaMA-2 70B (3.5), LLaMA 65B (2.68), Qwen 14B (2.67), DeepSeek 67B (2.68).
One wild outlier: T5 11B used dff = 65,536 with dmodel = 1,024 — a 64× ratio! Their follow-up T5 v1.1 reverted to 2.5× with GeGLU, implying the 64× was suboptimal.
Most models keep head dimension times number of heads equal to model dimension: nheads × dhead = dmodel. This is not a mathematical requirement — you could have dhead > dmodel/nheads (the T5 exception, where they used nheads=128, dhead=128, dmodel=1024 → ratio=16). But empirically, the 1:1 ratio works well and keeps the architecture simple.
| Model | dmodel | nheads | dhead | (nheads × dhead) / dmodel |
|---|---|---|---|---|
| GPT-3 175B | 12288 | 96 | 128 | 1 |
| LLaMA-2 70B | 8192 | 64 | 128 | 1 |
| PaLM 540B | 18432 | 48 | 256 | 0.67 |
| T5 11B | 1024 | 128 | 128 | 16 (!) |
Should your model be deep (many layers, narrow dmodel) or wide (few layers, large dmodel)? Looking across real models, the sweet spot for dmodel/nlayers is roughly 100–200. BLOOM: 205. PaLM 540B: 156. GPT-3 and Mistral: 128. LLaMA-2: 102. GPT-2: 33 (outlier, very deep for its size).
The intuition: extremely deep models are harder to pipeline across GPUs (more sequential communication barriers) and have higher latency per token at inference. Extremely wide models may be over-parameterized in depth. The 100–200 range balances these concerns empirically.
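The aspect ratios quoted above follow directly from each model's width and depth. The sketch below recomputes them; the (d_model, n_layers) pairs were entered by hand from the published configurations and are the only inputs.

```python
# (d_model, n_layers) pairs, entered by hand for illustration.
configs = {
    "BLOOM 176B": (14336, 70),
    "PaLM 540B": (18432, 118),
    "GPT-3 175B": (12288, 96),
    "Mistral 7B": (4096, 32),
    "LLaMA-2 70B": (8192, 80),
    "GPT-2 1.5B": (1600, 48),
}


def aspect_ratio(d_model: int, n_layers: int) -> float:
    """Width per layer: large values mean wide and shallow, small mean deep and narrow."""
    return d_model / n_layers


ratios = {name: aspect_ratio(*cfg) for name, cfg in configs.items()}

# Print from widest to deepest and flag whether each lands in the 100-200 band.
for name, ratio in sorted(ratios.items(), key=lambda item: -item[1]):
    in_band = 100 <= ratio <= 200
    print(f"{name:12s} d_model/n_layers = {ratio:6.1f}  in 100-200 band: {in_band}")
```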
Monolingual English models: 30,000–50,000 tokens. GPT-2/3: 50,257. LLaMA 1: 32,000. Multilingual / production systems: 100,000–256,000. GPT-4: 100,276. PaLM: 256,000. mT5: 250,000. Qwen 15B: 152,064.
Larger vocabulary = fewer tokens per sequence (more compact) but larger embedding matrix (dvocab × dmodel parameters) and a heavier final projection layer.
Modern frontier models mostly do no dropout during pretraining. The intuition: with trillions of tokens and a single pass through the data, there's virtually no overfitting risk. LLaMA, Chinchilla, PaLM: no dropout. But they do use weight decay (0.1 is typical). Why? Andriushchenko et al. 2023 showed weight decay interacts with the learning rate schedule (cosine decay) to improve optimization dynamics — it's not about preventing overfitting.
All the hyperparameter rules from Chapter 6 (FFN ratio, head dimension, aspect ratio) come together when you count parameters. The total breaks down across embedding, attention, FFN, and output layers, and real-model configurations let you verify your intuition against published numbers.
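The parameter count formula (no biases, SwiGLU optional) per transformer block:

$$P_\text{block} \approx \underbrace{4\, d_\text{model}^2}_{\text{attention: } W_Q,\, W_K,\, W_V,\, W_O} + \underbrace{c \cdot d_\text{model} \cdot d_\text{ff}}_{\text{FFN}}$$

where $c = 2$ for a standard FFN and $c = 3$ for SwiGLU/GeGLU (the norm gains add a negligible $2\, d_\text{model}$). For GQA with $n_\text{kv\_heads} < n_\text{heads}$, the attention term becomes $(n_\text{heads} + 2\, n_\text{kv\_heads}) \times d_\text{head} \times d_\text{model} + d_\text{model}^2$ (the last term is $W_O$). The KV projections are smaller.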
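The formula can be turned into a small calculator and checked against known models. This is a sketch under stated assumptions: no biases, two norm gains per block plus one final norm, and untied input/output embeddings by default. The function names and the two configurations are entered here for illustration.

```python
def block_params(d_model: int, d_ff: int, n_heads: int, n_kv_heads: int,
                 d_head: int, gated: bool = True) -> int:
    """Parameters in one transformer block (no biases)."""
    # Q has n_heads heads; K and V have n_kv_heads heads each.
    qkv = (n_heads + 2 * n_kv_heads) * d_head * d_model
    # W_O maps the concatenated heads back; equals d_model^2 when
    # n_heads * d_head == d_model.
    w_o = (n_heads * d_head) * d_model
    ffn = (3 if gated else 2) * d_model * d_ff
    norms = 2 * d_model  # one gain vector per norm, two norms per block
    return qkv + w_o + ffn + norms


def total_params(vocab: int, d_model: int, n_layers: int, d_ff: int,
                 n_heads: int, n_kv_heads: int, d_head: int,
                 gated: bool = True, tied_embeddings: bool = False) -> int:
    """Whole-model count: embeddings + blocks + final norm + output projection."""
    embed = vocab * d_model
    output = 0 if tied_embeddings else vocab * d_model
    blocks = n_layers * block_params(d_model, d_ff, n_heads, n_kv_heads,
                                     d_head, gated)
    return embed + blocks + d_model + output


llama2_7b = total_params(vocab=32000, d_model=4096, n_layers=32, d_ff=11008,
                         n_heads=32, n_kv_heads=32, d_head=128)
mistral_7b = total_params(vocab=32000, d_model=4096, n_layers=32, d_ff=14336,
                          n_heads=32, n_kv_heads=8, d_head=128)

print(f"LLaMA-2 7B: {llama2_7b:,} ({llama2_7b / 1e9:.2f}B)")   # 6,738,415,616 (6.74B)
print(f"Mistral 7B: {mistral_7b:,} ({mistral_7b / 1e9:.2f}B)")  # 7,241,732,096 (7.24B)
```

Mistral spends the parameters it saves on KV projections (GQA with 8 KV heads) on a wider FFN, which is why the two "7B" models differ by about half a billion parameters.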
Training a large model without instability is harder than it sounds. The culprit is usually the softmax function — and there are two places where it appears in a transformer.
The output softmax at the end converts logits (raw scores for each vocabulary token) to probabilities. If any logit grows very large, the exponential in softmax can overflow to infinity. Even before overflow, large logits concentrate the softmax output to near-zero for most tokens, making the gradient signal very sparse. PaLM pioneered a trick called the z-loss to prevent this.
The z-loss adds a small penalty proportional to the square of the log-partition function (the log-sum-exp of all logits). This penalizes the model for making any single logit too large, stabilizing the output softmax. Concretely, if $z = \log \sum_i \exp(\text{logit}_i)$ is the log-partition, the z-loss penalty is $\varepsilon \cdot z^2$ added to the cross-entropy loss. DCLM (2024) and OLMo 2 (2025) also use this.
The attention softmax inside each attention head can also be problematic. If query and key vectors have large magnitudes, the dot products $QK^\top / \sqrt{d_\text{head}}$ can become very large even after scaling, driving the softmax into a saturated regime where gradients vanish (the "attention collapse" problem). Two tricks address this:
QK-Norm: Apply RMSNorm (or LayerNorm) to the query and key vectors before computing attention scores. This caps their magnitude and prevents the dot products from exploding. Used by DCLM, OLMo 2, and Gemma 3 (which replaced Gemma 2's soft-capping with QK-norm); the idea comes from the vision and multimodal community (Dehghani et al. 2023, IDEFICS, Chameleon). Implementation is simple: add a learnable norm after the Q and K projections.
Logit soft-capping: Clamp attention logits to a maximum value via a tanh transformation. If raw logit is x, the capped version is max_val × tanh(x / max_val). This is differentiable and prevents logits from ever exceeding max_val. Used by Gemma 2. The potential downside: the tanh compression might hurt performance for cases where very large logits are genuinely meaningful — early reports suggest a small but real PPL cost.
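The tanh formula is a one-liner. In the sketch below, `soft_cap` is a name chosen for this example and `max_val=50.0` is an illustrative value, not a claim about any specific model's setting.

```python
import torch


def soft_cap(logits: torch.Tensor, max_val: float) -> torch.Tensor:
    """Smoothly squash logits into (-max_val, max_val) via tanh."""
    return max_val * torch.tanh(logits / max_val)


raw = torch.tensor([-1000.0, -5.0, 0.0, 5.0, 1000.0])
capped = soft_cap(raw, max_val=50.0)  # 50.0 is an illustrative cap
print(capped)

# Small logits pass through almost unchanged; huge ones saturate at the cap.
small_change = (capped[3] - raw[3]).abs().item()
largest = capped.abs().max().item()
print(f"change for logit 5: {small_change:.4f}, largest magnitude: {largest:.1f}")
```

Unlike a hard clamp, the gradient is nonzero everywhere, although it shrinks toward zero as the input approaches saturation.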
Parallel layers are an orthogonal stability/efficiency trick. Normal transformers compute attention then FFN sequentially. A few models (GPT-J, PaLM, GPT-NeoX, Cohere Command A) compute attention and FFN in parallel (adding both to the residual simultaneously). When done carefully, the LayerNorm can be shared between both branches, and the first matrix multiply in attention and FFN can be fused — reducing memory bandwidth. The speed gain is real; whether it hurts quality depends on the scale.
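The structural difference between serial and parallel blocks is small. The sketch below is a schematic only: `SerialBlock` and `ParallelBlock` are names introduced here, and the `attn` and `ffn` modules are simple placeholder layers standing in for real attention and feed-forward sub-layers.

```python
import torch
import torch.nn as nn


class SerialBlock(nn.Module):
    """Standard pre-norm block: attention first, then FFN on the updated stream."""

    def __init__(self, d_model: int, attn: nn.Module, ffn: nn.Module):
        super().__init__()
        self.attn, self.ffn = attn, ffn
        self.norm_attn = nn.LayerNorm(d_model)
        self.norm_ffn = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.norm_attn(x))
        return x + self.ffn(self.norm_ffn(x))


class ParallelBlock(nn.Module):
    """Parallel block: both branches read the SAME normalized input."""

    def __init__(self, d_model: int, attn: nn.Module, ffn: nn.Module):
        super().__init__()
        self.attn, self.ffn = attn, ffn
        self.norm = nn.LayerNorm(d_model)  # one norm shared by both branches

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.norm(x)
        # Neither branch depends on the other's output, so they can run
        # concurrently or have their input projections fused.
        return x + self.attn(h) + self.ffn(h)


torch.manual_seed(0)
d_model = 32
# Placeholder sub-layers (NOT real attention): any module mapping d_model -> d_model.
attn = nn.Linear(d_model, d_model, bias=False)
ffn = nn.Sequential(
    nn.Linear(d_model, 4 * d_model, bias=False),
    nn.GELU(),
    nn.Linear(4 * d_model, d_model, bias=False),
)

x = torch.randn(2, 5, d_model)
serial, parallel = SerialBlock(d_model, attn, ffn), ParallelBlock(d_model, attn, ffn)
out_serial, out_parallel = serial(x), parallel(x)

# The parallel output is exactly x plus the two independent branch outputs.
h = parallel.norm(x)
expected_parallel = x + attn(h) + ffn(h)
print(out_serial.shape, out_parallel.shape)
```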
No-bias recommendation also connects to stability. When linear layers have bias terms, those biases can grow arbitrarily without being regularized by the norm operations. Removing biases (as LLaMA-family does) constrains the optimization landscape and reduces the surface area for instability.
```python
import torch
import torch.nn.functional as F


def cross_entropy_with_z_loss(
    logits: torch.Tensor,   # [B, T, V]
    targets: torch.Tensor,  # [B, T] token ids
    z_loss_weight: float = 1e-4,
) -> torch.Tensor:
    """Cross-entropy with PaLM's z-loss stability term.

    z = log(sum_i exp(logit_i)) = log-partition
    z_loss = z_loss_weight * z^2
    This penalizes very large logits (large log-partition).
    """
    B, T, V = logits.shape
    logits_2d = logits.reshape(-1, V)  # [B*T, V]
    targets_1d = targets.reshape(-1)   # [B*T]
    # Standard cross-entropy
    ce_loss = F.cross_entropy(logits_2d, targets_1d)
    # Z-loss: penalize large log-partition = log(sum exp(logits))
    log_z = torch.logsumexp(logits_2d, dim=-1)  # [B*T], log-partition per position
    z_loss = z_loss_weight * (log_z ** 2).mean()
    return ce_loss + z_loss


class QKNorm(torch.nn.Module):
    """Apply RMSNorm to Q and K before attention score computation.

    Uses the RMSNorm class defined in the RMSNorm example earlier in this lesson.
    """

    def __init__(self, d_head: int):
        super().__init__()
        self.q_norm = RMSNorm(d_head)
        self.k_norm = RMSNorm(d_head)

    def forward(self, q, k):
        return self.q_norm(q), self.k_norm(k)


# Insert between QKV projection and attention score computation
```
You've now seen every major architecture decision made in modern LLMs — derived from first principles, with math, code, and real-model evidence. Here's how it all fits together.
| Choice | Consensus (2024+) | Why | Notable exceptions |
|---|---|---|---|
| Norm placement | Pre-norm | Clean gradient highways, stable training | OPT-350M (post-norm) |
| Norm type | RMSNorm | Lower memory bandwidth, same quality | GPT-2/3 (LayerNorm) |
| Bias terms | None | Fewer params to move, less instability surface | BERT, GPT-2 |
| Activation | SwiGLU or GeGLU | Consistent small gains over ReLU/GELU | GPT-3 (GELU), Falcon 2 (ReLU) |
| Position encoding | RoPE | True relative position, length extrapolation | T5 (relative bias), OPT (learned abs) |
| Attention | GQA | 8× smaller KV cache vs MHA at 8 kv-heads (64 query heads) | GPT-3, BLOOM (full MHA) |
| FFN ratio | 4× (or 8/3× for SwiGLU) | Empirical plateau; iso-FLOP correction | T5 (64×!) |
| Aspect ratio | dmodel/nlayers ≈ 100–200 | Parallelism & latency constraints | T5 (43), GPT-2 (33) |
| Dropout | None (pretraining) | Trillions of tokens, no overfit risk | GPT-2/3, T5, Qwen (0.1) |
| Stability | z-loss, QK-Norm | Prevents softmax overflow/collapse | LLaMA (no z-loss) |
| Model | dmodel | nlayers | nheads | nkv | dff/d | Vocab | ~Params |
|---|---|---|---|---|---|---|---|
| GPT-3 175B | 12288 | 96 | 96 | 96 | 4× | 50k | 175B |
| LLaMA-2 7B | 4096 | 32 | 32 | 32 | 2.67× | 32k | 7B |
| LLaMA-2 70B | 8192 | 80 | 64 | 8 | 3.5× | 32k | 70B |
| Mistral 7B | 4096 | 32 | 32 | 8 | 3.5× | 32k | 7B |
| PaLM 540B | 18432 | 118 | 48 | 1 (MQA) | 4× | 256k | 540B |