gemma4_simple.py

A clean, minimal PyTorch re-implementation of the Gemma 4 forward pass. Mirrors the HuggingFace source as closely as possible while stripping away framework boilerplate so the architecture is easy to read and experiment with.

Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma4/modeling_gemma4.py

Modules implemented (bottom-up order)

Shared primitives

Text tower

Vision tower

Multimodal fusion

1
2from __future__ import annotations
3import math
4from dataclasses import dataclass, field
5from typing import Optional
6
7import torch
8import torch.nn as nn
9import torch.nn.functional as F
Config dataclasses (minimal, no validation)
@dataclass
class TextConfig:
    """Hyper-parameters of the Gemma 4 text tower (plain container, no validation).

    Besides the declared fields, TextAttention and TextModel look up an optional
    ``_layer_types`` attribute with ``getattr``. When present it is a per-layer
    list whose ``"full_attention"`` entries mark global-attention layers, and it
    takes precedence over ``sliding_window_pattern``.
    """

    # ---- core transformer dimensions ----
    vocab_size: int = 262_144
    hidden_size: int = 2560
    num_hidden_layers: int = 34
    num_attention_heads: int = 16
    num_key_value_heads: int = 8
    head_dim: int = 256
    intermediate_size: int = 8192  # width of the dense (always-on) MLP

    # ---- mixture of experts ----
    num_experts: int = 0  # 0 disables MoE on every layer
    num_experts_per_tok: int = 2  # top-K experts selected per token
    moe_layers: list[int] = field(default_factory=list)  # layer indices that add an expert bank
    expert_intermediate_size: int = 4096

    # ---- KV sharing ----
    # Layers with index >= kv_share_from are flagged is_kv_shared and read K/V from
    # kv_cache["shared_kv"] when that entry exists (intended to hold the states of
    # layer kv_share_from - 1). NOTE(review): TextAttention only reads this entry;
    # it must be written elsewhere.
    kv_share_from: int | None = None

    # ---- attention ----
    sliding_window: int | None = 1024  # None -> no sliding-window limit
    # Fallback when _layer_types is unset: layers with idx % N == 0 use global attention.
    sliding_window_pattern: int = 6
    rope_theta: float = 10_000.0  # not read by TextRotaryEmbedding (it uses the two bases below)
    rope_local_base_freq: float = 10_000.0  # RoPE base for sliding-window layers
    rope_global_base_freq: float = 1_000_000.0  # RoPE base for global layers
    attn_logit_softcapping: float | None = None
    final_logit_softcapping: float | None = 30.0

    # ---- global-layer overrides ----
    global_head_dim: int | None = None  # None -> global layers use head_dim
    global_partial_rotary_factor: float = 1.0  # fraction of global_head_dim that gets rotated
    num_global_key_value_heads: int | None = None  # None -> num_key_value_heads

    # ---- value path ----
    # attention_k_eq_v: global layers have no v_proj; V is the raw k_proj output
    # (before k_norm and RoPE). v_norm is applied to V on every layer; use_v_norm
    # only decides whether that RMSNorm carries a learnable scale.
    attention_k_eq_v: bool = False
    use_v_norm: bool = False

    # ---- per-layer input gate (Gemma 4 specific) ----
    hidden_size_per_layer_input: int = 0  # 0 disables the per-layer side channel

    # ---- misc ----
    rms_norm_eps: float = 1e-6
    pad_token_id: int = 0
    embed_scale: float = 1.0  # multiplier applied to token embeddings
@dataclass
class VisionConfig:
    """Hyper-parameters of the Gemma 4 vision tower (plain container, no validation)."""

    # ---- encoder dimensions ----
    hidden_size: int = 768
    num_hidden_layers: int = 16
    num_attention_heads: int = 12
    head_dim: int = 64
    intermediate_size: int = 3072

    # ---- patching / positions ----
    patch_size: int = 16
    position_embedding_size: int = 10240  # position lookup-table size for each axis
    pooling_kernel_size: int = 3  # spatial pooling factor
    rope_theta: float = 100.0

    # ---- numerics ----
    rms_norm_eps: float = 1e-6
    standardize: bool = False  # optional standardisation after pooling
    use_clipped_linears: bool = True  # clamp activations in ClippableLinear (defined elsewhere)
@dataclass
class Gemma4Config:
    """Top-level multimodal config bundling the text and vision tower settings."""

    text: TextConfig = field(default_factory=TextConfig)
    vision: VisionConfig = field(default_factory=VisionConfig)
    # "bidirectional_vision" -> vision tokens use a bidirectional attention mask.
    use_bidirectional_attention: str = "none"
    image_token_id: int = 258_880  # sentinel id for image placeholder tokens
Shared primitives

Root Mean Square Layer Normalization

Unlike LayerNorm, RMSNorm omits the mean-centering step and normalises only by the root-mean-square of the activations:

$$\text{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{d}\sum_i x_i^2 + \epsilon}} \cdot w$$

where $w \in \mathbb{R}^d$ is a learned per-channel scale initialised to 1.

The computation is done in float32 for numerical stability regardless of the input dtype, then cast back to the input dtype — matching the HuggingFace implementation exactly. Pass with_scale=False to get the scale-free variant (no learnable weight). It is used for the MoE router input, and for $v$-normalisation inside attention when use_v_norm=False.

class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    Args:
        dim: size of the normalised (last) dimension.
        eps: added to the mean of squares before the reciprocal square root.
        with_scale: if True, the output is multiplied by a learnable per-channel
            weight initialised to ones; if False, ``self.weight`` is ``None``.

    The arithmetic runs in float32; the result is cast back to the input dtype.
    """

    def __init__(self, dim: int, eps: float = 1e-6, with_scale: bool = True):
        super().__init__()
        self.eps = eps
        if with_scale:
            self.weight = nn.Parameter(torch.ones(dim))
        else:
            self.weight = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        # Work in float32 for stability (matters for bf16/fp16 inputs).
        x32 = x.float()
        mean_sq = x32.pow(2).mean(-1, keepdim=True)
        out = x32 * torch.rsqrt(mean_sq + self.eps)
        if self.weight is not None:
            out = out * self.weight.float()
        return out.to(input_dtype)

Rotary helper: rotate the second half into the first

Splits the last dimension into two halves $[x_1, x_2]$ and returns $[-x_2, x_1]$. Combined with the cosine term this implements the complex-number rotation $x \cdot e^{i\theta}$.

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Return ``[-second_half, first_half]`` along the last dim (the RoPE 'i*x' term).

    The split point is ``x.shape[-1] // 2``, so for an odd width the second half
    is one element longer.
    """
    mid = x.size(-1) // 2
    first = x[..., :mid]
    second = x[..., mid:]
    return torch.cat((second.neg(), first), dim=-1)

1-D Rotary Position Embedding (RoPE)

Encodes absolute position into the query/key vectors by rotating pairs of channels by a position-dependent angle $\theta_i \cdot m$ (where $m$ is the token position and $\theta_i = \text{base}^{-2i/d}$):

$$x^\prime = x \cos\theta + \text{rotate\_half}(x) \sin\theta$$

Because rotations compose — $R_m^\top R_n = R_{n-m}$ — the dot product $(R_m q)\cdot(R_n k) = q \cdot R_{n-m}\,k$ depends only on the relative offset $n - m$. This gives the model free relative-position information without any learned parameters.

def apply_rotary_pos_emb(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    unsqueeze_dim: int = 1,
) -> torch.Tensor:
    """Rotate ``x`` by the given RoPE angles: ``x * cos + rotate_half(x) * sin``.

    ``cos`` / ``sin`` get a singleton axis inserted at ``unsqueeze_dim`` so they
    broadcast over the head axis of ``x`` (e.g. 2 for a ``[B, L, H, D]`` layout).
    """
    c = cos.unsqueeze(unsqueeze_dim)
    s = sin.unsqueeze(unsqueeze_dim)
    rotated = rotate_half(x)
    return (x * c) + (rotated * s)

2-D Rotary Position Embedding for Vision Patches

Vision patches have a 2-D grid position $(r, c)$. We encode both axes independently by splitting the head channels into two equal halves and applying 1-D RoPE with the row frequencies to the first half and column frequencies to the second:

$$x^\prime_{\text{row}} = \text{RoPE}(x_{:d/2},\ \theta_r),\quad x^\prime_{\text{col}} = \text{RoPE}(x_{d/2:},\ \theta_c)$$

cos/sin carry both sets of frequencies concatenated along the last dim.

def apply_2d_rope(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    position_ids: torch.Tensor,
    unsqueeze_dim: int = 2,
) -> torch.Tensor:
    """Apply RoPE independently to each spatial axis.

    The number of axes is read from ``position_ids.shape[-1]`` (2 for image
    rows/cols). The last dim of ``x``, ``cos`` and ``sin`` is split into that
    many equal, even-sized chunks. Chunk ``k`` is rotated with the ``k``-th
    slice of ``cos``/``sin``, and the chunks are concatenated back together.
    ``position_ids`` is only used for its axis count.
    """
    n_axes = position_ids.shape[-1]
    span = 2 * (x.shape[-1] // (2 * n_axes))
    sizes = [span] * n_axes
    chunks = zip(x.split(sizes, dim=-1), cos.split(sizes, dim=-1), sin.split(sizes, dim=-1))
    return torch.cat(
        [apply_rotary_pos_emb(xc, cc, sc, unsqueeze_dim) for xc, cc, sc in chunks],
        dim=-1,
    )
Text Tower

Dual-frequency RoPE for Gemma 4 text. Local (sliding-window) layers use base rope_local_base_freq over head_dim channels. Global layers use base rope_global_base_freq over global_head_dim channels, with optional partial rotation (global_partial_rotary_factor). Only rope_angles = int(factor * global_head_dim // 2) frequencies (i.e. 2·rope_angles channels) are rotated. The remaining global_head_dim/2 − rope_angles frequencies get a zero inv_freq (→ cos=1, sin=0, i.e. identity). Matches HF's _compute_proportional_rope_parameters exactly.

class TextRotaryEmbedding(nn.Module):
    """Precompute RoPE cos/sin tables for Gemma 4 text layers.

    Two inverse-frequency tables are stored as non-persistent buffers:

    * ``local_inv_freq``: sliding-window layers. It uses base ``rope_local_base_freq``,
      and every frequency of ``head_dim`` is rotated.
    * ``global_inv_freq``: full-attention layers. It uses base ``rope_global_base_freq``
      over ``global_head_dim`` (falling back to ``head_dim``). Only the first
      ``int(global_partial_rotary_factor * global_dim // 2)`` frequencies are
      non-zero. The rest are 0, so cos=1 / sin=0 (identity) on those channels.
    """

    def __init__(self, cfg: TextConfig):
        super().__init__()

        # -- local table: full rotation over head_dim --
        local_dim = cfg.head_dim
        local_exponents = torch.arange(0, local_dim, 2).float() / local_dim
        self.register_buffer(
            "local_inv_freq",
            1.0 / (cfg.rope_local_base_freq ** local_exponents),
            persistent=False,
        )
        self._local_head_dim = local_dim

        # -- global table: (possibly partial) rotation over global_head_dim --
        global_dim = cfg.head_dim if cfg.global_head_dim is None else cfg.global_head_dim
        self._global_head_dim = global_dim
        fraction = getattr(cfg, "global_partial_rotary_factor", 1.0)
        n_rotated = int(fraction * global_dim // 2)  # number of rotated frequencies
        # Denominator is the full global_dim, not the rotated width (matches HF).
        global_exponents = torch.arange(0, 2 * n_rotated, 2).float() / global_dim
        table = 1.0 / (cfg.rope_global_base_freq ** global_exponents)
        n_identity = global_dim // 2 - n_rotated
        if n_identity > 0:
            # Zero frequency -> angle 0 -> identity rotation on the "nope" channels.
            table = torch.cat([table, torch.zeros(n_identity)], dim=0)
        self.register_buffer("global_inv_freq", table, persistent=False)

    @torch.no_grad()
    def forward(
        self,
        x: torch.Tensor,
        position_ids: torch.Tensor,
        layer_type: str = "global",   # "local" or "global"
    ):
        """Return ``(cos, sin)``, each ``[B, L, D]`` in ``x.dtype``, for ``position_ids`` ``[B, L]``.

        ``x`` is only used for its dtype. ``D`` is twice the length of the
        selected inverse-frequency table.
        """
        inv_freq = getattr(self, f"{layer_type}_inv_freq")
        batch = position_ids.shape[0]
        # [B, D/2, 1] @ [B, 1, L] -> [B, D/2, L]; entry (b, i, l) = inv_freq[i] * pos[b, l]
        freq_col = inv_freq[None, :, None].expand(batch, -1, 1).float()
        pos_row = position_ids[:, None, :].float()
        angles = torch.matmul(freq_col, pos_row).transpose(1, 2)  # [B, L, D/2]
        # Layout [θ_0·m, …, θ_{D/2-1}·m, θ_0·m, …]: both halves share angles, as rotate_half expects.
        emb = torch.cat((angles, angles), dim=-1)
        return emb.cos().to(x.dtype), emb.sin().to(x.dtype)

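Token embedding multiplied by a constant scale factor. Gemma uses sqrt(hidden_size); here the value comes from the embed_scale argument, which TextModel takes from TextConfig.embed_scale. The scale is stored as a non-persistent float32 buffer and cast to the weight dtype at forward time, matching HF's Gemma4TextScaledWordEmbedding exactly (important for bf16).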

class ScaledEmbedding(nn.Embedding):
    """``nn.Embedding`` whose output is multiplied by a fixed ``embed_scale``.

    The scale lives in a non-persistent buffer (not saved in the state dict)
    and is cast to the weight dtype on every call, so bf16 weights are
    multiplied by a bf16-rounded scale.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int,
                 embed_scale: float = 1.0):
        super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)
        scale = torch.tensor(embed_scale)
        self.register_buffer("embed_scale", scale, persistent=False)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        looked_up = super().forward(input_ids)
        factor = self.embed_scale.to(self.weight.dtype)
        return looked_up * factor

Gated Feed-Forward Network (GeGLU: SwiGLU-style gating with a GELU gate)

A gated variant of the FFN where one linear branch acts as a gate:

$$\text{MLP}(x) = W_{\text{down}}\bigl(\text{GELU}(W_{\text{gate}}\, x) \odot W_{\text{up}}\, x\bigr)$$

The GELU gate (tanh approximation) selectively suppresses or amplifies each dimension of the intermediate representation before the final down-projection. This gives the network a multiplicative, content-dependent nonlinearity, at the cost of one extra $D \times D_i$ gate projection compared with a plain two-matrix FFN.

Used as the dense shared FFN that runs on every layer. In MoE layers it runs in parallel with the sparse expert bank and both outputs are summed.

class TextMLP(nn.Module):
    """Gated feed-forward block: ``down(gelu_tanh(gate(x)) * up(x))``, all bias-free.

    Args:
        hidden_size: model width ``D`` (input and output size).
        intermediate_size: inner width ``D_i`` of the gate/up projections.
    """

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        # Creation order (gate, up, down) fixes parameter order and init RNG usage.
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.gelu(self.gate_proj(x), approximate="tanh")
        up = self.up_proj(x)
        return self.down_proj(gate * up)

Grouped-Query Attention (GQA)

Standard scaled dot-product attention with several Gemma-4 twists:

$$\text{Attn}(Q,K,V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

QK-normalisation — RMSNorm is applied to each query and key head before RoPE. Because the norms are learnable scales, the effective temperature is absorbed into the weights, so the scaling constant is $1.0$ rather than $1/\sqrt{d_k}$.

V-normalisation — An RMSNorm is also applied to values, intended to keep activations well-scaled in mixed precision. It is scale-free (no extra parameters) by default and carries a learnable per-channel scale when use_v_norm=True.

Grouped-Query Attention — there are fewer KV heads than Q heads ($H_{kv} < H_q$). Each KV head is shared across $H_q / H_{kv}$ query heads via expand + reshape before the dot product.

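Local vs global layers — unless the config carries an explicit _layer_types list, layers whose index is divisible by sliding_window_pattern use full (global) attention. The rest use a causal sliding-window mask of size sliding_window, which limits each token to attending to the most recent $W$ positions. The mask itself is supplied by the caller through attention_mask.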

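KV sharing — layers with index ≥ kv_share_from are flagged is_kv_shared. When such a layer receives a kv_cache dict containing a "shared_kv" entry, it reuses that (key, value) pair instead of computing its own. This saves memory, reportedly without a significant quality drop. TextAttention only reads this entry; storing the states of the anchor layer (kv_share_from − 1) into it has to happen elsewhere.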

attention_k_eq_v — in global layers of large models, $V$ is derived from the same linear projection as $K$ (i.e. v_proj = None and $V = k\_proj(x)$ before k_norm). This halves the KV projection cost.

class TextAttention(nn.Module):
    """Gemma 4 grouped-query self-attention with QK-norm, V-norm and RoPE.

    Args:
        cfg: text config.
        layer_idx: index of this layer; decides local (sliding) vs global attention.
        is_kv_shared: if True and ``kv_cache`` holds a ``"shared_kv"`` entry, that
            (k, v) pair is used instead of this layer's own projections.

    Global layers may use ``global_head_dim`` / ``num_global_key_value_heads``.
    With ``attention_k_eq_v`` they have no ``v_proj``, and V comes from the raw
    ``k_proj`` output. The causal / sliding-window mask is supplied by the
    caller; ``self.sliding_window`` is only recorded here.
    NOTE(review): this module only reads ``kv_cache["shared_kv"]`` and never
    writes it. Whoever populates it must match this layer's head count and head_dim.
    """

    def __init__(self, cfg: TextConfig, layer_idx: int, is_kv_shared: bool = False):
        super().__init__()
        self.layer_idx = layer_idx
        self.num_heads = cfg.num_attention_heads
        self.is_kv_shared = is_kv_shared

        # An explicit _layer_types list (authoritative when present) beats the modulo pattern.
        layer_types = getattr(cfg, "_layer_types", None)
        if layer_types is None:
            is_global = layer_idx % cfg.sliding_window_pattern == 0
        else:
            is_global = layer_types[layer_idx] == "full_attention"
        self.sliding_window = None if is_global else cfg.sliding_window

        # Global layers may override head width and KV head count.
        wide_heads = is_global and cfg.global_head_dim is not None
        self.head_dim = cfg.global_head_dim if wide_heads else cfg.head_dim
        own_kv_heads = is_global and cfg.num_global_key_value_heads is not None
        self.num_kv_heads = (
            cfg.num_global_key_value_heads if own_kv_heads else cfg.num_key_value_heads
        )

        hidden = cfg.hidden_size
        q_width = self.num_heads * self.head_dim
        kv_width = self.num_kv_heads * self.head_dim

        # Projections exist even on KV-shared layers: sharing only takes effect
        # when forward() receives a populated kv_cache.
        self.q_proj = nn.Linear(hidden, q_width, bias=False)
        self.k_proj = nn.Linear(hidden, kv_width, bias=False)
        self.use_alternative_attention = cfg.attention_k_eq_v and is_global
        if self.use_alternative_attention:
            self.v_proj = None  # V reuses the raw k_proj output
        else:
            self.v_proj = nn.Linear(hidden, kv_width, bias=False)
        self.o_proj = nn.Linear(q_width, hidden, bias=False)

        # Per-head norms over head_dim (q/k are normalised before RoPE).
        self.q_norm = RMSNorm(self.head_dim, eps=cfg.rms_norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=cfg.rms_norm_eps)
        # Always applied to V; learnable scale only when use_v_norm=True.
        self.v_norm = RMSNorm(self.head_dim, eps=cfg.rms_norm_eps, with_scale=cfg.use_v_norm)

        self.attn_logit_softcapping = cfg.attn_logit_softcapping
        # No 1/sqrt(head_dim): QK-norm already controls the logit scale (matches HF).
        self.scaling = 1.0

    def _keys_and_values(self, hidden_states, cos, sin, kv_cache):
        """Return ``(k, v)`` shaped ``[B, Hkv, L, Dh]``, or the cached shared pair."""
        if self.is_kv_shared and kv_cache is not None and "shared_kv" in kv_cache:
            return kv_cache["shared_kv"]
        batch, seq_len, _ = hidden_states.shape
        kv_shape = (batch, seq_len, self.num_kv_heads, self.head_dim)
        key_raw = self.k_proj(hidden_states).view(kv_shape)
        if self.use_alternative_attention:
            value_raw = key_raw  # k_eq_v: same raw projection, before k_norm / RoPE
        else:
            value_raw = self.v_proj(hidden_states).view(kv_shape)
        key = self.k_norm(key_raw)
        value = self.v_norm(value_raw)
        key = apply_rotary_pos_emb(key, cos, sin, unsqueeze_dim=2)  # V gets no RoPE
        return key.transpose(1, 2), value.transpose(1, 2)

    def forward(
        self,
        hidden_states: torch.Tensor,          # [B, L, D]
        cos: torch.Tensor,
        sin: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[dict] = None,      # simple dict cache for demo
    ) -> torch.Tensor:
        batch, seq_len, _ = hidden_states.shape
        n_heads, n_kv, head_dim = self.num_heads, self.num_kv_heads, self.head_dim

        # Queries: project -> per-head norm -> RoPE -> [B, H, L, Dh]
        query = self.q_proj(hidden_states).view(batch, seq_len, n_heads, head_dim)
        query = apply_rotary_pos_emb(self.q_norm(query), cos, sin, unsqueeze_dim=2)
        query = query.transpose(1, 2)

        key, value = self._keys_and_values(hidden_states, cos, sin, kv_cache)

        # GQA: repeat each KV head for the group of query heads that shares it.
        if n_kv != n_heads:
            group = n_heads // n_kv
            key = key[:, :, None].expand(-1, -1, group, -1, -1).reshape(batch, n_heads, -1, head_dim)
            value = value[:, :, None].expand(-1, -1, group, -1, -1).reshape(batch, n_heads, -1, head_dim)

        scores = torch.matmul(query, key.transpose(-2, -1)) * self.scaling
        cap = self.attn_logit_softcapping
        if cap is not None:
            # Soft-capping keeps logits inside (-cap, cap).
            scores = torch.tanh(scores / cap) * cap
        if attention_mask is not None:
            scores = scores + attention_mask  # additive mask supplied by the caller
        # float32 softmax for stability, then back to the activation dtype (as HF eager).
        weights = F.softmax(scores, dim=-1, dtype=torch.float32).to(hidden_states.dtype)

        context = torch.matmul(weights, value)  # [B, H, L, Dh]
        context = context.transpose(1, 2).reshape(batch, seq_len, n_heads * head_dim)
        return self.o_proj(context)

MoE Token Router

Selects the $K$ best experts for each token and computes their routing weights.

For a token representation $x \in \mathbb{R}^D$:

  1. Normalise — scale-free RMSNorm stabilises the router input.
  2. Scale — element-wise multiply by a learned per-dimension scale $s \in \mathbb{R}^D$, then multiply by $D^{-1/2}$.
  3. Logits — linear projection to $E$ expert scores.
  4. Softmax → top-$K$ selection → renormalise the top-$K$ weights to sum to 1.
  5. Per-expert scale — multiply each weight by a learned scalar $\alpha_e$ (one per expert, init 1), giving the model a way to globally up- or down-weight specific experts during training.
class TextRouter(nn.Module):
    """Pick the top-K experts for each token and compute their routing weights.

    ``forward(x)`` takes flattened tokens ``[T, D]`` and returns
    ``(probs [T, E], weights [T, K], expert_ids [T, K])``. ``probs`` is the full
    softmax over experts. ``weights`` are the top-K probabilities renormalised
    to sum to 1 and then multiplied by a learned per-expert scalar.
    """

    def __init__(self, cfg: TextConfig):
        super().__init__()
        self.top_k = cfg.num_experts_per_tok
        # Scale-free RMSNorm (no learnable weight), as in HF.
        self.norm = RMSNorm(cfg.hidden_size, eps=cfg.rms_norm_eps, with_scale=False)
        # Learned per-dimension input scale, plus a constant D^-0.5 factor.
        self.scale = nn.Parameter(torch.ones(cfg.hidden_size))
        self._scale_factor = cfg.hidden_size ** -0.5
        self.proj = nn.Linear(cfg.hidden_size, cfg.num_experts, bias=False)
        self.per_expert_scale = nn.Parameter(torch.ones(cfg.num_experts))

    def forward(self, x: torch.Tensor):
        # RMS-normalise (this rescales but does not centre), then apply the learned
        # per-dim scale and D^-0.5.
        routed_in = self.norm(x) * self.scale * self._scale_factor
        probs = F.softmax(self.proj(routed_in), dim=-1)          # [T, E]
        weights, expert_ids = probs.topk(self.top_k, dim=-1)      # [T, K]
        weights = weights / weights.sum(dim=-1, keepdim=True)     # renormalise over the K picks
        weights = weights * self.per_expert_scale[expert_ids]     # per-expert re-weighting
        return probs, weights, expert_ids

Sparse Expert Bank (MoE)

Holds $E$ independent GELU-gated (GeGLU) FFNs. For each token only the $K$ experts selected by the router are evaluated; the rest are skipped entirely, so expert compute stays proportional to $K/E$ of evaluating every expert.

Weight layout — gate_up_proj[e] $\in \mathbb{R}^{2D_i \times D}$ and down_proj[e] $\in \mathbb{R}^{D \times D_i}$ use the [out, in] convention of F.linear (matching the HF checkpoint). They are stacked along a leading expert dimension, so selecting a single expert is a single tensor slice.

Dispatch loop — we build an [E, K, T] one-hot mask, iterate only over active experts (those that received ≥ 1 token), gather the relevant token vectors, run the SwiGLU FFN, and scatter-add results back via index_add_.

Numerics note — contributions are accumulated in a deterministic order. Experts are processed one at a time in ascending id order. Within a single index_add_ call, each token index appears at most once, because a token's top-$K$ experts are distinct. This matches HuggingFace's eager implementation bit-exactly in bfloat16. The alternative grouped_mm kernel accumulates in a different order, producing a ~0.06 diff per layer that grows to ~25 max_diff over 30 layers (see TEST_REPORT_26B_A4B.md).

class TextExperts(nn.Module):
    """Bank of E GELU-gated FFN experts with sparse top-K dispatch.

    Weights use the HF ``[E, out, in]`` layout so checkpoints load without a
    permute:

    * ``gate_up_proj``: ``[E, 2*Di, D]`` (gate rows first, then up rows)
    * ``down_proj``: ``[E, D, Di]``
    """

    def __init__(self, cfg: TextConfig):
        super().__init__()
        n_experts = cfg.num_experts
        hidden = cfg.hidden_size
        inner = cfg.expert_intermediate_size
        self.gate_up_proj = nn.Parameter(torch.empty(n_experts, 2 * inner, hidden))
        self.down_proj = nn.Parameter(torch.empty(n_experts, hidden, inner))
        for weight in (self.gate_up_proj, self.down_proj):
            nn.init.normal_(weight, std=0.02)
        self.num_experts = n_experts
        self.act = nn.GELU(approximate="tanh")

    def forward(
        self,
        x: torch.Tensor,           # [T, D]
        top_k_index: torch.Tensor,  # [T, K]
        top_k_weights: torch.Tensor, # [T, K]
    ) -> torch.Tensor:
        """Return ``[T, D]``: the sum over each token's K experts of weight * expert(x)."""
        num_tokens, hidden = x.shape
        result = x.new_zeros((num_tokens, hidden))

        # assignment[e, k, t] == 1  <=>  token t routed to expert e in slot k
        assignment = F.one_hot(top_k_index, self.num_experts).permute(2, 1, 0)
        routed_counts = assignment.sum(dim=(1, 2))

        # Visit only experts that received at least one token, in ascending id order.
        for expert in routed_counts.nonzero(as_tuple=True)[0].tolist():
            slot, token = torch.where(assignment[expert])
            gate, up = F.linear(x[token], self.gate_up_proj[expert]).chunk(2, dim=-1)
            expert_out = F.linear(self.act(gate) * up, self.down_proj[expert])
            expert_out = expert_out * top_k_weights[token, slot, None]
            # Each token appears at most once per call, so accumulation order is fixed.
            result.index_add_(0, token, expert_out.to(result.dtype))

        return result

Gemma 4 Transformer Decoder Layer

Each layer runs the following sequence of operations:

1. Self-Attention block $$h = x + \text{PostAttnNorm}\!\left(\text{Attn}\!\left(\text{PreNorm}(x)\right)\right)$$

2. Dense MLP (always active, even in MoE layers) $$h_{\text{mlp}} = \text{MLP}\!\left(\text{PreFFNNorm}(h)\right)$$

3. Sparse MoE (only in MoE-designated layers; 26B model)

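The dense MLP and the sparse expert bank run in parallel on the same pre-FFN residual $h$. Each branch gets its own post-norm, the two results are added, and the sum goes through the shared post-FFN norm: $$h = h + \text{PostFFNNorm}\!\left(\text{PostFFNNorm}_1(h_{\text{mlp}}) + \text{PostFFNNorm}_2\!\left(\text{MoE}\!\left(\text{PreFFNNorm}_2(h)\right)\right)\right)$$ The router itself is fed the un-normalised $h$ (it applies its own RMSNorm).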

4. Per-layer input gate (E4B and larger models)

A lightweight gating network injects a per-layer side-channel $p_\ell$ (derived from the full token embeddings) at every layer: $$h = h + \text{Norm}\!\left(W_{\text{proj}}\!\left(\text{GELU}(W_{\text{gate}}\,h) \odot p_\ell\right)\right)$$

This allows every layer to directly reference the original token context, acting as a persistent residual information highway.

5. Layer scalar

The layer's final hidden state is multiplied by a learned scalar $\lambda$ (initialised to 1). This includes the residual stream, not just this layer's update. It presumably gives the optimiser a soft mechanism to rescale a layer's output early in training.

class TextDecoderLayer(nn.Module):
    """One Gemma 4 decoder block.

    Stages: pre/post-normed self-attention; dense MLP (plus a parallel expert
    bank on MoE layers); optional per-layer input gate; learned output scalar.
    Submodule names mirror the HF checkpoint keys, and their creation order is
    kept fixed (it determines parameter order and init RNG consumption).
    """

    def __init__(self, cfg: TextConfig, layer_idx: int):
        super().__init__()
        shares_kv = cfg.kv_share_from is not None and layer_idx >= cfg.kv_share_from
        self.self_attn = TextAttention(cfg, layer_idx, is_kv_shared=shares_kv)

        dim, eps = cfg.hidden_size, cfg.rms_norm_eps
        self.input_layernorm = RMSNorm(dim, eps=eps)
        self.post_attention_layernorm = RMSNorm(dim, eps=eps)
        self.pre_feedforward_layernorm = RMSNorm(dim, eps=eps)
        self.post_feedforward_layernorm = RMSNorm(dim, eps=eps)

        self.mlp = TextMLP(dim, cfg.intermediate_size)

        self.enable_moe = layer_idx in cfg.moe_layers and cfg.num_experts > 0
        if self.enable_moe:
            self.router = TextRouter(cfg)
            self.experts = TextExperts(cfg)
            self.pre_feedforward_layernorm_2 = RMSNorm(dim, eps=eps)
            self.post_feedforward_layernorm_1 = RMSNorm(dim, eps=eps)
            self.post_feedforward_layernorm_2 = RMSNorm(dim, eps=eps)

        self.use_per_layer_gate = cfg.hidden_size_per_layer_input > 0
        if self.use_per_layer_gate:
            side_dim = cfg.hidden_size_per_layer_input
            self.per_layer_input_gate = nn.Linear(dim, side_dim, bias=False)
            self.per_layer_projection = nn.Linear(side_dim, dim, bias=False)
            self.post_per_layer_input_norm = RMSNorm(dim, eps=eps)
            self.act_fn = nn.GELU(approximate="tanh")

        # Learned multiplier on the whole layer output (initialised to 1).
        self.layer_scalar = nn.Parameter(torch.ones(1))

    def _attention_block(self, hidden_states, cos, sin, attention_mask, kv_cache):
        """x + PostAttnNorm(Attn(PreNorm(x)))."""
        attn_out = self.self_attn(self.input_layernorm(hidden_states), cos, sin, attention_mask, kv_cache)
        return hidden_states + self.post_attention_layernorm(attn_out)

    def _moe_branch(self, residual):
        """Sparse expert output for the pre-FFN residual, reshaped back and post-normed."""
        flat = residual.reshape(-1, residual.shape[-1])  # [B*L, D]: router/experts see tokens
        _, weights, expert_ids = self.router(flat)        # router has its own input norm
        moe_out = self.experts(self.pre_feedforward_layernorm_2(flat), expert_ids, weights)
        return self.post_feedforward_layernorm_2(moe_out.reshape(residual.shape))

    def _feedforward_block(self, hidden_states):
        """Dense MLP (+ parallel MoE on MoE layers), shared post-norm, residual add."""
        residual = hidden_states
        ffn_out = self.mlp(self.pre_feedforward_layernorm(hidden_states))
        if self.enable_moe:
            # Dense and sparse paths each get their own post-norm before being summed.
            ffn_out = self.post_feedforward_layernorm_1(ffn_out) + self._moe_branch(residual)
        return residual + self.post_feedforward_layernorm(ffn_out)

    def _per_layer_gate_block(self, hidden_states, per_layer_input):
        """x + Norm(W_proj(GELU(W_gate x) * per_layer_input))."""
        gated = self.act_fn(self.per_layer_input_gate(hidden_states)) * per_layer_input
        side = self.post_per_layer_input_norm(self.per_layer_projection(gated))
        return hidden_states + side

    def forward(
        self,
        hidden_states: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        per_layer_input: Optional[torch.Tensor] = None,
        kv_cache: Optional[dict] = None,
    ) -> torch.Tensor:
        hidden_states = self._attention_block(hidden_states, cos, sin, attention_mask, kv_cache)
        hidden_states = self._feedforward_block(hidden_states)
        if self.use_per_layer_gate and per_layer_input is not None:
            hidden_states = self._per_layer_gate_block(hidden_states, per_layer_input)
        return hidden_states * self.layer_scalar

Full Text Tower

Stacks $N$ TextDecoderLayers with a shared TextRotaryEmbedding that pre-computes $(cos, sin)$ tensors for the full input sequence length.

Per-layer input pipeline (E4B model, hidden_size_per_layer_input > 0):

Before the transformer runs, a compact side-channel tensor $P \in \mathbb{R}^{B \times L \times N \times D_p}$ is computed by combining a linear projection of the input embeddings $x$ with a second token-embedding lookup:

$$P = \frac{1}{\sqrt{2}}\left(\text{Norm}\!\left(\frac{W_{\text{proj}}\,x}{\sqrt{D}}\right) + E_{\text{tok}}\right)$$

where $E_{\text{tok}}$ is a second token embedding table (scaled by $\sqrt{D_p}$). The $\ell$-th slice $P_{:,:,\ell,:}$ is passed as per_layer_input to layer $\ell$.

When inputs_embeds is provided instead of input_ids (multimodal path), the per-layer projection is computed from the merged embeddings (vision + text), so image features influence the side-channel at every layer.

class TextModel(nn.Module):
    """Gemma 4 text tower: token embedding -> N TextDecoderLayers -> final RMSNorm.

    A shared TextRotaryEmbedding supplies (cos, sin) for each layer. When
    ``cfg.hidden_size_per_layer_input`` (Dp) is > 0, a per-layer side-channel
    P of shape [B, L, N, Dp] is built and slice ``P[:, :, i, :]`` is passed to
    layer ``i`` as ``per_layer_input``.
    """

    def __init__(self, cfg: TextConfig):
        super().__init__()
        self.embed_tokens = ScaledEmbedding(
            cfg.vocab_size, cfg.hidden_size, cfg.pad_token_id, cfg.embed_scale
        )
        self.rotary_emb = TextRotaryEmbedding(cfg)
        self.layers = nn.ModuleList([TextDecoderLayer(cfg, i) for i in range(cfg.num_hidden_layers)])
        self.norm   = RMSNorm(cfg.hidden_size, eps=cfg.rms_norm_eps)
        self.sliding_window_pattern = cfg.sliding_window_pattern
        # Optional explicit per-layer attention types; when present it takes
        # precedence over sliding_window_pattern (see _rope_type).
        self._layer_types = getattr(cfg, "_layer_types", None)

        # Per-layer input gate embedding and projection chain (optional; E4B and larger).
        # Matches HF Gemma4TextModel:
        #   embed_tokens_per_layer     : Embedding(vocab, N*Dp), scaled by sqrt(Dp)
        #   per_layer_model_projection : Linear(D, N*Dp), output scaled by D**-0.5
        #   per_layer_projection_norm  : RMSNorm(Dp)
        #   per_layer_input_scale      : 2**-0.5 (constant)
        Dp = cfg.hidden_size_per_layer_input
        N  = cfg.num_hidden_layers
        if Dp > 0:
            self.embed_tokens_per_layer = ScaledEmbedding(
                cfg.vocab_size, N * Dp, cfg.pad_token_id, embed_scale=Dp ** 0.5
            )
            self.per_layer_model_projection = nn.Linear(cfg.hidden_size, N * Dp, bias=False)
            self.per_layer_model_projection_scale = cfg.hidden_size ** -0.5
            self.per_layer_projection_norm = RMSNorm(Dp, eps=cfg.rms_norm_eps)
            self.per_layer_input_scale = 2.0 ** -0.5
        else:
            self.embed_tokens_per_layer = None
        self._num_layers = N
        self._per_layer_dim = Dp

    def _get_per_layer_embed(self, input_ids: torch.Tensor) -> Optional[torch.Tensor]:
        """Token-embedding part only — [B, L, N, Dp]. No projection yet.

        Returns None when the model has no per-layer inputs (Dp == 0).
        """
        if self.embed_tokens_per_layer is None:
            return None
        return self.embed_tokens_per_layer(input_ids).reshape(
            *input_ids.shape, self._num_layers, self._per_layer_dim
        )

    def _project_per_layer(self, inputs_embeds: torch.Tensor,
                           embed_part: torch.Tensor) -> torch.Tensor:
        """Add the inputs_embeds projection to embed_part and scale. -> [B, L, N, Dp].

        Computes (RMSNorm(W x * D**-0.5) + embed_part) * 2**-0.5, where the
        projection is reshaped to [..., N, Dp] before the norm.
        """
        proj = self.per_layer_model_projection(inputs_embeds) * self.per_layer_model_projection_scale
        proj = proj.reshape(*inputs_embeds.shape[:-1], self._num_layers, self._per_layer_dim)
        proj = self.per_layer_projection_norm(proj)
        return (proj + embed_part) * self.per_layer_input_scale

    def _compute_per_layer_inputs(self, input_ids: torch.Tensor, inputs_embeds: torch.Tensor) -> Optional[torch.Tensor]:
        """Compute [B, L, N, Dp] per-layer inputs (text-only path); None if Dp == 0."""
        embed = self._get_per_layer_embed(input_ids)
        if embed is None:
            return None
        return self._project_per_layer(inputs_embeds, embed)

    def _rope_type(self, layer_idx: int) -> str:
        """Return the RoPE frequency set ("local" or "global") used by layer ``layer_idx``.

        Uses ``_layer_types`` when available ("sliding_attention" -> "local");
        otherwise every layer whose index is a multiple of sliding_window_pattern
        (including layer 0) is "global".
        """
        # NOTE(review): HF Gemma 3 marks layer i as full attention when
        # (i + 1) % pattern == 0, whereas this fallback uses i % pattern == 0.
        # It must agree with TextDecoderLayer's own sliding/global choice — confirm.
        if self._layer_types is not None:
            return "local" if self._layer_types[layer_idx] == "sliding_attention" else "global"
        return "local" if layer_idx % self.sliding_window_pattern != 0 else "global"

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[dict] = None,
        per_layer_inputs: Optional[torch.Tensor] = None,  # [B, L, num_layers, Dp] if pre-computed
        inputs_embeds: Optional[torch.Tensor] = None,     # pre-merged embeddings (multimodal)
    ) -> torch.Tensor:
        """Run the decoder stack and return final-normed hidden states [B, L, D].

        Exactly one of ``input_ids`` / ``inputs_embeds`` is normally supplied;
        if both are, ``inputs_embeds`` is used as the hidden input and
        ``input_ids`` only feeds the per-layer embedding.

        Raises:
            ValueError: if neither ``input_ids`` nor ``inputs_embeds`` is given.
        """
        # Accept either token ids (text-only) or pre-built embeddings (multimodal path)
        if inputs_embeds is None:
            if input_ids is None:
                raise ValueError("TextModel.forward requires either input_ids or inputs_embeds")
            x = self.embed_tokens(input_ids)                # [B, L, D]
        else:
            x = inputs_embeds                               # [B, L, D], already embedded

        _, L, _ = x.shape

        # Simple 0..L-1 position indices — causal order; no offset is applied,
        # so this assumes a prefill-style call.
        position_ids = torch.arange(L, device=x.device).unsqueeze(0)  # [1, L]

        # Compute the full per-layer side-channel P ∈ [B, L, N, Dp] once up front
        # (text-only path; the multimodal path pre-computes this with merged embeddings)
        if per_layer_inputs is None and input_ids is not None and self.embed_tokens_per_layer is not None:
            per_layer_inputs = self._compute_per_layer_inputs(input_ids, x)

        # (cos, sin) depend only on the layer's RoPE type, not on the layer
        # itself, so compute each type once and reuse it for later layers.
        rope_cache: dict = {}
        for i, layer in enumerate(self.layers):
            # Alternate between local (sliding-window) and global RoPE frequencies
            layer_type = self._rope_type(i)
            if layer_type not in rope_cache:
                rope_cache[layer_type] = self.rotary_emb(x, position_ids, layer_type=layer_type)
            cos, sin = rope_cache[layer_type]

            # Slice the i-th layer's per-layer input vector — [B, L, Dp]
            pli = per_layer_inputs[:, :, i, :] if per_layer_inputs is not None else None
            x = layer(x, cos, sin, attention_mask=attention_mask,
                      per_layer_input=pli, kv_cache=kv_cache)

        # Final RMSNorm before the output projection / logit head
        return self.norm(x)
Vision Tower

2-D RoPE for vision patches. pixel_position_ids: [B, N, 2] (x, y per patch; -1 = padding) Returns cos/sin with shape [B, N, head_dim] where the head_dim is split evenly between x and y rotations.

class VisionRotaryEmbedding(nn.Module):
    """2-D rotary position embedding for vision patches.

    The head dimension is split in two halves: the first half rotates with the
    patch x coordinate, the second half with the y coordinate. Within each half
    the angle vector is the standard RoPE layout ``cat(freqs, freqs)``.

    forward(x, pixel_position_ids) takes pixel_position_ids of shape [B, N, 2]
    and returns (cos, sin), each [B, N, head_dim], cast to ``x.dtype``.
    """

    def __init__(self, head_dim: int, base: float = 10_000.0):
        super().__init__()
        half = head_dim // 2  # channels devoted to each spatial axis
        exponents = torch.arange(0, half, 2).float() / half
        self.register_buffer("inv_freq", 1.0 / (base ** exponents), persistent=False)

    @torch.no_grad()
    def forward(self, x: torch.Tensor, pixel_position_ids: torch.Tensor):
        batch = pixel_position_ids.shape[0]
        # [F] -> [B, F, 1]: a column of frequencies for a batched outer product
        freq_col = self.inv_freq[None, :, None].expand(batch, -1, 1).float()

        def _axis_angles(axis: int) -> torch.Tensor:
            # [B, 1, N] row of positions along one spatial axis
            pos_row = pixel_position_ids[:, None, :, axis].float()
            angles = (freq_col @ pos_row).transpose(1, 2)  # [B, N, F]
            return torch.cat((angles, angles), dim=-1)     # [B, N, 2F]

        # x-axis angles first, then y-axis angles -> [B, N, head_dim]
        angles = torch.cat([_axis_angles(0), _axis_angles(1)], dim=-1)
        return angles.cos().to(x.dtype), angles.sin().to(x.dtype)

Project pixel patches to hidden_size and add learned 2-D positional embeddings.

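Matches HF Gemma4VisionPatchEmbedder. Pixel values are rescaled from [0, 1] to [-1, 1] and linearly projected to hidden_size. They are then summed with learned x-axis and y-axis position embeddings, and padding patches are zeroed.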

class VisionPatchEmbedder(nn.Module):
    """Turn flattened pixel patches into hidden vectors with 2-D learned positions.

    Each patch (patch_size² × 3 values) is rescaled from [0, 1] to [-1, 1],
    projected to hidden_size by a bias-free linear layer, and summed with one
    learned embedding per spatial axis. Padding patches come out as zeros.
    """

    def __init__(self, cfg: VisionConfig):
        super().__init__()
        in_features = 3 * cfg.patch_size ** 2
        self.input_proj = nn.Linear(in_features, cfg.hidden_size, bias=False)
        # One table per spatial axis: [2 (x, y), position_embedding_size, hidden_size]
        self.position_embedding_table = nn.Parameter(
            torch.ones(2, cfg.position_embedding_size, cfg.hidden_size)
        )
        self.position_embedding_size = cfg.position_embedding_size

    def _position_embeddings(
        self,
        pixel_position_ids: torch.Tensor,  # [B, N, 2]
        padding_mask: torch.Tensor,        # [B, N] True = padding
    ) -> torch.Tensor:
        """Sum the x- and y-axis embeddings for every patch; zeros for padding."""
        table = self.position_embedding_table
        # Padding ids (-1) are clamped to 0 so one_hot is valid; masked below.
        clamped = pixel_position_ids.clamp(min=0)
        # [B, N, 2, S] -> [B, 2, N, S], matched to the table's dtype/device
        selectors = F.one_hot(clamped, num_classes=self.position_embedding_size)
        selectors = selectors.transpose(1, 2).to(table)
        per_axis = torch.matmul(selectors, table)  # [B, 2, N, D]
        summed = per_axis.sum(dim=1)               # [B, N, D]
        return summed.masked_fill(padding_mask.unsqueeze(-1), 0.0)

    def forward(
        self,
        pixel_values: torch.Tensor,         # [B, N, patch_dim]
        pixel_position_ids: torch.Tensor,   # [B, N, 2]  (x, y; -1 = padding)
        padding_mask: torch.Tensor,         # [B, N] True = padding
    ) -> torch.Tensor:
        # [0, 1] -> [-1, 1]
        scaled = (pixel_values - 0.5) * 2.0
        projected = self.input_proj(scaled.to(self.input_proj.weight.dtype))
        projected = projected + self._position_embeddings(pixel_position_ids, padding_mask)
        # Padding slots must not leak into the encoder
        return projected.masked_fill(padding_mask.unsqueeze(-1), 0.0)

A bias-free nn.Linear wrapped with optional input/output clamping (matches HF Gemma4ClippableLinear). When use_clipped_linears=True, the four clip bounds are registered as buffers, initialised to ±inf, and loaded from the checkpoint. When use_clipped_linears=False (the default), no bound buffers are registered and no clamping is applied.

class ClippableLinear(nn.Module):
    """Bias-free linear layer with optional clamping of its input and output.

    Mirrors HF Gemma4ClippableLinear. With ``use_clipped_linears=True`` four
    scalar buffers (input_min, input_max, output_min, output_max) are
    registered, initialised to ±inf and meant to be overwritten from a
    checkpoint. With ``use_clipped_linears=False`` no buffers exist and the
    layer is a plain ``nn.Linear(bias=False)``.
    """

    def __init__(self, in_features: int, out_features: int, use_clipped_linears: bool = False):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=False)
        if use_clipped_linears:
            inf = float("inf")
            bounds = (
                ("input_min", -inf),
                ("input_max", inf),
                ("output_min", -inf),
                ("output_max", inf),
            )
            for buffer_name, value in bounds:
                self.register_buffer(buffer_name, torch.tensor(value))
        self.use_clipped_linears = use_clipped_linears

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.use_clipped_linears:
            return self.linear(x)
        clipped_in = torch.clamp(x, self.input_min, self.input_max)
        return torch.clamp(self.linear(clipped_in), self.output_min, self.output_max)

Gated-GELU (GeGLU) FFN used in the vision encoder: down_proj(gelu_tanh(gate_proj(x)) * up_proj(x)) (matches HF Gemma4VisionMLP).

class VisionMLP(nn.Module):
    """Gated-GELU (GeGLU) feed-forward block of the vision encoder (HF Gemma4VisionMLP).

    Computes ``down_proj(gelu_tanh(gate_proj(x)) * up_proj(x))``; all three
    projections are ClippableLinear layers sharing cfg.use_clipped_linears.
    """

    def __init__(self, cfg: VisionConfig):
        super().__init__()
        d_model, d_ff = cfg.hidden_size, cfg.intermediate_size
        use_clip = cfg.use_clipped_linears
        self.gate_proj = ClippableLinear(d_model, d_ff, use_clip)
        self.up_proj   = ClippableLinear(d_model, d_ff, use_clip)
        self.down_proj = ClippableLinear(d_ff, d_model, use_clip)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.gelu(self.gate_proj(x), approximate="tanh")
        return self.down_proj(gate * self.up_proj(x))

ViT-style full attention with: • QKV norms (same as text, for training stability) • 2-D RoPE applied to Q and K

class VisionAttention(nn.Module):
    """Full (non-causal) multi-head self-attention for the vision encoder.

    Differences from textbook ViT attention, following HF Gemma4VisionAttention:
      * per-head RMSNorm on Q and K (learnable scale) and on V (no scale);
      * 2-D RoPE applied to Q and K from the (x, y) patch positions;
      * logit scaling of 1.0 instead of 1/sqrt(head_dim).
    ``attention_mask``, if given, is added to the logits (0 = attend, large
    negative = block).
    """

    def __init__(self, cfg: VisionConfig):
        super().__init__()
        n_heads, d_head = cfg.num_attention_heads, cfg.head_dim
        d_model = cfg.hidden_size
        inner = n_heads * d_head
        self.num_heads = n_heads
        self.head_dim  = d_head
        # HF Gemma4VisionAttention uses scaling=1.0, not 1/sqrt(head_dim)
        self.scaling   = 1.0

        clip = cfg.use_clipped_linears
        self.q_proj = ClippableLinear(d_model, inner, clip)
        self.k_proj = ClippableLinear(d_model, inner, clip)
        self.v_proj = ClippableLinear(d_model, inner, clip)
        self.o_proj = ClippableLinear(inner, d_model, clip)

        self.q_norm = RMSNorm(d_head, eps=cfg.rms_norm_eps, with_scale=True)
        self.k_norm = RMSNorm(d_head, eps=cfg.rms_norm_eps, with_scale=True)
        # v_norm has no learnable scale — bare normalisation only
        self.v_norm = RMSNorm(d_head, eps=cfg.rms_norm_eps, with_scale=False)

        self.rotary_emb = VisionRotaryEmbedding(d_head, cfg.rope_theta)

    def forward(
        self,
        hidden_states: torch.Tensor,          # [B, N, D]
        pixel_position_ids: torch.Tensor,     # [B, N, 2]
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        batch, seq, _ = hidden_states.shape
        heads_shape = (batch, seq, self.num_heads, self.head_dim)

        # Project, split into heads [B, N, H, Dh], then per-head RMSNorm
        query = self.q_norm(self.q_proj(hidden_states).view(heads_shape))
        key = self.k_norm(self.k_proj(hidden_states).view(heads_shape))
        value = self.v_norm(self.v_proj(hidden_states).view(heads_shape))

        # 2-D RoPE on queries and keys only
        cos, sin = self.rotary_emb(hidden_states, pixel_position_ids)
        query = apply_2d_rope(query, cos, sin, pixel_position_ids, unsqueeze_dim=2)
        key = apply_2d_rope(key, cos, sin, pixel_position_ids, unsqueeze_dim=2)

        # [B, N, H, Dh] -> [B, H, N, Dh]
        query, key, value = (t.transpose(1, 2) for t in (query, key, value))

        scores = torch.matmul(query, key.transpose(-2, -1)) * self.scaling
        if attention_mask is not None:
            scores = scores + attention_mask
        # Softmax in float32 for numerical stability, then back to the compute dtype
        probs = F.softmax(scores.float(), dim=-1).to(query.dtype)

        context = torch.matmul(probs, value).transpose(1, 2).reshape(batch, seq, -1)
        return self.o_proj(context)

ViT encoder block with sandwich layernorms (pre+post norm around both attention and MLP).

class VisionEncoderLayer(nn.Module):
    """One ViT block with "sandwich" RMSNorms around both sub-layers.

    x = x + post_attn_norm(attn(input_norm(x)))
    x = x + post_ffn_norm(mlp(pre_ffn_norm(x)))
    """

    def __init__(self, cfg: VisionConfig):
        super().__init__()
        self.self_attn = VisionAttention(cfg)
        self.mlp       = VisionMLP(cfg)

        d_model, eps = cfg.hidden_size, cfg.rms_norm_eps
        self.input_layernorm            = RMSNorm(d_model, eps=eps)
        self.post_attention_layernorm   = RMSNorm(d_model, eps=eps)
        self.pre_feedforward_layernorm  = RMSNorm(d_model, eps=eps)
        self.post_feedforward_layernorm = RMSNorm(d_model, eps=eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        pixel_position_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Attention sub-layer: pre-norm, attend, post-norm, residual add
        attn_out = self.self_attn(
            self.input_layernorm(hidden_states), pixel_position_ids, attention_mask
        )
        hidden_states = hidden_states + self.post_attention_layernorm(attn_out)

        # Feed-forward sub-layer: same sandwich pattern
        mlp_out = self.mlp(self.pre_feedforward_layernorm(hidden_states))
        return hidden_states + self.post_feedforward_layernorm(mlp_out)
class VisionEncoder(nn.Module):
    """Stack of ``cfg.num_hidden_layers`` VisionEncoderLayers applied in order.

    Returns hidden states [B, N, D] for every patch position, padding included.
    """

    def __init__(self, cfg: VisionConfig):
        super().__init__()
        self.layers = nn.ModuleList(VisionEncoderLayer(cfg) for _ in range(cfg.num_hidden_layers))

    def forward(
        self,
        hidden_states: torch.Tensor,
        pixel_position_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(hidden_states, pixel_position_ids, attention_mask)
        return hidden_states

Vision Spatial Pooler

Reduces the $N$ patch tokens to $N / k^2$ soft tokens by averaging each non-overlapping $k \times k$ block of patches (default $k = 3$, so 9 patches → 1 soft token):

$$\text{soft\_tok}_{(r',c')} = \frac{1}{k^2} \sum_{dr=0}^{k-1}\sum_{dc=0}^{k-1} h_{(kr'+dr,\, kc'+dc)}$$

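The output is scaled by $\sqrt{D_v}$ before returning, matching the HF implementation. Padding patches (indicated by padding_mask) are zeroed before pooling. Soft tokens whose bin received no patch are stripped from the output, so downstream layers see only valid content. Each bin is always divided by $k^2$, so a bin containing fewer than $k^2$ real patches is averaged as if the missing patches were zeros.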

This aggressive pooling is key to keeping the token count manageable: a $336 \times 336$ image at patch size 16 gives $21 \times 21 = 441$ raw patches, which pool down to $7 \times 7 = 49$ soft tokens.

class VisionPooler(nn.Module):
    """Pool ViT patch features into soft tokens.

    Patches are grouped into non-overlapping k×k bins by their (x, y) position
    (k = cfg.pooling_kernel_size). Each bin's features are summed and divided
    by k², then scaled by sqrt(hidden_size) and optionally standardised.
    Finally, bins that received no patch are dropped.
    """

    def __init__(self, cfg: VisionConfig):
        super().__init__()
        self.kernel = cfg.pooling_kernel_size
        self.scale  = cfg.hidden_size ** 0.5
        self.standardize = cfg.standardize
        if cfg.standardize:
            self.std_bias  = nn.Parameter(torch.zeros(cfg.hidden_size))
            self.std_scale = nn.Parameter(torch.ones(cfg.hidden_size))

    def _avg_pool_by_positions(
        self,
        hidden_states: torch.Tensor,       # [B, N, D]
        pixel_position_ids: torch.Tensor,  # [B, N, 2]
        output_length: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Position-aware average pooling — groups patches by their (x//k, y//k) bin.

        Returns:
            pooled: [B, output_length, D] in hidden_states' dtype.
            valid_mask: [B, output_length], True for bins that received a patch.

        Raises:
            ValueError: if N is not k² × output_length for a positive integer k.
        """
        N = hidden_states.shape[1]
        if output_length <= 0:
            raise ValueError(f"Cannot pool {N} patches to {output_length} soft tokens")
        # Exact integer square root (float sqrt + int() can misround).
        k = math.isqrt(N // output_length)
        k2 = k * k
        if k == 0 or k2 * output_length != N:
            raise ValueError(f"Cannot pool {N} patches to {output_length}: {k}^2 × {output_length} != {N}")

        # Clamp padding (-1) positions to 0 — padding features are zeroed already,
        # so they add nothing to bin 0's sum.
        pos = pixel_position_ids.clamp(min=0)                  # [B, N, 2]
        max_x = pos[..., 0].max(dim=-1, keepdim=True)[0] + 1   # [B, 1] grid width

        # Kernel index: which (col_bin, row_bin) does each patch fall into?
        kx = torch.div(pos[..., 0], k, rounding_mode="floor")  # [B, N]
        ky = torch.div(pos[..., 1], k, rounding_mode="floor")  # [B, N]
        kernel_idxs = kx + (max_x // k) * ky                   # [B, N] linear bin index

        # One-hot weighted average: each bin accumulates 1/k² of each of its patches
        weights = F.one_hot(kernel_idxs.long(), output_length).float() / k2  # [B, N, L]
        output = weights.transpose(1, 2) @ hidden_states.float()             # [B, L, D]
        valid_mask = (weights != 0).any(dim=1)                               # [B, L]
        return output.to(hidden_states.dtype), valid_mask

    def forward(
        self,
        hidden_states: torch.Tensor,          # [B, N, D]
        pixel_position_ids: torch.Tensor,     # [B, N, 2]
        padding_mask: torch.Tensor,           # [B, N] True = padding
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Pool, scale and strip padding.

        Returns:
            soft_tokens: flat [M, D] tensor of valid pooled tokens across the batch.
            valid_mask: [B, L] boolean, True for real (non-padding) soft tokens.
        """
        N = hidden_states.shape[1]

        # Target output length after pooling: N / k² soft tokens per image
        output_length = N // (self.kernel ** 2)

        # Zero-out padding patches before pooling so they don't bias the averages
        hidden_states = hidden_states.masked_fill(padding_mask.unsqueeze(-1), 0.0)

        if N != output_length:
            # Spatial pooling: each k×k neighbourhood of patches -> one soft token
            hidden_states, valid_mask = self._avg_pool_by_positions(
                hidden_states, pixel_position_ids, output_length
            )
        else:
            # No pooling needed (kernel=1); valid mask is just the non-padding positions
            valid_mask = ~padding_mask

        # Scale pooled vectors by √D_vision (matches HF implementation)
        hidden_states = hidden_states * self.scale

        # Optional channel-wise standardisation (bias + scale per feature dim)
        if self.standardize:
            hidden_states = (hidden_states - self.std_bias) * self.std_scale

        # Strip padding soft tokens -> flat [M, D]
        hidden_states = hidden_states[valid_mask]
        return hidden_states, valid_mask

Full vision encoder pipeline: pixels → PatchEmbedder → TransformerEncoder → VisionPooler → soft tokens

class VisionModel(nn.Module):
    """Vision tower: pixel patches -> patch embeddings -> ViT encoder -> pooled soft tokens.

    A patch whose position id is (-1, -1) is treated as padding: it is masked
    out of attention (additive -1e9 on its key column) and dropped by the pooler.
    Returns a flat [M, D_vision] tensor of valid soft tokens across the batch.
    """

    def __init__(self, cfg: VisionConfig):
        super().__init__()
        self.patch_embedder = VisionPatchEmbedder(cfg)
        self.encoder        = VisionEncoder(cfg)
        self.pooler         = VisionPooler(cfg)

    def forward(
        self,
        pixel_values: torch.Tensor,          # [B, N, patch_dim]
        pixel_position_ids: torch.Tensor,    # [B, N, 2]
    ) -> torch.Tensor:
        is_padding = (pixel_position_ids == -1).all(dim=-1)  # [B, N]
        # Additive key mask broadcast over heads and query positions: [B, 1, 1, N]
        key_mask = (is_padding.float() * -1e9).unsqueeze(1).unsqueeze(2)

        embedded = self.patch_embedder(pixel_values, pixel_position_ids, is_padding)
        encoded = self.encoder(embedded, pixel_position_ids, attention_mask=key_mask)
        pooled, _valid = self.pooler(encoded, pixel_position_ids, is_padding)
        return pooled
Multimodal Fusion

Projects vision soft-tokens into the language model embedding space. Applies a pre-projection RMSNorm then a linear projection.

class MultimodalEmbedder(nn.Module):
    """Map vision soft tokens (last dim = vision_dim) into the text embedding space.

    A scale-free RMSNorm (no learnable weight, as in HF) followed by a
    bias-free linear projection vision_dim -> text_dim.
    """

    def __init__(self, vision_dim: int, text_dim: int, eps: float = 1e-6):
        super().__init__()
        self.norm = RMSNorm(vision_dim, eps=eps, with_scale=False)
        self.proj = nn.Linear(vision_dim, text_dim, bias=False)

    def forward(self, soft_tokens: torch.Tensor) -> torch.Tensor:
        normed = self.norm(soft_tokens)
        return self.proj(normed)

Multimodal Backbone

Fuses the vision and text towers into a single forward pass:

Step 1 — Text embedding Token ids → ScaledEmbedding (with image placeholders replaced by pad_id so the embedding table doesn't see the sentinel token).

Step 2 — Vision encoding Raw patches → VisionModel (ViT encoder + spatial pooler) → $M$ soft tokens of shape $[M, D_v]$. Then MultimodalEmbedder (scale-free RMSNorm + Linear) projects them to text dimension $D$, producing $[M, D]$.

Step 3 — Token stream merge Image placeholder positions in the embedding sequence are overwritten by the projected soft tokens via masked_scatter:

$$e_i = \begin{cases} \text{soft\_tok}_j & \text{if } i \text{ is an image placeholder} \\ \text{text\_emb}_i & \text{otherwise} \end{cases}$$

Step 4 — Per-layer side-channel The per-layer projection is computed from the merged embeddings (not the original text-only ids), so vision features appear in $P_\ell$ for every layer.

Step 5 — Language model The merged embedding sequence is passed through TextModel as normal. Vision information reaches the transformer through two routes: (a) directly as soft tokens in the input positions, and (b) via the per-layer side-channel $P_\ell$ injected at every decoder layer.

class Gemma4Model(nn.Module):
    """Multimodal backbone: merges vision soft tokens into the text stream and runs the text tower.

    Image placeholder tokens (``cfg.image_token_id``) are embedded as the pad
    token, then overwritten in order by the projected vision soft tokens.
    When the text tower has per-layer inputs, their projection half is computed
    from the merged embeddings, so vision features also reach every layer's
    side-channel.
    """

    def __init__(self, cfg: Gemma4Config):
        super().__init__()
        self.language_model   = TextModel(cfg.text)
        self.vision_model     = VisionModel(cfg.vision)
        self.mm_embedder      = MultimodalEmbedder(cfg.vision.hidden_size, cfg.text.hidden_size)
        self.image_token_id = cfg.image_token_id

    def forward(
        self,
        input_ids: torch.Tensor,                           # [B, L]
        pixel_values: Optional[torch.Tensor] = None,       # [B, N, patch_dim]
        pixel_position_ids: Optional[torch.Tensor] = None, # [B, N, 2]
        attention_mask: Optional[torch.Tensor] = None,
        per_layer_inputs: Optional[torch.Tensor] = None,   # [B, L, N, Dp] pre-computed
    ) -> torch.Tensor:
        """Return final hidden states [B, L, D_text] from the text tower.

        The vision path runs only when both ``pixel_values`` and
        ``pixel_position_ids`` are given.

        Raises:
            ValueError: if the number of pooled soft tokens differs from the
                number of image placeholder tokens in ``input_ids``.
        """
        # Locate image placeholder positions in the token sequence
        image_mask = (input_ids == self.image_token_id)  # [B, L]

        # Replace placeholders with pad_id so the embedding table sees valid indices
        text_ids = input_ids.clone()
        text_ids[image_mask] = self.language_model.embed_tokens.padding_idx
        inputs_embeds = self.language_model.embed_tokens(text_ids)  # [B, L, D_text]

        # Compute the token-embedding half of per-layer inputs BEFORE overwriting
        # the placeholder positions — image slots use the pad embedding here.
        if per_layer_inputs is None:
            pli_embed = self.language_model._get_per_layer_embed(text_ids)
        else:
            pli_embed = None  # caller provided fully-projected per_layer_inputs

        # ── Vision path: encode pixels → soft tokens → inject into embedding sequence ──
        if pixel_values is not None and pixel_position_ids is not None:
            # ViT encoder + k×k spatial pooler → [M, D_vision] soft tokens
            soft_tokens = self.vision_model(pixel_values, pixel_position_ids)
            # Linear projection to text hidden dimension → [M, D_text]
            soft_tokens = self.mm_embedder(soft_tokens)

            # masked_scatter consumes the source in order and silently ignores
            # surplus elements, so a count mismatch would mis-align images.
            n_placeholders = int(image_mask.sum())
            if soft_tokens.shape[0] != n_placeholders:
                raise ValueError(
                    "Image features and image placeholder tokens do not match: "
                    f"{soft_tokens.shape[0]} soft tokens vs {n_placeholders} placeholders"
                )

            # Overwrite image placeholder embeddings with projected vision features
            img_mask_exp = image_mask.unsqueeze(-1).expand_as(inputs_embeds)
            inputs_embeds = inputs_embeds.masked_scatter(
                img_mask_exp, soft_tokens.to(inputs_embeds.dtype)
            )

        # Projection half uses the MERGED embeddings so vision features appear in
        # the per-layer side-channel P_ℓ passed to every decoder layer
        if pli_embed is not None:
            per_layer_inputs = self.language_model._project_per_layer(inputs_embeds, pli_embed)

        # Run the full text tower with vision features baked into the embedding sequence
        return self.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            per_layer_inputs=per_layer_inputs,
        )

Text-only Gemma 4 with final logit soft-capping.

class Gemma4ForCausalLM(nn.Module):
    """Text-only Gemma 4: TextModel + bias-free LM head + optional logit soft-capping.

    forward returns ``{"loss": ..., "logits": ...}``. The loss is next-token
    cross-entropy (labels shifted left by one, -100 ignored), or None when no
    labels are given.
    """

    def __init__(self, cfg: Gemma4Config):
        super().__init__()
        self.model   = TextModel(cfg.text)
        self.lm_head = nn.Linear(cfg.text.hidden_size, cfg.text.vocab_size, bias=False)
        self.final_logit_softcapping = cfg.text.final_logit_softcapping

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> dict:
        # [B, L, D] hidden states -> [B, L, V] vocabulary logits
        logits = self.lm_head(self.model(input_ids, attention_mask))

        # Soft-cap: cap * tanh(z / cap) keeps logits inside (-cap, +cap)
        cap = self.final_logit_softcapping
        if cap is not None:
            logits = cap * torch.tanh(logits / cap)

        loss = None
        if labels is not None:
            # Position i predicts token i + 1
            vocab = logits.shape[-1]
            shifted_logits = logits[:, :-1].reshape(-1, vocab)
            shifted_labels = labels[:, 1:].reshape(-1)
            loss = F.cross_entropy(shifted_logits, shifted_labels, ignore_index=-100)
        return {"loss": loss, "logits": logits}

Full multimodal Gemma 4 (text + images).

class Gemma4ForConditionalGeneration(nn.Module):
    """Full multimodal Gemma 4: Gemma4Model + bias-free LM head + optional logit soft-capping.

    forward returns ``{"loss": ..., "logits": ...}``. The loss is next-token
    cross-entropy (labels shifted left by one, -100 ignored), or None when no
    labels are given.
    """

    def __init__(self, cfg: Gemma4Config):
        super().__init__()
        self.model   = Gemma4Model(cfg)
        self.lm_head = nn.Linear(cfg.text.hidden_size, cfg.text.vocab_size, bias=False)
        self.final_logit_softcapping = cfg.text.final_logit_softcapping

    def forward(
        self,
        input_ids: torch.Tensor,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> dict:
        # Vision encoding + token-stream merge + text tower, then [B, L, V] logits
        hidden = self.model(input_ids, pixel_values, pixel_position_ids, attention_mask)
        logits = self.lm_head(hidden)

        # Soft-cap: cap * tanh(z / cap) keeps logits inside (-cap, +cap)
        cap = self.final_logit_softcapping
        if cap is not None:
            logits = cap * torch.tanh(logits / cap)

        loss = None
        if labels is not None:
            # Position i predicts token i + 1
            vocab = logits.shape[-1]
            loss = F.cross_entropy(
                logits[:, :-1].reshape(-1, vocab),
                labels[:, 1:].reshape(-1),
                ignore_index=-100,
            )
        return {"loss": loss, "logits": logits}
Smoke test
1391
if __name__ == "__main__":
    print("Running Gemma 4 smoke test (tiny config, random weights)...\n")

    # Tiny config to run on CPU quickly
    text_cfg = TextConfig(
        vocab_size=1000, hidden_size=64, num_hidden_layers=2,
        num_attention_heads=4, num_key_value_heads=2, head_dim=16,
        intermediate_size=128, num_experts=4, num_experts_per_tok=2,
        moe_layers=[1], expert_intermediate_size=64,
        sliding_window=8, sliding_window_pattern=2,
        final_logit_softcapping=30.0, rms_norm_eps=1e-6,
    )
    vision_cfg = VisionConfig(
        hidden_size=32, num_hidden_layers=1, num_attention_heads=2,
        head_dim=16, intermediate_size=64, patch_size=4,
        image_size=16, pooling_kernel_size=2,
    )
    cfg = Gemma4Config(text=text_cfg, vision=vision_cfg)

    # ── Text-only ────────────────────────────────────────────────────────
    model = Gemma4ForCausalLM(cfg)
    model.eval()
    B, L = 2, 10
    input_ids = torch.randint(0, text_cfg.vocab_size, (B, L))
    with torch.no_grad():
        out = model(input_ids)
    print(f"[CausalLM]  logits shape: {out['logits'].shape}")  # [2, 10, 1000]
    assert out["logits"].shape == (B, L, text_cfg.vocab_size)

    # ── Multimodal ───────────────────────────────────────────────────────
    mm_model = Gemma4ForConditionalGeneration(cfg)
    mm_model.eval()

    # 4 image patches per image on a 2x2 grid; patch_dim = patch_size**2 * 3 = 4*4*3 = 48.
    # With pooling_kernel_size=2: output_length = 4 // (2*2) = 1 soft token per image.
    N = 4
    patch_dim = vision_cfg.patch_size ** 2 * 3
    pixel_values = torch.rand(B, N, patch_dim)
    pixel_position_ids = torch.tensor([[[0, 0], [0, 1], [1, 0], [1, 1]]]).expand(B, -1, -1)

    # Put exactly 1 placeholder per image (= 1 pooled soft token), using the
    # configured image token id rather than a hard-coded value.
    # NOTE: assumes cfg.image_token_id does not also occur among the random ids.
    input_ids[:, 2] = cfg.image_token_id
    with torch.no_grad():
        out_mm = mm_model(input_ids, pixel_values, pixel_position_ids)
    print(f"[Multimodal] logits shape: {out_mm['logits'].shape}")  # [2, 10, 1000]
    assert out_mm["logits"].shape == (B, L, text_cfg.vocab_size)
    print("\n✅ Smoke test passed!")