A clean, minimal PyTorch re-implementation of the Gemma 4 forward pass. Mirrors the HuggingFace source as closely as possible while stripping away framework boilerplate so the architecture is easy to read and experiment with.
Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma4/modeling_gemma4.py
Shared primitives

- RMSNorm — root-mean-square layer normalisation
- rotate_half / apply_rotary_pos_emb — 1-D RoPE helpers
- apply_2d_rope — 2-D RoPE for vision patch grids

Text tower

- TextRotaryEmbedding — dual-frequency RoPE (local / global)
- ScaledEmbedding — token embeddings scaled by √hidden_size
- TextMLP — SwiGLU dense FFN
- TextAttention — GQA + QK-norm + sliding-window + KV-sharing
- TextRouter — top-k MoE router with per-expert scale
- TextExperts — sparse expert FFN bank
- TextDecoderLayer — full layer: Attention + MLP + optional MoE + per-layer gate
- TextModel — stack of decoder layers with shared RoPE

Vision tower

- VisionRotaryEmbedding — 2-D RoPE for patch (x, y) positions
- VisionPatchEmbedder — raw pixels → patch vectors
- VisionMLP — SwiGLU FFN with optional activation clipping
- VisionAttention — ViT-style full attention with 2-D RoPE and QKV norms
- VisionEncoderLayer — single ViT layer
- VisionEncoder — stack of ViT layers
- VisionPooler — k×k spatial average pooling → soft tokens
- VisionModel — patch embedder → encoder → pooler

Multimodal fusion

- MultimodalEmbedder — project vision soft-tokens → text hidden size
- Gemma4Model — merge image soft-tokens into token stream, run language model
- Gemma4ForCausalLM — text-only causal LM head
- Gemma4ForConditionalGeneration — full multimodal generation model

```python
from __future__ import annotations

import math
from dataclasses import dataclass, field
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
```
```python
@dataclass
class TextConfig:
    vocab_size: int = 262_144
    hidden_size: int = 2560
    num_hidden_layers: int = 34
    num_attention_heads: int = 16
    num_key_value_heads: int = 8
    head_dim: int = 256
    intermediate_size: int = 8192  # dense MLP width

    # MoE
    num_experts: int = 0                                 # 0 = no MoE layers
    num_experts_per_tok: int = 2
    moe_layers: list[int] = field(default_factory=list)  # which layers have MoE
    expert_intermediate_size: int = 4096

    # KV sharing: layers >= kv_share_from reuse KV from layer kv_share_from - 1
    kv_share_from: int | None = None

    # Attention
    sliding_window: int | None = 1024  # None → global attention on all layers
    sliding_window_pattern: int = 6    # every Nth layer uses global attn
    rope_theta: float = 10_000.0
    rope_local_base_freq: float = 10_000.0
    rope_global_base_freq: float = 1_000_000.0
    attn_logit_softcapping: float | None = None
    final_logit_softcapping: float | None = 30.0

    # Global attention layers use a larger head_dim and may have a different KV head count
    global_head_dim: int | None = None            # None → same as head_dim for all layers
    global_partial_rotary_factor: float = 1.0     # fraction of global_head_dim to rotate
    num_global_key_value_heads: int | None = None # None → same as num_key_value_heads

    # attention_k_eq_v=True: full-attention layers share the K/V projection
    # (v_proj=None, V = k_proj(x)); v_norm is then always applied.
    # use_v_norm applies v_norm regardless.
    attention_k_eq_v: bool = False
    use_v_norm: bool = False

    # Per-layer input gate (Gemma4-specific)
    hidden_size_per_layer_input: int = 0

    # Misc
    rms_norm_eps: float = 1e-6
    pad_token_id: int = 0
    embed_scale: float = 1.0
```
```python
@dataclass
class VisionConfig:
    hidden_size: int = 768
    num_hidden_layers: int = 16
    num_attention_heads: int = 12
    head_dim: int = 64
    intermediate_size: int = 3072
    patch_size: int = 16
    position_embedding_size: int = 10240  # lookup table size per axis
    pooling_kernel_size: int = 3          # spatial pooling factor
    rope_theta: float = 100.0
    rms_norm_eps: float = 1e-6
    standardize: bool = False             # optional post-pooling standardization
    use_clipped_linears: bool = True      # clamp activations in ClippableLinear
```
```python
@dataclass
class Gemma4Config:
    text: TextConfig = field(default_factory=TextConfig)
    vision: VisionConfig = field(default_factory=VisionConfig)

    # Attention: 'bidirectional_vision' → vision tokens use a bidirectional mask
    use_bidirectional_attention: str = "none"
    image_token_id: int = 258_880  # sentinel id for image placeholder tokens
```
Unlike LayerNorm, RMSNorm omits the mean-centering step and normalises only by the root-mean-square of the activations:
$$\text{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{d}\sum_i x_i^2 + \epsilon}} \cdot w$$
where $w \in \mathbb{R}^d$ is a learned per-channel scale initialised to 1.
The computation is done in float32 for numerical stability regardless of
the input dtype, then cast back — matching the HuggingFace implementation exactly.
Pass with_scale=False to get the scale-free variant used for $v$-normalisation
inside attention.
```python
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6, with_scale: bool = True):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim)) if with_scale else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Upcast to float32 for numerical stability, then cast back
        x_f = x.float()
        normed = x_f * torch.rsqrt(x_f.pow(2).mean(-1, keepdim=True) + self.eps)
        if self.weight is not None:
            normed = normed * self.weight.float()
        return normed.to(x.dtype)
```
Splits the last dimension into two halves $[x_1, x_2]$ and returns $[-x_2, x_1]$. Combined with the cosine term this implements the complex-number rotation $x \cdot e^{i\theta}$.
```python
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat([-x2, x1], dim=-1)
```
Encodes absolute position into the query/key vectors by rotating pairs of channels by a position-dependent angle $\theta_i \cdot m$ (where $m$ is the token position and $\theta_i = \text{base}^{-2i/d}$):
$$x^\prime = x \cos\theta + \text{rotate\_half}(x) \sin\theta$$
Because rotation is an isometry, the dot product $q \cdot k$ depends only on the relative offset $m - n$, giving the model free relative-position information without any learned parameters.
```python
def apply_rotary_pos_emb(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    unsqueeze_dim: int = 1,
) -> torch.Tensor:
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    return (x * cos) + (rotate_half(x) * sin)
```
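The relative-offset property is easy to verify numerically. The sketch below re-implements the helpers for a single vector (the 64-dim size and positions are illustrative, not tied to the model code) and checks that the post-RoPE dot product depends only on the position difference:

```python
import torch

def rotate_half(x):
    half = x.shape[-1] // 2
    return torch.cat([-x[..., half:], x[..., :half]], dim=-1)

def rope(x, pos, base=10_000.0):
    d = x.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, d, 2).float() / d))
    freqs = pos * inv_freq           # one angle per channel pair
    emb = torch.cat([freqs, freqs])  # duplicate for rotate_half layout
    return x * emb.cos() + rotate_half(x) * emb.sin()

torch.manual_seed(0)
q, k = torch.randn(64), torch.randn(64)
# q at position 10 against k at position 3 (offset 7) ...
dot_a = rope(q, torch.tensor(10.0)) @ rope(k, torch.tensor(3.0))
# ... scores the same as q at 107 against k at 100 (also offset 7)
dot_b = rope(q, torch.tensor(107.0)) @ rope(k, torch.tensor(100.0))
```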
Vision patches have a 2-D grid position $(r, c)$. We encode both axes independently by splitting the head channels into two equal halves and applying 1-D RoPE with the row frequencies to the first half and column frequencies to the second:
$$x^\prime_{\text{row}} = \text{RoPE}(x_{:d/2},\ \theta_r),\quad x^\prime_{\text{col}} = \text{RoPE}(x_{d/2:},\ \theta_c)$$
cos/sin carry both sets of frequencies concatenated along the last dim.
```python
def apply_2d_rope(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    position_ids: torch.Tensor,
    unsqueeze_dim: int = 2,
) -> torch.Tensor:
    ndim = position_ids.shape[-1]  # should be 2
    channels_per_dim = 2 * (x.shape[-1] // (2 * ndim))
    x_parts = torch.split(x, [channels_per_dim] * ndim, dim=-1)
    cos_parts = torch.split(cos, [channels_per_dim] * ndim, dim=-1)
    sin_parts = torch.split(sin, [channels_per_dim] * ndim, dim=-1)
    rotated = [
        apply_rotary_pos_emb(x_parts[k], cos_parts[k], sin_parts[k], unsqueeze_dim)
        for k in range(ndim)
    ]
    return torch.cat(rotated, dim=-1)
```
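The channel split can be checked in isolation; head_dim = 16 here is an illustrative size, not a model value:

```python
import torch

# With ndim = 2 axes (row, col), the head channels split into two equal halves:
# the first half gets row frequencies, the second half gets column frequencies
head_dim, ndim = 16, 2
channels_per_dim = 2 * (head_dim // (2 * ndim))
x = torch.randn(1, 4, head_dim)  # [B, N, head_dim]
x_row, x_col = torch.split(x, [channels_per_dim] * ndim, dim=-1)
```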
Dual-frequency RoPE for Gemma 4 text. Local layers use rope_local_base_freq over head_dim. Global layers use rope_global_base_freq over global_head_dim, with optional partial rotation (global_partial_rotary_factor): only rope_angles = int(factor * global_head_dim // 2) frequency pairs are rotated, while the remaining entries carry zero inv_freq (cos = 1, sin = 0, i.e. the identity). Matches HF's _compute_proportional_rope_parameters exactly.
```python
class TextRotaryEmbedding(nn.Module):
    def __init__(self, cfg: TextConfig):
        super().__init__()
        # Local (sliding attention) — full rotation over head_dim
        local_dim = cfg.head_dim
        local_inv_freq = 1.0 / (
            cfg.rope_local_base_freq ** (torch.arange(0, local_dim, 2).float() / local_dim)
        )
        self.register_buffer("local_inv_freq", local_inv_freq, persistent=False)
        self._local_head_dim = local_dim

        # Global (full attention) — proportional/partial rotation over global_head_dim
        global_dim = cfg.global_head_dim if cfg.global_head_dim is not None else cfg.head_dim
        self._global_head_dim = global_dim
        partial = getattr(cfg, "global_partial_rotary_factor", 1.0)
        # rope_angles: number of (freq, freq) pairs that are actually rotated
        rope_angles = int(partial * global_dim // 2)
        # inv_freq for the rotated part: the denominator is global_head_dim (matches HF)
        rotated_inv = 1.0 / (
            cfg.rope_global_base_freq
            ** (torch.arange(0, 2 * rope_angles, 2).float() / global_dim)
        )
        # Pad the remaining nope dimensions with zero (identity: cos=1, sin=0)
        nope_count = global_dim // 2 - rope_angles
        if nope_count > 0:
            global_inv_freq = torch.cat([rotated_inv, torch.zeros(nope_count)], dim=0)
        else:
            global_inv_freq = rotated_inv
        self.register_buffer("global_inv_freq", global_inv_freq, persistent=False)
```
```python
    @torch.no_grad()
    def forward(
        self,
        x: torch.Tensor,
        position_ids: torch.Tensor,
        layer_type: str = "global",  # "local" or "global"
    ):
        # Select the inv_freq table for this layer type (local=small base, global=large base)
        inv_freq = getattr(self, f"{layer_type}_inv_freq")
        # Outer product: inv_freq [D/2] × position [L] → [B, L, D/2]
        # Each entry (b, l, i) = inv_freq[i] * position_ids[b, l]
        inv_freq_exp = inv_freq[None, :, None].expand(position_ids.shape[0], -1, 1)
        pos_exp = position_ids[:, None, :].float()
        freqs = (inv_freq_exp.float() @ pos_exp.float()).transpose(1, 2)
        # Duplicate to get the full head_dim: [cos(θ_0·m), …, cos(θ_{D/2}·m), cos(θ_0·m), …]
        # This layout means the second half mirrors the first, enabling rotate_half
        emb = torch.cat([freqs, freqs], dim=-1)
        return emb.cos().to(x.dtype), emb.sin().to(x.dtype)
```
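The zero-padded inv_freq entries really do act as an identity. A minimal sketch (head_dim = 8 and rope_angles = 2 are illustrative toy values) shows that channels belonging to zeroed frequency pairs pass through the rotation unchanged:

```python
import torch

def rotate_half(x):
    half = x.shape[-1] // 2
    return torch.cat([-x[..., half:], x[..., :half]], dim=-1)

head_dim, rope_angles = 8, 2  # rotate only 2 of the 4 frequency pairs
rotated_inv = 1.0 / (1_000_000.0 ** (torch.arange(0, 2 * rope_angles, 2).float() / head_dim))
inv_freq = torch.cat([rotated_inv, torch.zeros(head_dim // 2 - rope_angles)])

pos = 5.0
freqs = pos * inv_freq          # zero inv_freq → zero angle
emb = torch.cat([freqs, freqs])
cos, sin = emb.cos(), emb.sin() # cos=1, sin=0 on the padded pairs

torch.manual_seed(0)
x = torch.randn(head_dim)
x_rot = x * cos + rotate_half(x) * sin
```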
Token embedding scaled by sqrt(hidden_size) as in Gemma. The scale is stored as a float32 buffer and cast to the weight dtype at forward time, matching HF's Gemma4TextScaledWordEmbedding exactly (important for bf16).
```python
class ScaledEmbedding(nn.Embedding):
    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int,
                 embed_scale: float = 1.0):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)
```
A gated variant of the FFN where one linear branch acts as a gate:
$$\text{MLP}(x) = W_{\text{down}}\bigl(\text{GELU}(W_{\text{gate}}\, x) \odot W_{\text{up}}\, x\bigr)$$
The GELU gate (tanh approximation) selectively suppresses or amplifies each dimension of the intermediate representation before the final down-projection, giving the network a multiplicative, content-dependent nonlinearity at essentially no parameter cost beyond the extra gate projection.
Used as the dense shared FFN that runs on every layer. In MoE layers it runs in parallel with the sparse expert bank and both outputs are summed.
```python
class TextMLP(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # GELU(gate) ⊙ up → down
        return self.down_proj(F.gelu(self.gate_proj(x), approximate="tanh") * self.up_proj(x))
```
Standard scaled dot-product attention with several Gemma-4 twists:
$$\text{Attn}(Q,K,V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$
- QK-normalisation — RMSNorm is applied to each query and key head before RoPE. Because the norms are learnable scales, the effective temperature is absorbed into the weights, so the scaling constant is $1.0$ rather than $1/\sqrt{d_k}$.
- V-normalisation — a scale-free RMSNorm is applied to values, stabilising training in mixed precision without introducing extra parameters.
- Grouped-Query Attention — there are fewer KV heads than Q heads ($H_{kv} < H_q$). Each KV head is shared across $H_q / H_{kv}$ query heads via expand + reshape before the dot product.
- Local vs global layers — layers at positions divisible by sliding_window_pattern use full (global) attention; the rest apply a causal sliding-window mask of size sliding_window, limiting each token to attending only the nearest $W$ positions.
- KV sharing — layers at index ≥ kv_share_from reuse the key/value states computed by layer kv_share_from − 1, saving memory without a significant quality drop.
- attention_k_eq_v — in global layers of large models, $V$ is derived from the same linear projection as $K$ (i.e. v_proj = None and $V = k\_proj(x)$ before k_norm). This halves the KV projection cost.
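The sliding-window mask can be sketched directly as an additive bias. Sequence length 6 and window size 3 below are illustrative toy values:

```python
import torch

# Additive attention bias: 0 where attending is allowed, -inf where it is blocked
L, window = 6, 3
i = torch.arange(L)[:, None]   # query positions (rows)
j = torch.arange(L)[None, :]   # key positions (columns)
causal = j <= i                # never attend to future tokens
in_window = (i - j) < window   # only the nearest `window` past tokens, self included
mask = torch.full((L, L), float("-inf"))
mask.masked_fill_(causal & in_window, 0.0)
```

Adding this bias to the logits before the softmax zeroes the probability of every blocked position, which is exactly how the attention layer consumes `attention_mask`.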
```python
class TextAttention(nn.Module):
    def __init__(self, cfg: TextConfig, layer_idx: int, is_kv_shared: bool = False):
        super().__init__()
        self.layer_idx = layer_idx
        self.num_heads = cfg.num_attention_heads
        self.is_kv_shared = is_kv_shared

        # Use _layer_types if set (authoritative), otherwise fall back to the pattern
        _layer_types = getattr(cfg, "_layer_types", None)
        is_global = (
            _layer_types[layer_idx] == "full_attention"
            if _layer_types is not None
            else (layer_idx % cfg.sliding_window_pattern == 0)
        )
        self.sliding_window = None if is_global else cfg.sliding_window

        # Global attention layers may use a larger head_dim and different KV head count
        self.head_dim = (
            cfg.global_head_dim if (is_global and cfg.global_head_dim is not None)
            else cfg.head_dim
        )
        self.num_kv_heads = (
            cfg.num_global_key_value_heads
            if (is_global and cfg.num_global_key_value_heads is not None)
            else cfg.num_key_value_heads
        )

        hs = cfg.hidden_size
        kv_dim = self.num_kv_heads * self.head_dim

        self.q_proj = nn.Linear(hs, self.num_heads * self.head_dim, bias=False)
        # KV-shared layers still have k/v projections; sharing only applies with a KV cache
        self.k_proj = nn.Linear(hs, kv_dim, bias=False)
        # attention_k_eq_v: full-attention layers share the K projection for V (v_proj=None).
        # In that case V = v_norm(k_proj(x)) (same input projection, different norm, no RoPE)
        self.use_alternative_attention = cfg.attention_k_eq_v and is_global
        self.v_proj = (
            None if self.use_alternative_attention
            else nn.Linear(hs, kv_dim, bias=False)
        )
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, hs, bias=False)

        # Per-head norms (applied before RoPE)
        self.q_norm = RMSNorm(self.head_dim, eps=cfg.rms_norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=cfg.rms_norm_eps)
        # v_norm: no learnable scale unless use_v_norm=True (4B+); always applied when k_eq_v
        self.v_norm = RMSNorm(self.head_dim, eps=cfg.rms_norm_eps, with_scale=cfg.use_v_norm)

        self.attn_logit_softcapping = cfg.attn_logit_softcapping
        # QK-norm (RMSNorm on q and k) is always applied, so scaling is 1.0, matching HF
        self.scaling = 1.0
```
```python
    def forward(
        self,
        hidden_states: torch.Tensor,  # [B, L, D]
        cos: torch.Tensor,
        sin: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[dict] = None,  # simple dict cache for demo
    ) -> torch.Tensor:
        B, L, _ = hidden_states.shape
        H, Hkv, Dh = self.num_heads, self.num_kv_heads, self.head_dim

        # ── Queries ──────────────────────────────────────────────────────
        q = self.q_proj(hidden_states).view(B, L, H, Dh)
        q = self.q_norm(q)
        q = apply_rotary_pos_emb(q, cos, sin, unsqueeze_dim=2)
        q = q.transpose(1, 2)  # [B, H, L, Dh]

        # ── Keys & Values ────────────────────────────────────────────────
        if self.is_kv_shared and kv_cache is not None and "shared_kv" in kv_cache:
            # Reuse the KV states stored by the designated anchor layer
            k, v = kv_cache["shared_kv"]
        else:
            k_raw = self.k_proj(hidden_states).view(B, L, Hkv, Dh)
            # attention_k_eq_v: V uses the same raw k projection (before k_norm and RoPE)
            v_raw = (k_raw if self.use_alternative_attention
                     else self.v_proj(hidden_states).view(B, L, Hkv, Dh))
            k = self.k_norm(k_raw)
            v = self.v_norm(v_raw)
            k = apply_rotary_pos_emb(k, cos, sin, unsqueeze_dim=2)
            k = k.transpose(1, 2)  # [B, Hkv, L, Dh]
            v = v.transpose(1, 2)

        # ── GQA: expand KV heads to match Q heads ────────────────────────
        # Each KV head is broadcast to H/Hkv query heads by inserting a repeat dim
        if Hkv != H:
            expand = H // Hkv
            k = k.unsqueeze(2).expand(-1, -1, expand, -1, -1).reshape(B, H, -1, Dh)
            v = v.unsqueeze(2).expand(-1, -1, expand, -1, -1).reshape(B, H, -1, Dh)

        # ── Scaled dot-product attention ─────────────────────────────────
        # scaling=1.0 because QK-norms already control the variance (no 1/√d needed)
        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scaling

        # Optional logit soft-capping: tanh(logit/cap) * cap keeps logits in (-cap, cap).
        # This prevents softmax collapse in very long contexts
        if self.attn_logit_softcapping is not None:
            attn = attn / self.attn_logit_softcapping
            attn = torch.tanh(attn)
            attn = attn * self.attn_logit_softcapping

        # Causal / sliding-window mask: upper-triangle = -inf, recent window = 0
        if attention_mask is not None:
            attn = attn + attention_mask

        # Upcast to float32 for numerically stable softmax (matches HF eager impl)
        attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(hidden_states.dtype)

        # Weighted sum of values, merge heads, project back to hidden_size
        out = torch.matmul(attn, v)  # [B, H, L, Dh]
        out = out.transpose(1, 2).reshape(B, L, H * Dh)
        return self.o_proj(out)
```
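The expand + reshape step used for GQA is equivalent to `repeat_interleave` along the head dimension; a standalone sketch with illustrative sizes:

```python
import torch

# Broadcast each of the Hkv KV heads to H/Hkv query heads
B, Hkv, H, L, Dh = 1, 2, 4, 5, 8
k = torch.randn(B, Hkv, L, Dh)
expand = H // Hkv
# Insert a repeat dim, expand (a view, no copy), then flatten the head dims
k_exp = k.unsqueeze(2).expand(-1, -1, expand, -1, -1).reshape(B, H, L, Dh)
# Same result as repeating each KV head `expand` times in place
k_ref = k.repeat_interleave(expand, dim=1)
```

The expand-based form avoids materialising the copies until the final reshape, which is why it is the common idiom in eager attention implementations.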
Selects the $K$ best experts for each token and computes their routing weights.
For a token representation $x \in \mathbb{R}^D$, the router normalises and rescales $x$, projects it to $E$ expert logits, takes a softmax, keeps the top-$K$ probabilities, renormalises them to sum to 1, and multiplies by a learned per-expert scale $\alpha_e$.
```python
class TextRouter(nn.Module):
    def __init__(self, cfg: TextConfig):
        super().__init__()
        self.top_k = cfg.num_experts_per_tok
        # Bug #2 fix: HF uses RMSNorm with_scale=False (pure normalisation, no learnable weight)
        self.norm = RMSNorm(cfg.hidden_size, eps=cfg.rms_norm_eps, with_scale=False)
        # Bug #1 fix: HF has a learned per-dim scale vector (NOT a scalar sqrt(D))
        self.scale = nn.Parameter(torch.ones(cfg.hidden_size))
        self._scale_factor = cfg.hidden_size ** -0.5
        # Projection to expert logits
        self.proj = nn.Linear(cfg.hidden_size, cfg.num_experts, bias=False)
        # Per-expert learned output scale
        self.per_expert_scale = nn.Parameter(torch.ones(cfg.num_experts))
```
```python
    def forward(self, x: torch.Tensor):
        # x: [T, D] (T = batch*seq flattened)
        # Step 1: normalise + scale → stable router input (unit-ish RMS).
        # norm is scale-free RMSNorm; self.scale is learned per-dim; _scale_factor = D^{-0.5}
        h = self.norm(x) * self.scale * self._scale_factor
        # Step 2: project to E expert logits — one score per expert per token
        logits = self.proj(h)  # [T, E]
        # Step 3: softmax over experts → probability distribution for each token
        probs = F.softmax(logits, dim=-1)  # [T, E]
        # Step 4: top-K selection — choose the K highest-probability experts
        top_w, top_idx = torch.topk(probs, self.top_k, dim=-1)  # [T, K]
        # Step 5: renormalise the K weights to sum to 1 (sparse softmax)
        top_w = top_w / top_w.sum(dim=-1, keepdim=True)
        # Step 6: apply learned per-expert scalar α_e — allows global expert re-weighting
        top_w = top_w * self.per_expert_scale[top_idx]
        # Return full prob distribution (for aux loss), top-K weights, and expert indices
        return probs, top_w, top_idx
```
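The routing arithmetic can be run in isolation. The sizes below and the identity per-expert scale are illustrative stand-ins for the learned parameters:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, E, K = 3, 4, 2                 # tokens, experts, top-k
logits = torch.randn(T, E)        # stand-in for the router projection output
probs = F.softmax(logits, dim=-1)                 # [T, E]
top_w, top_idx = torch.topk(probs, K, dim=-1)     # keep the K best experts
top_w = top_w / top_w.sum(dim=-1, keepdim=True)   # renormalise over the K picks
per_expert_scale = torch.ones(E)                  # learned alpha_e (identity here)
top_w = top_w * per_expert_scale[top_idx]
```

With the scale at its init value of 1, the K weights per token still sum to exactly 1; during training the learned scale lets the model re-weight whole experts globally.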
Holds $E$ independent SwiGLU FFNs. For each token only the $K$ experts selected by the router are evaluated — the rest are skipped entirely, keeping compute proportional to $K/E$ of the full dense cost.
- Weight layout — gate_up_proj[e] $\in \mathbb{R}^{2D_i \times D}$ and down_proj[e] $\in \mathbb{R}^{D \times D_i}$ are stacked along the first dimension, so loading a single expert is a single tensor slice.
- Dispatch loop — we build an [E, K, T] one-hot mask, iterate only over active experts (those that received ≥ 1 token), gather the relevant token vectors, run the SwiGLU FFN, and scatter-add results back via index_add_.
- Numerics note — index_add_ accumulates contributions in a deterministic order (one token at a time), matching HuggingFace's eager implementation bit-exactly in bfloat16. The alternative grouped_mm kernel accumulates in a different order, producing ~0.06 diff per layer that grows to ~25 max_diff over 30 layers (see TEST_REPORT_26B_A4B.md).
```python
class TextExperts(nn.Module):
    def __init__(self, cfg: TextConfig):
        super().__init__()
        E = cfg.num_experts
        D = cfg.hidden_size
        Di = cfg.expert_intermediate_size
        # Bug #3 fix: match HF weight layout [E, out, in] so checkpoints load without permute
        self.gate_up_proj = nn.Parameter(torch.empty(E, 2 * Di, D))
        self.down_proj = nn.Parameter(torch.empty(E, D, Di))
        nn.init.normal_(self.gate_up_proj, std=0.02)
        nn.init.normal_(self.down_proj, std=0.02)
        self.num_experts = E
        self.act = nn.GELU(approximate="tanh")
```
```python
    def forward(
        self,
        x: torch.Tensor,              # [T, D]
        top_k_index: torch.Tensor,    # [T, K]
        top_k_weights: torch.Tensor,  # [T, K]
    ) -> torch.Tensor:
        T, D = x.shape
        # Accumulator: we scatter expert outputs back into this tensor
        out = torch.zeros(T, D, dtype=x.dtype, device=x.device)

        # Build a 3-way membership mask, then transpose to expert-major order:
        # expert_mask[e, k, t] = 1 iff token t was assigned to expert e in slot k
        expert_mask = F.one_hot(top_k_index, self.num_experts)  # [T, K, E]
        expert_mask = expert_mask.permute(2, 1, 0)              # [E, K, T]

        # Only iterate over experts that actually received tokens (skip dead experts)
        active_experts = (expert_mask.sum(dim=(1, 2)) > 0).nonzero(as_tuple=True)[0]

        for expert_idx in active_experts:
            e = expert_idx.item()
            # Find which (k-slot, token) pairs are assigned to expert e
            k_slot, tok_idx = torch.where(expert_mask[e])  # [n], [n]
            h = x[tok_idx]                                 # [n, D] — gather tokens
            # SwiGLU FFN for this expert (gate and up projections packed into one matrix)
            gate, up = F.linear(h, self.gate_up_proj[e]).chunk(2, dim=-1)
            h = self.act(gate) * up             # GELU gate
            h = F.linear(h, self.down_proj[e])  # [n, D]
            # Scale by routing weight (how much this expert contributes for this token)
            h = h * top_k_weights[tok_idx, k_slot, None]
            # index_add_: deterministic accumulation order → bit-exact with HF eager
            out.index_add_(0, tok_idx, h.to(out.dtype))

        return out
```
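The scatter-add at the end of the loop, in isolation, with toy sizes standing in for the real expert output:

```python
import torch

# One expert's output for tokens 1 and 3 accumulates into the shared buffer `out`;
# rows for tokens that were never routed here stay exactly zero
T, D = 5, 4
out = torch.zeros(T, D)
tok_idx = torch.tensor([1, 3])
h = torch.full((2, D), 0.5)  # stand-in for the weighted expert FFN output
out.index_add_(0, tok_idx, h)
```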
Each layer runs the following sequence of operations:
1. Self-Attention block $$h = x + \text{PostAttnNorm}\!\left(\text{Attn}\!\left(\text{PreNorm}(x)\right)\right)$$
2. Dense MLP (always active, even in MoE layers) $$h_{\text{mlp}} = \text{MLP}\!\left(\text{PreFFNNorm}(h)\right)$$
3. Sparse MoE (only in MoE-designated layers; 26B model)
The dense MLP and the sparse expert bank run in parallel on the same pre-FFN residual and their outputs are added: $$h = h + \text{PostFFNNorm}\!\left(h_{\text{mlp}} + h_{\text{moe}}\right)$$
4. Per-layer input gate (E4B and larger models)
A lightweight gating network injects a per-layer side-channel $p_\ell$ (derived from the full token embeddings) at every layer: $$h = h + \text{Norm}\!\left(W_{\text{proj}}\!\left(\text{GELU}(W_{\text{gate}}\,h) \odot p_\ell\right)\right)$$
This allows every layer to directly reference the original token context, acting as a persistent residual information highway.
5. Layer scalar
The entire layer output is multiplied by a learned scalar $\lambda$ (init 1), giving the optimiser a soft mechanism to reduce a layer's contribution early in training.
```python
class TextDecoderLayer(nn.Module):
    def __init__(self, cfg: TextConfig, layer_idx: int):
        super().__init__()
        is_kv_shared = (cfg.kv_share_from is not None and layer_idx >= cfg.kv_share_from)
        self.self_attn = TextAttention(cfg, layer_idx, is_kv_shared=is_kv_shared)

        self.input_layernorm = RMSNorm(cfg.hidden_size, eps=cfg.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(cfg.hidden_size, eps=cfg.rms_norm_eps)
        self.pre_feedforward_layernorm = RMSNorm(cfg.hidden_size, eps=cfg.rms_norm_eps)
        self.post_feedforward_layernorm = RMSNorm(cfg.hidden_size, eps=cfg.rms_norm_eps)

        self.mlp = TextMLP(cfg.hidden_size, cfg.intermediate_size)

        self.enable_moe = layer_idx in cfg.moe_layers and cfg.num_experts > 0
        if self.enable_moe:
            self.router = TextRouter(cfg)
            self.experts = TextExperts(cfg)
            self.pre_feedforward_layernorm_2 = RMSNorm(cfg.hidden_size, eps=cfg.rms_norm_eps)
            self.post_feedforward_layernorm_1 = RMSNorm(cfg.hidden_size, eps=cfg.rms_norm_eps)
            self.post_feedforward_layernorm_2 = RMSNorm(cfg.hidden_size, eps=cfg.rms_norm_eps)

        # Per-layer input gate
        self.use_per_layer_gate = cfg.hidden_size_per_layer_input > 0
        if self.use_per_layer_gate:
            D, Dp = cfg.hidden_size, cfg.hidden_size_per_layer_input
            self.per_layer_input_gate = nn.Linear(D, Dp, bias=False)
            self.per_layer_projection = nn.Linear(Dp, D, bias=False)
            self.post_per_layer_input_norm = RMSNorm(D, eps=cfg.rms_norm_eps)
            self.act_fn = nn.GELU(approximate="tanh")

        # Learned output scalar (initialised to 1)
        self.layer_scalar = nn.Parameter(torch.ones(1))
```
```python
    def forward(
        self,
        hidden_states: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        per_layer_input: Optional[torch.Tensor] = None,
        kv_cache: Optional[dict] = None,
    ) -> torch.Tensor:
        # ── 1. Self-Attention ────────────────────────────────────────────
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(hidden_states, cos, sin, attention_mask, kv_cache)
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        # ── 2. Dense MLP (always) ────────────────────────────────────────
        residual = hidden_states
        hidden_states = self.pre_feedforward_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)

        # ── 3. MoE block (some layers) ───────────────────────────────────
        if self.enable_moe:
            # Separate post-norm path for the dense MLP output (26B model only).
            # The dense MLP and sparse MoE run in parallel on the same pre-FFN residual
            # and their contributions are summed before the shared post-FFN norm below.
            h_mlp = self.post_feedforward_layernorm_1(hidden_states)
            # Flatten to [B*L, D] so the router and experts see tokens, not (batch, seq)
            flat = residual.reshape(-1, residual.shape[-1])  # [B*L, D]
            # Router selects top-K experts for each token: returns routing weights + indices
            _, top_w, top_idx = self.router(flat)
            # Pre-norm before expert FFNs (separate norm for the MoE path)
            h_moe = self.pre_feedforward_layernorm_2(flat)
            # Sparse expert dispatch: only active experts run; scatter back via index_add_
            h_moe = self.experts(h_moe, top_idx, top_w)
            h_moe = h_moe.reshape(residual.shape)
            # Post-norm on MoE output (separate from the dense MLP norm above)
            h_moe = self.post_feedforward_layernorm_2(h_moe)
            # Dense MLP + sparse MoE contributions are combined additively
            hidden_states = h_mlp + h_moe

        hidden_states = self.post_feedforward_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        # ── 4. Per-layer input gate ──────────────────────────────────────
        if self.use_per_layer_gate and per_layer_input is not None:
            residual = hidden_states
            gate = self.act_fn(self.per_layer_input_gate(hidden_states))
            hidden_states = gate * per_layer_input
            hidden_states = self.per_layer_projection(hidden_states)
            hidden_states = self.post_per_layer_input_norm(hidden_states)
            hidden_states = residual + hidden_states

        # ── 5. Layer scalar ──────────────────────────────────────────────
        hidden_states = hidden_states * self.layer_scalar
        return hidden_states
```
Stacks $N$ TextDecoderLayers with a shared TextRotaryEmbedding that
pre-computes $(cos, sin)$ tensors for the full input sequence length.
Per-layer input pipeline (E4B model, hidden_size_per_layer_input > 0):
Before the transformer runs, a compact side-channel tensor $P \in \mathbb{R}^{B \times L \times N \times D_p}$ is computed by blending two projections of the input:
$$P = \frac{1}{\sqrt{2}}\;\text{Norm}\!\left(\frac{W_{\text{proj}}\,x}{\sqrt{D}} + E_{\text{tok}}\right)$$
where $E_{\text{tok}}$ is a second token embedding table (scaled by $\sqrt{D_p}$).
The $\ell$-th slice $P_{:,:,\ell,:}$ is passed as per_layer_input to layer $\ell$.
When inputs_embeds is provided instead of input_ids (multimodal path), the
per-layer projection is computed from the merged embeddings (vision + text),
so image features influence the side-channel at every layer.
```python
class TextModel(nn.Module):
    def __init__(self, cfg: TextConfig):
        super().__init__()
        self.embed_tokens = ScaledEmbedding(
            cfg.vocab_size, cfg.hidden_size, cfg.pad_token_id, cfg.embed_scale
        )
        self.rotary_emb = TextRotaryEmbedding(cfg)
        self.layers = nn.ModuleList([TextDecoderLayer(cfg, i) for i in range(cfg.num_hidden_layers)])
        self.norm = RMSNorm(cfg.hidden_size, eps=cfg.rms_norm_eps)
        self.sliding_window_pattern = cfg.sliding_window_pattern
        self._layer_types = getattr(cfg, "_layer_types", None)

        # Per-layer input gate embedding and projection chain (optional; E4B and larger).
        # Matches HF Gemma4TextModel exactly:
        #   embed_tokens_per_layer     : Embedding(vocab, N*Dp), scaled by sqrt(Dp)
        #   per_layer_model_projection : Linear(D, N*Dp), scaled by D**-0.5
        #   per_layer_projection_norm  : RMSNorm(Dp)
        #   per_layer_input_scale      : 2**-0.5 (constant)
        Dp = cfg.hidden_size_per_layer_input
        N = cfg.num_hidden_layers
        if Dp > 0:
            self.embed_tokens_per_layer = ScaledEmbedding(
                cfg.vocab_size, N * Dp, cfg.pad_token_id, embed_scale=Dp ** 0.5
            )
            self.per_layer_model_projection = nn.Linear(cfg.hidden_size, N * Dp, bias=False)
            self.per_layer_model_projection_scale = cfg.hidden_size ** -0.5
            self.per_layer_projection_norm = RMSNorm(Dp, eps=cfg.rms_norm_eps)
            self.per_layer_input_scale = 2.0 ** -0.5
        else:
            self.embed_tokens_per_layer = None
        self._num_layers = N
        self._per_layer_dim = Dp
```
```python
    def _get_per_layer_embed(self, input_ids: torch.Tensor) -> Optional[torch.Tensor]:
        # Token-embedding part only — [B, L, N, Dp]. No projection yet.
        if self.embed_tokens_per_layer is None:
            return None
        return self.embed_tokens_per_layer(input_ids).reshape(
            *input_ids.shape, self._num_layers, self._per_layer_dim
        )

    def _project_per_layer(self, inputs_embeds: torch.Tensor,
                           embed_part: torch.Tensor) -> torch.Tensor:
        # Add the inputs_embeds projection to embed_part and scale. → [B, L, N, Dp]
        proj = self.per_layer_model_projection(inputs_embeds) * self.per_layer_model_projection_scale
        proj = proj.reshape(*inputs_embeds.shape[:-1], self._num_layers, self._per_layer_dim)
        proj = self.per_layer_projection_norm(proj)
        return (proj + embed_part) * self.per_layer_input_scale

    def _compute_per_layer_inputs(self, input_ids: torch.Tensor,
                                  inputs_embeds: torch.Tensor) -> Optional[torch.Tensor]:
        # Compute [B, L, N, Dp] per-layer inputs (text-only path)
        embed = self._get_per_layer_embed(input_ids)
        if embed is None:
            return None
        return self._project_per_layer(inputs_embeds, embed)
```
800 def forward( 801 self, 802 input_ids: Optional[torch.Tensor] = None, 803 attention_mask: Optional[torch.Tensor] = None, 804 kv_cache: Optional[dict] = None, 805 per_layer_inputs: Optional[torch.Tensor] = None, # [B, L, num_layers, Dp] if pre-computed 806 inputs_embeds: Optional[torch.Tensor] = None, # pre-merged embeddings (multimodal) 807 ) -> torch.Tensor:
Accept either token ids (text-only) or pre-built embeddings (multimodal path)
809 if inputs_embeds is None: 810 x = self.embed_tokens(input_ids) # [B, L, D] 811 else: 812 x = inputs_embeds # [B, L, D], already embedded 813 814 B, L, _ = x.shape
Simple 0..L-1 position indices — causal order; no offset needed for prefill
816 position_ids = torch.arange(L, device=x.device).unsqueeze(0) # [1, L]
Compute the full per-layer side-channel P ∈ [B, L, N, Dp] once up front (text-only path; multimodal path pre-computes this with merged embeddings)
820 if per_layer_inputs is None and input_ids is not None and self.embed_tokens_per_layer is not None: 821 per_layer_inputs = self._compute_per_layer_inputs(input_ids, x) 822 823 for i, layer in enumerate(self.layers):
Alternate between local (sliding-window) and global RoPE frequencies
825 if self._layer_types is not None: 826 layer_type = "local" if self._layer_types[i] == "sliding_attention" else "global" 827 else: 828 layer_type = "local" if (i % self.sliding_window_pattern != 0) else "global" 829 cos, sin = self.rotary_emb(x, position_ids, layer_type=layer_type)
Slice the i-th layer's per-layer input vector — [B, L, Dp]
831 pli = per_layer_inputs[:, :, i, :] if per_layer_inputs is not None else None 832 x = layer(x, cos, sin, attention_mask=attention_mask, 833 per_layer_input=pli, kv_cache=kv_cache)
Final RMSNorm before the output projection / logit head
836 return self.norm(x)
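The local/global fallback above is easy to sanity-check on its own. A minimal sketch (with a hypothetical `sliding_window_pattern` of 2, not the production value) showing which layers get which RoPE frequencies when no explicit layer-type list is configured:

```python
# Every sliding_window_pattern-th layer uses the global RoPE frequencies;
# the rest use the local (sliding-window) frequencies.
sliding_window_pattern = 2  # hypothetical value for illustration
layer_types = [
    "local" if (i % sliding_window_pattern != 0) else "global"
    for i in range(6)
]
print(layer_types)  # → ['global', 'local', 'global', 'local', 'global', 'local']
```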
2-D RoPE for vision patches. pixel_position_ids: [B, N, 2] (x, y per patch; -1 = padding). Returns cos/sin with shape [B, N, head_dim], where head_dim is split evenly between the x and y rotations.
class VisionRotaryEmbedding(nn.Module):
    def __init__(self, head_dim: int, base: float = 10_000.0):
        super().__init__()
        dim = head_dim // 2  # half for x, half for y
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @torch.no_grad()
x: [B, N, D]; pixel_position_ids: [B, N, 2]
    def forward(self, x: torch.Tensor, pixel_position_ids: torch.Tensor):
        all_cos, all_sin = [], []
inv_freq shape [F] → expand to [B, F, 1] for matrix multiply with positions
        inv = self.inv_freq[None, :, None].expand(pixel_position_ids.shape[0], -1, 1)
        for dim_i in range(2):
Process x-axis (dim_i=0) then y-axis (dim_i=1) separately
            pos = pixel_position_ids[:, :, dim_i]  # [B, N] integer patch position
            pos_exp = pos[:, None, :].float()      # [B, 1, N] for broadcast matmul
Outer product: frequencies × positions → [B, F, N] → transpose to [B, N, F]
            freqs = (inv.float() @ pos_exp.float()).transpose(1, 2)  # [B, N, dim/2]
Duplicate freqs to fill full half-head_dim (standard RoPE cos/sin embedding)
            emb = torch.cat([freqs, freqs], dim=-1)  # [B, N, dim]
            all_cos.append(emb.cos())
            all_sin.append(emb.sin())
Concatenate x-axis and y-axis embeddings → [B, N, head_dim] where head_dim = 2*dim
        cos = torch.cat(all_cos, dim=-1).to(x.dtype)  # [B, N, head_dim]
        sin = torch.cat(all_sin, dim=-1).to(x.dtype)  # [B, N, head_dim]
        return cos, sin
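A standalone shape check mirroring the logic above (toy sizes, not the real config): head_dim splits evenly between the x and y rotations, so cos/sin come back as [B, N, head_dim], and patches at position (0, 0) receive no rotation at all.

```python
import torch

B, N, head_dim = 2, 5, 8
dim = head_dim // 2                                  # channels per axis
inv_freq = 1.0 / (10_000.0 ** (torch.arange(0, dim, 2).float() / dim))
pos = torch.zeros(B, N, 2)                           # all patches at (0, 0)
parts = []
for axis in range(2):
    freqs = pos[:, :, axis, None] * inv_freq         # [B, N, dim/2]
    parts.append(torch.cat([freqs, freqs], dim=-1))  # [B, N, dim]
emb = torch.cat(parts, dim=-1)                       # [B, N, head_dim]
assert emb.cos().shape == (B, N, head_dim)
assert torch.all(emb.cos() == 1.0)                   # position 0 → identity rotation
```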
Project pixel patches to hidden_size and add learned 2-D positional embeddings.
Matches HF Gemma4VisionPatchEmbedder exactly:
class VisionPatchEmbedder(nn.Module):
    def __init__(self, cfg: VisionConfig):
        super().__init__()
        patch_dim = cfg.patch_size * cfg.patch_size * 3
        self.input_proj = nn.Linear(patch_dim, cfg.hidden_size, bias=False)
Single parameter table covering both spatial axes, matching HF layout
        self.position_embedding_table = nn.Parameter(
            torch.ones(2, cfg.position_embedding_size, cfg.hidden_size)
        )
        self.position_embedding_size = cfg.position_embedding_size

    def _position_embeddings(
        self,
        pixel_position_ids: torch.Tensor,  # [B, N, 2]
        padding_mask: torch.Tensor,        # [B, N] True = padding
    ) -> torch.Tensor:
        pos = pixel_position_ids.clamp(min=0)  # [B, N, 2]
one_hot: [B, N, 2, position_embedding_size]
        one_hot = F.one_hot(pos, num_classes=self.position_embedding_size)
permute to [B, 2, N, pos_emb_size] for matmul with table [2, pos_emb_size, D]
        one_hot = one_hot.permute(0, 2, 1, 3).to(self.position_embedding_table)
[B, 2, N, D] → sum over axes → [B, N, D]
        pos_embed = (one_hot @ self.position_embedding_table).sum(dim=1)
Zero out padding
        pos_embed = torch.where(padding_mask.unsqueeze(-1), torch.zeros_like(pos_embed), pos_embed)
        return pos_embed

    def forward(
        self,
        pixel_values: torch.Tensor,        # [B, N, patch_dim]
        pixel_position_ids: torch.Tensor,  # [B, N, 2] (x, y; -1 = padding)
        padding_mask: torch.Tensor,        # [B, N] True = padding
    ) -> torch.Tensor:
Normalise pixel values from [0, 1] to [-1, 1] (standard ViT preprocessing)
        pixel_values = 2.0 * (pixel_values - 0.5)
Linear projection: each flattened patch (patch_size² × 3 channels) → hidden_size
        h = self.input_proj(pixel_values.to(self.input_proj.weight.dtype))
Add 2-D learned position embeddings (row + column axes combined, see _position_embeddings)
        h = h + self._position_embeddings(pixel_position_ids, padding_mask)
Zero out padding patch positions so they don't contaminate downstream computation
        h = h.masked_fill(padding_mask.unsqueeze(-1), 0.0)
        return h
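The one-hot matmul in `_position_embeddings` is just a differentiable table lookup: it sums the x-axis row and the y-axis row of the [2, S, D] table. A tiny equivalence check (toy sizes, random table):

```python
import torch
import torch.nn.functional as F

S, D = 4, 3
table = torch.randn(2, S, D)                            # [axes, table_size, D]
pos = torch.tensor([[[1, 2], [0, 3]]])                  # [B=1, N=2, 2] (x, y)
one_hot = F.one_hot(pos, S).permute(0, 2, 1, 3).float() # [B, 2, N, S]
via_matmul = (one_hot @ table).sum(dim=1)               # [B, N, D]
direct = table[0, pos[..., 0]] + table[1, pos[..., 1]]  # plain indexing
assert torch.allclose(via_matmul, direct)
```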
nn.Linear wrapped with optional input/output clamping (matches HF Gemma4ClippableLinear). Clip bounds are stored as buffers and loaded from checkpoint. When use_clipped_linears=False (default), bounds stay at ±inf (no-op clamp).
class ClippableLinear(nn.Module):
    def __init__(self, in_features: int, out_features: int, use_clipped_linears: bool = False):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=False)
        if use_clipped_linears:
            self.register_buffer("input_min", torch.tensor(-float("inf")))
            self.register_buffer("input_max", torch.tensor(float("inf")))
            self.register_buffer("output_min", torch.tensor(-float("inf")))
            self.register_buffer("output_max", torch.tensor(float("inf")))
        self.use_clipped_linears = use_clipped_linears

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_clipped_linears:
            x = torch.clamp(x, self.input_min, self.input_max)
        x = self.linear(x)
        if self.use_clipped_linears:
            x = torch.clamp(x, self.output_min, self.output_max)
        return x
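The clamp-linear-clamp pattern in miniature, with made-up finite bounds (the real bounds are buffers loaded from the checkpoint, and default to ±inf):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lin = nn.Linear(4, 4, bias=False)
x = torch.randn(2, 4) * 10.0              # deliberately out-of-range inputs
y = lin(torch.clamp(x, -1.0, 1.0))        # clip inputs, then project
y = torch.clamp(y, -0.5, 0.5)             # clip outputs
assert y.min().item() >= -0.5
assert y.max().item() <= 0.5
```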
SwiGLU FFN used in the vision encoder (matches HF Gemma4VisionMLP).
class VisionMLP(nn.Module):
    def __init__(self, cfg: VisionConfig):
        super().__init__()
        clip = cfg.use_clipped_linears
        self.gate_proj = ClippableLinear(cfg.hidden_size, cfg.intermediate_size, clip)
        self.up_proj = ClippableLinear(cfg.hidden_size, cfg.intermediate_size, clip)
        self.down_proj = ClippableLinear(cfg.intermediate_size, cfg.hidden_size, clip)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(F.gelu(self.gate_proj(x), approximate="tanh") * self.up_proj(x))
ViT-style full attention with:
• QKV norms (same as text, for training stability)
• 2-D RoPE applied to Q and K
class VisionAttention(nn.Module):
    def __init__(self, cfg: VisionConfig):
        super().__init__()
        H, Dh = cfg.num_attention_heads, cfg.head_dim
        D = cfg.hidden_size
        self.num_heads = H
        self.head_dim = Dh
HF Gemma4VisionAttention uses scaling=1.0, not 1/sqrt(head_dim)
        self.scaling = 1.0

        clip = cfg.use_clipped_linears
        self.q_proj = ClippableLinear(D, H * Dh, clip)
        self.k_proj = ClippableLinear(D, H * Dh, clip)
        self.v_proj = ClippableLinear(D, H * Dh, clip)
        self.o_proj = ClippableLinear(H * Dh, D, clip)

        self.q_norm = RMSNorm(Dh, eps=cfg.rms_norm_eps, with_scale=True)
        self.k_norm = RMSNorm(Dh, eps=cfg.rms_norm_eps, with_scale=True)
v_norm has NO learnable scale (with_scale=False) — just bare normalisation
        self.v_norm = RMSNorm(Dh, eps=cfg.rms_norm_eps, with_scale=False)

        self.rotary_emb = VisionRotaryEmbedding(Dh, cfg.rope_theta)

    def forward(
        self,
        hidden_states: torch.Tensor,       # [B, N, D]
        pixel_position_ids: torch.Tensor,  # [B, N, 2]
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        B, N, _ = hidden_states.shape
        H, Dh = self.num_heads, self.head_dim
Project to Q, K, V and reshape to per-head tensors [B, N, H, Dh]
        q = self.q_proj(hidden_states).view(B, N, H, Dh)
        k = self.k_proj(hidden_states).view(B, N, H, Dh)
        v = self.v_proj(hidden_states).view(B, N, H, Dh)
Per-head RMSNorm: stabilises attention logit scale, scaling=1.0 (no 1/√d)
        q = self.q_norm(q)
        k = self.k_norm(k)
v_norm is scale-free (no learnable weight) — just normalises the magnitude
        v = self.v_norm(v)
Compute 2-D RoPE frequencies from (row, col) patch positions
        cos, sin = self.rotary_emb(hidden_states, pixel_position_ids)
Apply 2-D RoPE: first head_dim/2 channels encode row, second half encode col
        q = apply_2d_rope(q, cos, sin, pixel_position_ids, unsqueeze_dim=2)
        k = apply_2d_rope(k, cos, sin, pixel_position_ids, unsqueeze_dim=2)
Transpose to [B, H, N, Dh] for batched matrix multiply
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
Scaled dot-product: scaling=1.0 because QK norms handle the variance
        attn = (q @ k.transpose(-2, -1)) * self.scaling
Optional attention mask (e.g. to block padding patches from attending)
        if attention_mask is not None:
            attn = attn + attention_mask  # additive: 0 = attend, -inf = block
Upcast to float32 for numerically stable softmax
        attn = F.softmax(attn.float(), dim=-1).to(q.dtype)
Weighted sum of values, merge heads, project back to D
        out = (attn @ v).transpose(1, 2).reshape(B, N, H * Dh)
        return self.o_proj(out)
ViT encoder block with sandwich layernorms (pre+post norm around both attention and MLP).
class VisionEncoderLayer(nn.Module):
    def __init__(self, cfg: VisionConfig):
        super().__init__()
        self.self_attn = VisionAttention(cfg)
        self.mlp = VisionMLP(cfg)

        self.input_layernorm = RMSNorm(cfg.hidden_size, eps=cfg.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(cfg.hidden_size, eps=cfg.rms_norm_eps)
        self.pre_feedforward_layernorm = RMSNorm(cfg.hidden_size, eps=cfg.rms_norm_eps)
        self.post_feedforward_layernorm = RMSNorm(cfg.hidden_size, eps=cfg.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        pixel_position_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
── Attention block with sandwich layernorms ──────────────────────── Pre-norm: normalise before attention (stabilises gradient flow)
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(hidden_states, pixel_position_ids, attention_mask)
Post-norm: second normalisation on the attention output before adding residual (double-norm "sandwich" differs from standard pre-norm transformers)
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states
── MLP block with sandwich layernorms ──────────────────────────────
        residual = hidden_states
        hidden_states = self.pre_feedforward_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
Post-FFN norm: same sandwich pattern as attention block
        hidden_states = self.post_feedforward_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states
class VisionEncoder(nn.Module):
    def __init__(self, cfg: VisionConfig):
        super().__init__()
        self.layers = nn.ModuleList([VisionEncoderLayer(cfg) for _ in range(cfg.num_hidden_layers)])

    def forward(
        self,
        hidden_states: torch.Tensor,
        pixel_position_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
Sequential ViT layers: each applies sandwich-norm attention + MLP with residual connections
        for layer in self.layers:
            hidden_states = layer(hidden_states, pixel_position_ids, attention_mask)
Final hidden states: [B, N, D] — all patch positions, including padding
        return hidden_states
Reduces the $N$ patch tokens to $N / k^2$ soft tokens by averaging each non-overlapping $k \times k$ block of patches (default $k = 3$, so 9 patches → 1 soft token):
$$\text{soft\_tok}_{(r',c')} = \frac{1}{k^2} \sum_{dr=0}^{k-1}\sum_{dc=0}^{k-1} h_{(kr'+dr,\, kc'+dc)}$$
The output is scaled by $\sqrt{D_v}$ before returning, matching the HF implementation. Padding tokens (indicated by padding_mask) are masked out and stripped from the output so downstream layers see only valid content.
This aggressive pooling is key to keeping the token count manageable: a $336 \times 336$ image at patch size 16 gives $21 \times 21 = 441$ raw patches, which pool down to $7 \times 7 = 49$ soft tokens.
class VisionPooler(nn.Module):
    def __init__(self, cfg: VisionConfig):
        super().__init__()
        self.kernel = cfg.pooling_kernel_size
        self.scale = cfg.hidden_size ** 0.5
        self.standardize = cfg.standardize
        if cfg.standardize:
            self.std_bias = nn.Parameter(torch.zeros(cfg.hidden_size))
            self.std_scale = nn.Parameter(torch.ones(cfg.hidden_size))
Position-aware average pooling — groups patches by their (x//k, y//k) bin.
    def _avg_pool_by_positions(
        self,
        hidden_states: torch.Tensor,       # [B, N, D]
        pixel_position_ids: torch.Tensor,  # [B, N, 2]
        output_length: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        N = hidden_states.shape[1]
        k = int((N // output_length) ** 0.5)
        k2 = k * k
        if k2 * output_length != N:
            raise ValueError(f"Cannot pool {N} patches to {output_length}: {k}^2 × {output_length} ≠ {N}")
Clamp padding (-1) positions to 0 — padding is zeroed out already
        pos = pixel_position_ids.clamp(min=0)                 # [B, N, 2]
        max_x = pos[..., 0].max(dim=-1, keepdim=True)[0] + 1  # [B, 1]
Kernel index: which (col_bin, row_bin) does each patch fall into?
        kx = torch.div(pos[..., 0], k, rounding_mode="floor")  # [B, N]
        ky = torch.div(pos[..., 1], k, rounding_mode="floor")  # [B, N]
        kernel_idxs = kx + (max_x // k) * ky                   # [B, N] linear index
One-hot weighted average: each kernel bin accumulates 1/k² of each patch
        weights = F.one_hot(kernel_idxs.long(), output_length).float() / k2  # [B, N, L]
        output = weights.transpose(1, 2) @ hidden_states.float()             # [B, L, D]
        valid_mask = (weights != 0).any(dim=1)                               # [B, L]
        return output.to(hidden_states.dtype), valid_mask

    def forward(
        self,
        hidden_states: torch.Tensor,       # [B, N, D]
        pixel_position_ids: torch.Tensor,  # [B, N, 2]
        padding_mask: torch.Tensor,        # [B, N] True = padding
    ) -> tuple[torch.Tensor, torch.Tensor]:
        B, N, D = hidden_states.shape
Target output length after pooling: N / k² soft tokens per image
        output_length = N // (self.kernel ** 2)
Zero-out padding patches before pooling so they don't bias the averages
        hidden_states = hidden_states.masked_fill(padding_mask.unsqueeze(-1), 0.0)

        if N != output_length:
Spatial pooling: each k×k neighborhood of patches → one soft token
            hidden_states, valid_mask = self._avg_pool_by_positions(
                hidden_states, pixel_position_ids, output_length
            )
        else:
No pooling needed (kernel=1); valid mask is just the non-padding positions
            valid_mask = ~padding_mask
Scale pooled vectors by √D_vision (matches HF implementation)
        hidden_states = hidden_states * self.scale
Optional channel-wise standardization (bias + scale per feature dim)
        if self.standardize:
            hidden_states = (hidden_states - self.std_bias) * self.std_scale
Strip padding soft tokens → return flat [M, D] tensor where M = valid pooled patches valid_mask: [B, L] boolean indicates which soft tokens are real vs padding
        hidden_states = hidden_states[valid_mask]
        return hidden_states, valid_mask
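The token-count arithmetic from the docstring above, written out:

```python
# 336x336 image, patch size 16 → 21x21 = 441 raw patches;
# a 3x3 pooling kernel reduces that to 7x7 = 49 soft tokens.
patches_per_side = 336 // 16
k = 3
soft_side = patches_per_side // k
print(patches_per_side ** 2, soft_side ** 2)  # → 441 49
```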
Full vision encoder pipeline: pixels → PatchEmbedder → TransformerEncoder → VisionPooler → soft tokens
class VisionModel(nn.Module):
    def __init__(self, cfg: VisionConfig):
        super().__init__()
        self.patch_embedder = VisionPatchEmbedder(cfg)
        self.encoder = VisionEncoder(cfg)
        self.pooler = VisionPooler(cfg)

    def forward(
        self,
        pixel_values: torch.Tensor,        # [B, N, patch_dim]
        pixel_position_ids: torch.Tensor,  # [B, N, 2]
    ) -> torch.Tensor:
Mark padding patches: position_id == (-1, -1) signals a padding slot
        padding_mask = (pixel_position_ids == -1).all(dim=-1)  # [B, N]
Build additive attention mask: 0 for real patches, -1e9 for padding (pre-softmax block)
        attn_mask = padding_mask.float() * -1e9
Expand to [B, 1, 1, N] so it broadcasts over all heads and query positions
        attn_mask = attn_mask[:, None, None, :]  # broadcast over heads
Stage 1: project each patch to hidden_size + add 2-D positional embeddings
        h = self.patch_embedder(pixel_values, pixel_position_ids, padding_mask)
Stage 2: run all ViT encoder layers with 2-D RoPE and sandwich-norm attention
        h = self.encoder(h, pixel_position_ids, attention_mask=attn_mask)
Stage 3: k×k spatial average pooling → M soft tokens, stripping padding positions
        soft_tokens, _ = self.pooler(h, pixel_position_ids, padding_mask)
        return soft_tokens  # [M, D_vision] (M = total valid pooled patches across batch)
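A toy sketch of the additive padding mask built here: padding columns get -1e9 before the softmax, so they receive essentially zero attention weight while each row still sums to one.

```python
import torch
import torch.nn.functional as F

padding_mask = torch.tensor([[False, False, True]])          # last patch is padding
attn_mask = (padding_mask.float() * -1e9)[:, None, None, :]  # [B, 1, 1, N]
scores = torch.zeros(1, 1, 3, 3)                             # uniform raw scores
probs = F.softmax(scores + attn_mask, dim=-1)
assert probs[0, 0, 0, 2].item() < 1e-6                       # padding column ~0
assert torch.allclose(probs.sum(-1), torch.ones(1, 1, 3))    # rows still normalised
```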
Projects vision soft-tokens into the language model embedding space. Applies a pre-projection RMSNorm then a linear projection.
class MultimodalEmbedder(nn.Module):
    def __init__(self, vision_dim: int, text_dim: int, eps: float = 1e-6):
        super().__init__()
HF uses with_scale=False (no learnable weight, pure normalization)
        self.norm = RMSNorm(vision_dim, eps=eps, with_scale=False)
        self.proj = nn.Linear(vision_dim, text_dim, bias=False)

    def forward(self, soft_tokens: torch.Tensor) -> torch.Tensor:
        return self.proj(self.norm(soft_tokens))
Fuses the vision and text towers into a single forward pass:
Step 1 — Text embedding
Token ids → ScaledEmbedding (with image placeholders replaced by pad_id
so the embedding table doesn't see the sentinel token).
Step 2 — Vision encoding
Raw patches → VisionModel (ViT encoder + spatial pooler) → $M$ soft tokens
of shape $[M, D_v]$. Then MultimodalEmbedder (RMSNorm + Linear) projects
them to text dimension $D$, producing $[M, D]$.
Step 3 — Token stream merge
Image placeholder positions in the embedding sequence are overwritten by the
projected soft tokens via masked_scatter:
$$e_i = \begin{cases} \text{soft\_tok}_j & \text{if } i \text{ is an image placeholder} \\ \text{text\_emb}_i & \text{otherwise} \end{cases}$$
Step 4 — Per-layer side-channel
The per-layer projection is computed from the merged embeddings (not the original text-only ids), so vision features appear in $P_\ell$ for every layer.
Step 5 — Language model
The merged embedding sequence is passed through TextModel as normal.
Vision information reaches the transformer through two routes:
(a) directly as soft tokens in the input positions, and
(b) via the per-layer side-channel $P_\ell$ injected at every decoder layer.
class Gemma4Model(nn.Module):
    def __init__(self, cfg: Gemma4Config):
        super().__init__()
        self.language_model = TextModel(cfg.text)
        self.vision_model = VisionModel(cfg.vision)
        self.mm_embedder = MultimodalEmbedder(cfg.vision.hidden_size, cfg.text.hidden_size)
        self.image_token_id = cfg.image_token_id

    def forward(
        self,
        input_ids: torch.Tensor,                            # [B, L]
        pixel_values: Optional[torch.Tensor] = None,        # [B, N, patch_dim]
        pixel_position_ids: Optional[torch.Tensor] = None,  # [B, N, 2]
        attention_mask: Optional[torch.Tensor] = None,
        per_layer_inputs: Optional[torch.Tensor] = None,    # [B, L, N, Dp] pre-computed
    ) -> torch.Tensor:
Locate image placeholder positions in the token sequence
        image_mask = (input_ids == self.image_token_id)  # [B, L]
Replace placeholders with pad_id so the embedding table sees valid indices
        text_ids = input_ids.clone()
        text_ids[image_mask] = self.language_model.embed_tokens.padding_idx
        inputs_embeds = self.language_model.embed_tokens(text_ids)  # [B, L, D_text]
Compute the token-embedding half of per-layer inputs BEFORE we overwrite the placeholder positions — image slots should use pad embedding, not vision
        if per_layer_inputs is None:
            pli_embed = self.language_model._get_per_layer_embed(text_ids)
        else:
            pli_embed = None  # caller provided fully-projected per_layer_inputs
── Vision path: encode pixels → soft tokens → inject into embedding sequence ──
        if pixel_values is not None and pixel_position_ids is not None:
ViT encoder + k×k spatial pooler → [M, D_vision] soft tokens
            soft_tokens = self.vision_model(pixel_values, pixel_position_ids)
Linear projection to text hidden dimension → [M, D_text]
            soft_tokens = self.mm_embedder(soft_tokens)
Overwrite image placeholder embeddings with projected vision features
            img_mask_exp = image_mask.unsqueeze(-1).expand_as(inputs_embeds)
            inputs_embeds = inputs_embeds.masked_scatter(
                img_mask_exp, soft_tokens.to(inputs_embeds.dtype)
            )
Recompute the projection half using MERGED embeddings so that vision features appear in the per-layer side-channel P_ℓ passed to every decoder layer
        if pli_embed is not None:
            per_layer_inputs = self.language_model._project_per_layer(inputs_embeds, pli_embed)
Run the full text tower with vision features baked into the embedding sequence
        return self.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            per_layer_inputs=per_layer_inputs,
        )
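The Step 3 merge in miniature: masked_scatter fills the True positions of the mask, in order, with rows from the soft-token tensor, leaving text positions untouched.

```python
import torch

embeds = torch.zeros(1, 4, 2)                         # [B, L, D] text embeddings
image_mask = torch.tensor([[False, True, True, False]])
soft_tokens = torch.tensor([[1.0, 1.0], [2.0, 2.0]])  # two projected soft tokens
merged = embeds.masked_scatter(
    image_mask.unsqueeze(-1).expand_as(embeds), soft_tokens
)
assert torch.equal(merged[0, 1], soft_tokens[0])  # first placeholder filled
assert torch.equal(merged[0, 2], soft_tokens[1])  # second placeholder filled
assert torch.equal(merged[0, 0], torch.zeros(2))  # text positions untouched
```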
Text-only Gemma 4 with final logit soft-capping.
class Gemma4ForCausalLM(nn.Module):
    def __init__(self, cfg: Gemma4Config):
        super().__init__()
        self.model = TextModel(cfg.text)
        self.lm_head = nn.Linear(cfg.text.hidden_size, cfg.text.vocab_size, bias=False)
        self.final_logit_softcapping = cfg.text.final_logit_softcapping

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> dict:
Run the full text decoder stack; returns final hidden states [B, L, D]
        hidden = self.model(input_ids, attention_mask)
Unembedding: project each position to vocabulary logits [B, L, V]
        logits = self.lm_head(hidden)
Logit soft-capping: tanh(z/cap)*cap squashes logits to (-cap, +cap), preventing extreme values that destabilise softmax in long-context generation
        if self.final_logit_softcapping is not None:
            cap = self.final_logit_softcapping
            logits = torch.tanh(logits / cap) * cap

        loss = None
        if labels is not None:
Shift by 1: predict token i+1 from position i (causal LM objective)
            loss = F.cross_entropy(
                logits[:, :-1].reshape(-1, logits.shape[-1]),
                labels[:, 1:].reshape(-1),
                ignore_index=-100,
            )
        return {"loss": loss, "logits": logits}
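A quick sanity check of the soft-cap behaviour: tanh(z/cap)*cap keeps logits inside [-cap, +cap] while acting as a near-identity on small values.

```python
import torch

cap = 30.0
z = torch.tensor([-1000.0, -1.0, 0.0, 1.0, 1000.0])
capped = torch.tanh(z / cap) * cap
assert capped.abs().max().item() <= cap       # extremes squashed to the cap
assert abs(capped[3].item() - 1.0) < 1e-3     # near-identity for small z
```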
Full multimodal Gemma 4 (text + images).
class Gemma4ForConditionalGeneration(nn.Module):
    def __init__(self, cfg: Gemma4Config):
        super().__init__()
        self.model = Gemma4Model(cfg)
        self.lm_head = nn.Linear(cfg.text.hidden_size, cfg.text.vocab_size, bias=False)
        self.final_logit_softcapping = cfg.text.final_logit_softcapping

    def forward(
        self,
        input_ids: torch.Tensor,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> dict:
Gemma4Model handles vision encoding + token stream merge + text tower in one call
        hidden = self.model(input_ids, pixel_values, pixel_position_ids, attention_mask)
Project hidden states to vocabulary logits [B, L, V]
        logits = self.lm_head(hidden)
Soft-cap logits: same as text-only model — squashes extremes to (-cap, +cap)
        if self.final_logit_softcapping is not None:
            cap = self.final_logit_softcapping
            logits = torch.tanh(logits / cap) * cap

        loss = None
        if labels is not None:
Causal LM cross-entropy: predict token at position i+1 from position i
            loss = F.cross_entropy(
                logits[:, :-1].reshape(-1, logits.shape[-1]),
                labels[:, 1:].reshape(-1),
                ignore_index=-100,
            )
        return {"loss": loss, "logits": logits}
if __name__ == "__main__":
    print("Running Gemma 4 smoke test (tiny config, random weights)...\n")
Tiny config to run on CPU quickly
    text_cfg = TextConfig(
        vocab_size=1000, hidden_size=64, num_hidden_layers=2,
        num_attention_heads=4, num_key_value_heads=2, head_dim=16,
        intermediate_size=128, num_experts=4, num_experts_per_tok=2,
        moe_layers=[1], expert_intermediate_size=64,
        sliding_window=8, sliding_window_pattern=2,
        final_logit_softcapping=30.0, rms_norm_eps=1e-6,
    )
    vision_cfg = VisionConfig(
        hidden_size=32, num_hidden_layers=1, num_attention_heads=2,
        head_dim=16, intermediate_size=64, patch_size=4,
        image_size=16, pooling_kernel_size=2,
    )
    cfg = Gemma4Config(text=text_cfg, vision=vision_cfg)
── Text-only ────────────────────────────────────────────────────────
    model = Gemma4ForCausalLM(cfg)
    model.eval()
    B, L = 2, 10
    input_ids = torch.randint(0, 1000, (B, L))
    with torch.no_grad():
        out = model(input_ids)
    print(f"[CausalLM] logits shape: {out['logits'].shape}")  # [2, 10, 1000]
── Multimodal ───────────────────────────────────────────────────────
    mm_model = Gemma4ForConditionalGeneration(cfg)
    mm_model.eval()
4 image patches per image with 2-D positions; patch_dim = 4 × 4 × 3 = 48. With pooling_kernel_size=2: output_length = 4 // (2*2) = 1 soft token per image
    N, patch_dim = 4, 48
    pixel_values = torch.rand(B, N, patch_dim)
    pixel_position_ids = torch.tensor([[[0, 0], [0, 1], [1, 0], [1, 1]]]).expand(B, -1, -1)
Put exactly 1 placeholder per image (= 1 pooled soft token)
    input_ids[:, 2] = 255_999
    with torch.no_grad():
        out_mm = mm_model(input_ids, pixel_values, pixel_position_ids)
    print(f"[Multimodal] logits shape: {out_mm['logits'].shape}")  # [2, 10, 1000]
    print("\n✅ Smoke test passed!")